Writing an API in Rust using procedural macros

Procedural macros in Rust are a very powerful code generation tool: they let you avoid writing piles of boilerplate code, or express entirely new concepts, as the developers of the async_trait crate did, for example.


Nevertheless, many people are justifiably wary of this tool, mainly because parsing the syntax tree and the macro attributes often turns into tedious manual labor, since the problem has to be solved at a very low level.


In this article I want to share some approaches to writing procedural macros that I consider successful, and to show that today procedural macros can be written relatively simply and conveniently.


Foreword


First of all, let's define the problem we are going to solve with macros: we will describe an abstract RPC API as a trait, which is then implemented by both the server side and the client side; procedural macros, in turn, will spare us a pile of boilerplate code. Although the API is somewhat abstract, the task itself is quite practical and, among other things, is ideal for demonstrating what procedural macros can do.


The API itself follows a very simple scheme with four types of requests:


  • GET requests with no parameters, for example: /ping.
  • GET requests with parameters, which are passed as a URL query, for example: /status?name=foo&count=15.
  • POST requests without parameters.
  • POST requests with parameters that are passed as JSON objects.

In all cases, the server will respond with a valid JSON object.
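
The four request types can be summed up in a few lines of plain Rust. This is only an illustrative sketch: the names `Method`, `ParamsLocation` and `params_location` are hypothetical and do not belong to the crates described in this article.

```rust
/// HTTP methods used by the API (hypothetical helper type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Method {
    Get,
    Post,
}

/// Where the parameters of a request travel (hypothetical helper type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParamsLocation {
    /// The request has no parameters at all.
    Absent,
    /// Parameters are encoded in the URL query, e.g. `?name=foo&count=15`.
    UrlQuery,
    /// Parameters are sent as a JSON object in the request body.
    JsonBody,
}

/// Maps each of the four request types onto the place
/// where its parameters are expected.
pub fn params_location(method: Method, has_params: bool) -> ParamsLocation {
    match (method, has_params) {
        (_, false) => ParamsLocation::Absent,
        (Method::Get, true) => ParamsLocation::UrlQuery,
        (Method::Post, true) => ParamsLocation::JsonBody,
    }
}
```

The same two-dimensional `match`, over the method and the presence of an argument, is what the code generator relies on later when it chooses a handler for an endpoint.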


As the server backend, we will use the warp crate.


Ideally, I want to get something like this:


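use std::net::SocketAddr;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};

use serde::{Deserialize, Serialize};

// The interface definition:

/// Request parameters. They can be obtained both from a URL query and from JSON.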
#[derive(Debug, FromUrlQuery, Deserialize, Serialize)]
struct Query {
    first: String,
    second: u64,
}

/// The interface itself. For it the macro generates a function that
/// serves this API by means of warp.
#[http_api(warp = "serve_ping_interface")]
trait PingInterface {
    #[http_api_endpoint(method = "get")]
    fn get(&self) -> Result<Query, Error>;
    #[http_api_endpoint(method = "get")]
    fn check(&self, query: Query) -> Result<bool, Error>;
    #[http_api_endpoint(method = "post")]
    fn set_value(&self, param: Query) -> Result<(), Error>;
    #[http_api_endpoint(method = "post")]
    fn increment(&self) -> Result<(), Error>;
}

// The service implementation:

/// The inner state of the service, which the handlers read and modify.
#[derive(Debug, Default)]
struct ServiceInner {
    first: String,
    second: u64,
}

/// The service itself: a cloneable handle to the shared state.
#[derive(Clone, Default)]
struct ServiceImpl(Arc<RwLock<ServiceInner>>);

impl ServiceImpl {
    fn new() -> Self {
        Self::default()
    }

    fn read(&self) -> RwLockReadGuard<'_, ServiceInner> {
        self.0.read().unwrap()
    }

    fn write(&self) -> RwLockWriteGuard<'_, ServiceInner> {
        self.0.write().unwrap()
    }
}

// The interface implementation:

impl PingInterface for ServiceImpl {
    fn get(&self) -> Result<Query, Error> {
        let inner = self.read();
        Ok(Query {
            first: inner.first.clone(),
            second: inner.second,
        })
    }

    fn check(&self, query: Query) -> Result<bool, Error> {
        let inner = self.read();
        Ok(inner.first == query.first && inner.second == query.second)
    }

    fn set_value(&self, param: Query) -> Result<(), Error> {
        let mut inner = self.write();
        inner.first = param.first;
        inner.second = param.second;
        Ok(())
    }

    fn increment(&self) -> Result<(), Error> {
        self.write().second += 1;
        Ok(())
    }
}

#[tokio::main]
async fn main() {
    let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
    // Start serving the API at the given address
    serve_ping_interface(ServiceImpl::new(), addr).await
}

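Procedural macros in Rust come in several kinds.
In this article we will deal with derive macros, familiar from serde, and with attribute macros.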


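The code is split into two crates: http_api, which contains the traits and helper code, and http_api_derive, which contains the procedural macros themselves.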


FromUrlQuery


Let's start with the simplest part: the derive macro FromUrlQuery.


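First we need a trait that describes how a type is obtained from a URL query. The derive macro will implement it for us. The trait looks like this: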


pub trait FromUrlQuery: Sized {
    fn from_query_str(query: &str) -> Result<Self, ParseQueryError>;
}
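
To make the goal concrete, here is a sketch of an implementation of this trait written by hand for the `Query` type from the example above. It is not what the macro generates (the article delegates parsing to serde), and the definition of `ParseQueryError` is an assumption, since the article does not show it. For brevity the sketch also skips percent-decoding.

```rust
/// Hypothetical error type; the article does not show its definition.
#[derive(Debug, PartialEq)]
pub struct ParseQueryError(pub String);

pub trait FromUrlQuery: Sized {
    fn from_query_str(query: &str) -> Result<Self, ParseQueryError>;
}

#[derive(Debug, PartialEq)]
pub struct Query {
    pub first: String,
    pub second: u64,
}

impl FromUrlQuery for Query {
    fn from_query_str(query: &str) -> Result<Self, ParseQueryError> {
        let mut first = None;
        let mut second = None;
        // Split "first=foo&second=15" into `key=value` pairs.
        for pair in query.split('&').filter(|pair| !pair.is_empty()) {
            let (key, value) = pair
                .split_once('=')
                .ok_or_else(|| ParseQueryError(format!("missing `=` in `{}`", pair)))?;
            match key {
                "first" => first = Some(value.to_owned()),
                "second" => {
                    let parsed = value
                        .parse::<u64>()
                        .map_err(|e| ParseQueryError(format!("invalid `second`: {}", e)))?;
                    second = Some(parsed);
                }
                other => return Err(ParseQueryError(format!("unknown field `{}`", other))),
            }
        }
        // Every field of `Query` is mandatory.
        Ok(Query {
            first: first.ok_or_else(|| ParseQueryError("missing field `first`".to_owned()))?,
            second: second.ok_or_else(|| ParseQueryError("missing field `second`".to_owned()))?,
        })
    }
}
```

Even for two fields the code is long and repetitive, which is exactly the kind of boilerplate a derive macro is meant to remove.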

Next we declare the derive macro itself:


/// Makes `#[derive(FromUrlQuery)]` available, together with helper
/// attributes of the form #[from_url_query(rename = "bar", skip, etc)]
#[proc_macro_derive(FromUrlQuery, attributes(from_url_query))]
pub fn from_url_query(input: TokenStream) -> TokenStream {
    from_url_query::impl_from_url_query(input)
}

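The usual tools for writing procedural macros are the syn and quote crates. The first parses Rust code into a syntax tree that is convenient to work with.
The second provides the quote! macro, which turns a template of Rust code back into a token stream.
For parsing attributes there is, in addition, the darling crate.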


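To see why it is worth using, let's first look at how attributes are parsed **without darling**. Here is an example:

The AST is traversed by hand, and every attribute has to be matched manually.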


fn get_field_names(input: &DeriveInput) -> Option<Vec<(Ident, Action)>> {
    let data = match &input.data {
        Data::Struct(x) => Some(x),
        Data::Enum(..) => None,
        _ => panic!("Protobuf convert can be derived for structs and enums only."),
    };
    data.map(|data| {
        data.fields
            .iter()
            .map(|f| {
                let mut action = Action::Convert;
                for attr in &f.attrs {
                    match attr.parse_meta() {
                        Ok(syn::Meta::List(ref meta)) if meta.ident == "protobuf_convert" => {
                            for nested in &meta.nested {
                                match nested {
                                    syn::NestedMeta::Meta(syn::Meta::Word(ident))
                                        if ident == "skip" =>
                                    {
                                        action = Action::Skip;
                                    }
                                    _ => {
                                        panic!("Unknown attribute");
                                    }
                                }
                            }
                        }
                        _ => {
                            // Other attributes are ignored
                        }
                    }
                }
                (f.ident.clone().unwrap(), action)
            })
            .collect()
    })
}

fn get_field_names_enum(input: &DeriveInput) -> Option<Vec<Ident>> {
    let data = match &input.data {
        Data::Struct(..) => None,
        Data::Enum(x) => Some(x),
        _ => panic!("Protobuf convert can be derived for structs and enums only."),
    };
    data.map(|data| data.variants.iter().map(|f| f.ident.clone()).collect())
}

fn implement_protobuf_convert_from_pb(field_names: &[(Ident, Action)]) -> impl quote::ToTokens {
    let mut to_convert = vec![];
    let mut to_skip = vec![];
    for (x, a) in field_names {
        match a {
            Action::Convert => to_convert.push(x),
            Action::Skip => to_skip.push(x),
        }
    }

    let getters = to_convert
        .iter()
        .map(|i| Ident::new(&format!("get_{}", i), Span::call_site()));
    let our_struct_names = to_convert.clone();
    let our_struct_names_skip = to_skip;

    quote! {
        fn from_pb(pb: Self::ProtoStruct) -> std::result::Result<Self, _FailureError> {
          Ok(Self {
           #( #our_struct_names: ProtobufConvert::from_pb(pb.#getters().to_owned())?, )*
           #( #our_struct_names_skip: Default::default(), )*
          })
        }
    }
}

With darling, things look much simpler.


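Let's return to FromUrlQuery and decide what the macro has to support. For a start, it should handle structures with optional fields, like this one: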


#[derive(FromUrlQuery)]
struct OptionalQuery {
    first: String,
    opt_value: Option<u64>,
}

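With darling, the input of the macro is described declaratively, as ordinary structures.


For a single field we derive FromField and list the parts of the field we are interested in: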


#[derive(Clone, Debug, FromField)]
struct QueryField {
    ident: Option<syn::Ident>,
    ty: syn::Type,
}

If something else is needed, for example the visibility of the field, it is enough to add one more member:


#[derive(Clone, Debug, FromField)]
struct QueryField {
    ident: Option<syn::Ident>,
    ty: syn::Type,
    vis: syn::Visibility,
}

For the input as a whole we derive FromDeriveInput:


#[derive(Debug, FromDeriveInput)]
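// The macro supports only structures with named fields;
// for any other input darling reports an error
// on its own, without extra code.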
#[darling(supports(struct_named))]
struct FromUrlQuery {
    ident: syn::Ident,
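    // The body of the type the macro is applied to.
    // darling::ast::Data has two type parameters: the first describes
    // enum variants, the second describes struct fields.
    // Enums are not supported here, so the first parameter
    // is simply ().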
    data: darling::ast::Data<(), QueryField>,
}

Now the input can be parsed, and parsing errors are reported to the user as compiler errors:


let input: DeriveInput = syn::parse(input).unwrap();
let from_url_query = match FromUrlQuery::from_derive_input(&input) {
    Ok(parsed) => parsed,
    Err(e) => return e.write_errors().into(),
};

That is all the parsing we need.


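We will not parse the URL query by hand but hand the work over to serde: the macro generates a wrapper structure that derives Deserialize, and the query string is then parsed with serde_urlencoded. So that users of the macro need not add serde to their own dependencies, the http_api crate re-exports everything required: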


#[doc(hidden)]
pub mod export {
    pub use serde;
    pub use serde_derive;
    pub use serde_urlencoded;
}

Now we can write the code generation for FromUrlQuery:


impl FromUrlQuery {
    // The name of the wrapper structure: the name of the original
    // type with the "Serde" suffix.
    fn serde_wrapper_ident(&self) -> syn::Ident {
        let ident_str = format!("{}Serde", self.ident);
        syn::Ident::new(&ident_str, proc_macro2::Span::call_site())
    }

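    /// Generates the wrapper structure that serde deserializes into.
    fn impl_serde_wrapper(&self) -> impl ToTokens {
        // The macro accepts only structures with named fields,
        // so this `unwrap` cannot fail.
        let fields = self.data.clone().take_struct().unwrap();
        // Copy every field into the wrapper and generate the conversion
        // from the wrapper (QuerySerde) back to the original type (Query).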
        let wrapped_fields = fields.iter().map(|field| {
            let ident = &field.ident;
            let ty = &field.ty;
            quote! { #ident: #ty }
        });
        let from_fields = fields.iter().map(|field| {
            let ident = &field.ident;
            quote! { #ident: v.#ident }
        });

        let wrapped_ident = self.serde_wrapper_ident();
        let ident = &self.ident;

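        // Finally, the code itself. The syntax of `quote!` resembles
        // that of `macro_rules!`, except that interpolation
        // uses "#" instead of "$".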
        quote! {
            // Take the serde derive macro from the re-export.
            use http_api::export::serde_derive::Deserialize;

            #[derive(Deserialize)]
            // Tell serde where its crate lives, since the user
            // may not depend on serde directly.
            #[serde(crate = "http_api::export::serde")]
            struct #wrapped_ident {
                #( #wrapped_fields, )*
            }

            impl From<#wrapped_ident> for #ident {
                fn from(v: #wrapped_ident) -> Self {
                    Self {
                        #( #from_fields, )*
                    }
                }
            }
        }
    }
}
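
To see what this generator amounts to, here is a hand-written sketch of the result for the `Query` type, with serde left out so that the snippet stays self-contained. The helper `serde_wrapper_name` is hypothetical and only mirrors `serde_wrapper_ident` on plain strings; in the real expansion the wrapper additionally carries `#[derive(Deserialize)]`.

```rust
/// The user's type, as in the example at the beginning of the article.
#[derive(Debug, PartialEq)]
pub struct Query {
    pub first: String,
    pub second: u64,
}

/// Hypothetical string-based counterpart of `serde_wrapper_ident`:
/// the wrapper is named after the original type plus the "Serde" suffix.
pub fn serde_wrapper_name(ident: &str) -> String {
    format!("{}Serde", ident)
}

/// The wrapper has exactly the same fields as the original type.
pub struct QuerySerde {
    pub first: String,
    pub second: u64,
}

/// Field-by-field conversion from the wrapper into the original type.
impl From<QuerySerde> for Query {
    fn from(v: QuerySerde) -> Self {
        Self {
            first: v.first,
            second: v.second,
        }
    }
}
```

The point of the wrapper is that the user's own type does not have to implement `Deserialize`; only the generated twin does.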

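The implementation of the FromUrlQuery trait itself is generated in the same manner and is not shown here: it only has to deserialize the wrapper with serde_urlencoded and convert it into the original type.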


http_api


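For attribute macros darling has no counterpart of FromDeriveInput, so part of the AST has to be parsed by hand. Let's go through it step by step.


The macro declaration: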


#[proc_macro_attribute]
pub fn http_api(attr: TokenStream, item: TokenStream) -> TokenStream {
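    // An attribute macro receives two token streams: the arguments of the
    // attribute and the item the attribute is attached to.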
    http_api::impl_http_api(attr, item)
}

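There is one catch: unlike a derive macro, an attribute macro cannot declare helper attributes (here, http_api_endpoint). If such an attribute is left in the resulting TokenStream, the compiler fails with "cannot find attribute http_api_endpoint in this scope". One way out is to strip these attributes from the trait before returning it; another is to make the attribute known to the compiler.


We take the second route and declare http_api_endpoint as a separate macro that does nothing with its input.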


#[proc_macro_attribute]
#[doc(hidden)]
pub fn http_api_endpoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
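    // This macro does nothing: `http_api_endpoint` attributes are read
    // and interpreted by the `http_api` macro.

    // It exists only so that `http_api_endpoint` is a known attribute;
    // without it, rustc would reject the trait
    // with an error.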
    item
}

With the declarations in place, we can move on to parsing.


Trait methods may take one of two forms:


// A method without parameters:
#[http_api_endpoint(method = "#method_type")]
fn #method_name(&self) -> Result<$ResponseType, Error>;
// A method with parameters:
#[http_api_endpoint(method = "#method_type")]
fn #method_name(&self, query: $QueryType) -> Result<$ResponseType, Error>;

The HTTP method is described by an enum, for which FromMeta is implemented by hand:


#[derive(Debug)]
enum SupportedHttpMethod {
    Get,
    Post,
}

impl FromMeta for SupportedHttpMethod {
    fn from_string(value: &str) -> Result<Self, darling::Error> {
        match value {
            "get" => Ok(SupportedHttpMethod::Get),
            "post" => Ok(SupportedHttpMethod::Post),
            other => Err(darling::Error::unknown_value(other)),
        }
    }
}
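
The behavior of this conversion is easy to try out without darling. The sketch below is a standard-library analogue, using `FromStr` in place of `FromMeta` and a plain `String` in place of `darling::Error`; both substitutions are assumptions made only to keep the example self-contained.

```rust
use std::str::FromStr;

#[derive(Debug, PartialEq)]
pub enum SupportedHttpMethod {
    Get,
    Post,
}

impl FromStr for SupportedHttpMethod {
    /// A plain string stands in for `darling::Error` here.
    type Err = String;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        match value {
            "get" => Ok(SupportedHttpMethod::Get),
            "post" => Ok(SupportedHttpMethod::Post),
            // Any other literal is reported back to the user.
            other => Err(format!("unknown HTTP method `{}`", other)),
        }
    }
}
```

Note that the match is case-sensitive, so `method = "GET"` would be rejected just like an unsupported method such as `"put"`.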

Next come the attributes of an endpoint:


#[derive(Debug, FromMeta)]
struct EndpointAttrs {
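    // The HTTP method of the endpoint; the attribute is mandatory,
    // so its absence is an error.
    method: SupportedHttpMethod,
    // An optional new name for the endpoint; if the attribute is absent
    // the field is None. With it, the usage looks like this:
    // #[http_api_endpoint(method = "get", rename = "foo")]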
    #[darling(default)]
    rename: Option<String>,
}

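A method is described by syn::Signature, and here darling has nothing ready-made to offer: the only help we get is FromMeta for the attributes.
So we first find the http_api_endpoint attribute among the attributes of the method and convert it into syn::NestedMeta, the form in which attribute arguments such as (foo = "bar", boo(first, second)) are represented.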


fn find_meta_attrs(name: &str, args: &[syn::Attribute]) -> Option<syn::NestedMeta> {
    args.as_ref()
        .iter()
        .filter_map(|a| a.parse_meta().ok())
        .find(|m| m.path().is_ident(name))
        .map(syn::NestedMeta::from)
}

Now everything is ready to parse a method. The most tedious part
is checking the signature:


/// Builds the error reported for a method with an unsupported signature.
fn invalid_method(span: &impl syn::spanned::Spanned) -> darling::Error {
    darling::Error::custom(
        "API method should have one of `fn foo(&self) -> Result<Bar, Error>` or \
         `fn foo(&self, arg: Foo) -> Result<Bar, Error>` form",
    )
    .with_span(span)
}

impl ParsedEndpoint {
    fn parse(sig: &syn::Signature, attrs: &[syn::Attribute]) -> Result<Self, darling::Error> {
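        // Iterate over the method arguments.
        let mut args = sig.inputs.iter();

        // The first argument must be exactly `&self`; receivers such as
        // `&mut self` or `self: Arc<Self>` are not supported
        // by this macro.
        if let Some(arg) = args.next() {
            match arg {
                // `self` is represented in `syn` by the `Receiver` type.
                syn::FnArg::Receiver(syn::Receiver {
                    // `reference` is set, so the receiver is taken by reference:
                    // `&self`.
                    reference: Some(_),
                    // `mutability` is absent, so the reference
                    // is not `mut`.
                    mutability: None,
                    // The remaining fields are of no interest.
                    ..
                }) => {
                    // Everything is fine, nothing to do.
                }
                _ => {
                    // Anything else is an unsupported receiver,
                    // so report an error.
                    return Err(invalid_method(&arg));
                }
            }
        } else {
            return Err(invalid_method(&sig));
        }

        // Extract the type of the optional second argument.
        let arg = args
            .next()
            .map(|arg| match arg {
                // `FnArg` is either `Typed` or `Receiver`, and a `Receiver`
                // can only be the first argument, so it cannot
                // occur here.
                syn::FnArg::Typed(arg) => Ok(arg.ty.clone()),
                // A second `self` is impossible.
                _ => unreachable!("Only first argument can be receiver."),
            })
            // `transpose` turns
            // `Option<Result<...>>` into `Result<Option<...>>`, so that the error
            // can be propagated with `?`.
            .transpose()?;

        // The method must declare a return type explicitly:
        // `ReturnType::Default` (no `-> ...`) is rejected.
        let ret = match &sig.output {
            syn::ReturnType::Type(_, ty) => Ok(ty.clone()),
            _ => Err(invalid_method(&sig)),
        }?;

        // Finally, find the `http_api_endpoint` attribute of the method
        // and parse its arguments
        // with `FromMeta::from_nested_meta`.
        let attrs = find_meta_attrs("http_api_endpoint", attrs)
            .map(|meta| EndpointAttrs::from_nested_meta(&meta))
            .unwrap_or_else(|| {
                Err(darling::Error::custom(
                    "API method should have an `http_api_endpoint` attribute",
                ))
            })?;

        // Everything is parsed; assemble the endpoint description.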
        Ok(Self {
            ident: sig.ident.clone(),
            arg,
            ret,
            attrs,
        })
    }
}
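
The `transpose` trick used for the optional second argument is worth a closer look, and it can be tried on plain standard-library types. In this sketch the function `parse_optional` is hypothetical; it parses an optional numeric string where the macro extracts an optional argument type.

```rust
use std::num::ParseIntError;

/// Parses an optional numeric argument: a missing value stays `None`,
/// while a present but malformed value becomes an error.
pub fn parse_optional(arg: Option<&str>) -> Result<Option<u64>, ParseIntError> {
    arg
        // Option<&str> -> Option<Result<u64, ParseIntError>>
        .map(|raw| raw.parse::<u64>())
        // Option<Result<u64, _>> -> Result<Option<u64>, _>
        .transpose()
}
```

After `transpose`, the usual `?` operator can be applied, and the "no argument" case is not confused with the "bad argument" case.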


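The endpoints can now be parsed. What remains is to parse the trait as a whole
and collect everything the code generator will need: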



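/// The parsed API definition: everything that is needed
/// to generate the code.
#[derive(Debug)]
struct ParsedApiDefinition {
    /// The trait the macro is applied to; it is stored
    /// in its original form.
    item_trait: syn::ItemTrait,
    /// The parsed endpoints.
    endpoints: Vec<ParsedEndpoint>,
    /// The arguments of the `http_api` attribute.
    attrs: ApiAttrs,
}

#[derive(Debug, FromMeta)]
struct ApiAttrs {
    /// The name of the generated function that
    /// serves the API with warp.
    warp: syn::Ident,
}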
}

impl ParsedApiDefinition {
    fn parse(
        item_trait: syn::ItemTrait,
        attrs: &[syn::NestedMeta],
    ) -> Result<Self, darling::Error> {
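        // Go through the items of the trait, pick out
        // the methods, parse each of them as an endpoint,
        // and collect the results.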
        let endpoints = item_trait
            .items
            .iter()
            .filter_map(|item| {
                if let syn::TraitItem::Method(method) = item {
                    Some(method)
                } else {
                    None
                }
            })
            .map(|method| ParsedEndpoint::parse(&method.sig, method.attrs.as_ref()))
            .collect::<Result<Vec<_>, darling::Error>>()?;

        // Parse the arguments of the `http_api` attribute.
        let attrs = ApiAttrs::from_list(attrs)?;

        // Assemble the parsed definition of the HTTP API.
        Ok(Self {
            item_trait,
            endpoints,
            attrs,
        })
    }
}
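
Collecting an iterator of `Result`s into a single `Result<Vec<_>, _>`, as done for the endpoints above, stops at the first error. A small standard-library sketch shows the behavior; `parse_endpoint` is a hypothetical stand-in for `ParsedEndpoint::parse` that merely validates a name.

```rust
/// Hypothetical stand-in for `ParsedEndpoint::parse`:
/// accepts only non-empty snake_case names.
pub fn parse_endpoint(name: &str) -> Result<String, String> {
    let valid = !name.is_empty()
        && name.chars().all(|c| c.is_ascii_lowercase() || c == '_');
    if valid {
        Ok(name.to_owned())
    } else {
        Err(format!("invalid endpoint name `{}`", name))
    }
}

/// Parses all endpoints; the first failure aborts the whole operation.
pub fn parse_all(names: &[&str]) -> Result<Vec<String>, String> {
    names.iter().map(|name| parse_endpoint(name)).collect()
}
```

One consequence of this design is that the user sees only the first problem in the trait; later invalid methods surface only after the first one is fixed.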


With parsing done, we can move on to generating code for warp.


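First, a few words about how warp is organized. The central concept in warp is the Filter trait. Filters are combined with one another through methods such as and, map, and and_then, and a request handler is simply such a chain of filters.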


For example, a handler for a GET request without parameters that replies with JSON looks roughly like this:


/// Creates a filter for a GET request without parameters, ready
/// to be handed over to warp.
pub fn simple_get<F, R, E>(name: &'static str, handler: F) -> JsonReply
where
    F: Fn() -> Result<R, E> + Clone + Send + Sync + 'static,
    R: ser::Serialize,
    E: Reject,
{
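    // Start with a filter that matches only GET requests;
    // everything else is rejected.
    warp::get()
        // Then require that the request path
        // matches the segment {name}.
        .and(warp::path(name))
        // Finally, `and_then` calls the handler
        // and converts its result into a JSON reply.
        .and_then(move || {
            let handler = handler.clone();
            // The handler itself is synchronous, but `and_then`
            // expects a future, so the call is wrapped in an async block.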
            async move {
                match handler() {
                    Ok(value) => Ok(warp::reply::json(&value)),
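                    // warp represents failed requests as rejections,
                    // so the error is wrapped into a custom rejection;
                    // turning rejections into proper HTTP responses
                    // is a separate topic.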
                    Err(e) => Err(warp::reject::custom(e)),
                }
            }
        })
        .boxed()
}

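A GET request with parameters is handled similarly; the only difference is that the query has to be parsed first: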


pub fn query_get<F, Q, R, E>(name: &'static str, handler: F) -> JsonReply
where
    F: Fn(Q) -> Result<R, E> + Clone + Send + Sync + 'static,
    Q: FromUrlQuery,
    R: ser::Serialize,
    E: Reject,
{
    warp::get()
        .and(warp::path(name))
        // Extract the raw URL query string; it is parsed below,
        // inside the handler.
        .and(warp::filters::query::raw())
        .and_then(move |raw_query: String| {
            let handler = handler.clone();
            async move {
                // Parse the query string with the FromUrlQuery trait
                // derived earlier, and reject the request if
                // parsing fails.
                let query = Q::from_query_str(&raw_query)
                    .map_err(|_| warp::reject::custom(IncorrectQuery))?;

                match handler(query) {
                    Ok(value) => Ok(warp::reply::json(&value)),
                    Err(e) => Err(warp::reject::custom(e)),
                }
            }
        })
        .boxed()
}

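The handlers for POST requests are written in the same way.


Filters can be combined not only with and, which requires both filters to match,
but also with or, which tries the second filter when the first one rejects the request,
so a whole API can be assembled from separate handlers.
For example: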


use std::net::SocketAddr;
use warp::Filter;

// This filter matches both `/:u32` and
// `/:socketaddr`
warp::path::param::<u32>()
    .or(warp::path::param::<SocketAddr>());
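
The idea behind `or` can be modelled in a few lines without warp. The following is a toy model, not warp's implementation: the `Filter` alias, the `Either` enum and the functions `or` and `param` are all defined here for illustration, and a "filter" is reduced to a closure that either extracts a value from a path segment or declines.

```rust
use std::str::FromStr;

/// A toy "filter": tries to extract a value from a path segment.
pub type Filter<T> = Box<dyn Fn(&str) -> Option<T>>;

/// Tells which of the two combined filters has matched.
#[derive(Debug, PartialEq)]
pub enum Either<A, B> {
    Left(A),
    Right(B),
}

/// Tries `first`; only if it declines, tries `second`.
pub fn or<A: 'static, B: 'static>(first: Filter<A>, second: Filter<B>) -> Filter<Either<A, B>> {
    Box::new(move |segment: &str| match first(segment) {
        Some(a) => Some(Either::Left(a)),
        None => second(segment).map(Either::Right),
    })
}

/// A filter that matches any segment parseable as `T`.
pub fn param<T: FromStr + 'static>() -> Filter<T> {
    Box::new(|segment: &str| segment.parse().ok())
}
```

With these definitions, `or(param::<u32>(), param::<std::net::SocketAddr>())` accepts `42` through the first branch and `127.0.0.1:8080` through the second, and declines anything else.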

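All that is left is to generate serve_ping_interface. For every endpoint we create a warp filter whose closure captures a clone of service and calls the corresponding method of the trait.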


impl ParsedEndpoint {
    fn impl_endpoint_handler(&self) -> impl ToTokens {
        // The path at which the endpoint is served.
        let path = self.endpoint_path();
        let ident = &self.ident;

        // Choose the helper by the HTTP method and the presence of an argument;
        // each helper returns a warp filter.
        match (&self.attrs.method, &self.arg) {
            (SupportedHttpMethod::Get, None) => {
                quote! {
                    let #ident = http_api::warp_backend::simple_get(#path, {
                        let out = service.clone();
                        move || out.#ident()
                    });
                }
            }

            (SupportedHttpMethod::Get, Some(_arg)) => {
                quote! {
                    let #ident = http_api::warp_backend::query_get(#path, {
                        let out = service.clone();
                        move |query| out.#ident(query)
                    });
                }
            }

            (SupportedHttpMethod::Post, None) => {
                quote! {
                    let #ident = http_api::warp_backend::simple_post(#path, {
                        let out = service.clone();
                        move || out.#ident()
                    });
                }
            }

            (SupportedHttpMethod::Post, Some(_arg)) => {
                quote! {
                    let #ident = http_api::warp_backend::params_post(#path, {
                        let out = service.clone();
                        move |params| out.#ident(params)
                    });
                }
            }
        }
    }
}

What remains is to combine the filters with or and start the server.


impl ToTokens for ParsedApiDefinition {
    fn to_tokens(&self, out: &mut proc_macro2::TokenStream) {
        let fn_name = &self.attrs.warp;
        let interface = &self.item_trait.ident;

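        // For every endpoint, generate the code of its filter
        // and remember the name of the variable
        // the filter is bound to.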
        let (filters, idents): (Vec<_>, Vec<_>) = self
            .endpoints
            .iter()
            .map(|endpoint| {
                let ident = &endpoint.ident;
                let handler = endpoint.impl_endpoint_handler();

                (handler, ident)
            })
            .unzip();

        let mut tail = idents.into_iter();
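        // Combine the filters into a chain of the form
        // `a.or(b).or(c).or(d)`. Note that the `unwrap` below panics
        // if the trait has no endpoints at all.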
        let head = tail.next().unwrap();
        let serve_impl = quote! {
            #head #( .or(#tail) )*
        };

        // The result: the function that serves the API.
        let tokens = quote! {
            fn #fn_name<T>(
                service: T,
                addr: impl Into<std::net::SocketAddr>,
            ) -> impl std::future::Future<Output = ()>
            where
                T: #interface + Clone + Send + Sync + 'static,
            {
                use warp::Filter;

                // Create a filter for every endpoint.
                #( #filters )*

                // Combine the filters into a single API and run it
                // on a warp server.
                warp::serve(#serve_impl).run(addr.into())
            }

        };
        out.extend(tokens)
    }
}
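
The head-and-tail construction of the `a.or(b).or(c)` chain can be reproduced on plain strings. In this sketch the function `or_chain` is hypothetical; it builds the text of the chain that `quote!` produces as tokens, and returns `None` for an empty list, which is the case where the `unwrap` in the macro would panic.

```rust
/// Builds the text of an `a.or(b).or(c)` chain from handler names.
/// Returns `None` when there are no handlers at all.
pub fn or_chain(idents: &[&str]) -> Option<String> {
    let mut tail = idents.iter();
    // The first handler starts the chain...
    let head = tail.next()?;
    let mut chain = head.to_string();
    // ...and every following one is attached with `.or(...)`.
    for ident in tail {
        chain.push_str(".or(");
        chain.push_str(ident);
        chain.push(')');
    }
    Some(chain)
}
```

For the `PingInterface` trait from the beginning of the article this yields `get.or(check).or(set_value).or(increment)`. In the macro itself it might be friendlier to turn the empty case into a `darling::Error` instead of a panic.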


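As you can see, writing derive and attribute macros
is not that hard when the right tools are at hand.
With a small amount of code we got a declarative description
of an RPC interface that looks like an ordinary Rust trait.
The client side can be generated in a similar way, for example on top of an HTTP
client such as reqwest.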


With macros, nothing stops you from going further and generating an OpenAPI or Swagger
specification for the interface types. It seems to me, though, that in this case it is better to go the other way and write a generator of Rust code from the specification, which leaves more room for maneuver.
If you write this generator as a build dependency, you can use the
syn and quote libraries, so writing the generator will be quite comfortable and simple. However, these are already far-reaching thoughts :)


The complete working code, fragments of which appear in this article, can be found at this
link.


Thank you for your attention!

Source: https://habr.com/ru/post/undefined/

