risingwave_expr_impl/scalar/
vector.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::Finite32;
16use risingwave_common::types::{DataType, F32, F64, ListRef, ScalarRefImpl, VectorRef, VectorVal};
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_common::vector::MeasureDistanceBuilder;
19use risingwave_common::vector::distance::{L1Distance, L2SqrDistance, inner_product_faiss};
20use risingwave_expr::expr::Context;
21use risingwave_expr::{ExprError, Result, function};
22
23fn check_dims(name: &'static str, lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<()> {
24    if lhs.dimension() != rhs.dimension() {
25        return Err(ExprError::InvalidParam {
26            name,
27            reason: format!(
28                "different vector dimensions {} and {}",
29                lhs.dimension(),
30                rhs.dimension()
31            )
32            .into(),
33        });
34    }
35    Ok(())
36}
37
38/// ```slt
39/// query R
40/// SELECT l2_distance('[0,0]'::vector(2), '[3,4]');
41/// ----
42/// 5
43///
44/// query R
45/// SELECT l2_distance('[0,0]'::vector(2), '[0,1]');
46/// ----
47/// 1
48///
49/// query error dimensions
50/// SELECT l2_distance('[1,2]'::vector(2), '[3]');
51///
52/// query R
53/// SELECT l2_distance('[3e38]'::vector(1), '[-3e38]');
54/// ----
55/// Infinity
56///
57/// query R
58/// SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::vector(9), '[1,1,1,1,1,1,1,4,5]');
59/// ----
60/// 5
61///
62/// query R
63/// SELECT '[0,0]'::vector(2) <-> '[3,4]';
64/// ----
65/// 5
66/// ```
67#[function("l2_distance(vector, vector) -> float8"/*, type_infer = "unreachable"*/)]
68fn l2_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
69    check_dims("l2_distance", lhs, rhs)?;
70    Ok(L2SqrDistance::distance(lhs, rhs).sqrt().into())
71}
72
73/// ```slt
74/// query R
75/// SELECT abs(cosine_distance('[1,2]'::vector(2), '[2,4]')) < 1e-5;
76/// ----
77/// t
78///
79/// query R
80/// SELECT cosine_distance('[1,2]'::vector(2), '[0,0]');
81/// ----
82/// NaN
83///
84/// query R
85/// SELECT abs(cosine_distance('[1,1]'::vector(2), '[1,1]')) < 1e-5;
86/// ----
87/// t
88///
89/// query R
90/// SELECT abs(cosine_distance('[1,0]'::vector(2), '[0,2]') - 1.0) < 1e-5;
91/// ----
92/// t
93///
94/// query R
95/// SELECT abs(cosine_distance('[1,1]'::vector(2), '[-1,-1]') - 2) < 1e-5;
96/// ----
97/// t
98///
99/// query error dimensions
100/// SELECT cosine_distance('[1,2]'::vector(2), '[3]');
101///
102/// query R
103/// SELECT cosine_distance('[1,1]'::vector(2), '[1.1,1.1]');
104/// ----
105/// 0
106///
107/// query R
108/// SELECT cosine_distance('[1,1]'::vector(2), '[-1.1,-1.1]');
109/// ----
110/// 2
111///
112/// query R
113/// SELECT cosine_distance('[3e38]'::vector(1), '[3e38]');
114/// ----
115/// NaN
116///
117/// query R
118/// SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
119/// ----
120/// 0
121///
122/// query R
123/// SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[-1,-2,-3,-4,-5,-6,-7,-8,-9]');
124/// ----
125/// 2
126///
127/// query R
128/// SELECT '[1,2]'::vector(2) <=> '[2,4]';
129/// ----
130/// 0
131/// ```
132#[function("cosine_distance(vector, vector) -> float8")]
133fn cosine_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
134    check_dims("cosine_distance", lhs, rhs)?;
135    Ok(risingwave_common::vector::distance::cosine_distance(lhs, rhs).into())
136}
137
138/// ```slt
139/// query R
140/// SELECT l1_distance('[0,0]'::vector(2), '[3,4]');
141/// ----
142/// 7
143///
144/// query R
145/// SELECT l1_distance('[0,0]'::vector(2), '[0,1]');
146/// ----
147/// 1
148///
149/// query error dimensions
150/// SELECT l1_distance('[1,2]'::vector(2), '[3]');
151///
152/// query R
153/// SELECT l1_distance('[3e38]'::vector(1), '[-3e38]');
154/// ----
155/// Infinity
156///
157/// query R
158/// SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
159/// ----
160/// 0
161///
162/// query R
163/// SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[0,3,2,5,4,7,6,9,8]');
164/// ----
165/// 9
166///
167/// query R
168/// SELECT '[0,0]'::vector(2) <+> '[3,4]';
169/// ----
170/// 7
171/// ```
172#[function("l1_distance(vector, vector) -> float8")]
173fn l1_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
174    check_dims("l1_distance", lhs, rhs)?;
175    Ok(L1Distance::distance(lhs, rhs).into())
176}
177
178/// ```slt
179/// query R
180/// SELECT inner_product('[1,2]'::vector(2), '[3,4]');
181/// ----
182/// 11
183///
184/// query error dimensions
185/// SELECT inner_product('[1,2]'::vector(2), '[3]');
186///
187/// query R
188/// SELECT inner_product('[3e38]'::vector(1), '[3e38]');
189/// ----
190/// Infinity
191///
192/// query R
193/// SELECT inner_product('[1,1,1,1,1,1,1,1,1]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
194/// ----
195/// 45
196///
197/// query R
198/// SELECT '[1,2]'::vector(2) <#> '[3,4]';
199/// ----
200/// -11
201/// ```
202#[function("inner_product(vector, vector) -> float8")]
203fn inner_product(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
204    check_dims("inner_product", lhs, rhs)?;
205    Ok(inner_product_faiss(lhs, rhs).into())
206}
207
208/// ```slt
209/// query R
210/// SELECT '[1,2,3]'::vector(3) + '[4,5,6]';
211/// ----
212/// [5,7,9]
213///
214/// query error out of range: overflow
215/// SELECT '[3e38]'::vector(1) + '[3e38]';
216///
217/// query error dimensions
218/// SELECT '[1,2]'::vector(2) + '[3]';
219/// ```
220#[function("add(vector, vector) -> vector", type_infer = "unreachable")]
221fn vector_add(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
222    check_dims("vector_add", lhs, rhs)?;
223    let lhs = lhs.as_raw_slice();
224    let rhs = rhs.as_raw_slice();
225
226    let result = lhs
227        .iter()
228        .zip_eq_fast(rhs.iter())
229        .map(|(l, r)| Finite32::try_from(l + r))
230        .try_collect()
231        .map_err(|_| ExprError::NumericOverflow)?;
232    Ok(result)
233}
234
235/// ```slt
236/// query R
237/// SELECT '[1,2,3]'::vector(3) - '[4,5,6]';
238/// ----
239/// [-3,-3,-3]
240///
241/// query error out of range: overflow
242/// SELECT '[-3e38]'::vector(1) - '[3e38]';
243///
244/// query error dimensions
245/// SELECT '[1,2]'::vector(2) - '[3]';
246/// ```
247#[function("subtract(vector, vector) -> vector", type_infer = "unreachable")]
248fn vector_subtract(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
249    check_dims("vector_subtract", lhs, rhs)?;
250    let lhs = lhs.as_raw_slice();
251    let rhs = rhs.as_raw_slice();
252
253    let result = lhs
254        .iter()
255        .zip_eq_fast(rhs.iter())
256        .map(|(l, r)| Finite32::try_from(l - r))
257        .try_collect()
258        .map_err(|_| ExprError::NumericOverflow)?;
259    Ok(result)
260}
261
262/// ```slt
263/// query R
264/// SELECT '[1,2,3]'::vector(3) * '[4,5,6]';
265/// ----
266/// [4,10,18]
267///
268/// query error out of range: overflow
269/// SELECT '[1e37]'::vector(1) * '[1e37]';
270///
271/// query error out of range: underflow
272/// SELECT '[1e-37]'::vector(1) * '[1e-37]';
273///
274/// query error dimensions
275/// SELECT '[1,2]'::vector(2) * '[3]';
276/// ```
277#[function("multiply(vector, vector) -> vector", type_infer = "unreachable")]
278fn vector_multiply(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
279    check_dims("vector_multiply", lhs, rhs)?;
280    let lhs = lhs.as_raw_slice();
281    let rhs = rhs.as_raw_slice();
282
283    let result = lhs
284        .iter()
285        .zip_eq_fast(rhs.iter())
286        .map(|(l, r)| {
287            let v = l * r;
288            match v == 0. && !(*l == 0. || *r == 0.) {
289                true => Err(ExprError::NumericUnderflow),
290                false => Finite32::try_from(v).map_err(|_| ExprError::NumericOverflow),
291            }
292        })
293        .try_collect()?;
294    Ok(result)
295}
296
297/// ```slt
298/// query R
299/// SELECT '[1,2,3]'::vector(3) || '[4,5]';
300/// ----
301/// [1,2,3,4,5]
302///
303/// query error cast
304/// SELECT '[1,2,3]'::vector(3) || null;
305///
306/// query R
307/// SELECT '[1,2,3]'::vector(3) || null::vector(4);
308/// ----
309/// NULL
310///
311/// query error vector cannot have more than 16000 dimensions
312/// SELECT null::vector(16000) || '[1]';
313/// ```
314#[function("vec_concat(vector, vector) -> vector", type_infer = "unreachable")]
315fn vector_concat(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
316    let lhs = lhs.as_raw_slice();
317    let rhs = rhs.as_raw_slice();
318
319    let result = lhs
320        .iter()
321        .chain(rhs)
322        .copied()
323        .map(Finite32::try_from)
324        .try_collect()
325        .map_err(|_| ExprError::NumericOverflow)?;
326    Ok(result)
327}
328
329/// ```slt
330/// query T
331/// SELECT '[1,2,3]'::vector(3)::real[];
332/// ----
333/// {1,2,3}
334/// ```
335#[function("cast(vector) -> float4[]")]
336fn vector_to_float4(v: VectorRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
337    writer.write_iter(
338        v.as_raw_slice()
339            .iter()
340            .map(|&f| Some(ScalarRefImpl::Float32(F32::from(f)))),
341    );
342}
343
344/// ```slt
345/// query T
346/// SELECT ARRAY[1,2,3]::vector(3);
347/// ----
348/// [1,2,3]
349///
350/// query T
351/// SELECT ARRAY[1.0,2.0,3.0]::vector(3);
352/// ----
353/// [1,2,3]
354///
355/// query T
356/// SELECT ARRAY[1,2,3]::float4[]::vector(3);
357/// ----
358/// [1,2,3]
359///
360/// query T
361/// SELECT ARRAY[1,2,3]::float8[]::vector(3);
362/// ----
363/// [1,2,3]
364///
365/// query T
366/// SELECT ARRAY[1,2,3]::numeric[]::vector(3);
367/// ----
368/// [1,2,3]
369///
370/// query T
371/// SELECT '{1,2,3}'::real[]::vector(3);
372/// ----
373/// [1,2,3]
374///
375/// query error expected 2 dimensions, not 3
376/// SELECT '{1,2,3}'::real[]::vector(2);
377///
378/// query error array must not contain nulls
379/// SELECT '{NULL}'::real[]::vector(1);
380///
381/// query error NaN not allowed in vector
382/// SELECT '{NaN}'::real[]::vector(1);
383///
384/// query error inf not allowed in vector
385/// SELECT '{Infinity}'::real[]::vector(1);
386///
387/// query error -inf not allowed in vector
388/// SELECT '{-Infinity}'::real[]::vector(1);
389///
390/// query error dimension
391/// SELECT '{}'::real[]::vector(1);
392///
393/// query error cannot cast
394/// SELECT '{{1}}'::real[][]::vector(1);
395///
396/// query T
397/// SELECT '{1,2,3}'::double precision[]::vector(3);
398/// ----
399/// [1,2,3]
400///
401/// query error expected 2 dimensions, not 3
402/// SELECT '{1,2,3}'::double precision[]::vector(2);
403///
404/// query error out of range
405/// SELECT '{4e38,-4e38}'::double precision[]::vector(2);
406///
407/// # Caveat: pgvector does not check underflow and returns 0 here.
408/// query error out of range
409/// SELECT '{1e-46,-1e-46}'::double precision[]::vector(2);
410/// ```
411#[function("cast(int4[]) -> vector", type_infer = "unreachable")]
412#[function("cast(decimal[]) -> vector", type_infer = "unreachable")]
413#[function("cast(float4[]) -> vector", type_infer = "unreachable")]
414#[function("cast(float8[]) -> vector", type_infer = "unreachable")]
415fn array_to_vector(array: ListRef<'_>, ctx: &Context) -> Result<VectorVal> {
416    macro_rules! bail_invalid_param {
417        ($($arg:tt)*) => {
418            return Err(ExprError::InvalidParam {
419                name: "array_to_vector",
420                reason: format!($($arg)*).into(),
421            });
422        };
423    }
424
425    let DataType::Vector(size) = ctx.return_type else {
426        unreachable!()
427    };
428    if array.len() != size {
429        bail_invalid_param!("expected {} dimensions, not {}", size, array.len());
430    }
431    let result = array
432        .iter()
433        .map(|scalar| {
434            let Some(scalar) = scalar else {
435                bail_invalid_param!("array must not contain nulls");
436            };
437            let val = match scalar {
438                ScalarRefImpl::Int32(val) => val.into(),
439                ScalarRefImpl::Decimal(val) => {
440                    val.try_into().map_err(|_| ExprError::NumericOverflow)?
441                }
442                ScalarRefImpl::Float32(val) => val,
443                ScalarRefImpl::Float64(val) => {
444                    val.try_into().map_err(|_| ExprError::NumericOverflow)?
445                }
446                _ => unreachable!(),
447            };
448            Finite32::try_from(val.0).map_err(|err| ExprError::InvalidParam {
449                name: "array_to_vector",
450                reason: err.into(),
451            })
452        })
453        .try_collect()?;
454    Ok(result)
455}
456
457/// ```slt
458/// query R
459/// SELECT round(vector_norm('[1,1]'::vector(2))::numeric, 5);
460/// ----
461/// 1.41421
462///
463/// query R
464/// SELECT vector_norm('[3,4]'::vector(2));
465/// ----
466/// 5
467///
468/// query R
469/// SELECT vector_norm('[0,1]'::vector(2));
470/// ----
471/// 1
472///
473/// query R
474/// SELECT vector_norm('[3e18,4e18]'::vector(2))::real;
475/// ----
476/// 5e+18
477///
478/// query R
479/// SELECT vector_norm('[0,0]'::vector(2));
480/// ----
481/// 0
482///
483/// query R
484/// SELECT vector_norm('[2]'::vector(1));
485/// ----
486/// 2
487/// ```
488#[function("l2_norm(vector) -> float8")]
489fn l2_norm(vector: VectorRef<'_>) -> F64 {
490    (vector.l2_norm() as f64).into()
491}
492
493/// ```slt
494/// query R
495/// SELECT l2_normalize('[3,4]'::vector(2));
496/// ----
497/// [0.6,0.8]
498///
499/// query R
500/// SELECT l2_normalize('[3,0]'::vector(2));
501/// ----
502/// [1,0]
503///
504/// query R
505/// SELECT l2_normalize('[0,0.1]'::vector(2));
506/// ----
507/// [0,1]
508///
509/// query R
510/// SELECT l2_normalize('[0,0]'::vector(2));
511/// ----
512/// [0,0]
513///
514/// query R
515/// SELECT l2_normalize('[3e18]'::vector(1));
516/// ----
517/// [1]
518/// ```
#[function(
    "l2_normalize(vector) -> vector",
    type_infer = "|args| Ok(args[0].clone())"
)]
fn l2_normalize(vector: VectorRef<'_>) -> VectorVal {
    // Scales the input to unit Euclidean length (e.g. [3,4] -> [0.6,0.8]); per the examples
    // above, a zero vector comes back unchanged as [0,0].
    // The registered `type_infer` returns the argument's own type, so the result keeps the
    // input's `vector(n)` dimension.
    vector.normalized()
}
526
527#[derive(Debug)]
528pub struct SubvectorContext {
529    pub start: usize,
530    pub end: usize,
531}
532
533impl SubvectorContext {
534    pub fn from_start_count(start: i32, count: i32) -> Result<Self> {
535        Ok(Self {
536            start: (start - 1) as usize,
537            end: (start + count - 1) as usize,
538        })
539    }
540}
541
542/// ```slt
543/// query R
544/// SELECT subvector('[1,2,3,4,5]'::vector(5), 1, 3);
545/// ----
546/// [1,2,3]
547///
548/// query R
549/// SELECT subvector('[1,2,3,4,5]'::vector(5), 3, 2);
550/// ----
551/// [3,4]
552///
553/// query R
554/// SELECT subvector('[1,2,3,4,5]'::vector(5), 1, 5);
555/// ----
556/// [1,2,3,4,5]
557///
558/// query R
559/// SELECT subvector('[1,2,3,4,5]'::vector(5), 5, 1);
560/// ----
561/// [5]
562///
563/// query R
564/// SELECT subvector('[1,2,3,4,5]'::vector(5), 2, 3);
565/// ----
566/// [2,3,4]
567///
568/// query R
569/// select subvector(vec, 1, 3) from (values ('[1,2,3,4,5]'::vector(5)), ('[6,7,8,9,10]'::vector(5))) as t(vec);
570/// ----
571/// [1,2,3]
572/// [6,7,8]
573///
574/// statement error
575/// SELECT subvector('[1,2,3,4,5]'::vector(5), -1, 2);
576///
577/// statement error
578/// SELECT subvector('[6,7,8,9,10]'::vector(5), 1, 6);
579///
580/// statement error
581/// SELECT subvector('[6,7,8,9,10]'::vector(5), 5, 2);
582/// ```
583#[function(
584    "subvector(vector, int4, int4) -> vector",
585    prebuild = "SubvectorContext::from_start_count($1, $2)?",
586    type_infer = "unreachable"
587)]
588fn subvector(v: VectorRef<'_>, ctx: &SubvectorContext) -> Result<VectorVal> {
589    Ok(v.subvector(ctx.start, ctx.end))
590}