risingwave_expr_impl/scalar/
ai_model.rs1use std::sync::Arc;
16
17use async_openai::Client;
18use async_openai::config::OpenAIConfig;
19use async_openai::types::{CreateEmbeddingRequestArgs, Embedding, EmbeddingInput};
20use risingwave_common::array::{
21 Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, F32Array, ListArrayBuilder, ListValue,
22};
23use risingwave_common::row::OwnedRow;
24use risingwave_common::types::{DataType, Datum, F32, ScalarImpl};
25use risingwave_expr::expr::{BoxedExpression, Expression};
26use risingwave_expr::{ExprError, Result, build_function};
27use thiserror_ext::AsReport;
28
/// Settings shared by every call of `openai_embedding`: an OpenAI API client configured with
/// the caller's API key, plus the model name sent with each embedding request.
///
/// NOTE(review): `Debug` is derived, so formatting this value also formats the client config;
/// confirm `OpenAIConfig`'s `Debug` impl does not print the API key.
#[derive(Debug)]
pub struct OpenAiEmbeddingContext {
    /// Client built from an `OpenAIConfig` carrying the API key.
    pub client: Client<OpenAIConfig>,
    /// Model identifier passed as `model` in every embedding request.
    pub model: String,
}
35
36impl OpenAiEmbeddingContext {
37 pub fn from_config(api_key: &str, model: &str) -> Result<Self> {
39 let config = OpenAIConfig::new().with_api_key(api_key);
40 let client = Client::with_config(config);
41 Ok(Self {
42 client,
43 model: model.to_owned(),
44 })
45 }
46}
47
/// Expression backing `openai_embedding(api_key, model, text)`.
///
/// It evaluates `text_expr` and maps each text to its embedding vector via the OpenAI API.
#[derive(Debug)]
struct OpenAiEmbedding {
    /// Produces the `varchar` texts to embed (the function's third argument).
    text_expr: BoxedExpression,
    /// Client and model resolved from the constant `api_key` and `model` arguments.
    context: OpenAiEmbeddingContext,
}
53
54impl OpenAiEmbedding {
55 async fn get_embeddings(&self, input: EmbeddingInput) -> Result<Vec<Embedding>> {
56 let request = CreateEmbeddingRequestArgs::default()
57 .model(&self.context.model)
58 .input(input)
59 .build()
60 .map_err(|e| {
61 tracing::error!(error = %e.as_report(), "Failed to build OpenAI embedding request");
62 ExprError::Custom("failed to build OpenAI embedding request".into())
63 })?;
64
65 let response = self
66 .context
67 .client
68 .embeddings()
69 .create(request)
70 .await
71 .map_err(|e| {
72 tracing::error!(error = %e.as_report(), "Failed to get embedding from OpenAI");
73 ExprError::Custom("failed to get embedding from OpenAI".into())
74 })?;
75
76 if response.data.is_empty() {
77 return Err(ExprError::Custom(
78 "no embedding data returned from OpenAI".into(),
79 ));
80 }
81
82 Ok(response.data)
83 }
84}
85
86#[async_trait::async_trait]
87impl Expression for OpenAiEmbedding {
88 fn return_type(&self) -> DataType {
89 DataType::List(Box::new(DataType::Float32))
90 }
91
92 async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
93 let text_array = self.text_expr.eval(input).await?;
94 let text_array = text_array.as_utf8();
95
96 let mut texts_to_embed = Vec::new();
98
99 for i in 0..input.capacity() {
100 if let Some(text) = text_array.value_at(i) {
101 if !text.is_empty() {
102 texts_to_embed.push(text.to_owned());
103 }
104 }
105 }
106 let n_texts_to_embed = texts_to_embed.len();
107
108 let embeddings = if texts_to_embed.is_empty() {
110 Vec::new()
111 } else {
112 self.get_embeddings(EmbeddingInput::StringArray(texts_to_embed))
113 .await?
114 };
115 if embeddings.len() != n_texts_to_embed {
116 return Err(ExprError::Custom(
117 "number of embeddings returned from OpenAI does not match the number of texts"
118 .into(),
119 ));
120 }
121
122 let mut builder = ListArrayBuilder::with_type(
124 input.capacity(),
125 DataType::List(Box::new(DataType::Float32)),
126 );
127 let mut embedding_idx = 0;
128
129 for i in 0..input.capacity() {
130 if let Some(text) = text_array.value_at(i) {
131 if !text.is_empty() {
132 if embedding_idx < embeddings.len() {
134 let embedding = &embeddings[embedding_idx].embedding;
135 let float_array =
136 F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
137 let list_value = ListValue::new(float_array.into());
138 builder.append_owned(Some(list_value));
139 embedding_idx += 1;
140 } else {
141 builder.append(None);
142 }
143 } else {
144 builder.append(None);
146 }
147 } else {
148 builder.append(None);
150 }
151 }
152
153 Ok(Arc::new(ArrayImpl::List(builder.finish())))
154 }
155
156 async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
157 let text_datum = self.text_expr.eval_row(input).await?;
158
159 if let Some(ScalarImpl::Utf8(text)) = text_datum.as_ref() {
160 if text.is_empty() {
161 return Ok(None);
162 }
163
164 let embeddings = self
165 .get_embeddings(EmbeddingInput::String(text.to_owned().into_string()))
166 .await?;
167 let embedding = &embeddings[0].embedding;
168 let float_array = F32Array::from_iter(embedding.iter().map(|&v| Some(F32::from(v))));
169 Ok(Some(ListValue::new(float_array.into()).into()))
170 } else {
171 Ok(None)
172 }
173 }
174}
175
176#[build_function("openai_embedding(varchar, varchar, varchar) -> float4[]")]
177fn build_openai_embedding_expr(
178 _: DataType,
179 mut children: Vec<BoxedExpression>,
180) -> Result<BoxedExpression> {
181 if children.len() != 3 {
182 return Err(ExprError::InvalidParam {
183 name: "openai_embedding",
184 reason: "expected 3 arguments".into(),
185 });
186 }
187
188 let api_key = if let Ok(Some(api_key_scalar)) = children[0].eval_const() {
190 if let ScalarImpl::Utf8(api_key_str) = api_key_scalar {
191 api_key_str.to_string()
192 } else {
193 return Err(ExprError::InvalidParam {
194 name: "openai_embedding",
195 reason: "`api_key` must be a string constant".into(),
196 });
197 }
198 } else {
199 return Err(ExprError::InvalidParam {
200 name: "openai_embedding",
201 reason: "`api_key` must be a constant".into(),
202 });
203 };
204
205 let model = if let Ok(Some(model_scalar)) = children[1].eval_const() {
206 if let ScalarImpl::Utf8(model_str) = model_scalar {
207 model_str.to_string()
208 } else {
209 return Err(ExprError::InvalidParam {
210 name: "openai_embedding",
211 reason: "`model` must be a string constant".into(),
212 });
213 }
214 } else {
215 return Err(ExprError::InvalidParam {
216 name: "openai_embedding",
217 reason: "`model` must be a constant".into(),
218 });
219 };
220
221 let context = OpenAiEmbeddingContext::from_config(&api_key, &model)?;
222
223 Ok(Box::new(OpenAiEmbedding {
224 text_expr: children.pop().unwrap(), context,
226 }))
227}