risingwave_expr_impl/scalar/ai_model.rs

use std::sync::Arc;

use async_openai::Client;
use async_openai::config::OpenAIConfig;
use async_openai::types::{CreateEmbeddingRequestArgs, Embedding, EmbeddingInput};
use risingwave_common::array::{
    Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, F32Array, ListArrayBuilder, ListValue,
};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum, F32, ScalarImpl};
use risingwave_expr::expr::{BoxedExpression, Expression};
use risingwave_expr::{ExprError, Result, build_function};
use serde::Deserialize;
use serde_json::Value;
use thiserror_ext::AsReport;

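/// Shared context for the `openai_embedding` expression: the configured OpenAI
/// client and the embedding model to request.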
#[derive(Debug)]
pub struct OpenAiEmbeddingContext {
    pub client: Client<OpenAIConfig>,
    pub model: String,
}

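/// Shape of the `jsonb` config argument: `model` is required; the optional
/// fields override the corresponding `async_openai` client defaults when set.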
#[derive(Deserialize)]
struct OpenAiEmbeddingConfig {
    model: String,
    api_key: Option<String>,
    org_id: Option<String>,
    project_id: Option<String>,
    api_base: Option<String>,
}

impl OpenAiEmbeddingContext {
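    /// Parses the user-provided config and builds an OpenAI client from it.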
    pub fn from_config(config: Value) -> Result<Self> {
        let param: OpenAiEmbeddingConfig = serde_json::from_value(config).map_err(|err| {
            invalid_param_err(format!("failed to parse config: {}", err.as_report()))
        })?;

        let mut config = OpenAIConfig::new();
        if let Some(api_key) = param.api_key {
            config = config.with_api_key(api_key);
        }
        if let Some(org_id) = param.org_id {
            config = config.with_org_id(org_id);
        }
        if let Some(proj_id) = param.project_id {
            config = config.with_project_id(proj_id);
        }
        if let Some(api_base) = param.api_base {
            config = config.with_api_base(api_base);
        }

        let client = Client::with_config(config);
        Ok(Self {
            client,
            model: param.model,
        })
    }
}

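/// Expression node for `openai_embedding`: embeds the text produced by
/// `text_expr` via the OpenAI embeddings API and yields a `float4[]` per row.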
#[derive(Debug)]
struct OpenAiEmbedding {
    text_expr: BoxedExpression,
    context: OpenAiEmbeddingContext,
}

impl OpenAiEmbedding {
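    /// Issues a single embeddings request for `input`; fails if the request
    /// cannot be built, the API call errors, or the response carries no data.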
    async fn get_embeddings(&self, input: EmbeddingInput) -> Result<Vec<Embedding>> {
        let request = CreateEmbeddingRequestArgs::default()
            .model(&self.context.model)
            .input(input)
            .build()
            .map_err(|e| {
                tracing::error!(error = %e.as_report(), "Failed to build OpenAI embedding request");
                ExprError::Custom("failed to build OpenAI embedding request".into())
            })?;

        let response = self
            .context
            .client
            .embeddings()
            .create(request)
            .await
            .map_err(|e| {
                tracing::error!(error = %e.as_report(), "Failed to get embedding from OpenAI");
                ExprError::Custom(format!(
                    "failed to get embedding from OpenAI: {}",
                    e.as_report()
                ))
            })?;

        if response.data.is_empty() {
            return Err(ExprError::Custom(
                "no embedding data returned from OpenAI".into(),
            ));
        }

        Ok(response.data)
    }
}

#[async_trait::async_trait]
impl Expression for OpenAiEmbedding {
    fn return_type(&self) -> DataType {
        DataType::List(Box::new(DataType::Float32))
    }

    async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
        let text_array = self.text_expr.eval(input).await?;
        let text_array = text_array.as_utf8();

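        // Collect the non-NULL, non-empty texts so the whole chunk can be
        // embedded with a single batched request.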
        let mut texts_to_embed = Vec::new();

        for i in 0..input.capacity() {
            if let Some(text) = text_array.value_at(i)
                && !text.is_empty()
            {
                texts_to_embed.push(text.to_owned());
            }
        }
        let n_texts_to_embed = texts_to_embed.len();

        let embeddings = if texts_to_embed.is_empty() {
            Vec::new()
        } else {
            self.get_embeddings(EmbeddingInput::StringArray(texts_to_embed))
                .await?
        };
        if embeddings.len() != n_texts_to_embed {
            return Err(ExprError::Custom(
                "number of embeddings returned from OpenAI does not match the number of texts"
                    .into(),
            ));
        }

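        // Walk the chunk again and line each embedding up with its source row;
        // NULL or empty inputs produce NULL outputs.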
        let mut builder = ListArrayBuilder::with_type(
            input.capacity(),
            DataType::List(Box::new(DataType::Float32)),
        );
        let mut embedding_idx = 0;

        for i in 0..input.capacity() {
            if let Some(text) = text_array.value_at(i) {
                if !text.is_empty() {
                    if embedding_idx < embeddings.len() {
                        let embedding = &embeddings[embedding_idx].embedding;
                        let float_array =
                            F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
                        let list_value = ListValue::new(float_array.into());
                        builder.append_owned(Some(list_value));
                        embedding_idx += 1;
                    } else {
                        builder.append(None);
                    }
                } else {
                    builder.append(None);
                }
            } else {
                builder.append(None);
            }
        }

        Ok(Arc::new(ArrayImpl::List(builder.finish())))
    }

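    /// Row-at-a-time path: embeds a single text; NULL or empty input yields NULL.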
    async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
        let text_datum = self.text_expr.eval_row(input).await?;

        if let Some(ScalarImpl::Utf8(text)) = text_datum.as_ref() {
            if text.is_empty() {
                return Ok(None);
            }

            let embeddings = self
                .get_embeddings(EmbeddingInput::String(text.to_owned().into_string()))
                .await?;
            let embedding = &embeddings[0].embedding;
            let float_array = F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
            Ok(Some(ListValue::new(float_array.into()).into()))
        } else {
            Ok(None)
        }
    }
}

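/// Builds an `InvalidParam` error attributed to the `openai_embedding` function.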
fn invalid_param_err(reason: impl Into<String>) -> ExprError {
    ExprError::InvalidParam {
        name: "openai_embedding",
        reason: reason.into().into(),
    }
}

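/// Registers `openai_embedding(jsonb, varchar) -> float4[]`: the first argument
/// must be a constant jsonb config (see `OpenAiEmbeddingConfig`), the second the
/// text to embed. Illustrative SQL call, with a placeholder model name and key:
/// `openai_embedding('{"model": "text-embedding-3-small", "api_key": "sk-..."}'::jsonb, col)`.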
#[build_function("openai_embedding(jsonb, varchar) -> float4[]")]
fn build_openai_embedding_expr(
    _: DataType,
    mut children: Vec<BoxedExpression>,
) -> Result<BoxedExpression> {
    if children.len() != 2 {
        return Err(invalid_param_err("expected 2 arguments"));
    }

    let config = if let Ok(Some(config_scalar)) = children[0].eval_const() {
        if let ScalarImpl::Jsonb(config) = config_scalar {
            config.take()
        } else {
            return Err(invalid_param_err(
                "`embedding_config` must be a jsonb constant",
            ));
        }
    } else {
        return Err(invalid_param_err("`embedding_config` must be a constant"));
    };

    let context = OpenAiEmbeddingContext::from_config(config)?;

    Ok(Box::new(OpenAiEmbedding {
        text_expr: children.pop().unwrap(),
        context,
    }))
}