risingwave_expr_macro/
context.rs

1// Copyright 2025 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use itertools::Itertools;
16use proc_macro2::TokenStream;
17use quote::{ToTokens, quote, quote_spanned};
18use syn::parse::{Parse, ParseStream};
19use syn::{Error, FnArg, Ident, ItemFn, Pat, PatType, Result, ReturnType, Token, Type, Visibility};
20
21use crate::utils::extend_vis_with_super;
22
/// See [`super::define_context!`].
///
/// A single `vis NAME: Type` entry in the input of `define_context!`.
#[derive(Debug, Clone)]
pub(super) struct DefineContextField {
    /// Visibility written before the name. When generating, it is extended with `super`
    /// and applied to the generated `try_with` getter.
    vis: Visibility,
    /// Name of the context variable; it becomes the name of the generated module and must be
    /// all-uppercase (checked in `gen`).
    name: Ident,
    /// Type of the value stored in the generated tokio task-local.
    ty: Type,
}
30
/// See [`super::define_context!`].
///
/// The whole input of `define_context!`: a comma-separated list of [`DefineContextField`]s.
#[derive(Debug, Clone)]
pub(super) struct DefineContextAttr {
    /// Declared fields in source order; each one expands to its own module.
    fields: Vec<DefineContextField>,
}
36
37impl Parse for DefineContextField {
38    fn parse(input: ParseStream<'_>) -> Result<Self> {
39        let vis: Visibility = input.parse()?;
40        let name: Ident = input.parse()?;
41        input.parse::<Token![:]>()?;
42        let ty: Type = input.parse()?;
43
44        Ok(Self { vis, name, ty })
45    }
46}
47
48impl Parse for DefineContextAttr {
49    fn parse(input: ParseStream<'_>) -> Result<Self> {
50        let fields = input.parse_terminated(DefineContextField::parse, Token![,])?;
51        Ok(Self {
52            fields: fields.into_iter().collect(),
53        })
54    }
55}
56
57impl DefineContextField {
58    pub(super) fn r#gen(self) -> Result<TokenStream> {
59        let Self { vis, name, ty } = self;
60
61        // We create a sub mod, so we need to extend the vis of getter.
62        let vis: Visibility = extend_vis_with_super(vis);
63
64        {
65            let name_s = name.to_string();
66            if name_s.to_uppercase() != name_s {
67                return Err(Error::new_spanned(
68                    name,
69                    "the name of context variable should be uppercase",
70                ));
71            }
72        }
73
74        Ok(quote! {
75            #[allow(non_snake_case)]
76            pub mod #name {
77                use super::*;
78                pub type Type = #ty;
79
80                tokio::task_local! {
81                    static LOCAL_KEY: #ty;
82                }
83
84                #vis fn try_with<F, R>(f: F) -> Result<R, risingwave_expr::ExprError>
85                where
86                    F: FnOnce(&#ty) -> R
87                {
88                    LOCAL_KEY.try_with(f).map_err(|_| risingwave_expr::ContextUnavailable::new(stringify!(#name))).map_err(Into::into)
89                }
90
91                pub fn scope<F>(value: #ty, f: F) -> tokio::task::futures::TaskLocalFuture<#ty, F>
92                where
93                    F: std::future::Future
94                {
95                    LOCAL_KEY.scope(value, f)
96                }
97
98                pub fn sync_scope<F, R>(value: #ty, f: F) -> R
99                where
100                    F: FnOnce() -> R
101                {
102                    LOCAL_KEY.sync_scope(value, f)
103                }
104            }
105        })
106    }
107}
108
109impl DefineContextAttr {
110    pub(super) fn r#gen(self) -> Result<TokenStream> {
111        let generated_fields: Vec<TokenStream> = self
112            .fields
113            .into_iter()
114            .map(DefineContextField::r#gen)
115            .try_collect()?;
116        Ok(quote! {
117            #(#generated_fields)*
118        })
119    }
120}
121
122pub struct CaptureContextAttr {
123    /// The context variables which are captured.
124    captures: Vec<Ident>,
125}
126
127impl Parse for CaptureContextAttr {
128    fn parse(input: ParseStream<'_>) -> Result<Self> {
129        let captures = input.parse_terminated(Ident::parse, Token![,])?;
130        Ok(Self {
131            captures: captures.into_iter().collect(),
132        })
133    }
134}
135
136pub(super) fn generate_captured_function(
137    attr: CaptureContextAttr,
138    mut user_fn: ItemFn,
139) -> Result<TokenStream> {
140    let CaptureContextAttr { captures } = attr;
141    let is_async = user_fn.sig.asyncness.is_some();
142    let mut orig_user_fn = user_fn.clone();
143    if is_async {
144        // Modify the return type to impl Future<Output = output> + Send + 'static for the original function.
145        let output_type = match &orig_user_fn.sig.output {
146            ReturnType::Type(_, ty) => ty.clone(),
147            ReturnType::Default => Box::new(syn::parse_quote!(())),
148        };
149        orig_user_fn.sig.output = ReturnType::Type(
150            syn::token::RArrow::default(),
151            Box::new(
152                syn::parse_quote!(impl std::future::Future<Output = #output_type> + Send + 'static),
153            ),
154        );
155        orig_user_fn.sig.asyncness = None;
156
157        // Generate clone statements for each input
158        let input_def: Vec<TokenStream> = orig_user_fn
159            .sig
160            .inputs
161            .iter()
162            .map(|arg| {
163                if let FnArg::Typed(PatType { pat, .. }) = arg {
164                    if let Pat::Ident(ident) = pat.as_ref() {
165                        let ident_name = &ident.ident;
166                        return quote! {
167                            let #ident_name = #ident_name.clone();
168                        };
169                    }
170                }
171                quote! {}
172            })
173            .collect();
174
175        // Wrap the original function body in async move { ... }.
176        let orig_body = &orig_user_fn.block;
177        orig_user_fn.block = Box::new(syn::parse_quote!({
178            #(#input_def)*
179            async move { #orig_body }
180        }));
181    }
182
183    let sig = &mut user_fn.sig;
184
185    let name = sig.ident.clone();
186
187    // Modify the name.
188    {
189        let new_name = format!("{}_captured", name);
190        let new_name = Ident::new(&new_name, sig.ident.span());
191        sig.ident = new_name;
192    }
193
194    if is_async {
195        // Ensure the function is async
196        sig.asyncness = Some(syn::token::Async::default());
197    }
198
199    // Modify the inputs of sig.
200    let inputs = &mut sig.inputs;
201    if inputs.len() < captures.len() {
202        return Err(syn::Error::new_spanned(
203            inputs,
204            format!("expected at least {} inputs", captures.len()),
205        ));
206    }
207
208    let arg_names: Vec<_> = inputs
209        .iter()
210        .map(|arg| {
211            let FnArg::Typed(arg) = arg else {
212                return Err(syn::Error::new_spanned(
213                    arg,
214                    "receiver is not allowed in captured function",
215                ));
216            };
217            Ok(arg.pat.to_token_stream())
218        })
219        .try_collect()?;
220
221    let (captured_inputs, remained_inputs) = {
222        let mut inputs = inputs.iter().cloned();
223        let inputs = inputs.by_ref();
224        let captured_inputs = inputs.take(captures.len()).collect_vec();
225        let remained_inputs = inputs.collect_vec();
226        (captured_inputs, remained_inputs)
227    };
228    *inputs = remained_inputs.into_iter().collect();
229
230    let call_old_fn = quote! {
231        #name(#(#arg_names),*)
232    };
233
234    let new_body = {
235        let mut scoped = quote! {
236            #call_old_fn
237        };
238
239        #[allow(clippy::disallowed_methods)]
240        for (context, arg) in captures.into_iter().zip(captured_inputs.into_iter()) {
241            let FnArg::Typed(arg) = arg else {
242                return Err(syn::Error::new_spanned(
243                    arg,
244                    "receiver is not allowed in captured function",
245                ));
246            };
247            let name = arg.pat.into_token_stream();
248            // TODO: Can we add an assertion here that `&<<#context::Type> as Deref>::Target` is same as `#arg.ty`?
249            scoped = if is_async {
250                quote_spanned! { context.span()=>
251                    #context::try_with(|#name| { #scoped })
252                }
253            } else {
254                quote_spanned! { context.span()=>
255                    #context::try_with(|#name| { #scoped }).flatten()
256                }
257            };
258        }
259        scoped
260    };
261    let new_user_fn = {
262        let vis = user_fn.vis;
263        let sig = user_fn.sig;
264        if is_async {
265            quote! {
266                #vis #sig {
267                    {#new_body}?.await
268                }
269            }
270        } else {
271            quote! {
272                #vis #sig {
273                    {#new_body}.map_err(Into::into)
274                }
275            }
276        }
277    };
278
279    Ok(quote! {
280        #orig_user_fn
281        #new_user_fn
282    })
283}