risingwave_batch/task/
context.rs

1// Copyright 2022 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use prometheus::core::Atomic;
18use risingwave_common::catalog::SysCatalogReaderRef;
19use risingwave_common::config::BatchConfig;
20use risingwave_common::memory::MemoryContext;
21use risingwave_common::metrics::TrAdderAtomic;
22use risingwave_common::metrics_reader::MetricsReader;
23use risingwave_common::util::addr::{HostAddr, is_local_address};
24use risingwave_connector::source::monitor::SourceMetrics;
25use risingwave_dml::dml_manager::DmlManagerRef;
26use risingwave_rpc_client::ComputeClientPoolRef;
27use risingwave_storage::StateStoreImpl;
28
29use crate::error::Result;
30use crate::monitor::{BatchMetrics, BatchMetricsInner, BatchSpillMetrics};
31use crate::task::{BatchEnvironment, TaskOutput, TaskOutputId};
32use crate::worker_manager::worker_node_manager::WorkerNodeManagerRef;
33
/// Context for batch task execution.
///
/// This context is specific to one task execution, and should *not* be shared by different tasks.
///
/// Implementors must be `Send + Sync + 'static`; the constructors of `ComputeNodeContext` in this
/// file hand the context out as `Arc<dyn BatchTaskContext>`.
pub trait BatchTaskContext: Send + Sync + 'static {
    /// Get task output identified by `task_output_id`.
    ///
    /// Returns an error if the task of `task_output_id` doesn't run in the same worker as the
    /// current task.
    fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput>;

    /// Get system catalog reader, used to read system tables.
    ///
    /// Not every implementation supports this: `ComputeNodeContext` panics via `unimplemented!`.
    fn catalog_reader(&self) -> SysCatalogReaderRef;

    /// Whether `peer_addr` is local to the node running the current task.
    ///
    /// `ComputeNodeContext` decides this by comparing `peer_addr` against its environment's
    /// server address.
    fn is_local_addr(&self, peer_addr: &HostAddr) -> bool;

    /// Get the DML manager handle.
    fn dml_manager(&self) -> DmlManagerRef;

    /// Get the state store implementation.
    fn state_store(&self) -> StateStoreImpl;

    /// Get batch metrics.
    ///
    /// `None` indicates that task metrics are not collected.
    fn batch_metrics(&self) -> Option<BatchMetrics>;

    /// Get the shared `BatchSpillMetrics` handle.
    fn spill_metrics(&self) -> Arc<BatchSpillMetrics>;

    /// Get compute client pool. This is used in grpc exchange to avoid creating new compute client
    /// for each grpc call.
    fn client_pool(&self) -> ComputeClientPoolRef;

    /// Get config for batch environment.
    fn get_config(&self) -> &BatchConfig;

    /// Get the shared `SourceMetrics` handle.
    fn source_metrics(&self) -> Arc<SourceMetrics>;

    /// Create a memory context for the executor identified by `executor_id`.
    ///
    /// Implementations are free to ignore `executor_id`; `ComputeNodeContext` does.
    fn create_executor_mem_context(&self, executor_id: &str) -> MemoryContext;

    /// Get the worker node manager, if this context has one.
    ///
    /// `ComputeNodeContext` always returns `None`.
    fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef>;

    /// Get metrics reader for reading channel delta stats and other metrics.
    ///
    /// Not every implementation supports this: `ComputeNodeContext` panics via `unimplemented!`.
    fn metrics_reader(&self) -> Arc<dyn MetricsReader>;
}
75
/// Batch task context on compute node.
///
/// Most `BatchTaskContext` methods on this type simply delegate to the wrapped
/// `BatchEnvironment`. Instances are built through `ComputeNodeContext::create` or
/// `ComputeNodeContext::for_test`.
#[derive(Clone)]
pub struct ComputeNodeContext {
    /// Environment that the `BatchTaskContext` implementation delegates to (task manager,
    /// DML manager, state store, client pool, config, source and spill metrics).
    env: BatchEnvironment,

    /// Metrics handle that `batch_metrics()` clones and returns.
    batch_metrics: BatchMetrics,

    /// Memory context passed as the parent of every context produced by
    /// `create_executor_mem_context`.
    mem_context: MemoryContext,
}
85
86impl BatchTaskContext for ComputeNodeContext {
87    fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput> {
88        self.env
89            .task_manager()
90            .take_output(&task_output_id.to_prost())
91    }
92
93    fn catalog_reader(&self) -> SysCatalogReaderRef {
94        unimplemented!("not supported in distributed mode")
95    }
96
97    fn is_local_addr(&self, peer_addr: &HostAddr) -> bool {
98        is_local_address(self.env.server_address(), peer_addr)
99    }
100
101    fn dml_manager(&self) -> DmlManagerRef {
102        self.env.dml_manager_ref()
103    }
104
105    fn state_store(&self) -> StateStoreImpl {
106        self.env.state_store()
107    }
108
109    fn batch_metrics(&self) -> Option<BatchMetrics> {
110        Some(self.batch_metrics.clone())
111    }
112
113    fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
114        self.env.spill_metrics()
115    }
116
117    fn client_pool(&self) -> ComputeClientPoolRef {
118        self.env.client_pool()
119    }
120
121    fn get_config(&self) -> &BatchConfig {
122        self.env.config()
123    }
124
125    fn source_metrics(&self) -> Arc<SourceMetrics> {
126        self.env.source_metrics()
127    }
128
129    fn create_executor_mem_context(&self, _executor_id: &str) -> MemoryContext {
130        let counter = TrAdderAtomic::new(0);
131        MemoryContext::new(Some(self.mem_context.clone()), counter)
132    }
133
134    fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef> {
135        None
136    }
137
138    fn metrics_reader(&self) -> Arc<dyn MetricsReader> {
139        unimplemented!("metrics_reader not supported in compute node context")
140    }
141}
142
143impl ComputeNodeContext {
144    pub fn for_test() -> Arc<dyn BatchTaskContext> {
145        Arc::new(Self {
146            env: BatchEnvironment::for_test(),
147            batch_metrics: BatchMetricsInner::for_test(),
148            mem_context: MemoryContext::none(),
149        })
150    }
151
152    pub fn create(env: BatchEnvironment) -> Arc<dyn BatchTaskContext> {
153        let mem_context = env.task_manager().memory_context_ref();
154        let batch_metrics = Arc::new(BatchMetricsInner::new(
155            env.task_manager().metrics(),
156            env.executor_metrics(),
157            env.iceberg_scan_metrics(),
158        ));
159        Arc::new(Self {
160            env,
161            batch_metrics,
162            mem_context,
163        })
164    }
165}