Skip to content
61 changes: 17 additions & 44 deletions eyepop/compute/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
from typing import Any

Expand Down Expand Up @@ -46,64 +45,56 @@ async def fetch_new_compute_session(
}

sessions_url = f"{compute_ctx.compute_url}/v1/sessions"
log.debug(f"Fetching sessions from: {sessions_url}")

res = None
need_new_session = False
create_reason = None

try:
async with client_session.get(sessions_url, headers=headers) as get_response:
log.debug(f"GET /v1/sessions - status: {get_response.status}")
if get_response.status == 404:
need_new_session = True
log.debug("GET /v1/sessions returned 404, will create new session")
create_reason = "no sessions endpoint (404)"
else:
get_response.raise_for_status()
res = await get_response.json()
log.debug(f"GET /v1/sessions: {get_response.status}")

if not res:
need_new_session = True
log.debug("Response is empty/None, need to create new session")
create_reason = "empty response"
elif isinstance(res, list) and len(res) == 0:
need_new_session = True
log.debug("Response is empty list, need to create new session")
create_reason = "no existing sessions"
elif isinstance(res, dict) and not res.get("session_uuid"):
need_new_session = True
log.debug("Response is dict without session_uuid, need to create new session")
create_reason = "response missing session_uuid"

except aiohttp.ClientResponseError as e:
if e.status == 404:
need_new_session = True
log.debug("GET /v1/sessions returned 404, will create new session")
create_reason = "no sessions endpoint (404)"
else:
raise ComputeSessionException(
f"Failed to fetch existing sessions: {e.message}",
) from e
except Exception as e:
raise ComputeSessionException(f"Unexpected error fetching sessions: {str(e)}") from e

if need_new_session:
if create_reason:
log.debug(f"Creating new session: {create_reason}")
try:
log.debug(f"Creating new session via POST to: {sessions_url}")
body = {}
if compute_ctx.pipeline_image:
body["pipeline_image"] = compute_ctx.pipeline_image
if compute_ctx.pipeline_version:
body["pipeline_version"] = compute_ctx.pipeline_version
if compute_ctx.session_opts:
body.update(compute_ctx.session_opts)

if body:
log.debug(f"POST /v1/sessions body: {body}")

