// risingwave_connector/parser/protobuf/parser.rs
use anyhow::Context;
use prost_reflect::{DescriptorPool, DynamicMessage, FileDescriptor, MessageDescriptor};
use risingwave_common::{bail, try_match_expand};
pub use risingwave_connector_codec::decoder::protobuf::parser::*;
use risingwave_connector_codec::decoder::protobuf::ProtobufAccess;
use risingwave_pb::plan_common::ColumnDesc;
use crate::error::ConnectorResult;
use crate::parser::unified::AccessImpl;
use crate::parser::util::bytes_from_url;
use crate::parser::{AccessBuilder, EncodingProperties};
use crate::schema::schema_registry::{extract_schema_id, handle_sr_list, Client, WireFormatError};
use crate::schema::SchemaLoader;
/// Produces protobuf accessors by decoding each raw payload as a single, fixed message type.
#[derive(Debug)]
pub struct ProtobufAccessBuilder {
    /// When set, every payload starts with a Confluent Schema Registry wire-format header
    /// that is stripped (see `resolve_pb_header`) before decoding.
    confluent_wire_type: bool,
    /// Descriptor of the message type that every payload is decoded as.
    message_descriptor: MessageDescriptor,
}
impl AccessBuilder for ProtobufAccessBuilder {
    /// Decodes `payload` as the configured protobuf message and wraps it in an accessor.
    ///
    /// When the builder is configured for the Confluent wire format, the schema-registry
    /// header is removed with `resolve_pb_header` before decoding.
    // The body never awaits; the `async` signature is presumably dictated by `AccessBuilder`.
    #[allow(clippy::unused_async)]
    async fn generate_accessor(&mut self, payload: Vec<u8>) -> ConnectorResult<AccessImpl<'_>> {
        let message_bytes: &[u8] = match self.confluent_wire_type {
            true => resolve_pb_header(&payload)?,
            false => &payload,
        };
        let descriptor = self.message_descriptor.clone();
        let decoded =
            DynamicMessage::decode(descriptor, message_bytes).context("failed to parse message")?;
        let accessor = ProtobufAccess::new(decoded);
        Ok(AccessImpl::Protobuf(accessor))
    }
}
impl ProtobufAccessBuilder {
pub fn new(config: ProtobufParserConfig) -> ConnectorResult<Self> {
let ProtobufParserConfig {
confluent_wire_type,
message_descriptor,
} = config;
Ok(Self {
confluent_wire_type,
message_descriptor,
})
}
}
/// Protobuf parsing configuration resolved from protobuf [`EncodingProperties`].
#[derive(Debug, Clone)]
pub struct ProtobufParserConfig {
    /// True when the schema was loaded from a schema registry; payloads are then expected
    /// to carry the Confluent wire-format header.
    confluent_wire_type: bool,
    /// Descriptor of the protobuf message that payloads are decoded as.
    pub(crate) message_descriptor: MessageDescriptor,
}
impl ProtobufParserConfig {
    /// Resolves a protobuf parser configuration from `encoding_properties`.
    ///
    /// If `use_schema_registry` is set, the value schema for `message_name` is loaded through
    /// the schema registry, and payloads are expected to carry the Confluent wire-format
    /// header. Otherwise the first location in `row_schema_location` is fetched and decoded
    /// with `DescriptorPool::decode` (an encoded `FileDescriptorSet`).
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `encoding_properties` is not protobuf.
    /// - The schema location list cannot be parsed or is empty.
    /// - A key message name is configured (unsupported).
    /// - The schema cannot be loaded or decoded.
    /// - `message_name` is not present in the resulting descriptor pool.
    pub async fn new(encoding_properties: EncodingProperties) -> ConnectorResult<Self> {
        let protobuf_config = try_match_expand!(encoding_properties, EncodingProperties::Protobuf)?;
        let location = &protobuf_config.row_schema_location;
        let message_name = &protobuf_config.message_name;
        let url = handle_sr_list(location.as_str())?;

        if protobuf_config.key_message_name.is_some() {
            bail!("protobuf key is not supported");
        }

        let pool = if protobuf_config.use_schema_registry {
            let client = Client::new(url, &protobuf_config.client_config)?;
            let loader = SchemaLoader {
                client,
                name_strategy: protobuf_config.name_strategy,
                topic: protobuf_config.topic,
                key_record_name: None,
                val_record_name: Some(message_name.clone()),
            };
            let (_schema_id, root_file_descriptor) = loader
                .load_val_schema::<FileDescriptor>()
                .await
                .context("load schema failed")?;
            root_file_descriptor.parent_pool().clone()
        } else {
            // Only the first location is consulted without a schema registry. The list comes
            // from user configuration, so report an empty list as an error instead of panicking.
            let url = url
                .first()
                .with_context(|| format!("no schema location found in `{location}`"))?;
            let schema_bytes = bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?;
            DescriptorPool::decode(schema_bytes.as_slice())
                .with_context(|| format!("cannot build descriptor pool from schema `{location}`"))?
        };

        let message_descriptor = pool.get_message_by_name(message_name).with_context(|| {
            format!("cannot find message `{message_name}` in schema `{location}`")
        })?;

        Ok(Self {
            message_descriptor,
            // Registry-managed schemas imply registry-framed (Confluent wire format) payloads.
            confluent_wire_type: protobuf_config.use_schema_registry,
        })
    }

