1313// See the License for the specific language governing permissions and
1414// limitations under the License.
1515
16-
1716#ifndef DECODER_PARAMS_H_
1817#define DECODER_PARAMS_H_
1918
2928#ifdef USE_TORCH
3029#include " decoder/torch_asr_model.h"
3130#endif
31+ #ifdef USE_XPU
32+ #include " decoder/xpu_asr_model.h"
33+ #endif
3234#include " frontend/feature_pipeline.h"
3335#include " post_processor/post_processor.h"
3436#include " utils/flags.h"
3537#include " utils/string.h"
3638
3739DEFINE_int32 (num_threads, 1 , " num threads for ASR model" );
40+ DEFINE_int32 (device_id, 0 , " set XPU DeviceID for ASR model" );
3841
3942// TorchAsrModel flags
4043DEFINE_string (model_path, " " , " pytorch exported model path" );
4144// OnnxAsrModel flags
4245DEFINE_string (onnx_dir, " " , " directory where the onnx model is saved" );
46+ // XPUAsrModel flags
47+ DEFINE_string (xpu_model_dir, " " ,
48+ " directory where the XPU model and weights is saved" );
4349
4450// FeaturePipelineConfig flags
4551DEFINE_int32 (num_bins, 80 , " num mel bins for fbank feature" );
@@ -66,7 +72,8 @@ DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
6672DEFINE_double (acoustic_scale, 1.0 , " acoustic scale for ctc wfst search" );
6773DEFINE_double (blank_skip_thresh, 1.0 ,
6874 " blank skip thresh for ctc wfst search, 1.0 means no skip" );
69- DEFINE_double (length_penalty, 0.0 , " length penalty ctc wfst search, will not"
75+ DEFINE_double (length_penalty, 0.0 ,
76+ " length penalty ctc wfst search, will not"
7077 " apply on self-loop arc, for balancing the del/ins ratio, "
7178 " suggest set to -3.0" );
7279DEFINE_int32 (nbest, 10 , " nbest for ctc wfst or prefix search" );
@@ -130,7 +137,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
130137#else
131138 LOG (FATAL) << " Please rebuild with cmake options '-DONNX=ON'." ;
132139#endif
133- } else {
140+ } else if (!FLAGS_model_path. empty ()) {
134141#ifdef USE_TORCH
135142 LOG (INFO) << " Reading torch model " << FLAGS_model_path;
136143 TorchAsrModel::InitEngineThreads (FLAGS_num_threads);
@@ -140,6 +147,19 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
140147#else
141148 LOG (FATAL) << " Please rebuild with cmake options '-DTORCH=ON'." ;
142149#endif
150+ } else if (!FLAGS_xpu_model_dir.empty ()) {
151+ #ifdef USE_XPU
152+ LOG (INFO) << " Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
153+ auto model = std::make_shared<XPUAsrModel>();
154+ model->SetEngineThreads (FLAGS_num_threads);
155+ model->SetDeviceId (FLAGS_device_id);
156+ model->Read (FLAGS_xpu_model_dir);
157+ resource->model = model;
158+ #else
159+ LOG (FATAL) << " Please rebuild with cmake options '-DXPU=ON'." ;
160+ #endif
161+ } else {
162+ LOG (FATAL) << " Please set ONNX, TORCH or XPU model path!!!" ;
143163 }
144164
145165 LOG (INFO) << " Reading unit table " << FLAGS_unit_path;
@@ -186,6 +206,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
186206 post_process_opts.lowercase = FLAGS_lowercase;
187207 resource->post_processor =
188208 std::make_shared<PostProcessor>(std::move (post_process_opts));
209+ LOG (INFO) << " Finish set PostProcessOptions. \n " ;
189210 return resource;
190211}
191212
0 commit comments