Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions runtime/core/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@ set(decoder_srcs
ctc_endpoint.cc
)

if(NOT TORCH AND NOT ONNX AND NOT XPU)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU!!!")
# At least one inference backend must be selected; fail early otherwise.
if(NOT TORCH AND NOT ONNX AND NOT XPU AND NOT IOS)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU or IOS!!!")
endif()
# Append the backend-specific ASR model implementation for each enabled backend.
if(TORCH)
list(APPEND decoder_srcs torch_asr_model.cc)
endif()
if(ONNX)
list(APPEND decoder_srcs onnx_asr_model.cc)
endif()
if(IOS)
list(APPEND decoder_srcs ios_asr_model.cc)
endif()

add_library(decoder STATIC ${decoder_srcs})
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
Expand Down
233 changes: 233 additions & 0 deletions runtime/core/decoder/ios_asr_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
Comment thread
pengzhendong marked this conversation as resolved.
Outdated
// 2022 Dan Ma (1067837450@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Comment thread
Ma-Dan marked this conversation as resolved.
Outdated
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "decoder/ios_asr_model.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <stdexcept>

#include "torch/script.h"

namespace wenet {

void IosAsrModel::Read(const std::string& model_path) {
torch::DeviceType device = at::kCPU;
torch::jit::script::Module model = torch::jit::load(model_path, device);
model_ = std::make_shared<TorchModule>(std::move(model));
torch::NoGradGuard no_grad;
model_->eval();
torch::jit::IValue o1 = model_->run_method("subsampling_rate");
CHECK_EQ(o1.isInt(), true);
subsampling_rate_ = o1.toInt();
torch::jit::IValue o2 = model_->run_method("right_context");
CHECK_EQ(o2.isInt(), true);
right_context_ = o2.toInt();
torch::jit::IValue o3 = model_->run_method("sos_symbol");
CHECK_EQ(o3.isInt(), true);
sos_ = o3.toInt();
torch::jit::IValue o4 = model_->run_method("eos_symbol");
CHECK_EQ(o4.isInt(), true);
eos_ = o4.toInt();
torch::jit::IValue o5 = model_->run_method("is_bidirectional_decoder");
CHECK_EQ(o5.isBool(), true);
is_bidirectional_decoder_ = o5.toBool();

VLOG(1) << "Torch Model Info:";
VLOG(1) << "\tsubsampling_rate " << subsampling_rate_;
VLOG(1) << "\tright context " << right_context_;
VLOG(1) << "\tsos " << sos_;
VLOG(1) << "\teos " << eos_;
VLOG(1) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
}

// Copy constructor: duplicates configuration and shares the TorchScript
// module; per-utterance decoding state is NOT copied.
IosAsrModel::IosAsrModel(const IosAsrModel& other) {
  // Model metadata read from the TorchScript module.
  subsampling_rate_ = other.subsampling_rate_;
  right_context_ = other.right_context_;
  sos_ = other.sos_;
  eos_ = other.eos_;
  is_bidirectional_decoder_ = other.is_bidirectional_decoder_;
  // Streaming decoding configuration.
  chunk_size_ = other.chunk_size_;
  num_left_chunks_ = other.num_left_chunks_;
  offset_ = other.offset_;
  // Share the module pointer rather than deep-copying: PyTorch allows
  // using multiple CPU threads during TorchScript model inference, see
  // https://pytorch.org/docs/stable/notes/cpu_
  // threading_torchscript_inference.html
  model_ = other.model_;
  // The forward-pass caches (att_cache_/cnn_cache_/encoder_outs_/
  // cached_feature_) are intentionally left at their defaults; callers
  // are expected to Reset() before decoding.
}

// Creates a clone that shares the TorchScript module but starts from a
// clean decoding state, suitable for decoding a new utterance.
std::shared_ptr<AsrModel> IosAsrModel::Copy() const {
  auto clone = std::make_shared<IosAsrModel>(*this);
  clone->Reset();  // clear caches/offset so the clone decodes from scratch
  return clone;
}

// Rewinds the streaming state for a new utterance: frame offset back to
// zero, attention/conv caches emptied, buffered encoder outputs and
// cached features dropped.
void IosAsrModel::Reset() {
  offset_ = 0;
  // Plain assignment from a prvalue already moves; the original
  // std::move(...) wrappers were no-ops (clang-tidy:
  // performance-move-const-arg).
  att_cache_ = torch::zeros({0, 0, 0, 0});
  cnn_cache_ = torch::zeros({0, 0, 0, 0});
  encoder_outs_.clear();
  cached_feature_.clear();
}

// Runs one streaming encoder step: splices cached_feature_ with the new
// chunk_feats, calls the model's forward_encoder_chunk, updates the
// attention/conv caches and offset, then emits per-frame CTC log
// probabilities into *out_prob.
// NOTE(review): assumes chunk_feats is non-empty and every row has the
// same dimension as chunk_feats[0] — TODO confirm callers guarantee this.
void IosAsrModel::ForwardEncoderFunc(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* out_prob) {
// 1. Prepare libtorch required data, splice cached_feature_ and chunk_feats
// The first dimension is for batchsize, which is 1.
int num_frames = cached_feature_.size() + chunk_feats.size();
const int feature_dim = chunk_feats[0].size();
torch::Tensor feats =
torch::zeros({1, num_frames, feature_dim}, torch::kFloat);
// Copy the cached (left-context) feature rows first.
for (size_t i = 0; i < cached_feature_.size(); ++i) {
// from_blob wraps the vector's memory without owning it; clone() makes
// an owning copy before it is assigned into feats.
torch::Tensor row =
torch::from_blob(const_cast<float*>(cached_feature_[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][i] = std::move(row);
}
// Then the new chunk's rows, appended after the cached ones.
for (size_t i = 0; i < chunk_feats.size(); ++i) {
torch::Tensor row =
torch::from_blob(const_cast<float*>(chunk_feats[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][cached_feature_.size() + i] = std::move(row);
}

// 2. Encoder chunk forward
// required_cache_size limits how much left attention context the model
// keeps (chunk_size_ * num_left_chunks_; negative/zero values follow the
// model's convention — see wenet/transformer/asr_model.py).
int required_cache_size = chunk_size_ * num_left_chunks_;
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs = {feats, offset_, required_cache_size,
att_cache_, cnn_cache_};

// Refer interfaces in wenet/transformer/asr_model.py
// Returns (chunk_out, new_att_cache, new_cnn_cache).
auto outputs =
model_->get_method("forward_encoder_chunk")(inputs).toTuple()->elements();
CHECK_EQ(outputs.size(), 3);
torch::Tensor chunk_out = outputs[0].toTensor();
att_cache_ = outputs[1].toTensor();
cnn_cache_ = outputs[2].toTensor();
// Advance the global frame offset by the number of encoder output frames.
offset_ += chunk_out.size(1);

// The first dimension of returned value is for batchsize, which is 1
torch::Tensor ctc_log_probs =
model_->run_method("ctc_activation", chunk_out).toTensor()[0];
// Keep the raw encoder output for later attention rescoring.
encoder_outs_.push_back(std::move(chunk_out));

// Copy to output
// ctc_log_probs is (num_output_frames, vocab_dim); copy row by row.
int num_outputs = ctc_log_probs.size(0);
int output_dim = ctc_log_probs.size(1);
out_prob->resize(num_outputs);
for (int i = 0; i < num_outputs; i++) {
(*out_prob)[i].resize(output_dim);
memcpy((*out_prob)[i].data(), ctc_log_probs[i].data_ptr(),
sizeof(float) * output_dim);
}
}

// Scores one hypothesis against a (length, vocab) decoder output tensor:
// accumulates the score of each hypothesis token at its position, plus the
// score of emitting eos immediately after the last token.
float IosAsrModel::ComputeAttentionScore(const torch::Tensor& prob,
                                         const std::vector<int>& hyp,
                                         int eos) {
  auto scores = prob.accessor<float, 2>();
  const size_t len = hyp.size();
  float total = 0.0f;
  for (size_t pos = 0; pos < len; ++pos) {
    total += scores[pos][hyp[pos]];
  }
  total += scores[len][eos];
  return total;
}

// Rescores n-best CTC hypotheses with the attention decoder.
// hyps: token-id sequences (without sos/eos); reverse_weight: interpolation
// weight for the right-to-left decoder score (used only when the model is
// bidirectional); rescoring_score: output, one combined score per hypothesis.
// Leaves all scores at 0 when there are no hypotheses or no encoder output.
void IosAsrModel::AttentionRescoring(
const std::vector<std::vector<int>>& hyps, float reverse_weight,
std::vector<float>* rescoring_score) {
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);

if (num_hyps == 0) {
return;
}
// No encoder output
if (encoder_outs_.size() == 0) {
return;
}

torch::NoGradGuard no_grad;
// Step 1: Prepare input for libtorch
// Each hypothesis is prefixed with sos, so its length is size() + 1; pad
// all rows to the longest hypothesis with zeros.
torch::Tensor hyps_length = torch::zeros({num_hyps}, torch::kLong);
int max_hyps_len = 0;
for (size_t i = 0; i < num_hyps; ++i) {
int length = hyps[i].size() + 1;
max_hyps_len = std::max(length, max_hyps_len);
hyps_length[i] = static_cast<int64_t>(length);
}
torch::Tensor hyps_tensor =
torch::zeros({num_hyps, max_hyps_len}, torch::kLong);
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
hyps_tensor[i][0] = sos_;
for (size_t j = 0; j < hyp.size(); ++j) {
hyps_tensor[i][j + 1] = hyp[j];
}
}

// Step 2: Forward attention decoder by hyps and corresponding encoder_outs_
// Concatenate all streamed encoder chunks along the time axis (dim 1).
torch::Tensor encoder_out = torch::cat(encoder_outs_, 1);
// Returns (left-to-right scores, right-to-left scores); see
// forward_attention_decoder in wenet/transformer/asr_model.py.
auto outputs = model_
->run_method("forward_attention_decoder", hyps_tensor,
hyps_length, encoder_out, reverse_weight)
.toTuple()
->elements();

auto probs = outputs[0].toTensor();
auto r_probs = outputs[1].toTensor();

CHECK_EQ(probs.size(0), num_hyps);
CHECK_EQ(probs.size(1), max_hyps_len);

// Step 3: Compute rescoring score
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
float score = 0.0f;
// left-to-right decoder score
score = ComputeAttentionScore(probs[i], hyp, eos_);
// Optional: Used for right to left score
float r_score = 0.0f;
if (is_bidirectional_decoder_ && reverse_weight > 0) {
// right-to-left score
CHECK_EQ(r_probs.size(0), num_hyps);
CHECK_EQ(r_probs.size(1), max_hyps_len);
// The right-to-left decoder scores the token sequence reversed.
std::vector<int> r_hyp(hyp.size());
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
// right to left decoder score
r_score = ComputeAttentionScore(r_probs[i], r_hyp, eos_);
}

// combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
}
}

} // namespace wenet
63 changes: 63 additions & 0 deletions runtime/core/decoder/ios_asr_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
Comment thread
Ma-Dan marked this conversation as resolved.
Outdated
// 2022 Dan Ma (1067837450@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef DECODER_IOS_ASR_MODEL_H_
#define DECODER_IOS_ASR_MODEL_H_

#include <memory>
#include <string>
#include <vector>

#include "torch/script.h"

#include "decoder/asr_model.h"
#include "utils/utils.h"

namespace wenet {

// iOS-targeted ASR model backed by a TorchScript (libtorch) module.
// Implements the AsrModel interface: streaming encoder forward with CTC
// output plus attention-decoder rescoring of n-best hypotheses.
class IosAsrModel : public AsrModel {
public:
using TorchModule = torch::jit::script::Module;
IosAsrModel() = default;
// Copy shares the TorchScript module; forward caches are not copied.
IosAsrModel(const IosAsrModel& other);
// Loads the TorchScript model from disk and reads its exported metadata.
void Read(const std::string& model_path);
// Non-owning-style accessor for the underlying TorchScript module.
std::shared_ptr<TorchModule> torch_model() const { return model_; }
// Clears all per-utterance streaming state (offset, caches, outputs).
void Reset() override;
// Rescores n-best hypotheses with the attention decoder; writes one
// combined score per hypothesis into rescoring_score.
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override;
// Returns a clone sharing the model but with reset decoding state.
std::shared_ptr<AsrModel> Copy() const override;

protected:
// One streaming encoder step: consumes chunk_feats, emits CTC probs.
void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* ctc_prob) override;

// Sums per-position scores of hyp (plus trailing eos) from a 2-D tensor.
float ComputeAttentionScore(const torch::Tensor& prob,
const std::vector<int>& hyp, int eos);

private:
std::shared_ptr<TorchModule> model_ = nullptr;
// Encoder outputs of all processed chunks, kept for rescoring.
std::vector<torch::Tensor> encoder_outs_;
// transformer/conformer attention cache
torch::Tensor att_cache_ = torch::zeros({0, 0, 0, 0});
// conformer-only conv_module cache
torch::Tensor cnn_cache_ = torch::zeros({0, 0, 0, 0});
};

} // namespace wenet

#endif // DECODER_IOS_ASR_MODEL_H_
Empty file.
Loading