risingwave_common/vector/
mod.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15pub mod distance;
16
17use std::slice;
18
19use bytes::{Buf, BufMut};
20use risingwave_common_estimate_size::EstimateSize;
21use tracing::warn;
22
23use crate::array::{VectorDistanceType, VectorItemType, VectorRef, VectorVal};
24use crate::types::F32;
25
26#[derive(Clone, Copy, PartialEq, Eq, EstimateSize)]
27pub struct VectorInner<T> {
28    pub(crate) inner: T,
29}
30
31impl<T: Ord> PartialOrd for VectorInner<T> {
32    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
33        Some(self.cmp(other))
34    }
35}
36impl<T: Ord> Ord for VectorInner<T> {
37    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
38        self.inner.cmp(&other.inner)
39    }
40}
41
42pub trait MeasureDistance {
43    fn measure(&self, other: VectorRef<'_>) -> VectorDistanceType;
44}
45
46pub trait MeasureDistanceBuilder {
47    type Measure<'a>: MeasureDistance + 'a;
48    fn new(target: VectorRef<'_>) -> Self::Measure<'_>;
49
50    fn distance(target: VectorRef<'_>, other: VectorRef<'_>) -> VectorDistanceType
51    where
52        Self: Sized,
53    {
54        Self::new(target).measure(other)
55    }
56}
57
58#[cfg_attr(not(test), expect(dead_code))]
59fn l2_norm_trivial(vec: &VectorInner<impl AsRef<[VectorItemType]>>) -> f32 {
60    vec.inner
61        .as_ref()
62        .iter()
63        .map(|item| item.0.powi(2))
64        .sum::<f32>()
65        .sqrt()
66}
67
68fn l2_norm_faiss(vec: &VectorInner<impl AsRef<[VectorItemType]>>) -> f32 {
69    faiss::utils::fvec_norm_l2sqr(F32::inner_slice(vec.inner.as_ref())).sqrt()
70}
71
72impl<T: AsRef<[VectorItemType]>> VectorInner<T> {
73    pub fn dimension(&self) -> usize {
74        self.inner.as_ref().len()
75    }
76
77    pub fn as_slice(&self) -> &[VectorItemType] {
78        self.inner.as_ref()
79    }
80
81    pub fn as_raw_slice(&self) -> &[f32] {
82        F32::inner_slice(self.inner.as_ref())
83    }
84
85    pub fn l2_norm(&self) -> f32 {
86        l2_norm_faiss(self)
87    }
88
89    pub fn normalized(&self) -> VectorVal {
90        let slice = self.inner.as_ref();
91        let len = slice.len();
92        let mut inner = Vec::with_capacity(len);
93        let l2_norm = self.l2_norm();
94        if l2_norm.is_infinite() || l2_norm.is_nan() {
95            warn!(
96                "encounter vector {:?} with invalid l2_norm {}. return zeros vector",
97                slice, l2_norm
98            );
99            return VectorVal {
100                inner: vec![0.0.into(); slice.len()].into_boxed_slice(),
101            };
102        }
103        if l2_norm < f32::MIN_POSITIVE {
104            warn!("normalize 0-norm vector. return original value");
105            return VectorVal {
106                inner: self.inner.as_ref().to_vec().into_boxed_slice(),
107            };
108        }
109        // TODO: vectorize it
110        inner.extend((0..len).map(|i| {
111            // safety: 0 <= i < len
112            unsafe { slice.get_unchecked(i) / l2_norm }
113        }));
114        VectorInner {
115            inner: inner.into_boxed_slice(),
116        }
117    }
118}
119
120pub fn encode_vector_payload(payload: &[VectorItemType], mut buf: impl BufMut) {
121    let vector_payload_ptr = payload.as_ptr() as *const u8;
122    // safety: correctly set the size of vector_payload
123    let vector_payload_slice =
124        unsafe { slice::from_raw_parts(vector_payload_ptr, size_of_val(payload)) };
125    buf.put_slice(vector_payload_slice);
126}
127
128pub fn decode_vector_payload(vector_item_count: usize, mut buf: impl Buf) -> Vec<VectorItemType> {
129    let mut vector_payload = Vec::with_capacity(vector_item_count);
130
131    let vector_payload_ptr = vector_payload.spare_capacity_mut().as_mut_ptr() as *mut u8;
132    // safety: no data append to vector_payload, and correctly set the size of vector_payload
133    let vector_payload_slice = unsafe {
134        slice::from_raw_parts_mut(
135            vector_payload_ptr,
136            vector_item_count * size_of::<VectorItemType>(),
137        )
138    };
139    buf.copy_to_slice(vector_payload_slice);
140    // safety: have written correct amount of data
141    unsafe {
142        vector_payload.set_len(vector_item_count);
143    }
144
145    vector_payload
146}
147
#[cfg(test)]
mod tests {
    use crate::array::VectorVal;
    use crate::vector::{l2_norm_faiss, l2_norm_trivial};

    /// Builds a [`VectorVal`] from plain `f32` components.
    fn vector_of(components: &[f32]) -> VectorVal {
        VectorVal {
            inner: components
                .iter()
                .copied()
                .map(Into::into)
                .collect::<Vec<_>>()
                .into_boxed_slice(),
        }
    }

    #[test]
    fn test_vector() {
        let vec = [0.238474, 0.578234];
        let [v1_1, v1_2] = vec;
        let vec = VectorVal {
            inner: vec.map(Into::into).to_vec().into_boxed_slice(),
        };

        assert_eq!(vec.l2_norm(), (v1_1.powi(2) + v1_2.powi(2)).sqrt());
        assert_eq!(l2_norm_faiss(&vec), l2_norm_trivial(&vec));

        let normalized_vec = VectorVal::from_iter(
            [v1_1 / vec.l2_norm(), v1_2 / vec.l2_norm()].map(|v| v.try_into().unwrap()),
        );
        assert_eq!(vec.normalized(), normalized_vec);
    }

    #[test]
    fn test_dimension() {
        assert_eq!(vector_of(&[]).dimension(), 0);
        assert_eq!(vector_of(&[1.0, 2.0, 3.0]).dimension(), 3);
    }

    #[test]
    fn test_normalized_scales_to_unit_norm() {
        // 3-4-5 triangle: the norm is exactly 5, so the expected components are exact.
        assert_eq!(vector_of(&[3.0, 4.0]).normalized(), vector_of(&[0.6, 0.8]));
    }

    #[test]
    fn test_normalized_zero_vector_is_returned_unchanged() {
        let zeros = vector_of(&[0.0, 0.0, 0.0]);
        assert_eq!(zeros.normalized(), zeros);
    }

    #[test]
    fn test_normalized_non_finite_norm_yields_zeros() {
        assert_eq!(
            vector_of(&[1.0, f32::NAN]).normalized(),
            vector_of(&[0.0, 0.0])
        );
        assert_eq!(
            vector_of(&[f32::INFINITY, 1.0]).normalized(),
            vector_of(&[0.0, 0.0])
        );
    }
}