risingwave_connector/parser/protobuf/
parser.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::collections::HashSet;
16
17use anyhow::Context;
18use prost_reflect::{DescriptorPool, DynamicMessage, FileDescriptor, MessageDescriptor};
19use risingwave_common::catalog::Field;
20use risingwave_common::{bail, try_match_expand};
21use risingwave_connector_codec::decoder::protobuf::ProtobufAccess;
22pub use risingwave_connector_codec::decoder::protobuf::parser::{PROTOBUF_MESSAGES_AS_JSONB, *};
23
24use crate::error::ConnectorResult;
25use crate::parser::unified::AccessImpl;
26use crate::parser::utils::bytes_from_url;
27use crate::parser::{AccessBuilder, EncodingProperties, SchemaLocation};
28use crate::schema::schema_registry::{Client, WireFormatError, extract_schema_id, handle_sr_list};
29use crate::schema::{
30    ConfluentSchemaLoader, InvalidOptionError, SchemaLoader, bail_invalid_option_error,
31};
32
33#[derive(Debug)]
34pub struct ProtobufAccessBuilder {
35    wire_type: WireType,
36    message_descriptor: MessageDescriptor,
37
38    // A HashSet containing protobuf message type full names (e.g. "google.protobuf.Any")
39    // that should be mapped to JSONB type when storing in RisingWave
40    messages_as_jsonb: HashSet<String>,
41}
42
43impl AccessBuilder for ProtobufAccessBuilder {
44    #[allow(clippy::unused_async)]
45    async fn generate_accessor(
46        &mut self,
47        payload: Vec<u8>,
48        _: &crate::source::SourceMeta,
49    ) -> ConnectorResult<AccessImpl<'_>> {
50        let payload = match self.wire_type {
51            WireType::Confluent => resolve_pb_header(&payload)?,
52            WireType::None => &payload,
53        };
54
55        let message = DynamicMessage::decode(self.message_descriptor.clone(), payload)
56            .context("failed to parse message")?;
57
58        Ok(AccessImpl::Protobuf(ProtobufAccess::new(
59            message,
60            &self.messages_as_jsonb,
61        )))
62    }
63}
64
65impl ProtobufAccessBuilder {
66    pub fn new(config: ProtobufParserConfig) -> ConnectorResult<Self> {
67        let ProtobufParserConfig {
68            wire_type,
69            message_descriptor,
70            messages_as_jsonb,
71        } = config;
72
73        Ok(Self {
74            wire_type,
75            message_descriptor,
76            messages_as_jsonb,
77        })
78    }
79}
80
/// Framing that wraps the protobuf-encoded message in each incoming payload.
#[derive(Debug, Clone)]
enum WireType {
    /// No framing: the payload is the protobuf-encoded message itself.
    None,
    /// Confluent Schema Registry framing: a magic byte, the schema id and the
    /// message indexes precede the protobuf-encoded message.
    Confluent,
    // NOTE: AWS Glue and Pulsar framings are not supported yet.
}
88
89impl TryFrom<&SchemaLocation> for WireType {
90    type Error = InvalidOptionError;
91
92    fn try_from(value: &SchemaLocation) -> Result<Self, Self::Error> {
93        match value {
94            SchemaLocation::File { .. } => Ok(Self::None),
95            SchemaLocation::Confluent { .. } => Ok(Self::Confluent),
96            SchemaLocation::Glue { .. } => bail_invalid_option_error!(
97                "encode protobuf from aws glue schema registry not supported yet"
98            ),
99        }
100    }
101}
102
103#[derive(Debug, Clone)]
104pub struct ProtobufParserConfig {
105    wire_type: WireType,
106    pub(crate) message_descriptor: MessageDescriptor,
107    messages_as_jsonb: HashSet<String>,
108}
109
110impl ProtobufParserConfig {
111    pub async fn new(encoding_properties: EncodingProperties) -> ConnectorResult<Self> {
112        let protobuf_config = try_match_expand!(encoding_properties, EncodingProperties::Protobuf)?;
113        let message_name = &protobuf_config.message_name;
114
115        let wire_type = (&protobuf_config.schema_location).try_into()?;
116        if protobuf_config.key_message_name.is_some() {
117            // https://docs.confluent.io/platform/7.5/control-center/topics/schema.html#c3-schemas-best-practices-key-value-pairs
118            bail!("protobuf key is not supported");
119        }
120        let pool = match protobuf_config.schema_location {
121            SchemaLocation::Confluent {
122                urls,
123                client_config,
124                name_strategy,
125                topic,
126            } => {
127                let url = handle_sr_list(urls.as_str())?;
128                let client = Client::new(url, &client_config)?;
129                let loader = SchemaLoader::Confluent(ConfluentSchemaLoader {
130                    client,
131                    name_strategy,
132                    topic,
133                    key_record_name: None,
134                    val_record_name: Some(message_name.clone()),
135                });
136                let (_schema_id, root_file_descriptor) = loader
137                    .load_val_schema::<FileDescriptor>()
138                    .await
139                    .context("load schema failed")?;
140                root_file_descriptor.parent_pool().clone()
141            }
142            SchemaLocation::File {
143                url,
144                aws_auth_props,
145            } => {
146                let url = handle_sr_list(url.as_str())?;
147                let url = url.first().unwrap();
148                let schema_bytes = bytes_from_url(url, aws_auth_props.as_ref()).await?;
149                DescriptorPool::decode(schema_bytes.as_slice())
150                    .with_context(|| format!("cannot build descriptor pool from schema `{url}`"))?
151            }
152            SchemaLocation::Glue { .. } => bail_invalid_option_error!(
153                "encode protobuf from aws glue schema registry not supported yet"
154            ),
155        };
156
157        let message_descriptor = pool
158            .get_message_by_name(message_name)
159            .with_context(|| format!("cannot find message `{message_name}` in schema"))?;
160
161        Ok(Self {
162            message_descriptor,
163            wire_type,
164            messages_as_jsonb: protobuf_config.messages_as_jsonb,
165        })
166    }
167
168    /// Maps the protobuf schema to relational schema.
169    pub fn map_to_columns(&self) -> ConnectorResult<Vec<Field>> {
170        pb_schema_to_fields(&self.message_descriptor, &self.messages_as_jsonb).map_err(|e| e.into())
171    }
172}
173
174/// A port from the implementation of confluent's Varint Zig-zag deserialization.
175/// See `ReadVarint` in <https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java>
176fn decode_varint_zigzag(buffer: &[u8]) -> ConnectorResult<(i32, usize)> {
177    // We expect the decoded number to be 4 bytes.
178    let mut value = 0u32;
179    let mut shift = 0;
180    let mut len = 0usize;
181
182    for &byte in buffer {
183        len += 1;
184        // The Varint encoding is limited to 5 bytes.
185        if len > 5 {
186            break;
187        }
188        // The byte is cast to u32 to avoid shifting overflow.
189        let byte_ext = byte as u32;
190        // In Varint encoding, the lowest 7 bits are used to represent number,
191        // while the highest zero bit indicates the end of the number with Varint encoding.
192        value |= (byte_ext & 0x7F) << shift;
193        if byte_ext & 0x80 == 0 {
194            return Ok((((value >> 1) as i32) ^ -((value & 1) as i32), len));
195        }
196
197        shift += 7;
198    }
199
200    Err(WireFormatError::ParseMessageIndexes.into())
201}
202
203/// Reference: <https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format>
204/// Wire format for Confluent pb header is:
205/// | 0          | 1-4        | 5-x             | x+1-end
206/// | magic-byte | schema-id  | message-indexes | protobuf-payload
207pub(crate) fn resolve_pb_header(payload: &[u8]) -> ConnectorResult<&[u8]> {
208    // there's a message index array at the front of payload
209    // if it is the first message in proto def, the array is just and `0`
210    let (_, remained) = extract_schema_id(payload)?;
211    // The message indexes are encoded as int using variable-length zig-zag encoding,
212    // prefixed by the length of the array.
213    // Note that if the first byte is 0, it is equivalent to (1, 0) as an optimization.
214    match remained.first() {
215        Some(0) => Ok(&remained[1..]),
216        Some(_) => {
217            let (index_len, mut offset) = decode_varint_zigzag(remained)?;
218            for _ in 0..index_len {
219                offset += decode_varint_zigzag(&remained[offset..])?.1;
220            }
221            Ok(&remained[offset..])
222        }
223        None => bail!("The proto payload is empty"),
224    }
225}
226
#[cfg(test)]
mod test {
    use super::*;

