risingwave_stream/executor/test_utils/
hash_join_executor.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::sync::atomic::AtomicU64;
17
18use itertools::Itertools;
19use risingwave_common::array::{I64Array, Op};
20use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
21use risingwave_common::hash::Key128;
22use risingwave_common::util::sort_util::OrderType;
23use risingwave_pb::plan_common::JoinType;
24use risingwave_storage::memory::MemoryStateStore;
25use strum_macros::Display;
26
27use super::*;
28use crate::common::table::test_utils::gen_pbtable;
29use crate::executor::monitor::StreamingMetrics;
30use crate::executor::prelude::StateTable;
31use crate::executor::test_utils::{MessageSender, MockSource};
32use crate::executor::{
33    ActorContext, HashJoinExecutor, JoinParams, JoinType as ConstJoinType, MemoryEncoding,
34};
35
36#[derive(Clone, Copy, Debug, Display)]
37pub enum HashJoinWorkload {
38    InCache,
39    NotInCache,
40}
41
42pub async fn create_in_memory_state_table(
43    mem_state: MemoryStateStore,
44    data_types: &[DataType],
45    order_types: &[OrderType],
46    pk_indices: &[usize],
47    table_id: u32,
48) -> (StateTable<MemoryStateStore>, StateTable<MemoryStateStore>) {
49    let column_descs = data_types
50        .iter()
51        .enumerate()
52        .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
53        .collect_vec();
54    let state_table = StateTable::from_table_catalog(
55        &gen_pbtable(
56            TableId::new(table_id),
57            column_descs,
58            order_types.to_vec(),
59            pk_indices.to_vec(),
60            0,
61        ),
62        mem_state.clone(),
63        None,
64    )
65    .await;
66
67    // Create degree table
68    let mut degree_table_column_descs = vec![];
69    pk_indices.iter().enumerate().for_each(|(pk_id, idx)| {
70        degree_table_column_descs.push(ColumnDesc::unnamed(
71            ColumnId::new(pk_id as i32),
72            data_types[*idx].clone(),
73        ))
74    });
75    degree_table_column_descs.push(ColumnDesc::unnamed(
76        ColumnId::new(pk_indices.len() as i32),
77        DataType::Int64,
78    ));
79    let degree_state_table = StateTable::from_table_catalog(
80        &gen_pbtable(
81            TableId::new(table_id + 1),
82            degree_table_column_descs,
83            order_types.to_vec(),
84            pk_indices.to_vec(),
85            0,
86        ),
87        mem_state,
88        None,
89    )
90    .await;
91    (state_table, degree_state_table)
92}
93
94/// 1. Refill state table of build side.
95/// 2. Init executor.
96/// 3. Push data to the probe side.
97/// 4. Check memory utilization.
98pub async fn setup_bench_stream_hash_join(
99    amp: usize,
100    workload: HashJoinWorkload,
101    join_type: JoinType,
102) -> (MessageSender, MessageSender, BoxedMessageStream) {
103    let fields = vec![DataType::Int64, DataType::Int64, DataType::Int64];
104    let orders = vec![OrderType::ascending(), OrderType::ascending()];
105    let state_store = MemoryStateStore::new();
106
107    // Probe side
108    let (lhs_state_table, lhs_degree_state_table) =
109        create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 0).await;
110
111    // Build side
112    let (mut rhs_state_table, rhs_degree_state_table) =
113        create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 2).await;
114
115    // Insert 100K records into the build side.
116    if matches!(workload, HashJoinWorkload::NotInCache) {
117        let stream_chunk = build_chunk(amp, 200_000);
118        // Write to state table.
119        rhs_state_table.write_chunk(stream_chunk);
120    }
121
122    let schema = Schema::new(fields.iter().cloned().map(Field::unnamed).collect());
123
124    let (tx_l, source_l) = MockSource::channel();
125    let source_l = source_l.into_executor(schema.clone(), vec![1]);
126    let (tx_r, source_r) = MockSource::channel();
127    let source_r = source_r.into_executor(schema, vec![1]);
128
129    // Schema is the concatenation of the two source schemas.
130    // [lhs(jk):0, lhs(pk):1, lhs(value):2, rhs(jk):0, rhs(pk):1, rhs(value):2]
131    // [0,         1,         2,            3,         4,         5           ]
132    let schema: Vec<_> = [source_l.schema().fields(), source_r.schema().fields()]
133        .concat()
134        .into_iter()
135        .collect();
136    let schema_len = schema.len();
137    let info = ExecutorInfo::new(
138        Schema { fields: schema },
139        vec![0, 1, 3, 4],
140        "HashJoinExecutor".to_owned(),
141        0,
142    );
143
144    // join-key is [0], primary-key is [1].
145    let params_l = JoinParams::new(vec![0], vec![1]);
146    let params_r = JoinParams::new(vec![0], vec![1]);
147
148    let cache_size = match workload {
149        HashJoinWorkload::InCache => Some(1_000_000),
150        HashJoinWorkload::NotInCache => None,
151    };
152
153    match join_type {
154        JoinType::Inner => {
155            let executor = HashJoinExecutor::<
156                Key128,
157                MemoryStateStore,
158                { ConstJoinType::Inner },
159                MemoryEncoding,
160            >::new_with_cache_size(
161                ActorContext::for_test(123),
162                info,
163                source_l,
164                source_r,
165                params_l,
166                params_r,
167                vec![false], // null-safe
168                (0..schema_len).collect_vec(),
169                None,   // condition, it is an eq join, we have no condition
170                vec![], // ineq pairs
171                lhs_state_table,
172                lhs_degree_state_table,
173                rhs_state_table,
174                rhs_degree_state_table,
175                Arc::new(AtomicU64::new(0)), // watermark epoch
176                false,                       // is_append_only
177                Arc::new(StreamingMetrics::unused()),
178                1024, // chunk_size
179                2048, // high_join_amplification_threshold
180                cache_size,
181            );
182            (tx_l, tx_r, executor.boxed().execute())
183        }
184        JoinType::LeftOuter => {
185            let executor = HashJoinExecutor::<
186                Key128,
187                MemoryStateStore,
188                { ConstJoinType::LeftOuter },
189                MemoryEncoding,
190            >::new_with_cache_size(
191                ActorContext::for_test(123),
192                info,
193                source_l,
194                source_r,
195                params_l,
196                params_r,
197                vec![false], // null-safe
198                (0..schema_len).collect_vec(),
199                None,   // condition, it is an eq join, we have no condition
200                vec![], // ineq pairs
201                lhs_state_table,
202                lhs_degree_state_table,
203                rhs_state_table,
204                rhs_degree_state_table,
205                Arc::new(AtomicU64::new(0)), // watermark epoch
206                false,                       // is_append_only
207                Arc::new(StreamingMetrics::unused()),
208                1024, // chunk_size
209                2048, // high_join_amplification_threshold
210                cache_size,
211            );
212            (tx_l, tx_r, executor.boxed().execute())
213        }
214        _ => panic!("Unsupported join type"),
215    }
216}
217
218fn build_chunk(size: usize, join_key_value: i64) -> StreamChunk {
219    // Create column [0]: join key. Each record has the same value, to trigger join amplification.
220    let mut int64_jk_builder = DataType::Int64.create_array_builder(size);
221    int64_jk_builder.append_array(&I64Array::from_iter(vec![Some(join_key_value); size]).into());
222    let jk = int64_jk_builder.finish();
223
224    // Create column [1]: pk. The original pk will be here, it will be unique.
225    let mut int64_pk_data_chunk_builder = DataType::Int64.create_array_builder(size);
226    let seq = I64Array::from_iter((0..size as i64).map(Some));
227    int64_pk_data_chunk_builder.append_array(&I64Array::from(seq).into());
228    let pk = int64_pk_data_chunk_builder.finish();
229
230    // Create column [2]: value. This can be an arbitrary value, so just clone the pk column.
231    let values = pk.clone();
232
233    // Build the stream chunk.
234    let columns = vec![jk.into(), pk.into(), values.into()];
235    let ops = vec![Op::Insert; size];
236    StreamChunk::new(ops, columns)
237}
238
239pub async fn handle_streams(
240    hash_join_workload: HashJoinWorkload,
241    join_type: JoinType,
242    amp: usize,
243    mut tx_l: MessageSender,
244    mut tx_r: MessageSender,
245    mut stream: BoxedMessageStream,
246) {
247    // Init executors
248    tx_l.push_barrier(test_epoch(1), false);
249    tx_r.push_barrier(test_epoch(1), false);
250
251    if matches!(hash_join_workload, HashJoinWorkload::InCache) {
252        // Push a single record into tx_r, so 100K records to be matched are cached.
253        let chunk = build_chunk(amp, 200_000);
254        tx_r.push_chunk(chunk);
255
256        // Ensure that the chunk on the rhs is processed, before inserting a chunk
257        // into the lhs. This is to ensure that the rhs chunk is cached,
258        // and we don't get interleaving of chunks between lhs and rhs.
259        tx_l.push_barrier(test_epoch(2), false);
260        tx_r.push_barrier(test_epoch(2), false);
261    }
262
263    // Push a chunk of records into tx_l, matches 100K records in the build side.
264    let chunk_size = match hash_join_workload {
265        HashJoinWorkload::InCache => 64,
266        HashJoinWorkload::NotInCache => 1,
267    };
268    let chunk = match join_type {
269        // Make sure all match
270        JoinType::Inner => build_chunk(chunk_size, 200_000),
271        // Make sure no match is found.
272        JoinType::LeftOuter => build_chunk(chunk_size, 300_000),
273        _ => panic!("Unsupported join type"),
274    };
275    tx_l.push_chunk(chunk);
276
277    match stream.next().await {
278        Some(Ok(Message::Barrier(b))) => {
279            assert_eq!(b.epoch.curr, test_epoch(1));
280        }
281        other => {
282            panic!("Expected a barrier, got {:?}", other);
283        }
284    }
285
286    if matches!(hash_join_workload, HashJoinWorkload::InCache) {
287        match stream.next().await {
288            Some(Ok(Message::Barrier(b))) => {
289                assert_eq!(b.epoch.curr, test_epoch(2));
290            }
291            other => {
292                panic!("Expected a barrier, got {:?}", other);
293            }
294        }
295    }
296
297    let expected_count = match join_type {
298        JoinType::LeftOuter => chunk_size,
299        JoinType::Inner => amp * chunk_size,
300        _ => panic!("Unsupported join type"),
301    };
302    let mut current_count = 0;
303    while current_count < expected_count {
304        match stream.next().await {
305            Some(Ok(Message::Chunk(c))) => {
306                current_count += c.cardinality();
307            }
308            other => {
309                panic!("Expected a barrier, got {:?}", other);
310            }
311        }
312    }
313}