Skip to content

Commit

Permalink
more functionality for dict serialization (#3516)
Browse files Browse the repository at this point in the history
* more functionality for dict serialization

* adds test for empty dict
  • Loading branch information
antonydellavecchia authored Mar 19, 2024
1 parent 608d2b6 commit ac3ae07
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
52 changes: 47 additions & 5 deletions src/Serialization/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ end
# Saving and loading dicts
@register_serialization_type Dict uses_params

function save_type_params(s::SerializerState, obj::Dict{S, T}) where {S <: Union{Symbol, String, Int}, T}
function save_type_params(s::SerializerState, obj::Dict{S, Any}) where S <: Union{Symbol, String, Int}
save_data_dict(s) do
save_object(s, encode_type(Dict), :name)
save_data_dict(s, :params) do
Expand All @@ -280,7 +280,42 @@ function save_type_params(s::SerializerState, obj::Dict{S, T}) where {S <: Union
end
end

function save_type_params(s::SerializerState, obj::Dict{S, T}) where {S <: Union{Symbol, String, Int}, T}
save_data_dict(s) do
save_object(s, encode_type(Dict), :name)
save_data_dict(s, :params) do
save_object(s, encode_type(S), :key_type)

if serialize_with_params(T)
if isempty(obj)
save_object(s, encode_type(T), :value_type)
else
v = first(values(obj))
save_object(s, encode_type(T), :value_type)
save_type_params(s, v, :value_params)
end
else
save_object(s, encode_type(T), :value_type)
end
end
end
end

function load_type_params(s::DeserializerState, ::Type{<:Dict})
if haskey(s, :value_type)
key_type = load_node(_ -> decode_type(s), s, :key_type)
value_type = load_node(_ -> decode_type(s), s, :value_type)
d = Dict{Symbol, Any}(:key_type => key_type, :value_type => value_type)

if serialize_with_params(value_type)
d[:value_params] = load_node(s, :value_params) do _
load_params_node(s)
end
end

return d
end

params_dict = Dict{Symbol, Any}()
for (k, _) in s.obj
load_node(s, k) do _
Expand Down Expand Up @@ -309,20 +344,27 @@ end

function load_object(s::DeserializerState, ::Type{<:Dict}, params::Dict{Symbol, Any})
key_type = params[:key_type]

dict = Dict{key_type, Any}()
value_type = haskey(params, :value_type) ? params[:value_type] : Any
dict = Dict{key_type, value_type}()

for (k, _) in s.obj
if k == :key_type
continue
end

if key_type == Int
key = parse(Int, string(k))
else
key = key_type(k)
end

if params[k] isa Type
if value_type != Any
if serialize_with_params(value_type)
dict[key] = load_object(s, value_type, params[:value_params], k)
else
dict[key] = load_object(s, value_type, k)
end
elseif params[k] isa Type
dict[key] = load_object(s, params[k], k)
else
dict[key] = load_object(s, params[k][type_key], params[k][:params], k)
Expand Down
15 changes: 15 additions & 0 deletions test/Serialization/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,21 @@
end
end

@testset "(de)serialization Dict{Symbol, T}" begin
Qx, x = QQ[:x]
for (T, values) in ((Int, [1, 2]), (PolyRingElem, [x^2, x - 1]))
original = Dict{Symbol, T}(:a => values[1], :b => values[2])
test_save_load_roundtrip(path, original) do loaded
@test original == loaded
end
end

original = Dict{Symbol, Int}()
test_save_load_roundtrip(path, original) do loaded
@test original == loaded
end
end

@testset "Testing (de)serialization of Set" begin
original = Set([Set([1, 2])])
test_save_load_roundtrip(path, original) do loaded
Expand Down

0 comments on commit ac3ae07

Please sign in to comment.