risingwave_stream/executor/
asof_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use std::collections::{BTreeMap, HashSet};
15use std::marker::PhantomData;
16use std::ops::Bound;
17use std::time::Duration;
18
19use either::Either;
20use itertools::Itertools;
21use multimap::MultiMap;
22use risingwave_common::array::Op;
23use risingwave_common::hash::{HashKey, NullBitmap};
24use risingwave_common::util::epoch::EpochPair;
25use risingwave_common::util::iter_util::ZipEqDebug;
26use tokio::time::Instant;
27
28use self::builder::JoinChunkBuilder;
29use super::barrier_align::*;
30use super::join::hash_join::*;
31use super::join::*;
32use super::watermark::*;
33use crate::executor::join::builder::JoinStreamChunkBuilder;
34use crate::executor::join::row::JoinEncoding;
35use crate::executor::prelude::*;
36
/// Number of processed rows between two cache evictions of the join hash tables.
const EVICT_EVERY_N_ROWS: u32 = 16;

/// Returns `true` when every index in `vec1` also occurs in `vec2`.
///
/// Duplicates are irrelevant; an empty `vec1` is always a subset.
fn is_subset(vec1: Vec<usize>, vec2: Vec<usize>) -> bool {
    let superset: HashSet<usize> = vec2.into_iter().collect();
    vec1.iter().all(|idx| superset.contains(idx))
}
43
/// Index parameters describing one input side of the join.
pub struct JoinParams {
    /// Column indices (within this side's input) used as equi-join keys.
    pub join_key_indices: Vec<usize>,
    /// Indices of the input pk after dedup.
    pub deduped_pk_indices: Vec<usize>,
}

