risingwave_stream/executor/wrapper/
schema_check.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use futures_async_stream::try_stream;
18
19use crate::executor::error::StreamExecutorError;
20use crate::executor::{ExecutorInfo, Message, MessageStream};
21
22/// Streams wrapped by `schema_check` will check the passing stream chunk against the expected
23/// schema.
24#[try_stream(ok = Message, error = StreamExecutorError)]
25pub async fn schema_check(info: Arc<ExecutorInfo>, input: impl MessageStream) {
26    #[for_await]
27    for message in input {
28        let message = message?;
29
30        match &message {
31            Message::Chunk(chunk) => risingwave_common::util::schema_check::schema_check(
32                info.schema.fields().iter().map(|f| &f.data_type),
33                chunk.columns(),
34            ),
35            Message::Watermark(watermark) => {
36                let expected = info.schema.fields()[watermark.col_idx].data_type();
37                let found = &watermark.data_type;
38                if &expected != found {
39                    Err(format!(
40                        "watermark type mismatched: expected {expected}, found {found}"
41                    ))
42                } else {
43                    Ok(())
44                }
45            }
46            Message::Barrier(_) => Ok(()),
47        }
48        .unwrap_or_else(|e| panic!("schema check failed on {:?}: {}", info, e));
49
50        yield message;
51    }
52}
53
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::{StreamExt, pin_mut};
    use risingwave_common::array::StreamChunk;
    use risingwave_common::array::stream_chunk::StreamChunkTestExt;
    use risingwave_common::catalog::{Field, Schema};
    use risingwave_common::types::DataType;
    use risingwave_common::util::epoch::test_epoch;

    use super::*;
    use crate::executor::test_utils::MockSource;

    /// Builds the two-column schema (`Int64`, `Float64`) that both tests declare for the source.
    fn int64_float64_schema() -> Schema {
        Schema {
            fields: vec![
                Field::unnamed(DataType::Int64),
                Field::unnamed(DataType::Float64),
            ],
        }
    }

    /// A chunk declared with `I` and `F` columns passes the check and is yielded, followed by
    /// the barrier.
    #[tokio::test]
    async fn test_schema_ok() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(int64_float64_schema(), vec![1]);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I     F
            + 100 200.0
            +  10  14.0
            +   4 300.0",
        ));
        tx.push_barrier(test_epoch(1), false);

        let checked = schema_check(source.info().clone().into(), source.execute());
        pin_mut!(checked);

        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Chunk(_));
        assert_matches!(checked.next().await.unwrap().unwrap(), Message::Barrier(_));
    }

    /// A chunk whose second column is declared as `I` while the schema expects `Float64` must
    /// make the wrapper panic.
    ///
    /// The expected substring is the prefix of the wrapper's own panic message, so the test
    /// does not pass on an unrelated panic (e.g. one raised during test setup).
    #[should_panic(expected = "schema check failed")]
    #[tokio::test]
    async fn test_schema_bad() {
        let (mut tx, source) = MockSource::channel();
        let source = source.into_executor(int64_float64_schema(), vec![1]);
        tx.push_chunk(StreamChunk::from_pretty(
            "   I   I
            + 100 200
            +  10  14
            +   4 300",
        ));
        tx.push_barrier(test_epoch(1), false);

        let checked = schema_check(source.info().clone().into(), source.execute());
        pin_mut!(checked);
        checked.next().await.unwrap().unwrap();
    }
}
117}