risingwave_batch/task/
context.rs1use std::sync::Arc;
16
17use prometheus::core::Atomic;
18use risingwave_common::catalog::SysCatalogReaderRef;
19use risingwave_common::config::BatchConfig;
20use risingwave_common::memory::MemoryContext;
21use risingwave_common::metrics::TrAdderAtomic;
22use risingwave_common::metrics_reader::MetricsReader;
23use risingwave_common::util::addr::{HostAddr, is_local_address};
24use risingwave_connector::source::monitor::SourceMetrics;
25use risingwave_dml::dml_manager::DmlManagerRef;
26use risingwave_rpc_client::ComputeClientPoolRef;
27use risingwave_storage::StateStoreImpl;
28
29use crate::error::Result;
30use crate::monitor::{BatchMetrics, BatchMetricsInner, BatchSpillMetrics};
31use crate::task::{BatchEnvironment, TaskOutput, TaskOutputId};
32use crate::worker_manager::worker_node_manager::WorkerNodeManagerRef;
33
34pub trait BatchTaskContext: Send + Sync + 'static {
38 fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput>;
42
43 fn catalog_reader(&self) -> SysCatalogReaderRef;
45
46 fn is_local_addr(&self, peer_addr: &HostAddr) -> bool;
48
49 fn dml_manager(&self) -> DmlManagerRef;
50
51 fn state_store(&self) -> StateStoreImpl;
52
53 fn batch_metrics(&self) -> Option<BatchMetrics>;
56
57 fn spill_metrics(&self) -> Arc<BatchSpillMetrics>;
58
59 fn client_pool(&self) -> ComputeClientPoolRef;
62
63 fn get_config(&self) -> &BatchConfig;
65
66 fn source_metrics(&self) -> Arc<SourceMetrics>;
67
68 fn create_executor_mem_context(&self, executor_id: &str) -> MemoryContext;
69
70 fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef>;
71
72 fn metrics_reader(&self) -> Arc<dyn MetricsReader>;
74}
75
/// Batch task context backed by a compute node's [`BatchEnvironment`].
///
/// NOTE(review): field declaration order determines drop order and the order in
/// which the derived `Clone` clones fields; keep it unless that is known not to matter.
#[derive(Clone)]
pub struct ComputeNodeContext {
    /// Environment that most `BatchTaskContext` methods delegate to.
    env: BatchEnvironment,

    /// Metrics handle; a clone of it is returned by `batch_metrics()`.
    batch_metrics: BatchMetrics,

    /// Task-level memory context. It is cloned into every executor memory context
    /// built by `create_executor_mem_context`, presumably as the parent (first
    /// argument of `MemoryContext::new`) — confirm against that constructor.
    mem_context: MemoryContext,
}
85
86impl BatchTaskContext for ComputeNodeContext {
87 fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput> {
88 self.env
89 .task_manager()
90 .take_output(&task_output_id.to_prost())
91 }
92
93 fn catalog_reader(&self) -> SysCatalogReaderRef {
94 unimplemented!("not supported in distributed mode")
95 }
96
97 fn is_local_addr(&self, peer_addr: &HostAddr) -> bool {
98 is_local_address(self.env.server_address(), peer_addr)
99 }
100
101 fn dml_manager(&self) -> DmlManagerRef {
102 self.env.dml_manager_ref()
103 }
104
105 fn state_store(&self) -> StateStoreImpl {
106 self.env.state_store()
107 }
108
109 fn batch_metrics(&self) -> Option<BatchMetrics> {
110 Some(self.batch_metrics.clone())
111 }
112
113 fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
114 self.env.spill_metrics()
115 }
116
117 fn client_pool(&self) -> ComputeClientPoolRef {
118 self.env.client_pool()
119 }
120
121 fn get_config(&self) -> &BatchConfig {
122 self.env.config()
123 }
124
125 fn source_metrics(&self) -> Arc<SourceMetrics> {
126 self.env.source_metrics()
127 }
128
129 fn create_executor_mem_context(&self, _executor_id: &str) -> MemoryContext {
130 let counter = TrAdderAtomic::new(0);
131 MemoryContext::new(Some(self.mem_context.clone()), counter)
132 }
133
134 fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef> {
135 None
136 }
137
138 fn metrics_reader(&self) -> Arc<dyn MetricsReader> {
139 unimplemented!("metrics_reader not supported in compute node context")
140 }
141}
142
143impl ComputeNodeContext {
144 pub fn for_test() -> Arc<dyn BatchTaskContext> {
145 Arc::new(Self {
146 env: BatchEnvironment::for_test(),
147 batch_metrics: BatchMetricsInner::for_test(),
148 mem_context: MemoryContext::none(),
149 })
150 }
151
152 pub fn create(env: BatchEnvironment) -> Arc<dyn BatchTaskContext> {
153 let mem_context = env.task_manager().memory_context_ref();
154 let batch_metrics = Arc::new(BatchMetricsInner::new(
155 env.task_manager().metrics(),
156 env.executor_metrics(),
157 env.iceberg_scan_metrics(),
158 ));
159 Arc::new(Self {
160 env,
161 batch_metrics,
162 mem_context,
163 })
164 }
165}