risingwave_connector_codec/decoder/protobuf/
parser.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use anyhow::Context;
18use itertools::Itertools;
19use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value};
20use risingwave_common::array::{ListValue, StructValue};
21use risingwave_common::catalog::Field;
22use risingwave_common::types::{
23    DataType, DatumCow, Decimal, F32, F64, JsonbVal, MapType, MapValue, ScalarImpl, StructType,
24    ToOwnedDatum,
25};
26use thiserror::Error;
27use thiserror_ext::Macro;
28
29use crate::decoder::{AccessError, AccessResult, uncategorized};
30
/// The literal key `"messages_as_jsonb"`.
///
/// The circular-reference error raised while mapping protobuf types suggests
/// adding the offending message kind to this key, so it presumably names the
/// user-facing option listing messages that are decoded as `jsonb` rather
/// than expanded into structs -- confirm against the option parser.
pub const PROTOBUF_MESSAGES_AS_JSONB: &str = "messages_as_jsonb";
32
33pub fn pb_schema_to_fields(
34    message_descriptor: &MessageDescriptor,
35    messages_as_jsonb: &HashSet<String>,
36) -> anyhow::Result<Vec<Field>> {
37    let mut parse_trace: Vec<String> = vec![];
38    message_descriptor
39        .fields()
40        .map(|field| {
41            let field_type = protobuf_type_mapping(&field, &mut parse_trace, messages_as_jsonb)
42                .context("failed to map protobuf type")?;
43            let column = Field::new(field.name(), field_type);
44            Ok(column)
45        })
46        .collect()
47}
48
49#[derive(Error, Debug, Macro)]
50#[error("{0}")]
51struct ProtobufTypeError(#[message] String);
52
53fn detect_loop_and_push(
54    trace: &mut Vec<String>,
55    fd: &FieldDescriptor,
56) -> std::result::Result<(), ProtobufTypeError> {
57    let identifier = format!("{}({})", fd.name(), fd.full_name());
58    if trace.iter().any(|s| s == identifier.as_str()) {
59        bail_protobuf_type_error!(
60            "circular reference detected: {}, conflict with {}, kind {:?}. Adding {:?} to {:?} may help.",
61            trace.iter().format("->"),
62            identifier,
63            fd.kind(),
64            fd.kind(),
65            PROTOBUF_MESSAGES_AS_JSONB,
66        );
67    }
68    trace.push(identifier);
69    Ok(())
70}
71
72pub fn from_protobuf_value<'a>(
73    field_desc: &FieldDescriptor,
74    value: &'a Value,
75    type_expected: &DataType,
76    messages_as_jsonb: &'a HashSet<String>,
77) -> AccessResult<DatumCow<'a>> {
78    let kind = field_desc.kind();
79
80    macro_rules! borrowed {
81        ($v:expr) => {
82            return Ok(DatumCow::Borrowed(Some($v.into())))
83        };
84    }
85
86    let v: ScalarImpl = match value {
87        Value::Bool(v) => ScalarImpl::Bool(*v),
88        Value::I32(i) => ScalarImpl::Int32(*i),
89        Value::U32(i) => ScalarImpl::Int64(*i as i64),
90        Value::I64(i) => ScalarImpl::Int64(*i),
91        Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)),
92        Value::F32(f) => ScalarImpl::Float32(F32::from(*f)),
93        Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
94        Value::String(s) => borrowed!(s.as_str()),
95        Value::EnumNumber(idx) => {
96            let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
97                expected: "enum".to_owned(),
98                got: format!("{kind:?}"),
99                value: value.to_string(),
100            })?;
101            let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| {
102                uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc)
103            })?;
104            ScalarImpl::Utf8(enum_symbol.name().into())
105        }
106        Value::Message(dyn_msg) => {
107            if messages_as_jsonb.contains(dyn_msg.descriptor().full_name()) {
108                ScalarImpl::Jsonb(JsonbVal::from(
109                    serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
110                ))
111            } else {
112                let desc = dyn_msg.descriptor();
113                let DataType::Struct(st) = type_expected else {
114                    return Err(AccessError::TypeError {
115                        expected: type_expected.to_string(),
116                        got: desc.full_name().to_owned(),
117                        value: value.to_string(), // Protobuf TEXT
118                    });
119                };
120
121                let mut rw_values = Vec::with_capacity(st.len());
122                for (name, expected_field_type) in st.iter() {
123                    let Some(field_desc) = desc.get_field_by_name(name) else {
124                        // Field deleted in protobuf. Fallback to SQL NULL (of proper RW type).
125                        rw_values.push(None);
126                        continue;
127                    };
128                    let value = dyn_msg.get_field(&field_desc);
129                    rw_values.push(
130                        from_protobuf_value(
131                            &field_desc,
132                            &value,
133                            expected_field_type,
134                            messages_as_jsonb,
135                        )?
136                        .to_owned_datum(),
137                    );
138                }
139                ScalarImpl::Struct(StructValue::new(rw_values))
140            }
141        }
142        Value::List(values) => {
143            let DataType::List(element_type) = type_expected else {
144                return Err(AccessError::TypeError {
145                    expected: type_expected.to_string(),
146                    got: format!("repeated {:?}", kind),
147                    value: value.to_string(), // Protobuf TEXT
148                });
149            };
150            let mut builder = element_type.create_array_builder(values.len());
151            for value in values {
152                builder.append(from_protobuf_value(
153                    field_desc,
154                    value,
155                    element_type,
156                    messages_as_jsonb,
157                )?);
158            }
159            ScalarImpl::List(ListValue::new(builder.finish()))
160        }
161        Value::Bytes(value) => borrowed!(&**value),
162        Value::Map(map) => {
163            let err = || {
164                AccessError::TypeError {
165                    expected: type_expected.to_string(),
166                    got: format!("{:?}", kind),
167                    value: value.to_string(), // Protobuf TEXT
168                }
169            };
170
171            let DataType::Map(map_type) = type_expected else {
172                return Err(err());
173            };
174            if !field_desc.is_map() {
175                return Err(err());
176            }
177            let map_desc = kind.as_message().ok_or_else(err)?;
178
179            let mut key_builder = map_type.key().create_array_builder(map.len());
180            let mut value_builder = map_type.value().create_array_builder(map.len());
181            // NOTE: HashMap's iter order is non-deterministic, but MapValue's
182            // order matters. We sort by key here to have deterministic order
183            // in tests. We might consider removing this, or make all MapValue sorted
184            // in the future.
185            for (key, value) in map.iter().sorted_by_key(|(k, _v)| *k) {
186                key_builder.append(from_protobuf_value(
187                    &map_desc.map_entry_key_field(),
188                    &key.clone().into(),
189                    map_type.key(),
190                    messages_as_jsonb,
191                )?);
192                value_builder.append(from_protobuf_value(
193                    &map_desc.map_entry_value_field(),
194                    value,
195                    map_type.value(),
196                    messages_as_jsonb,
197                )?);
198            }
199            let keys = key_builder.finish();
200            let values = value_builder.finish();
201            ScalarImpl::Map(
202                MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values))
203                    .map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?,
204            )
205        }
206    };
207    Ok(Some(v).into())
208}
209
210/// Maps protobuf type to RW type.
211fn protobuf_type_mapping(
212    field_descriptor: &FieldDescriptor,
213    parse_trace: &mut Vec<String>,
214    messages_as_jsonb: &HashSet<String>,
215) -> std::result::Result<DataType, ProtobufTypeError> {
216    detect_loop_and_push(parse_trace, field_descriptor)?;
217    let mut t = match field_descriptor.kind() {
218        Kind::Bool => DataType::Boolean,
219        Kind::Double => DataType::Float64,
220        Kind::Float => DataType::Float32,
221        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
222        // Fixed32 represents [0, 2^32 - 1]. It's equal to u32.
223        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
224            DataType::Int64
225        }
226        Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
227        Kind::String => DataType::Varchar,
228        Kind::Message(m) => {
229            if messages_as_jsonb.contains(m.full_name()) {
230                // Well-Known Types are identified by their full name
231                DataType::Jsonb
232            } else if m.is_map_entry() {
233                // Map is equivalent to `repeated MapFieldEntry map_field = N;`
234                debug_assert!(field_descriptor.is_map());
235                let key = protobuf_type_mapping(
236                    &m.map_entry_key_field(),
237                    parse_trace,
238                    messages_as_jsonb,
239                )?;
240                let value = protobuf_type_mapping(
241                    &m.map_entry_value_field(),
242                    parse_trace,
243                    messages_as_jsonb,
244                )?;
245                _ = parse_trace.pop();
246                return Ok(DataType::Map(MapType::from_kv(key, value)));
247            } else {
248                let fields = m
249                    .fields()
250                    .map(|f| {
251                        Ok((
252                            f.name().to_owned(),
253                            protobuf_type_mapping(&f, parse_trace, messages_as_jsonb)?,
254                        ))
255                    })
256                    .try_collect::<_, Vec<_>, _>()?;
257                StructType::new(fields).into()
258            }
259        }
260        Kind::Enum(_) => DataType::Varchar,
261        Kind::Bytes => DataType::Bytea,
262    };
263    if field_descriptor.cardinality() == Cardinality::Repeated {
264        debug_assert!(!field_descriptor.is_map());
265        t = DataType::List(Box::new(t))
266    }
267    _ = parse_trace.pop();
268    Ok(t)
269}