risingwave_connector/parser/avro/
glue_resolver.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::sync::Arc;
17
18use anyhow::Context;
19use apache_avro::Schema;
20use aws_sdk_glue::Client;
21use aws_sdk_glue::types::{SchemaId, SchemaVersionNumber};
22use moka::future::Cache;
23
24use crate::connector_common::AwsAuthProps;
25use crate::error::ConnectorResult;
26
27/// Fetch schemas from AWS Glue schema registry and cache them.
28///
29/// Background: This is mainly used for Avro **writer schema** (during schema evolution): When decoding an Avro message,
30/// we must get the message's schema id, and use the *exactly same schema* to decode the message, and then
31/// convert it with the reader schema. (This is also why Avro has to be used with a schema registry instead of a static schema file.)
32///
33/// TODO: support protobuf (not sure if it's needed)
34pub trait GlueSchemaCache {
35    /// Gets the a specific schema by id, which is used as *writer schema*.
36    async fn get_by_id(&self, schema_version_id: uuid::Uuid) -> ConnectorResult<Arc<Schema>>;
37    /// Gets the latest schema by arn, which is used as *reader schema*.
38    async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult<Arc<Schema>>;
39}
40
41#[derive(Debug)]
42pub enum GlueSchemaCacheImpl {
43    Real(RealGlueSchemaCache),
44    Mock(MockGlueSchemaCache),
45}
46
47impl GlueSchemaCacheImpl {
48    pub async fn new(
49        aws_auth_props: &AwsAuthProps,
50        mock_config: Option<&str>,
51    ) -> ConnectorResult<Self> {
52        if let Some(mock_config) = mock_config {
53            return Ok(Self::Mock(MockGlueSchemaCache::new(mock_config)));
54        }
55        Ok(Self::Real(RealGlueSchemaCache::new(aws_auth_props).await?))
56    }
57}
58
59impl GlueSchemaCache for GlueSchemaCacheImpl {
60    async fn get_by_id(&self, schema_version_id: uuid::Uuid) -> ConnectorResult<Arc<Schema>> {
61        match self {
62            Self::Real(inner) => inner.get_by_id(schema_version_id).await,
63            Self::Mock(inner) => inner.get_by_id(schema_version_id).await,
64        }
65    }
66
67    async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult<Arc<Schema>> {
68        match self {
69            Self::Real(inner) => inner.get_by_name(schema_arn).await,
70            Self::Mock(inner) => inner.get_by_name(schema_arn).await,
71        }
72    }
73}
74
75#[derive(Debug)]
76pub struct RealGlueSchemaCache {
77    writer_schemas: Cache<uuid::Uuid, Arc<Schema>>,
78    glue_client: Client,
79}
80
81impl RealGlueSchemaCache {
82    /// Create a new `GlueSchemaCache`
83    pub async fn new(aws_auth_props: &AwsAuthProps) -> ConnectorResult<Self> {
84        let client = Client::new(&aws_auth_props.build_config().await?);
85        Ok(Self {
86            writer_schemas: Cache::new(u64::MAX),
87            glue_client: client,
88        })
89    }
90
91    async fn parse_and_cache_schema(
92        &self,
93        schema_version_id: uuid::Uuid,
94        content: &str,
95    ) -> ConnectorResult<Arc<Schema>> {
96        let schema = Schema::parse_str(content).context("failed to parse avro schema")?;
97        let schema = Arc::new(schema);
98        self.writer_schemas
99            .insert(schema_version_id, Arc::clone(&schema))
100            .await;
101        Ok(schema)
102    }
103}
104
105impl GlueSchemaCache for RealGlueSchemaCache {
106    /// Gets the a specific schema by id, which is used as *writer schema*.
107    async fn get_by_id(&self, schema_version_id: uuid::Uuid) -> ConnectorResult<Arc<Schema>> {
108        if let Some(schema) = self.writer_schemas.get(&schema_version_id).await {
109            return Ok(schema);
110        }
111        let res = self
112            .glue_client
113            .get_schema_version()
114            .schema_version_id(schema_version_id)
115            .send()
116            .await
117            .context("glue sdk error")?;
118        let definition = res
119            .schema_definition()
120            .context("glue sdk response without definition")?;
121        self.parse_and_cache_schema(schema_version_id, definition)
122            .await
123    }
124
125    /// Gets the latest schema by arn, which is used as *reader schema*.
126    async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult<Arc<Schema>> {
127        let res = self
128            .glue_client
129            .get_schema_version()
130            .schema_id(SchemaId::builder().schema_arn(schema_arn).build())
131            .schema_version_number(SchemaVersionNumber::builder().latest_version(true).build())
132            .send()
133            .await
134            .context("glue sdk error")?;
135        let schema_version_id = res
136            .schema_version_id()
137            .context("glue sdk response without schema version id")?
138            .parse()
139            .context("glue sdk response invalid schema version id")?;
140        let definition = res
141            .schema_definition()
142            .context("glue sdk response without definition")?;
143        self.parse_and_cache_schema(schema_version_id, definition)
144            .await
145    }
146}
147
148#[derive(Debug)]
149pub struct MockGlueSchemaCache {
150    by_id: HashMap<uuid::Uuid, Arc<Schema>>,
151    arn_to_latest_id: HashMap<String, uuid::Uuid>,
152}
153
154impl MockGlueSchemaCache {
155    pub fn new(mock_config: &str) -> Self {
156        // The `mock_config` accepted is a JSON that looks like:
157        // {
158        //   "by_id": {
159        //     "4dc80ccf-2d0c-4846-9325-7e1c9e928121": {
160        //       "type": "record",
161        //       "name": "MyEvent",
162        //       "fields": [...]
163        //     },
164        //     "3df022f4-b16d-4afe-bdf7-cf4baf8d01d3": {
165        //       ...
166        //     }
167        //   },
168        //   "arn_to_latest_id": {
169        //     "arn:aws:glue:ap-southeast-1:123456123456:schema/default-registry/MyEvent": "3df022f4-b16d-4afe-bdf7-cf4baf8d01d3"
170        //   }
171        // }
172        //
173        // The format is not public and we can make breaking changes to it.
174        // Current format only supports avsc.
175        let parsed: serde_json::Value =
176            serde_json::from_str(mock_config).expect("mock config shall be valid json");
177        let by_id = parsed
178            .get("by_id")
179            .unwrap()
180            .as_object()
181            .unwrap()
182            .iter()
183            .map(|(schema_version_id, schema)| {
184                let schema_version_id = schema_version_id.parse().unwrap();
185                let schema = Schema::parse(schema).unwrap();
186                (schema_version_id, Arc::new(schema))
187            })
188            .collect();
189        let arn_to_latest_id = parsed
190            .get("arn_to_latest_id")
191            .unwrap()
192            .as_object()
193            .unwrap()
194            .iter()
195            .map(|(arn, latest_id)| (arn.clone(), latest_id.as_str().unwrap().parse().unwrap()))
196            .collect();
197        Self {
198            by_id,
199            arn_to_latest_id,
200        }
201    }
202}
203
204impl GlueSchemaCache for MockGlueSchemaCache {
205    async fn get_by_id(&self, schema_version_id: uuid::Uuid) -> ConnectorResult<Arc<Schema>> {
206        Ok(self
207            .by_id
208            .get(&schema_version_id)
209            .context("schema version id not found in mock registry")?
210            .clone())
211    }
212
213    async fn get_by_name(&self, schema_arn: &str) -> ConnectorResult<Arc<Schema>> {
214        let schema_version_id = self
215            .arn_to_latest_id
216            .get(schema_arn)
217            .context("schema arn not found in mock registry")?;
218        self.get_by_id(*schema_version_id).await
219    }
220}