risingwave_stream/executor/wrapper/
schema_check.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use futures_async_stream::try_stream;
18
19use crate::executor::error::StreamExecutorError;
20use crate::executor::{ExecutorInfo, Message, MessageStream};
21
22/// Streams wrapped by `schema_check` will check the passing stream chunk against the expected
23/// schema.
24#[try_stream(ok = Message, error = StreamExecutorError)]
25pub async fn schema_check(info: Arc<ExecutorInfo>, input: impl MessageStream) {
26    #[for_await]
27    for message in input {
28        let message = message?;
29
30        match &message {
31            Message::Chunk(chunk) => risingwave_common::util::schema_check::schema_check(
32                info.schema.fields().iter().map(|f| &f.data_type),
33                chunk.columns(),
34            )
35            .map_err(|e| format!("{e}\nchunk:\n{}", chunk.to_pretty())),
36            Message::Watermark(watermark) => {
37                let expected = info.schema.fields()[watermark.col_idx].data_type();
38                let found = &watermark.data_type;
39                if &expected != found {
40                    Err(format!(
41                        "watermark type mismatched: expected {expected}, found {found}"
42                    ))
43                } else {
44                    Ok(())
45                }
46            }
47            Message::Barrier(_) => Ok(()),
48        }
49        .unwrap_or_else(|e| panic!("schema check failed on {:?}: {}", info, e));
50
51        yield message;
52    }
53}
54
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{StreamExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;

    /// Builds the two-column `(Int64, Float64)` schema shared by the tests below.
    fn int64_float64_schema() -> Schema {
        let fields = vec![
            Field::unnamed(DataType::Int64),
            Field::unnamed(DataType::Float64),
        ];
        Schema { fields }
    }

    /// A chunk whose columns match the declared schema passes through, as does the barrier.
    #[tokio::test]
    async fn test_schema_ok() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(int64_float64_schema(), vec![1]);

        let matching_chunk = StreamChunk::from_pretty(
            "   I     F
            + 100 200.0
            +  10  14.0
            +   4 300.0",
        );
        tx.push_chunk(matching_chunk);
        tx.push_barrier(test_epoch(1), false);

        let info = source.info().clone().into();
        let checked = schema_check(info, source.execute());
        pin_mut!(checked);

        let first = checked.next().await.unwrap().unwrap();
        assert_matches!(first, Message::Chunk(_));

        let second = checked.next().await.unwrap().unwrap();
        assert_matches!(second, Message::Barrier(_));
    }

    /// A chunk whose second column is `I` while the schema declares `Float64` must panic.
    #[should_panic]
    #[tokio::test]
    async fn test_schema_bad() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(int64_float64_schema(), vec![1]);

        let mismatching_chunk = StreamChunk::from_pretty(
            "   I   I
            + 100 200
            +  10  14
            +   4 300",
        );
        tx.push_chunk(mismatching_chunk);
        tx.push_barrier(test_epoch(1), false);

        let info = source.info().clone().into();
        let checked = schema_check(info, source.execute());
        pin_mut!(checked);

        // Polling the first message runs the check on the mismatching chunk.
        checked.next().await.unwrap().unwrap();
    }
}