    /// Confluent header for schema id 1: magic byte `0` plus a big-endian `i32`.
    const HEADER: [u8; 5] = [0x00, 0x00, 0x00, 0x00, 0x01];

    /// Prepends [`HEADER`] to `rest`.
    fn framed(rest: &[u8]) -> Vec<u8> {
        let mut payload = HEADER.to_vec();
        payload.extend_from_slice(rest);
        payload
    }

    #[test]
    fn test_decode_varint_zigzag() {
        // (encoded bytes, expected value, expected number of bytes consumed)
        let cases: &[(&[u8], i32, usize)] = &[
            // 1. Positive number
            (&[0x02], 1, 1),
            // 2. Negative number
            (&[0x01], -1, 1),
            // 3. Larger positive number
            (&[0x9E, 0x03], 207, 2),
            // 4. Larger negative number
            (&[0xBF, 0x07], -480, 2),
            // 5. Maximum positive number
            (&[0xFE, 0xFF, 0xFF, 0xFF, 0x0F], i32::MAX, 5),
            // 6. Maximum negative number
            (&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F], i32::MIN, 5),
            // 7. More than 32 bits: the excess high bits are dropped
            (&[0xFF, 0xFF, 0xFF, 0xFF, 0x7F], i32::MIN, 5),
            // 8. Bytes after the varint are not consumed
            (&[0x02, 0x08], 1, 1),
        ];
        for &(buffer, expected_value, expected_len) in cases {
            let (value, len) = decode_varint_zigzag(buffer).unwrap();
            assert_eq!(value, expected_value, "value for {buffer:02X?}");
            assert_eq!(len, expected_len, "length for {buffer:02X?}");
        }

