From faa2f3399099b5c3d184bd890ebb5563a4c077ba Mon Sep 17 00:00:00 2001 From: Joris Belier Date: Mon, 12 Aug 2024 19:22:55 +0200 Subject: [PATCH] Support for union value unpacking. --- src/matfrostjuliacall/converttomatlab.hpp | 21 ++++++- test/MATFrostTest/src/MATFrostTest.jl | 15 +++++ test/matfrost_union_type_unpacking_test.m | 75 +++++++++++++++++++++++ 3 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 test/matfrost_union_type_unpacking_test.m diff --git a/src/matfrostjuliacall/converttomatlab.hpp b/src/matfrostjuliacall/converttomatlab.hpp index 4c00305..19e2d32 100644 --- a/src/matfrostjuliacall/converttomatlab.hpp +++ b/src/matfrostjuliacall/converttomatlab.hpp @@ -343,6 +343,20 @@ class AbstractConverter : public Converter { } }; +/** + * UnionConverter will construct a converter for values from union types. + */ +class UnionConverter : public Converter { +public: + + matlab::data::Array convert(jl_value_t* jlval, matlab::engine::MATLABEngine* matlabPtr) override { + jl_value_t* jlvalc = jl_call1(jl_get_function(jl_base_module, "identity"), jlval); // Unpacks the union value. + std::unique_ptr conv = converter((jl_datatype_t*) jl_typeof(jlvalc)); + return conv->convert(jlvalc, matlabPtr); + + } +}; + bool unbox_bool(jl_value_t* jlval){ return (bool) jl_unbox_bool(jlval); @@ -352,6 +366,8 @@ std::unique_ptr converter(jl_datatype_t* jltype){ jl_value_t* jlcomplex = (jl_value_t*) jl_get_function(jl_base_module, "Complex"); if (jl_is_abstracttype(jltype)){ return std::unique_ptr(new AbstractConverter()); + } else if(jl_is_uniontype(jltype)) { + return std::unique_ptr(new UnionConverter()); } else if(jl_is_primitivetype(jltype)){ if (jltype == jl_float32_type){ return std::unique_ptr(new PrimitiveConverter(jl_unbox_float32)); @@ -456,7 +472,7 @@ std::unique_ptr converter(jl_datatype_t* jltype){ else if(jlarrayof == jl_string_type){ return std::unique_ptr(new ArrayStringConverter()); } - else if (jl_is_array_type(jlarrayof) || jl_is_tuple_type(jlarrayof)) { + else if (jl_is_array_type(jlarrayof) || jl_is_tuple_type(jlarrayof) || jl_is_uniontype(jlarrayof)) { return std::unique_ptr(new ArrayConverter(jltype)); } else if (jl_is_structtype(jlarrayof) || jl_is_namedtuple_type(jlarrayof)){ @@ -466,7 +482,8 @@ std::unique_ptr converter(jl_datatype_t* jltype){ } else if (jl_is_structtype(jltype) || jl_is_namedtuple_type(jltype)){ return std::unique_ptr(new StructConverter(jltype)); } - throw std::invalid_argument("Wrong input MATFRost test - Cannot find matching type."); + throw std::invalid_argument("Wrong input MATFRost test - Cannot find matching type for: " + + std::string(jl_string_ptr(jl_call1(jl_get_function(jl_base_module, "string"), (jl_value_t*) jltype)))); } diff --git a/test/MATFrostTest/src/MATFrostTest.jl b/test/MATFrostTest/src/MATFrostTest.jl index 90793c8..2933a0a 100644 --- a/test/MATFrostTest/src/MATFrostTest.jl +++ b/test/MATFrostTest/src/MATFrostTest.jl @@ -229,6 +229,15 @@ function nested_structures_test1(v::Nest3_L1{T}) where {T} ) end +function interleave_with_number_and_string(vs::Vector{T}, vnum::Float64, vstring::String) where {T} + vo = Vector{Union{T, Float64, String}}(undef, length(vs)*3) + + vo[1:3:end] .= vs + vo[2:3:end] .= vnum + vo[3:3:end] .= vstring + + vo +end for (suf, prim) in ( @@ -311,6 +320,12 @@ for (suf, prim) in ( eval(:($(Symbol(:nested_structures_test1_vector_, suf))(v::Nest3_L1{Vector{$(prim)}}) = nested_structures_test1(v))) + + eval(:($(Symbol(:interleave_with_number_and_string_, suf))(vs::Vector{$(prim)}, vnum::Float64, vstring::String) = interleave_with_number_and_string(vs, vnum, vstring))) + + eval(:($(Symbol(:ifelse_, suf))(b::Bool, v::$(prim), vstrings::Vector{String}) = ifelse(b, v, vstrings))) + + end diff --git a/test/matfrost_union_type_unpacking_test.m b/test/matfrost_union_type_unpacking_test.m new file mode 100644 index 0000000..c7bf169 --- /dev/null +++ b/test/matfrost_union_type_unpacking_test.m @@ -0,0 +1,75 @@ +classdef matfrost_union_type_unpacking_test < matfrost_abstract_test + + + properties (TestParameter) + prim_type = {"bool"; + "string"; + "simple_population_type"; + "named_tuple_simple_population_type"; + "i8"; "ui8"; "i16"; "ui16"; "i32"; "ui32"; "i64"; "ui64"; ... + "f32"; "f64"; ... + "ci8"; "cui8"; "ci16"; "cui16"; "ci32"; "cui32"; "ci64"; "cui64"; ... + "cf32"; "cf64"}; + + v = {... + false; + "test"; + ... + struct("name", "Test", "population", int64(200)); + struct("name", "Test", "population", int64(200)); + ... + int8(8); uint8(14); + int16(478); uint16(4532); + int32(323442); uint32(53342); + int64(323434542); uint64(535345342); + ... + single(34.125); 2342.0625; + ... + complex(int8(8), int8(12)); + complex(uint8(14), uint8(7)); + complex(int16(478), int16(124)); + complex(uint16(4532), uint16(544)); + complex(int32(323442), int32(74571)); + complex(uint32(53342), uint32(56123));... + complex(int64(323434542), int64(84968213)); + complex(uint64(535345342), uint64(8492313));... + ... + complex(single(34.125), single(4234.5)); + complex(2342.0625, 12444.0625)}; + + + + end + + + + methods(Test, ParameterCombination="sequential") + + function ifelse_scalar_union_value(tc, prim_type, v) + + tc.verifyEqual( ... + tc.mjl.MATFrostTest.("ifelse_" + prim_type)(true, v, ["If"; "Else"]), ... + v); + + tc.verifyEqual( ... + tc.mjl.MATFrostTest.("ifelse_" + prim_type)(false, v, ["If"; "Else"]), ... + ["If"; "Else"]); + end + + + function interleave_with_number_and_string(tc, prim_type, v) + vs = repmat(v, 5, 1); + + vnum = 321.0; + + vstring = "interleaved"; + + vo = repmat({v; vnum; vstring}, 5, 1); + + tc.verifyEqual(tc.mjl.MATFrostTest.("interleave_with_number_and_string_" + prim_type)(vs, vnum, vstring), vo); + end + + + end + +end \ No newline at end of file