risingwave_common/array/
vector_array.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16
17use risingwave_common_estimate_size::EstimateSize;
18use risingwave_pb::data::PbArray;
19
20use super::{Array, ArrayBuilder, ListArray, ListArrayBuilder, ListRef, ListValue};
21use crate::array::ArrayError;
22use crate::bitmap::Bitmap;
23use crate::types::{DataType, Scalar, ScalarRef, ScalarRefImpl, ToText};
24
/// Builder for [`VectorArray`].
///
/// Values are accumulated in a wrapped [`ListArrayBuilder`], which `with_type` creates
/// with a `Float32` element type.
#[derive(Debug, Clone, EstimateSize)]
pub struct VectorArrayBuilder {
    /// Underlying list builder that stores the appended vectors.
    inner: ListArrayBuilder,
    /// Dimension taken from `DataType::Vector(..)`; appended values are asserted against it.
    elem_size: usize,
}
30
31impl ArrayBuilder for VectorArrayBuilder {
32    type ArrayType = VectorArray;
33
34    #[cfg(not(test))]
35    fn new(_capacity: usize) -> Self {
36        panic!("please use `VectorArrayBuilder::with_type` instead");
37    }
38
39    #[cfg(test)]
40    fn new(capacity: usize) -> Self {
41        Self::with_type(capacity, DataType::Vector(3))
42    }
43
44    fn with_type(capacity: usize, ty: DataType) -> Self {
45        let DataType::Vector(elem_size) = ty else {
46            panic!("VectorArrayBuilder only supports Vector type");
47        };
48        Self {
49            inner: ListArrayBuilder::with_type(capacity, DataType::List(DataType::Float32.into())),
50            elem_size,
51        }
52    }
53
54    fn append_n(&mut self, n: usize, value: Option<VectorRef<'_>>) {
55        if let Some(value) = value {
56            assert_eq!(self.elem_size, value.inner.len());
57        }
58        self.inner.append_n(n, value.map(|v| v.inner))
59    }
60
61    fn append_array(&mut self, other: &VectorArray) {
62        assert_eq!(self.elem_size, other.elem_size);
63        self.inner.append_array(&other.inner)
64    }
65
66    fn pop(&mut self) -> Option<()> {
67        self.inner.pop()
68    }
69
70    fn len(&self) -> usize {
71        self.inner.len()
72    }
73
74    fn finish(self) -> VectorArray {
75        VectorArray {
76            inner: self.inner.finish(),
77            elem_size: self.elem_size,
78        }
79    }
80}
81
/// Array of vectors, stored as a wrapped [`ListArray`] plus the vector dimension.
#[derive(Debug, Clone)]
pub struct VectorArray {
    /// Underlying list array holding the vector values.
    inner: ListArray,
    /// Dimension reported through `DataType::Vector(elem_size)` and written to protobuf.
    elem_size: usize,
}
87
88impl EstimateSize for VectorArray {
89    fn estimated_heap_size(&self) -> usize {
90        self.inner.estimated_heap_size()
91    }
92}
93
94impl Array for VectorArray {
95    type Builder = VectorArrayBuilder;
96    type OwnedItem = VectorVal;
97    type RefItem<'a> = VectorRef<'a>;
98
99    unsafe fn raw_value_at_unchecked(&self, idx: usize) -> Self::RefItem<'_> {
100        VectorRef {
101            inner: unsafe { self.inner.raw_value_at_unchecked(idx) },
102        }
103    }
104
105    fn len(&self) -> usize {
106        self.inner.len()
107    }
108
109    fn to_protobuf(&self) -> PbArray {
110        let mut pb_array = self.inner.to_protobuf();
111        pb_array.set_array_type(risingwave_pb::data::PbArrayType::Vector);
112        pb_array.list_array_data.as_mut().unwrap().elem_size = Some(self.elem_size as _);
113        pb_array
114    }
115
116    fn null_bitmap(&self) -> &Bitmap {
117        self.inner.null_bitmap()
118    }
119
120    fn into_null_bitmap(self) -> Bitmap {
121        self.inner.into_null_bitmap()
122    }
123
124    fn set_bitmap(&mut self, bitmap: Bitmap) {
125        self.inner.set_bitmap(bitmap)
126    }
127
128    fn data_type(&self) -> DataType {
129        DataType::Vector(self.elem_size)
130    }
131}
132
133impl VectorArray {
134    pub fn from_protobuf(
135        pb_array: &risingwave_pb::data::PbArray,
136    ) -> super::ArrayResult<super::ArrayImpl> {
137        let inner = ListArray::from_protobuf(pb_array)?.into_list();
138        let elem_size = pb_array
139            .list_array_data
140            .as_ref()
141            .unwrap()
142            .elem_size
143            .unwrap() as _;
144        Ok(Self { inner, elem_size }.into())
145    }
146}
147
/// Owned vector scalar, stored as a wrapped [`ListValue`].
#[derive(Clone, EstimateSize)]
pub struct VectorVal {
    /// Underlying list value holding the vector elements.
    inner: ListValue,
}
152
153impl Debug for VectorVal {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        self.as_scalar_ref().fmt(f)
156    }
157}
158
159impl PartialEq for VectorVal {
160    fn eq(&self, _other: &Self) -> bool {
161        todo!("VECTOR_PLACEHOLDER")
162    }
163}
164impl Eq for VectorVal {}
165impl PartialOrd for VectorVal {
166    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
167        Some(self.cmp(other))
168    }
169}
170impl Ord for VectorVal {
171    fn cmp(&self, _other: &Self) -> std::cmp::Ordering {
172        todo!("VECTOR_PLACEHOLDER")
173    }
174}
175
176impl Scalar for VectorVal {
177    type ScalarRefType<'a> = VectorRef<'a>;
178
179    fn as_scalar_ref(&self) -> VectorRef<'_> {
180        VectorRef {
181            inner: self.inner.as_scalar_ref(),
182        }
183    }
184}
185
186impl VectorVal {
187    pub fn from_text(text: &str, size: usize) -> Result<Self, String> {
188        let text = text.trim();
189        let text = text
190            .strip_prefix('[')
191            .ok_or_else(|| "vector must start with [".to_owned())?
192            .strip_suffix(']')
193            .ok_or_else(|| "vector must end with ]".to_owned())?;
194        let inner = text
195            .split(',')
196            .map(|s| {
197                s.trim()
198                    .parse::<f32>()
199                    .map_err(|_| format!("invalid f32: {s}"))
200                    .and_then(|f| {
201                        if f.is_finite() {
202                            Ok(crate::types::F32::from(f))
203                        } else {
204                            Err(format!("{f} not allowed in vector"))
205                        }
206                    })
207            })
208            .collect::<Result<ListValue, _>>()?;
209        if inner.len() != size {
210            return Err(format!("expected {} dimensions, not {}", size, inner.len()));
211        }
212        Ok(Self { inner })
213    }
214
215    /// Create a new vector from inner [`ListValue`].
216    ///
217    /// This is leak of implementation. Prefer [`VectorVal::from_iter`] below.
218    pub fn from_inner(inner: ListValue) -> Result<Self, ArrayError> {
219        for element in inner.iter() {
220            let Some(scalar) = element else {
221                return Err(ArrayError::internal("NULL not allowed in vector"));
222            };
223            let ScalarRefImpl::Float32(val) = scalar else {
224                return Err(ArrayError::internal(format!(
225                    "vector element must be f32 but found {scalar:?}"
226                )));
227            };
228            if !val.0.is_finite() {
229                return Err(ArrayError::internal(format!("{val} not allowed in vector")));
230            }
231        }
232        Ok(Self { inner })
233    }
234}
235
/// An `f32` that is guaranteed to be neither NaN nor ±infinity.
///
/// Serves as the intermediate item type when `try_collect`-ing `f32` values into a
/// `VectorVal`.
#[derive(Clone, Copy, Debug)]
pub struct Finite32(f32);

