use std::collections::{BTreeMap, HashMap};
use std::ops::Deref;
use std::sync::LazyLock;

use anyhow::anyhow;
use futures::TryFutureExt;
use futures::future::{TryJoinAll, try_join_all};
use futures::prelude::TryFuture;
use itertools::Itertools;
use mongodb::bson::{Array, Bson, Document, bson, doc};
use mongodb::{Client, Namespace};
use risingwave_common::array::{Op, RowRef, StreamChunk};
use risingwave_common::catalog::Schema;
use risingwave_common::log::LogSuppresser;
use risingwave_common::row::Row;
use risingwave_common::types::ScalarRefImpl;
use serde_derive::Deserialize;
use serde_with::{DisplayFromStr, serde_as};
use thiserror_ext::AsReport;
use with_options::WithOptions;

use super::encoder::BsonEncoder;
use super::log_store::DeliveryFutureManagerAddFuture;
use super::writer::{
    AsyncTruncateLogSinkerOf, AsyncTruncateSinkWriter, AsyncTruncateSinkWriterExt,
};
use crate::connector_common::MongodbCommon;
use crate::deserialize_bool_from_string;
use crate::enforce_secret::EnforceSecret;
use crate::sink::encoder::RowEncoder;
use crate::sink::{
    Result, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT, Sink, SinkError, SinkParam,
    SinkWriterParam,
};

mod send_bulk_write_command_future {
    use core::future::Future;

    use anyhow::anyhow;
    use mongodb::Database;
    use mongodb::bson::Document;

    use crate::sink::{Result, SinkError};

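    /// Opaque future type for one bulk-write round trip. Declared as a
    /// type-alias-impl-trait so callers can name the concrete future type
    /// (e.g. inside `TryJoinAll`) without boxing.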
    pub(super) type SendBulkWriteCommandFuture = impl Future<Output = Result<()>> + 'static;

    #[define_opaque(SendBulkWriteCommandFuture)]
    pub(super) fn send_bulk_write_commands(
        db: Database,
        upsert: Option<Document>,
        delete: Option<Document>,
    ) -> SendBulkWriteCommandFuture {
        async move {
            if let Some(upsert) = upsert {
                send_bulk_write_command(db.clone(), upsert).await?;
            }
            if let Some(delete) = delete {
                send_bulk_write_command(db, delete).await?;
            }
            Ok(())
        }
    }

    async fn send_bulk_write_command(db: Database, command: Document) -> Result<()> {
        let result = db.run_command(command).await.map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context(format!(
                "sending bulk write command failed, database: {}",
                db.name()
            )))
        })?;

        if let Ok(ok) = result.get_i32("ok")
            && ok != 1
        {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write command failed with ok: {}",
                ok
            )));
        }

        if let Ok(write_errors) = result.get_array("writeErrors") {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write responded with write errors: {:?}",
                write_errors,
            )));
        }

        if let Ok(write_concern_error) = result.get_array("writeConcernError") {
            return Err(SinkError::Mongodb(anyhow!(
                "bulk write responded with a write concern error: {:?}",
                write_concern_error,
            )));
        }

        Ok(())
    }
}

pub const MONGODB_SINK: &str = "mongodb";
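/// Maximum number of in-flight bulk-write delivery futures buffered by the
/// log sinker before it awaits completions.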
const MONGODB_SEND_FUTURE_BUFFER_MAX_SIZE: usize = 4096;

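/// MongoDB's reserved document key: every document is uniquely identified by
/// its `_id` field.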
pub const MONGODB_PK_NAME: &str = "_id";

static LOG_SUPPRESSER: LazyLock<LogSuppresser> = LazyLock::new(LogSuppresser::default);

const fn _default_bulk_write_max_entries() -> usize {
    1024
}

#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct MongodbConfig {
    #[serde(flatten)]
    pub common: MongodbCommon,

    /// The sink type: either `append-only` or `upsert`.
    pub r#type: String,

    /// Optional name of a varchar column whose per-row value selects the
    /// target collection, parsed as a `db.collection` namespace. Rows with a
    /// null or unparsable value fall back to the default `collection.name`.
    #[serde(rename = "collection.name.field")]
    pub collection_name_field: Option<String>,

    /// Whether to exclude the `collection.name.field` column from the
    /// documents written to MongoDB.
    #[serde(
        default,
        deserialize_with = "deserialize_bool_from_string",
        rename = "collection.name.field.drop"
    )]
    pub drop_collection_name_field: bool,

    #[serde(
        rename = "mongodb.bulk_write.max_entries",
        default = "_default_bulk_write_max_entries"
    )]
    #[serde_as(as = "DisplayFromStr")]
    #[deprecated]
    pub bulk_write_max_entries: usize,
}
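
// Illustrative `WITH` options consumed by this config (the connection options
// such as `mongodb.url` and the default `collection.name` are defined on
// `MongodbCommon`, not here):
//
//     connector = 'mongodb',
//     type = 'upsert',
//     mongodb.url = 'mongodb://localhost:27017/?replicaSet=rs0',
//     collection.name = 'db.collection',
//     collection.name.field = 'collection_name',
//     collection.name.field.drop = 'true',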

impl EnforceSecret for MongodbConfig {
    fn enforce_one(prop: &str) -> crate::error::ConnectorResult<()> {
        MongodbCommon::enforce_one(prop)
    }
}

impl MongodbConfig {
    pub fn from_btreemap(properties: BTreeMap<String, String>) -> crate::sink::Result<Self> {
        let config =
            serde_json::from_value::<MongodbConfig>(serde_json::to_value(properties).unwrap())
                .map_err(|e| SinkError::Config(anyhow!(e)))?;
        if config.r#type != SINK_TYPE_APPEND_ONLY && config.r#type != SINK_TYPE_UPSERT {
            return Err(SinkError::Config(anyhow!(
                "`{}` must be {} or {}",
                SINK_TYPE_OPTION,
                SINK_TYPE_APPEND_ONLY,
                SINK_TYPE_UPSERT
            )));
        }
        Ok(config)
    }
}
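/// Holds a [`Client`] together with a oneshot sender whose only purpose is to
/// be dropped. `Client::shutdown` is async and cannot run in `Drop`, so a
/// background task awaits the receiver and performs the shutdown once the
/// guard (and thus the sender) is dropped.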
struct ClientGuard {
    _tx: tokio::sync::oneshot::Sender<()>,
    client: Client,
}

impl ClientGuard {
    fn new(name: String, client: Client) -> Self {
        let client_copy = client.clone();
        let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
        tokio::spawn(async move {
            tracing::debug!(%name, "waiting for client to shut down");
            let _ = rx.await;
            tracing::debug!(%name, "sender dropped, calling client's shutdown");
            client_copy.shutdown().await;
            tracing::debug!(%name, "client shutdown succeeded");
        });
        Self { _tx, client }
    }
}

impl Deref for ClientGuard {
    type Target = Client;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}

#[derive(Debug)]
pub struct MongodbSink {
    pub config: MongodbConfig,
    param: SinkParam,
    schema: Schema,
    pk_indices: Vec<usize>,
    is_append_only: bool,
}

impl EnforceSecret for MongodbSink {
    fn enforce_secret<'a>(
        prop_iter: impl Iterator<Item = &'a str>,
    ) -> crate::error::ConnectorResult<()> {
        for prop in prop_iter {
            MongodbConfig::enforce_one(prop)?;
        }
        Ok(())
    }
}

impl MongodbSink {
    pub fn new(param: SinkParam) -> Result<Self> {
        let config = MongodbConfig::from_btreemap(param.properties.clone())?;
        let pk_indices = param.downstream_pk.clone();
        let is_append_only = param.sink_type.is_append_only();
        let schema = param.schema();
        Ok(Self {
            config,
            param,
            schema,
            pk_indices,
            is_append_only,
        })
    }
}

impl TryFrom<SinkParam> for MongodbSink {
    type Error = SinkError;

    fn try_from(param: SinkParam) -> std::result::Result<Self, Self::Error> {
        MongodbSink::new(param)
    }
}

impl Sink for MongodbSink {
    type LogSinker = AsyncTruncateLogSinkerOf<MongodbSinkWriter>;

    const SINK_NAME: &'static str = MONGODB_SINK;

    async fn validate(&self) -> Result<()> {
        if !self.is_append_only {
            if self.pk_indices.is_empty() {
                return Err(SinkError::Config(anyhow!(
                    "Primary key not defined for upsert mongodb sink (please define in `primary_key` field)"
                )));
            }

            if self
                .schema
                .fields
                .iter()
                .enumerate()
                .any(|(i, field)| !self.pk_indices.contains(&i) && field.name == MONGODB_PK_NAME)
            {
                return Err(SinkError::Config(anyhow!(
                    "the _id field must be the sink's primary key, but a non-primary-key field is named _id",
                )));
            }

            if self.pk_indices.len() > 1
                && self
                    .pk_indices
                    .iter()
                    .map(|&idx| self.schema.fields[idx].name.as_str())
                    .any(|field| field == MONGODB_PK_NAME)
            {
                return Err(SinkError::Config(anyhow!(
                    "a composite primary key must not contain a field named _id"
                )));
            }
        }

        if let Err(err) = self.config.common.collection_name.parse::<Namespace>() {
            return Err(SinkError::Config(anyhow!(err).context(format!(
                "invalid collection.name {}",
                self.config.common.collection_name
            ))));
        }

        let client = self.config.common.build_client().await?;
        let client = ClientGuard::new(self.param.sink_name.clone(), client);
        client
            .database("admin")
            .run_command(doc! {"hello":1})
            .await
            .map_err(|err| {
                SinkError::Mongodb(anyhow!(err).context("failed to send hello command to mongodb"))
            })?;

        if self.config.drop_collection_name_field && self.config.collection_name_field.is_none() {
            return Err(SinkError::Config(anyhow!(
                "collection.name.field must be specified when collection.name.field.drop is enabled"
            )));
        }

        if let Some(coll_field) = &self.config.collection_name_field {
            let fields = self.schema.fields();

            let coll_field_index = fields
                .iter()
                .enumerate()
                .find_map(|(index, field)| {
                    if &field.name == coll_field {
                        Some(index)
                    } else {
                        None
                    }
                })
                .ok_or(SinkError::Config(anyhow!(
                    "collection.name.field {} not found",
                    coll_field
                )))?;

            if fields[coll_field_index].data_type() != risingwave_common::types::DataType::Varchar
            {
                return Err(SinkError::Config(anyhow!(
                    "the type of collection.name.field {} must be varchar",
                    coll_field
                )));
            }

            if !self.is_append_only && self.pk_indices.contains(&coll_field_index) {
                return Err(SinkError::Config(anyhow!(
                    "collection.name.field {} must not be part of the primary key",
                    coll_field
                )));
            }
        }

        Ok(())
    }

    async fn new_log_sinker(&self, writer_param: SinkWriterParam) -> Result<Self::LogSinker> {
        Ok(MongodbSinkWriter::new(
            format!("{}-{}", writer_param.executor_id, self.param.sink_name),
            self.config.clone(),
            self.schema.clone(),
            self.pk_indices.clone(),
            self.is_append_only,
        )
        .await?
        .into_log_sinker(MONGODB_SEND_FUTURE_BUFFER_MAX_SIZE))
    }
}

use send_bulk_write_command_future::*;

pub struct MongodbSinkWriter {
    pub config: MongodbConfig,
    payload_writer: MongodbPayloadWriter,
    is_append_only: bool,
}

impl MongodbSinkWriter {
    pub async fn new(
        name: String,
        config: MongodbConfig,
        schema: Schema,
        pk_indices: Vec<usize>,
        is_append_only: bool,
    ) -> Result<Self> {
        let client = config.common.build_client().await?;

        let default_namespace =
            config
                .common
                .collection_name
                .parse()
                .map_err(|err: mongodb::error::Error| {
                    SinkError::Mongodb(anyhow!(err).context("parsing default namespace failed"))
                })?;

        let coll_name_field_index =
            config
                .collection_name_field
                .as_ref()
                .and_then(|coll_name_field| {
                    schema
                        .names_str()
                        .iter()
                        .position(|&name| coll_name_field == name)
                });

        // If the collection-name column is being dropped, exclude it from the
        // encoded columns; otherwise encode every column.
        let col_indices = if let Some(coll_name_field_index) = coll_name_field_index
            && config.drop_collection_name_field
        {
            (0..schema.fields.len())
                .filter(|idx| *idx != coll_name_field_index)
                .collect_vec()
        } else {
            (0..schema.fields.len()).collect_vec()
        };

        let row_encoder = BsonEncoder::new(schema.clone(), Some(col_indices), pk_indices.clone());

        let payload_writer = MongodbPayloadWriter::new(
            schema,
            pk_indices,
            default_namespace,
            coll_name_field_index,
            ClientGuard::new(name, client),
            row_encoder,
        );

        Ok(Self {
            config,
            payload_writer,
            is_append_only,
        })
    }

    fn append(&mut self, chunk: StreamChunk) -> Result<TryJoinAll<SendBulkWriteCommandFuture>> {
        let mut insert_builder: HashMap<MongodbNamespace, InsertCommandBuilder> = HashMap::new();
        for (op, row) in chunk.rows() {
            if op != Op::Insert {
                // In append-only mode, any non-insert op is unexpected; log it
                // (rate-limited) and skip the row.
                if let Ok(suppressed_count) = LOG_SUPPRESSER.check() {
                    tracing::warn!(
                        suppressed_count,
                        ?op,
                        ?row,
                        "non-insert op received in append-only mode"
                    );
                }
                continue;
            }
            self.payload_writer.append(&mut insert_builder, row)?;
        }
        Ok(self.payload_writer.flush_insert(insert_builder))
    }

    fn upsert(&mut self, chunk: StreamChunk) -> Result<TryJoinAll<SendBulkWriteCommandFuture>> {
        let mut upsert_builder: HashMap<MongodbNamespace, UpsertCommandBuilder> = HashMap::new();
        for (op, row) in chunk.rows() {
            if op == Op::UpdateDelete {
                // An `UpdateDelete` is followed by an `UpdateInsert` for the
                // same key, whose upsert overwrites the old document, so the
                // delete half can be skipped.
                continue;
            }
            self.payload_writer.upsert(&mut upsert_builder, op, row)?;
        }
        Ok(self.payload_writer.flush_upsert(upsert_builder))
    }
}
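/// The delivery future for one chunk: all per-namespace bulk-write commands,
/// joined and mapped to `()`.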
pub type MongodbSinkDeliveryFuture = impl TryFuture<Ok = (), Error = SinkError> + Unpin + 'static;

impl AsyncTruncateSinkWriter for MongodbSinkWriter {
    type DeliveryFuture = MongodbSinkDeliveryFuture;

    #[define_opaque(MongodbSinkDeliveryFuture)]
    async fn write_chunk<'a>(
        &'a mut self,
        chunk: StreamChunk,
        mut add_future: DeliveryFutureManagerAddFuture<'a, Self::DeliveryFuture>,
    ) -> Result<()> {
        let futures = if self.is_append_only {
            self.append(chunk)?
        } else {
            self.upsert(chunk)?
        };
        add_future
            .add_future_may_await(futures.map_ok(|_: Vec<()>| ()))
            .await?;
        Ok(())
    }
}
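/// Accumulates documents destined for one collection and renders them as a
/// raw MongoDB `insert` command, e.g.
/// `{ insert: <coll>, ordered: true, documents: [...] }`.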
struct InsertCommandBuilder {
    coll: String,
    inserts: Array,
}

impl InsertCommandBuilder {
    fn new(coll: String) -> Self {
        Self {
            coll,
            inserts: Array::new(),
        }
    }

    fn append(&mut self, row: Document) {
        self.inserts.push(Bson::Document(row));
    }

    fn build(self) -> Document {
        doc! {
            "insert": self.coll,
            "ordered": true,
            "documents": self.inserts,
        }
    }
}
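/// Accumulates upserts and deletes destined for one collection and renders
/// them as raw `update` and `delete` commands. Deletes are keyed by the
/// BSON-serialized primary key so that an upsert arriving later in the same
/// chunk can cancel a pending delete on the same key.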
struct UpsertCommandBuilder {
    coll: String,
    updates: Array,
    deletes: HashMap<Vec<u8>, Document>,
}

impl UpsertCommandBuilder {
    fn new(coll: String) -> Self {
        Self {
            coll,
            updates: Array::new(),
            deletes: HashMap::new(),
        }
    }

    fn add_upsert(&mut self, pk: Document, row: Document) -> Result<()> {
        let pk_data = mongodb::bson::to_vec(&pk).map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context("cannot serialize primary key"))
        })?;
        // This upsert supersedes any delete queued for the same key earlier in
        // the chunk.
        self.deletes.remove(&pk_data);

        self.updates.push(bson!({
            "q": pk,
            "u": bson!({
                "$set": row,
            }),
            "upsert": true,
            "multi": false,
        }));

        Ok(())
    }

    fn add_delete(&mut self, pk: Document) -> Result<()> {
        let pk_data = mongodb::bson::to_vec(&pk).map_err(|err| {
            SinkError::Mongodb(anyhow!(err).context("cannot serialize primary key"))
        })?;
        self.deletes.insert(pk_data, pk);
        Ok(())
    }

    fn build(self) -> (Option<Document>, Option<Document>) {
        let (mut upsert_document, mut delete_document) = (None, None);
        if !self.updates.is_empty() {
            upsert_document = Some(doc! {
                "update": self.coll.clone(),
                "ordered": true,
                "updates": self.updates,
            });
        }
        if !self.deletes.is_empty() {
            let deletes = self
                .deletes
                .into_values()
                .map(|pk| {
                    bson!({
                        "q": pk,
                        "limit": 1,
                    })
                })
                .collect::<Array>();

            delete_document = Some(doc! {
                "delete": self.coll,
                "ordered": true,
                "deletes": deletes,
            });
        }
        (upsert_document, delete_document)
    }
}
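/// A `(database, collection)` pair identifying the destination of a command.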
type MongodbNamespace = (String, String);
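/// Encodes rows, routes them to per-namespace command builders, and flushes
/// the builders as concurrent bulk-write commands.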
struct MongodbPayloadWriter {
    schema: Schema,
    pk_indices: Vec<usize>,
    default_namespace: Namespace,
    coll_name_field_index: Option<usize>,
    client: ClientGuard,
    row_encoder: BsonEncoder,
}

impl MongodbPayloadWriter {
    fn new(
        schema: Schema,
        pk_indices: Vec<usize>,
        default_namespace: Namespace,
        coll_name_field_index: Option<usize>,
        client: ClientGuard,
        row_encoder: BsonEncoder,
    ) -> Self {
        Self {
            schema,
            pk_indices,
            default_namespace,
            coll_name_field_index,
            client,
            row_encoder,
        }
    }

    fn extract_namespace_from_row_ref(&self, row: RowRef<'_>) -> MongodbNamespace {
        let ns = self.coll_name_field_index.and_then(|coll_name_field_index| {
            match row.datum_at(coll_name_field_index) {
                Some(ScalarRefImpl::Utf8(v)) => match v.parse::<Namespace>() {
                    Ok(ns) => Some(ns),
                    Err(err) => {
                        if let Ok(suppressed_count) = LOG_SUPPRESSER.check() {
                            tracing::warn!(
                                suppressed_count,
                                error = %err.as_report(),
                                collection_name = %v,
                                "failed to parse collection name, falling back to the default collection.name"
                            );
                        }
                        None
                    }
                },
                _ => {
                    if let Ok(suppressed_count) = LOG_SUPPRESSER.check() {
                        tracing::warn!(
                            suppressed_count,
                            "the value of collection.name.field is null or not a string, falling back to the default collection.name"
                        );
                    }
                    None
                }
            }
        });
        match ns {
            Some(ns) => (ns.db, ns.coll),
            None => (
                self.default_namespace.db.clone(),
                self.default_namespace.coll.clone(),
            ),
        }
    }

    fn append(
        &mut self,
        insert_builder: &mut HashMap<MongodbNamespace, InsertCommandBuilder>,
        row: RowRef<'_>,
    ) -> Result<()> {
        let document = self.row_encoder.encode(row)?;
        let ns = self.extract_namespace_from_row_ref(row);
        let coll = ns.1.clone();

        insert_builder
            .entry(ns)
            .or_insert_with(|| InsertCommandBuilder::new(coll))
            .append(document);
        Ok(())
    }

    fn upsert(
        &mut self,
        upsert_builder: &mut HashMap<MongodbNamespace, UpsertCommandBuilder>,
        op: Op,
        row: RowRef<'_>,
    ) -> Result<()> {
        let mut document = self.row_encoder.encode(row)?;
        let ns = self.extract_namespace_from_row_ref(row);
        let coll = ns.1.clone();

        let pk = self.row_encoder.construct_pk(row);

        // When the primary key is composite, or its single column is not named
        // `_id`, embed the constructed pk into the document as its `_id`.
        if self.pk_indices.len() > 1
            || self.schema.fields[self.pk_indices[0]].name != MONGODB_PK_NAME
        {
            document.insert(MONGODB_PK_NAME, pk.clone());
        }

        let pk = doc! {MONGODB_PK_NAME: pk};
        match op {
            Op::Insert | Op::UpdateInsert => upsert_builder
                .entry(ns)
                .or_insert_with(|| UpsertCommandBuilder::new(coll))
                .add_upsert(pk, document)?,
            Op::UpdateDelete => (),
            Op::Delete => upsert_builder
                .entry(ns)
                .or_insert_with(|| UpsertCommandBuilder::new(coll))
                .add_delete(pk)?,
        }
        Ok(())
    }

    fn flush_insert(
        &self,
        insert_builder: HashMap<MongodbNamespace, InsertCommandBuilder>,
    ) -> TryJoinAll<SendBulkWriteCommandFuture> {
        // Send one insert command per destination namespace; commands for
        // different namespaces are driven concurrently.
        let futures = insert_builder.into_iter().map(|(ns, builder)| {
            let db = self.client.database(&ns.0);
            send_bulk_write_commands(db, Some(builder.build()), None)
        });
        try_join_all(futures)
    }

    fn flush_upsert(
        &self,
        upsert_builder: HashMap<MongodbNamespace, UpsertCommandBuilder>,
    ) -> TryJoinAll<SendBulkWriteCommandFuture> {
        let futures = upsert_builder.into_iter().map(|(ns, builder)| {
            let (upsert, delete) = builder.build();
            let db = self.client.database(&ns.0);
            // `send_bulk_write_commands` issues the update command before the
            // delete command for this namespace.
            send_bulk_write_commands(db, upsert, delete)
        });
        try_join_all(futures)
    }
}