Skip to content

Commit 2dabc0d

Browse files
HuggingFace integration uses monitoring modules
1 parent 0b84b97 commit 2dabc0d

9 files changed

Lines changed: 69 additions & 78 deletions

File tree

examples/huggingface/text_embedding_monitoring.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from ml3_drift.huggingface.drift_detection_pipeline import (
44
HuggingFaceDriftDetectionPipeline,
55
)
6-
from ml3_drift.huggingface.univariate.ks import KSDriftDetector
6+
from ml3_drift.monitoring.multivariate.bonferroni import BonferroniCorrectionAlgorithm
7+
from ml3_drift.monitoring.univariate.continuous.ks import KSAlgorithm
78
from ml3_drift.callbacks.base import logger_callback
89

910

@@ -37,7 +38,8 @@
3738
# to monitor the drift in the embeddings.
3839

3940
hf_pipe = HuggingFaceDriftDetectionPipeline(
40-
drift_detector=KSDriftDetector(
41+
drift_detector=BonferroniCorrectionAlgorithm(
42+
algorithm=KSAlgorithm(p_value=0.05),
4143
callbacks=[
4244
partial(
4345
logger_callback,

src/ml3_drift/callbacks/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ def logger_callback(
1818
1919
callback = partial(logger_callback, logger=logging.getLogger("drift_callback"), level=logging.INFO)
2020
"""
21+
22+
if drift_info is None:
23+
logger.log(level, "Drift Detected!")
24+
return
2125
msg = f"Drift detected on feature at index {drift_info.feature_index} by drift detector {drift_info.drift_detector}."
2226

2327
if drift_info.p_value is not None:

src/ml3_drift/huggingface/drift_detection_pipeline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import torch
33
from transformers import Pipeline, pipeline
44

5-
from ml3_drift.huggingface.base import BaseDriftDetector
5+
from ml3_drift.monitoring.base import MonitoringAlgorithm
66

77

88
class HuggingFaceDriftDetectionPipeline:
9-
def __init__(self, drift_detector: BaseDriftDetector, **kwargs):
9+
def __init__(self, drift_detector: MonitoringAlgorithm, **kwargs): # noqa: F821
1010
"""
1111
Init
1212
"""
@@ -111,12 +111,12 @@ def _to_numpy(self, data) -> np.ndarray:
111111
return data
112112
case 3:
113113
# Take mean over the second-to-last dimension
114-
return data.mean(axis=1).reshape(1, -1)
114+
return data.mean(axis=1).reshape(-1, 1)
115115

116116
case 4:
117117
# Take mean over the second-to-last dimension and reshape
118-
# so that each sample is a row and each feature is a column
119-
return np.mean(data, axis=2).reshape(data.shape[0], -1)
118+
# so that each sample is a column vector
119+
return np.mean(data, axis=2).reshape(-1, data.shape[0])
120120
case _:
121121
raise ValueError(
122122
"Shape mismatch detected: expected data to have 3 or 4 dimensions, "

src/ml3_drift/monitoring/base.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class MonitoringAlgorithm(ABC):
3434
comparison_size: int | None, optional
3535
Only relevant in online monitoring algorithms.
3636
It defines the size of the sliding window used for comparison.
37-
callbacks: list[Callable[[DriftInfo], None]], optional
37+
callbacks: list[Callable[[DriftInfo | None], None]], optional
3838
A list of callback functions that are called when a drift is detected.
3939
Each callback receives a DriftInfo object containing information about the detected drift, or None when the algorithm does not produce one.
4040
If not provided, no callbacks are used.
@@ -50,7 +50,7 @@ def specs(cls) -> MonitoringAlgorithmSpecs:
5050
def __init__(
5151
self,
5252
comparison_size: int | None = None,
53-
callbacks: list[Callable[[DriftInfo], None]] | None = None,
53+
callbacks: list[Callable[[DriftInfo | None], None]] | None = None,
5454
) -> None:
5555
self.comparison_size = comparison_size
5656

@@ -220,10 +220,7 @@ def detect(self, X: np.ndarray) -> list[MonitoringOutput]:
220220

221221
if self.has_callbacks:
222222
for sample_output in detection_output:
223-
if (
224-
sample_output.drift_detected
225-
and sample_output.drift_info is not None
226-
):
223+
if sample_output.drift_detected:
227224
for callback in self.callbacks:
228225
callback(sample_output.drift_info)
229226

src/ml3_drift/monitoring/multivariate/bonferroni.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@ class BonferroniCorrectionAlgorithm(MonitoringAlgorithm):
2525
The univariate monitoring algorithm to be used for each dimension.
2626
p_value: float, default=0.005
2727
The p-value threshold for detecting drift, will be adjusted using Bonferroni correction.
28-
callbacks: list[Callable[[DriftInfo], None]] | None, default=None
29-
Callbacks to be executed when drift is detected.
30-
Each callback will receive a DriftInfo object with details about the detected drift.
28+
callbacks: list[Callable[[DriftInfo | None], None]] | None, optional
29+
A list of callback functions that are called when a drift is detected.
30+
Each callback receives a DriftInfo object containing information about the detected drift.
31+
If not provided, no callbacks are used. The current type hint also includes
32+
the case where drift_info is None (which happens only for some algorithms). This
33+
will change in the future as it's not very useful to have a callback that
34+
receives None as input.
3135
"""
3236

3337
@classmethod
@@ -42,7 +46,7 @@ def __init__(
4246
self,
4347
algorithm: T,
4448
p_value: float = 0.005,
45-
callbacks: list[Callable[[DriftInfo], None]] | None = None,
49+
callbacks: list[Callable[[DriftInfo | None], None]] | None = None,
4650
) -> None:
4751
super().__init__(comparison_size=None, callbacks=callbacks)
4852
self.p_value = p_value

src/ml3_drift/monitoring/univariate/continuous/ks.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ class KSAlgorithm(UnivariateMonitoringAlgorithm):
1818
----------
1919
p_value: float
2020
p-value threshold for detecting drift. Default is 0.005.
21+
callbacks: list[Callable[[DriftInfo | None], None]] | None, optional
22+
A list of callback functions that are called when a drift is detected.
23+
Each callback receives a DriftInfo object containing information about the detected drift.
24+
If not provided, no callbacks are used. The current type hint also includes
25+
the case where drift_info is None (which happens only for some algorithms). This
26+
will change in the future as it's not very useful to have a callback that
27+
receives None as input.
2128
"""
2229

2330
@classmethod
@@ -31,7 +38,7 @@ def specs(cls) -> MonitoringAlgorithmSpecs:
3138
def __init__(
3239
self,
3340
p_value: float = 0.005,
34-
callbacks: list[Callable[[DriftInfo], None]] | None = None,
41+
callbacks: list[Callable[[DriftInfo | None], None]] | None = None,
3542
) -> None:
3643
super().__init__(
3744
comparison_size=None,

src/ml3_drift/monitoring/univariate/discrete/chi_square.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ class ChiSquareAlgorithm(UnivariateMonitoringAlgorithm):
1818
----------
1919
p_value: float
2020
p-value threshold for detecting drift. Default is 0.005.
21+
callbacks: list[Callable[[DriftInfo | None], None]] | None, optional
22+
A list of callback functions that are called when a drift is detected.
23+
Each callback receives a DriftInfo object containing information about the detected drift.
24+
If not provided, no callbacks are used. The current type hint also includes
25+
the case where drift_info is None (which happens only for some algorithms). This
26+
will change in the future as it's not very useful to have a callback that
27+
receives None as input.
2128
"""
2229

2330
@classmethod
@@ -31,7 +38,7 @@ def specs(cls) -> MonitoringAlgorithmSpecs:
3138
def __init__(
3239
self,
3340
p_value: float = 0.005,
34-
callbacks: list[Callable[[DriftInfo], None]] | None = None,
41+
callbacks: list[Callable[[DriftInfo | None], None]] | None = None,
3542
) -> None:
3643
super().__init__(comparison_size=None, callbacks=callbacks)
3744
self._p_value = p_value

tests/test_huggingface/test_ks.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

tests/test_huggingface/test_pipe.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
from ml3_drift.monitoring.multivariate.bonferroni import BonferroniCorrectionAlgorithm
2+
from ml3_drift.monitoring.univariate.continuous.ks import KSAlgorithm
13
from tests.conftest import is_module_available
24

35
import pytest
46

57
if is_module_available("transformers"):
6-
from ml3_drift.huggingface.univariate.ks import (
7-
KSDriftDetector,
8-
)
98
from ml3_drift.huggingface.drift_detection_pipeline import (
109
HuggingFaceDriftDetectionPipeline,
1110
)
@@ -27,13 +26,8 @@ def test_text(self, text_data, return_tensors):
2726
Test pipeline with text data for drift detection.
2827
"""
2928

30-
# Not optimal as we are loading a big model,
31-
# but it didn't work with a simple model taken
32-
# from here:
33-
# https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/tests/pipelines/test_pipelines_feature_extraction.py#L46
34-
# We should do something.
3529
pipe = HuggingFaceDriftDetectionPipeline(
36-
drift_detector=KSDriftDetector(),
30+
drift_detector=KSAlgorithm(p_value=0.05),
3731
task="feature-extraction",
3832
model="hf-internal-testing/tiny-random-distilbert",
3933
framework="pt",
@@ -45,7 +39,7 @@ def test_text(self, text_data, return_tensors):
4539
)
4640

4741
assert pipe._drift_detector.is_fitted
48-
assert pipe._drift_detector.X_ref_.shape == (1, 32), (
42+
assert pipe._drift_detector.X_ref_.shape == (32, 1), (
4943
"Reference data shape mismatch."
5044
)
5145

@@ -60,7 +54,7 @@ def test_text(self, text_data, return_tensors):
6054
)
6155

6256
assert pipe._drift_detector.is_fitted
63-
assert pipe._drift_detector.X_ref_.shape == (1, 32), (
57+
assert pipe._drift_detector.X_ref_.shape == (32, 1), (
6458
"Reference data shape mismatch."
6559
)
6660

@@ -69,18 +63,24 @@ def test_text(self, text_data, return_tensors):
6963
return_tensors=return_tensors,
7064
)
7165

66+
pipe = HuggingFaceDriftDetectionPipeline(
67+
drift_detector=BonferroniCorrectionAlgorithm(
68+
p_value=0.05, algorithm=KSAlgorithm()
69+
),
70+
task="feature-extraction",
71+
model="hf-internal-testing/tiny-random-distilbert",
72+
framework="pt",
73+
)
74+
7275
pipe.fit_detector(
73-
[text_data, text_data],
76+
[text_data],
7477
return_tensors=return_tensors,
7578
)
7679

7780
assert pipe._drift_detector.is_fitted
78-
assert pipe._drift_detector.X_ref_.shape == (2, 32), (
79-
"Reference data shape mismatch."
80-
)
8181

8282
pipe(
83-
text_data,
83+
[text_data],
8484
return_tensors=return_tensors,
8585
)
8686

@@ -90,10 +90,8 @@ def test_image(self, image_data, return_tensors):
9090
Test pipeline with image data for drift detection.
9191
"""
9292

93-
# Not optimal as we are loading a big model,
94-
# We should do something.
9593
pipe = HuggingFaceDriftDetectionPipeline(
96-
drift_detector=KSDriftDetector(),
94+
drift_detector=KSAlgorithm(p_value=0.05),
9795
task="image-feature-extraction",
9896
model="hf-internal-testing/tiny-random-vit",
9997
framework="pt",
@@ -105,7 +103,7 @@ def test_image(self, image_data, return_tensors):
105103
)
106104

107105
assert pipe._drift_detector.is_fitted
108-
assert pipe._drift_detector.X_ref_.shape == (1, 32), (
106+
assert pipe._drift_detector.X_ref_.shape == (32, 1), (
109107
"Reference data shape mismatch."
110108
)
111109

@@ -120,7 +118,7 @@ def test_image(self, image_data, return_tensors):
120118
)
121119

122120
assert pipe._drift_detector.is_fitted
123-
assert pipe._drift_detector.X_ref_.shape == (1, 32), (
121+
assert pipe._drift_detector.X_ref_.shape == (32, 1), (
124122
"Reference data shape mismatch."
125123
)
126124

@@ -129,15 +127,21 @@ def test_image(self, image_data, return_tensors):
129127
return_tensors=return_tensors,
130128
)
131129

130+
pipe = HuggingFaceDriftDetectionPipeline(
131+
drift_detector=BonferroniCorrectionAlgorithm(
132+
p_value=0.05, algorithm=KSAlgorithm()
133+
),
134+
task="image-feature-extraction",
135+
model="hf-internal-testing/tiny-random-vit",
136+
framework="pt",
137+
)
138+
132139
pipe.fit_detector(
133-
[image_data, image_data],
140+
[image_data],
134141
return_tensors=return_tensors,
135142
)
136143

137144
assert pipe._drift_detector.is_fitted
138-
assert pipe._drift_detector.X_ref_.shape == (2, 32), (
139-
"Reference data shape mismatch."
140-
)
141145

142146
pipe(
143147
image_data,

0 commit comments

Comments
 (0)