Skip to content

Commit

Permalink
openapi: Add support for multiple authentication methods #627
Browse files Browse the repository at this point in the history
  • Loading branch information
sunli829 committed Aug 12, 2023
1 parent dccbd34 commit 8a6f5db
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 113 deletions.
8 changes: 5 additions & 3 deletions poem-openapi-derive/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,11 @@ fn generate_operation(
let scopes = &operation_param.scopes;
security.push(quote! {
if <#arg_ty as #crate_name::ApiExtractor>::TYPES.contains(&#crate_name::ApiExtractorType::SecurityScheme) {
security.push(<::std::collections::HashMap<&'static str, ::std::vec::Vec<&'static str>> as ::std::convert::From<_>>::from([
(<#arg_ty as #crate_name::ApiExtractor>::security_scheme().unwrap(), ::std::vec![#(#crate_name::OAuthScopes::name(&#scopes)),*])
]));
for security_name in <#arg_ty as #crate_name::ApiExtractor>::security_schemes() {
security.push(<::std::collections::HashMap<&'static str, ::std::vec::Vec<&'static str>> as ::std::convert::From<_>>::from([
(security_name, ::std::vec![#(#crate_name::OAuthScopes::name(&#scopes)),*])
]));
}
}
});
}
Expand Down
235 changes: 156 additions & 79 deletions poem-openapi-derive/src/security_scheme.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use darling::{
ast::{Data, Style},
util::{Ignored, SpannedValue},
FromDeriveInput, FromMeta,
ast::{Data, Fields, Style},
util::SpannedValue,
FromDeriveInput, FromMeta, FromVariant,
};
use http::header::HeaderName;
use proc_macro2::{Ident, Span, TokenStream};
Expand Down Expand Up @@ -185,18 +185,24 @@ pub(crate) enum ApiKeyInType {
Cookie,
}

#[derive(FromVariant)]
struct SecuritySchemeItem {
ident: Ident,
fields: Fields<syn::Type>,
}

#[derive(FromDeriveInput)]
#[darling(attributes(oai), forward_attrs(doc))]
struct SecuritySchemeArgs {
ident: Ident,
data: Data<Ignored, syn::Type>,
data: Data<SecuritySchemeItem, syn::Type>,
attrs: Vec<Attribute>,

#[darling(default)]
internal: bool,
#[darling(default)]
rename: Option<String>,
ty: AuthType,
ty: Option<AuthType>,
#[darling(default)]
key_in: Option<ApiKeyInType>,
#[darling(default)]
Expand All @@ -212,8 +218,15 @@ struct SecuritySchemeArgs {
}

impl SecuritySchemeArgs {
fn validate(&self) -> GeneratorResult<()> {
fn auth_type(&self) -> GeneratorResult<AuthType> {
match self.ty {
Some(ty) => Ok(ty),
None => Err(Error::new_spanned(&self.ident, "Missing an auth type.").into()),
}
}

fn validate(&self) -> GeneratorResult<()> {
match self.auth_type()? {
AuthType::ApiKey => self.validate_api_key(),
AuthType::OAuth2 => self.validate_oauth2(),
AuthType::OpenIdConnect => self.validate_openid_connect(),
Expand Down Expand Up @@ -308,7 +321,7 @@ impl SecuritySchemeArgs {
None => quote!(::std::option::Option::None),
};

let ts = match self.ty {
let ts = match self.auth_type()? {
AuthType::ApiKey => {
quote! {
registry.create_security_scheme(#name, #crate_name::registry::MetaSecurityScheme {
Expand Down Expand Up @@ -384,29 +397,31 @@ impl SecuritySchemeArgs {
Ok(ts)
}

fn generate_from_request(&self, crate_name: &TokenStream) -> TokenStream {
match self.ty {
fn generate_from_request(&self, crate_name: &TokenStream) -> GeneratorResult<TokenStream> {
match self.auth_type()? {
AuthType::ApiKey => {
let key_name = self.key_name.as_ref().unwrap().as_str();
let param_in = match self.key_in.as_ref().unwrap() {
ApiKeyInType::Query => quote!(#crate_name::registry::MetaParamIn::Query),
ApiKeyInType::Header => quote!(#crate_name::registry::MetaParamIn::Header),
ApiKeyInType::Cookie => quote!(#crate_name::registry::MetaParamIn::Cookie),
};
quote!(<#crate_name::auth::ApiKey as #crate_name::auth::ApiKeyAuthorization>::from_request(req, query, #key_name, #param_in))
}
AuthType::Basic => {
quote!(<#crate_name::auth::Basic as #crate_name::auth::BasicAuthorization>::from_request(req))
}
AuthType::Bearer => {
quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req))
}
AuthType::OAuth2 => {
quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req))
}
AuthType::OpenIdConnect => {
quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req))
Ok(
quote!(<#crate_name::auth::ApiKey as #crate_name::auth::ApiKeyAuthorization>::from_request(req, query, #key_name, #param_in)),
)
}
AuthType::Basic => Ok(
quote!(<#crate_name::auth::Basic as #crate_name::auth::BasicAuthorization>::from_request(req)),
),
AuthType::Bearer => Ok(
quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)),
),
AuthType::OAuth2 => Ok(
quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)),
),
AuthType::OpenIdConnect => Ok(
quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)),
),
}
}
}
Expand All @@ -415,69 +430,131 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult<TokenStream> {
let args: SecuritySchemeArgs = SecuritySchemeArgs::from_derive_input(&args)?;
let crate_name = get_crate_name(args.internal);
let ident = &args.ident;
let oai_typename = args.rename.clone().unwrap_or_else(|| ident.to_string());
args.validate()?;

let fields = match &args.data {
Data::Struct(e) => e,
_ => {
return Err(Error::new_spanned(
ident,
"SecurityScheme can only be applied to an struct.",
)
.into())
}
};

if fields.style == Style::Tuple && fields.fields.len() != 1 {
return Err(Error::new_spanned(
ident,
"Only one unnamed field is allowed in the SecurityScheme struct.",
)
.into());
}
match &args.data {
Data::Struct(fields) => {
let oai_typename = args.rename.clone().unwrap_or_else(|| ident.to_string());

let register_security_scheme =
args.generate_register_security_scheme(&crate_name, &oai_typename)?;
let from_request = args.generate_from_request(&crate_name);
let path = args.checker.as_ref();

let output = match path {
Some(_) => quote! {
let output = #crate_name::__private::CheckerReturn::from(#path(&req, #from_request?).await).into_result()?;
},
None => quote! {
let output = #from_request?;
},
};

let expanded = quote! {
#[#crate_name::__private::poem::async_trait]
impl<'a> #crate_name::ApiExtractor<'a> for #ident {
const TYPES: &'static [#crate_name::ApiExtractorType] = &[#crate_name::ApiExtractorType::SecurityScheme];

type ParamType = ();
type ParamRawType = ();

fn register(registry: &mut #crate_name::registry::Registry) {
#register_security_scheme
if fields.style != Style::Tuple || fields.fields.len() != 1 {
return Err(Error::new_spanned(
ident,
"Only one unnamed field is allowed in the SecurityScheme struct.",
)
.into());
}

fn security_scheme() -> ::std::option::Option<&'static str> {
::std::option::Option::Some(#oai_typename)
args.validate()?;

let register_security_scheme =
args.generate_register_security_scheme(&crate_name, &oai_typename)?;
let from_request = args.generate_from_request(&crate_name)?;
let path = args.checker.as_ref();

let output = match path {
Some(_) => quote! {
let output = #crate_name::__private::CheckerReturn::from(#path(&req, #from_request?).await).into_result()?;
},
None => quote! {
let output = #from_request?;
},
};

let expanded = quote! {
#[#crate_name::__private::poem::async_trait]
impl<'a> #crate_name::ApiExtractor<'a> for #ident {
const TYPES: &'static [#crate_name::ApiExtractorType] = &[#crate_name::ApiExtractorType::SecurityScheme];

type ParamType = ();
type ParamRawType = ();

fn register(registry: &mut #crate_name::registry::Registry) {
#register_security_scheme
}

fn security_schemes() -> ::std::vec::Vec<&'static str> {
::std::vec![#oai_typename]
}

async fn from_request(
req: &'a #crate_name::__private::poem::Request,
body: &mut #crate_name::__private::poem::RequestBody,
_param_opts: #crate_name::ExtractParamOptions<Self::ParamType>,
) -> #crate_name::__private::poem::Result<Self> {
let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap();
#output
::std::result::Result::Ok(Self(output))
}
}
};

Ok(expanded)
}
Data::Enum(items) => {
let mut registers = Vec::new();
let mut security_schemes = Vec::new();
let mut from_requests = Vec::new();

if items.is_empty() {
return Err(Error::new_spanned(ident, "At least one member is required.").into());
}

async fn from_request(
req: &'a #crate_name::__private::poem::Request,
body: &mut #crate_name::__private::poem::RequestBody,
_param_opts: #crate_name::ExtractParamOptions<Self::ParamType>,
) -> #crate_name::__private::poem::Result<Self> {
let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap();
#output
::std::result::Result::Ok(Self(output))
for item in items {
if item.fields.style != Style::Tuple || item.fields.fields.len() != 1 {
return Err(Error::new_spanned(
ident,
"Only one unnamed field is allowed in the SecurityScheme enum.",
)
.into());
}

let item_ident = &item.ident;
let item_type = &item.fields.fields[0];

registers.push(
quote! { <#item_type as #crate_name::ApiExtractor>::register(registry); },
);
security_schemes.push(quote! {
security_schemes.extend(<#item_type as #crate_name::ApiExtractor>::security_schemes());
});
from_requests.push(quote! {
match <#item_type as #crate_name::ApiExtractor>::from_request(req, body, param_opts.clone()).await {
::std::result::Result::Ok(item) => return Ok(#ident::#item_ident(item)),
::std::result::Result::Err(err) => last_err = ::std::option::Option::Some(err),
}
})
}
}
};

Ok(expanded)
let expanded = quote! {
#[#crate_name::__private::poem::async_trait]
impl<'a> #crate_name::ApiExtractor<'a> for #ident {
const TYPES: &'static [#crate_name::ApiExtractorType] = &[#crate_name::ApiExtractorType::SecurityScheme];

type ParamType = ();
type ParamRawType = ();

fn register(registry: &mut #crate_name::registry::Registry) {
#(#registers)*
}

fn security_schemes() -> ::std::vec::Vec<&'static str> {
let mut security_schemes = ::std::vec![];
#(#security_schemes)*
security_schemes
}

async fn from_request(
req: &'a #crate_name::__private::poem::Request,
body: &mut #crate_name::__private::poem::RequestBody,
param_opts: #crate_name::ExtractParamOptions<Self::ParamType>,
) -> #crate_name::__private::poem::Result<Self> {
let mut last_err = ::std::option::Option::None;
#(#from_requests)*
::std::result::Result::Err(last_err.unwrap())
}
}
};

Ok(expanded)
}
}
}
8 changes: 8 additions & 0 deletions poem-openapi/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

# [3.0.2] 2023-08-12

- Add support for multiple authentication methods. [#627](https://github.com/poem-web/poem/discussions/627)

## Breaking changes

- change `fn ApiExtractor::security_schemes() -> Option<&str>` to `fn ApiExtractor::security_schemes() -> Vec<&str>`

# [3.0.1] 2023-08-02

- openapi: allows multiple secutity schemes on one operation [#621](https://github.com/poem-web/poem/issues/621)
Expand Down
7 changes: 4 additions & 3 deletions poem-openapi/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ impl UrlQuery {
}

/// Options for the parameter extractor.
#[derive(Clone)]
pub struct ExtractParamOptions<T> {
/// The name of this parameter.
pub name: &'static str,
Expand Down Expand Up @@ -164,9 +165,9 @@ pub trait ApiExtractor<'a>: Sized {
/// Register related types to registry.
fn register(registry: &mut Registry) {}

/// Returns name of security scheme if this extractor is security scheme.
fn security_scheme() -> Option<&'static str> {
None
/// Returns names of security scheme if this extractor is security scheme.
fn security_schemes() -> Vec<&'static str> {
vec![]
}

/// Returns the location of the parameter if this extractor is parameter.
Expand Down

0 comments on commit 8a6f5db

Please sign in to comment.