This repository was archived by the owner on Oct 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsksym.py
More file actions
146 lines (101 loc) · 3.88 KB
/
sksym.py
File metadata and controls
146 lines (101 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Challenge symmetries with the sklearn interface."""
import math
from dataclasses import dataclass
from functools import partial
from numbers import Integral
from typing import Callable, Optional
import numpy
import scipy
@dataclass
class WhichIsReal:
    """Bundle a symmetry-testing configuration with sklearn-facing helpers.

    Holds the fake-generating transform and the number of fakes per real
    datum, and forwards to the module-level objective/pack/stack helpers.
    """

    # Callable mapping data to one fake version of it; only pack() needs it.
    transform: Optional[Callable] = None
    # How many fake versions accompany each real datum.
    nfakes: Integral = 1

    def objective(self):
        """Return the logistic-difference objective configured for nfakes."""
        return logistic_difference(self.nfakes)

    def pack(self, data):
        """Stack data with self.nfakes fakes produced by self.transform."""
        assert (
            self.transform is not None
        ), "pack requires a transform function; use stack for arrays of fakes"
        return pack(data, self.transform, self.nfakes)

    def stack(self, data, fakes):
        """Stack data with precomputed fakes, whose count must equal nfakes."""
        fake_list = list(fakes)
        count = len(fake_list)
        assert count == self.nfakes, "got len(fakes) == %d, but nfakes == %d" % (
            count,
            self.nfakes,
        )
        return stack(data, fake_list)
def logistic_difference(nfakes=1):
    """Build the antisymmetric logistic-loss objective for nfakes fakes per datum.

    The returned callable is invoked as ``objective(y_true, y_pred)`` and
    returns ``(jac, hess)``: first and second derivatives of the negative
    log likelihood with respect to y_pred.
    The task is self-supervised: y_true would be
    [1]*ndata + [0]*ndata*nfakes, but it is ignored.
    y_pred has shape ((1 + nfakes)*ndata,); its first ndata entries score
    the real data and the remaining entries score the fakes.
    """
    objective = partial(_logistic_difference_objective, nfakes)
    return objective
def _logistic_difference_objective(nfakes, _, y_pred):
zeta = y_pred.reshape(1 + nfakes, -1)
phi = zeta[0] - zeta[1:]
dxdash = scipy.special.expit(-phi)
if nfakes > 1:
dxdash *= 1 / nfakes
d2xdash = dxdash * scipy.special.expit(phi)
jac = numpy.concatenate([-dxdash.sum(axis=0), dxdash.ravel()])
hess = numpy.concatenate([d2xdash.sum(axis=0), d2xdash.ravel()])
return jac, hess
def pack(data, transform, nfakes=1):
    """Pair data with nfakes independently transformed fake copies.

    transform is called once per fake, each time on the original data.
    Result shape: (1 + nfakes, ndata, ndim), real data first.
    """
    fake_copies = [transform(data) for _ in range(nfakes)]
    return stack(data, fake_copies)
def stack(data, fakes):
    """Stack the real data on top of its fake versions along a new axis 0.

    Index 0 of the result is data; indices 1.. are the fakes, in order.
    Result shape: (1 + len(fakes), ndata, ndim).
    """
    layers = [data]
    layers.extend(fakes)
    return numpy.stack(layers)
# sklearn interface
def fit(model, packed, *args, **kwargs):
    """Fit model on packed data flattened to (ntot * ndata, ndim).

    packed[0] holds the real data x and packed[1:] its fakes sx.
    Extra positional and keyword arguments go straight to model.fit(...).
    """
    ndim = packed.shape[-1]
    flat = packed.reshape(-1, ndim)
    # The logistic_difference objective ignores y_true, so all-False
    # placeholder labels are enough to satisfy the fit(X, y) interface.
    dummy_labels = numpy.zeros(len(flat), dtype=bool)
    return model.fit(flat, dummy_labels, *args, **kwargs)
def score(model, packed, *, and_std=False, **kwargs):
    """Return the mean log likelihood ratio vs 50:50 at packed data.

    With and_std=True, returns (mean, standard_error) as computed by
    score_log_proba. Extra keyword arguments go to model.predict(...).
    """
    log_proba = predict_log_proba(model, packed, **kwargs)
    return score_log_proba(log_proba, and_std=and_std)
def predict_proba(model, packed, **kwargs):
    """Return probabilities at packed data (exp of predict_log_proba).

    shape: (ndata, 2) if nfakes == 1, else (nfakes, ndata, 2)
    """
    log_proba = predict_log_proba(model, packed, **kwargs)
    return numpy.exp(log_proba)
def predict_log_proba(model, packed, **kwargs):
    """Return log probabilities at packed data.

    With phi = zeta_real - zeta_fake, column 0 is log sigmoid(phi) and
    column 1 is log sigmoid(-phi).

    shape: (ndata, 2) if nfakes == 1, else (nfakes, ndata, 2)
    """
    zeta = predict_zeta(model, packed, **kwargs)
    margin = zeta[0] - zeta[1:]
    if len(margin) == 1:
        # A single fake: drop the leading fakes axis.
        margin = margin[0]
    pair = numpy.stack([-margin, margin], axis=-1)
    # -log(1 + exp(-x)) == log sigmoid(x), evaluated stably.
    return -numpy.logaddexp(0, pair)
# utility
def score_log_proba(log_proba, *, and_std=False):
    """Return the mean log likelihood ratio vs 50:50.

    log_proba is laid out like predict_log_proba's output: the last axis
    holds the two log probabilities and axis -2 is the data axis.
    With and_std=True, also return std / sqrt(ndata), cast to the ratio's
    dtype.
    """
    ndata = log_proba.shape[-2]
    # Column 0 measured against the log probability of a coin flip.
    ratio = log_proba[..., 0] - math.log(0.5)
    mean = ratio.mean()
    if not and_std:
        return mean
    # NOTE(review): with nfakes > 1 the std covers nfakes * ndata entries but
    # is divided by sqrt(ndata) only; presumably deliberate because fakes of
    # the same datum are correlated — confirm.
    stderr = ratio.std() / ndata**0.5
    return mean, stderr.astype(ratio.dtype)
def predict_zeta(model, packed, **kwargs):
    """Return model outputs in shape (ntot, ndata), where ntot = packed.shape[0].

    packed must be 3-D (ntot, ndata, ndim). It is flattened for a single
    model.predict(...) call and the outputs are folded back per row.
    """
    nrows, _, ncols = packed.shape
    flat = packed.reshape(-1, ncols)
    outputs = model.predict(flat, **kwargs)
    return outputs.reshape(nrows, -1)