pin_init_internal/helpers.rs
1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3#[cfg(not(kernel))]
4use proc_macro2 as proc_macro;
5
6use proc_macro::{TokenStream, TokenTree};
7
/// Parsed generics.
///
/// See the field documentation for an explanation of what each of the fields represents.
///
/// All three lists contain the tokens found between the outer `<` and `>` of the generics
/// (without those two delimiters), including the commas separating the parameters.
///
/// # Examples
///
/// ```rust,ignore
/// # let input = todo!();
/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input);
/// quote! {
///     struct Foo<$($decl_generics)*> {
///         // ...
///     }
///
///     impl<$impl_generics> Foo<$ty_generics> {
///         fn foo() {
///             // ...
///         }
///     }
/// }
/// ```
pub(crate) struct Generics {
    /// The generics with bounds and default values (e.g. `T: Clone, const N: usize = 0`).
    ///
    /// Use this on type definitions e.g. `struct Foo<$decl_generics> ...` (or `union`/`enum`).
    pub(crate) decl_generics: Vec<TokenTree>,
    /// The generics with bounds, but without default values (e.g. `T: Clone, const N: usize`).
    ///
    /// Use this on `impl` blocks e.g. `impl<$impl_generics> Trait for ...`.
    pub(crate) impl_generics: Vec<TokenTree>,
    /// The generics without bounds and without default values (e.g. `'a, T, N`).
    ///
    /// Use this when you use the type that is declared with these generics e.g.
    /// `Foo<$ty_generics>`.
    pub(crate) ty_generics: Vec<TokenTree>,
}
44
45/// Parses the given `TokenStream` into `Generics` and the rest.
46///
47/// The generics are not present in the rest, but a where clause might remain.
48pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) {
49    // The generics with bounds and default values.
50    let mut decl_generics = vec![];
51    // `impl_generics`, the declared generics with their bounds.
52    let mut impl_generics = vec![];
53    // Only the names of the generics, without any bounds.
54    let mut ty_generics = vec![];
55    // Tokens not related to the generics e.g. the `where` token and definition.
56    let mut rest = vec![];
57    // The current level of `<`.
58    let mut nesting = 0;
59    let mut toks = input.into_iter();
60    // If we are at the beginning of a generic parameter.
61    let mut at_start = true;
62    let mut skip_until_comma = false;
63    while let Some(tt) = toks.next() {
64        if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') {
65            // Found the end of the generics.
66            break;
67        } else if nesting >= 1 {
68            decl_generics.push(tt.clone());
69        }
70        match tt.clone() {
71            TokenTree::Punct(p) if p.as_char() == '<' => {
72                if nesting >= 1 && !skip_until_comma {
73                    // This is inside of the generics and part of some bound.
74                    impl_generics.push(tt);
75                }
76                nesting += 1;
77            }
78            TokenTree::Punct(p) if p.as_char() == '>' => {
79                // This is a parsing error, so we just end it here.
80                if nesting == 0 {
81                    break;
82                } else {
83                    nesting -= 1;
84                    if nesting >= 1 && !skip_until_comma {
85                        // We are still inside of the generics and part of some bound.
86                        impl_generics.push(tt);
87                    }
88                }
89            }
90            TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => {
91                if nesting == 1 {
92                    impl_generics.push(tt.clone());
93                    impl_generics.push(tt);
94                    skip_until_comma = false;
95                }
96            }
97            _ if !skip_until_comma => {
98                match nesting {
99                    // If we haven't entered the generics yet, we still want to keep these tokens.
100                    0 => rest.push(tt),
101                    1 => {
102                        // Here depending on the token, it might be a generic variable name.
103                        match tt.clone() {
104                            TokenTree::Ident(i) if at_start && i.to_string() == "const" => {
105                                let Some(name) = toks.next() else {
106                                    // Parsing error.
107                                    break;
108                                };
109                                impl_generics.push(tt);
110                                impl_generics.push(name.clone());
111                                ty_generics.push(name.clone());
112                                decl_generics.push(name);
113                                at_start = false;
114                            }
115                            TokenTree::Ident(_) if at_start => {
116                                impl_generics.push(tt.clone());
117                                ty_generics.push(tt);
118                                at_start = false;
119                            }
120                            TokenTree::Punct(p) if p.as_char() == ',' => {
121                                impl_generics.push(tt.clone());
122                                ty_generics.push(tt);
123                                at_start = true;
124                            }
125                            // Lifetimes begin with `'`.
126                            TokenTree::Punct(p) if p.as_char() == '\'' && at_start => {
127                                impl_generics.push(tt.clone());
128                                ty_generics.push(tt);
129                            }
130                            // Generics can have default values, we skip these.
131                            TokenTree::Punct(p) if p.as_char() == '=' => {
132                                skip_until_comma = true;
133                            }
134                            _ => impl_generics.push(tt),
135                        }
136                    }
137                    _ => impl_generics.push(tt),
138                }
139            }
140            _ => {}
141        }
142    }
143    rest.extend(toks);
144    (
145        Generics {
146            impl_generics,
147            decl_generics,
148            ty_generics,
149        },
150        rest,
151    )
152}