risingwave_frontend/optimizer/plan_node/
batch_lookup_join.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use pretty_xmlish::{Pretty, XmlNode};
18use risingwave_common::catalog::ColumnId;
19use risingwave_pb::batch_plan::plan_node::NodeBody;
20use risingwave_pb::batch_plan::{DistributedLookupJoinNode, LocalLookupJoinNode};
21use risingwave_pb::plan_common::AsOfJoinDesc;
22use risingwave_sqlparser::ast::AsOf;
23
24use super::batch::prelude::*;
25use super::utils::{Distill, childless_record, to_pb_time_travel_as_of};
26use super::{BatchPlanRef as PlanRef, BatchSeqScan, ExprRewritable, generic};
27use crate::TableCatalog;
28use crate::error::Result;
29use crate::expr::{Expr, ExprRewriter, ExprVisitor};
30use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
31use crate::optimizer::plan_node::utils::IndicesDisplay;
32use crate::optimizer::plan_node::{
33    EqJoinPredicate, EqJoinPredicateDisplay, PlanBase, PlanTreeNodeUnary, ToDistributedBatch,
34    ToLocalBatch, TryToBatchPb,
35};
36use crate::optimizer::property::{Distribution, Order, RequiredDist};
37use crate::scheduler::SchedulerResult;
38use crate::utils::ColIndexMappingRewriteExt;
39
/// Batch plan node for a lookup join.
///
/// The left side of `core` is the node's only plan input (see the `PlanTreeNodeUnary` impl);
/// the right side is described by `right_table` and is looked up through the equality keys of
/// `eq_join_predicate`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchLookupJoin {
    /// Common plan-node state, built in `new` from `core`, the derived distribution and
    /// `Order::any()`.
    pub base: PlanBase<Batch>,
    /// The generic join this node is based on (join type, both sides, output indices).
    core: generic::Join<PlanRef>,

    /// The join condition must be equivalent to `logical.on`, but separated into equal and
    /// non-equal parts to facilitate execution later
    eq_join_predicate: EqJoinPredicate,

    /// Table description of the right side table
    right_table: Arc<TableCatalog>,

    /// Output column ids of the right side table
    right_output_column_ids: Vec<ColumnId>,

    /// The prefix length of the order key of right side table.
    lookup_prefix_len: usize,

    /// If `distributed_lookup` is true, it will generate `DistributedLookupJoinNode` for
    /// `ToBatchPb`. Otherwise, it will generate `LocalLookupJoinNode`.
    distributed_lookup: bool,

    /// Optional `AS OF` clause; converted with `to_pb_time_travel_as_of` when the node is
    /// serialized to protobuf.
    as_of: Option<AsOf>,
    /// `AsOf` join description, forwarded unchanged to the protobuf node.
    asof_desc: Option<AsOfJoinDesc>,
}
66
67impl BatchLookupJoin {
68    pub fn new(
69        core: generic::Join<PlanRef>,
70        eq_join_predicate: EqJoinPredicate,
71        right_table: Arc<TableCatalog>,
72        right_output_column_ids: Vec<ColumnId>,
73        lookup_prefix_len: usize,
74        distributed_lookup: bool,
75        as_of: Option<AsOf>,
76        asof_desc: Option<AsOfJoinDesc>,
77    ) -> Self {
78        // We cannot create a `BatchLookupJoin` without any eq keys. We require eq keys to do the
79        // lookup.
80        assert!(eq_join_predicate.has_eq());
81        assert!(eq_join_predicate.eq_keys_are_type_aligned());
82        let dist = Self::derive_dist(core.left.distribution(), &core);
83        let base = PlanBase::new_batch_with_core(&core, dist, Order::any());
84        Self {
85            base,
86            core,
87            eq_join_predicate,
88            right_table,
89            right_output_column_ids,
90            lookup_prefix_len,
91            distributed_lookup,
92            as_of,
93            asof_desc,
94        }
95    }
96
97    fn derive_dist(left: &Distribution, core: &generic::Join<PlanRef>) -> Distribution {
98        match left {
99            Distribution::Single => Distribution::Single,
100            Distribution::HashShard(_) | Distribution::UpstreamHashShard(_, _) => {
101                let l2o = core.l2i_col_mapping().composite(&core.i2o_col_mapping());
102                l2o.rewrite_provided_distribution(left)
103            }
104            _ => unreachable!(),
105        }
106    }
107
108    fn eq_join_predicate(&self) -> &EqJoinPredicate {
109        &self.eq_join_predicate
110    }
111
112    pub fn right_table(&self) -> &TableCatalog {
113        &self.right_table
114    }
115
116    fn clone_with_distributed_lookup(&self, input: PlanRef, distributed_lookup: bool) -> Self {
117        let mut batch_lookup_join = self.clone_with_input(input);
118        batch_lookup_join.distributed_lookup = distributed_lookup;
119        batch_lookup_join
120    }
121
122    pub fn lookup_prefix_len(&self) -> usize {
123        self.lookup_prefix_len
124    }
125}
126
127impl Distill for BatchLookupJoin {
128    fn distill<'a>(&self) -> XmlNode<'a> {
129        let verbose = self.base.ctx().is_explain_verbose();
130        let mut vec = Vec::with_capacity(if verbose { 3 } else { 2 });
131        vec.push(("type", Pretty::debug(&self.core.join_type)));
132
133        let concat_schema = self.core.concat_schema();
134        vec.push((
135            "predicate",
136            Pretty::debug(&EqJoinPredicateDisplay {
137                eq_join_predicate: self.eq_join_predicate(),
138                input_schema: &concat_schema,
139            }),
140        ));
141
142        if verbose {
143            let data = IndicesDisplay::from_join(&self.core, &concat_schema);
144            vec.push(("output", data));
145        }
146
147        let scan: &BatchSeqScan = self.core.right.as_batch_seq_scan().unwrap();
148
149        vec.push(("lookup table", Pretty::display(&scan.core().table_name())));
150
151        childless_record("BatchLookupJoin", vec)
152    }
153}
154
155impl PlanTreeNodeUnary<Batch> for BatchLookupJoin {
156    fn input(&self) -> PlanRef {
157        self.core.left.clone()
158    }
159
160    // Only change left side
161    fn clone_with_input(&self, input: PlanRef) -> Self {
162        let mut core = self.core.clone();
163        core.left = input;
164        Self::new(
165            core,
166            self.eq_join_predicate.clone(),
167            self.right_table.clone(),
168            self.right_output_column_ids.clone(),
169            self.lookup_prefix_len,
170            self.distributed_lookup,
171            self.as_of.clone(),
172            self.asof_desc,
173        )
174    }
175}
176
177impl_plan_tree_node_for_unary! { Batch, BatchLookupJoin }
178
179impl ToDistributedBatch for BatchLookupJoin {
180    fn to_distributed(&self) -> Result<PlanRef> {
181        // Align left distribution keys with the right table.
182        let mut exchange_dist_keys = vec![];
183        let left_eq_indexes = self.eq_join_predicate.left_eq_indexes();
184        let right_table = &self.right_table;
185        for dist_col_index in &right_table.distribution_key {
186            let dist_col_id = right_table.columns[*dist_col_index].column_desc.column_id;
187            let output_pos = self
188                .right_output_column_ids
189                .iter()
190                .position(|p| *p == dist_col_id)
191                .unwrap();
192            let dist_in_eq_indexes = self
193                .eq_join_predicate
194                .right_eq_indexes()
195                .iter()
196                .position(|col| *col == output_pos)
197                .unwrap();
198            assert!(dist_in_eq_indexes < self.lookup_prefix_len);
199            exchange_dist_keys.push(left_eq_indexes[dist_in_eq_indexes]);
200        }
201
202        assert!(!exchange_dist_keys.is_empty());
203
204        let input = self.input().to_distributed_with_required(
205            &Order::any(),
206            &RequiredDist::PhysicalDist(Distribution::UpstreamHashShard(
207                exchange_dist_keys,
208                self.right_table.id,
209            )),
210        )?;
211
212        Ok(self.clone_with_distributed_lookup(input, true).into())
213    }
214}
215
216impl TryToBatchPb for BatchLookupJoin {
217    fn try_to_batch_prost_body(&self) -> SchedulerResult<NodeBody> {
218        Ok(if self.distributed_lookup {
219            NodeBody::DistributedLookupJoin(DistributedLookupJoinNode {
220                join_type: self.core.join_type as i32,
221                condition: self
222                    .eq_join_predicate
223                    .other_cond()
224                    .as_expr_unless_true()
225                    .map(|x| x.to_expr_proto()),
226                outer_side_key: self
227                    .eq_join_predicate
228                    .left_eq_indexes()
229                    .into_iter()
230                    .map(|a| a as _)
231                    .collect(),
232                inner_side_key: self
233                    .eq_join_predicate
234                    .right_eq_indexes()
235                    .into_iter()
236                    .map(|a| a as _)
237                    .collect(),
238                inner_side_table_desc: Some(self.right_table.table_desc().try_to_protobuf()?),
239                inner_side_column_ids: self
240                    .right_output_column_ids
241                    .iter()
242                    .map(ColumnId::get_id)
243                    .collect(),
244                output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
245                null_safe: self.eq_join_predicate.null_safes(),
246                lookup_prefix_len: self.lookup_prefix_len as u32,
247                as_of: to_pb_time_travel_as_of(&self.as_of)?,
248                asof_desc: self.asof_desc,
249            })
250        } else {
251            NodeBody::LocalLookupJoin(LocalLookupJoinNode {
252                join_type: self.core.join_type as i32,
253                condition: self
254                    .eq_join_predicate
255                    .other_cond()
256                    .as_expr_unless_true()
257                    .map(|x| x.to_expr_proto()),
258                outer_side_key: self
259                    .eq_join_predicate
260                    .left_eq_indexes()
261                    .into_iter()
262                    .map(|a| a as _)
263                    .collect(),
264                inner_side_key: self
265                    .eq_join_predicate
266                    .right_eq_indexes()
267                    .into_iter()
268                    .map(|a| a as _)
269                    .collect(),
270                inner_side_table_desc: Some(self.right_table.table_desc().try_to_protobuf()?),
271                inner_side_vnode_mapping: vec![], // To be filled in at local.rs
272                inner_side_column_ids: self
273                    .right_output_column_ids
274                    .iter()
275                    .map(ColumnId::get_id)
276                    .collect(),
277                output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(),
278                worker_nodes: vec![], // To be filled in at local.rs
279                null_safe: self.eq_join_predicate.null_safes(),
280                lookup_prefix_len: self.lookup_prefix_len as u32,
281                as_of: to_pb_time_travel_as_of(&self.as_of)?,
282                asof_desc: self.asof_desc,
283            })
284        })
285    }
286}
287
288impl ToLocalBatch for BatchLookupJoin {
289    fn to_local(&self) -> Result<PlanRef> {
290        let input = RequiredDist::single()
291            .batch_enforce_if_not_satisfies(self.input().to_local()?, &Order::any())?;
292
293        Ok(self.clone_with_distributed_lookup(input, false).into())
294    }
295}
296
297impl ExprRewritable<Batch> for BatchLookupJoin {
298    fn has_rewritable_expr(&self) -> bool {
299        true
300    }
301
302    fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
303        let base = self.base.clone_with_new_plan_id();
304        let mut core = self.core.clone();
305        core.rewrite_exprs(r);
306        Self {
307            base,
308            core,
309            eq_join_predicate: self.eq_join_predicate.rewrite_exprs(r),
310            ..Self::clone(self)
311        }
312        .into()
313    }
314}
315
316impl ExprVisitable for BatchLookupJoin {
317    fn visit_exprs(&self, v: &mut dyn ExprVisitor) {
318        self.core.visit_exprs(v);
319    }
320}