risingwave_batch_executors/executor/join/distributed_lookup_join.rs

// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::marker::PhantomData;
use std::mem::swap;

use anyhow::anyhow;
use futures::pin_mut;
use itertools::Itertools;
use risingwave_batch::task::ShutdownToken;
use risingwave_common::bitmap::Bitmap;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema};
use risingwave_common::hash::{HashKey, HashKeyDispatcher, VnodeCountCompat};
use risingwave_common::memory::MemoryContext;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::scan_range::ScanRange;
use risingwave_expr::expr::{BoxedExpression, build_from_prost};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::common::BatchQueryEpoch;
use risingwave_storage::store::PrefetchOptions;
use risingwave_storage::table::TableIter;
use risingwave_storage::table::batch_table::BatchTable;
use risingwave_storage::{StateStore, dispatch_state_store};

use super::AsOfDesc;
use crate::error::Result;
use crate::executor::join::JoinType;
use crate::executor::{
    BoxedDataChunkStream, BoxedExecutor, BoxedExecutorBuilder, BufferChunkExecutor, Executor,
    ExecutorBuilder, LookupExecutorBuilder, LookupJoinBase,
};

/// Distributed Lookup Join Executor.
/// High-level execution flow:
/// Repeat steps 1-3:
///   1. Read N rows from the outer side input and, after deduplication, send their keys to the
///      inner side builder.
///   2. The inner side looks up the inner side table with those keys and builds a hash map.
///   3. Outer side rows join the matching inner side rows by probing the hash map.
///
/// A distributed lookup join is already scheduled on the compute node that holds its inner side
/// data, so it can perform the lookup locally without sending RPCs to other compute nodes.
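///
/// `K` is the hash key type selected at runtime by `HashKeyDispatcher`, and `S` is the state
/// store backend used to read the inner side table.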
pub struct DistributedLookupJoinExecutor<K, S: StateStore> {
    base: LookupJoinBase<K, InnerSideExecutorBuilder<S>>,
}

impl<K: HashKey, S: StateStore> Executor for DistributedLookupJoinExecutor<K, S> {
    fn schema(&self) -> &Schema {
        &self.base.schema
    }

    fn identity(&self) -> &str {
        &self.base.identity
    }

    fn execute(self: Box<Self>) -> BoxedDataChunkStream {
        Box::new(self.base).do_execute()
    }
}

impl<K, S: StateStore> DistributedLookupJoinExecutor<K, S> {
    fn new(base: LookupJoinBase<K, InnerSideExecutorBuilder<S>>) -> Self {
        Self { base }
    }
}

pub struct DistributedLookupJoinExecutorBuilder {}

impl BoxedExecutorBuilder for DistributedLookupJoinExecutorBuilder {
    async fn new_boxed_executor(
        source: &ExecutorBuilder<'_>,
        inputs: Vec<BoxedExecutor>,
    ) -> Result<BoxedExecutor> {
        let [outer_side_input]: [_; 1] = inputs.try_into().unwrap();

        let distributed_lookup_join_node = try_match_expand!(
            source.plan_node().get_node_body().unwrap(),
            NodeBody::DistributedLookupJoin
        )?;

        let join_type = JoinType::from_prost(distributed_lookup_join_node.get_join_type()?);
        let condition = match distributed_lookup_join_node.get_condition() {
            Ok(cond_prost) => Some(build_from_prost(cond_prost)?),
            Err(_) => None,
        };

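        // `output_indices` projects the joined rows down to the columns requested by the plan.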
        let output_indices: Vec<usize> = distributed_lookup_join_node
            .get_output_indices()
            .iter()
            .map(|&x| x as usize)
            .collect();

        let outer_side_data_types = outer_side_input.schema().data_types();

        let table_desc = distributed_lookup_join_node.get_inner_side_table_desc()?;
        let inner_side_column_ids = distributed_lookup_join_node
            .get_inner_side_column_ids()
            .clone();

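        // Rebuild the inner side schema from the table descriptor, keeping only the columns that
        // this join actually reads.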
        let inner_side_schema = Schema {
            fields: inner_side_column_ids
                .iter()
                .map(|&id| {
                    let column = table_desc
                        .columns
                        .iter()
                        .find(|c| c.column_id == id)
                        .unwrap();
                    Field::from(&ColumnDesc::from(column))
                })
                .collect_vec(),
        };

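        // Left semi/anti joins output only the outer side columns; all other join types output the
        // outer side columns followed by the inner side columns.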
        let fields = if join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti {
            outer_side_input.schema().fields.clone()
        } else {
            [
                outer_side_input.schema().fields.clone(),
                inner_side_schema.fields.clone(),
            ]
            .concat()
        };

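        // `original_schema` is the full schema of the joined rows before projection;
        // `actual_schema` applies `output_indices` on top of it.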
        let original_schema = Schema { fields };
        let actual_schema = output_indices
            .iter()
            .map(|&idx| original_schema[idx].clone())
            .collect();

        let mut outer_side_key_idxs = vec![];
        for outer_side_key in distributed_lookup_join_node.get_outer_side_key() {
            outer_side_key_idxs.push(*outer_side_key as usize)
        }

        let outer_side_key_types: Vec<DataType> = outer_side_key_idxs
            .iter()
            .map(|&i| outer_side_data_types[i].clone())
            .collect_vec();

        let lookup_prefix_len: usize =
            distributed_lookup_join_node.get_lookup_prefix_len() as usize;

        let mut inner_side_key_idxs = vec![];
        for inner_side_key in distributed_lookup_join_node.get_inner_side_key() {
            inner_side_key_idxs.push(*inner_side_key as usize)
        }

        let inner_side_key_types = inner_side_key_idxs
            .iter()
            .map(|&i| inner_side_schema.fields[i].data_type.clone())
            .collect_vec();

        let null_safe = distributed_lookup_join_node.get_null_safe().clone();

        let chunk_size = source.context().get_config().developer.chunk_size;

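        // The AS OF descriptor is only present for AS OF joins; `transpose` turns the decoded
        // `Option<Result<_>>` into `Result<Option<_>>`.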
        let asof_desc = distributed_lookup_join_node
            .asof_desc
            .map(|desc| AsOfDesc::from_protobuf(&desc))
            .transpose()?;

        let column_ids = inner_side_column_ids
            .iter()
            .copied()
            .map(ColumnId::from)
            .collect();

        // The lookup join key always contains the distribution key, so we don't need a vnode bitmap.
        let vnodes = Some(Bitmap::ones(table_desc.vnode_count()).into());

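        // Dispatch on the concrete state store type, build a partial reader over the inner side
        // table, and let `HashKeyDispatcher` choose the hash key type from the join key data types.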
        dispatch_state_store!(source.context().state_store(), state_store, {
            let table = BatchTable::new_partial(state_store, column_ids, vnodes, table_desc);
            let inner_side_builder = InnerSideExecutorBuilder::new(
                outer_side_key_types,
                inner_side_key_types.clone(),
                lookup_prefix_len,
                distributed_lookup_join_node
                    .query_epoch
                    .ok_or_else(|| anyhow!("query_epoch not set in distributed lookup join"))?,
                vec![],
                table,
                chunk_size,
            );

            let identity = source.plan_node().get_identity().clone();

            Ok(DistributedLookupJoinExecutorArgs {
                join_type,
                condition,
                outer_side_input,
                outer_side_data_types,
                outer_side_key_idxs,
                inner_side_builder,
                inner_side_key_types,
                inner_side_key_idxs,
                null_safe,
                lookup_prefix_len,
                chunk_builder: DataChunkBuilder::new(original_schema.data_types(), chunk_size),
                schema: actual_schema,
                output_indices,
                chunk_size,
                asof_desc,
                identity: identity.clone(),
                shutdown_rx: source.shutdown_rx().clone(),
                mem_ctx: source.context().create_executor_mem_context(&identity),
            }
            .dispatch())
        })
    }
}

struct DistributedLookupJoinExecutorArgs<S: StateStore> {
    join_type: JoinType,
    condition: Option<BoxedExpression>,
    outer_side_input: BoxedExecutor,
    outer_side_data_types: Vec<DataType>,
    outer_side_key_idxs: Vec<usize>,
    inner_side_builder: InnerSideExecutorBuilder<S>,
    inner_side_key_types: Vec<DataType>,
    inner_side_key_idxs: Vec<usize>,
    null_safe: Vec<bool>,
    lookup_prefix_len: usize,
    chunk_builder: DataChunkBuilder,
    schema: Schema,
    output_indices: Vec<usize>,
    chunk_size: usize,
    asof_desc: Option<AsOfDesc>,
    identity: String,
    shutdown_rx: ShutdownToken,
    mem_ctx: MemoryContext,
}

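// The join key data types determine which concrete `HashKey` implementation the dispatcher
// instantiates for the executor.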
impl<S: StateStore> HashKeyDispatcher for DistributedLookupJoinExecutorArgs<S> {
    type Output = BoxedExecutor;

    fn dispatch_impl<K: HashKey>(self) -> Self::Output {
        Box::new(DistributedLookupJoinExecutor::<K, S>::new(LookupJoinBase {
            join_type: self.join_type,
            condition: self.condition,
            outer_side_input: self.outer_side_input,
            outer_side_data_types: self.outer_side_data_types,
            outer_side_key_idxs: self.outer_side_key_idxs,
            inner_side_builder: self.inner_side_builder,
            inner_side_key_types: self.inner_side_key_types,
            inner_side_key_idxs: self.inner_side_key_idxs,
            null_safe: self.null_safe,
            lookup_prefix_len: self.lookup_prefix_len,
            chunk_builder: self.chunk_builder,
            schema: self.schema,
            output_indices: self.output_indices,
            chunk_size: self.chunk_size,
            asof_desc: self.asof_desc,
            identity: self.identity,
            shutdown_rx: self.shutdown_rx,
            mem_ctx: self.mem_ctx,
            _phantom: PhantomData,
        }))
    }

    fn data_types(&self) -> &[DataType] {
        &self.inner_side_key_types
    }
}

/// Inner side executor builder for the `DistributedLookupJoinExecutor`.
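/// It buffers the rows fetched for each batch of lookup keys in `row_list` and serves them through
/// a `BufferChunkExecutor` when `build_executor` is called.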
struct InnerSideExecutorBuilder<S: StateStore> {
    outer_side_key_types: Vec<DataType>,
    inner_side_key_types: Vec<DataType>,
    lookup_prefix_len: usize,
    epoch: BatchQueryEpoch,
    row_list: Vec<OwnedRow>,
    table: BatchTable<S>,
    chunk_size: usize,
}

impl<S: StateStore> InnerSideExecutorBuilder<S> {
    fn new(
        outer_side_key_types: Vec<DataType>,
        inner_side_key_types: Vec<DataType>,
        lookup_prefix_len: usize,
        epoch: BatchQueryEpoch,
        row_list: Vec<OwnedRow>,
        table: BatchTable<S>,
        chunk_size: usize,
    ) -> Self {
        Self {
            outer_side_key_types,
            inner_side_key_types,
            lookup_prefix_len,
            epoch,
            row_list,
            table,
            chunk_size,
        }
    }
}

impl<S: StateStore> LookupExecutorBuilder for InnerSideExecutorBuilder<S> {
    fn reset(&mut self) {
        // No-op: `build_executor` already drains `row_list`, so there is no buffered state to clear.
    }

    /// Build a scan range from the given key datums and immediately fetch the matching rows from
    /// the inner side table into `row_list`.
    async fn add_scan_range(&mut self, key_datums: Vec<Datum>) -> Result<()> {
        let mut scan_range = ScanRange::full_table_scan();

        for ((datum, outer_type), inner_type) in key_datums
            .into_iter()
            .zip_eq_fast(
                self.outer_side_key_types
                    .iter()
                    .take(self.lookup_prefix_len),
            )
            .zip_eq_fast(
                self.inner_side_key_types
                    .iter()
                    .take(self.lookup_prefix_len),
            )
        {
            let datum = if inner_type == outer_type {
                datum
            } else {
                bail!("Join key types are not aligned: LHS: {outer_type:?}, RHS: {inner_type:?}");
            };

            scan_range.eq_conds.push(datum);
        }

        let pk_prefix = OwnedRow::new(scan_range.eq_conds);

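        // If the lookup prefix covers the whole primary key, a point get is sufficient; otherwise,
        // scan all rows that share this prefix.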
        if self.lookup_prefix_len == self.table.pk_indices().len() {
            let row = self.table.get_row(&pk_prefix, self.epoch.into()).await?;

            if let Some(row) = row {
                self.row_list.push(row);
            }
        } else {
            let iter = self
                .table
                .batch_iter_with_pk_bounds(
                    self.epoch.into(),
                    &pk_prefix,
                    ..,
                    false,
                    PrefetchOptions::default(),
                )
                .await?;

            pin_mut!(iter);
            while let Some(row) = iter.next_row().await? {
                self.row_list.push(row);
            }
        }

        Ok(())
    }

    /// Build a `BufferChunkExecutor` that returns all rows previously fetched by `add_scan_range`.
    async fn build_executor(&mut self) -> Result<BoxedExecutor> {
        let mut data_chunk_builder =
            DataChunkBuilder::new(self.table.schema().data_types(), self.chunk_size);
        let mut chunk_list = Vec::new();

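        // Take ownership of the buffered rows, leaving `row_list` empty for the next lookup batch.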
        let mut new_row_list = vec![];
        swap(&mut new_row_list, &mut self.row_list);

        for row in new_row_list {
            if let Some(chunk) = data_chunk_builder.append_one_row(row) {
                chunk_list.push(chunk);
            }
        }
        if let Some(chunk) = data_chunk_builder.consume_all() {
            chunk_list.push(chunk);
        }

        Ok(Box::new(BufferChunkExecutor::new(
            self.table.schema().clone(),
            chunk_list,
        )))
    }
}