// risingwave_stream/executor/asof_join.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::{BTreeMap, HashSet};
16use std::marker::PhantomData;
17use std::ops::Bound;
18use std::time::Duration;
19
20use either::Either;
21use itertools::Itertools;
22use multimap::MultiMap;
23use risingwave_common::array::Op;
24use risingwave_common::hash::{HashKey, NullBitmap};
25use risingwave_common::util::epoch::EpochPair;
26use risingwave_common::util::iter_util::ZipEqDebug;
27use tokio::time::Instant;
28
29use self::builder::JoinChunkBuilder;
30use super::barrier_align::*;
31use super::join::hash_join::*;
32use super::join::*;
33use super::watermark::*;
34use crate::executor::join::builder::JoinStreamChunkBuilder;
35use crate::executor::join::row::JoinEncoding;
36use crate::executor::prelude::*;
37
/// Number of processed rows between two cache evictions of the join sides (16).
const EVICT_EVERY_N_ROWS: u32 = 1 << 4;
40
/// Returns `true` when every element of `vec1` also occurs in `vec2`.
///
/// Order and duplicates are ignored; an empty `vec1` is a subset of anything.
fn is_subset(vec1: Vec<usize>, vec2: Vec<usize>) -> bool {
    let superset: HashSet<usize> = vec2.into_iter().collect();
    vec1.iter().all(|item| superset.contains(item))
}
44
/// Per-input parameters of the join.
pub struct JoinParams {
    /// Positions of the join key columns within the input.
    pub join_key_indices: Vec<usize>,
    /// Positions of the input pk columns after dedup (as computed by the caller).
    pub deduped_pk_indices: Vec<usize>,
}

impl JoinParams {
    /// Bundles the join key indices and the deduped pk indices of one input.
    pub fn new(join_key_indices: Vec<usize>, deduped_pk_indices: Vec<usize>) -> Self {
        JoinParams {
            deduped_pk_indices,
            join_key_indices,
        }
    }
}
60
61struct JoinSide<K: HashKey, S: StateStore, E: JoinEncoding> {
62    /// Store all data from a one side stream
63    ht: JoinHashMap<K, S, E>,
64    /// Indices of the join key columns
65    join_key_indices: Vec<usize>,
66    /// The data type of all columns without degree.
67    all_data_types: Vec<DataType>,
68    /// The start position for the side in output new columns
69    start_pos: usize,
70    /// The mapping from input indices of a side to output columns.
71    i2o_mapping: Vec<(usize, usize)>,
72    i2o_mapping_indexed: MultiMap<usize, usize>,
73    /// Whether degree table is needed for this side.
74    need_degree_table: bool,
75    _marker: std::marker::PhantomData<E>,
76}
77
78impl<K: HashKey, S: StateStore, E: JoinEncoding> std::fmt::Debug for JoinSide<K, S, E> {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("JoinSide")
81            .field("join_key_indices", &self.join_key_indices)
82            .field("col_types", &self.all_data_types)
83            .field("start_pos", &self.start_pos)
84            .field("i2o_mapping", &self.i2o_mapping)
85            .field("need_degree_table", &self.need_degree_table)
86            .finish()
87    }
88}
89
impl<K: HashKey, S: StateStore, E: JoinEncoding> JoinSide<K, S, E> {
    /// Intended to report whether this side has dirty state (see the assertion message in
    /// `clear_cache`).
    ///
    /// WARNING: Please do not call this until we implement it — it currently always panics via
    /// `unimplemented!()`.
    fn is_dirty(&self) -> bool {
        unimplemented!()
    }

    /// Placeholder for clearing this side's cache. It is never called (hence the dead-code
    /// expectation); if it were, the `is_dirty` check would panic because `is_dirty` is
    /// unimplemented. The actual clearing is commented out.
    #[expect(dead_code)]
    fn clear_cache(&mut self) {
        assert!(
            !self.is_dirty(),
            "cannot clear cache while states of hash join are dirty"
        );

        // TODO: not working with rearranged chain
        // self.ht.clear();
    }

