risingwave_common/array/
vector_array.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16use std::hash::Hash;
17use std::slice;
18use std::sync::LazyLock;
19
20use bytes::{Buf, BufMut};
21use itertools::{Itertools, repeat_n};
22use memcomparable::Error;
23use risingwave_common::types::F32;
24use risingwave_common_estimate_size::EstimateSize;
25use risingwave_pb::common::PbBuffer;
26use risingwave_pb::common::buffer::PbCompressionType;
27use risingwave_pb::data::{PbArray, PbArrayType, PbListArrayData};
28use serde::{Deserialize, Serialize};
29
30use super::{Array, ArrayBuilder};
31use crate::bitmap::{Bitmap, BitmapBuilder};
32use crate::types::{DataType, ListType, Scalar, ScalarRef, ToText};
33use crate::vector::{VectorInner, decode_vector_payload, encode_vector_payload};
34
35pub type VectorItemType = F32;
36pub type VectorDistanceType = f64;
37pub const VECTOR_ITEM_TYPE: DataType = DataType::Float32;
38pub const VECTOR_DISTANCE_TYPE: DataType = DataType::Float64;
39
40/// Sometimes we can interpret a vector as a list to reuse some code, pass this type around.
41pub static VECTOR_AS_LIST_TYPE: LazyLock<ListType> =
42    LazyLock::new(|| ListType::new(VECTOR_ITEM_TYPE));
43
44#[derive(Debug, Clone, EstimateSize)]
45pub struct VectorArrayBuilder {
46    bitmap: BitmapBuilder,
47    offsets: Vec<u32>,
48    inner: Vec<VectorItemType>,
49    elem_size: usize,
50}
51
52impl ArrayBuilder for VectorArrayBuilder {
53    type ArrayType = VectorArray;
54
55    #[cfg(not(test))]
56    fn new(_capacity: usize) -> Self {
57        panic!("please use `VectorArrayBuilder::with_type` instead");
58    }
59
60    #[cfg(test)]
61    fn new(capacity: usize) -> Self {
62        Self::with_type(capacity, VectorVal::test_type())
63    }
64
65    fn with_type(capacity: usize, ty: DataType) -> Self {
66        let DataType::Vector(elem_size) = ty else {
67            panic!("VectorArrayBuilder only supports Vector type");
68        };
69        let mut offsets = Vec::with_capacity(capacity + 1);
70        offsets.push(0);
71        Self {
72            bitmap: BitmapBuilder::with_capacity(capacity),
73            offsets,
74            inner: Vec::with_capacity(capacity * elem_size),
75            elem_size,
76        }
77    }
78
79    fn append_n(&mut self, n: usize, value: Option<VectorRef<'_>>) {
80        let last = self
81            .offsets
82            .last()
83            .cloned()
84            .expect("non-empty with an initial 0");
85        if let Some(value) = value {
86            assert_eq!(self.elem_size, value.inner.len());
87            self.inner.reserve(self.elem_size * n);
88            for _ in 0..n {
89                self.inner.extend_from_slice(value.inner);
90            }
91            self.offsets.reserve(n);
92            self.offsets.extend((1..=n).map(|i| {
93                last.checked_add((i * self.elem_size) as _)
94                    .expect("overflow")
95            }));
96            self.bitmap.append_n(n, true);
97        } else {
98            self.offsets.reserve(n);
99            self.offsets.extend(repeat_n(last, n));
100            self.bitmap.append_n(n, false);
101        }
102    }
103
104    fn append_array(&mut self, other: &VectorArray) {
105        assert_eq!(self.elem_size, other.elem_size);
106        self.bitmap.append_bitmap(&other.bitmap);
107        let last = self
108            .offsets
109            .last()
110            .cloned()
111            .expect("non-empty with an initial 0");
112        let other_offsets = &other.offsets[1..];
113        self.offsets.reserve(other_offsets.len());
114        self.offsets.extend(
115            other_offsets
116                .iter()
117                .map(|offset| last.checked_add(*offset).expect("overflow")),
118        );
119        self.inner.reserve(other.inner.len());
120        self.inner.extend_from_slice(&other.inner);
121    }
122
123    fn pop(&mut self) -> Option<()> {
124        if self.bitmap.pop().is_some() {
125            self.offsets
126                .pop()
127                .expect("non-empty when bitmap popped Some");
128            let last = self
129                .offsets
130                .last()
131                .cloned()
132                .expect("non-empty with initial 0");
133            self.inner.truncate(last as _);
134            Some(())
135        } else {
136            None
137        }
138    }
139
140    fn len(&self) -> usize {
141        self.bitmap.len()
142    }
143
144    fn finish(self) -> VectorArray {
145        VectorArray {
146            bitmap: self.bitmap.finish(),
147            offsets: self.offsets,
148            inner: self.inner,
149            elem_size: self.elem_size,
150        }
151    }
152}
153
154#[derive(Debug, Clone)]
155pub struct VectorArray {
156    bitmap: Bitmap,
157    /// Of size as `bitmap.len() + 1`. `(self.offsets[i]..self.offsets[i+1])` is the slice range of the i-th vector
158    /// if it's not null.
159    offsets: Vec<u32>,
160    inner: Vec<VectorItemType>,
161    elem_size: usize,
162}
163
164impl EstimateSize for VectorArray {
165    fn estimated_heap_size(&self) -> usize {
166        self.inner.estimated_heap_size()
167    }
168}
169
170impl Array for VectorArray {
171    type Builder = VectorArrayBuilder;
172    type OwnedItem = VectorVal;
173    type RefItem<'a> = VectorRef<'a>;
174
175    unsafe fn raw_value_at_unchecked(&self, idx: usize) -> Self::RefItem<'_> {
176        VectorRef {
177            inner: unsafe {
178                let start = self.inner.as_ptr().add(self.offsets[idx] as usize);
179                slice::from_raw_parts(start, self.elem_size)
180            },
181        }
182    }
183
184    fn len(&self) -> usize {
185        self.bitmap.len()
186    }
187
188    fn to_protobuf(&self) -> PbArray {
189        let mut payload = Vec::with_capacity(self.inner.len() * size_of::<VectorItemType>());
190        encode_vector_payload(self.inner.as_slice(), &mut payload);
191        PbArray {
192            array_type: PbArrayType::Vector as _,
193            null_bitmap: Some(self.bitmap.to_protobuf()),
194            values: vec![PbBuffer {
195                compression: PbCompressionType::None as _,
196                body: payload,
197            }],
198            struct_array_data: None,
199            list_array_data: Some(
200                PbListArrayData {
201                    offsets: self.offsets.clone(),
202                    value: None,
203                    value_type: Some(DataType::Float32.to_protobuf()),
204                    elem_size: Some(self.elem_size as _),
205                }
206                .into(),
207            ),
208        }
209    }
210
211    fn null_bitmap(&self) -> &Bitmap {
212        &self.bitmap
213    }
214
215    fn into_null_bitmap(self) -> Bitmap {
216        self.bitmap
217    }
218
219    fn set_bitmap(&mut self, bitmap: Bitmap) {
220        self.bitmap = bitmap;
221    }
222
223    fn data_type(&self) -> DataType {
224        DataType::Vector(self.elem_size)
225    }
226}
227
228impl VectorArray {
229    pub fn from_protobuf(
230        array: &risingwave_pb::data::PbArray,
231    ) -> super::ArrayResult<super::ArrayImpl> {
232        // reversing to_protobuf
233        assert_eq!(
234            array.array_type,
235            PbArrayType::Vector as i32,
236            "invalid array type for vector: {}",
237            array.array_type
238        );
239        let bitmap: Bitmap = array.get_null_bitmap()?.into();
240        let encoded_payload = &array.values[0].body;
241        let payload = decode_vector_payload(
242            encoded_payload
243                .len()
244                .checked_exact_div(size_of::<VectorItemType>())
245                .unwrap_or_else(|| {
246                    panic!("invalid payload len {} for vector", encoded_payload.len(),)
247                }),
248            array.values[0].body.as_slice(),
249        );
250        let array_data = array.get_list_array_data()?;
251        let elem_size = array_data.elem_size.expect("should exist for Vector") as usize;
252        let offsets = array_data.offsets.clone();
253        debug_assert_eq!(array_data.value_type, Some(DataType::Float32.to_protobuf()));
254        debug_assert_eq!(array_data.value, None);
255
256        Ok(VectorArray {
257            bitmap,
258            offsets,
259            inner: payload,
260            elem_size,
261        }
262        .into())
263    }
264
265    pub fn as_raw_slice(&self) -> &[f32] {
266        F32::inner_slice(&self.inner)
267    }
268
269    pub fn offsets(&self) -> &[u32] {
270        &self.offsets
271    }
272}
273
274pub type VectorVal = VectorInner<Box<[VectorItemType]>>;
275
276impl Debug for VectorVal {
277    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278        self.to_ref().fmt(f)
279    }
280}
281
282impl Scalar for VectorVal {
283    type ScalarRefType<'a> = VectorRef<'a>;
284
285    fn as_scalar_ref(&self) -> VectorRef<'_> {
286        self.to_ref()
287    }
288}
289
290impl VectorVal {
291    #[cfg(test)]
292    pub const TEST_VECTOR_DIMENSION: usize = 3;
293
294    pub fn from_text(text: &str, size: usize) -> Result<Self, String> {
295        let text = text.trim();
296        let text = text
297            .strip_prefix('[')
298            .ok_or_else(|| "vector must start with [".to_owned())?
299            .strip_suffix(']')
300            .ok_or_else(|| "vector must end with ]".to_owned())?;
301        let inner = text
302            .split(',')
303            .map(|s| {
304                s.trim()
305                    .parse::<f32>()
306                    .map_err(|_| format!("invalid f32: {s}"))
307                    .and_then(|f| {
308                        if f.is_finite() {
309                            Ok(f.into())
310                        } else {
311                            Err(format!("{f} not allowed in vector"))
312                        }
313                    })
314            })
315            .collect::<Result<Vec<_>, _>>()?;
316        if inner.len() != size {
317            return Err(format!("expected {} dimensions, not {}", size, inner.len()));
318        }
319        Ok(Self {
320            inner: inner.into(),
321        })
322    }
323
324    #[cfg(test)]
325    pub fn test_type() -> DataType {
326        DataType::Vector(Self::TEST_VECTOR_DIMENSION)
327    }
328
329    pub fn to_ref(&self) -> VectorRef<'_> {
330        VectorRef { inner: &self.inner }
331    }
332}
333
/// A `f32` without nan/inf/-inf. Added as intermediate type to `try_collect` `f32` values into a `VectorVal`.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct Finite32(f32);

