Skip to content

Commit

Permalink
feat(RUST): String detection is performed using SIMD techniques (apac…
Browse files Browse the repository at this point in the history
…he#1752)



## What does this PR do?

Using SIMD technology to speed up string detection

## Related issues



## Does this PR introduce any user-facing change?

- [x] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?


## Benchmark
In fury/rust, run
```bash
cargo bench
```
result:
```bash
SIMD sse short          time:   [992.07 ps 993.91 ps 995.81 ps]                            
                        change: [-1.1907% -0.7021% -0.2276%] (p = 0.00 < 0.05)
                        Change within noise threshold.
Found 5 outliers among 100 measurements (5.00%)
  2 (2.00%) high mild
  3 (3.00%) high severe

SIMD sse long           time:   [368.88 ns 369.70 ns 370.59 ns]                          
                        change: [-1.1407% -0.4415% +0.1537%] (p = 0.20 > 0.05)
                        No change in performance detected.
Found 11 outliers among 100 measurements (11.00%)
  9 (9.00%) high mild
  2 (2.00%) high severe

SIMD avx short          time:   [4.4313 ns 4.4425 ns 4.4566 ns]                            
                        change: [-0.4239% -0.0440% +0.3277%] (p = 0.82 > 0.05)
                        No change in performance detected.
Found 9 outliers among 100 measurements (9.00%)
  5 (5.00%) high mild
  4 (4.00%) high severe

SIMD avx long           time:   [18.215 ns 18.277 ns 18.351 ns]                           
                        change: [+0.6658% +1.0962% +1.6058%] (p = 0.00 < 0.05)
                        Change within noise threshold.
Found 9 outliers among 100 measurements (9.00%)
  5 (5.00%) high mild
  4 (4.00%) high severe

Standard short          time:   [5.1115 ns 5.1670 ns 5.2491 ns]                            
                        change: [+0.5348% +1.5193% +2.5623%] (p = 0.00 < 0.05)
                        Change within noise threshold.
Found 3 outliers among 100 measurements (3.00%)
  2 (2.00%) high mild
  1 (1.00%) high severe

Standard long           time:   [3.6904 µs 3.7205 µs 3.7606 µs]                           
                        change: [+1.5445% +2.5638% +3.9167%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 13 outliers among 100 measurements (13.00%)
  7 (7.00%) high mild
  6 (6.00%) high severe
  ```

---------

Co-authored-by: hezz <[email protected]>
  • Loading branch information
kitty-eu-org and hezz authored Jul 22, 2024
1 parent cfaac57 commit c09f5b9
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 1 deletion.
10 changes: 10 additions & 0 deletions rust/fury/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,13 @@ lazy_static = { version = "1.4" }
byteorder = { version = "1.4" }
chrono = "0.4"
thiserror = { default-features = false, version = "1.0" }


[[bench]]
name = "simd_bench"
harness = false


[dev-dependencies]
criterion = "0.5.1"
rand = "0.8.5"
112 changes: 112 additions & 0 deletions rust/fury/benches/simd_bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use criterion::{black_box, criterion_group, criterion_main, Criterion};
#[cfg(target_feature = "avx2")]
use std::arch::x86_64::*;

#[cfg(target_feature = "sse2")]
use std::arch::x86_64::*;

#[cfg(target_feature = "avx2")]
pub(crate) const MIN_DIM_SIZE_AVX: usize = 32;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) const MIN_DIM_SIZE_SIMD: usize = 16;

#[cfg(target_feature = "sse2")]
unsafe fn is_latin_sse(s: &str) -> bool {
let bytes = s.as_bytes();
let len = s.len();
let ascii_mask = _mm_set1_epi8(0x80u8 as i8);
let remaining = len % MIN_DIM_SIZE_SIMD;
let range_end = len - remaining;
for i in (0..range_end).step_by(MIN_DIM_SIZE_SIMD) {
let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i);
let masked = _mm_and_si128(chunk, ascii_mask);
let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128());
if _mm_movemask_epi8(cmp) != 0xFFFF {
return false;
}
}
for item in bytes.iter().take(range_end).skip(range_end) {
if !item.is_ascii() {
return false;
}
}
true
}

#[cfg(target_feature = "avx2")]
unsafe fn is_latin_avx(s: &str) -> bool {
let bytes = s.as_bytes();
let len = s.len();
let ascii_mask = _mm256_set1_epi8(0x80u8 as i8);
let remaining = len % MIN_DIM_SIZE_AVX;
let range_end = len - remaining;
for i in (0..(len - remaining)).step_by(MIN_DIM_SIZE_AVX) {
let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i);
let masked = _mm256_and_si256(chunk, ascii_mask);
let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256());
if _mm256_movemask_epi8(cmp) != 0xFFFF {
return false;
}
}
for item in bytes.iter().take(range_end).skip(range_end) {
if !item.is_ascii() {
return false;
}
}
true
}

fn is_latin_std(s: &str) -> bool {
s.bytes().all(|b| b.is_ascii())
}