    /// Initializes this side's hash table with `epoch` (the executor passes the epoch of the
    /// first barrier).
    pub async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
        self.ht.init(epoch).await
    }
}
111
112/// `AsOfJoinExecutor` takes two input streams and runs equal hash join on them.
113/// The output columns are the concatenation of left and right columns.
114pub struct AsOfJoinExecutor<
115    K: HashKey,
116    S: StateStore,
117    const T: AsOfJoinTypePrimitive,
118    E: JoinEncoding,
119> {
120    ctx: ActorContextRef,
121    info: ExecutorInfo,
122
123    /// Left input executor
124    input_l: Option<Executor>,
125    /// Right input executor
126    input_r: Option<Executor>,
127    /// The data types of the formed new columns
128    actual_output_data_types: Vec<DataType>,
129    /// The parameters of the left join executor
130    side_l: JoinSide<K, S, E>,
131    /// The parameters of the right join executor
132    side_r: JoinSide<K, S, E>,
133
134    metrics: Arc<StreamingMetrics>,
135    /// The maximum size of the chunk produced by executor at a time
136    chunk_size: usize,
137    /// Count the messages received, clear to 0 when counted to `EVICT_EVERY_N_MESSAGES`
138    cnt_rows_received: u32,
139
140    /// watermark column index -> `BufferedWatermarks`
141    watermark_buffers: BTreeMap<usize, BufferedWatermarks<SideTypePrimitive>>,
142
143    high_join_amplification_threshold: usize,
144    /// `AsOf` join description
145    asof_desc: AsOfDesc,
146}
147
148impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding> std::fmt::Debug
149    for AsOfJoinExecutor<K, S, T, E>
150{
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        f.debug_struct("AsOfJoinExecutor")
153            .field("join_type", &T)
154            .field("input_left", &self.input_l.as_ref().unwrap().identity())
155            .field("input_right", &self.input_r.as_ref().unwrap().identity())
156            .field("side_l", &self.side_l)
157            .field("side_r", &self.side_r)
158            .field("stream_key", &self.info.stream_key)
159            .field("schema", &self.info.schema)
160            .field("actual_output_data_types", &self.actual_output_data_types)
161            .finish()
162    }
163}
164
165impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding> Execute
166    for AsOfJoinExecutor<K, S, T, E>
167{
168    fn execute(self: Box<Self>) -> BoxedMessageStream {
169        self.into_stream().boxed()
170    }
171}
172
173struct EqJoinArgs<'a, K: HashKey, S: StateStore, E: JoinEncoding> {
174    ctx: &'a ActorContextRef,
175    side_l: &'a mut JoinSide<K, S, E>,
176    side_r: &'a mut JoinSide<K, S, E>,
177    asof_desc: &'a AsOfDesc,
178    actual_output_data_types: &'a [DataType],
179    // inequality_watermarks: &'a Watermark,
180    chunk: StreamChunk,
181    chunk_size: usize,
182    cnt_rows_received: &'a mut u32,
183    high_join_amplification_threshold: usize,
184}
185
186impl<K: HashKey, S: StateStore, const T: AsOfJoinTypePrimitive, E: JoinEncoding>
187    AsOfJoinExecutor<K, S, T, E>
188{
    /// Creates an AS OF join executor over `input_l` (left) and `input_r` (right).
    ///
    /// - `params_l` / `params_r`: join key indices and deduped pk indices of each side.
    /// - `null_safe`: one flag per join key. It is turned into the null bitmap shared by both
    ///   hash tables; `hash_eq_match` only lets a key with NULLs match if every NULL position
    ///   is flagged here.
    /// - `output_indices`: indices into the concatenated left ++ right columns that form the
    ///   output.
    /// - `state_table_l` / `state_table_r`: state tables backing each side's hash table.
    /// - `asof_desc`: the inequality condition; its `left_idx` / `right_idx` are used as the
    ///   inequality-key column of each side.
    ///
    /// # Panics
    ///
    /// - If the left and right join key data types differ.
    /// - If an index in `output_indices` or in the join key indices is out of range.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        ctx: ActorContextRef,
        info: ExecutorInfo,
        input_l: Executor,
        input_r: Executor,
        params_l: JoinParams,
        params_r: JoinParams,
        null_safe: Vec<bool>,
        output_indices: Vec<usize>,
        state_table_l: StateTable<S>,
        state_table_r: StateTable<S>,
        watermark_epoch: AtomicU64Ref,
        metrics: Arc<StreamingMetrics>,
        chunk_size: usize,
        high_join_amplification_threshold: usize,
        asof_desc: AsOfDesc,
    ) -> Self {
        // Number of left columns; used below as the right side's `start_pos`.
        let side_l_column_n = input_l.schema().len();

        let schema_fields = [
            input_l.schema().fields.clone(),
            input_r.schema().fields.clone(),
        ]
        .concat();

        let original_output_data_types = schema_fields
            .iter()
            .map(|field| field.data_type())
            .collect_vec();
        // Project the concatenated left ++ right types through `output_indices`.
        let actual_output_data_types = output_indices
            .iter()
            .map(|&idx| original_output_data_types[idx].clone())
            .collect_vec();

        // Data types of hash join state.
        let state_all_data_types_l = input_l.schema().data_types();
        let state_all_data_types_r = input_r.schema().data_types();

        let state_pk_indices_l = input_l.stream_key().to_vec();
        let state_pk_indices_r = input_r.stream_key().to_vec();

        let state_join_key_indices_l = params_l.join_key_indices;
        let state_join_key_indices_r = params_r.join_key_indices;

        // If pk is contained in join key.
        let pk_contained_in_jk_l = is_subset(state_pk_indices_l, state_join_key_indices_l.clone());
        let pk_contained_in_jk_r = is_subset(state_pk_indices_r, state_join_key_indices_r.clone());

        let join_key_data_types_l = state_join_key_indices_l
            .iter()
            .map(|idx| state_all_data_types_l[*idx].clone())
            .collect_vec();

        let join_key_data_types_r = state_join_key_indices_r
            .iter()
            .map(|idx| state_all_data_types_r[*idx].clone())
            .collect_vec();

        // Both sides' join keys must have identical types.
        assert_eq!(join_key_data_types_l, join_key_data_types_r);

        let null_matched = K::Bitmap::from_bool_vec(null_safe);

        let (left_to_output, right_to_output) = {
            // For semi/anti joins only one side's columns take part in the output mapping; the
            // other side is given zero columns.
            let (left_len, right_len) = if is_left_semi_or_anti(T) {
                (state_all_data_types_l.len(), 0usize)
            } else if is_right_semi_or_anti(T) {
                (0usize, state_all_data_types_r.len())
            } else {
                (state_all_data_types_l.len(), state_all_data_types_r.len())
            };
            JoinStreamChunkBuilder::get_i2o_mapping(&output_indices, left_len, right_len)
        };

        // Index the mappings by input column for watermark output lookup.
        let l2o_indexed = MultiMap::from_iter(left_to_output.iter().copied());
        let r2o_indexed = MultiMap::from_iter(right_to_output.iter().copied());

        // handle inequality watermarks
        // https://github.com/risingwavelabs/risingwave/issues/18503
        // let inequality_watermarks = None;
        let watermark_buffers = BTreeMap::new();

        // Inequality (AS OF) key column of each side.
        let inequal_key_idx_l = Some(asof_desc.left_idx);
        let inequal_key_idx_r = Some(asof_desc.right_idx);

        Self {
            ctx: ctx.clone(),
            info,
            input_l: Some(input_l),
            input_r: Some(input_r),
            actual_output_data_types,
            side_l: JoinSide {
                ht: JoinHashMap::new(
                    watermark_epoch.clone(),
                    join_key_data_types_l,
                    state_join_key_indices_l.clone(),
                    state_all_data_types_l.clone(),
                    state_table_l,
                    params_l.deduped_pk_indices,
                    None,
                    null_matched.clone(),
                    pk_contained_in_jk_l,
                    inequal_key_idx_l,
                    metrics.clone(),
                    ctx.id,
                    ctx.fragment_id,
                    "left",
                ),
                join_key_indices: state_join_key_indices_l,
                all_data_types: state_all_data_types_l,
                i2o_mapping: left_to_output,
                i2o_mapping_indexed: l2o_indexed,
                start_pos: 0,
                // No degree table is maintained by this executor.
                need_degree_table: false,
                _marker: PhantomData,
            },
            side_r: JoinSide {
                ht: JoinHashMap::new(
                    watermark_epoch,
                    join_key_data_types_r,
                    state_join_key_indices_r.clone(),
                    state_all_data_types_r.clone(),
                    state_table_r,
                    params_r.deduped_pk_indices,
                    None,
                    null_matched,
                    pk_contained_in_jk_r,
                    inequal_key_idx_r,
                    metrics.clone(),
                    ctx.id,
                    ctx.fragment_id,
                    "right",
                ),
                join_key_indices: state_join_key_indices_r,
                all_data_types: state_all_data_types_r,
                // Right columns follow all left columns.
                start_pos: side_l_column_n,
                i2o_mapping: right_to_output,
                i2o_mapping_indexed: r2o_indexed,
                need_degree_table: false,
                _marker: PhantomData,
            },
            metrics,
            chunk_size,
            cnt_rows_received: 0,
            watermark_buffers,
            high_join_amplification_threshold,
            asof_desc,
        }
    }
