use std::collections::{BTreeMap, HashMap};
use std::ops::Deref;
use std::sync::LazyLock;

use anyhow::anyhow;
use futures::TryFutureExt;
use futures::future::{TryJoinAll, try_join_all};
use futures::prelude::TryFuture;
use itertools::Itertools;
use mongodb::bson::{Array, Bson, Document, bson, doc};
use mongodb::{Client, Namespace};
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::log::LogSuppresser;
use risingwave_common::row::Row;
use risingwave_common::types::ScalarRefImpl;
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use with_options::WithOptions;

use super::encoder::BsonEncoder;
use super::log_store::DeliveryFutureManagerAddFuture;
use super::writer::{
    AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter, AsyncTruncateSinkWriterExt,
};
use crate::connector_common::MongodbCommon;
use crate::deserialize_bool_from_string;
use crate::enforce_secret::EnforceSecret;
use crate::sink::encoder::RowEncoder;
use crate::sink::{
    DummySinkCommitCoordinator, Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT,
    Sink, SinkError, SinkParam, SinkWriterParam,
};

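/// Wraps the raw bulk write round trip behind a named future type. The
/// type-alias-impl-trait gives the future a concrete name, so callers can hold a
/// `TryJoinAll<SendBulkWriteCommandFuture>` in their signatures without boxing.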
mod send_bulk_write_command_future {
    use core::future::Future;

    use anyhow::anyhow;
    use mongodb::Database;
    use mongodb::bson::Document;

    use crate::sink::{Result, SinkError};

    pub(super) type SendBulkWriteCommandFuture = impl Future<Output = Result<()>> + 'static;

    pub(super) fn send_bulk_write_commands(
        db: Database,
        upsert: Option<Document>,
        delete: Option<Document>,
    ) -> SendBulkWriteCommandFuture {
        async move {
            if let Some(upsert) = upsert {
                send_bulk_write_command(db.clone(), upsert).await?;
            }
            if let Some(delete) = delete {
                send_bulk_write_command(db, delete).await?;
            }
            Ok(())
        }
    }

    async fn send_bulk_write_command(db: Database, command: Document) -> Result<()> {
        let result = db.run_command(command).await.map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context(format!(
                "sending bulk write command failed, database: {}",
                db.name()
            )))
        })?;

        if let Ok(ok) = result.get_i32("ok")
            && ok != 1
        {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write command returned a non-ok status: {}",
                ok
            )));
        }

        if let Ok(write_errors) = result.get_array("writeErrors") {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write responded with write errors: {:?}",
                write_errors,
            )));
        }

        if let Ok(write_concern_error) = result.get_array("writeConcernError") {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write responded with a write concern error: {:?}",
                write_concern_error,
            )));
        }

        Ok(())
    }
}

pub const MONGODB_SINK: &str = "mongodb";
const MONGODB_SEND_FUTURE_BUFFER_MAX_SIZE: usize = 4096;

pub const MONGODB_PK_NAME: &str = "_id";

static LOG_SUPPRESSER: LazyLock<LogSuppresser> = LazyLock::new(LogSuppresser::default);

const fn _default_bulk_write_max_entries() -> usize {
    1024
}
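/// Configuration for the MongoDB sink, deserialized from the `WITH` options of
/// `CREATE SINK`. Connection options (URL, collection name) come from the flattened
/// [`MongodbCommon`]; the remaining fields control the sink type and the optional
/// per-row dynamic collection routing.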
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MongodbConfig {
    #[serde(flatten)]
    pub common: MongodbCommon,

    pub r#type: String,

    /// The field whose value designates the target collection for each row.
    /// When absent (or when a row's value cannot be parsed), `collection.name` is used.
    #[serde(rename = "collection.name.field")]
    pub collection_name_field: Option<String>,

    /// Whether to drop the `collection.name.field` column from the sunk documents.
    #[serde(
        default,
        deserialize_with = "deserialize_bool_from_string",
        rename = "collection.name.field.drop"
    )]
    pub drop_collection_name_field: bool,

    /// The maximum number of entries to buffer before performing a bulk write (deprecated).
    #[serde(
        rename = "mongodb.bulk_write.max_entries",
        default = "_default_bulk_write_max_entries"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[deprecated]
    pub bulk_write_max_entries: usize,
}

impl EnforceSecret for MongodbConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        MongodbCommon::enforce_one(prop)
    }
}

impl MongodbConfig {
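    /// Builds a `MongodbConfig` from raw `WITH` options and validates the sink `type`.
    /// A minimal sketch (keys besides `type` come from the flattened [`MongodbCommon`];
    /// the values here are placeholders):
    ///
    /// ```ignore
    /// let props = BTreeMap::from([
    ///     ("mongodb.url".to_owned(), "mongodb://localhost:27017".to_owned()),
    ///     ("collection.name".to_owned(), "demo.orders".to_owned()),
    ///     ("type".to_owned(), "upsert".to_owned()),
    /// ]);
    /// let config = MongodbConfig::from_btreemap(props)?;
    /// assert_eq!(config.r#type, "upsert");
    /// ```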
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> crate::sink::Result<Self> {
        let config =
            serde_json::from_value::<MongodbConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}

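/// Owns a `Client` together with a oneshot sender. Dropping the guard drops the
/// sender, which wakes the task spawned in `new` so it can call `Client::shutdown()`,
/// ensuring the client's background resources are released when the sink goes away.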
struct ClientGuard {
    _tx: tokio::sync::oneshot::Sender<()>,
    client: Client,
}

impl ClientGuard {
    fn new(name: String, client: Client) -> Self {
        let client_copy = client.clone();
        let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
        tokio::spawn(async move {
            tracing::debug!(%name, "waiting for client to shut down");
            let _ = rx.await;
            tracing::debug!(%name, "sender dropped, now calling client's shutdown");
            client_copy.shutdown().await;
            tracing::debug!(%name, "client shutdown succeeded");
        });
        Self { _tx, client }
    }
}

impl Deref for ClientGuard {
    type Target = Client;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}

#[derive(Debug)]
pub struct MongodbSink {
    pub config: MongodbConfig,
    param: SinkParam,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for MongodbSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::sink::ConnectorResult<()> {
        for prop in prop_iter {
            MongodbConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl MongodbSink {
    pub fn new(param: SinkParam) -> Result<Self> {
        let config = MongodbConfig::from_btreemap(param.properties.clone())?;
        let pk_indices = param.downstream_pk.clone();
        let is_append_only = param.sink_type.is_append_only();
        let schema = param.schema();
        Ok(Self {
            config,
            param,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for MongodbSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        MongodbSink::new(param)
    }
}

impl Sink for MongodbSink {
    type Coordinator = DummySinkCommitCoordinator;
    type LogSinker = AsyncTruncateLogSinkerOf<MongodbSinkWriter>;

    const SINK_NAME: &'static str = MONGODB_SINK;

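    /// Validates the sink configuration before creation: for upsert sinks the primary
    /// key must be defined and any `_id` column must be exactly the primary key; the
    /// collection name must parse as a `db.collection` namespace; connectivity is
    /// probed with a `hello` command; and the `collection.name.field` options must be
    /// consistent with the schema.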
    async fn validate(&self) -> Result<()> {
        if !self.is_append_only {
            if self.pk_indices.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key not defined for upsert mongodb sink (please define in `primary_key` field)"
                )));
            }

            // An `_id` column is only allowed if it is exactly the primary key.
            if self
                .schema
                .fields
                .iter()
                .enumerate()
                .any(|(i, field)| !self.pk_indices.contains(&i) && field.name == MONGODB_PK_NAME)
            {
                return Err(SinkError::Config(anyhow!(
                    "_id field must be the sink's primary key, but a non-primary-key field is named _id",
                )));
            }

            // With a composite primary key, the constructed key document is written to
            // `_id`, so none of the key fields themselves may be named `_id`.
            if self.pk_indices.len() > 1
                && self
                    .pk_indices
                    .iter()
                    .map(|&idx| self.schema.fields[idx].name.as_str())
                    .any(|field| field == MONGODB_PK_NAME)
            {
                return Err(SinkError::Config(anyhow!(
                    "composite primary key must not contain a field named _id"
                )));
            }
        }

        if let Err(err) = self.config.common.collection_name.parse::<Namespace>() {
            return Err(SinkError::Config(anyhow!(err).context(format!(
                "invalid collection.name {}",
                self.config.common.collection_name
            ))));
        }

        // Probe connectivity with a `hello` command before accepting the sink.
        let client = self.config.common.build_client().await?;
        let client = ClientGuard::new(self.param.sink_name.clone(), client);
        client
            .database("admin")
            .run_command(doc! { "hello": 1 })
            .await
            .map_err(|err| {
                SinkError::Mongodb(anyhow!(err).context("failed to send hello command to mongodb"))
            })?;

        if self.config.drop_collection_name_field && self.config.collection_name_field.is_none() {
            return Err(SinkError::Config(anyhow!(
                "collection.name.field must be specified when collection.name.field.drop is enabled"
            )));
        }

        if let Some(coll_field) = &self.config.collection_name_field {
            let fields = self.schema.fields();

            let coll_field_index = fields
                .iter()
                .position(|field| &field.name == coll_field)
                .ok_or_else(|| {
                    SinkError::Config(anyhow!("collection.name.field {} not found", coll_field))
                })?;

            if fields[coll_field_index].data_type() != risingwave_common::types::DataType::Varchar {
                return Err(SinkError::Config(anyhow!(
                    "the type of collection.name.field {} must be varchar",
                    coll_field
                )));
            }

            if !self.is_append_only && self.pk_indices.contains(&coll_field_index) {
                return Err(SinkError::Config(anyhow!(
                    "collection.name.field {} must not be a primary key field",
                    coll_field
                )));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(MongodbSinkWriter::new(
            format!("{}-{}", writer_param.executor_id, self.param.sink_name),
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(MONGODB_SEND_FUTURE_BUFFER_MAX_SIZE))
    }
}

use send_bulk_write_command_future::*;

pub struct MongodbSinkWriter {
    pub config: MongodbConfig,
    payload_writer: MongodbPayloadWriter,
    is_append_only: bool,
}

impl MongodbSinkWriter {
    pub async fn new(
        name: String,
        config: MongodbConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = config.common.build_client().await?;

        let default_namespace =
            config
                .common
                .collection_name
                .parse()
                .map_err(|err: mongodb::error::Error| {
                    SinkError::Mongodb(anyhow!(err).context("parsing default namespace failed"))
                })?;

        let coll_name_field_index =
            config
                .collection_name_field
                .as_ref()
                .and_then(|coll_name_field| {
                    schema
                        .names_str()
                        .iter()
                        .position(|&name| coll_name_field == name)
                });

        // When requested, drop the collection-name column from the encoded documents.
        let col_indices = if let Some(coll_name_field_index) = coll_name_field_index
            && config.drop_collection_name_field
        {
            (0..schema.fields.len())
                .filter(|idx| *idx != coll_name_field_index)
                .collect_vec()
        } else {
            (0..schema.fields.len()).collect_vec()
        };

        let row_encoder = BsonEncoder::new(schema.clone(), Some(col_indices), pk_indices.clone());

        let payload_writer = MongodbPayloadWriter::new(
            schema,
            pk_indices,
            default_namespace,
            coll_name_field_index,
            ClientGuard::new(name, client),
            row_encoder,
        );

        Ok(Self {
            config,
            payload_writer,
            is_append_only,
        })
    }

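    /// Groups the chunk's insert rows into one `insert` command per target namespace.
    /// Non-insert ops are skipped (with rate-limited warnings) because the sink is
    /// append-only.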
    fn append(&mut self, chunk: StreamChunk) -> Result<TryJoinAll<SendBulkWriteCommandFuture>> {
        let mut insert_builder: HashMap<MongodbNamespace, InsertCommandBuilder> = HashMap::new();
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                if let Ok(suppressed_count) = LOG_SUPPRESSER.check() {
                    tracing::warn!(
                        suppressed_count,
                        ?op,
                        ?row,
                        "non-insert op received in append-only mode"
                    );
                }
                continue;
            }
            self.payload_writer.append(&mut insert_builder, row)?;
        }
        Ok(self.payload_writer.flush_insert(insert_builder))
    }

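    /// Groups the chunk's rows into per-namespace `update`/`delete` commands.
    /// `UpdateDelete` is skipped: the paired `UpdateInsert` that follows carries the
    /// same primary key, so a single upsert covers the whole update.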
    fn upsert(&mut self, chunk: StreamChunk) -> Result<TryJoinAll<SendBulkWriteCommandFuture>> {
        let mut upsert_builder: HashMap<MongodbNamespace, UpsertCommandBuilder> = HashMap::new();
        for (op, row) in chunk.rows() {
            if op == Op::UpdateDelete {
                // Skipped: the following `UpdateInsert` upserts the same document.
                continue;
            }
            self.payload_writer.upsert(&mut upsert_builder, op, row)?;
        }
        Ok(self.payload_writer.flush_upsert(upsert_builder))
    }
}

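/// The delivery future handed to the log sinker: a `try_join_all` over the
/// per-namespace bulk-write futures produced for one chunk.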
pub type MongodbSinkDeliveryFuture = impl TryFuture<Ok = (), Error = SinkError> + Unpin + 'static;

impl AsyncTruncateSinkWriter for MongodbSinkWriter {
    type DeliveryFuture = MongodbSinkDeliveryFuture;

    async fn write_chunk<'a>(
        &'a mut self,
        chunk: StreamChunk,
        mut add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
    ) -> Result<()> {
        let futures = if self.is_append_only {
            self.append(chunk)?
        } else {
            self.upsert(chunk)?
        };
        add_future
            .add_future_may_await(futures.map_ok(|_: Vec<()>| ()))
            .await?;
        Ok(())
    }
}

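/// Accumulates encoded rows for a single collection and renders them into one
/// MongoDB `insert` command, i.e. a document of the shape
/// `{ "insert": <coll>, "ordered": true, "documents": [ ... ] }`.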
struct InsertCommandBuilder {
    coll: String,
    inserts: Array,
}

impl InsertCommandBuilder {
    fn new(coll: String) -> Self {
        Self {
            coll,
            inserts: Array::new(),
        }
    }

    fn append(&mut self, row: Document) {
        self.inserts.push(Bson::Document(row));
    }

    fn build(self) -> Document {
        doc! {
            "insert": self.coll,
            "ordered": true,
            "documents": self.inserts,
        }
    }
}

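/// Accumulates upserts and deletes for a single collection. Deletes are keyed by the
/// BSON-serialized primary key so that a later upsert for the same key cancels a
/// pending delete within the batch. `build` renders up to two commands: an `update`
/// command whose entries use `q`/`u` with `upsert: true`, and a `delete` command
/// whose entries use `q`/`limit: 1`.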
struct UpsertCommandBuilder {
    coll: String,
    updates: Array,
    deletes: HashMap<Vec<u8>, Document>,
}

impl UpsertCommandBuilder {
    fn new(coll: String) -> Self {
        Self {
            coll,
            updates: Array::new(),
            deletes: HashMap::new(),
        }
    }

    fn add_upsert(&mut self, pk: Document, row: Document) -> Result<()> {
        let pk_data = mongodb::bson::to_vec(&pk).map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context("cannot serialize primary key"))
        })?;
        // Deletes are sent after updates (see `send_bulk_write_commands`), so a
        // pending delete for the same key would clobber this upsert; drop it.
        self.deletes.remove(&pk_data);

        self.updates.push(bson!({
            "q": pk,
            "u": bson!({
                "$set": row,
            }),
            "upsert": true,
            "multi": false,
        }));

        Ok(())
    }

    fn add_delete(&mut self, pk: Document) -> Result<()> {
        let pk_data = mongodb::bson::to_vec(&pk).map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context("cannot serialize primary key"))
        })?;
        self.deletes.insert(pk_data, pk);
        Ok(())
    }

    fn build(self) -> (Option<Document>, Option<Document>) {
        let (mut upsert_document, mut delete_document) = (None, None);
        if !self.updates.is_empty() {
            upsert_document = Some(doc! {
                "update": self.coll.clone(),
                "ordered": true,
                "updates": self.updates,
            });
        }
        if !self.deletes.is_empty() {
            let deletes = self
                .deletes
                .into_values()
                .map(|pk| {
                    bson!({
                        "q": pk,
                        "limit": 1,
                    })
                })
                .collect::<Array>();

            delete_document = Some(doc! {
                "delete": self.coll,
                "ordered": true,
                "deletes": deletes,
            });
        }
        (upsert_document, delete_document)
    }
}

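/// A `(database name, collection name)` pair identifying the target namespace.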
type MongodbNamespace = (String, String);

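/// Encodes rows to BSON and routes them to per-namespace command builders, resolving
/// the target collection either from `collection.name.field` on each row or from the
/// default namespace configured via `collection.name`.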
struct MongodbPayloadWriter {
    schema: Schema,
    pk_indices: Vec<usize>,
    default_namespace: Namespace,
    coll_name_field_index: Option<usize>,
    client: ClientGuard,
    row_encoder: BsonEncoder,
}

impl MongodbPayloadWriter {
    fn new(
        schema: Schema,
        pk_indices: Vec<usize>,
        default_namespace: Namespace,
        coll_name_field_index: Option<usize>,
        client: ClientGuard,
        row_encoder: BsonEncoder,
    ) -> Self {
        Self {
            schema,
            pk_indices,
            default_namespace,
            coll_name_field_index,
            client,
            row_encoder,
        }
    }

    fn extract_namespace_from_row_ref(&self, row: RowRef<'_>) -> MongodbNamespace {
        let ns = self.coll_name_field_index.and_then(|coll_name_field_index| {
            match row.datum_at(coll_name_field_index) {
                Some(ScalarRefImpl::Utf8(v)) => match v.parse::<Namespace>() {
                    Ok(ns) => Some(ns),
                    Err(err) => {
                        if let Ok(suppressed_count) = LOG_SUPPRESSER.check() {
                            tracing::warn!(
                                suppressed_count,
                                error = %err.as_report(),
                                collection_name = %v,
                                "parsing collection name failed, falling back to the default collection.name"
                            );
                        }
                        None
                    }
                },
                _ => {
                    if let Ok(suppressed_count) = LOG_SUPPRESSER.check() {
                        tracing::warn!(
                            suppressed_count,
                            "the value of collection.name.field is null, falling back to the default collection.name"
                        );
                    }
                    None
                }
            }
        });
        match ns {
            Some(ns) => (ns.db, ns.coll),
            None => (
                self.default_namespace.db.clone(),
                self.default_namespace.coll.clone(),
            ),
        }
    }

    fn append(
        &mut self,
        insert_builder: &mut HashMap<MongodbNamespace, InsertCommandBuilder>,
        row: RowRef<'_>,
    ) -> Result<()> {
        let document = self.row_encoder.encode(row)?;
        let ns = self.extract_namespace_from_row_ref(row);
        let coll = ns.1.clone();

        insert_builder
            .entry(ns)
            .or_insert_with(|| InsertCommandBuilder::new(coll))
            .append(document);
        Ok(())
    }

    fn upsert(
        &mut self,
        upsert_builder: &mut HashMap<MongodbNamespace, UpsertCommandBuilder>,
        op: Op,
        row: RowRef<'_>,
    ) -> Result<()> {
        let mut document = self.row_encoder.encode(row)?;
        let ns = self.extract_namespace_from_row_ref(row);
        let coll = ns.1.clone();

        let pk = self.row_encoder.construct_pk(row);

        // Unless the primary key is a single column already named `_id`, write the
        // constructed key into the document's `_id` field.
        if self.pk_indices.len() > 1
            || self.schema.fields[self.pk_indices[0]].name != MONGODB_PK_NAME
        {
            document.insert(MONGODB_PK_NAME, pk.clone());
        }

        let pk = doc! { MONGODB_PK_NAME: pk };
        match op {
            Op::Insert | Op::UpdateInsert => upsert_builder
                .entry(ns)
                .or_insert_with(|| UpsertCommandBuilder::new(coll))
                .add_upsert(pk, document)?,
            // `UpdateDelete` is already filtered out by the caller.
            Op::UpdateDelete => (),
            Op::Delete => upsert_builder
                .entry(ns)
                .or_insert_with(|| UpsertCommandBuilder::new(coll))
                .add_delete(pk)?,
        }
        Ok(())
    }

    fn flush_insert(
        &self,
        insert_builder: HashMap<MongodbNamespace, InsertCommandBuilder>,
    ) -> TryJoinAll<SendBulkWriteCommandFuture> {
        // Issue one bulk `insert` command per target namespace and join the results.
        let futures = insert_builder.into_iter().map(|(ns, builder)| {
            let db = self.client.database(&ns.0);
            send_bulk_write_commands(db, Some(builder.build()), None)
        });
        try_join_all(futures)
    }

    fn flush_upsert(
        &self,
        upsert_builder: HashMap<MongodbNamespace, UpsertCommandBuilder>,
    ) -> TryJoinAll<SendBulkWriteCommandFuture> {
        // Issue the `update` command before the `delete` command for each namespace.
        let futures = upsert_builder.into_iter().map(|(ns, builder)| {
            let (upsert, delete) = builder.build();
            let db = self.client.database(&ns.0);
            send_bulk_write_commands(db, upsert, delete)
        });
        try_join_all(futures)
    }
}