Skip to content

Commit

Permalink
fix: lose decimal precision when using decimal type as tag (#5481)
Browse files Browse the repository at this point in the history
* fix: replicate() of decimal vector lose precision

* test: add sqlness test

* test: drop table
  • Loading branch information
evenyag authored Feb 6, 2025
1 parent c80d2a3 commit 0a16998
Show file tree
Hide file tree
Showing 5 changed files with 822 additions and 8 deletions.
37 changes: 37 additions & 0 deletions src/datatypes/src/vectors/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,43 @@ impl Decimal128VectorBuilder {

vectors::impl_try_from_arrow_array_for_vector!(Decimal128Array, Decimal128Vector);

/// Replicates every element of a [`Decimal128Vector`] according to `offsets`.
///
/// `offsets[i]` is the exclusive end position of element `i` in the output, so element `i`
/// is repeated `offsets[i] - offsets[i - 1]` times (the offset before the first element is
/// `0`). Null elements are replicated as nulls. The output builder is configured with the
/// precision and scale of `vector`, so the result keeps the same decimal type.
///
/// # Panics
///
/// Panics if `offsets.len() != vector.len()` or if `offsets` is not non-decreasing.
pub(crate) fn replicate_decimal128(
    vector: &Decimal128Vector,
    offsets: &[usize],
) -> Decimal128Vector {
    assert_eq!(offsets.len(), vector.len());

    // The last offset is the total output length; no offsets means an empty result.
    let Some(&total_len) = offsets.last() else {
        return vector.get_slice(0, 0);
    };

    // The precision and scale are read back from an existing vector, so re-applying them
    // to a fresh builder is expected to succeed.
    let mut builder = Decimal128VectorBuilder::with_capacity(total_len)
        .with_precision_and_scale(vector.precision(), vector.scale())
        .expect("precision and scale of an existing decimal vector must be valid");

    let mut previous_offset = 0usize;

    for (&offset, value) in offsets.iter().zip(vector.array.iter()) {
        // A decreasing offset would otherwise wrap around in release builds and hand a
        // huge repeat count to the unsafe append below.
        let repeat_times = offset
            .checked_sub(previous_offset)
            .expect("offsets must be non-decreasing");
        match value {
            Some(data) => {
                // SAFETY: `std::iter::Repeat` and `std::iter::Take` implement `TrustedLen`,
                // so the iterator yields exactly `repeat_times` items.
                unsafe {
                    builder
                        .mutable_array
                        .append_trusted_len_iter(std::iter::repeat(data).take(repeat_times));
                }
            }
            None => {
                builder.mutable_array.append_nulls(repeat_times);
            }
        }
        previous_offset = offset;
    }
    builder.finish()
}

#[cfg(test)]
pub mod tests {
use arrow_array::Decimal128Array;
Expand Down
31 changes: 24 additions & 7 deletions src/datatypes/src/vectors/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,30 @@ macro_rules! impl_scalar_vector_op {
)+};
}

impl_scalar_vector_op!(
BinaryVector,
BooleanVector,
ListVector,
StringVector,
Decimal128Vector
);
// `Decimal128Vector` is intentionally absent from this list: it has a hand-written
// `VectorOp` impl below whose `replicate()` goes through `replicate_decimal128`, which
// builds the output with the source vector's precision and scale.
impl_scalar_vector_op!(BinaryVector, BooleanVector, ListVector, StringVector);

/// Vector operations for [`Decimal128Vector`].
///
/// This impl is written by hand instead of being generated by `impl_scalar_vector_op!`;
/// `replicate` uses the decimal-specific `replicate_decimal128` helper, while the other
/// methods delegate to the shared helpers and macros of the operations module.
impl VectorOp for Decimal128Vector {
    /// Replicates the elements according to `offsets` through `replicate_decimal128` and
    /// wraps the result in a shared [`VectorRef`].
    fn replicate(&self, offsets: &[usize]) -> VectorRef {
        let replicated = replicate::replicate_decimal128(self, offsets);
        std::sync::Arc::new(replicated)
    }

    /// Delegates to `find_unique_scalar`.
    ///
    /// `prev_vector` is forwarded only if it downcasts to a `Decimal128Vector`; in every
    /// other case `None` is passed on.
    fn find_unique(&self, selected: &mut BitVec, prev_vector: Option<&dyn Vector>) {
        let previous_decimal = prev_vector
            .and_then(|previous| previous.as_any().downcast_ref::<Decimal128Vector>());
        find_unique::find_unique_scalar(self, selected, previous_decimal);
    }

    /// Applies `filter` to this vector; the work is done by `filter_non_constant!`.
    fn filter(&self, filter: &BooleanVector) -> Result<VectorRef> {
        filter::filter_non_constant!(self, Decimal128Vector, filter)
    }

    /// Casts this vector to `to_type`; the work is done by `cast_non_constant!`.
    fn cast(&self, to_type: &ConcreteDataType) -> Result<VectorRef> {
        cast::cast_non_constant!(self, to_type)
    }

    /// Takes the elements selected by `indices`; the work is done by `take_indices!`.
    fn take(&self, indices: &UInt32Vector) -> Result<VectorRef> {
        take::take_indices!(self, Decimal128Vector, indices)
    }
}

impl<T: LogicalPrimitiveType> VectorOp for PrimitiveVector<T> {
fn replicate(&self, offsets: &[usize]) -> VectorRef {
Expand Down
22 changes: 21 additions & 1 deletion src/datatypes/src/vectors/operations/replicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use crate::prelude::*;
pub(crate) use crate::vectors::decimal::replicate_decimal128;
pub(crate) use crate::vectors::null::replicate_null;
pub(crate) use crate::vectors::primitive::replicate_primitive;

Expand Down Expand Up @@ -45,7 +46,7 @@ mod tests {

use super::*;
use crate::vectors::constant::ConstantVector;
use crate::vectors::{Int32Vector, NullVector, StringVector, VectorOp};
use crate::vectors::{Decimal128Vector, Int32Vector, NullVector, StringVector, VectorOp};

#[test]
fn test_replicate_primitive() {
Expand Down Expand Up @@ -167,4 +168,23 @@ mod tests {
impl_replicate_timestamp_test!(Microsecond);
impl_replicate_timestamp_test!(Nanosecond);
}

#[test]
fn test_replicate_decimal() {
    // Single element: the whole output is one run that starts at offset 0.
    let v = Decimal128Vector::from_values(vec![100])
        .with_precision_and_scale(10, 2)
        .unwrap();
    let offsets = [5];
    let v = v.replicate(&offsets);
    assert_eq!(5, v.len());

    let expect: VectorRef = Arc::new(
        Decimal128Vector::from_values(vec![100; 5])
            .with_precision_and_scale(10, 2)
            .unwrap(),
    );
    assert_eq!(expect, v);

    // Two elements: the second run length is the difference between consecutive
    // offsets (5 - 2 = 3), which the single-element case above never exercises.
    let v = Decimal128Vector::from_values(vec![100, 200])
        .with_precision_and_scale(10, 2)
        .unwrap();
    let offsets = [2, 5];
    let v = v.replicate(&offsets);
    assert_eq!(5, v.len());

    let expect: VectorRef = Arc::new(
        Decimal128Vector::from_values(vec![100, 100, 200, 200, 200])
            .with_precision_and_scale(10, 2)
            .unwrap(),
    );
    assert_eq!(expect, v);
}
}
Loading

0 comments on commit 0a16998

Please sign in to comment.