// risingwave_stream/executor/over_window/eowc.rs

1// Copyright 2023 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::marker::PhantomData;
17use std::ops::Bound;
18
19use anyhow::Context;
20use itertools::Itertools;
21use risingwave_common::array::stream_record::Record;
22use risingwave_common::array::{ArrayRef, Op};
23use risingwave_common::row::RowExt;
24use risingwave_common::types::{ToDatumRef, ToOwnedDatum};
25use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast};
26use risingwave_common::util::memcmp_encoding::{self, MemcmpEncoded};
27use risingwave_common::util::row_serde::OrderedRowSerde;
28use risingwave_common::util::sort_util::OrderType;
29use risingwave_common::{must_match, row};
30use risingwave_common_estimate_size::EstimateSize;
31use risingwave_common_estimate_size::collections::EstimatedVecDeque;
32use risingwave_expr::window_function::{
33    StateEvictHint, StateKey, WindowFuncCall, WindowStateSnapshot, WindowStates,
34    create_window_state,
35};
36use risingwave_pb::window_function::{
37    StateKey as PbStateKey, WindowStateSnapshot as PbWindowStateSnapshot,
38};
39use risingwave_storage::store::PrefetchOptions;
40use tracing::debug;
41
42use crate::cache::ManagedLruCache;
43use crate::common::metrics::MetricsInfo;
44use crate::executor::prelude::*;
45
/// Per-partition in-memory state: one window state per window function call, plus the
/// buffer of input rows whose output has not been produced yet.
struct Partition {
    /// One window state per window function call; kept aligned with each other
    /// (see the `are_aligned` assertion during recovery).
    states: WindowStates,
    /// Input rows awaiting output, pushed at the back and popped from the front
    /// as windows become ready (i.e. FIFO, oldest row first).
    curr_row_buffer: EstimatedVecDeque<OwnedRow>,
    /// Cached intermediate state row for this partition, used for upsert operations.
    /// `None` means no prior row exists in the intermediate state table for this partition.
    intermediate_state_row: Option<OwnedRow>,
}
53
54impl EstimateSize for Partition {
55    fn estimated_heap_size(&self) -> usize {
56        let mut total_size = self.curr_row_buffer.estimated_heap_size();
57        for state in self.states.iter() {
58            total_size += state.estimated_heap_size();
59        }
60        if let Some(row) = &self.intermediate_state_row {
61            total_size += row.estimated_heap_size();
62        }
63        total_size
64    }
65}
66
/// LRU cache of per-partition window states, keyed by the memcmp-encoded partition key.
type PartitionCache = ManagedLruCache<MemcmpEncoded, Partition>; // TODO(rc): use `K: HashKey` as key like in hash agg?
68
69/// Encode a [`WindowStateSnapshot`] to bytes for persistence using protobuf.
70fn encode_snapshot(snapshot: &WindowStateSnapshot, pk_ser: &OrderedRowSerde) -> Vec<u8> {
71    use prost::Message;
72    let pb = PbWindowStateSnapshot {
73        last_output_key: snapshot.last_output_key.as_ref().map(|key| PbStateKey {
74            order_key: key.order_key.to_vec(),
75            pk: key.pk.as_inner().memcmp_serialize(pk_ser),
76        }),
77        function_state: Some(snapshot.function_state.clone()),
78    };
79    pb.encode_to_vec()
80}
81
82/// Decode a [`WindowStateSnapshot`] from bytes during recovery using protobuf.
83fn decode_snapshot(
84    bytes: &[u8],
85    pk_deser: &OrderedRowSerde,
86) -> StreamExecutorResult<WindowStateSnapshot> {
87    use prost::Message;
88    let pb = PbWindowStateSnapshot::decode(bytes).context("failed to decode snapshot")?;
89    let last_output_key = pb
90        .last_output_key
91        .map(|key| {
92            let pk = pk_deser
93                .deserialize(&key.pk)
94                .context("failed to deserialize pk")?;
95            Ok::<_, anyhow::Error>(StateKey {
96                order_key: key.order_key.into(),
97                pk: pk.into(),
98            })
99        })
100        .transpose()?;
101    let function_state = pb
102        .function_state
103        .context("snapshot missing function_state")?;
104    Ok(WindowStateSnapshot {
105        last_output_key,
106        function_state,
107    })
108}
109
/// [`EowcOverWindowExecutor`] consumes ordered input (on order key column with watermark in
/// ascending order) and outputs window function results. One [`EowcOverWindowExecutor`] can handle
/// one combination of partition key and order key.
///
/// The reason not to use [`SortBuffer`] is that the table schemas of [`EowcOverWindowExecutor`] and
/// [`SortBuffer`] are different, since we don't have something like a _grouped_ sort buffer.
///
/// [`SortBuffer`]: crate::executor::eowc::SortBuffer
///
/// Basic idea:
///
/// ```text
/// ──────────────┬────────────────────────────────────────────────────── curr evict row
///               │ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING
///        (1)    │ ─┬─
///               │  │RANGE BETWEEN '1hr' PRECEDING AND '1hr' FOLLOWING
///        ─┬─    │  │ ─┬─
///   LAG(1)│        │  │
/// ────────┴──┬─────┼──┼──────────────────────────────────────────────── curr output row
///     LEAD(1)│     │  │GROUPS 1 PRECEDING AND 1 FOLLOWING
///                  │
///                  │ (2)
/// ─────────────────┴─────────────────────────────────────────────────── curr input row
/// (1): additional buffered input (unneeded) for some window
/// (2): additional delay (already able to output) for some window
/// ```
///
/// - State table schema = input schema, state table pk = `partition key | order key | input pk`.
/// - Output schema = input schema + window function results.
/// - Rows in range (`curr evict row`, `curr input row`] are in state table.
/// - `curr evict row` <= min(last evict rows of all `WindowState`s).
/// - `WindowState` should output agg result for `curr output row`.
/// - Recover: iterate through state table, push rows to `WindowState`, ignore ready windows.
pub struct EowcOverWindowExecutor<S: StateStore> {
    /// Upstream executor providing the ordered input stream.
    input: Executor,
    /// Everything else; moved into the message stream when `execute` is called.
    inner: ExecutorInner<S>,
}
147
/// Owned internals of [`EowcOverWindowExecutor`], consumed by the execution stream.
struct ExecutorInner<S: StateStore> {
    actor_ctx: ActorContextRef,