impl JoinParams {
    /// Bundles the join-key indices and the deduplicated pk indices of one input.
    pub fn new(join_key_indices: Vec<usize>, deduped_pk_indices: Vec<usize>) -> Self {
        JoinParams {
            deduped_pk_indices,
            join_key_indices,
        }
    }
}
59
60struct JoinSide<K: HashKey, S: StateStore, E: JoinEncoding> {
61    /// Store all data from a one side stream
62    ht: JoinHashMap<K, S, E>,
63    /// Indices of the join key columns
64    join_key_indices: Vec<usize>,
65    /// The data type of all columns without degree.
66    all_data_types: Vec<DataType>,
67    /// The start position for the side in output new columns
68    start_pos: usize,
69    /// The mapping from input indices of a side to output columes.
70    i2o_mapping: Vec<(usize, usize)>,
71    i2o_mapping_indexed: MultiMap<usize, usize>,
72    /// Whether degree table is needed for this side.
73    need_degree_table: bool,
74    _marker: std::marker::PhantomData<E>,
75}
76
77impl<K: HashKey, S: StateStore, E: JoinEncoding> std::fmt::Debug for JoinSide<K, S, E> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("JoinSide")
80            .field("join_key_indices", &self.join_key_indices)
81            .field("col_types", &self.all_data_types)
82            .field("start_pos", &self.start_pos)
83            .field("i2o_mapping", &self.i2o_mapping)
84            .field("need_degree_table", &self.need_degree_table)
85            .finish()
86    }
87}
88
89impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinSide<K, S, E> {
90    // WARNING: Please do not call this until we implement it.
91    fn is_dirty(&self) -> bool {
92        unimplemented!()
93    }
94
95    #[expect(dead_code)]
96    fn clear_cache(&mut self) {
97        assert!(
98            !self.is_dirty(),
99            "cannot clear cache while states of hash join are dirty"
100        );
101
102        // TODO: not working with rearranged chain
103        // self.ht.clear();
104    }
105
106    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
107        self.ht.init(epoch).await
108    }
109}
110
111/// `AsOfJoinExecutor` takes two input streams and runs equal hash join on them.
112/// The output columns are the concatenation of left and right columns.
113pub struct AsOfJoinExecutor<
114    K: HashKey,
115    S: StateStore,
116    const T: AsOfJoinTypePrimitive,
117    E: JoinEncoding,
118> {
119    ctx: ActorContextRef,
120    info: ExecutorInfo,
121
122    /// Left input executor
123    input_l: Option<Executor>,
124    /// Right input executor
125    input_r: Option<Executor>,
126    /// The data types of the formed new columns
127    actual_output_data_types: Vec<DataType>,
128    /// The parameters of the left join executor
129    side_l: JoinSide<K, S, E>,
130    /// The parameters of the right join executor
131    side_r: JoinSide<K, S, E>,
132
133    metrics: Arc<StreamingMetrics>,
134    /// The maximum size of the chunk produced by executor at a time
135    chunk_size: usize,
136    /// Count the messages received, clear to 0 when counted to `EVICT_EVERY_N_MESSAGES`
137    cnt_rows_received: u32,
138
139    /// watermark column index -> `BufferedWatermarks`
140    watermark_buffers: BTreeMap<usize, BufferedWatermarks<SideTypePrimitive>>,
141
142    high_join_amplification_threshold: usize,
143    /// `AsOf` join description
144    asof_desc: AsOfDesc,
145}
146
147impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding> std::fmt::Debug
148    for AsOfJoinExecutor<K, S, T, E>
149{
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("AsOfJoinExecutor")
152            .field("join_type", &T)
153            .field("input_left", &self.input_l.as_ref().unwrap().identity())
154            .field("input_right", &self.input_r.as_ref().unwrap().identity())
155            .field("side_l", &self.side_l)
156            .field("side_r", &self.side_r)
157            .field("pk_indices", &self.info.pk_indices)
158            .field("schema", &self.info.schema)
159            .field("actual_output_data_types", &self.actual_output_data_types)
160            .finish()
161    }
162}
163
164impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding> Execute
165    for AsOfJoinExecutor<K, S, T, E>
166{
167    fn execute(self: Box<Self>) -> BoxedMessageStream {
168        self.into_stream().boxed()
169    }
170}
171
172struct EqJoinArgs<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
173    ctx: &'a ActorContextRef,
174    side_l: &'a mut JoinSide<K, S, E>,
175    side_r: &'a mut JoinSide<K, S, E>,
176    asof_desc: &'a AsOfDesc,
177    actual_output_data_types: &'a [DataType],
178    // inequality_watermarks: &'a Watermark,
179    chunk: StreamChunk,
180    chunk_size: usize,
181    cnt_rows_received: &'a mut u32,
182    high_join_amplification_threshold: usize,
183}
184
185impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding>
186    AsOfJoinExecutor<K, S, T, E>
187{
188    #[allow(clippy::too_many_arguments)]
189    pub fn new(
190        ctx: ActorContextRef,
191        info: ExecutorInfo,
192        input_l: Executor,
193        input_r: Executor,
194        params_l: JoinParams,
195        params_r: JoinParams,
196        null_safe: Vec<bool>,
197        output_indices: Vec<usize>,
198        state_table_l: StateTable<S>,
199        state_table_r: StateTable<S>,
200        watermark_epoch: AtomicU64Ref,
201        metrics: Arc<StreamingMetrics>,
202        chunk_size: usize,
203        high_join_amplification_threshold: usize,
204        asof_desc: AsOfDesc,
205    ) -> Self {
206        let side_l_column_n = input_l.schema().len();
207
208        let schema_fields = [
209            input_l.schema().fields.clone(),
210            input_r.schema().fields.clone(),
211        ]
212        .concat();
213
214        let original_output_data_types = schema_fields
215            .iter()
216            .map(|field| field.data_type())
217            .collect_vec();
218        let actual_output_data_types = output_indices
219            .iter()
220            .map(|&idx| original_output_data_types[idx].clone())
221            .collect_vec();
222
223        // Data types of of hash join state.
224        let state_all_data_types_l = input_l.schema().data_types();
225        let state_all_data_types_r = input_r.schema().data_types();
226
227        let state_pk_indices_l = input_l.pk_indices().to_vec();
228        let state_pk_indices_r = input_r.pk_indices().to_vec();
229
230        let state_join_key_indices_l = params_l.join_key_indices;
231        let state_join_key_indices_r = params_r.join_key_indices;
232
233        // If pk is contained in join key.
234        let pk_contained_in_jk_l =
235            is_subset(state_pk_indices_l.clone(), state_join_key_indices_l.clone());
236        let pk_contained_in_jk_r =
237            is_subset(state_pk_indices_r.clone(), state_join_key_indices_r.clone());
238
239        let join_key_data_types_l = state_join_key_indices_l
240            .iter()
241            .map(|idx| state_all_data_types_l[*idx].clone())
242            .collect_vec();
243
244        let join_key_data_types_r = state_join_key_indices_r
245            .iter()
246            .map(|idx| state_all_data_types_r[*idx].clone())
247            .collect_vec();
248
249        assert_eq!(join_key_data_types_l, join_key_data_types_r);
250
251        let null_matched = K::Bitmap::from_bool_vec(null_safe);
252
253        let (left_to_output, right_to_output) = {
254            let (left_len, right_len) = if is_left_semi_or_anti(T) {
255                (state_all_data_types_l.len(), 0usize)
256            } else if is_right_semi_or_anti(T) {
257                (0usize, state_all_data_types_r.len())
258            } else {
259                (state_all_data_types_l.len(), state_all_data_types_r.len())
260            };
261            JoinStreamChunkBuilder::get_i2o_mapping(&output_indices, left_len, right_len)
262        };
263
264        let l2o_indexed = MultiMap::from_iter(left_to_output.iter().copied());
265        let r2o_indexed = MultiMap::from_iter(right_to_output.iter().copied());
266
267        // handle inequality watermarks
268        // https://github.com/risingwavelabs/risingwave/issues/18503
269        // let inequality_watermarks = None;
270        let watermark_buffers = BTreeMap::new();
271
272        let inequal_key_idx_l = Some(asof_desc.left_idx);
273        let inequal_key_idx_r = Some(asof_desc.right_idx);
274
275        Self {
276            ctx: ctx.clone(),
277            info,
278            input_l: Some(input_l),
279            input_r: Some(input_r),
280            actual_output_data_types,
281            side_l: JoinSide {
282                ht: JoinHashMap::new(
283                    watermark_epoch.clone(),
284                    join_key_data_types_l,
285                    state_join_key_indices_l.clone(),
286                    state_all_data_types_l.clone(),
287                    state_table_l,
288                    params_l.deduped_pk_indices,
289                    None,
290                    null_matched.clone(),
291                    pk_contained_in_jk_l,
292                    inequal_key_idx_l,
293                    metrics.clone(),
294                    ctx.id,
295                    ctx.fragment_id,
296                    "left",
297                ),
298                join_key_indices: state_join_key_indices_l,
299                all_data_types: state_all_data_types_l,
300                i2o_mapping: left_to_output,
301                i2o_mapping_indexed: l2o_indexed,
302                start_pos: 0,
303                need_degree_table: false,
304                _marker: PhantomData,
305            },
306            side_r: JoinSide {
307                ht: JoinHashMap::new(
308                    watermark_epoch,
309                    join_key_data_types_r,
310                    state_join_key_indices_r.clone(),
311                    state_all_data_types_r.clone(),
312                    state_table_r,
313                    params_r.deduped_pk_indices,
314                    None,
315                    null_matched,
316                    pk_contained_in_jk_r,
317                    inequal_key_idx_r,
318                    metrics.clone(),
319                    ctx.id,
320                    ctx.fragment_id,
321                    "right",
322                ),
323                join_key_indices: state_join_key_indices_r,
324                all_data_types: state_all_data_types_r,
325                start_pos: side_l_column_n,
326                i2o_mapping: right_to_output,
327                i2o_mapping_indexed: r2o_indexed,
328                need_degree_table: false,
329                _marker: PhantomData,
330            },
331            metrics,
332            chunk_size,
333            cnt_rows_received: 0,
334            watermark_buffers,
335            high_join_amplification_threshold,
336            asof_desc,
337        }
338    }
339
    /// Main loop of the executor, producing its output message stream.
    ///
    /// 1. Aligns barriers of the two inputs.
    /// 2. Forwards the first barrier and initializes both hash tables with its epoch.
    /// 3. For every aligned message:
    ///    * watermarks go through `handle_watermark`;
    ///    * left/right chunks are joined by `eq_join_left` / `eq_join_right` and followed by
    ///      `try_flush_data`;
    ///    * barriers flush both sides, are forwarded downstream, and then apply any vnode
    ///      bitmap update and report the cached entry counts.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn into_stream(mut self) {
        // The inputs are moved out here; afterwards `input_l` / `input_r` are `None`.
        let input_l = self.input_l.take().unwrap();
        let input_r = self.input_r.take().unwrap();
        let aligned_stream = barrier_align(
            input_l.execute(),
            input_r.execute(),
            self.ctx.id,
            self.ctx.fragment_id,
            self.metrics.clone(),
            "Join",
        );
        pin_mut!(aligned_stream);
        let actor_id = self.ctx.id;

        let barrier = expect_first_barrier_from_aligned_stream(&mut aligned_stream).await?;
        let first_epoch = barrier.epoch;
        // The first barrier message should be propagated.
        yield Message::Barrier(barrier);
        self.side_l.init(first_epoch).await?;
        self.side_r.init(first_epoch).await?;

        let actor_id_str = self.ctx.id.to_string();
        let fragment_id_str = self.ctx.fragment_id.to_string();

        // Initialize the per-actor metric handles.
        let join_actor_input_waiting_duration_ns = self
            .metrics
            .join_actor_input_waiting_duration_ns
            .with_guarded_label_values(&[&actor_id_str, &fragment_id_str]);
        let left_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "left"]);
        let right_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "right"]);

        let barrier_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[
                actor_id_str.as_str(),
                fragment_id_str.as_str(),
                "barrier",
            ]);

        let left_join_cached_entry_count = self
            .metrics
            .join_cached_entry_count
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "left"]);

        let right_join_cached_entry_count = self
            .metrics
            .join_cached_entry_count
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "right"]);

        // Reset after each message, so the waiting metric measures time spent awaiting input.
        let mut start_time = Instant::now();

        // NOTE(review): the await label says "hash_join" although this is the AS OF join;
        // kept unchanged since dashboards/tooling may depend on it.
        while let Some(msg) = aligned_stream
            .next()
            .instrument_await("hash_join_barrier_align")
            .await
        {
            join_actor_input_waiting_duration_ns.inc_by(start_time.elapsed().as_nanos() as u64);
            match msg? {
                AlignedMessage::WatermarkLeft(watermark) => {
                    for watermark_to_emit in self.handle_watermark(SideType::Left, watermark)? {
                        yield Message::Watermark(watermark_to_emit);
                    }
                }
                AlignedMessage::WatermarkRight(watermark) => {
                    for watermark_to_emit in self.handle_watermark(SideType::Right, watermark)? {
                        yield Message::Watermark(watermark_to_emit);
                    }
                }
                AlignedMessage::Left(chunk) => {
                    // Match time excludes the time spent suspended at `yield`: the elapsed time
                    // is added before each yield and the timer restarts after it.
                    let mut left_time = Duration::from_nanos(0);
                    let mut left_start_time = Instant::now();
                    #[for_await]
                    for chunk in Self::eq_join_left(EqJoinArgs {
                        ctx: &self.ctx,
                        side_l: &mut self.side_l,
                        side_r: &mut self.side_r,
                        asof_desc: &self.asof_desc,
                        actual_output_data_types: &self.actual_output_data_types,
                        // inequality_watermarks: &self.inequality_watermarks,
                        chunk,
                        chunk_size: self.chunk_size,
                        cnt_rows_received: &mut self.cnt_rows_received,
                        high_join_amplification_threshold: self.high_join_amplification_threshold,
                    }) {
                        left_time += left_start_time.elapsed();
                        yield Message::Chunk(chunk?);
                        left_start_time = Instant::now();
                    }
                    left_time += left_start_time.elapsed();
                    left_join_match_duration_ns.inc_by(left_time.as_nanos() as u64);
                    self.try_flush_data().await?;
                }
                AlignedMessage::Right(chunk) => {
                    // Same timing scheme as the left side.
                    let mut right_time = Duration::from_nanos(0);
                    let mut right_start_time = Instant::now();
                    #[for_await]
                    for chunk in Self::eq_join_right(EqJoinArgs {
                        ctx: &self.ctx,
                        side_l: &mut self.side_l,
                        side_r: &mut self.side_r,
                        asof_desc: &self.asof_desc,
                        actual_output_data_types: &self.actual_output_data_types,
                        // inequality_watermarks: &self.inequality_watermarks,
                        chunk,
                        chunk_size: self.chunk_size,
                        cnt_rows_received: &mut self.cnt_rows_received,
                        high_join_amplification_threshold: self.high_join_amplification_threshold,
                    }) {
                        right_time += right_start_time.elapsed();
                        yield Message::Chunk(chunk?);
                        right_start_time = Instant::now();
                    }
                    right_time += right_start_time.elapsed();
                    right_join_match_duration_ns.inc_by(right_time.as_nanos() as u64);
                    self.try_flush_data().await?;
                }
                AlignedMessage::Barrier(barrier) => {
                    let barrier_start_time = Instant::now();
                    // Commit both sides before forwarding the barrier; the returned post-commit
                    // handles are completed only after the barrier has been yielded.
                    let (left_post_commit, right_post_commit) =
                        self.flush_data(barrier.epoch).await?;

                    let update_vnode_bitmap = barrier.as_update_vnode_bitmap(actor_id);
                    yield Message::Barrier(barrier);

                    // Update the vnode bitmap for state tables of both sides if asked.
                    right_post_commit
                        .post_yield_barrier(update_vnode_bitmap.clone())
                        .await?;
                    // When the left side reports `Some(true)`, the buffered watermarks are
                    // dropped (presumably because vnode ownership changed — confirm in
                    // `post_yield_barrier`).
                    if left_post_commit
                        .post_yield_barrier(update_vnode_bitmap)
                        .await?
                        .unwrap_or(false)
                    {
                        self.watermark_buffers
                            .values_mut()
                            .for_each(|buffers| buffers.clear());
                    }

                    // Report metrics of cached join rows/entries
                    for (join_cached_entry_count, ht) in [
                        (&left_join_cached_entry_count, &self.side_l.ht),
                        (&right_join_cached_entry_count, &self.side_r.ht),
                    ] {
                        join_cached_entry_count.set(ht.entry_count() as i64);
                    }

                    barrier_join_match_duration_ns
                        .inc_by(barrier_start_time.elapsed().as_nanos() as u64);
                }
            }
            start_time = Instant::now();
        }
    }
