risingwave_connector/sink/encoder/
proto.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18    DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::catalog::Schema;
21use risingwave_common::row::Row;
22use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
23use risingwave_common::util::iter_util::ZipEqDebug;
24
25use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
26
27type Result<T> = std::result::Result<T, FieldEncodeError>;
28
29pub struct ProtoEncoder {
30    schema: Schema,
31    col_indices: Option<Vec<usize>>,
32    descriptor: MessageDescriptor,
33    header: ProtoHeader,
34}
35
/// Framing written in front of the serialized protobuf message.
#[derive(Clone, Copy, Debug)]
pub enum ProtoHeader {
    /// No framing; the output is only the encoded message.
    None,
    /// Confluent Schema Registry wire format, carrying the schema ID.
    ///
    /// <https://docs.confluent.io/platform/7.5/schema-registry/fundamentals/serdes-develop/index.html#messages-wire-format>
    ///
    /// * magic byte `0`
    /// * 4-byte big-endian schema ID
    /// * message indexes locating the message type inside its file
    ConfluentSchemaRegistry(i32),
}
45
46impl ProtoEncoder {
47    pub fn new(
48        schema: Schema,
49        col_indices: Option<Vec<usize>>,
50        descriptor: MessageDescriptor,
51        header: ProtoHeader,
52    ) -> SinkResult<Self> {
53        match &col_indices {
54            Some(col_indices) => validate_fields(
55                col_indices.iter().map(|idx| {
56                    let f = &schema[*idx];
57                    (f.name.as_str(), &f.data_type)
58                }),
59                &descriptor,
60            )?,
61            None => validate_fields(
62                schema
63                    .fields
64                    .iter()
65                    .map(|f| (f.name.as_str(), &f.data_type)),
66                &descriptor,
67            )?,
68        };
69
70        Ok(Self {
71            schema,
72            col_indices,
73            descriptor,
74            header,
75        })
76    }
77}
78
79pub struct ProtoEncoded {
80    pub message: DynamicMessage,
81    header: ProtoHeader,
82}
83
84impl RowEncoder for ProtoEncoder {
85    type Output = ProtoEncoded;
86
87    fn schema(&self) -> &Schema {
88        &self.schema
89    }
90
91    fn col_indices(&self) -> Option<&[usize]> {
92        self.col_indices.as_deref()
93    }
94
95    fn encode_cols(
96        &self,
97        row: impl Row,
98        col_indices: impl Iterator<Item = usize>,
99    ) -> SinkResult<Self::Output> {
100        encode_fields(
101            col_indices.map(|idx| {
102                let f = &self.schema[idx];
103                ((f.name.as_str(), &f.data_type), row.datum_at(idx))
104            }),
105            &self.descriptor,
106        )
107        .map_err(Into::into)
108        .map(|m| ProtoEncoded {
109            message: m,
110            header: self.header,
111        })
112    }
113}
114
115impl SerTo<Vec<u8>> for ProtoEncoded {
116    fn ser_to(self) -> SinkResult<Vec<u8>> {
117        let mut buf = Vec::new();
118        match self.header {
119            ProtoHeader::None => { /* noop */ }
120            ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
121                buf.reserve(1 + 4);
122                buf.put_u8(0);
123                buf.put_i32(schema_id);
124                MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
125            }
126        }
127        self.message.encode(&mut buf).unwrap();
128        Ok(buf)
129    }
130}
131
132struct MessageIndexes(Vec<i32>);
133
134impl MessageIndexes {
135    fn from(desc: MessageDescriptor) -> Self {
136        // https://github.com/protocolbuffers/protobuf/blob/v25.1/src/google/protobuf/descriptor.proto
137        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/tag.rs.html
138        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/build/visit.rs.html#125
139        // `FileDescriptorProto` field #4 is `repeated DescriptorProto message_type`
140        const TAG_FILE_MESSAGE: i32 = 4;
141        // `DescriptorProto` field #3 is `repeated DescriptorProto nested_type`
142        const TAG_MESSAGE_NESTED: i32 = 3;
143
144        let mut indexes = vec![];
145        let mut path = desc.path().array_chunks();
146        let &[tag, idx] = path.next().unwrap();
147        assert_eq!(tag, TAG_FILE_MESSAGE);
148        indexes.push(idx);
149        for &[tag, idx] in path {
150            assert_eq!(tag, TAG_MESSAGE_NESTED);
151            indexes.push(idx);
152        }
153        Self(indexes)
154    }
155
156    fn zig_i32(value: i32, buf: &mut impl BufMut) {
157        let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
158        prost::encoding::encode_varint(unsigned, buf);
159    }
160
161    fn encode(&self, buf: &mut impl BufMut) {
162        if self.0 == [0] {
163            buf.put_u8(0);
164            return;
165        }
166        Self::zig_i32(self.0.len().try_into().unwrap(), buf);
167        for &idx in &self.0 {
168            Self::zig_i32(idx, buf);
169        }
170    }
171}
172
173/// A trait that assists code reuse between `validate` and `encode`.
174/// * For `validate`, the inputs are (RisingWave type, ProtoBuf type).
175/// * For `encode`, the inputs are (RisingWave type, RisingWave data, ProtoBuf type).
176///
177/// Thus we impl [`MaybeData`] for both [`()`] and [`ScalarRefImpl`].
178trait MaybeData: std::fmt::Debug {
179    type Out;
180
181    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;
182
183    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;
184
185    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;
186
187    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
188}
189
190impl MaybeData for () {
191    type Out = ();
192
193    fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
194        Ok(self)
195    }
196
197    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
198        validate_fields(st.iter(), pb)
199    }
200
201    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
202        on_field(elem, (), pb, true)
203    }
204
205    fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
206        debug_assert!(pb.is_map_entry());
207        on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
208        on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
209        Ok(())
210    }
211}
212
213/// Nullability is not part of type system in proto.
214/// * Top level is always a message.
215/// * All message fields can be omitted in proto3.
216/// * All repeated elements must have a value.
217///
218/// So we handle [`ScalarRefImpl`] rather than [`DatumRef`] here.
219impl MaybeData for ScalarRefImpl<'_> {
220    type Out = Value;
221
222    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
223        f(self)
224    }
225
226    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
227        let d = self.into_struct();
228        let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
229        Ok(Value::Message(message))
230    }
231
232    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
233        let d = self.into_list();
234        let vs = d
235            .iter()
236            .map(|d| {
237                on_field(
238                    elem,
239                    d.ok_or_else(|| {
240                        FieldEncodeError::new("array containing null not allowed as repeated field")
241                    })?,
242                    pb,
243                    true,
244                )
245            })
246            .try_collect()?;
247        Ok(Value::List(vs))
248    }
249
250    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
251        debug_assert!(pb.is_map_entry());
252        let vs = self
253            .into_map()
254            .iter()
255            .map(|(k, v)| {
256                let v =
257                    v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
258                let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
259                let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
260                Ok((
261                    k.into_map_key().ok_or_else(|| {
262                        FieldEncodeError::new("failed to convert map key to proto")
263                    })?,
264                    v,
265                ))
266            })
267            .try_collect()?;
268        Ok(Value::Map(vs))
269    }
270}
271
272fn validate_fields<'a>(
273    fields: impl Iterator<Item = (&'a str, &'a DataType)>,
274    descriptor: &MessageDescriptor,
275) -> Result<()> {
276    for (name, t) in fields {
277        let Some(proto_field) = descriptor.get_field_by_name(name) else {
278            return Err(FieldEncodeError::new("field not in proto").with_name(name));
279        };
280        if proto_field.cardinality() == prost_reflect::Cardinality::Required {
281            return Err(FieldEncodeError::new("`required` not supported").with_name(name));
282        }
283        on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
284    }
285    Ok(())
286}
287
288fn encode_fields<'a>(
289    fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
290    descriptor: &MessageDescriptor,
291) -> Result<DynamicMessage> {
292    let mut message = DynamicMessage::new(descriptor.clone());
293    for ((name, t), d) in fields_with_datums {
294        let proto_field = descriptor.get_field_by_name(name).unwrap();
295        // On `null`, simply skip setting the field.
296        if let Some(scalar) = d {
297            let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
298            message
299                .try_set_field(&proto_field, value)
300                .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
301        }
302    }
303    Ok(message)
304}
305
// Fully-qualified names of protobuf Well-Known Types.

