Skip to content

Commit 72dd083

Browse files
Default algorithms builders
1 parent fddd086 commit 72dd083

1 file changed

Lines changed: 21 additions & 3 deletions

File tree

src/ml3_drift/analysis/analyzer/batch.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from ml3_drift.models.monitoring import (
1010
MonitoringOutput,
1111
)
12+
from ml3_drift.monitoring.multivariate.bonferroni import BonferroniCorrectionAlgorithm
13+
from ml3_drift.monitoring.univariate.continuous.ks import KSAlgorithm
14+
from ml3_drift.monitoring.univariate.discrete.chi_square import ChiSquareAlgorithm
1215

1316

1417
class BatchDataDriftAnalyzer(DataDriftAnalyzer):
@@ -23,14 +26,29 @@ class BatchDataDriftAnalyzer(DataDriftAnalyzer):
2326
continuous_ma_builder: closure function that accepts int parameter as `comparison_window_size`
2427
and returns an instance of a MonitoringAlgorithm
2528
categorical_ma_builder: closure function that accepts int parameter as `comparison_window_size`
26-
and returns an instance of a MonitoringAlgorithm
29+
and returns an instance of a MonitoringAlgorithm. Notice that this parameter is needed
30+
also when there are no categorical columns (even though it is not used).
2731
batch_size: initial batch dimensions and also used as comparison_window_size
2832
"""
2933

34+
DEFAULT_CONTINUOUS_BUILDER: Callable[[int], MonitoringAlgorithm] = ( # noqa: E731
35+
lambda _: BonferroniCorrectionAlgorithm(
36+
algorithm_builder=lambda p_value: KSAlgorithm(p_value=p_value)
37+
)
38+
)
39+
40+
DEFAULT_CATEGORICAL_BUILDER = lambda _: BonferroniCorrectionAlgorithm( # noqa: E731
41+
algorithm_builder=lambda p_value: ChiSquareAlgorithm(p_value=p_value),
42+
)
43+
3044
def __init__(
3145
self,
32-
continuous_ma_builder: Callable[[int], MonitoringAlgorithm],
33-
categorical_ma_builder: Callable[[int], MonitoringAlgorithm],
46+
continuous_ma_builder: Callable[
47+
[int], MonitoringAlgorithm
48+
] = DEFAULT_CONTINUOUS_BUILDER,
49+
categorical_ma_builder: Callable[
50+
[int], MonitoringAlgorithm
51+
] = DEFAULT_CATEGORICAL_BUILDER,
3452
batch_size: int = 100,
3553
):
3654
super().__init__(

0 commit comments

Comments
 (0)