use std::collections::{BTreeMap, HashMap};
use std::ops::Deref;
use std::sync::LazyLock;

use anyhow::anyhow;
use futures::TryFutureExt;
use futures::future::{TryJoinAll, try_join_all};
use futures::prelude::TryFuture;
use itertools::Itertools;
use mongodb::bson::{Array, Bson, Document, bson, doc};
use mongodb::{Client, Namespace};
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::log::LogSuppresser;
use risingwave_common::row::Row;
use risingwave_common::types::ScalarRefImpl;
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use with_options::WithOptions;

use super::encoder::BsonEncoder;
use super::log_store::DeliveryFutureManagerAddFuture;
use super::writer::{
    AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter, AsyncTruncateSinkWriterExt,
};
use crate::connector_common::MongodbCommon;
use crate::deserialize_bool_from_string;
use crate::sink::encoder::RowEncoder;
use crate::sink::{
    DummySinkCommitCoordinator, Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
    Sink, SinkError, SinkParam, SinkWriterParam,
};

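/// Helper module that wraps the async bulk-write routine behind the opaque
/// `SendBulkWriteCommandFuture` alias (TAIT), so the writer can name
/// `TryJoinAll<SendBulkWriteCommandFuture>` in its return types.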
mod send_bulk_write_command_future {
    use core::future::Future;

    use anyhow::anyhow;
    use mongodb::Database;
    use mongodb::bson::Document;

    use crate::sink::{Result, SinkError};

    pub(super) type SendBulkWriteCommandFuture = impl Future<Output = Result<()>> + 'static;

    pub(super) fn send_bulk_write_commands(
        db: Database,
        upsert: Option<Document>,
        delete: Option<Document>,
    ) -> SendBulkWriteCommandFuture {
        async move {
            if let Some(upsert) = upsert {
                send_bulk_write_command(db.clone(), upsert).await?;
            }
            if let Some(delete) = delete {
                send_bulk_write_command(db, delete).await?;
            }
            Ok(())
        }
    }

    async fn send_bulk_write_command(db: Database, command: Document) -> Result<()> {
        let result = db.run_command(command).await.map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context(format!(
                "sending bulk write command failed, database: {}",
                db.name()
            )))
        })?;

        if let Ok(ok) = result.get_i32("ok")
            && ok != 1
        {
            return Err(SinkError::Mongodb(anyhow!("bulk write command failed, ok != 1")));
        }

        if let Ok(write_errors) = result.get_array("writeErrors") {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write responded with write errors: {:?}",
                write_errors,
            )));
        }

        if let Ok(write_concern_error) = result.get_array("writeConcernError") {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write responded with a write concern error: {:?}",
                write_concern_error,
            )));
        }

        Ok(())
    }
}

pub const MONGODB_SINK: &str = "mongodb";
const MONGODB_SEND_FUTURE_BUFFER_MAX_SIZE: usize = 4096;

pub const MONGODB_PK_NAME: &str = "_id";

static LOG_SUPPERSSER: LazyLock<LogSuppresser> = LazyLock::new(LogSuppresser::default);

const fn _default_bulk_write_max_entries() -> usize {
    1024
}
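
/// Configuration of the MongoDB sink, deserialized from the sink's WITH options.
///
/// A minimal sketch of how the options map onto this struct. The `mongodb.url` and
/// `collection.name` keys are assumed to be the ones declared on [`MongodbCommon`]
/// (flattened into `common` below):
///
/// ```ignore
/// use std::collections::BTreeMap;
///
/// let props = BTreeMap::from([
///     ("mongodb.url".to_owned(), "mongodb://localhost:27017/?replicaSet=rs0".to_owned()),
///     ("collection.name".to_owned(), "mydb.mycoll".to_owned()),
///     ("type".to_owned(), "upsert".to_owned()),
/// ]);
/// let config = MongodbConfig::from_btreemap(props)?;
/// assert_eq!(config.r#type, "upsert");
/// ```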
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MongodbConfig {
    #[serde(flatten)]
    pub common: MongodbCommon,

    pub r#type: String,

    #[serde(rename = "collection.name.field")]
    pub collection_name_field: Option<String>,

    #[serde(
        default,
        deserialize_with = "deserialize_bool_from_string",
        rename = "collection.name.field.drop"
    )]
    pub drop_collection_name_field: bool,

    #[serde(
        rename = "mongodb.bulk_write.max_entries",
        default = "_default_bulk_write_max_entries"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[deprecated]
    pub bulk_write_max_entries: usize,
}

impl MongodbConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> crate::sink::Result<Self> {
        let config =
            serde_json::from_value::<MongodbConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

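/// Wrapper around [`Client`] that ties the client's lifetime to this guard: dropping the
/// guard drops `_tx`, which wakes the background task spawned in [`ClientGuard::new`] and
/// lets it call `Client::shutdown()`.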
struct ClientGuard {
    _tx: tokio::sync::oneshot::Sender<()>,
    client: Client,
}

impl ClientGuard {
    fn new(name: String, client: Client) -> Self {
        let client_copy = client.clone();
        let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
        tokio::spawn(async move {
            tracing::debug!(%name, "waiting for client to shut down");
            let _ = rx.await;
            tracing::debug!(%name, "sender dropped, calling client shutdown");
            client_copy.shutdown().await;
            tracing::debug!(%name, "client shutdown succeeded");
        });
        Self { _tx, client }
    }
}

impl Deref for ClientGuard {
    type Target = Client;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}

#[derive(Debug)]
pub struct MongodbSink {
    pub config: MongodbConfig,
    param: SinkParam,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl MongodbSink {
    pub fn new(param: SinkParam) -> Result<Self> {
        let config = MongodbConfig::from_btreemap(param.properties.clone())?;
        let pk_indices = param.downstream_pk.clone();
        let is_append_only = param.sink_type.is_append_only();
        let schema = param.schema();
        Ok(Self {
            config,
            param,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for MongodbSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        MongodbSink::new(param)
    }
}

impl Sink for MongodbSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = AsyncTruncateLogSinkerOf<MongodbSinkWriter>;

    const SINK_NAME: &'static str = MONGODB_SINK;

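    /// Checks the sink configuration before the sink is created: primary-key constraints for
    /// upsert mode, that `collection.name` parses as a `db.collection` namespace, that the
    /// MongoDB server is reachable (via a `hello` command), and that `collection.name.field`,
    /// if set, refers to an existing varchar column and, for upsert sinks, is not part of the
    /// primary key.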
    async fn validate(&self) -> Result<()> {
        if !self.is_append_only {
            if self.pk_indices.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key not defined for upsert mongodb sink (please define in `primary_key` field)"
                )));
            }

            if self
                .schema
                .fields
                .iter()
                .enumerate()
                .any(|(i, field)| !self.pk_indices.contains(&i) && field.name == MONGODB_PK_NAME)
            {
                return Err(SinkError::Config(anyhow!(
263 "_id field must be the sink's primary key, but a non primary key field name is _id",
                )));
            }

            if self.pk_indices.len() > 1
                && self
                    .pk_indices
                    .iter()
                    .map(|&idx| self.schema.fields[idx].name.as_str())
                    .any(|field| field == MONGODB_PK_NAME)
            {
                return Err(SinkError::Config(anyhow!(
283 "primary key fields must not contain a field named _id"
                )));
            }
        }

        if let Err(err) = self.config.common.collection_name.parse::<Namespace>() {
            return Err(SinkError::Config(anyhow!(err).context(format!(
                "invalid collection.name {}",
                self.config.common.collection_name
            ))));
        }

        let client = self.config.common.build_client().await?;
        let client = ClientGuard::new(self.param.sink_name.clone(), client);
        client
            .database("admin")
            .run_command(doc! {"hello":1})
            .await
            .map_err(|err| {
                SinkError::Mongodb(anyhow!(err).context("failed to send hello command to mongodb"))
            })?;

        if self.config.drop_collection_name_field && self.config.collection_name_field.is_none() {
            return Err(SinkError::Config(anyhow!(
                "collection.name.field must be specified when collection.name.field.drop is enabled"
            )));
        }

        if let Some(coll_field) = &self.config.collection_name_field {
            let fields = self.schema.fields();

            let coll_field_index = fields
                .iter()
                .enumerate()
                .find_map(|(index, field)| {
                    if &field.name == coll_field {
                        Some(index)
                    } else {
                        None
                    }
                })
                .ok_or(SinkError::Config(anyhow!(
                    "collection.name.field {} not found",
                    coll_field
                )))?;

            if fields[coll_field_index].data_type() != risingwave_common::types::DataType::Varchar {
                return Err(SinkError::Config(anyhow!(
                    "the type of collection.name.field {} must be varchar",
                    coll_field
                )));
            }

            if !self.is_append_only && self.pk_indices.contains(&coll_field_index) {
                return Err(SinkError::Config(anyhow!(
340 "collection.name.field {} must not be equal to the primary key field",
                    coll_field
                )));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(MongodbSinkWriter::new(
            format!("{}-{}", writer_param.executor_id, self.param.sink_name),
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(MONGODB_SEND_FUTURE_BUFFER_MAX_SIZE))
    }
}

use send_bulk_write_command_future::*;

pub struct MongodbSinkWriter {
    pub config: MongodbConfig,
    payload_writer: MongodbPayloadWriter,
    is_append_only: bool,
}

impl MongodbSinkWriter {
    pub async fn new(
        name: String,
        config: MongodbConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = config.common.build_client().await?;

        let default_namespace =
            config
                .common
                .collection_name
                .parse()
                .map_err(|err: mongodb::error::Error| {
                    SinkError::Mongodb(anyhow!(err).context("parsing default namespace failed"))
                })?;

        let coll_name_field_index =
            config
                .collection_name_field
                .as_ref()
                .and_then(|coll_name_field| {
                    schema
                        .names_str()
                        .iter()
                        .position(|&name| coll_name_field == name)
                });

        let col_indices = if let Some(coll_name_field_index) = coll_name_field_index
            && config.drop_collection_name_field
        {
            (0..schema.fields.len())
                .filter(|idx| *idx != coll_name_field_index)
                .collect_vec()
        } else {
            (0..schema.fields.len()).collect_vec()
        };

        let row_encoder = BsonEncoder::new(schema.clone(), Some(col_indices), pk_indices.clone());

        let payload_writer = MongodbPayloadWriter::new(
            schema,
            pk_indices,
            default_namespace,
            coll_name_field_index,
            ClientGuard::new(name, client),
            row_encoder,
        );

        Ok(Self {
            config,
            payload_writer,
            is_append_only,
        })
    }

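    /// Builds one ordered `insert` command per target namespace from the `Insert` rows in the
    /// chunk. Non-insert ops are logged (rate-limited) and skipped, since the sink is
    /// append-only in this mode.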
    fn append(&mut self, chunk: StreamChunk) -> Result<TryJoinAll<SendBulkWriteCommandFuture>> {
        let mut insert_builder: HashMap<MongodbNamespace, InsertCommandBuilder> = HashMap::new();
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                if let Ok(suppressed_count) = LOG_SUPPERSSER.check() {
                    tracing::warn!(
                        suppressed_count,
                        ?op,
                        ?row,
                        "non-insert op received in append-only mode"
                    );
                }
                continue;
            }
            self.payload_writer.append(&mut insert_builder, row)?;
        }
        Ok(self.payload_writer.flush_insert(insert_builder))
    }

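    /// Builds per-namespace `update` (upsert) and `delete` commands from the chunk.
    /// `UpdateDelete` ops are skipped: the `UpdateInsert` that follows overwrites the same
    /// `_id`, so no separate delete statement is needed.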
    fn upsert(&mut self, chunk: StreamChunk) -> Result<TryJoinAll<SendBulkWriteCommandFuture>> {
        let mut upsert_builder: HashMap<MongodbNamespace, UpsertCommandBuilder> = HashMap::new();
        for (op, row) in chunk.rows() {
            if op == Op::UpdateDelete {
                continue;
            }
            self.payload_writer.upsert(&mut upsert_builder, op, row)?;
        }
        Ok(self.payload_writer.flush_upsert(upsert_builder))
    }
}

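/// Delivery future for one written chunk; it resolves once every bulk write command issued
/// for that chunk has completed.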
pub type MongodbSinkDeliveryFuture = impl TryFuture<Ok = (), Error = SinkError> + Unpin + 'static;

impl AsyncTruncateSinkWriter for MongodbSinkWriter {
    type DeliveryFuture = MongodbSinkDeliveryFuture;

    async fn write_chunk<'a>(
        &'a mut self,
        chunk: StreamChunk,
        mut add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
    ) -> Result<()> {
        let futures = if self.is_append_only {
            self.append(chunk)?
        } else {
            self.upsert(chunk)?
        };
        add_future
            .add_future_may_await(futures.map_ok(|_: Vec<()>| ()))
            .await?;
        Ok(())
    }
}

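/// Accumulates rows destined for a single collection and assembles them into one ordered
/// `insert` command document: `{ "insert": <coll>, "ordered": true, "documents": [...] }`.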
struct InsertCommandBuilder {
    coll: String,
    inserts: Array,
}

impl InsertCommandBuilder {
    fn new(coll: String) -> Self {
        Self {
            coll,
            inserts: Array::new(),
        }
    }

    fn append(&mut self, row: Document) {
        self.inserts.push(Bson::Document(row));
    }

    fn build(self) -> Document {
        doc! {
            "insert": self.coll,
            "ordered": true,
            "documents": self.inserts,
        }
    }
}

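/// Accumulates upserts and deletes destined for a single collection and assembles them into
/// `update` and `delete` command documents. Deletes are keyed by the BSON-serialized primary
/// key, so an upsert arriving later for the same key cancels a pending delete.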
struct UpsertCommandBuilder {
    coll: String,
    updates: Array,
    deletes: HashMap<Vec<u8>, Document>,
}

impl UpsertCommandBuilder {
    fn new(coll: String) -> Self {
        Self {
            coll,
            updates: Array::new(),
            deletes: HashMap::new(),
        }
    }

    fn add_upsert(&mut self, pk: Document, row: Document) -> Result<()> {
        let pk_data = mongodb::bson::to_vec(&pk).map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context("cannot serialize primary key"))
        })?;
        self.deletes.remove(&pk_data);

        self.updates.push(bson!( {
            "q": pk,
            "u": bson!( {
                "$set": row,
            }),
            "upsert": true,
            "multi": false,
        }));

        Ok(())
    }

    fn add_delete(&mut self, pk: Document) -> Result<()> {
        let pk_data = mongodb::bson::to_vec(&pk).map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context("cannot serialize primary key"))
        })?;
        self.deletes.insert(pk_data, pk);
        Ok(())
    }

    fn build(self) -> (Option<Document>, Option<Document>) {
        let (mut upsert_document, mut delete_document) = (None, None);
        if !self.updates.is_empty() {
            upsert_document = Some(doc! {
                "update": self.coll.clone(),
                "ordered": true,
                "updates": self.updates,
            });
        }
        if !self.deletes.is_empty() {
            let deletes = self
                .deletes
                .into_values()
                .map(|pk| {
                    bson!({
                        "q": pk,
                        "limit": 1,
                    })
                })
                .collect::<Array>();

            delete_document = Some(doc! {
                "delete": self.coll,
                "ordered": true,
                "deletes": deletes,
            });
        }
        (upsert_document, delete_document)
    }
}

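/// A `(database, collection)` pair identifying the target namespace of a command.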
type MongodbNamespace = (String, String);

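/// Encodes rows into BSON documents and routes them into per-namespace command builders,
/// resolving the target namespace from `collection.name.field` (when configured) or the
/// default `collection.name`.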
struct MongodbPayloadWriter {
    schema: Schema,
    pk_indices: Vec<usize>,
    default_namespace: Namespace,
    coll_name_field_index: Option<usize>,
    client: ClientGuard,
    row_encoder: BsonEncoder,
}

impl MongodbPayloadWriter {
    fn new(
        schema: Schema,
        pk_indices: Vec<usize>,
        default_namespace: Namespace,
        coll_name_field_index: Option<usize>,
        client: ClientGuard,
        row_encoder: BsonEncoder,
    ) -> Self {
        Self {
            schema,
            pk_indices,
            default_namespace,
            coll_name_field_index,
            client,
            row_encoder,
        }
    }

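    /// Resolves the `(db, coll)` namespace for a row: if `collection.name.field` is configured
    /// and the row's value parses as `db.collection`, that namespace is used; otherwise the
    /// default namespace is returned and a rate-limited warning is logged.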
    fn extract_namespace_from_row_ref(&self, row: RowRef<'_>) -> MongodbNamespace {
        let ns = self.coll_name_field_index.and_then(|coll_name_field_index| {
            match row.datum_at(coll_name_field_index) {
                Some(ScalarRefImpl::Utf8(v)) => match v.parse::<Namespace>() {
                    Ok(ns) => Some(ns),
                    Err(err) => {
                        if let Ok(suppressed_count) = LOG_SUPPERSSER.check() {
                            tracing::warn!(
                                suppressed_count,
                                error = %err.as_report(),
                                collection_name = %v,
626 "parsing collection name failed, fallback to use default collection.name"
                            );
                        }
                        None
                    }
                },
                _ => {
                    if let Ok(suppressed_count) = LOG_SUPPERSSER.check() {
                        tracing::warn!(
                            suppressed_count,
636 "the value of collection.name.field is null, fallback to use default collection.name"
                        );
                    }
                    None
                }
            }
        });
        match ns {
            Some(ns) => (ns.db, ns.coll),
            None => (
                self.default_namespace.db.clone(),
                self.default_namespace.coll.clone(),
            ),
        }
    }

    fn append(
        &mut self,
        insert_builder: &mut HashMap<MongodbNamespace, InsertCommandBuilder>,
        row: RowRef<'_>,
    ) -> Result<()> {
        let document = self.row_encoder.encode(row)?;
        let ns = self.extract_namespace_from_row_ref(row);
        let coll = ns.1.clone();

        insert_builder
            .entry(ns)
            .or_insert_with(|| InsertCommandBuilder::new(coll))
            .append(document);
        Ok(())
    }

    fn upsert(
        &mut self,
        upsert_builder: &mut HashMap<MongodbNamespace, UpsertCommandBuilder>,
        op: Op,
        row: RowRef<'_>,
    ) -> Result<()> {
        let mut document = self.row_encoder.encode(row)?;
        let ns = self.extract_namespace_from_row_ref(row);
        let coll = ns.1.clone();

        let pk = self.row_encoder.construct_pk(row);

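        // When the primary key is compound or is not already named `_id`, store the
        // constructed key under `_id` so MongoDB can address the document by it.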
        if self.pk_indices.len() > 1
            || self.schema.fields[self.pk_indices[0]].name != MONGODB_PK_NAME
        {
            document.insert(MONGODB_PK_NAME, pk.clone());
        }

        let pk = doc! {MONGODB_PK_NAME: pk};
        match op {
            Op::Insert | Op::UpdateInsert => upsert_builder
                .entry(ns)
                .or_insert_with(|| UpsertCommandBuilder::new(coll))
                .add_upsert(pk, document)?,
            Op::UpdateDelete => (),
            Op::Delete => upsert_builder
                .entry(ns)
                .or_insert_with(|| UpsertCommandBuilder::new(coll))
                .add_delete(pk)?,
        }
        Ok(())
    }

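    /// Issues one bulk `insert` command per target namespace and joins the resulting futures.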
    fn flush_insert(
        &self,
        insert_builder: HashMap<MongodbNamespace, InsertCommandBuilder>,
    ) -> TryJoinAll<SendBulkWriteCommandFuture> {
        let futures = insert_builder.into_iter().map(|(ns, builder)| {
            let db = self.client.database(&ns.0);
            send_bulk_write_commands(db, Some(builder.build()), None)
        });
        try_join_all(futures)
    }

    fn flush_upsert(
        &self,
        upsert_builder: HashMap<MongodbNamespace, UpsertCommandBuilder>,
    ) -> TryJoinAll<SendBulkWriteCommandFuture> {
        let futures = upsert_builder.into_iter().map(|(ns, builder)| {
            let (upsert, delete) = builder.build();
            let db = self.client.database(&ns.0);
            send_bulk_write_commands(db, upsert, delete)
        });
        try_join_all(futures)
    }
}