/// Full name of the `Timestamp` well-known type.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
/// Full name of the `BoolValue` well-known type (currently unused).
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
310
311/// Handles both `validate` (without actual data) and `encode`.
312/// See [`MaybeData`] for more info.
313fn on_field<D: MaybeData>(
314    data_type: &DataType,
315    maybe: D,
316    proto_field: &FieldDescriptor,
317    in_repeated: bool,
318) -> Result<D::Out> {
319    // Regarding (proto_field.is_list, in_repeated):
320    // (F, T) => impossible
321    // (F, F) => encoding to a non-repeated field
322    // (T, F) => encoding to a repeated field
323    // (T, T) => encoding to an element of a repeated field
324    // In the bottom 2 cases, we need to distinguish the same `proto_field` with the help of `in_repeated`.
325    assert!(proto_field.is_list() || !in_repeated);
326    let expect_list = proto_field.is_list() && !in_repeated;
327    if proto_field.is_group() {
328        return Err(FieldEncodeError::new("proto group not supported yet"));
329    }
330
331    let no_match_err = || {
332        Err(FieldEncodeError::new(format!(
333            "cannot encode {} column as {}{:?} field",
334            data_type,
335            if expect_list { "repeated " } else { "" },
336            proto_field.kind()
337        )))
338    };
339
340    if expect_list && !matches!(data_type, DataType::List(_)) {
341        return no_match_err();
342    }
343
344    let value = match &data_type {
345        // Group A: perfect match between RisingWave types and ProtoBuf types
346        DataType::Boolean => match proto_field.kind() {
347            Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
348            _ => return no_match_err(),
349        },
350        DataType::Varchar => match proto_field.kind() {
351            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
352            Kind::Enum(enum_desc) => maybe.on_base(|s| {
353                let name = s.into_utf8();
354                let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
355                    FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
356                })?;
357                Ok(Value::EnumNumber(enum_value_desc.number()))
358            })?,
359            _ => return no_match_err(),
360        },
361        DataType::Bytea => match proto_field.kind() {
362            Kind::Bytes => {
363                maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
364            }
365            _ => return no_match_err(),
366        },
367        DataType::Float32 => match proto_field.kind() {
368            Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
369            _ => return no_match_err(),
370        },
371        DataType::Float64 => match proto_field.kind() {
372            Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
373            _ => return no_match_err(),
374        },
375        DataType::Int32 => match proto_field.kind() {
376            Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
377                maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
378            }
379            _ => return no_match_err(),
380        },
381        DataType::Int64 => match proto_field.kind() {
382            Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
383                maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
384            }
385            _ => return no_match_err(),
386        },
387        DataType::Struct(st) => match proto_field.kind() {
388            Kind::Message(pb) => maybe.on_struct(st, &pb)?,
389            _ => return no_match_err(),
390        },
391        DataType::List(elem) => match expect_list {
392            true => maybe.on_list(elem, proto_field)?,
393            false => return no_match_err(),
394        },
395        // Group B: match between RisingWave types and ProtoBuf Well-Known types
396        DataType::Timestamptz => match proto_field.kind() {
397            Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
398                let d = s.into_timestamptz();
399                let message = prost_types::Timestamp {
400                    seconds: d.timestamp(),
401                    nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
402                };
403                Ok(Value::Message(message.transcode_to_dynamic()))
404            })?,
405            Kind::String => {
406                maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
407            }
408            _ => return no_match_err(),
409        },
410        DataType::Jsonb => match proto_field.kind() {
411            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
412            _ => return no_match_err(), /* Value, NullValue, Struct (map), ListValue
413                                         * Group C: experimental */
414        },
415        DataType::Int16 => match proto_field.kind() {
416            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
417            _ => return no_match_err(),
418        },
419        DataType::Date => match proto_field.kind() {
420            Kind::Int32 => {
421                maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
422            }
423            _ => return no_match_err(), // google.type.Date
424        },
425        DataType::Time => match proto_field.kind() {
426            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
427            _ => return no_match_err(), // google.type.TimeOfDay
428        },
429        DataType::Timestamp => match proto_field.kind() {
430            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
431            _ => return no_match_err(), // google.type.DateTime
432        },
433        DataType::Decimal => match proto_field.kind() {
434            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
435            _ => return no_match_err(), // google.type.Decimal
436        },
437        DataType::Interval => match proto_field.kind() {
438            Kind::String => {
439                maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
440            }
441            _ => return no_match_err(), // Group D: unsupported
442        },
443        DataType::Serial => match proto_field.kind() {
444            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
445            _ => return no_match_err(), // Group D: unsupported
446        },
447        DataType::Int256 => {
448            return no_match_err();
449        }
450        DataType::Map(map_type) => {
451            if proto_field.is_map() {
452                let msg = match proto_field.kind() {
453                    Kind::Message(m) => m,
454                    _ => return no_match_err(), // unreachable actually
455                };
456                return maybe.on_map(map_type, &msg);
457            } else {
458                return no_match_err();
459            }
460        }
461        DataType::Vector(_) => todo!("VECTOR_PLACEHOLDER"),
462    };
463
464    Ok(value)
465}
466
467#[cfg(test)]
468mod tests {
469    use itertools::Itertools;
470    use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
471    use risingwave_common::catalog::Field;
472    use risingwave_common::row::OwnedRow;
473    use risingwave_common::types::{
474        ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
475    };
476
477    use super::*;
478
479    #[test]
480    fn test_encode_proto_ok() {
481        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
482            .join("codec/tests/test_data/all-types.pb");
483        let pool_bytes = std::fs::read(pool_path).unwrap();
484        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
485        let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
486        let schema = Schema::new(vec![
487            Field::with_name(DataType::Boolean, "bool_field"),
488            Field::with_name(DataType::Varchar, "string_field"),
489            Field::with_name(DataType::Bytea, "bytes_field"),
490            Field::with_name(DataType::Float32, "float_field"),
491            Field::with_name(DataType::Float64, "double_field"),
492            Field::with_name(DataType::Int32, "int32_field"),
493            Field::with_name(DataType::Int64, "int64_field"),
494            Field::with_name(DataType::Int32, "sint32_field"),
495            Field::with_name(DataType::Int64, "sint64_field"),
496            Field::with_name(DataType::Int32, "sfixed32_field"),
497            Field::with_name(DataType::Int64, "sfixed64_field"),
498            Field::with_name(
499                DataType::Struct(StructType::new(vec![
500                    ("id", DataType::Int32),
501                    ("name", DataType::Varchar),
502                ])),
503                "nested_message_field",
504            ),
505            Field::with_name(DataType::List(DataType::Int32.into()), "repeated_int_field"),
506            Field::with_name(DataType::Timestamptz, "timestamp_field"),
507            Field::with_name(
508                DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
509                "map_field",
510            ),
511            Field::with_name(
512                DataType::Map(MapType::from_kv(
513                    DataType::Varchar,
514                    DataType::Struct(StructType::new(vec![
515                        ("id", DataType::Int32),
516                        ("name", DataType::Varchar),
517                    ])),
518                )),
519                "map_struct_field",
520            ),
521        ]);
522        let row = OwnedRow::new(vec![
523            Some(ScalarImpl::Bool(true)),
524            Some(ScalarImpl::Utf8("RisingWave".into())),
525            Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
526            Some(ScalarImpl::Float32(3.5f32.into())),
527            Some(ScalarImpl::Float64(4.25f64.into())),
528            Some(ScalarImpl::Int32(22)),
529            Some(ScalarImpl::Int64(23)),
530            Some(ScalarImpl::Int32(24)),
531            None,
532            Some(ScalarImpl::Int32(26)),
533            Some(ScalarImpl::Int64(27)),
534            Some(ScalarImpl::Struct(StructValue::new(vec![
535                Some(ScalarImpl::Int32(1)),
536                Some(ScalarImpl::Utf8("".into())),
537            ]))),
538            Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
539            Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
540            Some(ScalarImpl::Map(
541                MapValue::try_from_kv(
542                    ListValue::from_iter(["a", "b"]),
543                    ListValue::from_iter([1, 2]),
544                )
545                .unwrap(),
546            )),
547            {
548                let mut struct_array_builder = StructArrayBuilder::with_type(
549                    2,
550                    DataType::Struct(StructType::new(vec![
551                        ("id", DataType::Int32),
552                        ("name", DataType::Varchar),
553                    ])),
554                );
555                struct_array_builder.append(Some(
556                    StructValue::new(vec![
557                        Some(ScalarImpl::Int32(1)),
558                        Some(ScalarImpl::Utf8("x".into())),
559                    ])
560                    .as_scalar_ref(),
561                ));
562                struct_array_builder.append(Some(
563                    StructValue::new(vec![
564                        Some(ScalarImpl::Int32(2)),
565                        Some(ScalarImpl::Utf8("y".into())),
566                    ])
567                    .as_scalar_ref(),
568                ));
569                Some(ScalarImpl::Map(
570                    MapValue::try_from_kv(
571                        ListValue::from_iter(["a", "b"]),
572                        ListValue::new(struct_array_builder.finish().into()),
573                    )
574                    .unwrap(),
575                ))
576            },
577        ]);
578
579        let encoder =
580            ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap();
581        let m = encoder.encode(row).unwrap();
582        expect_test::expect![[r#"
583            field: FieldDescriptor {
584                name: "double_field",
585                full_name: "all_types.AllTypes.double_field",
586                json_name: "doubleField",
587                number: 1,
588                kind: double,
589                cardinality: Optional,
590                containing_oneof: None,
591                default_value: None,
592                is_group: false,
593                is_list: false,
594                is_map: false,
595                is_packed: false,
596                supports_presence: false,
597            }
598
599            value: F64(4.25)
600
601            ==============================
602            field: FieldDescriptor {
603                name: "float_field",
604                full_name: "all_types.AllTypes.float_field",
605                json_name: "floatField",
606                number: 2,
607                kind: float,
608                cardinality: Optional,
609                containing_oneof: None,
610                default_value: None,
611                is_group: false,
612                is_list: false,
613                is_map: false,
614                is_packed: false,
615                supports_presence: false,
616            }
617
618            value: F32(3.5)
619
620            ==============================
621            field: FieldDescriptor {
622                name: "int32_field",
623                full_name: "all_types.AllTypes.int32_field",
624                json_name: "int32Field",
625                number: 3,
626                kind: int32,
627                cardinality: Optional,
628                containing_oneof: None,
629                default_value: None,
630                is_group: false,
631                is_list: false,
632                is_map: false,
633                is_packed: false,
634                supports_presence: false,
635            }
636
637            value: I32(22)
638
639            ==============================
640            field: FieldDescriptor {
641                name: "int64_field",
642                full_name: "all_types.AllTypes.int64_field",
643                json_name: "int64Field",
644                number: 4,
645                kind: int64,
646                cardinality: Optional,
647                containing_oneof: None,
648                default_value: None,
649                is_group: false,
650                is_list: false,
651                is_map: false,
652                is_packed: false,
653                supports_presence: false,
654            }
655
656            value: I64(23)
657
658            ==============================
659            field: FieldDescriptor {
660                name: "sint32_field",
661                full_name: "all_types.AllTypes.sint32_field",
662                json_name: "sint32Field",
663                number: 7,
664                kind: sint32,
665                cardinality: Optional,
666                containing_oneof: None,
667                default_value: None,
668                is_group: false,
669                is_list: false,
670                is_map: false,
671                is_packed: false,
672                supports_presence: false,
673            }
674
675            value: I32(24)
676
677            ==============================
678            field: FieldDescriptor {
679                name: "sfixed32_field",
680                full_name: "all_types.AllTypes.sfixed32_field",
681                json_name: "sfixed32Field",
682                number: 11,
683                kind: sfixed32,
684                cardinality: Optional,
685                containing_oneof: None,
686                default_value: None,
687                is_group: false,
688                is_list: false,
689                is_map: false,
690                is_packed: false,
691                supports_presence: false,
692            }
693
694            value: I32(26)
695
696            ==============================
697            field: FieldDescriptor {
698                name: "sfixed64_field",
699                full_name: "all_types.AllTypes.sfixed64_field",
700                json_name: "sfixed64Field",
701                number: 12,
702                kind: sfixed64,
703                cardinality: Optional,
704                containing_oneof: None,
705                default_value: None,
706                is_group: false,
707                is_list: false,
708                is_map: false,
709                is_packed: false,
710                supports_presence: false,
711            }
712
713            value: I64(27)
714
715            ==============================
716            field: FieldDescriptor {
717                name: "bool_field",
718                full_name: "all_types.AllTypes.bool_field",
719                json_name: "boolField",
720                number: 13,
721                kind: bool,
722                cardinality: Optional,
723                containing_oneof: None,
724                default_value: None,
725                is_group: false,
726                is_list: false,
727                is_map: false,
728                is_packed: false,
729                supports_presence: false,
730            }
731
732            value: Bool(true)
733
734            ==============================
735            field: FieldDescriptor {
736                name: "string_field",
737                full_name: "all_types.AllTypes.string_field",
738                json_name: "stringField",
739                number: 14,
740                kind: string,
741                cardinality: Optional,
742                containing_oneof: None,
743                default_value: None,
744                is_group: false,
745                is_list: false,
746                is_map: false,
747                is_packed: false,
748                supports_presence: false,
749            }
750
751            value: String("RisingWave")
752
753            ==============================
754            field: FieldDescriptor {
755                name: "bytes_field",
756                full_name: "all_types.AllTypes.bytes_field",
757                json_name: "bytesField",
758                number: 15,
759                kind: bytes,
760                cardinality: Optional,
761                containing_oneof: None,
762                default_value: None,
763                is_group: false,
764                is_list: false,
765                is_map: false,
766                is_packed: false,
767                supports_presence: false,
768            }
769
770            value: Bytes(b"\xbe\xef")
771
772            ==============================
773            field: FieldDescriptor {
774                name: "nested_message_field",
775                full_name: "all_types.AllTypes.nested_message_field",
776                json_name: "nestedMessageField",
777                number: 17,
778                kind: all_types.AllTypes.NestedMessage,
779                cardinality: Optional,
780                containing_oneof: None,
781                default_value: None,
782                is_group: false,
783                is_list: false,
784                is_map: false,
785                is_packed: false,
786                supports_presence: true,
787            }
788
789            value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })
790
791            ==============================
792            field: FieldDescriptor {
793                name: "repeated_int_field",
794                full_name: "all_types.AllTypes.repeated_int_field",
795                json_name: "repeatedIntField",
796                number: 18,
797                kind: int32,
798                cardinality: Repeated,
799                containing_oneof: None,
800                default_value: None,
801                is_group: false,
802                is_list: true,
803                is_map: false,
804                is_packed: true,
805                supports_presence: false,
806            }
807
808            value: List([I32(4), I32(0), I32(4)])
809
810            ==============================
811            field: FieldDescriptor {
812                name: "map_field",
813                full_name: "all_types.AllTypes.map_field",
814                json_name: "mapField",
815                number: 22,
816                kind: all_types.AllTypes.MapFieldEntry,
817                cardinality: Repeated,
818                containing_oneof: None,
819                default_value: None,
820                is_group: false,
821                is_list: false,
822                is_map: true,
823                is_packed: false,
824                supports_presence: false,
825            }
826
827            value: Map({
828                String("a"): I32(1),
829                String("b"): I32(2),
830            })
831
832            ==============================
833            field: FieldDescriptor {
834                name: "timestamp_field",
835                full_name: "all_types.AllTypes.timestamp_field",
836                json_name: "timestampField",
837                number: 23,
838                kind: google.protobuf.Timestamp,
839                cardinality: Optional,
840                containing_oneof: None,
841                default_value: None,
842                is_group: false,
843                is_list: false,
844                is_map: false,
845                is_packed: false,
846                supports_presence: true,
847            }
848
849            value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })
850
851            ==============================
852            field: FieldDescriptor {
853                name: "map_struct_field",
854                full_name: "all_types.AllTypes.map_struct_field",
855                json_name: "mapStructField",
856                number: 29,
857                kind: all_types.AllTypes.MapStructFieldEntry,
858                cardinality: Repeated,
859                containing_oneof: None,
860                default_value: None,
861                is_group: false,
862                is_list: false,
863                is_map: true,
864                is_packed: false,
865                supports_presence: false,
866            }
867
868            value: Map({
869                String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
870                String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
871            })"#]].assert_eq(&format!("{}",
872            m.message.fields().format_with("\n\n==============================\n", |(field,value),f| {
873            f(&format!("field: {:#?}\n\nvalue: {}", field, print_proto(value)))
874        })));
875    }
876
877    fn print_proto(value: &Value) -> String {
878        match value {
879            Value::Map(m) => {
880                let mut res = String::new();
881                res.push_str("Map({\n");
882                for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
883                    res.push_str(&format!(
884                        "    {}: {},\n",
885                        print_proto(&k.clone().into()),
886                        print_proto(v)
887                    ));
888                }
889                res.push_str("})");
890                res
891            }
892            _ => format!("{:?}", value),
893        }
894    }
895
896    #[test]
897    fn test_encode_proto_repeated() {
898        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
899            .join("codec/tests/test_data/all-types.pb");
900        let pool_bytes = fs_err::read(pool_path).unwrap();
901        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
902        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
903
904        let schema = Schema::new(vec![Field::with_name(
905            DataType::List(DataType::List(DataType::Int32.into()).into()),
906            "repeated_int_field",
907        )]);
908
909        let err = validate_fields(
910            schema
911                .fields
912                .iter()
913                .map(|f| (f.name.as_str(), &f.data_type)),
914            &message_descriptor,
915        )
916        .unwrap_err();
917        assert_eq!(
918            err.to_string(),
919            "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
920        );
921
922        let schema = Schema::new(vec![Field::with_name(
923            DataType::List(DataType::Int32.into()),
924            "repeated_int_field",
925        )]);
926        let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
927            Some(0),
928            None,
929            Some(2),
930            Some(3),
931        ])))]);
932
933        let err = encode_fields(
934            schema
935                .fields
936                .iter()
937                .map(|f| (f.name.as_str(), &f.data_type))
938                .zip_eq_debug(row.iter()),
939            &message_descriptor,
940        )
941        .unwrap_err();
942        assert_eq!(
943            err.to_string(),
944            "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
945        );
946    }
947
948    #[test]
949    fn test_encode_proto_err() {
950        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
951            .join("codec/tests/test_data/all-types.pb");
952        let pool_bytes = std::fs::read(pool_path).unwrap();
953        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
954        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
955
956        let err = validate_fields(
957            std::iter::once(("not_exists", &DataType::Int16)),
958            &message_descriptor,
959        )
960        .unwrap_err();
961        assert_eq!(
962            err.to_string(),
963            "encode 'not_exists' error: field not in proto"
964        );
965
966        let err = validate_fields(
967            std::iter::once(("map_field", &DataType::Jsonb)),
968            &message_descriptor,
969        )
970        .unwrap_err();
971        assert_eq!(
972            err.to_string(),
973            "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
974        );
975    }
976}