1 #include <torch/csrc/jit/passes/shape_analysis.h>
2 
3 #include <c10/util/Exception.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/frontend/error_report.h>
6 #include <torch/csrc/jit/ir/alias_analysis.h>
7 #include <torch/csrc/jit/ir/constants.h>
8 #include <torch/csrc/jit/ir/ir.h>
9 #include <torch/csrc/jit/ir/ir_views.h>
10 #include <torch/csrc/jit/passes/utils/op_registry.h>
11 #include <torch/csrc/jit/runtime/exception_message.h>
12 #include <torch/csrc/jit/runtime/operator.h>
13 
14 #include <torch/csrc/autograd/variable.h>
15 
16 #include <ATen/DeviceGuard.h>
17 #include <ATen/ExpandUtils.h>
18 #include <ATen/core/symbol.h>
19 
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #else
23 #include <ATen/ops/empty_strided.h>
24 #endif
25 
26 #include <exception>
27 #include <memory>
28 #include <sstream>
29 #include <utility>
30 #include <vector>
31 
32 namespace torch::jit {
33 
mergeTypes(ArrayRef<Value * > lhs,ArrayRef<Value * > rhs,ArrayRef<Value * > outputs)34 bool mergeTypes(
35     ArrayRef<Value*> lhs,
36     ArrayRef<Value*> rhs,
37     ArrayRef<Value*> outputs) {
38   AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
39   bool changed = false;
40   for (const auto i : c10::irange(lhs.size())) {
41     auto old_output_type = outputs[i]->type();
42     auto new_type =
43         unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_union=*/true);
44     AT_ASSERT(new_type);
45     outputs[i]->setType(*new_type);
46     if (*old_output_type != *outputs[i]->type())
47       changed = true;
48   }
49   return changed;
50 }
51 
applyTypes(ArrayRef<Value * > src,ArrayRef<Value * > dst)52 static void applyTypes(ArrayRef<Value*> src, ArrayRef<Value*> dst) {
53   AT_ASSERT(src.size() == dst.size());
54   for (const auto i : c10::irange(src.size())) {
55     dst[i]->setType(src[i]->type());
56   }
57 }
58 
propagateBlock(Block * block,bool insert_expands)59 void PropertyPropBase::propagateBlock(Block* block, bool insert_expands) {
60   for (Node* node : block->nodes()) {
61     try {
62       propagateNode(node, insert_expands);
63     } catch (propagation_error& e) {
64       setUnshapedType(node);
65     } catch (std::exception& e) {
66       throw(
67           ErrorReport(node->sourceRange())
68           << ExceptionMessage(e)
69           << "\nThe above operation failed shape propagation in this context");
70     }
71   }
72 }
73 
processIf(Node * node)74 void PropertyPropBase::processIf(Node* node) {
75   auto then_block = node->blocks().at(0);
76   auto else_block = node->blocks().at(1);
77   propagateBlock(then_block);
78   propagateBlock(else_block);
79   mergeTypes(then_block->outputs(), else_block->outputs(), node->outputs());
80 }
81 
processLoop(Node * node)82 void PropertyPropBase::processLoop(Node* node) {
83   LoopView loop(node);
84   // propagate counter type
85   loop.currentTripCount()->setType(loop.maxTripCount()->type());
86   applyTypes(loop.carriedInputs(), loop.bodyCarriedInputs());
87 
88   do {
89     propagateBlock(loop.bodyBlock(), /*insert_expands=*/false);
90     // note: inserting expands is unsafe at this point, we don't know
91     // if the types are stable yet, so the arguments to expand may change
92   } while (mergeTypes(
93       loop.bodyCarriedInputs(),
94       loop.bodyCarriedOutputs(),
95       loop.bodyCarriedInputs()));
96 
97   // now that the types are stable, we can insert the expands
98   propagateBlock(loop.bodyBlock(), /*insert_expands=*/true);
99   applyTypes(loop.bodyCarriedInputs(), loop.carriedOutputs());
100 }
101 
setUnshapedType(Value * o)102 void PropertyPropBase::setUnshapedType(Value* o) {
103   o->setType(unshapedType(o->type()));
104 }
105 
setUnshapedType(Node * node)106 void PropertyPropBase::setUnshapedType(Node* node) {
107   for (auto o : node->outputs()) {
108     setUnshapedType(o);
109   }
110 }
111 
// Makes the c10::prim symbols available as torch::jit::prim::X, so the code
// below can write e.g. prim::If / prim::Loop.
namespace prim {
using namespace ::c10::prim;
}
115 
// Throws propagation_error when `cond` does not hold. propagateBlock catches
// propagation_error and resets the failing node's outputs to unshaped types.
// The do { ... } while (0) wrapper makes each use exactly one statement. In an
// unbraced `if (x) SHAPE_ASSERT(y); else ...`, the `else` therefore stays
// attached to the caller's `if` rather than the macro's internal one.
#define SHAPE_ASSERT(cond)       \
  do {                           \
    if (!(cond)) {               \
      throw propagation_error(); \
    }                            \
  } while (0)
