forked from Azure/azure-sdk-for-python
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcbs.py
More file actions
318 lines (291 loc) · 12.1 KB
/
cbs.py
File metadata and controls
318 lines (291 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import logging
from datetime import datetime
from typing import Any, Optional, Tuple, Union
from .utils import utc_now, utc_from_timestamp
from .management_link import ManagementLink
from .message import Message, Properties
from .error import (
AuthenticationException,
ErrorCondition,
TokenAuthFailure,
TokenExpired,
)
from .constants import (
CbsState,
CbsAuthState,
CBS_PUT_TOKEN,
CBS_EXPIRATION,
CBS_NAME,
CBS_TYPE,
CBS_OPERATION,
ManagementExecuteOperationResult,
ManagementOpenResult,
)
from .session import Session
from .authentication import JWTTokenAuth, SASTokenAuth
# Module-level logger, named after this module's import path; records emitted by
# CBSAuthenticator pass ``extra=self._network_trace_params`` for network tracing.
_LOGGER = logging.getLogger(__name__)
def check_expiration_and_refresh_status(expires_on: int, refresh_window: int) -> Tuple[bool, bool]:
    """Classify a token's remaining lifetime against the current UTC time.

    :param expires_on: Token expiry, in whole seconds since the Unix epoch.
    :param refresh_window: How many seconds before expiry a refresh becomes due.
    :returns: ``(is_expired, is_refresh_required)``. The token is expired once the current
        time has reached ``expires_on``; a refresh is required once the remaining lifetime
        is at most ``refresh_window`` seconds.
    :rtype: tuple[bool, bool]
    """
    now = int(utc_now().timestamp())
    return now >= expires_on, expires_on - now <= refresh_window
def check_put_timeout_status(auth_timeout: float, token_put_time: int) -> bool:
    """Tell whether an outstanding put-token request has waited too long.

    :param auth_timeout: Allowed wait in seconds; a non-positive value disables the check.
    :param token_put_time: When the request was sent, in whole seconds since the Unix epoch.
    :returns: True if the timeout is enabled and at least ``auth_timeout`` seconds have
        elapsed since ``token_put_time``; otherwise False.
    :rtype: bool
    """
    timeout_enabled = auth_timeout > 0
    if not timeout_enabled:
        return False
    elapsed_seconds = int(utc_now().timestamp()) - token_put_time
    return elapsed_seconds >= auth_timeout
class CBSAuthenticator:  # pylint:disable=too-many-instance-attributes, disable=unused-argument
    """Authenticate an AMQP session with claims-based security (CBS) tokens.

    A request/response link pair to the ``$cbs`` node is created on ``session``. Tokens
    obtained from ``auth.get_token()`` are sent to it as put-token requests. Two state
    fields are tracked: ``state`` (a ``CbsState`` describing the ``$cbs`` link) and
    ``auth_state`` (a ``CbsAuthState`` describing the current token). :meth:`handle_token`
    advances both and reports whether the token is currently accepted.

    :param session: Session on which the ``$cbs`` link pair is created.
    :param auth: Token provider exposing a callable ``get_token`` plus ``token_type`` and
        ``audience`` attributes.
    :keyword float auth_timeout: Seconds to wait for a put-token response before the attempt
        is reported as timed out. A value <= 0 disables the timeout.
    :raises ValueError: If ``auth.get_token`` is not callable.
    """

    def __init__(
        self,
        session: Session,
        auth: Union[JWTTokenAuth, SASTokenAuth],
        *,
        auth_timeout: float,
        **kwargs: Any
    ) -> None:
        # Validate the token provider first so an unusable ``auth`` fails before any link
        # pair is requested from the session. ``callable(None)`` is False, so a missing
        # (None) ``get_token`` is rejected here as well.
        if not callable(auth.get_token):  # type: ignore
            raise ValueError("get_token must be a callable object.")
        self._session = session
        self._connection = self._session._connection
        self._mgmt_link: ManagementLink = self._session.create_request_response_link_pair(
            endpoint="$cbs",
            on_amqp_management_open_complete=self._on_amqp_management_open_complete,
            on_amqp_management_error=self._on_amqp_management_error,
            status_code_field=b"status-code",
            status_description_field=b"status-description",
        )
        self._auth = auth
        self._encoding = "UTF-8"
        self._auth_timeout: float = auth_timeout
        # Epoch seconds at which the most recent put-token request was issued.
        self._token_put_time: Optional[int] = None
        # Epoch seconds at which the current token expires.
        self._expires_on: Optional[int] = None
        self._token: Optional[str] = None
        # Seconds before expiry at which a refresh is requested (10% of the token lifetime,
        # see update_token).
        self._refresh_window: Optional[int] = None
        # Attached as ``extra`` to every log record emitted by this authenticator.
        self._network_trace_params = {
            "amqpConnection": self._connection._container_id,
            "amqpSession": self._session.name,
            "amqpLink": ""
        }
        # Status of the last put-token response; surfaced through TokenAuthFailure.
        self._token_status_code: Optional[int] = None
        self._token_status_description: Optional[str] = None
        self.state = CbsState.CLOSED
        self.auth_state = CbsAuthState.IDLE

    def _put_token(self, token: str, token_type: str, audience: str, expires_on: Optional[datetime] = None) -> None:
        """Send a CBS put-token request over the management link.

        :param token: The token value, sent as the message body value.
        :param token_type: Token type, sent as the CBS type property.
        :param audience: Audience the token is for, sent as the CBS name property.
        :param expires_on: Token expiry, sent as the CBS expiration property.
        """
        message = Message(  # type: ignore # TODO: missing positional args header, etc.
            value=token,
            properties=Properties(message_id=self._mgmt_link.next_message_id),  # type: ignore
            application_properties={
                CBS_NAME: audience,
                CBS_OPERATION: CBS_PUT_TOKEN,
                CBS_TYPE: token_type,
                CBS_EXPIRATION: expires_on,
            },
        )
        self._mgmt_link.execute_operation(
            message,
            self._on_execute_operation_complete,
            timeout=self._auth_timeout,
            operation=CBS_PUT_TOKEN,
            type=token_type,
        )
        # Give the next request a distinct message id.
        self._mgmt_link.next_message_id += 1

    def _on_amqp_management_open_complete(self, management_open_result: ManagementOpenResult) -> None:
        """Management-link callback: update ``state`` once the ``$cbs`` link pair has opened.

        While OPENING, the link becomes OPEN on ``ManagementOpenResult.OK`` and CLOSED
        otherwise. An open-complete received while already OPEN is treated as an error.
        """
        if self.state in (CbsState.CLOSED, CbsState.ERROR):
            _LOGGER.debug(
                "CBS with status: %r encounters unexpected AMQP management open complete.",
                self.state,
                extra=self._network_trace_params
            )
        elif self.state == CbsState.OPEN:
            self.state = CbsState.ERROR
            _LOGGER.info(
                "Unexpected AMQP management open complete in OPEN, CBS error occurred.",
                extra=self._network_trace_params
            )
        elif self.state == CbsState.OPENING:
            self.state = (
                CbsState.OPEN
                if management_open_result == ManagementOpenResult.OK
                else CbsState.CLOSED
            )
            _LOGGER.debug(
                "CBS completed opening with status: %r",
                management_open_result,
                extra=self._network_trace_params
            )

    def _on_amqp_management_error(self) -> None:
        """Management-link callback: move ``state`` to ERROR after a link-level error.

        An error while OPENING also closes the management link.
        """
        if self.state == CbsState.CLOSED:
            _LOGGER.info("Unexpected AMQP error in CLOSED state.", extra=self._network_trace_params)
        elif self.state == CbsState.OPENING:
            self.state = CbsState.ERROR
            self._mgmt_link.close()
            _LOGGER.info(
                "CBS failed to open with status: %r",
                ManagementOpenResult.ERROR,
                extra=self._network_trace_params
            )
        elif self.state == CbsState.OPEN:
            self.state = CbsState.ERROR
            _LOGGER.info("CBS error occurred.", extra=self._network_trace_params)

    def _on_execute_operation_complete(
        self,
        execute_operation_result: ManagementExecuteOperationResult,
        status_code: int,
        status_description: str,
        _,
        error_condition: Optional[str] = None,
    ) -> None:
        """Put-token completion callback: record the response and update ``auth_state``.

        :param execute_operation_result: Outcome reported by the management link.
        :param status_code: Status code from the put-token response.
        :param status_description: Status description from the put-token response.
        :param _: Unused positional argument supplied by the management link.
        :param error_condition: If set, the request failed and ``auth_state`` becomes ERROR
            without recording the status fields.
        """
        if error_condition:
            _LOGGER.info(
                "CBS Put token error: %r",
                error_condition,
                extra=self._network_trace_params
            )
            self.auth_state = CbsAuthState.ERROR
            return
        _LOGGER.debug(
            "CBS Put token result (%r), status code: %s, status_description: %s.",
            execute_operation_result,
            status_code,
            status_description,
            extra=self._network_trace_params
        )
        self._token_status_code = status_code
        self._token_status_description = status_description
        if execute_operation_result == ManagementExecuteOperationResult.OK:
            self.auth_state = CbsAuthState.OK
        elif execute_operation_result == ManagementExecuteOperationResult.ERROR:
            self.auth_state = CbsAuthState.ERROR
            # put-token-message sending failure, rejected
            self._token_status_code = 0
            self._token_status_description = "Auth message has been rejected."
        elif (
            execute_operation_result
            == ManagementExecuteOperationResult.FAILED_BAD_STATUS
        ):
            self.auth_state = CbsAuthState.ERROR

    def _update_status(self) -> None:
        """Re-evaluate ``auth_state`` against the clock.

        An accepted token (OK / REFRESH_REQUIRED) becomes EXPIRED or REFRESH_REQUIRED as its
        expiry approaches. An IN_PROGRESS put-token becomes TIMEOUT once ``auth_timeout``
        has elapsed since it was sent.
        """
        if (
            self.auth_state in (CbsAuthState.OK, CbsAuthState.REFRESH_REQUIRED)
        ):
            is_expired, is_refresh_required = check_expiration_and_refresh_status(
                self._expires_on, self._refresh_window  # type: ignore
            )
            _LOGGER.debug(
                "CBS status check: state == %r, expired == %r, refresh required == %r",
                self.auth_state,
                is_expired,
                is_refresh_required,
                extra=self._network_trace_params
            )
            if is_expired:
                self.auth_state = CbsAuthState.EXPIRED
            elif is_refresh_required:
                self.auth_state = CbsAuthState.REFRESH_REQUIRED
        elif self.auth_state == CbsAuthState.IN_PROGRESS:
            _LOGGER.debug(
                "CBS update in progress. Token put time: %r",
                self._token_put_time,
                extra=self._network_trace_params
            )
            if self._token_put_time is not None:
                put_timeout = check_put_timeout_status(
                    self._auth_timeout, self._token_put_time
                )
                if put_timeout:
                    self.auth_state = CbsAuthState.TIMEOUT

    def _cbs_link_ready(self) -> bool:
        """Report whether the ``$cbs`` link can carry put-token requests.

        :returns: True when the link is OPEN; False while it is still opening.
        :raises TokenAuthFailure: If the link is CLOSED or in ERROR. It will not recover
            on its own, so waiting on it would never succeed.
        """
        if self.state == CbsState.OPEN:
            return True
        # Checked before the generic "not ready" answer; otherwise this branch would be
        # unreachable and a broken link would make handle_token wait indefinitely.
        if self.state in (CbsState.CLOSED, CbsState.ERROR):
            raise TokenAuthFailure(
                status_code=ErrorCondition.ClientError,
                status_description="CBS authentication link is in broken status, please recreate the cbs link.",
            )
        return False

    def open(self) -> None:
        """Start opening the ``$cbs`` management link; ``state`` becomes OPENING."""
        self.state = CbsState.OPENING
        self._mgmt_link.open()

    def close(self) -> None:
        """Close the ``$cbs`` management link; ``state`` becomes CLOSED."""
        self._mgmt_link.close()
        self.state = CbsState.CLOSED

    def update_token(self) -> None:
        """Fetch a token from ``auth`` and send it to the CBS node.

        Sets ``auth_state`` to IN_PROGRESS; the put-token completion callback moves it on.
        The refresh window is set to 10% of the token's remaining lifetime. An empty token
        string is logged and not sent, so the attempt is left to time out.

        :raises ValueError: If ``auth.get_token()`` returns an empty object, or the token or
            the token type is neither ``str`` nor ``bytes``.
        """
        self.auth_state = CbsAuthState.IN_PROGRESS
        access_token = self._auth.get_token()
        if not access_token:
            _LOGGER.info(
                "Token refresh function received an empty token object.",
                extra=self._network_trace_params
            )
            # The attribute accesses below would otherwise fail with an opaque AttributeError.
            raise ValueError("Token refresh function returned an empty token object.")
        if not access_token.token:
            _LOGGER.info(
                "Token refresh function received an empty token.",
                extra=self._network_trace_params
            )
        self._expires_on = access_token.expires_on
        expires_in = self._expires_on - int(utc_now().timestamp())
        self._refresh_window = int(float(expires_in) * 0.1)
        token_type: Optional[str] = None
        if isinstance(access_token.token, bytes):
            self._token = access_token.token.decode()
        elif isinstance(access_token.token, str):
            self._token = access_token.token
        else:
            raise ValueError("Token must be a string or bytes.")
        if isinstance(self._auth.token_type, bytes):
            token_type = self._auth.token_type.decode()
        elif isinstance(self._auth.token_type, str):
            token_type = self._auth.token_type
        else:
            raise ValueError("Token type must be a string or bytes.")
        # Recorded even if nothing is sent below, so an IN_PROGRESS attempt can time out.
        self._token_put_time = int(utc_now().timestamp())
        if self._token and token_type:
            self._put_token(
                self._token,
                token_type,
                self._auth.audience,  # type: ignore
                utc_from_timestamp(self._expires_on),
            )

    def handle_token(self) -> bool:  # pylint: disable=inconsistent-return-statements
        """Advance the CBS token state machine and report whether authentication is complete.

        Starts a put-token when idle and refreshes the token when its refresh window is
        reached.

        :returns: True when the current token has been accepted. False while the link is
            still opening, a put-token is in flight, or a (re)fresh has just been started.
        :rtype: bool
        :raises TokenAuthFailure: If the CBS link is closed or in error, or the put-token
            request failed.
        :raises AuthenticationException: If ``auth_state`` is FAILURE.
        :raises TimeoutError: If no put-token response arrived within ``auth_timeout``.
        :raises TokenExpired: If the accepted token has expired.
        :raises ValueError: Propagated from :meth:`update_token` for an unusable token.
        """
        if not self._cbs_link_ready():
            return False
        self._update_status()
        if self.auth_state == CbsAuthState.IDLE:
            self.update_token()
            return False
        if self.auth_state == CbsAuthState.IN_PROGRESS:
            return False
        if self.auth_state == CbsAuthState.OK:
            return True
        if self.auth_state == CbsAuthState.REFRESH_REQUIRED:
            _LOGGER.info(
                "Token will expire soon - attempting to refresh.",
                extra=self._network_trace_params
            )
            self.update_token()
            return False
        if self.auth_state == CbsAuthState.FAILURE:
            raise AuthenticationException(
                condition=ErrorCondition.InternalError,
                description="Failed to open CBS authentication link.",
            )
        if self.auth_state == CbsAuthState.ERROR:
            raise TokenAuthFailure(
                self._token_status_code,
                self._token_status_description,
                encoding=self._encoding,  # TODO: drop off all the encodings
            )
        if self.auth_state == CbsAuthState.TIMEOUT:
            raise TimeoutError("Authentication attempt timed-out.")
        if self.auth_state == CbsAuthState.EXPIRED:
            raise TokenExpired(
                condition=ErrorCondition.InternalError,
                description="CBS Authentication Expired.",
            )