// risingwave_batch/task/context.rs

use std::sync::Arc;
use prometheus::core::Atomic;
use risingwave_common::catalog::SysCatalogReaderRef;
use risingwave_common::config::BatchConfig;
use risingwave_common::memory::MemoryContext;
use risingwave_common::metrics::TrAdderAtomic;
use risingwave_common::util::addr::{is_local_address, HostAddr};
use risingwave_connector::source::monitor::SourceMetrics;
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_rpc_client::ComputeClientPoolRef;
use risingwave_storage::StateStoreImpl;
use crate::error::Result;
use crate::monitor::{BatchMetrics, BatchMetricsInner, BatchSpillMetrics};
use crate::task::{BatchEnvironment, TaskOutput, TaskOutputId};
use crate::worker_manager::worker_node_manager::WorkerNodeManagerRef;
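
/// Context passed to batch task executors. It gives executors access to the
/// services they need at runtime: task outputs, the system catalog reader, the
/// DML manager, the state store, metrics, the RPC client pool, and the batch
/// configuration.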
pub trait BatchTaskContext: Clone + Send + Sync + 'static {
    /// Gets the output of the task identified by `task_output_id`, which must be running on
    /// this worker node.
    fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput>;
    /// Reader for the system catalog; not available on compute nodes.
    fn catalog_reader(&self) -> SysCatalogReaderRef;
    /// Whether `peer_addr` refers to this node.
    fn is_local_addr(&self, peer_addr: &HostAddr) -> bool;
    fn dml_manager(&self) -> DmlManagerRef;
    fn state_store(&self) -> StateStoreImpl;
    fn batch_metrics(&self) -> Option<BatchMetrics>;
    fn spill_metrics(&self) -> Arc<BatchSpillMetrics>;
    fn client_pool(&self) -> ComputeClientPoolRef;
    fn get_config(&self) -> &BatchConfig;
    fn source_metrics(&self) -> Arc<SourceMetrics>;
    /// Creates a memory context for an executor, as a child of the task-level context.
    fn create_executor_mem_context(&self, executor_id: &str) -> MemoryContext;
    /// Worker node manager; the compute-node implementation returns `None`.
    fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef>;
}
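
/// [`BatchTaskContext`] implementation used for batch tasks running on a compute node.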
#[derive(Clone)]
pub struct ComputeNodeContext {
    env: BatchEnvironment,
    batch_metrics: BatchMetrics,
    mem_context: MemoryContext,
}

impl BatchTaskContext for ComputeNodeContext {
    fn get_task_output(&self, task_output_id: TaskOutputId) -> Result<TaskOutput> {
        self.env
            .task_manager()
            .take_output(&task_output_id.to_prost())
    }

    fn catalog_reader(&self) -> SysCatalogReaderRef {
        unimplemented!("not supported in distributed mode")
    }

    fn is_local_addr(&self, peer_addr: &HostAddr) -> bool {
        is_local_address(self.env.server_address(), peer_addr)
    }

    fn dml_manager(&self) -> DmlManagerRef {
        self.env.dml_manager_ref()
    }

    fn state_store(&self) -> StateStoreImpl {
        self.env.state_store()
    }

    fn batch_metrics(&self) -> Option<BatchMetrics> {
        Some(self.batch_metrics.clone())
    }

    fn spill_metrics(&self) -> Arc<BatchSpillMetrics> {
        self.env.spill_metrics()
    }

    fn client_pool(&self) -> ComputeClientPoolRef {
        self.env.client_pool()
    }

    fn get_config(&self) -> &BatchConfig {
        self.env.config()
    }

    fn source_metrics(&self) -> Arc<SourceMetrics> {
        self.env.source_metrics()
    }

    fn create_executor_mem_context(&self, _executor_id: &str) -> MemoryContext {
        // Each executor gets its own counter, parented to the task-level memory context.
        let counter = TrAdderAtomic::new(0);
        MemoryContext::new(Some(self.mem_context.clone()), counter)
    }

    fn worker_node_manager(&self) -> Option<WorkerNodeManagerRef> {
        None
    }
}

impl ComputeNodeContext {
    #[cfg(test)]
    pub fn for_test() -> Self {
        Self {
            env: BatchEnvironment::for_test(),
            batch_metrics: BatchMetricsInner::for_test(),
            mem_context: MemoryContext::none(),
        }
    }

    /// Creates a context from the compute node's batch environment, wiring up
    /// task-level metrics and the task-level memory context.
    pub fn new(env: BatchEnvironment) -> Self {
        let mem_context = env.task_manager().memory_context_ref();
        let batch_metrics = Arc::new(BatchMetricsInner::new(
            env.task_manager().metrics(),
            env.executor_metrics(),
            env.iceberg_scan_metrics(),
        ));
        Self {
            env,
            batch_metrics,
            mem_context,
        }
    }
}
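
// A minimal usage sketch (not part of the original file): exercises the context
// through the `BatchTaskContext` trait using the test constructors above. The
// executor id "executor-1" is an arbitrary illustrative value.
#[cfg(test)]
mod context_usage_sketch {
    use super::*;

    #[test]
    fn compute_node_context_smoke() {
        let context = ComputeNodeContext::for_test();
        // Each executor derives its own memory context from the task-level one.
        let _mem_context = context.create_executor_mem_context("executor-1");
        // The compute-node context exposes batch metrics but no worker node manager.
        assert!(context.batch_metrics().is_some());
        assert!(context.worker_node_manager().is_none());
    }
}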