502
503    async fn flush_data(
504        &mut self,
505        epoch: EpochPair,
506    ) -> StreamExecutorResult<(
507        JoinHashMapPostCommit<'_, K, S, E>,
508        JoinHashMapPostCommit<'_, K, S, E>,
509    )> {
510        // All changes to the state has been buffered in the mem-table of the state table. Just
511        // `commit` them here.
512        let left = self.side_l.ht.flush(epoch).await?;
513        let right = self.side_r.ht.flush(epoch).await?;
514        Ok((left, right))
515    }
516
517    async fn try_flush_data(&mut self) -> StreamExecutorResult<()> {
518        // All changes to the state has been buffered in the mem-table of the state table. Just
519        // `commit` them here.
520        self.side_l.ht.try_flush().await?;
521        self.side_r.ht.try_flush().await?;
522        Ok(())
523    }
524
525    // We need to manually evict the cache.
526    fn evict_cache(
527        side_update: &mut JoinSide<K, S, E>,
528        side_match: &mut JoinSide<K, S, E>,
529        cnt_rows_received: &mut u32,
530    ) {
531        *cnt_rows_received += 1;
532        if *cnt_rows_received == EVICT_EVERY_N_ROWS {
533            side_update.ht.evict();
534            side_match.ht.evict();
535            *cnt_rows_received = 0;
536        }
537    }
538
539    fn handle_watermark(
540        &mut self,
541        side: SideTypePrimitive,
542        watermark: Watermark,
543    ) -> StreamExecutorResult<Vec<Watermark>> {
544        let (side_update, side_match) = if side == SideType::Left {
545            (&mut self.side_l, &mut self.side_r)
546        } else {
547            (&mut self.side_r, &mut self.side_l)
548        };
549
550        // State cleaning
551        if side_update.join_key_indices[0] == watermark.col_idx {
552            side_match.ht.update_watermark(watermark.val.clone());
553        }
554
555        // Select watermarks to yield.
556        let wm_in_jk = side_update
557            .join_key_indices
558            .iter()
559            .positions(|idx| *idx == watermark.col_idx);
560        let mut watermarks_to_emit = vec![];
561        for idx in wm_in_jk {
562            let buffers = self
563                .watermark_buffers
564                .entry(idx)
565                .or_insert_with(|| BufferedWatermarks::with_ids([SideType::Left, SideType::Right]));
566            if let Some(selected_watermark) = buffers.handle_watermark(side, watermark.clone()) {
567                let empty_indices = vec![];
568                let output_indices = side_update
569                    .i2o_mapping_indexed
570                    .get_vec(&side_update.join_key_indices[idx])
571                    .unwrap_or(&empty_indices)
572                    .iter()
573                    .chain(
574                        side_match
575                            .i2o_mapping_indexed
576                            .get_vec(&side_match.join_key_indices[idx])
577                            .unwrap_or(&empty_indices),
578                    );
579                for output_idx in output_indices {
580                    watermarks_to_emit.push(selected_watermark.clone().with_idx(*output_idx));
581                }
582            };
583        }
584        Ok(watermarks_to_emit)
585    }
586
587    /// the data the hash table and match the coming
588    /// data chunk with the executor state
589    async fn hash_eq_match(
590        key: &K,
591        ht: &mut JoinHashMap<K, S, E>,
592    ) -> StreamExecutorResult<Option<HashValueType<E>>> {
593        if !key.null_bitmap().is_subset(ht.null_matched()) {
594            Ok(None)
595        } else {
596            ht.take_state(key).await.map(Some)
597        }
598    }
599
    /// Joins one chunk from the left input against the right side's state and yields the
    /// resulting output chunks.
    ///
    /// For each visible row:
    /// 1. Right-side rows with the same join key are taken from the right hash table. This
    ///    is skipped when `check_inequal_key_null` reports the row's inequality key as NULL,
    ///    or when the join key violates the null-safe bitmap.
    /// 2. If a lookup happened, the AS OF inequality picks at most one right row. A picked row
    ///    produces a matched output row; otherwise the left row goes through
    ///    `forward_if_not_matched`.
    /// 3. The left row is inserted into / deleted from the left hash table, and the taken
    ///    right-side state is put back.
    /// 4. Rows that could not be looked up are only forwarded and never stored, since they
    ///    can never match.
    ///
    /// Chunks are yielded whenever the builder emits one, plus a final partial chunk.
    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
    async fn eq_join_left(args: EqJoinArgs<'_, K, S, E>) {
        let EqJoinArgs {
            ctx: _,
            side_l,
            side_r,
            asof_desc,
            actual_output_data_types,
            // inequality_watermarks,
            chunk,
            chunk_size,
            cnt_rows_received,
            high_join_amplification_threshold: _,
        } = args;

        // Left rows update the left table and probe the right table.
        let (side_update, side_match) = (side_l, side_r);

        let mut join_chunk_builder =
            JoinChunkBuilder::<T, { SideType::Left }>::new(JoinStreamChunkBuilder::new(
                chunk_size,
                actual_output_data_types.to_vec(),
                side_update.i2o_mapping.clone(),
                side_match.i2o_mapping.clone(),
            ));

        let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk());
        for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
            // Holes are yielded as `None` (presumably invisible rows) and skipped.
            let Some((op, row)) = r else {
                continue;
            };
            Self::evict_cache(side_update, side_match, cnt_rows_received);

            let matched_rows = if !side_update.ht.check_inequal_key_null(&row) {
                Self::hash_eq_match(key, &mut side_match.ht).await?
            } else {
                None
            };
            let inequal_key = side_update.ht.serialize_inequal_key_from_row(row);

            if let Some(matched_rows) = matched_rows {
                // Pick at most one right row according to the inequality type. Presumably
                // `lower_bound_*` returns the closest row above the key and `upper_bound_*`
                // the closest row below it, with `Excluded`/`Included` selecting strict or
                // non-strict comparison — confirm against `HashValueType`.
                let matched_row_by_inequality = match asof_desc.inequality_type {
                    AsOfInequalityType::Lt => matched_rows.lower_bound_by_inequality(
                        Bound::Excluded(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Le => matched_rows.lower_bound_by_inequality(
                        Bound::Included(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Gt => matched_rows.upper_bound_by_inequality(
                        Bound::Excluded(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Ge => matched_rows.upper_bound_by_inequality(
                        Bound::Included(&inequal_key),
                        &side_match.all_data_types,
                    ),
                };
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        if let Some(matched_row_by_inequality) = matched_row_by_inequality {
                            let matched_row = matched_row_by_inequality?;

                            if let Some(chunk) =
                                join_chunk_builder.with_match_on_insert(&row, &matched_row)
                            {
                                yield chunk;
                            }
                        } else if let Some(chunk) =
                            // Unmatched rows are forwarded as plain `Insert`.
                            join_chunk_builder.forward_if_not_matched(Op::Insert, row)
                        {
                            yield chunk;
                        }
                        side_update.ht.insert_row(key, row)?;
                    }
                    Op::Delete | Op::UpdateDelete => {
                        if let Some(matched_row_by_inequality) = matched_row_by_inequality {
                            let matched_row = matched_row_by_inequality?;

                            if let Some(chunk) =
                                join_chunk_builder.with_match_on_delete(&row, &matched_row)
                            {
                                yield chunk;
                            }
                        } else if let Some(chunk) =
                            // Unmatched rows are forwarded as plain `Delete`.
                            join_chunk_builder.forward_if_not_matched(Op::Delete, row)
                        {
                            yield chunk;
                        }
                        side_update.ht.delete_row(key, row)?;
                    }
                }
                // Insert back the state taken from ht.
                side_match.ht.update_state(key, matched_rows);
            } else {
                // Row which violates null-safe bitmap will never be matched so we need not
                // store.
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Insert, row)
                        {
                            yield chunk;
                        }
                    }
                    Op::Delete | Op::UpdateDelete => {
                        if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Delete, row)
                        {
                            yield chunk;
                        }
                    }
                }
            }
        }
        // Emit whatever is left in the builder.
        if let Some(chunk) = join_chunk_builder.take() {
            yield chunk;
        }
    }