impl TryFrom<f32> for Finite32 {
    type Error = String;

    /// Wraps `value` when it is finite; otherwise returns a message naming the rejected
    /// value.
    fn try_from(value: f32) -> Result<Self, Self::Error> {
        match value {
            finite if finite.is_finite() => Ok(Self(finite)),
            rejected => Err(format!("{rejected} not allowed in vector")),
        }
    }
}
251
252impl FromIterator<Finite32> for VectorVal {
253    fn from_iter<I: IntoIterator<Item = Finite32>>(iter: I) -> Self {
254        let inner = ListValue::from_iter(iter.into_iter().map(|v| crate::types::F32::from(v.0)));
255        Self { inner }
256    }
257}
258
/// Borrowed vector scalar, stored as a wrapped [`ListRef`].
#[derive(Clone, Copy)]
pub struct VectorRef<'a> {
    /// Underlying list reference holding the vector elements.
    inner: ListRef<'a>,
}
263
264impl Debug for VectorRef<'_> {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        self.write_with_type(&DataType::Vector(self.into_slice().len()), f)
267    }
268}
269
270impl PartialEq for VectorRef<'_> {
271    fn eq(&self, _other: &Self) -> bool {
272        todo!("VECTOR_PLACEHOLDER")
273    }
274}
275impl Eq for VectorRef<'_> {}
276impl PartialOrd for VectorRef<'_> {
277    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
278        Some(self.cmp(other))
279    }
280}
281impl Ord for VectorRef<'_> {
282    fn cmp(&self, _other: &Self) -> std::cmp::Ordering {
283        todo!("VECTOR_PLACEHOLDER")
284    }
285}
286
287impl ToText for VectorRef<'_> {
288    fn write<W: std::fmt::Write>(&self, f: &mut W) -> std::fmt::Result {
289        self.write_with_type(&DataType::Vector(self.into_slice().len()), f)
290    }
291
292    fn write_with_type<W: std::fmt::Write>(&self, _ty: &DataType, f: &mut W) -> std::fmt::Result {
293        write!(f, "[")?;
294        for (i, item) in self.inner.iter().enumerate() {
295            if i > 0 {
296                write!(f, ",")?;
297            }
298            item.write_with_type(&DataType::Float32, f)?;
299        }
300        write!(f, "]")
301    }
302}
303
304impl<'a> ScalarRef<'a> for VectorRef<'a> {
305    type ScalarType = VectorVal;
306
307    fn to_owned_scalar(&self) -> VectorVal {
308        VectorVal {
309            inner: self.inner.to_owned_scalar(),
310        }
311    }
312
313    fn hash_scalar<H: std::hash::Hasher>(&self, state: &mut H) {
314        self.inner.hash_scalar(state)
315    }
316}
317
318impl<'a> VectorRef<'a> {
319    /// Get the inner [`ListRef`].
320    ///
321    /// This is leak of implementation. Prefer [`Self::into_slice`] below.
322    pub fn into_inner(self) -> ListRef<'a> {
323        self.inner
324    }
325
326    /// Get the slice of floats in this vector.
327    pub fn into_slice(self) -> &'a [f32] {
328        crate::types::F32::inner_slice(self.inner.as_primitive_slice().unwrap())
329    }
330}