risingwave_connector/parser/protobuf/
parser.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use anyhow::Context;
18use prost_reflect::{DescriptorPool, DynamicMessage, FileDescriptor, MessageDescriptor};
19use risingwave_common::catalog::Field;
20use risingwave_common::{bail, try_match_expand};
21use risingwave_connector_codec::decoder::protobuf::ProtobufAccess;
22pub use risingwave_connector_codec::decoder::protobuf::parser::{PROTOBUF_MESSAGES_AS_JSONB, *};
23
24use crate::error::ConnectorResult;
25use crate::parser::unified::AccessImpl;
26use crate::parser::utils::bytes_from_url;
27use crate::parser::{AccessBuilder, EncodingProperties, SchemaLocation};
28use crate::schema::schema_registry::{Client, WireFormatError, extract_schema_id, handle_sr_list};
29use crate::schema::{
30    ConfluentSchemaLoader, InvalidOptionError, SchemaLoader, bail_invalid_option_error,
31};
32
/// Builds [`AccessImpl::Protobuf`] accessors by decoding each raw payload as the
/// configured protobuf message type.
#[derive(Debug)]
pub struct ProtobufAccessBuilder {
    /// Framing of the payload on the wire; decides whether a Confluent header is
    /// stripped before decoding.
    wire_type: WireType,
    /// Descriptor of the message type that every payload is decoded as.
    message_descriptor: MessageDescriptor,

