risingwave_stream/executor/
asof_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use std::collections::{BTreeMap, HashSet};
15use std::marker::PhantomData;
16use std::ops::Bound;
17use std::time::Duration;
18
19use either::Either;
20use itertools::Itertools;
21use multimap::MultiMap;
22use risingwave_common::array::Op;
23use risingwave_common::hash::{HashKey, NullBitmap};
24use risingwave_common::util::epoch::EpochPair;
25use risingwave_common::util::iter_util::ZipEqDebug;
26use tokio::time::Instant;
27
28use self::builder::JoinChunkBuilder;
29use super::barrier_align::*;
30use super::join::hash_join::*;
31use super::join::*;
32use super::watermark::*;
33use crate::executor::join::builder::JoinStreamChunkBuilder;
34use crate::executor::join::row::JoinEncoding;
35use crate::executor::prelude::*;
36
/// Number of processed rows between two cache evictions of the join hash tables.
///
/// `evict_cache` bumps a row counter once per row; when the counter reaches this value, the
/// hash tables of both join sides are evicted and the counter is reset to zero.
const EVICT_EVERY_N_ROWS: u32 = 16;
39
/// Returns `true` when every element of `vec1` also occurs in `vec2`.
///
/// Order and duplicates are irrelevant; an empty `vec1` is trivially a subset.
fn is_subset(vec1: Vec<usize>, vec2: Vec<usize>) -> bool {
    let superset: HashSet<usize> = vec2.into_iter().collect();
    vec1.into_iter().all(|item| superset.contains(&item))
}
43
/// Per-side parameters supplied when constructing an as-of join executor.
pub struct JoinParams {
    /// Positions of the equi-join key columns within this side's input schema.
    pub join_key_indices: Vec<usize>,
    /// Positions of this side's input pk columns, with duplicates removed.
    pub deduped_pk_indices: Vec<usize>,
}

