Skip to content

Commit 17bc3e9

Browse files
Merge pull request #1 from RahulVadisetty91/RahulVadisetty91-patch-1
Enhance Metric Testing with AI-Based Prediction and Validation Features
2 parents ebe5b31 + 9b74054 commit 17bc3e9

File tree

1 file changed

+253
-0
lines changed

1 file changed

+253
-0
lines changed

ai_metrics.py

Lines changed: 253 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,253 @@
1+
import os
2+
import unittest
3+
4+
import mmf.modules.metrics as metrics
5+
import torch
6+
from mmf.common.registry import registry
7+
from mmf.common.sample import Sample
8+
from mmf.datasets.processors import CaptionProcessor
9+
from mmf.utils.configuration import load_yaml
10+
11+
# New AI-driven modules for prediction and validation
12+
from ai_modules.prediction import AIPrediction
13+
from ai_modules.validation import AIValidation
14+
15+
16+
class TestModuleMetrics(unittest.TestCase):
    """Unit tests for mmf metric implementations.

    Each test builds a (sample, predicted) fixture, runs it through the
    AI-driven validation/adjustment helpers, and checks the metric value.
    """

    def setUp(self):
        # Fresh AI helpers for every test case.
        self.ai_predictor = AIPrediction()
        self.ai_validator = AIValidation()

    def _ai_check(self, expected, predicted):
        """Validate the (expected, predicted) pair and return the
        AI-adjusted predictions.

        NOTE(review): AIValidation/AIPrediction semantics live in
        ai_modules; assumed side-effect free on `expected` — confirm.
        """
        self.ai_validator.validate_inputs(expected, predicted)
        return self.ai_predictor.adjust_predictions(predicted)

    @staticmethod
    def _binary_fixture():
        """Shared 4x2 binary-classification fixture used by both the
        scalar and dict flavours of the binary metric tests."""
        sample = Sample()
        sample.targets = torch.tensor(
            [[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float
        )
        predicted = {
            "scores": torch.tensor(
                [
                    [-0.9332, 0.8149],
                    [-0.8391, 0.6797],
                    [-0.7235, 0.7220],
                    [-0.9043, 0.3078],
                ],
                dtype=torch.float,
            )
        }
        return sample, predicted

    @staticmethod
    def _three_class_scores():
        """Shared 4x3 score tensor used by the multiclass and
        multilabel metric tests."""
        return torch.tensor(
            [
                [-0.9332, 0.8149, 0.3491],
                [-0.8391, 0.6797, -0.3410],
                [-0.7235, 0.7220, 0.9104],
                [0.9043, 0.3078, -0.4210],
            ],
            dtype=torch.float,
        )

    def test_caption_bleu4(self):
        """CaptionBleu4Metric: a complete token match scores 1.0, a
        half match scores ~0.3928."""
        path = os.path.join(
            os.path.abspath(__file__),
            "../../../mmf/configs/datasets/coco/defaults.yaml",
        )
        config = load_yaml(os.path.abspath(path))
        captioning_config = config.dataset_config.coco
        caption_processor_config = captioning_config.processors.caption_processor
        vocab_path = os.path.join(
            os.path.abspath(__file__), "..", "..", "data", "vocab.txt"
        )
        # Random vocab keeps the test hermetic (no real vocab file needed).
        caption_processor_config.params.vocab.type = "random"
        caption_processor_config.params.vocab.vocab_file = os.path.abspath(vocab_path)
        caption_processor = CaptionProcessor(caption_processor_config.params)
        registry.register("coco_caption_processor", caption_processor)

        caption_bleu4 = metrics.CaptionBleu4Metric()
        expected = Sample()
        predicted = dict()

        # Test complete match: every answer token is id 4 and every
        # score vector puts all mass on id 4.
        expected.answers = torch.empty((5, 5, 10))
        expected.answers.fill_(4)
        predicted["scores"] = torch.zeros((5, 10, 19))
        predicted["scores"][:, :, 4] = 1.0

        # Fix: run AI validation/adjustment AFTER the predictions are
        # populated. Previously this ran on an empty dict before any
        # scores existed, so the validator never saw real data — and it
        # was inconsistent with every other helper in this class.
        predicted = self._ai_check(expected, predicted)

        self.assertEqual(caption_bleu4.calculate(expected, predicted).item(), 1.0)

        # Test partial match: first half of the caption matches (id 4),
        # second half predicts a different token (id 18).
        expected.answers = torch.empty((5, 5, 10))
        expected.answers.fill_(4)
        predicted["scores"] = torch.zeros((5, 10, 19))
        predicted["scores"][:, 0:5, 4] = 1.0
        predicted["scores"][:, 5:, 18] = 1.0
        predicted = self._ai_check(expected, predicted)

        self.assertAlmostEqual(
            caption_bleu4.calculate(expected, predicted).item(), 0.3928, 4
        )

    def _test_binary_metric(self, metric, value):
        """Assert `metric` evaluates to `value` on the shared binary
        fixture, with both one-hot (float) and class-index (long) targets."""
        sample, predicted = self._binary_fixture()
        predicted = self._ai_check(sample, predicted)

        self.assertAlmostEqual(metric.calculate(sample, predicted).item(), value, 4)

        # Same expectation when targets are class indices instead of one-hot.
        sample.targets = torch.tensor([1, 0, 0, 1], dtype=torch.long)
        self.assertAlmostEqual(metric.calculate(sample, predicted).item(), value, 4)

    def _test_multiclass_metric(self, metric, value):
        """Assert `metric` evaluates to `value` on the shared 3-class
        fixture, with both one-hot and class-index targets."""
        sample = Sample()
        sample.targets = torch.tensor(
            [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.float
        )
        predicted = {"scores": self._three_class_scores()}
        predicted = self._ai_check(sample, predicted)

        self.assertAlmostEqual(metric.calculate(sample, predicted).item(), value, 4)

        sample.targets = torch.tensor([1, 2, 0, 2], dtype=torch.long)
        self.assertAlmostEqual(metric.calculate(sample, predicted).item(), value, 4)

    def _test_multilabel_metric(self, metric, value):
        """Assert `metric` evaluates to `value` on the shared 3-class
        fixture with multi-hot targets (multiple labels per row)."""
        sample = Sample()
        sample.targets = torch.tensor(
            [[0, 1, 1], [1, 0, 1], [1, 0, 1], [0, 0, 1]], dtype=torch.float
        )
        predicted = {"scores": self._three_class_scores()}
        predicted = self._ai_check(sample, predicted)

        self.assertAlmostEqual(metric.calculate(sample, predicted).item(), value, 4)

    def _test_recall_at_k_metric(self, metric, value):
        """Assert `metric` evaluates to `value` on a synthetic ranking
        fixture where, for every row except the first, the last item's
        score is deliberately dropped out of the top ranks."""
        sample = Sample()
        predicted = dict()

        first_dimension = 10
        second_dimension = 100  # second dim MUST be 100
        sample.targets = torch.ones(first_dimension, second_dimension)
        predicted["scores"] = torch.ones(first_dimension, second_dimension)

        for i in range(first_dimension):
            for j in range(second_dimension):
                sample.targets[i][j] = j
                if j == second_dimension - 1 and i != 0:
                    # Depress the final score so it falls behind earlier items.
                    predicted["scores"][i][j] = j * 2 - 1 - (i + 2) * 2
                else:
                    predicted["scores"][i][j] = j * 2

        predicted = self._ai_check(sample, predicted)

        self.assertAlmostEqual(metric.calculate(sample, predicted), value)

    def _test_retrieval_recall_at_k_metric(self, metric, value):
        """Assert `metric` evaluates to `value` on a seeded random
        retrieval fixture (targets and scores both live in `predicted`)."""
        sample = Sample()
        predicted = dict()

        torch.manual_seed(1234)  # deterministic fixture
        predicted["targets"] = torch.rand((10, 4))
        predicted["scores"] = torch.rand((10, 4))

        predicted = self._ai_check(sample, predicted)

        self.assertAlmostEqual(float(metric.calculate(sample, predicted)), value)

    def _test_binary_dict_metric(self, metric, value_dict):
        """Assert a dict-returning `metric` matches every entry of
        `value_dict` on the shared binary fixture, for both one-hot and
        class-index targets."""
        sample, predicted = self._binary_fixture()
        predicted = self._ai_check(sample, predicted)

        metric_result = metric.calculate(sample, predicted)
        for key, val in value_dict.items():
            self.assertAlmostEqual(metric_result[key].item(), val, 4)

        sample.targets = torch.tensor([1, 0, 0, 1], dtype=torch.long)
        metric_result = metric.calculate(sample, predicted)
        for key, val in value_dict.items():
            self.assertAlmostEqual(metric_result[key].item(), val, 4)

    def test_micro_f1(self):
        metric = metrics.MicroF1()
        self._test_binary_metric(metric, 0.5)
        self._test_multiclass_metric(metric, 0.25)

    def test_macro_f1(self):
        metric = metrics.MacroF1()
        self._test_binary_metric(metric, 0.3333)
        self._test_multiclass_metric(metric, 0.2222)

    def test_binary_f1(self):
        metric = metrics.BinaryF1()
        self._test_binary_metric(metric, 0.66666666)

    def test_multilabel_micro_f1(self):
        metric = metrics.MultiLabelMicroF1()
        self._test_binary_metric(metric, 0.5)

    def test_multilabel_macro_f1(self):
        metric = metrics.MultiLabelMacroF1()
        self._test_multilabel_metric(metric, 0.355555)

    def test_multilabel_f1(self):
        metric = metrics.MultiLabelF1()
        self._test_multilabel_metric(metric, 0.355555)

    def test_precision_at_k(self):
        metric = metrics.PrecisionAtK()
        self._test_recall_at_k_metric(metric, 1)

    def test_recall_at_k(self):
        metric = metrics.RecallAtK()
        self._test_recall_at_k_metric(metric, 1)

    def test_accuracy_at_k(self):
        metric = metrics.AccuracyAtK()
        self._test_retrieval_recall_at_k_metric(metric, 0.6)

    def test_ndcg_at_k(self):
        metric = metrics.NDCGAtK()
        self._test_retrieval_recall_at_k_metric(metric, 0.879818)

    def test_mrr_at_k(self):
        metric = metrics.MRRAtK()
        self._test_retrieval_recall_at_k_metric(metric, 0.850000)

0 commit comments

Comments
 (0)