    /// A set of protobuf message type full names (e.g. "google.protobuf.Any")
    /// that should be mapped to JSONB type when storing in RisingWave.
    messages_as_jsonb: HashSet<String>,
}
42
43impl AccessBuilder for ProtobufAccessBuilder {
44    async fn generate_accessor(
45        &mut self,
46        payload: Vec<u8>,
47        _: &crate::source::SourceMeta,
48    ) -> ConnectorResult<AccessImpl<'_>> {
49        let payload = match self.wire_type {
50            WireType::Confluent => resolve_pb_header(&payload)?,
51            WireType::None => &payload,
52        };
53
54        let message = DynamicMessage::decode(self.message_descriptor.clone(), payload)
55            .context("failed to parse message")?;
56
57        Ok(AccessImpl::Protobuf(ProtobufAccess::new(
58            message,
59            &self.messages_as_jsonb,
60        )))
61    }
62}
63
64impl ProtobufAccessBuilder {
65    pub fn new(config: ProtobufParserConfig) -> ConnectorResult<Self> {
66        let ProtobufParserConfig {
67            wire_type,
68            message_descriptor,
69            messages_as_jsonb,
70        } = config;
71
72        Ok(Self {
73            wire_type,
74            message_descriptor,
75            messages_as_jsonb,
76        })
77    }
78}
79
/// Framing that wraps a protobuf payload on the wire.
#[derive(Debug, Clone)]
enum WireType {
    /// The payload is the bare protobuf encoding of the message.
    None,
    /// The payload is prefixed with the Confluent Schema Registry header
    /// (see `resolve_pb_header`), which must be stripped before decoding.
    Confluent,
    // Other framings (e.g. AWS Glue, Pulsar) are not modelled yet.
}
87
88impl TryFrom<&SchemaLocation> for WireType {
89    type Error = InvalidOptionError;
90
91    fn try_from(value: &SchemaLocation) -> Result<Self, Self::Error> {
92        match value {
93            SchemaLocation::File { .. } => Ok(Self::None),
94            SchemaLocation::Confluent { .. } => Ok(Self::Confluent),
95            SchemaLocation::Glue { .. } => bail_invalid_option_error!(
96                "encode protobuf from aws glue schema registry not supported yet"
97            ),
98        }
99    }
100}
101
/// Resolved configuration for parsing protobuf-encoded source payloads.
///
/// Built by [`ProtobufParserConfig::new`] from [`EncodingProperties`] and consumed by
/// [`ProtobufAccessBuilder::new`].
#[derive(Debug, Clone)]
pub struct ProtobufParserConfig {
    /// Framing expected on each payload, derived from the schema location.
    wire_type: WireType,
    /// Descriptor of the configured message, looked up by full name in the loaded schema.
    pub(crate) message_descriptor: MessageDescriptor,
    /// Full names of protobuf message types that should be mapped to JSONB.
    messages_as_jsonb: HashSet<String>,
}
108
109impl ProtobufParserConfig {
110    pub async fn new(encoding_properties: EncodingProperties) -> ConnectorResult<Self> {
111        let protobuf_config = try_match_expand!(encoding_properties, EncodingProperties::Protobuf)?;
112        let message_name = &protobuf_config.message_name;
113
114        let wire_type = (&protobuf_config.schema_location).try_into()?;
115        if protobuf_config.key_message_name.is_some() {
116            // https://docs.confluent.io/platform/7.5/control-center/topics/schema.html#c3-schemas-best-practices-key-value-pairs
117            bail!("protobuf key is not supported");
118        }
119        let pool = match protobuf_config.schema_location {
120            SchemaLocation::Confluent {
121                urls,
122                client_config,
123                name_strategy,
124                topic,
125            } => {
126                let url = handle_sr_list(urls.as_str())?;
127                let client = Client::new(url, &client_config)?;
128                let loader = SchemaLoader::Confluent(ConfluentSchemaLoader {
129                    client,
130                    name_strategy,
131                    topic,
132                    key_record_name: None,
133                    val_record_name: Some(message_name.clone()),
134                });
135                let (_schema_id, root_file_descriptor) = loader
136                    .load_val_schema::<FileDescriptor>()
137                    .await
138                    .context("load schema failed")?;
139                root_file_descriptor.parent_pool().clone()
140            }
141            SchemaLocation::File {
142                url,
143                aws_auth_props,
144            } => {
145                let url = handle_sr_list(url.as_str())?;
146                let url = url.first().unwrap();
147                let schema_bytes = bytes_from_url(url, aws_auth_props.as_ref()).await?;
148                DescriptorPool::decode(schema_bytes.as_slice())
149                    .with_context(|| format!("cannot build descriptor pool from schema `{url}`"))?
150            }
151            SchemaLocation::Glue { .. } => bail_invalid_option_error!(
152                "encode protobuf from aws glue schema registry not supported yet"
153            ),
154        };
155
156        let message_descriptor = pool
157            .get_message_by_name(message_name)
158            .with_context(|| format!("cannot find message `{message_name}` in schema"))?;
159
160        Ok(Self {
161            message_descriptor,
162            wire_type,
163            messages_as_jsonb: protobuf_config.messages_as_jsonb,
164        })
165    }
166
167    /// Maps the protobuf schema to relational schema.
168    pub fn map_to_columns(&self) -> ConnectorResult<Vec<Field>> {
169        pb_schema_to_fields(&self.message_descriptor, &self.messages_as_jsonb).map_err(|e| e.into())
170    }
171}
172
173/// A port from the implementation of confluent's Varint Zig-zag deserialization.
174/// See `ReadVarint` in <https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java>
175fn decode_varint_zigzag(buffer: &[u8]) -> ConnectorResult<(i32, usize)> {
176    // We expect the decoded number to be 4 bytes.
177    let mut value = 0u32;
178    let mut shift = 0;
179    let mut len = 0usize;
180
181    for &byte in buffer {
182        len += 1;
183        // The Varint encoding is limited to 5 bytes.
184        if len > 5 {
185            break;
186        }
187        // The byte is cast to u32 to avoid shifting overflow.
188        let byte_ext = byte as u32;
189        // In Varint encoding, the lowest 7 bits are used to represent number,
190        // while the highest zero bit indicates the end of the number with Varint encoding.
191        value |= (byte_ext & 0x7F) << shift;
192        if byte_ext & 0x80 == 0 {
193            return Ok((((value >> 1) as i32) ^ -((value & 1) as i32), len));
194        }
195
196        shift += 7;
197    }
198
199    Err(WireFormatError::ParseMessageIndexes.into())
200}
201
202/// Reference: <https://docs.confluent.io/platform/current/schema-registry/fundamentals/serdes-develop/index.html#wire-format>
203/// Wire format for Confluent pb header is:
204/// | 0          | 1-4        | 5-x             | x+1-end
205/// | magic-byte | schema-id  | message-indexes | protobuf-payload
206pub(crate) fn resolve_pb_header(payload: &[u8]) -> ConnectorResult<&[u8]> {
207    // there's a message index array at the front of payload
208    // if it is the first message in proto def, the array is just and `0`
209    let (_, remained) = extract_schema_id(payload)?;
210    // The message indexes are encoded as int using variable-length zig-zag encoding,
211    // prefixed by the length of the array.
212    // Note that if the first byte is 0, it is equivalent to (1, 0) as an optimization.
213    match remained.first() {
214        Some(0) => Ok(&remained[1..]),
215        Some(_) => {
216            let (index_len, mut offset) = decode_varint_zigzag(remained)?;
217            for _ in 0..index_len {
218                offset += decode_varint_zigzag(&remained[offset..])?.1;
219            }
220            Ok(&remained[offset..])
221        }
222        None => bail!("The proto payload is empty"),
223    }
224}
225
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_decode_varint_zigzag() {
        // Each case: (encoded bytes, expected value, expected number of bytes consumed).
        let valid_cases: Vec<(Vec<u8>, i32, usize)> = vec![
            // 1. Positive number
            (vec![0x02], 1, 1),
            // 2. Negative number
            (vec![0x01], -1, 1),
            // 3. Larger positive number
            (vec![0x9E, 0x03], 207, 2),
            // 4. Larger negative number
            (vec![0xBF, 0x07], -480, 2),
            // 5. Maximum positive number
            (vec![0xFE, 0xFF, 0xFF, 0xFF, 0x0F], i32::MAX, 5),
            // 6. Maximum negative number
            (vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F], i32::MIN, 5),
            // 7. More than 32 bits: the excess high bits are discarded
            (vec![0xFF, 0xFF, 0xFF, 0xFF, 0x7F], i32::MIN, 5),
        ];

        for (buffer, expected_value, expected_len) in valid_cases {
            let (value, len) = decode_varint_zigzag(&buffer).unwrap();
            assert_eq!(value, expected_value);
            assert_eq!(len, expected_len);
        }

        // 8. Invalid input (more than 5 bytes)
        let too_long = [0xFF_u8; 6];
        assert!(decode_varint_zigzag(&too_long).is_err());
    }
}