• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <torch/csrc/jit/passes/refine_tuple_types.h>
2 #include <torch/csrc/jit/runtime/graph_iterator.h>
3 
4 #include <ATen/core/type_factory.h>
5 
6 #include <utility>
7 
8 namespace torch::jit {
9 
10 namespace {
VisitTupleNode(Node * node)11 static void VisitTupleNode(Node* node) {
12   TORCH_CHECK(
13       node->outputs().size() == 1, "Tuple must have exactly one output!");
14 
15   Value* output = node->outputs()[0];
16   auto tuple_type = output->type()->expectRef<TupleType>();
17 
18   TORCH_CHECK(
19       tuple_type.containedTypes().size() == node->inputs().size(),
20       "Number of contained types does not match number of inputs!");
21 
22   // Extract updated types from input values.
23   std::vector<c10::TypePtr> types;
24   for (const Value* input : node->inputs()) {
25     types.push_back(input->type());
26   }
27 
28   // Construct new tuple type based on input types.
29   output->setType(tuple_type.withContained(std::move(types)));
30 }
31 } // anonymous namespace
32 
RefineTupleTypes(std::shared_ptr<Graph> & graph)33 void RefineTupleTypes(std::shared_ptr<Graph>& graph) {
34   DepthFirstGraphNodeIterator it(graph);
35   for (auto* node = it.next(); node != nullptr; node = it.next()) {
36     if (node->kind() == prim::TupleConstruct) {
37       VisitTupleNode(node);
38     }
39   }
40 }
41 
42 } // namespace torch::jit
43