impl JoinParams {
    /// Bundles the join-key indices and the deduplicated pk indices of one join side.
    pub fn new(join_key_indices: Vec<usize>, deduped_pk_indices: Vec<usize>) -> Self {
        JoinParams {
            deduped_pk_indices,
            join_key_indices,
        }
    }
}
59
60struct JoinSide<K: HashKey, S: StateStore, E: JoinEncoding> {
61    /// Store all data from a one side stream
62    ht: JoinHashMap<K, S, E>,
63    /// Indices of the join key columns
64    join_key_indices: Vec<usize>,
65    /// The data type of all columns without degree.
66    all_data_types: Vec<DataType>,
67    /// The start position for the side in output new columns
68    start_pos: usize,
69    /// The mapping from input indices of a side to output columns.
70    i2o_mapping: Vec<(usize, usize)>,
71    i2o_mapping_indexed: MultiMap<usize, usize>,
72    /// Whether degree table is needed for this side.
73    need_degree_table: bool,
74    _marker: std::marker::PhantomData<E>,
75}
76
77impl<K: HashKey, S: StateStore, E: JoinEncoding> std::fmt::Debug for JoinSide<K, S, E> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("JoinSide")
80            .field("join_key_indices", &self.join_key_indices)
81            .field("col_types", &self.all_data_types)
82            .field("start_pos", &self.start_pos)
83            .field("i2o_mapping", &self.i2o_mapping)
84            .field("need_degree_table", &self.need_degree_table)
85            .finish()
86    }
87}
88
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinSide<K, S, E> {
    // WARNING: Please do not call this until we implement it.
    /// Placeholder for a check of whether this side holds state changes that have not been
    /// flushed yet (see the assertion message in `clear_cache`).
    ///
    /// # Panics
    ///
    /// Always panics: the body is `unimplemented!()`.
    fn is_dirty(&self) -> bool {
        unimplemented!()
    }

    /// Intended to clear this side's in-memory cache. The actual clearing is commented out,
    /// so currently only the dirtiness assertion runs.
    ///
    /// # Panics
    ///
    /// Always panics, because it calls the unimplemented `is_dirty`.
    #[expect(dead_code)]
    fn clear_cache(&mut self) {
        assert!(
            !self.is_dirty(),
            "cannot clear cache while states of hash join are dirty"
        );

        // TODO: not working with rearranged chain
        // self.ht.clear();
    }

    /// Initializes this side's join hash table at `epoch`.
    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.ht.init(epoch).await
    }
}
110
111/// `AsOfJoinExecutor` takes two input streams and runs equal hash join on them.
112/// The output columns are the concatenation of left and right columns.
113pub struct AsOfJoinExecutor<
114    K: HashKey,
115    S: StateStore,
116    const T: AsOfJoinTypePrimitive,
117    E: JoinEncoding,
118> {
119    ctx: ActorContextRef,
120    info: ExecutorInfo,
121
122    /// Left input executor
123    input_l: Option<Executor>,
124    /// Right input executor
125    input_r: Option<Executor>,
126    /// The data types of the formed new columns
127    actual_output_data_types: Vec<DataType>,
128    /// The parameters of the left join executor
129    side_l: JoinSide<K, S, E>,
130    /// The parameters of the right join executor
131    side_r: JoinSide<K, S, E>,
132
133    metrics: Arc<StreamingMetrics>,
134    /// The maximum size of the chunk produced by executor at a time
135    chunk_size: usize,
136    /// Count the messages received, clear to 0 when counted to `EVICT_EVERY_N_MESSAGES`
137    cnt_rows_received: u32,
138
139    /// watermark column index -> `BufferedWatermarks`
140    watermark_buffers: BTreeMap<usize, BufferedWatermarks<SideTypePrimitive>>,
141
142    high_join_amplification_threshold: usize,
143    /// `AsOf` join description
144    asof_desc: AsOfDesc,
145}
146
147impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding> std::fmt::Debug
148    for AsOfJoinExecutor<K, S, T, E>
149{
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("AsOfJoinExecutor")
152            .field("join_type", &T)
153            .field("input_left", &self.input_l.as_ref().unwrap().identity())
154            .field("input_right", &self.input_r.as_ref().unwrap().identity())
155            .field("side_l", &self.side_l)
156            .field("side_r", &self.side_r)
157            .field("stream_key", &self.info.stream_key)
158            .field("schema", &self.info.schema)
159            .field("actual_output_data_types", &self.actual_output_data_types)
160            .finish()
161    }
162}
163
164impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding> Execute
165    for AsOfJoinExecutor<K, S, T, E>
166{
167    fn execute(self: Box<Self>) -> BoxedMessageStream {
168        self.into_stream().boxed()
169    }
170}
171
/// Arguments of the per-chunk join routines (`eq_join_left` / `eq_join_right`).
///
/// Bundles borrows of the executor's fields so that one input chunk can be joined while
/// both join sides are borrowed mutably.
struct EqJoinArgs<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
    /// Actor context of the executor.
    ctx: &'a ActorContextRef,
    /// Left side state and layout.
    side_l: &'a mut JoinSide<K, S, E>,
    /// Right side state and layout.
    side_r: &'a mut JoinSide<K, S, E>,
    /// The as-of inequality condition.
    asof_desc: &'a AsOfDesc,
    /// Data types of the output columns.
    actual_output_data_types: &'a [DataType],
    // TODO: inequality watermarks are not handled yet
    // (https://github.com/risingwavelabs/risingwave/issues/18503).
    // inequality_watermarks: &'a Watermark,
    /// The input chunk to join.
    chunk: StreamChunk,
    /// Maximum number of rows per output chunk.
    chunk_size: usize,
    /// Row counter driving periodic cache eviction.
    cnt_rows_received: &'a mut u32,
    high_join_amplification_threshold: usize,
}
184
185impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding>
186    AsOfJoinExecutor<K, S, T, E>
187{
    /// Creates an as-of join executor over `input_l` and `input_r`.
    ///
    /// * `params_l` / `params_r` — equi-join key indices and deduplicated pk indices of each side.
    /// * `null_safe` — per join-key column, whether NULL keys may match; turned into the
    ///   null-matched bitmap of both hash tables.
    /// * `output_indices` — columns of the concatenated left+right schema to emit.
    /// * `state_table_l` / `state_table_r` — state tables handed to each side's hash table.
    /// * `asof_desc` — the inequality condition; its `left_idx` / `right_idx` become the
    ///   inequality key column of the corresponding hash table.
    ///
    /// # Panics
    ///
    /// Panics if the join-key data types of the two sides differ, or if an index in
    /// `output_indices` or in the join-key indices is out of range for the input schemas.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        ctx: ActorContextRef,
        info: ExecutorInfo,
        input_l: Executor,
        input_r: Executor,
        params_l: JoinParams,
        params_r: JoinParams,
        null_safe: Vec<bool>,
        output_indices: Vec<usize>,
        state_table_l: StateTable<S>,
        state_table_r: StateTable<S>,
        watermark_epoch: AtomicU64Ref,
        metrics: Arc<StreamingMetrics>,
        chunk_size: usize,
        high_join_amplification_threshold: usize,
        asof_desc: AsOfDesc,
    ) -> Self {
        // The right side's columns start right after the left side's in the concatenated row.
        let side_l_column_n = input_l.schema().len();

        let schema_fields = [
            input_l.schema().fields.clone(),
            input_r.schema().fields.clone(),
        ]
        .concat();

        let original_output_data_types = schema_fields
            .iter()
            .map(|field| field.data_type())
            .collect_vec();
        let actual_output_data_types = output_indices
            .iter()
            .map(|&idx| original_output_data_types[idx].clone())
            .collect_vec();

        // Data types of the hash join state.
        let state_all_data_types_l = input_l.schema().data_types();
        let state_all_data_types_r = input_r.schema().data_types();

        let state_pk_indices_l = input_l.stream_key().to_vec();
        let state_pk_indices_r = input_r.stream_key().to_vec();

        let state_join_key_indices_l = params_l.join_key_indices;
        let state_join_key_indices_r = params_r.join_key_indices;

        // If pk is contained in join key.
        let pk_contained_in_jk_l = is_subset(state_pk_indices_l, state_join_key_indices_l.clone());
        let pk_contained_in_jk_r = is_subset(state_pk_indices_r, state_join_key_indices_r.clone());

        let join_key_data_types_l = state_join_key_indices_l
            .iter()
            .map(|idx| state_all_data_types_l[*idx].clone())
            .collect_vec();

        let join_key_data_types_r = state_join_key_indices_r
            .iter()
            .map(|idx| state_all_data_types_r[*idx].clone())
            .collect_vec();

        // Both sides are hashed with the same key layout, so their key types must agree.
        assert_eq!(join_key_data_types_l, join_key_data_types_r);

        let null_matched = K::Bitmap::from_bool_vec(null_safe);

        // Input-to-output column mappings; a semi/anti side contributes no output columns.
        let (left_to_output, right_to_output) = {
            let (left_len, right_len) = if is_left_semi_or_anti(T) {
                (state_all_data_types_l.len(), 0usize)
            } else if is_right_semi_or_anti(T) {
                (0usize, state_all_data_types_r.len())
            } else {
                (state_all_data_types_l.len(), state_all_data_types_r.len())
            };
            JoinStreamChunkBuilder::get_i2o_mapping(&output_indices, left_len, right_len)
        };

        let l2o_indexed = MultiMap::from_iter(left_to_output.iter().copied());
        let r2o_indexed = MultiMap::from_iter(right_to_output.iter().copied());

        // handle inequality watermarks
        // https://github.com/risingwavelabs/risingwave/issues/18503
        // let inequality_watermarks = None;
        let watermark_buffers = BTreeMap::new();

        // Each hash table indexes its rows by the side's as-of inequality column.
        let inequal_key_idx_l = Some(asof_desc.left_idx);
        let inequal_key_idx_r = Some(asof_desc.right_idx);

        Self {
            ctx: ctx.clone(),
            info,
            input_l: Some(input_l),
            input_r: Some(input_r),
            actual_output_data_types,
            side_l: JoinSide {
                ht: JoinHashMap::new(
                    watermark_epoch.clone(),
                    join_key_data_types_l,
                    state_join_key_indices_l.clone(),
                    state_all_data_types_l.clone(),
                    state_table_l,
                    params_l.deduped_pk_indices,
                    None,
                    null_matched.clone(),
                    pk_contained_in_jk_l,
                    inequal_key_idx_l,
                    metrics.clone(),
                    ctx.id,
                    ctx.fragment_id,
                    "left",
                ),
                join_key_indices: state_join_key_indices_l,
                all_data_types: state_all_data_types_l,
                i2o_mapping: left_to_output,
                i2o_mapping_indexed: l2o_indexed,
                start_pos: 0,
                need_degree_table: false,
                _marker: PhantomData,
            },
            side_r: JoinSide {
                ht: JoinHashMap::new(
                    watermark_epoch,
                    join_key_data_types_r,
                    state_join_key_indices_r.clone(),
                    state_all_data_types_r.clone(),
                    state_table_r,
                    params_r.deduped_pk_indices,
                    None,
                    null_matched,
                    pk_contained_in_jk_r,
                    inequal_key_idx_r,
                    metrics.clone(),
                    ctx.id,
                    ctx.fragment_id,
                    "right",
                ),
                join_key_indices: state_join_key_indices_r,
                all_data_types: state_all_data_types_r,
                start_pos: side_l_column_n,
                i2o_mapping: right_to_output,
                i2o_mapping_indexed: r2o_indexed,
                need_degree_table: false,
                _marker: PhantomData,
            },
            metrics,
            chunk_size,
            cnt_rows_received: 0,
            watermark_buffers,
            high_join_amplification_threshold,
            asof_desc,
        }
    }
