forked from onnx/onnx
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathaxes_input_to_attribute.h
More file actions
64 lines (57 loc) · 2.31 KB
/
axes_input_to_attribute.h
File metadata and controls
64 lines (57 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for ops whose statically-known `axes` input is converted into an `axes` attribute
#pragma once
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class AxesInputToAttribute : public Adapter {
public:
explicit AxesInputToAttribute(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
// Identify if axes is statically determined; if so, feed as attribute
const ArrayRef<Value*>& inputs = node->inputs();
// Get axes from initializer or constant operator
// Identify whether we have a Constant Op or an Initializer
Value* const_val = inputs[1];
Node* node_ptr = const_val->node();
if (node_ptr->kind() == kConstant) {
// Get value attribute of kConstant
const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
if (int64s.empty()) {
// Also handle raw data
std::string raw_data = node_ptr->t(kvalue).raw();
ONNX_ASSERTM(
raw_data.size() != 0 && raw_data.size() % 8 == 0,
"Raw Data must be non-empty and size must be a multiple of 8");
int64_t* raw = (int64_t*)const_cast<char*>(raw_data.c_str());
node->is_(kaxes, std::vector<int64_t>(raw, raw + node_ptr->t(kvalue).size_from_dim(0)));
} else {
node->is_(kaxes, std::forward<const std::vector<int64_t>>(int64s));
}
// If Constant node isn't used anywhere else, remove it
node->removeInput(1);
if (const_val->uses().size() < 1) {
node_ptr->destroy();
}
} else {
// Get Value name, find Initializer with same name
for (const auto& initializer : graph->initializers()) {
if (initializer.name() == inputs[1]->uniqueName()) {
node->is_(kaxes, std::forward<const std::vector<int64_t>>(initializer.int64s()));
node->removeInput(1);
// Remove initializer
if (const_val->uses().size() < 1)
graph->eraseInitializerAndInput(const_val);
break;
}
}
}
ONNX_ASSERTM(node->hasAttribute(kaxes), "No initializer or constant input to node found");
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE