1+ from ml3_drift .monitoring .multivariate .bonferroni import BonferroniCorrectionAlgorithm
2+ from ml3_drift .monitoring .univariate .continuous .ks import KSAlgorithm
13from tests .conftest import is_module_available
24
35import pytest
46
57if is_module_available ("transformers" ):
6- from ml3_drift .huggingface .univariate .ks import (
7- KSDriftDetector ,
8- )
98 from ml3_drift .huggingface .drift_detection_pipeline import (
109 HuggingFaceDriftDetectionPipeline ,
1110 )
@@ -27,13 +26,8 @@ def test_text(self, text_data, return_tensors):
2726 Test pipeline with text data for drift detection.
2827 """
2928
30- # Not optimal as we are loading a big model,
31- # but it didn't work with a simple model taken
32- # from here:
33- # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/tests/pipelines/test_pipelines_feature_extraction.py#L46
34- # We should do something.
3529 pipe = HuggingFaceDriftDetectionPipeline (
36- drift_detector = KSDriftDetector ( ),
30+ drift_detector = KSAlgorithm ( p_value = 0.05 ),
3731 task = "feature-extraction" ,
3832 model = "hf-internal-testing/tiny-random-distilbert" ,
3933 framework = "pt" ,
@@ -45,7 +39,7 @@ def test_text(self, text_data, return_tensors):
4539 )
4640
4741 assert pipe ._drift_detector .is_fitted
48- assert pipe ._drift_detector .X_ref_ .shape == (1 , 32 ), (
42+ assert pipe ._drift_detector .X_ref_ .shape == (32 , 1 ), (
4943 "Reference data shape mismatch."
5044 )
5145
@@ -60,7 +54,7 @@ def test_text(self, text_data, return_tensors):
6054 )
6155
6256 assert pipe ._drift_detector .is_fitted
63- assert pipe ._drift_detector .X_ref_ .shape == (1 , 32 ), (
57+ assert pipe ._drift_detector .X_ref_ .shape == (32 , 1 ), (
6458 "Reference data shape mismatch."
6559 )
6660
@@ -69,18 +63,24 @@ def test_text(self, text_data, return_tensors):
6963 return_tensors = return_tensors ,
7064 )
7165
66+ pipe = HuggingFaceDriftDetectionPipeline (
67+ drift_detector = BonferroniCorrectionAlgorithm (
68+ p_value = 0.05 , algorithm = KSAlgorithm ()
69+ ),
70+ task = "feature-extraction" ,
71+ model = "hf-internal-testing/tiny-random-distilbert" ,
72+ framework = "pt" ,
73+ )
74+
7275 pipe .fit_detector (
73- [text_data , text_data ],
76+ [text_data ],
7477 return_tensors = return_tensors ,
7578 )
7679
7780 assert pipe ._drift_detector .is_fitted
78- assert pipe ._drift_detector .X_ref_ .shape == (2 , 32 ), (
79- "Reference data shape mismatch."
80- )
8181
8282 pipe (
83- text_data ,
83+ [ text_data ] ,
8484 return_tensors = return_tensors ,
8585 )
8686
@@ -90,10 +90,8 @@ def test_image(self, image_data, return_tensors):
9090 Test pipeline with image data for drift detection.
9191 """
9292
93- # Not optimal as we are loading a big model,
94- # We should do something.
9593 pipe = HuggingFaceDriftDetectionPipeline (
96- drift_detector = KSDriftDetector ( ),
94+ drift_detector = KSAlgorithm ( p_value = 0.05 ),
9795 task = "image-feature-extraction" ,
9896 model = "hf-internal-testing/tiny-random-vit" ,
9997 framework = "pt" ,
@@ -105,7 +103,7 @@ def test_image(self, image_data, return_tensors):
105103 )
106104
107105 assert pipe ._drift_detector .is_fitted
108- assert pipe ._drift_detector .X_ref_ .shape == (1 , 32 ), (
106+ assert pipe ._drift_detector .X_ref_ .shape == (32 , 1 ), (
109107 "Reference data shape mismatch."
110108 )
111109
@@ -120,7 +118,7 @@ def test_image(self, image_data, return_tensors):
120118 )
121119
122120 assert pipe ._drift_detector .is_fitted
123- assert pipe ._drift_detector .X_ref_ .shape == (1 , 32 ), (
121+ assert pipe ._drift_detector .X_ref_ .shape == (32 , 1 ), (
124122 "Reference data shape mismatch."
125123 )
126124
@@ -129,15 +127,21 @@ def test_image(self, image_data, return_tensors):
129127 return_tensors = return_tensors ,
130128 )
131129
130+ pipe = HuggingFaceDriftDetectionPipeline (
131+ drift_detector = BonferroniCorrectionAlgorithm (
132+ p_value = 0.05 , algorithm = KSAlgorithm ()
133+ ),
134+ task = "image-feature-extraction" ,
135+ model = "hf-internal-testing/tiny-random-vit" ,
136+ framework = "pt" ,
137+ )
138+
132139 pipe .fit_detector (
133- [image_data , image_data ],
140+ [image_data ],
134141 return_tensors = return_tensors ,
135142 )
136143
137144 assert pipe ._drift_detector .is_fitted
138- assert pipe ._drift_detector .X_ref_ .shape == (2 , 32 ), (
139- "Reference data shape mismatch."
140- )
141145
142146 pipe (
143147 image_data ,
0 commit comments