1use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18 DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::array::VECTOR_ITEM_TYPE;
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
24use risingwave_common::util::iter_util::ZipEqDebug;
25
26use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
27
/// File-local result alias: the field-level helpers below fail with a
/// [`FieldEncodeError`], which the public entry points convert into the sink error type.
type Result<T> = std::result::Result<T, FieldEncodeError>;
29
/// Row encoder producing protobuf [`DynamicMessage`]s for a given [`MessageDescriptor`].
///
/// Columns are matched to protobuf fields by name.
pub struct ProtoEncoder {
    /// Schema of the rows handed to the encoder.
    schema: Schema,
    /// Optional subset of column indices; when `None`, every column of `schema` is
    /// validated at construction time.
    col_indices: Option<Vec<usize>>,
    /// Descriptor of the target protobuf message.
    descriptor: MessageDescriptor,
    /// Header variant copied into every encoded output.
    header: ProtoHeader,
}
36
/// Framing emitted in front of the serialized protobuf message.
#[derive(Debug, Clone, Copy)]
pub enum ProtoHeader {
    /// No framing: the output is only the encoded message.
    None,
    /// Framing that carries a schema id. As serialized in this file: one zero byte, the
    /// id written with `BufMut::put_i32`, then the message-index path, then the message.
    /// NOTE(review): presumably the Confluent Schema Registry wire format — confirm.
    ConfluentSchemaRegistry(i32),
}
46
47impl ProtoEncoder {
48 pub fn new(
49 schema: Schema,
50 col_indices: Option<Vec<usize>>,
51 descriptor: MessageDescriptor,
52 header: ProtoHeader,
53 ) -> SinkResult<Self> {
54 match &col_indices {
55 Some(col_indices) => validate_fields(
56 col_indices.iter().map(|idx| {
57 let f = &schema[*idx];
58 (f.name.as_str(), &f.data_type)
59 }),
60 &descriptor,
61 )?,
62 None => validate_fields(
63 schema
64 .fields
65 .iter()
66 .map(|f| (f.name.as_str(), &f.data_type)),
67 &descriptor,
68 )?,
69 };
70
71 Ok(Self {
72 schema,
73 col_indices,
74 descriptor,
75 header,
76 })
77 }
78}
79
/// A row encoded as a protobuf message, paired with the header to emit when it is
/// serialized to bytes.
pub struct ProtoEncoded {
    /// The encoded message.
    pub message: DynamicMessage,
    /// Header variant written before the message bytes on serialization.
    header: ProtoHeader,
}
84
85impl RowEncoder for ProtoEncoder {
86 type Output = ProtoEncoded;
87
88 fn schema(&self) -> &Schema {
89 &self.schema
90 }
91
92 fn col_indices(&self) -> Option<&[usize]> {
93 self.col_indices.as_deref()
94 }
95
96 fn encode_cols(
97 &self,
98 row: impl Row,
99 col_indices: impl Iterator<Item = usize>,
100 ) -> SinkResult<Self::Output> {
101 encode_fields(
102 col_indices.map(|idx| {
103 let f = &self.schema[idx];
104 ((f.name.as_str(), &f.data_type), row.datum_at(idx))
105 }),
106 &self.descriptor,
107 )
108 .map_err(Into::into)
109 .map(|m| ProtoEncoded {
110 message: m,
111 header: self.header,
112 })
113 }
114}
115
116impl SerTo<Vec<u8>> for ProtoEncoded {
117 fn ser_to(self) -> SinkResult<Vec<u8>> {
118 let mut buf = Vec::new();
119 match self.header {
120 ProtoHeader::None => { }
121 ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
122 buf.reserve(1 + 4);
123 buf.put_u8(0);
124 buf.put_i32(schema_id);
125 MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
126 }
127 }
128 self.message.encode(&mut buf).unwrap();
129 Ok(buf)
130 }
131}
132
133struct MessageIndexes(Vec<i32>);
134
135impl MessageIndexes {
136 fn from(desc: MessageDescriptor) -> Self {
137 const TAG_FILE_MESSAGE: i32 = 4;
142 const TAG_MESSAGE_NESTED: i32 = 3;
144
145 let mut indexes = vec![];
146 let mut path = desc.path().array_chunks();
147 let &[tag, idx] = path.next().unwrap();
148 assert_eq!(tag, TAG_FILE_MESSAGE);
149 indexes.push(idx);
150 for &[tag, idx] in path {
151 assert_eq!(tag, TAG_MESSAGE_NESTED);
152 indexes.push(idx);
153 }
154 Self(indexes)
155 }
156
157 fn zig_i32(value: i32, buf: &mut impl BufMut) {
158 let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
159 prost::encoding::encode_varint(unsigned, buf);
160 }
161
162 fn encode(&self, buf: &mut impl BufMut) {
163 if self.0 == [0] {
164 buf.put_u8(0);
165 return;
166 }
167 Self::zig_i32(self.0.len().try_into().unwrap(), buf);
168 for &idx in &self.0 {
169 Self::zig_i32(idx, buf);
170 }
171 }
172}
173
/// A datum that may or may not be present, so that `on_field` can share one
/// type-matching routine between schema validation and value encoding.
///
/// In this file it is implemented for `()` (no data: only check that the types are
/// compatible) and for `ScalarRefImpl` (a value: produce the protobuf `Value`).
trait MaybeData: std::fmt::Debug {
    /// What a successful match yields.
    type Out;

