diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml new file mode 100644 index 0000000..62245fc --- /dev/null +++ b/.github/workflows/linux.yml @@ -0,0 +1,40 @@ +name: Test ubuntu latest + +on: + push: + branches: [ master, refactoring_rust_api ] + pull_request: + branches: [ master, refactoring_rust_api ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + strategy: + matrix: + test_args: ["", "-V --vg-suppressions ../leakcheck.supp"] + + steps: + - uses: actions/checkout@v3 + - name: Checkout submodules + run: git submodule update --init --recursive + - name: format + run: cargo fmt -- --check + - name: install rltest + run: python3 -m pip install RLTest gevent + - name: install redis + run: git clone https://github.com/redis/redis; cd redis; git checkout 7.0.5; BUILD_TLS=yes make valgrind install + - name: install valgrind + run: sudo apt-get install valgrind + - name: Build + run: | + cd tests/mr_test_module/ + cargo build --verbose + - name: Run tests + run: | + cd tests/mr_test_module/pytests/ + DEBUG=1 ./run_full_tests.sh ${{ matrix.test_args }} diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml new file mode 100644 index 0000000..97f540d --- /dev/null +++ b/.github/workflows/macos.yml @@ -0,0 +1,42 @@ +name: Test Macos + +on: + push: + branches: [ master, refactoring_rust_api ] + pull_request: + branches: [ master, refactoring_rust_api ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: macos-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Checkout submodules + run: git submodule update --init --recursive + - name: format + run: cargo fmt -- --check + - name: install rltest + run: python3 -m pip install RLTest gevent + - name: install redis + run: git clone https://github.com/redis/redis; cd redis; git checkout 7.0.5; BUILD_TLS=yes make install + - name: install automake + run: brew install automake + - name: install openssl + run: brew install openssl@1.1 + - name: Build + run: | + cd tests/mr_test_module/ + export PKG_CONFIG_PATH=/usr/local/opt/openssl@1.1/lib/pkgconfig/ + cargo build --verbose + - name: Run tests + run: | + cd tests/mr_test_module/pytests/ + DEBUG=1 ./run_full_tests.sh diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..4eb5a13 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "lib_mr" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +#redis-module = { version="0.22.0", features = ["experimental-api"]} +redis-module = { git = "https://github.com/RedisLabsModules/redismodule-rs", branch = "api_extentions", features = ["experimental-api"]} +serde_json = "1.0" +serde = "1.0" +serde_derive = "1.0" +libc = "0.2" +linkme = "0.3" + +[build-dependencies] +bindgen = "0.59.2" + +[lib] +crate-type = ["rlib"] +name = "mr" +path = "rust_api/lib.rs" diff --git a/LibMRDerive/Cargo.toml b/LibMRDerive/Cargo.toml new file mode 100644 index 0000000..20804e9 --- /dev/null +++ b/LibMRDerive/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "lib_mr_derive" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +syn = "1.0" +quote = "1.0" + +[lib] +name = "mr_derive" +path = "src/lib.rs" +proc-macro = true \ No newline at end of file diff --git a/LibMRDerive/src/lib.rs b/LibMRDerive/src/lib.rs new file mode 100644 index 0000000..56db2a9 --- /dev/null +++ b/LibMRDerive/src/lib.rs @@ -0,0 +1,28 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use quote::quote; +use quote::format_ident; +use syn; + +#[proc_macro_derive(BaseObject)] +pub fn base_object_derive(item: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(item).unwrap(); + let name = &ast.ident; + + let func_name = format_ident!("register_{}", name.to_string().to_lowercase()); + + let gen = quote! { + impl mr::libmr::base_object::BaseObject for #name { + fn get_name() -> &'static str { + concat!(stringify!(#name), "\0") + } + } + + #[linkme::distributed_slice(mr::libmr::REGISTER_LIST)] + fn #func_name() { + #name::register(); + } + }; + + gen.into() +} \ No newline at end of file diff --git a/Makefile b/Makefile index 6e90abd..7dabea2 100644 --- a/Makefile +++ b/Makefile @@ -4,9 +4,11 @@ clean: clean_libmr build_deps: make -C deps/ - -libmr: build_deps + +libmr_only: make -C src/ + +libmr: build_deps libmr_only run_tests: make -C ./tests/mr_test_module/ test diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..138e69e --- /dev/null +++ b/build.rs @@ -0,0 +1,79 @@ +extern crate bindgen; + +use std::env; +use std::path::{Path, PathBuf}; +use std::process::Command; + +fn probe_paths<'a>(paths: &'a [&'a str]) -> Option<&'a str> { + paths.iter().find(|path| Path::new(path).exists()).copied() +} + +fn find_macos_openssl_prefix_path() -> &'static str { + const PATHS: [&str; 3] = [ + "/usr/local/opt/openssl", + "/usr/local/opt/openssl@1.1", + "/opt/homebrew/opt/openssl@1.1", + ]; + probe_paths(&PATHS).unwrap_or("") +} + +fn main() { + println!("cargo:rerun-if-changed=src/*.c"); + println!("cargo:rerun-if-changed=src/*.h"); + println!("cargo:rerun-if-changed=src/utils/*.h"); + println!("cargo:rerun-if-changed=src/utils/*.c"); + + if !Command::new("make") + .env( + "MODULE_NAME", + std::env::var("MODULE_NAME").expect("module name was not given"), + ) + .status() + .expect("failed to compile libmr") + .success() + { + panic!("failed to compile libmr"); + } + + let output_dir = env::var("OUT_DIR").expect("Can not find out directory"); + + if !Command::new("cp") + .args(["src/libmr.a", &output_dir]) + .status() + .expect("failed copy libmr.a to output directory") + .success() + { + panic!("failed copy libmr.a to output directory"); + } + + let build = bindgen::Builder::default(); + + let bindings = build + .header("src/mr.h") + .size_t_is_usize(true) + .layout_tests(false) + .generate() + .expect("error generating bindings"); + + let out_path = PathBuf::from(&output_dir); + bindings + .write_to_file(out_path.join("libmr_bindings.rs")) + .expect("failed to write bindings to file"); + + let open_ssl_prefix_path = match std::option_env!("OPENSSL_PREFIX") { + Some(p) => p, + None if std::env::consts::OS == "macos" => find_macos_openssl_prefix_path(), + _ => "", + }; + + let open_ssl_lib_path_link_argument = if open_ssl_prefix_path.is_empty() { + "".to_owned() + } else { + format!("-L{open_ssl_prefix_path}/lib/") + }; + + println!( + "cargo:rustc-flags=-L{} {} -lmr -lssl -lcrypto", + output_dir, open_ssl_lib_path_link_argument + ); +} diff --git a/deps/Makefile b/deps/Makefile index 2e6e605..896c8d7 100644 --- a/deps/Makefile +++ b/deps/Makefile @@ -1,11 +1,23 @@ all: build_hiredis build_libevent build_hiredis: +ifneq ("$(wildcard built_hiredis)","") + echo hiredis already built +else MAKEFLAGS='' USE_SSL=1 make -C ./hiredis/ +endif + touch built_hiredis build_libevent: - cd libevent; autoreconf -v -i -f; CFLAGS=-fPIC ./configure; make - +ifneq ("$(wildcard built_libevent)","") + echo libevent already built +else + cd libevent; autoreconf -v -i -f; CFLAGS=-fPIC ./configure PKG_CONFIG_PATH=$(PKG_CONFIG_PATH); make +endif + touch built_libevent + clean: make -C ./hiredis/ clean make -C ./libevent/ clean + rm built_libevent + rm built_hiredis diff --git a/rust_api/lib.rs b/rust_api/lib.rs new file mode 100644 index 0000000..6c0b641 --- /dev/null +++ b/rust_api/lib.rs @@ -0,0 +1,11 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +#[macro_use] +extern crate serde_derive; + +pub mod libmr; +pub mod libmr_c_raw; diff --git a/rust_api/libmr/accumulator.rs b/rust_api/libmr/accumulator.rs new file mode 100644 index 0000000..78a58a8 --- /dev/null +++ b/rust_api/libmr/accumulator.rs @@ -0,0 +1,65 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{ + ExecutionCtx, MR_ExecutionCtxSetError, MR_RegisterAccumulator, Record, +}; + +use crate::libmr::base_object::{register, BaseObject}; +use crate::libmr::record; +use crate::libmr::record::MRBaseRecord; +use crate::libmr::RustMRError; + +use std::os::raw::{c_char, c_void}; + +use std::ptr; + +pub extern "C" fn rust_accumulate( + ectx: *mut ExecutionCtx, + accumulator: *mut Record, + r: *mut Record, + args: *mut c_void, +) -> *mut Record { + let s = unsafe { &*(args as *mut Step) }; + let accumulator = if accumulator.is_null() { + None + } else { + let mut accumulator = + unsafe { *Box::from_raw(accumulator as *mut MRBaseRecord) }; + Some(accumulator.record.take().unwrap()) + }; + let mut r = unsafe { Box::from_raw(r as *mut MRBaseRecord) }; + let res = match s.accumulate(accumulator, r.record.take().unwrap()) { + Ok(res) => res, + Err(e) => { + unsafe { MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len()) }; + return ptr::null_mut(); + } + }; + Box::into_raw(Box::new(MRBaseRecord::new(res))) as *mut Record +} + +pub trait AccumulateStep: BaseObject { + type InRecord: record::Record; + type Accumulator: record::Record; + + fn accumulate( + &self, + accumulator: Option, + r: Self::InRecord, + ) -> Result; + + fn register() { + let obj = register::(); + unsafe { + MR_RegisterAccumulator( + Self::get_name().as_ptr() as *mut c_char, + Some(rust_accumulate::), + obj, + ); + } + } +} diff --git a/rust_api/libmr/base_object.rs b/rust_api/libmr/base_object.rs new file mode 100644 index 0000000..eb0b238 --- /dev/null +++ b/rust_api/libmr/base_object.rs @@ -0,0 +1,86 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{ + MRError, MRObjectType, MR_RegisterObject, MR_SerializationCtxReadeBuffer, + MR_SerializationCtxWriteBuffer, ReaderSerializationCtx, WriteSerializationCtx, +}; + +use std::os::raw::{c_char, c_void}; + +use serde_json::{from_str, to_string}; + +use serde::ser::Serialize; + +use serde::de::Deserialize; + +use std::slice; +use std::str; + +pub extern "C" fn rust_obj_free(ctx: *mut c_void) { + unsafe { Box::from_raw(ctx as *mut T) }; +} + +pub extern "C" fn rust_obj_dup(arg: *mut c_void) -> *mut c_void { + let obj = unsafe { &mut *(arg as *mut T) }; + let mut obj = obj.clone(); + obj.init(); + Box::into_raw(Box::new(obj)) as *mut c_void +} + +pub extern "C" fn rust_obj_serialize( + sctx: *mut WriteSerializationCtx, + arg: *mut c_void, + error: *mut *mut MRError, +) { + let obj = unsafe { &mut *(arg as *mut T) }; + let s = to_string(obj).unwrap(); + unsafe { + MR_SerializationCtxWriteBuffer(sctx, s.as_ptr() as *const c_char, s.len(), error); + } +} + +pub extern "C" fn rust_obj_deserialize( + sctx: *mut ReaderSerializationCtx, + error: *mut *mut MRError, +) -> *mut c_void { + let mut len: usize = 0; + let s = unsafe { MR_SerializationCtxReadeBuffer(sctx, &mut len as *mut usize, error) }; + if !(unsafe { *error }).is_null() { + return 0 as *mut c_void; + } + let s = str::from_utf8(unsafe { slice::from_raw_parts(s as *const u8, len) }).unwrap(); + let mut obj: T = from_str(s).unwrap(); + obj.init(); + Box::into_raw(Box::new(obj)) as *mut c_void +} + +pub extern "C" fn rust_obj_to_string(_arg: *mut c_void) -> *mut c_char { + 0 as *mut c_char +} + +pub trait BaseObject: Clone + Serialize + Deserialize<'static> { + fn get_name() -> &'static str; + fn init(&mut self) {} +} + +pub(crate) fn register() -> *mut MRObjectType { + unsafe { + let obj = Box::into_raw(Box::new(MRObjectType { + type_: T::get_name().as_ptr() as *mut c_char, + id: 0, + free: Some(rust_obj_free::), + dup: Some(rust_obj_dup::), + serialize: Some(rust_obj_serialize::), + deserialize: Some(rust_obj_deserialize::), + tostring: Some(rust_obj_to_string), + })); + + MR_RegisterObject(obj); + + obj + } +} diff --git a/rust_api/libmr/execution_builder.rs b/rust_api/libmr/execution_builder.rs new file mode 100644 index 0000000..824f530 --- /dev/null +++ b/rust_api/libmr/execution_builder.rs @@ -0,0 +1,138 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use std::marker::PhantomData; + +use crate::libmr_c_raw::bindings::{ + ExecutionBuilder, MRError, MR_CreateExecution, MR_CreateExecutionBuilder, MR_ErrorGetMessage, + MR_ExecutionBuilderBuilAccumulate, MR_ExecutionBuilderCollect, MR_ExecutionBuilderFilter, + MR_ExecutionBuilderMap, MR_ExecutionBuilderReshuffle, MR_FreeExecutionBuilder, +}; + +use std::os::raw::{c_char, c_void}; + +use crate::libmr::accumulator::AccumulateStep; +use crate::libmr::execution_object::ExecutionObj; +use crate::libmr::filter::FilterStep; +use crate::libmr::mapper::MapStep; +use crate::libmr::reader::Reader; +use crate::libmr::record; +use crate::libmr::RustMRError; + +use std::slice; +use std::str; + +use libc::strlen; + +pub struct Builder { + inner_builder: Option<*mut ExecutionBuilder>, + phantom: PhantomData, +} + +pub fn create_builder(reader: Re) -> Builder { + let reader = Box::into_raw(Box::new(reader)); + let inner_builder = unsafe { + MR_CreateExecutionBuilder( + Re::get_name().as_ptr() as *const c_char, + reader as *mut c_void, + ) + }; + Builder:: { + inner_builder: Some(inner_builder), + phantom: PhantomData, + } +} + +impl Builder { + fn take(&mut self) -> *mut ExecutionBuilder { + self.inner_builder.take().unwrap() + } + + pub fn map>(mut self, step: Step) -> Builder { + let inner_builder = self.take(); + unsafe { + MR_ExecutionBuilderMap( + inner_builder, + Step::get_name().as_ptr() as *const c_char, + Box::into_raw(Box::new(step)) as *const Step as *mut c_void, + ) + } + Builder:: { + inner_builder: Some(inner_builder), + phantom: PhantomData, + } + } + + pub fn filter>(self, step: Step) -> Builder { + unsafe { + MR_ExecutionBuilderFilter( + self.inner_builder.unwrap(), + Step::get_name().as_ptr() as *const c_char, + Box::into_raw(Box::new(step)) as *const Step as *mut c_void, + ) + } + self + } + + pub fn accumulate>( + mut self, + step: Step, + ) -> Builder { + let inner_builder = self.take(); + unsafe { + MR_ExecutionBuilderBuilAccumulate( + inner_builder, + Step::get_name().as_ptr() as *const c_char, + Box::into_raw(Box::new(step)) as *const Step as *mut c_void, + ) + } + Builder:: { + inner_builder: Some(inner_builder), + phantom: PhantomData, + } + } + + pub fn collect(self) -> Self { + unsafe { + MR_ExecutionBuilderCollect(self.inner_builder.unwrap()); + } + self + } + + pub fn reshuffle(self) -> Self { + unsafe { + MR_ExecutionBuilderReshuffle(self.inner_builder.unwrap()); + } + self + } + + pub fn create_execution(&self) -> Result, RustMRError> { + let execution = unsafe { + let mut err: *mut MRError = 0 as *mut MRError; + let res = MR_CreateExecution(self.inner_builder.unwrap(), &mut err); + if !err.is_null() { + let c_msg = MR_ErrorGetMessage(err); + let r_str = + str::from_utf8(slice::from_raw_parts(c_msg.cast::(), strlen(c_msg))) + .unwrap(); + return Err(r_str.to_string()); + } + res + }; + Ok(ExecutionObj { + inner_e: execution, + phantom: PhantomData, + }) + } +} + +impl Drop for Builder { + fn drop(&mut self) { + if let Some(innder_builder) = self.inner_builder { + unsafe { MR_FreeExecutionBuilder(innder_builder) } + } + } +} diff --git a/rust_api/libmr/execution_object.rs b/rust_api/libmr/execution_object.rs new file mode 100644 index 0000000..d2ce99a --- /dev/null +++ b/rust_api/libmr/execution_object.rs @@ -0,0 +1,73 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use std::marker::PhantomData; + +use crate::libmr_c_raw::bindings::{ + Execution, ExecutionCtx, MR_ExecutionCtxGetError, MR_ExecutionCtxGetErrorsLen, + MR_ExecutionCtxGetResult, MR_ExecutionCtxGetResultsLen, MR_ExecutionSetMaxIdle, + MR_ExecutionSetOnDoneHandler, MR_FreeExecution, MR_Run, +}; + +use crate::libmr::record; + +use std::os::raw::c_void; + +use std::slice; +use std::str; + +use libc::strlen; + +pub struct ExecutionObj { + pub(crate) inner_e: *mut Execution, + pub(crate) phantom: PhantomData, +} + +pub extern "C" fn rust_on_done, Vec<&str>)>( + ectx: *mut ExecutionCtx, + pd: *mut c_void, +) { + let f = unsafe { Box::from_raw(pd as *mut F) }; + let mut res = Vec::new(); + let res_len = unsafe { MR_ExecutionCtxGetResultsLen(ectx) }; + for i in 0..res_len { + let r = + unsafe { &mut *(MR_ExecutionCtxGetResult(ectx, i) as *mut record::MRBaseRecord) }; + res.push(r.record.as_mut().unwrap()); + } + let mut errs = Vec::new(); + let errs_len = unsafe { MR_ExecutionCtxGetErrorsLen(ectx) }; + for i in 0..errs_len { + let r = unsafe { MR_ExecutionCtxGetError(ectx, i) }; + let s = + str::from_utf8(unsafe { slice::from_raw_parts(r.cast::(), strlen(r)) }).unwrap(); + errs.push(s); + } + f(res, errs); +} + +impl ExecutionObj { + pub fn set_max_idle(&self, max_idle: usize) { + unsafe { MR_ExecutionSetMaxIdle(self.inner_e, max_idle) }; + } + + pub fn set_done_hanlder, Vec<&str>)>(&self, f: F) { + let f = Box::into_raw(Box::new(f)); + unsafe { + MR_ExecutionSetOnDoneHandler(self.inner_e, Some(rust_on_done::), f as *mut c_void) + }; + } + + pub fn run(&self) { + unsafe { MR_Run(self.inner_e) }; + } +} + +impl Drop for ExecutionObj { + fn drop(&mut self) { + unsafe { MR_FreeExecution(self.inner_e) }; + } +} diff --git a/rust_api/libmr/filter.rs b/rust_api/libmr/filter.rs new file mode 100644 index 0000000..58491b3 --- /dev/null +++ b/rust_api/libmr/filter.rs @@ -0,0 +1,49 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{ + ExecutionCtx, MR_ExecutionCtxSetError, MR_RegisterFilter, Record, +}; + +use crate::libmr::base_object::{register, BaseObject}; +use crate::libmr::record; +use crate::libmr::record::MRBaseRecord; +use crate::libmr::RustMRError; + +use std::os::raw::{c_char, c_int, c_void}; + +pub extern "C" fn rust_filter( + ectx: *mut ExecutionCtx, + r: *mut Record, + args: *mut c_void, +) -> c_int { + let s = unsafe { &*(args as *mut Step) }; + let r = unsafe { &*(r as *mut MRBaseRecord) }; // do not take ownership on the record + match s.filter(&r.record.as_ref().unwrap()) { + Ok(res) => res as c_int, + Err(e) => { + unsafe { MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len()) }; + 0 as c_int + } + } +} + +pub trait FilterStep: BaseObject { + type R: record::Record; + + fn filter(&self, r: &Self::R) -> Result; + + fn register() { + let obj = register::(); + unsafe { + MR_RegisterFilter( + Self::get_name().as_ptr() as *mut c_char, + Some(rust_filter::), + obj, + ); + } + } +} diff --git a/rust_api/libmr/mapper.rs b/rust_api/libmr/mapper.rs new file mode 100644 index 0000000..f2422b2 --- /dev/null +++ b/rust_api/libmr/mapper.rs @@ -0,0 +1,53 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{ + ExecutionCtx, MR_ExecutionCtxSetError, MR_RegisterMapper, Record, +}; + +use crate::libmr::base_object::{register, BaseObject}; +use crate::libmr::record; +use crate::libmr::record::MRBaseRecord; +use crate::libmr::RustMRError; + +use std::os::raw::{c_char, c_void}; + +use std::ptr; + +pub extern "C" fn rust_map( + ectx: *mut ExecutionCtx, + r: *mut Record, + args: *mut c_void, +) -> *mut Record { + let s = unsafe { &*(args as *mut Step) }; + let mut r = unsafe { Box::from_raw(r as *mut MRBaseRecord) }; + let res = match s.map(r.record.take().unwrap()) { + Ok(res) => res, + Err(e) => { + unsafe { MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len()) }; + return ptr::null_mut(); + } + }; + Box::into_raw(Box::new(MRBaseRecord::new(res))) as *mut Record +} + +pub trait MapStep: BaseObject { + type InRecord: record::Record; + type OutRecord: record::Record; + + fn map(&self, r: Self::InRecord) -> Result; + + fn register() { + let obj = register::(); + unsafe { + MR_RegisterMapper( + Self::get_name().as_ptr() as *mut c_char, + Some(rust_map::), + obj, + ); + } + } +} diff --git a/rust_api/libmr/mod.rs b/rust_api/libmr/mod.rs new file mode 100644 index 0000000..07edf55 --- /dev/null +++ b/rust_api/libmr/mod.rs @@ -0,0 +1,48 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{MRRecordType, MR_CalculateSlot, MR_Init, RedisModuleCtx}; +use redis_module::Context; + +use std::os::raw::c_char; + +use linkme::distributed_slice; + +pub mod accumulator; +pub mod base_object; +pub mod execution_builder; +pub mod execution_object; +pub mod filter; +pub mod mapper; +pub mod reader; +pub mod record; +pub mod remote_task; + +#[distributed_slice()] +pub static REGISTER_LIST: [fn()] = [..]; + +impl Default for crate::libmr_c_raw::bindings::Record { + fn default() -> Self { + crate::libmr_c_raw::bindings::Record { + recordType: 0 as *mut MRRecordType, + } + } +} + +pub type RustMRError = String; + +pub fn mr_init(ctx: &Context, num_threads: usize) { + unsafe { MR_Init(ctx.ctx as *mut RedisModuleCtx, num_threads) }; + record::init(); + + for register in REGISTER_LIST { + register(); + } +} + +pub fn calc_slot(s: &[u8]) -> usize { + unsafe { MR_CalculateSlot(s.as_ptr() as *const c_char, s.len()) } +} diff --git a/rust_api/libmr/reader.rs b/rust_api/libmr/reader.rs new file mode 100644 index 0000000..8140736 --- /dev/null +++ b/rust_api/libmr/reader.rs @@ -0,0 +1,51 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{ + ExecutionCtx, MR_ExecutionCtxSetError, MR_RegisterReader, Record, +}; + +use crate::libmr::base_object::{register, BaseObject}; +use crate::libmr::record; +use crate::libmr::record::MRBaseRecord; +use crate::libmr::RustMRError; + +use std::os::raw::{c_char, c_void}; + +use std::ptr; + +extern "C" fn rust_reader(ectx: *mut ExecutionCtx, args: *mut c_void) -> *mut Record { + let r = unsafe { &mut *(args as *mut Step) }; + let res = match r.read() { + Ok(res) => match res { + Some(res) => res, + None => return ptr::null_mut(), + }, + Err(e) => { + unsafe { MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len()) }; + return ptr::null_mut(); + } + }; + + Box::into_raw(Box::new(MRBaseRecord::new(res))) as *mut Record +} + +pub trait Reader: BaseObject { + type R: record::Record; + + fn read(&mut self) -> Result, RustMRError>; + + fn register() { + let obj = register::(); + unsafe { + MR_RegisterReader( + Self::get_name().as_ptr() as *mut c_char, + Some(rust_reader::), + obj, + ); + } + } +} diff --git a/rust_api/libmr/record.rs b/rust_api/libmr/record.rs new file mode 100644 index 0000000..9eb584c --- /dev/null +++ b/rust_api/libmr/record.rs @@ -0,0 +1,145 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{ + MRError, MRObjectType, MRRecordType, MR_RegisterRecord, MR_SerializationCtxReadeBuffer, + MR_SerializationCtxWriteBuffer, ReaderSerializationCtx, RedisModuleCtx, WriteSerializationCtx, +}; + +use redis_module::RedisValue; + +use serde_json::{from_str, to_string}; + +use std::os::raw::{c_char, c_void}; + +use crate::libmr::base_object::BaseObject; + +use std::collections::HashMap; + +use std::slice; +use std::str; + +#[repr(C)] +#[derive(Clone, Serialize)] +pub(crate) struct MRBaseRecord { + #[serde(skip)] + base: crate::libmr_c_raw::bindings::Record, + pub(crate) record: Option, +} + +impl MRBaseRecord { + pub(crate) fn new(record: T) -> MRBaseRecord { + MRBaseRecord { + base: crate::libmr_c_raw::bindings::Record { + recordType: get_record_type(T::get_name()).unwrap(), + }, + record: Some(record), + } + } +} + +pub extern "C" fn rust_obj_free(ctx: *mut c_void) { + unsafe { Box::from_raw(ctx as *mut MRBaseRecord) }; +} + +pub extern "C" fn rust_obj_dup(arg: *mut c_void) -> *mut c_void { + let obj = unsafe { &mut *(arg as *mut MRBaseRecord) }; + let mut obj = obj.clone(); + obj.record.as_mut().unwrap().init(); + Box::into_raw(Box::new(obj)) as *mut c_void +} + +pub extern "C" fn rust_obj_serialize( + sctx: *mut WriteSerializationCtx, + arg: *mut c_void, + error: *mut *mut MRError, +) { + let obj = unsafe { &mut *(arg as *mut MRBaseRecord) }; + let s = to_string(obj.record.as_ref().unwrap()).unwrap(); + unsafe { + MR_SerializationCtxWriteBuffer(sctx, s.as_ptr() as *const c_char, s.len(), error); + } +} + +pub extern "C" fn rust_obj_deserialize( + sctx: *mut ReaderSerializationCtx, + error: *mut *mut MRError, +) -> *mut c_void { + let mut len: usize = 0; + let s = unsafe { MR_SerializationCtxReadeBuffer(sctx, &mut len as *mut usize, error) }; + if !(unsafe { *error }).is_null() { + return 0 as *mut c_void; + } + let s = str::from_utf8(unsafe { slice::from_raw_parts(s as *const u8, len) }).unwrap(); + let mut obj: T = from_str(s).unwrap(); + obj.init(); + Box::into_raw(Box::new(MRBaseRecord::new(obj))) as *mut c_void +} + +pub extern "C" fn rust_obj_to_string(_arg: *mut c_void) -> *mut c_char { + 0 as *mut c_char +} + +pub extern "C" fn rust_obj_send_reply( + _arg1: *mut RedisModuleCtx, + _record: *mut ::std::os::raw::c_void, +) { +} + +pub extern "C" fn rust_obj_hash_slot(record: *mut ::std::os::raw::c_void) -> usize { + let record = unsafe { &mut *(record as *mut MRBaseRecord) }; + record.record.as_ref().unwrap().hash_slot() +} + +fn register_record() -> *mut MRRecordType { + unsafe { + let obj = Box::into_raw(Box::new(MRRecordType { + type_: MRObjectType { + type_: T::get_name().as_ptr() as *mut c_char, + id: 0, + free: Some(rust_obj_free::), + dup: Some(rust_obj_dup::), + serialize: Some(rust_obj_serialize::), + deserialize: Some(rust_obj_deserialize::), + tostring: Some(rust_obj_to_string), + }, + sendReply: Some(rust_obj_send_reply), + hashTag: Some(rust_obj_hash_slot::), + })); + + MR_RegisterRecord(obj); + + obj + } +} + +static mut RECORD_TYPES: Option> = None; + +fn get_record_types_mut() -> &'static mut HashMap { + unsafe { RECORD_TYPES.as_mut().unwrap() } +} + +fn get_record_type(name: &str) -> Option<*mut MRRecordType> { + match unsafe { RECORD_TYPES.as_ref().unwrap() }.get(name) { + Some(r) => Some(*r), + None => None, + } +} + +pub(crate) fn init() { + unsafe { + RECORD_TYPES = Some(HashMap::new()); + } +} + +pub trait Record: BaseObject { + fn register() { + let record_type = register_record::(); + get_record_types_mut().insert(Self::get_name().to_string(), record_type); + } + fn to_redis_value(&mut self) -> RedisValue; + fn hash_slot(&self) -> usize; +} diff --git a/rust_api/libmr/remote_task.rs b/rust_api/libmr/remote_task.rs new file mode 100644 index 0000000..d6b1f70 --- /dev/null +++ b/rust_api/libmr/remote_task.rs @@ -0,0 +1,200 @@ +/* + * Copyright Redis Ltd. 2021 - present + * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or + * the Server Side Public License v1 (SSPLv1). + */ + +use crate::libmr_c_raw::bindings::{ + MRError, MR_ErrorCreate, MR_ErrorFree, MR_ErrorGetMessage, MR_RegisterRemoteTask, + MR_RunOnAllShards, MR_RunOnKey, Record, +}; + +use crate::libmr::base_object::{register, BaseObject}; +use crate::libmr::record; +use crate::libmr::record::MRBaseRecord; +use crate::libmr::RustMRError; + +use libc::strlen; +use std::os::raw::{c_char, c_void}; + +struct VoidHolder { + pd: *mut ::std::os::raw::c_void, +} + +impl VoidHolder { + fn get(&self) -> *mut ::std::os::raw::c_void { + self.pd + } +} + +unsafe impl Send for VoidHolder {} + +extern "C" fn rust_remote_task( + r: *mut Record, + args: *mut ::std::os::raw::c_void, + on_done: ::std::option::Option< + unsafe extern "C" fn(PD: *mut ::std::os::raw::c_void, r: *mut Record), + >, + on_error: ::std::option::Option< + unsafe extern "C" fn(PD: *mut ::std::os::raw::c_void, r: *mut MRError), + >, + pd: *mut ::std::os::raw::c_void, +) { + let void_holder = VoidHolder { pd }; + let s = unsafe { Box::from_raw(args as *mut Step) }; + let mut r = unsafe { Box::from_raw(r as *mut MRBaseRecord) }; + s.task( + r.record.take().unwrap(), + Box::new(move |res| { + let pd = void_holder.get(); + match res { + Ok(r) => { + let record = Box::new(MRBaseRecord::new(r)); + unsafe { on_done.unwrap()(pd, Box::into_raw(record) as *mut Record) } + } + Err(e) => { + let error = unsafe { MR_ErrorCreate(e.as_ptr() as *const c_char, e.len()) }; + unsafe { on_error.unwrap()(pd, error) }; + } + } + }), + ); +} + +pub trait RemoteTask: BaseObject { + type InRecord: record::Record; + type OutRecord: record::Record; + + fn task( + self, + r: Self::InRecord, + on_done: Box) + Send>, + ); + + fn register() { + let obj = register::(); + unsafe { + MR_RegisterRemoteTask( + Self::get_name().as_ptr() as *mut c_char, + Some(rust_remote_task::), + obj, + ); + } + } +} + +extern "C" fn on_done< + OutRecord: record::Record, + DoneCallback: FnOnce(Result), +>( + pd: *mut ::std::os::raw::c_void, + result: *mut Record, +) { + let callback = unsafe { Box::::from_raw(pd as *mut DoneCallback) }; + let mut r = unsafe { Box::from_raw(result as *mut MRBaseRecord) }; + callback(Ok(r.record.take().unwrap())); +} + +extern "C" fn on_done_on_all_shards< + OutRecord: record::Record, + DoneCallback: FnOnce(Vec, Vec), +>( + pd: *mut ::std::os::raw::c_void, + results: *mut *mut Record, + n_results: usize, + errs: *mut *mut MRError, + n_errs: usize, +) { + let callback = unsafe { Box::::from_raw(pd as *mut DoneCallback) }; + + let results_slice = unsafe { std::slice::from_raw_parts(results, n_results) }; + let mut results_vec = Vec::new(); + for res in results_slice { + results_vec.push( + unsafe { Box::from_raw(*res as *mut MRBaseRecord) } + .record + .take() + .unwrap(), + ) + } + + let errs_slice = unsafe { std::slice::from_raw_parts(errs, n_errs) }; + let mut errs_vec = Vec::new(); + for err in errs_slice { + let err_msg = unsafe { MR_ErrorGetMessage(*err) }; + let r_str = std::str::from_utf8(unsafe { + std::slice::from_raw_parts(err_msg.cast::(), strlen(err_msg)) + }) + .unwrap(); + errs_vec.push(r_str.to_string()) + } + + callback(results_vec, errs_vec); +} + +extern "C" fn on_error< + OutRecord: record::Record, + DoneCallback: FnOnce(Result), +>( + pd: *mut ::std::os::raw::c_void, + error: *mut MRError, +) { + let callback = unsafe { Box::::from_raw(pd as *mut DoneCallback) }; + let err_msg = unsafe { MR_ErrorGetMessage(error) }; + let r_str = std::str::from_utf8(unsafe { + std::slice::from_raw_parts(err_msg.cast::(), strlen(err_msg)) + }) + .unwrap(); + callback(Err(r_str.to_string())); + unsafe { MR_ErrorFree(error) }; +} + +pub fn run_on_key< + Remote: RemoteTask, + InRecord: record::Record, + OutRecord: record::Record, + DoneCallback: FnOnce(Result), +>( + key_name: &[u8], + remote_task: Remote, + r: InRecord, + done: DoneCallback, + timeout: usize, +) { + unsafe { + MR_RunOnKey( + key_name.as_ptr() as *mut c_char, + key_name.len(), + Remote::get_name().as_ptr() as *mut c_char, + Box::into_raw(Box::new(remote_task)) as *mut c_void, + Box::into_raw(Box::new(MRBaseRecord::new(r))) as *mut Record, + Some(on_done::), + Some(on_error::), + Box::into_raw(Box::new(done)) as *mut c_void, + timeout, + ) + } +} + +pub fn run_on_all_shards< + Remote: RemoteTask, + InRecord: record::Record, + OutRecord: record::Record, + DoneCallback: FnOnce(Vec, Vec), +>( + remote_task: Remote, + r: InRecord, + done: DoneCallback, + timeout: usize, +) { + unsafe { + MR_RunOnAllShards( + Remote::get_name().as_ptr() as *mut c_char, + Box::into_raw(Box::new(remote_task)) as *mut c_void, + Box::into_raw(Box::new(MRBaseRecord::new(r))) as *mut Record, + Some(on_done_on_all_shards::), + Box::into_raw(Box::new(done)) as *mut c_void, + timeout, + ) + } +} diff --git a/tests/mr_test_module/src/libmrraw/mod.rs b/rust_api/libmr_c_raw/mod.rs similarity index 86% rename from tests/mr_test_module/src/libmrraw/mod.rs rename to rust_api/libmr_c_raw/mod.rs index d89a57c..e00cf21 100644 --- a/tests/mr_test_module/src/libmrraw/mod.rs +++ b/rust_api/libmr_c_raw/mod.rs @@ -10,7 +10,7 @@ #![allow(dead_code)] pub mod bindings { - include!(concat!(env!("OUT_DIR"), "/mr.rs")); + include!(concat!(env!("OUT_DIR"), "/libmr_bindings.rs")); } // See: https://users.rust-lang.org/t/bindgen-generate-options-and-some-are-none/14027 diff --git a/src/Makefile b/src/Makefile index eb21e13..60efd8a 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,10 +6,13 @@ else GCC_FLAGS+=-O3 endif +OPENSSL_PREFIX?=/usr/local/opt/openssl + GCC_FLAGS+=-fvisibility=hidden -fPIC -DREDISMODULE_EXPERIMENTAL_API \ -I../deps/hiredis/ \ -I../deps/hiredis/adapters/ \ --I../deps/libevent/include/ +-I../deps/libevent/include/ \ +-I$(OPENSSL_PREFIX)/include ifeq ($(COVERAGE),1) GCC_FLAGS+=-fprofile-arcs -ftest-coverage @@ -33,12 +36,18 @@ ARTIFACT_NAME=libmr all: $(ARTIFACT_NAME) +uname_S := $(shell sh -c 'uname -s 2>/dev/null || echo not') +ifeq ($(uname_S),Darwin) + OPENSSL_PREFIX?=/usr/local/opt/openssl@1.1 + GCC_FLAGS+=-I$(OPENSSL_PREFIX)/include/ +endif + %.o : %.c gcc -c $(GCC_FLAGS) $< -o $@ -DMODULE_NAME=$(MODULE_NAME) $(ARTIFACT_NAME): $(SOURCES) - gcc $(SOURCES) $(HIREDIS) $(HIREDIS_SSL) $(LIBEVENT) $(LIBEVENT_PTHREADS) -r -o $(ARTIFACT_NAME).o $(LD_FLAGS) + gcc $(SOURCES) $(HIREDIS) $(HIREDIS_SSL) $(LIBEVENT) $(LIBEVENT_PTHREADS) -r -o $(ARTIFACT_NAME).o $(LD_FLAGS) ifeq ($(COMPILER),gcc) objcopy --localize-hidden $(ARTIFACT_NAME).o endif diff --git a/src/mr.c b/src/mr.c index 587ab4c..3a04c00 100644 --- a/src/mr.c +++ b/src/mr.c @@ -32,6 +32,8 @@ functionId PASS_RECORD_FUNCTION_ID = 0; functionId NOTIFY_STEP_DONE_FUNCTION_ID = 0; functionId NOTIFY_DONE_FUNCTION_ID = 0; functionId DROP_EXECUTION_FUNCTION_ID = 0; +functionId REMOTE_TASK_FUNCTION_ID = 0; +functionId REMOTE_TASK_DONE_FUNCTION_ID = 0; typedef struct RemoteFunctionDef { functionId* funcIdPointer; @@ -56,6 +58,8 @@ static void MR_PassRecord(RedisModuleCtx *ctx, const char *sender_id, uint8_t ty static void MR_NotifyDone(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, RedisModuleString* payload); static void MR_NotifyStepDone(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, RedisModuleString* payload); static void MR_DropExecution(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, RedisModuleString* payload); +static void MR_RemoteTask(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, RedisModuleString* payload); +static void MR_RemoteTaskDone(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, RedisModuleString* payload); /* Remote functions array */ RemoteFunctionDef remoteFunctions[] = { @@ -87,6 +91,14 @@ RemoteFunctionDef remoteFunctions[] = { .funcIdPointer = &DROP_EXECUTION_FUNCTION_ID, .functionPointer = MR_DropExecution, }, + { + .funcIdPointer = &REMOTE_TASK_FUNCTION_ID, + .functionPointer = MR_RemoteTask, + }, + { + .funcIdPointer = &REMOTE_TASK_DONE_FUNCTION_ID, + .functionPointer = MR_RemoteTaskDone, + }, }; typedef struct MRStats { @@ -100,12 +112,16 @@ struct MRCtx { /* protected by the event loop */ mr_dict* executionsDict; + /* protected by the event loop */ + mr_dict* remoteDict; + /* should be initialized at start and then read only */ ARR(MRObjectType*) objectTypesDict; /* Steps dictionaries */ mr_dict* readerDict; mr_dict* mappersDict; + mr_dict* remoteTasksDict; mr_dict* filtersDict; mr_dict* accumulatorsDict; @@ -1353,11 +1369,13 @@ int MR_Init(RedisModuleCtx* ctx, size_t numThreads) { mrCtx.lastExecutionId = 0; mrCtx.executionsDict = mr_dictCreate(&dictTypeHeapIds, NULL); + mrCtx.remoteDict = mr_dictCreate(&dictTypeHeapIds, NULL); mrCtx.objectTypesDict = array_new(MRObjectType*, 10); mrCtx.readerDict = mr_dictCreate(&mr_dictTypeHeapStrings, NULL); mrCtx.mappersDict = mr_dictCreate(&mr_dictTypeHeapStrings, NULL); + mrCtx.remoteTasksDict = mr_dictCreate(&mr_dictTypeHeapStrings, NULL); mrCtx.filtersDict = mr_dictCreate(&mr_dictTypeHeapStrings, NULL); mrCtx.accumulatorsDict = mr_dictCreate(&mr_dictTypeHeapStrings, NULL); @@ -1440,6 +1458,17 @@ LIBMR_API void MR_RegisterAccumulator(const char* name, ExecutionAccumulator acc mr_dictAdd(mrCtx.accumulatorsDict, asd->name, asd); } +LIBMR_API void MR_RegisterRemoteTask(const char* name, RemoteTask remote, MRObjectType* argType) { + RedisModule_Assert(!mr_dictFetchValue(mrCtx.remoteTasksDict, name)); + StepDefinition* asd = MR_ALLOC(sizeof(*asd)); + *asd = (StepDefinition) { + .name = MR_STRDUP(name), + .type = argType, + .callback = remote, + }; + mr_dictAdd(mrCtx.remoteTasksDict, asd->name, asd); +} + long long MR_SerializationCtxReadeLongLong(ReaderSerializationCtx* sctx, MRError** err) { int error = 0; long res = mr_BufferReaderReadLongLong(sctx, &error); @@ -1497,3 +1526,490 @@ void MR_ErrorFree(MRError* err) { MR_FREE(err); } } + +typedef struct RunOnKeyReplyMsg { + char *sender; + char *id; +} RunOnKeyReplyMsg; + +typedef enum ReplyType { + ReplyType_OK, ReplyType_ERROR, +}ReplyType; + +typedef enum RemoteTaksMsgType { + RemoteTaksMsgType_OnKey, RemoteTaksMsgType_OnAllShards, +} RemoteTaksMsgType; + +typedef struct RemoteTaksMsg { + char idStr[STR_ID_LEN]; + char id[ID_LEN]; + char *msg; + size_t msgLen; + size_t timeout; + MR_LoopTaskCtx* timeoutTask; + RemoteTaksMsgType remoteTaskType; +} RemoteTaksMsg; + +typedef struct RemoteTaskResult { + union { + Record *res; + MRError *error; + }; + ReplyType replyType; +} RemoteTaskResult; + +typedef struct RunOnKeyMsg { + RemoteTaksMsg remoteTaskBase; + size_t slot; + void (*onDone)(void *pd, Record* result); + void (*onError)(void *pd, MRError* err); + void *pd; + RemoteTaskResult remoteTaskRes; +} RunOnKeyMsg; + +typedef struct RunOnShardsMsg { + RemoteTaksMsg remoteTaskBase; + void (*onDone)(void *pd, Record** result, size_t nResults, MRError** errs, size_t nErrs); + void *pd; + ReplyType replyType; + void* args; + Record* r; + ARR(Record*) results; + ARR(MRError*) errs; + size_t expectedNResults; + size_t nResultsArrived; + StepDefinition* msd; +} RunOnShardsMsg; + +/* Run on thread pool */ +static void MR_RemoteTaskOnShardsDoneInternal(void* pd) { + RunOnShardsMsg *msg = pd; + msg->onDone(msg->pd, msg->results, array_len(msg->results), msg->errs, array_len(msg->errs)); + + array_free(msg->results); + array_free(msg->errs); + MR_FREE(msg); +} + +/* Run on thread pool */ +static void MR_RemoteTaskOnKeyDoneInternal(void* pd) { + RunOnKeyMsg *msg = pd; + + if (msg->remoteTaskRes.replyType == ReplyType_OK) { + msg->onDone(msg->pd, msg->remoteTaskRes.res); + } else { + msg->onError(msg->pd, msg->remoteTaskRes.error); + } + + MR_FREE(msg); +} + +/* Run on the event loop */ +static void MR_RemoteTaskDoneProcessesResult(const char *id, RemoteTaskResult remoteTaskRes) { + RemoteTaksMsg *msgBase = mr_dictFetchValue(mrCtx.remoteDict, id); + if (!msgBase) { + RedisModule_Log(NULL, "warning", "Got a remote task done on none existing ID %.*s", REDISMODULE_NODE_ID_LEN, id); + return; + } + + if (msgBase->remoteTaskType == RemoteTaksMsgType_OnKey) { + RunOnKeyMsg *msg = (RunOnKeyMsg*)msgBase; + msg->remoteTaskRes = remoteTaskRes; + + if (msg->remoteTaskBase.timeoutTask) { + MR_EventLoopDelayTaskCancel(msg->remoteTaskBase.timeoutTask); + msg->remoteTaskBase.timeoutTask = NULL; + } + + /* Remove msg from remoteDict, we will be done with it once we fire the done callback */ + mr_dictDelete(mrCtx.remoteDict, msgBase->id); + + /* Run the callback on the thread pool */ + mr_thpool_add_work(mrCtx.executionsThreadPool, MR_RemoteTaskOnKeyDoneInternal, msg); + } else { + RedisModule_Assert(msgBase->remoteTaskType == RemoteTaksMsgType_OnAllShards); + RunOnShardsMsg *msg = (RunOnShardsMsg*)msgBase; + if (remoteTaskRes.replyType == ReplyType_OK) { + msg->results = array_append(msg->results, remoteTaskRes.res); + } else { + msg->errs = array_append(msg->errs, remoteTaskRes.error); + } + ++msg->nResultsArrived; + if (msg->nResultsArrived == msg->expectedNResults) { + /* Got all the results */ + + if (msg->remoteTaskBase.timeoutTask) { + MR_EventLoopDelayTaskCancel(msg->remoteTaskBase.timeoutTask); + msg->remoteTaskBase.timeoutTask = NULL; + } + + /* Remove msg from remoteDict, we will be done with it once we fire the done callback */ + mr_dictDelete(mrCtx.remoteDict, msgBase->id); + + /* Run the callback on the thread pool */ + mr_thpool_add_work(mrCtx.executionsThreadPool, MR_RemoteTaskOnShardsDoneInternal, msg); + } + } +} + +/* Run on the event loop */ +static void MR_RemoteTaskDone(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, RedisModuleString* payload) { + size_t dataSize; + const char* data = RedisModule_StringPtrLen(payload, &dataSize); + mr_Buffer buff = { + .buff = (char*)data, + .size = dataSize, + .cap = dataSize, + }; + mr_BufferReader buffReader; + mr_BufferReaderInit(&buffReader, &buff); + + size_t idLen; + const char *id = mr_BufferReaderReadBuff(&buffReader, &idLen, NULL); + RedisModule_Assert(idLen == ID_LEN); + + RemoteTaskResult remoteTaskRes; + + if (mr_BufferReaderReadLongLong(&buffReader, NULL)) { + remoteTaskRes.res = MR_RecordDeSerialize(&buffReader); + remoteTaskRes.replyType = ReplyType_OK; + } else { + const char* errMsg = mr_BufferReaderReadString(&buffReader, NULL); + remoteTaskRes.error = MR_ErrorCreate(errMsg, strlen(errMsg)); + remoteTaskRes.replyType = ReplyType_ERROR; + } + + MR_RemoteTaskDoneProcessesResult(id, remoteTaskRes); +} + +static void MR_RemoteTaskErrorOnRemote(void *pd, MRError *error) { + RunOnKeyReplyMsg *replyMsg = pd; + + mr_Buffer buff; + mr_BufferInitialize(&buff); + mr_BufferWriter buffWriter; + mr_BufferWriterInit(&buffWriter, &buff); + + /* write id */ + mr_BufferWriterWriteBuff(&buffWriter, replyMsg->id, ID_LEN); + + mr_BufferWriterWriteLongLong(&buffWriter, 0); /* mean failure */ + + mr_BufferWriterWriteString(&buffWriter, error->msg); + + MR_ClusterSendMsg(replyMsg->sender, REMOTE_TASK_DONE_FUNCTION_ID, buff.buff, buff.size); + + MR_ErrorFree(error); + MR_FREE(replyMsg->id); + MR_FREE(replyMsg->sender); + MR_FREE(replyMsg); +} + +static void MR_RemoteTaskDoneOnRemote(void *pd, Record *res) { + RunOnKeyReplyMsg *replyMsg = pd; + + MRError* error = NULL; + mr_Buffer buff; + mr_BufferInitialize(&buff); + mr_BufferWriter buffWriter; + mr_BufferWriterInit(&buffWriter, &buff); + + /* write id */ + mr_BufferWriterWriteBuff(&buffWriter, replyMsg->id, ID_LEN); + + mr_BufferWriterWriteLongLong(&buffWriter, 1); /* mean success */ + + MR_RecordSerialize(res, &buffWriter); + /* todo: handler serialization failure */ + + MR_ClusterSendMsg(replyMsg->sender, REMOTE_TASK_DONE_FUNCTION_ID, buff.buff, buff.size); + + MR_RecordFree(res); + MR_FREE(replyMsg->id); + MR_FREE(replyMsg->sender); + MR_FREE(replyMsg); +} + +/* Runs on thread pool */ +static void MR_RemoteTaskInternal(void* pd) { + RedisModuleString* payload = pd; + size_t dataSize; + const char* data = RedisModule_StringPtrLen(payload, &dataSize); + mr_Buffer buff = { + .buff = (char*)data, + .size = dataSize, + .cap = dataSize, + }; + mr_BufferReader buffReader; + mr_BufferReaderInit(&buffReader, &buff); + + /* Read sender id */ + const char *sender = mr_BufferReaderReadString(&buffReader, NULL); + + size_t idLen; + const char *id = mr_BufferReaderReadBuff(&buffReader, &idLen, NULL); + RedisModule_Assert(idLen == ID_LEN); + + const char *remoteTaskName = mr_BufferReaderReadString(&buffReader, NULL); + StepDefinition* msd = mr_dictFetchValue(mrCtx.remoteTasksDict, remoteTaskName); + RedisModule_Assert(msd); + + MRError* error = NULL; + void *args = msd->type->deserialize(&buffReader, &error); + /* todo: handler serialization failure */ + + Record *r = MR_RecordDeSerialize(&buffReader); + /* todo: handler serialization failure */ + + RunOnKeyReplyMsg *replyMsg = MR_ALLOC(sizeof(*replyMsg)); + replyMsg->sender = MR_STRDUP(sender); + replyMsg->id = MR_ALLOC(idLen); + memcpy(replyMsg->id, id, idLen); + + ((RemoteTask)msd->callback)(r, args, MR_RemoteTaskDoneOnRemote, MR_RemoteTaskErrorOnRemote, replyMsg); + + /* We must take the Redis GIL to free the payload, + * RedisModuleString refcount are not thread safe. + * We better do it here and stuck on of the threads + * in the thread pool then do it on the event loop. + * Possible optimization would be to batch multiple + * payloads into one GIL locking */ + RedisModule_ThreadSafeContextLock(mr_staticCtx); + RedisModule_FreeString(NULL, payload); + RedisModule_ThreadSafeContextUnlock(mr_staticCtx); +} + +/* Runs on the event loop */ +static void MR_RemoteTask(RedisModuleCtx *ctx, const char *sender_id, uint8_t type, RedisModuleString* payload) { + /* We can directly move the job to the thread pool for deserialization and execution */ + mr_thpool_add_work(mrCtx.executionsThreadPool, MR_RemoteTaskInternal, RedisModule_HoldString(NULL, payload)); +} + +/* Invoked on the event look */ +static void MR_RemoteTaskOnKeyTimeoutOut(void* ctx) { + static const char TIMEOUT_TEXT[] = "Remote task timeout"; + RunOnKeyMsg *msg = ctx; + msg->remoteTaskBase.timeoutTask = NULL; + + msg->remoteTaskRes.error = MR_ErrorCreate(TIMEOUT_TEXT, sizeof(TIMEOUT_TEXT) - 1); + msg->remoteTaskRes.replyType = ReplyType_ERROR; + + /* Remove msg from remoteDict, we will be done with it once we fire the done callback */ + int res = mr_dictDelete(mrCtx.remoteDict, msg->remoteTaskBase.id); + RedisModule_Assert(res == DICT_OK); + + /* Run the callback on the thread pool */ + mr_thpool_add_work(mrCtx.executionsThreadPool, MR_RemoteTaskOnKeyDoneInternal, msg); +} + +/* Invoked on the event look */ +static void MR_RunOnKeyInternal(void* ctx) { + RunOnKeyMsg *msg = ctx; + + /* add the task to the remote mappers dictionary */ + mr_dictAdd(mrCtx.remoteDict, msg->remoteTaskBase.id, msg); + + /* send the message to the shard */ + MR_ClusterSendMsgBySlot(msg->slot, REMOTE_TASK_FUNCTION_ID, msg->remoteTaskBase.msg, msg->remoteTaskBase.msgLen); + + /* ownership on the message was moved to MR_ClusterSendMsgBySlot function */ + msg->remoteTaskBase.msg = NULL; + msg->remoteTaskBase.msgLen = 0; + + if (msg->remoteTaskBase.timeout != SIZE_MAX) { + msg->remoteTaskBase.timeoutTask = MR_EventLoopAddTaskWithDelay(MR_RemoteTaskOnKeyTimeoutOut, msg, msg->remoteTaskBase.timeout); + } +} + +LIBMR_API void MR_RunOnKey(const char* keyName, + size_t keyNameSize, + const char* remoteTaskName, + void* args, + Record* r, + void (*onDone)(void *pd, Record* result), + void (*onError)(void *pd, MRError* err), + void *pd, + size_t timeout) +{ + StepDefinition* msd = mr_dictFetchValue(mrCtx.remoteTasksDict, remoteTaskName); + RedisModule_Assert(msd); + size_t slot = MR_ClusterGetSlotdByKey(keyName, keyNameSize); + if (!MR_ClusterIsInClusterMode() || MR_ClusterIsMySlot(slot)) { + ((RemoteTask)msd->callback)(r, args, onDone, onError, pd); + return; + } + + RunOnKeyMsg *msg = MR_ALLOC(sizeof(*msg)); + msg->slot = slot; + msg->onDone = onDone; + msg->onError = onError; + msg->pd = pd; + msg->remoteTaskBase.timeout = timeout; + msg->remoteTaskBase.timeoutTask = NULL; + msg->remoteTaskBase.remoteTaskType = RemoteTaksMsgType_OnKey; + /* Set id */ + size_t id = __atomic_add_fetch(&mrCtx.lastExecutionId, 1, __ATOMIC_RELAXED); + SetId(msg->remoteTaskBase.id, msg->remoteTaskBase.idStr, id); + + MRError* error = NULL; + mr_Buffer buff; + mr_BufferInitialize(&buff); + mr_BufferWriter buffWriter; + mr_BufferWriterInit(&buffWriter, &buff); + /* write sender */ + mr_BufferWriterWriteString(&buffWriter, MR_ClusterGetMyId()); + /* write id */ + mr_BufferWriterWriteBuff(&buffWriter, msg->remoteTaskBase.id, ID_LEN); + /* mapped name to invoke */ + mr_BufferWriterWriteString(&buffWriter, remoteTaskName); + /* Serialize args */ + msd->type->serialize(&buffWriter, args, &error); + msd->type->free(args); + /* todo: handler serialization failure */ + /* serialize the record */ + MR_RecordSerialize(r, &buffWriter); + MR_RecordFree(r); + + msg->remoteTaskBase.msg = buff.buff; + msg->remoteTaskBase.msgLen = buff.size; + + MR_EventLoopAddTask(MR_RunOnKeyInternal, msg); +} + +typedef struct RemoteTaskLocalRun{ + char id[ID_LEN]; + StepDefinition* msd; + void* args; + Record* r; + RemoteTaskResult result; +} RemoteTaskLocalRun; + + +static void MR_RemoteTaskDoneOnLocalEVLoop(void* ctx) { + RemoteTaskLocalRun* localRun = ctx; + + MR_RemoteTaskDoneProcessesResult(localRun->id, localRun->result); + + MR_FREE(localRun); +} + +static void MR_RemoteTaskErrorOnLocal(void *pd, MRError *error) { + RemoteTaskLocalRun* localRun = pd; + localRun->result.replyType = ReplyType_ERROR; + localRun->result.error = error; + MR_EventLoopAddTask(MR_RemoteTaskDoneOnLocalEVLoop, localRun); +} + +static void MR_RemoteTaskDoneOnLocal(void *pd, Record *res) { + RemoteTaskLocalRun* localRun = pd; + localRun->result.replyType = ReplyType_OK; + localRun->result.res = res; + MR_EventLoopAddTask(MR_RemoteTaskDoneOnLocalEVLoop, localRun); +} + +static void MR_RemoteTaskRunOnLocal(void *pd) { + RemoteTaskLocalRun* localRun = pd; + void* args = localRun->args; + Record* r = localRun->r; + localRun->args = NULL; + localRun->r = NULL; + ((RemoteTask)localRun->msd->callback)(r, args, MR_RemoteTaskDoneOnLocal, MR_RemoteTaskErrorOnLocal, localRun); +} + +/* Invoked on the event look */ +static void MR_RemoteTaskOnAllShardsTimeoutOut(void* ctx) { + RunOnShardsMsg *msg = ctx; + msg->remoteTaskBase.timeoutTask = NULL; + + msg->errs = array_append(msg->errs, MR_ErrorCreate("Timeout", 7)); + + /* Remove msg from remoteDict, we will be done with it once we fire the done callback */ + int res = mr_dictDelete(mrCtx.remoteDict, msg->remoteTaskBase.id); + RedisModule_Assert(res == DICT_OK); + + /* Run the callback on the thread pool */ + mr_thpool_add_work(mrCtx.executionsThreadPool, MR_RemoteTaskOnShardsDoneInternal, msg); +} + +/* Invoked on the event loop */ +static void MR_RunOnAllShardsInternal(void* ctx) { + RunOnShardsMsg *msg = ctx; + + /* add the task to the remote mappers dictionary */ + mr_dictAdd(mrCtx.remoteDict, msg->remoteTaskBase.id, msg); + + if (MR_ClusterIsInClusterMode()) { + /* send the message to the shard */ + MR_ClusterSendMsg(NULL, REMOTE_TASK_FUNCTION_ID, msg->remoteTaskBase.msg, msg->remoteTaskBase.msgLen); + } else { + MR_FREE(msg->remoteTaskBase.msg); + } + + /* ownership on the message was moved to MR_ClusterSendMsgBySlot function */ + msg->remoteTaskBase.msg = NULL; + msg->remoteTaskBase.msgLen = 0; + + /* Create local run */ + RemoteTaskLocalRun* localRun = MR_ALLOC(sizeof(*localRun)); + memcpy(localRun->id, msg->remoteTaskBase.id, ID_LEN); + localRun->args = msg->args; + localRun->r = msg->r; + localRun->msd = msg->msd; + mr_thpool_add_work(mrCtx.executionsThreadPool, MR_RemoteTaskRunOnLocal, localRun); + msg->args = NULL; + msg->r = NULL; + + if (msg->remoteTaskBase.timeout != SIZE_MAX) { + msg->remoteTaskBase.timeoutTask = MR_EventLoopAddTaskWithDelay(MR_RemoteTaskOnAllShardsTimeoutOut, msg, msg->remoteTaskBase.timeout); + } +} + +LIBMR_API void MR_RunOnAllShards(const char* remoteTaskName, + void* args, + Record* r, + void (*onDone)(void *pd, Record** result, size_t nResults, MRError** errs, size_t nErrs), + void *pd, + size_t timeout) { + StepDefinition* msd = mr_dictFetchValue(mrCtx.remoteTasksDict, remoteTaskName); + RedisModule_Assert(msd); + + RunOnShardsMsg *msg = MR_ALLOC(sizeof(*msg)); + msg->onDone = onDone; + msg->pd = pd; + msg->remoteTaskBase.timeout = timeout; + msg->remoteTaskBase.timeoutTask = NULL; + msg->remoteTaskBase.remoteTaskType = RemoteTaksMsgType_OnAllShards; + msg->args = args; + msg->r = r; + msg->expectedNResults = MR_ClusterGetSize(); + msg->nResultsArrived = 0; + msg->results = array_new(Record*, 10); + msg->errs = array_new(MRError*, 10); + msg->msd = msd; + + /* Set id */ + size_t id = __atomic_add_fetch(&mrCtx.lastExecutionId, 1, __ATOMIC_RELAXED); + SetId(msg->remoteTaskBase.id, msg->remoteTaskBase.idStr, id); + + MRError* error = NULL; + mr_Buffer buff; + mr_BufferInitialize(&buff); + mr_BufferWriter buffWriter; + mr_BufferWriterInit(&buffWriter, &buff); + /* write sender */ + mr_BufferWriterWriteString(&buffWriter, MR_ClusterGetMyId()); + /* write id */ + mr_BufferWriterWriteBuff(&buffWriter, msg->remoteTaskBase.id, ID_LEN); + /* mapped name to invoke */ + mr_BufferWriterWriteString(&buffWriter, remoteTaskName); + /* Serialize args */ + msd->type->serialize(&buffWriter, args, &error); + /* todo: handler serialization failure */ + /* serialize the record */ + MR_RecordSerialize(r, &buffWriter); + + msg->remoteTaskBase.msg = buff.buff; + msg->remoteTaskBase.msgLen = buff.size; + + MR_EventLoopAddTask(MR_RunOnAllShardsInternal, msg); +} diff --git a/src/mr.h b/src/mr.h index d98f354..e2bc0c9 100644 --- a/src/mr.h +++ b/src/mr.h @@ -65,6 +65,32 @@ typedef Record* (*ExecutionReader)(ExecutionCtx* ectx, void* args); typedef Record* (*ExecutionMapper)(ExecutionCtx* ectx, Record* r, void* args); typedef int (*ExecutionFilter)(ExecutionCtx* ectx, Record* r, void* args); typedef Record* (*ExecutionAccumulator)(ExecutionCtx* ectx, Record* accumulator, Record* r, void* args); +typedef void (*RemoteTask)(Record* r, void* args, void (*onDone)(void* PD, Record *r), void (*onError)(void* PD, MRError *r), void *pd); + +/* Run a remote task on a shard responsible for a given key. + * There is not guarantee on which thread the task will run, if + * the current shard is responsible for the given key or if its + * a none cluster environment, then the callback will be called + * immediately (an so the onDone/onError) callbacks. + * If the key located on the remote shard, the task will + * be invoke on the thread pool of this remote shard, the onDone/onError + * callback will be invoke on the thread pool of the current shard. */ +LIBMR_API void MR_RunOnKey(const char* keyName, + size_t keyNameSize, + const char* remoteTaskName, + void* args, + Record* r, + void (*onDone)(void *pd, Record* result), + void (*onError)(void *pd, MRError* err), + void *pd, + size_t timeout); + +LIBMR_API void MR_RunOnAllShards(const char* remoteTaskName, + void* args, + Record* r, + void (*onDone)(void *pd, Record** result, size_t nResults, MRError** errs, size_t nErrs), + void *pd, + size_t timeout); /* Creatign a new execution builder */ LIBMR_API ExecutionBuilder* MR_CreateExecutionBuilder(const char* readerName, void* args); @@ -138,6 +164,9 @@ LIBMR_API void MR_RegisterFilter(const char* name, ExecutionFilter filter, MRObj /* Register an accumulate step */ LIBMR_API void MR_RegisterAccumulator(const char* name, ExecutionAccumulator accumulator, MRObjectType* argType); +/* Register a remote task */ +LIBMR_API void MR_RegisterRemoteTask(const char* name, RemoteTask remote, MRObjectType* argType); + /* Serialization Context functions */ LIBMR_API long long MR_SerializationCtxReadeLongLong(ReaderSerializationCtx* sctx, MRError** err); LIBMR_API const char* MR_SerializationCtxReadeBuffer(ReaderSerializationCtx* sctx, size_t* len, MRError** err); diff --git a/tests/mr_test_module/.cargo/config.toml b/tests/mr_test_module/.cargo/config.toml new file mode 100644 index 0000000..94c34ce --- /dev/null +++ b/tests/mr_test_module/.cargo/config.toml @@ -0,0 +1,2 @@ +[env] +MODULE_NAME = "MRTESTS" \ No newline at end of file diff --git a/tests/mr_test_module/Cargo.toml b/tests/mr_test_module/Cargo.toml index 1ddb34b..bde09d0 100644 --- a/tests/mr_test_module/Cargo.toml +++ b/tests/mr_test_module/Cargo.toml @@ -8,12 +8,15 @@ license = "Redis Source Available License 2.0 (RSALv2) or the Server Side Public # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -redis-module = { version="0.22.0", features = ["experimental-api"]} -# redis-module = { path="/home/meir/work/redismodule-rs", features = ["experimental-api"]} +#redis-module = { version="0.22.0", features = ["experimental-api"]} +redis-module = { git = "https://github.com/RedisLabsModules/redismodule-rs", branch = "api_extentions", features = ["experimental-api"]} serde_json = "1.0" serde = "1.0" serde_derive = "1.0" libc = "0.2" +lib_mr = { path = "../../" } +lib_mr_derive = { path = "../../LibMRDerive/" } +linkme = "0.3" [build-dependencies] bindgen = "0.57" diff --git a/tests/mr_test_module/Makefile b/tests/mr_test_module/Makefile index 55f8e11..a3265d4 100644 --- a/tests/mr_test_module/Makefile +++ b/tests/mr_test_module/Makefile @@ -18,14 +18,15 @@ else endif RUSTFLAGS=-L$(ROOT)/src/ +MODULE_NAME=MRTESTS all: build build_libmr: - DEBUG=$(DEBUG_RUN) COVERAGE=$(COVERAGE_RUN) MODULE_NAME=MRTESTS make -C $(ROOT)/src/ + DEBUG=$(DEBUG_RUN) COVERAGE=$(COVERAGE_RUN) MODULE_NAME=$(MODULE_NAME) make -C $(ROOT)/src/ build: build_libmr - RUSTFLAGS="$(RUSTFLAGS)" cargo build $(EXTRA_ARGS) + RUSTFLAGS="$(RUSTFLAGS)" MODULE_NAME==$(MODULE_NAME) cargo build $(EXTRA_ARGS) run: build redis-server --loadmodule $(MODULE_PATH) diff --git a/tests/mr_test_module/build.rs b/tests/mr_test_module/build.rs deleted file mode 100644 index 6e1181d..0000000 --- a/tests/mr_test_module/build.rs +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright Redis Ltd. 2021 - present - * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or - * the Server Side Public License v1 (SSPLv1). - */ - -extern crate bindgen; -extern crate cc; - -use std::env; -use std::path::PathBuf; - -#[derive(Debug)] -struct RedisModuleCallback; - -fn main() { - let build = bindgen::Builder::default(); - - let bindings = build - .header("src/include/mr.h") - .size_t_is_usize(true) - .generate() - .expect("error generating bindings"); - - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("mr.rs")) - .expect("failed to write bindings to file"); -} diff --git a/tests/mr_test_module/pytests/run_full_tests.sh b/tests/mr_test_module/pytests/run_full_tests.sh new file mode 100755 index 0000000..fd25c02 --- /dev/null +++ b/tests/mr_test_module/pytests/run_full_tests.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -x +set -e + +echo oss +DEBUG=$DEBUG ./run_tests.sh --env-reuse "$@" +echo single shard cluster +DEBUG=$DEBUG ./run_tests.sh --env-reuse --env oss-cluster --shards-count 1 "$@" +echo 2 shards cluster +DEBUG=$DEBUG ./run_tests.sh --env-reuse --env oss-cluster --shards-count 2 "$@" +echo 3 shards cluster +DEBUG=$DEBUG ./run_tests.sh --env-reuse --env oss-cluster --shards-count 3 "$@" + +echo ssl certificates +bash ../generate_tests_cert.sh + +echo 2 shards cluster ssl enabled +DEBUG=$DEBUG ./run_tests.sh --env-reuse --env oss-cluster --shards-count 2 --tls --tls-cert-file ../tests/tls/redis.crt --tls-key-file ../tests/tls/redis.key --tls-ca-cert-file ../tests/tls/ca.crt --tls-passphrase foobar "$@" + +echo 3 shards cluster ssl enabled +DEBUG=$DEBUG ./run_tests.sh --env-reuse --env oss-cluster --shards-count 3 --tls --tls-cert-file ../tests/tls/redis.crt --tls-key-file ../tests/tls/redis.key --tls-ca-cert-file ../tests/tls/ca.crt --tls-passphrase foobar "$@" \ No newline at end of file diff --git a/tests/mr_test_module/pytests/run_tests.sh b/tests/mr_test_module/pytests/run_tests.sh index 9bf04dd..95f8555 100755 --- a/tests/mr_test_module/pytests/run_tests.sh +++ b/tests/mr_test_module/pytests/run_tests.sh @@ -5,10 +5,17 @@ HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" ROOT=$(cd $HERE/../../../ && pwd) +OS=$(uname -s 2>/dev/null) +if [[ $OS == Darwin ]]; then + LIB_EXTENTION=dylib +else + LIB_EXTENTION=so +fi + if [[ $DEBUG == 1 ]]; then - MODULE_PATH=$HERE/../target/debug/libmr_test.so + MODULE_PATH=$HERE/../target/debug/libmr_test.$LIB_EXTENTION else - MODULE_PATH=$HERE/../target/release/libmr_test.so + MODULE_PATH=$HERE/../target/release/libmr_test.$LIB_EXTENTION fi diff --git a/tests/mr_test_module/pytests/test_basic.py b/tests/mr_test_module/pytests/test_basic.py index 1d243be..4840d96 100644 --- a/tests/mr_test_module/pytests/test_basic.py +++ b/tests/mr_test_module/pytests/test_basic.py @@ -63,3 +63,18 @@ def testUnevenWork(env, conn): except Exception as e: if str(e) != 'timeout': raise e + +@MRTestDecorator() +def testRemoteTaskOnKey(env, conn): + conn.execute_command('set', 'x', '1') + env.expect('lmrtest.get', 'x').equal('1') + env.expect('lmrtest.get', 'y').error().contains('bad result returned from') + +@MRTestDecorator() +def testRemoteTaskOnAllShards(env, conn): + for i in range(100): + conn.execute_command('set', 'doc%d' % i, '1') + env.expect('lmrtest.dbsize').equal(100) + for i in range(100): + conn.execute_command('del', 'doc%d' % i) + env.expect('lmrtest.dbsize').equal(0) diff --git a/tests/mr_test_module/src/include/mr.h b/tests/mr_test_module/src/include/mr.h deleted file mode 100644 index 7c6c285..0000000 --- a/tests/mr_test_module/src/include/mr.h +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright Redis Ltd. 2021 - present - * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or - * the Server Side Public License v1 (SSPLv1). - */ - -#ifndef SRC_MR_H_ -#define SRC_MR_H_ - -#include -#include - -#define LIBMR_API __attribute__ ((visibility("default"))) - -typedef struct RedisModuleCtx RedisModuleCtx; - -typedef struct MRError MRError; - -extern RedisModuleCtx* mr_staticCtx; - -/* Opaque struct build an execution */ -typedef struct ExecutionBuilder ExecutionBuilder; - -/* Opaque struct represents an execution */ -typedef struct Execution Execution; - -/* Opaque struct represents a record that pass in the execution pipe */ -typedef struct Record Record; - -/* Opaque struct that allow to serialize and deserialize objects */ -typedef struct mr_BufferReader ReaderSerializationCtx; -typedef struct mr_BufferWriter WriteSerializationCtx; - -/* MRObjectType callbacks definition */ -typedef void (*ObjectFree)(void* arg); -typedef void* (*ObjectDuplicate)(void* arg); -typedef void (*ObjectSerialize)(WriteSerializationCtx* sctx, void* arg, MRError** error); -typedef void* (*ObjectDeserialize)(ReaderSerializationCtx* sctx, MRError** error); -typedef char* (*ObjectToString)(void* arg); - -/* represent map reduce object type */ -typedef struct MRObjectType{ - char* type; - size_t id; - ObjectFree free; - ObjectDuplicate dup; - ObjectSerialize serialize; - ObjectDeserialize deserialize; - ObjectToString tostring; -}MRObjectType; - -/* Opaque struct that is given to execution steps */ -typedef struct ExecutionCtx ExecutionCtx; -LIBMR_API Record* MR_ExecutionCtxGetResult(ExecutionCtx* ectx, size_t i); -LIBMR_API size_t MR_ExecutionCtxGetResultsLen(ExecutionCtx* ectx); -LIBMR_API const char* MR_ExecutionCtxGetError(ExecutionCtx* ectx, size_t i); -LIBMR_API size_t MR_ExecutionCtxGetErrorsLen(ExecutionCtx* ectx); -LIBMR_API void MR_ExecutionCtxSetError(ExecutionCtx* ectx, const char* err, size_t len); - -/* Execution Callback definition */ -typedef void(*ExecutionCallback)(ExecutionCtx* ectx, void* pd); - -/* step functions signiture */ -typedef Record* (*ExecutionReader)(ExecutionCtx* ectx, void* args); -typedef Record* (*ExecutionMapper)(ExecutionCtx* ectx, Record* r, void* args); -typedef int (*ExecutionFilter)(ExecutionCtx* ectx, Record* r, void* args); -typedef Record* (*ExecutionAccumulator)(ExecutionCtx* ectx, Record* accumulator, Record* r, void* args); - -/* Creatign a new execution builder */ -LIBMR_API ExecutionBuilder* MR_CreateExecutionBuilder(const char* readerName, void* args); - -/* Add map step to the given builder. - * The function takes ownership on the given - * args so the user is not allow to use it anymore. */ -LIBMR_API void MR_ExecutionBuilderMap(ExecutionBuilder* builder, const char* name, void* args); - -/* Add filter step to the given builder. - * The function takes ownership on the given - * args so the user is not allow to use it anymore. */ -LIBMR_API void MR_ExecutionBuilderFilter(ExecutionBuilder* builder, const char* name, void* args); - -/* Add accumulate step to the given builder. - * The function takes ownership on the given - * args so the user is not allow to use it anymore. */ -LIBMR_API void MR_ExecutionBuilderBuilAccumulate(ExecutionBuilder* builder, const char* name, void* args); - -/* Add a collect step to the builder. - * Will return all the records to the initiator */ -LIBMR_API void MR_ExecutionBuilderCollect(ExecutionBuilder* builder); - -/* Add a reshuffle step to the builder. */ -LIBMR_API void MR_ExecutionBuilderReshuffle(ExecutionBuilder* builder); - -/* Free the give execution builder */ -LIBMR_API void MR_FreeExecutionBuilder(ExecutionBuilder* builder); - -/* Create execution from the given builder. - * Returns Execution which need to be freed using RM_FreeExecution. - * The user can use the returned Execution to set - * different callbacks, such as on_done callback and hold/resume callbacks. - * - * After callbacks are set the user can run the execution using MR_Run - * - * The function borrow the builder, which means that once returned - * the user can still use the builder, change it, or create more executions - * from it. - * - * Return NULL on error and set the error on err out param */ -LIBMR_API Execution* MR_CreateExecution(ExecutionBuilder* builder, MRError** err); - -/* Set max idle time (in ms) for the given execution */ -LIBMR_API void MR_ExecutionSetMaxIdle(Execution* e, size_t maxIdle); - -/* Set on execution done callbac */ -LIBMR_API void MR_ExecutionSetOnDoneHandler(Execution* e, ExecutionCallback onDone, void* pd); - -/* Run the given execution, should at most once on each execution. */ -LIBMR_API void MR_Run(Execution* e); - -/* Free the given execution */ -LIBMR_API void MR_FreeExecution(Execution* e); - -/* Initialize mr library */ -LIBMR_API int MR_Init(RedisModuleCtx* ctx, size_t numThreads); - -/* Register a new object type */ -LIBMR_API int MR_RegisterObject(MRObjectType* t); - -/* Register a reader */ -LIBMR_API void MR_RegisterReader(const char* name, ExecutionReader reader, MRObjectType* argType); - -/* Register a map step */ -LIBMR_API void MR_RegisterMapper(const char* name, ExecutionMapper mapper, MRObjectType* argType); - -/* Register a filter step */ -LIBMR_API void MR_RegisterFilter(const char* name, ExecutionFilter filter, MRObjectType* argType); - -/* Register an accumulate step */ -LIBMR_API void MR_RegisterAccumulator(const char* name, ExecutionAccumulator accumulator, MRObjectType* argType); - -/* Serialization Context functions */ -LIBMR_API long long MR_SerializationCtxReadeLongLong(ReaderSerializationCtx* sctx, MRError** err); -LIBMR_API const char* MR_SerializationCtxReadeBuffer(ReaderSerializationCtx* sctx, size_t* len, MRError** err); -LIBMR_API double MR_SerializationCtxReadeDouble(ReaderSerializationCtx* sctx, MRError** err); -LIBMR_API void MR_SerializationCtxWriteLongLong(WriteSerializationCtx* sctx, long long val, MRError** err); -LIBMR_API void MR_SerializationCtxWriteBuffer(WriteSerializationCtx* sctx, const char* buff, size_t len, MRError** err); -LIBMR_API void MR_SerializationCtxWriteDouble(WriteSerializationCtx* sctx, double val, MRError** err); - -/* records functions */ -typedef void (*SendAsRedisReply)(RedisModuleCtx*, void* record); -typedef size_t (*HashTag)(void* record); - -/* represent record type */ -typedef struct MRRecordType{ - MRObjectType type; - SendAsRedisReply sendReply; - HashTag hashTag; -}MRRecordType; - -/* Base record struct, each record should have it - * as first value */ -struct Record { - MRRecordType* recordType; -}; - -/* Register a new Record type */ -LIBMR_API int MR_RegisterRecord(MRRecordType* t); - -/* Free the give Record */ -LIBMR_API void MR_RecordFree(Record* r); - -/* Calculate slot on the given buffer */ -LIBMR_API size_t MR_CalculateSlot(const char* buff, size_t len); - -/* Create a new error object */ -LIBMR_API MRError* MR_ErrorCreate(const char* msg, size_t len); - -/* Get error message from the error object */ -LIBMR_API const char* MR_ErrorGetMessage(MRError* err); - -/* Free the error object */ -LIBMR_API void MR_ErrorFree(MRError* err); - -/***************** no public API **********************/ -MRObjectType* MR_GetObjectType(size_t id); - -#endif /* SRC_MR_H_ */ diff --git a/tests/mr_test_module/src/lib.rs b/tests/mr_test_module/src/lib.rs index 7c6c93f..a0aa8db 100644 --- a/tests/mr_test_module/src/lib.rs +++ b/tests/mr_test_module/src/lib.rs @@ -8,78 +8,35 @@ extern crate serde_derive; use redis_module::redisraw::bindings::{ - RedisModule_ScanCursorCreate, - RedisModuleScanCursor, - RedisModule_Scan, - RedisModule_GetDetachedThreadSafeContext, - RedisModuleCtx, - RedisModuleString, - RedisModuleKey, - RedisModule_ThreadSafeContextLock, + RedisModuleCtx, RedisModuleKey, RedisModuleScanCursor, RedisModuleString, + RedisModule_GetDetachedThreadSafeContext, RedisModule_Scan, RedisModule_ScanCursorCreate, + RedisModule_ScanCursorDestroy, RedisModule_ThreadSafeContextLock, RedisModule_ThreadSafeContextUnlock, - RedisModule_ScanCursorDestroy, }; use redis_module::{ - redis_module, - redis_command, - Context, - RedisError, - RedisResult, - RedisString, - RedisValue, - Status, + redis_command, redis_module, Context, RedisError, RedisResult, RedisString, RedisValue, Status, ThreadSafeContext, }; use std::str; -mod libmrraw; -mod libmr; - -use libmr::{ - create_builder, - BaseObject, - Record, - Reader, - MapStep, - RecordType, - RustMRError, - FilterStep, - AccumulateStep, +use mr::libmr::{ + accumulator::AccumulateStep, base_object::BaseObject, calc_slot, + execution_builder::create_builder, filter::FilterStep, mapper::MapStep, mr_init, + reader::Reader, record::Record, remote_task::run_on_key, remote_task::run_on_all_shards, remote_task::RemoteTask, RustMRError, }; -use libmrraw::bindings::{ - MR_Init, - MRRecordType, - MR_CalculateSlot, -}; - -use std::os::raw::{ - c_void, - c_char, -}; +use std::os::raw::c_void; use std::{thread, time}; -#[allow(improper_ctypes)] -#[link(name = "mr", kind = "static")] -extern "C" {} - -#[allow(improper_ctypes)] -#[link(name = "ssl")] -extern "C" {} - -#[allow(improper_ctypes)] -#[link(name = "crypto")] -extern "C" {} +use mr_derive::BaseObject; static mut DETACHED_CTX: *mut RedisModuleCtx = 0 as *mut RedisModuleCtx; fn get_redis_ctx() -> *mut RedisModuleCtx { - unsafe { - DETACHED_CTX - } + unsafe { DETACHED_CTX } } fn get_ctx() -> Context { @@ -89,44 +46,37 @@ fn get_ctx() -> Context { fn ctx_lock() { let inner = get_redis_ctx(); - unsafe{ + unsafe { RedisModule_ThreadSafeContextLock.unwrap()(inner); } } fn ctx_unlock() { let inner = get_redis_ctx(); - unsafe{ + unsafe { RedisModule_ThreadSafeContextUnlock.unwrap()(inner); } } fn strin_record_new(s: String) -> StringRecord { - let mut r = unsafe{ - HASH_RECORD_TYPE.as_ref().unwrap().create() - }; - r.s = Some(s); - r + StringRecord { s: Some(s) } } fn int_record_new(i: i64) -> IntRecord { - let mut r = unsafe{ - INT_RECORD_TYPE.as_ref().unwrap().create() - }; - r.i = i; - r + IntRecord { i: i } } fn lmr_map_error(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(KeysReader::new(None)). - map(ErrorMapper). - filter(DummyFilter). - reshuffle(). - collect(). - accumulate(CountAccumulator). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(KeysReader::new(None)) + .map(ErrorMapper) + .filter(DummyFilter) + .reshuffle() + .collect() + .accumulate(CountAccumulator) + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|res, errs|{ + execution.set_done_hanlder(|res, errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); let mut final_res = Vec::new(); final_res.push(RedisValue::Integer(res.len() as i64)); @@ -140,15 +90,16 @@ fn lmr_map_error(ctx: &Context, _args: Vec) -> RedisResult { } fn lmr_filter_error(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(KeysReader::new(None)). - filter(ErrorFilter). - map(DummyMapper). - reshuffle(). - collect(). - accumulate(CountAccumulator). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(KeysReader::new(None)) + .filter(ErrorFilter) + .map(DummyMapper) + .reshuffle() + .collect() + .accumulate(CountAccumulator) + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|res, errs|{ + execution.set_done_hanlder(|res, errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); let mut final_res = Vec::new(); final_res.push(RedisValue::Integer(res.len() as i64)); @@ -162,16 +113,17 @@ fn lmr_filter_error(ctx: &Context, _args: Vec) -> RedisResult { } fn lmr_accumulate_error(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(KeysReader::new(None)). - accumulate(ErrorAccumulator). - map(DummyMapper). - filter(DummyFilter). - reshuffle(). - collect(). - accumulate(CountAccumulator). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(KeysReader::new(None)) + .accumulate(ErrorAccumulator) + .map(DummyMapper) + .filter(DummyFilter) + .reshuffle() + .collect() + .accumulate(CountAccumulator) + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|res, errs|{ + execution.set_done_hanlder(|res, errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); let mut final_res = Vec::new(); final_res.push(RedisValue::Integer(res.len() as i64)); @@ -185,12 +137,13 @@ fn lmr_accumulate_error(ctx: &Context, _args: Vec) -> RedisResult { } fn lmr_uneven_work(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(MaxIdleReader::new(1)). - map(UnevenWorkMapper::new()). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(MaxIdleReader::new(1)) + .map(UnevenWorkMapper::new()) + .create_execution() + .map_err(|e| RedisError::String(e))?; execution.set_max_idle(2000); let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|mut res, mut errs|{ + execution.set_done_hanlder(|mut res, mut errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); if errs.len() > 0 { let err = errs.pop().unwrap(); @@ -207,15 +160,16 @@ fn lmr_uneven_work(ctx: &Context, _args: Vec) -> RedisResult { } fn lmr_read_error(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(ErrorReader::new()). - map(DummyMapper). - filter(DummyFilter). - reshuffle(). - collect(). - accumulate(CountAccumulator). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(ErrorReader::new()) + .map(DummyMapper) + .filter(DummyFilter) + .reshuffle() + .collect() + .accumulate(CountAccumulator) + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|res, errs|{ + execution.set_done_hanlder(|res, errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); let mut final_res = Vec::new(); final_res.push(RedisValue::Integer(res.len() as i64)); @@ -229,12 +183,13 @@ fn lmr_read_error(ctx: &Context, _args: Vec) -> RedisResult { } fn lmr_count_key(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(KeysReader::new(None)). - collect(). - accumulate(CountAccumulator). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(KeysReader::new(None)) + .collect() + .accumulate(CountAccumulator) + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|mut res, mut errs|{ + execution.set_done_hanlder(|mut res, mut errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); if errs.len() > 0 { let err = errs.pop().unwrap(); @@ -251,12 +206,13 @@ fn lmr_count_key(ctx: &Context, _args: Vec) -> RedisResult { } fn lmr_reach_max_idle(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(MaxIdleReader::new(50)). - collect(). - create_execution().map_err(|e|RedisError::String(e))?; - execution.set_max_idle(20); + let execution = create_builder(MaxIdleReader::new(200)) + .collect() + .create_execution() + .map_err(|e| RedisError::String(e))?; + execution.set_max_idle(10); let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|mut res, mut errs|{ + execution.set_done_hanlder(|mut res, mut errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); if errs.len() > 0 { let err = errs.pop().unwrap(); @@ -273,12 +229,13 @@ fn lmr_reach_max_idle(ctx: &Context, _args: Vec) -> RedisResult { } fn lmr_read_keys_type(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(KeysReader::new(None)). - map(TypeMapper). - collect(). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(KeysReader::new(None)) + .map(TypeMapper) + .collect() + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|mut res, mut errs|{ + execution.set_done_hanlder(|mut res, mut errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); if errs.len() > 0 { let err = errs.pop().unwrap(); @@ -296,17 +253,21 @@ fn lmr_read_keys_type(ctx: &Context, _args: Vec) -> RedisResult { fn replace_keys_values(ctx: &Context, args: Vec) -> RedisResult { let mut args = args.into_iter().skip(1); - let prefix = args.next().ok_or(RedisError::Str("not prefix was given"))?.try_as_str()?; - let execution = create_builder(KeysReader::new(Some(prefix.to_string()))). - filter(TypeFilter::new("string".to_string())). - map(ReadStringMapper{}). - reshuffle(). - map(WriteDummyString{}). - collect(). - create_execution().map_err(|e|RedisError::String(e))?; - + let prefix = args + .next() + .ok_or(RedisError::Str("not prefix was given"))? + .try_as_str()?; + let execution = create_builder(KeysReader::new(Some(prefix.to_string()))) + .filter(TypeFilter::new("string".to_string())) + .map(ReadStringMapper {}) + .reshuffle() + .map(WriteDummyString {}) + .collect() + .create_execution() + .map_err(|e| RedisError::String(e))?; + let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|mut res, mut errs|{ + execution.set_done_hanlder(|mut res, mut errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); if errs.len() > 0 { let err = errs.pop().unwrap(); @@ -323,12 +284,13 @@ fn replace_keys_values(ctx: &Context, args: Vec) -> RedisResult { } fn lmr_read_string_keys(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(KeysReader::new(None)). - filter(TypeFilter::new("string".to_string())). - collect(). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(KeysReader::new(None)) + .filter(TypeFilter::new("string".to_string())) + .collect() + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|mut res, mut errs|{ + execution.set_done_hanlder(|mut res, mut errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); if errs.len() > 0 { let err = errs.pop().unwrap(); @@ -344,12 +306,61 @@ fn lmr_read_string_keys(ctx: &Context, _args: Vec) -> RedisResult { Ok(RedisValue::NoReply) } +fn lmr_dbsize(ctx: &Context, _args: Vec) -> RedisResult { + let blocked_client = ctx.block_client(); + run_on_all_shards( + RemoteTaskDBSize, + int_record_new(0), + move |results: Vec, mut errs: Vec| { + let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); + if errs.len() > 0 { + let err = errs.pop().unwrap(); + thread_ctx.reply(Err(RedisError::String(err))); + } else { + let sum: i64 = results.into_iter().map(|e| e.i).sum(); + thread_ctx.reply(Ok(RedisValue::Integer(sum))); + } + }, + usize::MAX, + ); + Ok(RedisValue::NoReply) +} + +fn lmr_get(ctx: &Context, args: Vec) -> RedisResult { + let mut args = args.into_iter().skip(1); + let ke_redis_string = args.next().ok_or(RedisError::Str("not prefix was given"))?; + let key = ke_redis_string.try_as_str()?; + let blocked_client = ctx.block_client(); + thread::spawn(move || { + let record = strin_record_new(key.to_string()); + run_on_key( + key.as_bytes(), + RemoteTaskGet, + record, + move |res: Result| { + let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); + match res { + Ok(mut r) => { + thread_ctx.reply(Ok(r.to_redis_value())); + } + Err(e) => { + thread_ctx.reply(Err(RedisError::String(e))); + } + } + }, + usize::MAX, + ); + }); + Ok(RedisValue::NoReply) +} + fn lmr_read_all_keys(ctx: &Context, _args: Vec) -> RedisResult { - let execution = create_builder(KeysReader::new(None)). - collect(). - create_execution().map_err(|e|RedisError::String(e))?; + let execution = create_builder(KeysReader::new(None)) + .collect() + .create_execution() + .map_err(|e| RedisError::String(e))?; let blocked_client = ctx.block_client(); - execution.set_done_hanlder(|mut res, mut errs|{ + execution.set_done_hanlder(|mut res, mut errs| { let thread_ctx = ThreadSafeContext::with_blocked_client(blocked_client); if errs.len() > 0 { let err = errs.pop().unwrap(); @@ -365,128 +376,137 @@ fn lmr_read_all_keys(ctx: &Context, _args: Vec) -> RedisResult { Ok(RedisValue::NoReply) } -impl Default for crate::libmrraw::bindings::Record { - fn default() -> Self { - crate::libmrraw::bindings::Record { - recordType: 0 as *mut MRRecordType, +#[derive(Clone, Serialize, Deserialize, BaseObject)] +struct RemoteTaskGet; + +impl RemoteTask for RemoteTaskGet { + type InRecord = StringRecord; + type OutRecord = StringRecord; + + fn task( + self, + mut r: Self::InRecord, + on_done: Box) + Send>, + ) { + let ctx = get_ctx(); + ctx_lock(); + let res = ctx.call("get", &[r.s.as_ref().unwrap()]); + ctx_unlock(); + if let Ok(res) = res { + if let RedisValue::StringBuffer(res) = res { + r.s = Some(std::str::from_utf8(&res).unwrap().to_string()); + on_done(Ok(r)); + } else { + on_done(Err("bad result returned from `get` command".to_string())) + } + } else { + on_done(Err("bad result returned from `get` command".to_string())) } } } -#[repr(C)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] +struct RemoteTaskDBSize; + +impl RemoteTask for RemoteTaskDBSize { + type InRecord = IntRecord; + type OutRecord = IntRecord; + + fn task( + self, + mut r: Self::InRecord, + on_done: Box) + Send>, + ) { + let ctx = get_ctx(); + ctx_lock(); + let res = ctx.call("dbsize", &[]); + ctx_unlock(); + if let Ok(res) = res { + if let RedisValue::Integer(res) = res { + r.i = res; + on_done(Ok(r)); + } else { + on_done(Err("bad result returned from `dbsize` command".to_string())) + } + } else { + on_done(Err("bad result returned from `dbsize` command".to_string())) + } + } +} + +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct StringRecord { - #[serde(skip)] - base: crate::libmrraw::bindings::Record, pub s: Option, } impl Record for StringRecord { - fn new(t: *mut MRRecordType) -> Self { - StringRecord { - base: crate::libmrraw::bindings::Record { - recordType: t, - }, - s: None, - } - } - fn to_redis_value(&mut self) -> RedisValue { match self.s.take() { Some(s) => RedisValue::BulkString(s), None => RedisValue::Null, } - } fn hash_slot(&self) -> usize { - unsafe{MR_CalculateSlot(self.s.as_ref().unwrap().as_ptr() as *const c_char, self.s.as_ref().unwrap().len())} + calc_slot(self.s.as_ref().unwrap().as_bytes()) } } -impl BaseObject for StringRecord { - fn get_name() -> &'static str { - "StringRecord\0" - } -} - -#[repr(C)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct IntRecord { - #[serde(skip)] - base: crate::libmrraw::bindings::Record, pub i: i64, } impl Record for IntRecord { - fn new(t: *mut MRRecordType) -> Self { - IntRecord { - base: crate::libmrraw::bindings::Record { - recordType: t, - }, - i: 0, - } - } - fn to_redis_value(&mut self) -> RedisValue { RedisValue::Integer(self.i) } fn hash_slot(&self) -> usize { let s = self.i.to_string(); - unsafe{MR_CalculateSlot(s.as_ptr() as *const c_char, s.len())} + calc_slot(s.as_bytes()) } } -impl BaseObject for IntRecord { - fn get_name() -> &'static str { - "IntRecord\0" - } -} - -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct CountAccumulator; impl AccumulateStep for CountAccumulator { type InRecord = StringRecord; type Accumulator = IntRecord; - fn accumulate(&self, accumulator: Option, _r: Self::InRecord) -> Result { + fn accumulate( + &self, + accumulator: Option, + _r: Self::InRecord, + ) -> Result { let mut accumulator = match accumulator { Some(a) => a, - None => int_record_new(0) + None => int_record_new(0), }; - accumulator.i+=1; + accumulator.i += 1; Ok(accumulator) } } -impl BaseObject for CountAccumulator { - fn get_name() -> &'static str { - "CountAccumulator\0" - } -} - -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct ErrorAccumulator; impl AccumulateStep for ErrorAccumulator { type InRecord = StringRecord; type Accumulator = StringRecord; - fn accumulate(&self, _accumulator: Option, _r: Self::InRecord) -> Result { + fn accumulate( + &self, + _accumulator: Option, + _r: Self::InRecord, + ) -> Result { Err("accumulate_error".to_string()) } } -impl BaseObject for ErrorAccumulator { - fn get_name() -> &'static str { - "ErrorAccumulator\0" - } -} - /* filter by key type */ -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct DummyFilter; impl FilterStep for DummyFilter { @@ -497,14 +517,8 @@ impl FilterStep for DummyFilter { } } -impl BaseObject for DummyFilter { - fn get_name() -> &'static str { - "DummyFilter\0" - } -} - /* filter by key type */ -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct ErrorFilter; impl FilterStep for ErrorFilter { @@ -515,23 +529,15 @@ impl FilterStep for ErrorFilter { } } -impl BaseObject for ErrorFilter { - fn get_name() -> &'static str { - "ErrorFilter\0" - } -} - /* filter by key type */ -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct TypeFilter { t: String, } impl TypeFilter { - pub fn new(t: String) -> TypeFilter{ - TypeFilter{ - t: t, - } + pub fn new(t: String) -> TypeFilter { + TypeFilter { t: t } } } @@ -541,11 +547,11 @@ impl FilterStep for TypeFilter { fn filter(&self, r: &Self::R) -> Result { let ctx = get_ctx(); ctx_lock(); - let res = ctx.call("type",&[r.s.as_ref().unwrap()]); + let res = ctx.call("type", &[r.s.as_ref().unwrap()]); ctx_unlock(); if let Ok(res) = res { - if let RedisValue::SimpleString(res) = res { - if res == self.t { + if let RedisValue::StringBuffer(res) = res { + if std::str::from_utf8(&res).unwrap() == self.t { Ok(true) } else { Ok(false) @@ -559,14 +565,8 @@ impl FilterStep for TypeFilter { } } -impl BaseObject for TypeFilter { - fn get_name() -> &'static str { - "TypeFilter\0" - } -} - /* map key name to its type */ -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct TypeMapper; impl MapStep for TypeMapper { @@ -576,11 +576,11 @@ impl MapStep for TypeMapper { fn map(&self, mut r: Self::InRecord) -> Result { let ctx = get_ctx(); ctx_lock(); - let res = ctx.call("type",&[r.s.as_ref().unwrap()]); + let res = ctx.call("type", &[r.s.as_ref().unwrap()]); ctx_unlock(); if let Ok(res) = res { - if let RedisValue::SimpleString(res) = res { - r.s = Some(res); + if let RedisValue::StringBuffer(res) = res { + r.s = Some(std::str::from_utf8(&res).unwrap().to_string()); Ok(r) } else { Err("bad result returned from type command".to_string()) @@ -591,14 +591,8 @@ impl MapStep for TypeMapper { } } -impl BaseObject for TypeMapper { - fn get_name() -> &'static str { - "TypeMapper\0" - } -} - /* map key name to its type */ -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct ErrorMapper; impl MapStep for ErrorMapper { @@ -610,13 +604,8 @@ impl MapStep for ErrorMapper { } } -impl BaseObject for ErrorMapper { - fn get_name() -> &'static str { - "ErrorMapper\0" - } -} /* map key name to its type */ -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct DummyMapper; impl MapStep for DummyMapper { @@ -628,22 +617,16 @@ impl MapStep for DummyMapper { } } -impl BaseObject for DummyMapper { - fn get_name() -> &'static str { - "DummyMapper\0" - } -} - /* map key name to its type */ -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct UnevenWorkMapper { #[serde(skip)] - is_initiator: bool + is_initiator: bool, } impl UnevenWorkMapper { fn new() -> UnevenWorkMapper { - UnevenWorkMapper{ is_initiator: true } + UnevenWorkMapper { is_initiator: true } } } @@ -660,14 +643,7 @@ impl MapStep for UnevenWorkMapper { } } -impl BaseObject for UnevenWorkMapper { - fn get_name() -> &'static str { - "UnevenWorkMapper\0" - } -} - - -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct ReadStringMapper; impl MapStep for ReadStringMapper { @@ -677,11 +653,11 @@ impl MapStep for ReadStringMapper { fn map(&self, mut r: Self::InRecord) -> Result { let ctx = get_ctx(); ctx_lock(); - let res = ctx.call("get",&[r.s.as_ref().unwrap()]); + let res = ctx.call("get", &[r.s.as_ref().unwrap()]); ctx_unlock(); if let Ok(res) = res { - if let RedisValue::SimpleString(res) = res { - r.s = Some(res); + if let RedisValue::StringBuffer(res) = res { + r.s = Some(std::str::from_utf8(&res).unwrap().to_string()); Ok(r) } else { Err("bad result returned from type command".to_string()) @@ -692,13 +668,7 @@ impl MapStep for ReadStringMapper { } } -impl BaseObject for ReadStringMapper { - fn get_name() -> &'static str { - "ReadStringMapper\0" - } -} - -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct WriteDummyString; impl MapStep for WriteDummyString { @@ -708,11 +678,11 @@ impl MapStep for WriteDummyString { fn map(&self, mut r: Self::InRecord) -> Result { let ctx = get_ctx(); ctx_lock(); - let res = ctx.call("set",&[r.s.as_ref().unwrap(), "val"]); + let res = ctx.call("set", &[r.s.as_ref().unwrap(), "val"]); ctx_unlock(); if let Ok(res) = res { - if let RedisValue::SimpleString(res) = res { - r.s = Some(res); + if let RedisValue::StringBuffer(res) = res { + r.s = Some(std::str::from_utf8(&res).unwrap().to_string()); Ok(r) } else { Err("bad result returned from type command".to_string()) @@ -723,13 +693,7 @@ impl MapStep for WriteDummyString { } } -impl BaseObject for WriteDummyString { - fn get_name() -> &'static str { - "WriteDummyString\0" - } -} - -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, BaseObject)] struct MaxIdleReader { #[serde(skip)] is_initiator: bool, @@ -739,57 +703,49 @@ struct MaxIdleReader { impl MaxIdleReader { fn new(sleep_time: usize) -> MaxIdleReader { - MaxIdleReader {is_initiator: true, sleep_time: sleep_time, is_done: false} + MaxIdleReader { + is_initiator: true, + sleep_time: sleep_time, + is_done: false, + } } } impl Reader for MaxIdleReader { type R = StringRecord; - fn read(&mut self) -> Option> { + fn read(&mut self) -> Result, RustMRError> { if self.is_done { - return None; + return Ok(None); } self.is_done = true; if !self.is_initiator { let ten_millis = time::Duration::from_millis(self.sleep_time as u64); thread::sleep(ten_millis); } - Some(Ok(strin_record_new("record".to_string()))) + Ok(Some(strin_record_new("record".to_string()))) } } -impl BaseObject for MaxIdleReader { - fn get_name() -> &'static str { - "MaxIdleReader\0" - } -} - -#[derive(Clone, Serialize, Deserialize)] -struct ErrorReader{ +#[derive(Clone, Serialize, Deserialize, BaseObject)] +struct ErrorReader { is_done: bool, } impl ErrorReader { fn new() -> ErrorReader { - ErrorReader {is_done: false} + ErrorReader { is_done: false } } } impl Reader for ErrorReader { type R = StringRecord; - fn read(&mut self) -> Option> { + fn read(&mut self) -> Result, RustMRError> { if self.is_done { - return None; + return Ok(None); } self.is_done = true; - Some(Err("read_error".to_string())) - } -} - -impl BaseObject for ErrorReader { - fn get_name() -> &'static str { - "ErrorReader\0" + Err("read_error".to_string()) } } @@ -801,13 +757,14 @@ struct KeysReader { pending: Vec, #[serde(skip)] is_done: bool, - prefix: Option + prefix: Option, } impl KeysReader { fn new(prefix: Option) -> KeysReader { - let mut reader = KeysReader {cursor: None, - pending:Vec::new(), + let mut reader = KeysReader { + cursor: None, + pending: Vec::new(), is_done: false, prefix: prefix, }; @@ -816,12 +773,13 @@ impl KeysReader { } } -extern "C" fn cursor_callback(_ctx: *mut RedisModuleCtx, - keyname: *mut RedisModuleString, - _key: *mut RedisModuleKey, - privdata: *mut c_void) { - - let reader = unsafe{&mut *(privdata as *mut KeysReader)}; +extern "C" fn cursor_callback( + _ctx: *mut RedisModuleCtx, + keyname: *mut RedisModuleString, + _key: *mut RedisModuleKey, + privdata: *mut c_void, +) { + let reader = unsafe { &mut *(privdata as *mut KeysReader) }; let key_str = RedisString::from_ptr(keyname).unwrap(); if let Some(pre) = &reader.prefix { @@ -838,18 +796,26 @@ extern "C" fn cursor_callback(_ctx: *mut RedisModuleCtx, impl Reader for KeysReader { type R = StringRecord; - fn read(&mut self) -> Option> { - let cursor = *self.cursor.as_ref()?; + fn read(&mut self) -> Result, RustMRError> { + let cursor = *match self.cursor.as_ref() { + Some(s) => s, + None => return Ok(None), + }; loop { if let Some(element) = self.pending.pop() { - return Some(Ok(element)); + return Ok(Some(element)); } if self.is_done { - return None; + return Ok(None); } ctx_lock(); - let res = unsafe{ - let res = RedisModule_Scan.unwrap()(get_redis_ctx(), cursor, Some(cursor_callback), self as *mut KeysReader as *mut c_void); + let res = unsafe { + let res = RedisModule_Scan.unwrap()( + get_redis_ctx(), + cursor, + Some(cursor_callback), + self as *mut KeysReader as *mut c_void, + ); res }; ctx_unlock(); @@ -866,9 +832,7 @@ impl BaseObject for KeysReader { } fn init(&mut self) { - self.cursor = Some(unsafe{ - RedisModule_ScanCursorCreate.unwrap()() - }); + self.cursor = Some(unsafe { RedisModule_ScanCursorCreate.unwrap()() }); self.is_done = false; } } @@ -876,48 +840,31 @@ impl BaseObject for KeysReader { impl Drop for KeysReader { fn drop(&mut self) { if let Some(c) = self.cursor { - unsafe{RedisModule_ScanCursorDestroy.unwrap()(c)}; + unsafe { RedisModule_ScanCursorDestroy.unwrap()(c) }; } } } -static mut HASH_RECORD_TYPE: Option> = None; -static mut INT_RECORD_TYPE: Option> = None; - fn init_func(ctx: &Context, _args: &Vec) -> Status { - unsafe{ + unsafe { DETACHED_CTX = RedisModule_GetDetachedThreadSafeContext.unwrap()(ctx.ctx); + } + + mr_init(ctx, 3); + + KeysReader::register(); + + Status::Ok +} - MR_Init(ctx.ctx as *mut libmrraw::bindings::RedisModuleCtx, 3); - } - - unsafe{ - HASH_RECORD_TYPE = Some(RecordType::new()); - INT_RECORD_TYPE = Some(RecordType::new()); - }; - KeysReader::register(); - MaxIdleReader::register(); - ErrorReader::register(); - TypeMapper::register(); - ErrorMapper::register(); - DummyMapper::register(); - TypeFilter::register(); - DummyFilter::register(); - ErrorFilter::register(); - WriteDummyString::register(); - ReadStringMapper::register(); - CountAccumulator::register(); - ErrorAccumulator::register(); - UnevenWorkMapper::register(); - Status::Ok -} - -redis_module!{ +redis_module! { name: "lmrtest", version: 99_99_99, data_types: [], init: init_func, commands: [ + ["lmrtest.dbsize", lmr_dbsize, "readonly", 0,0,0], + ["lmrtest.get", lmr_get, "readonly", 0,0,0], ["lmrtest.readallkeys", lmr_read_all_keys, "readonly", 0,0,0], ["lmrtest.readallkeystype", lmr_read_keys_type, "readonly", 0,0,0], ["lmrtest.readallstringkeys", lmr_read_string_keys, "readonly", 0,0,0], diff --git a/tests/mr_test_module/src/libmr/mod.rs b/tests/mr_test_module/src/libmr/mod.rs deleted file mode 100644 index 9ecb56a..0000000 --- a/tests/mr_test_module/src/libmr/mod.rs +++ /dev/null @@ -1,439 +0,0 @@ -/* - * Copyright Redis Ltd. 2021 - present - * Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or - * the Server Side Public License v1 (SSPLv1). - */ - -use std::marker::PhantomData; -use crate::libmrraw::bindings::{ - ExecutionBuilder, - MR_CreateExecutionBuilder, - MR_FreeExecutionBuilder, - MR_ExecutionBuilderCollect, - MR_ExecutionBuilderMap, - MRObjectType, - WriteSerializationCtx, - ReaderSerializationCtx, - MR_SerializationCtxWriteBuffer, - MR_SerializationCtxReadeBuffer, - RedisModuleCtx, - MR_RegisterObject, - MR_RegisterMapper, - ExecutionCtx, - MR_RegisterReader, - MR_CreateExecution, - MR_Run, - Execution, - MR_ExecutionSetOnDoneHandler, - MR_FreeExecution, - MR_ExecutionCtxGetResultsLen, - MR_ExecutionCtxGetResult, - MR_ExecutionCtxGetErrorsLen, - MR_ExecutionCtxGetError, - MRRecordType, - MR_RegisterRecord, - MR_ExecutionCtxSetError, - MR_ExecutionBuilderFilter, - MR_RegisterFilter, - MR_ExecutionBuilderReshuffle, - MR_ExecutionSetMaxIdle, - MR_RegisterAccumulator, - MR_ExecutionBuilderBuilAccumulate, - MRError, - MR_ErrorGetMessage, -}; - -use serde::ser::{ - Serialize, -}; - -use serde::de::{ - Deserialize, -}; - -use serde_json::{ - to_string, - from_str, -}; - -use std::os::raw::{ - c_char, - c_void, - c_int, -}; - -use std::slice; -use std::str; - -use redis_module::{ - RedisValue, -}; - -use libc::{ - strlen, -}; - -pub type RustMRError = String; - -pub extern "C" fn rust_obj_free(ctx: *mut c_void) { - unsafe{Box::from_raw(ctx as *mut T)}; -} - -pub extern "C" fn rust_obj_dup(arg: *mut c_void) -> *mut c_void { - let obj = unsafe{&mut *(arg as *mut T)}; - let mut obj = obj.clone(); - obj.init(); - Box::into_raw(Box::new(obj)) as *mut c_void -} - -pub extern "C" fn rust_obj_serialize(sctx: *mut WriteSerializationCtx, arg: *mut c_void, error: *mut *mut MRError) { - let obj = unsafe{&mut *(arg as *mut T)}; - let s = to_string(obj).unwrap(); - unsafe{ - MR_SerializationCtxWriteBuffer(sctx, s.as_ptr() as *const c_char, s.len(), error); - } -} - -pub extern "C" fn rust_obj_deserialize(sctx: *mut ReaderSerializationCtx, error: *mut *mut MRError) -> *mut c_void { - let mut len: usize = 0; - let s = unsafe { - MR_SerializationCtxReadeBuffer(sctx, &mut len as *mut usize, error) - }; - if !(unsafe{*error}).is_null() { - return 0 as *mut c_void; - } - let s = str::from_utf8(unsafe { slice::from_raw_parts(s as *const u8, len) }).unwrap(); - let mut obj: T = from_str(s).unwrap(); - obj.init(); - Box::into_raw(Box::new(obj)) as *mut c_void -} - -pub extern "C" fn rust_obj_to_string(_arg: *mut c_void) -> *mut c_char { - 0 as *mut c_char -} - -pub extern "C" fn rust_obj_send_reply(_arg1: *mut RedisModuleCtx, _record: *mut ::std::os::raw::c_void) { - -} - -pub extern "C" fn rust_obj_hash_slot(record: *mut ::std::os::raw::c_void) -> usize { - let record = unsafe{&mut *(record as *mut T)}; - record.hash_slot() -} - -pub trait BaseObject: Clone + Serialize + Deserialize<'static> { - fn get_name() -> &'static str; - fn init(&mut self) {} -} - -fn register() -> *mut MRObjectType { - unsafe { - let obj = Box::into_raw(Box::new(MRObjectType { - type_: T::get_name().as_ptr() as *mut c_char, - id: 0, - free: Some(rust_obj_free::), - dup: Some(rust_obj_dup::), - serialize: Some(rust_obj_serialize::), - deserialize: Some(rust_obj_deserialize::), - tostring: Some(rust_obj_to_string), - })); - - MR_RegisterObject(obj); - - obj - } -} - -fn register_record() -> *mut MRRecordType { - unsafe { - let obj = Box::into_raw(Box::new(MRRecordType { - type_: MRObjectType{ - type_: T::get_name().as_ptr() as *mut c_char, - id: 0, - free: Some(rust_obj_free::), - dup: Some(rust_obj_dup::), - serialize: Some(rust_obj_serialize::), - deserialize: Some(rust_obj_deserialize::), - tostring: Some(rust_obj_to_string), - }, - sendReply: Some(rust_obj_send_reply), - hashTag: Some(rust_obj_hash_slot::), - })); - - MR_RegisterRecord(obj); - - obj - } -} - -pub struct RecordType { - t: *mut MRRecordType, - phantom: PhantomData, -} - -impl RecordType { - pub fn new() -> RecordType { - let obj = register_record::(); - RecordType { - t: obj, - phantom: PhantomData, - } - } - - pub fn create(&self) -> R { - R::new(self.t) - } -} - -pub trait Record: BaseObject{ - fn new(t: *mut MRRecordType) -> Self; - fn to_redis_value(&mut self) -> RedisValue; - fn hash_slot(&self) -> usize; -} - -pub extern "C" fn rust_reader(ectx: *mut ExecutionCtx, args: *mut ::std::os::raw::c_void) -> *mut crate::libmrraw::bindings::Record { - let r = unsafe{&mut *(args as *mut Step)}; - match r.read() { - Some(res) => { - match res { - Ok(res) => Box::into_raw(Box::new(res)) as *mut crate::libmrraw::bindings::Record, - Err(e) => { - unsafe{MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len())}; - 0 as *mut crate::libmrraw::bindings::Record - }, - } - }, - None => 0 as *mut crate::libmrraw::bindings::Record, - } -} - -pub trait Reader : BaseObject{ - type R: Record; - - fn read(&mut self) -> Option>; - - fn register() { - let obj = register::(); - unsafe{ - MR_RegisterReader(Self::get_name().as_ptr() as *mut c_char, Some(rust_reader::), obj); - } - } -} - -pub extern "C" fn rust_map(ectx: *mut ExecutionCtx, r: *mut crate::libmrraw::bindings::Record, args: *mut c_void) -> *mut crate::libmrraw::bindings::Record { - let s = unsafe{&*(args as *mut Step)}; - let r = unsafe{Box::from_raw(r as *mut Step::InRecord)}; - match s.map(*r) { - Ok(res) => Box::into_raw(Box::new(res)) as *mut crate::libmrraw::bindings::Record, - Err(e) => { - unsafe{MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len())}; - 0 as *mut crate::libmrraw::bindings::Record - } - } - -} - -pub trait MapStep: BaseObject{ - type InRecord: Record; - type OutRecord: Record; - - fn map(&self, r: Self::InRecord) -> Result; - - fn register() { - let obj = register::(); - unsafe{ - MR_RegisterMapper(Self::get_name().as_ptr() as *mut c_char, Some(rust_map::), obj); - } - } -} - -pub extern "C" fn rust_filter(ectx: *mut ExecutionCtx, r: *mut crate::libmrraw::bindings::Record, args: *mut c_void) -> c_int { - let s = unsafe{&*(args as *mut Step)}; - let r = unsafe{&*(r as *mut Step::R)}; // do not take ownership on the record - match s.filter(r) { - Ok(res) => res as c_int, - Err(e) => { - unsafe{MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len())}; - 0 as c_int - } - } - -} - -pub trait FilterStep: BaseObject{ - type R: Record; - - fn filter(&self, r: &Self::R) -> Result; - - fn register() { - let obj = register::(); - unsafe{ - MR_RegisterFilter(Self::get_name().as_ptr() as *mut c_char, Some(rust_filter::), obj); - } - } - -} - -pub extern "C" fn rust_accumulate(ectx: *mut ExecutionCtx, accumulator: *mut crate::libmrraw::bindings::Record, r: *mut crate::libmrraw::bindings::Record, args: *mut c_void) -> *mut crate::libmrraw::bindings::Record { - let s = unsafe{&*(args as *mut Step)}; - let accumulator = if accumulator.is_null() { - None - } else { - Some(unsafe{*Box::from_raw(accumulator as *mut Step::Accumulator)}) - }; - let r = unsafe{Box::from_raw(r as *mut Step::InRecord)}; - match s.accumulate(accumulator, *r) { - Ok(res) => Box::into_raw(Box::new(res)) as *mut crate::libmrraw::bindings::Record, - Err(e) => { - unsafe{MR_ExecutionCtxSetError(ectx, e.as_ptr() as *mut c_char, e.len())}; - 0 as *mut crate::libmrraw::bindings::Record - } - } - -} - -pub trait AccumulateStep: BaseObject{ - type InRecord: Record; - type Accumulator: Record; - - fn accumulate(&self, accumulator: Option, r: Self::InRecord) -> Result; - - fn register() { - let obj = register::(); - unsafe{ - MR_RegisterAccumulator(Self::get_name().as_ptr() as *mut c_char, Some(rust_accumulate::), obj); - } - } -} - -pub struct Builder { - inner_builder: Option<*mut ExecutionBuilder>, - phantom: PhantomData, -} - -pub fn create_builder(reader: Re) -> Builder { - let reader = Box::into_raw(Box::new(reader)); - let inner_builder = unsafe{ - MR_CreateExecutionBuilder(Re::get_name().as_ptr() as *const c_char, reader as *mut c_void) - }; - Builder:: { - inner_builder: Some(inner_builder), - phantom: PhantomData, - } -} - -impl Builder { - fn take(&mut self) -> *mut ExecutionBuilder{ - self.inner_builder.take().unwrap() - } - - pub fn map>(mut self, step: Step) -> Builder { - let inner_builder = self.take(); - unsafe { - MR_ExecutionBuilderMap(inner_builder, Step::get_name().as_ptr() as *const c_char, Box::into_raw(Box::new(step)) as *const Step as *mut c_void) - } - Builder:: { - inner_builder: Some(inner_builder), - phantom: PhantomData, - } - } - - pub fn filter>(self, step: Step) -> Builder { - unsafe { - MR_ExecutionBuilderFilter(self.inner_builder.unwrap(), Step::get_name().as_ptr() as *const c_char, Box::into_raw(Box::new(step)) as *const Step as *mut c_void) - } - self - } - - pub fn accumulate>(mut self, step: Step) -> Builder { - let inner_builder = self.take(); - unsafe { - MR_ExecutionBuilderBuilAccumulate(inner_builder, Step::get_name().as_ptr() as *const c_char, Box::into_raw(Box::new(step)) as *const Step as *mut c_void) - } - Builder:: { - inner_builder: Some(inner_builder), - phantom: PhantomData, - } - } - - pub fn collect(self) -> Self { - unsafe { - MR_ExecutionBuilderCollect(self.inner_builder.unwrap()); - } - self - } - - pub fn reshuffle(self) -> Self { - unsafe { - MR_ExecutionBuilderReshuffle(self.inner_builder.unwrap()); - } - self - } - - pub fn create_execution(&self) -> Result, RustMRError> { - let execution = unsafe { - let mut err: *mut MRError = 0 as *mut MRError; - let res = MR_CreateExecution(self.inner_builder.unwrap(), &mut err); - if !err.is_null() { - let c_msg = MR_ErrorGetMessage(err); - let r_str = str::from_utf8(slice::from_raw_parts(c_msg.cast::(), strlen(c_msg))).unwrap(); - return Err(r_str.to_string()); - } - res - }; - Ok(ExecutionObj{inner_e: execution, phantom: PhantomData,}) - } -} - -impl Drop for Builder { - fn drop(&mut self) { - if let Some(innder_builder) = self.inner_builder { - unsafe{MR_FreeExecutionBuilder(innder_builder)} - } - } -} - -pub struct ExecutionObj { - inner_e: *mut Execution, - phantom: PhantomData, -} - -pub extern "C" fn rust_on_done, Vec<&str>)>(ectx: *mut ExecutionCtx, pd: *mut c_void) { - let f = unsafe{Box::from_raw(pd as *mut F)}; - let mut res = Vec::new(); - let res_len = unsafe{MR_ExecutionCtxGetResultsLen(ectx)}; - for i in 0..res_len { - let r = unsafe{&mut *(MR_ExecutionCtxGetResult(ectx, i) as *mut R)}; - res.push(r); - } - let mut errs = Vec::new(); - let errs_len = unsafe{MR_ExecutionCtxGetErrorsLen(ectx)}; - for i in 0..errs_len { - let r = unsafe{MR_ExecutionCtxGetError(ectx, i)}; - let s = str::from_utf8(unsafe { slice::from_raw_parts(r.cast::(), strlen(r))}).unwrap(); - errs.push(s); - } - f(res, errs); -} - -impl ExecutionObj { - - pub fn set_max_idle(&self, max_idle: usize) { - unsafe{MR_ExecutionSetMaxIdle(self.inner_e, max_idle)}; - } - - pub fn set_done_hanlder, Vec<&str>)>(&self, f: F) { - let f = Box::into_raw(Box::new(f)); - unsafe{MR_ExecutionSetOnDoneHandler(self.inner_e, Some(rust_on_done::), f as *mut c_void)}; - } - - pub fn run(&self) { - unsafe{MR_Run(self.inner_e)}; - } -} - -impl Drop for ExecutionObj { - fn drop(&mut self) { - unsafe{MR_FreeExecution(self.inner_e)}; - } -} \ No newline at end of file