Skip to content

Commit e3e6628

Browse files
added comment to clarify comparison size = 1 in online algorithm, fixed multivariate edge case validation
1 parent 6e92d34 commit e3e6628

4 files changed

Lines changed: 13 additions & 8 deletions

File tree

src/ml3_drift/monitoring/algorithms/online/adwin.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ def __init__(
5050
self.grace_period = grace_period
5151
self._args = args
5252
self._kwargs = kwargs
53-
super().__init__(comparison_size=1, callbacks=callbacks)
54-
53+
super().__init__(
54+
comparison_size=1, callbacks=callbacks
55+
) # since we add only one sample per step and river handles building the window internally we set comparison_size to 1
5556

5657
def _reset_internal_parameters(self):
5758
self.drift_agent = RiverADWIN(
@@ -61,7 +62,7 @@ def _reset_internal_parameters(self):
6162
min_window_length=self.min_window_length,
6263
grace_period=self.grace_period,
6364
*self._args,
64-
**self._kwargs
65+
**self._kwargs,
6566
)
6667

6768
def _fit(self, X: np.ndarray):

src/ml3_drift/monitoring/algorithms/online/kswin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def __init__(
4848
self.seed = seed
4949
self._args = args
5050
self._kwargs = kwargs
51-
super().__init__(comparison_size=1, callbacks=callbacks)
51+
super().__init__(
52+
comparison_size=1, callbacks=callbacks
53+
) # since we add only one sample per step and river handles building the window internally we set comparison_size to 1
5254

5355
def _reset_internal_parameters(self):
5456
self.drift_agent = RiverKSWIN(
@@ -57,7 +59,7 @@ def _reset_internal_parameters(self):
5759
stat_size=self.stat_size,
5860
seed=self.seed,
5961
*self._args,
60-
**self._kwargs
62+
**self._kwargs,
6163
)
6264

6365
def _fit(self, X: np.ndarray):

src/ml3_drift/monitoring/algorithms/online/page_hinkley.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def __init__(
5353
self.seed = seed
5454
self._args = args
5555
self._kwargs = kwargs
56-
super().__init__(callbacks=callbacks, comparison_size=1)
56+
super().__init__(
57+
callbacks=callbacks, comparison_size=1
58+
) # since we add only one sample per step and river handles building the window internally we set comparison_size to 1
5759

5860
def _reset_internal_parameters(self):
5961
self.drift_agent = RiverPageHinkley(
@@ -63,7 +65,7 @@ def _reset_internal_parameters(self):
6365
threshold=self.threshold,
6466
mode=self.mode,
6567
*self._args,
66-
**self._kwargs
68+
**self._kwargs,
6769
)
6870

6971
def _fit(self, X: np.ndarray):

src/ml3_drift/monitoring/base/base_multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class MultivariateMonitoringAlgorithm(MonitoringAlgorithm, ABC):
1212
"""
1313

1414
def _is_valid(self, X: np.ndarray) -> tuple[bool, str]:
15-
if X.shape[1] > 1:
15+
if X.shape[1] >= 1:
1616
return True, ""
1717
else:
1818
return False, f"X must be multi-dimensional vector. Got {X.shape}"

0 commit comments

Comments
 (0)