-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_naive_bayes.py
More file actions
94 lines (78 loc) · 1.79 KB
/
test_naive_bayes.py
File metadata and controls
94 lines (78 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import pytest
from scipy.sparse import csr_matrix

from naive_bayes_em import *
@pytest.fixture(scope="session")
def data_matrix():
    """Sparse 10 x 2 binary document-feature matrix shared by every test.

    Rows 0-3 carry label 0 and rows 5-8 carry label 1. Rows 4 and 9 carry
    label -1 (unlabeled); see the ``labels`` fixture for the per-row labels.
    """
    rows = [
        [1, 1],  # row 0: class 0
        [1, 0],  # row 1: class 0
        [1, 0],  # row 2: class 0
        [0, 1],  # row 3: class 0
        [1, 0],  # row 4: unlabeled
        [1, 1],  # row 5: class 1
        [0, 1],  # row 6: class 1
        [0, 1],  # row 7: class 1
        [1, 0],  # row 8: class 1
        [0, 1],  # row 9: unlabeled
    ]
    return csr_matrix(rows)
@pytest.fixture(scope="session")
def labels():
    """Return one class label per row of ``data_matrix``.

    -1 marks an unlabeled row. The expected bootstrap parameters in
    ``TestNaiveBayes`` are frequencies over the eight rows labeled 0 or 1
    only, so the two -1 rows do not contribute to the bootstrap estimate.
    """
    class_0, class_1, unlabeled = 0, 1, -1
    return np.array(
        [class_0] * 4 + [unlabeled] + [class_1] * 4 + [unlabeled]
    )
@pytest.fixture(scope="session")
def model(data_matrix, labels):
    """Return a ``NaiveBayes`` initialised only by its ``_bootstrap_model`` step.

    ``fit`` is never called on this instance. It is session-scoped and shared
    across tests, so tests should treat it as read-only.
    """
    bootstrapped = NaiveBayes()
    bootstrapped._bootstrap_model(data_matrix, labels)
    return bootstrapped
class TestNaiveBayes:
    """Unit tests for ``NaiveBayes`` on the 10-row, 2-feature toy corpus.

    The expected bootstrap values below are plain per-class frequencies over
    the eight labeled rows (label 0 or 1); the two rows labeled -1 only enter
    through ``fit``.
    """

    # np.allclose's default tolerances. They keep the pass/fail threshold of
    # the original ``assert np.allclose(...)`` checks, while
    # np.testing.assert_allclose also reports which entries mismatch.
    _TOL = {"rtol": 1e-5, "atol": 1e-8}

    def test_bootstrap_model(self, model):
        """Check bootstrap parameters, priors and the derived linear form."""
        # P(feature_j = 1 | class): class 0 has feature 0 set in 3 of its 4
        # labeled rows and feature 1 in 2 of 4; class 1 is the mirror image.
        expected_params = np.array([
            [3 / 4, 0.5],
            [0.5, 3 / 4],
        ])
        # Class priors: four labeled rows per class.
        expected_bias_params = np.array([0.5, 0.5])
        np.testing.assert_allclose(
            model.parameters_, expected_params, **self._TOL
        )
        np.testing.assert_allclose(
            model.intercept_parameters_, expected_bias_params, **self._TOL
        )

        # Bernoulli NB written as a linear model over a binary row x:
        #   log P(c) + log P(x | c) = intercept_c + coef_c . x
        # with intercept_c = log prior_c + sum_j log(1 - p_cj)
        # and  coef_cj     = log p_cj - log(1 - p_cj).
        expected_intercept = (
            np.log(expected_bias_params)
            + np.log(1 - expected_params).sum(axis=1)
        )
        np.testing.assert_allclose(
            model.intercept_, expected_intercept, **self._TOL
        )
        expected_coef = np.log(expected_params) - np.log(1 - expected_params)
        np.testing.assert_allclose(model.coef_, expected_coef, **self._TOL)

    def test_predict_proba(self, model, data_matrix):
        """Check the bootstrap posterior P(class | row) for every row.

        Under the bootstrap parameters, a row with only feature 0 set favours
        class 0 by 3:1. A row with only feature 1 set favours class 1 by 3:1.
        A row with both features set is a tie.
        """
        expected_proba = np.array([
            [0.5, 0.5],
            [3 / 4, 1 / 4],
            [3 / 4, 1 / 4],
            [1 / 4, 3 / 4],
            [3 / 4, 1 / 4],
            [0.5, 0.5],
            [1 / 4, 3 / 4],
            [1 / 4, 3 / 4],
            [3 / 4, 1 / 4],
            [1 / 4, 3 / 4],
        ])
        proba = model.predict_proba(data_matrix)
        np.testing.assert_allclose(proba, expected_proba, **self._TOL)

    def test_fit(self, data_matrix, labels):
        """Check the parameters after one iteration of ``fit``.

        Under the bootstrap model, unlabeled row 4 ([1, 0]) has posterior
        (3/4, 1/4) and unlabeled row 9 ([0, 1]) has (1/4, 3/4). The expected
        values weight those rows by their posteriors:
        - Each class's denominator grows from 4 to 4 + 1.
        - Each feature count grows by the weight of the unlabeled row that
          has that feature set.
        """
        nb = NaiveBayes(n_iter=1)
        nb.fit(data_matrix, labels)
        expected_params = np.array([
            [(3 + 3 / 4) / (4 + 1), (2 + 1 / 4) / (4 + 1)],
            [(2 + 1 / 4) / (4 + 1), (3 + 3 / 4) / (4 + 1)],
        ])
        np.testing.assert_allclose(nb.parameters_, expected_params, **self._TOL)