risingwave_connector_codec/decoder/protobuf/
parser.rs1use std::borrow::Cow;
16use std::collections::HashSet;
17
18use anyhow::Context;
19use itertools::Itertools;
20use prost_reflect::{
21 Cardinality, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
22};
23use risingwave_common::array::{ListValue, StructValue};
24use risingwave_common::catalog::Field;
25use risingwave_common::types::{
26 DataType, DatumCow, Decimal, F32, F64, JsonbVal, MapType, MapValue, ScalarImpl, StructType,
27 ToOwnedDatum,
28};
29use thiserror::Error;
30use thiserror_ext::Macro;
31
32use crate::decoder::{AccessError, AccessResult, uncategorized};
33
/// Name of the `messages_as_jsonb` option, printed in the circular-reference hint raised by
/// `detect_loop_and_push`. Protobuf messages whose full name is listed in that set are mapped
/// to JSONB instead of being expanded into struct columns (see `protobuf_type_mapping`).
pub const PROTOBUF_MESSAGES_AS_JSONB: &str = "messages_as_jsonb";
35
36pub fn pb_schema_to_fields(
37 message_descriptor: &MessageDescriptor,
38 messages_as_jsonb: &HashSet<String>,
39) -> anyhow::Result<Vec<Field>> {
40 let mut parse_trace: Vec<String> = vec![];
41 message_descriptor
42 .fields()
43 .map(|field| {
44 let field_type = protobuf_type_mapping(&field, &mut parse_trace, messages_as_jsonb)
45 .context("failed to map protobuf type")?;
46 let column = Field::new(field.name(), field_type);
47 Ok(column)
48 })
49 .collect()
50}
51
/// Error raised while mapping a protobuf schema onto column data types.
///
/// The wrapped string is the whole `Display` text. The `Macro` derive (from `thiserror_ext`)
/// presumably generates the `bail_protobuf_type_error!` macro used by `detect_loop_and_push`.
#[derive(Error, Debug, Macro)]
#[error("{0}")]
struct ProtobufTypeError(#[message] String);
55
56fn detect_loop_and_push(
57 trace: &mut Vec<String>,
58 fd: &FieldDescriptor,
59) -> std::result::Result<(), ProtobufTypeError> {
60 let identifier = format!("{}({})", fd.name(), fd.full_name());
61 if trace.iter().any(|s| s == identifier.as_str()) {
62 bail_protobuf_type_error!(
63 "circular reference detected: {}, conflict with {}, kind {:?}. Adding {:?} to {:?} may help.",
64 trace.iter().format("->"),
65 identifier,
66 fd.kind(),
67 fd.kind(),
68 PROTOBUF_MESSAGES_AS_JSONB,
69 );
70 }
71 trace.push(identifier);
72 Ok(())
73}
74
75pub fn from_protobuf_message_field<'a>(
80 field_desc: &FieldDescriptor,
81 message: &'a DynamicMessage,
82 type_expected: &DataType,
83 messages_as_jsonb: &'a HashSet<String>,
84) -> AccessResult<DatumCow<'a>> {
85 let value = if field_desc.supports_presence() && !message.has_field(field_desc) {
86 return Ok(DatumCow::NULL);
89 } else {
90 message.get_field(field_desc)
92 };
93
94 match value {
95 Cow::Borrowed(value) => {
96 from_protobuf_value(field_desc, value, type_expected, messages_as_jsonb)
97 }
98 Cow::Owned(value) => {
99 from_protobuf_value(field_desc, &value, type_expected, messages_as_jsonb)
100 .map(|d| d.to_owned_datum().into())
101 }
102 }
103}
104
105fn from_protobuf_value<'a>(
107 field_desc: &FieldDescriptor,
108 value: &'a Value,
109 type_expected: &DataType,
110 messages_as_jsonb: &'a HashSet<String>,
111) -> AccessResult<DatumCow<'a>> {
112 let kind = field_desc.kind();
113
114 macro_rules! borrowed {
115 ($v:expr) => {
116 return Ok(DatumCow::Borrowed(Some($v.into())))
117 };
118 }
119
120 let v: ScalarImpl = match value {
121 Value::Bool(v) => ScalarImpl::Bool(*v),
122 Value::I32(i) => ScalarImpl::Int32(*i),
123 Value::U32(i) => ScalarImpl::Int64(*i as i64),
124 Value::I64(i) => ScalarImpl::Int64(*i),
125 Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)),
126 Value::F32(f) => ScalarImpl::Float32(F32::from(*f)),
127 Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
128 Value::String(s) => borrowed!(s.as_str()),
129 Value::EnumNumber(idx) => {
130 let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
131 expected: "enum".to_owned(),
132 got: format!("{kind:?}"),
133 value: value.to_string(),
134 })?;
135 let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| {
136 uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc)
137 })?;
138 ScalarImpl::Utf8(enum_symbol.name().into())
139 }
140 Value::Message(dyn_msg) => {
141 if messages_as_jsonb.contains(dyn_msg.descriptor().full_name()) {
142 ScalarImpl::Jsonb(JsonbVal::from(
143 serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
144 ))
145 } else {
146 let desc = dyn_msg.descriptor();
147 let DataType::Struct(st) = type_expected else {
148 return Err(AccessError::TypeError {
149 expected: type_expected.to_string(),
150 got: desc.full_name().to_owned(),
151 value: value.to_string(), });
153 };
154
155 let mut datums = Vec::with_capacity(st.len());
156 for (name, expected_field_type) in st.iter() {
157 let Some(field_desc) = desc.get_field_by_name(name) else {
158 datums.push(None);
160 continue;
161 };
162 let datum = from_protobuf_message_field(
163 &field_desc,
164 dyn_msg,
165 expected_field_type,
166 messages_as_jsonb,
167 )?;
168 datums.push(datum.to_owned_datum());
169 }
170 ScalarImpl::Struct(StructValue::new(datums))
171 }
172 }
173 Value::List(values) => {
174 let DataType::List(element_type) = type_expected else {
175 return Err(AccessError::TypeError {
176 expected: type_expected.to_string(),
177 got: format!("repeated {:?}", kind),
178 value: value.to_string(), });
180 };
181 let mut builder = element_type.create_array_builder(values.len());
182 for value in values {
183 builder.append(from_protobuf_value(
184 field_desc,
185 value,
186 element_type,
187 messages_as_jsonb,
188 )?);
189 }
190 ScalarImpl::List(ListValue::new(builder.finish()))
191 }
192 Value::Bytes(value) => borrowed!(&**value),
193 Value::Map(map) => {
194 let err = || {
195 AccessError::TypeError {
196 expected: type_expected.to_string(),
197 got: format!("{:?}", kind),
198 value: value.to_string(), }
200 };
201
202 let DataType::Map(map_type) = type_expected else {
203 return Err(err());
204 };
205 if !field_desc.is_map() {
206 return Err(err());
207 }
208 let map_desc = kind.as_message().ok_or_else(err)?;
209
210 let mut key_builder = map_type.key().create_array_builder(map.len());
211 let mut value_builder = map_type.value().create_array_builder(map.len());
212 for (key, value) in map.iter().sorted_by_key(|(k, _v)| *k) {
217 key_builder.append(from_protobuf_value(
218 &map_desc.map_entry_key_field(),
219 &key.clone().into(),
220 map_type.key(),
221 messages_as_jsonb,
222 )?);
223 value_builder.append(from_protobuf_value(
224 &map_desc.map_entry_value_field(),
225 value,
226 map_type.value(),
227 messages_as_jsonb,
228 )?);
229 }
230 let keys = key_builder.finish();
231 let values = value_builder.finish();
232 ScalarImpl::Map(
233 MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values))
234 .map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?,
235 )
236 }
237 };
238 Ok(Some(v).into())
239}
240
241fn protobuf_type_mapping(
243 field_descriptor: &FieldDescriptor,
244 parse_trace: &mut Vec<String>,
245 messages_as_jsonb: &HashSet<String>,
246) -> std::result::Result<DataType, ProtobufTypeError> {
247 detect_loop_and_push(parse_trace, field_descriptor)?;
248 let mut t = match field_descriptor.kind() {
249 Kind::Bool => DataType::Boolean,
250 Kind::Double => DataType::Float64,
251 Kind::Float => DataType::Float32,
252 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
253 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
255 DataType::Int64
256 }
257 Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
258 Kind::String => DataType::Varchar,
259 Kind::Message(m) => {
260 if messages_as_jsonb.contains(m.full_name()) {
261 DataType::Jsonb
263 } else if m.is_map_entry() {
264 debug_assert!(field_descriptor.is_map());
266 let key = protobuf_type_mapping(
267 &m.map_entry_key_field(),
268 parse_trace,
269 messages_as_jsonb,
270 )?;
271 let value = protobuf_type_mapping(
272 &m.map_entry_value_field(),
273 parse_trace,
274 messages_as_jsonb,
275 )?;
276 _ = parse_trace.pop();
277 return Ok(DataType::Map(MapType::from_kv(key, value)));
278 } else {
279 let fields = m
280 .fields()
281 .map(|f| {
282 Ok((
283 f.name().to_owned(),
284 protobuf_type_mapping(&f, parse_trace, messages_as_jsonb)?,
285 ))
286 })
287 .try_collect::<_, Vec<_>, _>()?;
288 StructType::new(fields).into()
289 }
290 }
291 Kind::Enum(_) => DataType::Varchar,
292 Kind::Bytes => DataType::Bytea,
293 };
294 if field_descriptor.cardinality() == Cardinality::Repeated {
295 debug_assert!(!field_descriptor.is_map());
296 t = DataType::List(Box::new(t))
297 }
298 _ = parse_trace.pop();
299 Ok(t)
300}