risingwave_batch/task/
context.rs1use std::sync::Arc;
15
16use prometheus::core::Atomic;
17use risingwave_common::catalog::SysCatalogReaderRef;
18use risingwave_common::config::BatchConfig;
19use risingwave_common::memory::MemoryContext;
20use risingwave_common::metrics::TrAdderAtomic;
21use risingwave_common::metrics_reader::MetricsReader;
22use risingwave_common::util::addr::{HostAddr, is_local_address};
23use risingwave_connector::source::monitor::SourceMetrics;
24use risingwave_dml::dml_manager::DmlManagerRef;
25use risingwave_rpc_client::ComputeClientPoolRef;
26use risingwave_storage::StateStoreImpl;
27
28use crate::error::Result;
29use crate::monitor::{BatchMetrics, BatchMetricsInner, BatchSpillMetrics};
30use crate::task::{BatchEnvironment, TaskOutput, TaskOutputId};
31use crate::worker_manager::worker_node_manager::WorkerNodeManagerRef;
32
33pub trait BatchTaskContext: Send + Sync + 'static {
37 fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput>;
41
42 fn catalog_reader(&self) -> SysCatalogReaderRef;
44
45 fn is_local_addr(&self, peer_addr: &HostAddr) -> bool;
47
48 fn dml_manager(&self) -> DmlManagerRef;
49
50 fn state_store(&self) -> StateStoreImpl;
51
52 fn batch_metrics(&self) -> Option<BatchMetrics>;
55
56 fn spill_metrics(&self) -> Arc<BatchSpillMetrics>;
57
58 fn client_pool(&self) -> ComputeClientPoolRef;
61
62 fn get_config(&self) -> &BatchConfig;
64
65 fn source_metrics(&self) -> Arc<SourceMetrics>;
66
67 fn create_executor_mem_context(&self, executor_id: &str) -> MemoryContext;
68
69 fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef>;
70
71 fn metrics_reader(&self) -> Arc<dyn MetricsReader>;
73}
74
/// [`BatchTaskContext`] implementation used by the compute node.
///
/// Bundles the node's [`BatchEnvironment`], the batch metrics it reports, and
/// the memory context under which per-executor memory contexts are created.
// Field order is kept as-is on purpose: it determines drop order and the order
// in which the derived `Clone` clones each field.
#[derive(Clone)]
pub struct ComputeNodeContext {
    /// Shared batch environment (task manager, state store, client pool, config, ...).
    env: BatchEnvironment,

    /// Metrics handle returned (wrapped in `Some`) by `batch_metrics()`.
    batch_metrics: BatchMetrics,

    /// Parent memory context for the contexts built by `create_executor_mem_context`.
    mem_context: MemoryContext,
}
84
85impl BatchTaskContext for ComputeNodeContext {
86 fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput> {
87 self.env
88 .task_manager()
89 .take_output(&task_output_id.to_prost())
90 }
91
92 fn catalog_reader(&self) -> SysCatalogReaderRef {
93 unimplemented!("not supported in distributed mode")
94 }
95
96 fn is_local_addr(&self, peer_addr: &HostAddr) -> bool {
97 is_local_address(self.env.server_address(), peer_addr)
98 }
99
100 fn dml_manager(&self) -> DmlManagerRef {
101 self.env.dml_manager_ref()
102 }
103
104 fn state_store(&self) -> StateStoreImpl {
105 self.env.state_store()
106 }
107
108 fn batch_metrics(&self) -> Option<BatchMetrics> {
109 Some(self.batch_metrics.clone())
110 }
111
112 fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
113 self.env.spill_metrics()
114 }
115
116 fn client_pool(&self) -> ComputeClientPoolRef {
117 self.env.client_pool()
118 }
119
120 fn get_config(&self) -> &BatchConfig {
121 self.env.config()
122 }
123
124 fn source_metrics(&self) -> Arc<SourceMetrics> {
125 self.env.source_metrics()
126 }
127
128 fn create_executor_mem_context(&self, _executor_id: &str) -> MemoryContext {
129 let counter = TrAdderAtomic::new(0);
130 MemoryContext::new(Some(self.mem_context.clone()), counter)
131 }
132
133 fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef> {
134 None
135 }
136
137 fn metrics_reader(&self) -> Arc<dyn MetricsReader> {
138 unimplemented!("metrics_reader not supported in compute node context")
139 }
140}
141
142impl ComputeNodeContext {
143 pub fn for_test() -> Arc<dyn BatchTaskContext> {
144 Arc::new(Self {
145 env: BatchEnvironment::for_test(),
146 batch_metrics: BatchMetricsInner::for_test(),
147 mem_context: MemoryContext::none(),
148 })
149 }
150
151 pub fn create(env: BatchEnvironment) -> Arc<dyn BatchTaskContext> {
152 let mem_context = env.task_manager().memory_context_ref();
153 let batch_metrics = Arc::new(BatchMetricsInner::new(
154 env.task_manager().metrics(),
155 env.executor_metrics(),
156 env.iceberg_scan_metrics(),
157 ));
158 Arc::new(Self {
159 env,
160 batch_metrics,
161 mem_context,
162 })
163 }
164}