risingwave_connector_codec/decoder/protobuf/
parser.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::collections::HashSet;
17
18use anyhow::Context;
19use itertools::Itertools;
20use prost_reflect::{
21    Cardinality, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
22};
23use risingwave_common::array::{ListValue, StructValue};
24use risingwave_common::catalog::Field;
25use risingwave_common::types::{
26    DataType, DatumCow, Decimal, F32, F64, JsonbVal, MapType, MapValue, ScalarImpl, StructType,
27    ToOwnedDatum,
28};
29use thiserror::Error;
30use thiserror_ext::Macro;
31
32use crate::decoder::{AccessError, AccessResult, uncategorized};
33
/// Name of the `messages_as_jsonb` setting: the set of protobuf message full names that are
/// mapped to `JSONB` instead of being expanded into structs. It is quoted in the
/// circular-reference error of this module as a remediation hint.
/// NOTE(review): presumably this is also the user-facing connector option key — confirm.
pub const PROTOBUF_MESSAGES_AS_JSONB: &str = "messages_as_jsonb";
35
36pub fn pb_schema_to_fields(
37    message_descriptor: &MessageDescriptor,
38    messages_as_jsonb: &HashSet<String>,
39) -> anyhow::Result<Vec<Field>> {
40    let mut parse_trace: Vec<String> = vec![];
41    message_descriptor
42        .fields()
43        .map(|field| {
44            let field_type = protobuf_type_mapping(&field, &mut parse_trace, messages_as_jsonb)
45                .context("failed to map protobuf type")?;
46            let column = Field::new(field.name(), field_type);
47            Ok(column)
48        })
49        .collect()
50}
51
52#[derive(Error, Debug, Macro)]
53#[error("{0}")]
54struct ProtobufTypeError(#[message] String);
55
56fn detect_loop_and_push(
57    trace: &mut Vec<String>,
58    fd: &FieldDescriptor,
59) -> std::result::Result<(), ProtobufTypeError> {
60    let identifier = format!("{}({})", fd.name(), fd.full_name());
61    if trace.iter().any(|s| s == identifier.as_str()) {
62        bail_protobuf_type_error!(
63            "circular reference detected: {}, conflict with {}, kind {:?}. Adding {:?} to {:?} may help.",
64            trace.iter().format("->"),
65            identifier,
66            fd.kind(),
67            fd.kind(),
68            PROTOBUF_MESSAGES_AS_JSONB,
69        );
70    }
71    trace.push(identifier);
72    Ok(())
73}
74
75/// Converts a protobuf message field to a datum.
76///
77/// We will get the protobuf value from the message by checking the field descriptor and correctly
78/// handling presence, then call [`from_protobuf_value`].
79pub fn from_protobuf_message_field<'a>(
80    field_desc: &FieldDescriptor,
81    message: &'a DynamicMessage,
82    type_expected: &DataType,
83    messages_as_jsonb: &'a HashSet<String>,
84) -> AccessResult<DatumCow<'a>> {
85    let value = if field_desc.supports_presence() && !message.has_field(field_desc) {
86        // The field supports presence and it's absent in the message. Treat it as NULL.
87        // This is the case for `optional` fields, message fields, and fields contained in `oneof`.
88        return Ok(DatumCow::NULL);
89    } else {
90        // Otherwise, directly call `get_field`, which will return the default value if absent.
91        message.get_field(field_desc)
92    };
93
94    match value {
95        Cow::Borrowed(value) => {
96            from_protobuf_value(field_desc, value, type_expected, messages_as_jsonb)
97        }
98        Cow::Owned(value) => {
99            from_protobuf_value(field_desc, &value, type_expected, messages_as_jsonb)
100                .map(|d| d.to_owned_datum().into())
101        }
102    }
103}
104
105/// Converts a protobuf value to a datum.
106fn from_protobuf_value<'a>(
107    field_desc: &FieldDescriptor,
108    value: &'a Value,
109    type_expected: &DataType,
110    messages_as_jsonb: &'a HashSet<String>,
111) -> AccessResult<DatumCow<'a>> {
112    let kind = field_desc.kind();
113
114    macro_rules! borrowed {
115        ($v:expr) => {
116            return Ok(DatumCow::Borrowed(Some($v.into())))
117        };
118    }
119
120    let v: ScalarImpl = match value {
121        Value::Bool(v) => ScalarImpl::Bool(*v),
122        Value::I32(i) => ScalarImpl::Int32(*i),
123        Value::U32(i) => ScalarImpl::Int64(*i as i64),
124        Value::I64(i) => ScalarImpl::Int64(*i),
125        Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)),
126        Value::F32(f) => ScalarImpl::Float32(F32::from(*f)),
127        Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
128        Value::String(s) => borrowed!(s.as_str()),
129        Value::EnumNumber(idx) => {
130            let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
131                expected: "enum".to_owned(),
132                got: format!("{kind:?}"),
133                value: value.to_string(),
134            })?;
135            let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| {
136                uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc)
137            })?;
138            ScalarImpl::Utf8(enum_symbol.name().into())
139        }
140        Value::Message(dyn_msg) => {
141            if messages_as_jsonb.contains(dyn_msg.descriptor().full_name()) {
142                ScalarImpl::Jsonb(JsonbVal::from(
143                    serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
144                ))
145            } else {
146                let desc = dyn_msg.descriptor();
147                let DataType::Struct(st) = type_expected else {
148                    return Err(AccessError::TypeError {
149                        expected: type_expected.to_string(),
150                        got: desc.full_name().to_owned(),
151                        value: value.to_string(), // Protobuf TEXT
152                    });
153                };
154
155                let mut datums = Vec::with_capacity(st.len());
156                for (name, expected_field_type) in st.iter() {
157                    let Some(field_desc) = desc.get_field_by_name(name) else {
158                        // Field deleted in protobuf. Fallback to SQL NULL (of proper RW type).
159                        datums.push(None);
160                        continue;
161                    };
162                    let datum = from_protobuf_message_field(
163                        &field_desc,
164                        dyn_msg,
165                        expected_field_type,
166                        messages_as_jsonb,
167                    )?;
168                    datums.push(datum.to_owned_datum());
169                }
170                ScalarImpl::Struct(StructValue::new(datums))
171            }
172        }
173        Value::List(values) => {
174            let DataType::List(element_type) = type_expected else {
175                return Err(AccessError::TypeError {
176                    expected: type_expected.to_string(),
177                    got: format!("repeated {:?}", kind),
178                    value: value.to_string(), // Protobuf TEXT
179                });
180            };
181            let mut builder = element_type.create_array_builder(values.len());
182            for value in values {
183                builder.append(from_protobuf_value(
184                    field_desc,
185                    value,
186                    element_type,
187                    messages_as_jsonb,
188                )?);
189            }
190            ScalarImpl::List(ListValue::new(builder.finish()))
191        }
192        Value::Bytes(value) => borrowed!(&**value),
193        Value::Map(map) => {
194            let err = || {
195                AccessError::TypeError {
196                    expected: type_expected.to_string(),
197                    got: format!("{:?}", kind),
198                    value: value.to_string(), // Protobuf TEXT
199                }
200            };
201
202            let DataType::Map(map_type) = type_expected else {
203                return Err(err());
204            };
205            if !field_desc.is_map() {
206                return Err(err());
207            }
208            let map_desc = kind.as_message().ok_or_else(err)?;
209
210            let mut key_builder = map_type.key().create_array_builder(map.len());
211            let mut value_builder = map_type.value().create_array_builder(map.len());
212            // NOTE: HashMap's iter order is non-deterministic, but MapValue's
213            // order matters. We sort by key here to have deterministic order
214            // in tests. We might consider removing this, or make all MapValue sorted
215            // in the future.
216            for (key, value) in map.iter().sorted_by_key(|(k, _v)| *k) {
217                key_builder.append(from_protobuf_value(
218                    &map_desc.map_entry_key_field(),
219                    &key.clone().into(),
220                    map_type.key(),
221                    messages_as_jsonb,
222                )?);
223                value_builder.append(from_protobuf_value(
224                    &map_desc.map_entry_value_field(),
225                    value,
226                    map_type.value(),
227                    messages_as_jsonb,
228                )?);
229            }
230            let keys = key_builder.finish();
231            let values = value_builder.finish();
232            ScalarImpl::Map(
233                MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values))
234                    .map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?,
235            )
236        }
237    };
238    Ok(Some(v).into())
239}
240
241/// Maps protobuf type to RW type.
242fn protobuf_type_mapping(
243    field_descriptor: &FieldDescriptor,
244    parse_trace: &mut Vec<String>,
245    messages_as_jsonb: &HashSet<String>,
246) -> std::result::Result<DataType, ProtobufTypeError> {
247    detect_loop_and_push(parse_trace, field_descriptor)?;
248    let mut t = match field_descriptor.kind() {
249        Kind::Bool => DataType::Boolean,
250        Kind::Double => DataType::Float64,
251        Kind::Float => DataType::Float32,
252        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
253        // Fixed32 represents [0, 2^32 - 1]. It's equal to u32.
254        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
255            DataType::Int64
256        }
257        Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
258        Kind::String => DataType::Varchar,
259        Kind::Message(m) => {
260            if messages_as_jsonb.contains(m.full_name()) {
261                // Well-Known Types are identified by their full name
262                DataType::Jsonb
263            } else if m.is_map_entry() {
264                // Map is equivalent to `repeated MapFieldEntry map_field = N;`
265                debug_assert!(field_descriptor.is_map());
266                let key = protobuf_type_mapping(
267                    &m.map_entry_key_field(),
268                    parse_trace,
269                    messages_as_jsonb,
270                )?;
271                let value = protobuf_type_mapping(
272                    &m.map_entry_value_field(),
273                    parse_trace,
274                    messages_as_jsonb,
275                )?;
276                _ = parse_trace.pop();
277                return Ok(DataType::Map(MapType::from_kv(key, value)));
278            } else {
279                let fields = m
280                    .fields()
281                    .map(|f| {
282                        Ok((
283                            f.name().to_owned(),
284                            protobuf_type_mapping(&f, parse_trace, messages_as_jsonb)?,
285                        ))
286                    })
287                    .try_collect::<_, Vec<_>, _>()?;
288                StructType::new(fields).into()
289            }
290        }
291        Kind::Enum(_) => DataType::Varchar,
292        Kind::Bytes => DataType::Bytea,
293    };
294    if field_descriptor.cardinality() == Cardinality::Repeated {
295        debug_assert!(!field_descriptor.is_map());
296        t = DataType::List(Box::new(t))
297    }
298    _ = parse_trace.pop();
299    Ok(t)
300}