338
    /// Main loop of the executor.
    ///
    /// It aligns the barriers of both inputs and forwards the first barrier, then initializes
    /// both sides with that barrier's epoch. After that, every aligned message is handled as
    /// follows:
    /// - Watermarks go through `handle_watermark`, and any selected watermarks are emitted.
    /// - Chunks are joined via `eq_join_left` / `eq_join_right`. The produced chunks are
    ///   yielded, then `try_flush_data` is called.
    /// - Barriers:
    ///   - both sides are flushed;
    ///   - the barrier is forwarded;
    ///   - the post-commit handles are finished with the barrier's vnode bitmap update (if
    ///     any);
    ///   - cache metrics are reported.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn into_stream(mut self) {
        // Move the inputs out; `input_l` / `input_r` are `None` from here on.
        let input_l = self.input_l.take().unwrap();
        let input_r = self.input_r.take().unwrap();
        let aligned_stream = barrier_align(
            input_l.execute(),
            input_r.execute(),
            self.ctx.id,
            self.ctx.fragment_id,
            self.metrics.clone(),
            "Join",
        );
        pin_mut!(aligned_stream);
        let actor_id = self.ctx.id;

        let barrier = expect_first_barrier_from_aligned_stream(&mut aligned_stream).await?;
        let first_epoch = barrier.epoch;
        // The first barrier message should be propagated.
        yield Message::Barrier(barrier);
        self.side_l.init(first_epoch).await?;
        self.side_r.init(first_epoch).await?;

        let actor_id_str = self.ctx.id.to_string();
        let fragment_id_str = self.ctx.fragment_id.to_string();

        // Initialize the per-actor metric handles used in the loop below.
        let join_actor_input_waiting_duration_ns = self
            .metrics
            .join_actor_input_waiting_duration_ns
            .with_guarded_label_values(&[&actor_id_str, &fragment_id_str]);
        let left_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "left"]);
        let right_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "right"]);

        let barrier_join_match_duration_ns = self
            .metrics
            .join_match_duration_ns
            .with_guarded_label_values(&[
                actor_id_str.as_str(),
                fragment_id_str.as_str(),
                "barrier",
            ]);

        let left_join_cached_entry_count = self
            .metrics
            .join_cached_entry_count
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "left"]);

        let right_join_cached_entry_count = self
            .metrics
            .join_cached_entry_count
            .with_guarded_label_values(&[actor_id_str.as_str(), fragment_id_str.as_str(), "right"]);

        // Measures how long we wait on the aligned input between two messages.
        let mut start_time = Instant::now();

        // NOTE(review): the await-tree label below still says `hash_join`; kept unchanged.
        while let Some(msg) = aligned_stream
            .next()
            .instrument_await("hash_join_barrier_align")
            .await
        {
            join_actor_input_waiting_duration_ns.inc_by(start_time.elapsed().as_nanos() as u64);
            match msg? {
                AlignedMessage::WatermarkLeft(watermark) => {
                    for watermark_to_emit in self.handle_watermark(SideType::Left, watermark)? {
                        yield Message::Watermark(watermark_to_emit);
                    }
                }
                AlignedMessage::WatermarkRight(watermark) => {
                    for watermark_to_emit in self.handle_watermark(SideType::Right, watermark)? {
                        yield Message::Watermark(watermark_to_emit);
                    }
                }
                AlignedMessage::Left(chunk) => {
                    // The timer is paused around each `yield`, so time spent downstream on a
                    // yielded chunk is not counted as match duration.
                    let mut left_time = Duration::from_nanos(0);
                    let mut left_start_time = Instant::now();
                    #[for_await]
                    for chunk in Self::eq_join_left(EqJoinArgs {
                        ctx: &self.ctx,
                        side_l: &mut self.side_l,
                        side_r: &mut self.side_r,
                        asof_desc: &self.asof_desc,
                        actual_output_data_types: &self.actual_output_data_types,
                        // inequality_watermarks: &self.inequality_watermarks,
                        chunk,
                        chunk_size: self.chunk_size,
                        cnt_rows_received: &mut self.cnt_rows_received,
                        high_join_amplification_threshold: self.high_join_amplification_threshold,
                    }) {
                        left_time += left_start_time.elapsed();
                        yield Message::Chunk(chunk?);
                        left_start_time = Instant::now();
                    }
                    left_time += left_start_time.elapsed();
                    left_join_match_duration_ns.inc_by(left_time.as_nanos() as u64);
                    self.try_flush_data().await?;
                }
                AlignedMessage::Right(chunk) => {
                    // Same timing scheme as the left branch.
                    let mut right_time = Duration::from_nanos(0);
                    let mut right_start_time = Instant::now();
                    #[for_await]
                    for chunk in Self::eq_join_right(EqJoinArgs {
                        ctx: &self.ctx,
                        side_l: &mut self.side_l,
                        side_r: &mut self.side_r,
                        asof_desc: &self.asof_desc,
                        actual_output_data_types: &self.actual_output_data_types,
                        // inequality_watermarks: &self.inequality_watermarks,
                        chunk,
                        chunk_size: self.chunk_size,
                        cnt_rows_received: &mut self.cnt_rows_received,
                        high_join_amplification_threshold: self.high_join_amplification_threshold,
                    }) {
                        right_time += right_start_time.elapsed();
                        yield Message::Chunk(chunk?);
                        right_start_time = Instant::now();
                    }
                    right_time += right_start_time.elapsed();
                    right_join_match_duration_ns.inc_by(right_time.as_nanos() as u64);
                    self.try_flush_data().await?;
                }
                AlignedMessage::Barrier(barrier) => {
                    let barrier_start_time = Instant::now();
                    // Commit both sides before forwarding the barrier.
                    let (left_post_commit, right_post_commit) =
                        self.flush_data(barrier.epoch).await?;

                    let update_vnode_bitmap = barrier.as_update_vnode_bitmap(actor_id);
                    yield Message::Barrier(barrier);

                    // Update the vnode bitmap for state tables of both sides if asked.
                    // Only the left side's result decides whether buffered watermarks are
                    // cleared. Presumably both sides report the same, since they receive the
                    // same bitmap update — TODO confirm.
                    right_post_commit
                        .post_yield_barrier(update_vnode_bitmap.clone())
                        .await?;
                    if left_post_commit
                        .post_yield_barrier(update_vnode_bitmap)
                        .await?
                        .unwrap_or(false)
                    {
                        self.watermark_buffers
                            .values_mut()
                            .for_each(|buffers| buffers.clear());
                    }

                    // Report metrics of cached join rows/entries
                    for (join_cached_entry_count, ht) in [
                        (&left_join_cached_entry_count, &self.side_l.ht),
                        (&right_join_cached_entry_count, &self.side_r.ht),
                    ] {
                        join_cached_entry_count.set(ht.entry_count() as i64);
                    }

                    barrier_join_match_duration_ns
                        .inc_by(barrier_start_time.elapsed().as_nanos() as u64);
                }
            }
            start_time = Instant::now();
        }
    }
