Skip to content

Commit af21a91

Browse files
committed
Faster dask tests
1 parent 3ac20d9 commit af21a91

File tree

3 files changed

+137
-75
lines changed

3 files changed

+137
-75
lines changed

python-package/xgboost/testing/dask.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,13 +338,13 @@ def create_dmatrix(
338338
return DMatrixT(*args, **kwargs)
339339

340340
def run(DMatrixT: Type[dxgb.DaskDMatrix]) -> None:
341-
enc, reenc, y, _, _ = make_recoded(device, n_features=96)
341+
enc, reenc, y, _, _ = make_recoded(device, n_features=16)
342342
to = get_client_workers(client)
343343

344344
denc, dreenc, dy = (
345-
dd.from_pandas(enc, npartitions=8).persist(workers=to),
346-
dd.from_pandas(reenc, npartitions=8).persist(workers=to),
347-
da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=to),
345+
dd.from_pandas(enc, npartitions=2).persist(workers=to),
346+
dd.from_pandas(reenc, npartitions=2).persist(workers=to),
347+
da.from_array(y, chunks=(y.shape[0] // 2,)).persist(workers=to),
348348
)
349349

350350
Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True)

tests/test_distributed/test_with_dask/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ def client(cluster: LocalCluster) -> Generator[Client, None, None]:
3636
yield dask_client
3737

3838

39+
@pytest.fixture(autouse=True)
40+
def client_as_current(request: pytest.FixtureRequest) -> Generator[None, None, None]:
41+
for name in ("client", "client_one_worker"):
42+
if name in request.fixturenames:
43+
dask_client = request.getfixturevalue(name)
44+
with dask_client.as_current():
45+
yield
46+
return
47+
yield
48+
49+
3950
@pytest.fixture(scope="session")
4051
def client_one_worker() -> Generator[Client, None, None]:
4152
n_threads = os.cpu_count()

0 commit comments

Comments
 (0)