119 
120 namespace {
121 
isValidArgumentForRunning(Value * v)122 bool isValidArgumentForRunning(Value* v) {
123   // allow constants
124   if (toIValue(v))
125     return true;
126   if (TensorTypePtr tt = v->type()->cast<TensorType>()) {
127     if (!tt->scalarType()) {
128       return false;
129     }
130     return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false);
131   }
132   return v->type()->isSubtypeOf(*FloatType::get());
133 }
134 
isValidReturnForRunning(Value * v)135 bool isValidReturnForRunning(Value* v) {
136   return v->type()->isSubtypeOf(*TensorType::get()) ||
137       v->type()->isSubtypeOf(*NumberType::get());
138 }
139 
containsTensorType(const TypePtr & t)140 bool containsTensorType(const TypePtr& t) {
141   auto n_contained = t->containedTypes().size();
142   if (n_contained == 1) {
143     return t->containedTypes().at(0)->isSubtypeOf(*TensorType::get());
144   } else if (n_contained > 1) {
145     return std::any_of(
146         t->containedTypes().begin(),
147         t->containedTypes().end(),
148         containsTensorType);
149   }
150   return false;
151 }
152 
153 // for each node in the schema with type Tensor, extract the T type
154 // returns std::nullopt if any Tensor in the schema does not have a known
155 // shape ignores non-tensor in the list of inputs
gatherTensorTypes(Node * node,bool complete=false)156 std::optional<std::vector<TensorTypePtr>> gatherTensorTypes(
157     Node* node,
158     bool complete = false) {
159   std::vector<TensorTypePtr> tensor_types;
160 
161   auto schema_opt = node->maybeSchema();
162   if (!schema_opt) {
163     return std::nullopt;
164   }
165   auto& schema = *schema_opt;
166   auto& args = schema.arguments();
167   // can't handle varargs primitives because we don't know what should be a
168   // Tensor
169   if (schema.is_vararg()) {
170     return std::nullopt;
171   }
172   for (const auto i : c10::irange(args.size())) {
173     if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) {
174       return std::nullopt;
175     } else if (args[i].type()->isSubtypeOf(*TensorType::get())) {
176       if (auto type = node->input(i)->type()->cast<TensorType>()) {
177         if (complete && !type->isComplete()) {
178           return std::nullopt;
179         }
180         tensor_types.push_back(type);
181       } else {
182         return std::nullopt;
183       }
184     } else /* non-tensor type */ {
185       continue;
186     }
187   }
188   return tensor_types;
189 }
190 
wrapDim(int64_t dim,at::IntArrayRef sizes)191 int64_t wrapDim(int64_t dim, at::IntArrayRef sizes) {
192   if (dim < 0) {
193     dim += (int64_t)sizes.size();
194   }
195   return dim;
196 }
197 
unionScalarTypes(c10::ScalarType original,c10::ScalarType next)198 c10::ScalarType unionScalarTypes(
199     c10::ScalarType original,
200     c10::ScalarType next) {
201   if (original == c10::ScalarType::Undefined) {
202     return next;
203   } else {
204     return c10::promoteTypes(original, next);
205   }
206 }
207 
208 // Promotes result types for arithmetic operations on Tensor operands using
209 // new type promotion logic. See tensor_attributes.rst for details.
210 // This doesn't handle the case of arithmetic ops with Scalar arguments (when
211 // `Tensor.getUnsafeTensorImpl()->is_wrapped_number()` would return true)
getPromotedTypeForArithmeticOp(Node * node)212 std::optional<c10::ScalarType> getPromotedTypeForArithmeticOp(Node* node) {
213   c10::ScalarType dimmed = c10::ScalarType::Undefined;
214   c10::ScalarType zerodim = c10::ScalarType::Undefined;
215   // binary arithmetic ops, more than 2 args is alpha.
216   for (const auto i : c10::irange(2)) {
217     auto dtt = node->inputs()[i]->type()->expect<TensorType>();
218     auto inputDtype = dtt->scalarType();
219     if (!dtt || !inputDtype) {
220       return std::nullopt;
221     }
222     if (dtt->dim() && *dtt->dim() > 0) {
223       dimmed = unionScalarTypes(dimmed, *inputDtype);
224     } else if (!isFloatingType(dimmed)) {
225       // if no dimensions
226       zerodim = unionScalarTypes(zerodim, *inputDtype);
227     }
228   }
229   // if a tensor with dimensions is already of the highest category, don't
230   // need to check zero-dim tensors.
231   if (isFloatingType(dimmed)) {
232     return dimmed;
233   }
234   // int_tensor * zero_dim_floating -> floating_tensor
235   if (isIntegralType(dimmed, false) && isFloatingType(zerodim)) {
236     return zerodim;
237   }
238   // bool_tensor * non_bool_scalar -> non_bool_tensor
239   if (c10::ScalarType::Bool == dimmed &&
240       c10::ScalarType::Undefined != zerodim) {
241     return zerodim;
242   }
243   // types of dimensioned tensors generally take precedence over zero-dim
244   // tensors if not promoting due to category. e.g.:
245   // int_tensor * long -> int_tensor
246   if (c10::ScalarType::Undefined != dimmed) {
247     return dimmed;
248   }
249 
250   // no dimmed tensors. e.g. zero_dim_tensor + zero_dim_tensor.
251   return zerodim;
252 }
253 
254 class ShapePropagator : public PropertyPropBase {
255  public:
  // Builds a propagator for `graph`. It builds an alias database for the
  // graph. It also records up front every value that a resizing op may write
  // to (see collectResizeSet). propagateNode later gives unshaped types to
  // anything aliasing those values.
  explicit ShapePropagator(const std::shared_ptr<Graph>& graph)
      : PropertyPropBase(graph), aliasDb_(graph) {
    collectResizeSet(graph->block());
  }
260 
 private:
  // Inputs that a resizing op (see resizesInput) may write to. Populated once
  // by collectResizeSet in the constructor.
  ValueSet resized_alias_set;
  // Alias information for the graph being propagated. It is declared after
  // resized_alias_set but initialized in the constructor's member-init list,
  // before collectResizeSet runs.
  const AliasDb aliasDb_;
264 
resizesInput(Node * n)265   bool resizesInput(Node* n) {
266     static std::unordered_set<Symbol> resize_ops{
267         aten::resize_,
268         aten::resize_as_,
269         aten::copy_,
270         aten::set_,
271         aten::unsqueeze_,
272         aten::t_,
273         aten::transpose_,
274     };
275 
276     if (resize_ops.count(n->kind()))
277       return true;
278 
279     if (!n->maybeSchema())
280       return false;
281 
282     // ops which take the result and write to input "out"
283     if (auto out_arg_index = n->schema().argumentIndexWithName("out")) {
284       auto arg = n->schema().arguments().at(*out_arg_index);
285       return arg.kwarg_only() && arg.type()->isSubtypeOf(*TensorType::get());
286     }
287     return false;
288   }
289 
collectResizeSet(Block * block)290   void collectResizeSet(Block* block) {
291     for (Node* n : block->nodes()) {
292       for (Block* b : n->blocks()) {
293         collectResizeSet(b);
294       }
295       if (resizesInput(n)) {
296         for (const auto input : n->inputs()) {
297           if (aliasDb_.writesToAlias(n, {input})) {
298             resized_alias_set.insert(input);
299           }
300         }
301       }
302     }
303   }
304 
representativeValue(Value * v)305   IValue representativeValue(Value* v) {
306     TypePtr type_ = v->type();
307     // if the value is actually constant, just use it!
308     if (auto iv = toIValue(v)) {
309       return *iv;
310     }
311     if (TensorTypePtr type = type_->cast<TensorType>()) {
312       if (type->isComplete()) {
313         at::DeviceGuard device_guard(*type->device());
314         return at::empty_strided(
315                    *type->sizes().concrete_sizes(),
316                    *type->strides().concrete_sizes(),
317                    at::TensorOptions(*type->device()).dtype(type->scalarType()))
318             .zero_();
319       }
320       // fallthrough
321     } else if (type_->isSubtypeOf(*FloatType::get())) {
322       return 0.f;
323     }
324     // we should not get here because isValidArgumentForRunning should have
325     // prevented it
326     std::stringstream ss;
327     ss << "unable to create representative value for: " << type_->str()
328        << ". File a bug report";
329     throw std::runtime_error(ss.str());
330   }
331 
broadcastBinary(Node * node,std::vector<TensorTypePtr> & types,size_t idx1,size_t idx2)332   void broadcastBinary(
333       Node* node,
334       std::vector<TensorTypePtr>& types,
335       size_t idx1,
336       size_t idx2) {
337     auto expected_size = at::infer_size(
338         *types[idx1]->sizes().concrete_sizes(),
339         *types[idx2]->sizes().concrete_sizes());
340     auto broadcast = [&](size_t input_idx) {
341       TensorTypePtr input_type = types.at(input_idx);
342       if (input_type->sizes() == expected_size)
343         return;
344       auto graph = node->owningGraph();
345       WithInsertPoint point_guard{node};
346       Node* expand = graph
347                          ->create(
348                              aten::expand,
349                              {node->inputs().at(input_idx),
350                               graph->insertConstant(expected_size),
351                               graph->insertConstant(false)})
352                          ->insertBefore(node);
353       propagateNode(expand);
354       node->replaceInput(input_idx, expand->output());
355     };
356     broadcast(idx1);
357     broadcast(idx2);
358     types[0] = node->inputs().at(idx1)->type()->expect<TensorType>();
359     types[1] = node->inputs().at(idx2)->type()->expect<TensorType>();
360   }
361 
  // Ops that must never have their output types discovered by running them
  // on representative inputs (checked in canPropagateShapeByRunningIt).
  // NOTE(review): presumably inverse is excluded because representativeValue
  // builds zero-filled tensors, which are singular. Confirm.
  OperatorSet cannot_propagate_shape_by_running_it = {
      "aten::inverse(Tensor self) -> Tensor",
  };
365 
366   // Check if this node depends on a value that has been mutated previously. If
367   // it has, then it's not safe to run this node in isolation, since we don't
368   // know whether the dependency has been executed.
369   std::unordered_map<Node*, bool> dependsOnMutationMemo_;
dependsOnMutation(Node * node)370   bool dependsOnMutation(Node* node) {
371     if (dependsOnMutationMemo_.count(node) != 0) {
372       return dependsOnMutationMemo_[node];
373     }
374 
375     if (aliasDb_.hasWriters(node)) {
376       // If something could have written to a value used by this node, we can't
377       // guarantee the result is the same when running it in isolation.
378       dependsOnMutationMemo_[node] = true;
379       return true;
380     }
381 
382     // recursively check the producers of its inputs. We need to do this if the
383     // mutable value has been laundered through a pure function:
384     //   a += 1
385     //   c = a + b
386     //   d = c + 1
387     // In this case, `d` cares whether `a` has been mutated even though it's not
388     // a direct input.
389     auto depends = false;
390     for (auto input : node->inputs()) {
391       depends |= dependsOnMutation(input->node());
392     }
393 
394     dependsOnMutationMemo_[node] = depends;
395     return depends;
396   }
397 
canPropagateShapeByRunningIt(Node * node)398   bool canPropagateShapeByRunningIt(Node* node) {
399     if (node->isMemberOf(cannot_propagate_shape_by_running_it)) {
400       return false;
401     }
402 
403     if (dependsOnMutation(node)) {
404       return false;
405     }
406 
407     bool valid_args = std::all_of(
408         node->inputs().begin(),
409         node->inputs().end(),
410         isValidArgumentForRunning);
411     if (!valid_args)
412       return false;
413 
414     bool valid_returns = std::all_of(
415         node->outputs().begin(),
416         node->outputs().end(),
417         isValidReturnForRunning);
418     if (!valid_returns)
419       return false;
420 
421     return true;
422   }
423 
424   // If there's no Tensor in outputs, e.g float / float,
425   // we don't need to propagate shape.
DoesntRefineOutputs(Node * node)426   bool DoesntRefineOutputs(Node* node) {
427     auto outputs = node->outputs();
428     for (auto& out : outputs) {
429       if (containsTensorType(out->type())) {
430         return false;
431       }
432     }
433     return true;
434   }
435 
PropagateShapeOnNodeByRunningIt(Node * node,Operation op=nullptr)436   bool PropagateShapeOnNodeByRunningIt(Node* node, Operation op = nullptr) {
437     if (!canPropagateShapeByRunningIt(node))
438       return false;
439 
440     if (!op)
441       op = node->getOperation();
442 
443     Stack stack;
444 
445     for (auto input : node->inputs()) {
446       stack.push_back(representativeValue(input));
447     }
448 
449     // XXX: we're not catching any exceptions from the op for now. This
450     // is to uncover any mistakes we could make when editing this code,
451     // and eventually it shouldn't matter, because this phase should be
452     // preceded by schema checking.
453     op(stack);
454 
455     AT_ASSERT(stack.size() == node->outputs().size());
456     for (const auto i : c10::irange(stack.size())) {
457       // some ops may have mixed tensor/primitive outputs
458       // for primitives, we don't need to change the type because it is already
459       // its most constrained form.
460       auto tensor_type = node->outputs()[i]->type()->cast<TensorType>();
461       if (stack[i].isTensor() && tensor_type) {
462         // gradient information isn't always available or part of representative
463         // inputs, maintain original grad property
464         auto tensor_grad = tensor_type->requiresGrad();
465         node->outputs()[i]->setType(TensorType::create(stack[i].toTensor())
466                                         ->withRequiresGrad(tensor_grad));
467       }
468     }
469     return true;
470   }
471 
PropagateCatShape(Node * cat_node)472   void PropagateCatShape(Node* cat_node) {
473     static const auto propagate_complete =
474         [](Node* node, at::ArrayRef<Value*> tensors) -> bool {
475       auto input_types =
476           fmap(tensors, [](Value* v) { return v->type()->cast<TensorType>(); });
477       if (!std::all_of(
478               input_types.begin(),
479               input_types.end(),
480               [](const TensorTypePtr& tp) {
481                 return tp != nullptr && tp->isComplete();
482               })) {
483         return false;
484       }
485       if (!node->is_constant(attr::dim))
486         return false;
487       std::vector<int64_t> sizes = *input_types[0]->sizes().concrete_sizes();
488       const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
489       const int64_t ndim = (int64_t)sizes.size();
490 
491       if (dim < 0 || dim >= ndim)
492         return false;
493 
494       sizes[dim] = 0;
495       for (auto& tp : input_types) {
496         auto tp_sizes = tp->sizes().concrete_sizes().value();
497         if (sizes.size() != tp_sizes.size())
498           return false;
499         for (const auto i : c10::irange(ndim)) {
500           if (sizes[i] != tp_sizes[i] && i != dim) {
501             return false;
502           }
503         }
504         sizes[dim] += tp_sizes[dim];
505       }
506       node->output()->setType(input_types[0]->withSizes(sizes));
507       return true;
508     };
509     static const auto propagate = [](Node* node,
510                                      at::ArrayRef<Value*> tensors) -> bool {
511       for (Value* v : tensors) {
512         if (auto type = v->type()->cast<TensorType>()) {
513           node->output()->setType(type->dimensionedOnly());
514           return true;
515         }
516       }
517       return false;
518     };
519     auto list_node =
520         ((cat_node->kind() == prim::FusedConcat)
521              ? cat_node
522              : cat_node->namedInput(attr::tensors)->node());
523     if (list_node->kind() == prim::ListConstruct ||
524         cat_node->kind() == prim::FusedConcat) {
525       auto tensors = list_node->inputs();
526       if (!tensors.empty()) {
527         // NOLINTNEXTLINE(bugprone-branch-clone)
528         if (propagate_complete(cat_node, tensors)) {
529           return;
530         } else if (propagate(cat_node, tensors)) {
531           return;
532         }
533       }
534     }
535     setUnshapedType(cat_node);
536   }
537 
propagateTorchTensorShape(Node * node)538   void propagateTorchTensorShape(Node* node) {
539     auto input_type = node->inputs().at(0)->type();
540 
541     size_t dims = 0;
542     auto input_base_type = input_type;
543     auto list_type = input_type->cast<ListType>();
544     while (list_type) {
545       dims++;
546       input_base_type = list_type->getElementType();
547       list_type = input_base_type->cast<ListType>();
548     }
549 
550     std::optional<at::ScalarType> default_type =
551         tryScalarTypeFromJitType(*input_base_type);
552     if (auto grad_index = node->schema().argumentIndexWithName("dtype")) {
553       auto inp = toIValue(node->inputs().at(*grad_index));
554       if (inp == std::nullopt) {
555         return;
556       } else if (!inp->isNone()) {
557         default_type = inp->toScalarType();
558       }
559     }
560 
561     at::Device default_device = at::kCPU;
562     if (auto device_index = node->schema().argumentIndexWithName("device")) {
563       auto inp = toIValue(node->inputs().at(*device_index));
564       if (inp == std::nullopt) {
565         return;
566       } else if (!inp->isNone()) {
567         default_device = inp->toDevice();
568       }
569     }
570     node->output()->setType(TensorType::create(
571         default_type, default_device, dims, /*requires_grad=*/std::nullopt));
572   }
573 
574   // returns whether any such values were found
setUnshapedTypeIfAliasResizedSet(at::ArrayRef<Value * > vs)575   bool setUnshapedTypeIfAliasResizedSet(at::ArrayRef<Value*> vs) {
576     bool in_resize = false;
577     for (auto v : vs) {
578       if (aliasDb_.mayAlias(ValueSet{v}, resized_alias_set)) {
579         setUnshapedType(v);
580         in_resize = true;
581       }
582     }
583     return in_resize;
584   }
585 
  // Propagates type/shape information through a single node. The steps are
  // tried in order, and the first one that handles the node returns:
  //   1. an input may alias a resized tensor -> all outputs become unshaped;
  //   2. node kinds special-cased in the switch below (control flow, number
  //      conversions, tuples, constants, optionals, ...). aten::tensor and
  //      aten::as_tensor with a Tensor first input break out and continue;
  //   3. nodes with side effects are left unchanged;
  //   4. aten::cat / prim::FusedConcat -> PropagateCatShape;
  //   5. every tensor input type is complete -> PropagateCompleteShapeOnNode;
  //   6. PropagateTensorShapeOnNode;
  //   7. DoesntRefineOutputs(node) -> outputs left unchanged;
  //   8. run the op on representative inputs;
  //   9. otherwise all outputs become unshaped.
  // `insert_expands` is forwarded to steps 5 and 6.
  void propagateNode(Node* node, bool insert_expands = true) override {
    // Certain ops like resize_ change the size of their input tensors. Because
    // our analysis is flow invariant, we set any Tensor that can alias a
    // resized Tensor to the base Tensor Type without size information.
    if (setUnshapedTypeIfAliasResizedSet(node->inputs())) {
      return setUnshapedType(node);
    }

    // These don't require the types, and have complicated schema. Return early
    // after we process them.
    switch (node->kind()) {
      case prim::If:
        return processIf(node);
      case prim::Loop: {
        return processLoop(node);
      }
      case aten::Bool:
      case aten::Int:
      case aten::Float:
      case aten::ScalarImplicit:
      case aten::FloatImplicit:
      case aten::IntImplicit:
        return; // correct num type is already set
      case prim::NumToTensor: {
        // int/bool -> 0-dim kLong CPU tensor; float -> 0-dim kDouble CPU
        // tensor; any other input type leaves the output type unchanged.
        // NOTE(review): bool inputs are typed as kLong here. Confirm this
        // matches the dtype prim::NumToTensor produces for bools at runtime.
        TypePtr typ = node->input()->type();
        if (typ->isSubtypeOf(*IntType::get()) ||
            typ->isSubtypeOf(*BoolType::get())) {
          node->output()->setType(TensorType::create(
              at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt));
        } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) {
          node->output()->setType(TensorType::create(
              at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt));
        }
        return;
      }
      case aten::tensor:
      case aten::as_tensor: {
        // as_tensor has an overloaded schema and can either have a tensor or
        // a list as the first input. If the input is a tensor, we delegate
        // the shape propagation to PropagateTensorShapeOnNode.
        if (node->inputs().at(0)->type()->isSubtypeOf(*TensorType::get())) {
          break;
        }
        return propagateTorchTensorShape(node);
      }
      case prim::TupleConstruct: {
        // We refresh the tuple type, because the input types could have been
        // refined.
        auto orig_type = node->output()->type()->expect<TupleType>();
        auto new_types =
            fmap(node->inputs(), [](Value* v) { return v->type(); });
        node->output()->setType(
            orig_type->createWithContained(std::move(new_types)));
        return;
      }
      case prim::TupleUnpack: {
        // Each output takes the corresponding element type of the input tuple.
        auto tuple_type = node->input()->type()->cast<TupleType>();
        AT_ASSERT(
            tuple_type &&
            tuple_type->elements().size() == node->outputs().size());
        auto elems = tuple_type->elements();
        for (size_t i = 0; i < node->outputs().size(); ++i) {
          node->output(i)->setType(elems[i]);
        }
        return;
      }
      case prim::Constant: {
        // Tensor constants get their type from the stored tensor value.
        if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
          node->output()->inferTypeFrom(node->t(attr::value));
        }
        return;
      }
      case prim::unchecked_unwrap_optional: {
        // If we have specialized the optional type to the element type,
        // we want to pass it down. We write this as input.isSubtypeOf(output)
        // to be sure that we don't screw up nested optionals.
        if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
          node->output()->setType(node->input()->type());
        }
        return;
      }
      case prim::ConstantChunk: {
        // Every chunk output gets dimensionedOnly() of the input's tensor
        // type. A non-tensor input makes all outputs unshaped.
        Value* tensor = node->input();
        if (auto type = tensor->type()->cast<TensorType>()) {
          type = type->dimensionedOnly();
          for (Value* output : node->outputs()) {
            output->setType(type);
          }
        } else {
          setUnshapedType(node);
        }
        return;
      }
      case prim::grad: {
        // NOTE(review): `tt` is unused. The only effect of the expect<> call is
        // that it fails when the input is not typed as a Tensor.
        auto tt = node->input()->type()->expect<TensorType>();
        // grad may be undefined
        // requires_grad may be required
        auto grad_type = TensorType::get()->withPossiblyUndefined();
        node->output()->setType(std::move(grad_type));
        return;
      }
      case prim::CallFunction:
      case prim::CallMethod:
      case prim::AutogradZero: {
        setUnshapedType(node);
        return;
      }
      case prim::GetAttr: {
        auto cls = node->input()->type()->expect<ClassType>();
        // propagate any type specializations encoded in the type of the class
        node->output()->setType(cls->getAttribute(node->s(attr::name)));
        return;
      }
      case aten::_unwrap_optional: {
        // If we have specialized the optional type to the element type,
        // we want to pass it down. We write this as input.isSubtypeOf(output)
        // to be sure that we don't screw up nested optionals.
        if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
          node->output()->setType(node->input()->type());
        }
        return;
      }
      default:
        break; // fall-through
    }

    // Nodes with side effects keep their current output types.
    if (node->hasSideEffects()) {
      return;
    }

    if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
        node->kind() == prim::FusedConcat) {
      return PropagateCatShape(node);
    }

    // Use the formulas that need complete tensor types, when every tensor
    // input has one.
    if (auto maybe_complete_types =
            gatherTensorTypes(node, /*complete=*/true)) {
      if (PropagateCompleteShapeOnNode(
              node, insert_expands, std::move(*maybe_complete_types))) {
        return;
      }
    }

    if (PropagateTensorShapeOnNode(node, insert_expands)) {
      return;
    }

    if (DoesntRefineOutputs(node)) {
      return;
    }

    if (PropagateShapeOnNodeByRunningIt(node)) {
      return;
    }
    return setUnshapedType(node);
  }
