Skip to content

Commit 2b45919

Browse files
committed
Simplify fixtures
1 parent af21a91 commit 2b45919

File tree

2 files changed

+13

-48

lines changed

tests/test_distributed/test_with_dask/conftest.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,10 @@ def client_kwargs(request: pytest.FixtureRequest) -> Dict[str, Any]:
2525

2626

2727
@pytest.fixture(scope="session")
28-
def cluster(client_kwargs: Dict[str, Any]) -> Generator[LocalCluster, None, None]:
28+
def client(client_kwargs: Dict[str, Any]) -> Generator[Client, None, None]:
2929
with LocalCluster(**client_kwargs) as dask_cluster:
30-
yield dask_cluster
31-
32-
33-
@pytest.fixture(scope="session")
34-
def client(cluster: LocalCluster) -> Generator[Client, None, None]:
35-
with Client(cluster) as dask_client:
36-
yield dask_client
30+
with Client(dask_cluster) as dask_client:
31+
yield dask_client
3732

3833

3934
@pytest.fixture(autouse=True)
@@ -56,24 +51,3 @@ def client_one_worker() -> Generator[Client, None, None]:
5651
) as dask_cluster:
5752
with Client(dask_cluster) as dask_client:
5853
yield dask_client
59-
60-
61-
@pytest.fixture
62-
def client_factory() -> Any:
63-
@contextmanager
64-
def _factory(**kwargs: Any) -> Iterator[Client]:
65-
with LocalCluster(**kwargs) as dask_cluster:
66-
with Client(dask_cluster) as dask_client:
67-
yield dask_client
68-
69-
return _factory
70-
71-
72-
@pytest.fixture
73-
def client_from_cluster() -> Any:
74-
@contextmanager
75-
def _factory(cluster: LocalCluster) -> Iterator[Client]:
76-
with Client(cluster) as dask_client:
77-
yield dask_client
78-
79-
return _factory

tests/test_distributed/test_with_dask/test_with_dask.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,14 +2038,7 @@ def test_tree_stats(
20382038
assert local == distributed
20392039

20402040

2041-
@pytest.mark.parametrize(
2042-
"client_kwargs",
2043-
[pytest.param({"n_workers": 4, "dashboard_address": ":0"}, id="4-workers")],
2044-
indirect=True,
2045-
)
2046-
def test_parallel_submit_multi_clients(
2047-
client: "Client", cluster: "LocalCluster", client_from_cluster: Any
2048-
) -> None:
2041+
def test_parallel_submit_multi_clients() -> None:
20492042
"""Test for running multiple train simultaneously from multiple clients."""
20502043
try:
20512044
from distributed import MultiLock # NOQA
@@ -2054,19 +2047,17 @@ def test_parallel_submit_multi_clients(
20542047

20552048
from sklearn.datasets import load_digits
20562049

2057-
workers = tm.dask.get_client_workers(client)
2050+
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
2051+
with Client(cluster) as client:
2052+
workers = tm.dask.get_client_workers(client)
20582053

2059-
n_submits = len(workers)
2060-
assert n_submits == 4
2061-
futures = []
2054+
n_submits = len(workers)
2055+
assert n_submits == 4
2056+
futures = []
20622057

2063-
with ExitStack() as stack:
20642058
for i in range(n_submits):
2065-
extra_client = stack.enter_context(client_from_cluster(cluster))
2059+
client = Client(cluster)
20662060
X_, y_ = load_digits(return_X_y=True)
2067-
X_, _, y_, _ = train_test_split(
2068-
X_, y_, train_size=300, stratify=y_, random_state=1994
2069-
)
20702061
X_ += 1.0
20712062
X = dd.from_array(X_, chunksize=32)
20722063
y = dd.from_array(y_, chunksize=32)
@@ -2075,8 +2066,8 @@ def test_parallel_submit_multi_clients(
20752066
n_estimators=i + 1,
20762067
eval_metric="merror",
20772068
)
2078-
f = extra_client.submit(cls.fit, X, y, pure=False)
2079-
futures.append((extra_client, f))
2069+
f = client.submit(cls.fit, X, y, pure=False)
2070+
futures.append((client, f))
20802071

20812072
t_futures = []
20822073
with ThreadPoolExecutor(max_workers=16) as e:

0 commit comments

Comments
 (0)