risingwave_expr_impl/scalar/
vector.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::Finite32;
16use risingwave_common::types::{
17    DataType, F32, F64, ListRef, ListValue, ScalarRefImpl, VectorRef, VectorVal,
18};
19use risingwave_common::util::iter_util::ZipEqFast;
20use risingwave_common::vector::MeasureDistanceBuilder;
21use risingwave_common::vector::distance::{L1Distance, L2SqrDistance, inner_product_faiss};
22use risingwave_expr::expr::Context;
23use risingwave_expr::{ExprError, Result, function};
24
25fn check_dims(name: &'static str, lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<()> {
26    if lhs.dimension() != rhs.dimension() {
27        return Err(ExprError::InvalidParam {
28            name,
29            reason: format!(
30                "different vector dimensions {} and {}",
31                lhs.dimension(),
32                rhs.dimension()
33            )
34            .into(),
35        });
36    }
37    Ok(())
38}
39
40/// ```slt
41/// query R
42/// SELECT l2_distance('[0,0]'::vector(2), '[3,4]');
43/// ----
44/// 5
45///
46/// query R
47/// SELECT l2_distance('[0,0]'::vector(2), '[0,1]');
48/// ----
49/// 1
50///
51/// query error dimensions
52/// SELECT l2_distance('[1,2]'::vector(2), '[3]');
53///
54/// query R
55/// SELECT l2_distance('[3e38]'::vector(1), '[-3e38]');
56/// ----
57/// Infinity
58///
59/// query R
60/// SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::vector(9), '[1,1,1,1,1,1,1,4,5]');
61/// ----
62/// 5
63///
64/// query R
65/// SELECT '[0,0]'::vector(2) <-> '[3,4]';
66/// ----
67/// 5
68/// ```
69#[function("l2_distance(vector, vector) -> float8"/*, type_infer = "unreachable"*/)]
70fn l2_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
71    check_dims("l2_distance", lhs, rhs)?;
72    Ok(L2SqrDistance::distance(lhs, rhs).sqrt().into())
73}
74
75/// ```slt
76/// query R
77/// SELECT abs(cosine_distance('[1,2]'::vector(2), '[2,4]')) < 1e-5;
78/// ----
79/// t
80///
81/// query R
82/// SELECT cosine_distance('[1,2]'::vector(2), '[0,0]');
83/// ----
84/// NaN
85///
86/// query R
87/// SELECT abs(cosine_distance('[1,1]'::vector(2), '[1,1]')) < 1e-5;
88/// ----
89/// t
90///
91/// query R
92/// SELECT abs(cosine_distance('[1,0]'::vector(2), '[0,2]') - 1.0) < 1e-5;
93/// ----
94/// t
95///
96/// query R
97/// SELECT abs(cosine_distance('[1,1]'::vector(2), '[-1,-1]') - 2) < 1e-5;
98/// ----
99/// t
100///
101/// query error dimensions
102/// SELECT cosine_distance('[1,2]'::vector(2), '[3]');
103///
104/// query R
105/// SELECT cosine_distance('[1,1]'::vector(2), '[1.1,1.1]');
106/// ----
107/// 0
108///
109/// query R
110/// SELECT cosine_distance('[1,1]'::vector(2), '[-1.1,-1.1]');
111/// ----
112/// 2
113///
114/// query R
115/// SELECT cosine_distance('[3e38]'::vector(1), '[3e38]');
116/// ----
117/// NaN
118///
119/// query R
120/// SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
121/// ----
122/// 0
123///
124/// query R
125/// SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[-1,-2,-3,-4,-5,-6,-7,-8,-9]');
126/// ----
127/// 2
128///
129/// query R
130/// SELECT '[1,2]'::vector(2) <=> '[2,4]';
131/// ----
132/// 0
133/// ```
134#[function("cosine_distance(vector, vector) -> float8")]
135fn cosine_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
136    check_dims("cosine_distance", lhs, rhs)?;
137    Ok(risingwave_common::vector::distance::cosine_distance(lhs, rhs).into())
138}
139
140/// ```slt
141/// query R
142/// SELECT l1_distance('[0,0]'::vector(2), '[3,4]');
143/// ----
144/// 7
145///
146/// query R
147/// SELECT l1_distance('[0,0]'::vector(2), '[0,1]');
148/// ----
149/// 1
150///
151/// query error dimensions
152/// SELECT l1_distance('[1,2]'::vector(2), '[3]');
153///
154/// query R
155/// SELECT l1_distance('[3e38]'::vector(1), '[-3e38]');
156/// ----
157/// Infinity
158///
159/// query R
160/// SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
161/// ----
162/// 0
163///
164/// query R
165/// SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[0,3,2,5,4,7,6,9,8]');
166/// ----
167/// 9
168///
169/// query R
170/// SELECT '[0,0]'::vector(2) <+> '[3,4]';
171/// ----
172/// 7
173/// ```
174#[function("l1_distance(vector, vector) -> float8")]
175fn l1_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
176    check_dims("l1_distance", lhs, rhs)?;
177    Ok(L1Distance::distance(lhs, rhs).into())
178}
179
180/// ```slt
181/// query R
182/// SELECT inner_product('[1,2]'::vector(2), '[3,4]');
183/// ----
184/// 11
185///
186/// query error dimensions
187/// SELECT inner_product('[1,2]'::vector(2), '[3]');
188///
189/// query R
190/// SELECT inner_product('[3e38]'::vector(1), '[3e38]');
191/// ----
192/// Infinity
193///
194/// query R
195/// SELECT inner_product('[1,1,1,1,1,1,1,1,1]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
196/// ----
197/// 45
198///
199/// query R
200/// SELECT '[1,2]'::vector(2) <#> '[3,4]';
201/// ----
202/// -11
203/// ```
204#[function("inner_product(vector, vector) -> float8")]
205fn inner_product(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
206    check_dims("inner_product", lhs, rhs)?;
207    Ok(inner_product_faiss(lhs, rhs).into())
208}
209
210/// ```slt
211/// query R
212/// SELECT '[1,2,3]'::vector(3) + '[4,5,6]';
213/// ----
214/// [5,7,9]
215///
216/// query error out of range: overflow
217/// SELECT '[3e38]'::vector(1) + '[3e38]';
218///
219/// query error dimensions
220/// SELECT '[1,2]'::vector(2) + '[3]';
221/// ```
222#[function("add(vector, vector) -> vector", type_infer = "unreachable")]
223fn vector_add(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
224    check_dims("vector_add", lhs, rhs)?;
225    let lhs = lhs.as_raw_slice();
226    let rhs = rhs.as_raw_slice();
227
228    let result = lhs
229        .iter()
230        .zip_eq_fast(rhs.iter())
231        .map(|(l, r)| Finite32::try_from(l + r))
232        .try_collect()
233        .map_err(|_| ExprError::NumericOverflow)?;
234    Ok(result)
235}
236
237/// ```slt
238/// query R
239/// SELECT '[1,2,3]'::vector(3) - '[4,5,6]';
240/// ----
241/// [-3,-3,-3]
242///
243/// query error out of range: overflow
244/// SELECT '[-3e38]'::vector(1) - '[3e38]';
245///
246/// query error dimensions
247/// SELECT '[1,2]'::vector(2) - '[3]';
248/// ```
249#[function("subtract(vector, vector) -> vector", type_infer = "unreachable")]
250fn vector_subtract(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
251    check_dims("vector_subtract", lhs, rhs)?;
252    let lhs = lhs.as_raw_slice();
253    let rhs = rhs.as_raw_slice();
254
255    let result = lhs
256        .iter()
257        .zip_eq_fast(rhs.iter())
258        .map(|(l, r)| Finite32::try_from(l - r))
259        .try_collect()
260        .map_err(|_| ExprError::NumericOverflow)?;
261    Ok(result)
262}
263
264/// ```slt
265/// query R
266/// SELECT '[1,2,3]'::vector(3) * '[4,5,6]';
267/// ----
268/// [4,10,18]
269///
270/// query error out of range: overflow
271/// SELECT '[1e37]'::vector(1) * '[1e37]';
272///
273/// query error out of range: underflow
274/// SELECT '[1e-37]'::vector(1) * '[1e-37]';
275///
276/// query error dimensions
277/// SELECT '[1,2]'::vector(2) * '[3]';
278/// ```
279#[function("multiply(vector, vector) -> vector", type_infer = "unreachable")]
280fn vector_multiply(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
281    check_dims("vector_multiply", lhs, rhs)?;
282    let lhs = lhs.as_raw_slice();
283    let rhs = rhs.as_raw_slice();
284
285    let result = lhs
286        .iter()
287        .zip_eq_fast(rhs.iter())
288        .map(|(l, r)| {
289            let v = l * r;
290            match v == 0. && !(*l == 0. || *r == 0.) {
291                true => Err(ExprError::NumericUnderflow),
292                false => Finite32::try_from(v).map_err(|_| ExprError::NumericOverflow),
293            }
294        })
295        .try_collect()?;
296    Ok(result)
297}
298
299/// ```slt
300/// query R
301/// SELECT '[1,2,3]'::vector(3) || '[4,5]';
302/// ----
303/// [1,2,3,4,5]
304///
305/// query error cast
306/// SELECT '[1,2,3]'::vector(3) || null;
307///
308/// query R
309/// SELECT '[1,2,3]'::vector(3) || null::vector(4);
310/// ----
311/// NULL
312///
313/// query error vector cannot have more than 16000 dimensions
314/// SELECT null::vector(16000) || '[1]';
315/// ```
316#[function("vec_concat(vector, vector) -> vector", type_infer = "unreachable")]
317fn vector_concat(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
318    let lhs = lhs.as_raw_slice();
319    let rhs = rhs.as_raw_slice();
320
321    let result = lhs
322        .iter()
323        .chain(rhs)
324        .copied()
325        .map(Finite32::try_from)
326        .try_collect()
327        .map_err(|_| ExprError::NumericOverflow)?;
328    Ok(result)
329}
330
331/// ```slt
332/// query T
333/// SELECT '[1,2,3]'::vector(3)::real[];
334/// ----
335/// {1,2,3}
336/// ```
337#[function("cast(vector) -> float4[]")]
338fn vector_to_float4(v: VectorRef<'_>) -> ListValue {
339    v.as_raw_slice().iter().copied().map(F32::from).collect()
340}
341
342/// ```slt
343/// query T
344/// SELECT ARRAY[1,2,3]::vector(3);
345/// ----
346/// [1,2,3]
347///
348/// query T
349/// SELECT ARRAY[1.0,2.0,3.0]::vector(3);
350/// ----
351/// [1,2,3]
352///
353/// query T
354/// SELECT ARRAY[1,2,3]::float4[]::vector(3);
355/// ----
356/// [1,2,3]
357///
358/// query T
359/// SELECT ARRAY[1,2,3]::float8[]::vector(3);
360/// ----
361/// [1,2,3]
362///
363/// query T
364/// SELECT ARRAY[1,2,3]::numeric[]::vector(3);
365/// ----
366/// [1,2,3]
367///
368/// query T
369/// SELECT '{1,2,3}'::real[]::vector(3);
370/// ----
371/// [1,2,3]
372///
373/// query error expected 2 dimensions, not 3
374/// SELECT '{1,2,3}'::real[]::vector(2);
375///
376/// query error array must not contain nulls
377/// SELECT '{NULL}'::real[]::vector(1);
378///
379/// query error NaN not allowed in vector
380/// SELECT '{NaN}'::real[]::vector(1);
381///
382/// query error inf not allowed in vector
383/// SELECT '{Infinity}'::real[]::vector(1);
384///
385/// query error -inf not allowed in vector
386/// SELECT '{-Infinity}'::real[]::vector(1);
387///
388/// query error dimension
389/// SELECT '{}'::real[]::vector(1);
390///
391/// query error cannot cast
392/// SELECT '{{1}}'::real[][]::vector(1);
393///
394/// query T
395/// SELECT '{1,2,3}'::double precision[]::vector(3);
396/// ----
397/// [1,2,3]
398///
399/// query error expected 2 dimensions, not 3
400/// SELECT '{1,2,3}'::double precision[]::vector(2);
401///
402/// query error out of range
403/// SELECT '{4e38,-4e38}'::double precision[]::vector(2);
404///
405/// # Caveat: pgvector does not check underflow and returns 0 here.
406/// query error out of range
407/// SELECT '{1e-46,-1e-46}'::double precision[]::vector(2);
408/// ```
409#[function("cast(int4[]) -> vector", type_infer = "unreachable")]
410#[function("cast(decimal[]) -> vector", type_infer = "unreachable")]
411#[function("cast(float4[]) -> vector", type_infer = "unreachable")]
412#[function("cast(float8[]) -> vector", type_infer = "unreachable")]
413fn array_to_vector(array: ListRef<'_>, ctx: &Context) -> Result<VectorVal> {
414    macro_rules! bail_invalid_param {
415        ($($arg:tt)*) => {
416            return Err(ExprError::InvalidParam {
417                name: "array_to_vector",
418                reason: format!($($arg)*).into(),
419            });
420        };
421    }
422
423    let DataType::Vector(size) = ctx.return_type else {
424        unreachable!()
425    };
426    if array.len() != size {
427        bail_invalid_param!("expected {} dimensions, not {}", size, array.len());
428    }
429    let result = array
430        .iter()
431        .map(|scalar| {
432            let Some(scalar) = scalar else {
433                bail_invalid_param!("array must not contain nulls");
434            };
435            let val = match scalar {
436                ScalarRefImpl::Int32(val) => val.into(),
437                ScalarRefImpl::Decimal(val) => {
438                    val.try_into().map_err(|_| ExprError::NumericOverflow)?
439                }
440                ScalarRefImpl::Float32(val) => val,
441                ScalarRefImpl::Float64(val) => {
442                    val.try_into().map_err(|_| ExprError::NumericOverflow)?
443                }
444                _ => unreachable!(),
445            };
446            Finite32::try_from(val.0).map_err(|err| ExprError::InvalidParam {
447                name: "array_to_vector",
448                reason: err.into(),
449            })
450        })
451        .try_collect()?;
452    Ok(result)
453}
454
455/// ```slt
456/// query R
457/// SELECT round(vector_norm('[1,1]'::vector(2))::numeric, 5);
458/// ----
459/// 1.41421
460///
461/// query R
462/// SELECT vector_norm('[3,4]'::vector(2));
463/// ----
464/// 5
465///
466/// query R
467/// SELECT vector_norm('[0,1]'::vector(2));
468/// ----
469/// 1
470///
471/// query R
472/// SELECT vector_norm('[3e18,4e18]'::vector(2))::real;
473/// ----
474/// 5e+18
475///
476/// query R
477/// SELECT vector_norm('[0,0]'::vector(2));
478/// ----
479/// 0
480///
481/// query R
482/// SELECT vector_norm('[2]'::vector(1));
483/// ----
484/// 2
485/// ```
486#[function("l2_norm(vector) -> float8")]
487fn l2_norm(vector: VectorRef<'_>) -> F64 {
488    (vector.l2_norm() as f64).into()
489}
490
491/// ```slt
492/// query R
493/// SELECT l2_normalize('[3,4]'::vector(2));
494/// ----
495/// [0.6,0.8]
496///
497/// query R
498/// SELECT l2_normalize('[3,0]'::vector(2));
499/// ----
500/// [1,0]
501///
502/// query R
503/// SELECT l2_normalize('[0,0.1]'::vector(2));
504/// ----
505/// [0,1]
506///
507/// query R
508/// SELECT l2_normalize('[0,0]'::vector(2));
509/// ----
510/// [0,0]
511///
512/// query R
513/// SELECT l2_normalize('[3e18]'::vector(1));
514/// ----
515/// [1]
516/// ```
517#[function(
518    "l2_normalize(vector) -> vector",
519    type_infer = "|args| Ok(args[0].clone())"
520)]
521fn l2_normalize(vector: VectorRef<'_>) -> VectorVal {
522    vector.normalized()
523}