risingwave_stream/from_proto/
hash_join.rs1use std::sync::Arc;
16
17use risingwave_common::config::streaming::JoinEncodingType;
18use risingwave_common::hash::{HashKey, HashKeyDispatcher};
19use risingwave_common::types::DataType;
20use risingwave_expr::expr::{NonStrictExpression, build_non_strict_from_prost};
21use risingwave_pb::plan_common::JoinType as JoinTypeProto;
22use risingwave_pb::stream_plan::{HashJoinNode, InequalityType};
23
24use super::*;
25use crate::common::table::state_table::{StateTable, StateTableBuilder};
26use crate::executor::hash_join::*;
27use crate::executor::monitor::StreamingMetrics;
28use crate::executor::{ActorContextRef, CpuEncoding, JoinType, MemoryEncoding};
29use crate::task::AtomicU64Ref;
30
/// Stateless builder type whose [`ExecutorBuilder`] implementation constructs
/// the streaming hash join executor from a `HashJoinNode` plan node.
pub struct HashJoinExecutorBuilder;
32
33impl ExecutorBuilder for HashJoinExecutorBuilder {
34 type Node = HashJoinNode;
35
36 async fn new_boxed_executor(
37 params: ExecutorParams,
38 node: &Self::Node,
39 store: impl StateStore,
40 ) -> StreamResult<Executor> {
41 let is_append_only = node.is_append_only;
42 let vnodes = Arc::new(params.vnode_bitmap.expect("vnodes not set for hash join"));
43
44 let [source_l, source_r]: [_; 2] = params.input.try_into().unwrap();
45
46 let table_l = node.get_left_table()?;
47 let degree_table_l = node.get_left_degree_table()?;
48
49 let table_r = node.get_right_table()?;
50 let degree_table_r = node.get_right_degree_table()?;
51
52 let params_l = JoinParams::new(
53 node.get_left_key()
54 .iter()
55 .map(|key| *key as usize)
56 .collect_vec(),
57 node.get_left_deduped_input_pk_indices()
58 .iter()
59 .map(|key| *key as usize)
60 .collect_vec(),
61 );
62 let params_r = JoinParams::new(
63 node.get_right_key()
64 .iter()
65 .map(|key| *key as usize)
66 .collect_vec(),
67 node.get_right_deduped_input_pk_indices()
68 .iter()
69 .map(|key| *key as usize)
70 .collect_vec(),
71 );
72 let null_safe = node.get_null_safe().clone();
73 let output_indices = node
74 .get_output_indices()
75 .iter()
76 .map(|&x| x as usize)
77 .collect_vec();
78
79 let condition = match node.get_condition() {
80 Ok(cond_prost) => Some(build_non_strict_from_prost(
81 cond_prost,
82 params.eval_error_report.clone(),
83 )?),
84 Err(_) => None,
85 };
86 trace!("Join non-equi condition: {:?}", condition);
87
88 let inequality_pairs: Vec<InequalityPairInfo> =
90 if !node.get_inequality_pairs_v2().is_empty() {
91 node.get_inequality_pairs_v2()
93 .iter()
94 .map(|pair| InequalityPairInfo {
95 left_idx: pair.get_left_idx() as usize,
96 right_idx: pair.get_right_idx() as usize,
97 clean_left_state: pair.get_clean_left_state(),
98 clean_right_state: pair.get_clean_right_state(),
99 op: pair.op(),
100 })
101 .collect()
102 } else {
103 node.get_inequality_pairs()
105 .iter()
106 .map(|pair| {
107 let key_required_larger = pair.get_key_required_larger() as usize;
108 let key_required_smaller = pair.get_key_required_smaller() as usize;
109 let left_input_len = source_l.schema().len();
110
111 let (left_idx, right_idx, clean_left, clean_right, op) =
115 if key_required_larger < left_input_len {
116 (
118 key_required_larger,
119 key_required_smaller - left_input_len,
120 pair.get_clean_state(),
121 false,
122 InequalityType::GreaterThanOrEqual,
123 )
124 } else {
125 (
127 key_required_smaller,
128 key_required_larger - left_input_len,
129 false,
130 pair.get_clean_state(),
131 InequalityType::LessThanOrEqual,
132 )
133 };
134
135 InequalityPairInfo {
136 left_idx,
137 right_idx,
138 clean_left_state: clean_left,
139 clean_right_state: clean_right,
140 op,
141 }
142 })
143 .collect()
144 };
145
146 let join_key_data_types = params_l
147 .join_key_indices
148 .iter()
149 .map(|idx| source_l.schema().fields[*idx].data_type())
150 .collect_vec();
151
152 let state_table_l = StateTableBuilder::new(table_l, store.clone(), Some(vnodes.clone()))
153 .enable_preload_all_rows_by_config(¶ms.config)
154 .build()
155 .await;
156 let degree_state_table_l =
157 StateTableBuilder::new(degree_table_l, store.clone(), Some(vnodes.clone()))
158 .enable_preload_all_rows_by_config(¶ms.config)
159 .build()
160 .await;
161
162 let state_table_r = StateTableBuilder::new(table_r, store.clone(), Some(vnodes.clone()))
163 .enable_preload_all_rows_by_config(¶ms.config)
164 .build()
165 .await;
166 let degree_state_table_r = StateTableBuilder::new(degree_table_r, store, Some(vnodes))
167 .enable_preload_all_rows_by_config(¶ms.config)
168 .build()
169 .await;
170
171 #[allow(deprecated)]
174 let join_encoding_type = node
175 .get_join_encoding_type()
176 .map_or(params.config.developer.join_encoding_type, Into::into);
177
178 let args = HashJoinExecutorDispatcherArgs {
179 ctx: params.actor_context,
180 info: params.info.clone(),
181 source_l,
182 source_r,
183 params_l,
184 params_r,
185 null_safe,
186 output_indices,
187 cond: condition,
188 inequality_pairs,
189 state_table_l,
190 degree_state_table_l,
191 state_table_r,
192 degree_state_table_r,
193 lru_manager: params.watermark_epoch,
194 is_append_only,
195 metrics: params.executor_stats,
196 join_type_proto: node.get_join_type()?,
197 join_key_data_types,
198 chunk_size: params.config.developer.chunk_size,
199 high_join_amplification_threshold: (params.config.developer)
200 .high_join_amplification_threshold,
201 join_encoding_type,
202 };
203
204 let exec = args.dispatch()?;
205 Ok((params.info, exec).into())
206 }
207}
208
209struct HashJoinExecutorDispatcherArgs<S: StateStore> {
210 ctx: ActorContextRef,
211 info: ExecutorInfo,
212 source_l: Executor,
213 source_r: Executor,
214 params_l: JoinParams,
215 params_r: JoinParams,
216 null_safe: Vec<bool>,
217 output_indices: Vec<usize>,
218 cond: Option<NonStrictExpression>,
219 inequality_pairs: Vec<InequalityPairInfo>,
220 state_table_l: StateTable<S>,
221 degree_state_table_l: StateTable<S>,
222 state_table_r: StateTable<S>,
223 degree_state_table_r: StateTable<S>,
224 lru_manager: AtomicU64Ref,
225 is_append_only: bool,
226 metrics: Arc<StreamingMetrics>,
227 join_type_proto: JoinTypeProto,
228 join_key_data_types: Vec<DataType>,
229 chunk_size: usize,
230 high_join_amplification_threshold: usize,
231 join_encoding_type: JoinEncodingType,
232}
233
234impl<S: StateStore> HashKeyDispatcher for HashJoinExecutorDispatcherArgs<S> {
235 type Output = StreamResult<Box<dyn Execute>>;
236
237 fn dispatch_impl<K: HashKey>(self) -> Self::Output {
238 macro_rules! build {
240 ($join_type:ident, $join_encoding:ident) => {
241 Ok(
242 HashJoinExecutor::<K, S, { JoinType::$join_type }, $join_encoding>::new(
243 self.ctx,
244 self.info,
245 self.source_l,
246 self.source_r,
247 self.params_l,
248 self.params_r,
249 self.null_safe,
250 self.output_indices,
251 self.cond,
252 self.inequality_pairs,
253 self.state_table_l,
254 self.degree_state_table_l,
255 self.state_table_r,
256 self.degree_state_table_r,
257 self.lru_manager,
258 self.is_append_only,
259 self.metrics,
260 self.chunk_size,
261 self.high_join_amplification_threshold,
262 )
263 .boxed(),
264 )
265 };
266 }
267
268 macro_rules! build_match {
269 ($($join_type:ident),*) => {
270 match (self.join_type_proto, self.join_encoding_type) {
271 (JoinTypeProto::AsofInner, _)
272 | (JoinTypeProto::AsofLeftOuter, _)
273 | (JoinTypeProto::Unspecified, _) => unreachable!(),
274 $(
275 (JoinTypeProto::$join_type, JoinEncodingType::Memory) => build!($join_type, MemoryEncoding),
276 (JoinTypeProto::$join_type, JoinEncodingType::Cpu) => build!($join_type, CpuEncoding),
277 )*
278 }
279 };
280 }
281 build_match! {
282 Inner,
283 LeftOuter,
284 RightOuter,
285 FullOuter,
286 LeftSemi,
287 LeftAnti,
288 RightSemi,
289 RightAnti
290 }
291 }
292
293 fn data_types(&self) -> &[DataType] {
294 &self.join_key_data_types
295 }
296}