risingwave_connector/sink/encoder/
proto.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18    DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::catalog::Schema;
21use risingwave_common::row::Row;
22use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
23use risingwave_common::util::iter_util::ZipEqDebug;
24
25use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
26
/// Module-local result alias whose error type is [`FieldEncodeError`].
///
/// The sink-level result type is imported under the name `SinkResult`, so the two do not clash.
type Result<T> = std::result::Result<T, FieldEncodeError>;
28
/// Row encoder that turns rows into dynamic protobuf messages described by a
/// [`MessageDescriptor`].
pub struct ProtoEncoder {
    /// Schema of the rows to encode; column names are looked up as proto field names.
    schema: Schema,
    /// Optional subset of `schema` column indices. With `None`, every schema field is
    /// validated against the descriptor in `new`.
    col_indices: Option<Vec<usize>>,
    /// Descriptor of the protobuf message each row is encoded into.
    descriptor: MessageDescriptor,
    /// Header variant copied into every [`ProtoEncoded`] produced by this encoder.
    header: ProtoHeader,
}
35
/// Framing written in front of the serialized protobuf message.
#[derive(Debug, Clone, Copy)]
pub enum ProtoHeader {
    /// No prefix: the output consists of the encoded message only.
    None,
    /// <https://docs.confluent.io/platform/7.5/schema-registry/fundamentals/serdes-develop/index.html#messages-wire-format>
    ///
    /// * 00
    /// * 4-byte big-endian schema ID
    /// * message indexes of the encoded message (see `MessageIndexes`)
    ConfluentSchemaRegistry(i32),
}
45
46impl ProtoEncoder {
47    pub fn new(
48        schema: Schema,
49        col_indices: Option<Vec<usize>>,
50        descriptor: MessageDescriptor,
51        header: ProtoHeader,
52    ) -> SinkResult<Self> {
53        match &col_indices {
54            Some(col_indices) => validate_fields(
55                col_indices.iter().map(|idx| {
56                    let f = &schema[*idx];
57                    (f.name.as_str(), &f.data_type)
58                }),
59                &descriptor,
60            )?,
61            None => validate_fields(
62                schema
63                    .fields
64                    .iter()
65                    .map(|f| (f.name.as_str(), &f.data_type)),
66                &descriptor,
67            )?,
68        };
69
70        Ok(Self {
71            schema,
72            col_indices,
73            descriptor,
74            header,
75        })
76    }
77}
78
/// Output of [`ProtoEncoder`]: a populated message plus the header to emit when it is
/// serialized.
pub struct ProtoEncoded {
    /// The encoded row as a dynamic protobuf message.
    pub message: DynamicMessage,
    /// Header variant written before the message bytes by `ser_to`.
    header: ProtoHeader,
}
83
84impl RowEncoder for ProtoEncoder {
85    type Output = ProtoEncoded;
86
87    fn schema(&self) -> &Schema {
88        &self.schema
89    }
90
91    fn col_indices(&self) -> Option<&[usize]> {
92        self.col_indices.as_deref()
93    }
94
95    fn encode_cols(
96        &self,
97        row: impl Row,
98        col_indices: impl Iterator<Item = usize>,
99    ) -> SinkResult<Self::Output> {
100        encode_fields(
101            col_indices.map(|idx| {
102                let f = &self.schema[idx];
103                ((f.name.as_str(), &f.data_type), row.datum_at(idx))
104            }),
105            &self.descriptor,
106        )
107        .map_err(Into::into)
108        .map(|m| ProtoEncoded {
109            message: m,
110            header: self.header,
111        })
112    }
113}
114
115impl SerTo<Vec<u8>> for ProtoEncoded {
116    fn ser_to(self) -> SinkResult<Vec<u8>> {
117        let mut buf = Vec::new();
118        match self.header {
119            ProtoHeader::None => { /* noop */ }
120            ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
121                buf.reserve(1 + 4);
122                buf.put_u8(0);
123                buf.put_i32(schema_id);
124                MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
125            }
126        }
127        self.message.encode(&mut buf).unwrap();
128        Ok(buf)
129    }
130}
131
/// Index path of a message within its file: the first entry is the index among the file's
/// top-level messages, and each following entry is an index among the nested messages of the
/// previous one (as extracted from the descriptor path in `from`).
struct MessageIndexes(Vec<i32>);
133
134impl MessageIndexes {
135    fn from(desc: MessageDescriptor) -> Self {
136        // https://github.com/protocolbuffers/protobuf/blob/v25.1/src/google/protobuf/descriptor.proto
137        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/tag.rs.html
138        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/build/visit.rs.html#125
139        // `FileDescriptorProto` field #4 is `repeated DescriptorProto message_type`
140        const TAG_FILE_MESSAGE: i32 = 4;
141        // `DescriptorProto` field #3 is `repeated DescriptorProto nested_type`
142        const TAG_MESSAGE_NESTED: i32 = 3;
143
144        let mut indexes = vec![];
145        let mut path = desc.path().array_chunks();
146        let &[tag, idx] = path.next().unwrap();
147        assert_eq!(tag, TAG_FILE_MESSAGE);
148        indexes.push(idx);
149        for &[tag, idx] in path {
150            assert_eq!(tag, TAG_MESSAGE_NESTED);
151            indexes.push(idx);
152        }
153        Self(indexes)
154    }
155
156    fn zig_i32(value: i32, buf: &mut impl BufMut) {
157        let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
158        prost::encoding::encode_varint(unsigned, buf);
159    }
160
161    fn encode(&self, buf: &mut impl BufMut) {
162        if self.0 == [0] {
163            buf.put_u8(0);
164            return;
165        }
166        Self::zig_i32(self.0.len().try_into().unwrap(), buf);
167        for &idx in &self.0 {
168            Self::zig_i32(idx, buf);
169        }
170    }
171}
172
/// A trait that assists code reuse between `validate` and `encode`.
/// * For `validate`, the inputs are (RisingWave type, ProtoBuf type).
/// * For `encode`, the inputs are (RisingWave type, RisingWave data, ProtoBuf type).
///
/// Thus we impl [`MaybeData`] for both [`()`] and [`ScalarRefImpl`].
trait MaybeData: std::fmt::Debug {
    /// Result of a successful step: `()` when validating, a proto [`Value`] when encoding.
    type Out;