501
502    async fn flush_data(
503        &mut self,
504        epoch: EpochPair,
505    ) -> StreamExecutorResult<(
506        JoinHashMapPostCommit<'_, K, S, E>,
507        JoinHashMapPostCommit<'_, K, S, E>,
508    )> {
509        // All changes to the state has been buffered in the mem-table of the state table. Just
510        // `commit` them here.
511        let left = self.side_l.ht.flush(epoch).await?;
512        let right = self.side_r.ht.flush(epoch).await?;
513        Ok((left, right))
514    }
515
516    async fn try_flush_data(&mut self) -> StreamExecutorResult<()> {
517        // All changes to the state has been buffered in the mem-table of the state table. Just
518        // `commit` them here.
519        self.side_l.ht.try_flush().await?;
520        self.side_r.ht.try_flush().await?;
521        Ok(())
522    }
523
524    // We need to manually evict the cache.
525    fn evict_cache(
526        side_update: &mut JoinSide<K, S, E>,
527        side_match: &mut JoinSide<K, S, E>,
528        cnt_rows_received: &mut u32,
529    ) {
530        *cnt_rows_received += 1;
531        if *cnt_rows_received == EVICT_EVERY_N_ROWS {
532            side_update.ht.evict();
533            side_match.ht.evict();
534            *cnt_rows_received = 0;
535        }
536    }
537
538    fn handle_watermark(
539        &mut self,
540        side: SideTypePrimitive,
541        watermark: Watermark,
542    ) -> StreamExecutorResult<Vec<Watermark>> {
543        let (side_update, side_match) = if side == SideType::Left {
544            (&mut self.side_l, &mut self.side_r)
545        } else {
546            (&mut self.side_r, &mut self.side_l)
547        };
548
549        // State cleaning
550        if side_update.join_key_indices[0] == watermark.col_idx {
551            side_match.ht.update_watermark(watermark.val.clone());
552        }
553
554        // Select watermarks to yield.
555        let wm_in_jk = side_update
556            .join_key_indices
557            .iter()
558            .positions(|idx| *idx == watermark.col_idx);
559        let mut watermarks_to_emit = vec![];
560        for idx in wm_in_jk {
561            let buffers = self
562                .watermark_buffers
563                .entry(idx)
564                .or_insert_with(|| BufferedWatermarks::with_ids([SideType::Left, SideType::Right]));
565            if let Some(selected_watermark) = buffers.handle_watermark(side, watermark.clone()) {
566                let empty_indices = vec![];
567                let output_indices = side_update
568                    .i2o_mapping_indexed
569                    .get_vec(&side_update.join_key_indices[idx])
570                    .unwrap_or(&empty_indices)
571                    .iter()
572                    .chain(
573                        side_match
574                            .i2o_mapping_indexed
575                            .get_vec(&side_match.join_key_indices[idx])
576                            .unwrap_or(&empty_indices),
577                    );
578                for output_idx in output_indices {
579                    watermarks_to_emit.push(selected_watermark.clone().with_idx(*output_idx));
580                }
581            };
582        }
583        Ok(watermarks_to_emit)
584    }
585
586    /// the data the hash table and match the coming
587    /// data chunk with the executor state
588    async fn hash_eq_match(
589        key: &K,
590        ht: &mut JoinHashMap<K, S, E>,
591    ) -> StreamExecutorResult<Option<HashValueType<E>>> {
592        if !key.null_bitmap().is_subset(ht.null_matched()) {
593            Ok(None)
594        } else {
595            ht.take_state(key).await.map(Some)
596        }
597    }
598
    /// Joins a chunk from the left input against the right side's state and yields output
    /// chunks.
    ///
    /// For each visible row:
    /// 1. If the row's inequality key is not NULL and its join key can match, the right
    ///    state for that key is taken out of the right hash table (`hash_eq_match`).
    /// 2. A single right row is picked with the inequality-specific bound. The joined row is
    ///    emitted; if no row was picked, the result of `forward_if_not_matched` for this join
    ///    type is emitted instead.
    /// 3. The left row is inserted into / deleted from the left hash table, and the taken right
    ///    state is put back.
    ///
    /// Rows that fail the NULL checks in step 1 are only forwarded through
    /// `forward_if_not_matched` and never stored. Any remaining buffered rows are flushed as a
    /// final chunk.
    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
    async fn eq_join_left(args: EqJoinArgs<'_, K, S, E>) {
        let EqJoinArgs {
            ctx: _,
            side_l,
            side_r,
            asof_desc,
            actual_output_data_types,
            // inequality_watermarks,
            chunk,
            chunk_size,
            cnt_rows_received,
            high_join_amplification_threshold: _,
        } = args;

        // Left rows update the left side and are matched against the right side.
        let (side_update, side_match) = (side_l, side_r);

        let mut join_chunk_builder =
            JoinChunkBuilder::<T, { SideType::Left }>::new(JoinStreamChunkBuilder::new(
                chunk_size,
                actual_output_data_types.to_vec(),
                side_update.i2o_mapping.clone(),
                side_match.i2o_mapping.clone(),
            ));

        let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk());
        for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
            // Skip invisible rows.
            let Some((op, row)) = r else {
                continue;
            };
            Self::evict_cache(side_update, side_match, cnt_rows_received);

            // A NULL inequality key can never satisfy the inequality, so no lookup is done.
            let matched_rows = if !side_update.ht.check_inequal_key_null(&row) {
                Self::hash_eq_match(key, &mut side_match.ht).await?
            } else {
                None
            };
            let inequal_key = side_update.ht.serialize_inequal_key_from_row(row);

            if let Some(matched_rows) = matched_rows {
                // Pick at most one equi-matched right row using the bound that corresponds
                // to the inequality type.
                let matched_row_by_inequality = match asof_desc.inequality_type {
                    AsOfInequalityType::Lt => matched_rows.lower_bound_by_inequality(
                        Bound::Excluded(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Le => matched_rows.lower_bound_by_inequality(
                        Bound::Included(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Gt => matched_rows.upper_bound_by_inequality(
                        Bound::Excluded(&inequal_key),
                        &side_match.all_data_types,
                    ),
                    AsOfInequalityType::Ge => matched_rows.upper_bound_by_inequality(
                        Bound::Included(&inequal_key),
                        &side_match.all_data_types,
                    ),
                };
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        if let Some(matched_row_by_inequality) = matched_row_by_inequality {
                            let matched_row = matched_row_by_inequality?;

                            if let Some(chunk) =
                                join_chunk_builder.with_match_on_insert(&row, &matched_row)
                            {
                                yield chunk;
                            }
                        } else if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Insert, row)
                        {
                            yield chunk;
                        }
                        side_update.ht.insert_row(key, row)?;
                    }
                    Op::Delete | Op::UpdateDelete => {
                        if let Some(matched_row_by_inequality) = matched_row_by_inequality {
                            let matched_row = matched_row_by_inequality?;

                            if let Some(chunk) =
                                join_chunk_builder.with_match_on_delete(&row, &matched_row)
                            {
                                yield chunk;
                            }
                        } else if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Delete, row)
                        {
                            yield chunk;
                        }
                        side_update.ht.delete_row(key, row)?;
                    }
                }
                // Insert back the state taken from ht.
                side_match.ht.update_state(key, matched_rows);
            } else {
                // Row which violates null-safe bitmap will never be matched so we need not
                // store.
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Insert, row)
                        {
                            yield chunk;
                        }
                    }
                    Op::Delete | Op::UpdateDelete => {
                        if let Some(chunk) =
                            join_chunk_builder.forward_if_not_matched(Op::Delete, row)
                        {
                            yield chunk;
                        }
                    }
                }
            }
        }
        // Emit whatever is still buffered in the builder.
        if let Some(chunk) = join_chunk_builder.take() {
            yield chunk;
        }
    }
