risingwave_stream/executor/test_utils/hash_join_executor.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::sync::atomic::AtomicU64;

use itertools::Itertools;
use risingwave_common::array::{I64Array, Op};
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
use risingwave_common::hash::Key128;
use risingwave_common::util::sort_util::OrderType;
use risingwave_pb::plan_common::JoinType;
use risingwave_storage::memory::MemoryStateStore;
use strum_macros::Display;

use super::*;
use crate::common::table::test_utils::gen_pbtable;
use crate::executor::monitor::StreamingMetrics;
use crate::executor::prelude::StateTable;
use crate::executor::test_utils::{MessageSender, MockSource};
use crate::executor::{ActorContext, HashJoinExecutor, JoinParams, JoinType as ConstJoinType};

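/// Workload shape for the hash join benchmark. With `InCache`, the join cache
/// is sized generously and warmed with the build side rows before probing;
/// with `NotInCache`, the build side rows are written directly to the state
/// table, so probes must fetch them from storage.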
#[derive(Clone, Copy, Debug, Display)]
pub enum HashJoinWorkload {
    InCache,
    NotInCache,
}

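/// Creates a join state table together with its companion degree table on the
/// given in-memory state store. The degree table mirrors the state table's pk
/// columns and appends an `Int64` degree count; it is registered under
/// `table_id + 1`.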
pub async fn create_in_memory_state_table(
    mem_state: MemoryStateStore,
    data_types: &[DataType],
    order_types: &[OrderType],
    pk_indices: &[usize],
    table_id: u32,
) -> (StateTable<MemoryStateStore>, StateTable<MemoryStateStore>) {
    let column_descs = data_types
        .iter()
        .enumerate()
        .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
        .collect_vec();
    let state_table = StateTable::from_table_catalog(
        &gen_pbtable(
            TableId::new(table_id),
            column_descs,
            order_types.to_vec(),
            pk_indices.to_vec(),
            0,
        ),
        mem_state.clone(),
        None,
    )
    .await;

    // Create degree table
    let mut degree_table_column_descs = vec![];
    pk_indices.iter().enumerate().for_each(|(pk_id, idx)| {
        degree_table_column_descs.push(ColumnDesc::unnamed(
            ColumnId::new(pk_id as i32),
            data_types[*idx].clone(),
        ))
    });
    degree_table_column_descs.push(ColumnDesc::unnamed(
        ColumnId::new(pk_indices.len() as i32),
        DataType::Int64,
    ));
    let degree_state_table = StateTable::from_table_catalog(
        &gen_pbtable(
            TableId::new(table_id + 1),
            degree_table_column_descs,
            order_types.to_vec(),
            pk_indices.to_vec(),
            0,
        ),
        mem_state,
        None,
    )
    .await;
    (state_table, degree_state_table)
}

/// 1. Refill the build side's state table.
/// 2. Initialize the executor.
/// 3. Push data to the probe side.
/// 4. Check memory utilization.
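///
/// # Example
///
/// A minimal sketch of driving the benchmark end-to-end (the `amp` value here
/// is only illustrative, and a tokio test runtime is assumed):
///
/// ```ignore
/// let amp = 100_000;
/// let (tx_l, tx_r, stream) =
///     setup_bench_stream_hash_join(amp, HashJoinWorkload::NotInCache, JoinType::Inner).await;
/// handle_streams(HashJoinWorkload::NotInCache, JoinType::Inner, amp, tx_l, tx_r, stream).await;
/// ```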
pub async fn setup_bench_stream_hash_join(
    amp: usize,
    workload: HashJoinWorkload,
    join_type: JoinType,
) -> (MessageSender, MessageSender, BoxedMessageStream) {
    let fields = vec![DataType::Int64, DataType::Int64, DataType::Int64];
    let orders = vec![OrderType::ascending(), OrderType::ascending()];
    let state_store = MemoryStateStore::new();

    // Probe side
    let (lhs_state_table, lhs_degree_state_table) =
        create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 0).await;

    // Build side
    let (mut rhs_state_table, rhs_degree_state_table) =
        create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 2).await;

    // For the `NotInCache` workload, write `amp` records (all sharing one join
    // key) directly into the build side's state table, bypassing the join cache.
    if matches!(workload, HashJoinWorkload::NotInCache) {
        let stream_chunk = build_chunk(amp, 200_000);
        // Write to the state table.
        rhs_state_table.write_chunk(stream_chunk);
    }

    let schema = Schema::new(fields.iter().cloned().map(Field::unnamed).collect());

    let (tx_l, source_l) = MockSource::channel();
    let source_l = source_l.into_executor(schema.clone(), vec![1]);
    let (tx_r, source_r) = MockSource::channel();
    let source_r = source_r.into_executor(schema, vec![1]);

    // Schema is the concatenation of the two source schemas.
    // [lhs(jk):0, lhs(pk):1, lhs(value):2, rhs(jk):0, rhs(pk):1, rhs(value):2]
    // [0,         1,         2,            3,         4,         5           ]
    let schema: Vec<_> = [source_l.schema().fields(), source_r.schema().fields()]
        .concat()
        .into_iter()
        .collect();
    let schema_len = schema.len();
    let info = ExecutorInfo::new(
        Schema { fields: schema },
        vec![0, 1, 3, 4],
        "HashJoinExecutor".to_owned(),
        0,
    );

    // join-key is [0], primary-key is [1].
    let params_l = JoinParams::new(vec![0], vec![1]);
    let params_r = JoinParams::new(vec![0], vec![1]);

    let cache_size = match workload {
        HashJoinWorkload::InCache => Some(1_000_000),
        HashJoinWorkload::NotInCache => None,
    };

    match join_type {
        JoinType::Inner => {
            let executor = HashJoinExecutor::<
                Key128,
                MemoryStateStore,
                { ConstJoinType::Inner },
            >::new_with_cache_size(
                ActorContext::for_test(123),
                info,
                source_l,
                source_r,
                params_l,
                params_r,
                vec![false], // null-safe
                (0..schema_len).collect_vec(),
                None,   // condition: pure equi-join, no extra condition
                vec![], // inequality pairs
                lhs_state_table,
                lhs_degree_state_table,
                rhs_state_table,
                rhs_degree_state_table,
                Arc::new(AtomicU64::new(0)), // watermark epoch
                false,                       // is_append_only
                Arc::new(StreamingMetrics::unused()),
                1024, // chunk_size
                2048, // high_join_amplification_threshold
                cache_size,
            );
            (tx_l, tx_r, executor.boxed().execute())
        }
        JoinType::LeftOuter => {
            let executor = HashJoinExecutor::<
                Key128,
                MemoryStateStore,
                { ConstJoinType::LeftOuter },
            >::new_with_cache_size(
                ActorContext::for_test(123),
                info,
                source_l,
                source_r,
                params_l,
                params_r,
                vec![false], // null-safe
                (0..schema_len).collect_vec(),
                None,   // condition: pure equi-join, no extra condition
                vec![], // inequality pairs
                lhs_state_table,
                lhs_degree_state_table,
                rhs_state_table,
                rhs_degree_state_table,
                Arc::new(AtomicU64::new(0)), // watermark epoch
                false,                       // is_append_only
                Arc::new(StreamingMetrics::unused()),
                1024, // chunk_size
                2048, // high_join_amplification_threshold
                cache_size,
            );
            (tx_l, tx_r, executor.boxed().execute())
        }
        _ => panic!("Unsupported join type"),
    }
}

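/// Builds an insert-only chunk of `size` rows shaped as `(join_key_value, pk, pk)`
/// with `pk` ranging over `0..size`: a constant join key paired with a unique
/// sequential pk, so every row lands in the same join bucket.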
fn build_chunk(size: usize, join_key_value: i64) -> StreamChunk {
    // Create column [0]: join key. Each record has the same value, to trigger join amplification.
    let mut int64_jk_builder = DataType::Int64.create_array_builder(size);
    int64_jk_builder.append_array(&I64Array::from_iter(vec![Some(join_key_value); size]).into());
    let jk = int64_jk_builder.finish();

    // Create column [1]: pk. Sequential values, so each record is unique.
    let mut int64_pk_data_chunk_builder = DataType::Int64.create_array_builder(size);
    let seq = I64Array::from_iter((0..size as i64).map(Some));
    int64_pk_data_chunk_builder.append_array(&seq.into());
    let pk = int64_pk_data_chunk_builder.finish();

    // Create column [2]: value. This can be an arbitrary value, so just clone the pk column.
    let values = pk.clone();

    // Build the stream chunk.
    let columns = vec![jk.into(), pk.into(), values.into()];
    let ops = vec![Op::Insert; size];
    StreamChunk::new(ops, columns)
}

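/// Drives the benchmark streams: initializes both sides with a barrier,
/// optionally warms the join cache for the `InCache` workload, pushes a probe
/// chunk into the lhs, and consumes the output until the expected number of
/// joined rows has been observed.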
pub async fn handle_streams(
    hash_join_workload: HashJoinWorkload,
    join_type: JoinType,
    amp: usize,
    mut tx_l: MessageSender,
    mut tx_r: MessageSender,
    mut stream: BoxedMessageStream,
) {
    // Init executors
    tx_l.push_barrier(test_epoch(1), false);
    tx_r.push_barrier(test_epoch(1), false);

    if matches!(hash_join_workload, HashJoinWorkload::InCache) {
        // Push `amp` records into tx_r, so the rows to be matched are cached
        // by the join executor.
        let chunk = build_chunk(amp, 200_000);
        tx_r.push_chunk(chunk);

        // Ensure the chunk on the rhs is processed before inserting a chunk
        // into the lhs. This guarantees the rhs chunk is cached, and that we
        // don't get interleaving of chunks between lhs and rhs.
        tx_l.push_barrier(test_epoch(2), false);
        tx_r.push_barrier(test_epoch(2), false);
    }

    // Push a chunk of records into tx_l to probe the build side.
    let chunk_size = match hash_join_workload {
        HashJoinWorkload::InCache => 64,
        HashJoinWorkload::NotInCache => 1,
    };
    let chunk = match join_type {
        // Use the build side's join key, so every probe row matches all `amp` build rows.
        JoinType::Inner => build_chunk(chunk_size, 200_000),
        // Use a different join key, so no match is found.
        JoinType::LeftOuter => build_chunk(chunk_size, 300_000),
        _ => panic!("Unsupported join type"),
    };
    tx_l.push_chunk(chunk);

    match stream.next().await {
        Some(Ok(Message::Barrier(b))) => {
            assert_eq!(b.epoch.curr, test_epoch(1));
        }
        other => {
            panic!("Expected a barrier, got {:?}", other);
        }
    }

    if matches!(hash_join_workload, HashJoinWorkload::InCache) {
        match stream.next().await {
            Some(Ok(Message::Barrier(b))) => {
                assert_eq!(b.epoch.curr, test_epoch(2));
            }
            other => {
                panic!("Expected a barrier, got {:?}", other);
            }
        }
    }
289
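    // Inner join: every probe row matches all `amp` build rows, so we expect
    // `amp * chunk_size` output rows. Left outer join: the probe keys match
    // nothing, so each probe row yields exactly one null-padded output row.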
    let expected_count = match join_type {
        JoinType::LeftOuter => chunk_size,
        JoinType::Inner => amp * chunk_size,
        _ => panic!("Unsupported join type"),
    };
    let mut current_count = 0;
    while current_count < expected_count {
        match stream.next().await {
            Some(Ok(Message::Chunk(c))) => {
                current_count += c.cardinality();
            }
            other => {
                panic!("Expected a chunk, got {:?}", other);
            }
        }
    }
}