forked from wenet-e2e/wenet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparams.h
More file actions
214 lines (192 loc) · 8.17 KB
/
params.h
File metadata and controls
214 lines (192 loc) · 8.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_PARAMS_H_
#define DECODER_PARAMS_H_
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "decoder/asr_decoder.h"
#ifdef USE_ONNX
#include "decoder/onnx_asr_model.h"
#endif
#ifdef USE_TORCH
#include "decoder/torch_asr_model.h"
#endif
#ifdef USE_XPU
#include "xpu/xpu_asr_model.h"
#endif
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/flags.h"
#include "utils/string.h"
// Command line flags consumed by the Init*FromFlags() helpers in this header.
//
// Engine flags: num_threads is forwarded to whichever model backend gets
// loaded; device_id is only forwarded to the XPU model.
DEFINE_int32(num_threads, 1, "num threads for ASR model");
DEFINE_int32(device_id, 0, "set XPU DeviceID for ASR model");
// Model location flags. InitDecodeResourceFromFlags() checks onnx_dir first,
// then model_path, then xpu_model_dir, and loads the first non-empty one.
// TorchAsrModel flags
DEFINE_string(model_path, "", "pytorch exported model path");
// OnnxAsrModel flags
DEFINE_string(onnx_dir, "", "directory where the onnx model is saved");
// XPUAsrModel flags
DEFINE_string(xpu_model_dir, "",
"directory where the XPU model and weights is saved");
// FeaturePipelineConfig flags
// Both values are passed straight to the FeaturePipelineConfig constructor.
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");
// TLG fst
// Leaving fst_path empty selects the "without LM" path in
// InitDecodeResourceFromFlags().
DEFINE_string(fst_path, "", "TLG fst path");
// DecodeOptions flags
// These are copied field by field into DecodeOptions by
// InitDecodeOptionsFromFlags().
//
// NOTE: adjacent string literals are concatenated verbatim by the compiler, so
// every continued help string must carry its own separating space.
DEFINE_int32(chunk_size, 16, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, 0.5,
              "ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight, 1.0,
              "rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight, 0.0,
              "used for bitransformer rescoring. it must be 0.0 if decoder is "
              "conventional transformer decoder, and only if "
              "reverse_weight > 0.0 will the right to left decoder be "
              "calculated and used");
// CTC WFST search options.
DEFINE_int32(max_active, 7000, "max active states in ctc wfst search");
DEFINE_int32(min_active, 200, "min active states in ctc wfst search");
DEFINE_double(beam, 16.0, "beam in ctc wfst search");
DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_double(blank_skip_thresh, 1.0,
              "blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(length_penalty, 0.0,
              "length penalty ctc wfst search, will not "
              "apply on self-loop arc, for balancing the del/ins ratio, "
              "suggest set to -3.0");
// nbest is shared: it sets the WFST search n-best size as well as both beam
// sizes of the CTC prefix search.
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
// SymbolTable flags
// unit_path is always read by InitDecodeResourceFromFlags(); dict_path is only
// read (and must be non-empty) when fst_path is set.
DEFINE_string(dict_path, "",
"dict symbol table path, required when LM is enabled");
DEFINE_string(unit_path, "",
"e2e model unit symbol table, it is used in both "
"with/without LM scenarios for context/timestamp");
// Context flags
// When context_path is non-empty the file is read line by line, each trimmed
// line becoming one entry handed to ContextGraph::BuildContextGraph();
// context_score is copied into ContextConfig.
DEFINE_string(context_path, "", "context path, is used to build context graph");
DEFINE_double(context_score, 3.0, "is used to rescore the decoded result");
// PostProcessOptions flags
// language_type == 0 selects kMandarinEnglish; InitDecodeResourceFromFlags()
// maps every other value to kIndoEuropean.
//
// NOTE: adjacent string literals are concatenated verbatim, so the first
// fragment ends with ": " to keep the help text readable.
DEFINE_int32(language_type, 0,
             "remove spaces according to language type: "
             "0x00 = kMandarinEnglish, "
             "0x01 = kIndoEuropean");