718
719    #[try_stream(ok = StreamChunk, error = StreamExecutorError)]
    /// Handles one chunk from the **right** input of the AsOf join. It updates the right-side
    /// state and yields the output changes that the chunk causes.
    ///
    /// The right side is the update side and the left side is the match side. For every
    /// visible right row this method:
    /// 1. Looks up the cached left rows, and then the right rows, that share the row's join
    ///    key. Rows whose inequality key is NULL skip the lookup and are not stored.
    /// 2. Decides which right row stops being, and which right row starts being, the AsOf
    ///    match for a range of left inequality keys. When several right rows share one
    ///    inequality key, the row with the smallest pk is the one used for matching.
    /// 3. For every left row in that range, emits a `Delete` of the old pairing and then an
    ///    `Insert` of the new one. In a left-outer join, a missing right row is emitted through
    ///    `append_row_matched`. Judging by the expected outputs in this file's tests, that
    ///    means a NULL-padded row.
    /// 4. Writes the cached states back into both hash tables. Then applies the row to the
    ///    right-side state (`insert_row` / `delete_row`).
    ///
    /// Whatever is still buffered in the chunk builder is flushed after the last row.
    async fn eq_join_right(args: EqJoinArgs<'_, K, S, E>) {
        let EqJoinArgs {
            ctx,
            side_l,
            side_r,
            asof_desc,
            actual_output_data_types,
            // inequality_watermarks,
            chunk,
            chunk_size,
            cnt_rows_received,
            high_join_amplification_threshold,
        } = args;

        // Rows arrive on the right input: the right side is updated, the left side is probed.
        let (side_update, side_match) = (side_r, side_l);

        let mut join_chunk_builder = JoinStreamChunkBuilder::new(
            chunk_size,
            actual_output_data_types.to_vec(),
            side_update.i2o_mapping.clone(),
            side_match.i2o_mapping.clone(),
        );

        // Records the number of left rows touched by each right row (observed once per row
        // below). Labelled by actor, fragment and right-side table.
        let join_matched_rows_metrics = ctx
            .streaming_metrics
            .join_matched_join_keys
            .with_guarded_label_values(&[
                &ctx.id.to_string(),
                &ctx.fragment_id.to_string(),
                &side_update.ht.table_id().to_string(),
            ]);

        let keys = K::build_many(&side_update.join_key_indices, chunk.data_chunk());
        for (r, key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
            // Holes (`None`, presumably invisible rows) carry no data; skip them.
            let Some((op, row)) = r else {
                continue;
            };
            let mut join_matched_rows_cnt = 0;

            Self::evict_cache(side_update, side_match, cnt_rows_received);

            // A row with a NULL inequality key can never match. It is not looked up here, and
            // because the `else` branch below does nothing, it is never stored either.
            let matched_rows = if !side_update.ht.check_inequal_key_null(&row) {
                Self::hash_eq_match(key, &mut side_match.ht).await?
            } else {
                None
            };
            let inequal_key = side_update.ht.serialize_inequal_key_from_row(row);

            if let Some(matched_rows) = matched_rows {
                // `hash_eq_match` hands out the cached state of the key. Both `matched_rows` and
                // `update_rows` are written back with `update_state` after the output is built.
                // NOTE(review): an early `?` return between here and `update_state` leaves the
                // taken state out of the hash table. Confirm that a stream error always tears
                // the executor down.
                let update_rows = Self::hash_eq_match(key, &mut side_update.ht).await?.expect("None is not expected because we have checked null in key when getting matched_rows");
                // Maps an inequality key to the pks of the right rows carrying that key.
                let right_inequality_index = update_rows.inequality_index();
                // The right row that stops matching (`row_to_delete_r`) and the one that starts
                // matching (`row_to_insert_r`):
                // - `Either::Left` is a row read back from the right-side cache.
                // - `Either::Right` is the incoming row itself.
                // - `None` means "not decided yet" and is resolved below from the neighbouring
                //   inequality key.
                let (row_to_delete_r, row_to_insert_r) =
                    if let Some(pks) = right_inequality_index.get(&inequal_key) {
                        assert!(!pks.is_empty());
                        let row_pk = side_update.ht.serialize_pk_from_row(row);
                        match op {
                            Op::Insert | Op::UpdateInsert => {
                                // If there are multiple rows match the inequality key in the right table, we use one with smallest pk.
                                let smallest_pk = pks.first_key_sorted().unwrap();
                                // The new row only takes over if its pk sorts before the pk
                                // of the current representative of this inequality key.
                                if smallest_pk > &row_pk {
                                    // smallest_pk is in the cache index, so it must exist in the cache.
                                    if let Some(to_delete_row) = update_rows
                                        .get_by_indexed_pk(smallest_pk, &side_update.all_data_types)
                                    {
                                        (
                                            Some(Either::Left(to_delete_row?.row)),
                                            Some(Either::Right(row)),
                                        )
                                    } else {
                                        // Something wrong happened. Ignore this row in non strict consistency mode.
                                        (None, None)
                                    }
                                } else {
                                    // No affected row in the right table.
                                    (None, None)
                                }
                            }
                            Op::Delete | Op::UpdateDelete => {
                                let smallest_pk = pks.first_key_sorted().unwrap();
                                // Only deleting the current representative changes the match.
                                // Its successor is the row with the next-smallest pk at the
                                // same inequality key, if there is one.
                                if smallest_pk == &row_pk {
                                    if let Some(second_smallest_pk) = pks.second_key_sorted() {
                                        if let Some(to_insert_row) = update_rows.get_by_indexed_pk(
                                            second_smallest_pk,
                                            &side_update.all_data_types,
                                        ) {
                                            (
                                                Some(Either::Right(row)),
                                                Some(Either::Left(to_insert_row?.row)),
                                            )
                                        } else {
                                            // Something wrong happened. Ignore this row in non strict consistency mode.
                                            (None, None)
                                        }
                                    } else {
                                        (Some(Either::Right(row)), None)
                                    }
                                } else {
                                    // No affected row in the right table.
                                    (None, None)
                                }
                            }
                        }
                    } else {
                        // First/last row at this inequality key: the other half of the change
                        // is the row at the neighbouring inequality key.
                        match op {
                            // Decide the row_to_delete later
                            Op::Insert | Op::UpdateInsert => (None, Some(Either::Right(row))),
                            // Decide the row_to_insert later
                            Op::Delete | Op::UpdateDelete => (Some(Either::Right(row)), None),
                        }
                    };

                // 4 cases for row_to_delete_r and row_to_insert_r:
                // 1. Some(_), Some(_): delete row_to_delete_r and insert row_to_insert_r
                // 2. None, Some(_)   : row_to_delete to be decided by the nearest inequality key
                // 3. Some(_), None   : row_to_insert to be decided by the nearest inequality key
                // 4. None, None      : do nothing
                if row_to_delete_r.is_none() && row_to_insert_r.is_none() {
                    // no row to delete or insert.
                } else {
                    // Nearest existing right inequality keys strictly below / above this one.
                    let prev_inequality_key =
                        right_inequality_index.upper_bound_key(Bound::Excluded(&inequal_key));
                    let next_inequality_key =
                        right_inequality_index.lower_bound_key(Bound::Excluded(&inequal_key));
                    // The right row that matches the affected range when no row at
                    // `inequal_key` does. It is the first row of the next key for `Lt`/`Le`,
                    // and of the previous key for `Gt`/`Ge`.
                    let affected_row_r = match asof_desc.inequality_type {
                        AsOfInequalityType::Lt | AsOfInequalityType::Le => next_inequality_key
                            .and_then(|k| {
                                update_rows.get_first_by_inequality(k, &side_update.all_data_types)
                            }),
                        AsOfInequalityType::Gt | AsOfInequalityType::Ge => prev_inequality_key
                            .and_then(|k| {
                                update_rows.get_first_by_inequality(k, &side_update.all_data_types)
                            }),
                    }
                    .transpose()?
                    .map(|r| Either::Left(r.row));

                    let (row_to_delete_r, row_to_insert_r) =
                        match (&row_to_delete_r, &row_to_insert_r) {
                            (Some(_), Some(_)) => (row_to_delete_r, row_to_insert_r),
                            (None, Some(_)) => (affected_row_r, row_to_insert_r),
                            (Some(_), None) => (row_to_delete_r, affected_row_r),
                            (None, None) => unreachable!(),
                        };
                    // Range of left inequality keys whose AsOf match is the right rows at
                    // `inequal_key`.
                    // - `Lt` (left < right): a left key matches the smallest right key above
                    //   it, so the affected keys are [prev, inequal_key).
                    // - The other variants move the boundary inclusivity, or mirror the range
                    //   towards the next key.
                    let range = match asof_desc.inequality_type {
                        AsOfInequalityType::Lt => (
                            prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included),
                            Bound::Excluded(&inequal_key),
                        ),
                        AsOfInequalityType::Le => (
                            prev_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded),
                            Bound::Included(&inequal_key),
                        ),
                        AsOfInequalityType::Gt => (
                            Bound::Excluded(&inequal_key),
                            next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Included),
                        ),
                        AsOfInequalityType::Ge => (
                            Bound::Included(&inequal_key),
                            next_inequality_key.map_or_else(|| Bound::Unbounded, Bound::Excluded),
                        ),
                    };

                    let rows_l =
                        matched_rows.range_by_inequality(range, &side_match.all_data_types);
                    // For every affected left row: retract the old pairing, then emit the new
                    // one. For left-outer joins an absent right row becomes a row emitted via
                    // `append_row_matched`.
                    for row_l in rows_l {
                        join_matched_rows_cnt += 1;
                        let row_l = row_l?.row;
                        if let Some(row_to_delete_r) = &row_to_delete_r {
                            if let Some(chunk) =
                                join_chunk_builder.append_row(Op::Delete, row_to_delete_r, &row_l)
                            {
                                yield chunk;
                            }
                        } else if is_as_of_left_outer(T)
                            && let Some(chunk) =
                                join_chunk_builder.append_row_matched(Op::Delete, &row_l)
                        {
                            yield chunk;
                        }
                        if let Some(row_to_insert_r) = &row_to_insert_r {
                            if let Some(chunk) =
                                join_chunk_builder.append_row(Op::Insert, row_to_insert_r, &row_l)
                            {
                                yield chunk;
                            }
                        } else if is_as_of_left_outer(T)
                            && let Some(chunk) =
                                join_chunk_builder.append_row_matched(Op::Insert, &row_l)
                        {
                            yield chunk;
                        }
                    }
                }
                // Insert back the state taken from ht.
                side_match.ht.update_state(key, matched_rows);
                side_update.ht.update_state(key, update_rows);

                // Apply the row to the right-side state only after the output has been built
                // from the pre-update index.
                match op {
                    Op::Insert | Op::UpdateInsert => {
                        side_update.ht.insert_row(key, row)?;
                    }
                    Op::Delete | Op::UpdateDelete => {
                        side_update.ht.delete_row(key, row)?;
                    }
                }
            } else {
                // Row which violates null-safe bitmap will never be matched so we need not
                // store.
                // Noop here because we only support left outer AsOf join.
                // (An unmatched right row produces no output for either inner or left-outer
                // AsOf joins, so there is nothing to emit.)
            }
            join_matched_rows_metrics.observe(join_matched_rows_cnt as _);
            if join_matched_rows_cnt > high_join_amplification_threshold {
                let join_key_data_types = side_update.ht.join_key_data_types();
                let key = key.deserialize(join_key_data_types)?;
                tracing::warn!(target: "high_join_amplification",
                    matched_rows_len = join_matched_rows_cnt,
                    update_table_id = %side_update.ht.table_id(),
                    match_table_id = %side_match.ht.table_id(),
                    join_key = ?key,
                    actor_id = %ctx.id,
                    fragment_id = %ctx.fragment_id,
                    "large rows matched for join key when AsOf join updating right side",
                );
            }
        }
        // Flush any output still buffered in the builder.
        if let Some(chunk) = join_chunk_builder.take() {
            yield chunk;
        }
    }