impl TryFrom<f32> for Finite32 {
    type Error = String;

    /// Wraps `value` when it is finite; otherwise returns a message naming the rejected value.
    fn try_from(value: f32) -> Result<Self, Self::Error> {
        if !value.is_finite() {
            return Err(format!("{value} not allowed in vector"));
        }
        Ok(Self(value))
    }
}
350
351impl From<Vec<Finite32>> for VectorVal {
352    fn from(value: Vec<Finite32>) -> Self {
353        let (ptr, len, cap) = value.into_raw_parts();
354        // Safety: OrderedFloat is #[repr(transparent)] and has no invalid values.
355        Self {
356            inner: unsafe { Vec::from_raw_parts(ptr as *mut F32, len, cap).into_boxed_slice() },
357        }
358    }
359}
360
361impl FromIterator<Finite32> for VectorVal {
362    fn from_iter<I: IntoIterator<Item = Finite32>>(iter: I) -> Self {
363        let inner = iter.into_iter().collect_vec();
364        Self::from(inner)
365    }
366}
367
368pub type VectorRef<'a> = VectorInner<&'a [VectorItemType]>;
369
370impl Debug for VectorRef<'_> {
371    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372        self.write_with_type(&DataType::Vector(self.dimension()), f)
373    }
374}
375
376impl ToText for VectorRef<'_> {
377    fn write<W: std::fmt::Write>(&self, f: &mut W) -> std::fmt::Result {
378        self.write_with_type(&DataType::Vector(self.dimension()), f)
379    }
380
381    fn write_with_type<W: std::fmt::Write>(&self, _ty: &DataType, f: &mut W) -> std::fmt::Result {
382        write!(f, "[")?;
383        for (i, item) in self.inner.iter().enumerate() {
384            if i > 0 {
385                write!(f, ",")?;
386            }
387            write!(f, "{}", item)?;
388        }
389        write!(f, "]")
390    }
391}
392
393impl<'a> ScalarRef<'a> for VectorRef<'a> {
394    type ScalarType = VectorVal;
395
396    fn to_owned_scalar(&self) -> VectorVal {
397        VectorVal {
398            inner: self.inner.to_vec().into_boxed_slice(),
399        }
400    }
401
402    fn hash_scalar<H: std::hash::Hasher>(&self, state: &mut H) {
403        self.inner.hash(state);
404    }
405}
406
407impl<'a> VectorRef<'a> {
408    /// Create a `VectorRef` from a slice of `VectorItemType` without checking the elements in the slice
409    /// is invalid, such as `inf` and `nan`.
410    pub fn from_slice_unchecked(inner: &'a [VectorItemType]) -> Self {
411        Self { inner }
412    }
413
414    pub fn memcmp_serialize(
415        self,
416        serializer: &mut memcomparable::Serializer<impl BufMut>,
417    ) -> memcomparable::Result<()> {
418        for item in self.inner {
419            item.serialize(&mut *serializer)?;
420        }
421        Ok(())
422    }
423}
424
425impl VectorVal {
426    pub fn memcmp_deserialize(
427        dimension: usize,
428        de: &mut memcomparable::Deserializer<impl Buf>,
429    ) -> memcomparable::Result<Self> {
430        let mut value = Vec::with_capacity(dimension);
431        for _ in 0..dimension {
432            value.push(Finite32::try_from(f32::deserialize(&mut *de)?).map_err(Error::Message)?)
433        }
434        Ok(VectorVal::from(value))
435    }
436}