risingwave_connector/sink/
remote.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::num::NonZero;
18use std::ops::Deref;
19use std::pin::pin;
20use std::time::Instant;
21
22use anyhow::{Context, anyhow};
23use async_trait::async_trait;
24use await_tree::{InstrumentAwait, span};
25use futures::TryStreamExt;
26use futures::future::select;
27use phf::phf_set;
28use prost::Message;
29use risingwave_common::array::StreamChunk;
30use risingwave_common::bail;
31use risingwave_common::catalog::{ColumnDesc, ColumnId, Field};
32use risingwave_common::global_jvm::Jvm;
33use risingwave_common::session_config::sink_decouple::SinkDecouple;
34use risingwave_common::types::DataType;
35use risingwave_jni_core::jvm_runtime::execute_with_jni_env;
36use risingwave_jni_core::{
37    JniReceiverType, JniSenderType, JniSinkWriterStreamRequest, call_static_method, gen_class_name,
38};
39use risingwave_pb::connector_service::sink_coordinator_stream_request::StartCoordinator;
40use risingwave_pb::connector_service::sink_writer_stream_request::{
41    Request as SinkRequest, StartSink,
42};
43use risingwave_pb::connector_service::{
44    PbSinkParam, SinkCoordinatorStreamRequest, SinkCoordinatorStreamResponse, SinkMetadata,
45    SinkWriterStreamRequest, SinkWriterStreamResponse, TableSchema, ValidateSinkRequest,
46    ValidateSinkResponse, sink_coordinator_stream_request, sink_coordinator_stream_response,
47    sink_writer_stream_response,
48};
49use risingwave_rpc_client::error::RpcError;
50use risingwave_rpc_client::{
51    BidiStreamReceiver, BidiStreamSender, DEFAULT_BUFFER_SIZE, SinkCoordinatorStreamHandle,
52    SinkWriterStreamHandle,
53};
54use rw_futures_util::drop_either_future;
55use sea_orm::DatabaseConnection;
56use thiserror_ext::AsReport;
57use tokio::sync::mpsc;
58use tokio::sync::mpsc::{Receiver, UnboundedSender, unbounded_channel};
59use tokio::task::spawn_blocking;
60use tokio_stream::wrappers::ReceiverStream;
61use tracing::warn;
62
63use super::SinkCommittedEpochSubscriber;
64use super::elasticsearch_opensearch::elasticsearch_converter::{
65    StreamChunkConverter, is_remote_es_sink,
66};
67use super::elasticsearch_opensearch::elasticsearch_opensearch_config::ES_OPTION_DELIMITER;
68use crate::connector_common::IcebergSinkCompactionUpdate;
69use crate::enforce_secret::EnforceSecret;
70use crate::error::ConnectorResult;
71use crate::sink::coordinate::CoordinatedLogSinker;
72use crate::sink::log_store::{LogStoreReadItem, LogStoreResult, TruncateOffset};
73use crate::sink::writer::SinkWriter;
74use crate::sink::{
75    LogSinker, Result, Sink, SinkCommitCoordinator, SinkError, SinkLogReader, SinkParam,
76    SinkWriterMetrics, SinkWriterParam,
77};
78
79macro_rules! def_remote_sink {
80    () => {
81        def_remote_sink! {
82            //todo!, delete java impl
83            // { ElasticSearchJava, ElasticSearchJavaSink, "elasticsearch_v1" }
84            // { OpensearchJava, OpenSearchJavaSink, "opensearch_v1"}
85            { Cassandra, CassandraSink, "cassandra", [ "cassandra.url" ] }
86            { Jdbc, JdbcSink, "jdbc", [ "jdbc.url" ] }
87        }
88    };
89    ({ $variant_name:ident, $sink_type_name:ident, $sink_name:expr, [ $($enforce_secret_prop:expr),* ] }) => {
90        #[derive(Debug)]
91        pub struct $variant_name;
92        impl RemoteSinkTrait for $variant_name {
93            const SINK_NAME: &'static str = $sink_name;
94        }
95        impl EnforceSecret for $variant_name {
96            const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
97                $($enforce_secret_prop),*
98            };
99        }
100        pub type $sink_type_name = RemoteSink<$variant_name>;
101    };
102    ({ $variant_name:ident, $sink_type_name:ident, $sink_name:expr, [ $($enforce_secret_prop:expr),* ], |$desc:ident| $body:expr }) => {
103        #[derive(Debug)]
104        pub struct $variant_name;
105        impl RemoteSinkTrait for $variant_name {
106            const SINK_NAME: &'static str = $sink_name;
107            fn default_sink_decouple($desc: &SinkDesc) -> bool {
108                $body
109            }
110        }
111        impl EnforceSecret for $variant_name {
112            const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = phf_set! {
113                $($enforce_secret_prop),*
114            };
115        }
116        pub type $sink_type_name = RemoteSink<$variant_name>;
117    };
118    ({ $($first:tt)+ } $({$($rest:tt)+})*) => {
119        def_remote_sink! {
120            {$($first)+}
121        }
122        def_remote_sink! {
123            $({$($rest)+})*
124        }
125    };
126    ($($invalid:tt)*) => {
127        compile_error! {concat! {"invalid `", stringify!{$($invalid)*}, "`"}}
128    }
129}
130
131def_remote_sink!();
132
133pub trait RemoteSinkTrait: EnforceSecret + Send + Sync + 'static {
134    const SINK_NAME: &'static str;
135    fn default_sink_decouple() -> bool {
136        true
137    }
138}
139
140#[derive(Debug)]
141pub struct RemoteSink<R: RemoteSinkTrait> {
142    param: SinkParam,
143    _phantom: PhantomData<R>,
144}
145
146impl<R: RemoteSinkTrait> EnforceSecret for RemoteSink<R> {
147    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = R::ENFORCE_SECRET_PROPERTIES;
148}
149
150impl<R: RemoteSinkTrait> TryFrom<SinkParam> for RemoteSink<R> {
151    type Error = SinkError;
152
153    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
154        Ok(Self {
155            param,
156            _phantom: PhantomData,
157        })
158    }
159}
160
161impl<R: RemoteSinkTrait> Sink for RemoteSink<R> {
162    type LogSinker = RemoteLogSinker;
163
164    const SINK_NAME: &'static str = R::SINK_NAME;
165
166    fn is_sink_decouple(user_specified: &SinkDecouple) -> Result<bool> {
167        match user_specified {
168            SinkDecouple::Default => Ok(R::default_sink_decouple()),
169            SinkDecouple::Enable => Ok(true),
170            SinkDecouple::Disable => Ok(false),
171        }
172    }
173
174    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
175        RemoteLogSinker::new(self.param.clone(), writer_param, Self::SINK_NAME).await
176    }
177
178    async fn validate(&self) -> Result<()> {
179        validate_remote_sink(&self.param, Self::SINK_NAME).await?;
180        Ok(())
181    }
182}
183
184async fn validate_remote_sink(param: &SinkParam, sink_name: &str) -> ConnectorResult<()> {
185    // if sink_name == OpenSearchJavaSink::SINK_NAME {
186    //     risingwave_common::license::Feature::OpenSearchSink
187    //         .check_available()
188    //         .map_err(|e| anyhow::anyhow!(e))?;
189    // }
190    if is_remote_es_sink(sink_name)
191        && param.downstream_pk.len() > 1
192        && !param.properties.contains_key(ES_OPTION_DELIMITER)
193    {
194        bail!("Es sink only supports single pk or pk with delimiter option");
195    }
196    // FIXME: support struct and array in stream sink
197    param.columns.iter().try_for_each(|col| {
198        match &col.data_type {
199            DataType::Int16
200                    | DataType::Int32
201                    | DataType::Int64
202                    | DataType::Float32
203                    | DataType::Float64
204                    | DataType::Boolean
205                    | DataType::Decimal
206                    | DataType::Timestamp
207                    | DataType::Timestamptz
208                    | DataType::Varchar
209                    | DataType::Date
210                    | DataType::Time
211                    | DataType::Interval
212                    | DataType::Jsonb
213                    | DataType::Bytea => Ok(()),
214            DataType::List(list) => {
215                if is_remote_es_sink(sink_name) || matches!(list.as_ref(), DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Float32 | DataType::Float64 | DataType::Varchar){
216                    Ok(())
217                } else{
218                    Err(SinkError::Remote(anyhow!(
219                        "Remote sink only supports list<int16, int32, int64, float, double, varchar>, got {:?}: {:?}",
220                        col.name,
221                        col.data_type,
222                    )))
223                }
224            },
225            DataType::Struct(_) => {
226                if is_remote_es_sink(sink_name){
227                    Ok(())
228                }else{
229                    Err(SinkError::Remote(anyhow!(
230                        "Only Es sink supports struct, got {:?}: {:?}",
231                        col.name,
232                        col.data_type,
233                    )))
234                }
235            },
236            DataType::Vector(_) |
237            DataType::Serial | DataType::Int256 | DataType::Map(_) => Err(SinkError::Remote(anyhow!(
238                            "remote sink supports Int16, Int32, Int64, Float32, Float64, Boolean, Decimal, Time, Date, Interval, Jsonb, Timestamp, Timestamptz, Bytea, List and Varchar, (Es sink support Struct) got {:?}: {:?}",
239                            col.name,
240                            col.data_type,
241                        )))}})?;
242
243    let jvm = Jvm::get_or_init()?;
244    let sink_param = param.to_proto();
245
246    spawn_blocking(move || -> anyhow::Result<()> {
247        execute_with_jni_env(jvm, |env| {
248            let validate_sink_request = ValidateSinkRequest {
249                sink_param: Some(sink_param),
250            };
251            let validate_sink_request_bytes =
252                env.byte_array_from_slice(&Message::encode_to_vec(&validate_sink_request))?;
253
254            let validate_sink_response_bytes = call_static_method!(
255                env,
256                {com.risingwave.connector.JniSinkValidationHandler},
257                {byte[] validate(byte[] validateSourceRequestBytes)},
258                &validate_sink_request_bytes
259            )?;
260
261            let validate_sink_response: ValidateSinkResponse = Message::decode(
262                risingwave_jni_core::to_guarded_slice(&validate_sink_response_bytes, env)?.deref(),
263            )?;
264
265            validate_sink_response.error.map_or_else(
266                || Ok(()), // If there is no error message, return Ok here.
267                |err| bail!("sink cannot pass validation: {}", err.error_message),
268            )
269        })
270    })
271    .await
272    .context("JoinHandle returns error")??;
273
274    Ok(())
275}
276
277pub struct RemoteLogSinker {
278    request_sender: BidiStreamSender<JniSinkWriterStreamRequest>,
279    response_stream: BidiStreamReceiver<SinkWriterStreamResponse>,
280    stream_chunk_converter: StreamChunkConverter,
281    sink_writer_metrics: SinkWriterMetrics,
282}
283
284impl RemoteLogSinker {
285    async fn new(
286        sink_param: SinkParam,
287        writer_param: SinkWriterParam,
288        sink_name: &str,
289    ) -> Result<Self> {
290        let sink_proto = sink_param.to_proto();
291        let payload_schema = if is_remote_es_sink(sink_name) {
292            let columns = vec![
293                ColumnDesc::unnamed(ColumnId::from(0), DataType::Varchar).to_protobuf(),
294                ColumnDesc::unnamed(ColumnId::from(1), DataType::Varchar).to_protobuf(),
295                ColumnDesc::unnamed(ColumnId::from(2), DataType::Jsonb).to_protobuf(),
296                ColumnDesc::unnamed(ColumnId::from(2), DataType::Varchar).to_protobuf(),
297            ];
298            Some(TableSchema {
299                columns,
300                pk_indices: vec![],
301            })
302        } else {
303            sink_proto.table_schema.clone()
304        };
305
306        let SinkWriterStreamHandle {
307            request_sender,
308            response_stream,
309        } = EmbeddedConnectorClient::new()?
310            .start_sink_writer_stream(payload_schema, sink_proto)
311            .await?;
312
313        Ok(RemoteLogSinker {
314            request_sender,
315            response_stream,
316            sink_writer_metrics: SinkWriterMetrics::new(&writer_param),
317            stream_chunk_converter: StreamChunkConverter::new(
318                sink_name,
319                sink_param.schema(),
320                &sink_param.downstream_pk,
321                &sink_param.properties,
322                sink_param.sink_type.is_append_only(),
323            )?,
324        })
325    }
326}
327
#[async_trait]
impl LogSinker for RemoteLogSinker {
    /// Drives the remote sink until an error occurs.
    ///
    /// Two futures run concurrently. The first to finish determines the result, and both
    /// can only finish with an error:
    /// - `poll_response_stream` forwards responses from the JVM writer into an unbounded
    ///   channel. It fails when the response stream errors or ends.
    /// - `poll_consume_log_and_sink` reads items from `log_reader` and sends chunks and
    ///   barriers to the JVM writer, remembering each sent offset. It truncates the log store
    ///   once the JVM acknowledges the matching offset.
    async fn consume_log_and_sink(self, mut log_reader: impl SinkLogReader) -> Result<!> {
        log_reader.start_from(None).await?;
        let mut request_tx = self.request_sender;
        let mut response_err_stream_rx = self.response_stream;
        let sink_writer_metrics = self.sink_writer_metrics;

        // Bridges JVM responses into the consuming loop below.
        let (response_tx, mut response_rx) = unbounded_channel();

        let poll_response_stream = async move {
            loop {
                let result = response_err_stream_rx
                    .stream
                    .try_next()
                    .instrument_await("log_sinker_wait_next_response")
                    .await;
                match result {
                    Ok(Some(response)) => {
                        response_tx.send(response).map_err(|err| {
                            SinkError::Remote(anyhow!("unable to send response: {:?}", err.0))
                        })?;
                    }
                    Ok(None) => return Err(SinkError::Remote(anyhow!("end of response stream"))),
                    Err(e) => return Err(SinkError::Remote(anyhow!(e))),
                }
            }
        };

        let poll_consume_log_and_sink = async move {
            /// Truncates `log_reader` up to `persisted_offset`, which the JVM writer has just
            /// acknowledged.
            ///
            /// Queued offsets older than `persisted_offset` are dropped first. NOTE(review):
            /// presumably the writer may acknowledge only some offsets — confirm. The next
            /// queued offset must then equal `persisted_offset`; otherwise this returns an
            /// error.
            ///
            /// If the entry carries a start time (only checkpoint barriers do), the elapsed
            /// time in milliseconds is recorded in `sink_commit_duration`.
            fn truncate_matched_offset(
                queue: &mut VecDeque<(TruncateOffset, Option<Instant>)>,
                persisted_offset: TruncateOffset,
                log_reader: &mut impl SinkLogReader,
                sink_writer_metrics: &SinkWriterMetrics,
            ) -> Result<()> {
                while let Some((sent_offset, _)) = queue.front()
                    && sent_offset < &persisted_offset
                {
                    queue.pop_front();
                }

                let (sent_offset, start_time) = queue.pop_front().ok_or_else(|| {
                    anyhow!("get unsent offset {:?} in response", persisted_offset)
                })?;
                if sent_offset != persisted_offset {
                    bail!(
                        "new response offset {:?} does not match the buffer offset {:?}",
                        persisted_offset,
                        sent_offset
                    );
                }

                if let (TruncateOffset::Barrier { .. }, Some(start_time)) =
                    (persisted_offset, start_time)
                {
                    sink_writer_metrics
                        .sink_commit_duration
                        .observe(start_time.elapsed().as_millis() as f64);
                }

                log_reader.truncate(persisted_offset)?;
                Ok(())
            }

            // Last offset sent to the JVM; used to check that log-store offsets advance.
            let mut prev_offset: Option<TruncateOffset> = None;
            // Offsets sent but not yet acknowledged, oldest first. Each entry carries the send
            // time for checkpoint barriers (used for the commit-duration metric).
            // Push from back and pop from front
            let mut sent_offset_queue: VecDeque<(TruncateOffset, Option<Instant>)> =
                VecDeque::new();

            loop {
                // Wait for whichever comes first: a JVM response or the next log-store item.
                // The future that loses the race is dropped.
                let either_result: futures::future::Either<
                    Option<SinkWriterStreamResponse>,
                    LogStoreResult<(u64, LogStoreReadItem)>,
                > = drop_either_future(
                    select(pin!(response_rx.recv()), pin!(log_reader.next_item())).await,
                );
                match either_result {
                    futures::future::Either::Left(opt) => {
                        let response = opt.context("end of response stream")?;
                        match response {
                            SinkWriterStreamResponse {
                                response:
                                    Some(sink_writer_stream_response::Response::Batch(
                                        sink_writer_stream_response::BatchWrittenResponse {
                                            epoch,
                                            batch_id,
                                        },
                                    )),
                            } => {
                                // Batch ids echo the chunk ids sent below as `batch_id`.
                                truncate_matched_offset(
                                    &mut sent_offset_queue,
                                    TruncateOffset::Chunk {
                                        epoch,
                                        chunk_id: batch_id as _,
                                    },
                                    &mut log_reader,
                                    &sink_writer_metrics,
                                )?;
                            }
                            SinkWriterStreamResponse {
                                response:
                                    Some(sink_writer_stream_response::Response::Commit(
                                        sink_writer_stream_response::CommitResponse {
                                            epoch,
                                            metadata,
                                        },
                                    )),
                            } => {
                                // This sink is not coordinated, so metadata is only logged.
                                if let Some(metadata) = metadata {
                                    warn!("get unexpected non-empty metadata: {:?}", metadata);
                                }
                                truncate_matched_offset(
                                    &mut sent_offset_queue,
                                    TruncateOffset::Barrier { epoch },
                                    &mut log_reader,
                                    &sink_writer_metrics,
                                )?;
                            }
                            response => {
                                return Err(SinkError::Remote(anyhow!(
                                    "get unexpected response: {:?}",
                                    response
                                )));
                            }
                        }
                    }
                    futures::future::Either::Right(result) => {
                        let (epoch, item): (u64, LogStoreReadItem) = result?;

                        match item {
                            LogStoreReadItem::StreamChunk { chunk, chunk_id } => {
                                let offset = TruncateOffset::Chunk { epoch, chunk_id };
                                if let Some(prev_offset) = &prev_offset {
                                    prev_offset.check_next_offset(offset)?;
                                }
                                let cardinality = chunk.cardinality();
                                sink_writer_metrics
                                    .connector_sink_rows_received
                                    .inc_by(cardinality as _);

                                let chunk = self.stream_chunk_converter.convert_chunk(chunk)?;
                                request_tx
                                    .send_request(JniSinkWriterStreamRequest::Chunk {
                                        epoch,
                                        batch_id: chunk_id as u64,
                                        chunk,
                                    })
                                    .instrument_await(span!(
                                        "log_sinker_send_chunk (chunk {chunk_id})"
                                    ))
                                    .await?;
                                prev_offset = Some(offset);
                                sent_offset_queue
                                    .push_back((TruncateOffset::Chunk { epoch, chunk_id }, None));
                            }
                            LogStoreReadItem::Barrier { is_checkpoint, .. } => {
                                let offset = TruncateOffset::Barrier { epoch };
                                if let Some(prev_offset) = &prev_offset {
                                    prev_offset.check_next_offset(offset)?;
                                }
                                // Only checkpoint barriers are timed; the time is reported
                                // when their commit response arrives.
                                let start_time = if is_checkpoint {
                                    let start_time = Instant::now();
                                    request_tx
                                        .barrier(epoch, true)
                                        .instrument_await(span!(
                                            "log_sinker_commit_checkpoint (epoch {epoch})"
                                        ))
                                        .await?;
                                    Some(start_time)
                                } else {
                                    request_tx
                                        .barrier(epoch, false)
                                        .instrument_await(span!(
                                            "log_sinker_send_barrier (epoch {epoch})"
                                        ))
                                        .await?;
                                    None
                                };
                                prev_offset = Some(offset);
                                sent_offset_queue
                                    .push_back((TruncateOffset::Barrier { epoch }, start_time));
                            }
                        }
                    }
                }
            }
        };

        // Whichever future fails first ends the sinker with its error.
        select(pin!(poll_response_stream), pin!(poll_consume_log_and_sink))
            .await
            .factor_first()
            .0
    }
}
523
524#[derive(Debug)]
525pub struct CoordinatedRemoteSink<R: RemoteSinkTrait> {
526    param: SinkParam,
527    _phantom: PhantomData<R>,
528}
529
530impl<R: RemoteSinkTrait> EnforceSecret for CoordinatedRemoteSink<R> {
531    const ENFORCE_SECRET_PROPERTIES: phf::Set<&'static str> = R::ENFORCE_SECRET_PROPERTIES;
532}
533
534impl<R: RemoteSinkTrait> TryFrom<SinkParam> for CoordinatedRemoteSink<R> {
535    type Error = SinkError;
536
537    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
538        Ok(Self {
539            param,
540            _phantom: PhantomData,
541        })
542    }
543}
544
545impl<R: RemoteSinkTrait> Sink for CoordinatedRemoteSink<R> {
546    type Coordinator = RemoteCoordinator;
547    type LogSinker = CoordinatedLogSinker<CoordinatedRemoteSinkWriter>;
548
549    const SINK_NAME: &'static str = R::SINK_NAME;
550
551    async fn validate(&self) -> Result<()> {
552        validate_remote_sink(&self.param, Self::SINK_NAME).await?;
553        Ok(())
554    }
555
556    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
557        let metrics = SinkWriterMetrics::new(&writer_param);
558        CoordinatedLogSinker::new(
559            &writer_param,
560            self.param.clone(),
561            CoordinatedRemoteSinkWriter::new(self.param.clone(), metrics.clone()).await?,
562            NonZero::new(1).unwrap(),
563        )
564        .await
565    }
566
567    fn is_coordinated_sink(&self) -> bool {
568        true
569    }
570
571    async fn new_coordinator(
572        &self,
573        _db: DatabaseConnection,
574        _iceberg_compact_stat_sender: Option<UnboundedSender<IcebergSinkCompactionUpdate>>,
575    ) -> Result<Self::Coordinator> {
576        RemoteCoordinator::new::<R>(self.param.clone()).await
577    }
578}
579
580pub struct CoordinatedRemoteSinkWriter {
581    epoch: Option<u64>,
582    batch_id: u64,
583    stream_handle: SinkWriterStreamHandle<JniSinkWriterStreamRequest>,
584    metrics: SinkWriterMetrics,
585}
586
587impl CoordinatedRemoteSinkWriter {
588    pub async fn new(param: SinkParam, metrics: SinkWriterMetrics) -> Result<Self> {
589        let sink_proto = param.to_proto();
590        let stream_handle = EmbeddedConnectorClient::new()?
591            .start_sink_writer_stream(sink_proto.table_schema.clone(), sink_proto)
592            .await?;
593
594        Ok(Self {
595            epoch: None,
596            batch_id: 0,
597            stream_handle,
598            metrics,
599        })
600    }
601
602    #[cfg(test)]
603    fn for_test(
604        response_receiver: Receiver<ConnectorResult<SinkWriterStreamResponse>>,
605        request_sender: tokio::sync::mpsc::Sender<JniSinkWriterStreamRequest>,
606    ) -> CoordinatedRemoteSinkWriter {
607        use futures::StreamExt;
608
609        let stream_handle = SinkWriterStreamHandle::for_test(
610            request_sender,
611            ReceiverStream::new(response_receiver)
612                .map_err(RpcError::from)
613                .boxed(),
614        );
615
616        CoordinatedRemoteSinkWriter {
617            epoch: None,
618            batch_id: 0,
619            stream_handle,
620            metrics: SinkWriterMetrics::for_test(),
621        }
622    }
623}
624
625#[async_trait]
626impl SinkWriter for CoordinatedRemoteSinkWriter {
627    type CommitMetadata = Option<SinkMetadata>;
628
629    async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> {
630        let cardinality = chunk.cardinality();
631        self.metrics
632            .connector_sink_rows_received
633            .inc_by(cardinality as _);
634
635        let epoch = self.epoch.ok_or_else(|| {
636            SinkError::Remote(anyhow!("epoch has not been initialize, call `begin_epoch`"))
637        })?;
638        let batch_id = self.batch_id;
639        self.stream_handle
640            .request_sender
641            .send_request(JniSinkWriterStreamRequest::Chunk {
642                chunk,
643                epoch,
644                batch_id,
645            })
646            .await?;
647        self.batch_id += 1;
648        Ok(())
649    }
650
651    async fn begin_epoch(&mut self, epoch: u64) -> Result<()> {
652        self.epoch = Some(epoch);
653        Ok(())
654    }
655
656    async fn barrier(&mut self, is_checkpoint: bool) -> Result<Option<SinkMetadata>> {
657        let epoch = self.epoch.ok_or_else(|| {
658            SinkError::Remote(anyhow!("epoch has not been initialize, call `begin_epoch`"))
659        })?;
660        if is_checkpoint {
661            // TODO: add metrics to measure commit time
662            let rsp = self.stream_handle.commit(epoch).await?;
663            rsp.metadata
664                .ok_or_else(|| {
665                    SinkError::Remote(anyhow!(
666                        "get none metadata in commit response for coordinated sink writer"
667                    ))
668                })
669                .map(Some)
670        } else {
671            self.stream_handle.barrier(epoch).await?;
672            Ok(None)
673        }
674    }
675}
676
677pub struct RemoteCoordinator {
678    stream_handle: SinkCoordinatorStreamHandle,
679}
680
681impl RemoteCoordinator {
682    pub async fn new<R: RemoteSinkTrait>(param: SinkParam) -> Result<Self> {
683        let stream_handle = EmbeddedConnectorClient::new()?
684            .start_sink_coordinator_stream(param.clone())
685            .await?;
686
687        tracing::trace!("{:?} RemoteCoordinator started", R::SINK_NAME,);
688
689        Ok(RemoteCoordinator { stream_handle })
690    }
691}
692
693#[async_trait]
694impl SinkCommitCoordinator for RemoteCoordinator {
695    async fn init(&mut self, _subscriber: SinkCommittedEpochSubscriber) -> Result<Option<u64>> {
696        Ok(None)
697    }
698
699    async fn commit(
700        &mut self,
701        epoch: u64,
702        metadata: Vec<SinkMetadata>,
703        add_columns: Option<Vec<Field>>,
704    ) -> Result<()> {
705        if let Some(add_columns) = add_columns {
706            return Err(anyhow!(
707                "remote coordinator not support add columns, but got: {:?}",
708                add_columns
709            )
710            .into());
711        }
712        Ok(self.stream_handle.commit(epoch, metadata).await?)
713    }
714}
715
716struct EmbeddedConnectorClient {
717    jvm: Jvm,
718}
719
720impl EmbeddedConnectorClient {
721    fn new() -> Result<Self> {
722        let jvm = Jvm::get_or_init().context("failed to create EmbeddedConnectorClient")?;
723        Ok(EmbeddedConnectorClient { jvm })
724    }
725
726    async fn start_sink_writer_stream(
727        &self,
728        payload_schema: Option<TableSchema>,
729        sink_proto: PbSinkParam,
730    ) -> Result<SinkWriterStreamHandle<JniSinkWriterStreamRequest>> {
731        let (handle, first_rsp) = SinkWriterStreamHandle::initialize(
732            SinkWriterStreamRequest {
733                request: Some(SinkRequest::Start(StartSink {
734                    sink_param: Some(sink_proto),
735                    payload_schema,
736                })),
737            },
738            |rx| async move {
739                let rx = self.start_jvm_worker_thread(
740                    gen_class_name!(com.risingwave.connector.JniSinkWriterHandler),
741                    "runJniSinkWriterThread",
742                    rx,
743                );
744                Ok(ReceiverStream::new(rx).map_err(RpcError::from))
745            },
746        )
747        .await?;
748
749        match first_rsp {
750            SinkWriterStreamResponse {
751                response: Some(sink_writer_stream_response::Response::Start(_)),
752            } => Ok(handle),
753            msg => Err(SinkError::Internal(anyhow!(
754                "should get start response but get {:?}",
755                msg
756            ))),
757        }
758    }
759
760    async fn start_sink_coordinator_stream(
761        &self,
762        param: SinkParam,
763    ) -> Result<SinkCoordinatorStreamHandle> {
764        let (handle, first_rsp) = SinkCoordinatorStreamHandle::initialize(
765            SinkCoordinatorStreamRequest {
766                request: Some(sink_coordinator_stream_request::Request::Start(
767                    StartCoordinator {
768                        param: Some(param.to_proto()),
769                    },
770                )),
771            },
772            |rx| async move {
773                let rx = self.start_jvm_worker_thread(
774                    gen_class_name!(com.risingwave.connector.JniSinkCoordinatorHandler),
775                    "runJniSinkCoordinatorThread",
776                    rx,
777                );
778                Ok(ReceiverStream::new(rx).map_err(RpcError::from))
779            },
780        )
781        .await?;
782
783        match first_rsp {
784            SinkCoordinatorStreamResponse {
785                response: Some(sink_coordinator_stream_response::Response::Start(_)),
786            } => Ok(handle),
787            msg => Err(SinkError::Internal(anyhow!(
788                "should get start response but get {:?}",
789                msg
790            ))),
791        }
792    }
793
    /// Runs the Java static method `class_name::method_name` on a newly spawned
    /// OS thread and returns the receiving end of its response channel.
    ///
    /// The Java method is invoked with the signature
    /// `void method(long requestRx, long responseTx)`. The two `long`s are raw
    /// pointers to `request_rx` and to a freshly created response sender.
    /// The Java side presumably drives the stream by calling back into Rust
    /// through these pointers. That is not visible here; confirm against the
    /// Java handler.
    ///
    /// Error reporting:
    /// - An error returned by the JNI call itself is only logged. It is not
    ///   sent to the returned receiver.
    /// - An error returned by `execute_with_jni_env` is sent as a single `Err`
    ///   item on the returned receiver.
    ///
    /// The response sender is moved into the spawned thread. Presumably it is
    /// dropped when the thread finishes, which closes the returned receiver.
    fn start_jvm_worker_thread<REQ: Send + 'static, RSP: Send + 'static>(
        &self,
        class_name: &'static str,
        method_name: &'static str,
        mut request_rx: JniReceiverType<REQ>,
    ) -> Receiver<std::result::Result<RSP, anyhow::Error>> {
        let (mut response_tx, response_rx): (JniSenderType<RSP>, _) =
            mpsc::channel(DEFAULT_BUFFER_SIZE);

        // Copy the JVM handle out of `self` so the spawned thread does not borrow `self`.
        let jvm = self.jvm;
        std::thread::spawn(move || {
            let result = execute_with_jni_env(jvm, |env| {
                // `request_rx` and `response_tx` are owned by this thread's closure.
                // The call below returns before either is dropped, so the raw
                // pointers handed to Java stay valid for the whole call.
                let result = call_static_method!(
                    env,
                    class_name,
                    method_name,
                    {{void}, {long requestRx, long responseTx}},
                    &mut request_rx as *mut JniReceiverType<REQ>,
                    &mut response_tx as *mut JniSenderType<RSP>
                );

                match result {
                    Ok(_) => {
                        tracing::debug!("end of jni call {}::{}", class_name, method_name);
                    }
                    Err(e) => {
                        tracing::error!(error = %e.as_report(), "jni call error");
                    }
                };

                // The JNI call's own error has been logged above and is not propagated.
                Ok(())
            });

            // This is a plain OS thread, not an async task, so the blocking send API
            // is used. A failed send is ignored, presumably because the receiver
            // may already have been dropped.
            if let Err(e) = result {
                let _ = response_tx.blocking_send(Err(e));
            }
        });
        response_rx
    }
833}
834
#[cfg(test)]
mod test {
    use std::time::Duration;

