risingwave_frontend/binder/
bind_param.rs
1use bytes::Bytes;
16use pgwire::types::{Format, FormatIterator};
17use risingwave_common::bail;
18use risingwave_common::error::BoxedError;
19use risingwave_common::types::{Datum, ScalarImpl};
20
21use super::BoundStatement;
22use super::statement::RewriteExprsRecursive;
23use crate::error::{ErrorCode, Result};
24use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal, default_rewrite_expr};
25
26pub(crate) struct ParamRewriter {
28 pub(crate) params: Vec<Option<Bytes>>,
29 pub(crate) parsed_params: Vec<Datum>,
30 pub(crate) param_formats: Vec<Format>,
31 pub(crate) error: Option<BoxedError>,
32}
33
34impl ParamRewriter {
35 pub(crate) fn new(param_formats: Vec<Format>, params: Vec<Option<Bytes>>) -> Self {
36 Self {
37 parsed_params: vec![None; params.len()],
38 params,
39 param_formats,
40 error: None,
41 }
42 }
43}
44
45impl ExprRewriter for ParamRewriter {
46 fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
47 if self.error.is_some() {
48 return expr;
49 }
50 default_rewrite_expr(self, expr)
51 }
52
53 fn rewrite_subquery(&mut self, mut subquery: crate::expr::Subquery) -> ExprImpl {
54 subquery.query.rewrite_exprs_recursive(self);
55 subquery.into()
56 }
57
58 fn rewrite_parameter(&mut self, parameter: crate::expr::Parameter) -> ExprImpl {
59 let data_type = parameter.return_type();
60
61 let parameter_index = (parameter.index - 1) as usize;
64
65 fn cstr_to_str(b: &[u8]) -> std::result::Result<&str, BoxedError> {
66 let without_null = if b.last() == Some(&0) {
67 &b[..b.len() - 1]
68 } else {
69 b
70 };
71 Ok(std::str::from_utf8(without_null)?)
72 }
73
74 let datum: Datum = if let Some(val_bytes) = self.params[parameter_index].clone() {
75 let res = match self.param_formats[parameter_index] {
76 Format::Text => {
77 cstr_to_str(&val_bytes).and_then(|str| ScalarImpl::from_text(str, &data_type))
78 }
79 Format::Binary => ScalarImpl::from_binary(&val_bytes, &data_type),
80 };
81 match res {
82 Ok(datum) => Some(datum),
83 Err(e) => {
84 self.error = Some(e);
85 return parameter.into();
86 }
87 }
88 } else {
89 None
90 };
91
92 self.parsed_params[parameter_index].clone_from(&datum);
93 Literal::new(datum, data_type).into()
94 }
95}
96
97impl BoundStatement {
98 pub fn bind_parameter(
99 mut self,
100 params: Vec<Option<Bytes>>,
101 param_formats: Vec<Format>,
102 ) -> Result<(BoundStatement, Vec<Datum>)> {
103 let mut rewriter = ParamRewriter::new(
104 FormatIterator::new(¶m_formats, params.len())
105 .map_err(ErrorCode::BindError)?
106 .collect(),
107 params,
108 );
109
110 self.rewrite_exprs_recursive(&mut rewriter);
111
112 if let Some(err) = rewriter.error {
113 bail!(err);
114 }
115
116 Ok((self, rewriter.parsed_params))
117 }
118}
119
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use pgwire::types::Format;
    use risingwave_common::types::DataType;
    use risingwave_sqlparser::test_utils::parse_sql_statements;

    use crate::binder::BoundStatement;
    use crate::binder::test_utils::{mock_binder, mock_binder_with_param_types};

    /// Binds the first statement of `sql` with the plain mock binder; used as
    /// the reference result that contains no placeholders.
    fn create_expect_bound(sql: &str) -> BoundStatement {
        let first_stmt = parse_sql_statements(sql).unwrap().remove(0);
        let mut binder = mock_binder();
        binder.bind(first_stmt).unwrap()
    }

    /// Binds the first statement of `sql` with a mock binder that knows
    /// `param_types`, then substitutes `params` decoded per `param_formats`.
    fn create_actual_bound(
        sql: &str,
        param_types: Vec<Option<DataType>>,
        params: Vec<Option<Bytes>>,
        param_formats: Vec<Format>,
    ) -> BoundStatement {
        let first_stmt = parse_sql_statements(sql).unwrap().remove(0);
        let mut binder = mock_binder_with_param_types(param_types);
        let with_placeholders = binder.bind(first_stmt).unwrap();
        let (with_values, _parsed) = with_placeholders
            .bind_parameter(params, param_formats)
            .unwrap();
        with_values
    }

    /// Asserts that both statements have the same `Debug` rendering.
    fn expect_actual_eq(expect: BoundStatement, actual: BoundStatement) {
        let expect_repr = format!("{expect:?}");
        let actual_repr = format!("{actual:?}");
        assert_eq!(expect_repr, actual_repr);
    }

    #[tokio::test]
    async fn basic_select() {
        let expect = create_expect_bound("select 1::int4");
        let actual = create_actual_bound(
            "select $1::int4",
            vec![],
            vec![Some("1".into())],
            vec![Format::Text],
        );
        expect_actual_eq(expect, actual);
    }

    #[tokio::test]
    async fn basic_value() {
        let expect = create_expect_bound("values(1::int4)");
        let actual = create_actual_bound(
            "values($1::int4)",
            vec![],
            vec![Some("1".into())],
            vec![Format::Text],
        );
        expect_actual_eq(expect, actual);
    }

    #[tokio::test]
    async fn default_type() {
        let expect = create_expect_bound("select '1'");
        let actual = create_actual_bound(
            "select $1",
            vec![],
            vec![Some("1".into())],
            vec![Format::Text],
        );
        expect_actual_eq(expect, actual);
    }

    #[tokio::test]
    async fn cast_after_specific() {
        let expect = create_expect_bound("select 1::varchar");
        let actual = create_actual_bound(
            "select $1::varchar",
            vec![Some(DataType::Int32)],
            vec![Some("1".into())],
            vec![Format::Text],
        );
        expect_actual_eq(expect, actual);
    }

    #[tokio::test]
    async fn infer_case() {
        let expect = create_expect_bound("select 1,1::INT4");
        let actual = create_actual_bound(
            "select $1,$1::INT4",
            vec![],
            vec![Some("1".into())],
            vec![Format::Text],
        );
        expect_actual_eq(expect, actual);
    }

    #[tokio::test]
    async fn subquery() {
        let expect = create_expect_bound("select (select '1')");
        let actual = create_actual_bound(
            "select (select $1)",
            vec![],
            vec![Some("1".into())],
            vec![Format::Text],
        );
        expect_actual_eq(expect, actual);
    }
}