Skip to content

Commit 5d6f0e2

Browse files
author
Giovanni Giacometti
authored
Merge pull request #11 from ml-cube/dev-new-proposition
docs + fix: new proposition, minor fixes
2 parents fea00c5 + 4cf78b3 commit 5d6f0e2

9 files changed

Lines changed: 1295 additions & 1200 deletions

File tree

README.md

Lines changed: 128 additions & 34 deletions
Large diffs are not rendered by default.

examples/sklearn/continuous_data.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
import logging
22

33
import numpy as np
4-
from ml3_drift.sklearn.univariate.ks import KSDriftDetector
4+
from ml3_drift.monitoring.algorithms.batch.ks import KSAlgorithm
55
from sklearn.tree import DecisionTreeRegressor
66
from sklearn.pipeline import Pipeline
77
from sklearn.preprocessing import StandardScaler
88
from ml3_drift.callbacks.base import logger_callback
99
from functools import partial
1010

11+
from ml3_drift.sklearn.base import SklearnDriftDetector
12+
1113
logger = logging.getLogger(__name__)
1214

1315

1416
if __name__ == "__main__":
1517
# Define your pipeline as usual, but also add a drift detector
16-
drift_detector = KSDriftDetector(
17-
callbacks=[
18-
partial(
19-
logger_callback,
20-
logger=logger,
21-
level=logging.CRITICAL,
22-
)
23-
]
18+
drift_detector = SklearnDriftDetector(
19+
KSAlgorithm(
20+
callbacks=[
21+
partial(
22+
logger_callback,
23+
logger=logger,
24+
level=logging.CRITICAL,
25+
)
26+
]
27+
)
2428
)
2529

2630
pipeline = Pipeline(

examples/sklearn/mixed_data_monitoring.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from functools import partial
22
import logging
3-
from ml3_drift.sklearn.univariate.ks import KSDriftDetector
4-
from ml3_drift.sklearn.univariate.chi_square import ChiSquareDriftDetector
3+
from ml3_drift.monitoring.algorithms.batch.ks import KSAlgorithm
4+
from ml3_drift.monitoring.algorithms.batch.chi_square import ChiSquareAlgorithm
55

66
from sklearn.tree import DecisionTreeRegressor
77
from sklearn.pipeline import Pipeline
88
from sklearn.compose import ColumnTransformer
99
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
1010
import numpy as np
1111
from ml3_drift.callbacks.base import logger_callback
12+
from ml3_drift.sklearn.base import SklearnDriftDetector
1213

1314

1415
logger = logging.getLogger(__name__)
@@ -51,27 +52,31 @@
5152
transformers=[
5253
(
5354
"cont",
54-
KSDriftDetector(
55-
callbacks=[
56-
partial(
57-
logger_callback,
58-
logger=logger,
59-
level=logging.CRITICAL,
60-
),
61-
]
55+
SklearnDriftDetector(
56+
KSAlgorithm(
57+
callbacks=[
58+
partial(
59+
logger_callback,
60+
logger=logger,
61+
level=logging.CRITICAL,
62+
),
63+
]
64+
)
6265
),
6366
[0, 1],
6467
),
6568
(
6669
"cat",
67-
ChiSquareDriftDetector(
68-
callbacks=[
69-
partial(
70-
logger_callback,
71-
logger=logger,
72-
level=logging.CRITICAL,
73-
),
74-
]
70+
SklearnDriftDetector(
71+
ChiSquareAlgorithm(
72+
callbacks=[
73+
partial(
74+
logger_callback,
75+
logger=logger,
76+
level=logging.CRITICAL,
77+
),
78+
]
79+
)
7580
),
7681
[2, 3],
7782
),

src/ml3_drift/analysis/report.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ def __init__(
44
):
55
self.concepts = concepts
66
self.same_distributions = same_distributions
7+
8+
def __repr__(self):
9+
return f"Report(concepts={self.concepts}, same_distributions={dict(self.same_distributions)})"

src/ml3_drift/callbacks/base.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import dataclasses
12
from logging import Logger
2-
from ml3_drift.callbacks.models import DriftInfo
3+
4+
from ml3_drift.models.monitoring import DriftInfo
35

46

57
def logger_callback(
6-
drift_info: DriftInfo,
8+
drift_info: DriftInfo | None,
79
logger: Logger,
810
level: int,
911
) -> None:
@@ -20,14 +22,24 @@ def logger_callback(
2022
"""
2123

2224
if drift_info is None:
23-
logger.log(level, "Drift Detected!")
25+
logger.log(level, "Drift Detected, no drift info provided!")
2426
return
25-
msg = f"Drift detected on feature at index {drift_info.feature_index} by drift detector {drift_info.drift_detector}."
2627

27-
if drift_info.p_value is not None:
28-
msg += f"\n p-value = {drift_info.p_value}"
28+
logger.log(level, f"Drift Detected, drift info: {dataclasses.asdict(drift_info)}")
29+
30+
31+
def print_callback(drift_info: DriftInfo | None) -> None:
32+
"""
33+
Print callback prints a message to the console when drift is detected.
34+
It should be used only for testing purposes.
35+
36+
Example
37+
-------
2938
30-
if drift_info.threshold is not None:
31-
msg += f"\n Threshold = {drift_info.threshold}"
39+
callback = print_callback
40+
"""
41+
if drift_info is None:
42+
print("Drift Detected, no drift info provided!")
43+
return
3244

33-
logger.log(level, msg)
45+
print(f"Drift Detected, drift info: {dataclasses.asdict(drift_info)}")

src/ml3_drift/callbacks/models.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

src/ml3_drift/monitoring/algorithms/online/kswin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ def __init__(
4848
self.seed = seed
4949
self._args = args
5050
self._kwargs = kwargs
51-
super().__init__(
52-
comparison_size=1, callbacks=callbacks
53-
) # since we add only one sample per step and river handles building the window internally we set comparison_size to 1
51+
# since we add only one sample per step and river handles building the window internally we set comparison_size to 1
52+
super().__init__(comparison_size=1, callbacks=callbacks)
5453

5554
def _reset_internal_parameters(self):
5655
self.drift_agent = RiverKSWIN(

tests/test_monitoring/test_online/test_univariate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_kswin_univariate_one_drift(self, abrupt_univariate_online_drift_info):
4040
)
4141

4242
def test_kswin_univariate_two_drift(self, abrupt_univariate_online_bidrift_info):
43+
np.random.seed(42)
4344
data_stream, drift_point_1, drift_point_2 = (
4445
abrupt_univariate_online_bidrift_info
4546
)

0 commit comments

Comments
 (0)