risingwave_connector/parser/protobuf/
parser.rs1use std::collections::HashSet;
16
17use anyhow::Context;
18use prost_reflect::{DescriptorPool, DynamicMessage, FileDescriptor, MessageDescriptor};
19use risingwave_common::catalog::Field;
20use risingwave_common::{bail, try_match_expand};
21use risingwave_connector_codec::decoder::protobuf::ProtobufAccess;
22pub use risingwave_connector_codec::decoder::protobuf::parser::{PROTOBUF_MESSAGES_AS_JSONB, *};
23
24use crate::error::ConnectorResult;
25use crate::parser::unified::AccessImpl;
26use crate::parser::utils::bytes_from_url;
27use crate::parser::{AccessBuilder, EncodingProperties, SchemaLocation};
28use crate::schema::schema_registry::{Client, WireFormatError, extract_schema_id, handle_sr_list};
29use crate::schema::{
30 ConfluentSchemaLoader, InvalidOptionError, SchemaLoader, bail_invalid_option_error,
31};
32
33#[derive(Debug)]
34pub struct ProtobufAccessBuilder {
35 wire_type: WireType,
36 message_descriptor: MessageDescriptor,
37
38 messages_as_jsonb: HashSet<String>,
41}
42
43impl AccessBuilder for ProtobufAccessBuilder {
44 async fn generate_accessor(
45 &mut self,
46 payload: Vec<u8>,
47 _: &crate::source::SourceMeta,
48 ) -> ConnectorResult<AccessImpl<'_>> {
49 let payload = match self.wire_type {
50 WireType::Confluent => resolve_pb_header(&payload)?,
51 WireType::None => &payload,
52 };
53
54 let message = DynamicMessage::decode(self.message_descriptor.clone(), payload)
55 .context("failed to parse message")?;
56
57 Ok(AccessImpl::Protobuf(ProtobufAccess::new(
58 message,
59 &self.messages_as_jsonb,
60 )))
61 }
62}
63
64impl ProtobufAccessBuilder {
65 pub fn new(config: ProtobufParserConfig) -> ConnectorResult<Self> {
66 let ProtobufParserConfig {
67 wire_type,
68 message_descriptor,
69 messages_as_jsonb,
70 } = config;
71
72 Ok(Self {
73 wire_type,
74 message_descriptor,
75 messages_as_jsonb,
76 })
77 }
78}
79
/// How an incoming payload is framed before the protobuf bytes start.
#[derive(Debug, Clone)]
enum WireType {
    /// The payload is the bare protobuf message; nothing is stripped.
    None,
    /// The payload carries a schema-registry header that must be removed first.
    Confluent,
}
87
88impl TryFrom<&SchemaLocation> for WireType {
89 type Error = InvalidOptionError;
90
91 fn try_from(value: &SchemaLocation) -> Result<Self, Self::Error> {
92 match value {
93 SchemaLocation::File { .. } => Ok(Self::None),
94 SchemaLocation::Confluent { .. } => Ok(Self::Confluent),
95 SchemaLocation::Glue { .. } => bail_invalid_option_error!(
96 "encode protobuf from aws glue schema registry not supported yet"
97 ),
98 }
99 }
100}
101
102#[derive(Debug, Clone)]
103pub struct ProtobufParserConfig {
104 wire_type: WireType,
105 pub(crate) message_descriptor: MessageDescriptor,
106 messages_as_jsonb: HashSet<String>,
107}
108
109impl ProtobufParserConfig {
110 pub async fn new(encoding_properties: EncodingProperties) -> ConnectorResult<Self> {
111 let protobuf_config = try_match_expand!(encoding_properties, EncodingProperties::Protobuf)?;
112 let message_name = &protobuf_config.message_name;
113
114 let wire_type = (&protobuf_config.schema_location).try_into()?;
115 if protobuf_config.key_message_name.is_some() {
116 bail!("protobuf key is not supported");
118 }
119 let pool = match protobuf_config.schema_location {
120 SchemaLocation::Confluent {
121 urls,
122 client_config,
123 name_strategy,
124 topic,
125 } => {
126 let url = handle_sr_list(urls.as_str())?;
127 let client = Client::new(url, &client_config)?;
128 let loader = SchemaLoader::Confluent(ConfluentSchemaLoader {
129 client,
130 name_strategy,
131 topic,
132 key_record_name: None,
133 val_record_name: Some(message_name.clone()),
134 });
135 let (_schema_id, root_file_descriptor) = loader
136 .load_val_schema::<FileDescriptor>()
137 .await
138 .context("load schema failed")?;
139 root_file_descriptor.parent_pool().clone()
140 }
141 SchemaLocation::File {
142 url,
143 aws_auth_props,
144 } => {
145 let url = handle_sr_list(url.as_str())?;
146 let url = url.first().unwrap();
147 let schema_bytes = bytes_from_url(url, aws_auth_props.as_ref()).await?;
148 DescriptorPool::decode(schema_bytes.as_slice())
149 .with_context(|| format!("cannot build descriptor pool from schema `{url}`"))?
150 }
151 SchemaLocation::Glue { .. } => bail_invalid_option_error!(
152 "encode protobuf from aws glue schema registry not supported yet"
153 ),
154 };
155
156 let message_descriptor = pool
157 .get_message_by_name(message_name)
158 .with_context(|| format!("cannot find message `{message_name}` in schema"))?;
159
160 Ok(Self {
161 message_descriptor,
162 wire_type,
163 messages_as_jsonb: protobuf_config.messages_as_jsonb,
164 })
165 }
166
167 pub fn map_to_columns(&self) -> ConnectorResult<Vec<Field>> {
169 pb_schema_to_fields(&self.message_descriptor, &self.messages_as_jsonb).map_err(|e| e.into())
170 }
171}
172
173fn decode_varint_zigzag(buffer: &[u8]) -> ConnectorResult<(i32, usize)> {
176 let mut value = 0u32;
178 let mut shift = 0;
179 let mut len = 0usize;
180
181 for &byte in buffer {
182 len += 1;
183 if len > 5 {
185 break;
186 }
187 let byte_ext = byte as u32;
189 value |= (byte_ext & 0x7F) << shift;
192 if byte_ext & 0x80 == 0 {
193 return Ok((((value >> 1) as i32) ^ -((value & 1) as i32), len));
194 }
195
196 shift += 7;
197 }
198
199 Err(WireFormatError::ParseMessageIndexes.into())
200}
201
202pub(crate) fn resolve_pb_header(payload: &[u8]) -> ConnectorResult<&[u8]> {
207 let (_, remained) = extract_schema_id(payload)?;
210 match remained.first() {
214 Some(0) => Ok(&remained[1..]),
215 Some(_) => {
216 let (index_len, mut offset) = decode_varint_zigzag(remained)?;
217 for _ in 0..index_len {
218 offset += decode_varint_zigzag(&remained[offset..])?.1;
219 }
220 Ok(&remained[offset..])
221 }
222 None => bail!("The proto payload is empty"),
223 }
224}
225
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `decode_varint_zigzag` against known encodings: small positive and
    /// negative values, two-byte values, both `i32` extremes, a fifth byte with excess
    /// high bits, and an over-long input that must be rejected.
    #[test]
    fn test_decode_varint_zigzag() {
        // (encoded bytes, expected value, expected consumed length)
        let cases: Vec<(Vec<u8>, i32, usize)> = vec![
            (vec![0x02], 1, 1),
            (vec![0x01], -1, 1),
            (vec![0x9E, 0x03], 207, 2),
            (vec![0xBF, 0x07], -480, 2),
            (vec![0xFE, 0xFF, 0xFF, 0xFF, 0x0F], i32::MAX, 5),
            (vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F], i32::MIN, 5),
            (vec![0xFF, 0xFF, 0xFF, 0xFF, 0x7F], i32::MIN, 5),
        ];

        for (bytes, expected_value, expected_len) in cases {
            let (value, len) = decode_varint_zigzag(&bytes).unwrap();
            assert_eq!(value, expected_value);
            assert_eq!(len, expected_len);
        }

        // Six continuation bytes exceed the five-byte limit.
        let too_long = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let result = decode_varint_zigzag(&too_long);
        assert!(result.is_err());
    }
}