1use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18 DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::array::VECTOR_ITEM_TYPE;
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
24use risingwave_common::util::iter_util::ZipEqDebug;
25
26use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
27
/// Result type for the field-level validation and encoding helpers in this module.
type Result<T> = std::result::Result<T, FieldEncodeError>;
29
30pub struct ProtoEncoder {
31 schema: Schema,
32 col_indices: Option<Vec<usize>>,
33 descriptor: MessageDescriptor,
34 header: ProtoHeader,
35}
36
/// Framing that `ser_to` prepends to the serialized protobuf payload.
#[derive(Debug, Clone, Copy)]
pub enum ProtoHeader {
    /// No framing: the output is the bare protobuf encoding of the message.
    None,
    /// Confluent Schema Registry framing. `ser_to` writes a `0` magic byte, then
    /// this schema id via `put_i32` (big-endian), then the message-index list of
    /// the encoded message type. The protobuf payload follows.
    ConfluentSchemaRegistry(i32),
}
46
47impl ProtoEncoder {
48 pub fn new(
49 schema: Schema,
50 col_indices: Option<Vec<usize>>,
51 descriptor: MessageDescriptor,
52 header: ProtoHeader,
53 ) -> SinkResult<Self> {
54 match &col_indices {
55 Some(col_indices) => validate_fields(
56 col_indices.iter().map(|idx| {
57 let f = &schema[*idx];
58 (f.name.as_str(), &f.data_type)
59 }),
60 &descriptor,
61 )?,
62 None => validate_fields(
63 schema
64 .fields
65 .iter()
66 .map(|f| (f.name.as_str(), &f.data_type)),
67 &descriptor,
68 )?,
69 };
70
71 Ok(Self {
72 schema,
73 col_indices,
74 descriptor,
75 header,
76 })
77 }
78}
79
80pub struct ProtoEncoded {
81 pub message: DynamicMessage,
82 header: ProtoHeader,
83}
84
85impl RowEncoder for ProtoEncoder {
86 type Output = ProtoEncoded;
87
88 fn schema(&self) -> &Schema {
89 &self.schema
90 }
91
92 fn col_indices(&self) -> Option<&[usize]> {
93 self.col_indices.as_deref()
94 }
95
96 fn encode_cols(
97 &self,
98 row: impl Row,
99 col_indices: impl Iterator<Item = usize>,
100 ) -> SinkResult<Self::Output> {
101 encode_fields(
102 col_indices.map(|idx| {
103 let f = &self.schema[idx];
104 ((f.name.as_str(), &f.data_type), row.datum_at(idx))
105 }),
106 &self.descriptor,
107 )
108 .map_err(Into::into)
109 .map(|m| ProtoEncoded {
110 message: m,
111 header: self.header,
112 })
113 }
114}
115
116impl SerTo<Vec<u8>> for ProtoEncoded {
117 fn ser_to(self) -> SinkResult<Vec<u8>> {
118 let mut buf = Vec::new();
119 match self.header {
120 ProtoHeader::None => { }
121 ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
122 buf.reserve(1 + 4);
123 buf.put_u8(0);
124 buf.put_i32(schema_id);
125 MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
126 }
127 }
128 self.message.encode(&mut buf).unwrap();
129 Ok(buf)
130 }
131}
132
133struct MessageIndexes(Vec<i32>);
134
135impl MessageIndexes {
136 fn from(desc: MessageDescriptor) -> Self {
137 const TAG_FILE_MESSAGE: i32 = 4;
142 const TAG_MESSAGE_NESTED: i32 = 3;
144
145 let mut indexes = vec![];
146 let mut path = desc.path().iter().copied().array_chunks();
147 let [tag, idx] = path.next().unwrap();
148 assert_eq!(tag, TAG_FILE_MESSAGE);
149 indexes.push(idx);
150 for [tag, idx] in path {
151 assert_eq!(tag, TAG_MESSAGE_NESTED);
152 indexes.push(idx);
153 }
154 Self(indexes)
155 }
156
157 fn zig_i32(value: i32, buf: &mut impl BufMut) {
158 let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
159 prost::encoding::encode_varint(unsigned, buf);
160 }
161
162 fn encode(&self, buf: &mut impl BufMut) {
163 if self.0 == [0] {
164 buf.put_u8(0);
165 return;
166 }
167 Self::zig_i32(self.0.len().try_into().unwrap(), buf);
168 for &idx in &self.0 {
169 Self::zig_i32(idx, buf);
170 }
171 }
172}
173
174trait MaybeData: std::fmt::Debug {
180 type Out;
181
182 fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;
183
184 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;
185
186 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;
187
188 fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
189}
190
191impl MaybeData for () {
192 type Out = ();
193
194 fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
195 Ok(self)
196 }
197
198 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
199 validate_fields(st.iter(), pb)
200 }
201
202 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
203 on_field(elem, (), pb, true)
204 }
205
206 fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
207 debug_assert!(pb.is_map_entry());
208 on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
209 on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
210 Ok(())
211 }
212}
213
214impl MaybeData for ScalarRefImpl<'_> {
221 type Out = Value;
222
223 fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
224 f(self)
225 }
226
227 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
228 let d = self.into_struct();
229 let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
230 Ok(Value::Message(message))
231 }
232
233 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
234 let d = self.into_list();
235 let vs = d
236 .iter()
237 .map(|d| {
238 on_field(
239 elem,
240 d.ok_or_else(|| {
241 FieldEncodeError::new("array containing null not allowed as repeated field")
242 })?,
243 pb,
244 true,
245 )
246 })
247 .try_collect()?;
248 Ok(Value::List(vs))
249 }
250
251 fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
252 debug_assert!(pb.is_map_entry());
253 let vs = self
254 .into_map()
255 .iter()
256 .map(|(k, v)| {
257 let v =
258 v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
259 let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
260 let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
261 Ok((
262 k.into_map_key().ok_or_else(|| {
263 FieldEncodeError::new("failed to convert map key to proto")
264 })?,
265 v,
266 ))
267 })
268 .try_collect()?;
269 Ok(Value::Map(vs))
270 }
271}
272
273fn validate_fields<'a>(
274 fields: impl Iterator<Item = (&'a str, &'a DataType)>,
275 descriptor: &MessageDescriptor,
276) -> Result<()> {
277 for (name, t) in fields {
278 let Some(proto_field) = descriptor.get_field_by_name(name) else {
279 return Err(FieldEncodeError::new("field not in proto").with_name(name));
280 };
281 if proto_field.cardinality() == prost_reflect::Cardinality::Required {
282 return Err(FieldEncodeError::new("`required` not supported").with_name(name));
283 }
284 on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
285 }
286 Ok(())
287}
288
289fn encode_fields<'a>(
290 fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
291 descriptor: &MessageDescriptor,
292) -> Result<DynamicMessage> {
293 let mut message = DynamicMessage::new(descriptor.clone());
294 for ((name, t), d) in fields_with_datums {
295 let proto_field = descriptor.get_field_by_name(name).unwrap();
296 if let Some(scalar) = d {
298 let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
299 message
300 .try_set_field(&proto_field, value)
301 .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
302 }
303 }
304 Ok(message)
305}
306
/// Full name of the well-known `google.protobuf.Timestamp` message; a
/// `timestamptz` column may be encoded into a field of this message type.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
/// Full name of the well-known `google.protobuf.BoolValue` wrapper. It is not
/// referenced by the encoder yet, hence the `dead_code` expectation.
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
311
312fn on_field<D: MaybeData>(
315 data_type: &DataType,
316 maybe: D,
317 proto_field: &FieldDescriptor,
318 in_repeated: bool,
319) -> Result<D::Out> {
320 assert!(proto_field.is_list() || !in_repeated);
327 let expect_list = proto_field.is_list() && !in_repeated;
328 if proto_field.is_group() {
329 return Err(FieldEncodeError::new("proto group not supported yet"));
330 }
331
332 let no_match_err = || {
333 Err(FieldEncodeError::new(format!(
334 "cannot encode {} column as {}{:?} field",
335 data_type,
336 if expect_list { "repeated " } else { "" },
337 proto_field.kind()
338 )))
339 };
340
341 if expect_list && !matches!(data_type, DataType::List(_)) {
342 return no_match_err();
343 }
344
345 let value = match &data_type {
346 DataType::Boolean => match proto_field.kind() {
348 Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
349 _ => return no_match_err(),
350 },
351 DataType::Varchar => match proto_field.kind() {
352 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
353 Kind::Enum(enum_desc) => maybe.on_base(|s| {
354 let name = s.into_utf8();
355 let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
356 FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
357 })?;
358 Ok(Value::EnumNumber(enum_value_desc.number()))
359 })?,
360 _ => return no_match_err(),
361 },
362 DataType::Bytea => match proto_field.kind() {
363 Kind::Bytes => {
364 maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
365 }
366 _ => return no_match_err(),
367 },
368 DataType::Float32 => match proto_field.kind() {
369 Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
370 _ => return no_match_err(),
371 },
372 DataType::Float64 => match proto_field.kind() {
373 Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
374 _ => return no_match_err(),
375 },
376 DataType::Int32 => match proto_field.kind() {
377 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
378 maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
379 }
380 _ => return no_match_err(),
381 },
382 DataType::Int64 => match proto_field.kind() {
383 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
384 maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
385 }
386 _ => return no_match_err(),
387 },
388 DataType::Struct(st) => match proto_field.kind() {
389 Kind::Message(pb) => maybe.on_struct(st, &pb)?,
390 _ => return no_match_err(),
391 },
392 DataType::List(lt) => match expect_list {
393 true => maybe.on_list(lt.elem(), proto_field)?,
394 false => return no_match_err(),
395 },
396 DataType::Timestamptz => match proto_field.kind() {
398 Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
399 let d = s.into_timestamptz();
400 let message = prost_types::Timestamp {
401 seconds: d.timestamp(),
402 nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
403 };
404 Ok(Value::Message(message.transcode_to_dynamic()))
405 })?,
406 Kind::String => {
407 maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
408 }
409 _ => return no_match_err(),
410 },
411 DataType::Jsonb => match proto_field.kind() {
412 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
413 _ => return no_match_err(), },
416 DataType::Int16 => match proto_field.kind() {
417 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
418 _ => return no_match_err(),
419 },
420 DataType::Date => match proto_field.kind() {
421 Kind::Int32 => {
422 maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
423 }
424 _ => return no_match_err(), },
426 DataType::Time => match proto_field.kind() {
427 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
428 _ => return no_match_err(), },
430 DataType::Timestamp => match proto_field.kind() {
431 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
432 _ => return no_match_err(), },
434 DataType::Decimal => match proto_field.kind() {
435 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
436 _ => return no_match_err(), },
438 DataType::Interval => match proto_field.kind() {
439 Kind::String => {
440 maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
441 }
442 _ => return no_match_err(), },
444 DataType::Serial => match proto_field.kind() {
445 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
446 _ => return no_match_err(), },
448 DataType::Int256 => {
449 return no_match_err();
450 }
451 DataType::Map(map_type) => {
452 if proto_field.is_map() {
453 let msg = match proto_field.kind() {
454 Kind::Message(m) => m,
455 _ => return no_match_err(), };
457 return maybe.on_map(map_type, &msg);
458 } else {
459 return no_match_err();
460 }
461 }
462 DataType::Vector(_) => match expect_list {
463 true => maybe.on_list(&VECTOR_ITEM_TYPE, proto_field)?,
464 false => return no_match_err(),
465 },
466 };
467
468 Ok(value)
469}
470
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
    use risingwave_common::catalog::Field;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{
        ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
    };

    use super::*;

    /// Encodes one row covering the supported column types against
    /// `all_types.AllTypes` and snapshots every field set on the message.
    #[test]
    fn test_encode_proto_ok() {
        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("codec/tests/test_data/all-types.pb");
        let pool_bytes = std::fs::read(pool_path).unwrap();
        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
        let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
        let schema = Schema::new(vec![
            Field::with_name(DataType::Boolean, "bool_field"),
            Field::with_name(DataType::Varchar, "string_field"),
            Field::with_name(DataType::Bytea, "bytes_field"),
            Field::with_name(DataType::Float32, "float_field"),
            Field::with_name(DataType::Float64, "double_field"),
            Field::with_name(DataType::Int32, "int32_field"),
            Field::with_name(DataType::Int64, "int64_field"),
            Field::with_name(DataType::Int32, "sint32_field"),
            Field::with_name(DataType::Int64, "sint64_field"),
            Field::with_name(DataType::Int32, "sfixed32_field"),
            Field::with_name(DataType::Int64, "sfixed64_field"),
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("id", DataType::Int32),
                    ("name", DataType::Varchar),
                ])),
                "nested_message_field",
            ),
            Field::with_name(DataType::Int32.list(), "repeated_int_field"),
            Field::with_name(DataType::Timestamptz, "timestamp_field"),
            Field::with_name(
                DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
                "map_field",
            ),
            Field::with_name(
                DataType::Map(MapType::from_kv(
                    DataType::Varchar,
                    DataType::Struct(StructType::new(vec![
                        ("id", DataType::Int32),
                        ("name", DataType::Varchar),
                    ])),
                )),
                "map_struct_field",
            ),
        ]);
        let row = OwnedRow::new(vec![
            Some(ScalarImpl::Bool(true)),
            Some(ScalarImpl::Utf8("RisingWave".into())),
            Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
            Some(ScalarImpl::Float32(3.5f32.into())),
            Some(ScalarImpl::Float64(4.25f64.into())),
            Some(ScalarImpl::Int32(22)),
            Some(ScalarImpl::Int64(23)),
            Some(ScalarImpl::Int32(24)),
            // NULL `sint64_field`: left unset, so it is absent from the snapshot.
            None,
            Some(ScalarImpl::Int32(26)),
            Some(ScalarImpl::Int64(27)),
            Some(ScalarImpl::Struct(StructValue::new(vec![
                Some(ScalarImpl::Int32(1)),
                Some(ScalarImpl::Utf8("".into())),
            ]))),
            Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
            Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
            Some(ScalarImpl::Map(
                MapValue::try_from_kv(
                    ListValue::from_iter(["a", "b"]),
                    ListValue::from_iter([1, 2]),
                )
                .unwrap(),
            )),
            {
                // Map values are structs, so the value list is built from a
                // struct array.
                let mut struct_array_builder = StructArrayBuilder::with_type(
                    2,
                    DataType::Struct(StructType::new(vec![
                        ("id", DataType::Int32),
                        ("name", DataType::Varchar),
                    ])),
                );
                struct_array_builder.append(Some(
                    StructValue::new(vec![
                        Some(ScalarImpl::Int32(1)),
                        Some(ScalarImpl::Utf8("x".into())),
                    ])
                    .as_scalar_ref(),
                ));
                struct_array_builder.append(Some(
                    StructValue::new(vec![
                        Some(ScalarImpl::Int32(2)),
                        Some(ScalarImpl::Utf8("y".into())),
                    ])
                    .as_scalar_ref(),
                ));
                Some(ScalarImpl::Map(
                    MapValue::try_from_kv(
                        ListValue::from_iter(["a", "b"]),
                        ListValue::new(struct_array_builder.finish().into()),
                    )
                    .unwrap(),
                ))
            },
        ]);

        let encoder = ProtoEncoder::new(schema, None, descriptor, ProtoHeader::None).unwrap();
        let m = encoder.encode(row).unwrap();
        // Snapshot: each set field's descriptor followed by its value, with map
        // entries printed in sorted key order by `print_proto`.
        expect_test::expect![[r#"
            field: FieldDescriptor {
                name: "double_field",
                full_name: "all_types.AllTypes.double_field",
                json_name: "doubleField",
                number: 1,
                kind: double,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: F64(4.25)

            ==============================
            field: FieldDescriptor {
                name: "float_field",
                full_name: "all_types.AllTypes.float_field",
                json_name: "floatField",
                number: 2,
                kind: float,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: F32(3.5)

            ==============================
            field: FieldDescriptor {
                name: "int32_field",
                full_name: "all_types.AllTypes.int32_field",
                json_name: "int32Field",
                number: 3,
                kind: int32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(22)

            ==============================
            field: FieldDescriptor {
                name: "int64_field",
                full_name: "all_types.AllTypes.int64_field",
                json_name: "int64Field",
                number: 4,
                kind: int64,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I64(23)

            ==============================
            field: FieldDescriptor {
                name: "sint32_field",
                full_name: "all_types.AllTypes.sint32_field",
                json_name: "sint32Field",
                number: 7,
                kind: sint32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(24)

            ==============================
            field: FieldDescriptor {
                name: "sfixed32_field",
                full_name: "all_types.AllTypes.sfixed32_field",
                json_name: "sfixed32Field",
                number: 11,
                kind: sfixed32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(26)

            ==============================
            field: FieldDescriptor {
                name: "sfixed64_field",
                full_name: "all_types.AllTypes.sfixed64_field",
                json_name: "sfixed64Field",
                number: 12,
                kind: sfixed64,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I64(27)

            ==============================
            field: FieldDescriptor {
                name: "bool_field",
                full_name: "all_types.AllTypes.bool_field",
                json_name: "boolField",
                number: 13,
                kind: bool,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: Bool(true)

            ==============================
            field: FieldDescriptor {
                name: "string_field",
                full_name: "all_types.AllTypes.string_field",
                json_name: "stringField",
                number: 14,
                kind: string,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: String("RisingWave")

            ==============================
            field: FieldDescriptor {
                name: "bytes_field",
                full_name: "all_types.AllTypes.bytes_field",
                json_name: "bytesField",
                number: 15,
                kind: bytes,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: Bytes(b"\xbe\xef")

            ==============================
            field: FieldDescriptor {
                name: "nested_message_field",
                full_name: "all_types.AllTypes.nested_message_field",
                json_name: "nestedMessageField",
                number: 17,
                kind: all_types.AllTypes.NestedMessage,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: true,
            }

            value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })

            ==============================
            field: FieldDescriptor {
                name: "repeated_int_field",
                full_name: "all_types.AllTypes.repeated_int_field",
                json_name: "repeatedIntField",
                number: 18,
                kind: int32,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: true,
                is_map: false,
                is_packed: true,
                supports_presence: false,
            }

            value: List([I32(4), I32(0), I32(4)])

            ==============================
            field: FieldDescriptor {
                name: "map_field",
                full_name: "all_types.AllTypes.map_field",
                json_name: "mapField",
                number: 22,
                kind: all_types.AllTypes.MapFieldEntry,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: true,
                is_packed: false,
                supports_presence: false,
            }

            value: Map({
                String("a"): I32(1),
                String("b"): I32(2),
            })

            ==============================
            field: FieldDescriptor {
                name: "timestamp_field",
                full_name: "all_types.AllTypes.timestamp_field",
                json_name: "timestampField",
                number: 23,
                kind: google.protobuf.Timestamp,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: true,
            }

            value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })

            ==============================
            field: FieldDescriptor {
                name: "map_struct_field",
                full_name: "all_types.AllTypes.map_struct_field",
                json_name: "mapStructField",
                number: 29,
                kind: all_types.AllTypes.MapStructFieldEntry,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: true,
                is_packed: false,
                supports_presence: false,
            }

            value: Map({
                String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
                String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
            })"#]]
        .assert_eq(&format!(
            "{}",
            m.message
                .fields()
                .format_with("\n\n==============================\n", |(field, value), f| {
                    f(&format!(
                        "field: {:#?}\n\nvalue: {}",
                        field,
                        print_proto(value)
                    ))
                })
        ));
    }

    /// Formats a protobuf value for snapshots.
    ///
    /// Map entries are printed one per line, sorted by key. Every other value
    /// uses its `Debug` output.
    fn print_proto(value: &Value) -> String {
        match value {
            Value::Map(m) => {
                let mut res = String::new();
                res.push_str("Map({\n");
                for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
                    res.push_str(&format!(
                        "    {}: {},\n",
                        print_proto(&k.clone().into()),
                        print_proto(v)
                    ));
                }
                res.push_str("})");
                res
            }
            _ => format!("{:?}", value),
        }
    }

    /// Repeated fields: a nested list column is rejected at validation, and a
    /// NULL element is rejected at encode time.
    #[test]
    fn test_encode_proto_repeated() {
        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("codec/tests/test_data/all-types.pb");
        let pool_bytes = fs_err::read(pool_path).unwrap();
        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();

        // `integer[][]`: the inner `integer[]` element cannot fill an `int32` element.
        let schema = Schema::new(vec![Field::with_name(
            DataType::Int32.list().list(),
            "repeated_int_field",
        )]);

        let err = validate_fields(
            schema
                .fields
                .iter()
                .map(|f| (f.name.as_str(), &f.data_type)),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
        );

        let schema = Schema::new(vec![Field::with_name(
            DataType::Int32.list(),
            "repeated_int_field",
        )]);
        let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
            Some(0),
            None,
            Some(2),
            Some(3),
        ])))]);

        let err = encode_fields(
            schema
                .fields
                .iter()
                .map(|f| (f.name.as_str(), &f.data_type))
                .zip_eq_debug(row.iter()),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
        );
    }

    /// Validation errors for an unknown field and for a type/kind mismatch.
    #[test]
    fn test_encode_proto_err() {
        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("codec/tests/test_data/all-types.pb");
        let pool_bytes = std::fs::read(pool_path).unwrap();
        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();

        let err = validate_fields(
            std::iter::once(("not_exists", &DataType::Int16)),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'not_exists' error: field not in proto"
        );

        let err = validate_fields(
            std::iter::once(("map_field", &DataType::Jsonb)),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
        );
    }
}