risingwave_expr_impl/scalar/ai_model.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

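//! Scalar function `openai_embedding`, which turns text into a `float4[]`
//! embedding via the OpenAI embeddings API.
//!
//! Illustrative call shape (a sketch; the argument order follows the
//! `#[build_function]` signature at the bottom of this file):
//! `SELECT openai_embedding(api_key, model, text)`.
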
use std::sync::Arc;

use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::{CreateEmbeddingRequestArgs, Embedding, EmbeddingInput};
use risingwave_common::array::{
    Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, F32Array, ListArrayBuilder, ListValue,
};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum, F32, ScalarImpl};
use risingwave_expr::expr::{BoxedExpression, Expression};
use risingwave_expr::{ExprError, Result, build_function};
use thiserror_ext::AsReport;

/// `OpenAI` embedding context that holds the client and model configuration
#[derive(Debug)]
pub struct OpenAiEmbeddingContext {
    pub client: Client<OpenAIConfig>,
    pub model: String,
}

impl OpenAiEmbeddingContext {
    /// Create a new `OpenAI` embedding context from `api_key` and model
    pub fn from_config(api_key: &str, model: &str) -> Result<Self> {
        let config = OpenAIConfig::new().with_api_key(api_key);
        let client = Client::with_config(config);
        Ok(Self {
            client,
            model: model.to_owned(),
        })
    }
}

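/// Expression that embeds its text argument with the `OpenAI` API. The
/// `api_key` and `model` are fixed when the expression is built; only the
/// text expression is evaluated per input row.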
#[derive(Debug)]
struct OpenAiEmbedding {
    text_expr: BoxedExpression,
    context: OpenAiEmbeddingContext,
}

impl OpenAiEmbedding {
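    /// Build a `CreateEmbeddingRequest` for the configured model, send it, and
    /// return the embeddings. Request-build and API failures are surfaced as
    /// `ExprError::Custom`; a response with no data is also treated as an error.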
    async fn get_embeddings(&self, input: EmbeddingInput) -> Result<Vec<Embedding>> {
        let request = CreateEmbeddingRequestArgs::default()
            .model(&self.context.model)
            .input(input)
            .build()
            .map_err(|e| {
                tracing::error!(error = %e.as_report(), "Failed to build OpenAI embedding request");
                ExprError::Custom("failed to build OpenAI embedding request".into())
            })?;

        let response = self
            .context
            .client
            .embeddings()
            .create(request)
            .await
            .map_err(|e| {
                tracing::error!(error = %e.as_report(), "Failed to get embedding from OpenAI");
                ExprError::Custom("failed to get embedding from OpenAI".into())
            })?;

        if response.data.is_empty() {
            return Err(ExprError::Custom(
                "no embedding data returned from OpenAI".into(),
            ));
        }

        Ok(response.data)
    }
}

#[async_trait::async_trait]
impl Expression for OpenAiEmbedding {
    fn return_type(&self) -> DataType {
        DataType::List(Box::new(DataType::Float32))
    }

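    /// Batch evaluation: all non-null, non-empty texts in the chunk are sent in
    /// a single request. This relies on the API returning one embedding per
    /// input; the length check below guards the count, and the results are
    /// assumed to come back in request order.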
    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
        let text_array = self.text_expr.eval(input).await?;
        let text_array = text_array.as_utf8();

        // Collect non-null and non-empty texts
        let mut texts_to_embed = Vec::new();

        for i in 0..input.capacity() {
            if let Some(text) = text_array.value_at(i) {
                if !text.is_empty() {
                    texts_to_embed.push(text.to_owned());
                }
            }
        }
        let n_texts_to_embed = texts_to_embed.len();

        // Get embeddings in batch
        let embeddings = if texts_to_embed.is_empty() {
            Vec::new()
        } else {
            self.get_embeddings(EmbeddingInput::StringArray(texts_to_embed))
                .await?
        };
        if embeddings.len() != n_texts_to_embed {
            return Err(ExprError::Custom(
                "number of embeddings returned from OpenAI does not match the number of texts"
                    .into(),
            ));
        }

        // Map results back to original positions
        let mut builder = ListArrayBuilder::with_type(
            input.capacity(),
            DataType::List(Box::new(DataType::Float32)),
        );
        let mut embedding_idx = 0;

        for i in 0..input.capacity() {
            if let Some(text) = text_array.value_at(i) {
                if !text.is_empty() {
                    // Non-empty text, use the embedding result
                    if embedding_idx < embeddings.len() {
                        let embedding = &embeddings[embedding_idx].embedding;
                        let float_array =
                            F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
                        let list_value = ListValue::new(float_array.into());
                        builder.append_owned(Some(list_value));
                        embedding_idx += 1;
                    } else {
                        builder.append(None);
                    }
                } else {
                    // Empty text returns NULL
                    builder.append(None);
                }
            } else {
                // Null text returns NULL
                builder.append(None);
            }
        }

        Ok(Arc::new(ArrayImpl::List(builder.finish())))
    }

    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
        let text_datum = self.text_expr.eval_row(input).await?;

        if let Some(ScalarImpl::Utf8(text)) = text_datum.as_ref() {
            if text.is_empty() {
                return Ok(None);
            }

            let embeddings = self
                .get_embeddings(EmbeddingInput::String(text.to_owned().into_string()))
                .await?;
            let embedding = &embeddings[0].embedding;
            let float_array = F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
            Ok(Some(ListValue::new(float_array.into()).into()))
        } else {
            Ok(None)
        }
    }
}

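/// Builds `openai_embedding(api_key, model, text)`. `api_key` and `model` must
/// be string constants so the client can be configured once at build time;
/// `text` is the expression that gets embedded per row.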
#[build_function("openai_embedding(varchar, varchar, varchar) -> float4[]")]
fn build_openai_embedding_expr(
    _: DataType,
    mut children: Vec<BoxedExpression>,
) -> Result<BoxedExpression> {
    if children.len() != 3 {
        return Err(ExprError::InvalidParam {
            name: "openai_embedding",
            reason: "expected 3 arguments".into(),
        });
    }

    // The first two arguments (`api_key` and `model`) must be string constants
    let api_key = if let Ok(Some(api_key_scalar)) = children[0].eval_const() {
        if let ScalarImpl::Utf8(api_key_str) = api_key_scalar {
            api_key_str.to_string()
        } else {
            return Err(ExprError::InvalidParam {
                name: "openai_embedding",
                reason: "`api_key` must be a string constant".into(),
            });
        }
    } else {
        return Err(ExprError::InvalidParam {
            name: "openai_embedding",
            reason: "`api_key` must be a constant".into(),
        });
    };

    let model = if let Ok(Some(model_scalar)) = children[1].eval_const() {
        if let ScalarImpl::Utf8(model_str) = model_scalar {
            model_str.to_string()
        } else {
            return Err(ExprError::InvalidParam {
                name: "openai_embedding",
                reason: "`model` must be a string constant".into(),
            });
        }
    } else {
        return Err(ExprError::InvalidParam {
            name: "openai_embedding",
            reason: "`model` must be a constant".into(),
        });
    };

    let context = OpenAiEmbeddingContext::from_config(&api_key, &model)?;

    Ok(Box::new(OpenAiEmbedding {
        text_expr: children.pop().unwrap(), // Take the third (text) expression
        context,
    }))
}
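
// A minimal test sketch: it only exercises `from_config`, which constructs the
// client without touching the network. The key and model strings below are
// placeholders.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_config_stores_model() {
        let ctx = OpenAiEmbeddingContext::from_config("sk-test", "text-embedding-3-small")
            .expect("building the context should not fail");
        assert_eq!(ctx.model, "text-embedding-3-small");
    }
}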