risingwave_connector/sink/
big_query.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use core::pin::Pin;
16use core::time::Duration;
17use std::collections::{BTreeMap, HashMap, VecDeque};
18
19use anyhow::{Context, anyhow};
20use async_trait::async_trait;
21use base64::Engine;
22use base64::prelude::BASE64_STANDARD;
23use futures::future::pending;
24use futures::prelude::Future;
25use futures::{Stream, StreamExt};
26use futures_async_stream::try_stream;
27use gcp_bigquery_client::Client;
28use gcp_bigquery_client::error::BQError;
29use gcp_bigquery_client::model::query_request::QueryRequest;
30use gcp_bigquery_client::model::query_response::ResultSet;
31use gcp_bigquery_client::model::table::Table;
32use gcp_bigquery_client::model::table_field_schema::TableFieldSchema;
33use gcp_bigquery_client::model::table_schema::TableSchema;
34use google_cloud_bigquery::grpc::apiv1::conn_pool::ConnectionManager;
35use google_cloud_gax::conn::{ConnectionOptions, Environment};
36use google_cloud_gax::grpc::{Request, Response, Status};
37use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{
38    MissingValueInterpretation, ProtoData, Rows as AppendRowsRequestRows,
39};
40use google_cloud_googleapis::cloud::bigquery::storage::v1::{
41    AppendRowsRequest, AppendRowsResponse, ProtoRows, ProtoSchema,
42};
43use google_cloud_pubsub::client::google_cloud_auth;
44use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile;
45use phf::{Set, phf_set};
46use prost_reflect::{FieldDescriptor, MessageDescriptor};
47use prost_types::{
48    DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
49    field_descriptor_proto,
50};
51use risingwave_common::array::{Op, StreamChunk};
52use risingwave_common::catalog::{Field, Schema};
53use risingwave_common::types::DataType;
54use serde::Deserialize;
55use serde_with::{DisplayFromStr, serde_as};
56use simd_json::prelude::ArrayTrait;
57use tokio::sync::mpsc;
58use url::Url;
59use uuid::Uuid;
60use with_options::WithOptions;
61use yup_oauth2::ServiceAccountKey;
62
63use super::encoder::{ProtoEncoder, ProtoHeader, RowEncoder, SerTo};
64use super::log_store::{LogStoreReadItem, TruncateOffset};
65use super::{
66    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
67};
68use crate::aws_utils::load_file_descriptor_from_s3;
69use crate::connector_common::AwsAuthProps;
70use crate::enforce_secret::EnforceSecret;
71use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};
72
pub const BIGQUERY_SINK: &str = "bigquery";
/// Name of the BigQuery CDC pseudo-column that carries the row operation
/// ("UPSERT"/"DELETE"); appended to the proto schema in upsert mode.
pub const CHANGE_TYPE: &str = "_CHANGE_TYPE";
// Number of gRPC channels for the Storage Write API connection pool.
const DEFAULT_GRPC_CHANNEL_NUMS: usize = 4;
// Timeout for establishing a gRPC connection.
const CONNECT_TIMEOUT: Option<Duration> = Some(Duration::from_secs(30));
// Per-request timeout; `None` means no deadline on individual requests.
const CONNECTION_TIMEOUT: Option<Duration> = None;
// Upper bound on offsets awaiting acknowledgement in `BigQueryFutureManager`.
const BIGQUERY_SEND_FUTURE_BUFFER_MAX_SIZE: usize = 65536;
// BigQuery caps AppendRows requests at 10MB; batches are split at 8MB to stay
// safely under the limit.
const MAX_ROW_SIZE: usize = 8 * 1024 * 1024;
81
/// Connection and authentication options shared by the REST validation client
/// and the Storage Write API writer client.
#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct BigQueryCommon {
    /// Local filesystem path to a service-account JSON key file.
    #[serde(rename = "bigquery.local.path")]
    pub local_path: Option<String>,
    /// S3 URL pointing at a service-account JSON key file.
    #[serde(rename = "bigquery.s3.path")]
    pub s3_path: Option<String>,
    /// GCP project id of the destination table.
    #[serde(rename = "bigquery.project")]
    pub project: String,
    /// BigQuery dataset id of the destination table.
    #[serde(rename = "bigquery.dataset")]
    pub dataset: String,
    /// Destination table id.
    #[serde(rename = "bigquery.table")]
    pub table: String,
    /// When true, create the table on validation if it does not exist.
    #[serde(default)] // default false
    #[serde_as(as = "DisplayFromStr")]
    pub auto_create: bool,
    /// Inline credentials: service-account JSON, either raw or base64-encoded.
    /// Takes precedence over `bigquery.local.path` and `bigquery.s3.path`.
    #[serde(rename = "bigquery.credentials")]
    pub credentials: Option<String>,
}
101
impl EnforceSecret for BigQueryCommon {
    // Properties whose values are secrets and must be provided via secret
    // management rather than as plain text.
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "bigquery.credentials",
    };
}
107
/// Tracks in-flight append requests so that an offset is only reported for log
/// truncation once BigQuery has acknowledged every request sent for it.
struct BigQueryFutureManager {
    // Each entry pairs a `TruncateOffset` with the number of `resp_stream`
    // acknowledgements still required before the offset can be truncated:
    // - Barrier: 0, nothing was sent, so no ack is awaited.
    // - Chunk with no serialized rows: 0, nothing was sent.
    // - Chunk within `MAX_ROW_SIZE`: 1, a single append request was sent.
    // - Chunk split into n requests (exceeding `MAX_ROW_SIZE`): n acks from
    //   `resp_stream` are awaited.
    offset_queue: VecDeque<(TruncateOffset, usize)>,
    resp_stream: Pin<Box<dyn Stream<Item = Result<()>> + Send>>,
}
impl BigQueryFutureManager {
    /// Creates a manager with a queue preallocated for `max_future_num`
    /// in-flight offsets, consuming acknowledgements from `resp_stream`.
    pub fn new(
        max_future_num: usize,
        resp_stream: impl Stream<Item = Result<()>> + Send + 'static,
    ) -> Self {
        let offset_queue = VecDeque::with_capacity(max_future_num);
        Self {
            offset_queue,
            resp_stream: Box::pin(resp_stream),
        }
    }

