risingwave_connector/sink/encoder/proto.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use bytes::{BufMut, Bytes};
16use prost::Message;
17use prost_reflect::{
18    DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value,
19};
20use risingwave_common::array::VECTOR_ITEM_TYPE;
21use risingwave_common::catalog::Schema;
22use risingwave_common::row::Row;
23use risingwave_common::types::{DataType, DatumRef, MapType, ScalarRefImpl, StructType};
24use risingwave_common::util::iter_util::ZipEqDebug;
25
26use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo};
27
28type Result<T> = std::result::Result<T, FieldEncodeError>;
29
30pub struct ProtoEncoder {
31    schema: Schema,
32    col_indices: Option<Vec<usize>>,
33    descriptor: MessageDescriptor,
34    header: ProtoHeader,
35}
36
/// Optional framing written before the serialized protobuf message.
#[derive(Debug, Clone, Copy)]
pub enum ProtoHeader {
    /// No framing: the output is the bare protobuf message.
    None,
    /// Confluent Schema Registry framing, carrying the registered schema ID:
    /// <https://docs.confluent.io/platform/7.5/schema-registry/fundamentals/serdes-develop/index.html#messages-wire-format>
    ///
    /// * 00
    /// * 4-byte big-endian schema ID
    ConfluentSchemaRegistry(i32),
}
46
47impl ProtoEncoder {
48    pub fn new(
49        schema: Schema,
50        col_indices: Option<Vec<usize>>,
51        descriptor: MessageDescriptor,
52        header: ProtoHeader,
53    ) -> SinkResult<Self> {
54        match &col_indices {
55            Some(col_indices) => validate_fields(
56                col_indices.iter().map(|idx| {
57                    let f = &schema[*idx];
58                    (f.name.as_str(), &f.data_type)
59                }),
60                &descriptor,
61            )?,
62            None => validate_fields(
63                schema
64                    .fields
65                    .iter()
66                    .map(|f| (f.name.as_str(), &f.data_type)),
67                &descriptor,
68            )?,
69        };
70
71        Ok(Self {
72            schema,
73            col_indices,
74            descriptor,
75            header,
76        })
77    }
78}
79
80pub struct ProtoEncoded {
81    pub message: DynamicMessage,
82    header: ProtoHeader,
83}
84
85impl RowEncoder for ProtoEncoder {
86    type Output = ProtoEncoded;
87
88    fn schema(&self) -> &Schema {
89        &self.schema
90    }
91
92    fn col_indices(&self) -> Option<&[usize]> {
93        self.col_indices.as_deref()
94    }
95
96    fn encode_cols(
97        &self,
98        row: impl Row,
99        col_indices: impl Iterator<Item = usize>,
100    ) -> SinkResult<Self::Output> {
101        encode_fields(
102            col_indices.map(|idx| {
103                let f = &self.schema[idx];
104                ((f.name.as_str(), &f.data_type), row.datum_at(idx))
105            }),
106            &self.descriptor,
107        )
108        .map_err(Into::into)
109        .map(|m| ProtoEncoded {
110            message: m,
111            header: self.header,
112        })
113    }
114}
115
116impl SerTo<Vec<u8>> for ProtoEncoded {
117    fn ser_to(self) -> SinkResult<Vec<u8>> {
118        let mut buf = Vec::new();
119        match self.header {
120            ProtoHeader::None => { /* noop */ }
121            ProtoHeader::ConfluentSchemaRegistry(schema_id) => {
122                buf.reserve(1 + 4);
123                buf.put_u8(0);
124                buf.put_i32(schema_id);
125                MessageIndexes::from(self.message.descriptor()).encode(&mut buf);
126            }
127        }
128        self.message.encode(&mut buf).unwrap();
129        Ok(buf)
130    }
131}
132
133struct MessageIndexes(Vec<i32>);
134
135impl MessageIndexes {
136    fn from(desc: MessageDescriptor) -> Self {
137        // https://github.com/protocolbuffers/protobuf/blob/v25.1/src/google/protobuf/descriptor.proto
138        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/tag.rs.html
139        // https://docs.rs/prost-reflect/0.12.0/src/prost_reflect/descriptor/build/visit.rs.html#125
140        // `FileDescriptorProto` field #4 is `repeated DescriptorProto message_type`
141        const TAG_FILE_MESSAGE: i32 = 4;
142        // `DescriptorProto` field #3 is `repeated DescriptorProto nested_type`
143        const TAG_MESSAGE_NESTED: i32 = 3;
144
145        let mut indexes = vec![];
146        let mut path = desc.path().iter().copied().array_chunks();
147        let [tag, idx] = path.next().unwrap();
148        assert_eq!(tag, TAG_FILE_MESSAGE);
149        indexes.push(idx);
150        for [tag, idx] in path {
151            assert_eq!(tag, TAG_MESSAGE_NESTED);
152            indexes.push(idx);
153        }
154        Self(indexes)
155    }
156
157    fn zig_i32(value: i32, buf: &mut impl BufMut) {
158        let unsigned = ((value << 1) ^ (value >> 31)) as u32 as u64;
159        prost::encoding::encode_varint(unsigned, buf);
160    }
161
162    fn encode(&self, buf: &mut impl BufMut) {
163        if self.0 == [0] {
164            buf.put_u8(0);
165            return;
166        }
167        Self::zig_i32(self.0.len().try_into().unwrap(), buf);
168        for &idx in &self.0 {
169            Self::zig_i32(idx, buf);
170        }
171    }
172}
173
174/// A trait that assists code reuse between `validate` and `encode`.
175/// * For `validate`, the inputs are (RisingWave type, ProtoBuf type).
176/// * For `encode`, the inputs are (RisingWave type, RisingWave data, ProtoBuf type).
177///
178/// Thus we impl [`MaybeData`] for both `()` and [`ScalarRefImpl`].
179trait MaybeData: std::fmt::Debug {
180    type Out;
181
182    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out>;
183
184    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out>;
185
186    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out>;
187
188    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out>;
189}
190
191impl MaybeData for () {
192    type Out = ();
193
194    fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
195        Ok(self)
196    }
197
198    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
199        validate_fields(st.iter(), pb)
200    }
201
202    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
203        on_field(elem, (), pb, true)
204    }
205
206    fn on_map(self, elem: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
207        debug_assert!(pb.is_map_entry());
208        on_field(elem.key(), (), &pb.map_entry_key_field(), false)?;
209        on_field(elem.value(), (), &pb.map_entry_value_field(), false)?;
210        Ok(())
211    }
212}
213
214/// Nullability is not part of type system in proto.
215/// * Top level is always a message.
216/// * All message fields can be omitted in proto3.
217/// * All repeated elements must have a value.
218///
219/// So we handle [`ScalarRefImpl`] rather than [`DatumRef`] here.
220impl MaybeData for ScalarRefImpl<'_> {
221    type Out = Value;
222
223    fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result<Value>) -> Result<Self::Out> {
224        f(self)
225    }
226
227    fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result<Self::Out> {
228        let d = self.into_struct();
229        let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?;
230        Ok(Value::Message(message))
231    }
232
233    fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result<Self::Out> {
234        let d = self.into_list();
235        let vs = d
236            .iter()
237            .map(|d| {
238                on_field(
239                    elem,
240                    d.ok_or_else(|| {
241                        FieldEncodeError::new("array containing null not allowed as repeated field")
242                    })?,
243                    pb,
244                    true,
245                )
246            })
247            .try_collect()?;
248        Ok(Value::List(vs))
249    }
250
251    fn on_map(self, m: &MapType, pb: &MessageDescriptor) -> Result<Self::Out> {
252        debug_assert!(pb.is_map_entry());
253        let vs = self
254            .into_map()
255            .iter()
256            .map(|(k, v)| {
257                let v =
258                    v.ok_or_else(|| FieldEncodeError::new("map containing null not allowed"))?;
259                let k = on_field(m.key(), k, &pb.map_entry_key_field(), false)?;
260                let v = on_field(m.value(), v, &pb.map_entry_value_field(), false)?;
261                Ok((
262                    k.into_map_key().ok_or_else(|| {
263                        FieldEncodeError::new("failed to convert map key to proto")
264                    })?,
265                    v,
266                ))
267            })
268            .try_collect()?;
269        Ok(Value::Map(vs))
270    }
271}
272
273fn validate_fields<'a>(
274    fields: impl Iterator<Item = (&'a str, &'a DataType)>,
275    descriptor: &MessageDescriptor,
276) -> Result<()> {
277    for (name, t) in fields {
278        let Some(proto_field) = descriptor.get_field_by_name(name) else {
279            return Err(FieldEncodeError::new("field not in proto").with_name(name));
280        };
281        if proto_field.cardinality() == prost_reflect::Cardinality::Required {
282            return Err(FieldEncodeError::new("`required` not supported").with_name(name));
283        }
284        on_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?;
285    }
286    Ok(())
287}
288
289fn encode_fields<'a>(
290    fields_with_datums: impl Iterator<Item = ((&'a str, &'a DataType), DatumRef<'a>)>,
291    descriptor: &MessageDescriptor,
292) -> Result<DynamicMessage> {
293    let mut message = DynamicMessage::new(descriptor.clone());
294    for ((name, t), d) in fields_with_datums {
295        let proto_field = descriptor.get_field_by_name(name).unwrap();
296        // On `null`, simply skip setting the field.
297        if let Some(scalar) = d {
298            let value = on_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?;
299            message
300                .try_set_field(&proto_field, value)
301                .map_err(|e| FieldEncodeError::new(e).with_name(name))?;
302        }
303    }
304    Ok(message)
305}
306
// Full names of protobuf Well-Known Types that columns may be encoded into.

