Skip to content

Commit 0e950f0

Browse files
committed
refactor: improve type hints for mypy
1 parent b459ddb commit 0e950f0

13 files changed

Lines changed: 111 additions & 112 deletions

asgi_webdav/auth.py

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import hashlib
44
import re
55
from base64 import b64decode
6-
from enum import Enum
76
from logging import getLogger
87
from typing import Any
98
from urllib.parse import parse_qs
@@ -16,7 +15,7 @@
1615
DAVCacheType,
1716
)
1817
from asgi_webdav.config import Config
19-
from asgi_webdav.constants import DAVMethod, DAVUser
18+
from asgi_webdav.constants import DAVMethod, DAVUpperEnumAbc, DAVUser
2019
from asgi_webdav.exception import DAVExceptionAuthFailed, DAVExceptionConfig
2120
from asgi_webdav.request import DAVRequest
2221
from asgi_webdav.response import DAVResponse
@@ -58,32 +57,22 @@ def _md5(data: str) -> str:
5857
return hashlib.new("md5", data.encode("utf-8")).hexdigest()
5958

6059

61-
class DAVPasswordType(Enum):
62-
INVALID = ":", 0
60+
class DAVPasswordType(DAVUpperEnumAbc):
61+
INVALID = "X", -1
62+
6363
RAW = ":", 0
6464
HASHLIB = ":", 4
6565
DIGEST = ":", 3
6666
LDAP = "#", 5
6767

68-
def __init__(self, *args, **kwds):
69-
self._value_ = self._name_
70-
self.split_char = args[0]
71-
self.split_count = args[1]
68+
def __init__(self, *args: Any, **kwargs: Any) -> None:
69+
super().__init__(*args, **kwargs)
7270

73-
@classmethod
74-
def _missing_(cls, value):
75-
if isinstance(value, str):
76-
value = value.upper()
77-
else:
78-
return cls.INVALID
79-
80-
try:
81-
return cls[value]
71+
self.split_char, self.split_count = args
8272

83-
except KeyError:
84-
pass
85-
86-
return cls.INVALID
73+
@classmethod
74+
def default_value(cls, value: Any) -> str:
75+
return "INVALID"
8776

8877