742 
determineListSize(Value * list)743   static std::optional<size_t> determineListSize(Value* list) {
744     AT_ASSERT(list->type()->cast<ListType>());
745     if (auto shape = constant_as<c10::List<int64_t>>(list)) {
746       return shape->size();
747     }
748     auto input_node = list->node();
749     if (input_node->kind() == prim::ListConstruct) {
750       return input_node->inputs().size();
751     }
752     return std::nullopt;
753   }
754 
  // When is it OK to try running the op (see canPropagateShapeByRunningIt)?
  // If an input is a constant, then we assume that the input is valid
  // and we can try to run it.
  // Otherwise:
  // Integral-typed _inputs_ are often an indicator that we're indexing into
  // a tensor, so we should special-case these ops in the shape propagation.
  // Additionally, passing a zero-filled representative tensor into an integer
  // division op causes divide-by-zero errors.
  // _Outputs_ must be tensors or primitives (numbers). Tensor outputs take the
  // type of the tensor the op produced; primitive outputs are left unchanged.
  // We still allow primitive returns because we want to support mixed
  // primitive/tensor outputs.
767 
PropagateTensorShapeOnNode(Node * node,bool insert_expands)768   bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
769     static const auto broadcast =
770         [](std::vector<TensorTypePtr>& tensor_types,
771            std::optional<at::ScalarType> t) -> TensorTypePtr {
772       if (tensor_types.size() == 1) {
773         return tensor_types[0]->dimensionedOnly()->withScalarType(t);
774       }
775       AT_ASSERT(!tensor_types.empty());
776       auto any_type = tensor_types[0];
777       auto max_dims = any_type->dim();
778       for (auto& type : tensor_types) {
779         if (!max_dims || !type->dim()) {
780           max_dims = std::nullopt;
781         } else {
782           max_dims = std::max(*max_dims, *type->dim());
783         }
784       }
785       return TensorType::create(
786           t,
787           any_type->device(),
788           max_dims,
789           /*requires_grad=*/std::nullopt);
790     };
791 
792     using type_vec_t = std::vector<TensorTypePtr>;
793     // Formula is expected to return a vector of length equal to the number of
794     // tensor outputs of the node, or an empty vector which implies that it
795     // failed to propagate.
796     using formula_t = std::function<type_vec_t(Node*)>;
797     static std::mutex shape_formulas_mutex;
798     static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
799     struct register_formula_for {
800       register_formula_for(OperatorSet operators, formula_t formula) {
801         std::unique_lock<std::mutex> lock{shape_formulas_mutex};
802         shape_formulas.emplace_back(std::move(operators), std::move(formula));
803       }
804     };
805 
806     // Requirements:
807     //   dims           : preserved
808     //   scalar type    : preserved
809     //   device         : preserved
810     //   tensor inputs  : 1
811     //   tensor outputs : 1
812     // Additionally:
813     //   - First input should be the only tensor input
814     static const register_formula_for simple_unary_ops{
815         {
816             "aten::acos(Tensor self) -> Tensor",
817             "aten::neg(Tensor self) -> Tensor",
818             "aten::t(Tensor self) -> Tensor",
819             "aten::sigmoid(Tensor self) -> Tensor",
820             "aten::logit(Tensor self, float? eps=None) -> Tensor",
821             "aten::tanh(Tensor self) -> Tensor",
822             "aten::relu(Tensor self) -> Tensor",
823             "aten::asin(Tensor self) -> Tensor",
824             "aten::atan(Tensor self) -> Tensor",
825             "aten::ceil(Tensor self) -> Tensor",
826             "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
827             "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
828             "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
829             "aten::celu(Tensor self, Scalar alpha) -> Tensor",
830             "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
831             "aten::clamp_max(Tensor self, Scalar max) -> Tensor",
832             "aten::clamp_min(Tensor self, Scalar min) -> Tensor",
833             "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
834             "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
835             "aten::cos(Tensor self) -> Tensor",
836             "aten::cosh(Tensor self) -> Tensor",
837             "aten::digamma(Tensor self) -> Tensor",
838             "aten::dropout(Tensor input, float p, bool train) -> Tensor",
839             "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
840             "aten::erf(Tensor self) -> Tensor",
841             "aten::erfc(Tensor self) -> Tensor",
842             "aten::erfinv(Tensor self) -> Tensor",
843             "aten::exp(Tensor self) -> Tensor",
844             "aten::expm1(Tensor self) -> Tensor",
845             "aten::log(Tensor self) -> Tensor",
846             "aten::log10(Tensor self) -> Tensor",
847             "aten::log1p(Tensor self) -> Tensor",
848             "aten::log2(Tensor self) -> Tensor",
849             "aten::log_sigmoid(Tensor self) -> Tensor",
850             "aten::floor(Tensor self) -> Tensor",
851             "aten::frac(Tensor self) -> Tensor",
852             "aten::flip(Tensor self, int[] dims) -> Tensor",
853             "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
854             "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
855             "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
856             "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
857             "aten::glu(Tensor self, int dim) -> Tensor",
858             "aten::inverse(Tensor self) -> Tensor",
859             "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
860             "aten::lgamma(Tensor self) -> Tensor",
861             "aten::mvlgamma(Tensor self, int p) -> Tensor",
862             "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
863             "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
864             "aten::permute(Tensor self, int[] dims) -> Tensor",
865             "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)",
866             "aten::pinverse(Tensor self, float rcond) -> Tensor",
867             "aten::reciprocal(Tensor self) -> Tensor",
868             "aten::relu(Tensor self) -> Tensor",
869             "aten::round(Tensor self) -> Tensor",
870             "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
871             "aten::rsqrt(Tensor self) -> Tensor",
872             "aten::selu(Tensor self) -> Tensor",
873             "aten::gelu(Tensor self, *, str approximate='none') -> Tensor",
874             "aten::sigmoid(Tensor self) -> Tensor",
875             "aten::sign(Tensor self) -> Tensor",
876             "aten::sin(Tensor self) -> Tensor",
877             "aten::sinh(Tensor self) -> Tensor",
878             "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
879             "aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
880             "aten::sqrt(Tensor self) -> Tensor",
881             "aten::tan(Tensor self) -> Tensor",
882             "aten::tanh(Tensor self) -> Tensor",
883             "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
884             "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
885             "aten::tril(Tensor self, int diagonal) -> Tensor",
886             "aten::triu(Tensor self, int diagonal) -> Tensor",
887             "aten::trunc(Tensor self) -> Tensor",
888             "aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
889             "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
890             "aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor",
891             "aten::alias(Tensor self) -> Tensor",
892         },
893         [](Node* node) -> type_vec_t {
894           auto input_type = node->input(0)->type()->cast<TensorType>();
895           return input_type ? type_vec_t{input_type->dimensionedOnly()}
896                             : type_vec_t{};
897         }};
898 
899     // Requirements:
900     //   dims           : preserved
901     //   scalar type    : preserved, except complex maps to float
902     //   device         : preserved
903     //   tensor inputs  : 1
904     //   tensor outputs : 1
905     // Additionally:
906     //   - First input should be the only tensor input
907     static const register_formula_for simple_unary_ops_complex_to_float{
908         {
909             "aten::abs(Tensor self) -> Tensor",
910         },
911         [](Node* node) -> type_vec_t {
912           auto input_type = node->input(0)->type()->cast<TensorType>();
913 
914           // Maps complex -> float
915           if (input_type->scalarType()) {
916             const auto scalar_type = *(input_type->scalarType());
917             if (isComplexType(scalar_type)) {
918               const auto out_type = c10::toRealValueType(scalar_type);
919               return type_vec_t{
920                   input_type->dimensionedOnly()->withScalarType(out_type)};
921             }
922           }
923 
924           return input_type ? type_vec_t{input_type->dimensionedOnly()}
925                             : type_vec_t{};
926         }};
927 
928     // Requirements:
929     //   dims           : broadcast all tensor args
930     //   scalar type    : promoted from input dtypes
931     //   device         : always matching and preserved
932     //   tensor inputs  : *
933     //   tensor outputs : 1
934     static const register_formula_for broadcasting_ops_arithmetic{
935         {
936             // Tensor-Tensor operators
937             "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
938             "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
939             "aten::mul(Tensor self, Tensor other) -> Tensor",
940             "aten::div(Tensor self, Tensor other) -> Tensor",
941         },
942         [](Node* node) -> type_vec_t {
943           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
944             AT_ASSERT(maybe_tensor_types->size() >= 2);
945             auto dtype = getPromotedTypeForArithmeticOp(node);
946             return {broadcast(*maybe_tensor_types, dtype)};
947           }
948           return {};
949         }};
950 
951     // Requirements:
952     //   dims           : broadcast all tensor args
953     //   scalar type    : always matching and preserved
954     //   device         : always matching and preserved
955     //   tensor inputs  : *
956     //   tensor outputs : 1
957     static const register_formula_for broadcasting_ops{
958         {
959             "aten::pow(Tensor self, Tensor exponent) -> Tensor",
960             "aten::fmod(Tensor self, Tensor other) -> Tensor",
961             "aten::remainder(Tensor self, Tensor other) -> Tensor",
962             "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
963             "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
964             "aten::max(Tensor self, Tensor other) -> Tensor",
965             "aten::min(Tensor self, Tensor other) -> Tensor",
966             "aten::__and__(Tensor self, Tensor other) -> Tensor",
967             "aten::__or__(Tensor self, Tensor other) -> Tensor",
968             "aten::__xor__(Tensor self, Tensor other) -> Tensor",
969             "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
970             "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
971             "aten::__iand__(Tensor self, Tensor other) -> Tensor",
972             "aten::__ior__(Tensor self, Tensor other) -> Tensor",
973             "aten::__ixor__(Tensor self, Tensor other) -> Tensor",
974             "aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
975             "aten::__irshift__(Tensor self, Tensor other) -> Tensor",
976 
977             // Ops with Tensor-Tensor overloads only
978             "aten::atan2(Tensor self, Tensor other) -> Tensor",
979         },
980         [](Node* node) -> type_vec_t {
981           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
982             AT_ASSERT(maybe_tensor_types->size() >= 2);
983             auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
984             auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
985             if (!first_scalar_type || !second_scalar_type) {
986               return {};
987             }
988             size_t arg_for_type = 0;
989             if (c10::promoteTypes(*first_scalar_type, *second_scalar_type) !=
990                 first_scalar_type) {
991               arg_for_type = 1;
992             }
993             auto t = (*maybe_tensor_types)[arg_for_type]->scalarType();
994             return {broadcast(*maybe_tensor_types, t)};
995           }
996           return {};
997         }};
998 
999     static const register_formula_for fused_accum_binary_ops{
1000         {
1001             // Non-binary ops
1002             "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
1003             "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
1004         },
1005         [](Node* node) -> type_vec_t {
1006           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1007             auto dtype = (*maybe_tensor_types)[0]->scalarType();
1008             if (!dtype) {
1009               return {};
1010             }
1011             return {broadcast(*maybe_tensor_types, dtype)};
1012           }
1013           return {};
1014         }};
1015 
1016     static const register_formula_for broadcasting_tensor_scalar_ops_arithmetic{
1017         {
1018             // Tensor-Scalar operators
1019             "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1020             "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1021             "aten::mul(Tensor self, Scalar other) -> Tensor",
1022             "aten::div(Tensor self, Scalar other) -> Tensor",
1023         },
1024         [](Node* node) -> type_vec_t {
1025           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1026             auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
1027             auto second_scalar_type =
1028                 tryScalarTypeFromJitType(*node->inputs()[1]->type());
1029             if (!first_scalar_type || !second_scalar_type) {
1030               return {};
1031             }
1032             if (isIntegralType(*first_scalar_type, false) &&
1033                 isFloatingType(*second_scalar_type)) {
1034               auto default_dtype =
1035                   at::typeMetaToScalarType(caffe2::get_default_dtype());
1036               return {broadcast(*maybe_tensor_types, default_dtype)};
1037             }
1038             if (c10::ScalarType::Bool == *first_scalar_type &&
1039                 c10::ScalarType::Bool != *second_scalar_type) {
1040               auto result_type =
1041                   c10::promoteTypes(*first_scalar_type, *second_scalar_type);
1042               return {broadcast(*maybe_tensor_types, result_type)};
1043             }
1044             return {broadcast(*maybe_tensor_types, first_scalar_type)};
1045           }
1046           return {};
1047         }};
1048 
1049     // NB: we always take the scalar type of the Tensor
1050     static const register_formula_for broadcasting_tensor_scalar_ops{
1051         {
1052 
1053             "aten::pow(Tensor self, Scalar exponent) -> Tensor",
1054             "aten::fmod(Tensor self, Scalar other) -> Tensor",
1055             "aten::remainder(Tensor self, Scalar other) -> Tensor",
1056             "aten::pow(Scalar self, Tensor exponent) -> Tensor",
1057             "aten::__and__(Tensor self, Scalar other) -> Tensor",
1058             "aten::__or__(Tensor self, Scalar other) -> Tensor",
1059             "aten::__xor__(Tensor self, Scalar other) -> Tensor",
1060             "aten::__lshift__(Tensor self, Scalar other) -> Tensor",
1061             "aten::__rshift__(Tensor self, Scalar other) -> Tensor",
1062             "aten::__iand__(Tensor self, Scalar other) -> Tensor",
1063             "aten::__ior__(Tensor self, Scalar other) -> Tensor",
1064             "aten::__ixor__(Tensor self, Scalar other) -> Tensor",
1065             "aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
1066             "aten::__irshift__(Tensor self, Scalar other) -> Tensor",
1067         },
1068         [](Node* node) -> type_vec_t {
1069           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1070             return {broadcast(
1071                 *maybe_tensor_types, (*maybe_tensor_types)[0]->scalarType())};
1072           }
1073           return {};
1074         }};
1075 
1076     // aten::where is special in that its return type is the second argument's
1077     // (self) type rather than the that of condition
1078     static const register_formula_for where_op{
1079         {
1080             "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
1081         },
1082         [](Node* node) -> type_vec_t {
1083           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1084             return {broadcast(
1085                 *maybe_tensor_types, (*maybe_tensor_types)[1]->scalarType())};
1086           }
1087           return {};
1088         }};
1089 
1090     static const auto any_tensor_type = [](Node* node) -> TensorTypePtr {
1091       for (Value* input : node->inputs()) {
1092         if (auto type = input->type()->cast<TensorType>()) {
1093           if (type->dim().has_value()) {
1094             return type;
1095           }
1096         }
1097       }
1098       return nullptr;
1099     };
1100 
1101     // Requirements:
1102     //   dims           : always matching and preserved
1103     //   scalar type    : always matching and preserved
1104     //   device         : always matching and preserved
1105     //   tensor inputs  : 2
1106     //   tensor outputs : 1
1107     static const register_formula_for binary_ops_strict_match{
1108         {
1109             "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
1110             "aten::mm(Tensor self, Tensor mat2) -> Tensor",
1111             "aten::bmm(Tensor self, Tensor mat2) -> Tensor",
1112         },
1113         [](Node* node) -> type_vec_t {
1114           if (auto type = any_tensor_type(node)) {
1115             return {std::move(type)};
1116           }
1117           return {};
1118         }};
1119 
1120     // Requirements:
1121     //   dims           : all tensor args are broadcast
1122     //   scalar type    : byte/uint8
1123     //   device         : always matching and preserved
1124     //   tensor inputs  : *
1125     //   tensor outputs : 1
1126     static const register_formula_for comparison_ops{
1127         {
1128             "aten::lt(Tensor self, Tensor other) -> Tensor",
1129             "aten::le(Tensor self, Tensor other) -> Tensor",
1130             "aten::gt(Tensor self, Tensor other) -> Tensor",
1131             "aten::ge(Tensor self, Tensor other) -> Tensor",
1132             "aten::eq(Tensor self, Tensor other) -> Tensor",
1133             "aten::ne(Tensor self, Tensor other) -> Tensor",
1134             "aten::lt(Tensor self, Scalar other) -> Tensor",
1135             "aten::le(Tensor self, Scalar other) -> Tensor",
1136             "aten::gt(Tensor self, Scalar other) -> Tensor",
1137             "aten::ge(Tensor self, Scalar other) -> Tensor",
1138             "aten::eq(Tensor self, Scalar other) -> Tensor",
1139             "aten::ne(Tensor self, Scalar other) -> Tensor",
1140         },
1141         [](Node* node) -> type_vec_t {
1142           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1143             return {broadcast(*maybe_tensor_types, at::kBool)};
1144           }
1145           return {};
1146         }};
1147 
1148     static const register_formula_for nn_ops_first_input_formula{
1149         *nn_ops_first_input_preserving(), [](Node* node) -> type_vec_t {
1150           if (auto type = node->input(0)->type()->cast<TensorType>()) {
1151             return {type->dimensionedOnly()};
1152           }
1153           return {};
1154         }};
1155 
1156     // Requirements:
1157     //   dims           : 0
1158     //   scalar type    : preserved
1159     //   device         : preserved
1160     //   tensor inputs  : 1
1161     //   tensor outputs : 1
1162     // Additionally:
1163     //   - First input should be the only tensor input
1164     static const register_formula_for all_reduce_ops{
1165         {
1166             "aten::det(Tensor self) -> Tensor",
1167             "aten::logdet(Tensor self) -> Tensor",
1168             "aten::max(Tensor self) -> Tensor",
1169             "aten::min(Tensor self) -> Tensor",
1170             "aten::median(Tensor self) -> Tensor",
1171             "aten::nanmedian(Tensor self) -> Tensor",
1172             "aten::norm(Tensor self, Scalar p) -> Tensor",
1173             "aten::std(Tensor self, bool unbiased) -> Tensor",
1174             "aten::trace(Tensor self) -> Tensor",
1175             "aten::var(Tensor self, bool unbiased) -> Tensor",
1176             "aten::all(Tensor self) -> Tensor",
1177             "aten::any(Tensor self) -> Tensor",
1178         },
1179         [](Node* node) -> type_vec_t {
1180           if (auto type = node->input(0)->type()->cast<TensorType>()) {
1181             return {type->withDim(0)};
1182           }
1183           return {};
1184         }};
1185 
1186     // Requirements:
1187     //   dims           : 0
1188     //   scalar type    : dtype if specified, else preserved
1189     //   device         : preserved
1190     //   tensor inputs  : 1
1191     //   tensor outputs : 1
1192     // Additionally:
1193     //   - First input should be the only tensor input
1194     static const register_formula_for reduce_ops_with_opt_dtype{
1195         {"aten::mean(Tensor self, *, int? dtype) -> Tensor"},
1196         [](Node* node) -> type_vec_t {
1197           std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1198           if (auto type = node->input(0)->type()->cast<TensorType>()) {
1199             auto ret = type->withDim(0);
1200             if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
1201               return {ret->withScalarType(maybe_dtype_option->toScalarType())};
1202             } else {
1203               return {std::move(ret)};
1204             }
1205           }
1206           return {};
1207         }};
1208 
1209     // Requirements:
1210     //   dims           : 0
1211     //   scalar type    : dtype if specified, else preserved if floating point,
1212     //   otherwise long/int64 device         : preserved tensor inputs  : 1
1213     //   tensor outputs : 1
1214     // Additionally:
1215     //   - First input should be the only tensor input
1216     static const register_formula_for
1217         all_reduce_ops_with_integer_upcast_and_dtype{
1218             {
1219                 "aten::sum(Tensor self, *, int? dtype) -> Tensor",
1220                 "aten::prod(Tensor self, *, int? dtype) -> Tensor",
1221             },
1222             [](Node* node) -> type_vec_t {
1223               if (auto type = node->input(0)->type()->cast<TensorType>()) {
1224                 type = type->withDim(0);
1225                 std::optional<IValue> maybe_dtype_option =
1226                     node->get(attr::dtype);
1227                 if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
1228                   return {
1229                       type->withScalarType(maybe_dtype_option->toScalarType())};
1230                 }
1231                 if (type->scalarType()) {
1232                   return {
1233                       at::isFloatingType(*type->scalarType())
1234                           ? std::move(type)
1235                           : type->withScalarType(at::kLong)};
1236                 } else {
1237                   return {std::move(type)};
1238                 }
1239               }
1240               return {};
1241             }};
1242 
1243     static const auto reduce_op_handler = [](Node* node,
1244                                              int64_t num_reduced_dim = 0,
1245                                              bool upcast_integer = false,
1246                                              std::optional<IValue> opt_dtype =
1247                                                  std::nullopt) -> type_vec_t {
1248       if (auto type = node->input(0)->type()->cast<TensorType>()) {
1249         if (!type->scalarType() || !type->dim()) {
1250           return {};
1251         }
1252         if (opt_dtype && !opt_dtype->isNone()) {
1253           type = type->withScalarType(opt_dtype->toScalarType());
1254         } else if (upcast_integer && !at::isFloatingType(*type->scalarType())) {
1255           type = type->withScalarType(at::kLong);
1256         }
1257         if (static_cast<int64_t>(*type->dim()) >= num_reduced_dim &&
1258             num_reduced_dim > 0) {
1259           return {type->withDim(*type->dim() - num_reduced_dim)};
1260         } else {
1261           return {std::move(type)};
1262         }
1263       }
1264       return {};
1265     };
1266 
1267     static const auto multidim_reduce_with_keepdim =
1268         [](Node* node,
1269            int64_t num_reduced_dim,
1270            bool upcast_integer) -> type_vec_t {
1271       auto maybe_keepdim = node->get<bool>(attr::keepdim);
1272       if (!maybe_keepdim)
1273         return {};
1274       return reduce_op_handler(
1275           node, *maybe_keepdim ? 0 : num_reduced_dim, upcast_integer);
1276     };
1277 
1278     // Requirements:
1279     //   dims           : 0 if dim is None, otherwise preserved if keepdim ==
1280     //   false or 1 smaller otherwise scalar type    : preserved device :
1281     //   preserved tensor inputs  : 1 tensor outputs : 1
1282     // Additionally:
1283     //   - First input should be the only tensor input
1284     //   - Has a bool keepdim argument
1285     static const register_formula_for argminmax{
1286         {
1287             "aten::argmax(Tensor self, int? dim, bool keepdim) -> Tensor",
1288             "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor",
1289         },
1290         [](Node* node) -> type_vec_t {
1291           if (auto type = node->input(0)->type()->cast<TensorType>()) {
1292             if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) {
1293               return {type->withDim(0)};
1294             } else {
1295               return multidim_reduce_with_keepdim(
1296                   node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1297             }
1298           }
1299           return {};
1300         }};
1301 
1302     // Requirements:
1303     //   dims           : preserved if keepdim == false, 1 smaller otherwise
1304     //   scalar type    : preserved for first output, byte/uint8 for second
1305     //   output if exists device         : preserved tensor inputs  : 1 tensor
1306     //   outputs : 1 or 2
1307     // Additionally:
1308     //   - First input should be the only tensor input
1309     //   - Has a bool keepdim argument
1310     static const register_formula_for dim_reduce_ops{
1311         {
1312             "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
1313             "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
1314 
1315             // Ops returning indices as second output
1316             "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
1317             "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1318             "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1319             "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1320             "aten::nanmedian(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1321             "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1322         },
1323         [](Node* node) -> type_vec_t {
1324           // NB: Note that while this function is generally meant to be used
1325           // with ops that have a single output, we will fix up its return right
1326           // below.
1327           auto output_types = multidim_reduce_with_keepdim(
1328               node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1329           if (!output_types.empty() && node->outputs().size() == 2) {
1330             output_types.push_back(
1331                 output_types.back()->withScalarType(at::kLong));
1332           }
1333           return output_types;
1334         }};
1335 
1336     // Requirements:
1337     //   dims           : preserved if keepdim == false, 1 smaller otherwise
1338     //   scalar type    : dtype if specified. preserved if floating point,
1339     //   otherwise long/int64 device         : preserved tensor inputs  : 1
1340     //   tensor outputs : 1
1341     // Additionally:
1342     //   - First input should be the only tensor input
1343     //   - has a bool keepdim argument
1344     static const register_formula_for dim_reduce_ops_with_integer_upcast{
1345         {
1346             "aten::prod(Tensor self, int dim, bool keepdim, *, int? dtype) -> Tensor",
1347         },
1348         [](Node* node) -> type_vec_t {
1349           auto maybe_keepdim = node->get<bool>(attr::keepdim);
1350           std::optional<IValue> opt_dtype = node->get(attr::dtype);
1351           return reduce_op_handler(
1352               node,
1353               /*num_reduce_dim=*/*maybe_keepdim ? 0 : 1,
1354               /*integer_upcast=*/true,
1355               std::move(opt_dtype));
1356         }};
1357 
1358     // Requirements:
1359     //   dims           : preserved
1360     //   scalar type    : dtype if specified, preserved if floating point,
1361     //    otherwise long/int64
1362     //   device         : preserved
1363     //   tensor inputs  : 1
1364     //   tensor outputs : 1
1365     // Additionally:
1366     //   - First input should be the only tensor input
1367     static const register_formula_for dim_reduce_ops_dtype{
1368         {"aten::cumprod(Tensor self, int dim, *, int? dtype) -> Tensor",
1369          "aten::cumsum(Tensor self, int dim, *, int? dtype) -> Tensor",
1370          "aten::log_softmax(Tensor self, int dim, int? dtype) -> Tensor"},
1371         [](Node* node) -> type_vec_t {
1372           std::optional<IValue> opt_dtype = node->get(attr::dtype);
1373           return reduce_op_handler(
1374               node,
1375               /*num_reduce_dim=*/0,
1376               /*integer_upcast=*/true,
1377               std::move(opt_dtype));
1378         }};
1379 
1380     // Requirements:
1381     //   dims           : preserved
1382     //   scalar type    : dtype if specified, otherwise preserved
1383     //   device         : preserved
1384     //   tensor inputs  : 1
1385     //   tensor outputs : 1
1386     // Additionally:
1387     //   - has bool keepdim and int[] dim arguments
1388     static const register_formula_for register_softmax{
1389         {"aten::softmax(Tensor self, int dim, int? dtype) -> Tensor"},
1390         [](Node* node) -> type_vec_t {
1391           std::optional<IValue> opt_dtype = node->get(attr::dtype);
1392           return reduce_op_handler(
1393               node,
1394               /*num_reduced_dim=*/0,
1395               /*upcast_integer=*/false,
1396               std::move(opt_dtype));
1397         }};
1398 
1399     static const auto factory_with_ndim =
1400         [](Node* node, int dim, at::ScalarType default_dtype) -> type_vec_t {
1401       std::optional<IValue> maybe_layout_option = node->get(attr::layout);
1402       if (!maybe_layout_option)
1403         return {};
1404 
1405       std::optional<IValue> maybe_device_option = node->get(attr::device);
1406       if (!maybe_device_option)
1407         return {};
1408       auto device =
1409           (maybe_device_option->isNone() ? at::kCPU
1410                                          : maybe_device_option->toDevice());
1411 
1412       std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1413       if (!maybe_dtype_option)
1414         return {};
1415       auto dtype =
1416           (maybe_dtype_option->isNone() ? default_dtype
1417                                         : maybe_dtype_option->toScalarType());
1418 
1419       return {TensorType::create(
1420           dtype, device, dim, /*requires_grad=*/std::nullopt)};
1421     };
1422 
1423     static const auto factory_like_with_ndim = [](Node* node,
1424                                                   int dim) -> type_vec_t {
1425       auto tt = node->input(0)->type()->expect<TensorType>();
1426       auto in_type = tt->scalarType();
1427       auto in_dev = tt->device();
1428 
1429       std::optional<IValue> maybe_layout_option = node->get(attr::layout);
1430       if (!maybe_layout_option)
1431         return {};
1432 
1433       std::optional<IValue> maybe_device_option = node->get(attr::device);
1434       if (!maybe_device_option)
1435         return {};
1436 
1437       if (!maybe_device_option->isNone()) {
1438         in_dev = maybe_device_option->toDevice();
1439       }
1440 
1441       std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1442       if (!maybe_dtype_option)
1443         return {};
1444 
1445       if (!maybe_dtype_option->isNone()) {
1446         in_type = maybe_dtype_option->toScalarType();
1447       }
1448 
1449       return {TensorType::create(
1450           in_type, in_dev, dim, /*requires_grad=*/std::nullopt)};
1451     };
1452 
1453     // Requirements:
1454     //   dims           : preserved
1455     //   scalar type    : equal to value of dtype
1456     //   device         : equal to value of device
1457     //   tensor inputs  : 1
1458     //   tensor outputs : 1
1459     // Additionally:
1460     //   - has ScalarType dtype, Layout layout and Device device arguments
1461     static const register_formula_for like_factories_with_options{
1462         {
1463             "aten::empty_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1464             "aten::full_like(Tensor self, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1465             "aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1466             "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1467             "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1468             "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1469             "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1470             "aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1471         },
1472         [](Node* node) -> type_vec_t {
1473           if (auto type =
1474                   node->namedInput(attr::self)->type()->cast<TensorType>()) {
1475             if (type->dim()) {
1476               return factory_like_with_ndim(node, (int)*type->dim());
1477             }
1478           }
1479           return {};
1480         }};
1481 
1482     // Requirements:
1483     //   dims           : equal to number of elements in size
1484     //   scalar type    : equal to value of dtype
1485     //   device         : equal to value of device
1486     //   tensor inputs  : 1
1487     //   tensor outputs : 1
1488     // Additionally:
1489     //   - has int[] size, ScalarType dtype, Layout layout and Device device
1490     //   arguments
1491     static const register_formula_for size_factories_with_options{
1492         {
1493             "aten::empty(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory, MemoryFormat? memory_format=contiguous_format) -> Tensor",
1494             "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1495             "aten::ones(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1496             "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1497             "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1498             "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1499         },
1500         [](Node* node) -> type_vec_t {
1501           if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) {
1502             return factory_with_ndim(
1503                 node, (int)maybe_size->size(), at::kDouble);
1504           }
1505           return {};
1506         }};
1507 
1508     static const register_formula_for randint{
1509         {
1510             "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1511             "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1512         },
1513         [](Node* node) -> type_vec_t {
1514           if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) {
1515             return factory_with_ndim(node, (int)maybe_size->size(), at::kLong);
1516           }
1517           return {};
1518         }};
1519 
1520     static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
1521       switch (node->kind()) {
1522         case aten::_cast_Byte:
1523           return at::kByte;
1524         case aten::_cast_Char:
1525           return at::kChar;
1526         case aten::_cast_Double:
1527           return at::kDouble;
1528         case aten::_cast_Float:
1529           return at::kFloat;
1530         case aten::_cast_Half:
1531           return at::kHalf;
1532         case aten::_cast_Int:
1533           return at::kInt;
1534         case aten::_cast_Long:
1535           return at::kLong;
1536         case aten::_cast_Short:
1537           return at::kShort;
1538         default:
1539           AT_ASSERTM(
1540               false,
1541               "unknown node kind in get_cast_scalar_type: ",
1542               node->kind().toQualString());
1543       }
1544     };
1545     static const register_formula_for cast_ops{
1546         {
1547             "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
1548             "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
1549             "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
1550             "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
1551             "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
1552             "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
1553             "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
1554             "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
1555         },
1556         [](Node* node) -> type_vec_t {
1557           if (auto type =
1558                   node->namedInput(attr::self)->type()->cast<TensorType>()) {
1559             return {type->withScalarType(get_cast_scalar_type(node))};
1560           }
1561           return {};
1562         }};
1563 
1564     // First, try to match one of the registered formulas to their operator
1565     // sets.
1566     for (auto& entry : shape_formulas) {
1567       if (node->isMemberOf(entry.first)) {
1568         auto types = entry.second(node);
1569         if (types.empty()) {
1570           return false;
1571         } else {
1572           auto outputs = node->outputs();
1573           AT_ASSERT(types.size() == outputs.size());
1574           for (const auto i : c10::irange(types.size())) {
1575             AT_ASSERT(outputs[i]->type()->isSubtypeOf(*TensorType::get()));
1576             outputs[i]->setType(types[i]);
1577           }
1578           return true;
1579         }
1580       }
1581     }
1582 
1583     // This section implements shape prop for an assorted set of nodes that only
1584     // need partial information about their input types.
1585     const auto input_type = [node](size_t index) {
1586       auto result = node->input(index)->type()->cast<TensorType>();
1587       if (result) {
1588         result = result->dimensionedOnly();
1589       }
1590       return result;
1591     };
1592     if (node->matches(
1593             "aten::masked_select(Tensor self, Tensor mask) -> Tensor")) {
1594       if (auto type = input_type(0)) {
1595         node->output()->setType(type->withDim(1));
1596         return true;
1597       }
1598     } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
1599       if (auto type = input_type(0)) {
1600         node->output()->setType(type->withRequiresGrad(false));
1601         return true;
1602       }
1603     } else if (
1604         node->matches(
1605             "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)")) {
1606       if (auto type = input_type(0)) {
1607         if (type->scalarType() == at::kHalf) {
1608           type = type->withScalarType(at::kFloat);
1609         }
1610         type = type->withDim(1);
1611         node->outputs()[0]->setType(type);
1612         node->outputs()[1]->setType(std::move(type));
1613         return true;
1614       }
1615     } else if (node->matches(
1616                    "aten::dot(Tensor self, Tensor tensor) -> Tensor")) {
1617       if (auto type = any_tensor_type(node)) {
1618         node->output()->setType(type->withDim(0));
1619         return true;
1620       }
1621     } else if (
1622         node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor") ||
1623         node->matches(
1624             "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) {
1625       if (auto type = any_tensor_type(node)) {
1626         node->output()->setType(type->withDim(1));
1627         return true;
1628       }
1629     } else if (
1630         node->matches(
1631             "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1632         node->matches(
1633             "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1634         node->matches(
1635             "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1636       if (auto type = any_tensor_type(node)) {
1637         node->output()->setType(type->withDim(2));
1638         return true;
1639       }
1640     } else if (
1641         node->matches(
1642             "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1643       if (auto type = any_tensor_type(node)) {
1644         node->output()->setType(type->withDim(3));
1645         return true;
1646       }
1647     } else if (
1648         node->matches(
1649             "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) {
1650       auto type = input_type(0);
1651       auto index_type = input_type(1);
1652       // index_select behaves very weirdly when self.dim() == 0. It allows both
1653       // 0D and 1D indices, and returns a value that has as many dimensions as
1654       // index.
1655       if (type && index_type && type->dim()) {
1656         if (*type->dim() == 0) {
1657           node->output()->setType(type->withDim(index_type->dim()));
1658         } else {
1659           node->output()->setType(std::move(type));
1660         }
1661         return true;
1662       }
1663     } else if (
1664         node->matches(
1665             "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
1666       auto type = input_type(0);
1667       auto index_type = input_type(1);
1668       // Gather has this annoying edge case where index always needs to match
1669       // the number of dims of self, **except** when self is 1D and index is 0D
1670       // in which case we return a 0D output.
1671       if (type && index_type && index_type->dim()) {
1672         if (*index_type->dim() == 0) {
1673           node->output()->setType(type->withDim(0));
1674         } else {
1675           node->output()->setType(std::move(type));
1676         }
1677         return true;
1678       }
1679     } else if (
1680         node->matches(
1681             "aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
1682       auto weight_type = input_type(0);
1683       auto indices_type = input_type(1);
1684       if (weight_type && indices_type && indices_type->dim()) {
1685         node->output()->setType(weight_type->withDim(*indices_type->dim() + 1));
1686         return true;
1687       }
1688     } else if (
1689         node->matches(
1690             "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor")) {
1691       if (auto type = input_type(0)) {
1692         node->output()->setType(std::move(type));
1693         return true;
1694       }
1695       if (auto type = input_type(1)) {
1696         node->output()->setType(std::move(type));
1697         return true;
1698       }
1699     } else if (
1700         node->matches(
1701             "aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) {
1702       if (auto type = any_tensor_type(node)) {
1703         node->output()->setType(type->withDim(0));
1704         return true;
1705       }
1706     }
1707 
1708     // The code below implements formulas that need type information for all
1709     // their tensor inputs, and have exactly one output.
1710     std::vector<TensorTypePtr> tensor_types;
1711     static const auto reshape_prop =
1712         [](Node* node,
1713            Symbol shape_input,
1714            const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
1715       if (auto list_size = determineListSize(node->namedInput(shape_input))) {
1716         return tensor_types.at(0)->withDim(list_size);
1717       }
1718       return nullptr;
1719     };
1720     const auto getSingleOutputType = [&]() -> TypePtr {
1721       if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
1722         return tensor_types.at(0)->withScalarType(
1723             tensor_types.at(1)->scalarType());
1724       } else if (
1725           node->matches(
1726               "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
1727           node->matches(
1728               "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
1729           node->matches(
1730               "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)")) {
1731         return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
1732       } else if (
1733           node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
1734           node->matches(
1735               "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
1736           node->matches(
1737               "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
1738         return reshape_prop(node, attr::size, tensor_types);
1739       } else if (
1740           node->matches(
1741               "aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor")) {
1742         TypePtr input_type = node->inputs().at(0)->type();
1743         if (auto type = input_type->cast<TensorType>()) {
1744           if (type->scalarType() && type->device()) {
1745             at::ScalarType default_type = *type->scalarType();
1746             c10::Device default_device = *type->device();
1747             if (auto dtype_index =
1748                     node->schema().argumentIndexWithName("dtype")) {
1749               auto inp = toIValue(node->inputs().at(*dtype_index));
1750               if (inp == std::nullopt) {
1751                 return nullptr;
1752               }
1753               if (!inp->isNone()) {
1754                 default_type = inp->toScalarType();
1755               }
1756             }
1757             if (auto device_index =
1758                     node->schema().argumentIndexWithName("device")) {
1759               auto inp = toIValue(node->inputs().at(*device_index));
1760               if (inp == std::nullopt) {
1761                 return nullptr;
1762               }
1763               if (!inp->isNone()) {
1764                 default_device = inp->toDevice();
1765               }
1766             }
1767             node->output()->setType(TensorType::create(
1768                 default_type,
1769                 default_device,
1770                 type->dim(),
1771                 /*requires_grad=*/std::nullopt));
1772           }
1773         }
1774         return nullptr;
1775       } else if (
1776           node->matches(
1777               "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)")) {
1778         return reshape_prop(node, attr::shape, tensor_types);
1779       } else if (node->matches(
1780                      "aten::repeat(Tensor self, int[] repeats) -> Tensor")) {
1781         return reshape_prop(node, attr::repeats, tensor_types);
1782       } else if (node->matches(
1783                      "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
1784         auto& t = tensor_types.at(0);
1785         if (!t->dim()) {
1786           return t;
1787         }
1788         return t->withDim(*t->dim() + 1);
1789       } else if (
1790           node->matches(
1791               "aten::select(Tensor self, int dim, int index) -> Tensor") ||
1792           node->matches(
1793               "aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) {
1794         auto& t = tensor_types.at(0);
1795         return t->dim() && *t->dim() > 0 ? t->withDim(*t->dim() - 1) : nullptr;
1796       } else if (node->matches(
1797                      "aten::matmul(Tensor self, Tensor other) -> Tensor")) {
1798         if (!tensor_types.at(0)->dim() || !tensor_types.at(1)->dim()) {
1799           return nullptr;
1800         }
1801         auto dim1 = *tensor_types.at(0)->dim();
1802         auto dim2 = *tensor_types.at(1)->dim();
1803         if (dim1 == 1 && dim2 == 1) {
1804           // Dot product
1805           return tensor_types.at(0)->withDim(0);
1806           // NOLINTNEXTLINE(bugprone-branch-clone)
1807         } else if (dim1 == 2 && dim2 == 2) {
1808           // Matrix multiply
1809           return tensor_types.at(0);
1810         } else if (dim1 == 1 && dim2 == 2) {
1811           // Unsqueeze + matrix multiply + squeeze
1812           return tensor_types.at(0);
1813         } else if (dim1 == 2 && dim2 == 1) {
1814           // Matrix vector multiply
1815           return tensor_types.at(1);
1816         } else {
1817           // Batched matrix multiply (possibly with squeeze + unsqueeze if one
1818           // argument is 1D)
1819           auto type = broadcast(tensor_types, tensor_types[0]->scalarType());
1820           if (dim1 == 1 || dim2 == 1) {
1821             type = type->withDim(type->dim().value() - 1);
1822           }
1823           return type;
1824         }
1825       } else if (node->matches("aten::nonzero(Tensor self) -> Tensor")) {
1826         return tensor_types.at(0)->dimensionedOnly()->withScalarType(at::kLong);
1827       } else if (node->matches(
1828                      "aten::take(Tensor self, Tensor index) -> Tensor")) {
1829         return tensor_types.at(1)->dimensionedOnly()->withScalarType(
1830             tensor_types.at(0)->scalarType());
1831       } else if (node->matches(
1832                      "aten::diagflat(Tensor self, int offset) -> Tensor")) {
1833         return tensor_types.at(0)->withDim(2);
1834       } else if (node->matches(
1835                      "aten::diag(Tensor self, int diagonal) -> Tensor")) {
1836         auto& t = tensor_types.at(0);
1837         if (t->dim() && *t->dim() == 1) {
1838           return t->withDim(2);
1839         } else if (t->dim() && *t->dim() == 2) {
1840           return t->withDim(1);
1841         } else {
1842           return nullptr;
1843         }
1844       } else if (
1845           node->matches(
1846               "aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) {
1847         auto& t = tensor_types.at(0);
1848         if (!t->dim()) {
1849           return nullptr;
1850         }
1851         return t->withDim(*t->dim() + 1);
1852       } else if (node->matches(
1853                      "aten::polygamma(int n, Tensor self) -> Tensor")) {
1854         return tensor_types.at(0);
1855       }
1856       return nullptr;
1857     };
1858     if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1859       tensor_types = std::move(*maybe_tensor_types);
1860     } else {
1861       return false;
1862     }
1863     if (node->outputs().size() == 1) {
1864       if (auto type = getSingleOutputType()) {
1865         node->output()->setType(std::move(type));
1866         return true;
1867       }
1868     }
1869     return false;
1870   }
1871 
PropagateCompleteShapeOnNode(Node * node,bool insert_expands,std::vector<TensorTypePtr> tensor_types)1872   bool PropagateCompleteShapeOnNode(
1873       Node* node,
1874       bool insert_expands,
1875       std::vector<TensorTypePtr> tensor_types) {
1876     // For expensive ops we can directly encode their shape propagation
1877     // here, otherwise we fallback to running a fake version of the op
1878     // to get a quick and dirty propagation.
1879     if (node->matches(
1880             "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1881         node->matches(
1882             "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1883         node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
1884       // These nodes handle tensors of different shapes internally, so there's
1885       // no need to insert explicit expand nodes.
1886       return PropagateShapeOnNodeByRunningIt(node);
1887     } else if (node->matches(
1888                    "aten::div(Tensor self, Tensor other) -> Tensor")) {
1889       // "div" handle tensors of different shapes internally, so there's no need
1890       // to insert explicit expand nodes.
1891       // Note that this function could be merged to the one above , but "div" is
1892       // not always safe to run by itself due to integer divide-by-zero.
1893       // We fake the execution by running "mul" operation instead.
1894       auto op = getOperatorForLiteral(
1895                     "aten::mul(Tensor self, Tensor other) -> Tensor")
1896                     ->getOperation();
1897       return PropagateShapeOnNodeByRunningIt(node, std::move(op));
1898     } else if (node->matches(
1899                    "aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
1900       node->output()->setType(tensor_types.at(0));
1901       return true;
1902     } else if (
1903         node->matches(
1904             "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1905         node->matches(
1906             "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1907         node->matches("aten::div(Tensor self, Scalar other) -> Tensor") ||
1908         node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
1909       auto first_scalar_type = (tensor_types)[0]->scalarType();
1910       auto second_scalar_type =
1911           tryScalarTypeFromJitType(*node->inputs()[1]->type());
1912       if (!first_scalar_type || !second_scalar_type) {
1913         return false;
1914       }
1915       if (isIntegralType(*first_scalar_type, false) &&
1916           isFloatingType(*second_scalar_type)) {
1917         auto default_dtype =
1918             at::typeMetaToScalarType(caffe2::get_default_dtype());
1919         auto type = tensor_types[0]->withScalarType(default_dtype);
1920         node->output()->setType(std::move(type));
1921         return true;
1922       }
1923       if (c10::ScalarType::Bool == *first_scalar_type &&
1924           c10::ScalarType::Bool != *second_scalar_type) {
1925         auto result_type =
1926             c10::promoteTypes(*first_scalar_type, *second_scalar_type);
1927         auto type = tensor_types[0]->withScalarType(result_type);
1928         node->output()->setType(std::move(type));
1929         return true;
1930       }
1931       auto type = tensor_types[0]->withScalarType(first_scalar_type);
1932       node->output()->setType(std::move(type));
1933       return true;
1934     } else if (
1935         insert_expands &&
1936         (node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor") ||
1937          node->matches("aten::min(Tensor self, Tensor other) -> Tensor") ||
1938          node->matches("aten::max(Tensor self, Tensor other) -> Tensor") ||
1939          node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") ||
1940          node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
1941          node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
1942          node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
1943          node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") ||
1944          node->matches("aten::ne(Tensor self, Tensor other) -> Tensor"))) {
1945       // Binary broadcasting ops
1946       // NB: we don't handle the nodes in any other way (note the lack of
1947       // return!), because the type casting logic in scalar cases is
1948       // non-trivial. It's better to just run them.
1949       broadcastBinary(node, tensor_types, 0, 1);
1950       return PropagateShapeOnNodeByRunningIt(node);
1951     } else if (
1952         node->matches(
1953             "aten::logit(Tensor self, float? eps = None) -> Tensor") ||
1954         node->matches("aten::neg(Tensor self) -> Tensor") ||
1955         node->matches("aten::sigmoid(Tensor self) -> Tensor") ||
1956         node->matches("aten::tanh(Tensor self) -> Tensor")) {
1957       node->output()->setType(tensor_types.at(0)->contiguous());
1958       return true;
1959     } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
1960       auto lhs_type = tensor_types.at(0);
1961       auto rhs_type = tensor_types.at(1);
1962       auto lhs_sizes = lhs_type->sizes().concrete_sizes().value();
1963       auto rhs_sizes = rhs_type->sizes().concrete_sizes().value();
1964       SHAPE_ASSERT(
1965           *lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2);
1966       node->output()->setType(TensorType::createContiguous(
1967           *lhs_type->scalarType(),
1968           *lhs_type->device(),
1969           at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]}));
1970       return true;
1971     } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
1972       auto tp = tensor_types.at(0);
1973       auto sizes = tp->sizes().concrete_sizes().value();
1974       auto strides = tp->strides().concrete_sizes().value();
1975       SHAPE_ASSERT(sizes.size() == 2);
1976       std::swap(sizes.at(0), sizes.at(1));
1977       std::swap(strides.at(0), strides.at(1));
1978       node->output()->setType(tp->withSizesStrides(sizes, strides));
1979       return true;
1980     } else if (
1981         node->matches(
1982             "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
1983             /*const_inputs=*/{attr::dim, attr::length})) {
1984       auto tp = tensor_types.at(0);
1985       auto sizes = tp->sizes().concrete_sizes().value();
1986       int64_t dim = node->get<int64_t>(attr::dim).value();
1987       int64_t length = node->get<int64_t>(attr::length).value();
1988       SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
1989       sizes.at(dim) = length;
1990       node->output()->setType(
1991           tp->withSizesStrides(sizes, tp->strides().concrete_sizes().value()));
1992       return true;
1993     } else if (node->matches(
1994                    "aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
1995       node->output()->setType(tensor_types.at(0)->withSizes({}));
1996       return true;
1997     } else if (
1998         node->matches(
1999             "aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor",
2000             /*const_inputs=*/{attr::dim, attr::keepdim})) {
2001       auto& tp = tensor_types.at(0);
2002       auto sizes = tp->sizes().concrete_sizes().value();
2003       auto dims = node->get<c10::List<int64_t>>(attr::dim).value();
2004       bool keepdim = node->get<bool>(attr::keepdim).value();
2005       std::reverse(dims.begin(), dims.end());
2006       for (int64_t dim : dims) {
2007         SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
2008         if (keepdim) {
2009           sizes.at(dim) = 1;
2010         } else {
2011           sizes.erase(sizes.begin() + dim);
2012         }
2013       }
2014       node->output()->setType(tp->withSizes(sizes));
2015       return true;
2016     } else if (node->matches(
2017                    "aten::squeeze(Tensor self, int dim) -> Tensor",
2018                    /*const_inputs=*/attr::dim)) {
2019       auto& tp = tensor_types.at(0);
2020       auto sizes = tp->sizes().concrete_sizes().value();
2021       auto strides = tp->strides().concrete_sizes().value();
2022       int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
2023       SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
2024       if (sizes.at(dim) == 1) {
2025         sizes.erase(sizes.begin() + dim);
2026         strides.erase(strides.begin() + dim);
2027       }
2028       node->output()->setType(tp->withSizesStrides(sizes, strides));
2029       return true;
2030     } else if (node->matches(
2031                    "aten::unsqueeze(Tensor self, int dim) -> Tensor",
2032                    /*const_inputs=*/attr::dim)) {
2033       auto& tp = tensor_types.at(0);
2034       auto sizes = tp->sizes().concrete_sizes().value();
2035       auto strides = tp->strides().concrete_sizes().value();
2036       int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
2037       SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
2038       int64_t new_stride = dim >= static_cast<int64_t>(sizes.size())
2039           ? 1
2040           : sizes.at(dim) * strides.at(dim);
2041       sizes.insert(sizes.begin() + dim, 1);
2042       strides.insert(strides.begin() + dim, new_stride);
2043       node->output()->setType(tp->withSizesStrides(sizes, strides));
2044       return true;
2045     } else if (node->matches(
2046                    "aten::view(Tensor self, int[] size) -> Tensor",
2047                    /*const_inputs=*/attr::size)) {
2048       auto sizes = node->get<c10::List<int64_t>>(attr::size).value();
2049       bool inferred = false;
2050       size_t inferred_idx = 0;
2051       int64_t size_product = 1;
2052       for (const auto i : c10::irange(sizes.size())) {
2053         if (sizes.get(i) == -1) {
2054           if (inferred)
2055             throw propagation_error();
2056           inferred = true;
2057           inferred_idx = i;
2058         } else {
2059           size_product *= sizes.get(i);
2060         }
2061       }
2062 
2063       if (inferred) {
2064         SHAPE_ASSERT(size_product != 0);
2065         int64_t numel = 1;
2066         auto concrete_sizes =
2067             tensor_types.at(0)->sizes().concrete_sizes().value();
2068         for (int64_t s : concrete_sizes)
2069           numel *= s;
2070         int64_t inferred_size = numel / size_product;
2071         sizes[inferred_idx] = inferred_size;
2072       }
2073       node->output()->setType(tensor_types.at(0)->withSizes(sizes.vec()));
2074       return true;
2075     } else if (node->matches(
2076                    "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
2077       if (tensor_types.at(0)->scalarType() ==
2078           tensor_types.at(1)->scalarType()) {
2079         node->output()->setType(node->namedInput(attr::self)->type());
2080       } else {
2081         // This will be a copy, so the result will be contiguous
2082         node->output()->setType(tensor_types.at(1)->withSizes(
2083             tensor_types.at(0)->sizes().concrete_sizes().value()));
2084       }
2085       return true;
2086     } else if (
2087         node->matches(
2088             "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
2089             /*const_inputs=*/attr::size)) {
2090       auto tp = tensor_types.at(0);
2091       auto sizesAndStrides = at::inferExpandGeometry_dimvector(
2092           tp->sizes().concrete_sizes().value(),
2093           tp->strides().concrete_sizes().value(),
2094           node->get<c10::List<int64_t>>(attr::size).value().vec());
2095       node->output()->setType(
2096           tp->withSizesStrides(sizesAndStrides.sizes, sizesAndStrides.strides));
2097       return true;
2098     } else if (
2099         node->matches(
2100             "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
2101             /*const_inputs=*/attr::dim)) {
2102       auto ten = tensor_types.at(0);
2103       auto index = tensor_types.at(1);
2104       int64_t dim = node->get<int64_t>(attr::dim).value();
2105       SHAPE_ASSERT(*index->sizes().size() == 1);
2106       SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
2107       std::vector<int64_t> sizes = ten->sizes().concrete_sizes().value();
2108       sizes[dim] = index->sizes()[0].value();
2109       node->output()->setType(ten->withSizes(sizes));
2110       return true;
2111     } else if (node->matches(
2112                    "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
2113                    /*const_inputs=*/{attr::chunks, attr::dim})) {
2114       auto input_type = tensor_types.at(0);
2115       auto sizes = input_type->sizes().concrete_sizes().value();
2116       auto strides = input_type->strides().concrete_sizes().value();
2117       int64_t dim = node->get<int64_t>(attr::dim).value();
2118       int64_t chunks = node->get<int64_t>(attr::chunks).value();
2119       sizes[dim] /= chunks;
2120       for (Value* output : node->outputs()) {
2121         output->setType(input_type->withSizesStrides(sizes, strides));
2122       }
2123       if (*input_type->sizes()[dim] % chunks != 0) {
2124         sizes[dim] = *input_type->sizes()[dim] % chunks;
2125         node->outputs().back()->setType(
2126             input_type->withSizesStrides(sizes, strides));
2127       }
2128       return true;
2129     } else if (node->kind() == ::c10::onnx::Shape) {
2130       SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
2131       std::vector<int64_t> dim_vec = {
2132           (int64_t)*tensor_types.at(0)->sizes().size()};
2133       at::IntArrayRef dims(dim_vec);
2134       node->output()->setType(
2135           TensorType::createContiguous(at::kLong, at::kCPU, dims));
2136       return true;
2137     } else if (node->kind() == ::c10::onnx::Reshape) {
2138       setUnshapedType(node);
2139       return true;
2140     }
2141     setUnshapedType(node);
2142     return false;
2143   }
2144 };
2145 } // anonymous namespace
2146 
PropagateInputShapes(const std::shared_ptr<Graph> & graph)2147 void PropagateInputShapes(const std::shared_ptr<Graph>& graph) {
2148   ShapePropagator(graph).propagateBlock(graph->block());
2149 }
2150 
2151 namespace {
2152 
2153 using TypeCache = std::unordered_map<TypePtr, TypePtr>;
2154 
2155 TypePtr getOrCreateUnshapedType(
2156     const TypePtr& type,
2157     TypeCache& unshaped_type_cache);
2158 
unshapedTypeImpl(TypePtr type,TypeCache & unshaped_type_cache)2159 TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
2160   if (type->isSubtypeOf(*TensorType::get())) {
2161     return TensorType::get();
2162   }
2163   at::ArrayRef<TypePtr> contained = type->containedTypes();
2164   if (contained.empty()) {
2165     return type;
2166   }
2167   std::vector<TypePtr> unshaped_contained_types;
2168   for (const auto& contained_type : contained) {
2169     unshaped_contained_types.push_back(
2170         getOrCreateUnshapedType(contained_type, unshaped_type_cache));
2171   }
2172   return type->withContained(std::move(unshaped_contained_types));
2173 }
2174 
getOrCreateUnshapedType(const TypePtr & type,TypeCache & unshaped_type_cache)2175 TypePtr getOrCreateUnshapedType(
2176     const TypePtr& type,
2177     TypeCache& unshaped_type_cache) {
2178   auto maybe_cached_type = unshaped_type_cache.find(type);
2179   if (maybe_cached_type != unshaped_type_cache.end()) {
2180     return maybe_cached_type->second;
2181   }
2182   auto unshaped_type = unshapedTypeImpl(type, unshaped_type_cache);
2183   unshaped_type_cache[type] = unshaped_type;
2184   return unshaped_type;
2185 }
2186 
2187 void EraseShapeInformation(
2188     const std::shared_ptr<Graph>& graph,
2189     TypeCache& unshaped_type_cache);
2190 
EraseShapeInformation(at::ArrayRef<Value * > vals,TypeCache & unshaped_type_cache)2191 void EraseShapeInformation(
2192     at::ArrayRef<Value*> vals,
2193     TypeCache& unshaped_type_cache) {
2194   for (Value* v : vals) {
2195     v->setType(getOrCreateUnshapedType(v->type(), unshaped_type_cache));
2196   }
2197 }
2198 
EraseShapeInformation(Block * b,TypeCache & unshaped_type_cache)2199 void EraseShapeInformation(Block* b, TypeCache& unshaped_type_cache) {
2200   EraseShapeInformation(b->inputs(), unshaped_type_cache);
2201   EraseShapeInformation(b->outputs(), unshaped_type_cache);
2202   for (Node* n : b->nodes()) {
2203     EraseShapeInformation(n->outputs(), unshaped_type_cache);
2204     for (Block* sb : n->blocks()) {
2205       EraseShapeInformation(sb, unshaped_type_cache);
2206     }
2207     if (n->hasAttribute(attr::Subgraph)) {
2208       EraseShapeInformation(n->g(attr::Subgraph), unshaped_type_cache);
2209     }
2210   }
2211 }
2212 
EraseShapeInformation(const std::shared_ptr<Graph> & graph,TypeCache & unshaped_type_cache)2213 void EraseShapeInformation(
2214     const std::shared_ptr<Graph>& graph,
2215     TypeCache& unshaped_type_cache) {
2216   EraseShapeInformation(graph->block(), unshaped_type_cache);
2217 }
2218 
2219 } // anonymous namespace
2220 
EraseShapeInformation(const std::shared_ptr<Graph> & graph)2221 void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
2222   TypeCache unshaped_type_cache;
2223   EraseShapeInformation(graph->block(), unshaped_type_cache);
2224 }
2225 } // namespace torch::jit
2226