risingwave_connector/schema/
protobuf.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16
17use anyhow::Context as _;
18use prost_reflect::{DescriptorPool, FileDescriptor, MessageDescriptor};
19use prost_types::FileDescriptorSet;
20use risingwave_connector_codec::common::protobuf::compile_pb;
21
22use super::loader::{LoadedSchema, SchemaLoader};
23use super::schema_registry::Subject;
24use super::{
25    InvalidOptionError, MESSAGE_NAME_KEY, SCHEMA_LOCATION_KEY, SCHEMA_REGISTRY_KEY,
26    SchemaFetchError, invalid_option_error,
27};
28use crate::connector_common::AwsAuthProps;
29use crate::parser::{EncodingProperties, ProtobufParserConfig, ProtobufProperties};
30
31/// `aws_auth_props` is only required when reading `s3://` URL.
32pub async fn fetch_descriptor(
33    format_options: &BTreeMap<String, String>,
34    topic: &str,
35    aws_auth_props: Option<&AwsAuthProps>,
36) -> Result<(MessageDescriptor, Option<i32>), SchemaFetchError> {
37    let message_name = format_options
38        .get(MESSAGE_NAME_KEY)
39        .ok_or_else(|| invalid_option_error!("{MESSAGE_NAME_KEY} required"))?
40        .clone();
41    let schema_location = format_options.get(SCHEMA_LOCATION_KEY);
42    let schema_registry = format_options.get(SCHEMA_REGISTRY_KEY);
43    let row_schema_location = match (schema_location, schema_registry) {
44        (Some(_), Some(_)) => {
45            return Err(invalid_option_error!(
46                "cannot use {SCHEMA_LOCATION_KEY} and {SCHEMA_REGISTRY_KEY} together"
47            )
48            .into());
49        }
50        (None, None) => {
51            return Err(invalid_option_error!(
52                "requires one of {SCHEMA_LOCATION_KEY} or {SCHEMA_REGISTRY_KEY}"
53            )
54            .into());
55        }
56        (None, Some(_)) => {
57            let (md, sid) = fetch_from_registry(&message_name, format_options, topic).await?;
58            return Ok((md, Some(sid)));
59        }
60        (Some(url), None) => url.clone(),
61    };
62
63    if row_schema_location.starts_with("s3") && aws_auth_props.is_none() {
64        return Err(invalid_option_error!("s3 URL not supported yet").into());
65    }
66
67    let enc = EncodingProperties::Protobuf(ProtobufProperties {
68        schema_location: crate::parser::SchemaLocation::File {
69            url: row_schema_location,
70            aws_auth_props: aws_auth_props.cloned(),
71        },
72        message_name,
73        // name_strategy, topic, key_message_name, enable_upsert, client_config
74        ..Default::default()
75    });
76    // Ideally, we should extract the schema loading logic from source parser to this place,
77    // and call this in both source and sink.
78    // But right now this function calls into source parser for its schema loading functionality.
79    // This reversed dependency will be fixed when we support schema registry.
80    let conf = ProtobufParserConfig::new(enc)
81        .await
82        .map_err(SchemaFetchError::YetToMigrate)?;
83    Ok((conf.message_descriptor, None))
84}
85
86pub async fn fetch_from_registry(
87    message_name: &str,
88    format_options: &BTreeMap<String, String>,
89    topic: &str,
90) -> Result<(MessageDescriptor, i32), SchemaFetchError> {
91    let loader = SchemaLoader::from_format_options(topic, format_options).await?;
92
93    let (vid, vpb) = loader.load_val_schema::<FileDescriptor>().await?;
94    let vid = match vid {
95        super::SchemaVersion::Confluent(vid) => vid,
96        super::SchemaVersion::Glue(_) => {
97            return Err(
98                invalid_option_error!("Protobuf with Glue Schema Registry unsupported").into(),
99            );
100        }
101    };
102    let message_descriptor = vpb
103        .parent_pool()
104        .get_message_by_name(message_name)
105        .ok_or_else(|| invalid_option_error!("message {message_name} not defined in proto"))?;
106
107    Ok((message_descriptor, vid))
108}
109
110impl LoadedSchema for FileDescriptor {
111    fn compile(primary: Subject, references: Vec<Subject>) -> Result<Self, SchemaFetchError> {
112        let primary_name = primary.name.clone();
113
114        match compile_pb_subject(primary, references)
115            .context("failed to compile protobuf schema into fd set")
116        {
117            Err(e) => Err(SchemaFetchError::SchemaCompile(e.into())),
118            Ok(fd_set) => DescriptorPool::from_file_descriptor_set(fd_set)
119                .context("failed to convert fd set to descriptor pool")
120                .and_then(|pool| {
121                    pool.get_file_by_name(&primary_name)
122                        .context("file lost after compilation")
123                })
124                .map_err(|e| SchemaFetchError::SchemaCompile(e.into())),
125        }
126    }
127}
128
129fn compile_pb_subject(
130    primary_subject: Subject,
131    dependency_subjects: Vec<Subject>,
132) -> Result<FileDescriptorSet, SchemaFetchError> {
133    compile_pb(
134        (
135            primary_subject.name.clone(),
136            primary_subject.schema.content.clone(),
137        ),
138        dependency_subjects
139            .into_iter()
140            .map(|s| (s.name.clone(), s.schema.content.clone())),
141    )
142    .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))
143}