Skip to content

Commit

Permalink
Allows the middleware to override the rest API requests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-Charles Bertin committed Jul 17, 2022
1 parent e684710 commit 51c0dda
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions chalice/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,7 @@ def __init__(self, event_dict: Dict[str, Any],
if query_params is None else MultiDict(query_params)
self.headers: CaseInsensitiveMapping = \
CaseInsensitiveMapping(event_dict['headers'])
self.uri_params: Optional[Dict[str, str]] \
= event_dict['pathParameters']
self.uri_params: Dict[str, str] = event_dict['pathParameters']
self.method: str = event_dict['requestContext']['httpMethod']
self._is_base64_encoded = event_dict.get('isBase64Encoded', False)
self._body: Any = event_dict['body']
Expand Down Expand Up @@ -1857,12 +1856,17 @@ def wrapped_event(request: Request) -> Response:
return response.to_dict(self.api.binary_types)

def _main_rest_api_handler(self, event: Any, context: Any) -> Response:
resource_path = event.get('requestContext', {}).get('resourcePath')
if resource_path is None:
return error_response(error_code='InternalServerError',
message='Unknown request.',
http_status_code=500)
http_method = event['requestContext']['httpMethod']
current_request: Optional[Request] = self.current_request
if current_request:
resource_path = current_request.path
http_method = current_request.method
else:
resource_path = event.get('requestContext', {}).get('resourcePath')
if resource_path is None:
return error_response(error_code='InternalServerError',
message='Unknown request.',
http_status_code=500)
http_method = event['requestContext']['httpMethod']
if http_method not in self.routes[resource_path]:
allowed_methods = ', '.join(self.routes[resource_path].keys())
return error_response(
Expand All @@ -1872,8 +1876,12 @@ def _main_rest_api_handler(self, event: Any, context: Any) -> Response:
headers={'Allow': allowed_methods})
route_entry = self.routes[resource_path][http_method]
view_function = route_entry.view_function
function_args = {name: event['pathParameters'][name]
for name in route_entry.view_args}
if current_request:
function_args = {name: current_request.uri_params[name]
for name in route_entry.view_args}
else:
function_args = {name: event['pathParameters'][name]
for name in route_entry.view_args}
self.lambda_context = context
# We're getting the CORS headers before validation to be able to
# output desired headers with
Expand All @@ -1883,8 +1891,8 @@ def _main_rest_api_handler(self, event: Any, context: Any) -> Response:
# We're doing the header validation after creating the request
# so can leverage the case insensitive dict that the Request class
# uses for headers.
if self.current_request and route_entry.content_types:
content_type = self.current_request.headers.get(
if current_request and route_entry.content_types:
content_type = current_request.headers.get(
'content-type', 'application/json')
if not _matches_content_type(content_type,
route_entry.content_types):
Expand All @@ -1900,8 +1908,8 @@ def _main_rest_api_handler(self, event: Any, context: Any) -> Response:
self._add_cors_headers(response, cors_headers)

response_headers = CaseInsensitiveMapping(response.headers)
if self.current_request and not self._validate_binary_response(
self.current_request.headers, response_headers):
if current_request and not self._validate_binary_response(
current_request.headers, response_headers):
content_type = response_headers.get('content-type', '')
return error_response(
error_code='BadRequest',
Expand Down

0 comments on commit 51c0dda

Please sign in to comment.