337
    /// Drives the executor. It aligns barriers of both inputs, joins every incoming chunk,
    /// forwards selected watermarks, and flushes both sides' state on each barrier.
    ///
    /// The first barrier is yielded before the two hash tables are initialized with its epoch.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn into_stream(mut self) {
        let input_l = self.input_l.take().unwrap();
        let input_r = self.input_r.take().unwrap();
        let aligned_stream = barrier_align(
            input_l.execute(),
            input_r.execute(),
            self.ctx.id,
            self.ctx.fragment_id,
            self.metrics.clone(),
            "Join",
        );
        pin_mut!(aligned_stream);
        let actor_id = self.ctx.id;

        let barrier = expect_first_barrier_from_aligned_stream(&mut aligned_stream).await?;
        let first_epoch = barrier.epoch;
        // The first barrier message should be propagated.
        yield Message::Barrier(barrier);
        self.side_l.init(first_epoch).await?;
        self.side_r.init(first_epoch).await?;

        let actor_id_str = self.ctx.id.to_string();
        let fragment_id_str = self.ctx.fragment_id.to_string();

        // Metric handles, labelled by actor id, fragment id and (where relevant) side.
        let join_actor_input_waiting_duration_ns = self
            .metrics
            .join_actor_input_waiting_duration_ns
            .with_guarded_label_values(&[&actor_id_str, &fragment_id_str]);
        let left_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "left"]);
        let right_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "right"]);

        let barrier_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[
                actor_id_str.as_str(),
                fragment_id_str.as_str(),
                "barrier",
            ]);

        let left_join_cached_entry_count = self
            .metrics
            .join_cached_entry_count
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "left"]);

        let right_join_cached_entry_count = self
            .metrics
            .join_cached_entry_count
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "right"]);

        let mut start_time = Instant::now();

        // NOTE(review): the await-tree label still reads `hash_join_barrier_align`.
        while let Some(msg) = aligned_stream
            .next()
            .instrument_await("hash_join_barrier_align")
            .await
        {
            join_actor_input_waiting_duration_ns.inc_by(start_time.elapsed().as_nanos() as u64);
            match msg? {
                AlignedMessage::WatermarkLeft(watermark) => {
                    for watermark_to_emit in self.handle_watermark(SideType::Left, watermark)? {
                        yield Message::Watermark(watermark_to_emit);
                    }
                }
                AlignedMessage::WatermarkRight(watermark) => {
                    for watermark_to_emit in self.handle_watermark(SideType::Right, watermark)? {
                        yield Message::Watermark(watermark_to_emit);
                    }
                }
                AlignedMessage::Left(chunk) => {
                    // Match time excludes the time spent suspended at `yield`: the timer is
                    // stopped before each yield and restarted after it.
                    let mut left_time = Duration::from_nanos(0);
                    let mut left_start_time = Instant::now();
                    #[for_await]
                    for chunk in Self::eq_join_left(EqJoinArgs {
                        ctx: &self.ctx,
                        side_l: &mut self.side_l,
                        side_r: &mut self.side_r,
                        asof_desc: &self.asof_desc,
                        actual_output_data_types: &self.actual_output_data_types,
                        // inequality_watermarks: &self.inequality_watermarks,
                        chunk,
                        chunk_size: self.chunk_size,
                        cnt_rows_received: &mut self.cnt_rows_received,
                        high_join_amplification_threshold: self.high_join_amplification_threshold,
                    }) {
                        left_time += left_start_time.elapsed();
                        yield Message::Chunk(chunk?);
                        left_start_time = Instant::now();
                    }
                    left_time += left_start_time.elapsed();
                    left_join_match_duration_ns.inc_by(left_time.as_nanos() as u64);
                    self.try_flush_data().await?;
                }
                AlignedMessage::Right(chunk) => {
                    // Same timing scheme as the left branch.
                    let mut right_time = Duration::from_nanos(0);
                    let mut right_start_time = Instant::now();
                    #[for_await]
                    for chunk in Self::eq_join_right(EqJoinArgs {
                        ctx: &self.ctx,
                        side_l: &mut self.side_l,
                        side_r: &mut self.side_r,
                        asof_desc: &self.asof_desc,
                        actual_output_data_types: &self.actual_output_data_types,
                        // inequality_watermarks: &self.inequality_watermarks,
                        chunk,
                        chunk_size: self.chunk_size,
                        cnt_rows_received: &mut self.cnt_rows_received,
                        high_join_amplification_threshold: self.high_join_amplification_threshold,
                    }) {
                        right_time += right_start_time.elapsed();
                        yield Message::Chunk(chunk?);
                        right_start_time = Instant::now();
                    }
                    right_time += right_start_time.elapsed();
                    right_join_match_duration_ns.inc_by(right_time.as_nanos() as u64);
                    self.try_flush_data().await?;
                }
                AlignedMessage::Barrier(barrier) => {
                    let barrier_start_time = Instant::now();
                    // Flush both sides before forwarding the barrier; the post-commit handles
                    // are completed only after the barrier has been yielded.
                    let (left_post_commit, right_post_commit) =
                        self.flush_data(barrier.epoch).await?;

                    let update_vnode_bitmap = barrier.as_update_vnode_bitmap(actor_id);
                    yield Message::Barrier(barrier);

                    // Update the vnode bitmap for state tables of both sides if asked.
                    right_post_commit
                        .post_yield_barrier(update_vnode_bitmap.clone())
                        .await?;
                    // Buffered watermarks are dropped when the left side's post-commit reports
                    // `true` (presumably: its vnode ownership changed — confirm in
                    // `post_yield_barrier`).
                    if left_post_commit
                        .post_yield_barrier(update_vnode_bitmap)
                        .await?
                        .unwrap_or(false)
                    {
                        self.watermark_buffers
                            .values_mut()
                            .for_each(|buffers| buffers.clear());
                    }

                    // Report metrics of cached join rows/entries
                    for (join_cached_entry_count, ht) in [
                        (&left_join_cached_entry_count, &self.side_l.ht),
                        (&right_join_cached_entry_count, &self.side_r.ht),
                    ] {
                        join_cached_entry_count.set(ht.entry_count() as i64);
                    }

                    barrier_join_match_duration_ns
                        .inc_by(barrier_start_time.elapsed().as_nanos() as u64);
                }
            }
            start_time = Instant::now();
        }
    }