    use risingwave_common::array::StreamChunk;
    use risingwave_common::test_prelude::StreamChunkTestExt;
    use risingwave_jni_core::JniSinkWriterStreamRequest;
    use risingwave_pb::connector_service::sink_writer_stream_request::{Barrier, Request};
    use risingwave_pb::connector_service::sink_writer_stream_response::{CommitResponse, Response};
    use risingwave_pb::connector_service::{SinkWriterStreamRequest, SinkWriterStreamResponse};
    use tokio::sync::mpsc;

    use crate::sink::SinkWriter;
    use crate::sink::remote::CoordinatedRemoteSinkWriter;

    /// Calls made before any epoch has begun must fail and must not push any
    /// request onto the mocked request channel.
    #[tokio::test]
    async fn test_epoch_check() {
        let (req_tx, mut req_rx) = mpsc::channel(16);
        // The response sender is discarded right away. The calls below are
        // expected to be rejected without needing any response.
        let (_, resp_rx) = mpsc::channel(16);

        let mut writer = CoordinatedRemoteSinkWriter::for_test(resp_rx, req_tx);
        let chunk = StreamChunk::from_pretty(
            " i T
            + 1 Ripper
        ",
        );

        // A barrier without an epoch must error out, not hang, and must send nothing.
        let commit_outcome = tokio::time::timeout(Duration::from_secs(10), writer.barrier(true))
            .await
            .expect("test failed: should not commit without epoch");
        assert!(
            commit_outcome.is_err(),
            "test failed: no epoch check for commit()"
        );
        assert!(
            req_rx.try_recv().is_err(),
            "test failed: unchecked epoch before request"
        );

        // The same holds for writing a chunk.
        let write_outcome = tokio::time::timeout(Duration::from_secs(1), writer.write_batch(chunk))
            .await
            .expect("test failed: should not write without epoch");
        assert!(
            write_outcome.is_err(),
            "test failed: no epoch check for write_batch()"
        );
        assert!(
            req_rx.try_recv().is_err(),
            "test failed: unchecked epoch before request"
        );
    }