8978
class DAVPassword:
@@ -233,7 +222,7 @@ def check_digest_password(
233222

234223
return False, None
235224

236-
def __repr__(self):
225+
def __repr__(self) -> str:
237226
return f"{self.type}|{self.data}"
238227

239228

@@ -443,7 +432,7 @@ def nonce(self) -> str:
443432
return _md5(f"{uuid4().hex}{self.secret}")
444433

445434
@staticmethod
446-
def authorization_str_parser_to_data(authorization: str) -> dict:
435+
def authorization_str_parser_to_data(authorization: str) -> dict[str, str]:
447436
values = authorization.split(",")
448437
data = dict()
449438
for value in values:
@@ -720,18 +709,3 @@ def _match_user_agent(rule: str, user_agent: str) -> bool:
720709
return False
721710

722711
return True
723-
724-
@staticmethod
725-
def _parser_digest_request(authorization: str) -> dict:
726-
values = authorization[7:].split(",")
727-
728-
data = dict()
729-
for value in values:
730-
value = value.replace('"', "").replace(" ", "")
731-
try:
732-
k, v = value.split("=")
733-
data[k] = v
734-
except ValueError:
735-
pass
736-
737-
return data

asgi_webdav/cache.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ class DAVCacheAbc: # pragma: no cover
2020
async def prepare(self) -> None:
2121
raise NotImplementedError
2222

23-
async def get(self, key):
23+
async def get(self, key: str | bytes) -> Any:
2424
raise NotImplementedError
2525

26-
async def set(self, key, value):
26+
async def set(self, key: str | bytes, value: Any) -> None:
2727
raise NotImplementedError
2828

2929
async def purge(self) -> None:
@@ -38,10 +38,10 @@ class DAVCacheBypass(DAVCacheAbc):
3838
async def prepare(self) -> None: # pragma: no cover
3939
pass
4040

41-
async def get(self, key):
41+
async def get(self, key: str | bytes) -> Any:
4242
return None
4343

44-
async def set(self, key, value):
44+
async def set(self, key: str | bytes, value: Any) -> None:
4545
pass
4646

4747
async def purge(self) -> None:
@@ -53,7 +53,7 @@ async def close(self) -> None: # pragma: no cover
5353

5454
class DAVCacheMemory(DAVCacheAbc):
5555
_lock: Lock
56-
_cache: dict
56+
_cache: dict[str | bytes, Any]
5757

5858
def __init__(self) -> None:
5959
self._cache = {}
@@ -62,11 +62,11 @@ def __init__(self) -> None:
6262
async def prepare(self) -> None: # pragma: no cover
6363
pass
6464

65-
async def get(self, key):
65+
async def get(self, key: str | bytes) -> Any:
6666
async with self._lock:
6767
return self._cache.get(key)
6868

69-
async def set(self, key, value):
69+
async def set(self, key: str | bytes, value: Any) -> None:
7070
async with self._lock:
7171
self._cache[key] = value
7272

@@ -80,7 +80,7 @@ async def close(self) -> None: # pragma: no cover
8080

8181
class DAVCacheExpiring(DAVCacheAbc):
8282
_lock: Lock
83-
_cache: dict[bytes, tuple[Any, datetime]]
83+
_cache: dict[str | bytes, tuple[Any, datetime]]
8484
_cache_expiration_timedelta: timedelta
8585

8686
def __init__(self, cache_expiration: int) -> None:
@@ -94,7 +94,7 @@ def __init__(self, cache_expiration: int) -> None:
9494
async def prepare(self) -> None: # pragma: no cover
9595
pass
9696

97-
async def get(self, key):
97+
async def get(self, key: str | bytes) -> Any:
9898
async with self._lock:
9999
cached = self._cache.get(key)
100100
if cached:
@@ -105,7 +105,7 @@ async def get(self, key):
105105
# Cache entry expired
106106
self._cache.pop(key, None)
107107

108-
async def set(self, key, value):
108+
async def set(self, key: str | bytes, value: Any) -> None:
109109
async with self._lock:
110110
self._cache[key] = (value, datetime.now())
111111

asgi_webdav/config.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
22
import sys
3+
from collections.abc import Callable
34
from dataclasses import dataclass, field
45
from enum import Enum
56
from logging import getLogger
67
from pathlib import Path
8+
from typing import Any
79

810
if sys.version_info >= (3, 11):
911
import tomllib
@@ -105,8 +107,8 @@ class GuessTypeExtension:
105107
enable: bool = True
106108
enable_default_mapping: bool = True
107109

108-
filename_mapping: dict = field(default_factory=dict)
109-
suffix_mapping: dict = field(default_factory=dict)
110+
filename_mapping: dict[str, str] = field(default_factory=dict)
111+
suffix_mapping: dict[str, str] = field(default_factory=dict)
110112

111113

112114
@dataclass
@@ -305,15 +307,19 @@ def get_config() -> Config:
305307
return _config
306308

307309

308-
def get_config_copy_from_dict(data: dict, complete_config: bool = False) -> Config:
310+
def get_config_copy_from_dict(
311+
data: dict[str, Any], complete_config: bool = False
312+
) -> Config:
309313
config = Config.from_dict(data)
310314
if complete_config:
311315
config._complete_config()
312316

313317
return config
314318

315319

316-
def reinit_config_from_dict(data: dict, complete_config: bool = False) -> Config:
320+
def reinit_config_from_dict(
321+
data: dict[str, Any], complete_config: bool = False
322+
) -> Config:
317323
global _config
318324

319325
logger.debug("Load config value from python object(dict)")
@@ -323,6 +329,8 @@ def reinit_config_from_dict(data: dict, complete_config: bool = False) -> Config
323329

324330

325331
def reinit_config_from_file(file_name: str, complete_config: bool = False) -> bool:
332+
load_func: Callable[[Any], Any]
333+
326334
file = Path(file_name)
327335
match file.suffix:
328336
case ".json":

asgi_webdav/constants.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum, IntEnum, auto
55
from functools import cache
66
from time import time
7-
from typing import NewType
7+
from typing import Any, NewType
88
from uuid import UUID
99

1010
import arrow
@@ -23,31 +23,26 @@ class DAVUpperEnumAbc(Enum):
2323
默认值为空,需要继承实现;默认不会自动匹配默认值
2424
"""
2525

26-
def __init__(self, *args, **kwargs) -> None:
26+
def __init__(self, *args: Any, **kwargs: Any) -> None:
2727
self._value_ = self._name_.upper()
28-
2928
label = args[0]
3029
if not isinstance(label, str):
3130
self.label = str(label)
3231
else:
3332
self.label = label
3433

3534
@classmethod
36-
def _missing_(cls, value):
35+
def _missing_(cls, value: Any) -> "DAVUpperEnumAbc":
3736
if not isinstance(value, str):
3837
raise ValueError(f"Invalid {cls.__name__} value: {value}")
3938

40-
if "." in value:
41-
# 兼容枚举前后端转换问题
42-
value = value.split(".")[1]
43-
4439
try:
4540
return cls[value.upper()]
4641
except KeyError:
4742
return cls[cls.default_value(value).upper()]
4843

4944
@classmethod
50-
def default_value(cls, value) -> str:
45+
def default_value(cls, value: Any) -> str:
5146
raise ValueError(f"Invalid {cls.__name__} value: {value}")
5247

5348
@classmethod
@@ -68,6 +63,9 @@ def value_label_mapping(cls) -> dict[str, str]:
6863

6964
# WebDAV protocol ---
7065
class DAVMethod(DAVUpperEnumAbc):
66+
# default/fallback
67+
UNKNOWN = auto()
68+
7169
# rfc4918:9.1
7270
PROPFIND = auto()
7371
# rfc4918:9.2
@@ -93,11 +91,8 @@ class DAVMethod(DAVUpperEnumAbc):
9391
# only for inside page
9492
POST = auto()
9593

96-
# only for request parser failed
97-
UNKNOWN = auto()
98-
9994
@classmethod
100-
def default_value(cls, value) -> str:
95+
def default_value(cls, value: Any) -> str:
10196
return "UNKNOWN"
10297

10398
@classmethod
@@ -216,7 +211,10 @@ def add_child(self, child: "DAVPath | str") -> "DAVPath":
216211
def __hash__(self) -> int:
217212
return hash(self.raw)
218213

219-
def __eq__(self, other):
214+
def __eq__(self, other: object) -> bool:
215+
if not isinstance(other, DAVPath):
216+
return False
217+
220218
return self.raw == other.raw
221219

222220
def __lt__(self, other: "DAVPath") -> bool:

asgi_webdav/lock.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import pprint
3+
from collections.abc import Iterable
34
from time import time
45
from uuid import UUID, uuid4
56

@@ -21,7 +22,7 @@ def __init__(self) -> None:
2122
def __contains__(self, item: DAVPath) -> bool:
2223
return item in self.data
2324

24-
def keys(self):
25+
def keys(self) -> Iterable[DAVPath]:
2526
return self.data.keys()
2627

2728
def get_tokens(self, path: DAVPath) -> list[UUID]:
@@ -129,7 +130,7 @@ async def is_locking(self, path: DAVPath, owner_token: UUID | None = None) -> bo
129130
return False
130131

131132
async def get_info_by_path(self, path: DAVPath) -> list[DAVLockInfo]:
132-
result = list()
133+
result: list[DAVLockInfo] = list()
133134
async with self.lock:
134135
if path not in self.path2token_map:
135136
return result
@@ -170,7 +171,7 @@ async def _release_by_path(self, path: DAVPath) -> None:
170171
for token in self.path2token_map.get_tokens(path):
171172
self._remove_token(path, token)
172173

173-
def __repr__(self):
174+
def __repr__(self) -> str:
174175
s = "{}\n{}".format(
175176
pprint.pformat(self.path2token_map), pprint.pformat(self.lock_map)
176177
)

asgi_webdav/property.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,20 @@ class DAVPropertyBasicData:
1414
last_modified: DAVTime
1515

1616
# resource_type: str = field(init=False)
17-
content_type: str | None = field(default=None)
17+
content_type: str = ""
1818
content_charset: str | None = None
19-
content_length: int = field(default=0)
19+
content_length: int = 0
2020
content_encoding: str | None = None
2121

2222
def __post_init__(self) -> None:
2323
# https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Basics_of_HTTP/MIME_types
24-
if self.content_type is None:
24+
if not self.content_type:
2525
if self.is_collection:
2626
# self.content_type = "httpd/unix-directory"
2727
self.content_type = "application/index"
2828
else:
2929
self.content_type = "application/octet-stream"
3030

31-
if self.content_length is None:
32-
self.content_length = 0
33-
3431
@property
3532
def etag(self) -> str:
3633
return generate_etag(self.content_length, self.last_modified.timestamp)
@@ -66,8 +63,8 @@ def get_get_head_response_headers(self) -> dict[bytes, bytes]:
6663

6764
return headers
6865

69-
def as_dict(self) -> dict[str, str]:
70-
data = {
66+
def as_dict(self) -> dict[str, str | int]:
67+
data: dict[str, str | int] = {
7168
"displayname": self.display_name,
7269
"getetag": self.etag,
7370
"creationdate": self.creation_date.dav_creation_date(),

asgi_webdav/provider/webhdfs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def _create_dav_property_obj(
137137
last_modified=DAVTime(float(file_status.get("modificationTime", 0.0))),
138138
content_type=content_type,
139139
content_charset=charset,
140-
content_length=file_status.get("length"),
140+
content_length=file_status.get("length", 0),
141141
content_encoding=content_encoding,
142142
)
143143

0 commit comments

Comments
 (0)