risingwave_expr_impl/scalar/
ai_model.rs1use std::sync::Arc;
16
17use async_openai::Client;
18use async_openai::config::OpenAIConfig;
19use async_openai::types::{CreateEmbeddingRequestArgs, Embedding, EmbeddingInput};
20use risingwave_common::array::{
21 Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, F32Array, ListArrayBuilder, ListValue,
22};
23use risingwave_common::row::OwnedRow;
24use risingwave_common::types::{DataType, Datum, F32, ScalarImpl};
25use risingwave_expr::expr::{BoxedExpression, Expression};
26use risingwave_expr::{ExprError, Result, build_function};
27use serde::Deserialize;
28use serde_json::Value;
29use thiserror_ext::AsReport;
30
31#[derive(Debug)]
33pub struct OpenAiEmbeddingContext {
34 pub client: Client<OpenAIConfig>,
35 pub model: String,
36}
37
38#[derive(Deserialize)]
39struct OpenAiEmbeddingConfig {
40 model: String,
41 api_key: Option<String>,
42 org_id: Option<String>,
43 project_id: Option<String>,
44 api_base: Option<String>,
45}
46
47impl OpenAiEmbeddingContext {
48 pub fn from_config(config: Value) -> Result<Self> {
50 let param: OpenAiEmbeddingConfig = serde_json::from_value(config).map_err(|err| {
51 invalid_param_err(format!("failed to parse config: {}", err.as_report()))
52 })?;
53
54 let mut config = OpenAIConfig::new();
55 if let Some(api_key) = param.api_key {
56 config = config.with_api_key(api_key);
57 }
58 if let Some(org_id) = param.org_id {
59 config = config.with_org_id(org_id);
60 }
61 if let Some(proj_id) = param.project_id {
62 config = config.with_project_id(proj_id);
63 }
64 if let Some(api_base) = param.api_base {
65 config = config.with_api_base(api_base);
66 }
67
68 let client = Client::with_config(config);
69 Ok(Self {
70 client,
71 model: param.model,
72 })
73 }
74}
75
76#[derive(Debug)]
77struct OpenAiEmbedding {
78 text_expr: BoxedExpression,
79 context: OpenAiEmbeddingContext,
80}
81
82impl OpenAiEmbedding {
83 async fn get_embeddings(&self, input: EmbeddingInput) -> Result<Vec<Embedding>> {
84 let request = CreateEmbeddingRequestArgs::default()
85 .model(&self.context.model)
86 .input(input)
87 .build()
88 .map_err(|e| {
89 tracing::error!(error = %e.as_report(), "Failed to build OpenAI embedding request");
90 ExprError::Custom("failed to build OpenAI embedding request".into())
91 })?;
92
93 let response = self
94 .context
95 .client
96 .embeddings()
97 .create(request)
98 .await
99 .map_err(|e| {
100 tracing::error!(error = %e.as_report(), "Failed to get embedding from OpenAI");
101 ExprError::Custom(format!(
102 "failed to get embedding from OpenAI: {}",
103 e.as_report()
104 ))
105 })?;
106
107 if response.data.is_empty() {
108 return Err(ExprError::Custom(
109 "no embedding data returned from OpenAI".into(),
110 ));
111 }
112
113 Ok(response.data)
114 }
115}
116
117#[async_trait::async_trait]
118impl Expression for OpenAiEmbedding {
119 fn return_type(&self) -> DataType {
120 DataType::Float32.list()
121 }
122
123 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
124 let text_array = self.text_expr.eval(input).await?;
125 let text_array = text_array.as_utf8();
126
127 let mut texts_to_embed = Vec::new();
129
130 for i in 0..input.capacity() {
131 if let Some(text) = text_array.value_at(i)
132 && !text.is_empty()
133 {
134 texts_to_embed.push(text.to_owned());
135 }
136 }
137 let n_texts_to_embed = texts_to_embed.len();
138
139 let embeddings = if texts_to_embed.is_empty() {
141 Vec::new()
142 } else {
143 self.get_embeddings(EmbeddingInput::StringArray(texts_to_embed))
144 .await?
145 };
146 if embeddings.len() != n_texts_to_embed {
147 return Err(ExprError::Custom(
148 "number of embeddings returned from OpenAI does not match the number of texts"
149 .into(),
150 ));
151 }
152
153 let mut builder = ListArrayBuilder::with_type(input.capacity(), DataType::Float32.list());
155 let mut embedding_idx = 0;
156
157 for i in 0..input.capacity() {
158 if let Some(text) = text_array.value_at(i) {
159 if !text.is_empty() {
160 if embedding_idx < embeddings.len() {
162 let embedding = &embeddings[embedding_idx].embedding;
163 let float_array =
164 F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
165 let list_value = ListValue::new(float_array.into());
166 builder.append_owned(Some(list_value));
167 embedding_idx += 1;
168 } else {
169 builder.append(None);
170 }
171 } else {
172 builder.append(None);
174 }
175 } else {
176 builder.append(None);
178 }
179 }
180
181 Ok(Arc::new(ArrayImpl::List(builder.finish())))
182 }
183
184 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
185 let text_datum = self.text_expr.eval_row(input).await?;
186
187 if let Some(ScalarImpl::Utf8(text)) = text_datum.as_ref() {
188 if text.is_empty() {
189 return Ok(None);
190 }
191
192 let embeddings = self
193 .get_embeddings(EmbeddingInput::String(text.to_owned().into_string()))
194 .await?;
195 let embedding = &embeddings[0].embedding;
196 let float_array = F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
197 Ok(Some(ListValue::new(float_array.into()).into()))
198 } else {
199 Ok(None)
200 }
201 }
202}
203
204fn invalid_param_err(reason: impl Into<String>) -> ExprError {
205 ExprError::InvalidParam {
206 name: "openai_embedding",
207 reason: reason.into().into(),
208 }
209}
210
211#[build_function("openai_embedding(jsonb, varchar) -> float4[]")]
212fn build_openai_embedding_expr(
213 _: DataType,
214 mut children: Vec<BoxedExpression>,
215) -> Result<BoxedExpression> {
216 if children.len() != 2 {
217 return Err(invalid_param_err("expected 2 arguments"));
218 }
219
220 let config = if let Ok(Some(config_scalar)) = children[0].eval_const() {
222 if let ScalarImpl::Jsonb(config) = config_scalar {
223 config.take()
224 } else {
225 return Err(invalid_param_err(
226 "`embedding_config` must be a jsonb constant",
227 ));
228 }
229 } else {
230 return Err(invalid_param_err("`embedding_config` must be a constant"));
231 };
232
233 let context = OpenAiEmbeddingContext::from_config(config)?;
234
235 Ok(Box::new(OpenAiEmbedding {
236 text_expr: children.pop().unwrap(), context,
238 }))
239}