Skip to content

Commit 79d922a

Browse files
author
Giovanni Giacometti
authored
Merge pull request #9 from gloriadesideri/main
Added univariate streaming algorithms
2 parents 94c3e90 + e3e6628 commit 79d922a

29 files changed

Lines changed: 1962 additions & 1202 deletions

File tree

examples/huggingface/text_embedding_monitoring.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from ml3_drift.huggingface.drift_detection_pipeline import (
44
HuggingFaceDriftDetectionPipeline,
55
)
6-
from ml3_drift.monitoring.multivariate.bonferroni import BonferroniCorrectionAlgorithm
7-
from ml3_drift.monitoring.univariate.continuous.ks import KSAlgorithm
6+
from ml3_drift.monitoring.algorithms.batch.bonferroni import (
7+
BonferroniCorrectionAlgorithm,
8+
)
9+
from ml3_drift.monitoring.algorithms.batch.ks import KSAlgorithm
810
from ml3_drift.callbacks.base import logger_callback
911

1012

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ dependencies = [
2222

2323
sklearn = ["scikit-learn>=1.6.1"]
2424

25-
huggingface = ["scipy>=1.15.2", "transformers[torch]>=4.52.3"]
25+
huggingface = ["transformers[torch]>=4.52.3"]
2626

2727
polars = ["polars>=1.31.0"]
2828

2929
pandas = ["pandas>=2.2.3"]
3030

31+
river = ["river>=0.22.0"]
32+
3133

3234
# -------------------------------------------------
3335

src/ml3_drift/analysis/analyzer/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
from typing_extensions import TypeIs
55

66
from ml3_drift.analysis.report import Report
7-
from ml3_drift.monitoring.base import MonitoringAlgorithm
8-
from ml3_drift.monitoring.multivariate.bonferroni import BonferroniCorrectionAlgorithm
9-
from ml3_drift.monitoring.univariate.continuous.ks import KSAlgorithm
10-
from ml3_drift.monitoring.univariate.discrete.chi_square import (
7+
from ml3_drift.monitoring.algorithms.batch.bonferroni import (
8+
BonferroniCorrectionAlgorithm,
9+
)
10+
from ml3_drift.monitoring.algorithms.batch.ks import KSAlgorithm
11+
from ml3_drift.monitoring.algorithms.batch.chi_square import (
1112
ChiSquareAlgorithm,
1213
)
14+
from ml3_drift.monitoring.base.base import MonitoringAlgorithm
1315

1416
if TYPE_CHECKING:
1517
import pandas as pd
@@ -158,6 +160,10 @@ def analyze(
158160
else:
159161
categorical_columns_ids = []
160162

163+
if not continuous_columns_ids and not categorical_columns_ids:
164+
raise ValueError(
165+
"At least one of continuous_columns or categorical_columns must be provided."
166+
)
161167
# Input and target in canonical form
162168
array_X = self._to_numpy(X)
163169

src/ml3_drift/analysis/analyzer/batch.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ml3_drift.analysis.analyzer.base import DataDriftAnalyzer
66
from ml3_drift.analysis.report import Report
7-
from ml3_drift.monitoring.base import MonitoringAlgorithm
7+
from ml3_drift.monitoring.base.base import MonitoringAlgorithm
88

99
from ml3_drift.models.monitoring import (
1010
MonitoringOutput,
@@ -74,7 +74,7 @@ def _single_scan_data(
7474
y_categorical: bool,
7575
first_batch_indexes: tuple[int, int],
7676
second_batch_indexes: tuple[int, int],
77-
) -> tuple[MonitoringOutput, MonitoringOutput]:
77+
) -> tuple[MonitoringOutput | None, MonitoringOutput | None]:
7878
"""
7979
Inner helper method that performs a single scan of two batches
8080
"""
@@ -95,16 +95,20 @@ def _single_scan_data(
9595
y_categorical,
9696
second_batch_indexes,
9797
)
98-
99-
cont_algorithm = deepcopy(self.continuous_monitoring_algorithm).fit(
100-
first_batch_cont
101-
)
102-
cat_algorithm = deepcopy(self.categorical_monitoring_algorithm).fit(
103-
first_batch_cat
104-
)
105-
106-
cont_output = cont_algorithm.detect(second_batch_cont)[0]
107-
cat_output = cat_algorithm.detect(second_batch_cat)[0]
98+
if len(continuous_columns_ids) > 0:
99+
cont_algorithm = deepcopy(self.continuous_monitoring_algorithm).fit(
100+
first_batch_cont
101+
)
102+
cont_output = cont_algorithm.detect(second_batch_cont)[0]
103+
else:
104+
cont_output = None
105+
if len(categorical_columns_ids) > 0:
106+
cat_algorithm = deepcopy(self.categorical_monitoring_algorithm).fit(
107+
first_batch_cat
108+
)
109+
cat_output = cat_algorithm.detect(second_batch_cat)[0]
110+
else:
111+
cat_output = None
108112

109113
return cont_output, cat_output
110114

@@ -149,7 +153,9 @@ def _scan_data(
149153
next_batch_indexes,
150154
)
151155

152-
if cont_output.drift_detected | cat_output.drift_detected:
156+
if (cont_output is not None and cont_output.drift_detected) | (
157+
cat_output is not None and cat_output.drift_detected
158+
):
153159
# if a drift is detected then, we close the current batch and open a new one
154160
merged_batches.append(
155161
(current_batch_start, current_batch_indexes[1] - 1)
@@ -181,7 +187,10 @@ def _scan_data(
181187

182188
# if no drift is detected the two batches are considered to belong to the same distribution
183189
# and are added to the same distribution list
184-
if not (cont_output.drift_detected | cat_output.drift_detected):
190+
if not (
191+
(cont_output is not None and cont_output.drift_detected)
192+
| (cat_output is not None and cat_output.drift_detected)
193+
):
185194
same_distributions[pair[0]].append(pair[1])
186195

187196
return Report(concepts=merged_batches, same_distributions=same_distributions)

src/ml3_drift/analysis/analyzer/stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ml3_drift.analysis.analyzer.base import DataDriftAnalyzer
44
from ml3_drift.analysis.report import Report
5-
from ml3_drift.monitoring.base import MonitoringAlgorithm
5+
from ml3_drift.monitoring.base.base import MonitoringAlgorithm
66

77

88
class StreamDataDriftAnalyzer(DataDriftAnalyzer):

src/ml3_drift/huggingface/drift_detection_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from transformers import Pipeline, pipeline
44

5-
from ml3_drift.monitoring.base import MonitoringAlgorithm
5+
from ml3_drift.monitoring.base.base import MonitoringAlgorithm
66

77

88
class HuggingFaceDriftDetectionPipeline:

src/ml3_drift/monitoring/multivariate/bonferroni.py renamed to src/ml3_drift/monitoring/algorithms/batch/bonferroni.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,18 @@
88
MonitoringAlgorithmSpecs,
99
MonitoringOutput,
1010
)
11-
from ml3_drift.monitoring.base import MonitoringAlgorithm
12-
from ml3_drift.monitoring.univariate.base import UnivariateMonitoringAlgorithm
11+
from ml3_drift.monitoring.base.base_multivariate import MultivariateMonitoringAlgorithm
12+
from ml3_drift.monitoring.base.base_univariate import UnivariateMonitoringAlgorithm
13+
from ml3_drift.monitoring.base.batch_monitoring_algorithm import (
14+
BatchMonitoringAlgorithm,
15+
)
1316

1417
T = TypeVar("T", bound=UnivariateMonitoringAlgorithm)
1518

1619

17-
class BonferroniCorrectionAlgorithm(MonitoringAlgorithm):
20+
class BonferroniCorrectionAlgorithm(
21+
BatchMonitoringAlgorithm, MultivariateMonitoringAlgorithm
22+
):
1823
"""
1924
Extension of p-value based univariate algorithms with Bonferroni correction
2025
to handle multivariate data

src/ml3_drift/monitoring/univariate/discrete/chi_square.py renamed to src/ml3_drift/monitoring/algorithms/batch/chi_square.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
MonitoringAlgorithmSpecs,
99
MonitoringOutput,
1010
)
11-
from ml3_drift.monitoring.univariate.base import UnivariateMonitoringAlgorithm
11+
from ml3_drift.monitoring.base.base_univariate import UnivariateMonitoringAlgorithm
12+
from ml3_drift.monitoring.base.batch_monitoring_algorithm import (
13+
BatchMonitoringAlgorithm,
14+
)
1215

1316

14-
class ChiSquareAlgorithm(UnivariateMonitoringAlgorithm):
17+
class ChiSquareAlgorithm(BatchMonitoringAlgorithm, UnivariateMonitoringAlgorithm):
1518
"""Monitoring algorithm based on the Chi Square statistic test.
1619
1720
Parameters

src/ml3_drift/monitoring/univariate/continuous/ks.py renamed to src/ml3_drift/monitoring/algorithms/batch/ks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
MonitoringAlgorithmSpecs,
99
MonitoringOutput,
1010
)
11-
from ml3_drift.monitoring.univariate.base import UnivariateMonitoringAlgorithm
11+
from ml3_drift.monitoring.base.base_univariate import UnivariateMonitoringAlgorithm
12+
from ml3_drift.monitoring.base.batch_monitoring_algorithm import (
13+
BatchMonitoringAlgorithm,
14+
)
1215

1316

14-
class KSAlgorithm(UnivariateMonitoringAlgorithm):
17+
class KSAlgorithm(BatchMonitoringAlgorithm, UnivariateMonitoringAlgorithm):
1518
"""Monitoring algorithm based on the Kolmogorov-Smirnov statistic test.
1619
1720
Parameters
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import Callable
2+
3+
import numpy as np
4+
from ml3_drift.enums.monitoring import DataDimension, DataType, MonitoringType
5+
from ml3_drift.models.monitoring import (
6+
DriftInfo,
7+
MonitoringAlgorithmSpecs,
8+
MonitoringOutput,
9+
)
10+
from ml3_drift.monitoring.base.base_univariate import UnivariateMonitoringAlgorithm
11+
from ml3_drift.monitoring.base.online_monitorning_algorithm import (
12+
OnlineMonitorningAlgorithm,
13+
)
14+
15+
RIVER = True
16+
try:
17+
from river.drift.adwin import ADWIN as RiverADWIN
18+
except ModuleNotFoundError:
19+
RIVER = False
20+
21+
22+
class ADWIN(OnlineMonitorningAlgorithm, UnivariateMonitoringAlgorithm):
    """Online drift detection algorithm backed by river's ADWIN detector.

    The constructor only stores the configuration. The underlying river
    detector is created in `_reset_internal_parameters`, so a fresh detector
    is built every time that method runs.

    Parameters
    ----------
    callbacks: list[Callable[[DriftInfo | None], None]] | None
        Forwarded unchanged to the base-class constructor.
    p_value: float
        Passed to river's ADWIN as its `delta` argument.
    clock: float
        Passed to river's ADWIN as `clock`.
    max_buckets: int
        Passed to river's ADWIN as `max_buckets`.
    min_window_length: int
        Passed to river's ADWIN as `min_window_length`.
    grace_period: int
        Passed to river's ADWIN as `grace_period`.
    *args, **kwargs
        Extra arguments forwarded verbatim to river's ADWIN constructor.
        NOTE(review): positional extras are bound before the keyword
        arguments above, so they presumably collide with `delta` — confirm
        whether `*args` can ever be used successfully.

    Raises
    ------
    ModuleNotFoundError
        If the optional `river` dependency is not installed.
    """

    @classmethod
    def specs(cls) -> MonitoringAlgorithmSpecs:
        """Return the static description of the data this algorithm handles."""
        # NOTE(review): the class derives from UnivariateMonitoringAlgorithm
        # but advertises MULTIVARIATE / MIX data here — confirm this is
        # intended and not a copy-paste leftover.
        return MonitoringAlgorithmSpecs(
            data_dimension=DataDimension.MULTIVARIATE,
            data_type=DataType.MIX,
            monitoring_type=MonitoringType.ONLINE,
        )

    def __init__(
        self,
        callbacks: list[Callable[[DriftInfo | None], None]] | None = None,
        p_value: float = 0.002,
        clock: float = 32,
        max_buckets: int = 5,
        min_window_length: int = 5,
        grace_period: int = 10,
        *args,
        **kwargs,
    ) -> None:
        # Fail early with an actionable message when the optional dependency
        # could not be imported at module load time.
        if not RIVER:
            raise ModuleNotFoundError(
                "River library is required for ADWIN algorithm. Please install it using pip install/ uv add ml3-drift[river]"
            )
        self.p_value = p_value
        self.clock = clock
        self.max_buckets = max_buckets
        self.min_window_length = min_window_length
        self.grace_period = grace_period
        self._args = args
        self._kwargs = kwargs
        # Only one sample is added per step and river builds its window
        # internally, hence comparison_size is fixed to 1.
        super().__init__(comparison_size=1, callbacks=callbacks)

    def _reset_internal_parameters(self) -> None:
        """Replace `self.drift_agent` with a new river ADWIN detector built
        from the stored configuration."""
        # Star-args are unpacked before the keyword arguments: the binding is
        # the same as in the previous ordering, but the call now reads in the
        # order Python actually applies it.
        self.drift_agent = RiverADWIN(
            *self._args,
            delta=self.p_value,
            clock=self.clock,
            max_buckets=self.max_buckets,
            min_window_length=self.min_window_length,
            grace_period=self.grace_period,
            **self._kwargs,
        )

    def _fit(self, X: np.ndarray) -> None:
        """Fit the ADWIN algorithm to the data.

        `X` is only validated here; it is not fed to the river detector.
        """
        self._validate(X)
        # NOTE(review): this calls `reset_internal_parameters` (no leading
        # underscore), presumably a base-class wrapper that ends up invoking
        # `_reset_internal_parameters` — confirm against the base class.
        self.reset_internal_parameters()
        self.is_fitted = True

    def _detect(self) -> MonitoringOutput:
        """Update the river detector with `self.comparison_data` and report
        whether it currently signals a drift.

        The returned output never carries drift details (`drift_info=None`).
        """
        self.drift_agent.update(self.comparison_data)
        return MonitoringOutput(
            drift_detected=self.drift_agent.drift_detected, drift_info=None
        )

0 commit comments

Comments
 (0)