risingwave_batch/task/
context.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use std::sync::Arc;
15
16use prometheus::core::Atomic;
17use risingwave_common::catalog::SysCatalogReaderRef;
18use risingwave_common::config::BatchConfig;
19use risingwave_common::memory::MemoryContext;
20use risingwave_common::metrics::TrAdderAtomic;
21use risingwave_common::util::addr::{HostAddr, is_local_address};
22use risingwave_connector::source::monitor::SourceMetrics;
23use risingwave_dml::dml_manager::DmlManagerRef;
24use risingwave_rpc_client::ComputeClientPoolRef;
25use risingwave_storage::StateStoreImpl;
26
27use crate::error::Result;
28use crate::monitor::{BatchMetrics, BatchMetricsInner, BatchSpillMetrics};
29use crate::task::{BatchEnvironment, TaskOutput, TaskOutputId};
30use crate::worker_manager::worker_node_manager::WorkerNodeManagerRef;
31
32/// Context for batch task execution.
33///
34/// This context is specific to one task execution, and should *not* be shared by different tasks.
35pub trait BatchTaskContext: Send + Sync + 'static {
36    /// Get task output identified by `task_output_id`.
37    ///
38    /// Returns error if the task of `task_output_id` doesn't run in same worker as current task.
39    fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput>;
40
41    /// Get system catalog reader, used to read system table.
42    fn catalog_reader(&self) -> SysCatalogReaderRef;
43
44    /// Whether `peer_addr` is in same as current task.
45    fn is_local_addr(&self, peer_addr: &HostAddr) -> bool;
46
47    fn dml_manager(&self) -> DmlManagerRef;
48
49    fn state_store(&self) -> StateStoreImpl;
50
51    /// Get batch metrics.
52    /// None indicates that not collect task metrics.
53    fn batch_metrics(&self) -> Option<BatchMetrics>;
54
55    fn spill_metrics(&self) -> Arc<BatchSpillMetrics>;
56
57    /// Get compute client pool. This is used in grpc exchange to avoid creating new compute client
58    /// for each grpc call.
59    fn client_pool(&self) -> ComputeClientPoolRef;
60
61    /// Get config for batch environment
62    fn get_config(&self) -> &BatchConfig;
63
64    fn source_metrics(&self) -> Arc<SourceMetrics>;
65
66    fn create_executor_mem_context(&self, executor_id: &str) -> MemoryContext;
67
68    fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef>;
69}
70
71/// Batch task context on compute node.
72#[derive(Clone)]
73pub struct ComputeNodeContext {
74    env: BatchEnvironment,
75
76    batch_metrics: BatchMetrics,
77
78    mem_context: MemoryContext,
79}
80
81impl BatchTaskContext for ComputeNodeContext {
82    fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput> {
83        self.env
84            .task_manager()
85            .take_output(&task_output_id.to_prost())
86    }
87
88    fn catalog_reader(&self) -> SysCatalogReaderRef {
89        unimplemented!("not supported in distributed mode")
90    }
91
92    fn is_local_addr(&self, peer_addr: &HostAddr) -> bool {
93        is_local_address(self.env.server_address(), peer_addr)
94    }
95
96    fn dml_manager(&self) -> DmlManagerRef {
97        self.env.dml_manager_ref()
98    }
99
100    fn state_store(&self) -> StateStoreImpl {
101        self.env.state_store()
102    }
103
104    fn batch_metrics(&self) -> Option<BatchMetrics> {
105        Some(self.batch_metrics.clone())
106    }
107
108    fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
109        self.env.spill_metrics()
110    }
111
112    fn client_pool(&self) -> ComputeClientPoolRef {
113        self.env.client_pool()
114    }
115
116    fn get_config(&self) -> &BatchConfig {
117        self.env.config()
118    }
119
120    fn source_metrics(&self) -> Arc<SourceMetrics> {
121        self.env.source_metrics()
122    }
123
124    fn create_executor_mem_context(&self, _executor_id: &str) -> MemoryContext {
125        let counter = TrAdderAtomic::new(0);
126        MemoryContext::new(Some(self.mem_context.clone()), counter)
127    }
128
129    fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef> {
130        None
131    }
132}
133
134impl ComputeNodeContext {
135    pub fn for_test() -> Arc<dyn BatchTaskContext> {
136        Arc::new(Self {
137            env: BatchEnvironment::for_test(),
138            batch_metrics: BatchMetricsInner::for_test(),
139            mem_context: MemoryContext::none(),
140        })
141    }
142
143    pub fn create(env: BatchEnvironment) -> Arc<dyn BatchTaskContext> {
144        let mem_context = env.task_manager().memory_context_ref();
145        let batch_metrics = Arc::new(BatchMetricsInner::new(
146            env.task_manager().metrics(),
147            env.executor_metrics(),
148            env.iceberg_scan_metrics(),
149        ));
150        Arc::new(Self {
151            env,
152            batch_metrics,
153            mem_context,
154        })
155    }
156}