risingwave_connector_codec/decoder/avro/
schema.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::{Arc, LazyLock};
16
17use anyhow::Context;
18use apache_avro::AvroResult;
19use apache_avro::schema::{DecimalSchema, NamesRef, RecordSchema, ResolvedSchema, Schema};
20use itertools::Itertools;
21use risingwave_common::catalog::Field;
22use risingwave_common::error::NotImplemented;
23use risingwave_common::log::LogSuppressor;
24use risingwave_common::types::{DataType, Decimal, MapType, StructType};
25use risingwave_common::{bail, bail_not_implemented};
26
27use super::get_nullable_union_inner;
28
29/// Avro schema with `Ref` inlined. The newtype is used to indicate whether the schema is resolved.
30///
31/// TODO: Actually most of the place should use resolved schema, but currently they just happen to work (Some edge cases are not met yet).
32///
33/// TODO: refactor avro lib to use the feature there.
34#[derive(Debug)]
35pub struct ResolvedAvroSchema {
36    /// Should be used for parsing bytes into Avro value
37    pub original_schema: Arc<Schema>,
38}
39
40impl ResolvedAvroSchema {
41    pub fn create(schema: Arc<Schema>) -> AvroResult<Self> {
42        Ok(Self {
43            original_schema: schema,
44        })
45    }
46}
47
48/// How to convert the map type from the input encoding to RisingWave's datatype.
49///
50/// XXX: Should this be `avro.map.handling.mode`? Can it be shared between Avro and Protobuf?
51#[derive(Debug, Copy, Clone)]
52pub enum MapHandling {
53    Jsonb,
54    Map,
55}
56
57impl MapHandling {
58    pub const OPTION_KEY: &'static str = "map.handling.mode";
59
60    pub fn from_options(
61        options: &std::collections::BTreeMap<String, String>,
62    ) -> anyhow::Result<Option<Self>> {
63        let mode = match options.get(Self::OPTION_KEY).map(std::ops::Deref::deref) {
64            Some("jsonb") => Self::Jsonb,
65            Some("map") => Self::Map,
66            Some(v) => bail!("unrecognized {} value {}", Self::OPTION_KEY, v),
67            None => return Ok(None),
68        };
69        Ok(Some(mode))
70    }
71}
72
73/// This function expects original schema (with `Ref`).
74/// TODO: change `map_handling` to some `Config`, and also unify debezium.
75pub fn avro_schema_to_fields(
76    schema: &Schema,
77    map_handling: Option<MapHandling>,
78) -> anyhow::Result<Vec<Field>> {
79    let resolved = ResolvedSchema::try_from(schema)?;
80    let mut ancestor_records: Vec<String> = vec![];
81    let root_type = avro_type_mapping(
82        schema,
83        &mut ancestor_records,
84        resolved.get_names(),
85        map_handling,
86    )?;
87    let DataType::Struct(root_struct) = root_type else {
88        bail!("schema invalid, record type required at top level of the schema.");
89    };
90    let fields = root_struct
91        .iter()
92        .map(|(name, data_type)| Field::new(name, data_type.clone()))
93        .collect();
94    Ok(fields)
95}
96
97const DBZ_VARIABLE_SCALE_DECIMAL_NAME: &str = "VariableScaleDecimal";
98const DBZ_VARIABLE_SCALE_DECIMAL_NAMESPACE: &str = "io.debezium.data";
99
100/// This function expects original schema (with `Ref`).
101fn avro_type_mapping(
102    schema: &Schema,
103    ancestor_records: &mut Vec<String>,
104    refs: &NamesRef<'_>,
105    map_handling: Option<MapHandling>,
106) -> anyhow::Result<DataType> {
107    let data_type = match schema {
108        Schema::String => DataType::Varchar,
109        Schema::Int => DataType::Int32,
110        Schema::Long => DataType::Int64,
111        Schema::Boolean => DataType::Boolean,
112        Schema::Float => DataType::Float32,
113        Schema::Double => DataType::Float64,
114        Schema::Decimal(DecimalSchema { precision, .. }) => {
115            if *precision > Decimal::MAX_PRECISION.into() {
116                static LOG_SUPPRESSOR: LazyLock<LogSuppressor> =
117                    LazyLock::new(LogSuppressor::default);
118                if let Ok(suppressed_count) = LOG_SUPPRESSOR.check() {
119                    tracing::warn!(
120                        suppressed_count,
121                        "RisingWave supports decimal precision up to {}, but got {}. Will truncate.",
122                        Decimal::MAX_PRECISION,
123                        precision
124                    );
125                }
126            }
127            DataType::Decimal
128        }
129        Schema::Date => DataType::Date,
130        Schema::LocalTimestampMillis => DataType::Timestamp,
131        Schema::LocalTimestampMicros => DataType::Timestamp,
132        Schema::TimestampMillis => DataType::Timestamptz,
133        Schema::TimestampMicros => DataType::Timestamptz,
134        Schema::Duration => DataType::Interval,
135        Schema::Bytes => DataType::Bytea,
136        Schema::Enum { .. } => DataType::Varchar,
137        Schema::TimeMillis => DataType::Time,
138        Schema::TimeMicros => DataType::Time,
139        Schema::Record(RecordSchema { fields, name, .. }) => {
140            if name.name == DBZ_VARIABLE_SCALE_DECIMAL_NAME
141                && name.namespace == Some(DBZ_VARIABLE_SCALE_DECIMAL_NAMESPACE.into())
142            {
143                return Ok(DataType::Decimal);
144            }
145
146            let unique_name = name.fullname(None);
147            if ancestor_records.contains(&unique_name) {
148                bail!(
149                    "circular reference detected in Avro schema: {} -> {}",
150                    ancestor_records.join(" -> "),
151                    unique_name
152                );
153            }
154
155            ancestor_records.push(unique_name);
156            let ty = StructType::new(
157                fields
158                    .iter()
159                    .map(|f| {
160                        Ok((
161                            &f.name,
162                            avro_type_mapping(&f.schema, ancestor_records, refs, map_handling)?,
163                        ))
164                    })
165                    .collect::<anyhow::Result<Vec<_>>>()?,
166            )
167            .into();
168            ancestor_records.pop();
169            ty
170        }
171        Schema::Array(array_schema) => {
172            let item_schema = &array_schema.items;
173            let item_type =
174                avro_type_mapping(item_schema.as_ref(), ancestor_records, refs, map_handling)?;
175            DataType::list(item_type)
176        }
177        Schema::Union(union_schema) => {
178            // Note: Unions may not immediately contain other unions. So a `null` must represent a top-level null.
179            // e.g., ["null", ["null", "string"]] is not allowed
180
181            // Note: Unions may not contain more than one schema with the same type, except for the named types record, fixed and enum.
182            // https://avro.apache.org/docs/1.11.1/specification/_print/#unions
183            debug_assert!(
184                union_schema
185                    .variants()
186                    .iter()
187                    .map(Schema::canonical_form) // Schema doesn't implement Eq, but only PartialEq.
188                    .duplicates()
189                    .next()
190                    .is_none(),
191                "Union contains duplicate types: {union_schema:?}",
192            );
193            match get_nullable_union_inner(union_schema) {
194                Some(inner) => avro_type_mapping(inner, ancestor_records, refs, map_handling)?,
195                None => {
196                    // Convert the union to a struct, each field of the struct represents a variant of the union.
197                    // Refer to https://github.com/risingwavelabs/risingwave/issues/16273#issuecomment-2179761345 to see why it's not perfect.
198                    // Note: Avro union's variant tag is type name, not field name (unlike Rust enum, or Protobuf oneof).
199
200                    // XXX: do we need to introduce union.handling.mode?
201                    let fields = union_schema
202                        .variants()
203                        .iter()
204                        // null will mean the whole struct is null
205                        .filter(|variant| !matches!(variant, &&Schema::Null))
206                        .map(|variant| {
207                            avro_type_mapping(variant, ancestor_records, refs, map_handling)
208                                .and_then(|t| {
209                                    let name = avro_schema_to_struct_field_name(variant)?;
210                                    Ok((name, t))
211                                })
212                        })
213                        .try_collect::<_, Vec<_>, _>()
214                        .context("failed to convert Avro union to struct")?;
215
216                    StructType::new(fields).into()
217                }
218            }
219        }
220        Schema::Ref { name } => {
221            if name.name == DBZ_VARIABLE_SCALE_DECIMAL_NAME
222                && name.namespace == Some(DBZ_VARIABLE_SCALE_DECIMAL_NAMESPACE.into())
223            {
224                DataType::Decimal
225            } else {
226                avro_type_mapping(
227                    refs[name], // `ResolvedSchema::try_from` already handles lookup failure
228                    ancestor_records,
229                    refs,
230                    map_handling,
231                )?
232            }
233        }
234        Schema::Map(map_schema) => {
235            let value_schema = &map_schema.types;
236            // TODO: support native map type
237            match map_handling {
238                Some(MapHandling::Jsonb) => {
239                    if supported_avro_to_json_type(value_schema) {
240                        DataType::Jsonb
241                    } else {
242                        bail_not_implemented!(
243                            issue = 16963,
244                            "Avro map type to jsonb: {:?}",
245                            schema
246                        );
247                    }
248                }
249                Some(MapHandling::Map) | None => {
250                    let value = avro_type_mapping(
251                        value_schema.as_ref(),
252                        ancestor_records,
253                        refs,
254                        map_handling,
255                    )
256                    .context("failed to convert Avro map type")?;
257                    DataType::Map(MapType::from_kv(DataType::Varchar, value))
258                }
259            }
260        }
261        Schema::Uuid => DataType::Varchar,
262        Schema::Null
263        | Schema::BigDecimal
264        | Schema::TimestampNanos
265        | Schema::LocalTimestampNanos
266        | Schema::Fixed(_) => {
267            bail_not_implemented!("Avro type: {:?}", schema);
268        }
269    };
270
271    Ok(data_type)
272}
273
274/// Check for [`super::avro_to_jsonb`]
275fn supported_avro_to_json_type(schema: &Schema) -> bool {
276    match schema {
277        Schema::Null | Schema::Boolean | Schema::Int | Schema::String => true,
278
279        Schema::Map(map_schema) => supported_avro_to_json_type(&map_schema.types),
280        Schema::Array(array_schema) => supported_avro_to_json_type(&array_schema.items),
281        Schema::Record(RecordSchema { fields, .. }) => fields
282            .iter()
283            .all(|f| supported_avro_to_json_type(&f.schema)),
284        Schema::Long
285        | Schema::Float
286        | Schema::Double
287        | Schema::Bytes
288        | Schema::Enum(_)
289        | Schema::Fixed(_)
290        | Schema::Decimal(_)
291        | Schema::BigDecimal
292        | Schema::Uuid
293        | Schema::Date
294        | Schema::TimeMillis
295        | Schema::TimeMicros
296        | Schema::TimestampMillis
297        | Schema::TimestampMicros
298        | Schema::TimestampNanos
299        | Schema::LocalTimestampMillis
300        | Schema::LocalTimestampMicros
301        | Schema::LocalTimestampNanos
302        | Schema::Duration
303        | Schema::Ref { name: _ }
304        | Schema::Union(_) => false,
305    }
306}
307
308/// The field name when converting Avro union type to RisingWave struct type.
309pub(super) fn avro_schema_to_struct_field_name(schema: &Schema) -> Result<String, NotImplemented> {
310    Ok(match schema {
311        Schema::Null => unreachable!(),
312        Schema::Union(_) => unreachable!(),
313        // Primitive types
314        Schema::Boolean => "boolean".to_owned(),
315        Schema::Int => "int".to_owned(),
316        Schema::Long => "long".to_owned(),
317        Schema::Float => "float".to_owned(),
318        Schema::Double => "double".to_owned(),
319        Schema::Bytes => "bytes".to_owned(),
320        Schema::String => "string".to_owned(),
321        // Unnamed Complex types
322        Schema::Array(_) => "array".to_owned(),
323        Schema::Map(_) => "map".to_owned(),
324        // Named Complex types
325        Schema::Enum(_) | Schema::Ref { name: _ } | Schema::Fixed(_) | Schema::Record(_) => {
326            // schema.name().unwrap().fullname(None)
327            // See test_avro_lib_union_record_bug
328            // https://github.com/risingwavelabs/risingwave/issues/17632
329            bail_not_implemented!(issue=17632, "Avro named type used in Union type: {:?}", schema)
330
331        }
332
333        // Logical types are currently banned. See https://github.com/risingwavelabs/risingwave/issues/17616
334
335/*
336        Schema::Uuid => "uuid".to_string(),
337        // Decimal is the most tricky. https://avro.apache.org/docs/1.11.1/specification/_print/#decimal
338        // - A decimal logical type annotates Avro bytes _or_ fixed types.
339        // - It has attributes `precision` and `scale`.
340        //  "For the purposes of schema resolution, two schemas that are decimal logical types match if their scales and precisions match."
341        // - When the physical type is fixed, it's a named type. And a schema containing 2 decimals is possible:
342        //   [
343        //     {"type":"fixed","name":"Decimal128","size":16,"logicalType":"decimal","precision":38,"scale":2},
344        //     {"type":"fixed","name":"Decimal256","size":32,"logicalType":"decimal","precision":50,"scale":2}
345        //   ]
346        //   In this case (a logical type's physical type is a named type), perhaps we should use the physical type's `name`.
347        Schema::Decimal(_) => "decimal".to_string(),
348        Schema::Date => "date".to_string(),
349        // Note: in Avro, the name style is "time-millis", etc.
350        // But in RisingWave (Postgres), it will require users to use quotes, i.e.,
351        // select (struct)."time-millis", (struct).time_millies from t;
352        // The latter might be more user-friendly.
353        Schema::TimeMillis => "time_millis".to_string(),
354        Schema::TimeMicros => "time_micros".to_string(),
355        Schema::TimestampMillis => "timestamp_millis".to_string(),
356        Schema::TimestampMicros => "timestamp_micros".to_string(),
357        Schema::LocalTimestampMillis => "local_timestamp_millis".to_string(),
358        Schema::LocalTimestampMicros => "local_timestamp_micros".to_string(),
359        Schema::Duration => "duration".to_string(),
360*/
361        Schema::Uuid
362        | Schema::Decimal(_)
363        | Schema::BigDecimal
364        | Schema::Date
365        | Schema::TimeMillis
366        | Schema::TimeMicros
367        | Schema::TimestampMillis
368        | Schema::TimestampMicros
369        | Schema::TimestampNanos
370        | Schema::LocalTimestampMillis
371        | Schema::LocalTimestampMicros
372        | Schema::LocalTimestampNanos
373        | Schema::Duration => {
374            bail_not_implemented!(issue=17616, "Avro logicalType used in Union type: {:?}", schema)
375        }
376    })
377}