risingwave_connector_codec/decoder/protobuf/
parser.rs1use std::borrow::Cow;
16use std::collections::HashSet;
17
18use anyhow::Context;
19use itertools::Itertools;
20use prost_reflect::{
21 Cardinality, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
22};
23use risingwave_common::array::{ListValue, StructValue};
24use risingwave_common::catalog::Field;
25use risingwave_common::types::{
26 DataType, DatumCow, Decimal, F32, F64, JsonbVal, MapType, MapValue, ScalarImpl, StructType,
27 ToOwnedDatum,
28};
29use thiserror::Error;
30use thiserror_ext::Macro;
31
32use crate::decoder::{AccessError, AccessResult, uncategorized};
33
/// Name of the option listing protobuf message types that should be mapped to
/// `jsonb` instead of being expanded into a struct. It is also quoted in the
/// circular-reference error message as a hint for the user.
pub const PROTOBUF_MESSAGES_AS_JSONB: &str = "messages_as_jsonb";
35
36pub fn pb_schema_to_fields(
37 message_descriptor: &MessageDescriptor,
38 messages_as_jsonb: &HashSet<String>,
39) -> anyhow::Result<Vec<Field>> {
40 let mut parse_trace: Vec<String> = vec![];
41 message_descriptor
42 .fields()
43 .map(|field| {
44 let field_type = protobuf_type_mapping(&field, &mut parse_trace, messages_as_jsonb)
45 .context("failed to map protobuf type")?;
46 let column = Field::new(field.name(), field_type);
47 Ok(column)
48 })
49 .collect()
50}
51
52#[derive(Error, Debug, Macro)]
53#[error("{0}")]
54struct ProtobufTypeError(#[message] String);
55
56fn detect_loop_and_push(
57 trace: &mut Vec<String>,
58 fd: &FieldDescriptor,
59) -> std::result::Result<(), ProtobufTypeError> {
60 let identifier = format!("{}({})", fd.name(), fd.full_name());
61 if trace.iter().any(|s| s == identifier.as_str()) {
62 bail_protobuf_type_error!(
63 "circular reference detected: {}, conflict with {}, kind {:?}. Adding {:?} to {:?} may help.",
64 trace.iter().format("->"),
65 identifier,
66 fd.kind(),
67 fd.kind(),
68 PROTOBUF_MESSAGES_AS_JSONB,
69 );
70 }
71 trace.push(identifier);
72 Ok(())
73}
74
75pub fn from_protobuf_message_field<'a>(
80 field_desc: &FieldDescriptor,
81 message: &'a DynamicMessage,
82 type_expected: &DataType,
83 messages_as_jsonb: &'a HashSet<String>,
84) -> AccessResult<DatumCow<'a>> {
85 let value = if field_desc.supports_presence() && !message.has_field(field_desc) {
86 return Ok(DatumCow::NULL);
89 } else {
90 message.get_field(field_desc)
92 };
93
94 match value {
95 Cow::Borrowed(value) => {
96 from_protobuf_value(field_desc, value, type_expected, messages_as_jsonb)
97 }
98 Cow::Owned(value) => {
99 from_protobuf_value(field_desc, &value, type_expected, messages_as_jsonb)
100 .map(|d| d.to_owned_datum().into())
101 }
102 }
103}
104
105fn from_protobuf_value<'a>(
107 field_desc: &FieldDescriptor,
108 value: &'a Value,
109 type_expected: &DataType,
110 messages_as_jsonb: &'a HashSet<String>,
111) -> AccessResult<DatumCow<'a>> {
112 let kind = field_desc.kind();
113
114 macro_rules! borrowed {
115 ($v:expr) => {
116 return Ok(DatumCow::Borrowed(Some($v.into())))
117 };
118 }
119
120 let v: ScalarImpl = match value {
121 Value::Bool(v) => ScalarImpl::Bool(*v),
122 Value::I32(i) => ScalarImpl::Int32(*i),
123 Value::U32(i) => ScalarImpl::Int64(*i as i64),
124 Value::I64(i) => ScalarImpl::Int64(*i),
125 Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)),
126 Value::F32(f) => ScalarImpl::Float32(F32::from(*f)),
127 Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
128 Value::String(s) => borrowed!(s.as_str()),
129 Value::EnumNumber(idx) => {
130 let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
131 expected: "enum".to_owned(),
132 got: format!("{kind:?}"),
133 value: value.to_string(),
134 })?;
135 let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| {
136 uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc)
137 })?;
138 ScalarImpl::Utf8(enum_symbol.name().into())
139 }
140 Value::Message(dyn_msg) => {
141 if messages_as_jsonb.contains(dyn_msg.descriptor().full_name()) {
142 ScalarImpl::Jsonb(JsonbVal::from(
143 serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
144 ))
145 } else {
146 let desc = dyn_msg.descriptor();
147 let DataType::Struct(st) = type_expected else {
148 return Err(AccessError::TypeError {
149 expected: type_expected.to_string(),
150 got: desc.full_name().to_owned(),
151 value: value.to_string(), });
153 };
154
155 let mut datums = Vec::with_capacity(st.len());
156 for (name, expected_field_type) in st.iter() {
157 let Some(field_desc) = desc.get_field_by_name(name) else {
158 datums.push(None);
160 continue;
161 };
162 let datum = from_protobuf_message_field(
163 &field_desc,
164 dyn_msg,
165 expected_field_type,
166 messages_as_jsonb,
167 )?;
168 datums.push(datum.to_owned_datum());
169 }
170 ScalarImpl::Struct(StructValue::new(datums))
171 }
172 }
173 Value::List(values) => {
174 let DataType::List(list_type) = type_expected else {
175 return Err(AccessError::TypeError {
176 expected: type_expected.to_string(),
177 got: format!("repeated {:?}", kind),
178 value: value.to_string(), });
180 };
181 let elem_type = list_type.elem();
182 let mut builder = elem_type.create_array_builder(values.len());
183 for value in values {
184 builder.append(from_protobuf_value(
185 field_desc,
186 value,
187 elem_type,
188 messages_as_jsonb,
189 )?);
190 }
191 ScalarImpl::List(ListValue::new(builder.finish()))
192 }
193 Value::Bytes(value) => borrowed!(&**value),
194 Value::Map(map) => {
195 let err = || {
196 AccessError::TypeError {
197 expected: type_expected.to_string(),
198 got: format!("{:?}", kind),
199 value: value.to_string(), }
201 };
202
203 let DataType::Map(map_type) = type_expected else {
204 return Err(err());
205 };
206 if !field_desc.is_map() {
207 return Err(err());
208 }
209 let map_desc = kind.as_message().ok_or_else(err)?;
210
211 let mut key_builder = map_type.key().create_array_builder(map.len());
212 let mut value_builder = map_type.value().create_array_builder(map.len());
213 for (key, value) in map.iter().sorted_by_key(|(k, _v)| *k) {
218 key_builder.append(from_protobuf_value(
219 &map_desc.map_entry_key_field(),
220 &key.clone().into(),
221 map_type.key(),
222 messages_as_jsonb,
223 )?);
224 value_builder.append(from_protobuf_value(
225 &map_desc.map_entry_value_field(),
226 value,
227 map_type.value(),
228 messages_as_jsonb,
229 )?);
230 }
231 let keys = key_builder.finish();
232 let values = value_builder.finish();
233 ScalarImpl::Map(
234 MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values))
235 .map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?,
236 )
237 }
238 };
239 Ok(Some(v).into())
240}
241
242fn protobuf_type_mapping(
244 field_descriptor: &FieldDescriptor,
245 parse_trace: &mut Vec<String>,
246 messages_as_jsonb: &HashSet<String>,
247) -> std::result::Result<DataType, ProtobufTypeError> {
248 detect_loop_and_push(parse_trace, field_descriptor)?;
249 let mut t = match field_descriptor.kind() {
250 Kind::Bool => DataType::Boolean,
251 Kind::Double => DataType::Float64,
252 Kind::Float => DataType::Float32,
253 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
254 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
256 DataType::Int64
257 }
258 Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
259 Kind::String => DataType::Varchar,
260 Kind::Message(m) => {
261 if messages_as_jsonb.contains(m.full_name()) {
262 DataType::Jsonb
264 } else if m.is_map_entry() {
265 debug_assert!(field_descriptor.is_map());
267 let key = protobuf_type_mapping(
268 &m.map_entry_key_field(),
269 parse_trace,
270 messages_as_jsonb,
271 )?;
272 let value = protobuf_type_mapping(
273 &m.map_entry_value_field(),
274 parse_trace,
275 messages_as_jsonb,
276 )?;
277 _ = parse_trace.pop();
278 return Ok(DataType::Map(MapType::from_kv(key, value)));
279 } else {
280 let fields = m
281 .fields()
282 .map(|f| {
283 Ok((
284 f.name().to_owned(),
285 protobuf_type_mapping(&f, parse_trace, messages_as_jsonb)?,
286 ))
287 })
288 .try_collect::<_, Vec<_>, _>()?;
289 StructType::new(fields).into()
290 }
291 }
292 Kind::Enum(_) => DataType::Varchar,
293 Kind::Bytes => DataType::Bytea,
294 };
295 if field_descriptor.cardinality() == Cardinality::Repeated {
296 debug_assert!(!field_descriptor.is_map());
297 t = DataType::list(t)
298 }
299 _ = parse_trace.pop();
300 Ok(t)
301}