-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmain.py
More file actions
84 lines (54 loc) · 2.03 KB
/
main.py
File metadata and controls
84 lines (54 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import logging
import os

import hydra
from omegaconf import OmegaConf

from source.helper.LabelDescriptionHelper import LabelDescriptionHelper
from source.helper.PromptOptimizerHelper import PromptOptimizerHelper
from source.helper.RankingAggregationHelper import RankingAggregationHelper
from source.helper.RankingFusionHelper import RankingFusionHelper
from source.helper.SparseRetrieverHelper import SparseRetrieverHelper
from source.helper.retriever.RetrieverEvalHelper import RetrieverEvalHelper
from source.helper.retriever.RetrieverFitHelper import RetrieverFitHelper
from source.helper.retriever.RetrieverPredictHelper import RetrieverPredictHelper
def sparse_retrieve(params):
    """Run the ``sparse_retrieve`` task.

    Constructs a ``SparseRetrieverHelper`` from the configuration and calls
    its ``run`` method.

    Args:
        params: Resolved Hydra/OmegaConf configuration, passed through as-is.
    """
    helper = SparseRetrieverHelper(params)
    helper.run()
def fit(params):
    """Run the ``fit`` task for the model selected by ``params.model.type``.

    Only the ``"retriever"`` model type is implemented; it is delegated to
    ``RetrieverFitHelper(params).run()``. Any other model type is skipped,
    and a warning is logged so the missing step shows up in the run log
    instead of disappearing silently.

    Args:
        params: Resolved Hydra/OmegaConf configuration; ``params.model.type``
            selects the implementation.
    """
    model_type = params.model.type
    if model_type == "retriever":
        RetrieverFitHelper(params).run()
    else:
        # Not raising on purpose: keep the pipeline running, but make the
        # skipped stage visible.
        logging.getLogger(__name__).warning(
            "Task 'fit' skipped: model type %r is not supported "
            "(only 'retriever' is implemented).",
            model_type,
        )
def predict(params):
    """Run the ``predict`` task for the model selected by ``params.model.type``.

    Only the ``"retriever"`` model type is implemented; it builds a
    ``RetrieverPredictHelper`` and calls its ``perform_predict`` method.
    Any other model type is skipped, and a warning is logged so the missing
    step is visible instead of silently doing nothing.

    Args:
        params: Resolved Hydra/OmegaConf configuration; ``params.model.type``
            selects the implementation.
    """
    model_type = params.model.type
    if model_type == "retriever":
        helper = RetrieverPredictHelper(params)
        helper.perform_predict()
    else:
        # Not raising on purpose: keep the pipeline running, but make the
        # skipped stage visible.
        logging.getLogger(__name__).warning(
            "Task 'predict' skipped: model type %r is not supported "
            "(only 'retriever' is implemented).",
            model_type,
        )
# NOTE: this module-level name shadows the built-in ``eval``; it is kept
# because ``perform_tasks`` (and possibly external callers) use it by name.
def eval(params):
    """Run the ``eval`` task for the model selected by ``params.model.type``.

    Only the ``"retriever"`` model type is implemented; it builds a
    ``RetrieverEvalHelper`` and calls its ``perform_eval`` method. Any
    other model type is skipped, and a warning is logged so the missing
    step is visible instead of silently doing nothing.

    Args:
        params: Resolved Hydra/OmegaConf configuration; ``params.model.type``
            selects the implementation.
    """
    model_type = params.model.type
    if model_type == "retriever":
        helper = RetrieverEvalHelper(params)
        helper.perform_eval()
    else:
        # Not raising on purpose: keep the pipeline running, but make the
        # skipped stage visible.
        logging.getLogger(__name__).warning(
            "Task 'eval' skipped: model type %r is not supported "
            "(only 'retriever' is implemented).",
            model_type,
        )
def aggregate(params):
    """Run the ``aggregate`` task.

    Constructs a ``RankingAggregationHelper`` from the configuration and
    calls its ``run`` method.

    Args:
        params: Resolved Hydra/OmegaConf configuration, passed through as-is.
    """
    aggregator = RankingAggregationHelper(params)
    aggregator.run()
def fuse(params):
    """Run the ``fuse`` task.

    Constructs a ``RankingFusionHelper`` from the configuration and calls
    its ``run`` method.

    Args:
        params: Resolved Hydra/OmegaConf configuration, passed through as-is.
    """
    fuser = RankingFusionHelper(params)
    fuser.run()
def prompt_opt(params):
    """Run the ``prompt_opt`` task.

    Constructs a ``PromptOptimizerHelper`` from the configuration and calls
    its ``run`` method.

    Args:
        params: Resolved Hydra/OmegaConf configuration, passed through as-is.
    """
    optimizer = PromptOptimizerHelper(params)
    optimizer.run()
def label_desc(params):
    """Run the ``label_desc`` task.

    Constructs a ``LabelDescriptionHelper`` from the configuration and calls
    its ``run`` method.

    Args:
        params: Resolved Hydra/OmegaConf configuration, passed through as-is.
    """
    describer = LabelDescriptionHelper(params)
    describer.run()
@hydra.main(config_path="setting", config_name="setting.yaml", version_base=None)
def perform_tasks(params):
    """Hydra entry point: run every task named in ``params.tasks``.

    Hydra composes the configuration from ``setting/setting.yaml`` (plus any
    command-line overrides) and passes it in as ``params``. Selected tasks
    always run in the fixed order of ``pipeline`` below, regardless of the
    order in which they appear in ``params.tasks``. When ``params.tasks`` is
    a sequence, names that match no known task are reported with a warning
    instead of being silently ignored.

    Args:
        params: Composed Hydra/OmegaConf configuration; ``params.tasks``
            names the stages to run.
    """
    # Hydra may run the job in a per-run output directory (e.g. when
    # hydra.job.chdir is enabled); switch back to the launch directory so
    # relative paths in the config resolve against it.
    os.chdir(hydra.utils.get_original_cwd())
    # Resolve ${...} interpolations in place once, before any stage reads them.
    OmegaConf.resolve(params)

    # (task name, handler) in execution order. Note that fuse runs before
    # aggregate.
    pipeline = (
        ("sparse_retrieve", sparse_retrieve),
        ("fit", fit),
        ("predict", predict),
        ("eval", eval),
        ("fuse", fuse),
        ("aggregate", aggregate),
        ("prompt_opt", prompt_opt),
        ("label_desc", label_desc),
    )

    requested = params.tasks
    # Element-wise validation only makes sense for a sequence of names; a
    # plain string is still matched by substring below, so it is not checked.
    if not isinstance(requested, str):
        known = [name for name, _ in pipeline]
        unknown = [task for task in requested if task not in known]
        if unknown:
            logging.getLogger(__name__).warning(
                "Ignoring unknown task(s) %s; known tasks are %s.",
                unknown,
                known,
            )

    for name, handler in pipeline:
        if name in requested:
            handler(params)
if __name__ == "__main__":
    # The @hydra.main decorator on perform_tasks builds and supplies the
    # configuration argument, so the entry point is called with no arguments.
    perform_tasks()