500
501    async fn flush_data(
502        &mut self,
503        epoch: EpochPair,
504    ) -> StreamExecutorResult<(
505        JoinHashMapPostCommit<'_, K, S, E>,
506        JoinHashMapPostCommit<'_, K, S, E>,
507    )> {
508        // All changes to the state has been buffered in the mem-table of the state table. Just
509        // `commit` them here.
510        let left = self.side_l.ht.flush(epoch).await?;
511        let right = self.side_r.ht.flush(epoch).await?;
512        Ok((left, right))
513    }
514
515    async fn try_flush_data(&mut self) -> StreamExecutorResult<()> {
516        // All changes to the state has been buffered in the mem-table of the state table. Just
517        // `commit` them here.
518        self.side_l.ht.try_flush().await?;
519        self.side_r.ht.try_flush().await?;
520        Ok(())
521    }
522
523    // We need to manually evict the cache.
524    fn evict_cache(
525        side_update: &mut JoinSide<K, S, E>,
526        side_match: &mut JoinSide<K, S, E>,
527        cnt_rows_received: &mut u32,
528    ) {
529        *cnt_rows_received += 1;
530        if *cnt_rows_received == EVICT_EVERY_N_ROWS {
531            side_update.ht.evict();
532            side_match.ht.evict();
533            *cnt_rows_received = 0;
534        }
535    }
536
537    fn handle_watermark(
538        &mut self,
539        side: SideTypePrimitive,
540        watermark: Watermark,
541    ) -> StreamExecutorResult<Vec<Watermark>> {
542        let (side_update, side_match) = if side == SideType::Left {
543            (&mut self.side_l, &mut self.side_r)
544        } else {
545            (&mut self.side_r, &mut self.side_l)
546        };
547
548        // State cleaning
549        if side_update.join_key_indices[0] == watermark.col_idx {
550            side_match.ht.update_watermark(watermark.val.clone());
551        }
552
553        // Select watermarks to yield.
554        let wm_in_jk = side_update
555            .join_key_indices
556            .iter()
557            .positions(|idx| *idx == watermark.col_idx);
558        let mut watermarks_to_emit = vec![];
559        for idx in wm_in_jk {
560            let buffers = self
561                .watermark_buffers
562                .entry(idx)
563                .or_insert_with(|| BufferedWatermarks::with_ids([SideType::Left, SideType::Right]));
564            if let Some(selected_watermark) = buffers.handle_watermark(side, watermark.clone()) {
565                let empty_indices = vec![];
566                let output_indices = side_update
567                    .i2o_mapping_indexed
568                    .get_vec(&side_update.join_key_indices[idx])
569                    .unwrap_or(&empty_indices)
570                    .iter()
571                    .chain(
572                        side_match
573                            .i2o_mapping_indexed
574                            .get_vec(&side_match.join_key_indices[idx])
575                            .unwrap_or(&empty_indices),
576                    );
577                for output_idx in output_indices {
578                    watermarks_to_emit.push(selected_watermark.clone().with_idx(*output_idx));
579                }
580            };
581        }
582        Ok(watermarks_to_emit)
583    }
584
585    /// the data the hash table and match the coming
586    /// data chunk with the executor state
587    async fn hash_eq_match(
588        key: &K,
589        ht: &mut JoinHashMap<K, S, E>,
590    ) -> StreamExecutorResult<Option<HashValueType<E>>> {
591        if !key.null_bitmap().is_subset(ht.null_matched()) {
592            Ok(None)
593        } else {
594            ht.take_state(key).await.map(Some)
595        }
596    }
597
    /// Joins one chunk from the left input against the right side's state.
    ///
    /// For each row it does the following:
    /// 1. Probes the right hash table with the row's join key.
    /// 2. Picks at most one right row through the right side's inequality index (bound chosen
    ///    by `asof_desc.inequality_type`).
    /// 3. Emits the joined row, or forwards the row as unmatched.
    /// 4. Inserts the row into, or deletes it from, the left hash table.
    ///
    /// Rows whose join key cannot match (NULL in a non-null-safe column) or whose inequality
    /// key is NULL are only forwarded as unmatched and are not stored in the left hash table.
    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
    async fn eq_join_left(args: EqJoinArgs<'_, K, S, E>) {
        let EqJoinArgs {
            ctx: _,
            side_l,
            side_r,
            asof_desc,
            actual_output_data_types,
            // inequality_watermarks,
            chunk,
            chunk_size,
            cnt_rows_received,
            high_join_amplification_threshold: _,
        } = args;

        let (side_update, side_match) = (side_l, side_r);

        let mut join_chunk_builder =
            JoinChunkBuilder::<T, { SideType::Left }>::new(JoinStreamChunkBuilder::new(
                chunk_size,
                actual_output_data_types.to_vec(),
                side_update.i2o_mapping.clone(),
                side_match.i2o_mapping.clone(),
            ));

        let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk());
        for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
            // Skip invisible rows.
            let Some((op, row)) = r else {
                continue;
            };
            Self::evict_cache(side_update, side_match, cnt_rows_received);

            // A row with a NULL inequality key is never probed, so it is treated as unmatched.
            let matched_rows = if !side_update.ht.check_inequal_key_null(&row) {
                Self::hash_eq_match(key, &mut side_match.ht).await?
            } else {
                None
            };
            let inequal_key = side_update.ht.serialize_inequal_key_from_row(row);

            if let Some(matched_rows) = matched_rows {
                // `Lt`/`Le` take the lower bound of the right side's inequality index and
                // `Gt`/`Ge` the upper bound; `Excluded`/`Included` encodes strictness.
                let matched_row_by_inequality = match asof_desc.inequality_type {
                    AsOfInequalityType::Lt => matched_rows.lower_bound_by_inequality(
                        Bound::Excluded(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Le => matched_rows.lower_bound_by_inequality(
                        Bound::Included(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Gt => matched_rows.upper_bound_by_inequality(
                        Bound::Excluded(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Ge => matched_rows.upper_bound_by_inequality(
                        Bound::Included(&inequal_key),
                        &side_match.all_data_types,
                    ),
                };
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        if let Some(matched_row_by_inequality) = matched_row_by_inequality {
                            let matched_row = matched_row_by_inequality?;

                            if let Some(chunk) =
                                join_chunk_builder.with_match_on_insert(&row, &matched_row)
                            {
                                yield chunk;
                            }
                        } else if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Insert, row)
                        {
                            yield chunk;
                        }
                        side_update.ht.insert_row(key, row)?;
                    }
                    Op::Delete | Op::UpdateDelete => {
                        if let Some(matched_row_by_inequality) = matched_row_by_inequality {
                            let matched_row = matched_row_by_inequality?;

                            if let Some(chunk) =
                                join_chunk_builder.with_match_on_delete(&row, &matched_row)
                            {
                                yield chunk;
                            }
                        } else if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Delete, row)
                        {
                            yield chunk;
                        }
                        side_update.ht.delete_row(key, row)?;
                    }
                }
                // Insert back the state taken from ht by `hash_eq_match`.
                // NOTE(review): an early `?` return above skips this put-back — confirm that
                // an error here always fails the actor.
                side_match.ht.update_state(key, matched_rows);
            } else {
                // The row's join key violates the null-safe bitmap or its inequality key is
                // NULL, so it can never be matched and need not be stored.
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Insert, row)
                        {
                            yield chunk;
                        }
                    }
                    Op::Delete | Op::UpdateDelete => {
                        if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Delete, row)
                        {
                            yield chunk;
                        }
                    }
                }
            }
        }
        // Emit whatever is left in the builder.
        if let Some(chunk) = join_chunk_builder.take() {
            yield chunk;
        }
    }