fn criterion_benchmark(c: &mut Criterion) {
let test_str_short = "Hello, World!";
let test_str_long = "Hello, World! ".repeat(1000);

#[cfg(target_feature = "sse2")]
c.bench_function("SIMD sse short", |b| {
b.iter(|| unsafe { is_latin_sse(black_box(test_str_short)) })
});
#[cfg(target_feature = "sse2")]
c.bench_function("SIMD sse long", |b| {
b.iter(|| unsafe { is_latin_sse(black_box(&test_str_long)) })
});
#[cfg(target_feature = "avx2")]
c.bench_function("SIMD avx short", |b| {
b.iter(|| unsafe { is_latin_avx(black_box(test_str_short)) })
});
#[cfg(target_feature = "avx2")]
c.bench_function("SIMD avx long", |b| {
b.iter(|| unsafe { is_latin_avx(black_box(&test_str_long)) })
});

c.bench_function("Standard short", |b| {
b.iter(|| is_latin_std(black_box(test_str_short)))
});

c.bench_function("Standard long", |b| {
b.iter(|| is_latin_std(black_box(&test_str_long)))
});
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
4 changes: 3 additions & 1 deletion rust/fury/src/meta/meta_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use crate::meta::string_util;

#[derive(Debug, PartialEq)]
pub enum Encoding {
Utf8 = 0x00,
Expand Down Expand Up @@ -102,7 +104,7 @@ impl MetaStringEncoder {
}

fn is_latin(&self, s: &str) -> bool {
s.bytes().all(|b| b.is_ascii())
string_util::is_latin(s)
}

pub fn encode(&self, input: &str) -> Result<MetaString, Error> {
Expand Down
1 change: 1 addition & 0 deletions rust/fury/src/meta/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
// under the License.

mod meta_string;
mod string_util;
pub use meta_string::{Encoding, MetaStringDecoder, MetaStringEncoder};
189 changes: 189 additions & 0 deletions rust/fury/src/meta/string_util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#[cfg(target_feature = "neon")]
use std::arch::aarch64::*;

#[cfg(target_feature = "avx2")]
use std::arch::x86_64::*;

#[cfg(target_feature = "sse2")]
use std::arch::x86_64::*;

#[cfg(target_arch = "x86_64")]
pub(crate) const MIN_DIM_SIZE_AVX: usize = 32;

#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
all(target_arch = "aarch64", target_feature = "neon")
))]
pub(crate) const MIN_DIM_SIZE_SIMD: usize = 16;

#[cfg(target_arch = "x86_64")]
unsafe fn is_latin_avx(s: &str) -> bool {
let bytes = s.as_bytes();
let len = bytes.len();
let ascii_mask = _mm256_set1_epi8(0x80u8 as i8);
let remaining = len % MIN_DIM_SIZE_AVX;
let range_end = len - remaining;

for i in (0..range_end).step_by(MIN_DIM_SIZE_AVX) {
let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i);
let masked = _mm256_and_si256(chunk, ascii_mask);
let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256());
if _mm256_movemask_epi8(cmp) != -1 {
return false;
}
}
for item in bytes.iter().take(len).skip(range_end) {
if !item.is_ascii() {
return false;
}
}
true
}

#[cfg(target_feature = "sse2")]
unsafe fn is_latin_sse(s: &str) -> bool {
let bytes = s.as_bytes();
let len = bytes.len();
let ascii_mask = _mm_set1_epi8(0x80u8 as i8);
let remaining = len % MIN_DIM_SIZE_SIMD;
let range_end = len - remaining;
for i in (0..range_end).step_by(MIN_DIM_SIZE_SIMD) {
let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i);
let masked = _mm_and_si128(chunk, ascii_mask);
let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128());
if _mm_movemask_epi8(cmp) != 0xFFFF {
return false;
}
}
for item in bytes.iter().take(len).skip(range_end) {
if !item.is_ascii() {
return false;
}
}
true
}

#[cfg(target_feature = "neon")]
unsafe fn is_latin_neon(s: &str) -> bool {
let bytes = s.as_bytes();
let len = bytes.len();
let ascii_mask = vdupq_n_u8(0x80);
let remaining = len % MIN_DIM_SIZE_SIMD;
let range_end = len - remaining;
for i in (0..range_end).step_by(MIN_DIM_SIZE_SIMD) {
let chunk = vld1q_u8(bytes.as_ptr().add(i));
let masked = vandq_u8(chunk, ascii_mask);
let cmp = vceqq_u8(masked, vdupq_n_u8(0));
if vminvq_u8(cmp) == 0 {
return false;
}
}
for item in bytes.iter().take(len).skip(range_end) {
if !item.is_ascii() {
return false;
}
}
true
}

fn is_latin_standard(s: &str) -> bool {
s.bytes().all(|b| b.is_ascii())
}

pub(crate) fn is_latin(s: &str) -> bool {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx")
&& is_x86_feature_detected!("fma")
&& s.len() >= MIN_DIM_SIZE_AVX
{
return unsafe { is_latin_avx(s) };
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { is_latin_sse(s) };
}
}

#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD {
return unsafe { is_latin_neon(s) };
}
}
is_latin_standard(s)
}

#[cfg(test)]
mod tests {
// 导入外部模块中的内容
use super::*;
use rand::Rng;

fn generate_random_string(length: usize) -> String {
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
let mut rng = rand::thread_rng();

let result: String = (0..length)
.map(|_| {
let idx = rng.gen_range(0..CHARSET.len());
CHARSET[idx] as char
})
.collect();

result
}

#[test]
fn test_is_latin() {
let s = generate_random_string(1000);
let not_latin_str = generate_random_string(1000) + "abc\u{1234}";

#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
assert!(unsafe { is_latin_avx(&s) });
assert!(!unsafe { is_latin_avx(&not_latin_str) });
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD {
assert!(unsafe { is_latin_sse(&s) });
assert!(!unsafe { is_latin_sse(&not_latin_str) });
}
}

#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD {
assert!(unsafe { is_latin_neon(&s) });
assert!(!unsafe { is_latin_neon(&not_latin_str) });
}
}
assert!(is_latin_standard(&s));
assert!(!is_latin_standard(&not_latin_str));
}
}

0 comments on commit c09f5b9

Please sign in to comment.