DEFINE_bool(lowercase, true, "lowercase final result if needed");
namespace wenet {
// Builds a shared FeaturePipelineConfig from the --num_bins and --sample_rate
// flags, passed to the constructor in that order.
std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
  return std::make_shared<FeaturePipelineConfig>(FLAGS_num_bins,
                                                 FLAGS_sample_rate);
}
// Collects the decoding related flags into a freshly allocated DecodeOptions.
//
// --nbest feeds three fields: the n-best size of the CTC WFST search and both
// the first and second beam size of the CTC prefix search.
std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
  auto options = std::make_shared<DecodeOptions>();

  // Chunking and score combination weights.
  options->chunk_size = FLAGS_chunk_size;
  options->num_left_chunks = FLAGS_num_left_chunks;
  options->ctc_weight = FLAGS_ctc_weight;
  options->rescoring_weight = FLAGS_rescoring_weight;
  options->reverse_weight = FLAGS_reverse_weight;

  // CTC WFST search.
  auto& wfst_opts = options->ctc_wfst_search_opts;
  wfst_opts.max_active = FLAGS_max_active;
  wfst_opts.min_active = FLAGS_min_active;
  wfst_opts.beam = FLAGS_beam;
  wfst_opts.lattice_beam = FLAGS_lattice_beam;
  wfst_opts.acoustic_scale = FLAGS_acoustic_scale;
  wfst_opts.blank_skip_thresh = FLAGS_blank_skip_thresh;
  wfst_opts.length_penalty = FLAGS_length_penalty;
  wfst_opts.nbest = FLAGS_nbest;

  // CTC prefix search.
  auto& prefix_opts = options->ctc_prefix_search_opts;
  prefix_opts.first_beam_size = FLAGS_nbest;
  prefix_opts.second_beam_size = FLAGS_nbest;

  return options;
}
// Builds the DecodeResource from command line flags.
//
// Steps, in order:
//   1. Read the ASR model. The backend is chosen by the first non-empty flag
//      among --onnx_dir, --model_path and --xpu_model_dir. Selecting a backend
//      that was not compiled in, or setting none of the flags, logs FATAL.
//   2. Read the unit table from --unit_path (CHECK-fails when unreadable).
//   3. If --fst_path is set, read the TLG fst and the --dict_path symbol
//      table; otherwise the unit table is reused as the symbol table.
//   4. If --context_path is set, read one context entry per line (trimmed)
//      and build the context graph with --context_score.
//   5. Create the post processor from --language_type and --lowercase.
//
// Returns the fully populated resource.
std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
  auto resource = std::make_shared<DecodeResource>();

  if (!FLAGS_onnx_dir.empty()) {
#ifdef USE_ONNX
    // Log the directory as well, matching the torch and XPU branches.
    LOG(INFO) << "Reading onnx model " << FLAGS_onnx_dir;
    OnnxAsrModel::InitEngineThreads(FLAGS_num_threads);
    auto model = std::make_shared<OnnxAsrModel>();
    model->Read(FLAGS_onnx_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake options '-DONNX=ON'.";
#endif
  } else if (!FLAGS_model_path.empty()) {
#ifdef USE_TORCH
    LOG(INFO) << "Reading torch model " << FLAGS_model_path;
    TorchAsrModel::InitEngineThreads(FLAGS_num_threads);
    auto model = std::make_shared<TorchAsrModel>();
    model->Read(FLAGS_model_path);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake options '-DTORCH=ON'.";
#endif
  } else if (!FLAGS_xpu_model_dir.empty()) {
#ifdef USE_XPU
    LOG(INFO) << "Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
    auto model = std::make_shared<XPUAsrModel>();
    model->SetEngineThreads(FLAGS_num_threads);
    model->SetDeviceId(FLAGS_device_id);
    model->Read(FLAGS_xpu_model_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake options '-DXPU=ON'.";
#endif
  } else {
    LOG(FATAL) << "Please set ONNX, TORCH or XPU model path!!!";
  }

  LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
  auto unit_table = std::shared_ptr<fst::SymbolTable>(
      fst::SymbolTable::ReadText(FLAGS_unit_path));
  CHECK(unit_table != nullptr)
      << "Failed to read unit table " << FLAGS_unit_path;
  resource->unit_table = unit_table;

  if (!FLAGS_fst_path.empty()) {  // With LM
    CHECK(!FLAGS_dict_path.empty())
        << "dict_path is required when fst_path is set";
    LOG(INFO) << "Reading fst " << FLAGS_fst_path;
    auto fst = std::shared_ptr<fst::Fst<fst::StdArc>>(
        fst::Fst<fst::StdArc>::Read(FLAGS_fst_path));
    CHECK(fst != nullptr) << "Failed to read fst " << FLAGS_fst_path;
    resource->fst = fst;

    LOG(INFO) << "Reading symbol table " << FLAGS_dict_path;
    auto symbol_table = std::shared_ptr<fst::SymbolTable>(
        fst::SymbolTable::ReadText(FLAGS_dict_path));
    CHECK(symbol_table != nullptr)
        << "Failed to read symbol table " << FLAGS_dict_path;
    resource->symbol_table = symbol_table;
  } else {  // Without LM, symbol_table is the same as unit_table
    resource->symbol_table = unit_table;
  }

  if (!FLAGS_context_path.empty()) {
    LOG(INFO) << "Reading context " << FLAGS_context_path;
    std::ifstream infile(FLAGS_context_path);
    // An unopenable file previously produced a silently empty context graph;
    // fail loudly instead, like every other resource loaded above.
    CHECK(infile.is_open())
        << "Failed to open context file " << FLAGS_context_path;
    std::vector<std::string> contexts;
    std::string context;
    while (std::getline(infile, context)) {
      contexts.emplace_back(Trim(context));
    }
    ContextConfig config;
    config.context_score = FLAGS_context_score;
    resource->context_graph = std::make_shared<ContextGraph>(config);
    resource->context_graph->BuildContextGraph(contexts,
                                               resource->symbol_table);
  }

  PostProcessOptions post_process_opts;
  // Only 0 selects kMandarinEnglish; any other value means kIndoEuropean.
  post_process_opts.language_type =
      FLAGS_language_type == 0 ? kMandarinEnglish : kIndoEuropean;
  post_process_opts.lowercase = FLAGS_lowercase;
  resource->post_processor =
      std::make_shared<PostProcessor>(std::move(post_process_opts));
  return resource;
}
} // namespace wenet
#endif // DECODER_PARAMS_H_