risingwave_expr_macro/
context.rs1use itertools::Itertools;
16use proc_macro2::TokenStream;
17use quote::{ToTokens, quote, quote_spanned};
18use syn::parse::{Parse, ParseStream};
19use syn::{Error, FnArg, Ident, ItemFn, Pat, PatType, Result, ReturnType, Token, Type, Visibility};
20
21use crate::utils::extend_vis_with_super;
22
23#[derive(Debug, Clone)]
25pub(super) struct DefineContextField {
26 vis: Visibility,
27 name: Ident,
28 ty: Type,
29}
30
31#[derive(Debug, Clone)]
33pub(super) struct DefineContextAttr {
34 fields: Vec<DefineContextField>,
35}
36
37impl Parse for DefineContextField {
38 fn parse(input: ParseStream<'_>) -> Result<Self> {
39 let vis: Visibility = input.parse()?;
40 let name: Ident = input.parse()?;
41 input.parse::<Token![:]>()?;
42 let ty: Type = input.parse()?;
43
44 Ok(Self { vis, name, ty })
45 }
46}
47
48impl Parse for DefineContextAttr {
49 fn parse(input: ParseStream<'_>) -> Result<Self> {
50 let fields = input.parse_terminated(DefineContextField::parse, Token![,])?;
51 Ok(Self {
52 fields: fields.into_iter().collect(),
53 })
54 }
55}
56
57impl DefineContextField {
58 pub(super) fn r#gen(self) -> Result<TokenStream> {
59 let Self { vis, name, ty } = self;
60
61 let vis: Visibility = extend_vis_with_super(vis);
63
64 {
65 let name_s = name.to_string();
66 if name_s.to_uppercase() != name_s {
67 return Err(Error::new_spanned(
68 name,
69 "the name of context variable should be uppercase",
70 ));
71 }
72 }
73
74 Ok(quote! {
75 #[allow(non_snake_case)]
76 pub mod #name {
77 use super::*;
78 pub type Type = #ty;
79
80 tokio::task_local! {
81 static LOCAL_KEY: #ty;
82 }
83
84 #vis fn try_with<F, R>(f: F) -> Result<R, risingwave_expr::ExprError>
85 where
86 F: FnOnce(&#ty) -> R
87 {
88 LOCAL_KEY.try_with(f).map_err(|_| risingwave_expr::ContextUnavailable::new(stringify!(#name))).map_err(Into::into)
89 }
90
91 pub fn scope<F>(value: #ty, f: F) -> tokio::task::futures::TaskLocalFuture<#ty, F>
92 where
93 F: std::future::Future
94 {
95 LOCAL_KEY.scope(value, f)
96 }
97
98 pub fn sync_scope<F, R>(value: #ty, f: F) -> R
99 where
100 F: FnOnce() -> R
101 {
102 LOCAL_KEY.sync_scope(value, f)
103 }
104 }
105 })
106 }
107}
108
109impl DefineContextAttr {
110 pub(super) fn r#gen(self) -> Result<TokenStream> {
111 let generated_fields: Vec<TokenStream> = self
112 .fields
113 .into_iter()
114 .map(DefineContextField::r#gen)
115 .try_collect()?;
116 Ok(quote! {
117 #(#generated_fields)*
118 })
119 }
120}
121
122pub struct CaptureContextAttr {
123 captures: Vec<Ident>,
125}
126
127impl Parse for CaptureContextAttr {
128 fn parse(input: ParseStream<'_>) -> Result<Self> {
129 let captures = input.parse_terminated(Ident::parse, Token![,])?;
130 Ok(Self {
131 captures: captures.into_iter().collect(),
132 })
133 }
134}
135
136pub(super) fn generate_captured_function(
137 attr: CaptureContextAttr,
138 mut user_fn: ItemFn,
139) -> Result<TokenStream> {
140 let CaptureContextAttr { captures } = attr;
141 let is_async = user_fn.sig.asyncness.is_some();
142 let mut orig_user_fn = user_fn.clone();
143 if is_async {
144 let output_type = match &orig_user_fn.sig.output {
146 ReturnType::Type(_, ty) => ty.clone(),
147 ReturnType::Default => Box::new(syn::parse_quote!(())),
148 };
149 orig_user_fn.sig.output = ReturnType::Type(
150 syn::token::RArrow::default(),
151 Box::new(
152 syn::parse_quote!(impl std::future::Future<Output = #output_type> + Send + 'static),
153 ),
154 );
155 orig_user_fn.sig.asyncness = None;
156
157 let input_def: Vec<TokenStream> = orig_user_fn
159 .sig
160 .inputs
161 .iter()
162 .map(|arg| {
163 if let FnArg::Typed(PatType { pat, .. }) = arg {
164 if let Pat::Ident(ident) = pat.as_ref() {
165 let ident_name = &ident.ident;
166 return quote! {
167 let #ident_name = #ident_name.clone();
168 };
169 }
170 }
171 quote! {}
172 })
173 .collect();
174
175 let orig_body = &orig_user_fn.block;
177 orig_user_fn.block = Box::new(syn::parse_quote!({
178 #(#input_def)*
179 async move { #orig_body }
180 }));
181 }
182
183 let sig = &mut user_fn.sig;
184
185 let name = sig.ident.clone();
186
187 {
189 let new_name = format!("{}_captured", name);
190 let new_name = Ident::new(&new_name, sig.ident.span());
191 sig.ident = new_name;
192 }
193
194 if is_async {
195 sig.asyncness = Some(syn::token::Async::default());
197 }
198
199 let inputs = &mut sig.inputs;
201 if inputs.len() < captures.len() {
202 return Err(syn::Error::new_spanned(
203 inputs,
204 format!("expected at least {} inputs", captures.len()),
205 ));
206 }
207
208 let arg_names: Vec<_> = inputs
209 .iter()
210 .map(|arg| {
211 let FnArg::Typed(arg) = arg else {
212 return Err(syn::Error::new_spanned(
213 arg,
214 "receiver is not allowed in captured function",
215 ));
216 };
217 Ok(arg.pat.to_token_stream())
218 })
219 .try_collect()?;
220
221 let (captured_inputs, remained_inputs) = {
222 let mut inputs = inputs.iter().cloned();
223 let inputs = inputs.by_ref();
224 let captured_inputs = inputs.take(captures.len()).collect_vec();
225 let remained_inputs = inputs.collect_vec();
226 (captured_inputs, remained_inputs)
227 };
228 *inputs = remained_inputs.into_iter().collect();
229
230 let call_old_fn = quote! {
231 #name(#(#arg_names),*)
232 };
233
234 let new_body = {
235 let mut scoped = quote! {
236 #call_old_fn
237 };
238
239 #[allow(clippy::disallowed_methods)]
240 for (context, arg) in captures.into_iter().zip(captured_inputs.into_iter()) {
241 let FnArg::Typed(arg) = arg else {
242 return Err(syn::Error::new_spanned(
243 arg,
244 "receiver is not allowed in captured function",
245 ));
246 };
247 let name = arg.pat.into_token_stream();
248 scoped = if is_async {
250 quote_spanned! { context.span()=>
251 #context::try_with(|#name| { #scoped })
252 }
253 } else {
254 quote_spanned! { context.span()=>
255 #context::try_with(|#name| { #scoped }).flatten()
256 }
257 };
258 }
259 scoped
260 };
261 let new_user_fn = {
262 let vis = user_fn.vis;
263 let sig = user_fn.sig;
264 if is_async {
265 quote! {
266 #vis #sig {
267 {#new_body}?.await
268 }
269 }
270 } else {
271 quote! {
272 #vis #sig {
273 {#new_body}.map_err(Into::into)
274 }
275 }
276 }
277 };
278
279 Ok(quote! {
280 #orig_user_fn
281 #new_user_fn
282 })
283}