forked from wenet-e2e/wenet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathios_asr_model.cc
More file actions
233 lines (207 loc) · 8.04 KB
/
ios_asr_model.cc
File metadata and controls
233 lines (207 loc) · 8.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
// 2022 Dan Ma (1067837450@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ios_asr_model.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "torch/script.h"
namespace wenet {
void IosAsrModel::Read(const std::string& model_path) {
torch::DeviceType device = at::kCPU;
torch::jit::script::Module model = torch::jit::load(model_path, device);
model_ = std::make_shared<TorchModule>(std::move(model));
torch::NoGradGuard no_grad;
model_->eval();
torch::jit::IValue o1 = model_->run_method("subsampling_rate");
CHECK_EQ(o1.isInt(), true);
subsampling_rate_ = o1.toInt();
torch::jit::IValue o2 = model_->run_method("right_context");
CHECK_EQ(o2.isInt(), true);
right_context_ = o2.toInt();
torch::jit::IValue o3 = model_->run_method("sos_symbol");
CHECK_EQ(o3.isInt(), true);
sos_ = o3.toInt();
torch::jit::IValue o4 = model_->run_method("eos_symbol");
CHECK_EQ(o4.isInt(), true);
eos_ = o4.toInt();
torch::jit::IValue o5 = model_->run_method("is_bidirectional_decoder");
CHECK_EQ(o5.isBool(), true);
is_bidirectional_decoder_ = o5.toBool();
VLOG(1) << "Torch Model Info:";
VLOG(1) << "\tsubsampling_rate " << subsampling_rate_;
VLOG(1) << "\tright context " << right_context_;
VLOG(1) << "\tsos " << sos_;
VLOG(1) << "\teos " << eos_;
VLOG(1) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
}
// Copy constructor: shares the underlying TorchScript module and copies the
// model metadata plus the streaming configuration/offset.
// NOTE(Binbin Zhang): the forward inner states (att_cache_, cnn_cache_,
// encoder_outs_) are deliberately NOT copied; callers that want a fresh
// decoding session should go through Copy(), which calls Reset().
IosAsrModel::IosAsrModel(const IosAsrModel& other) {
  // Sharing the module pointer (instead of cloning the module) is fine because
  // PyTorch allows using multiple CPU threads during TorchScript model
  // inference, see
  // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
  model_ = other.model_;

  // Model metadata.
  sos_ = other.sos_;
  eos_ = other.eos_;
  subsampling_rate_ = other.subsampling_rate_;
  right_context_ = other.right_context_;
  is_bidirectional_decoder_ = other.is_bidirectional_decoder_;

  // Streaming configuration and position.
  chunk_size_ = other.chunk_size_;
  num_left_chunks_ = other.num_left_chunks_;
  offset_ = other.offset_;
}
std::shared_ptr<AsrModel> IosAsrModel::Copy() const {
auto asr_model = std::make_shared<IosAsrModel>(*this);
// Reset the inner states for new decoding
asr_model->Reset();
return asr_model;
}
// Clears all per-utterance streaming state: the frame offset, the attention
// and CNN caches, the accumulated encoder outputs and the cached input
// features. The loaded model and its configuration are kept.
void IosAsrModel::Reset() {
  offset_ = 0;
  // Empty 4-D tensors; presumably interpreted by forward_encoder_chunk as
  // "no cache yet" -- TODO confirm against wenet/transformer/asr_model.py.
  // torch::zeros() already yields a prvalue, so it is move-assigned without
  // an explicit std::move.
  att_cache_ = torch::zeros({0, 0, 0, 0});
  cnn_cache_ = torch::zeros({0, 0, 0, 0});
  encoder_outs_.clear();
  cached_feature_.clear();
}
void IosAsrModel::ForwardEncoderFunc(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* out_prob) {
// 1. Prepare libtorch required data, splice cached_feature_ and chunk_feats
// The first dimension is for batchsize, which is 1.
int num_frames = cached_feature_.size() + chunk_feats.size();
const int feature_dim = chunk_feats[0].size();
torch::Tensor feats =
torch::zeros({1, num_frames, feature_dim}, torch::kFloat);
for (size_t i = 0; i < cached_feature_.size(); ++i) {
torch::Tensor row =
torch::from_blob(const_cast<float*>(cached_feature_[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][i] = std::move(row);
}
for (size_t i = 0; i < chunk_feats.size(); ++i) {
torch::Tensor row =
torch::from_blob(const_cast<float*>(chunk_feats[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][cached_feature_.size() + i] = std::move(row);
}
// 2. Encoder chunk forward
int required_cache_size = chunk_size_ * num_left_chunks_;
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs = {feats, offset_, required_cache_size,
att_cache_, cnn_cache_};
// Refer interfaces in wenet/transformer/asr_model.py
auto outputs =
model_->get_method("forward_encoder_chunk")(inputs).toTuple()->elements();
CHECK_EQ(outputs.size(), 3);
torch::Tensor chunk_out = outputs[0].toTensor();
att_cache_ = outputs[1].toTensor();
cnn_cache_ = outputs[2].toTensor();
offset_ += chunk_out.size(1);
// The first dimension of returned value is for batchsize, which is 1
torch::Tensor ctc_log_probs =
model_->run_method("ctc_activation", chunk_out).toTensor()[0];
encoder_outs_.push_back(std::move(chunk_out));
// Copy to output
int num_outputs = ctc_log_probs.size(0);
int output_dim = ctc_log_probs.size(1);
out_prob->resize(num_outputs);
for (int i = 0; i < num_outputs; i++) {
(*out_prob)[i].resize(output_dim);
memcpy((*out_prob)[i].data(), ctc_log_probs[i].data_ptr(),
sizeof(float) * output_dim);
}
}
// Scores one hypothesis under the attention decoder output `prob`.
// Row k of `prob` is read at column hyp[k] for every token, then row
// hyp.size() is read at column `eos`; the sum of those entries is returned.
// Assumes `prob` is a 2-D CPU float tensor with at least hyp.size() + 1 rows
// (accessor<float, 2> enforces rank and dtype; row count is the caller's job).
float IosAsrModel::ComputeAttentionScore(const torch::Tensor& prob,
                                         const std::vector<int>& hyp,
                                         int eos) {
  const auto rows = prob.accessor<float, 2>();
  float total = 0.0f;
  size_t step = 0;
  for (const int token : hyp) {
    total += rows[step][token];
    ++step;
  }
  // After consuming every token, add the score of predicting eos.
  total += rows[step][eos];
  return total;
}
void IosAsrModel::AttentionRescoring(
const std::vector<std::vector<int>>& hyps, float reverse_weight,
std::vector<float>* rescoring_score) {
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);
if (num_hyps == 0) {
return;
}
// No encoder output
if (encoder_outs_.size() == 0) {
return;
}
torch::NoGradGuard no_grad;
// Step 1: Prepare input for libtorch
torch::Tensor hyps_length = torch::zeros({num_hyps}, torch::kLong);
int max_hyps_len = 0;
for (size_t i = 0; i < num_hyps; ++i) {
int length = hyps[i].size() + 1;
max_hyps_len = std::max(length, max_hyps_len);
hyps_length[i] = static_cast<int64_t>(length);
}
torch::Tensor hyps_tensor =
torch::zeros({num_hyps, max_hyps_len}, torch::kLong);
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
hyps_tensor[i][0] = sos_;
for (size_t j = 0; j < hyp.size(); ++j) {
hyps_tensor[i][j + 1] = hyp[j];
}
}
// Step 2: Forward attention decoder by hyps and corresponding encoder_outs_
torch::Tensor encoder_out = torch::cat(encoder_outs_, 1);
auto outputs = model_
->run_method("forward_attention_decoder", hyps_tensor,
hyps_length, encoder_out, reverse_weight)
.toTuple()
->elements();
auto probs = outputs[0].toTensor();
auto r_probs = outputs[1].toTensor();
CHECK_EQ(probs.size(0), num_hyps);
CHECK_EQ(probs.size(1), max_hyps_len);
// Step 3: Compute rescoring score
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
float score = 0.0f;
// left-to-right decoder score
score = ComputeAttentionScore(probs[i], hyp, eos_);
// Optional: Used for right to left score
float r_score = 0.0f;
if (is_bidirectional_decoder_ && reverse_weight > 0) {
// right-to-left score
CHECK_EQ(r_probs.size(0), num_hyps);
CHECK_EQ(r_probs.size(1), max_hyps_len);
std::vector<int> r_hyp(hyp.size());
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
// right to left decoder score
r_score = ComputeAttentionScore(r_probs[i], r_hyp, eos_);
}
// combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
}
}
} // namespace wenet