risingwave_expr_impl/scalar/vector.rs
1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use risingwave_common::array::Finite32;
16use risingwave_common::types::{DataType, F32, F64, ListRef, ScalarRefImpl, VectorRef, VectorVal};
17use risingwave_common::util::iter_util::ZipEqFast;
18use risingwave_common::vector::MeasureDistanceBuilder;
19use risingwave_common::vector::distance::{L1Distance, L2SqrDistance, inner_product_faiss};
20use risingwave_expr::expr::Context;
21use risingwave_expr::{ExprError, Result, function};
22
23fn check_dims(name: &'static str, lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<()> {
24 if lhs.dimension() != rhs.dimension() {
25 return Err(ExprError::InvalidParam {
26 name,
27 reason: format!(
28 "different vector dimensions {} and {}",
29 lhs.dimension(),
30 rhs.dimension()
31 )
32 .into(),
33 });
34 }
35 Ok(())
36}
37
38/// ```slt
39/// query R
40/// SELECT l2_distance('[0,0]'::vector(2), '[3,4]');
41/// ----
42/// 5
43///
44/// query R
45/// SELECT l2_distance('[0,0]'::vector(2), '[0,1]');
46/// ----
47/// 1
48///
49/// query error dimensions
50/// SELECT l2_distance('[1,2]'::vector(2), '[3]');
51///
52/// query R
53/// SELECT l2_distance('[3e38]'::vector(1), '[-3e38]');
54/// ----
55/// Infinity
56///
57/// query R
58/// SELECT l2_distance('[1,1,1,1,1,1,1,1,1]'::vector(9), '[1,1,1,1,1,1,1,4,5]');
59/// ----
60/// 5
61///
62/// query R
63/// SELECT '[0,0]'::vector(2) <-> '[3,4]';
64/// ----
65/// 5
66/// ```
67#[function("l2_distance(vector, vector) -> float8"/*, type_infer = "unreachable"*/)]
68fn l2_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
69 check_dims("l2_distance", lhs, rhs)?;
70 Ok(L2SqrDistance::distance(lhs, rhs).sqrt().into())
71}
72
73/// ```slt
74/// query R
75/// SELECT abs(cosine_distance('[1,2]'::vector(2), '[2,4]')) < 1e-5;
76/// ----
77/// t
78///
79/// query R
80/// SELECT cosine_distance('[1,2]'::vector(2), '[0,0]');
81/// ----
82/// NaN
83///
84/// query R
85/// SELECT abs(cosine_distance('[1,1]'::vector(2), '[1,1]')) < 1e-5;
86/// ----
87/// t
88///
89/// query R
90/// SELECT abs(cosine_distance('[1,0]'::vector(2), '[0,2]') - 1.0) < 1e-5;
91/// ----
92/// t
93///
94/// query R
95/// SELECT abs(cosine_distance('[1,1]'::vector(2), '[-1,-1]') - 2) < 1e-5;
96/// ----
97/// t
98///
99/// query error dimensions
100/// SELECT cosine_distance('[1,2]'::vector(2), '[3]');
101///
102/// query R
103/// SELECT cosine_distance('[1,1]'::vector(2), '[1.1,1.1]');
104/// ----
105/// 0
106///
107/// query R
108/// SELECT cosine_distance('[1,1]'::vector(2), '[-1.1,-1.1]');
109/// ----
110/// 2
111///
112/// query R
113/// SELECT cosine_distance('[3e38]'::vector(1), '[3e38]');
114/// ----
115/// NaN
116///
117/// query R
118/// SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
119/// ----
120/// 0
121///
122/// query R
123/// SELECT cosine_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[-1,-2,-3,-4,-5,-6,-7,-8,-9]');
124/// ----
125/// 2
126///
127/// query R
128/// SELECT '[1,2]'::vector(2) <=> '[2,4]';
129/// ----
130/// 0
131/// ```
132#[function("cosine_distance(vector, vector) -> float8")]
133fn cosine_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
134 check_dims("cosine_distance", lhs, rhs)?;
135 Ok(risingwave_common::vector::distance::cosine_distance(lhs, rhs).into())
136}
137
138/// ```slt
139/// query R
140/// SELECT l1_distance('[0,0]'::vector(2), '[3,4]');
141/// ----
142/// 7
143///
144/// query R
145/// SELECT l1_distance('[0,0]'::vector(2), '[0,1]');
146/// ----
147/// 1
148///
149/// query error dimensions
150/// SELECT l1_distance('[1,2]'::vector(2), '[3]');
151///
152/// query R
153/// SELECT l1_distance('[3e38]'::vector(1), '[-3e38]');
154/// ----
155/// Infinity
156///
157/// query R
158/// SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
159/// ----
160/// 0
161///
162/// query R
163/// SELECT l1_distance('[1,2,3,4,5,6,7,8,9]'::vector(9), '[0,3,2,5,4,7,6,9,8]');
164/// ----
165/// 9
166///
167/// query R
168/// SELECT '[0,0]'::vector(2) <+> '[3,4]';
169/// ----
170/// 7
171/// ```
172#[function("l1_distance(vector, vector) -> float8")]
173fn l1_distance(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
174 check_dims("l1_distance", lhs, rhs)?;
175 Ok(L1Distance::distance(lhs, rhs).into())
176}
177
178/// ```slt
179/// query R
180/// SELECT inner_product('[1,2]'::vector(2), '[3,4]');
181/// ----
182/// 11
183///
184/// query error dimensions
185/// SELECT inner_product('[1,2]'::vector(2), '[3]');
186///
187/// query R
188/// SELECT inner_product('[3e38]'::vector(1), '[3e38]');
189/// ----
190/// Infinity
191///
192/// query R
193/// SELECT inner_product('[1,1,1,1,1,1,1,1,1]'::vector(9), '[1,2,3,4,5,6,7,8,9]');
194/// ----
195/// 45
196///
197/// query R
198/// SELECT '[1,2]'::vector(2) <#> '[3,4]';
199/// ----
200/// -11
201/// ```
202#[function("inner_product(vector, vector) -> float8")]
203fn inner_product(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<F64> {
204 check_dims("inner_product", lhs, rhs)?;
205 Ok(inner_product_faiss(lhs, rhs).into())
206}
207
208/// ```slt
209/// query R
210/// SELECT '[1,2,3]'::vector(3) + '[4,5,6]';
211/// ----
212/// [5,7,9]
213///
214/// query error out of range: overflow
215/// SELECT '[3e38]'::vector(1) + '[3e38]';
216///
217/// query error dimensions
218/// SELECT '[1,2]'::vector(2) + '[3]';
219/// ```
220#[function("add(vector, vector) -> vector", type_infer = "unreachable")]
221fn vector_add(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
222 check_dims("vector_add", lhs, rhs)?;
223 let lhs = lhs.as_raw_slice();
224 let rhs = rhs.as_raw_slice();
225
226 let result = lhs
227 .iter()
228 .zip_eq_fast(rhs.iter())
229 .map(|(l, r)| Finite32::try_from(l + r))
230 .try_collect()
231 .map_err(|_| ExprError::NumericOverflow)?;
232 Ok(result)
233}
234
235/// ```slt
236/// query R
237/// SELECT '[1,2,3]'::vector(3) - '[4,5,6]';
238/// ----
239/// [-3,-3,-3]
240///
241/// query error out of range: overflow
242/// SELECT '[-3e38]'::vector(1) - '[3e38]';
243///
244/// query error dimensions
245/// SELECT '[1,2]'::vector(2) - '[3]';
246/// ```
247#[function("subtract(vector, vector) -> vector", type_infer = "unreachable")]
248fn vector_subtract(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
249 check_dims("vector_subtract", lhs, rhs)?;
250 let lhs = lhs.as_raw_slice();
251 let rhs = rhs.as_raw_slice();
252
253 let result = lhs
254 .iter()
255 .zip_eq_fast(rhs.iter())
256 .map(|(l, r)| Finite32::try_from(l - r))
257 .try_collect()
258 .map_err(|_| ExprError::NumericOverflow)?;
259 Ok(result)
260}
261
262/// ```slt
263/// query R
264/// SELECT '[1,2,3]'::vector(3) * '[4,5,6]';
265/// ----
266/// [4,10,18]
267///
268/// query error out of range: overflow
269/// SELECT '[1e37]'::vector(1) * '[1e37]';
270///
271/// query error out of range: underflow
272/// SELECT '[1e-37]'::vector(1) * '[1e-37]';
273///
274/// query error dimensions
275/// SELECT '[1,2]'::vector(2) * '[3]';
276/// ```
277#[function("multiply(vector, vector) -> vector", type_infer = "unreachable")]
278fn vector_multiply(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
279 check_dims("vector_multiply", lhs, rhs)?;
280 let lhs = lhs.as_raw_slice();
281 let rhs = rhs.as_raw_slice();
282
283 let result = lhs
284 .iter()
285 .zip_eq_fast(rhs.iter())
286 .map(|(l, r)| {
287 let v = l * r;
288 match v == 0. && !(*l == 0. || *r == 0.) {
289 true => Err(ExprError::NumericUnderflow),
290 false => Finite32::try_from(v).map_err(|_| ExprError::NumericOverflow),
291 }
292 })
293 .try_collect()?;
294 Ok(result)
295}
296
297/// ```slt
298/// query R
299/// SELECT '[1,2,3]'::vector(3) || '[4,5]';
300/// ----
301/// [1,2,3,4,5]
302///
303/// query error cast
304/// SELECT '[1,2,3]'::vector(3) || null;
305///
306/// query R
307/// SELECT '[1,2,3]'::vector(3) || null::vector(4);
308/// ----
309/// NULL
310///
311/// query error vector cannot have more than 16000 dimensions
312/// SELECT null::vector(16000) || '[1]';
313/// ```
314#[function("vec_concat(vector, vector) -> vector", type_infer = "unreachable")]
315fn vector_concat(lhs: VectorRef<'_>, rhs: VectorRef<'_>) -> Result<VectorVal> {
316 let lhs = lhs.as_raw_slice();
317 let rhs = rhs.as_raw_slice();
318
319 let result = lhs
320 .iter()
321 .chain(rhs)
322 .copied()
323 .map(Finite32::try_from)
324 .try_collect()
325 .map_err(|_| ExprError::NumericOverflow)?;
326 Ok(result)
327}
328
329/// ```slt
330/// query T
331/// SELECT '[1,2,3]'::vector(3)::real[];
332/// ----
333/// {1,2,3}
334/// ```
335#[function("cast(vector) -> float4[]")]
336fn vector_to_float4(v: VectorRef<'_>, writer: &mut impl risingwave_common::array::ListWrite) {
337 writer.write_iter(
338 v.as_raw_slice()
339 .iter()
340 .map(|&f| Some(ScalarRefImpl::Float32(F32::from(f)))),
341 );
342}
343
344/// ```slt
345/// query T
346/// SELECT ARRAY[1,2,3]::vector(3);
347/// ----
348/// [1,2,3]
349///
350/// query T
351/// SELECT ARRAY[1.0,2.0,3.0]::vector(3);
352/// ----
353/// [1,2,3]
354///
355/// query T
356/// SELECT ARRAY[1,2,3]::float4[]::vector(3);
357/// ----
358/// [1,2,3]
359///
360/// query T
361/// SELECT ARRAY[1,2,3]::float8[]::vector(3);
362/// ----
363/// [1,2,3]
364///
365/// query T
366/// SELECT ARRAY[1,2,3]::numeric[]::vector(3);
367/// ----
368/// [1,2,3]
369///
370/// query T
371/// SELECT '{1,2,3}'::real[]::vector(3);
372/// ----
373/// [1,2,3]
374///
375/// query error expected 2 dimensions, not 3
376/// SELECT '{1,2,3}'::real[]::vector(2);
377///
378/// query error array must not contain nulls
379/// SELECT '{NULL}'::real[]::vector(1);
380///
381/// query error NaN not allowed in vector
382/// SELECT '{NaN}'::real[]::vector(1);
383///
384/// query error inf not allowed in vector
385/// SELECT '{Infinity}'::real[]::vector(1);
386///
387/// query error -inf not allowed in vector
388/// SELECT '{-Infinity}'::real[]::vector(1);
389///
390/// query error dimension
391/// SELECT '{}'::real[]::vector(1);
392///
393/// query error cannot cast
394/// SELECT '{{1}}'::real[][]::vector(1);
395///
396/// query T
397/// SELECT '{1,2,3}'::double precision[]::vector(3);
398/// ----
399/// [1,2,3]
400///
401/// query error expected 2 dimensions, not 3
402/// SELECT '{1,2,3}'::double precision[]::vector(2);
403///
404/// query error out of range
405/// SELECT '{4e38,-4e38}'::double precision[]::vector(2);
406///
407/// # Caveat: pgvector does not check underflow and returns 0 here.
408/// query error out of range
409/// SELECT '{1e-46,-1e-46}'::double precision[]::vector(2);
410/// ```
411#[function("cast(int4[]) -> vector", type_infer = "unreachable")]
412#[function("cast(decimal[]) -> vector", type_infer = "unreachable")]
413#[function("cast(float4[]) -> vector", type_infer = "unreachable")]
414#[function("cast(float8[]) -> vector", type_infer = "unreachable")]
415fn array_to_vector(array: ListRef<'_>, ctx: &Context) -> Result<VectorVal> {
416 macro_rules! bail_invalid_param {
417 ($($arg:tt)*) => {
418 return Err(ExprError::InvalidParam {
419 name: "array_to_vector",
420 reason: format!($($arg)*).into(),
421 });
422 };
423 }
424
425 let DataType::Vector(size) = ctx.return_type else {
426 unreachable!()
427 };
428 if array.len() != size {
429 bail_invalid_param!("expected {} dimensions, not {}", size, array.len());
430 }
431 let result = array
432 .iter()
433 .map(|scalar| {
434 let Some(scalar) = scalar else {
435 bail_invalid_param!("array must not contain nulls");
436 };
437 let val = match scalar {
438 ScalarRefImpl::Int32(val) => val.into(),
439 ScalarRefImpl::Decimal(val) => {
440 val.try_into().map_err(|_| ExprError::NumericOverflow)?
441 }
442 ScalarRefImpl::Float32(val) => val,
443 ScalarRefImpl::Float64(val) => {
444 val.try_into().map_err(|_| ExprError::NumericOverflow)?
445 }
446 _ => unreachable!(),
447 };
448 Finite32::try_from(val.0).map_err(|err| ExprError::InvalidParam {
449 name: "array_to_vector",
450 reason: err.into(),
451 })
452 })
453 .try_collect()?;
454 Ok(result)
455}
456
457/// ```slt
458/// query R
459/// SELECT round(vector_norm('[1,1]'::vector(2))::numeric, 5);
460/// ----
461/// 1.41421
462///
463/// query R
464/// SELECT vector_norm('[3,4]'::vector(2));
465/// ----
466/// 5
467///
468/// query R
469/// SELECT vector_norm('[0,1]'::vector(2));
470/// ----
471/// 1
472///
473/// query R
474/// SELECT vector_norm('[3e18,4e18]'::vector(2))::real;
475/// ----
476/// 5e+18
477///
478/// query R
479/// SELECT vector_norm('[0,0]'::vector(2));
480/// ----
481/// 0
482///
483/// query R
484/// SELECT vector_norm('[2]'::vector(1));
485/// ----
486/// 2
487/// ```
488#[function("l2_norm(vector) -> float8")]
489fn l2_norm(vector: VectorRef<'_>) -> F64 {
490 (vector.l2_norm() as f64).into()
491}
492
/// Returns the L2-normalized form of a vector by delegating to `VectorRef::normalized`.
///
/// The `type_infer` closure makes the return type the same as the argument type, so the
/// dimension is preserved. As the cases below show, an all-zero vector is returned unchanged.
///
/// ```slt
/// query R
/// SELECT l2_normalize('[3,4]'::vector(2));
/// ----
/// [0.6,0.8]
///
/// query R
/// SELECT l2_normalize('[3,0]'::vector(2));
/// ----
/// [1,0]
///
/// query R
/// SELECT l2_normalize('[0,0.1]'::vector(2));
/// ----
/// [0,1]
///
/// query R
/// SELECT l2_normalize('[0,0]'::vector(2));
/// ----
/// [0,0]
///
/// query R
/// SELECT l2_normalize('[3e18]'::vector(1));
/// ----
/// [1]
/// ```
#[function(
    "l2_normalize(vector) -> vector",
    type_infer = "|args| Ok(args[0].clone())"
)]
fn l2_normalize(vector: VectorRef<'_>) -> VectorVal {
    vector.normalized()
}
526
527#[derive(Debug)]
528pub struct SubvectorContext {
529 pub start: usize,
530 pub end: usize,
531}
532
533impl SubvectorContext {
534 pub fn from_start_count(start: i32, count: i32) -> Result<Self> {
535 Ok(Self {
536 start: (start - 1) as usize,
537 end: (start + count - 1) as usize,
538 })
539 }
540}
541
542/// ```slt
543/// query R
544/// SELECT subvector('[1,2,3,4,5]'::vector(5), 1, 3);
545/// ----
546/// [1,2,3]
547///
548/// query R
549/// SELECT subvector('[1,2,3,4,5]'::vector(5), 3, 2);
550/// ----
551/// [3,4]
552///
553/// query R
554/// SELECT subvector('[1,2,3,4,5]'::vector(5), 1, 5);
555/// ----
556/// [1,2,3,4,5]
557///
558/// query R
559/// SELECT subvector('[1,2,3,4,5]'::vector(5), 5, 1);
560/// ----
561/// [5]
562///
563/// query R
564/// SELECT subvector('[1,2,3,4,5]'::vector(5), 2, 3);
565/// ----
566/// [2,3,4]
567///
568/// query R
569/// select subvector(vec, 1, 3) from (values ('[1,2,3,4,5]'::vector(5)), ('[6,7,8,9,10]'::vector(5))) as t(vec);
570/// ----
571/// [1,2,3]
572/// [6,7,8]
573///
574/// statement error
575/// SELECT subvector('[1,2,3,4,5]'::vector(5), -1, 2);
576///
577/// statement error
578/// SELECT subvector('[6,7,8,9,10]'::vector(5), 1, 6);
579///
580/// statement error
581/// SELECT subvector('[6,7,8,9,10]'::vector(5), 5, 2);
582/// ```
583#[function(
584 "subvector(vector, int4, int4) -> vector",
585 prebuild = "SubvectorContext::from_start_count($1, $2)?",
586 type_infer = "unreachable"
587)]
588fn subvector(v: VectorRef<'_>, ctx: &SubvectorContext) -> Result<VectorVal> {
589 Ok(v.subvector(ctx.start, ctx.end))
590}