diff --git a/poem-openapi-derive/src/api.rs b/poem-openapi-derive/src/api.rs index b86a83094e..7b950c17d6 100644 --- a/poem-openapi-derive/src/api.rs +++ b/poem-openapi-derive/src/api.rs @@ -367,9 +367,11 @@ fn generate_operation( let scopes = &operation_param.scopes; security.push(quote! { if <#arg_ty as #crate_name::ApiExtractor>::TYPES.contains(&#crate_name::ApiExtractorType::SecurityScheme) { - security.push(<::std::collections::HashMap<&'static str, ::std::vec::Vec<&'static str>> as ::std::convert::From<_>>::from([ - (<#arg_ty as #crate_name::ApiExtractor>::security_scheme().unwrap(), ::std::vec![#(#crate_name::OAuthScopes::name(&#scopes)),*]) - ])); + for security_name in <#arg_ty as #crate_name::ApiExtractor>::security_schemes() { + security.push(<::std::collections::HashMap<&'static str, ::std::vec::Vec<&'static str>> as ::std::convert::From<_>>::from([ + (security_name, ::std::vec![#(#crate_name::OAuthScopes::name(&#scopes)),*]) + ])); + } } }); } diff --git a/poem-openapi-derive/src/security_scheme.rs b/poem-openapi-derive/src/security_scheme.rs index 5619850fb5..207457e5cb 100644 --- a/poem-openapi-derive/src/security_scheme.rs +++ b/poem-openapi-derive/src/security_scheme.rs @@ -1,7 +1,7 @@ use darling::{ - ast::{Data, Style}, - util::{Ignored, SpannedValue}, - FromDeriveInput, FromMeta, + ast::{Data, Fields, Style}, + util::SpannedValue, + FromDeriveInput, FromMeta, FromVariant, }; use http::header::HeaderName; use proc_macro2::{Ident, Span, TokenStream}; @@ -185,18 +185,24 @@ pub(crate) enum ApiKeyInType { Cookie, } +#[derive(FromVariant)] +struct SecuritySchemeItem { + ident: Ident, + fields: Fields, +} + #[derive(FromDeriveInput)] #[darling(attributes(oai), forward_attrs(doc))] struct SecuritySchemeArgs { ident: Ident, - data: Data, + data: Data, attrs: Vec, #[darling(default)] internal: bool, #[darling(default)] rename: Option, - ty: AuthType, + ty: Option, #[darling(default)] key_in: Option, #[darling(default)] @@ -212,8 +218,15 @@ struct SecuritySchemeArgs { } impl SecuritySchemeArgs { - fn validate(&self) -> GeneratorResult<()> { + fn auth_type(&self) -> GeneratorResult { match self.ty { + Some(ty) => Ok(ty), + None => Err(Error::new_spanned(&self.ident, "Missing an auth type.").into()), + } + } + + fn validate(&self) -> GeneratorResult<()> { + match self.auth_type()? { AuthType::ApiKey => self.validate_api_key(), AuthType::OAuth2 => self.validate_oauth2(), AuthType::OpenIdConnect => self.validate_openid_connect(), @@ -308,7 +321,7 @@ impl SecuritySchemeArgs { None => quote!(::std::option::Option::None), }; - let ts = match self.ty { + let ts = match self.auth_type()? { AuthType::ApiKey => { quote! { registry.create_security_scheme(#name, #crate_name::registry::MetaSecurityScheme { @@ -384,8 +397,8 @@ impl SecuritySchemeArgs { Ok(ts) } - fn generate_from_request(&self, crate_name: &TokenStream) -> TokenStream { - match self.ty { + fn generate_from_request(&self, crate_name: &TokenStream) -> GeneratorResult { + match self.auth_type()? { AuthType::ApiKey => { let key_name = self.key_name.as_ref().unwrap().as_str(); let param_in = match self.key_in.as_ref().unwrap() { @@ -393,20 +406,22 @@ impl SecuritySchemeArgs { ApiKeyInType::Header => quote!(#crate_name::registry::MetaParamIn::Header), ApiKeyInType::Cookie => quote!(#crate_name::registry::MetaParamIn::Cookie), }; - quote!(<#crate_name::auth::ApiKey as #crate_name::auth::ApiKeyAuthorization>::from_request(req, query, #key_name, #param_in)) - } - AuthType::Basic => { - quote!(<#crate_name::auth::Basic as #crate_name::auth::BasicAuthorization>::from_request(req)) - } - AuthType::Bearer => { - quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)) - } - AuthType::OAuth2 => { - quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)) - } - AuthType::OpenIdConnect => { - quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)) + Ok( + quote!(<#crate_name::auth::ApiKey as #crate_name::auth::ApiKeyAuthorization>::from_request(req, query, #key_name, #param_in)), + ) } + AuthType::Basic => Ok( + quote!(<#crate_name::auth::Basic as #crate_name::auth::BasicAuthorization>::from_request(req)), + ), + AuthType::Bearer => Ok( + quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)), + ), + AuthType::OAuth2 => Ok( + quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)), + ), + AuthType::OpenIdConnect => Ok( + quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req)), + ), } } } @@ -415,69 +430,131 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { let args: SecuritySchemeArgs = SecuritySchemeArgs::from_derive_input(&args)?; let crate_name = get_crate_name(args.internal); let ident = &args.ident; - let oai_typename = args.rename.clone().unwrap_or_else(|| ident.to_string()); - args.validate()?; - let fields = match &args.data { - Data::Struct(e) => e, - _ => { - return Err(Error::new_spanned( - ident, - "SecurityScheme can only be applied to an struct.", - ) - .into()) - } - }; - - if fields.style == Style::Tuple && fields.fields.len() != 1 { - return Err(Error::new_spanned( - ident, - "Only one unnamed field is allowed in the SecurityScheme struct.", - ) - .into()); - } + match &args.data { + Data::Struct(fields) => { + let oai_typename = args.rename.clone().unwrap_or_else(|| ident.to_string()); - let register_security_scheme = - args.generate_register_security_scheme(&crate_name, &oai_typename)?; - let from_request = args.generate_from_request(&crate_name); - let path = args.checker.as_ref(); - - let output = match path { - Some(_) => quote! { - let output = #crate_name::__private::CheckerReturn::from(#path(&req, #from_request?).await).into_result()?; - }, - None => quote! { - let output = #from_request?; - }, - }; - - let expanded = quote! { - #[#crate_name::__private::poem::async_trait] - impl<'a> #crate_name::ApiExtractor<'a> for #ident { - const TYPES: &'static [#crate_name::ApiExtractorType] = &[#crate_name::ApiExtractorType::SecurityScheme]; - - type ParamType = (); - type ParamRawType = (); - - fn register(registry: &mut #crate_name::registry::Registry) { - #register_security_scheme + if fields.style != Style::Tuple || fields.fields.len() != 1 { + return Err(Error::new_spanned( + ident, + "Only one unnamed field is allowed in the SecurityScheme struct.", + ) + .into()); } - fn security_scheme() -> ::std::option::Option<&'static str> { - ::std::option::Option::Some(#oai_typename) + args.validate()?; + + let register_security_scheme = + args.generate_register_security_scheme(&crate_name, &oai_typename)?; + let from_request = args.generate_from_request(&crate_name)?; + let path = args.checker.as_ref(); + + let output = match path { + Some(_) => quote! { + let output = #crate_name::__private::CheckerReturn::from(#path(&req, #from_request?).await).into_result()?; + }, + None => quote! { + let output = #from_request?; + }, + }; + + let expanded = quote! { + #[#crate_name::__private::poem::async_trait] + impl<'a> #crate_name::ApiExtractor<'a> for #ident { + const TYPES: &'static [#crate_name::ApiExtractorType] = &[#crate_name::ApiExtractorType::SecurityScheme]; + + type ParamType = (); + type ParamRawType = (); + + fn register(registry: &mut #crate_name::registry::Registry) { + #register_security_scheme + } + + fn security_schemes() -> ::std::vec::Vec<&'static str> { + ::std::vec![#oai_typename] + } + + async fn from_request( + req: &'a #crate_name::__private::poem::Request, + body: &mut #crate_name::__private::poem::RequestBody, + _param_opts: #crate_name::ExtractParamOptions, + ) -> #crate_name::__private::poem::Result { + let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap(); + #output + ::std::result::Result::Ok(Self(output)) + } + } + }; + + Ok(expanded) + } + Data::Enum(items) => { + let mut registers = Vec::new(); + let mut security_schemes = Vec::new(); + let mut from_requests = Vec::new(); + + if items.is_empty() { + return Err(Error::new_spanned(ident, "At least one member is required.").into()); } - async fn from_request( - req: &'a #crate_name::__private::poem::Request, - body: &mut #crate_name::__private::poem::RequestBody, - _param_opts: #crate_name::ExtractParamOptions, - ) -> #crate_name::__private::poem::Result { - let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap(); - #output - ::std::result::Result::Ok(Self(output)) + for item in items { + if item.fields.style != Style::Tuple || item.fields.fields.len() != 1 { + return Err(Error::new_spanned( + ident, + "Only one unnamed field is allowed in the SecurityScheme enum.", + ) + .into()); + } + + let item_ident = &item.ident; + let item_type = &item.fields.fields[0]; + + registers.push( + quote! { <#item_type as #crate_name::ApiExtractor>::register(registry); }, + ); + security_schemes.push(quote! { + security_schemes.extend(<#item_type as #crate_name::ApiExtractor>::security_schemes()); + }); + from_requests.push(quote! { + match <#item_type as #crate_name::ApiExtractor>::from_request(req, body, param_opts.clone()).await { + ::std::result::Result::Ok(item) => return Ok(#ident::#item_ident(item)), + ::std::result::Result::Err(err) => last_err = ::std::option::Option::Some(err), + } + }) } - } - }; - Ok(expanded) + let expanded = quote! { + #[#crate_name::__private::poem::async_trait] + impl<'a> #crate_name::ApiExtractor<'a> for #ident { + const TYPES: &'static [#crate_name::ApiExtractorType] = &[#crate_name::ApiExtractorType::SecurityScheme]; + + type ParamType = (); + type ParamRawType = (); + + fn register(registry: &mut #crate_name::registry::Registry) { + #(#registers)* + } + + fn security_schemes() -> ::std::vec::Vec<&'static str> { + let mut security_schemes = ::std::vec![]; + #(#security_schemes)* + security_schemes + } + + async fn from_request( + req: &'a #crate_name::__private::poem::Request, + body: &mut #crate_name::__private::poem::RequestBody, + param_opts: #crate_name::ExtractParamOptions, + ) -> #crate_name::__private::poem::Result { + let mut last_err = ::std::option::Option::None; + #(#from_requests)* + ::std::result::Result::Err(last_err.unwrap()) + } + } + }; + + Ok(expanded) + } + } } diff --git a/poem-openapi/CHANGELOG.md b/poem-openapi/CHANGELOG.md index 9d4279528e..b1c32fa2a9 100644 --- a/poem-openapi/CHANGELOG.md +++ b/poem-openapi/CHANGELOG.md @@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# [3.0.2] 2023-08-12 + +- Add support for multiple authentication methods. [#627](https://github.com/poem-web/poem/discussions/627) + +## Breaking changes + +- change `fn ApiExtractor::security_schemes() -> Option<&str>` to `fn ApiExtractor::security_schemes() -> Vec<&str>` + # [3.0.1] 2023-08-02 - openapi: allows multiple secutity schemes on one operation [#621](https://github.com/poem-web/poem/issues/621) diff --git a/poem-openapi/src/base.rs b/poem-openapi/src/base.rs index 8aaf20bf50..7d7a6765b7 100644 --- a/poem-openapi/src/base.rs +++ b/poem-openapi/src/base.rs @@ -57,6 +57,7 @@ impl UrlQuery { } /// Options for the parameter extractor. +#[derive(Clone)] pub struct ExtractParamOptions { /// The name of this parameter. pub name: &'static str, @@ -164,9 +165,9 @@ pub trait ApiExtractor<'a>: Sized { /// Register related types to registry. fn register(registry: &mut Registry) {} - /// Returns name of security scheme if this extractor is security scheme. - fn security_scheme() -> Option<&'static str> { - None + /// Returns names of security scheme if this extractor is security scheme. + fn security_schemes() -> Vec<&'static str> { + vec![] } /// Returns the location of the parameter if this extractor is parameter. diff --git a/poem-openapi/src/docs/security_scheme.md b/poem-openapi/src/docs/security_scheme.md index 971477ccf4..f9f3a2b384 100644 --- a/poem-openapi/src/docs/security_scheme.md +++ b/poem-openapi/src/docs/security_scheme.md @@ -2,21 +2,21 @@ Define a OpenAPI Security Scheme. # Macro parameters -| Attribute | Description | Type | Optional | -| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------- | -------- | -| rename | Rename the security scheme. | string | Y | -| ty | The type of the security scheme. (api_key, basic, bearer, oauth2, openid_connect) | string | N | -| key_in | `api_key` The location of the API key. Valid values are "query", "header" or "cookie". (query, header, cookie) | string | Y | -| key_name | `api_key` The name of the header, query or cookie parameter to be used.. | string | Y | -| bearer_format | `bearer` A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation purposes. | string | Y | -| flows | `oauth2` An object containing configuration information for the flow types supported. | OAuthFlows | Y | -| openid_connect_url | OpenId Connect URL to discover OAuth2 configuration values. | string | Y | +| Attribute | Description | Type | Optional | +|--------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|----------| +| rename | Rename the security scheme. | string | Y | +| ty | The type of the security scheme. (api_key, basic, bearer, oauth2, openid_connect) | string | N | +| key_in | `api_key` The location of the API key. Valid values are "query", "header" or "cookie". (query, header, cookie) | string | Y | +| key_name | `api_key` The name of the header, query or cookie parameter to be used.. | string | Y | +| bearer_format | `bearer` A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation purposes. | string | Y | +| flows | `oauth2` An object containing configuration information for the flow types supported. | OAuthFlows | Y | +| openid_connect_url | OpenId Connect URL to discover OAuth2 configuration values. | string | Y | | checker | Specify a function to check the original authentication information and convert it to the return type of this function. This function must return `Option` or `poem::Result`, with `None` meaning a General Authorization error and anĀ `Err` reflecting the error supplied. | string | Y | # OAuthFlows | Attribute | description | Type | Optional | -| ------------------ | -------------------------------------------------------- | --------- | -------- | +|--------------------|----------------------------------------------------------|-----------|----------| | implicit | Configuration for the OAuth Implicit flow | OAuthFlow | Y | | password | Configuration for the OAuth Resource Owner Password flow | OAuthFlow | Y | | client_credentials | Configuration for the OAuth Client Credentials flow | OAuthFlow | Y | @@ -24,9 +24,50 @@ Define a OpenAPI Security Scheme. # OAuthFlow -| Attribute | description | Type | Optional | -| ----------------- | -------------------------------------------------------------------------------------------------- | ----------- | -------- | -| authorization_url | `implicit` `authorization_code` The authorization URL to be used for this flow. | string | Y | +| Attribute | description | Type | Optional | +|-------------------|----------------------------------------------------------------------------------------------|-------------|----------| +| authorization_url | `implicit` `authorization_code` The authorization URL to be used for this flow. | string | Y | | token_url | `password` `client_credentials` `authorization_code` The token URL to be used for this flow. | string | Y | -| refresh_url | The URL to be used for obtaining refresh tokens. | string | Y | -| scopes | The available scopes for the OAuth2 security scheme. | OAuthScopes | Y | +| refresh_url | The URL to be used for obtaining refresh tokens. | string | Y | +| scopes | The available scopes for the OAuth2 security scheme. | OAuthScopes | Y | + +# Multiple Authentication Methods + +When `SecurityScheme` macro is used with an enumerated type, it is used to define multiple authentication methods. + +```rust +use poem_openapi::{OpenApi, SecurityScheme}; +use poem_openapi::payload::PlainText; +use poem_openapi::auth::{ApiKey, Basic}; + +#[derive(SecurityScheme)] +#[oai(ty = "basic")] +struct MySecurityScheme1(Basic); + +#[derive(SecurityScheme)] +#[oai(ty = "api_key", key_name = "X-API-Key", key_in = "header")] +struct MySecurityScheme2(ApiKey); + +#[derive(SecurityScheme)] +enum MySecurityScheme { + MySecurityScheme1(MySecurityScheme1), + MySecurityScheme2(MySecurityScheme2), +} + +struct MyApi; + +#[OpenApi] +impl MyApi { + #[oai(path = "/test", method = "get")] + async fn test(&self, auth: MySecurityScheme) -> PlainText { + match auth { + MySecurityScheme::MySecurityScheme1(auth) => { + PlainText(format!("basic: {}", auth.0.username)) + } + MySecurityScheme::MySecurityScheme2(auth) => { + PlainText(format!("api-key: {}", auth.0.key)) + } + } + } +} +``` \ No newline at end of file diff --git a/poem-openapi/tests/security_scheme.rs b/poem-openapi/tests/security_scheme.rs index 30f0cf3d7d..b3d2fe4217 100644 --- a/poem-openapi/tests/security_scheme.rs +++ b/poem-openapi/tests/security_scheme.rs @@ -20,7 +20,7 @@ fn rename() { #[oai(rename = "ABC", ty = "basic")] struct MySecurityScheme(Basic); - assert_eq!(MySecurityScheme::security_scheme().unwrap(), "ABC"); + assert_eq!(MySecurityScheme::security_schemes(), &["ABC"]); } #[test] @@ -29,10 +29,7 @@ fn default_rename() { #[oai(ty = "basic")] struct MySecurityScheme(Basic); - assert_eq!( - MySecurityScheme::security_scheme().unwrap(), - "MySecurityScheme" - ); + assert_eq!(MySecurityScheme::security_schemes(), &["MySecurityScheme"]); } #[test] @@ -462,9 +459,6 @@ async fn checker_result() { } } - let mut registry = Registry::new(); - MySecurityScheme::register(&mut registry); - struct MyApi; #[OpenApi] @@ -483,7 +477,7 @@ async fn checker_result() { .send() .await; resp.assert_status_is_ok(); - resp.assert_text("Authed: Enabled".to_string()).await; + resp.assert_text("Authed: Enabled").await; let resp = client .get("/test") @@ -508,9 +502,6 @@ async fn checker_option() { } } - let mut registry = Registry::new(); - MySecurityScheme::register(&mut registry); - struct MyApi; #[OpenApi] @@ -529,7 +520,7 @@ async fn checker_option() { .send() .await; resp.assert_status_is_ok(); - resp.assert_text("Authed: Enabled".to_string()).await; + resp.assert_text("Authed: Enabled").await; let resp = client .get("/test") @@ -538,3 +529,58 @@ async fn checker_option() { .await; resp.assert_status(StatusCode::UNAUTHORIZED); } + +#[tokio::test] +async fn multiple_auth_methods() { + #[derive(SecurityScheme)] + #[oai(ty = "basic")] + struct MySecurityScheme1(Basic); + + #[derive(SecurityScheme)] + #[oai(ty = "api_key", key_name = "X-API-Key", key_in = "header")] + struct MySecurityScheme2(ApiKey); + + #[derive(SecurityScheme)] + enum MySecurityScheme { + MySecurityScheme1(MySecurityScheme1), + MySecurityScheme2(MySecurityScheme2), + } + + struct MyApi; + + #[OpenApi] + impl MyApi { + #[oai(path = "/test", method = "get")] + async fn test(&self, auth: MySecurityScheme) -> PlainText { + match auth { + MySecurityScheme::MySecurityScheme1(auth) => { + PlainText(format!("basic: {}", auth.0.username)) + } + MySecurityScheme::MySecurityScheme2(auth) => { + PlainText(format!("api-key: {}", auth.0.key)) + } + } + } + } + + let service = OpenApiService::new(MyApi, "test", "1.0"); + let client = TestClient::new(service); + let resp = client + .get("/test") + .typed_header(headers::Authorization::basic("sunli", "password")) + .send() + .await; + resp.assert_status_is_ok(); + resp.assert_text("basic: sunli").await; + + let resp = client + .get("/test") + .header("X-API-Key", "abcdef") + .send() + .await; + resp.assert_status_is_ok(); + resp.assert_text("api-key: abcdef").await; + + let resp = client.get("/test").send().await; + resp.assert_status(StatusCode::UNAUTHORIZED); +}