risingwave_stream/executor/test_utils/hash_join_executor.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16use std::sync::atomic::AtomicU64;
17
18use itertools::Itertools;
19use risingwave_common::array::{I64Array, Op};
20use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
21use risingwave_common::hash::Key128;
22use risingwave_common::util::sort_util::OrderType;
23use risingwave_pb::plan_common::JoinType;
24use risingwave_storage::memory::MemoryStateStore;
25use strum_macros::Display;
26
27use super::*;
28use crate::common::table::test_utils::gen_pbtable;
29use crate::executor::monitor::StreamingMetrics;
30use crate::executor::prelude::StateTable;
31use crate::executor::test_utils::{MessageSender, MockSource};
32use crate::executor::{
33    ActorContext, HashJoinExecutor, JoinParams, JoinType as ConstJoinType, MemoryEncoding,
34};
35
36#[derive(Clone, Copy, Debug, Display)]
37pub enum HashJoinWorkload {
38    InCache,
39    NotInCache,
40}
41
42pub async fn create_in_memory_state_table(
43    mem_state: MemoryStateStore,
44    data_types: &[DataType],
45    order_types: &[OrderType],
46    pk_indices: &[usize],
47    table_id: u32,
48) -> (StateTable<MemoryStateStore>, StateTable<MemoryStateStore>) {
49    let column_descs = data_types
50        .iter()
51        .enumerate()
52        .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
53        .collect_vec();
54    let state_table = StateTable::from_table_catalog(
55        &gen_pbtable(
56            TableId::new(table_id),
57            column_descs,
58            order_types.to_vec(),
59            pk_indices.to_vec(),
60            0,
61        ),
62        mem_state.clone(),
63        None,
64    )
65    .await;
66
67    // Create degree table
68    let mut degree_table_column_descs = vec![];
69    pk_indices.iter().enumerate().for_each(|(pk_id, idx)| {
70        degree_table_column_descs.push(ColumnDesc::unnamed(
71            ColumnId::new(pk_id as i32),
72            data_types[*idx].clone(),
73        ))
74    });
75    degree_table_column_descs.push(ColumnDesc::unnamed(
76        ColumnId::new(pk_indices.len() as i32),
77        DataType::Int64,
78    ));
79    let degree_state_table = StateTable::from_table_catalog(
80        &gen_pbtable(
81            TableId::new(table_id + 1),
82            degree_table_column_descs,
83            order_types.to_vec(),
84            pk_indices.to_vec(),
85            0,
86        ),
87        mem_state,
88        None,
89    )
90    .await;
91    (state_table, degree_state_table)
92}
93
94/// 1. Refill state table of build side.
95/// 2. Init executor.
96/// 3. Push data to the probe side.
97/// 4. Check memory utilization.
98pub async fn setup_bench_stream_hash_join(
99    amp: usize,
100    workload: HashJoinWorkload,
101    join_type: JoinType,
102) -> (MessageSender, MessageSender, BoxedMessageStream) {
103    let fields = vec![DataType::Int64, DataType::Int64, DataType::Int64];
104    let orders = vec![OrderType::ascending(), OrderType::ascending()];
105    let state_store = MemoryStateStore::new();
106
107    // Probe side
108    let (lhs_state_table, lhs_degree_state_table) =
109        create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 0).await;
110
111    // Build side
112    let (mut rhs_state_table, rhs_degree_state_table) =
113        create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 2).await;
114
115    // Insert 100K records into the build side.
116    if matches!(workload, HashJoinWorkload::NotInCache) {
117        let stream_chunk = build_chunk(amp, 200_000);
118        // Write to state table.
119        rhs_state_table.write_chunk(stream_chunk);
120    }
121
122    let schema = Schema::new(fields.iter().cloned().map(Field::unnamed).collect());
123
124    let (tx_l, source_l) = MockSource::channel();
125    let source_l = source_l.into_executor(schema.clone(), vec![1]);
126    let (tx_r, source_r) = MockSource::channel();
127    let source_r = source_r.into_executor(schema, vec![1]);
128
129    // Schema is the concatenation of the two source schemas.
130    // [lhs(jk):0, lhs(pk):1, lhs(value):2, rhs(jk):0, rhs(pk):1, rhs(value):2]
131    // [0,         1,         2,            3,         4,         5           ]
132    let schema: Vec<_> = [source_l.schema().fields(), source_r.schema().fields()]
133        .concat()
134        .into_iter()
135        .collect();
136    let schema_len = schema.len();
137    let info = ExecutorInfo::for_test(
138        Schema { fields: schema },
139        vec![0, 1, 3, 4],
140        "HashJoinExecutor".to_owned(),
141        0,
142    );
143
144    // join-key is [0], primary-key is [1].
145    let params_l = JoinParams::new(vec![0], vec![1]);
146    let params_r = JoinParams::new(vec![0], vec![1]);
147
148    let cache_size = match workload {
149        HashJoinWorkload::InCache => Some(1_000_000),
150        HashJoinWorkload::NotInCache => None,
151    };
152
153    match join_type {
154        JoinType::Inner => {
155            let executor = HashJoinExecutor::<
156                Key128,
157                MemoryStateStore,
158                { ConstJoinType::Inner },
159                MemoryEncoding,
160            >::new_with_cache_size(
161                ActorContext::for_test(123),
162                info,
163                source_l,
164                source_r,
165                params_l,
166                params_r,
167                vec![false], // null-safe
168                (0..schema_len).collect_vec(),
169                None,   // condition, it is an eq join, we have no condition
170                vec![], // ineq pairs
171                lhs_state_table,
172                lhs_degree_state_table,
173                rhs_state_table,
174                rhs_degree_state_table,
175                Arc::new(AtomicU64::new(0)), // watermark epoch
176                false,                       // is_append_only
177                Arc::new(StreamingMetrics::unused()),
178                1024, // chunk_size
179                2048, // high_join_amplification_threshold
180                cache_size,
181                vec![], // watermark_indices_in_jk
182            );
183            (tx_l, tx_r, executor.boxed().execute())
184        }
185        JoinType::LeftOuter => {
186            let executor = HashJoinExecutor::<
187                Key128,
188                MemoryStateStore,
189                { ConstJoinType::LeftOuter },
190                MemoryEncoding,
191            >::new_with_cache_size(
192                ActorContext::for_test(123),
193                info,
194                source_l,
195                source_r,
196                params_l,
197                params_r,
198                vec![false], // null-safe
199                (0..schema_len).collect_vec(),
200                None,   // condition, it is an eq join, we have no condition
201                vec![], // ineq pairs
202                lhs_state_table,
203                lhs_degree_state_table,
204                rhs_state_table,
205                rhs_degree_state_table,
206                Arc::new(AtomicU64::new(0)), // watermark epoch
207                false,                       // is_append_only
208                Arc::new(StreamingMetrics::unused()),
209                1024, // chunk_size
210                2048, // high_join_amplification_threshold
211                cache_size,
212                vec![], // watermark_indices_in_jk
213            );
214            (tx_l, tx_r, executor.boxed().execute())
215        }
216        _ => panic!("Unsupported join type"),
217    }
218}
219
220fn build_chunk(size: usize, join_key_value: i64) -> StreamChunk {
221    // Create column [0]: join key. Each record has the same value, to trigger join amplification.
222    let mut int64_jk_builder = DataType::Int64.create_array_builder(size);
223    int64_jk_builder.append_array(&I64Array::from_iter(vec![Some(join_key_value); size]).into());
224    let jk = int64_jk_builder.finish();
225
226    // Create column [1]: pk. The original pk will be here, it will be unique.
227    let mut int64_pk_data_chunk_builder = DataType::Int64.create_array_builder(size);
228    let seq = I64Array::from_iter((0..size as i64).map(Some));
229    int64_pk_data_chunk_builder.append_array(&I64Array::from(seq).into());
230    let pk = int64_pk_data_chunk_builder.finish();
231
232    // Create column [2]: value. This can be an arbitrary value, so just clone the pk column.
233    let values = pk.clone();
234
235    // Build the stream chunk.
236    let columns = vec![jk.into(), pk.into(), values.into()];
237    let ops = vec![Op::Insert; size];
238    StreamChunk::new(ops, columns)
239}
240
241pub async fn handle_streams(
242    hash_join_workload: HashJoinWorkload,
243    join_type: JoinType,
244    amp: usize,
245    mut tx_l: MessageSender,
246    mut tx_r: MessageSender,
247    mut stream: BoxedMessageStream,
248) {
249    // Init executors
250    tx_l.push_barrier(test_epoch(1), false);
251    tx_r.push_barrier(test_epoch(1), false);
252
253    if matches!(hash_join_workload, HashJoinWorkload::InCache) {
254        // Push a single record into tx_r, so 100K records to be matched are cached.
255        let chunk = build_chunk(amp, 200_000);
256        tx_r.push_chunk(chunk);
257
258        // Ensure that the chunk on the rhs is processed, before inserting a chunk
259        // into the lhs. This is to ensure that the rhs chunk is cached,
260        // and we don't get interleaving of chunks between lhs and rhs.
261        tx_l.push_barrier(test_epoch(2), false);
262        tx_r.push_barrier(test_epoch(2), false);
263    }
264
265    // Push a chunk of records into tx_l, matches 100K records in the build side.
266    let chunk_size = match hash_join_workload {
267        HashJoinWorkload::InCache => 64,
268        HashJoinWorkload::NotInCache => 1,
269    };
270    let chunk = match join_type {
271        // Make sure all match
272        JoinType::Inner => build_chunk(chunk_size, 200_000),
273        // Make sure no match is found.
274        JoinType::LeftOuter => build_chunk(chunk_size, 300_000),
275        _ => panic!("Unsupported join type"),
276    };
277    tx_l.push_chunk(chunk);
278
279    match stream.next().await {
280        Some(Ok(Message::Barrier(b))) => {
281            assert_eq!(b.epoch.curr, test_epoch(1));
282        }
283        other => {
284            panic!("Expected a barrier, got {:?}", other);
285        }
286    }
287
288    if matches!(hash_join_workload, HashJoinWorkload::InCache) {
289        match stream.next().await {
290            Some(Ok(Message::Barrier(b))) => {
291                assert_eq!(b.epoch.curr, test_epoch(2));
292            }
293            other => {
294                panic!("Expected a barrier, got {:?}", other);
295            }
296        }
297    }
298
299    let expected_count = match join_type {
300        JoinType::LeftOuter => chunk_size,
301        JoinType::Inner => amp * chunk_size,
302        _ => panic!("Unsupported join type"),
303    };
304    let mut current_count = 0;
305    while current_count < expected_count {
306        match stream.next().await {
307            Some(Ok(Message::Chunk(c))) => {
308                current_count += c.cardinality();
309            }
310            other => {
311                panic!("Expected a barrier, got {:?}", other);
312            }
313        }
314    }
315}