949}
950
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicU64;

    use risingwave_common::array::*;
    use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
    use risingwave_common::hash::Key64;
    use risingwave_common::util::epoch::test_epoch;
    use risingwave_common::util::sort_util::OrderType;
    use risingwave_storage::memory::MemoryStateStore;

    use super::*;
    use crate::common::table::test_utils::gen_pbtable;
    use crate::executor::MemoryEncoding;
    use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt};

    /// Builds an in-memory [`StateTable`] with table id `table_id`.
    ///
    /// The table has one unnamed column per entry of `data_types`. Its primary key is the
    /// columns in `pk_indices`, ordered by `order_types`.
    async fn create_in_memory_state_table(
        mem_state: MemoryStateStore,
        data_types: &[DataType],
        order_types: &[OrderType],
        pk_indices: &[usize],
        table_id: u32,
    ) -> StateTable<MemoryStateStore> {
        let column_descs = data_types
            .iter()
            .enumerate()
            .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
            .collect_vec();
        StateTable::from_table_catalog(
            &gen_pbtable(
                TableId::new(table_id),
                column_descs,
                order_types.to_vec(),
                pk_indices.to_vec(),
                0,
            ),
            // `mem_state` is owned and not used afterwards, so it can be moved in directly.
            mem_state,
            None,
        )
        .await
    }

    /// Creates an AsOf join executor of type `T` over two mocked inputs.
    ///
    /// Both inputs have schema `(Int64, Int64, Int64)`, with column 0 as the join key and
    /// column 1 as the stream pk. Each side's state table is keyed by
    /// `[join key, inequality column, stream pk]`. The inequality columns are taken from
    /// `asof_desc`.
    ///
    /// Returns the left sender, the right sender and the executor's output stream.
    async fn create_executor<const T: AsOfJoinTypePrimitive>(
        asof_desc: AsOfDesc,
    ) -> (MessageSender, MessageSender, BoxedMessageStream) {
        let schema = Schema {
            fields: vec![
                Field::unnamed(DataType::Int64), // join key
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Int64),
            ],
        };
        let (tx_l, source_l) = MockSource::channel();
        let source_l = source_l.into_executor(schema.clone(), vec![1]);
        let (tx_r, source_r) = MockSource::channel();
        let source_r = source_r.into_executor(schema, vec![1]);
        let params_l = JoinParams::new(vec![0], vec![1]);
        let params_r = JoinParams::new(vec![0], vec![1]);

        let mem_state = MemoryStateStore::new();

        let state_l = create_in_memory_state_table(
            mem_state.clone(),
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, asof_desc.left_idx, 1],
            0,
        )
        .await;

        let state_r = create_in_memory_state_table(
            mem_state,
            &[DataType::Int64, DataType::Int64, DataType::Int64],
            &[
                OrderType::ascending(),
                OrderType::ascending(),
                OrderType::ascending(),
            ],
            &[0, asof_desc.right_idx, 1],
            1,
        )
        .await;

        // Output schema: all left columns followed by all right columns.
        let schema: Schema = [source_l.schema().fields(), source_r.schema().fields()]
            .concat()
            .into_iter()
            .collect();
        let schema_len = schema.len();
        let info = ExecutorInfo::for_test(schema, vec![1], "AsOfJoinExecutor".to_owned(), 0);

        let executor = AsOfJoinExecutor::<Key64, MemoryStateStore, T, MemoryEncoding>::new(
            ActorContext::for_test(123),
            info,
            source_l,
            source_r,
            params_l,
            params_r,
            vec![false],
            (0..schema_len).collect_vec(),
            state_l,
            state_r,
            Arc::new(AtomicU64::new(0)),
            Arc::new(StreamingMetrics::unused()),
            1024,
            2048,
            asof_desc,
        );
        (tx_l, tx_r, executor.boxed().execute())
    }

    /// Inner AsOf join on `left.col0 < right.col2`.
    ///
    /// Covers:
    /// - right inserts and deletes that move the chosen match,
    /// - right rows that do not affect any left row,
    /// - left deletes,
    /// - ties on the right inequality key, resolved by the smallest right pk.
    #[tokio::test]
    async fn test_as_of_inner_join() -> StreamExecutorResult<()> {
        let asof_desc = AsOfDesc {
            left_idx: 0,
            right_idx: 2,
            inequality_type: AsOfInequalityType::Lt,
        };

        let chunk_l1 = StreamChunk::from_pretty(
            "  I I I
             + 1 4 7
             + 2 5 8
             + 3 6 9",
        );
        let chunk_l2 = StreamChunk::from_pretty(
            "  I I I
             + 3 8 1
             - 3 8 1",
        );
        let chunk_r1 = StreamChunk::from_pretty(
            "  I I I
             + 2 1 7
             + 2 2 1
             + 2 3 4
             + 2 4 2
             + 6 1 9
             + 6 2 9",
        );
        let chunk_r2 = StreamChunk::from_pretty(
            "  I I I
             - 2 3 4",
        );
        let chunk_r3 = StreamChunk::from_pretty(
            "  I I I
             + 2 3 3",
        );
        let chunk_l3 = StreamChunk::from_pretty(
            "  I I I
             - 2 5 8",
        );
        let chunk_l4 = StreamChunk::from_pretty(
            "  I I I
             + 6 3 1
             + 6 4 1",
        );
        let chunk_r4 = StreamChunk::from_pretty(
            "  I I I
             - 6 1 9",
        );

        let (mut tx_l, mut tx_r, mut asof_join) =
            create_executor::<{ AsOfJoinType::Inner }>(asof_desc).await;

        // push the init barrier for left and right
        tx_l.push_barrier(test_epoch(1), false);
        tx_r.push_barrier(test_epoch(1), false);
        asof_join.next_unwrap_ready_barrier()?;

        // push the 1st left chunk: no right rows yet, so an inner join emits nothing
        tx_l.push_chunk(chunk_l1);
        asof_join.next_unwrap_pending();

        // push the 2nd barrier for left and right
        tx_l.push_barrier(test_epoch(2), false);
        tx_r.push_barrier(test_epoch(2), false);
        asof_join.next_unwrap_ready_barrier()?;

        // push the 2nd left chunk
        tx_l.push_chunk(chunk_l2);
        asof_join.next_unwrap_pending();

        // push the 1st right chunk
        tx_r.push_chunk(chunk_r1);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 2 5 8 2 1 7
                - 2 5 8 2 1 7
                + 2 5 8 2 3 4"
            )
        );

        // push the 2nd right chunk
        tx_r.push_chunk(chunk_r2);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 3 4
                + 2 5 8 2 1 7"
            )
        );

        // push the 3rd right chunk
        tx_r.push_chunk(chunk_r3);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 1 7
                + 2 5 8 2 3 3"
            )
        );

        // push the 3rd left chunk
        tx_l.push_chunk(chunk_l3);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 3 3"
            )
        );

        // push the 4th left chunk: both right rows with key 6 share inequality value 9,
        // so the one with the smallest pk is used
        tx_l.push_chunk(chunk_l4);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 6 3 1 6 1 9
                + 6 4 1 6 1 9"
            )
        );

        // push the 4th right chunk: deleting the smallest-pk row hands the match to the next pk
        tx_r.push_chunk(chunk_r4);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 6 3 1 6 1 9
                + 6 3 1 6 2 9
                - 6 4 1 6 1 9
                + 6 4 1 6 2 9"
            )
        );

        Ok(())
    }

    /// Left-outer AsOf join on `left.col1 >= right.col2`.
    ///
    /// Covers NULL-padded output for unmatched left rows, and transitions between matched and
    /// unmatched as right rows are inserted and deleted.
    #[tokio::test]
    async fn test_as_of_left_outer_join() -> StreamExecutorResult<()> {
        let asof_desc = AsOfDesc {
            left_idx: 1,
            right_idx: 2,
            inequality_type: AsOfInequalityType::Ge,
        };

        let chunk_l1 = StreamChunk::from_pretty(
            "  I I I
             + 1 4 7
             + 2 5 8
             + 3 6 9",
        );
        let chunk_l2 = StreamChunk::from_pretty(
            "  I I I
             + 3 8 1
             - 3 8 1",
        );
        let chunk_r1 = StreamChunk::from_pretty(
            "  I I I
             + 2 3 4
             + 2 2 5
             + 2 1 5
             + 6 1 8
             + 6 2 9",
        );
        let chunk_r2 = StreamChunk::from_pretty(
            "  I I I
             - 2 3 4
             - 2 1 5
             - 2 2 5",
        );
        let chunk_l3 = StreamChunk::from_pretty(
            "  I I I
             + 6 8 9",
        );
        let chunk_r3 = StreamChunk::from_pretty(
            "  I I I
             - 6 1 8",
        );

        let (mut tx_l, mut tx_r, mut asof_join) =
            create_executor::<{ AsOfJoinType::LeftOuter }>(asof_desc).await;

        // push the init barrier for left and right
        tx_l.push_barrier(test_epoch(1), false);
        tx_r.push_barrier(test_epoch(1), false);
        asof_join.next_unwrap_ready_barrier()?;

        // push the 1st left chunk: no right rows yet, so every left row is NULL-padded
        tx_l.push_chunk(chunk_l1);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 1 4 7 . . .
                + 2 5 8 . . .
                + 3 6 9 . . ."
            )
        );

        // push the 2nd barrier for left and right
        tx_l.push_barrier(test_epoch(2), false);
        tx_r.push_barrier(test_epoch(2), false);
        asof_join.next_unwrap_ready_barrier()?;

        // push the 2nd left chunk
        tx_l.push_chunk(chunk_l2);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 3 8 1 . . . D
                - 3 8 1 . . . D"
            )
        );

        // push the 1st right chunk
        tx_r.push_chunk(chunk_r1);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 . . .
                + 2 5 8 2 3 4
                - 2 5 8 2 3 4
                + 2 5 8 2 2 5
                - 2 5 8 2 2 5
                + 2 5 8 2 1 5"
            )
        );

        // push the 2nd right chunk
        tx_r.push_chunk(chunk_r2);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 2 5 8 2 1 5
                + 2 5 8 2 2 5
                - 2 5 8 2 2 5
                + 2 5 8 . . ."
            )
        );

        // push the 3rd left chunk
        tx_l.push_chunk(chunk_l3);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                + 6 8 9 6 1 8"
            )
        );

        // push the 3rd right chunk: the only matching right row goes away, so the left row
        // falls back to NULL padding
        tx_r.push_chunk(chunk_r3);
        let chunk = asof_join.next_unwrap_ready_chunk()?;
        assert_eq!(
            chunk,
            StreamChunk::from_pretty(
                " I I I I I I
                - 6 8 9 6 1 8
                + 6 8 9 . . ."
            )
        );
        Ok(())
    }
}