// risingwave_common/vector/mod.rs
pub mod distance;
16
17use std::slice;
18
19use bytes::{Buf, BufMut};
20use risingwave_common_estimate_size::EstimateSize;
21use tracing::warn;
22
23use crate::array::{VectorDistanceType, VectorItemType, VectorRef, VectorVal};
24use crate::types::F32;
25
/// Generic wrapper around a vector payload of type `T`.
///
/// Equality is derived from `T`. The manual `PartialOrd`/`Ord` impls in this
/// module delegate ordering to the wrapped value, and the methods defined here
/// are available when `T: AsRef<[VectorItemType]>`.
#[derive(Clone, Copy, PartialEq, Eq, EstimateSize)]
pub struct VectorInner<T> {
    /// The wrapped payload; visible within the crate only.
    pub(crate) inner: T,
}
30
31impl<T: Ord> PartialOrd for VectorInner<T> {
32 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
33 Some(self.cmp(other))
34 }
35}
36impl<T: Ord> Ord for VectorInner<T> {
37 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
38 self.inner.cmp(&other.inner)
39 }
40}
41
/// A distance measure that can be evaluated against another vector.
///
/// Implementors are produced by [`MeasureDistanceBuilder::new`] from a target
/// vector; presumably they capture that target (or state derived from it) so
/// it can be compared against many candidates — confirm against implementors.
pub trait MeasureDistance {
    /// Returns the distance between the state held by `self` and `other`.
    fn measure(&self, other: VectorRef<'_>) -> VectorDistanceType;
}
45
46pub trait MeasureDistanceBuilder {
47 type Measure<'a>: MeasureDistance + 'a;
48 fn new(target: VectorRef<'_>) -> Self::Measure<'_>;
49
50 fn distance(target: VectorRef<'_>, other: VectorRef<'_>) -> VectorDistanceType
51 where
52 Self: Sized,
53 {
54 Self::new(target).measure(other)
55 }
56}
57
58#[cfg_attr(not(test), expect(dead_code))]
59fn l2_norm_trivial(vec: &VectorInner<impl AsRef<[VectorItemType]>>) -> f32 {
60 vec.inner
61 .as_ref()
62 .iter()
63 .map(|item| item.0.powi(2))
64 .sum::<f32>()
65 .sqrt()
66}
67
68fn l2_norm_faiss(vec: &VectorInner<impl AsRef<[VectorItemType]>>) -> f32 {
69 faiss::utils::fvec_norm_l2sqr(F32::inner_slice(vec.inner.as_ref())).sqrt()
70}
71
72impl<T: AsRef<[VectorItemType]>> VectorInner<T> {
73 pub fn dimension(&self) -> usize {
74 self.inner.as_ref().len()
75 }
76
77 pub fn as_slice(&self) -> &[VectorItemType] {
78 self.inner.as_ref()
79 }
80
81 pub fn as_raw_slice(&self) -> &[f32] {
82 F32::inner_slice(self.inner.as_ref())
83 }
84
85 pub fn l2_norm(&self) -> f32 {
86 l2_norm_faiss(self)
87 }
88
89 pub fn normalized(&self) -> VectorVal {
90 let slice = self.inner.as_ref();
91 let len = slice.len();
92 let mut inner = Vec::with_capacity(len);
93 let l2_norm = self.l2_norm();
94 if l2_norm.is_infinite() || l2_norm.is_nan() {
95 warn!(
96 "encounter vector {:?} with invalid l2_norm {}. return zeros vector",
97 slice, l2_norm
98 );
99 return VectorVal {
100 inner: vec![0.0.into(); slice.len()].into_boxed_slice(),
101 };
102 }
103 if l2_norm < f32::MIN_POSITIVE {
104 warn!("normalize 0-norm vector. return original value");
105 return VectorVal {
106 inner: self.inner.as_ref().to_vec().into_boxed_slice(),
107 };
108 }
109 inner.extend((0..len).map(|i| {
111 unsafe { slice.get_unchecked(i) / l2_norm }
113 }));
114 VectorInner {
115 inner: inner.into_boxed_slice(),
116 }
117 }
118}
119
120pub fn encode_vector_payload(payload: &[VectorItemType], mut buf: impl BufMut) {
121 let vector_payload_ptr = payload.as_ptr() as *const u8;
122 let vector_payload_slice =
124 unsafe { slice::from_raw_parts(vector_payload_ptr, size_of_val(payload)) };
125 buf.put_slice(vector_payload_slice);
126}
127
128pub fn decode_vector_payload(vector_item_count: usize, mut buf: impl Buf) -> Vec<VectorItemType> {
129 let mut vector_payload = Vec::with_capacity(vector_item_count);
130
131 let vector_payload_ptr = vector_payload.spare_capacity_mut().as_mut_ptr() as *mut u8;
132 let vector_payload_slice = unsafe {
134 slice::from_raw_parts_mut(
135 vector_payload_ptr,
136 vector_item_count * size_of::<VectorItemType>(),
137 )
138 };
139 buf.copy_to_slice(vector_payload_slice);
140 unsafe {
142 vector_payload.set_len(vector_item_count);
143 }
144
145 vector_payload
146}
147
#[cfg(test)]
mod tests {
    use crate::array::VectorVal;
    use crate::vector::{l2_norm_faiss, l2_norm_trivial};

    /// Checks `l2_norm` against a hand-computed value, checks that the faiss
    /// and trivial norm implementations agree, and checks `normalized`
    /// against a manually scaled vector.
    #[test]
    fn test_vector() {
        let vec = [0.238474, 0.578234];
        // Keep the raw components around for the hand-written expectations below.
        let [v1_1, v1_2] = vec;
        let vec = VectorVal {
            inner: vec.map(Into::into).to_vec().into_boxed_slice(),
        };

        // Exact float equality: the norm must match the naive formula bit for bit.
        assert_eq!(vec.l2_norm(), (v1_1.powi(2) + v1_2.powi(2)).sqrt());
        assert_eq!(l2_norm_faiss(&vec), l2_norm_trivial(&vec));

        // Expected result: every component divided by the norm.
        let normalized_vec = VectorVal::from_iter(
            [v1_1 / vec.l2_norm(), v1_2 / vec.l2_norm()].map(|v| v.try_into().unwrap()),
        );
        assert_eq!(vec.normalized(), normalized_vec);
    }
}