diff --git a/src/pyicloud_ipd/base.py b/src/pyicloud_ipd/base.py index cf4479130..b7e704d71 100644 --- a/src/pyicloud_ipd/base.py +++ b/src/pyicloud_ipd/base.py @@ -53,6 +53,12 @@ "scnt": "scnt", } +def origin_referer_headers(input: str) -> Dict[str, str]: + return { + "Origin": input, + "Referer": f"{input}/" + } + class TrustedPhoneContextProvider(NamedTuple): domain: str oauth_session: AuthenticatedSession @@ -105,10 +111,12 @@ def __init__( self.password_filter: PyiCloudPasswordFilter|None = None if (domain == 'com'): + self.AUTH_ROOT_ENDPOINT = "https://idmsa.apple.com" self.AUTH_ENDPOINT = "https://idmsa.apple.com/appleauth/auth" self.HOME_ENDPOINT = "https://www.icloud.com" self.SETUP_ENDPOINT = "https://setup.icloud.com/setup/ws/1" elif (domain == 'cn'): + self.AUTH_ROOT_ENDPOINT = "https://idmsa.apple.com.cn" self.AUTH_ENDPOINT = "https://idmsa.apple.com.cn/appleauth/auth" self.HOME_ENDPOINT = "https://www.icloud.com.cn" self.SETUP_ENDPOINT = "https://setup.icloud.com.cn/setup/ws/1" @@ -376,7 +384,7 @@ def encode(self) -> bytes: 'protocols': ['s2k', 's2k_fo'] } - headers = self._get_auth_headers() + headers = self._get_auth_headers(origin_referer_headers(self.AUTH_ROOT_ENDPOINT)) try: if self.response_observer: @@ -487,7 +495,7 @@ def _authenticate_raw_password(self, password: str) -> None: if self.session_data.get("trust_token"): data["trustTokens"] = [self.session_data.get("trust_token")] - headers = self._get_auth_headers() + headers = self._get_auth_headers(origin_referer_headers(self.AUTH_ROOT_ENDPOINT)) try: # set observer with obfuscator if self.response_observer: @@ -517,6 +525,7 @@ def _authenticate_raw_password(self, password: str) -> None: def _validate_token(self) -> Dict[str, Any]: """Checks if the current access token is still valid.""" LOGGER.debug("Checking session token validity") + headers = origin_referer_headers(self.HOME_ENDPOINT) try: # set observer with obfuscator if self.response_observer: @@ -533,7 +542,7 @@ def _validate_token(self) -> Dict[str, Any]: rules = [] with self.use_rules(rules): - response = self.session.post("%s/validate" % self.SETUP_ENDPOINT, data="null") + response = self.session.post("%s/validate" % self.SETUP_ENDPOINT, data="null", headers=headers) LOGGER.debug("Session token is still valid") result: Dict[str, Any] = response.json() return result