risingwave_connector/sink/big_query.rs

// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use core::pin::Pin;
use core::time::Duration;
use std::collections::{BTreeMap, HashMap, VecDeque};

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use futures::future::pending;
use futures::prelude::Future;
use futures::{Stream, StreamExt};
use futures_async_stream::try_stream;
use gcp_bigquery_client::Client;
use gcp_bigquery_client::error::BQError;
use gcp_bigquery_client::model::query_request::QueryRequest;
use gcp_bigquery_client::model::table::Table;
use gcp_bigquery_client::model::table_field_schema::TableFieldSchema;
use gcp_bigquery_client::model::table_schema::TableSchema;
use google_cloud_bigquery::grpc::apiv1::conn_pool::ConnectionManager;
use google_cloud_gax::conn::{ConnectionOptions, Environment};
use google_cloud_gax::grpc::{Request, Response, Status};
use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{
    MissingValueInterpretation, ProtoData, Rows as AppendRowsRequestRows,
};
use google_cloud_googleapis::cloud::bigquery::storage::v1::{
    AppendRowsRequest, AppendRowsResponse, ProtoRows, ProtoSchema,
};
use google_cloud_pubsub::client::google_cloud_auth;
use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile;
use phf::{Set, phf_set};
use prost::Message;
use prost_014::Message as Message014;
use prost_reflect::{FieldDescriptor, MessageDescriptor};
use prost_types::{
    DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
    field_descriptor_proto,
};
use prost_types_014::DescriptorProto as DescriptorProto014;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use serde::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use simd_json::prelude::ArrayTrait;
use tokio::sync::mpsc;
use url::Url;
use uuid::Uuid;
use with_options::WithOptions;
use yup_oauth2::ServiceAccountKey;

use super::encoder::{ProtoEncoder, ProtoHeader, RowEncoder, SerTo};
use super::log_store::{LogStoreReadItem, TruncateOffset};
use super::{
    LogSinker, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, SinkError, SinkLogReader,
};
use crate::aws_utils::load_file_descriptor_from_s3;
use crate::connector_common::AwsAuthProps;
use crate::enforce_secret::EnforceSecret;
use crate::sink::{Result, Sink, SinkParam, SinkWriterParam};

pub const BIGQUERY_SINK: &str = "bigquery";
pub const CHANGE_TYPE: &str = "_CHANGE_TYPE";
const DEFAULT_GRPC_CHANNEL_NUMS: usize = 4;
const CONNECT_TIMEOUT: Option<Duration> = Some(Duration::from_secs(30));
const CONNECTION_TIMEOUT: Option<Duration> = None;
const BIGQUERY_SEND_FUTURE_BUFFER_MAX_SIZE: usize = 65536;
// BigQuery limits each AppendRows request to less than 10 MB, so we cap each batch at 8 MB.
const MAX_ROW_SIZE: usize = 8 * 1024 * 1024;

#[serde_as]
#[derive(Deserialize, Debug, Clone, WithOptions)]
pub struct BigQueryCommon {
    #[serde(rename = "bigquery.local.path")]
    pub local_path: Option<String>,
    #[serde(rename = "bigquery.s3.path")]
    pub s3_path: Option<String>,
    #[serde(rename = "bigquery.project")]
    pub project: String,
    #[serde(rename = "bigquery.dataset")]
    pub dataset: String,
    #[serde(rename = "bigquery.table")]
    pub table: String,
    #[serde(default)] // default false
    #[serde_as(as = "DisplayFromStr")]
    pub auto_create: bool,
    #[serde(rename = "bigquery.credentials")]
    pub credentials: Option<String>,
}

impl EnforceSecret for BigQueryCommon {
    const ENFORCE_SECRET_PROPERTIES: Set<&'static str> = phf_set! {
        "bigquery.credentials",
    };
}

struct BigQueryFutureManager {
    // `offset_queue` records, for each offset, how many responses we still expect from `resp_stream`.
    // When the `TruncateOffset` is a barrier, the count is 0: nothing was sent, so there is no response to wait for.
    // When the `TruncateOffset` is a chunk:
    // 1. If the chunk has no rows, nothing was sent, the count is 0, and there is no response to wait for.
    // 2. If the serialized chunk fits within `MAX_ROW_SIZE`, it is sent as one request, the count is 1, and we wait for one response from `resp_stream`.
    // 3. If the serialized chunk exceeds `MAX_ROW_SIZE`, it is split into n requests, the count is n, and we wait for n responses from `resp_stream`.
    offset_queue: VecDeque<(TruncateOffset, usize)>,
    resp_stream: Pin<Box<dyn Stream<Item = Result<()>> + Send>>,
}
impl BigQueryFutureManager {
    pub fn new(
        max_future_num: usize,
        resp_stream: impl Stream<Item = Result<()>> + Send + 'static,
    ) -> Self {
        let offset_queue = VecDeque::with_capacity(max_future_num);
        Self {
            offset_queue,
            resp_stream: Box::pin(resp_stream),
        }
    }

    pub fn add_offset(&mut self, offset: TruncateOffset, resp_num: usize) {
        self.offset_queue.push_back((offset, resp_num));
    }

    pub async fn next_offset(&mut self) -> Result<TruncateOffset> {
        if let Some((_offset, remaining_resp_num)) = self.offset_queue.front_mut() {
            if *remaining_resp_num == 0 {
                return Ok(self.offset_queue.pop_front().unwrap().0);
            }
            while *remaining_resp_num > 0 {
                self.resp_stream
                    .next()
                    .await
                    .ok_or_else(|| SinkError::BigQuery(anyhow::anyhow!("end of stream")))??;
                *remaining_resp_num -= 1;
            }
            Ok(self.offset_queue.pop_front().unwrap().0)
        } else {
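            // The queue is empty: park forever. The surrounding `select!` drops this future
            // and polls a fresh `next_offset` call once new offsets are added.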
            pending().await
        }
    }
}
pub struct BigQueryLogSinker {
    writer: BigQuerySinkWriter,
    bigquery_future_manager: BigQueryFutureManager,
    future_num: usize,
}
impl BigQueryLogSinker {
    pub fn new(
        writer: BigQuerySinkWriter,
        resp_stream: impl Stream<Item = Result<()>> + Send + 'static,
        future_num: usize,
    ) -> Self {
        Self {
            writer,
            bigquery_future_manager: BigQueryFutureManager::new(future_num, resp_stream),
            future_num,
        }
    }
}

#[async_trait]
impl LogSinker for BigQueryLogSinker {
    async fn consume_log_and_sink(mut self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        loop {
            tokio::select!(
                offset = self.bigquery_future_manager.next_offset() => {
                    log_reader.truncate(offset?)?;
                }
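                // Only fetch the next log item while the number of in-flight offsets stays
                // within the future budget, which backpressures the log reader.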
                item_result = log_reader.next_item(), if self.bigquery_future_manager.offset_queue.len() <= self.future_num => {
                    let (epoch, item) = item_result?;
                    match item {
                        LogStoreReadItem::StreamChunk { chunk_id, chunk } => {
                            let resp_num = self.writer.write_chunk(chunk)?;
                            self.bigquery_future_manager
                                .add_offset(TruncateOffset::Chunk { epoch, chunk_id }, resp_num);
                        }
                        LogStoreReadItem::Barrier { .. } => {
                            self.bigquery_future_manager
                                .add_offset(TruncateOffset::Barrier { epoch }, 0);
                        }
                    }
                }
            )
        }
    }
}

impl BigQueryCommon {
    async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result<Client> {
        let auth_json = self.get_auth_json_from_path(aws_auth_props).await?;

        let service_account =
            if let Ok(auth_json_from_base64) = BASE64_STANDARD.decode(auth_json.clone()) {
                serde_json::from_slice::<ServiceAccountKey>(&auth_json_from_base64)
            } else {
                serde_json::from_str::<ServiceAccountKey>(&auth_json)
            }
            .map_err(|e| SinkError::BigQuery(e.into()))?;

        let client: Client = Client::from_service_account_key(service_account, false)
            .await
            .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?;
        Ok(client)
    }

    async fn build_writer_client(
        &self,
        aws_auth_props: &AwsAuthProps,
    ) -> Result<(StorageWriterClient, impl Stream<Item = Result<()>> + use<>)> {
        let auth_json = self.get_auth_json_from_path(aws_auth_props).await?;

        let credentials_file =
            if let Ok(auth_json_from_base64) = BASE64_STANDARD.decode(auth_json.clone()) {
                serde_json::from_slice::<CredentialsFile>(&auth_json_from_base64)
            } else {
                serde_json::from_str::<CredentialsFile>(&auth_json)
            }
            .map_err(|e| SinkError::BigQuery(e.into()))?;

        StorageWriterClient::new(credentials_file).await
    }

    async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result<String> {
        if let Some(credentials) = &self.credentials {
            Ok(credentials.clone())
        } else if let Some(local_path) = &self.local_path {
            std::fs::read_to_string(local_path)
                .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))
        } else if let Some(s3_path) = &self.s3_path {
            let url =
                Url::parse(s3_path).map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?;
            let auth_vec = load_file_descriptor_from_s3(&url, aws_auth_props)
                .await
                .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?;
            Ok(String::from_utf8(auth_vec).map_err(|e| SinkError::BigQuery(e.into()))?)
        } else {
            Err(SinkError::BigQuery(anyhow::anyhow!(
                "At least one of `bigquery.credentials`, `bigquery.local.path` or `bigquery.s3.path` must be set."
            )))
        }
    }
}

#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct BigQueryConfig {
    #[serde(flatten)]
    pub common: BigQueryCommon,
    #[serde(flatten)]
    pub aws_auth_props: AwsAuthProps,
    pub r#type: String, // accept "append-only" or "upsert"
}

impl EnforceSecret for BigQueryConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        BigQueryCommon::enforce_one(prop)?;
        AwsAuthProps::enforce_one(prop)?;
        Ok(())
    }
}

impl BigQueryConfig {
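    /// Builds a `BigQueryConfig` from the flattened `WITH` options of a sink definition.
    ///
    /// A minimal sketch of the expected shape; the values below are illustrative only and
    /// not a working setup (credentials and AWS options are omitted):
    ///
    /// ```ignore
    /// let props = BTreeMap::from([
    ///     ("bigquery.project".to_owned(), "my-project".to_owned()),
    ///     ("bigquery.dataset".to_owned(), "my_dataset".to_owned()),
    ///     ("bigquery.table".to_owned(), "my_table".to_owned()),
    ///     ("type".to_owned(), "append-only".to_owned()),
    /// ]);
    /// let config = BigQueryConfig::from_btreemap(props)?;
    /// ```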
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> Result<Self> {
        let config =
            serde_json::from_value::<BigQueryConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

#[derive(Debug)]
pub struct BigQuerySink {
    pub config: BigQueryConfig,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for BigQuerySink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            BigQueryConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl BigQuerySink {
    pub fn new(
        config: BigQueryConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        Ok(Self {
            config,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl BigQuerySink {
    fn check_column_name_and_type(
        &self,
        big_query_columns_desc: HashMap<String, String>,
    ) -> Result<()> {
        let rw_fields_name = self.schema.fields();
        if big_query_columns_desc.is_empty() {
            return Err(SinkError::BigQuery(anyhow::anyhow!(
                "Cannot find the table in BigQuery"
            )));
        }
        if rw_fields_name.len().ne(&big_query_columns_desc.len()) {
            return Err(SinkError::BigQuery(anyhow::anyhow!(
                "The number of columns in RisingWave ({}) must equal the number of columns in BigQuery ({})",
                rw_fields_name.len(),
                big_query_columns_desc.len()
            )));
        }

        for i in rw_fields_name {
            let value = big_query_columns_desc.get(&i.name).ok_or_else(|| {
                SinkError::BigQuery(anyhow::anyhow!(
                    "Column `{:?}` on RisingWave side is not found on BigQuery side.",
                    i.name
                ))
            })?;
            let data_type_string = Self::get_string_and_check_support_from_datatype(&i.data_type)?;
            if data_type_string.ne(value) {
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "Data type mismatch for column `{:?}`. BigQuery side: `{:?}`, RisingWave side: `{:?}`. ",
                    i.name,
                    value,
                    data_type_string
                )));
            };
        }
        Ok(())
    }

    fn get_string_and_check_support_from_datatype(rw_data_type: &DataType) -> Result<String> {
        match rw_data_type {
            DataType::Boolean => Ok("BOOL".to_owned()),
            DataType::Int16 => Ok("INT64".to_owned()),
            DataType::Int32 => Ok("INT64".to_owned()),
            DataType::Int64 => Ok("INT64".to_owned()),
            DataType::Float32 => Err(SinkError::BigQuery(anyhow::anyhow!(
                "REAL is not supported for BigQuery sink. Please convert to FLOAT64 or other supported types."
            ))),
            DataType::Float64 => Ok("FLOAT64".to_owned()),
            DataType::Decimal => Ok("NUMERIC".to_owned()),
            DataType::Date => Ok("DATE".to_owned()),
            DataType::Varchar => Ok("STRING".to_owned()),
            DataType::Time => Ok("TIME".to_owned()),
            DataType::Timestamp => Ok("DATETIME".to_owned()),
            DataType::Timestamptz => Ok("TIMESTAMP".to_owned()),
            DataType::Interval => Ok("INTERVAL".to_owned()),
            DataType::Struct(structs) => {
                let mut elements_vec = vec![];
                for (name, datatype) in structs.iter() {
                    let element_string =
                        Self::get_string_and_check_support_from_datatype(datatype)?;
                    elements_vec.push(format!("{} {}", name, element_string));
                }
                Ok(format!("STRUCT<{}>", elements_vec.join(", ")))
            }
            DataType::List(l) => {
                let element_string = Self::get_string_and_check_support_from_datatype(l.elem())?;
                Ok(format!("ARRAY<{}>", element_string))
            }
            DataType::Bytea => Ok("BYTES".to_owned()),
            DataType::Jsonb => Ok("JSON".to_owned()),
            DataType::Serial => Ok("INT64".to_owned()),
            DataType::Int256 => Err(SinkError::BigQuery(anyhow::anyhow!(
                "INT256 is not supported for BigQuery sink."
            ))),
            DataType::Map(_) => Err(SinkError::BigQuery(anyhow::anyhow!(
                "MAP is not supported for BigQuery sink."
            ))),
            DataType::Vector(_) => Err(SinkError::BigQuery(anyhow::anyhow!(
                "VECTOR is not supported for BigQuery sink."
            ))),
        }
    }

    fn map_field(rw_field: &Field) -> Result<TableFieldSchema> {
        let tfs = match &rw_field.data_type {
            DataType::Boolean => TableFieldSchema::bool(&rw_field.name),
            DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Serial => {
                TableFieldSchema::integer(&rw_field.name)
            }
            DataType::Float32 => {
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "REAL is not supported for BigQuery sink. Please convert to FLOAT64 or other supported types."
                )));
            }
            DataType::Float64 => TableFieldSchema::float(&rw_field.name),
            DataType::Decimal => TableFieldSchema::numeric(&rw_field.name),
            DataType::Date => TableFieldSchema::date(&rw_field.name),
            DataType::Varchar => TableFieldSchema::string(&rw_field.name),
            DataType::Time => TableFieldSchema::time(&rw_field.name),
            DataType::Timestamp => TableFieldSchema::date_time(&rw_field.name),
            DataType::Timestamptz => TableFieldSchema::timestamp(&rw_field.name),
            DataType::Interval => {
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "INTERVAL is not supported for BigQuery sink. Please convert to VARCHAR or other supported types."
                )));
            }
            DataType::Struct(st) => {
                let mut sub_fields = Vec::with_capacity(st.len());
                for (name, dt) in st.iter() {
                    let rw_field = Field::with_name(dt.clone(), name);
                    let field = Self::map_field(&rw_field)?;
                    sub_fields.push(field);
                }
                TableFieldSchema::record(&rw_field.name, sub_fields)
            }
            DataType::List(lt) => {
                let inner_field =
                    Self::map_field(&Field::with_name(lt.elem().clone(), &rw_field.name))?;
                TableFieldSchema {
                    mode: Some("REPEATED".to_owned()),
                    ..inner_field
                }
            }

            DataType::Bytea => TableFieldSchema::bytes(&rw_field.name),
            DataType::Jsonb => TableFieldSchema::json(&rw_field.name),
            DataType::Int256 => {
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "INT256 is not supported for BigQuery sink."
                )));
            }
            DataType::Map(_) => {
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "MAP is not supported for BigQuery sink."
                )));
            }
            DataType::Vector(_) => {
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "VECTOR is not supported for BigQuery sink."
                )));
            }
        };
        Ok(tfs)
    }

    async fn create_table(
        &self,
        client: &Client,
        project_id: &str,
        dataset_id: &str,
        table_id: &str,
        fields: &Vec<Field>,
    ) -> Result<Table> {
        let dataset = client
            .dataset()
            .get(project_id, dataset_id)
            .await
            .map_err(|e| SinkError::BigQuery(e.into()))?;
        let fields: Vec<_> = fields.iter().map(Self::map_field).collect::<Result<_>>()?;
        let table = Table::from_dataset(&dataset, table_id, TableSchema::new(fields));

        client
            .table()
            .create(table)
            .await
            .map_err(|e| SinkError::BigQuery(e.into()))
    }
}

impl Sink for BigQuerySink {
    type LogSinker = BigQueryLogSinker;

    const SINK_NAME: &'static str = BIGQUERY_SINK;

    async fn new_log_sinker(&self, _writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        let (writer, resp_stream) = BigQuerySinkWriter::new(
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?;
        Ok(BigQueryLogSinker::new(
            writer,
            resp_stream,
            BIGQUERY_SEND_FUTURE_BUFFER_MAX_SIZE,
        ))
    }

    async fn validate(&self) -> Result<()> {
        risingwave_common::license::Feature::BigQuerySink
            .check_available()
            .map_err(|e| anyhow::anyhow!(e))?;
        if !self.is_append_only && self.pk_indices.is_empty() {
            return Err(SinkError::Config(anyhow!(
                "Primary key not defined for upsert bigquery sink (please define in `primary_key` field)"
            )));
        }
        let client = self
            .config
            .common
            .build_client(&self.config.aws_auth_props)
            .await?;
        let BigQueryCommon {
            project: project_id,
            dataset: dataset_id,
            table: table_id,
            ..
        } = &self.config.common;

        if self.config.common.auto_create {
            match client
                .table()
                .get(project_id, dataset_id, table_id, None)
                .await
            {
                Err(BQError::RequestError(_)) => {
                    // early return: no need to query schema to check column and type
                    return self
                        .create_table(
                            &client,
                            project_id,
                            dataset_id,
                            table_id,
                            &self.schema.fields,
                        )
                        .await
                        .map(|_| ());
                }
                Err(e) => return Err(SinkError::BigQuery(e.into())),
                _ => {}
            }
        }

        let mut rs = client
            .job()
            .query(
                &self.config.common.project,
                QueryRequest::new(format!(
                    "SELECT column_name, data_type FROM `{}.{}.INFORMATION_SCHEMA.COLUMNS` WHERE table_name = '{}'",
                    project_id, dataset_id, table_id,
                )),
            ).await.map_err(|e| SinkError::BigQuery(e.into()))?;

        let mut big_query_schema = HashMap::default();
        while rs.next_row() {
            big_query_schema.insert(
                rs.get_string_by_name("column_name")
                    .map_err(|e| SinkError::BigQuery(e.into()))?
                    .ok_or_else(|| {
                        SinkError::BigQuery(anyhow::anyhow!("Cannot find column_name"))
                    })?,
                rs.get_string_by_name("data_type")
                    .map_err(|e| SinkError::BigQuery(e.into()))?
                    .ok_or_else(|| {
                        SinkError::BigQuery(anyhow::anyhow!("Cannot find data_type"))
                    })?,
            );
        }

        self.check_column_name_and_type(big_query_schema)?;
        Ok(())
    }
}

pub struct BigQuerySinkWriter {
    pub config: BigQueryConfig,
    #[expect(dead_code)]
    schema: Schema,
    #[expect(dead_code)]
    pk_indices: Vec<usize>,
    client: StorageWriterClient,
    is_append_only: bool,
    row_encoder: ProtoEncoder,
    writer_pb_schema: ProtoSchema,
    #[expect(dead_code)]
    message_descriptor: MessageDescriptor,
    write_stream: String,
    proto_field: Option<FieldDescriptor>,
}

impl TryFrom<SinkParam> for BigQuerySink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        let schema = param.schema();
        let pk_indices = param.downstream_pk_or_empty();
        let config = BigQueryConfig::from_btreemap(param.properties)?;
        BigQuerySink::new(config, schema, pk_indices, param.sink_type.is_append_only())
    }
}

impl BigQuerySinkWriter {
    pub async fn new(
        config: BigQueryConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<(Self, impl Stream<Item = Result<()>>)> {
        let (client, resp_stream) = config
            .common
            .build_writer_client(&config.aws_auth_props)
            .await?;
        let mut descriptor_proto = build_protobuf_schema(
            schema
                .fields()
                .iter()
                .map(|f| (f.name.as_str(), &f.data_type)),
            config.common.table.clone(),
        )?;

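        // In upsert mode, append BigQuery's `_CHANGE_TYPE` pseudo-column to the schema so
        // each encoded row can carry the UPSERT/DELETE marker expected by the Storage Write API.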
        if !is_append_only {
            let field = FieldDescriptorProto {
                name: Some(CHANGE_TYPE.to_owned()),
                number: Some((schema.len() + 1) as i32),
                r#type: Some(field_descriptor_proto::Type::String.into()),
                ..Default::default()
            };
            descriptor_proto.field.push(field);
        }

        let descriptor_pool = build_protobuf_descriptor_pool(&descriptor_proto)?;
        let message_descriptor = descriptor_pool
            .get_message_by_name(&config.common.table)
            .ok_or_else(|| {
                SinkError::BigQuery(anyhow::anyhow!(
                    "Can't find message proto {}",
                    &config.common.table
                ))
            })?;
        let proto_field = if !is_append_only {
            let proto_field = message_descriptor
                .get_field_by_name(CHANGE_TYPE)
                .ok_or_else(|| {
                    SinkError::BigQuery(anyhow::anyhow!("Can't find {}", CHANGE_TYPE))
                })?;
            Some(proto_field)
        } else {
            None
        };
        let row_encoder = ProtoEncoder::new(
            schema.clone(),
            None,
            message_descriptor.clone(),
            ProtoHeader::None,
        )?;
        Ok((
            Self {
                write_stream: format!(
                    "projects/{}/datasets/{}/tables/{}/streams/_default",
                    config.common.project, config.common.dataset, config.common.table
                ),
                config,
                schema,
                pk_indices,
                client,
                is_append_only,
                row_encoder,
                message_descriptor,
                proto_field,
                writer_pb_schema: ProtoSchema {
                    proto_descriptor: Some(to_gcloud_descriptor(&descriptor_proto)?),
                },
            },
            resp_stream,
        ))
    }

    fn append_only(&mut self, chunk: StreamChunk) -> Result<Vec<Vec<u8>>> {
        let mut serialized_rows: Vec<Vec<u8>> = Vec::with_capacity(chunk.capacity());
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                continue;
            }
            serialized_rows.push(self.row_encoder.encode(row)?.ser_to()?)
        }
        Ok(serialized_rows)
    }

    fn upsert(&mut self, chunk: StreamChunk) -> Result<Vec<Vec<u8>>> {
        let mut serialized_rows: Vec<Vec<u8>> = Vec::with_capacity(chunk.capacity());
        for (op, row) in chunk.rows() {
            if op == Op::UpdateDelete {
                continue;
            }
            let mut pb_row = self.row_encoder.encode(row)?;
            match op {
                Op::Insert => pb_row
                    .message
                    .try_set_field(
                        self.proto_field.as_ref().unwrap(),
                        prost_reflect::Value::String("UPSERT".to_owned()),
                    )
                    .map_err(|e| SinkError::BigQuery(e.into()))?,
                Op::Delete => pb_row
                    .message
                    .try_set_field(
                        self.proto_field.as_ref().unwrap(),
                        prost_reflect::Value::String("DELETE".to_owned()),
                    )
                    .map_err(|e| SinkError::BigQuery(e.into()))?,
                Op::UpdateDelete => continue,
                Op::UpdateInsert => pb_row
                    .message
                    .try_set_field(
                        self.proto_field.as_ref().unwrap(),
                        prost_reflect::Value::String("UPSERT".to_owned()),
                    )
                    .map_err(|e| SinkError::BigQuery(e.into()))?,
            };

            serialized_rows.push(pb_row.ser_to()?)
        }
        Ok(serialized_rows)
    }

    fn write_chunk(&mut self, chunk: StreamChunk) -> Result<usize> {
        let serialized_rows = if self.is_append_only {
            self.append_only(chunk)?
        } else {
            self.upsert(chunk)?
        };
        if serialized_rows.is_empty() {
            return Ok(0);
        }
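        // Greedily split the serialized rows into batches that stay within `MAX_ROW_SIZE`.
        // Each batch becomes one AppendRows request, and the returned count is the number
        // of responses the future manager must wait for before truncating this chunk.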
        let mut result = Vec::new();
        let mut result_inner = Vec::new();
        let mut size_count = 0;
        for i in serialized_rows {
            size_count += i.len();
            if size_count > MAX_ROW_SIZE {
                result.push(result_inner);
                result_inner = Vec::new();
                size_count = i.len();
            }
            result_inner.push(i);
        }
        if !result_inner.is_empty() {
            result.push(result_inner);
        }
        let len = result.len();
        for serialized_rows in result {
            let rows = AppendRowsRequestRows::ProtoRows(ProtoData {
                writer_schema: Some(self.writer_pb_schema.clone()),
                rows: Some(ProtoRows { serialized_rows }),
            });
            self.client.append_rows(rows, self.write_stream.clone())?;
        }
        Ok(len)
    }
}

#[try_stream(ok = (), error = SinkError)]
pub async fn resp_to_stream(
    resp_stream: impl Future<
        Output = std::result::Result<
            Response<google_cloud_gax::grpc::Streaming<AppendRowsResponse>>,
            Status,
        >,
    >
    + 'static
    + Send,
) {
    let mut resp_stream = resp_stream
        .await
        .map_err(|e| SinkError::BigQuery(e.into()))?
        .into_inner();
    loop {
        match resp_stream
            .message()
            .await
            .map_err(|e| SinkError::BigQuery(e.into()))?
        {
            Some(append_rows_response) => {
                if !append_rows_response.row_errors.is_empty() {
                    return Err(SinkError::BigQuery(anyhow::anyhow!(
                        "bigquery insert error {:?}",
                        append_rows_response.row_errors
                    )));
                }
                if let Some(google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_response::Response::Error(status)) = append_rows_response.response {
                    return Err(SinkError::BigQuery(anyhow::anyhow!(
                        "bigquery insert error {:?}",
                        status
                    )));
                }
                yield ();
            }
            None => {
                return Err(SinkError::BigQuery(anyhow::anyhow!(
                    "bigquery insert error: end of resp stream",
                )));
            }
        }
    }
}

struct StorageWriterClient {
    #[expect(dead_code)]
    environment: Environment,
    request_sender: mpsc::UnboundedSender<AppendRowsRequest>,
}
impl StorageWriterClient {
    pub async fn new(
        credentials: CredentialsFile,
    ) -> Result<(Self, impl Stream<Item = Result<()>>)> {
        let ts_grpc = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials(
            Self::bigquery_grpc_auth_config(),
            Box::new(credentials),
        )
        .await
        .map_err(|e| SinkError::BigQuery(e.into()))?;
        let conn_options = ConnectionOptions {
            connect_timeout: CONNECT_TIMEOUT,
            timeout: CONNECTION_TIMEOUT,
        };
        let environment = Environment::GoogleCloud(Box::new(ts_grpc));
        let conn = ConnectionManager::new(DEFAULT_GRPC_CHANNEL_NUMS, &environment, &conn_options)
            .await
            .map_err(|e| SinkError::BigQuery(e.into()))?;
        let mut client = conn.writer();

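        // Requests are pushed through an unbounded channel; the receiving end is wrapped
        // into a stream and handed to the bidirectional `append_rows` gRPC call, whose
        // responses are surfaced back to the sink via `resp_to_stream`.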
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);

        let resp = async move { client.append_rows(Request::new(stream)).await };
        let resp_stream = resp_to_stream(resp);

        Ok((
            StorageWriterClient {
                environment,
                request_sender: tx,
            },
            resp_stream,
        ))
    }

    pub fn append_rows(&mut self, row: AppendRowsRequestRows, write_stream: String) -> Result<()> {
        let append_req = AppendRowsRequest {
            write_stream,
            offset: None,
            trace_id: Uuid::new_v4().hyphenated().to_string(),
            missing_value_interpretations: HashMap::default(),
            rows: Some(row),
            default_missing_value_interpretation: MissingValueInterpretation::DefaultValue as i32,
        };
        self.request_sender
            .send(append_req)
            .map_err(|e| SinkError::BigQuery(e.into()))?;
        Ok(())
    }

    fn bigquery_grpc_auth_config() -> google_cloud_auth::project::Config<'static> {
        let mut auth_config = google_cloud_auth::project::Config::default();
        auth_config =
            auth_config.with_audience(google_cloud_bigquery::grpc::apiv1::conn_pool::AUDIENCE);
        auth_config =
            auth_config.with_scopes(&google_cloud_bigquery::grpc::apiv1::conn_pool::SCOPES);
        auth_config
    }
}

fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> Result<prost_reflect::DescriptorPool> {
    let file_descriptor = FileDescriptorProto {
        message_type: vec![desc.clone()],
        name: Some("bigquery".to_owned()),
        ..Default::default()
    };

    prost_reflect::DescriptorPool::from_file_descriptor_set(FileDescriptorSet {
        file: vec![file_descriptor],
    })
    .context("failed to build descriptor pool")
    .map_err(SinkError::BigQuery)
}

fn to_gcloud_descriptor(desc: &DescriptorProto) -> Result<DescriptorProto014> {
    let bytes = Message::encode_to_vec(desc);
    Message014::decode(bytes.as_slice())
        .context("failed to convert descriptor proto")
        .map_err(SinkError::BigQuery)
}

fn build_protobuf_schema<'a>(
    fields: impl Iterator<Item = (&'a str, &'a DataType)>,
    name: String,
) -> Result<DescriptorProto> {
    let mut proto = DescriptorProto {
        name: Some(name),
        ..Default::default()
    };
    let mut struct_vec = vec![];
    let field_vec = fields
        .enumerate()
        .map(|(index, (name, data_type))| {
            let (field, des_proto) =
                build_protobuf_field(data_type, (index + 1) as i32, name.to_owned())?;
            if let Some(sv) = des_proto {
                struct_vec.push(sv);
            }
            Ok(field)
        })
        .collect::<Result<Vec<_>>>()?;
    proto.field = field_vec;
    proto.nested_type = struct_vec;
    Ok(proto)
}

fn build_protobuf_field(
    data_type: &DataType,
    index: i32,
    name: String,
) -> Result<(FieldDescriptorProto, Option<DescriptorProto>)> {
    let mut field = FieldDescriptorProto {
        name: Some(name.clone()),
        number: Some(index),
        ..Default::default()
    };
    match data_type {
        DataType::Boolean => field.r#type = Some(field_descriptor_proto::Type::Bool.into()),
        DataType::Int32 => field.r#type = Some(field_descriptor_proto::Type::Int32.into()),
        DataType::Int16 | DataType::Int64 => {
            field.r#type = Some(field_descriptor_proto::Type::Int64.into())
        }
        DataType::Float64 => field.r#type = Some(field_descriptor_proto::Type::Double.into()),
        DataType::Decimal => field.r#type = Some(field_descriptor_proto::Type::String.into()),
        DataType::Date => field.r#type = Some(field_descriptor_proto::Type::Int32.into()),
        DataType::Varchar => field.r#type = Some(field_descriptor_proto::Type::String.into()),
        DataType::Time => field.r#type = Some(field_descriptor_proto::Type::String.into()),
        DataType::Timestamp => field.r#type = Some(field_descriptor_proto::Type::String.into()),
        DataType::Timestamptz => field.r#type = Some(field_descriptor_proto::Type::String.into()),
        DataType::Interval => field.r#type = Some(field_descriptor_proto::Type::String.into()),
        DataType::Struct(s) => {
            field.r#type = Some(field_descriptor_proto::Type::Message.into());
            let name = format!("Struct{}", name);
            let sub_proto = build_protobuf_schema(s.iter(), name.clone())?;
            field.type_name = Some(name);
            return Ok((field, Some(sub_proto)));
        }
        DataType::List(l) => {
            let (mut field, proto) = build_protobuf_field(l.elem(), index, name)?;
            field.label = Some(field_descriptor_proto::Label::Repeated.into());
            return Ok((field, proto));
        }
        DataType::Bytea => field.r#type = Some(field_descriptor_proto::Type::Bytes.into()),
        DataType::Jsonb => field.r#type = Some(field_descriptor_proto::Type::String.into()),
        DataType::Serial => field.r#type = Some(field_descriptor_proto::Type::Int64.into()),
        DataType::Float32 | DataType::Int256 => {
            return Err(SinkError::BigQuery(anyhow::anyhow!(
                "FLOAT32 and INT256 are not supported for BigQuery sink."
            )));
        }
        DataType::Map(_) => {
            return Err(SinkError::BigQuery(anyhow::anyhow!(
                "MAP is not supported for BigQuery sink."
            )));
        }
        DataType::Vector(_) => {
            return Err(SinkError::BigQuery(anyhow::anyhow!(
                "VECTOR is not supported for BigQuery sink."
            )));
        }
    }
    Ok((field, None))
}

#[cfg(test)]
mod test {

    use std::assert_matches::assert_matches;

    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::{DataType, StructType};

    use crate::sink::big_query::{
        BigQuerySink, build_protobuf_descriptor_pool, build_protobuf_schema,
    };

    #[tokio::test]
    async fn test_type_check() {
        let big_query_type_string = "ARRAY<STRUCT<v1 ARRAY<INT64>, v2 STRUCT<v1 INT64, v2 INT64>>>";
        let rw_datatype = DataType::list(DataType::Struct(StructType::new(vec![
            ("v1".to_owned(), DataType::Int64.list()),
            (
                "v2".to_owned(),
                DataType::Struct(StructType::new(vec![
                    ("v1".to_owned(), DataType::Int64),
                    ("v2".to_owned(), DataType::Int64),
                ])),
            ),
        ])));
        assert_eq!(
            BigQuerySink::get_string_and_check_support_from_datatype(&rw_datatype).unwrap(),
            big_query_type_string
        );
    }

    #[tokio::test]
    async fn test_schema_check() {
        let schema = Schema {
            fields: vec![
                Field::with_name(DataType::Int64, "v1"),
                Field::with_name(DataType::Float64, "v2"),
                Field::with_name(
                    DataType::list(DataType::Struct(StructType::new(vec![
                        ("v1".to_owned(), DataType::Int64.list()),
                        (
                            "v3".to_owned(),
                            DataType::Struct(StructType::new(vec![
                                ("v1".to_owned(), DataType::Int64),
                                ("v2".to_owned(), DataType::Int64),
                            ])),
                        ),
                    ]))),
                    "v3",
                ),
            ],
        };
        let fields = schema
            .fields()
            .iter()
            .map(|f| (f.name.as_str(), &f.data_type));
        let desc = build_protobuf_schema(fields, "t1".to_owned()).unwrap();
        let pool = build_protobuf_descriptor_pool(&desc).unwrap();
        let t1_message = pool.get_message_by_name("t1").unwrap();
        assert_matches!(
            t1_message.get_field_by_name("v1").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
        assert_matches!(
            t1_message.get_field_by_name("v2").unwrap().kind(),
            prost_reflect::Kind::Double
        );
        assert_matches!(
            t1_message.get_field_by_name("v3").unwrap().kind(),
            prost_reflect::Kind::Message(_)
        );

        let v3_message = pool.get_message_by_name("t1.Structv3").unwrap();
        assert_matches!(
            v3_message.get_field_by_name("v1").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
        assert!(v3_message.get_field_by_name("v1").unwrap().is_list());

        let v3_v3_message = pool.get_message_by_name("t1.Structv3.Structv3").unwrap();
        assert_matches!(
            v3_v3_message.get_field_by_name("v1").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
        assert_matches!(
            v3_v3_message.get_field_by_name("v2").unwrap().kind(),
            prost_reflect::Kind::Int64
        );
    }
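
    // A minimal sketch of how `BigQueryFutureManager` pairs each truncate offset with the
    // number of AppendRows responses it still expects; the epochs, chunk id, and response
    // counts below are illustrative only.
    #[tokio::test]
    async fn test_future_manager_offset_order() {
        use super::{BigQueryFutureManager, Result, TruncateOffset};

        // Two successful responses, enough for one chunk that was split into two requests.
        let responses: Vec<Result<()>> = vec![Ok(()), Ok(())];
        let mut manager = BigQueryFutureManager::new(16, futures::stream::iter(responses));
        // A barrier needs no response; a chunk split into two requests needs two.
        manager.add_offset(TruncateOffset::Barrier { epoch: 1 }, 0);
        manager.add_offset(TruncateOffset::Chunk { epoch: 2, chunk_id: 0 }, 2);

        // The barrier is released immediately, without touching `resp_stream`.
        assert!(matches!(
            manager.next_offset().await.unwrap(),
            TruncateOffset::Barrier { epoch: 1 }
        ));
        // The chunk offset is released only after both responses have been consumed.
        assert!(matches!(
            manager.next_offset().await.unwrap(),
            TruncateOffset::Chunk { epoch: 2, .. }
        ));
    }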
}