diff --git a/blosc2/blosc2_ext.pyx b/blosc2/blosc2_ext.pyx index c457a920..fb955f56 100644 --- a/blosc2/blosc2_ext.pyx +++ b/blosc2/blosc2_ext.pyx @@ -646,24 +646,34 @@ def get_blocksize(): cdef _check_cparams(blosc2_cparams *cparams): if cparams.nthreads > 1: - if BLOSC2_USER_REGISTERED_CODECS_START <= cparams.compcode <= BLOSC2_USER_REGISTERED_CODECS_STOP: - raise ValueError("Cannot use multi-threading with user defined codecs") - elif any([BLOSC2_USER_REGISTERED_FILTERS_START <= filter <= BLOSC2_USER_REGISTERED_FILTERS_STOP - for filter in cparams.filters]): - raise ValueError("Cannot use multi-threading with user defined filters") - elif cparams.prefilter != NULL: + if BLOSC2_USER_REGISTERED_CODECS_START <= cparams.compcode <= BLOSC2_USER_REGISTERED_CODECS_STOP\ + and cparams.compcode in blosc2.ucodecs_registry.keys(): + raise ValueError("Cannot use multi-threading with user defined Python codecs") + + ufilters = [BLOSC2_USER_REGISTERED_FILTERS_START <= filter <= BLOSC2_USER_REGISTERED_FILTERS_STOP + for filter in cparams.filters] + for i in range(len(ufilters)): + if ufilters[i] and cparams.filters[i] in blosc2.ufilters_registry.keys(): + raise ValueError("Cannot use multi-threading with user defined Python filters") + + if cparams.prefilter != NULL: raise ValueError("`nthreads` must be 1 when a prefilter is set") cdef _check_dparams(blosc2_dparams* dparams, blosc2_cparams* cparams=NULL): if cparams == NULL: return if dparams.nthreads > 1: - if BLOSC2_USER_REGISTERED_CODECS_START <= cparams.compcode <= BLOSC2_USER_REGISTERED_CODECS_STOP: - raise ValueError("Cannot use multi-threading with user defined codecs") - elif any([BLOSC2_USER_REGISTERED_FILTERS_START <= filter <= BLOSC2_USER_REGISTERED_FILTERS_STOP - for filter in cparams.filters]): - raise ValueError("Cannot use multi-threading with user defined filters") - elif dparams.postfilter != NULL: + if BLOSC2_USER_REGISTERED_CODECS_START <= cparams.compcode <= BLOSC2_USER_REGISTERED_CODECS_STOP\ + and cparams.compcode in blosc2.ucodecs_registry.keys(): + raise ValueError("Cannot use multi-threading with user defined Python codecs") + + ufilters = [BLOSC2_USER_REGISTERED_FILTERS_START <= filter <= BLOSC2_USER_REGISTERED_FILTERS_STOP + for filter in cparams.filters] + for i in range(len(ufilters)): + if ufilters[i] and cparams.filters[i] in blosc2.ufilters_registry.keys(): + raise ValueError("Cannot use multi-threading with user defined Python filters") + + if dparams.postfilter != NULL: raise ValueError("`nthreads` must be 1 when a postfilter is set") diff --git a/tests/test_ucodecs.py b/tests/test_ucodecs.py index 159ba9ce..48df45eb 100644 --- a/tests/test_ucodecs.py +++ b/tests/test_ucodecs.py @@ -83,3 +83,72 @@ def decoder1(input, output, meta, schunk): assert np.array_equal(data, out) blosc2.remove_urlpath(urlpath) + + +@pytest.mark.parametrize( + "cparams, dparams", + [ + ({"codec": 163, "nthreads": 1}, {"nthreads": 4}), + ({"codec": 163, "nthreads": 4}, {"nthreads": 1}), + ], +) +def test_pyucodecs_error(cparams, dparams): + chunk_len = 20 * 1000 + dtype = np.dtype(np.int32) + + def encoder1(input, output, meta, schunk): + nd_input = input.view(dtype) + if np.max(nd_input) == np.min(nd_input): + output[0 : schunk.typesize] = input[0 : schunk.typesize] + n = nd_input.size.to_bytes(4, sys.byteorder) + output[schunk.typesize : schunk.typesize + 4] = [n[i] for i in range(4)] + return schunk.typesize + 4 + else: + # memcpy + return 0 + + def decoder1(input, output, meta, schunk): + nd_input = input.view(np.int32) + nd_output = output.view(dtype) + nd_output[0 : nd_input[1]] = [nd_input[0]] * nd_input[1] + return nd_input[1] * schunk.typesize + + if cparams["codec"] not in blosc2.ucodecs_registry: + blosc2.register_codec("codec3", cparams["codec"], encoder1, decoder1) + + nchunks = 2 + fill_value = 341 + data = np.full(chunk_len * nchunks, fill_value, dtype=dtype) + + with pytest.raises(ValueError): + _ = blosc2.SChunk( + chunksize=chunk_len * dtype.itemsize, + data=data, + cparams=cparams, + dparams=dparams, + ) + + +@pytest.mark.parametrize( + "cparams, dparams", + [ + ({"codec": 254, "nthreads": 1}, {"nthreads": 4}), + ({"codec": 254, "nthreads": 4}, {"nthreads": 1}), + ], +) +def test_dynamic_ucodecs_error(cparams, dparams): + blosc2.register_codec("codec4", cparams["codec"], None, None) + + chunk_len = 100 + dtype = np.dtype(np.int32) + nchunks = 1 + fill_value = 341 + data = np.arange(chunk_len * nchunks, dtype=dtype) + + with pytest.raises(RuntimeError): + schunk = blosc2.SChunk( + chunksize=chunk_len * dtype.itemsize, + data=data, + cparams=cparams, + dparams=dparams, + ) diff --git a/tests/test_ufilters.py b/tests/test_ufilters.py index 514ddb0b..2690e417 100644 --- a/tests/test_ufilters.py +++ b/tests/test_ufilters.py @@ -90,3 +90,64 @@ def backward2(input, output, meta, schunk): assert np.array_equal(data, out) blosc2.remove_urlpath(urlpath) + +@pytest.mark.parametrize( + "cparams, dparams", + [ + ({"nthreads": 4, "filters": [255, blosc2.Filter.SHUFFLE], "filters_meta": [0, 0]}, {"nthreads": 1}), + ({"nthreads": 1, "filters": [255], "filters_meta": [4]}, {"nthreads": 4}) + ], +) +def test_pyufilters_error(cparams, dparams): + dtype = np.dtype(np.int32) + def forward(input, output, meta, schunk): + nd_input = input.view(dtype) + nd_output = output.view(dtype) + + nd_output[:] = nd_input + 1 + + def backward(input, output, meta, schunk): + nd_input = input.view(dtype) + nd_output = output.view(dtype) + + nd_output[:] = nd_input - 1 + if 255 not in blosc2.ufilters_registry: + blosc2.register_filter(255, forward, backward) + + nchunks = 1 + chunk_len = 100 + fill_value = 341 + data = np.full(chunk_len * nchunks, fill_value, dtype=dtype) + + with pytest.raises(ValueError): + _ = blosc2.SChunk( + chunksize=chunk_len * dtype.itemsize, + data=data, + cparams=cparams, + dparams=dparams, + ) + + +@pytest.mark.parametrize( + "cparams, dparams", + [ + ({"nthreads": 4, "filters": [163, blosc2.Filter.SHUFFLE], "filters_meta": [0, 0]}, {"nthreads": 1}), + ({"nthreads": 1, "filters": [163], "filters_meta": [4]}, {"nthreads": 4}) + ], +) +def test_dynamic_ufilters_error(cparams, dparams): + dtype = np.dtype(np.int32) + blosc2.register_filter(163, None, None, "ufilter_test") + + nchunks = 1 + chunk_len = 100 + fill_value = 341 + data = np.full(chunk_len * nchunks, fill_value, dtype=dtype) + + with pytest.raises(RuntimeError): + _ = blosc2.SChunk( + chunksize=chunk_len * dtype.itemsize, + data=data, + cparams=cparams, + dparams=dparams, + )