717
718    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
    /// Applies one chunk from the right input to the AsOf join and yields the resulting
    /// output changes.
    ///
    /// The right side is the side being updated and the left side is the side being matched.
    /// For every visible row of `chunk` this:
    ///
    /// 1. Fetches the join-key state of both sides via `hash_eq_match`. A row whose
    ///    inequality key is null (`check_inequal_key_null`), or for which `hash_eq_match`
    ///    returns `None`, produces no output and is not stored.
    /// 2. Determines which right row the affected left rows stop matching
    ///    (`row_to_delete_r`) and which right row they start matching (`row_to_insert_r`).
    ///    When several right rows share one inequality key, the one with the smallest pk is
    ///    the one that gets matched.
    /// 3. Scans the left rows whose inequality key lies between the neighbouring right
    ///    inequality keys and emits a `Delete` and/or `Insert` for each. For left outer joins
    ///    a missing right row is emitted through `append_row_matched`, i.e. as a left row
    ///    without a right match.
    /// 4. Puts the fetched states back into the hash tables and applies the row to the
    ///    right hash table.
    ///
    /// Output is accumulated in a `JoinStreamChunkBuilder` created with `chunk_size`.
    /// Chunks are yielded whenever the builder hands one back, and the remainder is flushed
    /// at the end.
    async fn eq_join_right(args: EqJoinArgs<'_, K, S, E>) {
        let EqJoinArgs {
            ctx,
            side_l,
            side_r,
            asof_desc,
            actual_output_data_types,
            // inequality_watermarks,
            chunk,
            chunk_size,
            cnt_rows_received,
            high_join_amplification_threshold,
        } = args;

        // Right rows update the right state and probe the left state.
        let (side_update, side_match) = (side_r, side_l);

        let mut join_chunk_builder = JoinStreamChunkBuilder::new(
            chunk_size,
            actual_output_data_types.to_vec(),
            side_update.i2o_mapping.clone(),
            side_match.i2o_mapping.clone(),
        );

        // Observed once per processed input row (see the end of the loop), labelled by
        // actor, fragment and the right-side table id.
        let join_matched_rows_metrics = ctx
            .streaming_metrics
            .join_matched_join_keys
            .with_guarded_label_values(&[
                &ctx.id.to_string(),
                &ctx.fragment_id.to_string(),
                &side_update.ht.table_id().to_string(),
            ]);

        // One join key per row of the chunk. `rows_with_holes` yields `None` for rows that
        // are skipped (presumably invisible rows), which keeps both iterators aligned.
        let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk());
        for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
            let Some((op, row)) = r else {
                continue;
            };
            // Number of left rows whose output was changed by this right row; feeds the
            // metric and the amplification warning below.
            let mut join_matched_rows_cnt = 0;

            Self::evict_cache(side_update, side_match, cnt_rows_received);

            // A row with a null inequality key skips the lookup entirely. It is then neither
            // matched nor stored (see the `else` branch further down).
            let matched_rows = if !side_update.ht.check_inequal_key_null(&row) {
                Self::hash_eq_match(key, &mut side_match.ht).await?
            } else {
                None
            };
            // Serialized inequality key of the incoming right row, used to probe the
            // inequality indexes of both sides.
            let inequal_key = side_update.ht.serialize_inequal_key_from_row(row);

            if let Some(matched_rows) = matched_rows {
                // Right-side state for the same join key. Like `matched_rows`, it is taken
                // out of the hash table here and put back via `update_state` below.
                let update_rows = Self::hash_eq_match(key, &mut side_update.ht).await?.expect("None is not expected because we have checked null in key when getting matched_rows");
                // Maps each inequality key of the cached right rows to the pks sharing it.
                // It reflects the state *before* this row is applied.
                let right_inequality_index = update_rows.inequality_index();
                // `Either::Left` holds a row fetched from the right cache; `Either::Right` is
                // the incoming row from `chunk`.
                let (row_to_delete_r, row_to_insert_r) =
                    if let Some(pks) = right_inequality_index.get(&inequal_key) {
                        // Other right rows already share this inequality key. Only a change
                        // of the smallest-pk representative affects the output.
                        assert!(!pks.is_empty());
                        let row_pk = side_update.ht.serialize_pk_from_row(row);
                        match op {
                            Op::Insert | Op::UpdateInsert => {
                                // If there are multiple rows match the inequality key in the right table, we use one with smallest pk.
                                let smallest_pk = pks.first_key_sorted().unwrap();
                                if smallest_pk > &row_pk {
                                    // smallest_pk is in the cache index, so it must exist in the cache.
                                    if let Some(to_delete_row) = update_rows
                                        .get_by_indexed_pk(smallest_pk, &side_update.all_data_types)
                                    {
                                        (
                                            Some(Either::Left(to_delete_row?.row)),
                                            Some(Either::Right(row)),
                                        )
                                    } else {
                                        // Something wrong happened. Ignore this row in non strict consistency mode.
                                        (None, None)
                                    }
                                } else {
                                    // No affected row in the right table.
                                    (None, None)
                                }
                            }
                            Op::Delete | Op::UpdateDelete => {
                                // Only deleting the current representative changes the output.
                                // Its successor (second smallest pk), if any, takes over.
                                // Otherwise the replacement is decided below from the
                                // neighbouring inequality keys.
                                let smallest_pk = pks.first_key_sorted().unwrap();
                                if smallest_pk == &row_pk {
                                    if let Some(second_smallest_pk) = pks.second_key_sorted() {
                                        if let Some(to_insert_row) = update_rows.get_by_indexed_pk(
                                            second_smallest_pk,
                                            &side_update.all_data_types,
                                        ) {
                                            (
                                                Some(Either::Right(row)),
                                                Some(Either::Left(to_insert_row?.row)),
                                            )
                                        } else {
                                            // Something wrong happened. Ignore this row in non strict consistency mode.
                                            (None, None)
                                        }
                                    } else {
                                        (Some(Either::Right(row)), None)
                                    }
                                } else {
                                    // No affected row in the right table.
                                    (None, None)
                                }
                            }
                        }
                    } else {
                        match op {
                            // Decide the row_to_delete later
                            Op::Insert | Op::UpdateInsert => (None, Some(Either::Right(row))),
                            // Decide the row_to_insert later
                            Op::Delete | Op::UpdateDelete => (Some(Either::Right(row)), None),
                        }
                    };

                // 4 cases for row_to_delete_r and row_to_insert_r:
                // 1. Some(_), Some(_): delete row_to_delete_r and insert row_to_insert_r
                // 2. None, Some(_)   : row_to_delete to be decided by the nearest inequality key
                // 3. Some(_), None   : row_to_insert to be decided by the nearest inequality key
                // 4. None, None      : do nothing
                if row_to_delete_r.is_none() && row_to_insert_r.is_none() {
                    // no row to delete or insert.
                } else {
                    // Nearest existing right inequality keys below (`prev`) and above (`next`)
                    // `inequal_key`, the key itself excluded.
                    let prev_inequality_key =
                        right_inequality_index.upper_bound_key(Bound::Excluded(&inequal_key));
                    let next_inequality_key =
                        right_inequality_index.lower_bound_key(Bound::Excluded(&inequal_key));
                    // The right row the affected left rows match when `inequal_key` has no
                    // representative. For `Lt`/`Le` it is the first row at the next larger
                    // key; for `Gt`/`Ge` it is the first row at the next smaller key.
                    let affected_row_r = match asof_desc.inequality_type {
                        AsOfInequalityType::Lt | AsOfInequalityType::Le => next_inequality_key
                            .and_then(|k| {
                                update_rows.get_first_by_inequality(k, &side_update.all_data_types)
                            }),
                        AsOfInequalityType::Gt | AsOfInequalityType::Ge => prev_inequality_key
                            .and_then(|k| {
                                update_rows.get_first_by_inequality(k, &side_update.all_data_types)
                            }),
                    }
                    .transpose()?
                    .map(|r| Either::Left(r.row));

                    let (row_to_delete_r, row_to_insert_r) =
                        match (&row_to_delete_r, &row_to_insert_r) {
                            (Some(_), Some(_)) => (row_to_delete_r, row_to_insert_r),
                            (None, Some(_)) => (affected_row_r, row_to_insert_r),
                            (Some(_), None) => (row_to_delete_r, affected_row_r),
                            (None, None) => unreachable!(),
                        };
                    // Left inequality keys whose AsOf match changes. This is the gap between
                    // `inequal_key` and its neighbour on the side the inequality looks
                    // towards. Whether each end is included depends on whether the
                    // inequality is strict.
                    let range = match asof_desc.inequality_type {
                        AsOfInequalityType::Lt => (
                            prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included),
                            Bound::Excluded(&inequal_key),
                        ),
                        AsOfInequalityType::Le => (
                            prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded),
                            Bound::Included(&inequal_key),
                        ),
                        AsOfInequalityType::Gt => (
                            Bound::Excluded(&inequal_key),
                            next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included),
                        ),
                        AsOfInequalityType::Ge => (
                            Bound::Included(&inequal_key),
                            next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded),
                        ),
                    };

                    // For each affected left row, retract the old pairing and emit the new
                    // one. In a left outer join, a missing right side becomes a row emitted
                    // via `append_row_matched`.
                    let rows_l =
                        matched_rows.range_by_inequality(range, &side_match.all_data_types);
                    for row_l in rows_l {
                        join_matched_rows_cnt += 1;
                        let row_l = row_l?.row;
                        if let Some(row_to_delete_r) = &row_to_delete_r {
                            if let Some(chunk) =
                                join_chunk_builder.append_row(Op::Delete, row_to_delete_r, &row_l)
                            {
                                yield chunk;
                            }
                        } else if is_as_of_left_outer(T)
                            && let Some(chunk) =
                                join_chunk_builder.append_row_matched(Op::Delete, &row_l)
                        {
                            yield chunk;
                        }
                        if let Some(row_to_insert_r) = &row_to_insert_r {
                            if let Some(chunk) =
                                join_chunk_builder.append_row(Op::Insert, row_to_insert_r, &row_l)
                            {
                                yield chunk;
                            }
                        } else if is_as_of_left_outer(T)
                            && let Some(chunk) =
                                join_chunk_builder.append_row_matched(Op::Insert, &row_l)
                        {
                            yield chunk;
                        }
                    }
                }
                // Insert back the state taken from ht.
                side_match.ht.update_state(key, matched_rows);
                side_update.ht.update_state(key, update_rows);

                // Apply the row to the right state only after the output was computed, since
                // the computation above relies on the state before this row.
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        side_update.ht.insert_row(key, row)?;
                    }
                    Op::Delete | Op::UpdateDelete => {
                        side_update.ht.delete_row(key, row)?;
                    }
                }
            } else {
                // Row which violates null-safe bitmap will never be matched so we need not
                // store.
                // Noop here because we only support left outer AsOf join.
            }
            join_matched_rows_metrics.observe(join_matched_rows_cnt as _);
            // Warn, with the deserialized join key, when a single right row changed the
            // output for more left rows than the configured threshold.
            if join_matched_rows_cnt > high_join_amplification_threshold {
                let join_key_data_types = side_update.ht.join_key_data_types();
                let key = key.deserialize(join_key_data_types)?;
                tracing::warn!(target: "high_join_amplification",
                    matched_rows_len = join_matched_rows_cnt,
                    update_table_id = %side_update.ht.table_id(),
                    match_table_id = %side_match.ht.table_id(),
                    join_key = ?key,
                    actor_id = %ctx.id,
                    fragment_id = %ctx.fragment_id,
                    "large rows matched for join key when AsOf join updating right side",
                );
            }
        }
        // Flush whatever is left in the builder.
        if let Some(chunk) = join_chunk_builder.take() {
            yield chunk;
        }
    }
