Skip to content

Commit d512e73

Browse files
authored
[dev] update mypy to 1.14.1 (#7101)
* [dev] update mypy to 1.16.1 mypy 1.17 drops support for python 3.8, so this is the latest version that works. I did this to get python/mypy#15896, which will allow us to properly typecheck context_utils.to_thread. * switch to 1.14.1
1 parent f09291d commit d512e73

File tree

12 files changed

+40
-16
lines changed

12 files changed

+40
-16
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ repos:
3333
files: "^sky/skylet/providers/ibm/.*" # Only match IBM-specific directory
3434

3535
- repo: https://github.com/pre-commit/mirrors-mypy
36-
rev: v1.4.0 # Match the version from requirements
36+
rev: v1.14.1 # Match the version from requirements
3737
hooks:
3838
- id: mypy
3939
args:

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ isort==5.12.0
1616

1717
# type checking
1818
# match the version with .pre-commit-config.yaml
19-
mypy==1.4.0
19+
mypy==1.14.1
2020
types-PyYAML
2121
types-paramiko
2222
# 2.31 requires urlib3>2, which is incompatible with IBM and

sky/backends/backend_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def get_expirable_clouds(
539539
# get all custom contexts
540540
contexts = kubernetes_utils.get_custom_config_k8s_contexts()
541541
# add remote_identity of each context if it exists
542-
remote_identities = None
542+
remote_identities: Optional[Union[str, List[Dict[str, str]]]] = None
543543
for context in contexts:
544544
context_remote_identity = skypilot_config.get_effective_region_config(
545545
cloud='kubernetes',
@@ -550,9 +550,11 @@ def get_expirable_clouds(
550550
if remote_identities is None:
551551
remote_identities = []
552552
if isinstance(context_remote_identity, str):
553+
assert isinstance(remote_identities, list)
553554
remote_identities.append(
554555
{context: context_remote_identity})
555556
elif isinstance(context_remote_identity, list):
557+
assert isinstance(remote_identities, list)
556558
remote_identities.extend(context_remote_identity)
557559
# add global kubernetes remote identity if it exists, if not, add default
558560
global_remote_identity = skypilot_config.get_effective_region_config(
@@ -564,8 +566,10 @@ def get_expirable_clouds(
564566
if remote_identities is None:
565567
remote_identities = []
566568
if isinstance(global_remote_identity, str):
569+
assert isinstance(remote_identities, list)
567570
remote_identities.append({'*': global_remote_identity})
568571
elif isinstance(global_remote_identity, list):
572+
assert isinstance(remote_identities, list)
569573
remote_identities.extend(global_remote_identity)
570574
if remote_identities is None:
571575
remote_identities = schemas.get_default_remote_identity(

sky/clouds/aws.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def _get_max_efa_interfaces(instance_type: str, region_name: str) -> int:
156156
try:
157157
client = aws.client('ec2', region_name=region_name)
158158
response = client.describe_instance_types(
159-
InstanceTypes=[instance_type],
159+
# TODO(cooperc): fix the types for mypy 1.16
160+
# Boto3 type stubs expect Literal instance types; using str list here.
161+
InstanceTypes=[instance_type], # type: ignore
160162
Filters=[{
161163
'Name': 'network-info.efa-supported',
162164
'Values': ['true']

sky/data/storage.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2700,7 +2700,11 @@ def from_metadata(cls, metadata: AbstractStore.StoreMetadata,
27002700
name=override_args.get('name', metadata.name),
27012701
storage_account_name=override_args.get(
27022702
'storage_account', metadata.storage_account_name),
2703-
source=override_args.get('source', metadata.source),
2703+
# TODO(cooperc): fix the types for mypy 1.16
2704+
# Azure store expects a string path; metadata.source may be a Path
2705+
# or List[Path].
2706+
source=override_args.get('source',
2707+
metadata.source), # type: ignore[arg-type]
27042708
region=override_args.get('region', metadata.region),
27052709
is_sky_managed=override_args.get('is_sky_managed',
27062710
metadata.is_sky_managed),

sky/jobs/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,12 +1425,12 @@ def load_managed_job_queue(
14251425
"""Load job queue from json string."""
14261426
result = message_utils.decode_payload(payload)
14271427
result_type = ManagedJobQueueResultType.DICT
1428-
status_counts = {}
1428+
status_counts: Dict[str, int] = {}
14291429
if isinstance(result, dict):
1430-
jobs = result['jobs']
1431-
total = result['total']
1430+
jobs: List[Dict[str, Any]] = result['jobs']
1431+
total: int = result['total']
14321432
status_counts = result.get('status_counts', {})
1433-
total_no_filter = result.get('total_no_filter', total)
1433+
total_no_filter: int = result.get('total_no_filter', total)
14341434
else:
14351435
jobs = result
14361436
total = len(jobs)

sky/provision/aws/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,10 @@ def _get_route_tables(ec2: 'mypy_boto3_ec2.ServiceResource',
305305
Returns:
306306
A list of route tables associated with the options VPC and region
307307
"""
308-
filters = [{'Name': 'association.main', 'Values': [str(main).lower()]}]
308+
filters: List['ec2_type_defs.FilterTypeDef'] = [{
309+
'Name': 'association.main',
310+
'Values': [str(main).lower()],
311+
}]
309312
if vpc_id is not None:
310313
filters.append({'Name': 'vpc-id', 'Values': [vpc_id]})
311314
logger.debug(

sky/provision/gcp/config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import typing
66
from typing import Any, Dict, List, Set, Tuple
77

8+
from typing_extensions import TypedDict
9+
810
from sky.adaptors import gcp
911
from sky.clouds.utils import gcp_utils
1012
from sky.provision import common
@@ -415,6 +417,9 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict:
415417
return iam_role
416418

417419

420+
AllowedList = TypedDict('AllowedList', {'IPProtocol': str, 'ports': List[str]})
421+
422+
418423
def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str,
419424
compute):
420425
"""Check if the firewall rules in the VPC are sufficient."""
@@ -466,7 +471,7 @@ def _merge_and_refine_rule(
466471
}
467472
"""
468473
source2rules: Dict[Tuple[str, str], Dict[str, Set[int]]] = {}
469-
source2allowed_list: Dict[Tuple[str, str], List[Dict[str, str]]] = {}
474+
source2allowed_list: Dict[Tuple[str, str], List[AllowedList]] = {}
470475
for rule in rules:
471476
# Rules applied to specific VM (targetTags) may not work for the
472477
# current VM, so should be skipped.

sky/server/requests/serializers/encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def encode_jobs_queue(jobs: List[dict],) -> List[Dict[str, Any]]:
131131
def encode_jobs_queue_v2(
132132
jobs_or_tuple) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
133133
# Support returning either a plain jobs list or a (jobs, total) tuple
134-
status_counts = {}
134+
status_counts: Dict[str, int] = {}
135135
if isinstance(jobs_or_tuple, tuple):
136136
if len(jobs_or_tuple) == 2:
137137
jobs, total = jobs_or_tuple

sky/server/server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1829,7 +1829,7 @@ async def all_contexts(request: fastapi.Request) -> None:
18291829
async def gpu_metrics() -> fastapi.Response:
18301830
"""Gets the GPU metrics from multiple external k8s clusters"""
18311831
contexts = core.get_all_contexts()
1832-
all_metrics = []
1832+
all_metrics: List[str] = []
18331833
successful_contexts = 0
18341834

18351835
tasks = [
@@ -1844,6 +1844,10 @@ async def gpu_metrics() -> fastapi.Response:
18441844
if isinstance(result, Exception):
18451845
logger.error(
18461846
f'Failed to get metrics for context {contexts[i]}: {result}')
1847+
elif isinstance(result, BaseException):
1848+
# Avoid changing behavior for non-Exception BaseExceptions
1849+
# like KeyboardInterrupt/SystemExit: re-raise them.
1850+
raise result
18471851
else:
18481852
metrics_text = result
18491853
all_metrics.append(metrics_text)

0 commit comments

Comments
 (0)