    /// Registers an offset together with the number of acknowledgements that
    /// must arrive before it may be truncated (0 means truncatable at once).
    pub fn add_offset(&mut self, offset: TruncateOffset, resp_num: usize) {
        self.offset_queue.push_back((offset, resp_num));
    }

    /// Waits until the oldest registered offset is fully acknowledged and
    /// returns it. Pends forever when the queue is empty (so it can sit idle
    /// in a `select!`). The remaining-ack counter is decremented in place as
    /// each ack arrives, so progress is preserved if this future is cancelled
    /// by `select!` and polled again later.
    pub async fn next_offset(&mut self) -> Result<TruncateOffset> {
        if let Some((_offset, remaining_resp_num)) = self.offset_queue.front_mut() {
            if *remaining_resp_num == 0 {
                return Ok(self.offset_queue.pop_front().unwrap().0);
            }
            while *remaining_resp_num > 0 {
                // A `None` here means the gRPC response stream ended early,
                // which is an error: every sent request expects an ack.
                self.resp_stream
                    .next()
                    .await
                    .ok_or_else(|| SinkError::BigQuery(anyhow::anyhow!("end of stream")))??;
                *remaining_resp_num -= 1;
            }
            Ok(self.offset_queue.pop_front().unwrap().0)
        } else {
            pending().await
        }
    }
}
/// `LogSinker` that pipelines writes: chunks are encoded and sent immediately
/// while log truncation lags behind, gated on BigQuery acknowledgements.
pub struct BigQueryLogSinker {
    writer: BigQuerySinkWriter,
    bigquery_future_manager: BigQueryFutureManager,
    // Upper bound on the number of offsets allowed in flight at once.
    future_num: usize,
}
157impl BigQueryLogSinker {
158    pub fn new(
159        writer: BigQuerySinkWriter,
160        resp_stream: impl Stream<Item = Result<()>> + Send + 'static,
161        future_num: usize,
162    ) -> Self {
163        Self {
164            writer,
165            bigquery_future_manager: BigQueryFutureManager::new(future_num, resp_stream),
166            future_num,
167        }
168    }
169}
170
#[async_trait]
impl LogSinker for BigQueryLogSinker {
    /// Main sink loop: concurrently (a) truncates the log as offsets become
    /// fully acknowledged and (b) reads new items and sends them to BigQuery,
    /// as long as fewer than `future_num` offsets are in flight.
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            tokio::select!(
                // `next_offset` keeps its progress across cancelled polls (the
                // ack counter is decremented in place), so racing it here is safe.
                offset = self.bigquery_future_manager.next_offset() => {
                        log_reader.truncate(offset?)?;
                }
                // Backpressure: stop pulling new items once the in-flight
                // offset queue reaches the configured limit.
                item_result = log_reader.next_item(), if self.bigquery_future_manager.offset_queue.len() <= self.future_num => {
                    let (epoch, item) = item_result?;
                    match item {
                        LogStoreReadItem::StreamChunk { chunk_id, chunk } => {
                            // `write_chunk` returns how many append requests were
                            // sent, i.e. how many acks to await for this offset.
                            let resp_num = self.writer.write_chunk(chunk)?;
                            self.bigquery_future_manager
                                .add_offset(TruncateOffset::Chunk { epoch, chunk_id },resp_num);
                        }
                        LogStoreReadItem::Barrier { .. } => {
                            // Barriers send nothing, so no ack is required.
                            self.bigquery_future_manager
                                .add_offset(TruncateOffset::Barrier { epoch },0);
                        }
                    }
                }
            )
        }
    }
}
198
199impl BigQueryCommon {
200    async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result<Client> {
201        let auth_json = self.get_auth_json_from_path(aws_auth_props).await?;
202
203        let service_account =
204            if let Ok(auth_json_from_base64) = BASE64_STANDARD.decode(auth_json.clone()) {
205                serde_json::from_slice::<ServiceAccountKey>(&auth_json_from_base64)
206            } else {
207                serde_json::from_str::<ServiceAccountKey>(&auth_json)
208            }
209            .map_err(|e| SinkError::BigQuery(e.into()))?;
210
211        let client: Client = Client::from_service_account_key(service_account, false)
212            .await
213            .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?;
214        Ok(client)
215    }
216
217    async fn build_writer_client(
218        &self,
219        aws_auth_props: &AwsAuthProps,
220    ) -> Result<(StorageWriterClient, impl Stream<Item = Result<()>> + use<>)> {
221        let auth_json = self.get_auth_json_from_path(aws_auth_props).await?;
222
223        let credentials_file =
224            if let Ok(auth_json_from_base64) = BASE64_STANDARD.decode(auth_json.clone()) {
225                serde_json::from_slice::<CredentialsFile>(&auth_json_from_base64)
226            } else {
227                serde_json::from_str::<CredentialsFile>(&auth_json)
228            }
229            .map_err(|e| SinkError::BigQuery(e.into()))?;
230
231        StorageWriterClient::new(credentials_file).await
232    }
233
234    async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result<String> {
235        if let Some(credentials) = &self.credentials {
236            Ok(credentials.clone())
237        } else if let Some(local_path) = &self.local_path {
238            std::fs::read_to_string(local_path)
239                .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))
240        } else if let Some(s3_path) = &self.s3_path {
241            let url =
242                Url::parse(s3_path).map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?;
243            let auth_vec = load_file_descriptor_from_s3(&url, aws_auth_props)
244                .await
245                .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?;
246            Ok(String::from_utf8(auth_vec).map_err(|e| SinkError::BigQuery(e.into()))?)
247        } else {
248            Err(SinkError::BigQuery(anyhow::anyhow!(
249                "`bigquery.local.path` and `bigquery.s3.path` set at least one, configure as needed."
250            )))
251        }
252    }
253}
254
/// Complete sink configuration: BigQuery connection options, AWS auth options
/// (used when credentials are loaded from S3), and the sink type.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct BigQueryConfig {
    #[serde(flatten)]
    pub common: BigQueryCommon,
    #[serde(flatten)]
    pub aws_auth_props: AwsAuthProps,
    pub r#type: String, // accept "append-only" or "upsert"
}
264
impl EnforceSecret for BigQueryConfig {
    /// Delegates secret enforcement to both flattened property groups; fails on
    /// the first property that violates either group's rules.
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        BigQueryCommon::enforce_one(prop)?;
        AwsAuthProps::enforce_one(prop)?;
        Ok(())
    }
}
272
273impl BigQueryConfig {
274    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
275        let config =
276            serde_json::from_value::<BigQueryConfig>(serde_json::to_value(properties).unwrap())
277                .map_err(|e| SinkError::Config(anyhow!(e)))?;
278        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
279            return Err(SinkError::Config(anyhow!(
280                "`{}` must be {}, or {}",
281                SINK_TYPE_OPTION,
282                SINK_TYPE_APPEND_ONLY,
283                SINK_TYPE_UPSERT
284            )));
285        }
286        Ok(config)
287    }
288}
289
/// Sink entry point for BigQuery: holds the validated configuration plus the
/// upstream schema, downstream primary key, and append-only/upsert mode.
#[derive(Debug)]
pub struct BigQuerySink {
    pub config: BigQueryConfig,
    schema: Schema,
    // Indices into `schema` forming the downstream primary key (upsert mode).
    pk_indices: Vec<usize>,
    is_append_only: bool,
}
297
298impl EnforceSecret for BigQuerySink {
299    fn enforce_secret<'a>(
300        prop_iter: impl Iterator<Item = &'a str>,
301    ) -> crate::error::ConnectorResult<()> {
302        for prop in prop_iter {
303            BigQueryConfig::enforce_one(prop)?;
304        }
305        Ok(())
306    }
307}
308
309impl BigQuerySink {
310    pub fn new(
311        config: BigQueryConfig,
312        schema: Schema,
313        pk_indices: Vec<usize>,
314        is_append_only: bool,
315    ) -> Result<Self> {
316        Ok(Self {
317            config,
318            schema,
319            pk_indices,
320            is_append_only,
321        })
322    }
323}
324
325impl BigQuerySink {
326    fn check_column_name_and_type(
327        &self,
328        big_query_columns_desc: HashMap<String, String>,
329    ) -> Result<()> {
330        let rw_fields_name = self.schema.fields();
331        if big_query_columns_desc.is_empty() {
332            return Err(SinkError::BigQuery(anyhow::anyhow!(
333                "Cannot find table in bigquery"
334            )));
335        }
336        if rw_fields_name.len().ne(&big_query_columns_desc.len()) {
337            return Err(SinkError::BigQuery(anyhow::anyhow!(
338                "The length of the RisingWave column {} must be equal to the length of the bigquery column {}",
339                rw_fields_name.len(),
340                big_query_columns_desc.len()
341            )));
342        }
343
344        for i in rw_fields_name {
345            let value = big_query_columns_desc.get(&i.name).ok_or_else(|| {
346                SinkError::BigQuery(anyhow::anyhow!(
347                    "Column `{:?}` on RisingWave side is not found on BigQuery side.",
348                    i.name
349                ))
350            })?;
351            let data_type_string = Self::get_string_and_check_support_from_datatype(&i.data_type)?;
352            if data_type_string.ne(value) {
353                return Err(SinkError::BigQuery(anyhow::anyhow!(
354                    "Data type mismatch for column `{:?}`. BigQuery side: `{:?}`, RisingWave side: `{:?}`. ",
355                    i.name,
356                    value,
357                    data_type_string
358                )));
359            };
360        }
361        Ok(())
362    }
363
364    fn get_string_and_check_support_from_datatype(rw_data_type: &DataType) -> Result<String> {
365        match rw_data_type {
366            DataType::Boolean => Ok("BOOL".to_owned()),
367            DataType::Int16 => Ok("INT64".to_owned()),
368            DataType::Int32 => Ok("INT64".to_owned()),
369            DataType::Int64 => Ok("INT64".to_owned()),
370            DataType::Float32 => Err(SinkError::BigQuery(anyhow::anyhow!(
371                "REAL is not supported for BigQuery sink. Please convert to FLOAT64 or other supported types."
372            ))),
373            DataType::Float64 => Ok("FLOAT64".to_owned()),
374            DataType::Decimal => Ok("NUMERIC".to_owned()),
375            DataType::Date => Ok("DATE".to_owned()),
376            DataType::Varchar => Ok("STRING".to_owned()),
377            DataType::Time => Ok("TIME".to_owned()),
378            DataType::Timestamp => Ok("DATETIME".to_owned()),
379            DataType::Timestamptz => Ok("TIMESTAMP".to_owned()),
380            DataType::Interval => Ok("INTERVAL".to_owned()),
381            DataType::Struct(structs) => {
382                let mut elements_vec = vec![];
383                for (name, datatype) in structs.iter() {
384                    let element_string =
385                        Self::get_string_and_check_support_from_datatype(datatype)?;
386                    elements_vec.push(format!("{} {}", name, element_string));
387                }
388                Ok(format!("STRUCT<{}>", elements_vec.join(", ")))
389            }
390            DataType::List(l) => {
391                let element_string = Self::get_string_and_check_support_from_datatype(l.elem())?;
392                Ok(format!("ARRAY<{}>", element_string))
393            }
394            DataType::Bytea => Ok("BYTES".to_owned()),
395            DataType::Jsonb => Ok("JSON".to_owned()),
396            DataType::Serial => Ok("INT64".to_owned()),
397            DataType::Int256 => Err(SinkError::BigQuery(anyhow::anyhow!(
398                "INT256 is not supported for BigQuery sink."
399            ))),
400            DataType::Map(_) => Err(SinkError::BigQuery(anyhow::anyhow!(
401                "MAP is not supported for BigQuery sink."
402            ))),
403            DataType::Vector(_) => Err(SinkError::BigQuery(anyhow::anyhow!(
404                "VECTOR is not supported for BigQuery sink."
405            ))),
406        }
407    }
408
409    fn map_field(rw_field: &Field) -> Result<TableFieldSchema> {
410        let tfs = match &rw_field.data_type {
411            DataType::Boolean => TableFieldSchema::bool(&rw_field.name),
412            DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Serial => {
413                TableFieldSchema::integer(&rw_field.name)
414            }
415            DataType::Float32 => {
416                return Err(SinkError::BigQuery(anyhow::anyhow!(
417                    "REAL is not supported for BigQuery sink. Please convert to FLOAT64 or other supported types."
418                )));
419            }
420            DataType::Float64 => TableFieldSchema::float(&rw_field.name),
421            DataType::Decimal => TableFieldSchema::numeric(&rw_field.name),
422            DataType::Date => TableFieldSchema::date(&rw_field.name),
423            DataType::Varchar => TableFieldSchema::string(&rw_field.name),
424            DataType::Time => TableFieldSchema::time(&rw_field.name),
425            DataType::Timestamp => TableFieldSchema::date_time(&rw_field.name),
426            DataType::Timestamptz => TableFieldSchema::timestamp(&rw_field.name),
427            DataType::Interval => {
428                return Err(SinkError::BigQuery(anyhow::anyhow!(
429                    "INTERVAL is not supported for BigQuery sink. Please convert to VARCHAR or other supported types."
430                )));
431            }
432            DataType::Struct(st) => {
433                let mut sub_fields = Vec::with_capacity(st.len());
434                for (name, dt) in st.iter() {
435                    let rw_field = Field::with_name(dt.clone(), name);
436                    let field = Self::map_field(&rw_field)?;
437                    sub_fields.push(field);
438                }
439                TableFieldSchema::record(&rw_field.name, sub_fields)
440            }
441            DataType::List(lt) => {
442                let inner_field =
443                    Self::map_field(&Field::with_name(lt.elem().clone(), &rw_field.name))?;
444                TableFieldSchema {
445                    mode: Some("REPEATED".to_owned()),
446                    ..inner_field
447                }
448            }
449
450            DataType::Bytea => TableFieldSchema::bytes(&rw_field.name),
451            DataType::Jsonb => TableFieldSchema::json(&rw_field.name),
452            DataType::Int256 => {
453                return Err(SinkError::BigQuery(anyhow::anyhow!(
454                    "INT256 is not supported for BigQuery sink."
455                )));
456            }
457            DataType::Map(_) => {
458                return Err(SinkError::BigQuery(anyhow::anyhow!(
459                    "MAP is not supported for BigQuery sink."
460                )));
461            }
462            DataType::Vector(_) => {
463                return Err(SinkError::BigQuery(anyhow::anyhow!(
464                    "VECTOR is not supported for BigQuery sink."
465                )));
466            }
467        };
468        Ok(tfs)
469    }
470
471    async fn create_table(
472        &self,
473        client: &Client,
474        project_id: &str,
475        dataset_id: &str,
476        table_id: &str,
477        fields: &Vec<Field>,
478    ) -> Result<Table> {
479        let dataset = client
480            .dataset()
481            .get(project_id, dataset_id)
482            .await
483            .map_err(|e| SinkError::BigQuery(e.into()))?;
484        let fields: Vec<_> = fields.iter().map(Self::map_field).collect::<Result<_>>()?;
485        let table = Table::from_dataset(&dataset, table_id, TableSchema::new(fields));
486
487        client
488            .table()
489            .create(table)
490            .await
491            .map_err(|e| SinkError::BigQuery(e.into()))
492    }
493}
494
495impl Sink for BigQuerySink {
496    type LogSinker = BigQueryLogSinker;
497
498    const SINK_NAME: &'static str = BIGQUERY_SINK;
499
500    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
501        let (writer, resp_stream) = BigQuerySinkWriter::new(
502            self.config.clone(),
503            self.schema.clone(),
504            self.pk_indices.clone(),
505            self.is_append_only,
506        )
507        .await?;
508        Ok(BigQueryLogSinker::new(
509            writer,
510            resp_stream,
511            BIGQUERY_SEND_FUTURE_BUFFER_MAX_SIZE,
512        ))
513    }
514
515    async fn validate(&self) -> Result<()> {
516        risingwave_common::license::Feature::BigQuerySink
517            .check_available()
518            .map_err(|e| anyhow::anyhow!(e))?;
519        if !self.is_append_only && self.pk_indices.is_empty() {
520            return Err(SinkError::Config(anyhow!(
521                "Primary key not defined for upsert bigquery sink (please define in `primary_key` field)"
522            )));
523        }
524        let client = self
525            .config
526            .common
527            .build_client(&self.config.aws_auth_props)
528            .await?;
529        let BigQueryCommon {
530            project: project_id,
531            dataset: dataset_id,
532            table: table_id,
533            ..
534        } = &self.config.common;
535
536        if self.config.common.auto_create {
537            match client
538                .table()
539                .get(project_id, dataset_id, table_id, None)
540                .await
541            {
542                Err(BQError::RequestError(_)) => {
543                    // early return: no need to query schema to check column and type
544                    return self
545                        .create_table(
546                            &client,
547                            project_id,
548                            dataset_id,
549                            table_id,
550                            &self.schema.fields,
551                        )
552                        .await
553                        .map(|_| ());
554                }
555                Err(e) => return Err(SinkError::BigQuery(e.into())),
556                _ => {}
557            }
558        }
559
560        let rs = client
561            .job()
562            .query(
563                &self.config.common.project,
564                QueryRequest::new(format!(
565                    "SELECT column_name, data_type FROM `{}.{}.INFORMATION_SCHEMA.COLUMNS` WHERE table_name = '{}'",
566                    project_id, dataset_id, table_id,
567                )),
568            ).await.map_err(|e| SinkError::BigQuery(e.into()))?;
569        let mut rs = ResultSet::new_from_query_response(rs);
570
571        let mut big_query_schema = HashMap::default();
572        while rs.next_row() {
573            big_query_schema.insert(
574                rs.get_string_by_name("column_name")
575                    .map_err(|e| SinkError::BigQuery(e.into()))?
576                    .ok_or_else(|| {
577                        SinkError::BigQuery(anyhow::anyhow!("Cannot find column_name"))
578                    })?,
579                rs.get_string_by_name("data_type")
580                    .map_err(|e| SinkError::BigQuery(e.into()))?
581                    .ok_or_else(|| {
582                        SinkError::BigQuery(anyhow::anyhow!("Cannot find column_name"))
583                    })?,
584            );
585        }
586
587        self.check_column_name_and_type(big_query_schema)?;
588        Ok(())
589    }
590}
591
/// Writer half of the BigQuery sink: encodes stream chunks into protobuf rows
/// and sends `AppendRows` requests to the Storage Write API default stream.
pub struct BigQuerySinkWriter {
    pub config: BigQueryConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    client: StorageWriterClient,
    is_append_only: bool,
    // Encodes RisingWave rows into protobuf messages matching `writer_pb_schema`.
    row_encoder: ProtoEncoder,
    writer_pb_schema: ProtoSchema,
    #[expect(dead_code)]
    message_descriptor: MessageDescriptor,
    // Fully qualified default write stream:
    // projects/{project}/datasets/{dataset}/tables/{table}/streams/_default
    write_stream: String,
    // Descriptor of the `_CHANGE_TYPE` field; `Some` only in upsert mode.
    proto_field: Option<FieldDescriptor>,
}
607
608impl TryFrom<SinkParam> for BigQuerySink {
609    type Error = SinkError;
610
611    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
612        let schema = param.schema();
613        let pk_indices = param.downstream_pk_or_empty();
614        let config = BigQueryConfig::from_btreemap(param.properties)?;
615        BigQuerySink::new(config, schema, pk_indices, param.sink_type.is_append_only())
616    }
617}
618
619impl BigQuerySinkWriter {
620    pub async fn new(
621        config: BigQueryConfig,
622        schema: Schema,
623        pk_indices: Vec<usize>,
624        is_append_only: bool,
625    ) -> Result<(Self, impl Stream<Item = Result<()>>)> {
626        let (client, resp_stream) = config
627            .common
628            .build_writer_client(&config.aws_auth_props)
629            .await?;
630        let mut descriptor_proto = build_protobuf_schema(
631            schema
632                .fields()
633                .iter()
634                .map(|f| (f.name.as_str(), &f.data_type)),
635            config.common.table.clone(),
636        )?;
637
638        if !is_append_only {
639            let field = FieldDescriptorProto {
640                name: Some(CHANGE_TYPE.to_owned()),
641                number: Some((schema.len() + 1) as i32),
642                r#type: Some(field_descriptor_proto::Type::String.into()),
643                ..Default::default()
644            };
645            descriptor_proto.field.push(field);
646        }
647
648        let descriptor_pool = build_protobuf_descriptor_pool(&descriptor_proto)?;
649        let message_descriptor = descriptor_pool
650            .get_message_by_name(&config.common.table)
651            .ok_or_else(|| {
652                SinkError::BigQuery(anyhow::anyhow!(
653                    "Can't find message proto {}",
654                    &config.common.table
655                ))
656            })?;
657        let proto_field = if !is_append_only {
658            let proto_field = message_descriptor
659                .get_field_by_name(CHANGE_TYPE)
660                .ok_or_else(|| {
661                    SinkError::BigQuery(anyhow::anyhow!("Can't find {}", CHANGE_TYPE))
662                })?;
663            Some(proto_field)
664        } else {
665            None
666        };
667        let row_encoder = ProtoEncoder::new(
668            schema.clone(),
669            None,
670            message_descriptor.clone(),
671            ProtoHeader::None,
672        )?;
673        Ok((
674            Self {
675                write_stream: format!(
676                    "projects/{}/datasets/{}/tables/{}/streams/_default",
677                    config.common.project, config.common.dataset, config.common.table
678                ),
679                config,
680                schema,
681                pk_indices,
682                client,
683                is_append_only,
684                row_encoder,
685                message_descriptor,
686                proto_field,
687                writer_pb_schema: ProtoSchema {
688                    proto_descriptor: Some(descriptor_proto.clone()),
689                },
690            },
691            resp_stream,
692        ))
693    }
694
695    fn append_only(&mut self, chunk: StreamChunk) -> Result<Vec<Vec<u8>>> {
696        let mut serialized_rows: Vec<Vec<u8>> = Vec::with_capacity(chunk.capacity());
697        for (op, row) in chunk.rows() {
698            if op != Op::Insert {
699                continue;
700            }
701            serialized_rows.push(self.row_encoder.encode(row)?.ser_to()?)
702        }
703        Ok(serialized_rows)
704    }
705
706    fn upsert(&mut self, chunk: StreamChunk) -> Result<Vec<Vec<u8>>> {
707        let mut serialized_rows: Vec<Vec<u8>> = Vec::with_capacity(chunk.capacity());
708        for (op, row) in chunk.rows() {
709            if op == Op::UpdateDelete {
710                continue;
711            }
712            let mut pb_row = self.row_encoder.encode(row)?;
713            match op {
714                Op::Insert => pb_row
715                    .message
716                    .try_set_field(
717                        self.proto_field.as_ref().unwrap(),
718                        prost_reflect::Value::String("UPSERT".to_owned()),
719                    )
720                    .map_err(|e| SinkError::BigQuery(e.into()))?,
721                Op::Delete => pb_row
722                    .message
723                    .try_set_field(
724                        self.proto_field.as_ref().unwrap(),
725                        prost_reflect::Value::String("DELETE".to_owned()),
726                    )
727                    .map_err(|e| SinkError::BigQuery(e.into()))?,
728                Op::UpdateDelete => continue,
729                Op::UpdateInsert => pb_row
730                    .message
731                    .try_set_field(
732                        self.proto_field.as_ref().unwrap(),
733                        prost_reflect::Value::String("UPSERT".to_owned()),
734                    )
735                    .map_err(|e| SinkError::BigQuery(e.into()))?,
736            };
737
738            serialized_rows.push(pb_row.ser_to()?)
739        }
740        Ok(serialized_rows)
741    }
742
743    fn write_chunk(&mut self, chunk: StreamChunk) -> Result<usize> {
744        let serialized_rows = if self.is_append_only {
745            self.append_only(chunk)?
746        } else {
747            self.upsert(chunk)?
748        };
749        if serialized_rows.is_empty() {
750            return Ok(0);
751        }
752        let mut result = Vec::new();
753        let mut result_inner = Vec::new();
754        let mut size_count = 0;
755        for i in serialized_rows {
756            size_count += i.len();
757            if size_count > MAX_ROW_SIZE {
758                result.push(result_inner);
759                result_inner = Vec::new();
760                size_count = i.len();
761            }
762            result_inner.push(i);
763        }
764        if !result_inner.is_empty() {
765            result.push(result_inner);
766        }
767        let len = result.len();
768        for serialized_rows in result {
769            let rows = AppendRowsRequestRows::ProtoRows(ProtoData {
770                writer_schema: Some(self.writer_pb_schema.clone()),
771                rows: Some(ProtoRows { serialized_rows }),
772            });
773            self.client.append_rows(rows, self.write_stream.clone())?;
774        }
775        Ok(len)
776    }
777}
778
/// Converts the `AppendRows` gRPC response stream into a stream of unit
/// acknowledgements: yields `()` per successful response, and errors out on
/// row-level errors, response-level errors, or premature end of stream.
#[try_stream(ok = (), error = SinkError)]
pub async fn resp_to_stream(
    resp_stream: impl Future<
        Output = std::result::Result<
            Response<google_cloud_gax::grpc::Streaming<AppendRowsResponse>>,
            Status,
        >,
    >
    + 'static
    + Send,
) {
    let mut resp_stream = resp_stream
        .await
        .map_err(|e| SinkError::BigQuery(e.into()))?
        .into_inner();
    loop {
        match resp_stream
            .message()
            .await
            .map_err(|e| SinkError::BigQuery(e.into()))?
        {
            Some(append_rows_response) => {
                // Per-row errors reported by BigQuery fail the whole sink.
                if !append_rows_response.row_errors.is_empty() {
                    return Err(SinkError::BigQuery(anyhow::anyhow!(
                        "bigquery insert error {:?}",
                        append_rows_response.row_errors
                    )));
                }
                // A response-level error status also fails the sink.
                if let Some(google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_response::Response::Error(status)) = append_rows_response.response{
                            return Err(SinkError::BigQuery(anyhow::anyhow!(
                                "bigquery insert error {:?}",
                                status
                            )));
                        }
                yield ();
            }
            None => {
                // The server closed the stream while requests may still be
                // awaiting acknowledgement.
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "bigquery insert error: end of resp stream",
                )));
            }
        }
    }
}
823
/// Handle to the BigQuery storage-write gRPC stream: requests pushed through
/// `request_sender` are forwarded to the `AppendRows` bidi stream opened in
/// [`StorageWriterClient::new`].
struct StorageWriterClient {
    // Unused after construction; retained in the struct — presumably to keep
    // the auth environment alive alongside the connection (TODO confirm).
    #[expect(dead_code)]
    environment: Environment,
    // Producer side of the unbounded channel feeding the request stream.
    request_sender: mpsc::UnboundedSender<AppendRowsRequest>,
}
829impl StorageWriterClient {
830    pub async fn new(
831        credentials: CredentialsFile,
832    ) -> Result<(Self, impl Stream<Item = Result<()>>)> {
833        let ts_grpc = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials(
834            Self::bigquery_grpc_auth_config(),
835            Box::new(credentials),
836        )
837        .await
838        .map_err(|e| SinkError::BigQuery(e.into()))?;
839        let conn_options = ConnectionOptions {
840            connect_timeout: CONNECT_TIMEOUT,
841            timeout: CONNECTION_TIMEOUT,
842        };
843        let environment = Environment::GoogleCloud(Box::new(ts_grpc));
844        let conn = ConnectionManager::new(DEFAULT_GRPC_CHANNEL_NUMS, &environment, &conn_options)
845            .await
846            .map_err(|e| SinkError::BigQuery(e.into()))?;
847        let mut client = conn.writer();
848
849        let (tx, rx) = mpsc::unbounded_channel();
850        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
851
852        let resp = async move { client.append_rows(Request::new(stream)).await };
853        let resp_stream = resp_to_stream(resp);
854
855        Ok((
856            StorageWriterClient {
857                environment,
858                request_sender: tx,
859            },
860            resp_stream,
861        ))
862    }
863
864    pub fn append_rows(&mut self, row: AppendRowsRequestRows, write_stream: String) -> Result<()> {
865        let append_req = AppendRowsRequest {
866            write_stream,
867            offset: None,
868            trace_id: Uuid::new_v4().hyphenated().to_string(),
869            missing_value_interpretations: HashMap::default(),
870            rows: Some(row),
871            default_missing_value_interpretation: MissingValueInterpretation::DefaultValue as i32,
872        };
873        self.request_sender
874            .send(append_req)
875            .map_err(|e| SinkError::BigQuery(e.into()))?;
876        Ok(())
877    }
878
879    fn bigquery_grpc_auth_config() -> google_cloud_auth::project::Config<'static> {
880        let mut auth_config = google_cloud_auth::project::Config::default();
881        auth_config =
882            auth_config.with_audience(google_cloud_bigquery::grpc::apiv1::conn_pool::AUDIENCE);
883        auth_config =
884            auth_config.with_scopes(&google_cloud_bigquery::grpc::apiv1::conn_pool::SCOPES);
885        auth_config
886    }
887}
888
889fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> Result<prost_reflect::DescriptorPool> {
890    let file_descriptor = FileDescriptorProto {
891        message_type: vec![desc.clone()],
892        name: Some("bigquery".to_owned()),
893        ..Default::default()
894    };
895
896    prost_reflect::DescriptorPool::from_file_descriptor_set(FileDescriptorSet {
897        file: vec![file_descriptor],
898    })
899    .context("failed to build descriptor pool")
900    .map_err(SinkError::BigQuery)
901}
902
903fn build_protobuf_schema<'a>(
904    fields: impl Iterator<Item = (&'a str, &'a DataType)>,
905    name: String,
906) -> Result<DescriptorProto> {
907    let mut proto = DescriptorProto {
908        name: Some(name),
909        ..Default::default()
910    };
911    let mut struct_vec = vec![];
912    let field_vec = fields
913        .enumerate()
914        .map(|(index, (name, data_type))| {
915            let (field, des_proto) =
916                build_protobuf_field(data_type, (index + 1) as i32, name.to_owned())?;
917            if let Some(sv) = des_proto {
918                struct_vec.push(sv);
919            }
920            Ok(field)
921        })
922        .collect::<Result<Vec<_>>>()?;
923    proto.field = field_vec;
924    proto.nested_type = struct_vec;
925    Ok(proto)
926}
927
928fn build_protobuf_field(
929    data_type: &DataType,
930    index: i32,
931    name: String,
932) -> Result<(FieldDescriptorProto, Option<DescriptorProto>)> {
933    let mut field = FieldDescriptorProto {
934        name: Some(name.clone()),
935        number: Some(index),
936        ..Default::default()
937    };
938    match data_type {
939        DataType::Boolean => field.r#type = Some(field_descriptor_proto::Type::Bool.into()),
940        DataType::Int32 => field.r#type = Some(field_descriptor_proto::Type::Int32.into()),
941        DataType::Int16 | DataType::Int64 => {
942            field.r#type = Some(field_descriptor_proto::Type::Int64.into())
943        }
944        DataType::Float64 => field.r#type = Some(field_descriptor_proto::Type::Double.into()),
945        DataType::Decimal => field.r#type = Some(field_descriptor_proto::Type::String.into()),
946        DataType::Date => field.r#type = Some(field_descriptor_proto::Type::Int32.into()),
947        DataType::Varchar => field.r#type = Some(field_descriptor_proto::Type::String.into()),
948        DataType::Time => field.r#type = Some(field_descriptor_proto::Type::String.into()),
949        DataType::Timestamp => field.r#type = Some(field_descriptor_proto::Type::String.into()),
950        DataType::Timestamptz => field.r#type = Some(field_descriptor_proto::Type::String.into()),
951        DataType::Interval => field.r#type = Some(field_descriptor_proto::Type::String.into()),
952        DataType::Struct(s) => {
953            field.r#type = Some(field_descriptor_proto::Type::Message.into());
954            let name = format!("Struct{}", name);
955            let sub_proto = build_protobuf_schema(s.iter(), name.clone())?;
956            field.type_name = Some(name);
957            return Ok((field, Some(sub_proto)));
958        }
959        DataType::List(l) => {
960            let (mut field, proto) = build_protobuf_field(l.elem(), index, name)?;
961            field.label = Some(field_descriptor_proto::Label::Repeated.into());
962            return Ok((field, proto));
963        }
964        DataType::Bytea => field.r#type = Some(field_descriptor_proto::Type::Bytes.into()),
965        DataType::Jsonb => field.r#type = Some(field_descriptor_proto::Type::String.into()),
966        DataType::Serial => field.r#type = Some(field_descriptor_proto::Type::Int64.into()),
967        DataType::Float32 | DataType::Int256 => {
968            return Err(SinkError::BigQuery(anyhow::anyhow!(
969                "Don't support Float32 and Int256"
970            )));
971        }
972        DataType::Map(_) => return Err(SinkError::BigQuery(anyhow::anyhow!("Don't support Map"))),
973        DataType::Vector(_) => {
974            return Err(SinkError::BigQuery(anyhow::anyhow!("Don't support Vector")));
975        }
976    }
977    Ok((field, None))
978}
979
#[cfg(test)]
mod test {

