1use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18 DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::catalog::Schema;
21use risingwave_common::row::Row;
22use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
23use risingwave_common::util::iter_util::ZipEqDebug;
24
25use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
26
/// Local result alias for field-level validation/encoding; [`FieldEncodeError`]s are converted
/// into sink errors by the public entry points (`ProtoEncoder::new`, `encode_cols`).
type Result<T> = std::result::Result<T, FieldEncodeError>;
28
/// Encodes RisingWave rows into protobuf [`DynamicMessage`]s of a fixed message type.
pub struct ProtoEncoder {
    /// Schema of the rows handed to the encoder.
    schema: Schema,
    /// Optional subset of column indices; when `None`, `new` validates every column of `schema`.
    col_indices: Option<Vec<usize>>,
    /// Target message type; columns are matched to its fields by name.
    descriptor: MessageDescriptor,
    /// Framing written in front of each serialized message (see [`ProtoHeader`]).
    header: ProtoHeader,
}
35
/// Optional prefix written before the protobuf payload when an encoded row is serialized.
#[derive(Debug, Clone, Copy)]
pub enum ProtoHeader {
    /// No prefix: the output is the bare protobuf encoding.
    None,
    /// Confluent Schema Registry framing: a `0` magic byte, the `i32` schema id written with
    /// `put_i32`, then the message-index path of the message type, then the payload.
    ConfluentSchemaRegistry(i32),
}
45
46impl ProtoEncoder {
47 pub fn new(
48 schema: Schema,
49 col_indices: Option<Vec<usize>>,
50 descriptor: MessageDescriptor,
51 header: ProtoHeader,
52 ) -> SinkResult<Self> {
53 match &col_indices {
54 Some(col_indices) => validate_fields(
55 col_indices.iter().map(|idx| {
56 let f = &schema[*idx];
57 (f.name.as_str(), &f.data_type)
58 }),
59 &descriptor,
60 )?,
61 None => validate_fields(
62 schema
63 .fields
64 .iter()
65 .map(|f| (f.name.as_str(), &f.data_type)),
66 &descriptor,
67 )?,
68 };
69
70 Ok(Self {
71 schema,
72 col_indices,
73 descriptor,
74 header,
75 })
76 }
77}
78
/// A row encoded as a protobuf message, plus the framing to emit when it is serialized.
pub struct ProtoEncoded {
    /// The encoded message.
    pub message: DynamicMessage,
    /// Prefix written before `message` by `ser_to`.
    header: ProtoHeader,
}
83
84impl RowEncoder for ProtoEncoder {
85 type Output = ProtoEncoded;
86
87 fn schema(&self) -> &Schema {
88 &self.schema
89 }
90
91 fn col_indices(&self) -> Option<&[usize]> {
92 self.col_indices.as_deref()
93 }
94
95 fn encode_cols(
96 &self,
97 row: impl Row,
98 col_indices: impl Iterator<Item = usize>,
99 ) -> SinkResult<Self::Output> {
100 encode_fields(
101 col_indices.map(|idx| {
102 let f = &self.schema[idx];
103 ((f.name.as_str(), &f.data_type), row.datum_at(idx))
104 }),
105 &self.descriptor,
106 )
107 .map_err(Into::into)
108 .map(|m| ProtoEncoded {
109 message: m,
110 header: self.header,
111 })
112 }
113}
114
115impl SerTo<Vec<u8>> for ProtoEncoded {
116 fn ser_to(self) -> SinkResult<Vec<u8>> {
117 let mut buf = Vec::new();
118 match self.header {
119 ProtoHeader::None => { }
120 ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
121 buf.reserve(1 + 4);
122 buf.put_u8(0);
123 buf.put_i32(schema_id);
124 MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
125 }
126 }
127 self.message.encode(&mut buf).unwrap();
128 Ok(buf)
129 }
130}
131
132struct MessageIndexes(Vec<i32>);
133
134impl MessageIndexes {
135 fn from(desc: MessageDescriptor) -> Self {
136 const TAG_FILE_MESSAGE: i32 = 4;
141 const TAG_MESSAGE_NESTED: i32 = 3;
143
144 let mut indexes = vec![];
145 let mut path = desc.path().array_chunks();
146 let &[tag, idx] = path.next().unwrap();
147 assert_eq!(tag, TAG_FILE_MESSAGE);
148 indexes.push(idx);
149 for &[tag, idx] in path {
150 assert_eq!(tag, TAG_MESSAGE_NESTED);
151 indexes.push(idx);
152 }
153 Self(indexes)
154 }
155
156 fn zig_i32(value: i32, buf: &mut impl BufMut) {
157 let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
158 prost::encoding::encode_varint(unsigned, buf);
159 }
160
161 fn encode(&self, buf: &mut impl BufMut) {
162 if self.0 == [0] {
163 buf.put_u8(0);
164 return;
165 }
166 Self::zig_i32(self.0.len().try_into().unwrap(), buf);
167 for &idx in &self.0 {
168 Self::zig_i32(idx, buf);
169 }
170 }
171}
172
/// "Maybe a datum": lets [`on_field`] share one type-matching table between validation and
/// encoding. The `()` impl only type-checks (output `()`); the `ScalarRefImpl` impl also
/// converts the datum into a proto [`Value`].
trait MaybeData: std::fmt::Debug {
    /// `()` for validation, [`Value`] for encoding.
    type Out;

