risingwave_connector_codec/decoder/protobuf/
parser.rs1use std::collections::HashSet;
16
17use anyhow::Context;
18use itertools::Itertools;
19use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value};
20use risingwave_common::array::{ListValue, StructValue};
21use risingwave_common::catalog::Field;
22use risingwave_common::types::{
23 DataType, DatumCow, Decimal, F32, F64, JsonbVal, MapType, MapValue, ScalarImpl, StructType,
24 ToOwnedDatum,
25};
26use thiserror::Error;
27use thiserror_ext::Macro;
28
29use crate::decoder::{AccessError, AccessResult, uncategorized};
30
/// Name of the setting that lists protobuf message types (by full name) to decode as `JSONB`
/// instead of as structs. It is quoted in the circular-reference error message as a possible
/// remedy. NOTE(review): presumably a user-facing option key; confirm against the caller.
pub const PROTOBUF_MESSAGES_AS_JSONB: &str = "messages_as_jsonb";
32
33pub fn pb_schema_to_fields(
34 message_descriptor: &MessageDescriptor,
35 messages_as_jsonb: &HashSet<String>,
36) -> anyhow::Result<Vec<Field>> {
37 let mut parse_trace: Vec<String> = vec![];
38 message_descriptor
39 .fields()
40 .map(|field| {
41 let field_type = protobuf_type_mapping(&field, &mut parse_trace, messages_as_jsonb)
42 .context("failed to map protobuf type")?;
43 let column = Field::new(field.name(), field_type);
44 Ok(column)
45 })
46 .collect()
47}
48
/// Error raised while mapping a protobuf schema to RisingWave data types.
///
/// The string payload is displayed verbatim. The `Macro` derive (from `thiserror_ext`) supplies
/// the `bail_protobuf_type_error!` macro that `detect_loop_and_push` uses to return this error.
#[derive(Error, Debug, Macro)]
#[error("{0}")]
struct ProtobufTypeError(#[message] String);
52
53fn detect_loop_and_push(
54 trace: &mut Vec<String>,
55 fd: &FieldDescriptor,
56) -> std::result::Result<(), ProtobufTypeError> {
57 let identifier = format!("{}({})", fd.name(), fd.full_name());
58 if trace.iter().any(|s| s == identifier.as_str()) {
59 bail_protobuf_type_error!(
60 "circular reference detected: {}, conflict with {}, kind {:?}. Adding {:?} to {:?} may help.",
61 trace.iter().format("->"),
62 identifier,
63 fd.kind(),
64 fd.kind(),
65 PROTOBUF_MESSAGES_AS_JSONB,
66 );
67 }
68 trace.push(identifier);
69 Ok(())
70}
71
72pub fn from_protobuf_value<'a>(
73 field_desc: &FieldDescriptor,
74 value: &'a Value,
75 type_expected: &DataType,
76 messages_as_jsonb: &'a HashSet<String>,
77) -> AccessResult<DatumCow<'a>> {
78 let kind = field_desc.kind();
79
80 macro_rules! borrowed {
81 ($v:expr) => {
82 return Ok(DatumCow::Borrowed(Some($v.into())))
83 };
84 }
85
86 let v: ScalarImpl = match value {
87 Value::Bool(v) => ScalarImpl::Bool(*v),
88 Value::I32(i) => ScalarImpl::Int32(*i),
89 Value::U32(i) => ScalarImpl::Int64(*i as i64),
90 Value::I64(i) => ScalarImpl::Int64(*i),
91 Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)),
92 Value::F32(f) => ScalarImpl::Float32(F32::from(*f)),
93 Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
94 Value::String(s) => borrowed!(s.as_str()),
95 Value::EnumNumber(idx) => {
96 let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
97 expected: "enum".to_owned(),
98 got: format!("{kind:?}"),
99 value: value.to_string(),
100 })?;
101 let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| {
102 uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc)
103 })?;
104 ScalarImpl::Utf8(enum_symbol.name().into())
105 }
106 Value::Message(dyn_msg) => {
107 if messages_as_jsonb.contains(dyn_msg.descriptor().full_name()) {
108 ScalarImpl::Jsonb(JsonbVal::from(
109 serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
110 ))
111 } else {
112 let desc = dyn_msg.descriptor();
113 let DataType::Struct(st) = type_expected else {
114 return Err(AccessError::TypeError {
115 expected: type_expected.to_string(),
116 got: desc.full_name().to_owned(),
117 value: value.to_string(), });
119 };
120
121 let mut rw_values = Vec::with_capacity(st.len());
122 for (name, expected_field_type) in st.iter() {
123 let Some(field_desc) = desc.get_field_by_name(name) else {
124 rw_values.push(None);
126 continue;
127 };
128 let value = dyn_msg.get_field(&field_desc);
129 rw_values.push(
130 from_protobuf_value(
131 &field_desc,
132 &value,
133 expected_field_type,
134 messages_as_jsonb,
135 )?
136 .to_owned_datum(),
137 );
138 }
139 ScalarImpl::Struct(StructValue::new(rw_values))
140 }
141 }
142 Value::List(values) => {
143 let DataType::List(element_type) = type_expected else {
144 return Err(AccessError::TypeError {
145 expected: type_expected.to_string(),
146 got: format!("repeated {:?}", kind),
147 value: value.to_string(), });
149 };
150 let mut builder = element_type.create_array_builder(values.len());
151 for value in values {
152 builder.append(from_protobuf_value(
153 field_desc,
154 value,
155 element_type,
156 messages_as_jsonb,
157 )?);
158 }
159 ScalarImpl::List(ListValue::new(builder.finish()))
160 }
161 Value::Bytes(value) => borrowed!(&**value),
162 Value::Map(map) => {
163 let err = || {
164 AccessError::TypeError {
165 expected: type_expected.to_string(),
166 got: format!("{:?}", kind),
167 value: value.to_string(), }
169 };
170
171 let DataType::Map(map_type) = type_expected else {
172 return Err(err());
173 };
174 if !field_desc.is_map() {
175 return Err(err());
176 }
177 let map_desc = kind.as_message().ok_or_else(err)?;
178
179 let mut key_builder = map_type.key().create_array_builder(map.len());
180 let mut value_builder = map_type.value().create_array_builder(map.len());
181 for (key, value) in map.iter().sorted_by_key(|(k, _v)| *k) {
186 key_builder.append(from_protobuf_value(
187 &map_desc.map_entry_key_field(),
188 &key.clone().into(),
189 map_type.key(),
190 messages_as_jsonb,
191 )?);
192 value_builder.append(from_protobuf_value(
193 &map_desc.map_entry_value_field(),
194 value,
195 map_type.value(),
196 messages_as_jsonb,
197 )?);
198 }
199 let keys = key_builder.finish();
200 let values = value_builder.finish();
201 ScalarImpl::Map(
202 MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values))
203 .map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?,
204 )
205 }
206 };
207 Ok(Some(v).into())
208}
209
210fn protobuf_type_mapping(
212 field_descriptor: &FieldDescriptor,
213 parse_trace: &mut Vec<String>,
214 messages_as_jsonb: &HashSet<String>,
215) -> std::result::Result<DataType, ProtobufTypeError> {
216 detect_loop_and_push(parse_trace, field_descriptor)?;
217 let mut t = match field_descriptor.kind() {
218 Kind::Bool => DataType::Boolean,
219 Kind::Double => DataType::Float64,
220 Kind::Float => DataType::Float32,
221 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
222 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
224 DataType::Int64
225 }
226 Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
227 Kind::String => DataType::Varchar,
228 Kind::Message(m) => {
229 if messages_as_jsonb.contains(m.full_name()) {
230 DataType::Jsonb
232 } else if m.is_map_entry() {
233 debug_assert!(field_descriptor.is_map());
235 let key = protobuf_type_mapping(
236 &m.map_entry_key_field(),
237 parse_trace,
238 messages_as_jsonb,
239 )?;
240 let value = protobuf_type_mapping(
241 &m.map_entry_value_field(),
242 parse_trace,
243 messages_as_jsonb,
244 )?;
245 _ = parse_trace.pop();
246 return Ok(DataType::Map(MapType::from_kv(key, value)));
247 } else {
248 let fields = m
249 .fields()
250 .map(|f| {
251 Ok((
252 f.name().to_owned(),
253 protobuf_type_mapping(&f, parse_trace, messages_as_jsonb)?,
254 ))
255 })
256 .try_collect::<_, Vec<_>, _>()?;
257 StructType::new(fields).into()
258 }
259 }
260 Kind::Enum(_) => DataType::Varchar,
261 Kind::Bytes => DataType::Bytea,
262 };
263 if field_descriptor.cardinality() == Cardinality::Repeated {
264 debug_assert!(!field_descriptor.is_map());
265 t = DataType::List(Box::new(t))
266 }
267 _ = parse_trace.pop();
268 Ok(t)
269}