1 #include <torch/csrc/jit/passes/refine_tuple_types.h> 2 #include <torch/csrc/jit/runtime/graph_iterator.h> 3 4 #include <ATen/core/type_factory.h> 5 6 #include <utility> 7 8 namespace torch::jit { 9 10 namespace { VisitTupleNode(Node * node)11static void VisitTupleNode(Node* node) { 12 TORCH_CHECK( 13 node->outputs().size() == 1, "Tuple must have exactly one output!"); 14 15 Value* output = node->outputs()[0]; 16 auto tuple_type = output->type()->expectRef<TupleType>(); 17 18 TORCH_CHECK( 19 tuple_type.containedTypes().size() == node->inputs().size(), 20 "Number of contained types does not match number of inputs!"); 21 22 // Extract updated types from input values. 23 std::vector<c10::TypePtr> types; 24 for (const Value* input : node->inputs()) { 25 types.push_back(input->type()); 26 } 27 28 // Construct new tuple type based on input types. 29 output->setType(tuple_type.withContained(std::move(types))); 30 } 31 } // anonymous namespace 32 RefineTupleTypes(std::shared_ptr<Graph> & graph)33void RefineTupleTypes(std::shared_ptr<Graph>& graph) { 34 DepthFirstGraphNodeIterator it(graph); 35 for (auto* node = it.next(); node != nullptr; node = it.next()) { 36 if (node->kind() == prim::TupleConstruct) { 37 VisitTupleNode(node); 38 } 39 } 40 } 41 42 } // namespace torch::jit 43