11import base64
22import urllib .parse
33from typing import Any , Dict , Generator , List , Tuple
4+ from itertools import islice
45
56from .abstract_handler import AbstractHandler
67from .. import Response , Request
@@ -25,12 +26,25 @@ def all_casings(input_string: str) -> Generator:
2526 yield first .upper () + sub_casing
2627
2728
29+ def case_mutated_headers (multi_value_headers : Dict [str , List [str ]]) -> Dict [str , str ]:
30+ """Create str/str key/value headers, with duplicate keys case mutated."""
31+ headers = {}
32+ for key , values in multi_value_headers .items ():
33+ if len (values ) > 0 :
34+ casings = list (islice (all_casings (key ), len (values )))
35+ for value , cased_key in zip (values , casings ):
36+ headers [cased_key ] = value
37+ return headers
38+
39+
2840class AwsAlb (AbstractHandler ):
2941 """
3042 Handles AWS Elastic Load Balancer, really Application Load Balancer events
3143 transforming them into ASGI Scope and handling responses
3244
33- See: https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
45+ See:
46+ 1. https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html
47+ 2. https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html # noqa: E501
3448 """
3549
3650 TYPE = "AWS_ALB"
@@ -71,22 +85,40 @@ def encode_query_string(self) -> bytes:
7185
7286 return urllib .parse .urlencode (query ).encode ()
7387
88+ def transform_headers (self ) -> List [Tuple [bytes , bytes ]]:
89+ """Convert headers to a list of two-tuples per ASGI spec.
90+
91+ Only one of `multiValueHeaders` or `headers` should be defined in the
92+ trigger event. However, we act as though they both might exist and pull
93+ headers out of both.
94+ """
95+ headers = []
96+ if "multiValueHeaders" in self .trigger_event :
97+ for k , v in self .trigger_event ["multiValueHeaders" ].items ():
98+ for inner_v in v :
99+ headers .append ((k .lower ().encode (), inner_v .encode ()))
100+ else :
101+ for k , v in self .trigger_event ["headers" ].items ():
102+ headers .append ((k .lower ().encode (), v .encode ()))
103+ return headers
104+
74105 @property
75106 def request (self ) -> Request :
76107 event = self .trigger_event
77108
78- headers = {}
79- if event .get ("headers" ):
80- headers = {k .lower (): v for k , v in event .get ("headers" , {}).items ()}
109+ headers = self .transform_headers ()
110+ list_headers = [list (x ) for x in headers ]
111+ # Unique headers. If there are duplicates, it will use the last defined.
112+ uq_headers = {k .decode (): v .decode () for k , v in headers }
81113
82- source_ip = headers .get ("x-forwarded-for" , "" )
114+ source_ip = uq_headers .get ("x-forwarded-for" , "" )
83115 path = event ["path" ]
84116 http_method = event ["httpMethod" ]
85117 query_string = self .encode_query_string ()
86118
87- server_name = headers .get ("host" , "mangum" )
119+ server_name = uq_headers .get ("host" , "mangum" )
88120 if ":" not in server_name :
89- server_port = headers .get ("x-forwarded-port" , 80 )
121+ server_port = uq_headers .get ("x-forwarded-port" , 80 )
90122 else :
91123 server_name , server_port = server_name .split (":" ) # pragma: no cover
92124 server = (server_name , int (server_port ))
@@ -97,9 +129,9 @@ def request(self) -> Request:
97129
98130 return Request (
99131 method = http_method ,
100- headers = [[ k . encode (), v . encode ()] for k , v in headers . items ()] ,
132+ headers = list_headers ,
101133 path = urllib .parse .unquote (path ),
102- scheme = headers .get ("x-forwarded-proto" , "https" ),
134+ scheme = uq_headers .get ("x-forwarded-proto" , "https" ),
103135 query_string = query_string ,
104136 server = server ,
105137 client = client ,
@@ -119,36 +151,34 @@ def body(self) -> bytes:
119151
120152 return body
121153
122- def handle_headers (
123- self ,
124- response_headers : List [List [bytes ]],
125- ) -> Tuple [Dict [str , str ], Dict [str , List [str ]]]:
126- headers , multi_value_headers = self ._handle_multi_value_headers (
127- response_headers
128- )
129- if "multiValueHeaders" not in self .trigger_event :
130- # If there are multiple occurrences of headers, create case-mutated
131- # variations: https://github.com/logandk/serverless-wsgi/issues/11
132- for key , values in multi_value_headers .items ():
133- if len (values ) > 1 :
134- for value , cased_key in zip (values , all_casings (key )):
135- headers [cased_key ] = value
136-
137- multi_value_headers = {}
154+ def transform_response (self , response : Response ) -> Dict [str , Any ]:
138155
139- return headers , multi_value_headers
156+ multi_value_headers : Dict [str , List [str ]] = {}
157+ for key , value in response .headers :
158+ lower_key = key .decode ().lower ()
159+ if lower_key not in multi_value_headers :
160+ multi_value_headers [lower_key ] = []
161+ multi_value_headers [lower_key ].append (value .decode ())
140162
141- def transform_response (self , response : Response ) -> Dict [str , Any ]:
142- headers , multi_value_headers = self .handle_headers (response .headers )
163+ headers = case_mutated_headers (multi_value_headers )
143164
144165 body , is_base64_encoded = self ._handle_base64_response_body (
145166 response .body , headers
146167 )
147168
148- return {
169+ out = {
149170 "statusCode" : response .status ,
150- "headers" : headers ,
151- "multiValueHeaders" : multi_value_headers ,
152171 "body" : body ,
153172 "isBase64Encoded" : is_base64_encoded ,
154173 }
174+
175+ # "You must use multiValueHeaders if you have enabled multi-value headers
176+ # and headers otherwise"
177+ # https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html
178+ multi_value_headers_enabled = "multiValueHeaders" in self .trigger_event
179+ if multi_value_headers_enabled :
180+ out ["multiValueHeaders" ] = multi_value_headers
181+ else :
182+ out ["headers" ] = headers
183+
184+ return out
0 commit comments