risingwave_expr_impl/scalar/
ai_model.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use async_openai::Client;
18use async_openai::config::OpenAIConfig;
19use async_openai::types::{CreateEmbeddingRequestArgs, Embedding, EmbeddingInput};
20use risingwave_common::array::{
21    Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, F32Array, ListArrayBuilder, ListValue,
22};
23use risingwave_common::row::OwnedRow;
24use risingwave_common::types::{DataType, Datum, F32, ScalarImpl};
25use risingwave_expr::expr::{BoxedExpression, Expression};
26use risingwave_expr::{ExprError, Result, build_function};
27use serde::Deserialize;
28use serde_json::Value;
29use thiserror_ext::AsReport;
30
31/// `OpenAI` embedding context that holds the client and model configuration
32#[derive(Debug)]
33pub struct OpenAiEmbeddingContext {
34    pub client: Client<OpenAIConfig>,
35    pub model: String,
36}
37
38#[derive(Deserialize)]
39struct OpenAiEmbeddingConfig {
40    model: String,
41    api_key: Option<String>,
42    org_id: Option<String>,
43    project_id: Option<String>,
44    api_base: Option<String>,
45}
46
47impl OpenAiEmbeddingContext {
48    /// Create a new `OpenAI` embedding context from `api_key` and model
49    pub fn from_config(config: Value) -> Result<Self> {
50        let param: OpenAiEmbeddingConfig = serde_json::from_value(config).map_err(|err| {
51            invalid_param_err(format!("failed to parse config: {}", err.as_report()))
52        })?;
53
54        let mut config = OpenAIConfig::new();
55        if let Some(api_key) = param.api_key {
56            config = config.with_api_key(api_key);
57        }
58        if let Some(org_id) = param.org_id {
59            config = config.with_org_id(org_id);
60        }
61        if let Some(proj_id) = param.project_id {
62            config = config.with_project_id(proj_id);
63        }
64        if let Some(api_base) = param.api_base {
65            config = config.with_api_base(api_base);
66        }
67
68        let client = Client::with_config(config);
69        Ok(Self {
70            client,
71            model: param.model,
72        })
73    }
74}
75
76#[derive(Debug)]
77struct OpenAiEmbedding {
78    text_expr: BoxedExpression,
79    context: OpenAiEmbeddingContext,
80}
81
82impl OpenAiEmbedding {
83    async fn get_embeddings(&self, input: EmbeddingInput) -> Result<Vec<Embedding>> {
84        let request = CreateEmbeddingRequestArgs::default()
85            .model(&self.context.model)
86            .input(input)
87            .build()
88            .map_err(|e| {
89                tracing::error!(error = %e.as_report(), "Failed to build OpenAI embedding request");
90                ExprError::Custom("failed to build OpenAI embedding request".into())
91            })?;
92
93        let response = self
94            .context
95            .client
96            .embeddings()
97            .create(request)
98            .await
99            .map_err(|e| {
100                tracing::error!(error = %e.as_report(), "Failed to get embedding from OpenAI");
101                ExprError::Custom(format!(
102                    "failed to get embedding from OpenAI: {}",
103                    e.as_report()
104                ))
105            })?;
106
107        if response.data.is_empty() {
108            return Err(ExprError::Custom(
109                "no embedding data returned from OpenAI".into(),
110            ));
111        }
112
113        Ok(response.data)
114    }
115}
116
117#[async_trait::async_trait]
118impl Expression for OpenAiEmbedding {
119    fn return_type(&self) -> DataType {
120        DataType::List(Box::new(DataType::Float32))
121    }
122
123    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
124        let text_array = self.text_expr.eval(input).await?;
125        let text_array = text_array.as_utf8();
126
127        // Collect non-null and non-empty texts
128        let mut texts_to_embed = Vec::new();
129
130        for i in 0..input.capacity() {
131            if let Some(text) = text_array.value_at(i)
132                && !text.is_empty()
133            {
134                texts_to_embed.push(text.to_owned());
135            }
136        }
137        let n_texts_to_embed = texts_to_embed.len();
138
139        // Get embeddings in batch
140        let embeddings = if texts_to_embed.is_empty() {
141            Vec::new()
142        } else {
143            self.get_embeddings(EmbeddingInput::StringArray(texts_to_embed))
144                .await?
145        };
146        if embeddings.len() != n_texts_to_embed {
147            return Err(ExprError::Custom(
148                "number of embeddings returned from OpenAI does not match the number of texts"
149                    .into(),
150            ));
151        }
152
153        // Map results back to original positions
154        let mut builder = ListArrayBuilder::with_type(
155            input.capacity(),
156            DataType::List(Box::new(DataType::Float32)),
157        );
158        let mut embedding_idx = 0;
159
160        for i in 0..input.capacity() {
161            if let Some(text) = text_array.value_at(i) {
162                if !text.is_empty() {
163                    // Non-empty text, use the embedding result
164                    if embedding_idx < embeddings.len() {
165                        let embedding = &embeddings[embedding_idx].embedding;
166                        let float_array =
167                            F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
168                        let list_value = ListValue::new(float_array.into());
169                        builder.append_owned(Some(list_value));
170                        embedding_idx += 1;
171                    } else {
172                        builder.append(None);
173                    }
174                } else {
175                    // Empty text returns NULL
176                    builder.append(None);
177                }
178            } else {
179                // Null text returns NULL
180                builder.append(None);
181            }
182        }
183
184        Ok(Arc::new(ArrayImpl::List(builder.finish())))
185    }
186
187    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
188        let text_datum = self.text_expr.eval_row(input).await?;
189
190        if let Some(ScalarImpl::Utf8(text)) = text_datum.as_ref() {
191            if text.is_empty() {
192                return Ok(None);
193            }
194
195            let embeddings = self
196                .get_embeddings(EmbeddingInput::String(text.to_owned().into_string()))
197                .await?;
198            let embedding = &embeddings[0].embedding;
199            let float_array = F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
200            Ok(Some(ListValue::new(float_array.into()).into()))
201        } else {
202            Ok(None)
203        }
204    }
205}
206
207fn invalid_param_err(reason: impl Into<String>) -> ExprError {
208    ExprError::InvalidParam {
209        name: "openai_embedding",
210        reason: reason.into().into(),
211    }
212}
213
214#[build_function("openai_embedding(jsonb, varchar) -> float4[]")]
215fn build_openai_embedding_expr(
216    _: DataType,
217    mut children: Vec<BoxedExpression>,
218) -> Result<BoxedExpression> {
219    if children.len() != 2 {
220        return Err(invalid_param_err("expected 2 arguments"));
221    }
222
223    // Check if the first two parameters are constants
224    let config = if let Ok(Some(config_scalar)) = children[0].eval_const() {
225        if let ScalarImpl::Jsonb(config) = config_scalar {
226            config.take()
227        } else {
228            return Err(invalid_param_err(
229                "`embedding_config` must be a jsonb constant",
230            ));
231        }
232    } else {
233        return Err(invalid_param_err("`embedding_config` must be a constant"));
234    };
235
236    let context = OpenAiEmbeddingContext::from_config(config)?;
237
238    Ok(Box::new(OpenAiEmbedding {
239        text_expr: children.pop().unwrap(), // Take the second expression
240        context,
241    }))
242}