diff --git a/stringdtype/stringdtype/src/static_string.c b/stringdtype/stringdtype/src/static_string.c index 7288a82b..5d79e914 100644 --- a/stringdtype/stringdtype/src/static_string.c +++ b/stringdtype/stringdtype/src/static_string.c @@ -44,11 +44,14 @@ typedef union _npy_static_string_u { #define NPY_STRING_SHORT 0x40 // 0100 0000 #define NPY_STRING_ARENA_FREED 0x20 // 0010 0000 #define NPY_STRING_ON_HEAP 0x10 // 0001 0000 +#define NPY_STRING_MEDIUM 0x08 // 0000 1000 +#define NPY_STRING_FLAG_MASK 0xF8 // 1111 1000 // short string sizes fit in a 4-bit integer #define NPY_SHORT_STRING_SIZE_MASK 0x0F // 0000 1111 #define NPY_SHORT_STRING_MAX_SIZE \ - (sizeof(npy_static_string) - 1) // 15 or 7 depending on arch + (sizeof(npy_static_string) - 1) // 15 or 7 depending on arch +#define NPY_MEDIUM_STRING_MAX_SIZE 0xFF // 255 // Since this has no flags set, technically this is a heap-allocated string // with size zero. Practically, that doesn't matter because we always do size @@ -86,8 +89,7 @@ struct npy_string_allocator { void set_vstring_size(_npy_static_string_u *str, size_t size) { - unsigned char *flags = &str->direct_buffer.flags_and_size; - unsigned char current_flags = *flags & ~NPY_SHORT_STRING_SIZE_MASK; + unsigned char current_flags = str->direct_buffer.flags_and_size; str->vstring.size = size; str->direct_buffer.flags_and_size = current_flags; } @@ -110,9 +112,13 @@ npy_string_arena_malloc(npy_string_arena *arena, npy_string_realloc_func r, size_t size) { // one extra size_t to store the size of the allocation - size_t string_storage_size = size + sizeof(size_t); - // expand size to nearest multiple of 8 bytes to ensure 64 bit alignment - string_storage_size += (8 - string_storage_size % 8); + size_t string_storage_size; + if (size <= NPY_MEDIUM_STRING_MAX_SIZE) { + string_storage_size = size + sizeof(unsigned char); + } + else { + string_storage_size = size + sizeof(size_t); + } if ((arena->size - arena->cursor) <= string_storage_size) { // realloc the buffer 
so there is enough room // first guess is to double the size of the buffer @@ -130,7 +136,7 @@ npy_string_arena_malloc(npy_string_arena *arena, npy_string_realloc_func r, // doubling the current size isn't enough newsize = 2 * (arena->cursor + size); } - // realloc passed a NULL pointer acts like malloc + // passing a NULL buffer to realloc is the same as malloc char *newbuf = r(arena->buffer, newsize); if (newbuf == NULL) { return NULL; @@ -139,9 +145,18 @@ npy_string_arena_malloc(npy_string_arena *arena, npy_string_realloc_func r, arena->buffer = newbuf; arena->size = newsize; } - size_t *size_loc = (size_t *)&arena->buffer[arena->cursor]; - *size_loc = size; - char *ret = &arena->buffer[arena->cursor + sizeof(size_t)]; + char *ret; + if (size <= NPY_MEDIUM_STRING_MAX_SIZE) { + unsigned char *size_loc = + (unsigned char *)&arena->buffer[arena->cursor]; + *size_loc = size; + ret = &arena->buffer[arena->cursor + sizeof(char)]; + } + else { + char *size_ptr = (char *)&arena->buffer[arena->cursor]; + memcpy(size_ptr, &size, sizeof(size_t)); + ret = &arena->buffer[arena->cursor + sizeof(size_t)]; + } arena->cursor += string_storage_size; return ret; } @@ -207,6 +222,15 @@ is_short_string(const npy_packed_static_string *s) return has_short_flag && !has_on_heap_flag; } +int +is_medium_string(const _npy_static_string_u *s) +{ + unsigned char high_byte = s->direct_buffer.flags_and_size; + int has_short_flag = (high_byte & NPY_STRING_SHORT); + int has_medium_flag = (high_byte & NPY_STRING_MEDIUM); + return (!has_short_flag && has_medium_flag); +} + int npy_string_isnull(const npy_packed_static_string *s) { @@ -286,10 +310,19 @@ heap_or_arena_allocate(npy_string_allocator *allocator, if (buf == NULL) { return NULL; } - size_t alloc_size = *((size_t *)(buf - 1)); + size_t alloc_size; + if (is_medium_string(to_init_u)) { + // 1-byte header; read as unsigned char so >127 stays positive + alloc_size = (size_t)*(unsigned char *)(buf - 1); + } + else { + // not necessarily memory-aligned, so need to use memcpy + 
size_t *size_loc = (size_t *)((uintptr_t)buf - sizeof(size_t)); + memcpy(&alloc_size, size_loc, sizeof(size_t)); + } if (size <= alloc_size) { // we have room! - *flags = NPY_STRING_ARENA_FREED; + *flags &= ~NPY_STRING_ARENA_FREED; return buf; } else { @@ -316,8 +349,13 @@ heap_or_arena_allocate(npy_string_allocator *allocator, if (arena == NULL) { return NULL; } - return npy_string_arena_malloc(arena, allocator->realloc, - sizeof(char) * size); + char *ret = npy_string_arena_malloc(arena, allocator->realloc, + sizeof(char) * size); + // same <= bound as npy_string_arena_malloc's one-byte size header + if (size <= NPY_MEDIUM_STRING_MAX_SIZE) { + *flags |= NPY_STRING_MEDIUM; + } + return ret; } int diff --git a/stringdtype/tests/test_char.py b/stringdtype/tests/test_char.py index 4ad99cdf..c153c511 100644 --- a/stringdtype/tests/test_char.py +++ b/stringdtype/tests/test_char.py @@ -4,7 +4,12 @@ from stringdtype import StringDType -TEST_DATA = ["hello", "Ae¢☃€ 😊", "entry\nwith\nnewlines", "entry\twith\ttabs"] +TEST_DATA = [ + "hello" * 10, + "Ae¢☃€ 😊" * 100, + "entry\nwith\nnewlines", + "entry\twith\ttabs", +] @pytest.fixture @@ -94,11 +99,11 @@ def test_binary(string_array, unicode_array, function_name, args): def test_strip(string_array, unicode_array): - rjs = np.char.rjust(string_array, 25) - rju = np.char.rjust(unicode_array, 25) + rjs = np.char.rjust(string_array, 1000) + rju = np.char.rjust(unicode_array, 1000) - ljs = np.char.ljust(string_array, 25) - lju = np.char.ljust(unicode_array, 25) + ljs = np.char.ljust(string_array, 1000) + lju = np.char.ljust(unicode_array, 1000) assert_array_equal( np.char.lstrip(rjs), diff --git a/stringdtype/tests/test_stringdtype.py b/stringdtype/tests/test_stringdtype.py index 627ea96c..eb48fef6 100644 --- a/stringdtype/tests/test_stringdtype.py +++ b/stringdtype/tests/test_stringdtype.py @@ -17,7 +17,7 @@ @pytest.fixture def string_list(): - return ["abc", "def", "ghi" * 10, "A¢☃€ 😊", "Abc", "DEF"] + return ["abc", "def", "ghi" * 10, "A¢☃€ 😊" * 100, "Abc" * 1000, "DEF"] pd_param = pytest.param( 
@@ -121,7 +121,7 @@ def test_array_creation_utf8(dtype, data): def test_array_creation_scalars(string_list): arr = np.array([StringScalar(s) for s in string_list]) assert ( - str(arr) + str(arr).replace("\n", "") == "[" + " ".join(["'" + str(s) + "'" for s in string_list]) + "]" ) assert arr.dtype == StringDType()