    /// Converts the fields of the configured message descriptor into column descriptors.
    pub fn map_to_columns(&self) -> ConnectorResult<Vec<ColumnDesc>> {
        pb_schema_to_column_descs(&self.message_descriptor).map_err(Into::into)
    }
}
/// Decodes one zigzag-encoded 32-bit varint from the start of `buffer`.
///
/// At most five bytes are examined. Bits of the fifth byte that do not fit into 32 bits are
/// discarded. Bytes after the terminating byte (high bit clear) are ignored.
///
/// Returns the decoded value together with the number of bytes it occupied.
///
/// # Errors
///
/// Returns `WireFormatError::ParseMessageIndexes` if no terminating byte appears within the
/// first five bytes, including when `buffer` is empty or ends mid-varint.
fn decode_varint_zigzag(buffer: &[u8]) -> ConnectorResult<(i32, usize)> {
    let mut raw: u32 = 0;
    for (idx, &byte) in buffer.iter().take(5).enumerate() {
        raw |= u32::from(byte & 0x7F) << (7 * idx);
        let is_last = byte & 0x80 == 0;
        if is_last {
            // Zigzag: even raw values map to non-negative, odd raw values to negative.
            let decoded = ((raw >> 1) as i32) ^ -((raw & 1) as i32);
            return Ok((decoded, idx + 1));
        }
    }
    Err(WireFormatError::ParseMessageIndexes.into())
}
pub(crate) fn resolve_pb_header(payload: &[u8]) -> ConnectorResult<&[u8]> {
let (_, remained) = extract_schema_id(payload)?;
match remained.first() {
Some(0) => Ok(&remained[1..]),
Some(_) => {
let (index_len, mut offset) = decode_varint_zigzag(remained)?;
for _ in 0..index_len {
offset += decode_varint_zigzag(&remained[offset..])?.1;
}
Ok(&remained[offset..])
}
None => bail!("The proto payload is empty"),
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Well-formed inputs: (encoded bytes, expected zigzag-decoded value, expected length).
    #[test]
    fn test_decode_varint_zigzag() {
        let cases: &[(&[u8], i32, usize)] = &[
            (&[0x00], 0, 1),
            (&[0x02], 1, 1),
            (&[0x01], -1, 1),
            (&[0x9E, 0x03], 207, 2),
            (&[0xBF, 0x07], -480, 2),
            (&[0xFE, 0xFF, 0xFF, 0xFF, 0x0F], i32::MAX, 5),
            (&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F], i32::MIN, 5),
            // Bits beyond 32 in the fifth byte are discarded.
            (&[0xFF, 0xFF, 0xFF, 0xFF, 0x7F], i32::MIN, 5),
            // Only the first varint is consumed; trailing bytes are ignored.
            (&[0x02, 0xFF], 1, 1),
        ];
        for &(buffer, expected_value, expected_len) in cases {
            let (value, len) = decode_varint_zigzag(buffer).unwrap();
            assert_eq!(value, expected_value, "value for {buffer:02X?}");
            assert_eq!(len, expected_len, "length for {buffer:02X?}");
        }
    }

    /// Inputs that never terminate within the five-byte limit must be rejected.
    #[test]
    fn test_decode_varint_zigzag_malformed() {
        let malformed: &[&[u8]] = &[
            // Empty input.
            &[],
            // Continuation bit set on the last available byte.
            &[0x80],
            &[0xFF, 0xFF],
            // Five continuation bytes: no terminating byte within the limit.
            &[0x80, 0x80, 0x80, 0x80, 0x80],
            // Longer than the five-byte maximum for a 32-bit varint.
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF],
        ];
        for &buffer in malformed {
            assert!(
                decode_varint_zigzag(buffer).is_err(),
                "expected error for {buffer:02X?}"
            );
        }
    }
}