    /// Handles a non-composite type; `f` converts the scalar into a proto [`Value`].
    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;

    /// Handles a struct type `st` mapped onto the proto message `pb`.
    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;

    /// Handles a list with element type `elem` mapped onto the repeated proto field `pb`.
    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;

    /// Handles a map type `m` mapped onto the proto map-entry message `pb`.
    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
}
189
190impl MaybeData for () {
191    type Out = ();
192
193    fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
194        Ok(self)
195    }
196
197    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
198        validate_fields(st.iter(), pb)
199    }
200
201    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
202        on_field(elem, (), pb, true)
203    }
204
205    fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
206        debug_assert!(pb.is_map_entry());
207        on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
208        on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
209        Ok(())
210    }
211}
212
213/// Nullability is not part of type system in proto.
214/// * Top level is always a message.
215/// * All message fields can be omitted in proto3.
216/// * All repeated elements must have a value.
217///
218/// So we handle [`ScalarRefImpl`] rather than [`DatumRef`] here.
219impl MaybeData for ScalarRefImpl<'_> {
220    type Out = Value;
221
222    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
223        f(self)
224    }
225
226    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
227        let d = self.into_struct();
228        let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
229        Ok(Value::Message(message))
230    }
231
232    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
233        let d = self.into_list();
234        let vs = d
235            .iter()
236            .map(|d| {
237                on_field(
238                    elem,
239                    d.ok_or_else(|| {
240                        FieldEncodeError::new("array containing null not allowed as repeated field")
241                    })?,
242                    pb,
243                    true,
244                )
245            })
246            .try_collect()?;
247        Ok(Value::List(vs))
248    }
249
250    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
251        debug_assert!(pb.is_map_entry());
252        let vs = self
253            .into_map()
254            .iter()
255            .map(|(k, v)| {
256                let v =
257                    v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
258                let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
259                let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
260                Ok((
261                    k.into_map_key().ok_or_else(|| {
262                        FieldEncodeError::new("failed to convert map key to proto")
263                    })?,
264                    v,
265                ))
266            })
267            .try_collect()?;
268        Ok(Value::Map(vs))
269    }
270}
271
272fn validate_fields<'a>(
273    fields: impl Iterator<Item = (&'a str, &'a DataType)>,
274    descriptor: &MessageDescriptor,
275) -> Result<()> {
276    for (name, t) in fields {
277        let Some(proto_field) = descriptor.get_field_by_name(name) else {
278            return Err(FieldEncodeError::new("field not in proto").with_name(name));
279        };
280        if proto_field.cardinality() == prost_reflect::Cardinality::Required {
281            return Err(FieldEncodeError::new("`required` not supported").with_name(name));
282        }
283        on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
284    }
285    Ok(())
286}
287
288fn encode_fields<'a>(
289    fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
290    descriptor: &MessageDescriptor,
291) -> Result<DynamicMessage> {
292    let mut message = DynamicMessage::new(descriptor.clone());
293    for ((name, t), d) in fields_with_datums {
294        let proto_field = descriptor.get_field_by_name(name).unwrap();
295        // On `null`, simply skip setting the field.
296        if let Some(scalar) = d {
297            let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
298            message
299                .try_set_field(&proto_field, value)
300                .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
301        }
302    }
303    Ok(message)
304}
305
// Full names of protobuf Well-Known Types.
/// Compared against a message's full name when encoding `Timestamptz` columns.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
/// Not referenced by the code in this file; `#[expect(dead_code)]` acknowledges that.
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
310
/// Handles both `validate` (without actual data) and `encode`.
/// See [`MaybeData`] for more info.
///
/// * `data_type`: RisingWave type of the column (or nested element) being mapped.
/// * `maybe`: `()` when only validating, or the non-null scalar when encoding.
/// * `proto_field`: target proto field; for list elements this is the repeated field itself.
/// * `in_repeated`: `true` when handling one element of the repeated `proto_field` rather than
///   the field as a whole.
///
/// Returns `D::Out` on success, or a [`FieldEncodeError`] when the type pair is not supported
/// or a value cannot be converted.
///
/// # Panics
/// Panics if `in_repeated` is `true` while `proto_field` is not a list.
fn on_field<D: MaybeData>(
    data_type: &DataType,
    maybe: D,
    proto_field: &FieldDescriptor,
    in_repeated: bool,
) -> Result<D::Out> {
    // Regarding (proto_field.is_list, in_repeated):
    // (F, T) => impossible
    // (F, F) => encoding to a non-repeated field
    // (T, F) => encoding to a repeated field
    // (T, T) => encoding to an element of a repeated field
    // In the bottom 2 cases, we need to distinguish the same `proto_field` with the help of `in_repeated`.
    assert!(proto_field.is_list() || !in_repeated);
    let expect_list = proto_field.is_list() && !in_repeated;
    if proto_field.is_group() {
        return Err(FieldEncodeError::new("proto group not supported yet"));
    }

    // Shared error for every unsupported (RisingWave type, proto kind) combination.
    let no_match_err = || {
        Err(FieldEncodeError::new(format!(
            "cannot encode {} column as {}{:?} field",
            data_type,
            if expect_list { "repeated " } else { "" },
            proto_field.kind()
        )))
    };

    // A repeated field as a whole only accepts a list column.
    if expect_list && !matches!(data_type, DataType::List(_)) {
        return no_match_err();
    }

    let value = match &data_type {
        // Group A: perfect match between RisingWave types and ProtoBuf types
        DataType::Boolean => match proto_field.kind() {
            Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
            _ => return no_match_err(),
        },
        DataType::Varchar => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
            // A varchar may also target an enum field: the string must name one of its values.
            Kind::Enum(enum_desc) => maybe.on_base(|s| {
                let name = s.into_utf8();
                let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
                    FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
                })?;
                Ok(Value::EnumNumber(enum_value_desc.number()))
            })?,
            _ => return no_match_err(),
        },
        DataType::Bytea => match proto_field.kind() {
            Kind::Bytes => {
                maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
            }
            _ => return no_match_err(),
        },
        DataType::Float32 => match proto_field.kind() {
            Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
            _ => return no_match_err(),
        },
        DataType::Float64 => match proto_field.kind() {
            Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
            _ => return no_match_err(),
        },
        DataType::Int32 => match proto_field.kind() {
            Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
                maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
            }
            _ => return no_match_err(),
        },
        DataType::Int64 => match proto_field.kind() {
            Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
                maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
            }
            _ => return no_match_err(),
        },
        DataType::Struct(st) => match proto_field.kind() {
            Kind::Message(pb) => maybe.on_struct(st, &pb)?,
            _ => return no_match_err(),
        },
        // A list is only valid for the repeated field as a whole, never for one of its elements.
        DataType::List(elem) => match expect_list {
            true => maybe.on_list(elem, proto_field)?,
            false => return no_match_err(),
        },
        // Group B: match between RisingWave types and ProtoBuf Well-Known types
        DataType::Timestamptz => match proto_field.kind() {
            Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
                let d = s.into_timestamptz();
                let message = prost_types::Timestamp {
                    seconds: d.timestamp(),
                    nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
                };
                Ok(Value::Message(message.transcode_to_dynamic()))
            })?,
            Kind::String => {
                maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
            }
            _ => return no_match_err(),
        },
        DataType::Jsonb => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
            // NOTE(review): presumably the Well-Known Types `Value`, `NullValue`, `Struct` (map)
            // and `ListValue` are candidate targets that are not handled yet — confirm.
            _ => return no_match_err(),
        },
        // Group C: experimental
        DataType::Int16 => match proto_field.kind() {
            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
            _ => return no_match_err(),
        },
        DataType::Date => match proto_field.kind() {
            Kind::Int32 => {
                maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
            }
            _ => return no_match_err(), // google.type.Date
        },
        DataType::Time => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
            _ => return no_match_err(), // google.type.TimeOfDay
        },
        DataType::Timestamp => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
            _ => return no_match_err(), // google.type.DateTime
        },
        DataType::Decimal => match proto_field.kind() {
            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
            _ => return no_match_err(), // google.type.Decimal
        },
        DataType::Interval => match proto_field.kind() {
            Kind::String => {
                maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
            }
            _ => return no_match_err(),
        },
        DataType::Serial => match proto_field.kind() {
            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
            _ => return no_match_err(),
        },
        // Group D: unsupported
        DataType::Int256 => {
            return no_match_err();
        }
        // A map column is only accepted by a proto map field, whose kind is its map-entry message.
        DataType::Map(map_type) => {
            if proto_field.is_map() {
                let msg = match proto_field.kind() {
                    Kind::Message(m) => m,
                    _ => return no_match_err(), // unreachable actually
                };
                return maybe.on_map(map_type, &msg);
            } else {
                return no_match_err();
            }
        }
    };

    Ok(value)
}
465
466#[cfg(test)]
467mod tests {
468    use itertools::Itertools;
469    use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
470    use risingwave_common::catalog::Field;
471    use risingwave_common::row::OwnedRow;
472    use risingwave_common::types::{
473        ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
474    };
475
476    use super::*;
477
478    #[test]
479    fn test_encode_proto_ok() {
480        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
481            .join("codec/tests/test_data/all-types.pb");
482        let pool_bytes = std::fs::read(pool_path).unwrap();
483        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
484        let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
485        let schema = Schema::new(vec![
486            Field::with_name(DataType::Boolean, "bool_field"),
487            Field::with_name(DataType::Varchar, "string_field"),
488            Field::with_name(DataType::Bytea, "bytes_field"),
489            Field::with_name(DataType::Float32, "float_field"),
490            Field::with_name(DataType::Float64, "double_field"),
491            Field::with_name(DataType::Int32, "int32_field"),
492            Field::with_name(DataType::Int64, "int64_field"),
493            Field::with_name(DataType::Int32, "sint32_field"),
494            Field::with_name(DataType::Int64, "sint64_field"),
495            Field::with_name(DataType::Int32, "sfixed32_field"),
496            Field::with_name(DataType::Int64, "sfixed64_field"),
497            Field::with_name(
498                DataType::Struct(StructType::new(vec![
499                    ("id", DataType::Int32),
500                    ("name", DataType::Varchar),
501                ])),
502                "nested_message_field",
503            ),
504            Field::with_name(DataType::List(DataType::Int32.into()), "repeated_int_field"),
505            Field::with_name(DataType::Timestamptz, "timestamp_field"),
506            Field::with_name(
507                DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
508                "map_field",
509            ),
510            Field::with_name(
511                DataType::Map(MapType::from_kv(
512                    DataType::Varchar,
513                    DataType::Struct(StructType::new(vec![
514                        ("id", DataType::Int32),
515                        ("name", DataType::Varchar),
516                    ])),
517                )),
518                "map_struct_field",
519            ),
520        ]);
521        let row = OwnedRow::new(vec![
522            Some(ScalarImpl::Bool(true)),
523            Some(ScalarImpl::Utf8("RisingWave".into())),
524            Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
525            Some(ScalarImpl::Float32(3.5f32.into())),
526            Some(ScalarImpl::Float64(4.25f64.into())),
527            Some(ScalarImpl::Int32(22)),
528            Some(ScalarImpl::Int64(23)),
529            Some(ScalarImpl::Int32(24)),
530            None,
531            Some(ScalarImpl::Int32(26)),
532            Some(ScalarImpl::Int64(27)),
533            Some(ScalarImpl::Struct(StructValue::new(vec![
534                Some(ScalarImpl::Int32(1)),
535                Some(ScalarImpl::Utf8("".into())),
536            ]))),
537            Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
538            Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
539            Some(ScalarImpl::Map(
540                MapValue::try_from_kv(
541                    ListValue::from_iter(["a", "b"]),
542                    ListValue::from_iter([1, 2]),
543                )
544                .unwrap(),
545            )),
546            {
547                let mut struct_array_builder = StructArrayBuilder::with_type(
548                    2,
549                    DataType::Struct(StructType::new(vec![
550                        ("id", DataType::Int32),
551                        ("name", DataType::Varchar),
552                    ])),
553                );
554                struct_array_builder.append(Some(
555                    StructValue::new(vec![
556                        Some(ScalarImpl::Int32(1)),
557                        Some(ScalarImpl::Utf8("x".into())),
558                    ])
559                    .as_scalar_ref(),
560                ));
561                struct_array_builder.append(Some(
562                    StructValue::new(vec![
563                        Some(ScalarImpl::Int32(2)),
564                        Some(ScalarImpl::Utf8("y".into())),
565                    ])
566                    .as_scalar_ref(),
567                ));
568                Some(ScalarImpl::Map(
569                    MapValue::try_from_kv(
570                        ListValue::from_iter(["a", "b"]),
571                        ListValue::new(struct_array_builder.finish().into()),
572                    )
573                    .unwrap(),
574                ))
575            },
576        ]);
577
578        let encoder =
579            ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap();
580        let m = encoder.encode(row).unwrap();
581        expect_test::expect![[r#"
582            field: FieldDescriptor {
583                name: "double_field",
584                full_name: "all_types.AllTypes.double_field",
585                json_name: "doubleField",
586                number: 1,
587                kind: double,
588                cardinality: Optional,
589                containing_oneof: None,
590                default_value: None,
591                is_group: false,
592                is_list: false,
593                is_map: false,
594                is_packed: false,
595                supports_presence: false,
596            }
597
598            value: F64(4.25)
599
600            ==============================
601            field: FieldDescriptor {
602                name: "float_field",
603                full_name: "all_types.AllTypes.float_field",
604                json_name: "floatField",
605                number: 2,
606                kind: float,
607                cardinality: Optional,
608                containing_oneof: None,
609                default_value: None,
610                is_group: false,
611                is_list: false,
612                is_map: false,
613                is_packed: false,
614                supports_presence: false,
615            }
616
617            value: F32(3.5)
618
619            ==============================
620            field: FieldDescriptor {
621                name: "int32_field",
622                full_name: "all_types.AllTypes.int32_field",
623                json_name: "int32Field",
624                number: 3,
625                kind: int32,
626                cardinality: Optional,
627                containing_oneof: None,
628                default_value: None,
629                is_group: false,
630                is_list: false,
631                is_map: false,
632                is_packed: false,
633                supports_presence: false,
634            }
635
636            value: I32(22)
637
638            ==============================
639            field: FieldDescriptor {
640                name: "int64_field",
641                full_name: "all_types.AllTypes.int64_field",
642                json_name: "int64Field",
643                number: 4,
644                kind: int64,
645                cardinality: Optional,
646                containing_oneof: None,
647                default_value: None,
648                is_group: false,
649                is_list: false,
650                is_map: false,
651                is_packed: false,
652                supports_presence: false,
653            }
654
655            value: I64(23)
656
657            ==============================
658            field: FieldDescriptor {
659                name: "sint32_field",
660                full_name: "all_types.AllTypes.sint32_field",
661                json_name: "sint32Field",
662                number: 7,
663                kind: sint32,
664                cardinality: Optional,
665                containing_oneof: None,
666                default_value: None,
667                is_group: false,
668                is_list: false,
669                is_map: false,
670                is_packed: false,
671                supports_presence: false,
672            }
673
674            value: I32(24)
675
676            ==============================
677            field: FieldDescriptor {
678                name: "sfixed32_field",
679                full_name: "all_types.AllTypes.sfixed32_field",
680                json_name: "sfixed32Field",
681                number: 11,
682                kind: sfixed32,
683                cardinality: Optional,
684                containing_oneof: None,
685                default_value: None,
686                is_group: false,
687                is_list: false,
688                is_map: false,
689                is_packed: false,
690                supports_presence: false,
691            }
692
693            value: I32(26)
694
695            ==============================
696            field: FieldDescriptor {
697                name: "sfixed64_field",
698                full_name: "all_types.AllTypes.sfixed64_field",
699                json_name: "sfixed64Field",
700                number: 12,
701                kind: sfixed64,
702                cardinality: Optional,
703                containing_oneof: None,
704                default_value: None,
705                is_group: false,
706                is_list: false,
707                is_map: false,
708                is_packed: false,
709                supports_presence: false,
710            }
711
712            value: I64(27)
713
714            ==============================
715            field: FieldDescriptor {
716                name: "bool_field",
717                full_name: "all_types.AllTypes.bool_field",
718                json_name: "boolField",
719                number: 13,
720                kind: bool,
721                cardinality: Optional,
722                containing_oneof: None,
723                default_value: None,
724                is_group: false,
725                is_list: false,
726                is_map: false,
727                is_packed: false,
728                supports_presence: false,
729            }
730
731            value: Bool(true)
732
733            ==============================
734            field: FieldDescriptor {
735                name: "string_field",
736                full_name: "all_types.AllTypes.string_field",
737                json_name: "stringField",
738                number: 14,
739                kind: string,
740                cardinality: Optional,
741                containing_oneof: None,
742                default_value: None,
743                is_group: false,
744                is_list: false,
745                is_map: false,
746                is_packed: false,
747                supports_presence: false,
748            }
749
750            value: String("RisingWave")
751
752            ==============================
753            field: FieldDescriptor {
754                name: "bytes_field",
755                full_name: "all_types.AllTypes.bytes_field",
756                json_name: "bytesField",
757                number: 15,
758                kind: bytes,
759                cardinality: Optional,
760                containing_oneof: None,
761                default_value: None,
762                is_group: false,
763                is_list: false,
764                is_map: false,
765                is_packed: false,
766                supports_presence: false,
767            }
768
769            value: Bytes(b"\xbe\xef")
770
771            ==============================
772            field: FieldDescriptor {
773                name: "nested_message_field",
774                full_name: "all_types.AllTypes.nested_message_field",
775                json_name: "nestedMessageField",
776                number: 17,
777                kind: all_types.AllTypes.NestedMessage,
778                cardinality: Optional,
779                containing_oneof: None,
780                default_value: None,
781                is_group: false,
782                is_list: false,
783                is_map: false,
784                is_packed: false,
785                supports_presence: true,
786            }
787
788            value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })
789
790            ==============================
791            field: FieldDescriptor {
792                name: "repeated_int_field",
793                full_name: "all_types.AllTypes.repeated_int_field",
794                json_name: "repeatedIntField",
795                number: 18,
796                kind: int32,
797                cardinality: Repeated,
798                containing_oneof: None,
799                default_value: None,
800                is_group: false,
801                is_list: true,
802                is_map: false,
803                is_packed: true,
804                supports_presence: false,
805            }
806
807            value: List([I32(4), I32(0), I32(4)])
808
809            ==============================
810            field: FieldDescriptor {
811                name: "map_field",
812                full_name: "all_types.AllTypes.map_field",
813                json_name: "mapField",
814                number: 22,
815                kind: all_types.AllTypes.MapFieldEntry,
816                cardinality: Repeated,
817                containing_oneof: None,
818                default_value: None,
819                is_group: false,
820                is_list: false,
821                is_map: true,
822                is_packed: false,
823                supports_presence: false,
824            }
825
826            value: Map({
827                String("a"): I32(1),
828                String("b"): I32(2),
829            })
830
831            ==============================
832            field: FieldDescriptor {
833                name: "timestamp_field",
834                full_name: "all_types.AllTypes.timestamp_field",
835                json_name: "timestampField",
836                number: 23,
837                kind: google.protobuf.Timestamp,
838                cardinality: Optional,
839                containing_oneof: None,
840                default_value: None,
841                is_group: false,
842                is_list: false,
843                is_map: false,
844                is_packed: false,
845                supports_presence: true,
846            }
847
848            value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })
849
850            ==============================
851            field: FieldDescriptor {
852                name: "map_struct_field",
853                full_name: "all_types.AllTypes.map_struct_field",
854                json_name: "mapStructField",
855                number: 29,
856                kind: all_types.AllTypes.MapStructFieldEntry,
857                cardinality: Repeated,
858                containing_oneof: None,
859                default_value: None,
860                is_group: false,
861                is_list: false,
862                is_map: true,
863                is_packed: false,
864                supports_presence: false,
865            }
866
867            value: Map({
868                String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
869                String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: None, is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
870            })"#]].assert_eq(&format!("{}",
871            m.message.fields().format_with("\n\n==============================\n", |(field,value),f| {
872            f(&format!("field: {:#?}\n\nvalue: {}", field, print_proto(value)))
873        })));
874    }
875
876    fn print_proto(value: &Value) -> String {
877        match value {
878            Value::Map(m) => {
879                let mut res = String::new();
880                res.push_str("Map({\n");
881                for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
882                    res.push_str(&format!(
883                        "    {}: {},\n",
884                        print_proto(&k.clone().into()),
885                        print_proto(v)
886                    ));
887                }
888                res.push_str("})");
889                res
890            }
891            _ => format!("{:?}", value),
892        }
893    }
894
895    #[test]
896    fn test_encode_proto_repeated() {
897        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
898            .join("codec/tests/test_data/all-types.pb");
899        let pool_bytes = fs_err::read(pool_path).unwrap();
900        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
901        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
902
903        let schema = Schema::new(vec![Field::with_name(
904            DataType::List(DataType::List(DataType::Int32.into()).into()),
905            "repeated_int_field",
906        )]);
907
908        let err = validate_fields(
909            schema
910                .fields
911                .iter()
912                .map(|f| (f.name.as_str(), &f.data_type)),
913            &message_descriptor,
914        )
915        .unwrap_err();
916        assert_eq!(
917            err.to_string(),
918            "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
919        );
920
921        let schema = Schema::new(vec![Field::with_name(
922            DataType::List(DataType::Int32.into()),
923            "repeated_int_field",
924        )]);
925        let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
926            Some(0),
927            None,
928            Some(2),
929            Some(3),
930        ])))]);
931
932        let err = encode_fields(
933            schema
934                .fields
935                .iter()
936                .map(|f| (f.name.as_str(), &f.data_type))
937                .zip_eq_debug(row.iter()),
938            &message_descriptor,
939        )
940        .unwrap_err();
941        assert_eq!(
942            err.to_string(),
943            "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
944        );
945    }
946
947    #[test]
948    fn test_encode_proto_err() {
949        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
950            .join("codec/tests/test_data/all-types.pb");
951        let pool_bytes = std::fs::read(pool_path).unwrap();
952        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
953        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
954
955        let err = validate_fields(
956            std::iter::once(("not_exists", &DataType::Int16)),
957            &message_descriptor,
958        )
959        .unwrap_err();
960        assert_eq!(
961            err.to_string(),
962            "encode 'not_exists' error: field not in proto"
963        );
964
965        let err = validate_fields(
966            std::iter::once(("map_field", &DataType::Jsonb)),
967            &message_descriptor,
968        )
969        .unwrap_err();
970        assert_eq!(
971            err.to_string(),
972            "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
973        );
974    }
975}