# Explicit kwargs instead of **post_kwargs — the dict unpacking
# confuses aiohttp's overloaded signature and breaks type checking
post_headers = {**headers, **compute_ctx.session_headers} if compute_ctx.session_headers else headers
async with client_session.post(
f'{sessions_url}?wait=true',
headers=headers,
headers=post_headers,
json=body if body else None,
) as post_response:
post_response.raise_for_status()
res = await post_response.json()
log.debug(f"POST /v1/sessions response: {post_response.status}")
log.debug(f"POST /v1/sessions - status: {post_response.status}")
except aiohttp.ClientResponseError as e:
raise ComputeSessionException(
f"Failed to create new session: {e.message}",
Expand Down Expand Up @@ -138,24 +129,14 @@ def _compute_context_from_response(compute_ctx: ComputeContext, res: dict | None
compute_ctx.access_token_expires_at = session_response.access_token_expires_at
compute_ctx.access_token_expires_in = session_response.access_token_expires_in
pipeline_id = ""

if len(session_response.pipelines) > 0:
pipeline_id = session_response.pipelines[0].get("id", None)
if not pipeline_id:
pipeline_id = session_response.pipelines[0].get("pipeline_id", "")

compute_ctx.pipeline_id = pipeline_id

debug_obj = {
"session_endpoint": session_response.session_endpoint,
"session_uuid": session_response.session_uuid,
"m2m_access_token": session_response.access_token,
"m2m_access_token_expires_at": session_response.access_token_expires_at,
"m2m_access_token_expires_in": session_response.access_token_expires_in,
"pipeline_id": pipeline_id,
"pipelines": session_response.pipelines,
}
log.debug(json.dumps(debug_obj, indent=4))

if not session_response.access_token or len(session_response.access_token.strip()) == 0:
raise ComputeSessionException(
"No M2M access_token received from compute API session response. "
Expand All @@ -176,32 +157,25 @@ async def refresh_compute_token(
headers = {"Authorization": f"Bearer {compute_ctx.api_key}", "Accept": "application/json"}

refresh_url = f"{compute_ctx.compute_url}/v1/auth/authenticate"
log.debug(f"Refreshing token at: {refresh_url}")

try:
async with client_session.post(refresh_url, headers=headers) as response:
response.raise_for_status()
token_response = await response.json()
log.debug(f"Token refresh response: {token_response}")
log.debug(f"POST /v1/auth/authenticate - status: {response.status}")

compute_ctx.m2m_access_token = token_response.get("access_token", "")
compute_ctx.access_token_expires_at = token_response.get("expires_at", "")
compute_ctx.access_token_expires_in = token_response.get("expires_in", 0)

log.debug(
f"Token refreshed successfully, expires in: {compute_ctx.access_token_expires_in}s"
)
return compute_ctx

except aiohttp.ClientResponseError as e:
log.error(f"Failed to refresh token: HTTP {e.status} - {e.message}")
raise ComputeTokenException(
f"Token refresh failed: HTTP {e.status} - {e.message}",
session_uuid=compute_ctx.session_uuid,
) from e
except Exception as e:
log.error("Failed to refresh token")
log.debug(str(e))
raise ComputeTokenException(
f"Token refresh failed: {str(e)}", session_uuid=compute_ctx.session_uuid
) from e
Expand All @@ -218,13 +192,12 @@ async def fetch_permanent_compute_session(
}

session_url = f"{compute_ctx.compute_url}/v1/sessions/{permanent_session_uuid}"
log.debug(f"Fetching session from: {session_url}")

try:
async with client_session.get(session_url, headers=headers) as get_response:
get_response.raise_for_status()
res = await get_response.json()
log.debug(f"GET /v1/sessions: {get_response.status}")
log.debug(f"GET /v1/sessions/{permanent_session_uuid} - status: {get_response.status}")
_compute_context_from_response(compute_ctx, res)
return compute_ctx
except aiohttp.ClientResponseError as e:
Expand Down
8 changes: 8 additions & 0 deletions eyepop/compute/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class ComputeContext(BaseModel):
description="Custom Docker image tag for the worker pipeline",
default_factory=lambda: os.getenv("EYEPOP_PIPELINE_VERSION", "")
)
session_opts: dict = Field(
description="Arbitrary extra fields merged into the session POST body",
default_factory=dict
)
session_headers: dict = Field(
description="Arbitrary extra headers sent with the session POST request",
default_factory=dict
)


class PipelineStatus(str, Enum):
Expand Down
58 changes: 18 additions & 40 deletions eyepop/compute/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@

log = logging.getLogger("eyepop.compute")

_TERMINAL_STATES = {PipelineStatus.FAILED, PipelineStatus.ERROR, PipelineStatus.STOPPED}


async def wait_for_session(
compute_config: ComputeContext, client_session: aiohttp.ClientSession
) -> bool:
timeout = compute_config.wait_for_session_timeout
interval = compute_config.wait_for_session_interval

# Session endpoint health check ALWAYS uses the JWT access_token
if not compute_config.m2m_access_token or len(compute_config.m2m_access_token.strip()) == 0:
raise ComputeHealthCheckException(
"No access_token in compute_config. "
Expand All @@ -25,19 +26,13 @@ async def wait_for_session(
session_endpoint=compute_config.session_endpoint,
)

auth_header = f"Bearer {compute_config.m2m_access_token}"
log.debug(
f"Using JWT access_token for session health check at {compute_config.session_endpoint}/health"
)

headers = {
"Authorization": auth_header,
"Authorization": f"Bearer {compute_config.m2m_access_token}",
"Accept": "application/json",
}

health_url = f"{compute_config.session_endpoint}/health"
log.debug(f"Waiting for session to be ready at: {health_url}")
log.debug(f"Timeout: {timeout}s, Interval: {interval}s")
log.debug(f"Waiting for session ready: {health_url} (timeout={timeout}s, interval={interval}s)")

end_time = asyncio.get_event_loop().time() + timeout
last_message = "No message received"
Expand All @@ -46,58 +41,41 @@ async def wait_for_session(
while asyncio.get_event_loop().time() < end_time:
attempt += 1
try:
log.debug(f"Health check attempt {attempt}")

async with client_session.get(health_url, headers=headers) as response:
log.debug(f"Health check response status: {response.status}")

if response.status == 200:
log.debug("Session is ready (status 200)")
return True

if response.status != 200:
last_message = f"Health check returned status {response.status}"
log.debug(last_message)
last_message = f"HTTP {response.status}"
log.debug(f"GET /health - status: {response.status} (attempt {attempt})")
await asyncio.sleep(interval)
continue

session_response = ComputeApiSessionResponse(**(await response.json()))
status = session_response.session_status
log.debug(f"GET /health - status: 200, pipeline: {status.value} (attempt {attempt})")

if status == PipelineStatus.RUNNING:
log.debug("Session is running")
return True
elif status == PipelineStatus.PENDING:
last_message = f"Session status: {status.value}"
log.debug(f"Session still pending/creating: {last_message}")
await asyncio.sleep(interval)
continue
elif status in [
PipelineStatus.FAILED,
PipelineStatus.ERROR,
PipelineStatus.STOPPED,
]:

if status in _TERMINAL_STATES:
raise ComputeHealthCheckException(
f"Session in terminal state: {status.value}. Message: {session_response.session_message}",
f"Session in terminal state: {status.value}. "
f"Message: {session_response.session_message}",
session_endpoint=compute_config.session_endpoint,
last_status=status.value,
)
else:
last_message = f"Session status: {status.value}"
log.debug(f"Unknown session status, continuing to wait: {last_message}")
await asyncio.sleep(interval)
continue

last_message = f"Pipeline status: {status.value}"
await asyncio.sleep(interval)

except ComputeHealthCheckException:
raise
except aiohttp.ClientResponseError as e:
last_message = f"HTTP {e.status}: {e.message}"
log.debug(f"HTTP error during health check: {last_message}")
log.debug(f"GET /health - error: {last_message} (attempt {attempt})")
await asyncio.sleep(interval)
except Exception as e:
last_message = str(e)
log.debug(f"Exception during health check: {last_message}")

await asyncio.sleep(interval)
log.debug(f"GET /health - error: {last_message} (attempt {attempt})")
await asyncio.sleep(interval)

log.error(f"Session timed out after {timeout}s. Last message: {last_message}")
raise TimeoutError(f"Session timed out after {timeout}s. Last message: {last_message}")
3 changes: 1 addition & 2 deletions eyepop/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ async def __get_access_token(self) -> str | None:
if self.compute_ctx.m2m_access_token:
return self.compute_ctx.m2m_access_token
else:
log.debug("compute ctx m2m access token is None, fetching new token")
assert self.client_session is not None
authenticate_url = f'{self.compute_ctx.compute_url}/v1/auth/authenticate'
api_auth_header = {
Expand All @@ -224,9 +223,9 @@ async def __get_access_token(self) -> str | None:
'Accept': 'application/json'
}
async with self.client_session.post(authenticate_url, headers=api_auth_header) as response:
log.debug(f"POST /v1/auth/authenticate - status: {response.status}")
response_json = await response.json()
self.compute_ctx.m2m_access_token = response_json['access_token']
log.debug(f"compute ctx m2m access token: {self.compute_ctx.m2m_access_token}")
return self.compute_ctx.m2m_access_token
if self.secret_key is None:
return None
Expand Down
8 changes: 8 additions & 0 deletions eyepop/eyepopsdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def workerEndpoint(
dataset_uuid: str | None = None,
pipeline_image: str | None = None,
pipeline_version: str | None = None,
**kwargs,
) -> WorkerEndpoint | SyncWorkerEndpoint:
if is_async:
return EyePopSdk.async_worker(
Expand All @@ -50,6 +51,7 @@ def workerEndpoint(
dataset_uuid=dataset_uuid,
pipeline_image=pipeline_image,
pipeline_version=pipeline_version,
**kwargs,
)
else:
return EyePopSdk.sync_worker(
Expand All @@ -67,6 +69,7 @@ def workerEndpoint(
dataset_uuid=dataset_uuid,
pipeline_image=pipeline_image,
pipeline_version=pipeline_version,
**kwargs,
)

@staticmethod
Expand All @@ -85,6 +88,7 @@ def sync_worker(
dataset_uuid: str | None = None,
pipeline_image: str | None = None,
pipeline_version: str | None = None,
**kwargs,
) -> SyncWorkerEndpoint:
endpoint = EyePopSdk.async_worker(
pop_id=pop_id,
Expand All @@ -101,6 +105,7 @@ def sync_worker(
dataset_uuid=dataset_uuid,
pipeline_image=pipeline_image,
pipeline_version=pipeline_version,
**kwargs,
)
return SyncWorkerEndpoint(endpoint)

Expand All @@ -120,6 +125,7 @@ def async_worker(
dataset_uuid: str | None = None,
pipeline_image: str | None = None,
pipeline_version: str | None = None,
**kwargs,
) -> WorkerEndpoint:
if is_local_mode is None:
local_mode_env = os.getenv("EYEPOP_LOCAL_MODE", "")
Expand Down Expand Up @@ -187,6 +193,8 @@ def async_worker(
dataset_uuid=dataset_uuid,
pipeline_image=pipeline_image,
pipeline_version=pipeline_version,
session_opts=kwargs.get("session_opts"),
session_headers=kwargs.get("session_headers"),
)
return endpoint

Expand Down
8 changes: 7 additions & 1 deletion eyepop/worker/worker_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(
dataset_uuid: str | None = None,
pipeline_image: str | None = None,
pipeline_version: str | None = None,
session_opts: dict | None = None,
session_headers: dict | None = None,
):
super().__init__(
secret_key=secret_key,
Expand All @@ -83,6 +85,10 @@ def __init__(
self.compute_ctx.pipeline_image = pipeline_image
if pipeline_version:
self.compute_ctx.pipeline_version = pipeline_version
if session_opts:
self.compute_ctx.session_opts = dict(session_opts)
if session_headers:
self.compute_ctx.session_headers = dict(session_headers)
self.is_dev_mode = not bool(session_uuid)
else:
self.is_dev_mode = True
Expand Down Expand Up @@ -194,7 +200,7 @@ async def _reconnect(self):
raise e

if self.compute_ctx:
log_requests.debug(f'Using compute context config: {self.worker_config}')
log_requests.debug(f'Compute session: {self.compute_ctx.session_uuid}')

base_url = await self.dev_mode_base_url()

Expand Down
Loading