/// `google.protobuf.Timestamp`: matched as a target for `timestamptz` columns.
const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp";
/// `google.protobuf.BoolValue`: not used by any encoding rule yet.
#[expect(dead_code)]
const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue";
311
312/// Handles both `validate` (without actual data) and `encode`.
313/// See [`MaybeData`] for more info.
314fn on_field<D: MaybeData>(
315    data_type: &DataType,
316    maybe: D,
317    proto_field: &FieldDescriptor,
318    in_repeated: bool,
319) -> Result<D::Out> {
320    // Regarding (proto_field.is_list, in_repeated):
321    // (F, T) => impossible
322    // (F, F) => encoding to a non-repeated field
323    // (T, F) => encoding to a repeated field
324    // (T, T) => encoding to an element of a repeated field
325    // In the bottom 2 cases, we need to distinguish the same `proto_field` with the help of `in_repeated`.
326    assert!(proto_field.is_list() || !in_repeated);
327    let expect_list = proto_field.is_list() && !in_repeated;
328    if proto_field.is_group() {
329        return Err(FieldEncodeError::new("proto group not supported yet"));
330    }
331
332    let no_match_err = || {
333        Err(FieldEncodeError::new(format!(
334            "cannot encode {} column as {}{:?} field",
335            data_type,
336            if expect_list { "repeated " } else { "" },
337            proto_field.kind()
338        )))
339    };
340
341    if expect_list && !matches!(data_type, DataType::List(_)) {
342        return no_match_err();
343    }
344
345    let value = match &data_type {
346        // Group A: perfect match between RisingWave types and ProtoBuf types
347        DataType::Boolean => match proto_field.kind() {
348            Kind::Bool => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?,
349            _ => return no_match_err(),
350        },
351        DataType::Varchar => match proto_field.kind() {
352            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?,
353            Kind::Enum(enum_desc) => maybe.on_base(|s| {
354                let name = s.into_utf8();
355                let enum_value_desc = enum_desc.get_value_by_name(name).ok_or_else(|| {
356                    FieldEncodeError::new(format!("'{name}' not in enum {}", enum_desc.name()))
357                })?;
358                Ok(Value::EnumNumber(enum_value_desc.number()))
359            })?,
360            _ => return no_match_err(),
361        },
362        DataType::Bytea => match proto_field.kind() {
363            Kind::Bytes => {
364                maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))?
365            }
366            _ => return no_match_err(),
367        },
368        DataType::Float32 => match proto_field.kind() {
369            Kind::Float => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?,
370            _ => return no_match_err(),
371        },
372        DataType::Float64 => match proto_field.kind() {
373            Kind::Double => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?,
374            _ => return no_match_err(),
375        },
376        DataType::Int32 => match proto_field.kind() {
377            Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => {
378                maybe.on_base(|s| Ok(Value::I32(s.into_int32())))?
379            }
380            _ => return no_match_err(),
381        },
382        DataType::Int64 => match proto_field.kind() {
383            Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => {
384                maybe.on_base(|s| Ok(Value::I64(s.into_int64())))?
385            }
386            _ => return no_match_err(),
387        },
388        DataType::Struct(st) => match proto_field.kind() {
389            Kind::Message(pb) => maybe.on_struct(st, &pb)?,
390            _ => return no_match_err(),
391        },
392        DataType::List(lt) => match expect_list {
393            true => maybe.on_list(lt.elem(), proto_field)?,
394            false => return no_match_err(),
395        },
396        // Group B: match between RisingWave types and ProtoBuf Well-Known types
397        DataType::Timestamptz => match proto_field.kind() {
398            Kind::Message(pb) if pb.full_name() == WKT_TIMESTAMP => maybe.on_base(|s| {
399                let d = s.into_timestamptz();
400                let message = prost_types::Timestamp {
401                    seconds: d.timestamp(),
402                    nanos: d.timestamp_subsec_nanos().try_into().unwrap(),
403                };
404                Ok(Value::Message(message.transcode_to_dynamic()))
405            })?,
406            Kind::String => {
407                maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))?
408            }
409            _ => return no_match_err(),
410        },
411        DataType::Jsonb => match proto_field.kind() {
412            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))?,
413            _ => return no_match_err(), /* Value, NullValue, Struct (map), ListValue
414                                         * Group C: experimental */
415        },
416        DataType::Int16 => match proto_field.kind() {
417            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?,
418            _ => return no_match_err(),
419        },
420        DataType::Date => match proto_field.kind() {
421            Kind::Int32 => {
422                maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))?
423            }
424            _ => return no_match_err(), // google.type.Date
425        },
426        DataType::Time => match proto_field.kind() {
427            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))?,
428            _ => return no_match_err(), // google.type.TimeOfDay
429        },
430        DataType::Timestamp => match proto_field.kind() {
431            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))?,
432            _ => return no_match_err(), // google.type.DateTime
433        },
434        DataType::Decimal => match proto_field.kind() {
435            Kind::String => maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))?,
436            _ => return no_match_err(), // google.type.Decimal
437        },
438        DataType::Interval => match proto_field.kind() {
439            Kind::String => {
440                maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))?
441            }
442            _ => return no_match_err(), // Group D: unsupported
443        },
444        DataType::Serial => match proto_field.kind() {
445            Kind::Int64 => maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))?,
446            _ => return no_match_err(), // Group D: unsupported
447        },
448        DataType::Int256 => {
449            return no_match_err();
450        }
451        DataType::Map(map_type) => {
452            if proto_field.is_map() {
453                let msg = match proto_field.kind() {
454                    Kind::Message(m) => m,
455                    _ => return no_match_err(), // unreachable actually
456                };
457                return maybe.on_map(map_type, &msg);
458            } else {
459                return no_match_err();
460            }
461        }
462        DataType::Vector(_) => match expect_list {
463            true => maybe.on_list(&VECTOR_ITEM_TYPE, proto_field)?,
464            false => return no_match_err(),
465        },
466    };
467
468    Ok(value)
469}
470
471#[cfg(test)]
472mod tests {
473    use itertools::Itertools;
474    use risingwave_common::array::{ArrayBuilder, StructArrayBuilder};
475    use risingwave_common::catalog::Field;
476    use risingwave_common::row::OwnedRow;
477    use risingwave_common::types::{
478        ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Timestamptz,
479    };
480
481    use super::*;
482
483    #[test]
484    fn test_encode_proto_ok() {
485        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
486            .join("codec/tests/test_data/all-types.pb");
487        let pool_bytes = std::fs::read(pool_path).unwrap();
488        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
489        let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
490        let schema = Schema::new(vec![
491            Field::with_name(DataType::Boolean, "bool_field"),
492            Field::with_name(DataType::Varchar, "string_field"),
493            Field::with_name(DataType::Bytea, "bytes_field"),
494            Field::with_name(DataType::Float32, "float_field"),
495            Field::with_name(DataType::Float64, "double_field"),
496            Field::with_name(DataType::Int32, "int32_field"),
497            Field::with_name(DataType::Int64, "int64_field"),
498            Field::with_name(DataType::Int32, "sint32_field"),
499            Field::with_name(DataType::Int64, "sint64_field"),
500            Field::with_name(DataType::Int32, "sfixed32_field"),
501            Field::with_name(DataType::Int64, "sfixed64_field"),
502            Field::with_name(
503                DataType::Struct(StructType::new(vec![
504                    ("id", DataType::Int32),
505                    ("name", DataType::Varchar),
506                ])),
507                "nested_message_field",
508            ),
509            Field::with_name(DataType::Int32.list(), "repeated_int_field"),
510            Field::with_name(DataType::Timestamptz, "timestamp_field"),
511            Field::with_name(
512                DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)),
513                "map_field",
514            ),
515            Field::with_name(
516                DataType::Map(MapType::from_kv(
517                    DataType::Varchar,
518                    DataType::Struct(StructType::new(vec![
519                        ("id", DataType::Int32),
520                        ("name", DataType::Varchar),
521                    ])),
522                )),
523                "map_struct_field",
524            ),
525        ]);
526        let row = OwnedRow::new(vec![
527            Some(ScalarImpl::Bool(true)),
528            Some(ScalarImpl::Utf8("RisingWave".into())),
529            Some(ScalarImpl::Bytea([0xbe, 0xef].into())),
530            Some(ScalarImpl::Float32(3.5f32.into())),
531            Some(ScalarImpl::Float64(4.25f64.into())),
532            Some(ScalarImpl::Int32(22)),
533            Some(ScalarImpl::Int64(23)),
534            Some(ScalarImpl::Int32(24)),
535            None,
536            Some(ScalarImpl::Int32(26)),
537            Some(ScalarImpl::Int64(27)),
538            Some(ScalarImpl::Struct(StructValue::new(vec![
539                Some(ScalarImpl::Int32(1)),
540                Some(ScalarImpl::Utf8("".into())),
541            ]))),
542            Some(ScalarImpl::List(ListValue::from_iter([4, 0, 4]))),
543            Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))),
544            Some(ScalarImpl::Map(
545                MapValue::try_from_kv(
546                    ListValue::from_iter(["a", "b"]),
547                    ListValue::from_iter([1, 2]),
548                )
549                .unwrap(),
550            )),
551            {
552                let mut struct_array_builder = StructArrayBuilder::with_type(
553                    2,
554                    DataType::Struct(StructType::new(vec![
555                        ("id", DataType::Int32),
556                        ("name", DataType::Varchar),
557                    ])),
558                );
559                struct_array_builder.append(Some(
560                    StructValue::new(vec![
561                        Some(ScalarImpl::Int32(1)),
562                        Some(ScalarImpl::Utf8("x".into())),
563                    ])
564                    .as_scalar_ref(),
565                ));
566                struct_array_builder.append(Some(
567                    StructValue::new(vec![
568                        Some(ScalarImpl::Int32(2)),
569                        Some(ScalarImpl::Utf8("y".into())),
570                    ])
571                    .as_scalar_ref(),
572                ));
573                Some(ScalarImpl::Map(
574                    MapValue::try_from_kv(
575                        ListValue::from_iter(["a", "b"]),
576                        ListValue::new(struct_array_builder.finish().into()),
577                    )
578                    .unwrap(),
579                ))
580            },
581        ]);
582
583        let encoder = ProtoEncoder::new(schema, None, descriptor, ProtoHeader::None).unwrap();
584        let m = encoder.encode(row).unwrap();
585        expect_test::expect![[r#"
586            field: FieldDescriptor {
587                name: "double_field",
588                full_name: "all_types.AllTypes.double_field",
589                json_name: "doubleField",
590                number: 1,
591                kind: double,
592                cardinality: Optional,
593                containing_oneof: None,
594                default_value: F64(
595                    0.0,
596                ),
597                is_group: false,
598                is_list: false,
599                is_map: false,
600                is_packed: false,
601                supports_presence: false,
602            }
603
604            value: F64(4.25)
605
606            ==============================
607            field: FieldDescriptor {
608                name: "float_field",
609                full_name: "all_types.AllTypes.float_field",
610                json_name: "floatField",
611                number: 2,
612                kind: float,
613                cardinality: Optional,
614                containing_oneof: None,
615                default_value: F32(
616                    0.0,
617                ),
618                is_group: false,
619                is_list: false,
620                is_map: false,
621                is_packed: false,
622                supports_presence: false,
623            }
624
625            value: F32(3.5)
626
627            ==============================
628            field: FieldDescriptor {
629                name: "int32_field",
630                full_name: "all_types.AllTypes.int32_field",
631                json_name: "int32Field",
632                number: 3,
633                kind: int32,
634                cardinality: Optional,
635                containing_oneof: None,
636                default_value: I32(
637                    0,
638                ),
639                is_group: false,
640                is_list: false,
641                is_map: false,
642                is_packed: false,
643                supports_presence: false,
644            }
645
646            value: I32(22)
647
648            ==============================
649            field: FieldDescriptor {
650                name: "int64_field",
651                full_name: "all_types.AllTypes.int64_field",
652                json_name: "int64Field",
653                number: 4,
654                kind: int64,
655                cardinality: Optional,
656                containing_oneof: None,
657                default_value: I64(
658                    0,
659                ),
660                is_group: false,
661                is_list: false,
662                is_map: false,
663                is_packed: false,
664                supports_presence: false,
665            }
666
667            value: I64(23)
668
669            ==============================
670            field: FieldDescriptor {
671                name: "sint32_field",
672                full_name: "all_types.AllTypes.sint32_field",
673                json_name: "sint32Field",
674                number: 7,
675                kind: sint32,
676                cardinality: Optional,
677                containing_oneof: None,
678                default_value: I32(
679                    0,
680                ),
681                is_group: false,
682                is_list: false,
683                is_map: false,
684                is_packed: false,
685                supports_presence: false,
686            }
687
688            value: I32(24)
689
690            ==============================
691            field: FieldDescriptor {
692                name: "sfixed32_field",
693                full_name: "all_types.AllTypes.sfixed32_field",
694                json_name: "sfixed32Field",
695                number: 11,
696                kind: sfixed32,
697                cardinality: Optional,
698                containing_oneof: None,
699                default_value: I32(
700                    0,
701                ),
702                is_group: false,
703                is_list: false,
704                is_map: false,
705                is_packed: false,
706                supports_presence: false,
707            }
708
709            value: I32(26)
710
711            ==============================
712            field: FieldDescriptor {
713                name: "sfixed64_field",
714                full_name: "all_types.AllTypes.sfixed64_field",
715                json_name: "sfixed64Field",
716                number: 12,
717                kind: sfixed64,
718                cardinality: Optional,
719                containing_oneof: None,
720                default_value: I64(
721                    0,
722                ),
723                is_group: false,
724                is_list: false,
725                is_map: false,
726                is_packed: false,
727                supports_presence: false,
728            }
729
730            value: I64(27)
731
732            ==============================
733            field: FieldDescriptor {
734                name: "bool_field",
735                full_name: "all_types.AllTypes.bool_field",
736                json_name: "boolField",
737                number: 13,
738                kind: bool,
739                cardinality: Optional,
740                containing_oneof: None,
741                default_value: Bool(
742                    false,
743                ),
744                is_group: false,
745                is_list: false,
746                is_map: false,
747                is_packed: false,
748                supports_presence: false,
749            }
750
751            value: Bool(true)
752
753            ==============================
754            field: FieldDescriptor {
755                name: "string_field",
756                full_name: "all_types.AllTypes.string_field",
757                json_name: "stringField",
758                number: 14,
759                kind: string,
760                cardinality: Optional,
761                containing_oneof: None,
762                default_value: String(
763                    "",
764                ),
765                is_group: false,
766                is_list: false,
767                is_map: false,
768                is_packed: false,
769                supports_presence: false,
770            }
771
772            value: String("RisingWave")
773
774            ==============================
775            field: FieldDescriptor {
776                name: "bytes_field",
777                full_name: "all_types.AllTypes.bytes_field",
778                json_name: "bytesField",
779                number: 15,
780                kind: bytes,
781                cardinality: Optional,
782                containing_oneof: None,
783                default_value: Bytes(
784                    b"",
785                ),
786                is_group: false,
787                is_list: false,
788                is_map: false,
789                is_packed: false,
790                supports_presence: false,
791            }
792
793            value: Bytes(b"\xbe\xef")
794
795            ==============================
796            field: FieldDescriptor {
797                name: "nested_message_field",
798                full_name: "all_types.AllTypes.nested_message_field",
799                json_name: "nestedMessageField",
800                number: 17,
801                kind: all_types.AllTypes.NestedMessage,
802                cardinality: Optional,
803                containing_oneof: None,
804                default_value: Message(
805                    DynamicMessage {
806                        desc: MessageDescriptor {
807                            name: "NestedMessage",
808                            full_name: "all_types.AllTypes.NestedMessage",
809                            is_map_entry: false,
810                            fields: [
811                                FieldDescriptor {
812                                    name: "id",
813                                    full_name: "all_types.AllTypes.NestedMessage.id",
814                                    json_name: "id",
815                                    number: 1,
816                                    kind: int32,
817                                    cardinality: Optional,
818                                    containing_oneof: None,
819                                    default_value: I32(
820                                        0,
821                                    ),
822                                    is_group: false,
823                                    is_list: false,
824                                    is_map: false,
825                                    is_packed: false,
826                                    supports_presence: false,
827                                },
828                                FieldDescriptor {
829                                    name: "name",
830                                    full_name: "all_types.AllTypes.NestedMessage.name",
831                                    json_name: "name",
832                                    number: 2,
833                                    kind: string,
834                                    cardinality: Optional,
835                                    containing_oneof: None,
836                                    default_value: String(
837                                        "",
838                                    ),
839                                    is_group: false,
840                                    is_list: false,
841                                    is_map: false,
842                                    is_packed: false,
843                                    supports_presence: false,
844                                },
845                            ],
846                            oneofs: [],
847                        },
848                        fields: DynamicMessageFieldSet {
849                            fields: {},
850                        },
851                    },
852                ),
853                is_group: false,
854                is_list: false,
855                is_map: false,
856                is_packed: false,
857                supports_presence: true,
858            }
859
860            value: Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: String(""), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String(""))} } })
861
862            ==============================
863            field: FieldDescriptor {
864                name: "repeated_int_field",
865                full_name: "all_types.AllTypes.repeated_int_field",
866                json_name: "repeatedIntField",
867                number: 18,
868                kind: int32,
869                cardinality: Repeated,
870                containing_oneof: None,
871                default_value: List(
872                    [],
873                ),
874                is_group: false,
875                is_list: true,
876                is_map: false,
877                is_packed: true,
878                supports_presence: false,
879            }
880
881            value: List([I32(4), I32(0), I32(4)])
882
883            ==============================
884            field: FieldDescriptor {
885                name: "map_field",
886                full_name: "all_types.AllTypes.map_field",
887                json_name: "mapField",
888                number: 22,
889                kind: all_types.AllTypes.MapFieldEntry,
890                cardinality: Repeated,
891                containing_oneof: None,
892                default_value: Map(
893                    {},
894                ),
895                is_group: false,
896                is_list: false,
897                is_map: true,
898                is_packed: false,
899                supports_presence: false,
900            }
901
902            value: Map({
903                String("a"): I32(1),
904                String("b"): I32(2),
905            })
906
907            ==============================
908            field: FieldDescriptor {
909                name: "timestamp_field",
910                full_name: "all_types.AllTypes.timestamp_field",
911                json_name: "timestampField",
912                number: 23,
913                kind: google.protobuf.Timestamp,
914                cardinality: Optional,
915                containing_oneof: None,
916                default_value: Message(
917                    DynamicMessage {
918                        desc: MessageDescriptor {
919                            name: "Timestamp",
920                            full_name: "google.protobuf.Timestamp",
921                            is_map_entry: false,
922                            fields: [
923                                FieldDescriptor {
924                                    name: "seconds",
925                                    full_name: "google.protobuf.Timestamp.seconds",
926                                    json_name: "seconds",
927                                    number: 1,
928                                    kind: int64,
929                                    cardinality: Optional,
930                                    containing_oneof: None,
931                                    default_value: I64(
932                                        0,
933                                    ),
934                                    is_group: false,
935                                    is_list: false,
936                                    is_map: false,
937                                    is_packed: false,
938                                    supports_presence: false,
939                                },
940                                FieldDescriptor {
941                                    name: "nanos",
942                                    full_name: "google.protobuf.Timestamp.nanos",
943                                    json_name: "nanos",
944                                    number: 2,
945                                    kind: int32,
946                                    cardinality: Optional,
947                                    containing_oneof: None,
948                                    default_value: I32(
949                                        0,
950                                    ),
951                                    is_group: false,
952                                    is_list: false,
953                                    is_map: false,
954                                    is_packed: false,
955                                    supports_presence: false,
956                                },
957                            ],
958                            oneofs: [],
959                        },
960                        fields: DynamicMessageFieldSet {
961                            fields: {},
962                        },
963                    },
964                ),
965                is_group: false,
966                is_list: false,
967                is_map: false,
968                is_packed: false,
969                supports_presence: true,
970            }
971
972            value: Message(DynamicMessage { desc: MessageDescriptor { name: "Timestamp", full_name: "google.protobuf.Timestamp", is_map_entry: false, fields: [FieldDescriptor { name: "seconds", full_name: "google.protobuf.Timestamp.seconds", json_name: "seconds", number: 1, kind: int64, cardinality: Optional, containing_oneof: None, default_value: I64(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "nanos", full_name: "google.protobuf.Timestamp.nanos", json_name: "nanos", number: 2, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {2: Value(I32(3000))} } })
973
974            ==============================
975            field: FieldDescriptor {
976                name: "map_struct_field",
977                full_name: "all_types.AllTypes.map_struct_field",
978                json_name: "mapStructField",
979                number: 29,
980                kind: all_types.AllTypes.MapStructFieldEntry,
981                cardinality: Repeated,
982                containing_oneof: None,
983                default_value: Map(
984                    {},
985                ),
986                is_group: false,
987                is_list: false,
988                is_map: true,
989                is_packed: false,
990                supports_presence: false,
991            }
992
993            value: Map({
994                String("a"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: String(""), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(1)), 2: Value(String("x"))} } }),
995                String("b"): Message(DynamicMessage { desc: MessageDescriptor { name: "NestedMessage", full_name: "all_types.AllTypes.NestedMessage", is_map_entry: false, fields: [FieldDescriptor { name: "id", full_name: "all_types.AllTypes.NestedMessage.id", json_name: "id", number: 1, kind: int32, cardinality: Optional, containing_oneof: None, default_value: I32(0), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }, FieldDescriptor { name: "name", full_name: "all_types.AllTypes.NestedMessage.name", json_name: "name", number: 2, kind: string, cardinality: Optional, containing_oneof: None, default_value: String(""), is_group: false, is_list: false, is_map: false, is_packed: false, supports_presence: false }], oneofs: [] }, fields: DynamicMessageFieldSet { fields: {1: Value(I32(2)), 2: Value(String("y"))} } }),
996            })"#]].assert_eq(&format!("{}",
997            m.message.fields().format_with("\n\n==============================\n", |(field,value),f| {
998            f(&format!("field: {:#?}\n\nvalue: {}", field, print_proto(value)))
999        })));
1000    }
1001
1002    fn print_proto(value: &Value) -> String {
1003        match value {
1004            Value::Map(m) => {
1005                let mut res = String::new();
1006                res.push_str("Map({\n");
1007                for (k, v) in m.iter().sorted_by_key(|(k, _v)| *k) {
1008                    res.push_str(&format!(
1009                        "    {}: {},\n",
1010                        print_proto(&k.clone().into()),
1011                        print_proto(v)
1012                    ));
1013                }
1014                res.push_str("})");
1015                res
1016            }
1017            _ => format!("{:?}", value),
1018        }
1019    }
1020
1021    #[test]
1022    fn test_encode_proto_repeated() {
1023        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
1024            .join("codec/tests/test_data/all-types.pb");
1025        let pool_bytes = fs_err::read(pool_path).unwrap();
1026        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
1027        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
1028
1029        let schema = Schema::new(vec![Field::with_name(
1030            DataType::Int32.list().list(),
1031            "repeated_int_field",
1032        )]);
1033
1034        let err = validate_fields(
1035            schema
1036                .fields
1037                .iter()
1038                .map(|f| (f.name.as_str(), &f.data_type)),
1039            &message_descriptor,
1040        )
1041        .unwrap_err();
1042        assert_eq!(
1043            err.to_string(),
1044            "encode 'repeated_int_field' error: cannot encode integer[] column as int32 field"
1045        );
1046
1047        let schema = Schema::new(vec![Field::with_name(
1048            DataType::Int32.list(),
1049            "repeated_int_field",
1050        )]);
1051        let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::from_iter([
1052            Some(0),
1053            None,
1054            Some(2),
1055            Some(3),
1056        ])))]);
1057
1058        let err = encode_fields(
1059            schema
1060                .fields
1061                .iter()
1062                .map(|f| (f.name.as_str(), &f.data_type))
1063                .zip_eq_debug(row.iter()),
1064            &message_descriptor,
1065        )
1066        .unwrap_err();
1067        assert_eq!(
1068            err.to_string(),
1069            "encode 'repeated_int_field' error: array containing null not allowed as repeated field"
1070        );
1071    }
1072
1073    #[test]
1074    fn test_encode_proto_err() {
1075        let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
1076            .join("codec/tests/test_data/all-types.pb");
1077        let pool_bytes = std::fs::read(pool_path).unwrap();
1078        let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap();
1079        let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap();
1080
1081        let err = validate_fields(
1082            std::iter::once(("not_exists", &DataType::Int16)),
1083            &message_descriptor,
1084        )
1085        .unwrap_err();
1086        assert_eq!(
1087            err.to_string(),
1088            "encode 'not_exists' error: field not in proto"
1089        );
1090
1091        let err = validate_fields(
1092            std::iter::once(("map_field", &DataType::Jsonb)),
1093            &message_descriptor,
1094        )
1095        .unwrap_err();
1096        assert_eq!(
1097            err.to_string(),
1098            "encode 'map_field' error: cannot encode jsonb column as all_types.AllTypes.MapFieldEntry field"
1099        );
1100    }
1101}