risingwave_connector/parser/
sql_server.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::str::FromStr;
17use std::sync::LazyLock;
18
19use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
20use risingwave_common::catalog::Schema;
21use risingwave_common::log::LogSuppressor;
22use risingwave_common::row::OwnedRow;
23use risingwave_common::types::{DataType, Date, Decimal, ScalarImpl, Time, Timestamp, Timestamptz};
24use rust_decimal::Decimal as RustDecimal;
25use thiserror_ext::AsReport;
26use tiberius::Row;
27use tiberius::xml::XmlData;
28use uuid::Uuid;
29
30use crate::parser::utils::log_error;
31
/// Module-wide [`LogSuppressor`], lazily built with `LogSuppressor::default` on first use.
///
/// NOTE(review): nothing in this file names it directly; it is presumably referenced by the
/// `log_error!` macro (imported from `crate::parser::utils`) to throttle repeated
/// "parse column failed" logs — confirm against that macro's definition.
static LOG_SUPPRESSOR: LazyLock<LogSuppressor> = LazyLock::new(LogSuppressor::default);
33
34pub fn sql_server_row_to_owned_row(row: &mut Row, schema: &Schema) -> OwnedRow {
35    let mut datums: Vec<Option<ScalarImpl>> = vec![];
36    let mut money_fields: HashSet<&str> = HashSet::new();
37    // Special handling of the money field, as the third-party library Tiberius converts the money type to i64.
38    for (column, _) in row.cells() {
39        if column.column_type() == tiberius::ColumnType::Money {
40            money_fields.insert(column.name());
41        }
42    }
43    for i in 0..schema.fields.len() {
44        let rw_field = &schema.fields[i];
45        let name = rw_field.name.as_str();
46        let datum = match money_fields.contains(name) {
47            true => match row.try_get::<i64, usize>(i) {
48                Ok(Some(value)) => Some(convert_money_i64_to_type(value, &rw_field.data_type)),
49                Ok(None) => None,
50                Err(err) => {
51                    log_error!(name, err, "parse column failed");
52                    None
53                }
54            },
55            false => match row.try_get::<ScalarImplTiberiusWrapper, usize>(i) {
56                Ok(datum) => datum
57                    .map(|d| d.0)
58                    .map(|scalar| coerce_scalar_to_target_type(scalar, &rw_field.data_type)),
59                Err(err) => {
60                    log_error!(name, err, "parse column failed");
61                    None
62                }
63            },
64        };
65
66        datums.push(datum);
67    }
68    OwnedRow::new(datums)
69}
70
71fn coerce_scalar_to_target_type(scalar: ScalarImpl, target_type: &DataType) -> ScalarImpl {
72    match (scalar, target_type) {
73        // SQL Server validator allows integer upcast (e.g. `int` -> `BIGINT`).
74        // Coerce snapshot values to the target RW type to keep validation and execution consistent.
75        (ScalarImpl::Int16(v), DataType::Int32) => ScalarImpl::Int32(v as i32),
76        (ScalarImpl::Int16(v), DataType::Int64) => ScalarImpl::Int64(v as i64),
77        (ScalarImpl::Int32(v), DataType::Int64) => ScalarImpl::Int64(v as i64),
78        // SQL Server `real` may map to `FLOAT` in RW validator.
79        (ScalarImpl::Float32(v), DataType::Float64) => ScalarImpl::Float64((v.0 as f64).into()),
80        (scalar, _) => scalar,
81    }
82}
83
84pub fn convert_money_i64_to_type(value: i64, data_type: &DataType) -> ScalarImpl {
85    match data_type {
86        DataType::Decimal => {
87            ScalarImpl::Decimal(Decimal::from(value) / Decimal::from_str("10000").unwrap())
88        }
89        _ => {
90            panic!(
91                "Conversion of Money type to {:?} is not supported",
92                data_type
93            );
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use risingwave_common::types::F32;

    use super::*;

    /// Narrower integers are widened to the requested wider integer type.
    #[test]
    fn test_integer_upcast_coercion() {
        // (decoded scalar, target type, expected coerced scalar)
        let cases = [
            (ScalarImpl::Int32(7), DataType::Int64, ScalarImpl::Int64(7)),
            (ScalarImpl::Int16(7), DataType::Int32, ScalarImpl::Int32(7)),
            (ScalarImpl::Int16(7), DataType::Int64, ScalarImpl::Int64(7)),
        ];
        for (input, target, expected) in cases {
            assert_eq!(coerce_scalar_to_target_type(input, &target), expected);
        }
    }

    /// `real` values are widened when the target is `FLOAT`.
    #[test]
    fn test_float_upcast_coercion() {
        let widened = coerce_scalar_to_target_type(
            ScalarImpl::Float32(F32::from(1.25)),
            &DataType::Float64,
        );
        assert_eq!(widened, ScalarImpl::Float64(1.25.into()));
    }

    /// A value already of the target type passes through unchanged.
    #[test]
    fn test_non_upcast_keeps_original() {
        let same_width = coerce_scalar_to_target_type(ScalarImpl::Int32(7), &DataType::Int32);
        assert_eq!(same_width, ScalarImpl::Int32(7));
    }
}
129macro_rules! impl_tiberius_wrapper {
130    ($wrapper_name:ident, $variant_name:ident) => {
131        pub struct $wrapper_name($variant_name);
132
133        impl From<$variant_name> for $wrapper_name {
134            fn from(value: $variant_name) -> Self {
135                Self(value)
136            }
137        }
138    };
139}
140
141impl_tiberius_wrapper!(ScalarImplTiberiusWrapper, ScalarImpl);
142impl_tiberius_wrapper!(TimeTiberiusWrapper, Time);
143impl_tiberius_wrapper!(DateTiberiusWrapper, Date);
144impl_tiberius_wrapper!(TimestampTiberiusWrapper, Timestamp);
145impl_tiberius_wrapper!(TimestamptzTiberiusWrapper, Timestamptz);
146impl_tiberius_wrapper!(DecimalTiberiusWrapper, Decimal);
147
148macro_rules! impl_chrono_tiberius_wrapper {
149    ($wrapper_name:ident, $variant_name:ident, $chrono:ty) => {
150        impl<'a> tiberius::IntoSql<'a> for $wrapper_name {
151            fn into_sql(self) -> tiberius::ColumnData<'a> {
152                self.0.0.into_sql()
153            }
154        }
155
156        impl<'a> tiberius::FromSql<'a> for $wrapper_name {
157            fn from_sql(
158                value: &'a tiberius::ColumnData<'static>,
159            ) -> tiberius::Result<Option<Self>> {
160                let instant = <$chrono>::from_sql(value)?;
161                let time = instant.map($variant_name::from).map($wrapper_name::from);
162                tiberius::Result::Ok(time)
163            }
164        }
165    };
166}
167
168impl_chrono_tiberius_wrapper!(TimeTiberiusWrapper, Time, NaiveTime);
169impl_chrono_tiberius_wrapper!(DateTiberiusWrapper, Date, NaiveDate);
170impl_chrono_tiberius_wrapper!(TimestampTiberiusWrapper, Timestamp, NaiveDateTime);
171
172impl<'a> tiberius::IntoSql<'a> for DecimalTiberiusWrapper {
173    fn into_sql(self) -> tiberius::ColumnData<'a> {
174        match self.0 {
175            Decimal::Normalized(d) => d.into_sql(),
176            Decimal::NaN => tiberius::ColumnData::Numeric(None),
177            Decimal::PositiveInf => tiberius::ColumnData::Numeric(None),
178            Decimal::NegativeInf => tiberius::ColumnData::Numeric(None),
179        }
180    }
181}
182
183impl<'a> tiberius::FromSql<'a> for DecimalTiberiusWrapper {
184    // TODO(kexiang): will sql server have inf/-inf/nan for decimal?
185    fn from_sql(value: &'a tiberius::ColumnData<'static>) -> tiberius::Result<Option<Self>> {
186        tiberius::Result::Ok(
187            RustDecimal::from_sql(value)?
188                .map(Decimal::Normalized)
189                .map(DecimalTiberiusWrapper::from),
190        )
191    }
192}
193
194impl<'a> tiberius::IntoSql<'a> for TimestamptzTiberiusWrapper {
195    fn into_sql(self) -> tiberius::ColumnData<'a> {
196        self.0.to_datetime_utc().into_sql()
197    }
198}
199
200impl<'a> tiberius::FromSql<'a> for TimestamptzTiberiusWrapper {
201    fn from_sql(value: &'a tiberius::ColumnData<'static>) -> tiberius::Result<Option<Self>> {
202        let instant = DateTime::<Utc>::from_sql(value)?;
203        let time = instant
204            .map(Timestamptz::from)
205            .map(TimestamptzTiberiusWrapper::from);
206        tiberius::Result::Ok(time)
207    }
208}
209
210/// The following table shows the mapping between Rust types and Sql Server types in tiberius.
211/// |Rust Type|Sql Server Type|
212/// |`u8`|`tinyint`|
213/// |`i16`|`smallint`|
214/// |`i32`|`int`|
215/// |`i64`|`bigint`|
216/// |`f32`|`float(24)`|
217/// |`f64`|`float(53)`|
218/// |`bool`|`bit`|
219/// |`String`/`&str`|`nvarchar`/`varchar`/`nchar`/`char`/`ntext`/`text`|
220/// |`Vec<u8>`/`&[u8]`|`binary`/`varbinary`/`image`|
221/// |[`Uuid`]|`uniqueidentifier`|
222/// |[`Numeric`]|`numeric`/`decimal`|
223/// |[`Decimal`] (with feature flag `rust_decimal`)|`numeric`/`decimal`|
224/// |[`XmlData`]|`xml`|
225/// |[`NaiveDateTime`] (with feature flag `chrono`)|`datetime`/`datetime2`/`smalldatetime`|
226/// |[`NaiveDate`] (with feature flag `chrono`)|`date`|
227/// |[`NaiveTime`] (with feature flag `chrono`)|`time`|
228/// |[`DateTime`] (with feature flag `chrono`)|`datetimeoffset`|
229///
230/// See the [`time`] module for more information about the date and time structs.
231///
232/// [`Row#get`]: struct.Row.html#method.get
233/// [`Row#try_get`]: struct.Row.html#method.try_get
234/// [`time`]: time/index.html
235/// [`Uuid`]: struct.Uuid.html
236/// [`Numeric`]: numeric/struct.Numeric.html
237/// [`Decimal`]: numeric/struct.Decimal.html
238/// [`XmlData`]: xml/struct.XmlData.html
239/// [`NaiveDateTime`]: time/chrono/struct.NaiveDateTime.html
240/// [`NaiveDate`]: time/chrono/struct.NaiveDate.html
241/// [`NaiveTime`]: time/chrono/struct.NaiveTime.html
242/// [`DateTime`]: time/chrono/struct.DateTime.html
243impl<'a> tiberius::FromSql<'a> for ScalarImplTiberiusWrapper {
244    fn from_sql(value: &'a tiberius::ColumnData<'static>) -> tiberius::Result<Option<Self>> {
245        Ok(match &value {
246            tiberius::ColumnData::U8(_) => u8::from_sql(value)?
247                .map(|v| ScalarImplTiberiusWrapper::from(ScalarImpl::from(v as i16))),
248            tiberius::ColumnData::I16(_) => i16::from_sql(value)?
249                .map(ScalarImpl::from)
250                .map(ScalarImplTiberiusWrapper::from),
251            tiberius::ColumnData::I32(_) => i32::from_sql(value)?
252                .map(ScalarImpl::from)
253                .map(ScalarImplTiberiusWrapper::from),
254            tiberius::ColumnData::I64(_) => i64::from_sql(value)?
255                .map(ScalarImpl::from)
256                .map(ScalarImplTiberiusWrapper::from),
257            tiberius::ColumnData::F32(_) => f32::from_sql(value)?
258                .map(ScalarImpl::from)
259                .map(ScalarImplTiberiusWrapper::from),
260            tiberius::ColumnData::F64(_) => f64::from_sql(value)?
261                .map(ScalarImpl::from)
262                .map(ScalarImplTiberiusWrapper::from),
263            tiberius::ColumnData::Bit(_) => bool::from_sql(value)?
264                .map(ScalarImpl::from)
265                .map(ScalarImplTiberiusWrapper::from),
266            tiberius::ColumnData::String(_) => <&str>::from_sql(value)?
267                .map(ScalarImpl::from)
268                .map(ScalarImplTiberiusWrapper::from),
269            tiberius::ColumnData::Numeric(_) => DecimalTiberiusWrapper::from_sql(value)?
270                .map(|w| ScalarImpl::from(w.0))
271                .map(ScalarImplTiberiusWrapper::from),
272            tiberius::ColumnData::DateTime(_)
273            | tiberius::ColumnData::DateTime2(_)
274            | tiberius::ColumnData::SmallDateTime(_) => TimestampTiberiusWrapper::from_sql(value)?
275                .map(|w| ScalarImpl::from(w.0))
276                .map(ScalarImplTiberiusWrapper::from),
277            tiberius::ColumnData::Time(_) => TimeTiberiusWrapper::from_sql(value)?
278                .map(|w| ScalarImpl::from(w.0))
279                .map(ScalarImplTiberiusWrapper::from),
280            tiberius::ColumnData::Date(_) => DateTiberiusWrapper::from_sql(value)?
281                .map(|w| ScalarImpl::from(w.0))
282                .map(ScalarImplTiberiusWrapper::from),
283            tiberius::ColumnData::DateTimeOffset(_) => TimestamptzTiberiusWrapper::from_sql(value)?
284                .map(|w| ScalarImpl::from(w.0))
285                .map(ScalarImplTiberiusWrapper::from),
286            tiberius::ColumnData::Binary(_) => <&[u8]>::from_sql(value)?
287                .map(ScalarImpl::from)
288                .map(ScalarImplTiberiusWrapper::from),
289            tiberius::ColumnData::Guid(_) => <Uuid>::from_sql(value)?
290                .map(|uuid| uuid.to_string().to_uppercase())
291                .map(ScalarImpl::from)
292                .map(ScalarImplTiberiusWrapper::from),
293            tiberius::ColumnData::Xml(_) => <&XmlData>::from_sql(value)?
294                .map(|xml| xml.clone().into_string())
295                .map(ScalarImpl::from)
296                .map(ScalarImplTiberiusWrapper::from),
297        })
298    }
299}
300
301/// The following table shows the mapping between Rust types and Sql Server types in tiberius.
302/// |Rust type|Sql Server type|
303/// |--------|--------|
304/// |`u8`|`tinyint`|
305/// |`i16`|`smallint`|
306/// |`i32`|`int`|
307/// |`i64`|`bigint`|
308/// |`f32`|`float(24)`|
309/// |`f64`|`float(53)`|
310/// |`bool`|`bit`|
311/// |`String`/`&str` (< 4000 characters)|`nvarchar(4000)`|
312/// |`String`/`&str`|`nvarchar(max)`|
313/// |`Vec<u8>`/`&[u8]` (< 8000 bytes)|`varbinary(8000)`|
314/// |`Vec<u8>`/`&[u8]`|`varbinary(max)`|
315/// |[`Uuid`]|`uniqueidentifier`|
316/// |[`Numeric`]|`numeric`/`decimal`|
317/// |[`Decimal`] (with feature flag `rust_decimal`)|`numeric`/`decimal`|
318/// |[`BigDecimal`] (with feature flag `bigdecimal`)|`numeric`/`decimal`|
319/// |[`XmlData`]|`xml`|
320/// |[`NaiveDate`] (with `chrono` feature, TDS 7.3 >)|`date`|
321/// |[`NaiveTime`] (with `chrono` feature, TDS 7.3 >)|`time`|
322/// |[`DateTime`] (with `chrono` feature, TDS 7.3 >)|`datetimeoffset`|
323/// |[`NaiveDateTime`] (with `chrono` feature, TDS 7.3 >)|`datetime2`|
324/// |[`NaiveDateTime`] (with `chrono` feature, TDS 7.2)|`datetime`|
325///
326/// It is possible to use some of the types to write into columns that are not
327/// of the same type. For example on systems following the TDS 7.3 standard (SQL
328/// Server 2008 and later), the chrono type `NaiveDateTime` can also be used to
329/// write to `datetime`, `datetime2` and `smalldatetime` columns. All string
330/// types can also be used with `ntext`, `text`, `varchar`, `nchar` and `char`
331/// columns. All binary types can also be used with `binary` and `image`
332/// columns.
333///
334/// See the [`time`] module for more information about the date and time structs.
335///
336/// [`Client#query`]: struct.Client.html#method.query
337/// [`Client#execute`]: struct.Client.html#method.execute
338/// [`time`]: time/index.html
339/// [`Uuid`]: struct.Uuid.html
340/// [`Numeric`]: numeric/struct.Numeric.html
341/// [`Decimal`]: numeric/struct.Decimal.html
342/// [`BigDecimal`]: numeric/struct.BigDecimal.html
343/// [`XmlData`]: xml/struct.XmlData.html
344/// [`NaiveDateTime`]: time/chrono/struct.NaiveDateTime.html
345/// [`NaiveDate`]: time/chrono/struct.NaiveDate.html
346/// [`NaiveTime`]: time/chrono/struct.NaiveTime.html
347/// [`DateTime`]: time/chrono/struct.DateTime.html
348impl<'a> tiberius::IntoSql<'a> for ScalarImplTiberiusWrapper {
349    fn into_sql(self) -> tiberius::ColumnData<'a> {
350        match self.0 {
351            ScalarImpl::Int16(v) => v.into_sql(),
352            ScalarImpl::Int32(v) => v.into_sql(),
353            ScalarImpl::Int64(v) => v.into_sql(),
354            ScalarImpl::Float32(v) => v.0.into_sql(),
355            ScalarImpl::Float64(v) => v.0.into_sql(),
356            ScalarImpl::Bool(v) => v.into_sql(),
357            ScalarImpl::Decimal(v) => DecimalTiberiusWrapper::from(v).into_sql(),
358            ScalarImpl::Date(v) => DateTiberiusWrapper::from(v).into_sql(),
359            ScalarImpl::Timestamp(v) => TimestampTiberiusWrapper::from(v).into_sql(),
360            ScalarImpl::Timestamptz(v) => TimestamptzTiberiusWrapper::from(v).into_sql(),
361            ScalarImpl::Time(v) => TimeTiberiusWrapper::from(v).into_sql(),
362            ScalarImpl::Bytea(v) => {
363                let value: Vec<u8> = (*v).to_vec();
364                value.into_sql()
365            }
366            ScalarImpl::Utf8(v) => String::from(v).into_sql(),
367            value => {
368                // Serial, Interval, Jsonb, Int256, Struct, List are not supported yet
369                unimplemented!("the sql server decoding for {:?} is unsupported", value);
370            }
371        }
372    }
373}