risingwave_batch_executors/executor/postgres_query.rs

// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Context;
use futures_async_stream::try_stream;
use futures_util::stream::StreamExt;
use risingwave_common::array::DataChunk;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum, Decimal, ScalarImpl};
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;
use risingwave_connector::connector_common::{SslMode, create_pg_client};
use risingwave_pb::batch_plan::plan_node::NodeBody;
use tokio_postgres;

use crate::error::BatchError;
use crate::executor::{BoxedExecutor, BoxedExecutorBuilder, Executor, ExecutorBuilder};

/// `PostgresQuery` executor. Runs a query against a Postgres database.
pub struct PostgresQueryExecutor {
    schema: Schema,
    params: PostgresConnectionParams,
    query: String,
    identity: String,
    chunk_size: usize,
}

/// Connection parameters for the upstream Postgres database, as provided by the plan node.
pub struct PostgresConnectionParams {
    pub host: String,
    pub port: String,
    pub username: String,
    pub password: String,
    pub database: String,
    pub ssl_mode: SslMode,
    pub ssl_root_cert: Option<String>,
}

impl Executor for PostgresQueryExecutor {
    fn schema(&self) -> &risingwave_common::catalog::Schema {
        &self.schema
    }

    fn identity(&self) -> &str {
        &self.identity
    }

    fn execute(self: Box<Self>) -> super::BoxedDataChunkStream {
        self.do_execute().boxed()
    }
}

/// Converts a `tokio_postgres` row into an `OwnedRow`, mapping each column according to the
/// corresponding field of `schema`.
pub fn postgres_row_to_owned_row(
    row: tokio_postgres::Row,
    schema: &Schema,
) -> Result<OwnedRow, BatchError> {
    let mut datums = vec![];
    for i in 0..schema.fields.len() {
        let rw_field = &schema.fields[i];
        let name = rw_field.name.as_str();
        let datum = postgres_cell_to_scalar_impl(&row, &rw_field.data_type, i, name)?;
        datums.push(datum);
    }
    Ok(OwnedRow::new(datums))
}

// TODO(kwannoel): Support more types, see postgres connector's ScalarAdapter.
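/// Converts a single Postgres cell into a `Datum`. Unsupported data types are logged and mapped
/// to NULL.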
fn postgres_cell_to_scalar_impl(
    row: &tokio_postgres::Row,
    data_type: &DataType,
    i: usize,
    name: &str,
) -> Result<Datum, BatchError> {
    let datum = match data_type {
        DataType::Boolean
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::Float32
        | DataType::Float64
        | DataType::Date
        | DataType::Time
        | DataType::Timestamp
        | DataType::Timestamptz
        | DataType::Jsonb
        | DataType::Interval
        | DataType::Varchar
        | DataType::Bytea => {
            // ScalarAdapter is also fine. But ScalarImpl is more efficient
            row.try_get::<_, Option<ScalarImpl>>(i)?
        }
        DataType::Decimal => {
            // Decimal is more efficient than PgNumeric in ScalarAdapter
            let val = row.try_get::<_, Option<Decimal>>(i)?;
            val.map(ScalarImpl::from)
        }
        _ => {
            tracing::warn!(name, ?data_type, "unsupported data type, set to null");
            None
        }
    };
    Ok(datum)
}

impl PostgresQueryExecutor {
    pub fn new(
        schema: Schema,
        params: PostgresConnectionParams,
        query: String,
        identity: String,
        chunk_size: usize,
    ) -> Self {
        Self {
            schema,
            params,
            query,
            identity,
            chunk_size,
        }
    }

    /// Connects to the upstream Postgres database, runs the query, and streams the result rows
    /// back as `DataChunk`s.
    #[try_stream(ok = DataChunk, error = BatchError)]
    async fn do_execute(self: Box<Self>) {
        tracing::debug!("postgres_query_executor: started");

        let client = create_pg_client(
            &self.params.username,
            &self.params.password,
            &self.params.host,
            &self.params.port,
            &self.params.database,
            &self.params.ssl_mode,
            &self.params.ssl_root_cert,
        )
        .await?;

        // `query_raw` expects an explicit parameter list; the query carries no bind parameters,
        // so pass an empty slice.
        let params: &[&str] = &[];
        let row_stream = client
            .query_raw(&self.query, params)
            .await
            .context("postgres_query received error from remote server")?;
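        // Accumulate rows into chunks of `chunk_size` before yielding them downstream.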
        let mut builder = DataChunkBuilder::new(self.schema.data_types(), self.chunk_size);
        tracing::debug!("postgres_query_executor: query executed, start deserializing rows");
        // deserialize the rows
        #[for_await]
        for row in row_stream {
            let row = row?;
            let owned_row = postgres_row_to_owned_row(row, &self.schema)?;
            if let Some(chunk) = builder.append_one_row(owned_row) {
                yield chunk;
            }
        }
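        // Flush any remaining rows that did not fill a complete chunk.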
        if let Some(chunk) = builder.consume_all() {
            yield chunk;
        }
        return Ok(());
    }
}

/// Builds a `PostgresQueryExecutor` from a `PostgresQuery` plan node.
pub struct PostgresQueryExecutorBuilder {}

impl BoxedExecutorBuilder for PostgresQueryExecutorBuilder {
    async fn new_boxed_executor(
        source: &ExecutorBuilder<'_>,
        _inputs: Vec<BoxedExecutor>,
    ) -> crate::error::Result<BoxedExecutor> {
        let postgres_query_node = try_match_expand!(
            source.plan_node().get_node_body().unwrap(),
            NodeBody::PostgresQuery
        )?;

        Ok(Box::new(PostgresQueryExecutor::new(
            Schema::from_iter(postgres_query_node.columns.iter().map(Field::from)),
            PostgresConnectionParams {
                host: postgres_query_node.hostname.clone(),
                port: postgres_query_node.port.clone(),
                username: postgres_query_node.username.clone(),
                password: postgres_query_node.password.clone(),
                database: postgres_query_node.database.clone(),
                ssl_mode: postgres_query_node.ssl_mode.parse().unwrap_or_default(),
                ssl_root_cert: if postgres_query_node.ssl_root_cert.is_empty() {
                    None
                } else {
                    Some(postgres_query_node.ssl_root_cert.clone())
                },
            },
            postgres_query_node.query.clone(),
            source.plan_node().get_identity().clone(),
            source.context().get_config().developer.chunk_size,
        )))
    }
}