risingwave_connector_codec/decoder/protobuf/
parser.rsuse anyhow::Context;
use itertools::Itertools;
use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value};
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{
DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, ToOwnedDatum, F32, F64,
};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion};
use thiserror::Error;
use thiserror_ext::Macro;
use crate::decoder::{uncategorized, AccessError, AccessResult};
pub fn pb_schema_to_column_descs(
message_descriptor: &MessageDescriptor,
) -> anyhow::Result<Vec<ColumnDesc>> {
let mut columns = Vec::with_capacity(message_descriptor.fields().len());
let mut index = 0;
let mut parse_trace: Vec<String> = vec![];
for field in message_descriptor.fields() {
columns.push(pb_field_to_col_desc(&field, &mut index, &mut parse_trace)?);
}
Ok(columns)
}
fn pb_field_to_col_desc(
field_descriptor: &FieldDescriptor,
index: &mut i32,
parse_trace: &mut Vec<String>,
) -> anyhow::Result<ColumnDesc> {
let field_type = protobuf_type_mapping(field_descriptor, parse_trace)
.context("failed to map protobuf type")?;
if let Kind::Message(m) = field_descriptor.kind() {
let field_descs = if let DataType::List { .. } = field_type {
vec![]
} else {
m.fields()
.map(|f| pb_field_to_col_desc(&f, index, parse_trace))
.try_collect()?
};
*index += 1;
Ok(ColumnDesc {
column_id: *index,
name: field_descriptor.name().to_string(),
column_type: Some(field_type.to_protobuf()),
field_descs,
type_name: m.full_name().to_string(),
generated_or_default_column: None,
description: None,
additional_column_type: 0, additional_column: Some(AdditionalColumn { column_type: None }),
version: ColumnDescVersion::Pr13707 as i32,
})
} else {
*index += 1;
Ok(ColumnDesc {
column_id: *index,
name: field_descriptor.name().to_string(),
column_type: Some(field_type.to_protobuf()),
additional_column: Some(AdditionalColumn { column_type: None }),
version: ColumnDescVersion::Pr13707 as i32,
..Default::default()
})
}
}
#[derive(Error, Debug, Macro)]
#[error("{0}")]
struct ProtobufTypeError(#[message] String);
fn detect_loop_and_push(
trace: &mut Vec<String>,
fd: &FieldDescriptor,
) -> std::result::Result<(), ProtobufTypeError> {
let identifier = format!("{}({})", fd.name(), fd.full_name());
if trace.iter().any(|s| s == identifier.as_str()) {
bail_protobuf_type_error!(
"circular reference detected: {}, conflict with {}, kind {:?}",
trace.iter().format("->"),
identifier,
fd.kind(),
);
}
trace.push(identifier);
Ok(())
}
pub fn from_protobuf_value<'a>(
field_desc: &FieldDescriptor,
value: &'a Value,
type_expected: &DataType,
) -> AccessResult<DatumCow<'a>> {
let kind = field_desc.kind();
macro_rules! borrowed {
($v:expr) => {
return Ok(DatumCow::Borrowed(Some($v.into())))
};
}
let v: ScalarImpl = match value {
Value::Bool(v) => ScalarImpl::Bool(*v),
Value::I32(i) => ScalarImpl::Int32(*i),
Value::U32(i) => ScalarImpl::Int64(*i as i64),
Value::I64(i) => ScalarImpl::Int64(*i),
Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)),
Value::F32(f) => ScalarImpl::Float32(F32::from(*f)),
Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
Value::String(s) => borrowed!(s.as_str()),
Value::EnumNumber(idx) => {
let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
expected: "enum".to_owned(),
got: format!("{kind:?}"),
value: value.to_string(),
})?;
let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| {
uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc)
})?;
ScalarImpl::Utf8(enum_symbol.name().into())
}
Value::Message(dyn_msg) => {
if dyn_msg.descriptor().full_name() == "google.protobuf.Any" {
ScalarImpl::Jsonb(JsonbVal::from(
serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
))
} else {
let desc = dyn_msg.descriptor();
let DataType::Struct(st) = type_expected else {
return Err(AccessError::TypeError {
expected: type_expected.to_string(),
got: desc.full_name().to_string(),
value: value.to_string(), });
};
let mut rw_values = Vec::with_capacity(st.len());
for (name, expected_field_type) in st.iter() {
let Some(field_desc) = desc.get_field_by_name(name) else {
rw_values.push(None);
continue;
};
let value = dyn_msg.get_field(&field_desc);
rw_values.push(
from_protobuf_value(&field_desc, &value, expected_field_type)?
.to_owned_datum(),
);
}
ScalarImpl::Struct(StructValue::new(rw_values))
}
}
Value::List(values) => {
let DataType::List(element_type) = type_expected else {
return Err(AccessError::TypeError {
expected: type_expected.to_string(),
got: format!("repeated {:?}", kind),
value: value.to_string(), });
};
let mut builder = element_type.create_array_builder(values.len());
for value in values {
builder.append(from_protobuf_value(field_desc, value, element_type)?);
}
ScalarImpl::List(ListValue::new(builder.finish()))
}
Value::Bytes(value) => borrowed!(&**value),
Value::Map(map) => {
let err = || {
AccessError::TypeError {
expected: type_expected.to_string(),
got: format!("{:?}", kind),
value: value.to_string(), }
};
let DataType::Map(map_type) = type_expected else {
return Err(err());
};
if !field_desc.is_map() {
return Err(err());
}
let map_desc = kind.as_message().ok_or_else(err)?;
let mut key_builder = map_type.key().create_array_builder(map.len());
let mut value_builder = map_type.value().create_array_builder(map.len());
for (key, value) in map.iter().sorted_by_key(|(k, _v)| *k) {
key_builder.append(from_protobuf_value(
&map_desc.map_entry_key_field(),
&key.clone().into(),
map_type.key(),
)?);
value_builder.append(from_protobuf_value(
&map_desc.map_entry_value_field(),
value,
map_type.value(),
)?);
}
let keys = key_builder.finish();
let values = value_builder.finish();
ScalarImpl::Map(
MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values))
.map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?,
)
}
};
Ok(Some(v).into())
}
fn protobuf_type_mapping(
field_descriptor: &FieldDescriptor,
parse_trace: &mut Vec<String>,
) -> std::result::Result<DataType, ProtobufTypeError> {
detect_loop_and_push(parse_trace, field_descriptor)?;
let mut t = match field_descriptor.kind() {
Kind::Bool => DataType::Boolean,
Kind::Double => DataType::Float64,
Kind::Float => DataType::Float32,
Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
DataType::Int64
}
Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
Kind::String => DataType::Varchar,
Kind::Message(m) => {
if m.full_name() == "google.protobuf.Any" {
DataType::Jsonb
} else if m.is_map_entry() {
debug_assert!(field_descriptor.is_map());
let key = protobuf_type_mapping(&m.map_entry_key_field(), parse_trace)?;
let value = protobuf_type_mapping(&m.map_entry_value_field(), parse_trace)?;
_ = parse_trace.pop();
return Ok(DataType::Map(MapType::from_kv(key, value)));
} else {
let fields = m
.fields()
.map(|f| protobuf_type_mapping(&f, parse_trace))
.try_collect()?;
let field_names = m.fields().map(|f| f.name().to_string()).collect_vec();
DataType::new_struct(fields, field_names)
}
}
Kind::Enum(_) => DataType::Varchar,
Kind::Bytes => DataType::Bytea,
};
if field_descriptor.cardinality() == Cardinality::Repeated {
debug_assert!(!field_descriptor.is_map());
t = DataType::List(Box::new(t))
}
_ = parse_trace.pop();
Ok(t)
}