Add batch get for MmapHashmap #256

Merged: 4 commits, Sep 26, 2023
19 changes: 19 additions & 0 deletions pecos/core/base.py
@@ -1716,6 +1716,15 @@ def link_mmap_hashmap_methods(self):
c_uint64, # key int64
],
}
batch_key_args_dict = {
"str2int": [
c_void_p, # Array of pointers to the key strings
POINTER(c_uint32), # Array of the key strings' lengths
],
"int2int": [
POINTER(c_uint64), # Array of int64 keys
],
}
self.mmap_map_fn_dict = {}

for map_type in map_type_list:
@@ -1760,6 +1769,16 @@ def link_mmap_hashmap_methods(self):
local_fn_dict[fn_name], c_uint64, [c_void_p] + key_args_dict[map_type] + [c_uint64]
)

fn_name = "batch_get_w_default"
local_fn_dict[fn_name] = getattr(self.clib_float32, f"{fn_prefix}_{fn_name}_{map_type}")
corelib.fillprototype(
local_fn_dict[fn_name],
None,
[c_void_p, c_uint32]
+ batch_key_args_dict[map_type] # noqa: W503
+ [c_uint64, POINTER(c_uint64), c_uint32], # noqa: W503
)

fn_name = "contains"
local_fn_dict[fn_name] = getattr(self.clib_float32, f"{fn_prefix}_{fn_name}_{map_type}")
corelib.fillprototype(
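
For reference, a minimal sketch of the call shape these prototypes describe for the str2int variant; `fn_dict` and `map_ptr` below are hypothetical placeholders for the function dict and opened-map pointer obtained elsewhere in this module:

import numpy as np
from ctypes import c_char_p, c_uint32, c_uint64, POINTER

keys = [b"aaaa", b"bb"]  # UTF-8 encoded byte-string keys
keys_ptr = (c_char_p * len(keys))(*keys)  # passed as the c_void_p argument
keys_lens = np.array([len(k) for k in keys], dtype=np.uint32)
vals = np.zeros(len(keys), dtype=np.uint64)  # output buffer filled by the C side

fn_dict["batch_get_w_default"](
    map_ptr,                                      # c_void_p to the opened hashmap
    len(keys),                                    # c_uint32 number of keys
    keys_ptr,                                     # str2int key arguments:
    keys_lens.ctypes.data_as(POINTER(c_uint32)),  # pointers and lengths
    10,                                           # c_uint64 default value
    vals.ctypes.data_as(POINTER(c_uint64)),       # c_uint64* results
    1,                                            # c_uint32 thread count
)
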
6 changes: 6 additions & 0 deletions pecos/core/libpecos.cpp
@@ -538,6 +538,12 @@ extern "C" {
MMAP_MAP_GET_W_DEFAULT(str2int, KEY_SINGLE_ARG(const char* key, uint32_t key_len), KEY_SINGLE_ARG(key, key_len))
MMAP_MAP_GET_W_DEFAULT(int2int, uint64_t key, key)

#define MMAP_MAP_BATCH_GET_W_DEFAULT(SUFFIX, KEY, FUNC_CALL_KEY) \
void mmap_hashmap_batch_get_w_default_ ## SUFFIX (void* map_ptr, const uint32_t n_key, KEY, uint64_t def_val, uint64_t* vals, const int threads) { \
static_cast<mmap_hashmap_ ## SUFFIX *>(map_ptr)->batch_get_w_default(n_key, FUNC_CALL_KEY, def_val, vals, threads); }
MMAP_MAP_BATCH_GET_W_DEFAULT(str2int, KEY_SINGLE_ARG(const char* const* keys, const uint32_t* keys_lens), KEY_SINGLE_ARG(keys, keys_lens))
MMAP_MAP_BATCH_GET_W_DEFAULT(int2int, const uint64_t* key, key)

// Contains
#define MMAP_MAP_CONTAINS(SUFFIX, KEY, FUNC_CALL_KEY) \
bool mmap_hashmap_contains_ ## SUFFIX (void* map_ptr, KEY) { \
16 changes: 16 additions & 0 deletions pecos/core/utils/mmap_hashmap.hpp
@@ -14,9 +14,11 @@
#ifndef __MMAP_ANKERL_HASHMAP_H__
#define __MMAP_ANKERL_HASHMAP_H__

#include <omp.h>
#include "third_party/ankerl/unordered_dense.h"
#include "mmap_util.hpp"


namespace pecos {
namespace mmap_hashmap {

@@ -374,6 +376,13 @@ class Str2IntMap {
} catch (...) { return def_val;}
}

void batch_get_w_default(const uint32_t n_key, const char* const* keys, const uint32_t* keys_lens, const uint64_t def_val, uint64_t* vals, const int threads) {
#pragma omp parallel for schedule(static, 1) num_threads(threads)
for (uint32_t i=0; i<n_key; ++i) {
vals[i] = get_w_default(keys[i], keys_lens[i], def_val);
}
}

bool contains(const char* key, uint32_t key_len) {
return map.contains(std::string_view(key, key_len));
}
@@ -403,6 +412,13 @@ class Int2IntMap {
} catch (...) { return def_val;}
}

void batch_get_w_default(const uint32_t n_key, const uint64_t* keys, const uint64_t def_val, uint64_t* vals, const int threads) {
#pragma omp parallel for schedule(static, 1) num_threads(threads)
for (uint32_t i=0; i<n_key; ++i) {
vals[i] = get_w_default(keys[i], def_val);
}
}

bool contains(uint64_t key) { return map.contains(key); }

size_t size() { return map.size(); }
154 changes: 152 additions & 2 deletions pecos/utils/mmap_hashmap_util.py
@@ -11,8 +11,10 @@
import logging
from abc import abstractmethod
from pecos.core import clib
from typing import Optional

from typing import Optional, Tuple
from ctypes import c_char_p, c_uint32, c_uint64, POINTER
import numpy as np
import os

LOGGER = logging.getLogger(__name__)

@@ -88,6 +90,53 @@ def __del__(self):
self.close()


class MmapHashmapBatchGetter(object):
"""
Batch getter for a MmapHashmap opened in read-only mode.
"""

def __init__(self, mmap_r, max_batch_size: int, threads: int = 1):
if not isinstance(mmap_r, _MmapHashmapReadOnly):
raise ValueError(f"Expected a read-only MmapHashmap, got {type(mmap_r)}")
if max_batch_size <= 0:
raise ValueError(f"Max batch size should be > 0, got {max_batch_size}")
if threads <= 0 and threads != -1:
raise ValueError(f"Number of threads should be > 0 or -1, got {threads}")

self.mmap_r: Optional[_MmapHashmapReadOnly] = mmap_r
self.max_batch_size = max_batch_size
self.key_prealloc = mmap_r.get_keyalloc(max_batch_size)

# `os.cpu_count()` is not equivalent to the number of CPUs the current process can use.
# The number of usable CPUs can be obtained with len(os.sched_getaffinity(0))
n_usable_cpu = len(os.sched_getaffinity(0))
self.threads_c_uint32 = c_uint32(
min(n_usable_cpu, n_usable_cpu if threads == -1 else threads)
)

# Pre-allocated space for returns
self.vals = np.zeros(max_batch_size, dtype=np.uint64)

def get(self, keys, default_val):
"""
Batch get values for multiple keys. For non-existent keys, `default_val` is returned.

NOTE:
1) Make sure the given keys are compatible with the `MmapHashmap` `batch_get` type:
i) str2int: list of UTF-8 encoded byte strings
ii) int2int: 1D NumPy array of int64
2) The return value is a view over a reused buffer; use or copy the data as soon as you get it, since it is not guaranteed to persist across calls.
"""
self.mmap_r.batch_get(
len(keys),
self.key_prealloc.get_key_prealloc(keys),
default_val,
self.vals,
self.threads_c_uint32,
)
return memoryview(self.vals)[: len(keys)]


class _MmapHashmapBase(object):
"""Base class for methods shared by all modes"""

@@ -117,6 +166,15 @@ def __getitem__(self, key):
def __contains__(self, key):
pass

@abstractmethod
def batch_get(self, n_keys, keys, default_val, vals, threads_c_uint32):
pass

@classmethod
@abstractmethod
def get_keyalloc(cls, max_batch_size):
pass

@classmethod
def init(cls, map_type, map_dir, lazy_load):
fn_dict = clib.mmap_hashmap_init(map_type)
@@ -135,6 +193,7 @@ def get(self, key_utf8, default_val):
"""
Args:
key_utf8: UTF8 encoded bytes string key
default_val: Default value for key not found
"""
return self.fn_dict["get_w_default"](
self.map_ptr,
@@ -149,6 +208,59 @@ def __getitem__(self, key_utf8):
def __contains__(self, key_utf8):
return self.fn_dict["contains"](self.map_ptr, key_utf8, len(key_utf8))

def batch_get(
self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threads_c_uint32: c_uint32
):
"""
Batch get values for UTF-8 encoded byte string keys.
Return values are stored in `vals`.

How to build the inputs from a list of UTF-8 encoded byte string keys `keys_utf8`:
> keys_ptr = (c_char_p * n_keys)()
> keys_ptr[:] = keys_utf8
> keys_lens = np.array([len(k) for k in keys_utf8], dtype=np.uint32)

Args:
n_keys: int. Number of keys to get.
keys_utf8: Tuple of (keys_ptr, keys_lens)
keys_ptr: ctypes array of pointers to the UTF-8 encoded byte string keys
keys_lens: 1D UInt32 NumPy array of the keys' lengths
default_val: Default value returned for keys that are not found
vals: 1D UInt64 NumPy array in which the results are returned
threads_c_uint32: Number of threads to use.
"""
keys_ptr, keys_lens = keys_utf8
self.fn_dict["batch_get_w_default"](
self.map_ptr,
n_keys,
keys_ptr,
keys_lens.ctypes.data_as(POINTER(c_uint32)),
default_val,
vals.ctypes.data_as(POINTER(c_uint64)),
threads_c_uint32,
)
return vals

@classmethod
def get_keyalloc(cls, max_batch_size):
return _Str2IntBatchGetterKeyPreAlloc(max_batch_size)


class _Str2IntBatchGetterKeyPreAlloc(object):
"""
Pre-allocated key buffers for the Str2Int MmapHashmap batch getter.
"""

def __init__(self, max_batch_size: int):
self.keys_ptr = (c_char_p * max_batch_size)()
self.keys_lens = np.zeros(max_batch_size, dtype=np.uint32)

def get_key_prealloc(self, keys_utf8):
self.keys_ptr[: len(keys_utf8)] = keys_utf8
self.keys_lens.flat[: len(keys_utf8)] = [len(k) for k in keys_utf8]

return (self.keys_ptr, self.keys_lens)


class _MmapHashmapInt2IntReadOnly(_MmapHashmapReadOnly):
def get(self, key, default_val):
@@ -160,6 +272,44 @@ def __getitem__(self, key):
def __contains__(self, key):
return self.fn_dict["contains"](self.map_ptr, key)

def batch_get(self, n_keys: int, keys, default_val: int, vals, threads_c_uint32: c_uint32):
"""
Batch get values for Int64 keys.
Return values are stored in vals.

Args:
n_keys: int. Number of keys to get.
keys: 1D Int64 Numpy array
default_val: Default value returned for keys that are not found
vals: 1D UInt64 NumPy array in which the results are returned
threads_c_uint32: Number of threads to use.
"""
self.fn_dict["batch_get_w_default"](
self.map_ptr,
n_keys,
keys.ctypes.data_as(POINTER(c_uint64)),
default_val,
vals.ctypes.data_as(POINTER(c_uint64)),
threads_c_uint32,
)
return vals

@classmethod
def get_keyalloc(cls, max_batch_size):
return _Int2IntBatchGetterKeyPreAlloc(max_batch_size)


class _Int2IntBatchGetterKeyPreAlloc(object):
"""
Dummy key pre-allocation for the Int2Int MmapHashmap batch getter; keys are passed through unchanged.
"""

def __init__(self, max_batch_size: int):
pass

def get_key_prealloc(self, keys):
return keys


class _MmapHashmapWrite(_MmapHashmapBase):
"""Base class for methods shared by all write modes"""
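
Taken together, a minimal usage sketch of the new batch getter; the map contents are illustrative, and `r_map` is assumed to be an int2int MmapHashmap already opened read-only, as in the tests below:

import numpy as np
from pecos.utils.mmap_hashmap_util import MmapHashmapBatchGetter

# threads=-1 uses every CPU usable by the current process (len(os.sched_getaffinity(0))).
getter = MmapHashmapBatchGetter(r_map.map, max_batch_size=1024, threads=-1)

keys = np.array([10, 20, 1000], dtype=np.int64)  # 1000 is assumed to be missing
vals = getter.get(keys, 0)  # missing keys come back as the default value 0

# The result is a memoryview over a reused buffer; copy it before the next get() call.
vals = np.array(vals, dtype=np.uint64)

# For a str2int map, the keys would instead be a list of UTF-8 encoded byte strings,
# e.g. getter.get([b"aaaa", b"bb"], 0).
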
33 changes: 31 additions & 2 deletions test/pecos/utils/test_mmap_hashmap_util.py
@@ -12,7 +12,7 @@


def test_str2int_mmap_hashmap(tmpdir):
from pecos.utils.mmap_hashmap_util import MmapHashmap
from pecos.utils.mmap_hashmap_util import MmapHashmap, MmapHashmapBatchGetter

map_dir = tmpdir.join("str2int_mmap").realpath().strpath
kv_dict = {"aaaa".encode("utf-8"): 2, "bb".encode("utf-8"): 3}
@@ -45,9 +45,25 @@ def test_str2int_mmap_hashmap(tmpdir):
# Size
assert r_map.map.size() == len(kv_dict)

# Batch get with default
max_batch_size = 5
# max_batch_size > number of keys
r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size)
ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(ks, 10).tolist() == vs
# max_batch_size == number of keys
ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] * (
max_batch_size - len(kv_dict)
) # Non-exist key
vs = list(kv_dict.values()) + [10] * (max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(ks, 10).tolist() == vs
# Cannot test max_batch_size < number of keys; it would result in a segmentation fault


def test_int2int_mmap_hashmap(tmpdir):
from pecos.utils.mmap_hashmap_util import MmapHashmap
from pecos.utils.mmap_hashmap_util import MmapHashmap, MmapHashmapBatchGetter
import numpy as np

map_dir = tmpdir.join("int2int_mmap").realpath().strpath
kv_dict = {10: 2, 20: 3}
@@ -79,3 +95,16 @@ def test_int2int_mmap_hashmap(tmpdir):
assert not (1000 in r_map.map)
# Size
assert r_map.map.size() == len(kv_dict)

# Batch get with default
max_batch_size = 5
# max_batch_size > number of keys
r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size)
ks = list(kv_dict.keys()) + [1000] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs
# max_batch_size == number of keys
ks = list(kv_dict.keys()) + [1000] * (max_batch_size - len(kv_dict)) # Non-exist key
vs = list(kv_dict.values()) + [10] * (max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs
# Cannot test max_batch_size < number of keys; it would result in a segmentation fault
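
A sketch of an additional assertion that could be appended to the test above to exercise the multi-threaded path; it reuses r_map, kv_dict, and max_batch_size as defined there:

# The multi-threaded getter should agree with the single-threaded results above.
mt_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size, threads=2)
ks = list(kv_dict.keys()) + [1000]  # includes one non-existent key
vs = list(kv_dict.values()) + [10]
assert mt_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs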