    /// Handles a scalar leaf; `f` converts the scalar into a proto value.
    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;

    /// Handles a struct column mapped onto the proto message `pb`.
    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;

    /// Handles a list column with element type `elem` mapped onto the repeated field `pb`.
    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;

    /// Handles a map column mapped onto the map-entry message `pb`.
    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
}
189
190impl MaybeData for () {
191 type Out = ();
192
193 fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
194 Ok(self)
195 }
196
197 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
198 validate_fields(st.iter(), pb)
199 }
200
201 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
202 on_field(elem, (), pb, true)
203 }
204
205 fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
206 debug_assert!(pb.is_map_entry());
207 on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
208 on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
209 Ok(())
210 }
211}
212
213impl MaybeData for ScalarRefImpl<'_> {
220 type Out = Value;
221
222 fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
223 f(self)
224 }
225
226 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
227 let d = self.into_struct();
228 let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
229 Ok(Value::Message(message))
230 }
231
232 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
233 let d = self.into_list();
234 let vs = d
235 .iter()
236 .map(|d| {
237 on_field(
238 elem,
239 d.ok_or_else(|| {
240 FieldEncodeError::new("array containing null not allowed as repeated field")
241 })?,
242 pb,
243 true,
244 )
245 })
246 .try_collect()?;
247 Ok(Value::List(vs))
248 }
249
250 fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
251 debug_assert!(pb.is_map_entry());
252 let vs = self
253 .into_map()
254 .iter()
255 .map(|(k, v)| {
256 let v =
257 v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
258 let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
259 let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
260 Ok((
261 k.into_map_key().ok_or_else(|| {
262 FieldEncodeError::new("failed to convert map key to proto")
263 })?,
264 v,
265 ))
266 })
267 .try_collect()?;
268 Ok(Value::Map(vs))
269 }
270}
271
272fn validate_fields<'a>(
273 fields: impl Iterator<Item = (&'a str, &'a DataType)>,
274 descriptor: &MessageDescriptor,
275) -> Result<()> {
276 for (name, t) in fields {
277 let Some(proto_field) = descriptor.get_field_by_name(name) else {
278 return Err(FieldEncodeError::new("field not in proto").with_name(name));
279 };
280 if proto_field.cardinality() == prost_reflect::Cardinality::Required {
281 return Err(FieldEncodeError::new("`required` not supported").with_name(name));
282 }
283 on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
284 }
285 Ok(())
286}
287
288fn encode_fields<'a>(
289 fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
290 descriptor: &MessageDescriptor,
291) -> Result<DynamicMessage> {
292 let mut message = DynamicMessage::new(descriptor.clone());
293 for ((name, t), d) in fields_with_datums {
294 let proto_field = descriptor.get_field_by_name(name).unwrap();
295 if let Some(scalar) = d {
297 let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
298 message
299 .try_set_field(&proto_field, value)
300 .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
301 }
302 }
303 Ok(message)
304}
305
/// Full name of the well-known `google.protobuf.Timestamp` message, accepted for
/// `timestamptz` columns.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
/// Full name of the well-known `google.protobuf.BoolValue` wrapper; not referenced yet.
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
310
311fn on_field<D: MaybeData>(
314 data_type: &DataType,
315 maybe: D,
316 proto_field: &FieldDescriptor,
317 in_repeated: bool,
318) -> Result<D::Out> {
319 assert!(proto_field.is_list() || !in_repeated);
326 let expect_list = proto_field.is_list() && !in_repeated;
327 if proto_field.is_group() {
328 return Err(FieldEncodeError::new("proto group not supported yet"));
329 }
330
331 let no_match_err = || {
332 Err(FieldEncodeError::new(format!(
333 "cannot encode {} column as {}{:?} field",
334 data_type,
335 if expect_list { "repeated " } else { "" },
336 proto_field.kind()
337 )))
338 };
339
340 if expect_list && !matches!(data_type, DataType::List(_)) {
341 return no_match_err();
342 }
343
344 let value = match &data_type {
345 DataType::Boolean => match proto_field.kind() {
347 Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
348 _ => return no_match_err(),
349 },
350 DataType::Varchar => match proto_field.kind() {
351 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
352 Kind::Enum(enum_desc) => maybe.on_base(|s| {
353 let name = s.into_utf8();
354 let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
355 FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
356 })?;
357 Ok(Value::EnumNumber(enum_value_desc.number()))
358 })?,
359 _ => return no_match_err(),
360 },
361 DataType::Bytea => match proto_field.kind() {
362 Kind::Bytes => {
363 maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
364 }
365 _ => return no_match_err(),
366 },
367 DataType::Float32 => match proto_field.kind() {
368 Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
369 _ => return no_match_err(),
370 },
371 DataType::Float64 => match proto_field.kind() {
372 Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
373 _ => return no_match_err(),
374 },
375 DataType::Int32 => match proto_field.kind() {
376 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
377 maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
378 }
379 _ => return no_match_err(),
380 },
381 DataType::Int64 => match proto_field.kind() {
382 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
383 maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
384 }
385 _ => return no_match_err(),
386 },
387 DataType::Struct(st) => match proto_field.kind() {
388 Kind::Message(pb) => maybe.on_struct(st, &pb)?,
389 _ => return no_match_err(),
390 },
391 DataType::List(elem) => match expect_list {
392 true => maybe.on_list(elem, proto_field)?,
393 false => return no_match_err(),
394 },
395 DataType::Timestamptz => match proto_field.kind() {
397 Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
398 let d = s.into_timestamptz();
399 let message = prost_types::Timestamp {
400 seconds: d.timestamp(),
401 nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
402 };
403 Ok(Value::Message(message.transcode_to_dynamic()))
404 })?,
405 Kind::String => {
406 maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
407 }
408 _ => return no_match_err(),
409 },
410 DataType::Jsonb => match proto_field.kind() {
411 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
412 _ => return no_match_err(), },
415 DataType::Int16 => match proto_field.kind() {
416 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
417 _ => return no_match_err(),
418 },
419 DataType::Date => match proto_field.kind() {
420 Kind::Int32 => {
421 maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
422 }
423 _ => return no_match_err(), },
425 DataType::Time => match proto_field.kind() {
426 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
427 _ => return no_match_err(), },
429 DataType::Timestamp => match proto_field.kind() {
430 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
431 _ => return no_match_err(), },
433 DataType::Decimal => match proto_field.kind() {
434 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
435 _ => return no_match_err(), },
437 DataType::Interval => match proto_field.kind() {
438 Kind::String => {
439 maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
440 }
441 _ => return no_match_err(), },
443 DataType::Serial => match proto_field.kind() {
444 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
445 _ => return no_match_err(), },
447 DataType::Int256 => {
448 return no_match_err();
449 }
450 DataType::Map(map_type) => {
451 if proto_field.is_map() {
452 let msg = match proto_field.kind() {
453 Kind::Message(m) => m,
454 _ => return no_match_err(), };
456 return maybe.on_map(map_type, &msg);
457 } else {
458 return no_match_err();
459 }
460 }
461 };
462
463 Ok(value)
464}
465
/// Unit tests for the proto encoder, using the `all_types.AllTypes` message from the
/// `codec/tests/test_data/all-types.pb` descriptor set.
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
    use risingwave_common::catalog::Field;
    use risingwave_common::row::OwnedRow;
    use risingwave_common::types::{
        ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
    };

    use super::*;

    /// Encodes one row covering the supported column types and snapshots every set field of
    /// the resulting message. `sint64_field` is null, so it stays unset and is absent below.
    #[test]
    fn test_encode_proto_ok() {
        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("codec/tests/test_data/all-types.pb");
        let pool_bytes = std::fs::read(pool_path).unwrap();
        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
        let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
        let schema = Schema::new(vec![
            Field::with_name(DataType::Boolean, "bool_field"),
            Field::with_name(DataType::Varchar, "string_field"),
            Field::with_name(DataType::Bytea, "bytes_field"),
            Field::with_name(DataType::Float32, "float_field"),
            Field::with_name(DataType::Float64, "double_field"),
            Field::with_name(DataType::Int32, "int32_field"),
            Field::with_name(DataType::Int64, "int64_field"),
            Field::with_name(DataType::Int32, "sint32_field"),
            Field::with_name(DataType::Int64, "sint64_field"),
            Field::with_name(DataType::Int32, "sfixed32_field"),
            Field::with_name(DataType::Int64, "sfixed64_field"),
            Field::with_name(
                DataType::Struct(StructType::new(vec![
                    ("id", DataType::Int32),
                    ("name", DataType::Varchar),
                ])),
                "nested_message_field",
            ),
            Field::with_name(DataType::List(DataType::Int32.into()), "repeated_int_field"),
            Field::with_name(DataType::Timestamptz, "timestamp_field"),
            Field::with_name(
                DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
                "map_field",
            ),
            Field::with_name(
                DataType::Map(MapType::from_kv(
                    DataType::Varchar,
                    DataType::Struct(StructType::new(vec![
                        ("id", DataType::Int32),
                        ("name", DataType::Varchar),
                    ])),
                )),
                "map_struct_field",
            ),
        ]);
        let row = OwnedRow::new(vec![
            Some(ScalarImpl::Bool(true)),
            Some(ScalarImpl::Utf8("RisingWave".into())),
            Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
            Some(ScalarImpl::Float32(3.5f32.into())),
            Some(ScalarImpl::Float64(4.25f64.into())),
            Some(ScalarImpl::Int32(22)),
            Some(ScalarImpl::Int64(23)),
            Some(ScalarImpl::Int32(24)),
            None,
            Some(ScalarImpl::Int32(26)),
            Some(ScalarImpl::Int64(27)),
            Some(ScalarImpl::Struct(StructValue::new(vec![
                Some(ScalarImpl::Int32(1)),
                Some(ScalarImpl::Utf8("".into())),
            ]))),
            Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
            Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
            Some(ScalarImpl::Map(
                MapValue::try_from_kv(
                    ListValue::from_iter(["a", "b"]),
                    ListValue::from_iter([1, 2]),
                )
                .unwrap(),
            )),
            {
                // Map values are structs, so build the value list through a struct array.
                let mut struct_array_builder = StructArrayBuilder::with_type(
                    2,
                    DataType::Struct(StructType::new(vec![
                        ("id", DataType::Int32),
                        ("name", DataType::Varchar),
                    ])),
                );
                struct_array_builder.append(Some(
                    StructValue::new(vec![
                        Some(ScalarImpl::Int32(1)),
                        Some(ScalarImpl::Utf8("x".into())),
                    ])
                    .as_scalar_ref(),
                ));
                struct_array_builder.append(Some(
                    StructValue::new(vec![
                        Some(ScalarImpl::Int32(2)),
                        Some(ScalarImpl::Utf8("y".into())),
                    ])
                    .as_scalar_ref(),
                ));
                Some(ScalarImpl::Map(
                    MapValue::try_from_kv(
                        ListValue::from_iter(["a", "b"]),
                        ListValue::new(struct_array_builder.finish().into()),
                    )
                    .unwrap(),
                ))
            },
        ]);

        let encoder =
            ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap();
        let m = encoder.encode(row).unwrap();
        expect_test::expect![[r#"
            field: FieldDescriptor {
                name: "double_field",
                full_name: "all_types.AllTypes.double_field",
                json_name: "doubleField",
                number: 1,
                kind: double,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: F64(4.25)

            ==============================
            field: FieldDescriptor {
                name: "float_field",
                full_name: "all_types.AllTypes.float_field",
                json_name: "floatField",
                number: 2,
                kind: float,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: F32(3.5)

            ==============================
            field: FieldDescriptor {
                name: "int32_field",
                full_name: "all_types.AllTypes.int32_field",
                json_name: "int32Field",
                number: 3,
                kind: int32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(22)

            ==============================
            field: FieldDescriptor {
                name: "int64_field",
                full_name: "all_types.AllTypes.int64_field",
                json_name: "int64Field",
                number: 4,
                kind: int64,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I64(23)

            ==============================
            field: FieldDescriptor {
                name: "sint32_field",
                full_name: "all_types.AllTypes.sint32_field",
                json_name: "sint32Field",
                number: 7,
                kind: sint32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(24)

            ==============================
            field: FieldDescriptor {
                name: "sfixed32_field",
                full_name: "all_types.AllTypes.sfixed32_field",
                json_name: "sfixed32Field",
                number: 11,
                kind: sfixed32,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I32(26)

            ==============================
            field: FieldDescriptor {
                name: "sfixed64_field",
                full_name: "all_types.AllTypes.sfixed64_field",
                json_name: "sfixed64Field",
                number: 12,
                kind: sfixed64,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: I64(27)

            ==============================
            field: FieldDescriptor {
                name: "bool_field",
                full_name: "all_types.AllTypes.bool_field",
                json_name: "boolField",
                number: 13,
                kind: bool,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: Bool(true)

            ==============================
            field: FieldDescriptor {
                name: "string_field",
                full_name: "all_types.AllTypes.string_field",
                json_name: "stringField",
                number: 14,
                kind: string,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: String("RisingWave")

            ==============================
            field: FieldDescriptor {
                name: "bytes_field",
                full_name: "all_types.AllTypes.bytes_field",
                json_name: "bytesField",
                number: 15,
                kind: bytes,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: false,
            }

            value: Bytes(b"\xbe\xef")

            ==============================
            field: FieldDescriptor {
                name: "nested_message_field",
                full_name: "all_types.AllTypes.nested_message_field",
                json_name: "nestedMessageField",
                number: 17,
                kind: all_types.AllTypes.NestedMessage,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: true,
            }

            value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })

            ==============================
            field: FieldDescriptor {
                name: "repeated_int_field",
                full_name: "all_types.AllTypes.repeated_int_field",
                json_name: "repeatedIntField",
                number: 18,
                kind: int32,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: true,
                is_map: false,
                is_packed: true,
                supports_presence: false,
            }

            value: List([I32(4), I32(0), I32(4)])

            ==============================
            field: FieldDescriptor {
                name: "map_field",
                full_name: "all_types.AllTypes.map_field",
                json_name: "mapField",
                number: 22,
                kind: all_types.AllTypes.MapFieldEntry,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: true,
                is_packed: false,
                supports_presence: false,
            }

            value: Map({
                String("a"): I32(1),
                String("b"): I32(2),
            })

            ==============================
            field: FieldDescriptor {
                name: "timestamp_field",
                full_name: "all_types.AllTypes.timestamp_field",
                json_name: "timestampField",
                number: 23,
                kind: google.protobuf.Timestamp,
                cardinality: Optional,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: false,
                is_packed: false,
                supports_presence: true,
            }

            value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })

            ==============================
            field: FieldDescriptor {
                name: "map_struct_field",
                full_name: "all_types.AllTypes.map_struct_field",
                json_name: "mapStructField",
                number: 29,
                kind: all_types.AllTypes.MapStructFieldEntry,
                cardinality: Repeated,
                containing_oneof: None,
                default_value: None,
                is_group: false,
                is_list: false,
                is_map: true,
                is_packed: false,
                supports_presence: false,
            }

            value: Map({
                String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
                String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
            })"#]].assert_eq(&format!("{}",
            m.message.fields().format_with("\n\n==============================\n", |(field,value),f| {
                f(&format!("field: {:#?}\n\nvalue: {}", field, print_proto(value)))
        })));
    }

    /// Debug-formats a proto value, printing map entries one per line sorted by key so the
    /// snapshot does not depend on map iteration order.
    fn print_proto(value: &Value) -> String {
        match value {
            Value::Map(m) => {
                let mut res = String::new();
                res.push_str("Map({\n");
                for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
                    res.push_str(&format!(
                        "    {}: {},\n",
                        print_proto(&k.clone().into()),
                        print_proto(v)
                    ));
                }
                res.push_str("})");
                res
            }
            _ => format!("{:?}", value),
        }
    }

    /// A nested list cannot target a repeated scalar field, and a list containing a null
    /// element cannot be encoded as a repeated field.
    #[test]
    fn test_encode_proto_repeated() {
        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("codec/tests/test_data/all-types.pb");
        let pool_bytes = fs_err::read(pool_path).unwrap();
        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();

        let schema = Schema::new(vec![Field::with_name(
            DataType::List(DataType::List(DataType::Int32.into()).into()),
            "repeated_int_field",
        )]);

        let err = validate_fields(
            schema
                .fields
                .iter()
                .map(|f| (f.name.as_str(), &f.data_type)),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
        );

        let schema = Schema::new(vec![Field::with_name(
            DataType::List(DataType::Int32.into()),
            "repeated_int_field",
        )]);
        let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
            Some(0),
            None,
            Some(2),
            Some(3),
        ])))]);

        let err = encode_fields(
            schema
                .fields
                .iter()
                .map(|f| (f.name.as_str(), &f.data_type))
                .zip_eq_debug(row.iter()),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
        );
    }

    /// Validation rejects unknown field names and type/kind mismatches.
    #[test]
    fn test_encode_proto_err() {
        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("codec/tests/test_data/all-types.pb");
        let pool_bytes = std::fs::read(pool_path).unwrap();
        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();

        let err = validate_fields(
            std::iter::once(("not_exists", &DataType::Int16)),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'not_exists' error: field not in proto"
        );

        let err = validate_fields(
            std::iter::once(("map_field", &DataType::Jsonb)),
            &message_descriptor,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
        );
    }
}