948}
949
/// Unit tests for `AsOfJoinExecutor`, driving it through mock left/right sources backed by an
/// in-memory state store.
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU64;

    use risingwave_common::array::*;
    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
    use risingwave_common::hash::Key64;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::common::table::test_utils::gen_pbtable;
    use crate::executor::MemoryEncoding;
    use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt};

    /// Builds a `StateTable` on `mem_state`.
    ///
    /// The table has id `table_id` and one unnamed column per entry of `data_types` (column
    /// ids `0..n`). Its primary key is `pk_indices`, ordered by `order_types`.
    async fn create_in_memory_state_table(
        mem_state: MemoryStateStore,
        data_types: &[DataType],
        order_types: &[OrderType],
        pk_indices: &[usize],
        table_id: u32,
    ) -> StateTable<MemoryStateStore> {
        let column_descs = data_types
            .iter()
            .enumerate()
            .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
            .collect_vec();
        StateTable::from_table_catalog(
            &gen_pbtable(
                TableId::new(table_id),
                column_descs,
                order_types.to_vec(),
                pk_indices.to_vec(),
                0,
            ),
            mem_state.clone(),
            None,
        )
        .await
    }

    /// Builds an AsOf join executor of join type `T` using `asof_desc`.
    ///
    /// Both inputs have three `Int64` columns. Column 0 is the join key and column 1 is the
    /// (deduped) pk. The output schema is the left schema followed by the right schema.
    /// Returns the left sender, the right sender and the executor's output stream.
    async fn create_executor<const T: AsOfJoinTypePrimitive>(
        asof_desc: AsOfDesc,
    ) -> (MessageSender, MessageSender, BoxedMessageStream) {
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int64), // join key
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        };
        let (tx_l, source_l) = MockSource::channel();
        let source_l = source_l.into_executor(schema.clone(), vec![1]);
        let (tx_r, source_r) = MockSource::channel();
        let source_r = source_r.into_executor(schema, vec![1]);
        // Join key indices `[0]`, deduped pk indices `[1]` on both sides.
        let params_l = JoinParams::new(vec![0], vec![1]);
        let params_r = JoinParams::new(vec![0], vec![1]);

        let mem_state = MemoryStateStore::new();

        // State tables are keyed by (join key, inequality column, pk column).
        // NOTE(review): when `left_idx == 0` (as in `test_as_of_inner_join`), column 0 appears
        // twice in the left pk.
        let state_l = create_in_memory_state_table(
            mem_state.clone(),
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, asof_desc.left_idx, 1],
            0,
        )
        .await;

        let state_r = create_in_memory_state_table(
            mem_state,
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, asof_desc.right_idx, 1],
            1,
        )
        .await;

        let schema: Schema = [source_l.schema().fields(), source_r.schema().fields()]
            .concat()
            .into_iter()
            .collect();
        let schema_len = schema.len();
        // NOTE(review): the identity says "HashJoinExecutor" although an AsOfJoinExecutor is
        // built below. It is not checked by any assertion in this module.
        let info = ExecutorInfo::for_test(schema, vec![1], "HashJoinExecutor".to_owned(), 0);

        // `vec![false]` presumably marks the single join key as not null-safe; the output
        // projects every column of the concatenated schema.
        let executor = AsOfJoinExecutor::<Key64, MemoryStateStore, T, MemoryEncoding>::new(
            ActorContext::for_test(123),
            info,
            source_l,
            source_r,
            params_l,
            params_r,
            vec![false],
            (0..schema_len).collect_vec(),
            state_l,
            state_r,
            Arc::new(AtomicU64::new(0)),
            Arc::new(StreamingMetrics::unused()),
            1024,
            2048,
            asof_desc,
        );
        (tx_l, tx_r, executor.boxed().execute())
    }

    /// Inner AsOf join with `left.col0 < right.col2`.
    ///
    /// Each left row pairs with the right row of the same join key that has the smallest
    /// col 2 strictly greater than the left row's col 0. Ties on col 2 go to the smallest
    /// pk (col 1). Left rows without such a right row produce no output.
    #[tokio::test]
    async fn test_as_of_inner_join() -> StreamExecutorResult<()> {
        let asof_desc = AsOfDesc {
            left_idx: 0,
            right_idx: 2,
            inequality_type: AsOfInequalityType::Lt,
        };

        let chunk_l1 = StreamChunk::from_pretty(
            "  I I I
             + 1 4 7
             + 2 5 8
             + 3 6 9",
        );
        let chunk_l2 = StreamChunk::from_pretty(
            "  I I I
             + 3 8 1
             - 3 8 1",
        );
        let chunk_r1 = StreamChunk::from_pretty(
            "  I I I
             + 2 1 7
             + 2 2 1
             + 2 3 4
             + 2 4 2
             + 6 1 9
             + 6 2 9",
        );
        let chunk_r2 = StreamChunk::from_pretty(
            "  I I I
             - 2 3 4",
        );
        let chunk_r3 = StreamChunk::from_pretty(
            "  I I I
             + 2 3 3",
        );
        let chunk_l3 = StreamChunk::from_pretty(
            "  I I I
             - 2 5 8",
        );
        let chunk_l4 = StreamChunk::from_pretty(
            "  I I I
             + 6 3 1
             + 6 4 1",
        );
        let chunk_r4 = StreamChunk::from_pretty(
            "  I I I
             - 6 1 9",
        );

        // `hash_join` is the output stream of the AsOf join executor.
        let (mut tx_l, mut tx_r, mut hash_join) =
            create_executor::<{ AsOfJoinType::Inner }>(asof_desc).await;

        // push the init barrier for left and right
        tx_l.push_barrier(test_epoch(1), false);
        tx_r.push_barrier(test_epoch(1), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 1st left chunk: no right rows yet, so an inner join emits nothing
        tx_l.push_chunk(chunk_l1);
        hash_join.next_unwrap_pending();

        // push the second barrier (epoch 2) for left and right
        tx_l.push_barrier(test_epoch(2), false);
        tx_r.push_barrier(test_epoch(2), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 2nd left chunk: still no right rows, nothing emitted
        tx_l.push_chunk(chunk_l2);
        hash_join.next_unwrap_pending();

        // push the 1st right chunk: left (2, 5, 8) first pairs with col2 = 7, then moves to
        // col2 = 4 once that smaller qualifying key (> 2) arrives. Keys 1 and 2 do not qualify.
        tx_r.push_chunk(chunk_r1);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 2 5 8 2 1 7
                - 2 5 8 2 1 7
                + 2 5 8 2 3 4"
            )
        );

        // push the 2nd right chunk: deleting the matched row falls back to the next larger key
        tx_r.push_chunk(chunk_r2);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 3 4
                + 2 5 8 2 1 7"
            )
        );

        // push the 3rd right chunk: col2 = 3 is a closer match than 7
        tx_r.push_chunk(chunk_r3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 1 7
                + 2 5 8 2 3 3"
            )
        );

        // push the 3rd left chunk: deleting the left row retracts its pairing
        tx_l.push_chunk(chunk_l3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 3 3"
            )
        );

        // push the 4th left chunk: both right rows of key 6 have col2 = 9, so the smallest pk
        // (col1 = 1) is chosen
        tx_l.push_chunk(chunk_l4);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 6 3 1 6 1 9
                + 6 4 1 6 1 9"
            )
        );

        // push the 4th right chunk: the next-smallest pk with the same col2 takes over
        tx_r.push_chunk(chunk_r4);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 6 3 1 6 1 9
                + 6 3 1 6 2 9
                - 6 4 1 6 1 9
                + 6 4 1 6 2 9"
            )
        );

        Ok(())
    }

    /// Left outer AsOf join with `left.col1 >= right.col2`.
    ///
    /// Each left row pairs with the right row of the same join key that has the largest
    /// col 2 not exceeding the left row's col 1. Ties on col 2 go to the smallest pk
    /// (col 1). Left rows without such a right row are emitted with `.` (null) right columns.
    #[tokio::test]
    async fn test_as_of_left_outer_join() -> StreamExecutorResult<()> {
        let asof_desc = AsOfDesc {
            left_idx: 1,
            right_idx: 2,
            inequality_type: AsOfInequalityType::Ge,
        };

        let chunk_l1 = StreamChunk::from_pretty(
            "  I I I
             + 1 4 7
             + 2 5 8
             + 3 6 9",
        );
        let chunk_l2 = StreamChunk::from_pretty(
            "  I I I
             + 3 8 1
             - 3 8 1",
        );
        let chunk_r1 = StreamChunk::from_pretty(
            "  I I I
             + 2 3 4
             + 2 2 5
             + 2 1 5
             + 6 1 8
             + 6 2 9",
        );
        let chunk_r2 = StreamChunk::from_pretty(
            "  I I I
             - 2 3 4
             - 2 1 5
             - 2 2 5",
        );
        let chunk_l3 = StreamChunk::from_pretty(
            "  I I I
             + 6 8 9",
        );
        let chunk_r3 = StreamChunk::from_pretty(
            "  I I I
             - 6 1 8",
        );

        // `hash_join` is the output stream of the AsOf join executor.
        let (mut tx_l, mut tx_r, mut hash_join) =
            create_executor::<{ AsOfJoinType::LeftOuter }>(asof_desc).await;

        // push the init barrier for left and right
        tx_l.push_barrier(test_epoch(1), false);
        tx_r.push_barrier(test_epoch(1), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 1st left chunk: no right rows, so every left row comes out unmatched
        tx_l.push_chunk(chunk_l1);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 1 4 7 . . .
                + 2 5 8 . . .
                + 3 6 9 . . ."
            )
        );

        // push the second barrier (epoch 2) for left and right
        tx_l.push_barrier(test_epoch(2), false);
        tx_r.push_barrier(test_epoch(2), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 2nd left chunk: an insert/delete pair of the same row. Both output rows
        // carry the trailing `D` marker in the expected chunk (presumably marking them
        // invisible).
        tx_l.push_chunk(chunk_l2);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 3 8 1 . . . D
                - 3 8 1 . . . D"
            )
        );

        // push the 1st right chunk: left (2, 5, 8) moves from unmatched to col2 = 4, then to
        // col2 = 5 (pk 2), then to the smaller pk 1 with the same col2
        tx_r.push_chunk(chunk_r1);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 . . .
                + 2 5 8 2 3 4
                - 2 5 8 2 3 4
                + 2 5 8 2 2 5
                - 2 5 8 2 2 5
                + 2 5 8 2 1 5"
            )
        );

        // push the 2nd right chunk: deleting the unmatched (2, 3, 4) changes nothing. Deleting
        // pk 1 falls back to pk 2; deleting pk 2 leaves the left row unmatched.
        tx_r.push_chunk(chunk_r2);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 1 5
                + 2 5 8 2 2 5
                - 2 5 8 2 2 5
                + 2 5 8 . . ."
            )
        );

        // push the 3rd left chunk: col1 = 8 matches col2 = 8 but not col2 = 9
        tx_l.push_chunk(chunk_l3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 6 8 9 6 1 8"
            )
        );

        // push the 3rd right chunk: no remaining right row with col2 <= 8, so it is unmatched
        tx_r.push_chunk(chunk_r3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 6 8 9 6 1 8
                + 6 8 9 . . ."
            )
        );
        Ok(())
    }
}