@@ -120,6 +120,7 @@ def clean(self) -> None:
120120 self .prisma_api_url = normalize_prisma_url (os .getenv ('PRISMA_API_URL' , 'https://api0.prismacloud.io' ))
121121 self .prisma_policies_url : str | None = None
122122 self .prisma_policy_filters_url : str | None = None
123+ self .custom_auth_headers : dict [str , str ] = {}
123124 self .setup_api_urls ()
124125 self .customer_run_config_response = None
125126 self .runtime_run_config_response = None
@@ -163,6 +164,7 @@ def init_instance(self, platform_integration_data: dict[str, Any]) -> None:
163164 self .credentials = platform_integration_data ["credentials" ]
164165 self .platform_integration_configured = platform_integration_data ["platform_integration_configured" ]
165166 self .prisma_api_url = platform_integration_data ["prisma_api_url" ]
167+ self .custom_auth_headers = platform_integration_data ["custom_auth_headers" ]
166168 self .repo_branch = platform_integration_data ["repo_branch" ]
167169 self .repo_id = platform_integration_data ["repo_id" ]
168170 self .repo_path = platform_integration_data ["repo_path" ]
@@ -187,6 +189,7 @@ def generate_instance_data(self) -> dict[str, Any]:
187189 "credentials" : self .credentials ,
188190 "platform_integration_configured" : self .platform_integration_configured ,
189191 "prisma_api_url" : self .prisma_api_url ,
192+ "custom_auth_headers" : self .custom_auth_headers ,
190193 "repo_branch" : self .repo_branch ,
191194 "repo_id" : self .repo_id ,
192195 "repo_path" : self .repo_path ,
@@ -479,7 +482,8 @@ def _get_s3_creds(self, repo_id: str, token: str) -> dict[str, Any]:
479482 request = self .http .request ("POST" , self .integrations_api_url , # type:ignore[union-attr]
480483 body = json .dumps ({"repoId" : repo_id , "support" : self .support_flag_enabled }),
481484 headers = merge_dicts ({"Authorization" : token , "Content-Type" : "application/json" },
482- get_user_agent_header ()))
485+ get_user_agent_header (),
486+ self .custom_auth_headers ))
483487 logging .debug (f'Request ID: { request .headers .get ("x-amzn-requestid" )} ' )
484488 logging .debug (f'Trace ID: { request .headers .get ("x-amzn-trace-id" )} ' )
485489 if request .status == 403 :
@@ -834,7 +838,8 @@ def commit_repository(self, branch: str) -> str | None:
834838 "Content-Type" : "application/json" ,
835839 'x-api-client' : self .bc_source .name ,
836840 'x-api-checkov-version' : checkov_version },
837- get_user_agent_header ()
841+ get_user_agent_header (),
842+ self .custom_auth_headers
838843 ))
839844 response = json .loads (request .data .decode ("utf8" ))
840845 logging .debug (f'Request ID: { request .headers .get ("x-amzn-requestid" )} ' )
@@ -939,7 +944,8 @@ def get_customer_run_config(self) -> None:
939944 try :
940945 token = self .get_auth_token ()
941946 headers = merge_dicts (get_auth_header (token ),
942- get_default_get_headers (self .bc_source , self .bc_source_version ))
947+ get_default_get_headers (self .bc_source , self .bc_source_version ),
948+ self .custom_auth_headers )
943949
944950 self .setup_http_manager ()
945951 if not self .http :
@@ -989,7 +995,8 @@ def get_reachability_run_config(self) -> Union[Dict[str, Any], None]:
989995 try :
990996 token = self .get_auth_token ()
991997 headers = merge_dicts (get_auth_header (token ),
992- get_default_get_headers (self .bc_source , self .bc_source_version ))
998+ get_default_get_headers (self .bc_source , self .bc_source_version ),
999+ self .custom_auth_headers )
9931000
9941001 self .setup_http_manager ()
9951002 if not self .http :
@@ -1030,7 +1037,8 @@ def get_runtime_run_config(self) -> None:
10301037
10311038 token = self .get_auth_token ()
10321039 headers = merge_dicts (get_auth_header (token ),
1033- get_default_get_headers (self .bc_source , self .bc_source_version ))
1040+ get_default_get_headers (self .bc_source , self .bc_source_version ),
1041+ self .custom_auth_headers )
10341042
10351043 self .setup_http_manager ()
10361044 if not self .http :
@@ -1075,7 +1083,7 @@ def get_prisma_build_policies(self, policy_filter: str) -> None:
10751083
10761084 try :
10771085 token = self .get_auth_token ()
1078- headers = merge_dicts (get_prisma_auth_header (token ), get_prisma_get_headers ())
1086+ headers = merge_dicts (get_prisma_auth_header (token ), get_prisma_get_headers (), self . custom_auth_headers )
10791087
10801088 self .setup_http_manager ()
10811089 if not self .http :
@@ -1107,7 +1115,7 @@ def get_prisma_policy_filters(self) -> Dict[str, Dict[str, Any]]:
11071115 request = None
11081116 try :
11091117 token = self .get_auth_token ()
1110- headers = merge_dicts (get_prisma_auth_header (token ), get_prisma_get_headers ())
1118+ headers = merge_dicts (get_prisma_auth_header (token ), get_prisma_get_headers (), self . custom_auth_headers )
11111119
11121120 self .setup_http_manager ()
11131121 if not self .http :
@@ -1301,10 +1309,12 @@ def get_default_headers(self, request_type: str) -> dict[str, Any]:
13011309
13021310 if request_type .upper () == "GET" :
13031311 return merge_dicts (get_default_get_headers (self .bc_source , self .bc_source_version ),
1304- {"Authorization" : self .get_auth_token ()})
1312+ {"Authorization" : self .get_auth_token ()},
1313+ self .custom_auth_headers )
13051314 elif request_type .upper () == "POST" :
13061315 return merge_dicts (get_default_post_headers (self .bc_source , self .bc_source_version ),
1307- {"Authorization" : self .get_auth_token ()})
1316+ {"Authorization" : self .get_auth_token ()},
1317+ self .custom_auth_headers )
13081318
13091319 logging .info (f"Unsupported request { request_type } " )
13101320 return {}
@@ -1316,7 +1326,8 @@ def get_sso_prismacloud_url(self, report_url: str) -> str:
13161326 url_saml_config = f"{ bc_integration .prisma_api_url } /saml/config"
13171327 token = self .get_auth_token ()
13181328 headers = merge_dicts (get_auth_header (token ),
1319- get_default_get_headers (self .bc_source , self .bc_source_version ))
1329+ get_default_get_headers (self .bc_source , self .bc_source_version ),
1330+ bc_integration .custom_auth_headers )
13201331
13211332 request = self .http .request ("GET" , url_saml_config , headers = headers , timeout = 10 ) # type:ignore[no-untyped-call]
13221333 if request .status >= 300 :
0 commit comments