risingwave_connector/sink/encoder/
proto.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18    DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::array::VECTOR_ITEM_TYPE;
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
24use risingwave_common::util::iter_util::ZipEqDebug;
25
26use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
27
/// Result of the field-level helpers in this module. Errors are converted into sink errors
/// where they cross into [`ProtoEncoder`]'s `SinkResult`-returning methods.
type Result<T> = std::result::Result<T, FieldEncodeError>;
29
/// Encodes RisingWave rows into protobuf [`DynamicMessage`]s of one fixed message type.
/// Each column is written to the message field with the same name.
pub struct ProtoEncoder {
    /// Schema of the rows handed to the encoder.
    schema: Schema,
    /// Indices into `schema` of the columns to encode; `None` selects every column of `schema`.
    col_indices: Option<Vec<usize>>,
    /// Descriptor of the target message type.
    descriptor: MessageDescriptor,
    /// Header copied into every [`ProtoEncoded`] and written when it is serialized to bytes.
    header: ProtoHeader,
}
36
/// Prefix written in front of the protobuf payload when a [`ProtoEncoded`] is serialized.
#[derive(Debug, Clone, Copy)]
pub enum ProtoHeader {
    /// No prefix: the output is just the protobuf encoding of the message.
    None,
    /// <https://docs.confluent.io/platform/7.5/schema-registry/fundamentals/serdes-develop/index.html#messages-wire-format>
    ///
    /// * magic byte `0x00`
    /// * 4-byte big-endian schema ID (the wrapped value)
    /// * the message indexes locating the message type inside its schema
    ConfluentSchemaRegistry(i32),
}
46
47impl ProtoEncoder {
48    pub fn new(
49        schema: Schema,
50        col_indices: Option<Vec<usize>>,
51        descriptor: MessageDescriptor,
52        header: ProtoHeader,
53    ) -> SinkResult<Self> {
54        match &col_indices {
55            Some(col_indices) => validate_fields(
56                col_indices.iter().map(|idx| {
57                    let f = &schema[*idx];
58                    (f.name.as_str(), &f.data_type)
59                }),
60                &descriptor,
61            )?,
62            None => validate_fields(
63                schema
64                    .fields
65                    .iter()
66                    .map(|f| (f.name.as_str(), &f.data_type)),
67                &descriptor,
68            )?,
69        };
70
71        Ok(Self {
72            schema,
73            col_indices,
74            descriptor,
75            header,
76        })
77    }
78}
79
/// Output of [`ProtoEncoder`]: an encoded message plus the header to emit in front of it
/// when it is serialized to bytes.
pub struct ProtoEncoded {
    /// The encoded message.
    pub message: DynamicMessage,
    /// Header written before `message` by the `SerTo<Vec<u8>>` implementation.
    header: ProtoHeader,
}
84
85impl RowEncoder for ProtoEncoder {
86    type Output = ProtoEncoded;
87
88    fn schema(&self) -> &Schema {
89        &self.schema
90    }
91
92    fn col_indices(&self) -> Option<&[usize]> {
93        self.col_indices.as_deref()
94    }
95
96    fn encode_cols(
97        &self,
98        row: impl Row,
99        col_indices: impl Iterator<Item = usize>,
100    ) -> SinkResult<Self::Output> {
101        encode_fields(
102            col_indices.map(|idx| {
103                let f = &self.schema[idx];
104                ((f.name.as_str(), &f.data_type), row.datum_at(idx))
105            }),
106            &self.descriptor,
107        )
108        .map_err(Into::into)
109        .map(|m| ProtoEncoded {
110            message: m,
111            header: self.header,
112        })
113    }
114}
115
116impl SerTo<Vec<u8>> for ProtoEncoded {
117    fn ser_to(self) -> SinkResult<Vec<u8>> {
118        let mut buf = Vec::new();
119        match self.header {
120            ProtoHeader::None => { /* noop */ }
121            ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
122                buf.reserve(1 + 4);
123                buf.put_u8(0);
124                buf.put_i32(schema_id);
125                MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
126            }
127        }
128        self.message.encode(&mut buf).unwrap();
129        Ok(buf)
130    }
131}
132
133struct MessageIndexes(Vec<i32>);
134
135impl MessageIndexes {
136    fn from(desc: MessageDescriptor) -> Self {
137        // https://github.com/protocolbuffers/protobuf/blob/v25.1/src/google/protobuf/descriptor.proto
138        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/tag.rs.html
139        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/build/visit.rs.html#125
140        // `FileDescriptorProto` field #4 is `repeated DescriptorProto message_type`
141        const TAG_FILE_MESSAGE: i32 = 4;
142        // `DescriptorProto` field #3 is `repeated DescriptorProto nested_type`
143        const TAG_MESSAGE_NESTED: i32 = 3;
144
145        let mut indexes = vec![];
146        let mut path = desc.path().array_chunks();
147        let &[tag, idx] = path.next().unwrap();
148        assert_eq!(tag, TAG_FILE_MESSAGE);
149        indexes.push(idx);
150        for &[tag, idx] in path {
151            assert_eq!(tag, TAG_MESSAGE_NESTED);
152            indexes.push(idx);
153        }
154        Self(indexes)
155    }
156
157    fn zig_i32(value: i32, buf: &mut impl BufMut) {
158        let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
159        prost::encoding::encode_varint(unsigned, buf);
160    }
161
162    fn encode(&self, buf: &mut impl BufMut) {
163        if self.0 == [0] {
164            buf.put_u8(0);
165            return;
166        }
167        Self::zig_i32(self.0.len().try_into().unwrap(), buf);
168        for &idx in &self.0 {
169            Self::zig_i32(idx, buf);
170        }
171    }
172}
173
/// A trait that assists code reuse between `validate` and `encode`.
/// * For `validate`, the inputs are (RisingWave type, ProtoBuf type).
/// * For `encode`, the inputs are (RisingWave type, RisingWave data, ProtoBuf type).
///
/// Thus we impl [`MaybeData`] for both [`()`] and [`ScalarRefImpl`].
trait MaybeData: std::fmt::Debug {
    /// `()` when validating; the protobuf [`Value`] when encoding.
    type Out;

