Skip to content

Commit

Permalink
Merge pull request #45 from ASML-Labs/union-type-unpacking-convert-to…
Browse files Browse the repository at this point in the history
…-matlab

Support for union value unpacking.
  • Loading branch information
jorisbelierasml authored Aug 13, 2024
2 parents dcd7052 + faa2f33 commit c39b2e5
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/matfrostjuliacall/converttomatlab.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,20 @@ class AbstractConverter : public Converter {
}
};

/**
* UnionConverter will construct a converter for values from union types.
*/
class UnionConverter : public Converter {
public:

matlab::data::Array convert(jl_value_t* jlval, matlab::engine::MATLABEngine* matlabPtr) override {
jl_value_t* jlvalc = jl_call1(jl_get_function(jl_base_module, "identity"), jlval); // Unpacks the union value.
std::unique_ptr<Converter> conv = converter((jl_datatype_t*) jl_typeof(jlvalc));
return conv->convert(jlvalc, matlabPtr);

}
};


bool unbox_bool(jl_value_t* jlval){
return (bool) jl_unbox_bool(jlval);
Expand All @@ -352,6 +366,8 @@ std::unique_ptr<Converter> converter(jl_datatype_t* jltype){
jl_value_t* jlcomplex = (jl_value_t*) jl_get_function(jl_base_module, "Complex");
if (jl_is_abstracttype(jltype)){
return std::unique_ptr<Converter>(new AbstractConverter());
} else if(jl_is_uniontype(jltype)) {
return std::unique_ptr<Converter>(new UnionConverter());
} else if(jl_is_primitivetype(jltype)){
if (jltype == jl_float32_type){
return std::unique_ptr<Converter>(new PrimitiveConverter<float>(jl_unbox_float32));
Expand Down Expand Up @@ -456,7 +472,7 @@ std::unique_ptr<Converter> converter(jl_datatype_t* jltype){
else if(jlarrayof == jl_string_type){
return std::unique_ptr<Converter>(new ArrayStringConverter());
}
else if (jl_is_array_type(jlarrayof) || jl_is_tuple_type(jlarrayof)) {
else if (jl_is_array_type(jlarrayof) || jl_is_tuple_type(jlarrayof) || jl_is_uniontype(jlarrayof)) {
return std::unique_ptr<Converter>(new ArrayConverter(jltype));
}
else if (jl_is_structtype(jlarrayof) || jl_is_namedtuple_type(jlarrayof)){
Expand All @@ -466,7 +482,8 @@ std::unique_ptr<Converter> converter(jl_datatype_t* jltype){
} else if (jl_is_structtype(jltype) || jl_is_namedtuple_type(jltype)){
return std::unique_ptr<Converter>(new StructConverter(jltype));
}
throw std::invalid_argument("Wrong input MATFRost test - Cannot find matching type.");
throw std::invalid_argument("Wrong input MATFRost test - Cannot find matching type for: " +
std::string(jl_string_ptr(jl_call1(jl_get_function(jl_base_module, "string"), (jl_value_t*) jltype))));
}


Expand Down
15 changes: 15 additions & 0 deletions test/MATFrostTest/src/MATFrostTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ function nested_structures_test1(v::Nest3_L1{T}) where {T}
)
end

function interleave_with_number_and_string(vs::Vector{T}, vnum::Float64, vstring::String) where {T}
vo = Vector{Union{T, Float64, String}}(undef, length(vs)*3)

vo[1:3:end] .= vs
vo[2:3:end] .= vnum
vo[3:3:end] .= vstring

vo
end


for (suf, prim) in (
Expand Down Expand Up @@ -311,6 +320,12 @@ for (suf, prim) in (


eval(:($(Symbol(:nested_structures_test1_vector_, suf))(v::Nest3_L1{Vector{$(prim)}}) = nested_structures_test1(v)))

eval(:($(Symbol(:interleave_with_number_and_string_, suf))(vs::Vector{$(prim)}, vnum::Float64, vstring::String) = interleave_with_number_and_string(vs, vnum, vstring)))

eval(:($(Symbol(:ifelse_, suf))(b::Bool, v::$(prim), vstrings::Vector{String}) = ifelse(b, v, vstrings)))


end


Expand Down
75 changes: 75 additions & 0 deletions test/matfrost_union_type_unpacking_test.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
classdef matfrost_union_type_unpacking_test < matfrost_abstract_test


properties (TestParameter)
prim_type = {"bool";
"string";
"simple_population_type";
"named_tuple_simple_population_type";
"i8"; "ui8"; "i16"; "ui16"; "i32"; "ui32"; "i64"; "ui64"; ...
"f32"; "f64"; ...
"ci8"; "cui8"; "ci16"; "cui16"; "ci32"; "cui32"; "ci64"; "cui64"; ...
"cf32"; "cf64"};

v = {...
false;
"test";
...
struct("name", "Test", "population", int64(200));
struct("name", "Test", "population", int64(200));
...
int8(8); uint8(14);
int16(478); uint16(4532);
int32(323442); uint32(53342);
int64(323434542); uint64(535345342);
...
single(34.125); 2342.0625;
...
complex(int8(8), int8(12));
complex(uint8(14), uint8(7));
complex(int16(478), int16(124));
complex(uint16(4532), uint16(544));
complex(int32(323442), int32(74571));
complex(uint32(53342), uint32(56123));...
complex(int64(323434542), int64(84968213));
complex(uint64(535345342), uint64(8492313));...
...
complex(single(34.125), single(4234.5));
complex(2342.0625, 12444.0625)};



end



methods(Test, ParameterCombination="sequential")

function ifelse_scalar_union_value(tc, prim_type, v)

tc.verifyEqual( ...
tc.mjl.MATFrostTest.("ifelse_" + prim_type)(true, v, ["If"; "Else"]), ...
v);

tc.verifyEqual( ...
tc.mjl.MATFrostTest.("ifelse_" + prim_type)(false, v, ["If"; "Else"]), ...
["If"; "Else"]);
end


function interleave_with_number_and_string(tc, prim_type, v)
vs = repmat(v, 5, 1);

vnum = 321.0;

vstring = "interleaved";

vo = repmat({v; vnum; vstring}, 5, 1);

tc.verifyEqual(tc.mjl.MATFrostTest.("interleave_with_number_and_string_" + prim_type)(vs, vnum, vstring), vo);
end


end

end

0 comments on commit c39b2e5

Please sign in to comment.