1use std::sync::Arc;
16use std::sync::atomic::AtomicU64;
17
18use itertools::Itertools;
19use risingwave_common::array::{I64Array, Op};
20use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
21use risingwave_common::hash::Key128;
22use risingwave_common::util::sort_util::OrderType;
23use risingwave_pb::plan_common::JoinType;
24use risingwave_storage::memory::MemoryStateStore;
25use strum_macros::Display;
26
27use super::*;
28use crate::common::table::test_utils::gen_pbtable;
29use crate::executor::monitor::StreamingMetrics;
30use crate::executor::prelude::StateTable;
31use crate::executor::test_utils::{MessageSender, MockSource};
32use crate::executor::{
33 ActorContext, HashJoinExecutor, JoinParams, JoinType as ConstJoinType, MemoryEncoding,
34};
35
/// Which cache behavior the hash-join benchmark should exercise.
#[derive(Clone, Copy, Debug, Display)]
pub enum HashJoinWorkload {
    /// Build side is warmed through the executor first, so probes hit the
    /// in-memory join cache.
    InCache,
    /// Build side is written directly to the state store, so probes miss the
    /// cache and must fetch from storage.
    NotInCache,
}
41
42pub async fn create_in_memory_state_table(
43 mem_state: MemoryStateStore,
44 data_types: &[DataType],
45 order_types: &[OrderType],
46 pk_indices: &[usize],
47 table_id: u32,
48) -> (StateTable<MemoryStateStore>, StateTable<MemoryStateStore>) {
49 let column_descs = data_types
50 .iter()
51 .enumerate()
52 .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
53 .collect_vec();
54 let state_table = StateTable::from_table_catalog(
55 &gen_pbtable(
56 TableId::new(table_id),
57 column_descs,
58 order_types.to_vec(),
59 pk_indices.to_vec(),
60 0,
61 ),
62 mem_state.clone(),
63 None,
64 )
65 .await;
66
67 let mut degree_table_column_descs = vec![];
69 pk_indices.iter().enumerate().for_each(|(pk_id, idx)| {
70 degree_table_column_descs.push(ColumnDesc::unnamed(
71 ColumnId::new(pk_id as i32),
72 data_types[*idx].clone(),
73 ))
74 });
75 degree_table_column_descs.push(ColumnDesc::unnamed(
76 ColumnId::new(pk_indices.len() as i32),
77 DataType::Int64,
78 ));
79 let degree_state_table = StateTable::from_table_catalog(
80 &gen_pbtable(
81 TableId::new(table_id + 1),
82 degree_table_column_descs,
83 order_types.to_vec(),
84 pk_indices.to_vec(),
85 0,
86 ),
87 mem_state,
88 None,
89 )
90 .await;
91 (state_table, degree_state_table)
92}
93
94pub async fn setup_bench_stream_hash_join(
99 amp: usize,
100 workload: HashJoinWorkload,
101 join_type: JoinType,
102) -> (MessageSender, MessageSender, BoxedMessageStream) {
103 let fields = vec![DataType::Int64, DataType::Int64, DataType::Int64];
104 let orders = vec![OrderType::ascending(), OrderType::ascending()];
105 let state_store = MemoryStateStore::new();
106
107 let (lhs_state_table, lhs_degree_state_table) =
109 create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 0).await;
110
111 let (mut rhs_state_table, rhs_degree_state_table) =
113 create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 2).await;
114
115 if matches!(workload, HashJoinWorkload::NotInCache) {
117 let stream_chunk = build_chunk(amp, 200_000);
118 rhs_state_table.write_chunk(stream_chunk);
120 }
121
122 let schema = Schema::new(fields.iter().cloned().map(Field::unnamed).collect());
123
124 let (tx_l, source_l) = MockSource::channel();
125 let source_l = source_l.into_executor(schema.clone(), vec![1]);
126 let (tx_r, source_r) = MockSource::channel();
127 let source_r = source_r.into_executor(schema, vec![1]);
128
129 let schema: Vec<_> = [source_l.schema().fields(), source_r.schema().fields()]
133 .concat()
134 .into_iter()
135 .collect();
136 let schema_len = schema.len();
137 let info = ExecutorInfo::for_test(
138 Schema { fields: schema },
139 vec![0, 1, 3, 4],
140 "HashJoinExecutor".to_owned(),
141 0,
142 );
143
144 let params_l = JoinParams::new(vec![0], vec![1]);
146 let params_r = JoinParams::new(vec![0], vec![1]);
147
148 let cache_size = match workload {
149 HashJoinWorkload::InCache => Some(1_000_000),
150 HashJoinWorkload::NotInCache => None,
151 };
152
153 match join_type {
154 JoinType::Inner => {
155 let executor = HashJoinExecutor::<
156 Key128,
157 MemoryStateStore,
158 { ConstJoinType::Inner },
159 MemoryEncoding,
160 >::new_with_cache_size(
161 ActorContext::for_test(123),
162 info,
163 source_l,
164 source_r,
165 params_l,
166 params_r,
167 vec![false], (0..schema_len).collect_vec(),
169 None, vec![], lhs_state_table,
172 lhs_degree_state_table,
173 rhs_state_table,
174 rhs_degree_state_table,
175 Arc::new(AtomicU64::new(0)), false, Arc::new(StreamingMetrics::unused()),
178 1024, 2048, cache_size,
181 vec![], );
183 (tx_l, tx_r, executor.boxed().execute())
184 }
185 JoinType::LeftOuter => {
186 let executor = HashJoinExecutor::<
187 Key128,
188 MemoryStateStore,
189 { ConstJoinType::LeftOuter },
190 MemoryEncoding,
191 >::new_with_cache_size(
192 ActorContext::for_test(123),
193 info,
194 source_l,
195 source_r,
196 params_l,
197 params_r,
198 vec![false], (0..schema_len).collect_vec(),
200 None, vec![], lhs_state_table,
203 lhs_degree_state_table,
204 rhs_state_table,
205 rhs_degree_state_table,
206 Arc::new(AtomicU64::new(0)), false, Arc::new(StreamingMetrics::unused()),
209 1024, 2048, cache_size,
212 vec![], );
214 (tx_l, tx_r, executor.boxed().execute())
215 }
216 _ => panic!("Unsupported join type"),
217 }
218}
219
220fn build_chunk(size: usize, join_key_value: i64) -> StreamChunk {
221 let mut int64_jk_builder = DataType::Int64.create_array_builder(size);
223 int64_jk_builder.append_array(&I64Array::from_iter(vec![Some(join_key_value); size]).into());
224 let jk = int64_jk_builder.finish();
225
226 let mut int64_pk_data_chunk_builder = DataType::Int64.create_array_builder(size);
228 let seq = I64Array::from_iter((0..size as i64).map(Some));
229 int64_pk_data_chunk_builder.append_array(&I64Array::from(seq).into());
230 let pk = int64_pk_data_chunk_builder.finish();
231
232 let values = pk.clone();
234
235 let columns = vec![jk.into(), pk.into(), values.into()];
237 let ops = vec![Op::Insert; size];
238 StreamChunk::new(ops, columns)
239}
240
241pub async fn handle_streams(
242 hash_join_workload: HashJoinWorkload,
243 join_type: JoinType,
244 amp: usize,
245 mut tx_l: MessageSender,
246 mut tx_r: MessageSender,
247 mut stream: BoxedMessageStream,
248) {
249 tx_l.push_barrier(test_epoch(1), false);
251 tx_r.push_barrier(test_epoch(1), false);
252
253 if matches!(hash_join_workload, HashJoinWorkload::InCache) {
254 let chunk = build_chunk(amp, 200_000);
256 tx_r.push_chunk(chunk);
257
258 tx_l.push_barrier(test_epoch(2), false);
262 tx_r.push_barrier(test_epoch(2), false);
263 }
264
265 let chunk_size = match hash_join_workload {
267 HashJoinWorkload::InCache => 64,
268 HashJoinWorkload::NotInCache => 1,
269 };
270 let chunk = match join_type {
271 JoinType::Inner => build_chunk(chunk_size, 200_000),
273 JoinType::LeftOuter => build_chunk(chunk_size, 300_000),
275 _ => panic!("Unsupported join type"),
276 };
277 tx_l.push_chunk(chunk);
278
279 match stream.next().await {
280 Some(Ok(Message::Barrier(b))) => {
281 assert_eq!(b.epoch.curr, test_epoch(1));
282 }
283 other => {
284 panic!("Expected a barrier, got {:?}", other);
285 }
286 }
287
288 if matches!(hash_join_workload, HashJoinWorkload::InCache) {
289 match stream.next().await {
290 Some(Ok(Message::Barrier(b))) => {
291 assert_eq!(b.epoch.curr, test_epoch(2));
292 }
293 other => {
294 panic!("Expected a barrier, got {:?}", other);
295 }
296 }
297 }
298
299 let expected_count = match join_type {
300 JoinType::LeftOuter => chunk_size,
301 JoinType::Inner => amp * chunk_size,
302 _ => panic!("Unsupported join type"),
303 };
304 let mut current_count = 0;
305 while current_count < expected_count {
306 match stream.next().await {
307 Some(Ok(Message::Chunk(c))) => {
308 current_count += c.cardinality();
309 }
310 other => {
311 panic!("Expected a barrier, got {:?}", other);
312 }
313 }
314 }
315}