1use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18 DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::array::VECTOR_ITEM_TYPE;
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
24use risingwave_common::util::iter_util::ZipEqDebug;
25
26use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
27
28type Result<T> = std::result::Result<T, FieldEncodeError>;
29
30pub struct ProtoEncoder {
31 schema: Schema,
32 col_indices: Option<Vec<usize>>,
33 descriptor: MessageDescriptor,
34 header: ProtoHeader,
35}
36
/// Framing prepended to each serialized protobuf payload.
#[derive(Debug, Clone, Copy)]
pub enum ProtoHeader {
    /// No prefix: the output is the bare protobuf payload.
    None,
    /// Confluent Schema Registry framing. The payload is prefixed by a `0` magic byte, the given
    /// schema id as a big-endian `i32`, and the message-index list of the message type.
    ConfluentSchemaRegistry(i32),
}
46
47impl ProtoEncoder {
48 pub fn new(
49 schema: Schema,
50 col_indices: Option<Vec<usize>>,
51 descriptor: MessageDescriptor,
52 header: ProtoHeader,
53 ) -> SinkResult<Self> {
54 match &col_indices {
55 Some(col_indices) => validate_fields(
56 col_indices.iter().map(|idx| {
57 let f = &schema[*idx];
58 (f.name.as_str(), &f.data_type)
59 }),
60 &descriptor,
61 )?,
62 None => validate_fields(
63 schema
64 .fields
65 .iter()
66 .map(|f| (f.name.as_str(), &f.data_type)),
67 &descriptor,
68 )?,
69 };
70
71 Ok(Self {
72 schema,
73 col_indices,
74 descriptor,
75 header,
76 })
77 }
78}
79
80pub struct ProtoEncoded {
81 pub message: DynamicMessage,
82 header: ProtoHeader,
83}
84
85impl RowEncoder for ProtoEncoder {
86 type Output = ProtoEncoded;
87
88 fn schema(&self) -> &Schema {
89 &self.schema
90 }
91
92 fn col_indices(&self) -> Option<&[usize]> {
93 self.col_indices.as_deref()
94 }
95
96 fn encode_cols(
97 &self,
98 row: impl Row,
99 col_indices: impl Iterator<Item = usize>,
100 ) -> SinkResult<Self::Output> {
101 encode_fields(
102 col_indices.map(|idx| {
103 let f = &self.schema[idx];
104 ((f.name.as_str(), &f.data_type), row.datum_at(idx))
105 }),
106 &self.descriptor,
107 )
108 .map_err(Into::into)
109 .map(|m| ProtoEncoded {
110 message: m,
111 header: self.header,
112 })
113 }
114}
115
116impl SerTo<Vec<u8>> for ProtoEncoded {
117 fn ser_to(self) -> SinkResult<Vec<u8>> {
118 let mut buf = Vec::new();
119 match self.header {
120 ProtoHeader::None => { }
121 ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
122 buf.reserve(1 + 4);
123 buf.put_u8(0);
124 buf.put_i32(schema_id);
125 MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
126 }
127 }
128 self.message.encode(&mut buf).unwrap();
129 Ok(buf)
130 }
131}
132
133struct MessageIndexes(Vec<i32>);
134
135impl MessageIndexes {
136 fn from(desc: MessageDescriptor) -> Self {
137 const TAG_FILE_MESSAGE: i32 = 4;
142 const TAG_MESSAGE_NESTED: i32 = 3;
144
145 let mut indexes = vec![];
146 let mut path = desc.path().iter().copied().array_chunks();
147 let [tag, idx] = path.next().unwrap();
148 assert_eq!(tag, TAG_FILE_MESSAGE);
149 indexes.push(idx);
150 for [tag, idx] in path {
151 assert_eq!(tag, TAG_MESSAGE_NESTED);
152 indexes.push(idx);
153 }
154 Self(indexes)
155 }
156
157 fn zig_i32(value: i32, buf: &mut impl BufMut) {
158 let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
159 prost::encoding::encode_varint(unsigned, buf);
160 }
161
162 fn encode(&self, buf: &mut impl BufMut) {
163 if self.0 == [0] {
164 buf.put_u8(0);
165 return;
166 }
167 Self::zig_i32(self.0.len().try_into().unwrap(), buf);
168 for &idx in &self.0 {
169 Self::zig_i32(idx, buf);
170 }
171 }
172}
173
174trait MaybeData: std::fmt::Debug {
180 type Out;
181
182 fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;
183
184 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;
185
186 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;
187
188 fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
189}
190
191impl MaybeData for () {
192 type Out = ();
193
194 fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
195 Ok(self)
196 }
197
198 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
199 validate_fields(st.iter(), pb)
200 }
201
202 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
203 on_field(elem, (), pb, true)
204 }
205
206 fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
207 debug_assert!(pb.is_map_entry());
208 on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
209 on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
210 Ok(())
211 }
212}
213
214impl MaybeData for ScalarRefImpl<'_> {
221 type Out = Value;
222
223 fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
224 f(self)
225 }
226
227 fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
228 let d = self.into_struct();
229 let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
230 Ok(Value::Message(message))
231 }
232
233 fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
234 let d = self.into_list();
235 let vs = d
236 .iter()
237 .map(|d| {
238 on_field(
239 elem,
240 d.ok_or_else(|| {
241 FieldEncodeError::new("array containing null not allowed as repeated field")
242 })?,
243 pb,
244 true,
245 )
246 })
247 .try_collect()?;
248 Ok(Value::List(vs))
249 }
250
251 fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
252 debug_assert!(pb.is_map_entry());
253 let vs = self
254 .into_map()
255 .iter()
256 .map(|(k, v)| {
257 let v =
258 v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
259 let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
260 let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
261 Ok((
262 k.into_map_key().ok_or_else(|| {
263 FieldEncodeError::new("failed to convert map key to proto")
264 })?,
265 v,
266 ))
267 })
268 .try_collect()?;
269 Ok(Value::Map(vs))
270 }
271}
272
273fn validate_fields<'a>(
274 fields: impl Iterator<Item = (&'a str, &'a DataType)>,
275 descriptor: &MessageDescriptor,
276) -> Result<()> {
277 for (name, t) in fields {
278 let Some(proto_field) = descriptor.get_field_by_name(name) else {
279 return Err(FieldEncodeError::new("field not in proto").with_name(name));
280 };
281 if proto_field.cardinality() == prost_reflect::Cardinality::Required {
282 return Err(FieldEncodeError::new("`required` not supported").with_name(name));
283 }
284 on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
285 }
286 Ok(())
287}
288
289fn encode_fields<'a>(
290 fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
291 descriptor: &MessageDescriptor,
292) -> Result<DynamicMessage> {
293 let mut message = DynamicMessage::new(descriptor.clone());
294 for ((name, t), d) in fields_with_datums {
295 let proto_field = descriptor.get_field_by_name(name).unwrap();
296 if let Some(scalar) = d {
298 let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
299 message
300 .try_set_field(&proto_field, value)
301 .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
302 }
303 }
304 Ok(message)
305}
306
/// Full name of the well-known type `google.protobuf.Timestamp`. It is the only message type
/// that a `timestamptz` column may be encoded into.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
/// Full name of the well-known wrapper `google.protobuf.BoolValue` (not referenced yet).
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
311
312fn on_field<D: MaybeData>(
315 data_type: &DataType,
316 maybe: D,
317 proto_field: &FieldDescriptor,
318 in_repeated: bool,
319) -> Result<D::Out> {
320 assert!(proto_field.is_list() || !in_repeated);
327 let expect_list = proto_field.is_list() && !in_repeated;
328 if proto_field.is_group() {
329 return Err(FieldEncodeError::new("proto group not supported yet"));
330 }
331
332 let no_match_err = || {
333 Err(FieldEncodeError::new(format!(
334 "cannot encode {} column as {}{:?} field",
335 data_type,
336 if expect_list { "repeated " } else { "" },
337 proto_field.kind()
338 )))
339 };
340
341 if expect_list && !matches!(data_type, DataType::List(_)) {
342 return no_match_err();
343 }
344
345 let value = match &data_type {
346 DataType::Boolean => match proto_field.kind() {
348 Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
349 _ => return no_match_err(),
350 },
351 DataType::Varchar => match proto_field.kind() {
352 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
353 Kind::Enum(enum_desc) => maybe.on_base(|s| {
354 let name = s.into_utf8();
355 let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
356 FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
357 })?;
358 Ok(Value::EnumNumber(enum_value_desc.number()))
359 })?,
360 _ => return no_match_err(),
361 },
362 DataType::Bytea => match proto_field.kind() {
363 Kind::Bytes => {
364 maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
365 }
366 _ => return no_match_err(),
367 },
368 DataType::Float32 => match proto_field.kind() {
369 Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
370 _ => return no_match_err(),
371 },
372 DataType::Float64 => match proto_field.kind() {
373 Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
374 _ => return no_match_err(),
375 },
376 DataType::Int32 => match proto_field.kind() {
377 Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
378 maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
379 }
380 _ => return no_match_err(),
381 },
382 DataType::Int64 => match proto_field.kind() {
383 Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
384 maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
385 }
386 _ => return no_match_err(),
387 },
388 DataType::Struct(st) => match proto_field.kind() {
389 Kind::Message(pb) => maybe.on_struct(st, &pb)?,
390 _ => return no_match_err(),
391 },
392 DataType::List(lt) => match expect_list {
393 true => maybe.on_list(lt.elem(), proto_field)?,
394 false => return no_match_err(),
395 },
396 DataType::Timestamptz => match proto_field.kind() {
398 Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
399 let d = s.into_timestamptz();
400 let message = prost_types::Timestamp {
401 seconds: d.timestamp(),
402 nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
403 };
404 Ok(Value::Message(message.transcode_to_dynamic()))
405 })?,
406 Kind::String => {
407 maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
408 }
409 _ => return no_match_err(),
410 },
411 DataType::Jsonb => match proto_field.kind() {
412 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
413 _ => return no_match_err(), },
416 DataType::Int16 => match proto_field.kind() {
417 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
418 _ => return no_match_err(),
419 },
420 DataType::Date => match proto_field.kind() {
421 Kind::Int32 => {
422 maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
423 }
424 _ => return no_match_err(), },
426 DataType::Time => match proto_field.kind() {
427 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
428 _ => return no_match_err(), },
430 DataType::Timestamp => match proto_field.kind() {
431 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
432 _ => return no_match_err(), },
434 DataType::Decimal => match proto_field.kind() {
435 Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
436 _ => return no_match_err(), },
438 DataType::Interval => match proto_field.kind() {
439 Kind::String => {
440 maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
441 }
442 _ => return no_match_err(), },
444 DataType::Serial => match proto_field.kind() {
445 Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
446 _ => return no_match_err(), },
448 DataType::Int256 => {
449 return no_match_err();
450 }
451 DataType::Map(map_type) => {
452 if proto_field.is_map() {
453 let msg = match proto_field.kind() {
454 Kind::Message(m) => m,
455 _ => return no_match_err(), };
457 return maybe.on_map(map_type, &msg);
458 } else {
459 return no_match_err();
460 }
461 }
462 DataType::Vector(_) => match expect_list {
463 true => maybe.on_list(&VECTOR_ITEM_TYPE, proto_field)?,
464 false => return no_match_err(),
465 },
466 };
467
468 Ok(value)
469}
470
471#[cfg(test)]
472mod tests {
473 use itertools::Itertools;
474 use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
475 use risingwave_common::catalog::Field;
476 use risingwave_common::row::OwnedRow;
477 use risingwave_common::types::{
478 ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
479 };
480
481 use super::*;
482
483 #[test]
484 fn test_encode_proto_ok() {
485 let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
486 .join("codec/tests/test_data/all-types.pb");
487 let pool_bytes = std::fs::read(pool_path).unwrap();
488 let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
489 let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
490 let schema = Schema::new(vec![
491 Field::with_name(DataType::Boolean, "bool_field"),
492 Field::with_name(DataType::Varchar, "string_field"),
493 Field::with_name(DataType::Bytea, "bytes_field"),
494 Field::with_name(DataType::Float32, "float_field"),
495 Field::with_name(DataType::Float64, "double_field"),
496 Field::with_name(DataType::Int32, "int32_field"),
497 Field::with_name(DataType::Int64, "int64_field"),
498 Field::with_name(DataType::Int32, "sint32_field"),
499 Field::with_name(DataType::Int64, "sint64_field"),
500 Field::with_name(DataType::Int32, "sfixed32_field"),
501 Field::with_name(DataType::Int64, "sfixed64_field"),
502 Field::with_name(
503 DataType::Struct(StructType::new(vec![
504 ("id", DataType::Int32),
505 ("name", DataType::Varchar),
506 ])),
507 "nested_message_field",
508 ),
509 Field::with_name(DataType::Int32.list(), "repeated_int_field"),
510 Field::with_name(DataType::Timestamptz, "timestamp_field"),
511 Field::with_name(
512 DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
513 "map_field",
514 ),
515 Field::with_name(
516 DataType::Map(MapType::from_kv(
517 DataType::Varchar,
518 DataType::Struct(StructType::new(vec![
519 ("id", DataType::Int32),
520 ("name", DataType::Varchar),
521 ])),
522 )),
523 "map_struct_field",
524 ),
525 ]);
526 let row = OwnedRow::new(vec![
527 Some(ScalarImpl::Bool(true)),
528 Some(ScalarImpl::Utf8("RisingWave".into())),
529 Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
530 Some(ScalarImpl::Float32(3.5f32.into())),
531 Some(ScalarImpl::Float64(4.25f64.into())),
532 Some(ScalarImpl::Int32(22)),
533 Some(ScalarImpl::Int64(23)),
534 Some(ScalarImpl::Int32(24)),
535 None,
536 Some(ScalarImpl::Int32(26)),
537 Some(ScalarImpl::Int64(27)),
538 Some(ScalarImpl::Struct(StructValue::new(vec![
539 Some(ScalarImpl::Int32(1)),
540 Some(ScalarImpl::Utf8("".into())),
541 ]))),
542 Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
543 Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
544 Some(ScalarImpl::Map(
545 MapValue::try_from_kv(
546 ListValue::from_iter(["a", "b"]),
547 ListValue::from_iter([1, 2]),
548 )
549 .unwrap(),
550 )),
551 {
552 let mut struct_array_builder = StructArrayBuilder::with_type(
553 2,
554 DataType::Struct(StructType::new(vec![
555 ("id", DataType::Int32),
556 ("name", DataType::Varchar),
557 ])),
558 );
559 struct_array_builder.append(Some(
560 StructValue::new(vec![
561 Some(ScalarImpl::Int32(1)),
562 Some(ScalarImpl::Utf8("x".into())),
563 ])
564 .as_scalar_ref(),
565 ));
566 struct_array_builder.append(Some(
567 StructValue::new(vec![
568 Some(ScalarImpl::Int32(2)),
569 Some(ScalarImpl::Utf8("y".into())),
570 ])
571 .as_scalar_ref(),
572 ));
573 Some(ScalarImpl::Map(
574 MapValue::try_from_kv(
575 ListValue::from_iter(["a", "b"]),
576 ListValue::new(struct_array_builder.finish().into()),
577 )
578 .unwrap(),
579 ))
580 },
581 ]);
582
583 let encoder = ProtoEncoder::new(schema, None, descriptor, ProtoHeader::None).unwrap();
584 let m = encoder.encode(row).unwrap();
585 expect_test::expect![[r#"
586 field: FieldDescriptor {
587 name: "double_field",
588 full_name: "all_types.AllTypes.double_field",
589 json_name: "doubleField",
590 number: 1,
591 kind: double,
592 cardinality: Optional,
593 containing_oneof: None,
594 default_value: F64(
595 0.0,
596 ),
597 is_group: false,
598 is_list: false,
599 is_map: false,
600 is_packed: false,
601 supports_presence: false,
602 }
603
604 value: F64(4.25)
605
606 ==============================
607 field: FieldDescriptor {
608 name: "float_field",
609 full_name: "all_types.AllTypes.float_field",
610 json_name: "floatField",
611 number: 2,
612 kind: float,
613 cardinality: Optional,
614 containing_oneof: None,
615 default_value: F32(
616 0.0,
617 ),
618 is_group: false,
619 is_list: false,
620 is_map: false,
621 is_packed: false,
622 supports_presence: false,
623 }
624
625 value: F32(3.5)
626
627 ==============================
628 field: FieldDescriptor {
629 name: "int32_field",
630 full_name: "all_types.AllTypes.int32_field",
631 json_name: "int32Field",
632 number: 3,
633 kind: int32,
634 cardinality: Optional,
635 containing_oneof: None,
636 default_value: I32(
637 0,
638 ),
639 is_group: false,
640 is_list: false,
641 is_map: false,
642 is_packed: false,
643 supports_presence: false,
644 }
645
646 value: I32(22)
647
648 ==============================
649 field: FieldDescriptor {
650 name: "int64_field",
651 full_name: "all_types.AllTypes.int64_field",
652 json_name: "int64Field",
653 number: 4,
654 kind: int64,
655 cardinality: Optional,
656 containing_oneof: None,
657 default_value: I64(
658 0,
659 ),
660 is_group: false,
661 is_list: false,
662 is_map: false,
663 is_packed: false,
664 supports_presence: false,
665 }
666
667 value: I64(23)
668
669 ==============================
670 field: FieldDescriptor {
671 name: "sint32_field",
672 full_name: "all_types.AllTypes.sint32_field",
673 json_name: "sint32Field",
674 number: 7,
675 kind: sint32,
676 cardinality: Optional,
677 containing_oneof: None,
678 default_value: I32(
679 0,
680 ),
681 is_group: false,
682 is_list: false,
683 is_map: false,
684 is_packed: false,
685 supports_presence: false,
686 }
687
688 value: I32(24)
689
690 ==============================
691 field: FieldDescriptor {
692 name: "sfixed32_field",
693 full_name: "all_types.AllTypes.sfixed32_field",
694 json_name: "sfixed32Field",
695 number: 11,
696 kind: sfixed32,
697 cardinality: Optional,
698 containing_oneof: None,
699 default_value: I32(
700 0,
701 ),
702 is_group: false,
703 is_list: false,
704 is_map: false,
705 is_packed: false,
706 supports_presence: false,
707 }
708
709 value: I32(26)
710
711 ==============================
712 field: FieldDescriptor {
713 name: "sfixed64_field",
714 full_name: "all_types.AllTypes.sfixed64_field",
715 json_name: "sfixed64Field",
716 number: 12,
717 kind: sfixed64,
718 cardinality: Optional,
719 containing_oneof: None,
720 default_value: I64(
721 0,
722 ),
723 is_group: false,
724 is_list: false,
725 is_map: false,
726 is_packed: false,
727 supports_presence: false,
728 }
729
730 value: I64(27)
731
732 ==============================
733 field: FieldDescriptor {
734 name: "bool_field",
735 full_name: "all_types.AllTypes.bool_field",
736 json_name: "boolField",
737 number: 13,
738 kind: bool,
739 cardinality: Optional,
740 containing_oneof: None,
741 default_value: Bool(
742 false,
743 ),
744 is_group: false,
745 is_list: false,
746 is_map: false,
747 is_packed: false,
748 supports_presence: false,
749 }
750
751 value: Bool(true)
752
753 ==============================
754 field: FieldDescriptor {
755 name: "string_field",
756 full_name: "all_types.AllTypes.string_field",
757 json_name: "stringField",
758 number: 14,
759 kind: string,
760 cardinality: Optional,
761 containing_oneof: None,
762 default_value: String(
763 "",
764 ),
765 is_group: false,
766 is_list: false,
767 is_map: false,
768 is_packed: false,
769 supports_presence: false,
770 }
771
772 value: String("RisingWave")
773
774 ==============================
775 field: FieldDescriptor {
776 name: "bytes_field",
777 full_name: "all_types.AllTypes.bytes_field",
778 json_name: "bytesField",
779 number: 15,
780 kind: bytes,
781 cardinality: Optional,
782 containing_oneof: None,
783 default_value: Bytes(
784 b"",
785 ),
786 is_group: false,
787 is_list: false,
788 is_map: false,
789 is_packed: false,
790 supports_presence: false,
791 }
792
793 value: Bytes(b"\xbe\xef")
794
795 ==============================
796 field: FieldDescriptor {
797 name: "nested_message_field",
798 full_name: "all_types.AllTypes.nested_message_field",
799 json_name: "nestedMessageField",
800 number: 17,
801 kind: all_types.AllTypes.NestedMessage,
802 cardinality: Optional,
803 containing_oneof: None,
804 default_value: Message(
805 DynamicMessage {
806 desc: MessageDescriptor {
807 name: "NestedMessage",
808 full_name: "all_types.AllTypes.NestedMessage",
809 is_map_entry: false,
810 fields: [
811 FieldDescriptor {
812 name: "id",
813 full_name: "all_types.AllTypes.NestedMessage.id",
814 json_name: "id",
815 number: 1,
816 kind: int32,
817 cardinality: Optional,
818 containing_oneof: None,
819 default_value: I32(
820 0,
821 ),
822 is_group: false,
823 is_list: false,
824 is_map: false,
825 is_packed: false,
826 supports_presence: false,
827 },
828 FieldDescriptor {
829 name: "name",
830 full_name: "all_types.AllTypes.NestedMessage.name",
831 json_name: "name",
832 number: 2,
833 kind: string,
834 cardinality: Optional,
835 containing_oneof: None,
836 default_value: String(
837 "",
838 ),
839 is_group: false,
840 is_list: false,
841 is_map: false,
842 is_packed: false,
843 supports_presence: false,
844 },
845 ],
846 oneofs: [],
847 },
848 fields: DynamicMessageFieldSet {
849 fields: {},
850 },
851 },
852 ),
853 is_group: false,
854 is_list: false,
855 is_map: false,
856 is_packed: false,
857 supports_presence: true,
858 }
859
860 value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: String(""), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })
861
862 ==============================
863 field: FieldDescriptor {
864 name: "repeated_int_field",
865 full_name: "all_types.AllTypes.repeated_int_field",
866 json_name: "repeatedIntField",
867 number: 18,
868 kind: int32,
869 cardinality: Repeated,
870 containing_oneof: None,
871 default_value: List(
872 [],
873 ),
874 is_group: false,
875 is_list: true,
876 is_map: false,
877 is_packed: true,
878 supports_presence: false,
879 }
880
881 value: List([I32(4), I32(0), I32(4)])
882
883 ==============================
884 field: FieldDescriptor {
885 name: "map_field",
886 full_name: "all_types.AllTypes.map_field",
887 json_name: "mapField",
888 number: 22,
889 kind: all_types.AllTypes.MapFieldEntry,
890 cardinality: Repeated,
891 containing_oneof: None,
892 default_value: Map(
893 {},
894 ),
895 is_group: false,
896 is_list: false,
897 is_map: true,
898 is_packed: false,
899 supports_presence: false,
900 }
901
902 value: Map({
903 String("a"): I32(1),
904 String("b"): I32(2),
905 })
906
907 ==============================
908 field: FieldDescriptor {
909 name: "timestamp_field",
910 full_name: "all_types.AllTypes.timestamp_field",
911 json_name: "timestampField",
912 number: 23,
913 kind: google.protobuf.Timestamp,
914 cardinality: Optional,
915 containing_oneof: None,
916 default_value: Message(
917 DynamicMessage {
918 desc: MessageDescriptor {
919 name: "Timestamp",
920 full_name: "google.protobuf.Timestamp",
921 is_map_entry: false,
922 fields: [
923 FieldDescriptor {
924 name: "seconds",
925 full_name: "google.protobuf.Timestamp.seconds",
926 json_name: "seconds",
927 number: 1,
928 kind: int64,
929 cardinality: Optional,
930 containing_oneof: None,
931 default_value: I64(
932 0,
933 ),
934 is_group: false,
935 is_list: false,
936 is_map: false,
937 is_packed: false,
938 supports_presence: false,
939 },
940 FieldDescriptor {
941 name: "nanos",
942 full_name: "google.protobuf.Timestamp.nanos",
943 json_name: "nanos",
944 number: 2,
945 kind: int32,
946 cardinality: Optional,
947 containing_oneof: None,
948 default_value: I32(
949 0,
950 ),
951 is_group: false,
952 is_list: false,
953 is_map: false,
954 is_packed: false,
955 supports_presence: false,
956 },
957 ],
958 oneofs: [],
959 },
960 fields: DynamicMessageFieldSet {
961 fields: {},
962 },
963 },
964 ),
965 is_group: false,
966 is_list: false,
967 is_map: false,
968 is_packed: false,
969 supports_presence: true,
970 }
971
972 value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: I64(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })
973
974 ==============================
975 field: FieldDescriptor {
976 name: "map_struct_field",
977 full_name: "all_types.AllTypes.map_struct_field",
978 json_name: "mapStructField",
979 number: 29,
980 kind: all_types.AllTypes.MapStructFieldEntry,
981 cardinality: Repeated,
982 containing_oneof: None,
983 default_value: Map(
984 {},
985 ),
986 is_group: false,
987 is_list: false,
988 is_map: true,
989 is_packed: false,
990 supports_presence: false,
991 }
992
993 value: Map({
994 String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: String(""), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
995 String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: String(""), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
996 })"#]].assert_eq(&format!("{}",
997 m.message.fields().format_with("\n\n==============================\n", |(field,value),f| {
998 f(&format!("field: {:#?}\n\nvalue: {}", field, print_proto(value)))
999 })));
1000 }
1001
1002 fn print_proto(value: &Value) -> String {
1003 match value {
1004 Value::Map(m) => {
1005 let mut res = String::new();
1006 res.push_str("Map({\n");
1007 for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
1008 res.push_str(&format!(
1009 " {}: {},\n",
1010 print_proto(&k.clone().into()),
1011 print_proto(v)
1012 ));
1013 }
1014 res.push_str("})");
1015 res
1016 }
1017 _ => format!("{:?}", value),
1018 }
1019 }
1020
1021 #[test]
1022 fn test_encode_proto_repeated() {
1023 let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
1024 .join("codec/tests/test_data/all-types.pb");
1025 let pool_bytes = fs_err::read(pool_path).unwrap();
1026 let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
1027 let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
1028
1029 let schema = Schema::new(vec![Field::with_name(
1030 DataType::Int32.list().list(),
1031 "repeated_int_field",
1032 )]);
1033
1034 let err = validate_fields(
1035 schema
1036 .fields
1037 .iter()
1038 .map(|f| (f.name.as_str(), &f.data_type)),
1039 &message_descriptor,
1040 )
1041 .unwrap_err();
1042 assert_eq!(
1043 err.to_string(),
1044 "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
1045 );
1046
1047 let schema = Schema::new(vec![Field::with_name(
1048 DataType::Int32.list(),
1049 "repeated_int_field",
1050 )]);
1051 let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
1052 Some(0),
1053 None,
1054 Some(2),
1055 Some(3),
1056 ])))]);
1057
1058 let err = encode_fields(
1059 schema
1060 .fields
1061 .iter()
1062 .map(|f| (f.name.as_str(), &f.data_type))
1063 .zip_eq_debug(row.iter()),
1064 &message_descriptor,
1065 )
1066 .unwrap_err();
1067 assert_eq!(
1068 err.to_string(),
1069 "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
1070 );
1071 }
1072
1073 #[test]
1074 fn test_encode_proto_err() {
1075 let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
1076 .join("codec/tests/test_data/all-types.pb");
1077 let pool_bytes = std::fs::read(pool_path).unwrap();
1078 let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
1079 let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
1080
1081 let err = validate_fields(
1082 std::iter::once(("not_exists", &DataType::Int16)),
1083 &message_descriptor,
1084 )
1085 .unwrap_err();
1086 assert_eq!(
1087 err.to_string(),
1088 "encode 'not_exists' error: field not in proto"
1089 );
1090
1091 let err = validate_fields(
1092 std::iter::once(("map_field", &DataType::Jsonb)),
1093 &message_descriptor,
1094 )
1095 .unwrap_err();
1096 assert_eq!(
1097 err.to_string(),
1098 "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
1099 );
1100 }
1101}