11#include " partitioning.h"
2+ #include < queue>
23#include " core/conversion/evaluators/eval_util.h"
34#include " core/lowering/passes/passes.h"
45#include " core/util/prelude.h"
56#include " torch/csrc/jit/api/module.h"
6- #include " torch/csrc/jit/ir/constants .h"
7+ #include " torch/csrc/jit/passes/constant_pooling .h"
78
89namespace trtorch {
910namespace core {
1011namespace partitioning {
1112
13+ inline bool isTensorOrTensorList (torch::jit::Value* val) {
14+ return val->type ()->isSubtypeOf (torch::jit::TensorType::get ()) ||
15+ val->type ()->isSubtypeOf (torch::jit::ListType::ofTensors ());
16+ }
17+
// Bookkeeping for one nonTensor value: which segmented block produces it and
// which blocks consume it. All ids are indices into the segmented_blocks vector.
struct usage_info {
  int produce_id = -1; // index of the segment producing this value; -1 if no segment produces it
  std::vector<int> torch_use_id;    // indices of Torch segments that consume it
  std::vector<int> tensorrt_use_id; // indices of TensorRT segments that consume it
};
23+
1224torch::jit::Value* getOrAddInputForValue (
1325 torch::jit::Value* old_value,
1426 std::shared_ptr<torch::jit::Graph>& graph,
@@ -39,6 +51,7 @@ torch::jit::Node* cloneNode(
3951 auto * block = graph->block ();
4052 auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue (v, graph, old_to_new); };
4153
54+ // create node for current graph by using the metadata in node and input Values in env
4255 auto new_node = block->appendNode (graph->createClone (node, env));
4356 for (size_t i = 0 ; i < node->outputs ().size (); ++i) {
4457 auto oo = node->outputs ()[i];
@@ -68,7 +81,6 @@ void registerSegmentInOutIValues(
6881 // create a module to run the graph
6982 auto g = seg_block.g ();
7083 auto copy_g = g->copy ();
71- // LOG_INFO(*copy_g << "(copy graph)\n");
7284
7385 // create tuple for multiple outputs
7486 if (seg_block.raw_outputs ().size () > 1 ) {
@@ -110,7 +122,10 @@ void registerSegmentInOutIValues(
110122
111123 // run segments to get outputs for later segments input shape, and other arguments such as Int
112124 std::vector<torch::jit::IValue> jit_results;
125+ printf (" before forward\n " );
113126 torch::jit::IValue jit_results_ivalues = cur_mod.forward (jit_inputs_ivalues);
127+ printf (" after forward\n " );
128+
114129 if (jit_results_ivalues.isTuple ()) {
115130 auto results = jit_results_ivalues.toTuple ()->elements ();
116131 for (auto r : results) {
@@ -149,13 +164,10 @@ std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::Inp
149164 return random_inputs;
150165}
151166
152- void registerSegmentsInputsOutputs (
153- std::vector<SegmentedBlock>& segmented_blocks,
154- std::shared_ptr<torch::jit::Graph> g) {
167+ void registerSegmentsOutputs (std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
155168 // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
156169 std::set<torch::jit::Value*> input_values;
157170 for (auto & seg_block : segmented_blocks) {
158- seg_block.registerInputs ();
159171 for (auto & input : seg_block.raw_inputs ()) {
160172 input_values.insert (input);
161173 }
@@ -165,51 +177,124 @@ void registerSegmentsInputsOutputs(
165177 input_values.insert (graph_output);
166178 }
167179
168- // should be careful here because some in-place operations don't return any values
  // Be careful here: some in-place operations don't return any values, so there is no
  // output for that kind of segment. Identify the outputs of each mini-graph by checking
  // whether any value in the graph is used later; we shouldn't register nonTensor
  // outputs for TensorRT segments.
169183 for (auto & seg_block : segmented_blocks) {
170184 for (auto & mini_graph_input : input_values) {
171185 if (std::find (seg_block.raw_inputs ().begin (), seg_block.raw_inputs ().end (), mini_graph_input) ==
172186 seg_block.raw_inputs ().end () &&
173- seg_block.contain_raw_input (mini_graph_input)) {
187+ seg_block.contain_raw_value (mini_graph_input)) {
188+ if (!isTensorOrTensorList (mini_graph_input) && seg_block.target () == SegmentedBlock::kTensorRT )
189+ continue ;
174190 seg_block.registerOutput (mini_graph_input);
175191 }
176192 }
193+ // if no output, then register the last node's output as current graph's output
177194 if (seg_block.raw_outputs ().empty ()) {
178- seg_block.registerOutput (seg_block.raw_inputs ()[0 ]);
195+ // for Torch segments, register input as output
196+ if (seg_block.target () == SegmentedBlock::kTorch ) {
197+ seg_block.registerOutput (seg_block.raw_inputs ()[0 ]);
198+ } else {
199+ // for TensorRT segments, register last nonInput Tensor outputs
200+ for (int i = seg_block.raw_nodes ().size () - 1 ; i >= 0 ; --i) {
201+ for (auto node_output : seg_block.raw_nodes ()[i]->outputs ()) {
202+ if (isTensorOrTensorList (node_output))
203+ seg_block.registerOutput (node_output);
204+ }
205+ if (!seg_block.raw_outputs ().empty ())
206+ break ;
207+ }
208+ }
179209 }
180210 }
211+ // erase segments which still have no output
212+ segmented_blocks.erase (
213+ std::remove_if (
214+ segmented_blocks.begin (),
215+ segmented_blocks.end (),
216+ [](SegmentedBlock& seg_block) { return seg_block.raw_outputs ().empty (); }),
217+ segmented_blocks.end ());
181218
182219 return ;
183220}
184221
185- void eraseNonTensorInputsOutputs (
186- SegmentedBlock& seg_block,
187- std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
188- if (seg_block.target () == SegmentedBlock::kTorch )
189- return ;
190- auto mini_graph = seg_block.g ();
191-
192- for (int i = seg_block.raw_inputs ().size () - 1 ; i >= 0 ; --i) {
193- // erase this input and prepend a prim::Constant if it's not Tensor
194- if (!seg_block.raw_inputs ()[i]->type ()->isSubtypeOf (torch::jit::TensorType::get ()) &&
195- !seg_block.raw_inputs ()[i]->type ()->isSubtypeOf (c10::ListType::ofTensors ())) {
196- auto new_val = torch::jit::insertConstant (*mini_graph, ivalues_maps[seg_block.raw_inputs ()[i]]);
197- seg_block.inputs ()[i]->replaceAllUsesWith (new_val);
198- seg_block.eraseInput (i);
222+ std::vector<torch::jit::Node*> getDependencyNodes (std::vector<torch::jit::Value*>& vals) {
223+ // using bfs to get the DAG dependency nodes for input value
224+ std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q (
225+ std::deque<torch::jit::Value*>(vals.begin (), vals.end ()));
226+ std::unordered_set<torch::jit::Node*> visited;
227+ std::vector<torch::jit::Node*> stk;
228+ while (!q.empty ()) {
229+ auto cur_val = q.front ();
230+ q.pop ();
231+ auto node = cur_val->node ();
232+ if (node->kind () != torch::jit::prim::Constant && !visited.count (node)) {
233+ stk.push_back (node);
234+ for (auto input : node->inputs ()) {
235+ if (!isTensorOrTensorList (input)) {
236+ q.push (input);
237+ }
238+ }
199239 }
200240 }
241+ std::reverse (stk.begin (), stk.end ());
242+ return stk;
243+ }
201244
202- for (int i = seg_block.raw_outputs ().size () - 1 ; i >= 0 ; --i) {
203- if (!seg_block.raw_outputs ()[i]->type ()->isSubtypeOf (torch::jit::TensorType::get ()) &&
204- !seg_block.raw_outputs ()[i]->type ()->isSubtypeOf (c10::ListType::ofTensors ())) {
205- seg_block.eraseOutput (i);
245+ SegmentedBlock injectNodesForNonTensorInputs (SegmentedBlock& seg_block) {
246+ // reconstruct segmented_block if this block requires nonTensor input
247+ std::vector<torch::jit::Value*> nontensor_inputs;
248+ for (auto input : seg_block.raw_inputs ()) {
249+ if (!isTensorOrTensorList (input)) {
250+ nontensor_inputs.push_back (input);
206251 }
207252 }
253+ std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes (nontensor_inputs);
254+ new_block_nodes.insert (new_block_nodes.end (), seg_block.raw_nodes ().begin (), seg_block.raw_nodes ().end ());
255+ return SegmentedBlock (seg_block.target (), new_block_nodes);
256+ }
208257
209- // not sure to delete this block or just fallback to pytorch
210- if (seg_block.raw_outputs ().empty ()) {
211- seg_block.update_target (SegmentedBlock::kTorch );
258+ void resolveNonTensorInputs (std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
259+ // for NonTensor inputs in TensorRT segments, count the usages on Torch segments and TensorRT segments
260+ std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
261+ for (int i = segmented_blocks.size () - 1 ; i >= 0 ; --i) {
262+ for (auto input : segmented_blocks[i].raw_inputs ()) {
263+ if (!isTensorOrTensorList (input)) {
264+ segmented_blocks[i].target () == SegmentedBlock::kTorch ? usage_counts[input].torch_use_id .push_back (i)
265+ : usage_counts[input].tensorrt_use_id .push_back (i);
266+ }
267+ }
268+ for (auto & use : usage_counts) {
269+ if (segmented_blocks[i].contain_raw_value (use.first )) {
270+ use.second .produce_id = i;
271+ }
272+ }
212273 }
274+ std::unordered_set<int > updated_segments;
275+ for (auto & use : usage_counts) {
276+ auto use_info = use.second ;
277+ // if the segment that produce this nonTensor value is kTensorRT but consumed in kTorch, inject nodes in the first
278+ // kTorch segments
279+ if (segmented_blocks[use_info.produce_id ].target () == SegmentedBlock::kTensorRT && !use_info.torch_use_id .empty ()) {
280+ int first_torch_id = use_info.torch_use_id .front ();
281+ if (!updated_segments.count (first_torch_id)) {
282+ auto new_torch_block = injectNodesForNonTensorInputs (segmented_blocks[first_torch_id]);
283+ segmented_blocks[first_torch_id] = new_torch_block;
284+ updated_segments.insert (first_torch_id);
285+ }
286+ } else {
287+ // KTensorRT segments always need to inject nodes for the nonTensor inputs
288+ for (int i : use_info.tensorrt_use_id ) {
289+ if (!updated_segments.count (i)) {
290+ auto new_seg_block = injectNodesForNonTensorInputs (segmented_blocks[i]);
291+ segmented_blocks[i] = new_seg_block;
292+ updated_segments.insert (i);
293+ }
294+ }
295+ }
296+ }
297+ return ;
213298}
214299
215300void construct_segments (
@@ -231,20 +316,18 @@ void construct_segments(
231316 }
232317}
233318
234- std::vector<SegmentedBlock> segment_graph (
319+ void segment_graph (
235320 std::shared_ptr<torch::jit::Graph> g,
236- std::vector< conversion::InputRange>& input_ranges ,
237- const conversion::TorchFallback& fallback_info ) {
321+ const conversion::TorchFallback& fallback_info ,
322+ std::vector<SegmentedBlock>& segmented_blocks ) {
238323 auto min_block_size = fallback_info.min_block_size ;
239324 std::unordered_set<std::string> forced_fallback_operators (
240325 fallback_info.forced_fallback_operators .begin (), fallback_info.forced_fallback_operators .end ());
241- std::vector<SegmentedBlock> segmented_blocks;
242326
243327 auto nodes = g->block ()->nodes ();
244328
245329 // segment the nodes
246330 std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
247-
248331 for (const auto n : nodes) {
249332 if (n->kind () == torch::jit::prim::Constant)
250333 continue ;
@@ -261,22 +344,33 @@ std::vector<SegmentedBlock> segment_graph(
261344 if (!pytorch_nodes.empty ()) {
262345 segmented_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
263346 }
347+ }
348+
349+ std::vector<SegmentedBlock> Partition (
350+ std::shared_ptr<torch::jit::Graph> g,
351+ std::vector<conversion::InputRange>& input_ranges,
352+ const conversion::TorchFallback& fallback_info) {
353+ // segment lowering global graph into blocks
354+ std::vector<SegmentedBlock> segmented_blocks;
355+ segment_graph (g, fallback_info, segmented_blocks);
264356
265- // register input/output torch::jit::Value for segmetned graphs
266- registerSegmentsInputsOutputs (segmented_blocks, g);
357+ // resolve nonTensor inputs/outputs
358+ resolveNonTensorInputs (segmented_blocks, g);
359+
360+ // register input/output torch::jit::Value for segmented graphs
361+ registerSegmentsOutputs (segmented_blocks, g);
267362
268363 // store the mapping from lowering graph torch::jit::Value => torch::jit::IValue that we get by running segments
269364 std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
270-
271365 std::vector<torch::jit::IValue> random_inputs = generateRandomInputs (input_ranges);
272366 for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
273367 ivalues_maps[g->inputs ()[i]] = random_inputs[i];
274368 }
275369
276- // register every segment's input shape, and it's running output Ivalues
370+ // register every segment's input shape, and it's running output IValues
277371 for (auto & seg_block : segmented_blocks) {
372+ torch::jit::ConstantPooling (seg_block.g ());
278373 registerSegmentInOutIValues (seg_block, ivalues_maps);
279- eraseNonTensorInputsOutputs (seg_block, ivalues_maps);
280374 }
281375
282376 return segmented_blocks;
0 commit comments