    /// Handles a scalar column; `f` converts the scalar into its protobuf `Value`.
    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;

    /// Handles a struct column of type `st` mapped to the message described by `pb`.
    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;

    /// Handles a list column with element type `elem` mapped to the repeated field `pb`.
    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;

    /// Handles a map column of type `m` mapped to the map-entry message `pb`.
    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
}
190
191impl MaybeData for () {
192 type Out = ();
193
194 fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
195 Ok(self)
196 }
197
198 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
199 validate_fields(st.iter(), pb)
200 }
201
202 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
203 on_field(elem, (), pb, true)
204 }
205
206 fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
207 debug_assert!(pb.is_map_entry());
208 on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
209 on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
210 Ok(())
211 }
212}
213
214impl MaybeData for ScalarRefImpl<'_> {
221 type Out = Value;
222
223 fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
224 f(self)
225 }
226
227 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
228 let d = self.into_struct();
229 let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
230 Ok(Value::Message(message))
231 }
232
233 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
234 let d = self.into_list();
235 let vs = d
236 .iter()
237 .map(|d| {
238 on_field(
239 elem,
240 d.ok_or_else(|| {
241 FieldEncodeError::new("array containing null not allowed as repeated field")
242 })?,
243 pb,
244 true,
245 )
246 })
247 .try_collect()?;
248 Ok(Value::List(vs))
249 }
250
251 fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
252 debug_assert!(pb.is_map_entry());
253 let vs = self
254 .into_map()
255 .iter()
256 .map(|(k, v)| {
257 let v =
258 v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
259 let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
260 let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
261 Ok((
262 k.into_map_key().ok_or_else(|| {
263 FieldEncodeError::new("failed to convert map key to proto")
264 })?,
265 v,
266 ))
267 })
268 .try_collect()?;
269 Ok(Value::Map(vs))
270 }
271}
272
273fn validate_fields<'a>(
274 fields: impl Iterator<Item = (&'a str, &'a DataType)>,
275 descriptor: &MessageDescriptor,
276) -> Result<()> {
277 for (name, t) in fields {
278 let Some(proto_field) = descriptor.get_field_by_name(name) else {
279 return Err(FieldEncodeError::new("field not in proto").with_name(name));
280 };
281 if proto_field.cardinality() == prost_reflect::Cardinality::Required {
282 return Err(FieldEncodeError::new("`required` not supported").with_name(name));
283 }
284 on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
285 }
286 Ok(())
287}
288
289fn encode_fields<'a>(
290 fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
291 descriptor: &MessageDescriptor,
292) -> Result<DynamicMessage> {
293 let mut message = DynamicMessage::new(descriptor.clone());
294 for ((name, t), d) in fields_with_datums {
295 let proto_field = descriptor.get_field_by_name(name).unwrap();
296 if let Some(scalar) = d {
298 let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
299 message
300 .try_set_field(&proto_field, value)
301 .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
302 }
303 }
304 Ok(message)
305}
306
/// Fully-qualified name of the protobuf `Timestamp` message; `on_field` compares a
/// message field's full name against it when encoding `timestamptz` columns.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
/// Fully-qualified name of the protobuf `BoolValue` message. Nothing in the visible
/// code references it, hence the `dead_code` expectation.
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
311
312fn on_field<D: MaybeData>(
315 data_type: &DataType,
316 maybe: D,
317 proto_field: &FieldDescriptor,
318 in_repeated: bool,
319) -> Result<D::Out> {
320 assert!(proto_field.is_list() || !in_repeated);
327 let expect_list = proto_field.is_list() && !in_repeated;
328 if proto_field.is_group() {
329 return Err(FieldEncodeError::new("proto group not supported yet"));
330 }
331
332 let no_match_err = || {
333 Err(FieldEncodeError::new(format!(
334 "cannot encode {} column as {}{:?} field",
335 data_type,
336 if expect_list { "repeated " } else { "" },
337 proto_field.kind()
338 )))
339 };
340
341 if expect_list && !matches!(data_type, DataType::List(_)) {
342 return no_match_err();
343 }
344
345 let value = match &data_type {
346 DataType::Boolean => match proto_field.kind() {
348 Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
349 _ => return no_match_err(),
350 },
351 DataType::Varchar => match proto_field.kind() {
352 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
353 Kind::Enum(enum_desc) => maybe.on_base(|s| {
354 let name = s.into_utf8();
355 let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
356 FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
357 })?;
358 Ok(Value::EnumNumber(enum_value_desc.number()))
359 })?,
360 _ => return no_match_err(),
361 },
362 DataType::Bytea => match proto_field.kind() {
363 Kind::Bytes => {
364 maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
365 }
366 _ => return no_match_err(),
367 },
368 DataType::Float32 => match proto_field.kind() {
369 Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
370 _ => return no_match_err(),
371 },
372 DataType::Float64 => match proto_field.kind() {
373 Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
374 _ => return no_match_err(),
375 },
376 DataType::Int32 => match proto_field.kind() {
377 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
378 maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
379 }
380 _ => return no_match_err(),
381 },
382 DataType::Int64 => match proto_field.kind() {
383 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
384 maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
385 }
386 _ => return no_match_err(),
387 },
388 DataType::Struct(st) => match proto_field.kind() {
389 Kind::Message(pb) => maybe.on_struct(st, &pb)?,
390 _ => return no_match_err(),
391 },
392 DataType::List(elem) => match expect_list {
393 true => maybe.on_list(elem, proto_field)?,
394 false => return no_match_err(),
395 },
396 DataType::Timestamptz => match proto_field.kind() {
398 Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
399 let d = s.into_timestamptz();
400 let message = prost_types::Timestamp {
401 seconds: d.timestamp(),
402 nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
403 };
404 Ok(Value::Message(message.transcode_to_dynamic()))
405 })?,
406 Kind::String => {
407 maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
408 }
409 _ => return no_match_err(),
410 },
411 DataType::Jsonb => match proto_field.kind() {
412 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
413 _ => return no_match_err(), },
416 DataType::Int16 => match proto_field.kind() {
417 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
418 _ => return no_match_err(),
419 },
420 DataType::Date => match proto_field.kind() {
421 Kind::Int32 => {
422 maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
423 }
424 _ => return no_match_err(), },
426 DataType::Time => match proto_field.kind() {
427 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
428 _ => return no_match_err(), },
430 DataType::Timestamp => match proto_field.kind() {
431 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
432 _ => return no_match_err(), },
434 DataType::Decimal => match proto_field.kind() {
435 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
436 _ => return no_match_err(), },
438 DataType::Interval => match proto_field.kind() {
439 Kind::String => {
440 maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
441 }
442 _ => return no_match_err(), },
444 DataType::Serial => match proto_field.kind() {
445 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
446 _ => return no_match_err(), },
448 DataType::Int256 => {
449 return no_match_err();
450 }
451 DataType::Map(map_type) => {
452 if proto_field.is_map() {
453 let msg = match proto_field.kind() {
454 Kind::Message(m) => m,
455 _ => return no_match_err(), };
457 return maybe.on_map(map_type, &msg);
458 } else {
459 return no_match_err();
460 }
461 }
462 DataType::Vector(_) => match expect_list {
463 true => maybe.on_list(&VECTOR_ITEM_TYPE, proto_field)?,
464 false => return no_match_err(),
465 },
466 };
467
468 Ok(value)
469}
470
471#[cfg(test)]
472mod tests {
473 use itertools::Itertools;
474 use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
475 use risingwave_common::catalog::Field;
476 use risingwave_common::row::OwnedRow;
477 use risingwave_common::types::{
478 ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
479 };
480
481 use super::*;
482
    // Encodes one row covering the supported column/field pairings with no header and
    // compares the Debug rendering of every set field against an inline snapshot.
    #[test]
    fn test_encode_proto_ok() {
        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("codec/tests/test_data/all-types.pb");
        let pool_bytes = std::fs::read(pool_path).unwrap();
        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
        let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
        // Column names must equal the protobuf field names they are encoded into.
        let schema = Schema::new(vec![
            Field::with_name(DataType::Boolean, "bool_field"),
            Field::with_name(DataType::Varchar, "string_field"),
            Field::with_name(DataType::Bytea, "bytes_field"),
            Field::with_name(DataType::Float32, "float_field"),
            Field::with_name(DataType::Float64, "double_field"),
            Field::with_name(DataType::Int32, "int32_field"),
            Field::with_name(DataType::Int64, "int64_field"),
            Field::with_name(DataType::Int32, "sint32_field"),
            Field::with_name(DataType::Int64, "sint64_field"),
            Field::with_name(DataType::Int32, "sfixed32_field"),
            Field::with_name(DataType::Int64, "sfixed64_field"),
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("id", DataType::Int32),
                    ("name", DataType::Varchar),
                ])),
                "nested_message_field",
            ),
            Field::with_name(DataType::List(DataType::Int32.into()), "repeated_int_field"),
            Field::with_name(DataType::Timestamptz, "timestamp_field"),
            Field::with_name(
                DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
                "map_field",
            ),
            Field::with_name(
                DataType::Map(MapType::from_kv(
                    DataType::Varchar,
                    DataType::Struct(StructType::new(vec![
                        ("id", DataType::Int32),
                        ("name", DataType::Varchar),
                    ])),
                )),
                "map_struct_field",
            ),
        ]);
        let row = OwnedRow::new(vec![
            Some(ScalarImpl::Bool(true)),
            Some(ScalarImpl::Utf8("RisingWave".into())),
            Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
            Some(ScalarImpl::Float32(3.5f32.into())),
            Some(ScalarImpl::Float64(4.25f64.into())),
            Some(ScalarImpl::Int32(22)),
            Some(ScalarImpl::Int64(23)),
            Some(ScalarImpl::Int32(24)),
            // `sint64_field` is NULL: the field stays unset and is absent from the snapshot.
            None,
            Some(ScalarImpl::Int32(26)),
            Some(ScalarImpl::Int64(27)),
            Some(ScalarImpl::Struct(StructValue::new(vec![
                Some(ScalarImpl::Int32(1)),
                Some(ScalarImpl::Utf8("".into())),
            ]))),
            Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
            // 3 microseconds: the snapshot shows only `nanos` = 3000 being set.
            Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
            Some(ScalarImpl::Map(
                MapValue::try_from_kv(
                    ListValue::from_iter(["a", "b"]),
                    ListValue::from_iter([1, 2]),
                )
                .unwrap(),
            )),
            {
                // Map whose values are structs: the value list is built through a struct array builder.
                let mut struct_array_builder = StructArrayBuilder::with_type(
                    2,
                    DataType::Struct(StructType::new(vec![
                        ("id", DataType::Int32),
                        ("name", DataType::Varchar),
                    ])),
                );
                struct_array_builder.append(Some(
                    StructValue::new(vec![
                        Some(ScalarImpl::Int32(1)),
                        Some(ScalarImpl::Utf8("x".into())),
                    ])
                    .as_scalar_ref(),
                ));
                struct_array_builder.append(Some(
                    StructValue::new(vec![
                        Some(ScalarImpl::Int32(2)),
                        Some(ScalarImpl::Utf8("y".into())),
                    ])
                    .as_scalar_ref(),
                ));
                Some(ScalarImpl::Map(
                    MapValue::try_from_kv(
                        ListValue::from_iter(["a", "b"]),
                        ListValue::new(struct_array_builder.finish().into()),
                    )
                    .unwrap(),
                ))
            },
        ]);

        let encoder =
            ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap();
        let m = encoder.encode(row).unwrap();
        // Snapshot of every set field; map values are rendered by `print_proto`, which
        // sorts map entries by key.
        expect_test::expect![[r#"
            field: FieldDescriptor {
                name: "double_field",
                full_name: "all_types.AllTypes.double_field",
                json_name: "doubleField",
                number: 1,
                kind: double,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: F64(4.25)

            ==============================
            field: FieldDescriptor {
                name: "float_field",
                full_name: "all_types.AllTypes.float_field",
                json_name: "floatField",
                number: 2,
                kind: float,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: F32(3.5)

            ==============================
            field: FieldDescriptor {
                name: "int32_field",
                full_name: "all_types.AllTypes.int32_field",
                json_name: "int32Field",
                number: 3,
                kind: int32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(22)

            ==============================
            field: FieldDescriptor {
                name: "int64_field",
                full_name: "all_types.AllTypes.int64_field",
                json_name: "int64Field",
                number: 4,
                kind: int64,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I64(23)

            ==============================
            field: FieldDescriptor {
                name: "sint32_field",
                full_name: "all_types.AllTypes.sint32_field",
                json_name: "sint32Field",
                number: 7,
                kind: sint32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(24)

            ==============================
            field: FieldDescriptor {
                name: "sfixed32_field",
                full_name: "all_types.AllTypes.sfixed32_field",
                json_name: "sfixed32Field",
                number: 11,
                kind: sfixed32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(26)

            ==============================
            field: FieldDescriptor {
                name: "sfixed64_field",
                full_name: "all_types.AllTypes.sfixed64_field",
                json_name: "sfixed64Field",
                number: 12,
                kind: sfixed64,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I64(27)

            ==============================
            field: FieldDescriptor {
                name: "bool_field",
                full_name: "all_types.AllTypes.bool_field",
                json_name: "boolField",
                number: 13,
                kind: bool,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: Bool(true)

            ==============================
            field: FieldDescriptor {
                name: "string_field",
                full_name: "all_types.AllTypes.string_field",
                json_name: "stringField",
                number: 14,
                kind: string,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: String("RisingWave")

            ==============================
            field: FieldDescriptor {
                name: "bytes_field",
                full_name: "all_types.AllTypes.bytes_field",
                json_name: "bytesField",
                number: 15,
                kind: bytes,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: Bytes(b"\xbe\xef")

            ==============================
            field: FieldDescriptor {
                name: "nested_message_field",
                full_name: "all_types.AllTypes.nested_message_field",
                json_name: "nestedMessageField",
                number: 17,
                kind: all_types.AllTypes.NestedMessage,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: true,
            }

            value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })

            ==============================
            field: FieldDescriptor {
                name: "repeated_int_field",
                full_name: "all_types.AllTypes.repeated_int_field",
                json_name: "repeatedIntField",
                number: 18,
                kind: int32,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: true,
                is_map: false,
                is_packed: true,
                supports_presence: false,
            }

            value: List([I32(4), I32(0), I32(4)])

            ==============================
            field: FieldDescriptor {
                name: "map_field",
                full_name: "all_types.AllTypes.map_field",
                json_name: "mapField",
                number: 22,
                kind: all_types.AllTypes.MapFieldEntry,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: true,
                is_packed: false,
                supports_presence: false,
            }

            value: Map({
             String("a"): I32(1),
             String("b"): I32(2),
            })

            ==============================
            field: FieldDescriptor {
                name: "timestamp_field",
                full_name: "all_types.AllTypes.timestamp_field",
                json_name: "timestampField",
                number: 23,
                kind: google.protobuf.Timestamp,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: true,
            }

            value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })

            ==============================
            field: FieldDescriptor {
                name: "map_struct_field",
                full_name: "all_types.AllTypes.map_struct_field",
                json_name: "mapStructField",
                number: 29,
                kind: all_types.AllTypes.MapStructFieldEntry,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: true,
                is_packed: false,
                supports_presence: false,
            }

            value: Map({
             String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
             String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
            })"#]].assert_eq(&format!("{}",
            m.message.fields().format_with("\n\n==============================\n", |(field,value),f| {
            f(&format!("field: {:#?}\n\nvalue: {}", field, print_proto(value)))
        })));
    }
880
881 fn print_proto(value: &Value) -> String {
882 match value {
883 Value::Map(m) => {
884 let mut res = String::new();
885 res.push_str("Map({\n");
886 for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
887 res.push_str(&format!(
888 " {}: {},\n",
889 print_proto(&k.clone().into()),
890 print_proto(v)
891 ));
892 }
893 res.push_str("})");
894 res
895 }
896 _ => format!("{:?}", value),
897 }
898 }
899
900 #[test]
901 fn test_encode_proto_repeated() {
902 let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
903 .join("codec/tests/test_data/all-types.pb");
904 let pool_bytes = fs_err::read(pool_path).unwrap();
905 let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
906 let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
907
908 let schema = Schema::new(vec![Field::with_name(
909 DataType::List(DataType::List(DataType::Int32.into()).into()),
910 "repeated_int_field",
911 )]);
912
913 let err = validate_fields(
914 schema
915 .fields
916 .iter()
917 .map(|f| (f.name.as_str(), &f.data_type)),
918 &message_descriptor,
919 )
920 .unwrap_err();
921 assert_eq!(
922 err.to_string(),
923 "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
924 );
925
926 let schema = Schema::new(vec![Field::with_name(
927 DataType::List(DataType::Int32.into()),
928 "repeated_int_field",
929 )]);
930 let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
931 Some(0),
932 None,
933 Some(2),
934 Some(3),
935 ])))]);
936
937 let err = encode_fields(
938 schema
939 .fields
940 .iter()
941 .map(|f| (f.name.as_str(), &f.data_type))
942 .zip_eq_debug(row.iter()),
943 &message_descriptor,
944 )
945 .unwrap_err();
946 assert_eq!(
947 err.to_string(),
948 "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
949 );
950 }
951
952 #[test]
953 fn test_encode_proto_err() {
954 let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
955 .join("codec/tests/test_data/all-types.pb");
956 let pool_bytes = std::fs::read(pool_path).unwrap();
957 let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
958 let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
959
960 let err = validate_fields(
961 std::iter::once(("not_exists", &DataType::Int16)),
962 &message_descriptor,
963 )
964 .unwrap_err();
965 assert_eq!(
966 err.to_string(),
967 "encode 'not_exists' error: field not in proto"
968 );
969
970 let err = validate_fields(
971 std::iter::once(("map_field", &DataType::Jsonb)),
972 &message_descriptor,
973 )
974 .unwrap_err();
975 assert_eq!(
976 err.to_string(),
977 "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
978 );
979 }
980}