    /// Handles a leaf (non-composite) type. `f` converts the data into a protobuf [`Value`].
    /// The `()` implementation ignores it, and the [`ScalarRefImpl`] implementation calls it.
    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;

    /// Handles a struct column mapped to the message type `pb`.
    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;

    /// Handles a list column with element type `elem`, mapped to the repeated field `pb`.
    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;

    /// Handles a map column mapped to the map-entry message type `pb`.
    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
}
190
191impl MaybeData for () {
192    type Out = ();
193
194    fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
195        Ok(self)
196    }
197
198    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
199        validate_fields(st.iter(), pb)
200    }
201
202    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
203        on_field(elem, (), pb, true)
204    }
205
206    fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
207        debug_assert!(pb.is_map_entry());
208        on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
209        on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
210        Ok(())
211    }
212}
213
214/// Nullability is not part of type system in proto.
215/// * Top level is always a message.
216/// * All message fields can be omitted in proto3.
217/// * All repeated elements must have a value.
218///
219/// So we handle [`ScalarRefImpl`] rather than [`DatumRef`] here.
220impl MaybeData for ScalarRefImpl<'_> {
221    type Out = Value;
222
223    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
224        f(self)
225    }
226
227    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
228        let d = self.into_struct();
229        let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
230        Ok(Value::Message(message))
231    }
232
233    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
234        let d = self.into_list();
235        let vs = d
236            .iter()
237            .map(|d| {
238                on_field(
239                    elem,
240                    d.ok_or_else(|| {
241                        FieldEncodeError::new("array containing null not allowed as repeated field")
242                    })?,
243                    pb,
244                    true,
245                )
246            })
247            .try_collect()?;
248        Ok(Value::List(vs))
249    }
250
251    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
252        debug_assert!(pb.is_map_entry());
253        let vs = self
254            .into_map()
255            .iter()
256            .map(|(k, v)| {
257                let v =
258                    v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
259                let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
260                let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
261                Ok((
262                    k.into_map_key().ok_or_else(|| {
263                        FieldEncodeError::new("failed to convert map key to proto")
264                    })?,
265                    v,
266                ))
267            })
268            .try_collect()?;
269        Ok(Value::Map(vs))
270    }
271}
272
273fn validate_fields<'a>(
274    fields: impl Iterator<Item = (&'a str, &'a DataType)>,
275    descriptor: &MessageDescriptor,
276) -> Result<()> {
277    for (name, t) in fields {
278        let Some(proto_field) = descriptor.get_field_by_name(name) else {
279            return Err(FieldEncodeError::new("field not in proto").with_name(name));
280        };
281        if proto_field.cardinality() == prost_reflect::Cardinality::Required {
282            return Err(FieldEncodeError::new("`required` not supported").with_name(name));
283        }
284        on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
285    }
286    Ok(())
287}
288
289fn encode_fields<'a>(
290    fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
291    descriptor: &MessageDescriptor,
292) -> Result<DynamicMessage> {
293    let mut message = DynamicMessage::new(descriptor.clone());
294    for ((name, t), d) in fields_with_datums {
295        let proto_field = descriptor.get_field_by_name(name).unwrap();
296        // On `null`, simply skip setting the field.
297        if let Some(scalar) = d {
298            let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
299            message
300                .try_set_field(&proto_field, value)
301                .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
302        }
303    }
304    Ok(message)
305}
306
// Fully-qualified names of protobuf Well-Known Types. `on_field` compares `WKT_TIMESTAMP`
// against a message field's `full_name()`. `WKT_BOOL_VALUE` is not referenced yet, hence
// `#[expect(dead_code)]`.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
311
/// Handles both `validate` (without actual data) and `encode`.
/// See [`MaybeData`] for more info.
///
/// Arguments:
/// * `data_type` - RisingWave type of the column (or list element / map key / map value).
/// * `maybe` - `()` when validating, or the non-null scalar when encoding.
/// * `proto_field` - protobuf field that `data_type` is checked against or encoded into.
/// * `in_repeated` - `true` when handling a single element of the repeated `proto_field`,
///   rather than the repeated field as a whole.
///
/// # Errors
/// Returns a [`FieldEncodeError`] in any of these cases: `proto_field` is a group,
/// `data_type` has no supported mapping to the kind of `proto_field`, or converting a value
/// fails (e.g. a string that names no value of the target enum).
///
/// # Panics
/// Panics if `in_repeated` is `true` while `proto_field` is not a list field.
fn on_field<D: MaybeData>(
    data_type: &DataType,
    maybe: D,
    proto_field: &FieldDescriptor,
    in_repeated: bool,
) -> Result<D::Out> {
    // Regarding (proto_field.is_list, in_repeated):
    // (F, T) => impossible
    // (F, F) => encoding to a non-repeated field
    // (T, F) => encoding to a repeated field
    // (T, T) => encoding to an element of a repeated field
    // In the bottom 2 cases, we need to distinguish the same `proto_field` with the help of `in_repeated`.
    assert!(proto_field.is_list() || !in_repeated);
    let expect_list = proto_field.is_list() && !in_repeated;
    if proto_field.is_group() {
        return Err(FieldEncodeError::new("proto group not supported yet"));
    }

    // Builds the generic type-mismatch error for this (RisingWave type, proto field) pair.
    let no_match_err = || {
        Err(FieldEncodeError::new(format!(
            "cannot encode {} column as {}{:?} field",
            data_type,
            if expect_list { "repeated " } else { "" },
            proto_field.kind()
        )))
    };

    // A repeated field as a whole can only be fed from a list column.
    // NOTE(review): this guard also rejects `DataType::Vector`. That makes the `true` branch of
    // the `DataType::Vector(_)` arm below unreachable, so vector columns currently always fail
    // with `no_match_err`. Confirm whether that is intended before widening this check. Note
    // that the encoding path (`ScalarRefImpl::on_list`) calls `into_list()` on the scalar.
    if expect_list && !matches!(data_type, DataType::List(_)) {
        return no_match_err();
    }

    let value = match &data_type {
        // Group A: perfect match between RisingWave types and ProtoBuf types
        DataType::Boolean => match proto_field.kind() {
            Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
            _ => return no_match_err(),
        },
        DataType::Varchar => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
            // The string must be the name of one of the enum's values; its number is encoded.
            Kind::Enum(enum_desc) => maybe.on_base(|s| {
                let name = s.into_utf8();
                let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
                    FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
                })?;
                Ok(Value::EnumNumber(enum_value_desc.number()))
            })?,
            _ => return no_match_err(),
        },
        DataType::Bytea => match proto_field.kind() {
            Kind::Bytes => {
                maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
            }
            _ => return no_match_err(),
        },
        DataType::Float32 => match proto_field.kind() {
            Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
            _ => return no_match_err(),
        },
        DataType::Float64 => match proto_field.kind() {
            Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
            _ => return no_match_err(),
        },
        DataType::Int32 => match proto_field.kind() {
            Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
                maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
            }
            _ => return no_match_err(),
        },
        DataType::Int64 => match proto_field.kind() {
            Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
                maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
            }
            _ => return no_match_err(),
        },
        DataType::Struct(st) => match proto_field.kind() {
            Kind::Message(pb) => maybe.on_struct(st, &pb)?,
            _ => return no_match_err(),
        },
        DataType::List(elem) => match expect_list {
            true => maybe.on_list(elem, proto_field)?,
            false => return no_match_err(),
        },
        // Group B: match between RisingWave types and ProtoBuf Well-Known types
        DataType::Timestamptz => match proto_field.kind() {
            Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
                let d = s.into_timestamptz();
                let message = prost_types::Timestamp {
                    seconds: d.timestamp(),
                    nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
                };
                Ok(Value::Message(message.transcode_to_dynamic()))
            })?,
            Kind::String => {
                maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
            }
            _ => return no_match_err(),
        },
        DataType::Jsonb => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
            _ => return no_match_err(), // Value, NullValue, Struct (map), ListValue
        },
        // Group C: experimental
        DataType::Int16 => match proto_field.kind() {
            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
            _ => return no_match_err(),
        },
        DataType::Date => match proto_field.kind() {
            // Encoded as the day count relative to the Unix epoch.
            Kind::Int32 => {
                maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
            }
            _ => return no_match_err(), // google.type.Date
        },
        DataType::Time => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
            _ => return no_match_err(), // google.type.TimeOfDay
        },
        DataType::Timestamp => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
            _ => return no_match_err(), // google.type.DateTime
        },
        DataType::Decimal => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
            _ => return no_match_err(), // google.type.Decimal
        },
        DataType::Interval => match proto_field.kind() {
            Kind::String => {
                maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
            }
            _ => return no_match_err(),
        },
        DataType::Serial => match proto_field.kind() {
            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
            _ => return no_match_err(),
        },
        // Group D: unsupported
        DataType::Int256 => {
            return no_match_err();
        }
        // Map columns only match proto `map<K, V>` fields, whose kind is the map-entry message.
        DataType::Map(map_type) => {
            if proto_field.is_map() {
                let msg = match proto_field.kind() {
                    Kind::Message(m) => m,
                    _ => return no_match_err(), // unreachable actually
                };
                return maybe.on_map(map_type, &msg);
            } else {
                return no_match_err();
            }
        }
        // NOTE(review): the `true` branch is unreachable because the `expect_list` guard above
        // already returned for non-`List` types.
        DataType::Vector(_) => match expect_list {
            true => maybe.on_list(&VECTOR_ITEM_TYPE, proto_field)?,
            false => return no_match_err(),
        },
    };

    Ok(value)
}
470
471#[cfg(test)]
472mod tests {
473    use itertools::Itertools;
474    use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
475    use risingwave_common::catalog::Field;
476    use risingwave_common::row::OwnedRow;
477    use risingwave_common::types::{
478        ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
479    };
480
481    use super::*;
482
483    #[test]
484    fn test_encode_proto_ok() {
485        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
486            .join("codec/tests/test_data/all-types.pb");
487        let pool_bytes = std::fs::read(pool_path).unwrap();
488        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
489        let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
490        let schema = Schema::new(vec![
491            Field::with_name(DataType::Boolean, "bool_field"),
492            Field::with_name(DataType::Varchar, "string_field"),
493            Field::with_name(DataType::Bytea, "bytes_field"),
494            Field::with_name(DataType::Float32, "float_field"),
495            Field::with_name(DataType::Float64, "double_field"),
496            Field::with_name(DataType::Int32, "int32_field"),
497            Field::with_name(DataType::Int64, "int64_field"),
498            Field::with_name(DataType::Int32, "sint32_field"),
499            Field::with_name(DataType::Int64, "sint64_field"),
500            Field::with_name(DataType::Int32, "sfixed32_field"),
501            Field::with_name(DataType::Int64, "sfixed64_field"),
502            Field::with_name(
503                DataType::Struct(StructType::new(vec![
504                    ("id", DataType::Int32),
505                    ("name", DataType::Varchar),
506                ])),
507                "nested_message_field",
508            ),
509            Field::with_name(DataType::List(DataType::Int32.into()), "repeated_int_field"),
510            Field::with_name(DataType::Timestamptz, "timestamp_field"),
511            Field::with_name(
512                DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
513                "map_field",
514            ),
515            Field::with_name(
516                DataType::Map(MapType::from_kv(
517                    DataType::Varchar,
518                    DataType::Struct(StructType::new(vec![
519                        ("id", DataType::Int32),
520                        ("name", DataType::Varchar),
521                    ])),
522                )),
523                "map_struct_field",
524            ),
525        ]);
526        let row = OwnedRow::new(vec![
527            Some(ScalarImpl::Bool(true)),
528            Some(ScalarImpl::Utf8("RisingWave".into())),
529            Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
530            Some(ScalarImpl::Float32(3.5f32.into())),
531            Some(ScalarImpl::Float64(4.25f64.into())),
532            Some(ScalarImpl::Int32(22)),
533            Some(ScalarImpl::Int64(23)),
534            Some(ScalarImpl::Int32(24)),
535            None,
536            Some(ScalarImpl::Int32(26)),
537            Some(ScalarImpl::Int64(27)),
538            Some(ScalarImpl::Struct(StructValue::new(vec![
539                Some(ScalarImpl::Int32(1)),
540                Some(ScalarImpl::Utf8("".into())),
541            ]))),
542            Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
543            Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
544            Some(ScalarImpl::Map(
545                MapValue::try_from_kv(
546                    ListValue::from_iter(["a", "b"]),
547                    ListValue::from_iter([1, 2]),
548                )
549                .unwrap(),
550            )),
551            {
552                let mut struct_array_builder = StructArrayBuilder::with_type(
553                    2,
554                    DataType::Struct(StructType::new(vec![
555                        ("id", DataType::Int32),
556                        ("name", DataType::Varchar),
557                    ])),
558                );
559                struct_array_builder.append(Some(
560                    StructValue::new(vec![
561                        Some(ScalarImpl::Int32(1)),
562                        Some(ScalarImpl::Utf8("x".into())),
563                    ])
564                    .as_scalar_ref(),
565                ));
566                struct_array_builder.append(Some(
567                    StructValue::new(vec![
568                        Some(ScalarImpl::Int32(2)),
569                        Some(ScalarImpl::Utf8("y".into())),
570                    ])
571                    .as_scalar_ref(),
572                ));
573                Some(ScalarImpl::Map(
574                    MapValue::try_from_kv(
575                        ListValue::from_iter(["a", "b"]),
576                        ListValue::new(struct_array_builder.finish().into()),
577                    )
578                    .unwrap(),
579                ))
580            },
581        ]);
582
583        let encoder =
584            ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap();
585        let m = encoder.encode(row).unwrap();
586        expect_test::expect![[r#"
587            field: FieldDescriptor {
588                name: "double_field",
589                full_name: "all_types.AllTypes.double_field",
590                json_name: "doubleField",
591                number: 1,
592                kind: double,
593                cardinality: Optional,
594                containing_oneof: None,
595                default_value: None,
596                is_group: false,
597                is_list: false,
598                is_map: false,
599                is_packed: false,
600                supports_presence: false,
601            }
602
603            value: F64(4.25)
604
605            ==============================
606            field: FieldDescriptor {
607                name: "float_field",
608                full_name: "all_types.AllTypes.float_field",
609                json_name: "floatField",
610                number: 2,
611                kind: float,
612                cardinality: Optional,
613                containing_oneof: None,
614                default_value: None,
615                is_group: false,
616                is_list: false,
617                is_map: false,
618                is_packed: false,
619                supports_presence: false,
620            }
621
622            value: F32(3.5)
623
624            ==============================
625            field: FieldDescriptor {
626                name: "int32_field",
627                full_name: "all_types.AllTypes.int32_field",
628                json_name: "int32Field",
629                number: 3,
630                kind: int32,
631                cardinality: Optional,
632                containing_oneof: None,
633                default_value: None,
634                is_group: false,
635                is_list: false,
636                is_map: false,
637                is_packed: false,
638                supports_presence: false,
639            }
640
641            value: I32(22)
642
643            ==============================
644            field: FieldDescriptor {
645                name: "int64_field",
646                full_name: "all_types.AllTypes.int64_field",
647                json_name: "int64Field",
648                number: 4,
649                kind: int64,
650                cardinality: Optional,
651                containing_oneof: None,
652                default_value: None,
653                is_group: false,
654                is_list: false,
655                is_map: false,
656                is_packed: false,
657                supports_presence: false,
658            }
659
660            value: I64(23)
661
662            ==============================
663            field: FieldDescriptor {
664                name: "sint32_field",
665                full_name: "all_types.AllTypes.sint32_field",
666                json_name: "sint32Field",
667                number: 7,
668                kind: sint32,
669                cardinality: Optional,
670                containing_oneof: None,
671                default_value: None,
672                is_group: false,
673                is_list: false,
674                is_map: false,
675                is_packed: false,
676                supports_presence: false,
677            }
678
679            value: I32(24)
680
681            ==============================
682            field: FieldDescriptor {
683                name: "sfixed32_field",
684                full_name: "all_types.AllTypes.sfixed32_field",
685                json_name: "sfixed32Field",
686                number: 11,
687                kind: sfixed32,
688                cardinality: Optional,
689                containing_oneof: None,
690                default_value: None,
691                is_group: false,
692                is_list: false,
693                is_map: false,
694                is_packed: false,
695                supports_presence: false,
696            }
697
698            value: I32(26)
699
700            ==============================
701            field: FieldDescriptor {
702                name: "sfixed64_field",
703                full_name: "all_types.AllTypes.sfixed64_field",
704                json_name: "sfixed64Field",
705                number: 12,
706                kind: sfixed64,
707                cardinality: Optional,
708                containing_oneof: None,
709                default_value: None,
710                is_group: false,
711                is_list: false,
712                is_map: false,
713                is_packed: false,
714                supports_presence: false,
715            }
716
717            value: I64(27)
718
719            ==============================
720            field: FieldDescriptor {
721                name: "bool_field",
722                full_name: "all_types.AllTypes.bool_field",
723                json_name: "boolField",
724                number: 13,
725                kind: bool,
726                cardinality: Optional,
727                containing_oneof: None,
728                default_value: None,
729                is_group: false,
730                is_list: false,
731                is_map: false,
732                is_packed: false,
733                supports_presence: false,
734            }
735
736            value: Bool(true)
737
738            ==============================
739            field: FieldDescriptor {
740                name: "string_field",
741                full_name: "all_types.AllTypes.string_field",
742                json_name: "stringField",
743                number: 14,
744                kind: string,
745                cardinality: Optional,
746                containing_oneof: None,
747                default_value: None,
748                is_group: false,
749                is_list: false,
750                is_map: false,
751                is_packed: false,
752                supports_presence: false,
753            }
754
755            value: String("RisingWave")
756
757            ==============================
758            field: FieldDescriptor {
759                name: "bytes_field",
760                full_name: "all_types.AllTypes.bytes_field",
761                json_name: "bytesField",
762                number: 15,
763                kind: bytes,
764                cardinality: Optional,
765                containing_oneof: None,
766                default_value: None,
767                is_group: false,
768                is_list: false,
769                is_map: false,
770                is_packed: false,
771                supports_presence: false,
772            }
773
774            value: Bytes(b"\xbe\xef")
775
776            ==============================
777            field: FieldDescriptor {
778                name: "nested_message_field",
779                full_name: "all_types.AllTypes.nested_message_field",
780                json_name: "nestedMessageField",
781                number: 17,
782                kind: all_types.AllTypes.NestedMessage,
783                cardinality: Optional,
784                containing_oneof: None,
785                default_value: None,
786                is_group: false,
787                is_list: false,
788                is_map: false,
789                is_packed: false,
790                supports_presence: true,
791            }
792
793            value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })
794
795            ==============================
796            field: FieldDescriptor {
797                name: "repeated_int_field",
798                full_name: "all_types.AllTypes.repeated_int_field",
799                json_name: "repeatedIntField",
800                number: 18,
801                kind: int32,
802                cardinality: Repeated,
803                containing_oneof: None,
804                default_value: None,
805                is_group: false,
806                is_list: true,
807                is_map: false,
808                is_packed: true,
809                supports_presence: false,
810            }
811
812            value: List([I32(4), I32(0), I32(4)])
813
814            ==============================
815            field: FieldDescriptor {
816                name: "map_field",
817                full_name: "all_types.AllTypes.map_field",
818                json_name: "mapField",
819                number: 22,
820                kind: all_types.AllTypes.MapFieldEntry,
821                cardinality: Repeated,
822                containing_oneof: None,
823                default_value: None,
824                is_group: false,
825                is_list: false,
826                is_map: true,
827                is_packed: false,
828                supports_presence: false,
829            }
830
831            value: Map({
832                String("a"): I32(1),
833                String("b"): I32(2),
834            })
835
836            ==============================
837            field: FieldDescriptor {
838                name: "timestamp_field",
839                full_name: "all_types.AllTypes.timestamp_field",
840                json_name: "timestampField",
841                number: 23,
842                kind: google.protobuf.Timestamp,
843                cardinality: Optional,
844                containing_oneof: None,
845                default_value: None,
846                is_group: false,
847                is_list: false,
848                is_map: false,
849                is_packed: false,
850                supports_presence: true,
851            }
852
853            value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })
854
855            ==============================
856            field: FieldDescriptor {
857                name: "map_struct_field",
858                full_name: "all_types.AllTypes.map_struct_field",
859                json_name: "mapStructField",
860                number: 29,
861                kind: all_types.AllTypes.MapStructFieldEntry,
862                cardinality: Repeated,
863                containing_oneof: None,
864                default_value: None,
865                is_group: false,
866                is_list: false,
867                is_map: true,
868                is_packed: false,
869                supports_presence: false,
870            }
871
872            value: Map({
873                String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
874                String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
875            })"#]].assert_eq(&format!("{}",
876            m.message.fields().format_with("\n\n==============================\n", |(field,value),f| {
877            f(&format!("field: {:#?}\n\nvalue: {}", field, print_proto(value)))
878        })));
879    }
880
881    fn print_proto(value: &Value) -> String {
882        match value {
883            Value::Map(m) => {
884                let mut res = String::new();
885                res.push_str("Map({\n");
886                for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
887                    res.push_str(&format!(
888                        "    {}: {},\n",
889                        print_proto(&k.clone().into()),
890                        print_proto(v)
891                    ));
892                }
893                res.push_str("})");
894                res
895            }
896            _ => format!("{:?}", value),
897        }
898    }
899
900    #[test]
901    fn test_encode_proto_repeated() {
902        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
903            .join("codec/tests/test_data/all-types.pb");
904        let pool_bytes = fs_err::read(pool_path).unwrap();
905        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
906        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
907
908        let schema = Schema::new(vec![Field::with_name(
909            DataType::List(DataType::List(DataType::Int32.into()).into()),
910            "repeated_int_field",
911        )]);
912
913        let err = validate_fields(
914            schema
915                .fields
916                .iter()
917                .map(|f| (f.name.as_str(), &f.data_type)),
918            &message_descriptor,
919        )
920        .unwrap_err();
921        assert_eq!(
922            err.to_string(),
923            "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
924        );
925
926        let schema = Schema::new(vec![Field::with_name(
927            DataType::List(DataType::Int32.into()),
928            "repeated_int_field",
929        )]);
930        let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
931            Some(0),
932            None,
933            Some(2),
934            Some(3),
935        ])))]);
936
937        let err = encode_fields(
938            schema
939                .fields
940                .iter()
941                .map(|f| (f.name.as_str(), &f.data_type))
942                .zip_eq_debug(row.iter()),
943            &message_descriptor,
944        )
945        .unwrap_err();
946        assert_eq!(
947            err.to_string(),
948            "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
949        );
950    }
951
952    #[test]
953    fn test_encode_proto_err() {
954        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
955            .join("codec/tests/test_data/all-types.pb");
956        let pool_bytes = std::fs::read(pool_path).unwrap();
957        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
958        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
959
960        let err = validate_fields(
961            std::iter::once(("not_exists", &DataType::Int16)),
962            &message_descriptor,
963        )
964        .unwrap_err();
965        assert_eq!(
966            err.to_string(),
967            "encode 'not_exists' error: field not in proto"
968        );
969
970        let err = validate_fields(
971            std::iter::once(("map_field", &DataType::Jsonb)),
972            &message_descriptor,
973        )
974        .unwrap_err();
975        assert_eq!(
976            err.to_string(),
977            "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
978        );
979    }
980}