719
720    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
721    async fn eq_join_right(args: EqJoinArgs<'_, K, S, E>) {
722        let EqJoinArgs {
723            ctx,
724            side_l,
725            side_r,
726            asof_desc,
727            actual_output_data_types,
728            // inequality_watermarks,
729            chunk,
730            chunk_size,
731            cnt_rows_received,
732            high_join_amplification_threshold,
733        } = args;
734
735        let (side_update, side_match) = (side_r, side_l);
736
737        let mut join_chunk_builder = JoinStreamChunkBuilder::new(
738            chunk_size,
739            actual_output_data_types.to_vec(),
740            side_update.i2o_mapping.clone(),
741            side_match.i2o_mapping.clone(),
742        );
743
744        let join_matched_rows_metrics = ctx
745            .streaming_metrics
746            .join_matched_join_keys
747            .with_guarded_label_values(&[
748                &ctx.id.to_string(),
749                &ctx.fragment_id.to_string(),
750                &side_update.ht.table_id().to_string(),
751            ]);
752
753        let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk());
754        for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
755            let Some((op, row)) = r else {
756                continue;
757            };
758            let mut join_matched_rows_cnt = 0;
759
760            Self::evict_cache(side_update, side_match, cnt_rows_received);
761
762            let matched_rows = if !side_update.ht.check_inequal_key_null(&row) {
763                Self::hash_eq_match(key, &mut side_match.ht).await?
764            } else {
765                None
766            };
767            let inequal_key = side_update.ht.serialize_inequal_key_from_row(row);
768
769            if let Some(matched_rows) = matched_rows {
770                let update_rows = Self::hash_eq_match(key, &mut side_update.ht).await?.expect("None is not expected because we have checked null in key when getting matched_rows");
771                let right_inequality_index = update_rows.inequality_index();
772                let (row_to_delete_r, row_to_insert_r) =
773                    if let Some(pks) = right_inequality_index.get(&inequal_key) {
774                        assert!(!pks.is_empty());
775                        let row_pk = side_update.ht.serialize_pk_from_row(row);
776                        match op {
777                            Op::Insert | Op::UpdateInsert => {
778                                // If there are multiple rows match the inequality key in the right table, we use one with smallest pk.
779                                let smallest_pk = pks.first_key_sorted().unwrap();
780                                if smallest_pk > &row_pk {
781                                    // smallest_pk is in the cache index, so it must exist in the cache.
782                                    if let Some(to_delete_row) = update_rows
783                                        .get_by_indexed_pk(smallest_pk, &side_update.all_data_types)
784                                    {
785                                        (
786                                            Some(Either::Left(to_delete_row?.row)),
787                                            Some(Either::Right(row)),
788                                        )
789                                    } else {
790                                        // Something wrong happened. Ignore this row in non strict consistency mode.
791                                        (None, None)
792                                    }
793                                } else {
794                                    // No affected row in the right table.
795                                    (None, None)
796                                }
797                            }
798                            Op::Delete | Op::UpdateDelete => {
799                                let smallest_pk = pks.first_key_sorted().unwrap();
800                                if smallest_pk == &row_pk {
801                                    if let Some(second_smallest_pk) = pks.second_key_sorted() {
802                                        if let Some(to_insert_row) = update_rows.get_by_indexed_pk(
803                                            second_smallest_pk,
804                                            &side_update.all_data_types,
805                                        ) {
806                                            (
807                                                Some(Either::Right(row)),
808                                                Some(Either::Left(to_insert_row?.row)),
809                                            )
810                                        } else {
811                                            // Something wrong happened. Ignore this row in non strict consistency mode.
812                                            (None, None)
813                                        }
814                                    } else {
815                                        (Some(Either::Right(row)), None)
816                                    }
817                                } else {
818                                    // No affected row in the right table.
819                                    (None, None)
820                                }
821                            }
822                        }
823                    } else {
824                        match op {
825                            // Decide the row_to_delete later
826                            Op::Insert | Op::UpdateInsert => (None, Some(Either::Right(row))),
827                            // Decide the row_to_insert later
828                            Op::Delete | Op::UpdateDelete => (Some(Either::Right(row)), None),
829                        }
830                    };
831
832                // 4 cases for row_to_delete_r and row_to_insert_r:
833                // 1. Some(_), Some(_): delete row_to_delete_r and insert row_to_insert_r
834                // 2. None, Some(_)   : row_to_delete to be decided by the nearest inequality key
835                // 3. Some(_), None   : row_to_insert to be decided by the nearest inequality key
836                // 4. None, None      : do nothing
837                if row_to_delete_r.is_none() && row_to_insert_r.is_none() {
838                    // no row to delete or insert.
839                } else {
840                    let prev_inequality_key =
841                        right_inequality_index.upper_bound_key(Bound::Excluded(&inequal_key));
842                    let next_inequality_key =
843                        right_inequality_index.lower_bound_key(Bound::Excluded(&inequal_key));
844                    let affected_row_r = match asof_desc.inequality_type {
845                        AsOfInequalityType::Lt | AsOfInequalityType::Le => next_inequality_key
846                            .and_then(|k| {
847                                update_rows.get_first_by_inequality(k, &side_update.all_data_types)
848                            }),
849                        AsOfInequalityType::Gt | AsOfInequalityType::Ge => prev_inequality_key
850                            .and_then(|k| {
851                                update_rows.get_first_by_inequality(k, &side_update.all_data_types)
852                            }),
853                    }
854                    .transpose()?
855                    .map(|r| Either::Left(r.row));
856
857                    let (row_to_delete_r, row_to_insert_r) =
858                        match (&row_to_delete_r, &row_to_insert_r) {
859                            (Some(_), Some(_)) => (row_to_delete_r, row_to_insert_r),
860                            (None, Some(_)) => (affected_row_r, row_to_insert_r),
861                            (Some(_), None) => (row_to_delete_r, affected_row_r),
862                            (None, None) => unreachable!(),
863                        };
864                    let range = match asof_desc.inequality_type {
865                        AsOfInequalityType::Lt => (
866                            prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included),
867                            Bound::Excluded(&inequal_key),
868                        ),
869                        AsOfInequalityType::Le => (
870                            prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded),
871                            Bound::Included(&inequal_key),
872                        ),
873                        AsOfInequalityType::Gt => (
874                            Bound::Excluded(&inequal_key),
875                            next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included),
876                        ),
877                        AsOfInequalityType::Ge => (
878                            Bound::Included(&inequal_key),
879                            next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded),
880                        ),
881                    };
882
883                    let rows_l =
884                        matched_rows.range_by_inequality(range, &side_match.all_data_types);
885                    for row_l in rows_l {
886                        join_matched_rows_cnt += 1;
887                        let row_l = row_l?.row;
888                        if let Some(row_to_delete_r) = &row_to_delete_r {
889                            if let Some(chunk) =
890                                join_chunk_builder.append_row(Op::Delete, row_to_delete_r, &row_l)
891                            {
892                                yield chunk;
893                            }
894                        } else if is_as_of_left_outer(T)
895                            && let Some(chunk) =
896                                join_chunk_builder.append_row_matched(Op::Delete, &row_l)
897                        {
898                            yield chunk;
899                        }
900                        if let Some(row_to_insert_r) = &row_to_insert_r {
901                            if let Some(chunk) =
902                                join_chunk_builder.append_row(Op::Insert, row_to_insert_r, &row_l)
903                            {
904                                yield chunk;
905                            }
906                        } else if is_as_of_left_outer(T)
907                            && let Some(chunk) =
908                                join_chunk_builder.append_row_matched(Op::Insert, &row_l)
909                        {
910                            yield chunk;
911                        }
912                    }
913                }
914                // Insert back the state taken from ht.
915                side_match.ht.update_state(key, matched_rows);
916                side_update.ht.update_state(key, update_rows);
917
918                match op {
919                    Op::Insert | Op::UpdateInsert => {
920                        side_update.ht.insert_row(key, row)?;
921                    }
922                    Op::Delete | Op::UpdateDelete => {
923                        side_update.ht.delete_row(key, row)?;
924                    }
925                }
926            } else {
927                // Row which violates null-safe bitmap will never be matched so we need not
928                // store.
929                // Noop here because we only support left outer AsOf join.
930            }
931            join_matched_rows_metrics.observe(join_matched_rows_cnt as _);
932            if join_matched_rows_cnt > high_join_amplification_threshold {
933                let join_key_data_types = side_update.ht.join_key_data_types();
934                let key = key.deserialize(join_key_data_types)?;
935                tracing::warn!(target: "high_join_amplification",
936                    matched_rows_len = join_matched_rows_cnt,
937                    update_table_id = side_update.ht.table_id(),
938                    match_table_id = side_match.ht.table_id(),
939                    join_key = ?key,
940                    actor_id = ctx.id,
941                    fragment_id = ctx.fragment_id,
942                    "large rows matched for join key when AsOf join updating right side",
943                );
944            }
945        }
946        if let Some(chunk) = join_chunk_builder.take() {
947            yield chunk;
948        }
949    }
950}
951
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU64;

    use risingwave_common::array::*;
    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
    use risingwave_common::hash::Key64;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::common::table::test_utils::gen_pbtable;
    use crate::executor::MemoryEncoding;
    use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt};

    /// Builds a `StateTable` backed by `mem_state` for a table with the given column types.
    ///
    /// `order_types` and `pk_indices` are passed through to `gen_pbtable`. Columns are
    /// unnamed, and their column ids are assigned sequentially starting from 0.
    async fn create_in_memory_state_table(
        mem_state: MemoryStateStore,
        data_types: &[DataType],
        order_types: &[OrderType],
        pk_indices: &[usize],
        table_id: u32,
    ) -> StateTable<MemoryStateStore> {
        let column_descs = data_types
            .iter()
            .enumerate()
            .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
            .collect_vec();
        StateTable::from_table_catalog(
            &gen_pbtable(
                TableId::new(table_id),
                column_descs,
                order_types.to_vec(),
                pk_indices.to_vec(),
                0,
            ),
            mem_state.clone(),
            None,
        )
        .await
    }

    /// Creates an `AsOfJoinExecutor` of join type `T` over two mock inputs.
    ///
    /// Both inputs have the schema `(Int64 join key, Int64, Int64)` and use column 1 as their
    /// (deduped) pk. Each side's state table is keyed by `[0, <that side's AsOf column>, 1]`.
    /// Both tables live in a single `MemoryStateStore`, with table id 0 for the left side and
    /// table id 1 for the right side.
    ///
    /// Returns the left-input sender, the right-input sender and the executor's output stream.
    async fn create_executor<const T: AsOfJoinTypePrimitive>(
        asof_desc: AsOfDesc,
    ) -> (MessageSender, MessageSender, BoxedMessageStream) {
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int64), // join key
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        };
        let (tx_l, source_l) = MockSource::channel();
        let source_l = source_l.into_executor(schema.clone(), vec![1]);
        let (tx_r, source_r) = MockSource::channel();
        let source_r = source_r.into_executor(schema, vec![1]);
        let params_l = JoinParams::new(vec![0], vec![1]);
        let params_r = JoinParams::new(vec![0], vec![1]);

        let mem_state = MemoryStateStore::new();

        // NOTE(review): when `left_idx` is 0 (as in `test_as_of_inner_join`), column 0
        // appears twice in this pk. Confirm this is intended.
        let state_l = create_in_memory_state_table(
            mem_state.clone(),
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, asof_desc.left_idx, 1],
            0,
        )
        .await;

        let state_r = create_in_memory_state_table(
            mem_state,
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, asof_desc.right_idx, 1],
            1,
        )
        .await;

        // Output schema: all left columns followed by all right columns.
        let schema: Schema = [source_l.schema().fields(), source_r.schema().fields()]
            .concat()
            .into_iter()
            .collect();
        let schema_len = schema.len();
        // NOTE(review): the identity string says "HashJoinExecutor" even though this
        // builds an AsOf join executor.
        let info = ExecutorInfo::new(schema, vec![1], "HashJoinExecutor".to_owned(), 0);

        // NOTE(review): presumably `vec![false]` is the null-safe flag of the single join key,
        // and 1024 / 2048 are the chunk size and the high-join-amplification threshold.
        // Confirm against `AsOfJoinExecutor::new`.
        let executor = AsOfJoinExecutor::<Key64, MemoryStateStore, T, MemoryEncoding>::new(
            ActorContext::for_test(123),
            info,
            source_l,
            source_r,
            params_l,
            params_r,
            vec![false],
            (0..schema_len).collect_vec(),
            state_l,
            state_r,
            Arc::new(AtomicU64::new(0)),
            Arc::new(StreamingMetrics::unused()),
            1024,
            2048,
            asof_desc,
        );
        (tx_l, tx_r, executor.boxed().execute())
    }

    /// Inner AsOf join on `left.col0 < right.col2`.
    ///
    /// Each left row is paired with the right row of the same join key that has the smallest
    /// `col2` strictly greater than the left `col0`. Among right rows with an equal `col2`,
    /// the one with the smallest pk (column 1) is used.
    #[tokio::test]
    async fn test_as_of_inner_join() -> StreamExecutorResult<()> {
        let asof_desc = AsOfDesc {
            left_idx: 0,
            right_idx: 2,
            inequality_type: AsOfInequalityType::Lt,
        };

        let chunk_l1 = StreamChunk::from_pretty(
            "  I I I
             + 1 4 7
             + 2 5 8
             + 3 6 9",
        );
        let chunk_l2 = StreamChunk::from_pretty(
            "  I I I
             + 3 8 1
             - 3 8 1",
        );
        let chunk_r1 = StreamChunk::from_pretty(
            "  I I I
             + 2 1 7
             + 2 2 1
             + 2 3 4
             + 2 4 2
             + 6 1 9
             + 6 2 9",
        );
        let chunk_r2 = StreamChunk::from_pretty(
            "  I I I
             - 2 3 4",
        );
        let chunk_r3 = StreamChunk::from_pretty(
            "  I I I
             + 2 3 3",
        );
        let chunk_l3 = StreamChunk::from_pretty(
            "  I I I
             - 2 5 8",
        );
        let chunk_l4 = StreamChunk::from_pretty(
            "  I I I
             + 6 3 1
             + 6 4 1",
        );
        let chunk_r4 = StreamChunk::from_pretty(
            "  I I I
             - 6 1 9",
        );

        let (mut tx_l, mut tx_r, mut hash_join) =
            create_executor::<{ AsOfJoinType::Inner }>(asof_desc).await;

        // push the init barrier for left and right
        tx_l.push_barrier(test_epoch(1), false);
        tx_r.push_barrier(test_epoch(1), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 1st left chunk; inner join with no right rows yet, so no output
        tx_l.push_chunk(chunk_l1);
        hash_join.next_unwrap_pending();

        // push the 2nd barrier for left and right
        tx_l.push_barrier(test_epoch(2), false);
        tx_r.push_barrier(test_epoch(2), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 2nd left chunk; still no right rows, so no output
        tx_l.push_chunk(chunk_l2);
        hash_join.next_unwrap_pending();

        // push the 1st right chunk.
        // Left (2, 5, 8) first pairs with (2, 1, 7). Then (2, 3, 4) is a closer match
        // (2 < 4 < 7), so the old pair is retracted. (2, 2, 1) and (2, 4, 2) do not satisfy
        // 2 < col2. The key-6 rows have no left rows yet.
        tx_r.push_chunk(chunk_r1);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 2 5 8 2 1 7
                - 2 5 8 2 1 7
                + 2 5 8 2 3 4"
            )
        );

        // push the 2nd right chunk; deleting the current match falls back to (2, 1, 7)
        tx_r.push_chunk(chunk_r2);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 3 4
                + 2 5 8 2 1 7"
            )
        );

        // push the 3rd right chunk; (2, 3, 3) is closer than (2, 1, 7)
        tx_r.push_chunk(chunk_r3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 1 7
                + 2 5 8 2 3 3"
            )
        );

        // push the 3rd left chunk; deleting the left row retracts its current pair
        tx_l.push_chunk(chunk_l3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 3 3"
            )
        );

        // push the 4th left chunk.
        // Both key-6 right rows have col2 = 9, so the smaller pk (6, 1, 9) wins.
        tx_l.push_chunk(chunk_l4);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 6 3 1 6 1 9
                + 6 4 1 6 1 9"
            )
        );

        // push the 4th right chunk; deleting (6, 1, 9) promotes (6, 2, 9) for each left row
        tx_r.push_chunk(chunk_r4);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 6 3 1 6 1 9
                + 6 3 1 6 2 9
                - 6 4 1 6 1 9
                + 6 4 1 6 2 9"
            )
        );

        Ok(())
    }

    /// Left-outer AsOf join on `left.col1 >= right.col2`.
    ///
    /// Each left row is paired with the right row of the same join key that has the largest
    /// `col2` not greater than the left `col1`. Ties go to the smallest pk (column 1). Left rows
    /// without a match are emitted with NULLs (`.`) on the right side.
    #[tokio::test]
    async fn test_as_of_left_outer_join() -> StreamExecutorResult<()> {
        let asof_desc = AsOfDesc {
            left_idx: 1,
            right_idx: 2,
            inequality_type: AsOfInequalityType::Ge,
        };

        let chunk_l1 = StreamChunk::from_pretty(
            "  I I I
             + 1 4 7
             + 2 5 8
             + 3 6 9",
        );
        let chunk_l2 = StreamChunk::from_pretty(
            "  I I I
             + 3 8 1
             - 3 8 1",
        );
        let chunk_r1 = StreamChunk::from_pretty(
            "  I I I
             + 2 3 4
             + 2 2 5
             + 2 1 5
             + 6 1 8
             + 6 2 9",
        );
        let chunk_r2 = StreamChunk::from_pretty(
            "  I I I
             - 2 3 4
             - 2 1 5
             - 2 2 5",
        );
        let chunk_l3 = StreamChunk::from_pretty(
            "  I I I
             + 6 8 9",
        );
        let chunk_r3 = StreamChunk::from_pretty(
            "  I I I
             - 6 1 8",
        );

        let (mut tx_l, mut tx_r, mut hash_join) =
            create_executor::<{ AsOfJoinType::LeftOuter }>(asof_desc).await;

        // push the init barrier for left and right
        tx_l.push_barrier(test_epoch(1), false);
        tx_r.push_barrier(test_epoch(1), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 1st left chunk; no right rows yet, so every left row is NULL-padded
        tx_l.push_chunk(chunk_l1);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 1 4 7 . . .
                + 2 5 8 . . .
                + 3 6 9 . . ."
            )
        );

        // push the 2nd barrier for left and right
        tx_l.push_barrier(test_epoch(2), false);
        tx_r.push_barrier(test_epoch(2), false);
        hash_join.next_unwrap_ready_barrier()?;

        // push the 2nd left chunk; insert and delete are both NULL-padded
        tx_l.push_chunk(chunk_l2);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 3 8 1 . . .
                - 3 8 1 . . ."
            )
        );

        // push the 1st right chunk.
        // For left (2, 5, 8), each closer right row (col2 <= 5) replaces the current pair.
        // (2, 1, 5) beats (2, 2, 5) on the col2 = 5 tie by its smaller pk. The key-6 rows
        // have no left rows yet.
        tx_r.push_chunk(chunk_r1);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 . . .
                + 2 5 8 2 3 4
                - 2 5 8 2 3 4
                + 2 5 8 2 2 5
                - 2 5 8 2 2 5
                + 2 5 8 2 1 5"
            )
        );

        // push the 2nd right chunk.
        // Deleting (2, 3, 4) changes nothing because it is not the current match.
        // Deleting (2, 1, 5) promotes (2, 2, 5). Deleting (2, 2, 5) leaves no match, so the
        // NULL-padded row comes back.
        tx_r.push_chunk(chunk_r2);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 1 5
                + 2 5 8 2 2 5
                - 2 5 8 2 2 5
                + 2 5 8 . . ."
            )
        );

        // push the 3rd left chunk; 8 >= 8 matches (6, 1, 8), while (6, 2, 9) has 9 > 8
        tx_l.push_chunk(chunk_l3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 6 8 9 6 1 8"
            )
        );

        // push the 3rd right chunk; removing the only eligible match reverts to NULL padding
        tx_r.push_chunk(chunk_r3);
        let chunk = hash_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 6 8 9 6 1 8
                + 6 8 9 . . ."
            )
        );
        Ok(())
    }
}