    use std::assert_matches::assert_matches;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use crate::sink::big_query::{
        BigQuerySink, build_protobuf_descriptor_pool, build_protobuf_schema,
    };

    /// Nested RisingWave types should map to the expected BigQuery type
    /// string. Fully synchronous, so no async runtime is needed
    /// (was `#[tokio::test]` with no `.await`).
    #[test]
    fn test_type_check() {
        let big_query_type_string = "ARRAY<STRUCT<v1 ARRAY<INT64>, v2 STRUCT<v1 INT64, v2 INT64>>>";
        let rw_datatype = DataType::list(DataType::Struct(StructType::new(vec![
            ("v1".to_owned(), DataType::Int64.list()),
            (
                "v2".to_owned(),
                DataType::Struct(StructType::new(vec![
                    ("v1".to_owned(), DataType::Int64),
                    ("v2".to_owned(), DataType::Int64),
                ])),
            ),
        ])));
        assert_eq!(
            BigQuerySink::get_string_and_check_support_from_datatype(&rw_datatype).unwrap(),
            big_query_type_string
        );
    }

    /// Building a protobuf schema from a RisingWave schema should produce
    /// the expected field kinds, including nested struct messages and
    /// repeated (list) labels. Fully synchronous, so no async runtime is
    /// needed (was `#[tokio::test]` with no `.await`).
    #[test]
    fn test_schema_check() {
        let schema = Schema {
            fields: vec![
                Field::with_name(DataType::Int64, "v1"),
                Field::with_name(DataType::Float64, "v2"),
                Field::with_name(
                    DataType::list(DataType::Struct(StructType::new(vec![
                        ("v1".to_owned(), DataType::Int64.list()),
                        (
                            "v3".to_owned(),
                            DataType::Struct(StructType::new(vec![
                                ("v1".to_owned(), DataType::Int64),
                                ("v2".to_owned(), DataType::Int64),
                            ])),
                        ),
                    ]))),
                    "v3",
                ),
            ],
        };
        let fields = schema
            .fields()
            .iter()
            .map(|f| (f.name.as_str(), &f.data_type));
        let desc = build_protobuf_schema(fields, "t1".to_owned()).unwrap();
        let pool = build_protobuf_descriptor_pool(&desc).unwrap();
        let t1_message = pool.get_message_by_name("t1").unwrap();
        assert_matches!(
            t1_message.get_field_by_name("v1").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
        assert_matches!(
            t1_message.get_field_by_name("v2").unwrap().kind(),
            prost_reflect::Kind::Double
        );
        assert_matches!(
            t1_message.get_field_by_name("v3").unwrap().kind(),
            prost_reflect::Kind::Message(_)
        );

        // Nested struct message generated for the "v3" list-of-struct field.
        let v3_message = pool.get_message_by_name("t1.Structv3").unwrap();
        assert_matches!(
            v3_message.get_field_by_name("v1").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
        assert!(v3_message.get_field_by_name("v1").unwrap().is_list());

        // Struct nested one level deeper ("v3" inside "v3").
        let v3_v3_message = pool.get_message_by_name("t1.Structv3.Structv3").unwrap();
        assert_matches!(
            v3_v3_message.get_field_by_name("v1").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
        assert_matches!(
            v3_v3_message.get_field_by_name("v2").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
    }
}