1use std::sync::Arc;
16use std::sync::atomic::AtomicU64;
17
18use itertools::Itertools;
19use risingwave_common::array::{I64Array, Op};
20use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
21use risingwave_common::hash::Key128;
22use risingwave_common::util::sort_util::OrderType;
23use risingwave_pb::plan_common::JoinType;
24use risingwave_storage::memory::MemoryStateStore;
25use strum_macros::Display;
26
27use super::*;
28use crate::common::table::test_utils::gen_pbtable;
29use crate::executor::monitor::StreamingMetrics;
30use crate::executor::prelude::StateTable;
31use crate::executor::test_utils::{MessageSender, MockSource};
32use crate::executor::{
33 ActorContext, HashJoinExecutor, JoinParams, JoinType as ConstJoinType, MemoryEncoding,
34};
35
36#[derive(Clone, Copy, Debug, Display)]
37pub enum HashJoinWorkload {
38 InCache,
39 NotInCache,
40}
41
42pub async fn create_in_memory_state_table(
43 mem_state: MemoryStateStore,
44 data_types: &[DataType],
45 order_types: &[OrderType],
46 pk_indices: &[usize],
47 table_id: u32,
48) -> (StateTable<MemoryStateStore>, StateTable<MemoryStateStore>) {
49 let column_descs = data_types
50 .iter()
51 .enumerate()
52 .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
53 .collect_vec();
54 let state_table = StateTable::from_table_catalog(
55 &gen_pbtable(
56 TableId::new(table_id),
57 column_descs,
58 order_types.to_vec(),
59 pk_indices.to_vec(),
60 0,
61 ),
62 mem_state.clone(),
63 None,
64 )
65 .await;
66
67 let mut degree_table_column_descs = vec![];
69 pk_indices.iter().enumerate().for_each(|(pk_id, idx)| {
70 degree_table_column_descs.push(ColumnDesc::unnamed(
71 ColumnId::new(pk_id as i32),
72 data_types[*idx].clone(),
73 ))
74 });
75 degree_table_column_descs.push(ColumnDesc::unnamed(
76 ColumnId::new(pk_indices.len() as i32),
77 DataType::Int64,
78 ));
79 let degree_state_table = StateTable::from_table_catalog(
80 &gen_pbtable(
81 TableId::new(table_id + 1),
82 degree_table_column_descs,
83 order_types.to_vec(),
84 pk_indices.to_vec(),
85 0,
86 ),
87 mem_state,
88 None,
89 )
90 .await;
91 (state_table, degree_state_table)
92}
93
94pub async fn setup_bench_stream_hash_join(
99 amp: usize,
100 workload: HashJoinWorkload,
101 join_type: JoinType,
102) -> (MessageSender, MessageSender, BoxedMessageStream) {
103 let fields = vec![DataType::Int64, DataType::Int64, DataType::Int64];
104 let orders = vec![OrderType::ascending(), OrderType::ascending()];
105 let state_store = MemoryStateStore::new();
106
107 let (lhs_state_table, lhs_degree_state_table) =
109 create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 0).await;
110
111 let (mut rhs_state_table, rhs_degree_state_table) =
113 create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 2).await;
114
115 if matches!(workload, HashJoinWorkload::NotInCache) {
117 let stream_chunk = build_chunk(amp, 200_000);
118 rhs_state_table.write_chunk(stream_chunk);
120 }
121
122 let schema = Schema::new(fields.iter().cloned().map(Field::unnamed).collect());
123
124 let (tx_l, source_l) = MockSource::channel();
125 let source_l = source_l.into_executor(schema.clone(), vec![1]);
126 let (tx_r, source_r) = MockSource::channel();
127 let source_r = source_r.into_executor(schema, vec![1]);
128
129 let schema: Vec<_> = [source_l.schema().fields(), source_r.schema().fields()]
133 .concat()
134 .into_iter()
135 .collect();
136 let schema_len = schema.len();
137 let info = ExecutorInfo::new(
138 Schema { fields: schema },
139 vec![0, 1, 3, 4],
140 "HashJoinExecutor".to_owned(),
141 0,
142 );
143
144 let params_l = JoinParams::new(vec![0], vec![1]);
146 let params_r = JoinParams::new(vec![0], vec![1]);
147
148 let cache_size = match workload {
149 HashJoinWorkload::InCache => Some(1_000_000),
150 HashJoinWorkload::NotInCache => None,
151 };
152
153 match join_type {
154 JoinType::Inner => {
155 let executor = HashJoinExecutor::<
156 Key128,
157 MemoryStateStore,
158 { ConstJoinType::Inner },
159 MemoryEncoding,
160 >::new_with_cache_size(
161 ActorContext::for_test(123),
162 info,
163 source_l,
164 source_r,
165 params_l,
166 params_r,
167 vec![false], (0..schema_len).collect_vec(),
169 None, vec![], lhs_state_table,
172 lhs_degree_state_table,
173 rhs_state_table,
174 rhs_degree_state_table,
175 Arc::new(AtomicU64::new(0)), false, Arc::new(StreamingMetrics::unused()),
178 1024, 2048, cache_size,
181 );
182 (tx_l, tx_r, executor.boxed().execute())
183 }
184 JoinType::LeftOuter => {
185 let executor = HashJoinExecutor::<
186 Key128,
187 MemoryStateStore,
188 { ConstJoinType::LeftOuter },
189 MemoryEncoding,
190 >::new_with_cache_size(
191 ActorContext::for_test(123),
192 info,
193 source_l,
194 source_r,
195 params_l,
196 params_r,
197 vec![false], (0..schema_len).collect_vec(),
199 None, vec![], lhs_state_table,
202 lhs_degree_state_table,
203 rhs_state_table,
204 rhs_degree_state_table,
205 Arc::new(AtomicU64::new(0)), false, Arc::new(StreamingMetrics::unused()),
208 1024, 2048, cache_size,
211 );
212 (tx_l, tx_r, executor.boxed().execute())
213 }
214 _ => panic!("Unsupported join type"),
215 }
216}
217
218fn build_chunk(size: usize, join_key_value: i64) -> StreamChunk {
219 let mut int64_jk_builder = DataType::Int64.create_array_builder(size);
221 int64_jk_builder.append_array(&I64Array::from_iter(vec![Some(join_key_value); size]).into());
222 let jk = int64_jk_builder.finish();
223
224 let mut int64_pk_data_chunk_builder = DataType::Int64.create_array_builder(size);
226 let seq = I64Array::from_iter((0..size as i64).map(Some));
227 int64_pk_data_chunk_builder.append_array(&I64Array::from(seq).into());
228 let pk = int64_pk_data_chunk_builder.finish();
229
230 let values = pk.clone();
232
233 let columns = vec![jk.into(), pk.into(), values.into()];
235 let ops = vec![Op::Insert; size];
236 StreamChunk::new(ops, columns)
237}
238
239pub async fn handle_streams(
240 hash_join_workload: HashJoinWorkload,
241 join_type: JoinType,
242 amp: usize,
243 mut tx_l: MessageSender,
244 mut tx_r: MessageSender,
245 mut stream: BoxedMessageStream,
246) {
247 tx_l.push_barrier(test_epoch(1), false);
249 tx_r.push_barrier(test_epoch(1), false);
250
251 if matches!(hash_join_workload, HashJoinWorkload::InCache) {
252 let chunk = build_chunk(amp, 200_000);
254 tx_r.push_chunk(chunk);
255
256 tx_l.push_barrier(test_epoch(2), false);
260 tx_r.push_barrier(test_epoch(2), false);
261 }
262
263 let chunk_size = match hash_join_workload {
265 HashJoinWorkload::InCache => 64,
266 HashJoinWorkload::NotInCache => 1,
267 };
268 let chunk = match join_type {
269 JoinType::Inner => build_chunk(chunk_size, 200_000),
271 JoinType::LeftOuter => build_chunk(chunk_size, 300_000),
273 _ => panic!("Unsupported join type"),
274 };
275 tx_l.push_chunk(chunk);
276
277 match stream.next().await {
278 Some(Ok(Message::Barrier(b))) => {
279 assert_eq!(b.epoch.curr, test_epoch(1));
280 }
281 other => {
282 panic!("Expected a barrier, got {:?}", other);
283 }
284 }
285
286 if matches!(hash_join_workload, HashJoinWorkload::InCache) {
287 match stream.next().await {
288 Some(Ok(Message::Barrier(b))) => {
289 assert_eq!(b.epoch.curr, test_epoch(2));
290 }
291 other => {
292 panic!("Expected a barrier, got {:?}", other);
293 }
294 }
295 }
296
297 let expected_count = match join_type {
298 JoinType::LeftOuter => chunk_size,
299 JoinType::Inner => amp * chunk_size,
300 _ => panic!("Unsupported join type"),
301 };
302 let mut current_count = 0;
303 while current_count < expected_count {
304 match stream.next().await {
305 Some(Ok(Message::Chunk(c))) => {
306 current_count += c.cardinality();
307 }
308 other => {
309 panic!("Expected a barrier, got {:?}", other);
310 }
311 }
312 }
313}