    /// Drives the writer through two epochs. It checks the epoch and batch-id
    /// bookkeeping and the shape of the requests forwarded to the mocked channel.
    #[tokio::test]
    async fn test_remote_sink() {
        let (req_tx, mut req_rx) = mpsc::channel(16);
        let (resp_tx, resp_rx) = mpsc::channel(16);
        let mut writer = CoordinatedRemoteSinkWriter::for_test(resp_rx, req_tx);

        let chunk_a = StreamChunk::from_pretty(
            " i T
            + 1 Alice
            + 2 Bob
            + 3 Clare
        ",
        );
        let chunk_b = StreamChunk::from_pretty(
            " i T
            + 4 David
            + 5 Eve
            + 6 Frank
        ",
        );

        // Epoch 2022: the first chunk goes out as a `Chunk` request with batch id 0.
        writer.begin_epoch(2022).await.unwrap();
        assert_eq!(writer.epoch, Some(2022));

        writer.write_batch(chunk_a.clone()).await.unwrap();
        assert_eq!(writer.epoch, Some(2022));
        assert_eq!(writer.batch_id, 1);
        let first_forwarded = req_rx.recv().await.unwrap();
        if let JniSinkWriterStreamRequest::Chunk {
            epoch,
            batch_id,
            chunk,
        } = first_forwarded
        {
            assert_eq!(epoch, 2022);
            assert_eq!(batch_id, 0);
            assert_eq!(chunk, chunk_a);
        } else {
            panic!("test failed: failed to construct write request");
        }

        // Queue a commit response for epoch 2022, then issue a non-checkpoint
        // barrier. It must be forwarded as a `Barrier` request for the same epoch.
        resp_tx
            .send(Ok(SinkWriterStreamResponse {
                response: Some(Response::Commit(CommitResponse {
                    epoch: 2022,
                    metadata: None,
                })),
            }))
            .await
            .expect("test failed: failed to sync epoch");
        writer.barrier(false).await.unwrap();
        let barrier_forwarded = req_rx.recv().await.unwrap();
        if let JniSinkWriterStreamRequest::PbRequest(SinkWriterStreamRequest {
            request:
                Some(Request::Barrier(Barrier {
                    epoch,
                    is_checkpoint: false,
                })),
        }) = barrier_forwarded
        {
            assert_eq!(epoch, 2022);
        } else {
            panic!("test failed: failed to construct sync request ");
        }

        // Epoch 2023: the batch id keeps counting across epochs.
        writer.begin_epoch(2023).await.unwrap();
        assert_eq!(writer.epoch, Some(2023));

        writer.write_batch(chunk_b.clone()).await.unwrap();
        assert_eq!(writer.epoch, Some(2023));
        assert_eq!(writer.batch_id, 2);
        let second_forwarded = req_rx.recv().await.unwrap();
        if let JniSinkWriterStreamRequest::Chunk {
            epoch,
            batch_id,
            chunk,
        } = second_forwarded
        {
            assert_eq!(epoch, 2023);
            assert_eq!(batch_id, 1);
            assert_eq!(chunk, chunk_b);
        } else {
            panic!("test failed: failed to construct write request");
        }
    }
}