        // Invalid input: more than 5 bytes
        assert!(decode_varint_zigzag(&[0xFF; 6]).is_err());
        // Invalid input: continuation bit set on the last available byte
        assert!(decode_varint_zigzag(&[0x80]).is_err());
        // Invalid input: empty buffer
        assert!(decode_varint_zigzag(&[]).is_err());
    }

    #[test]
    fn test_resolve_pb_header() {
        let message: &[u8] = &[0x08, 0x96, 0x01];
        let index_prefixes: &[&[u8]] = &[
            // `0` alone is shorthand for the index array `[0]`.
            &[0x00],
            // Array of length 1 holding index 1.
            &[0x02, 0x02],
            // Array of length 2 holding indexes [1, 0].
            &[0x04, 0x02, 0x00],
            // Array of length 1 holding index 64, encoded as a two-byte varint.
            &[0x02, 0x80, 0x01],
        ];
        for prefix in index_prefixes {
            let payload = framed(&[*prefix, message].concat());
            assert_eq!(
                resolve_pb_header(&payload).unwrap(),
                message,
                "message indexes {prefix:02X?}"
            );
        }

        // Nothing follows the schema id.
        assert!(resolve_pb_header(&HEADER).is_err());
        // The index array declares two entries but only one is present.
        assert!(resolve_pb_header(&framed(&[0x04, 0x02])).is_err());
        // Missing magic byte.
        assert!(resolve_pb_header(&[0x01, 0x00, 0x00, 0x00, 0x01, 0x00]).is_err());
    }
}