risingwave_connector/source/cdc/external/
sql_server.rsuse std::cmp::Ordering;
use anyhow::{anyhow, Context};
use futures::stream::BoxStream;
use futures::{pin_mut, StreamExt, TryStreamExt};
use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::bail;
use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, ScalarImpl};
use serde_derive::{Deserialize, Serialize};
use tiberius::{Config, Query, QueryItem};
use crate::error::{ConnectorError, ConnectorResult};
use crate::parser::{sql_server_row_to_owned_row, ScalarImplTiberiusWrapper};
use crate::sink::sqlserver::SqlServerClient;
use crate::source::cdc::external::{
CdcOffset, CdcOffsetParseFunc, DebeziumOffset, ExternalTableConfig, ExternalTableReader,
SchemaTableName,
};
const MAX_COMMIT_LSN: &str = "ffffffff:ffffffff:ffff";
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct SqlServerOffset {
pub change_lsn: String,
pub commit_lsn: String,
}
impl PartialOrd for SqlServerOffset {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.change_lsn.partial_cmp(&other.change_lsn) {
Some(Ordering::Equal) => self.commit_lsn.partial_cmp(&other.commit_lsn),
other => other,
}
}
}
impl SqlServerOffset {
pub fn parse_debezium_offset(offset: &str) -> ConnectorResult<Self> {
let dbz_offset: DebeziumOffset = serde_json::from_str(offset)
.with_context(|| format!("invalid upstream offset: {}", offset))?;
Ok(Self {
change_lsn: dbz_offset
.source_offset
.change_lsn
.context("invalid sql server change_lsn")?,
commit_lsn: dbz_offset
.source_offset
.commit_lsn
.context("invalid sql server commit_lsn")?,
})
}
}
pub struct SqlServerExternalTable {
column_descs: Vec<ColumnDesc>,
pk_names: Vec<String>,
}
impl SqlServerExternalTable {
pub async fn connect(config: ExternalTableConfig) -> ConnectorResult<Self> {
tracing::debug!("connect to sql server");
let mut client_config = Config::new();
client_config.host(&config.host);
client_config.database(&config.database);
client_config.port(config.port.parse::<u16>().unwrap());
client_config.authentication(tiberius::AuthMethod::sql_server(
&config.username,
&config.password,
));
if config.encrypt == "true" {
client_config.encryption(tiberius::EncryptionLevel::Required);
}
client_config.trust_cert();
let mut client = SqlServerClient::new_with_config(client_config).await?;
let mut column_descs = vec![];
let mut pk_names = vec![];
{
let sql = Query::new(format!(
"SELECT
COLUMN_NAME,
DATA_TYPE
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE
TABLE_SCHEMA = '{}'
AND TABLE_NAME = '{}'",
config.schema.clone(),
config.table.clone(),
));
let mut stream = sql.query(&mut client.inner_client).await?;
while let Some(item) = stream.try_next().await? {
match item {
QueryItem::Metadata(_) => {}
QueryItem::Row(row) => {
let col_name: &str = row.try_get(0)?.unwrap();
let col_type: &str = row.try_get(1)?.unwrap();
column_descs.push(ColumnDesc::named(
col_name,
ColumnId::placeholder(),
type_to_rw_type(col_type, col_name)?,
));
}
}
}
}
{
let sql = Query::new(format!(
"SELECT kcu.COLUMN_NAME
FROM
INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
JOIN
INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS kcu
ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME AND
tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA AND
tc.TABLE_NAME = kcu.TABLE_NAME
WHERE
tc.CONSTRAINT_TYPE = 'PRIMARY KEY' AND
tc.TABLE_SCHEMA = '{}' AND tc.TABLE_NAME = '{}'",
config.schema, config.table,
));
let mut stream = sql.query(&mut client.inner_client).await?;
while let Some(item) = stream.try_next().await? {
match item {
QueryItem::Metadata(_) => {}
QueryItem::Row(row) => {
let pk_name: &str = row.try_get(0)?.unwrap();
pk_names.push(pk_name.to_string());
}
}
}
}
Ok(Self {
column_descs,
pk_names,
})
}
pub fn column_descs(&self) -> &Vec<ColumnDesc> {
&self.column_descs
}
pub fn pk_names(&self) -> &Vec<String> {
&self.pk_names
}
}
fn type_to_rw_type(col_type: &str, col_name: &str) -> ConnectorResult<DataType> {
let dtype = match col_type.to_lowercase().as_str() {
"bit" => DataType::Boolean,
"binary" | "varbinary" => DataType::Bytea,
"tinyint" | "smallint" => DataType::Int16,
"integer" | "int" => DataType::Int32,
"bigint" => DataType::Int64,
"real" => DataType::Float32,
"float" => DataType::Float64,
"decimal" | "numeric" => DataType::Decimal,
"date" => DataType::Date,
"time" => DataType::Time,
"datetime" | "datetime2" | "smalldatetime" => DataType::Timestamp,
"datetimeoffset" => DataType::Timestamptz,
"char" | "nchar" | "varchar" | "nvarchar" | "text" | "ntext" | "xml"
| "uniqueidentifier" => DataType::Varchar,
mssql_type => {
return Err(anyhow!(
"Unsupported Sql Server data type: {:?}, column name: {}",
mssql_type,
col_name
)
.into());
}
};
Ok(dtype)
}
#[derive(Debug)]
pub struct SqlServerExternalTableReader {
rw_schema: Schema,
field_names: String,
client: tokio::sync::Mutex<SqlServerClient>,
}
impl ExternalTableReader for SqlServerExternalTableReader {
async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
let mut client = self.client.lock().await;
let row = client
.inner_client
.simple_query(String::from("SELECT sys.fn_cdc_get_max_lsn()"))
.await?
.into_row()
.await?
.expect("No result returned by `SELECT sys.fn_cdc_get_max_lsn()`");
let max_lsn = match row.try_get::<&[u8], usize>(0)? {
Some(bytes) => {
let mut hex_string = String::with_capacity(bytes.len() * 2 + 2);
assert_eq!(
bytes.len(),
10,
"sys.fn_cdc_get_max_lsn() should return a 10 bytes array."
);
for byte in &bytes[0..4] {
hex_string.push_str(&format!("{:02x}", byte));
}
hex_string.push(':');
for byte in &bytes[4..8] {
hex_string.push_str(&format!("{:02x}", byte));
}
hex_string.push(':');
for byte in &bytes[8..10] {
hex_string.push_str(&format!("{:02x}", byte));
}
hex_string
}
None => bail!("None is returned by `SELECT sys.fn_cdc_get_max_lsn()`, please ensure Sql Server Agent is running."),
};
tracing::debug!("current max_lsn: {}", max_lsn);
Ok(CdcOffset::SqlServer(SqlServerOffset {
change_lsn: max_lsn,
commit_lsn: MAX_COMMIT_LSN.into(),
}))
}
fn snapshot_read(
&self,
table_name: SchemaTableName,
start_pk: Option<OwnedRow>,
primary_keys: Vec<String>,
limit: u32,
) -> BoxStream<'_, ConnectorResult<OwnedRow>> {
self.snapshot_read_inner(table_name, start_pk, primary_keys, limit)
}
}
impl SqlServerExternalTableReader {
pub async fn new(
config: ExternalTableConfig,
rw_schema: Schema,
pk_indices: Vec<usize>,
) -> ConnectorResult<Self> {
tracing::info!(
?rw_schema,
?pk_indices,
"create sql server external table reader"
);
let mut client_config = Config::new();
client_config.host(&config.host);
client_config.database(&config.database);
client_config.port(config.port.parse::<u16>().unwrap());
client_config.authentication(tiberius::AuthMethod::sql_server(
&config.username,
&config.password,
));
if config.encrypt == "true" {
client_config.encryption(tiberius::EncryptionLevel::Required);
}
client_config.trust_cert();
let client = SqlServerClient::new_with_config(client_config).await?;
let field_names = rw_schema
.fields
.iter()
.map(|f| Self::quote_column(&f.name))
.join(",");
Ok(Self {
rw_schema,
field_names,
client: tokio::sync::Mutex::new(client),
})
}
pub fn get_cdc_offset_parser() -> CdcOffsetParseFunc {
Box::new(move |offset| {
Ok(CdcOffset::SqlServer(
SqlServerOffset::parse_debezium_offset(offset)?,
))
})
}
#[try_stream(boxed, ok = OwnedRow, error = ConnectorError)]
async fn snapshot_read_inner(
&self,
table_name: SchemaTableName,
start_pk_row: Option<OwnedRow>,
primary_keys: Vec<String>,
limit: u32,
) {
let order_key = primary_keys
.iter()
.map(|col| Self::quote_column(col))
.join(",");
let mut sql = Query::new(if start_pk_row.is_none() {
format!(
"SELECT {} FROM {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
self.field_names,
Self::get_normalized_table_name(&table_name),
order_key,
)
} else {
let filter_expr = Self::filter_expression(&primary_keys);
format!(
"SELECT {} FROM {} WHERE {} ORDER BY {} OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY",
self.field_names,
Self::get_normalized_table_name(&table_name),
filter_expr,
order_key,
)
});
let mut client = self.client.lock().await;
if let Some(pk_row) = start_pk_row {
let params: Vec<Option<ScalarImpl>> = pk_row.into_iter().collect();
for param in params {
sql.bind(ScalarImplTiberiusWrapper::from(param.unwrap()));
}
}
let stream = sql.query(&mut client.inner_client).await?.into_row_stream();
let row_stream = stream.map(|res| {
let mut row = res?;
Ok::<_, ConnectorError>(sql_server_row_to_owned_row(&mut row, &self.rw_schema))
});
pin_mut!(row_stream);
#[for_await]
for row in row_stream {
let row = row?;
yield row;
}
}
pub fn get_normalized_table_name(table_name: &SchemaTableName) -> String {
format!(
"\"{}\".\"{}\"",
table_name.schema_name, table_name.table_name
)
}
fn filter_expression(columns: &[String]) -> String {
let mut conditions = vec![];
conditions.push(format!("({} > @P{})", Self::quote_column(&columns[0]), 1));
for i in 2..=columns.len() {
let mut condition = String::new();
for (j, col) in columns.iter().enumerate().take(i - 1) {
if j == 0 {
condition.push_str(&format!("({} = @P{})", Self::quote_column(col), j + 1));
} else {
condition.push_str(&format!(
" AND ({} = @P{})",
Self::quote_column(col),
j + 1
));
}
}
condition.push_str(&format!(
" AND ({} > @P{})",
Self::quote_column(&columns[i - 1]),
i
));
conditions.push(format!("({})", condition));
}
if columns.len() > 1 {
conditions.join(" OR ")
} else {
conditions.join("")
}
}
fn quote_column(column: &str) -> String {
format!("\"{}\"", column)
}
}
#[cfg(test)]
mod tests {
use crate::source::cdc::external::SqlServerExternalTableReader;
#[test]
fn test_sql_server_filter_expr() {
let cols = vec!["id".to_string()];
let expr = SqlServerExternalTableReader::filter_expression(&cols);
assert_eq!(expr, "(\"id\" > @P1)");
let cols = vec!["aa".to_string(), "bb".to_string(), "cc".to_string()];
let expr = SqlServerExternalTableReader::filter_expression(&cols);
assert_eq!(
expr,
"(\"aa\" > @P1) OR ((\"aa\" = @P1) AND (\"bb\" > @P2)) OR ((\"aa\" = @P1) AND (\"bb\" = @P2) AND (\"cc\" > @P3))"
);
}
}