    /// Output schema: input columns followed by one column per window function result.
    schema: Schema,
    /// Window function calls, one per appended output column.
    calls: Vec<WindowFuncCall>,
    /// Indices of the input's stream key (pk) columns.
    input_stream_key: Vec<usize>,
    /// Indices of the partition key columns in the input.
    partition_key_indices: Vec<usize>,
    order_key_index: usize, // no `OrderType` here, cuz we expect the input is ascending
    /// Buffers input rows that are still needed by some window state.
    state_table: StateTable<S>,
    /// Number of columns in `state_table` (equals the input schema length).
    state_table_schema_len: usize,
    watermark_sequence: AtomicU64Ref,
    /// Optional state table for persisting window function intermediate states.
    /// See `StreamEowcOverWindow::infer_intermediate_state_table` for schema definition.
    intermediate_state_table: Option<StateTable<S>>,
    /// Serde for input stream key (pk), used for encoding/decoding `StateKey` in snapshots.
    /// Only initialized when `intermediate_state_table` is present.
    pk_serde: Option<OrderedRowSerde>,
}
166
/// Mutable state local to one run of the execution stream.
struct ExecutionVars<S: StateStore> {
    /// Per-partition window states, LRU-cached by encoded partition key.
    partitions: PartitionCache,
    _phantom: PhantomData<S>,
}
171
172impl<S: StateStore> Execute for EowcOverWindowExecutor<S> {
173    fn execute(self: Box<Self>) -> BoxedMessageStream {
174        self.executor_inner().boxed()
175    }
176}
177
/// Construction arguments for [`EowcOverWindowExecutor::new`].
pub struct EowcOverWindowExecutorArgs<S: StateStore> {
    pub actor_ctx: ActorContextRef,

    /// Upstream input executor.
    pub input: Executor,

