1use std::sync::Arc;
16use std::sync::atomic::AtomicU64;
17
18use itertools::Itertools;
19use risingwave_common::array::{I64Array, Op};
20use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId};
21use risingwave_common::hash::Key128;
22use risingwave_common::util::sort_util::OrderType;
23use risingwave_pb::plan_common::JoinType;
24use risingwave_storage::memory::MemoryStateStore;
25use strum_macros::Display;
26
27use super::*;
28use crate::common::table::test_utils::gen_pbtable;
29use crate::executor::monitor::StreamingMetrics;
30use crate::executor::prelude::StateTable;
31use crate::executor::test_utils::{MessageSender, MockSource};
32use crate::executor::{ActorContext, HashJoinExecutor, JoinParams, JoinType as ConstJoinType};
33
34#[derive(Clone, Copy, Debug, Display)]
35pub enum HashJoinWorkload {
36 InCache,
37 NotInCache,
38}
39
40pub async fn create_in_memory_state_table(
41 mem_state: MemoryStateStore,
42 data_types: &[DataType],
43 order_types: &[OrderType],
44 pk_indices: &[usize],
45 table_id: u32,
46) -> (StateTable<MemoryStateStore>, StateTable<MemoryStateStore>) {
47 let column_descs = data_types
48 .iter()
49 .enumerate()
50 .map(|(id, data_type)| ColumnDesc::unnamed(ColumnId::new(id as i32), data_type.clone()))
51 .collect_vec();
52 let state_table = StateTable::from_table_catalog(
53 &gen_pbtable(
54 TableId::new(table_id),
55 column_descs,
56 order_types.to_vec(),
57 pk_indices.to_vec(),
58 0,
59 ),
60 mem_state.clone(),
61 None,
62 )
63 .await;
64
65 let mut degree_table_column_descs = vec![];
67 pk_indices.iter().enumerate().for_each(|(pk_id, idx)| {
68 degree_table_column_descs.push(ColumnDesc::unnamed(
69 ColumnId::new(pk_id as i32),
70 data_types[*idx].clone(),
71 ))
72 });
73 degree_table_column_descs.push(ColumnDesc::unnamed(
74 ColumnId::new(pk_indices.len() as i32),
75 DataType::Int64,
76 ));
77 let degree_state_table = StateTable::from_table_catalog(
78 &gen_pbtable(
79 TableId::new(table_id + 1),
80 degree_table_column_descs,
81 order_types.to_vec(),
82 pk_indices.to_vec(),
83 0,
84 ),
85 mem_state,
86 None,
87 )
88 .await;
89 (state_table, degree_state_table)
90}
91
92pub async fn setup_bench_stream_hash_join(
97 amp: usize,
98 workload: HashJoinWorkload,
99 join_type: JoinType,
100) -> (MessageSender, MessageSender, BoxedMessageStream) {
101 let fields = vec![DataType::Int64, DataType::Int64, DataType::Int64];
102 let orders = vec![OrderType::ascending(), OrderType::ascending()];
103 let state_store = MemoryStateStore::new();
104
105 let (lhs_state_table, lhs_degree_state_table) =
107 create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 0).await;
108
109 let (mut rhs_state_table, rhs_degree_state_table) =
111 create_in_memory_state_table(state_store.clone(), &fields, &orders, &[0, 1], 2).await;
112
113 if matches!(workload, HashJoinWorkload::NotInCache) {
115 let stream_chunk = build_chunk(amp, 200_000);
116 rhs_state_table.write_chunk(stream_chunk);
118 }
119
120 let schema = Schema::new(fields.iter().cloned().map(Field::unnamed).collect());
121
122 let (tx_l, source_l) = MockSource::channel();
123 let source_l = source_l.into_executor(schema.clone(), vec![1]);
124 let (tx_r, source_r) = MockSource::channel();
125 let source_r = source_r.into_executor(schema, vec![1]);
126
127 let schema: Vec<_> = [source_l.schema().fields(), source_r.schema().fields()]
131 .concat()
132 .into_iter()
133 .collect();
134 let schema_len = schema.len();
135 let info = ExecutorInfo::new(
136 Schema { fields: schema },
137 vec![0, 1, 3, 4],
138 "HashJoinExecutor".to_owned(),
139 0,
140 );
141
142 let params_l = JoinParams::new(vec![0], vec![1]);
144 let params_r = JoinParams::new(vec![0], vec![1]);
145
146 let cache_size = match workload {
147 HashJoinWorkload::InCache => Some(1_000_000),
148 HashJoinWorkload::NotInCache => None,
149 };
150
151 match join_type {
152 JoinType::Inner => {
153 let executor =
154 HashJoinExecutor::<Key128, MemoryStateStore, { ConstJoinType::Inner }>::new_with_cache_size(
155 ActorContext::for_test(123),
156 info,
157 source_l,
158 source_r,
159 params_l,
160 params_r,
161 vec![false], (0..schema_len).collect_vec(),
163 None, vec![], lhs_state_table,
166 lhs_degree_state_table,
167 rhs_state_table,
168 rhs_degree_state_table,
169 Arc::new(AtomicU64::new(0)), false, Arc::new(StreamingMetrics::unused()),
172 1024, 2048, cache_size,
175 );
176 (tx_l, tx_r, executor.boxed().execute())
177 }
178 JoinType::LeftOuter => {
179 let executor = HashJoinExecutor::<
180 Key128,
181 MemoryStateStore,
182 { ConstJoinType::LeftOuter },
183 >::new_with_cache_size(
184 ActorContext::for_test(123),
185 info,
186 source_l,
187 source_r,
188 params_l,
189 params_r,
190 vec![false], (0..schema_len).collect_vec(),
192 None, vec![], lhs_state_table,
195 lhs_degree_state_table,
196 rhs_state_table,
197 rhs_degree_state_table,
198 Arc::new(AtomicU64::new(0)), false, Arc::new(StreamingMetrics::unused()),
201 1024, 2048, cache_size,
204 );
205 (tx_l, tx_r, executor.boxed().execute())
206 }
207 _ => panic!("Unsupported join type"),
208 }
209}
210
211fn build_chunk(size: usize, join_key_value: i64) -> StreamChunk {
212 let mut int64_jk_builder = DataType::Int64.create_array_builder(size);
214 int64_jk_builder.append_array(&I64Array::from_iter(vec![Some(join_key_value); size]).into());
215 let jk = int64_jk_builder.finish();
216
217 let mut int64_pk_data_chunk_builder = DataType::Int64.create_array_builder(size);
219 let seq = I64Array::from_iter((0..size as i64).map(Some));
220 int64_pk_data_chunk_builder.append_array(&I64Array::from(seq).into());
221 let pk = int64_pk_data_chunk_builder.finish();
222
223 let values = pk.clone();
225
226 let columns = vec![jk.into(), pk.into(), values.into()];
228 let ops = vec![Op::Insert; size];
229 StreamChunk::new(ops, columns)
230}
231
232pub async fn handle_streams(
233 hash_join_workload: HashJoinWorkload,
234 join_type: JoinType,
235 amp: usize,
236 mut tx_l: MessageSender,
237 mut tx_r: MessageSender,
238 mut stream: BoxedMessageStream,
239) {
240 tx_l.push_barrier(test_epoch(1), false);
242 tx_r.push_barrier(test_epoch(1), false);
243
244 if matches!(hash_join_workload, HashJoinWorkload::InCache) {
245 let chunk = build_chunk(amp, 200_000);
247 tx_r.push_chunk(chunk);
248
249 tx_l.push_barrier(test_epoch(2), false);
253 tx_r.push_barrier(test_epoch(2), false);
254 }
255
256 let chunk_size = match hash_join_workload {
258 HashJoinWorkload::InCache => 64,
259 HashJoinWorkload::NotInCache => 1,
260 };
261 let chunk = match join_type {
262 JoinType::Inner => build_chunk(chunk_size, 200_000),
264 JoinType::LeftOuter => build_chunk(chunk_size, 300_000),
266 _ => panic!("Unsupported join type"),
267 };
268 tx_l.push_chunk(chunk);
269
270 match stream.next().await {
271 Some(Ok(Message::Barrier(b))) => {
272 assert_eq!(b.epoch.curr, test_epoch(1));
273 }
274 other => {
275 panic!("Expected a barrier, got {:?}", other);
276 }
277 }
278
279 if matches!(hash_join_workload, HashJoinWorkload::InCache) {
280 match stream.next().await {
281 Some(Ok(Message::Barrier(b))) => {
282 assert_eq!(b.epoch.curr, test_epoch(2));
283 }
284 other => {
285 panic!("Expected a barrier, got {:?}", other);
286 }
287 }
288 }
289
290 let expected_count = match join_type {
291 JoinType::LeftOuter => chunk_size,
292 JoinType::Inner => amp * chunk_size,
293 _ => panic!("Unsupported join type"),
294 };
295 let mut current_count = 0;
296 while current_count < expected_count {
297 match stream.next().await {
298 Some(Ok(Message::Chunk(c))) => {
299 current_count += c.cardinality();
300 }
301 other => {
302 panic!("Expected a barrier, got {:?}", other);
303 }
304 }
305 }
306}