Skip to content

Commit

Permalink
Add batch get for MmapHashmap (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
weiliw-amz authored Sep 26, 2023
1 parent 7643ebe commit c3bccd5
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 4 deletions.
19 changes: 19 additions & 0 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,15 @@ def link_mmap_hashmap_methods(self):
c_uint64, # key int64
],
}
batch_key_args_dict = {
"str2int": [
c_void_p, # List of pointer of key string
POINTER(c_uint32), # List of length of key string
],
"int2int": [
POINTER(c_uint64), # List of key int64
],
}
self.mmap_map_fn_dict = {}

for map_type in map_type_list:
Expand Down Expand Up @@ -1760,6 +1769,16 @@ def link_mmap_hashmap_methods(self):
local_fn_dict[fn_name], c_uint64, [c_void_p] + key_args_dict[map_type] + [c_uint64]
)

fn_name = "batch_get_w_default"
local_fn_dict[fn_name] = getattr(self.clib_float32, f"{fn_prefix}_{fn_name}_{map_type}")
corelib.fillprototype(
local_fn_dict[fn_name],
None,
[c_void_p, c_uint32]
+ batch_key_args_dict[map_type] # noqa: W503
+ [c_uint64, POINTER(c_uint64), c_uint32], # noqa: W503
)

fn_name = "contains"
local_fn_dict[fn_name] = getattr(self.clib_float32, f"{fn_prefix}_{fn_name}_{map_type}")
corelib.fillprototype(
Expand Down
6 changes: 6 additions & 0 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,12 @@ extern "C" {
MMAP_MAP_GET_W_DEFAULT(str2int, KEY_SINGLE_ARG(const char* key, uint32_t key_len), KEY_SINGLE_ARG(key, key_len))
MMAP_MAP_GET_W_DEFAULT(int2int, uint64_t key, key)

// Batch get with default
// Defines the C entry point mmap_hashmap_batch_get_w_default_<SUFFIX>, which casts
// map_ptr to mmap_hashmap_<SUFFIX>* and forwards to its batch_get_w_default member.
//   KEY           - key parameter list as it appears in the C signature
//   FUNC_CALL_KEY - the same key arguments as passed on to the member function
// Results are written into vals (one slot per key); nothing is returned.
// NOTE(review): `threads` is `const int` here, while the Python prototype in
// pecos/core/base.py declares the last argument as c_uint32 - confirm this is intended.
#define MMAP_MAP_BATCH_GET_W_DEFAULT(SUFFIX, KEY, FUNC_CALL_KEY) \
void mmap_hashmap_batch_get_w_default_ ## SUFFIX (void* map_ptr, const uint32_t n_key, KEY, uint64_t def_val, uint64_t* vals, const int threads) { \
static_cast<mmap_hashmap_ ## SUFFIX *>(map_ptr)->batch_get_w_default(n_key, FUNC_CALL_KEY, def_val, vals, threads); }
// str2int: keys are passed as an array of string pointers plus an array of their lengths
MMAP_MAP_BATCH_GET_W_DEFAULT(str2int, KEY_SINGLE_ARG(const char* const* keys, const uint32_t* keys_lens), KEY_SINGLE_ARG(keys, keys_lens))
// int2int: keys are passed as a single array of uint64
MMAP_MAP_BATCH_GET_W_DEFAULT(int2int, const uint64_t* key, key)

// Contains
#define MMAP_MAP_CONTAINS(SUFFIX, KEY, FUNC_CALL_KEY) \
bool mmap_hashmap_contains_ ## SUFFIX (void* map_ptr, KEY) { \
Expand Down
16 changes: 16 additions & 0 deletions pecos/core/utils/mmap_hashmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
#ifndef __MMAP_ANKERL_HASHMAP_H__
#define __MMAP_ANKERL_HASHMAP_H__

#include <omp.h>
#include "third_party/ankerl/unordered_dense.h"
#include "mmap_util.hpp"


namespace pecos {
namespace mmap_hashmap {

Expand Down Expand Up @@ -374,6 +376,13 @@ class Str2IntMap {
} catch (...) { return def_val;}
}

// Look up n_key string keys in a single call.
// vals[i] receives the result of get_w_default for (keys[i], keys_lens[i]) with def_val
// as the fallback value. The loop is distributed over `threads` OpenMP threads with a
// static schedule of chunk size 1.
// Caller must provide keys, keys_lens and vals with at least n_key entries each.
void batch_get_w_default(const uint32_t n_key, const char* const* keys, const uint32_t* keys_lens, const uint64_t def_val, uint64_t* vals, const int threads) {
    #pragma omp parallel for schedule(static, 1) num_threads(threads)
    for (uint32_t idx = 0; idx < n_key; idx++) {
        const char* cur_key = keys[idx];
        const uint32_t cur_len = keys_lens[idx];
        vals[idx] = get_w_default(cur_key, cur_len, def_val);
    }
}

bool contains(const char* key, uint32_t key_len) {
return map.contains(std::string_view(key, key_len));
}
Expand Down Expand Up @@ -403,6 +412,13 @@ class Int2IntMap {
} catch (...) { return def_val;}
}

// Look up n_key integer keys in a single call.
// vals[i] receives the result of get_w_default for keys[i] with def_val as the fallback
// value. The loop is distributed over `threads` OpenMP threads with a static schedule
// of chunk size 1.
// Caller must provide keys and vals with at least n_key entries each.
void batch_get_w_default(const uint32_t n_key, const uint64_t* keys, const uint64_t def_val, uint64_t* vals, const int threads) {
    #pragma omp parallel for schedule(static, 1) num_threads(threads)
    for (uint32_t idx = 0; idx < n_key; idx++) {
        const uint64_t cur_key = keys[idx];
        vals[idx] = get_w_default(cur_key, def_val);
    }
}

bool contains(uint64_t key) { return map.contains(key); }

size_t size() { return map.size(); }
Expand Down
154 changes: 152 additions & 2 deletions pecos/utils/mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import logging
from abc import abstractmethod
from pecos.core import clib
from typing import Optional

from typing import Optional, Tuple
from ctypes import c_char_p, c_uint32, c_uint64, POINTER
import numpy as np
import os

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,6 +90,53 @@ def __del__(self):
self.close()


class MmapHashmapBatchGetter(object):
    """
    Batch getter for MmapHashmap opened for readonly.

    The key staging area (obtained from the map's `get_keyalloc`) and the result
    buffer are allocated once in the constructor and reused by every `get` call.
    """

    def __init__(self, mmap_r, max_batch_size: int, threads: int = 1):
        """
        Args:
            mmap_r: Read-only MmapHashmap to query (a `_MmapHashmapReadOnly` instance).
            max_batch_size: Largest number of keys a single `get` call may pass. Must be >0.
            threads: Number of threads to use. -1 means all CPUs usable by this process;
                larger values are capped to the number of usable CPUs.

        Raises:
            ValueError: If `mmap_r` is not a read-only map, or if `max_batch_size`
                or `threads` is out of range.
        """
        if not isinstance(mmap_r, _MmapHashmapReadOnly):
            raise ValueError(f"Should get from readonly MmapHashmap, got {type(mmap_r)}")
        if max_batch_size <= 0:
            raise ValueError(f"Max batch size should >0, got {max_batch_size}")
        if threads <= 0 and threads != -1:
            raise ValueError(f"Number of threads should >0 or =-1, got {threads}")

        self.mmap_r: Optional[_MmapHashmapReadOnly] = mmap_r
        self.max_batch_size = max_batch_size
        self.key_prealloc = mmap_r.get_keyalloc(max_batch_size)

        # `os.cpu_count()` is not equivalent to the number of CPUs the current process can use.
        # The number of usable CPUs can be obtained with len(os.sched_getaffinity(0)),
        # but that function only exists on some Unix platforms (not macOS/Windows),
        # so fall back to `os.cpu_count()` where it is missing.
        try:
            n_usable_cpu = len(os.sched_getaffinity(0))
        except AttributeError:
            n_usable_cpu = os.cpu_count() or 1
        n_threads = n_usable_cpu if threads == -1 else min(threads, n_usable_cpu)
        self.threads_c_uint32 = c_uint32(n_threads)

        # Pre-allocated space for returns
        self.vals = np.zeros(max_batch_size, dtype=np.uint64)

    def get(self, keys, default_val):
        """
        Batch get multiple keys' values. For non-exist keys, `default_val` is returned.
        NOTE:
            1) Make sure keys given is compatible with the `MmapHashmap` `batch_get` type.
                i) str2int: List of UTF8 encoded strings
                ii) int2int: 1D numpy array of int64
            2) The return is a reused buffer, use or copy the data once you get it. It is not guaranteed to last.

        Args:
            keys: Keys to look up, at most `max_batch_size` of them.
            default_val: Value returned for keys not found.

        Returns:
            memoryview over the first `len(keys)` entries of the internal result buffer.

        Raises:
            ValueError: If more than `max_batch_size` keys are given.
        """
        n_keys = len(keys)
        # The native call writes one value per key into `self.vals`, which only holds
        # `max_batch_size` slots; reject oversized batches instead of overrunning the buffer.
        if n_keys > self.max_batch_size:
            raise ValueError(
                f"Number of keys should <= max batch size {self.max_batch_size}, got {n_keys}"
            )

        self.mmap_r.batch_get(
            n_keys,
            self.key_prealloc.get_key_prealloc(keys),
            default_val,
            self.vals,
            self.threads_c_uint32,
        )
        return memoryview(self.vals)[:n_keys]


class _MmapHashmapBase(object):
"""Base class for methods shared by all modes"""

Expand Down Expand Up @@ -117,6 +166,15 @@ def __getitem__(self, key):
def __contains__(self, key):
pass

@abstractmethod
def batch_get(self, n_keys, keys, default_val, vals, threads_c_uint32):
    """
    Batch get values for `n_keys` keys, storing the results in `vals`.
    To be implemented by the map-type specific classes.

    Args:
        n_keys: Number of keys to get.
        keys: Keys in the form expected by the concrete map type.
        default_val: Default value for key not found.
        vals: Buffer that receives the results.
        threads_c_uint32: Number of threads to use.
    """
    pass

@classmethod
@abstractmethod
def get_keyalloc(cls, max_batch_size):
    """
    Return the key pre-allocation helper for this map type.
    To be implemented by the map-type specific classes.

    Args:
        max_batch_size: Maximum number of keys per batch the helper must support.
    """
    pass

@classmethod
def init(cls, map_type, map_dir, lazy_load):
fn_dict = clib.mmap_hashmap_init(map_type)
Expand All @@ -135,6 +193,7 @@ def get(self, key_utf8, default_val):
"""
Args:
key_utf8: UTF8 encoded bytes string key
default_val: Default value for key not found
"""
return self.fn_dict["get_w_default"](
self.map_ptr,
Expand All @@ -149,6 +208,59 @@ def __getitem__(self, key_utf8):
def __contains__(self, key_utf8):
return self.fn_dict["contains"](self.map_ptr, key_utf8, len(key_utf8))

def batch_get(
    self, n_keys: int, keys_utf8: Tuple, default_val: int, vals, threads_c_uint32: c_uint32
):
    """
    Look up a batch of UTF8 encoded bytes string keys and write the values into `vals`.

    Preparing `keys_utf8` from a list of UTF8 encoded bytes keys `ks`:
        > keys_ptr = (c_char_p * n_keys)()
        > keys_ptr[:] = ks
        > keys_lens = np.array([len(k) for k in ks], dtype=np.uint32)
        > keys_utf8 = (keys_ptr, keys_lens)

    Args:
        n_keys: Number of keys to get.
        keys_utf8: Pair `(keys_ptr, keys_lens)`
            keys_ptr: ctypes array of pointers to the key strings
            keys_lens: 1D uint32 NumPy array holding each key's length
        default_val: Value stored for keys that are not found.
        vals: 1D uint64 NumPy array receiving the results.
        threads_c_uint32: Number of threads to use.

    Returns:
        The same `vals` array that was passed in.
    """
    batch_fn = self.fn_dict["batch_get_w_default"]
    map_handle = self.map_ptr
    ptr_array, len_array = keys_utf8

    # Hand the raw NumPy buffers to the C side as typed pointers.
    lens_c = len_array.ctypes.data_as(POINTER(c_uint32))
    out_c = vals.ctypes.data_as(POINTER(c_uint64))

    batch_fn(map_handle, n_keys, ptr_array, lens_c, default_val, out_c, threads_c_uint32)
    return vals

@classmethod
def get_keyalloc(cls, max_batch_size):
    """Create the Str2Int key pre-allocation helper for batches of up to `max_batch_size` keys."""
    key_prealloc = _Str2IntBatchGetterKeyPreAlloc(max_batch_size)
    return key_prealloc


class _Str2IntBatchGetterKeyPreAlloc(object):
    """
    Reusable key buffers for Str2Int MmapHashmap batch get.

    Holds a ctypes array of string pointers and a uint32 array of key lengths,
    both sized for `max_batch_size` keys and overwritten on every call.
    """

    def __init__(self, max_batch_size: int):
        # One C string pointer slot per key.
        ptr_array_type = c_char_p * max_batch_size
        self.keys_ptr = ptr_array_type()
        # One length slot per key.
        self.keys_lens = np.zeros(max_batch_size, dtype=np.uint32)

    def get_key_prealloc(self, keys_utf8):
        """
        Copy the given UTF8 encoded bytes keys into the leading slots of the buffers.

        Returns:
            Tuple `(keys_ptr, keys_lens)` of the (reused) buffers.
        """
        n_keys = len(keys_utf8)
        self.keys_ptr[:n_keys] = keys_utf8
        self.keys_lens.flat[:n_keys] = list(map(len, keys_utf8))

        return self.keys_ptr, self.keys_lens


class _MmapHashmapInt2IntReadOnly(_MmapHashmapReadOnly):
def get(self, key, default_val):
Expand All @@ -160,6 +272,44 @@ def __getitem__(self, key):
def __contains__(self, key):
return self.fn_dict["contains"](self.map_ptr, key)

def batch_get(self, n_keys: int, keys, default_val: int, vals, threads_c_uint32: c_uint32):
    """
    Look up a batch of 64-bit integer keys and write the values into `vals`.

    Args:
        n_keys: Number of keys to get.
        keys: 1D NumPy array of 64-bit integer keys; its buffer is handed to the
            C side as a uint64 pointer.
        default_val: Value stored for keys that are not found.
        vals: 1D uint64 NumPy array receiving the results.
        threads_c_uint32: Number of threads to use.

    Returns:
        The same `vals` array that was passed in.
    """
    batch_fn = self.fn_dict["batch_get_w_default"]
    map_handle = self.map_ptr

    # Hand the raw NumPy buffers to the C side as typed pointers.
    keys_c = keys.ctypes.data_as(POINTER(c_uint64))
    out_c = vals.ctypes.data_as(POINTER(c_uint64))

    batch_fn(map_handle, n_keys, keys_c, default_val, out_c, threads_c_uint32)
    return vals

@classmethod
def get_keyalloc(cls, max_batch_size):
    """Create the Int2Int key pre-allocation helper for batches of up to `max_batch_size` keys."""
    key_prealloc = _Int2IntBatchGetterKeyPreAlloc(max_batch_size)
    return key_prealloc


class _Int2IntBatchGetterKeyPreAlloc(object):
    """
    Placeholder key pre-allocation for Int2Int MmapHashmap.

    Allocates nothing: the caller's keys are passed through unchanged.
    """

    def __init__(self, max_batch_size: int):
        """Accept `max_batch_size` for interface parity; nothing is allocated."""

    def get_key_prealloc(self, keys):
        """Return `keys` unchanged."""
        passthrough = keys
        return passthrough


class _MmapHashmapWrite(_MmapHashmapBase):
"""Base class for methods shared by all write modes"""
Expand Down
33 changes: 31 additions & 2 deletions test/pecos/utils/test_mmap_hashmap_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def test_str2int_mmap_hashmap(tmpdir):
from pecos.utils.mmap_hashmap_util import MmapHashmap
from pecos.utils.mmap_hashmap_util import MmapHashmap, MmapHashmapBatchGetter

map_dir = tmpdir.join("str2int_mmap").realpath().strpath
kv_dict = {"aaaa".encode("utf-8"): 2, "bb".encode("utf-8"): 3}
Expand Down Expand Up @@ -45,9 +45,25 @@ def test_str2int_mmap_hashmap(tmpdir):
# Size
assert r_map.map.size() == len(kv_dict)

# Batch get with default
max_batch_size = 5
# max_batch_size > num of key
r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size)
ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(ks, 10).tolist() == vs
# max_batch_size = num of key
ks = list(kv_dict.keys()) + ["ccccc".encode("utf-8")] * (
max_batch_size - len(kv_dict)
) # Non-exist key
vs = list(kv_dict.values()) + [10] * (max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(ks, 10).tolist() == vs
# Cannot test for max_batch_size < num of key, will result in segmentation fault


def test_int2int_mmap_hashmap(tmpdir):
from pecos.utils.mmap_hashmap_util import MmapHashmap
from pecos.utils.mmap_hashmap_util import MmapHashmap, MmapHashmapBatchGetter
import numpy as np

map_dir = tmpdir.join("int2int_mmap").realpath().strpath
kv_dict = {10: 2, 20: 3}
Expand Down Expand Up @@ -79,3 +95,16 @@ def test_int2int_mmap_hashmap(tmpdir):
assert not (1000 in r_map.map)
# Size
assert r_map.map.size() == len(kv_dict)

# Batch get with default
max_batch_size = 5
# max_batch_size > num of key
r_map_batch_getter = MmapHashmapBatchGetter(r_map.map, max_batch_size)
ks = list(kv_dict.keys()) + [1000] # Non-exist key
vs = list(kv_dict.values()) + [10]
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs
# max_batch_size = num of key
ks = list(kv_dict.keys()) + [1000] * (max_batch_size - len(kv_dict)) # Non-exist key
vs = list(kv_dict.values()) + [10] * (max_batch_size - len(kv_dict))
assert r_map_batch_getter.get(np.array(ks, dtype=np.int64), 10).tolist() == vs
# Cannot test for max_batch_size < num of key, will result in segmentation fault

0 comments on commit c3bccd5

Please sign in to comment.