    /// Output schema (input columns + window function result columns).
    pub schema: Schema,
    pub calls: Vec<WindowFuncCall>,
    pub partition_key_indices: Vec<usize>,
    pub order_key_index: usize,
    pub state_table: StateTable<S>,
    pub watermark_epoch: AtomicU64Ref,
    /// Optional state table for persisting window function intermediate states.
    /// See `StreamEowcOverWindow::infer_intermediate_state_table` for schema definition.
    pub intermediate_state_table: Option<StateTable<S>>,
}
193
194impl<S: StateStore> EowcOverWindowExecutor<S> {
195    pub fn new(args: EowcOverWindowExecutorArgs<S>) -> Self {
196        let input_info = args.input.info().clone();
197
198        // Build pk_serde if intermediate_state_table is present
199        let pk_serde = args.intermediate_state_table.as_ref().map(|_| {
200            let pk_data_types: Vec<_> = input_info
201                .stream_key
202                .iter()
203                .map(|&i| args.schema[i].data_type())
204                .collect();
205            let pk_order_types: Vec<_> = input_info
206                .stream_key
207                .iter()
208                .map(|_| OrderType::ascending())
209                .collect();
210            OrderedRowSerde::new(pk_data_types, pk_order_types)
211        });
212
213        Self {
214            input: args.input,
215            inner: ExecutorInner {
216                actor_ctx: args.actor_ctx,
217                schema: args.schema,
218                calls: args.calls,
219                input_stream_key: input_info.stream_key,
220                partition_key_indices: args.partition_key_indices,
221                order_key_index: args.order_key_index,
222                state_table: args.state_table,
223                state_table_schema_len: input_info.schema.len(),
224                watermark_sequence: args.watermark_epoch,
225                intermediate_state_table: args.intermediate_state_table,
226                pk_serde,
227            },
228        }
229    }
230
231    /// Load intermediate state snapshots from the state table and restore into partition states.
232    async fn load_intermediate_state(
233        this: &ExecutorInner<S>,
234        partition: &mut Partition,
235        partition_key: impl Row,
236        encoded_partition_key: &MemcmpEncoded,
237    ) -> StreamExecutorResult<()> {
238        let Some(intermediate_state_table) = &this.intermediate_state_table else {
239            return Ok(());
240        };
241        let pk_serde = this
242            .pk_serde
243            .as_ref()
244            .expect("pk_serde must be set when intermediate_state_table is present");
245
246        for state in partition.states.iter_mut() {
247            state.enable_persistence();
248        }
249
250        let partition_key_owned = partition_key.to_owned_row();
251        if let Some(row) = intermediate_state_table
252            .get_row(&partition_key_owned)
253            .await?
254        {
255            let num_partition_key_cols = this.partition_key_indices.len();
256            let num_calls = this.calls.len();
257
258            for call_index in 0..num_calls {
259                let state_col = num_partition_key_cols + call_index;
260                if state_col < row.len() {
261                    if let Some(state_bytes) = row.datum_at(state_col) {
262                        let snapshot = decode_snapshot(state_bytes.into_bytea(), pk_serde)?;
263                        debug!(
264                            "Restoring intermediate state for partition {:?}, call_index {}, has_last_key: {}",
265                            encoded_partition_key,
266                            call_index,
267                            snapshot.last_output_key.is_some()
268                        );
269                        partition
270                            .states
271                            .get_mut(call_index)
272                            .unwrap()
273                            .restore(snapshot)?;
274                    }
275                } else {
276                    return Err(anyhow::anyhow!(
277                        "intermediate state row has fewer columns ({}) than expected ({}) \
278                        at call_index {}, state may be corrupted",
279                        row.len(),
280                        num_partition_key_cols + num_calls,
281                        call_index
282                    )
283                    .into());
284                }
285            }
286            partition.intermediate_state_row = Some(row);
287        }
288        Ok(())
289    }
290
291    /// Persist intermediate state snapshots to the state table.
292    fn persist_intermediate_state(
293        this: &mut ExecutorInner<S>,
294        partition: &mut Partition,
295        partition_key: impl Row,
296    ) {
297        let Some(intermediate_state_table) = &mut this.intermediate_state_table else {
298            return;
299        };
300        let pk_serde = this
301            .pk_serde
302            .as_ref()
303            .expect("pk_serde must be set when intermediate_state_table is present");
304
305        let num_calls = partition.states.len();
306        let num_partition_key_cols = partition_key.len();
307
308        // Build the new row: partition_key columns + state_0..state_{n-1}
309        let mut new_row_values = Vec::with_capacity(num_partition_key_cols + num_calls);
310        for datum in partition_key.iter() {
311            new_row_values.push(datum.to_owned_datum());
312        }
313
314        // For each call, encode snapshot or preserve previous value
315        for (call_index, state) in partition.states.iter().enumerate() {
316            if let Some(snapshot) = state.snapshot() {
317                let snapshot_bytes = encode_snapshot(&snapshot, pk_serde);
318                new_row_values.push(Some(snapshot_bytes.into_boxed_slice().into()));
319            } else if let Some(ref old_row) = partition.intermediate_state_row {
320                let state_col = num_partition_key_cols + call_index;
321                if state_col < old_row.len() {
322                    new_row_values.push(old_row.datum_at(state_col).to_owned_datum());
323                } else {
324                    new_row_values.push(None);
325                }
326            } else {
327                new_row_values.push(None);
328            }
329        }
330        let new_row = OwnedRow::new(new_row_values);
331
332        // Upsert: update if old row exists, otherwise insert
333        if let Some(old_row) = partition.intermediate_state_row.take() {
334            intermediate_state_table.update(old_row, new_row.clone());
335        } else {
336            intermediate_state_table.insert(new_row.clone());
337        }
338        partition.intermediate_state_row = Some(new_row);
339    }
340
    /// Make sure the partition identified by `encoded_partition_key` is present in
    /// `cache`. On a miss, a fresh [`Partition`] is built and recovered: intermediate
    /// snapshots are restored first (when enabled), then all buffered rows for this
    /// partition are replayed from the state table into the window states, and
    /// windows that were already output before the restart are slid over silently.
    async fn ensure_key_in_cache(
        this: &ExecutorInner<S>,
        cache: &mut PartitionCache,
        partition_key: impl Row,
        encoded_partition_key: &MemcmpEncoded,
    ) -> StreamExecutorResult<()> {
        if cache.contains(encoded_partition_key) {
            return Ok(());
        }

        // One window state per window function call; row buffer starts empty.
        let mut partition = Partition {
            states: WindowStates::new(this.calls.iter().map(create_window_state).try_collect()?),
            curr_row_buffer: Default::default(),
            intermediate_state_row: None,
        };

        // If intermediate state table exists, load and restore intermediate state snapshots
        Self::load_intermediate_state(this, &mut partition, &partition_key, encoded_partition_key)
            .await?;

        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
        // Recover states from state table.
        let table_iter = this
            .state_table
            .iter_with_prefix(&partition_key, sub_range, PrefetchOptions::default())
            .await?;

        #[for_await]
        for keyed_row in table_iter {
            let row = keyed_row?.into_owned_row();
            // Re-derive the `StateKey` (memcmp-encoded order key + input pk) for this row.
            let order_key_enc = memcmp_encoding::encode_row(
                row::once(Some(
                    row.datum_at(this.order_key_index)
                        .expect("order key column must be non-NULL")
                        .into_scalar_impl(),
                )),
                &[OrderType::ascending()],
            )?;
            let pk = (&row).project(&this.input_stream_key).into_owned_row();
            let key = StateKey {
                order_key: order_key_enc,
                pk: pk.into(),
            };
            // Feed each window state with its own projection of the row's call arguments.
            for (call, state) in this.calls.iter().zip_eq_fast(partition.states.iter_mut()) {
                state.append(
                    key.clone(),
                    (&row)
                        .project(call.args.val_indices())
                        .into_owned_row()
                        .as_inner()
                        .into(),
                );
            }
            partition.curr_row_buffer.push_back(row);
        }

        // Ensure states correctness.
        assert!(partition.states.are_aligned());

        // Ignore ready windows (all ready windows were outputted before).
        // Use just_slide which calls slide_no_output and respects recovery skip logic.
        while partition.states.are_ready() {
            partition.states.just_slide()?;
            partition.curr_row_buffer.pop_front();
        }

        cache.put(encoded_partition_key.clone(), partition);
        Ok(())
    }
410
    /// Apply one input chunk: feed every inserted row into its partition's window
    /// states, emit output rows for windows that become ready, evict no-longer-needed
    /// rows from the state table, and finally persist intermediate snapshots once per
    /// partition that produced output. Returns `None` when no output row was produced.
    async fn apply_chunk(
        this: &mut ExecutorInner<S>,
        vars: &mut ExecutionVars<S>,
        chunk: StreamChunk,
    ) -> StreamExecutorResult<Option<StreamChunk>> {
        let mut builders = this.schema.create_array_builders(chunk.capacity()); // just an estimate
        // Track partitions that produced output during this chunk, so we persist
        // intermediate state only once per partition at the end.
        let mut dirty_partitions: HashMap<MemcmpEncoded, OwnedRow> = HashMap::new();

        // We assume that the input is sorted by order key.
        for record in chunk.records() {
            // EOWC input must be append-only; anything else is a bug upstream.
            let input_row = must_match!(record, Record::Insert { new_row } => new_row);

            let partition_key = input_row
                .project(&this.partition_key_indices)
                .into_owned_row();
            let encoded_partition_key = memcmp_encoding::encode_row(
                &partition_key,
                &vec![OrderType::ascending(); this.partition_key_indices.len()],
            )?;

            // Get the partition.
            Self::ensure_key_in_cache(
                this,
                &mut vars.partitions,
                &partition_key,
                &encoded_partition_key,
            )
            .await?;
            let partition: &mut Partition =
                &mut vars.partitions.get_mut(&encoded_partition_key).unwrap();

            // Materialize input to state table.
            this.state_table.insert(input_row);

            // Feed the row to all window states.
            let order_key_enc = memcmp_encoding::encode_row(
                row::once(Some(
                    input_row
                        .datum_at(this.order_key_index)
                        .expect("order key column must be non-NULL")
                        .into_scalar_impl(),
                )),
                &[OrderType::ascending()],
            )?;
            let pk = input_row.project(&this.input_stream_key).into_owned_row();
            let key = StateKey {
                order_key: order_key_enc,
                pk: pk.into(),
            };
            // Each window state receives its own projection of the call arguments.
            for (call, state) in this.calls.iter().zip_eq_fast(partition.states.iter_mut()) {
                state.append(
                    key.clone(),
                    input_row
                        .project(call.args.val_indices())
                        .into_owned_row()
                        .as_inner()
                        .into(),
                );
            }
            partition
                .curr_row_buffer
                .push_back(input_row.into_owned_row());

            let mut has_output = false;
            while partition.states.are_ready() {
                has_output = true;
                // The partition is ready to output, so we can produce a row.

                // Get all outputs.
                let (ret_values, evict_hint) = partition.states.slide()?;
                let curr_row = partition
                    .curr_row_buffer
                    .pop_front()
                    .expect("ready window must have corresponding current row");

                // Append to output builders.
                // Output row = current input row columns + one result per window call.
                for (builder, datum) in builders.iter_mut().zip_eq_debug(
                    curr_row
                        .iter()
                        .chain(ret_values.iter().map(|v| v.to_datum_ref())),
                ) {
                    builder.append(datum);
                }

                // Evict unneeded rows from state table.
                if let StateEvictHint::CanEvict(keys_to_evict) = evict_hint {
                    for key in keys_to_evict {
                        let order_key = memcmp_encoding::decode_row(
                            &key.order_key,
                            &[this.schema[this.order_key_index].data_type()],
                            &[OrderType::ascending()],
                        )?;
                        let state_row_pk = (&partition_key).chain(order_key).chain(key.pk);
                        let state_row = {
                            // FIXME(rc): quite hacky here, we may need `state_table.delete_by_pk`
                            let mut state_row = vec![None; this.state_table_schema_len];
                            for (i_in_pk, &i) in this.state_table.pk_indices().iter().enumerate() {
                                state_row[i] = state_row_pk.datum_at(i_in_pk).to_owned_datum();
                            }
                            OwnedRow::new(state_row)
                        };
                        // NOTE: We don't know the value of the row here, so the table must allow
                        // inconsistent ops.
                        this.state_table.delete(state_row);
                    }
                }
            }

            if has_output && this.intermediate_state_table.is_some() {
                dirty_partitions
                    .entry(encoded_partition_key)
                    .or_insert(partition_key);
            }
        }

        // Persist intermediate state snapshots once per dirty partition at the end of the chunk.
        for (encoded_partition_key, partition_key) in &dirty_partitions {
            let partition = &mut *vars.partitions.get_mut(encoded_partition_key).unwrap();
            Self::persist_intermediate_state(this, partition, partition_key);
        }

        let columns: Vec<ArrayRef> = builders.into_iter().map(|b| b.finish().into()).collect();
        let chunk_size = columns[0].len();
        Ok(if chunk_size > 0 {
            Some(StreamChunk::new(vec![Op::Insert; chunk_size], columns))
        } else {
            None
        })
    }
542
    /// The main executor loop, yielded as a stream of [`Message`]s: initializes the
    /// state tables on the first barrier, then processes chunks and commits state on
    /// every subsequent barrier. Input watermarks are consumed without being forwarded.
    #[try_stream(ok = Message, error = StreamExecutorError)]
    async fn executor_inner(self) {
        let EowcOverWindowExecutor {
            input,
            inner: mut this,
        } = self;

        let metrics_info = MetricsInfo::new(
            this.actor_ctx.streaming_metrics.clone(),
            this.state_table.table_id(),
            this.actor_ctx.id,
            "EowcOverWindow",
        );

        let mut vars = ExecutionVars {
            partitions: ManagedLruCache::unbounded(this.watermark_sequence.clone(), metrics_info),
            _phantom: PhantomData::<S>,
        };

        let mut input = input.execute();
        // Both state tables must be initialized with the first barrier's epoch
        // before any data is processed.
        let barrier = expect_first_barrier(&mut input).await?;
        let first_epoch = barrier.epoch;
        yield Message::Barrier(barrier);
        this.state_table.init_epoch(first_epoch).await?;
        if let Some(intermediate_state_table) = &mut this.intermediate_state_table {
            intermediate_state_table.init_epoch(first_epoch).await?;
        }

        #[for_await]
        for msg in input {
            let msg = msg?;
            match msg {
                Message::Watermark(_) => {
                    // Input watermarks are swallowed; nothing is forwarded downstream here.
                    continue;
                }
                Message::Chunk(chunk) => {
                    let output_chunk = Self::apply_chunk(&mut this, &mut vars, chunk).await?;
                    if let Some(chunk) = output_chunk {
                        yield Message::Chunk(chunk);
                    }
                    this.state_table.try_flush().await?;
                    if let Some(intermediate_state_table) = &mut this.intermediate_state_table {
                        intermediate_state_table.try_flush().await?;
                    }
                }
                Message::Barrier(barrier) => {
                    // Commit both tables at this epoch before forwarding the barrier.
                    let post_commit = this.state_table.commit(barrier.epoch).await?;
                    let intermediate_post_commit = if let Some(intermediate_state_table) =
                        &mut this.intermediate_state_table
                    {
                        Some(intermediate_state_table.commit(barrier.epoch).await?)
                    } else {
                        None
                    };

                    vars.partitions.evict();

                    let update_vnode_bitmap = barrier.as_update_vnode_bitmap(this.actor_ctx.id);
                    yield Message::Barrier(barrier);

                    // If either table reports a possible vnode-bitmap change after the
                    // barrier, the partition cache may hold data for vnodes we no
                    // longer own, so it must be cleared.
                    let mut cache_may_stale = false;
                    if let Some((_, stale)) = post_commit
                        .post_yield_barrier(update_vnode_bitmap.clone())
                        .await?
                    {
                        cache_may_stale = cache_may_stale || stale;
                    }
                    if let Some(intermediate_post_commit) = intermediate_post_commit
                        && let Some((_, stale)) = intermediate_post_commit
                            .post_yield_barrier(update_vnode_bitmap)
                            .await?
                    {
                        cache_may_stale = cache_may_stale || stale;
                    }
                    if cache_may_stale {
                        vars.partitions.clear();
                    }
                }
            }
        }
    }
624}