1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <initializer_list>
21 #include <iterator>
22 #include <queue>
23 
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/ADT/None.h"
26 #include "llvm/ADT/PointerUnion.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "llvm/Support/Casting.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/FormatVariadic.h"
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
34 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
35 #include "mlir/IR/Attributes.h"  // from @llvm-project
36 #include "mlir/IR/Block.h"  // from @llvm-project
37 #include "mlir/IR/Builders.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
40 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
41 #include "mlir/IR/Location.h"  // from @llvm-project
42 #include "mlir/IR/Operation.h"  // from @llvm-project
43 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
44 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
45 #include "mlir/IR/Value.h"  // from @llvm-project
46 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
47 #include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
48 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
49 #include "mlir/Pass/Pass.h"  // from @llvm-project
50 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
51 #include "mlir/Support/LLVM.h"  // from @llvm-project
52 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
53 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
54 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
55 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
56 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
57 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
58 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
59 #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
60 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
61 #include "tensorflow/core/framework/shape_inference.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 
64 #define DEBUG_TYPE "tf-shape-inference"
65 
66 #define DCOMMENT(MSG) LLVM_DEBUG(llvm::dbgs() << MSG << "\n")
67 #define DCOMMENT_OP(OP, MSG) \
68   LLVM_DEBUG(OP->print(llvm::dbgs() << MSG << " "); llvm::dbgs() << "\n")
69 
70 using ::tensorflow::int64;
71 using tensorflow::shape_inference::DimensionHandle;
72 using tensorflow::shape_inference::InferenceContext;
73 using tensorflow::shape_inference::ShapeHandle;
74 
75 namespace mlir {
76 namespace TF {
77 namespace {
78 
79 // Returns whether type can be further refined.
80 bool CanBeRefined(Type type) {
81   auto shape_type = type.dyn_cast<ShapedType>();
82   if (!shape_type) return false;
83 
84   // Returns whether type with subtypes can be further refined.
85   auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) {
86     return tws.GetSubtypes().empty() ||
87            llvm::any_of(tws.GetSubtypes(), CanBeRefined);
88   };
89   auto type_with_subtype =
90       shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
91   if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;
92 
93   return !shape_type.hasStaticShape();
94 }
95 
96 // Returns whether `original_type` can be refined with
97 // `potential_refined_type`.
98 bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
99   if (original_type == potential_refined_type || !CanBeRefined(original_type))
100     return false;
101 
102   auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
103   if (!shape_type) return false;
104   if (shape_type.hasRank()) return true;
105 
106   auto element_type_with_subtype =
107       shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
108   return element_type_with_subtype &&
109          !element_type_with_subtype.GetSubtypes().empty();
110 }
111 
112 // Returns if the shape inference pass supports an op outside the TF dialect.
113 bool IsSupportedNonTFOp(Operation* op) {
114   return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp,
115              tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp,
116              tf_executor::GraphOp, tf_executor::IslandOp,
117              tf_executor::LoopCondOp, tf_executor::MergeOp,
118              tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
119              tf_executor::SwitchOp, tf_executor::YieldOp>(op);
120 }
121 
122 // Returns whether a cast back would need to be inserted, i.e., whether the
123 // operation owning `use` does not allow shape refinement of its operand
124 // without a cast.
125 bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
126   return use.getOwner()->getDialect() != tf_dialect &&
127          !IsSupportedNonTFOp(use.getOwner());
128 }
129 
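// Creates a tensor type with `element_type`: ranked with `shape` if a shape is
// provided, and unranked otherwise.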
130 TensorType CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape,
131                             Type element_type) {
132   if (shape.hasValue())
133     return RankedTensorType::get(shape.getValue(), element_type);
134   return UnrankedTensorType::get(element_type);
135 }
136 
137 // Returns true if the op creates a TensorList.
138 bool IsTensorListInitOp(Operation* op) {
139   return isa<TensorListReserveOp>(op) || isa<EmptyTensorListOp>(op) ||
140          isa<TensorListFromTensorOp>(op);
141 }
142 
143 // Returns the `element_shape` operand of the ops that create a TensorList.
144 Value GetElementShapeOperand(Operation* op) {
145   if (auto empty_tl = dyn_cast<EmptyTensorListOp>(op))
146     return empty_tl.element_shape();
147   if (auto tl_reserve = dyn_cast<TensorListReserveOp>(op))
148     return tl_reserve.element_shape();
149   if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op))
150     return tl_from_tensor.element_shape();
151   llvm_unreachable("unsupported TensorList op");
152 }
153 
154 // Utility function to create a ranked tensor type after dropping the first
155 // dimension from the input type.
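// For example, tensor<5x4xf32> yields tensor<4xf32>; a non-ranked input yields
// a null type.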
156 RankedTensorType DropFirstDimension(Type type) {
157   RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>();
158   if (!ranked_type) return {};
159   llvm::ArrayRef<int64_t> dims_except_first =
160       ranked_type.getShape().drop_front();
161   return RankedTensorType::get(dims_except_first, ranked_type.getElementType());
162 }
163 
164 // Follows the use chain of a TensorList and returns true iff all elements
165 // written to the TensorList have the same static shape. If so, that shape is
166 // assigned to `potential_element_type`.
167 //
168 // This can handle multiple mutations of a TensorList object and would return
169 // true if across all mutations the elements written have the same shape.
170 bool CanInferTensorListElementType(Value tensorlist,
171                                    Value initial_element_shape,
172                                    RankedTensorType* potential_element_type) {
173   // Verifies that the new element type has a static shape and matches the
174   // potential type passed from the caller. Updates `potential_element_type`
175   // if it is not defined yet.
176   auto verify_and_update_potential_element_type =
177       [&](RankedTensorType new_element_type) -> bool {
178     if (!new_element_type || !new_element_type.hasStaticShape()) return false;
179     if (!*potential_element_type) {
180       *potential_element_type = new_element_type;
181       return true;
182     }
183     return *potential_element_type == new_element_type;
184   };
185 
186   // TensorLists are semantically immutable. For example, TensorListSetItem
187   // takes a TensorList as input and produces a TensorList as output. So to
188   // traverse modifications to TensorList and verify that all elements written
189   // to it have the same shape, we need to follow use-def chain of ops that
190   // (conceptually) modify it i.e., ops that take an input TensorList and
191   // produce an output TensorList.
192   for (auto& use : tensorlist.getUses()) {
193     if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
194       auto element_type = push.tensor().getType().dyn_cast<RankedTensorType>();
195       if (!verify_and_update_potential_element_type(element_type)) return false;
196       if (!CanInferTensorListElementType(push.output_handle(),
197                                          initial_element_shape,
198                                          potential_element_type))
199         return false;
200       continue;
201     }
202     if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
203             use.getOwner())) {
204       // For scatter op we can get the element shape by dropping the first
205       // dimension of the input tensor.
206       RankedTensorType element_type =
207           DropFirstDimension(scatter.tensor().getType());
208       if (!verify_and_update_potential_element_type(element_type)) return false;
209       if (!CanInferTensorListElementType(scatter.output_handle(),
210                                          initial_element_shape,
211                                          potential_element_type))
212         return false;
213       continue;
214     }
215     if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
216       auto element_type =
217           set_item.item().getType().dyn_cast<RankedTensorType>();
218       if (!verify_and_update_potential_element_type(element_type)) return false;
219       if (!CanInferTensorListElementType(set_item.output_handle(),
220                                          initial_element_shape,
221                                          potential_element_type))
222         return false;
223       continue;
224     }
225     if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
226       if (!CanInferTensorListElementType(pop.output_handle(),
227                                          initial_element_shape,
228                                          potential_element_type))
229         return false;
230       continue;
231     }
232     if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
233       if (!CanInferTensorListElementType(resize.output_handle(),
234                                          initial_element_shape,
235                                          potential_element_type))
236         return false;
237       continue;
238     }
239     // WhileRegionOp can explicitly capture a TensorList value to be used inside
240     // its regions. So we check the uses of the corresponding block argument in
241     // each region and the uses of the TensorList returned via YieldOp.
242     if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
243       for (auto branch : while_region.getRegions()) {
244         if (!CanInferTensorListElementType(
245                 branch->getArgument(use.getOperandNumber()),
246                 initial_element_shape, potential_element_type))
247           return false;
248       }
249       continue;
250     }
251     if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
252       Operation* parent = yield->getParentOp();
253       if (!CanInferTensorListElementType(
254               parent->getResult(use.getOperandNumber()), initial_element_shape,
255               potential_element_type))
256         return false;
257       continue;
258     }
259     // Refining the tensor list element type might change the output of
260     // TensorListElementShape, which is expected to be the element shape
261     // originally assigned to the TensorList init ops. So replace it with the
262     // original element shape value.
263     if (auto tl_element_shape =
264             dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
265       tl_element_shape.replaceAllUsesWith(initial_element_shape);
266       continue;
267     }
268     // Ignore ops that just consume a TensorList and do not output another
269     // TensorList.
270     if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
271             TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
272       continue;
273 
274     // For any other unknown users of the TensorList, we are conservative and
275     // stop element shape inference.
276     return false;
277   }
278   return true;
279 }
280 }  // namespace
281 
282 // Combination of value producer and port of value produced (e.g.,
283 //   <value result output>:<value in output tensor>,
284 // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
285 // scalar value).
286 struct ValuePort {
287   PointerUnion<Operation*, BlockArgument> producer;
288   SmallVector<unsigned int, 2> port;
289 
290   bool operator==(const ValuePort& other) const {
291     return producer == other.producer && port == other.port;
292   }
293 
294   // Convert output value to ValuePort.
295   explicit ValuePort(Value v) {
296     OpResult opr = v.dyn_cast<OpResult>();
297     if (opr) {
298       producer = opr.getOwner();
299       port = {opr.getResultNumber()};
300     } else {
301       producer = v.cast<BlockArgument>();
302       port = {0};
303     }
304   }
305   ValuePort(PointerUnion<Operation*, BlockArgument> producer,
306             SmallVector<unsigned int, 2> port)
307       : producer(producer), port(port) {}
308 
309   raw_ostream& print(raw_ostream& os) const {
310     if (auto* op = producer.dyn_cast<Operation*>())
311       os << "op " << op->getName();
312     if (auto ba = producer.dyn_cast<BlockArgument>())
313       os << "block_arg " << ba.getArgNumber();
314     os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
315     return os;
316   }
317 };
318 
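// Hash functor for ValuePort that combines the producer pointer and the port
// path, for use in the unordered_map below.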
319 struct ValuePortHasher {
320   std::size_t operator()(const ValuePort& other) const {
321     return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
322                         hash_value(ArrayRef<unsigned int>(other.port)));
323   }
324 };
325 
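// Maps each ValuePort to the constant Attribute (if any) computed for it.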
326 using ValuePortResultMap =
327     std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
328 using ComputedQueryFn = function_ref<bool(ValuePort)>;
329 using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
330 using ValuePortInputs = SmallVectorImpl<ValuePort>;
331 
332 // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
333 // intended to be switched to op interfaces once more refined.
334 LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
335                                              ComputedQueryFn has_been_computed,
336                                              ValuePortInputs* inputs) {
337   auto op = value_port.producer.dyn_cast<Operation*>();
338   auto& port = value_port.port;
339   if (!op) return failure();
340 
341   // No inputs required for constants.
342   if (matchPattern(op, m_Constant())) return success();
343 
344   // Note: this focuses only on the trivial pack op case and could be
345   // generalized.
346   if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
347     auto type = pack_op.getType().cast<TensorType>();
348     if (!type.hasRank() || type.getRank() != 1) return failure();
349     if (port.size() != 2) return failure();
350     assert(port[0] == 0);
351     ValuePort req(pack_op.getOperand(port[1]));
352     if (!has_been_computed(req)) inputs->push_back(req);
353     return success();
354   }
355 
356   return failure();
357 }
358 
359 // Computes the output produced by ValuePort using the query function of
360 // existing computed values.
361 Attribute ComputeOutputComponent(const ValuePort& value_port,
362                                  ValueQueryFn values) {
363   LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
364   if (auto known = values(value_port)) return known;
365 
366   auto op = value_port.producer.dyn_cast<Operation*>();
367   if (!op) return nullptr;
368   auto& port = value_port.port;
369 
370   if (port.empty()) {
371     LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n");
372     return nullptr;
373   }
374 
375   ElementsAttr attr;
376   if (matchPattern(op, m_Constant(&attr))) {
377     if (port.size() == 1 && port[0] == 0) return attr;
378     return nullptr;
379   }
380 
381   // Note: this focuses only on the trivial pack op case and could be
382   // generalized.
383   if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
384     TensorType type = pack_op.getType().cast<TensorType>();
385     if (!type.hasRank() || type.getRank() != 1) return nullptr;
386     if (port.size() != 2 || port[0] != 0) return nullptr;
387     ValuePort op_port(op->getOperand(port[1]));
388     return values(op_port);
389   }
390 
391   if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
392     if (port.size() == 1)
393       return ComputeOutputComponent(
394           ValuePort(graph.GetFetch().fetches()[port[0]]), values);
395     return nullptr;
396   }
397 
398   if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
399     if (port.size() == 1)
400       return ComputeOutputComponent(
401           ValuePort(island.GetYield().fetches()[port[0]]), values);
402     return nullptr;
403   }
404 
405   return nullptr;
406 }
407 
408 // Context used during ShapeInference. This class contains common information
409 // that is required by the individual shape inference helper functions (e.g.,
410 // TF Graph version, constant values computed, etc.)
411 class ShapeInference {
412  public:
413   ShapeInference(int64_t graph_version, MLIRContext* context,
414                  bool propagate_caller_callee_constants);
415 
416   LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
417                                                ValuePortInputs* inputs) {
418     return ::mlir::TF::ComputeInputsRequiredForOutput(
419         value_port,
420         [this](const ValuePort& port) {
421           return results_.find(port) != results_.end();
422         },
423         inputs);
424   }
425 
426   Attribute ComputeOutputComponent(const ValuePort& value_port) {
427     if (auto known_attr = results_[value_port]) return known_attr;
428     auto attr = ::mlir::TF::ComputeOutputComponent(
429         value_port, [this](const ValuePort& port) { return results_[port]; });
430     RecordValue(value_port, attr);
431     return attr;
432   }
433 
434   // Returns ShapeHandle if the op result could be computed as shape.
435   ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
436 
437   void RecordValue(const ValuePort& value_port, Attribute value) {
438     LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
439                << value << "\n");
440     results_[value_port] = value;
441   }
442 
443   // Infers shape of tf.While/tf.WhileRegion. If `shape_invariant` attribute is
444   // set, operand types are set as result types if associated body result types
445   // match the operand type (does not change per loop iteration). If operand and
446   // body result types are not the same, only handle types are propagated to
447   // result types. This is necessary to not incorrectly change result shapes
448   // when the While op will have a different result shape. Otherwise operand
449   // shapes are propagated to result shapes.
450   template <typename WhileOpTy>
451   bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types);
452 
453   // Performs shape inference on the provided op and returns true if the type
454   // of at least one result has been changed.
455   // A tf.Cast() is inserted for any uses that aren't in the TensorFlow dialect.
456   // `graph_version` indicates the current GraphDef compatibility versions
457   // (the versions field in graph.proto).
458   bool InferShapeForSingleOperation(Operation* op);
459 
460   // Infers shapes in the provided region, including nested regions, iterating
461   // until a fixed point with a limit of max_iterations. Returns success if a
462   // fixed point is reached before max_iterations.
463   LogicalResult InferShapeUntilFixPoint(Region* region, int64_t max_iterations);
464 
465   // Updates input types and refines shapes inside the bodies of functions that
466   // are attached to ControlFlow ops (If/While) or Calls. These functions
467   // include Then/Else branches of IfOp and Cond/Body functions of WhileOp.
468   // Functions attached to control flow share the following common properties:
469   //   1) They are never reused, i.e. they have a single use in the module.
470   //   2) Their input types match those of their parent ops (excluding inputs
471   //      like predicate).
472   // For calls, functions can be reused across multiple call sites. In this case
473   // we propagate the types when all call sites have the same operand types.
474   LogicalResult PropagateShapeToFunctions(ModuleOp module,
475                                           TypeRange input_types,
476                                           ArrayRef<FuncOp> functions,
477                                           int64_t max_iteration);
478 
479   // Propagates shapes to regions given the shapes of the inputs of the regions.
480   // All regions provided in `regions` are assumed to have inputs of type
481   // `input_types`.
482   LogicalResult PropagateShapeToRegions(TypeRange input_types,
483                                         ArrayRef<Region*> regions,
484                                         int64_t max_iteration);
485 
486   // Shape propagation for call/control flow ops.
487   LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
488                                                     int64_t max_iteration);
489 
490   // Shape propagation for region based control flow.
491   LogicalResult PropagateShapeIntoAttachedRegions(Operation* op,
492                                                   int64_t max_iterations);
493 
494   // Propagates any constant operand of call_op to the called function body's
495   // corresponding argument if the callee has only one use.
496   //
497   // TODO(b/154065712): Move this to a more general inter-procedural constant
498   // folding pass.
499   void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func,
500                                  ModuleOp module);
501 
502   // Propagates any constant return value of the callee function to the call
503   // op's corresponding result.
504   void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func,
505                                    ModuleOp module);
506 
507   // Tries to compute the result of folding the op. This doesn't actually
508   // perform constant folding, it just computes the equivalent constants.
509   // Returns whether it was able to compute constant values.
510   LogicalResult TryToFold(Operation* op);
511 
512   // Makes result types match the operand types (the i-th result type will
513   // match the i-th operand type). Returns true if anything is changed.
514   bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands,
515                                         ResultRange results);
516 
517   // Makes result type's shape match the corresponding operand's shape.
518   // Returns whether any change was made.
519   bool RefineShapeForPassThroughOps(Operation* op);
520 
521   // Infers shape for necessary ops that are not in the TF dialect. Returns
522   // whether any result type changed.
523   bool InferShapeForNonTFDialectOperation(Operation* op);
524 
525   // Infers shape for the function return type.
526   void InferShapeForFunctionReturnType(FuncOp func);
527 
528   // Enqueues function for processing.
529   void enqueue(FuncOp fn) {
530     LLVM_DEBUG(llvm::dbgs()
531                << "enqueue " << fn.getName() << " ("
532                << (queue_set_.count(fn) ? "already inserted" : "newly inserted")
533                << ")\n");
534     if (queue_set_.insert(fn).second) queue_.push(fn);
535   }
536 
537   // Enqueues the callers of the function.
538   void EnqueueCallers(FuncOp fn);
539 
540   // Returns the function at the front of the queue.
541   FuncOp front() { return queue_.front(); }
542 
543   // Returns whether work queue is empty.
544   bool EmptyQueue() const { return queue_.empty(); }
545 
546   // Removes and returns the function at the front of the work queue.
547   FuncOp pop_front() {
548     FuncOp ret = queue_.front();
549     queue_.pop();
550     queue_set_.erase(ret);
551     return ret;
552   }
553 
554   // Returns the current size of the queue.
555   std::queue<FuncOp>::size_type QueueSize() const { return queue_.size(); }
556 
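  // The TF dialect, cached at construction so helpers (e.g., NeedsCastBack)
  // can cheaply check whether an op belongs to the TensorFlow dialect.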
557   Dialect* const tf_dialect_;
558 
559  private:
560   // Updates the result of an operation to a new inferred type. Also inserts
561   // tf.Cast operation for uses that are incompatible with the new type.
562   void UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, Value result);
563 
564   // Refines the type of `result` of `op` using the type
565   // `potential_refined_type`. Return true if the type was changed.
566   bool RefineResultType(Operation* op, Value result,
567                         Type potential_refined_type);
568 
569   // Infers the shape from a (Stateful)PartitionedCall operation by looking up the
570   // called function and propagating the return type.
571   bool InferShapeForCall(CallOpInterface call_op);
572 
573   bool InferShapeForCast(CastOp op);
574 
575   // Infers the shape of IfOp outputs based on the shapes of the then and else
576   // function result types.
577   bool InferShapeForIf(IfOp op);
578 
579   // Infers the shape of IfRegion outputs based on the shapes of the then and else
580   // yields.
581   bool InferShapeForIfRegion(IfRegionOp op);
582 
583   // Infers the shape of ops that create TensorList. Specifically,
584   // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
585   // refines the element shape if all tensors written to the list across all
586   // mutations have identical static shape.
587   bool InferShapeForTensorListInitOps(Operation* op);
588 
589   bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti);
590 
591   // Returns all the callers of a function.
592   // Note: Usage of the return value of this function may not be interleaved
593   // with insertions into the callers map. This could occur if GetCallers is
594   // called with two separate functions, the second call incurs a resize, and
595   // then both the first and second stored caller lists are used.
596   ArrayRef<Operation*> GetCallers(FuncOp fn);
597 
598   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
599   // e.g., first element of OpResult produced) to an Attribute if the ValuePort
600   // corresponds to a constant value.
601   ValuePortResultMap results_;
602 
603   // Map from a function to the callers of that function.
604   llvm::DenseMap<FuncOp, SmallVector<Operation*, 4>> callers_of_func_;
605 
606   // Queue of functions being processed.
607   llvm::DenseSet<FuncOp> queue_set_;
608   std::queue<FuncOp> queue_;
609 
610   int64_t graph_version_;
611 
612   // TODO(b/154065712): Remove propagate_caller_callee_constants once using
613   // SCCP pass instead.
614   bool propagate_caller_callee_constants_;
615 };
616 
617 ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
618                                bool propagate_caller_callee_constants)
619     : tf_dialect_(context->getLoadedDialect<TensorFlowDialect>()),
620       graph_version_(graph_version),
621       propagate_caller_callee_constants_(propagate_caller_callee_constants) {}
622 
623 ArrayRef<Operation*> ShapeInference::GetCallers(FuncOp fn) {
624   auto pair = callers_of_func_.try_emplace(fn);
625   if (pair.second) {
626     ModuleOp module = fn->getParentOfType<ModuleOp>();
627     auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module);
628     if (uses) {
629       pair.first->second.reserve(pair.first->second.size());
630       for (auto use : *uses) {
631         pair.first->second.push_back(use.getUser());
632       }
633     }
634   }
635   return pair.first->second;
636 }
637 
638 void ShapeInference::EnqueueCallers(FuncOp fn) {
639   for (auto user : GetCallers(fn)) enqueue(user->getParentOfType<FuncOp>());
640 }
641 
642 void ShapeInference::UpdateTypeAndInsertIncompatibleUseCasts(Type new_type,
643                                                              Value result) {
644   // A tf.Cast operation is lazily created on the first use that requires a cast.
645   TF::CastOp cast_op;
646   auto get_cast_op = [&]() {
647     if (!cast_op) {
648       Operation* op = result.getDefiningOp();
649       OpBuilder b(op);
650       b.setInsertionPointAfter(op);
651       cast_op = b.create<TF::CastOp>(op->getLoc(), result.getType(), result,
652                                      /*truncate=*/b.getBoolAttr(false));
653     }
654     return Value(cast_op);
655   };
656   // First insert cast back for uses that need a cast and then
657   // update the type.
658   bool enqueue_callers = false;
659   for (OpOperand& use : make_early_inc_range(result.getUses())) {
660     if (isa<ReturnOp>(use.getOwner()))
661       enqueue_callers = true;
662     else if (NeedsCastBack(use, tf_dialect_))
663       use.set(get_cast_op());
664   }
665 
666   result.setType(new_type);
667   if (enqueue_callers)
668     EnqueueCallers(result.getDefiningOp()->getParentOfType<FuncOp>());
669 }
670 
671 bool ShapeInference::RefineResultType(Operation* op, Value result,
672                                       Type potential_refined_type) {
673   if (!CanRefineTypeWith(result.getType(), potential_refined_type))
674     return false;
675 
676   UpdateTypeAndInsertIncompatibleUseCasts(potential_refined_type, result);
677   return true;
678 }
679 
680 // Infers the shape from a (Stateful)PartitionedCall operation by looking up the
681 // called function and propagating the return type.
682 bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
683   FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
684   if (!func) return false;
685 
686   LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName());
687   Operation* op = call_op.getOperation();
688   bool changed = false;
689   // Map each of the results of the call to the returned type of the
690   // function.
691   for (auto result : zip(op->getResults(), func.getType().getResults())) {
692     changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
693               changed;
694   }
695   LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n");
696 
697   return changed;
698 }
699 
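// Infers the result shape of a tf.Cast op from its ranked operand, unless the
// shapes already match or every use would just need a cast back to the
// original type.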
700 bool ShapeInference::InferShapeForCast(CastOp op) {
701   DCOMMENT_OP(op.getOperation(), "Inferring shape for ");
702   Value result = op.getResult();
703   if (!CanBeRefined(result.getType())) return false;
704 
705   Type operand_type = op.getOperand().getType();
706   auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>();
707   if (!ranked_op_type) return false;
708   auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
709   if (ranked_res_type &&
710       ranked_op_type.getShape() == ranked_res_type.getShape())
711     return false;
712 
713   // Avoid inserting a cast where no user's type could be refined (e.g., where
714   // a cast would need to be inserted for every user again).
715   if (llvm::all_of(result.getUses(), [this](OpOperand& use) {
716         return NeedsCastBack(use, tf_dialect_);
717       }))
718     return false;
719 
720   auto new_type = RankedTensorType::get(
721       ranked_op_type.getShape(),
722       result.getType().cast<ShapedType>().getElementType());
723 
724   UpdateTypeAndInsertIncompatibleUseCasts(new_type, op.getResult());
725   return true;
726 }
727 
728 bool ShapeInference::InferShapeForIf(IfOp op) {
729   DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
730   bool changed = false;
731   auto then_results = op.then_function().getType().getResults();
732   auto else_results = op.else_function().getType().getResults();
733   for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
734     // If then and else types do not match, skip refinement for that result.
735     if (std::get<1>(it) != std::get<2>(it)) continue;
736     changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
737   }
738   return changed;
739 }
740 
741 bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {
742   bool changed = false;
743 
744   Operation* then_yield = op.then_branch().front().getTerminator();
745   Operation* else_yield = op.else_branch().front().getTerminator();
746   for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
747                          else_yield->getOperandTypes())) {
748     // If then and else types do not match, skip refinement for that result.
749     if (std::get<1>(result) != std::get<2>(result)) continue;
750     changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
751               changed;
752   }
753   return changed;
754 }
755 
756 bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
757   DCOMMENT_OP(op, "Inferring shape for TensorList ");
758   Value handle = op->getResult(0);
759   Value initial_element_shape = GetElementShapeOperand(op);
760   RankedTensorType element_type;
761   if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) {
762     // For TensorListFromTensor op we can infer element shape by dropping the
763     // first dimension of input tensor.
764     element_type = DropFirstDimension(tl_from_tensor.tensor().getType());
765     if (!element_type || !element_type.hasStaticShape()) return false;
766   }
767   if (!CanInferTensorListElementType(handle, initial_element_shape,
768                                      &element_type))
769     return false;
770   if (!element_type || !element_type.hasStaticShape()) return false;
771   auto variant_type = VariantType::get(element_type, op->getContext());
772   auto tensor_type = RankedTensorType::get({}, variant_type);
773   bool changed = RefineResultType(op, handle, tensor_type);
774   if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
775   return changed;
776 }
777 
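// Refines result types of `op` using its InferTypeOpInterface implementation,
// inserting casts for incompatible uses. Returns whether any result type
// changed.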
778 bool ShapeInference::RefineWithInferTypeOpInterface(
779     InferTypeOpInterface infer_ti) {
780   Operation* op = infer_ti.getOperation();
781   SmallVector<Type, 4> inferred;
782   LogicalResult res = infer_ti.inferReturnTypes(
783       op->getContext(), op->getLoc(), op->getOperands(),
784       op->getAttrDictionary(), op->getRegions(), inferred);
785   if (failed(res)) {
786     op->emitOpError("failed to refine type as inference failed");
787     return false;
788   }
789 
790   if (inferred == op->getResultTypes()) return false;
791 
792   // Map each of the results of the call to the returned type of the
793   // function.
794   bool changed = false;
795   for (auto result : zip(op->getResults(), inferred)) {
796     if (std::get<0>(result).getType() == std::get<1>(result)) continue;
797 
798     UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result),
799                                             std::get<0>(result));
800     changed = true;
801   }
802   return changed;
803 }
804 
805 ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
806                                                  InferenceContext* ic) {
807   LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
808   auto rt = result.getType().dyn_cast<RankedTensorType>();
809   if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
810   int dim_size = rt.getDimSize(0);
811 
812   // Worklist to direct partial evaluation.
813   SmallVector<ValuePort, 4> worklist;
814 
815   // Simple evaluator that attempts to partially evaluate the input value even
816   // if it is unable to evaluate the complete output. Below follows a simple
817   // stack-based evaluation where it queries which operands (or parts of
818   // operands) need to be evaluated and attempts to partially evaluate those
819   // operands. It does so by pushing the required operands onto the worklist
820   // before enqueuing the operation requiring those values.
821   std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim());
822   for (unsigned int i = 0, e = dims.size(); i != e; ++i) {
823     LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n");
824 
825     worklist.push_back(
826         ValuePort{result.getOwner(), {result.getResultNumber(), i}});
827     while (!worklist.empty()) {
828       auto front = worklist.pop_back_val();
829       LLVM_DEBUG(front.print(llvm::dbgs() << "\nWorklist front "));
830 
831       SmallVector<ValuePort, 4> inputs;
832       auto res = ComputeInputsRequiredForOutput(front, &inputs);
833       if (failed(res)) {
834         // Abort if unable to find which required inputs need to be computed.
835         worklist.clear();
836         break;
837       }
838 
839       if (!inputs.empty()) {
840         // Enqueue required computation followed by its required operands in
841         // stack.
842         worklist.push_back(std::move(front));
843         for (auto& it : inputs) worklist.push_back(std::move(it));
844         continue;
845       }
846 
847       auto ret = ComputeOutputComponent(front);
848       if (!ret) continue;
849 
850       LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
851 
852       // If worklist is empty, then this is the root query op.
853       if (worklist.empty()) {
854         LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
855         if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
856           if (dea.getNumElements() != 1) {
857             LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n");
858             return {};
859           }
860           int64_t val = (*dea.getIntValues().begin()).getSExtValue();
861           dims[i] = ic->MakeDim(val);
862         }
863       }
864     }
865   }
866   return ic->MakeShape(dims);
867 }
868 
869 bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
870                                                       OperandRange operands,
871                                                       ResultRange results) {
872   bool changed = false;
873   for (auto entry : zip(operands, results)) {
874     Type operand_type = std::get<0>(entry).getType();
875     Value result = std::get<1>(entry);
876     TensorType result_type = result.getType().cast<TensorType>();
877     if (operand_type == result_type) continue;
878     // Pass through nodes may remove ref types, don't consider that as
879     // refinement.
880     // TODO(jpienaar): There could be refinement in addition to this, so
881     // refine this.
882     if (operand_type.cast<TensorType>()
883             .getElementType()
884             .isa<TF::TensorFlowRefType>() &&
885         !result_type.cast<TensorType>()
886              .getElementType()
887              .isa<TF::TensorFlowRefType>())
888       continue;
889 
890     UpdateTypeAndInsertIncompatibleUseCasts(operand_type, result);
891     changed = true;
892   }
893   return changed;
894 }
895 
896 bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
897   DCOMMENT_OP(op, "Pass through op");
898   auto is_allowed_dtype = [](Type t) {
899     // Skip if element type is not in standard or TF dialect.
900     // TODO(jpienaar): The tf.Cast op, which is uniformly inserted at the
901     // moment, cannot handle arbitrary types (e.g., it can't handle quantized
902     // types). This restriction can be relaxed if not only tf.Cast is used.
903     return t.getDialect().getNamespace().empty() ||
904            isa<TensorFlowDialect>(t.getDialect());
905   };
906 
907   bool changed = false;
908   for (auto entry : zip(op->getOperands(), op->getResults())) {
909     TensorType operand_type = std::get<0>(entry).getType().cast<TensorType>();
910     Value result = std::get<1>(entry);
911     TensorType result_type = result.getType().cast<TensorType>();
912     if (operand_type == result_type) continue;
913     if (!operand_type.hasRank()) continue;
914     if (result_type.hasRank() &&
915         result_type.getShape() == operand_type.getShape())
916       continue;
917     if (!is_allowed_dtype(operand_type.getElementType()) ||
918         !is_allowed_dtype(result_type.getElementType()))
919       continue;
920 
921     auto new_type = RankedTensorType::get(operand_type.getShape(),
922                                           result_type.getElementType());
923     UpdateTypeAndInsertIncompatibleUseCasts(new_type, result);
924     changed = true;
925   }
926   return changed;
927 }
928 
929 bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
930   if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
931     return RefineTypeForPassThroughOperands(
932         graph_op.GetFetch(), graph_op.GetFetch().fetches(), op->getResults());
933   }
934   if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
935     return RefineTypeForPassThroughOperands(
936         island_op.GetYield(), island_op.GetYield().fetches(), op->getResults());
937   }
938   if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
939     auto iter_source = cast<tf_executor::NextIterationSourceOp>(
940         iter_sink.token().getDefiningOp());
941     return RefineTypeForPassThroughOperands(
942         op, iter_sink.getOperands().drop_front().take_front(),
943         iter_source.getResults());
944   }
945   if (auto launch_op = dyn_cast<tf_device::LaunchOp>(op)) {
946     auto terminator = launch_op.GetBody().getTerminator();
947     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
948                                             op->getResults());
949   }
950   if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
951     auto terminator = cluster_op.GetBody().getTerminator();
952     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
953                                             op->getResults());
954   }
955   if (op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
956       isa<tensor::CastOp>(op)) {
957     return RefineShapeForPassThroughOps(op);
958   }
959   if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
960   return false;
961 }
962 
963 // Finds the element type to be used for the result based on the operand, with
964 // special handling for handle types.
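// The operand's element type (which carries the subtypes) is only used when
// the operand has subtypes and the result does not; otherwise the result's
// element type is kept.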
965 Type GetElementTypeFromOperand(TensorType operand_type,
966                                TensorType result_type) {
967   auto operand_handle_type =
968       operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
969   if (!operand_handle_type) return result_type.getElementType();
970   auto result_handle_type =
971       result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
972   if (operand_handle_type.GetSubtypes().empty() ||
973       !result_handle_type.GetSubtypes().empty())
974     return result_type.getElementType();
975   return operand_handle_type;
976 }
977 
978 // Checks if one tensor type can refine another type for tf.While/
979 // tf.WhileRegion. If rank differs or static dimensions can be lost, the other
980 // type cannot be used for refinement.
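// For example, tensor<?x4xf32> can be refined with tensor<8x4xf32>, but not
// with tensor<8x5xf32> or with an unranked tensor.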
981 bool CanWhileTypeBeRefinedWith(TensorType current_type,
982                                TensorType potential_refined_type) {
983   if (!current_type.hasRank()) return true;
984   if (!potential_refined_type.hasRank()) return false;
985   if (current_type.getRank() != potential_refined_type.getRank()) return false;
986   for (auto dim :
987        llvm::zip(current_type.getShape(), potential_refined_type.getShape())) {
988     int64_t current_dim = std::get<0>(dim);
989     int64_t potential_refined_dim = std::get<1>(dim);
990     if (current_dim != potential_refined_dim &&
991         current_dim != ShapedType::kDynamicSize)
992       return false;
993   }
994   return true;
995 }
996 
997 template <typename WhileOpTy>
998 bool ShapeInference::InferShapeForWhile(WhileOpTy op,
999                                         TypeRange body_result_types) {
1000   if (!op.shape_invariant())
1001     return RefineTypeForPassThroughOperands(op, op.input(), op.output());
1002 
1003   bool changed = false;
1004   for (auto entry :
1005        zip(op.input().getTypes(), op.output(), body_result_types)) {
1006     Value result = std::get<1>(entry);
1007     TensorType body_result_type =
1008         std::get<2>(entry).template cast<TensorType>();
1009     auto result_type = result.getType().cast<TensorType>();
1010 
1011     Type potential_refined_type;
1012     if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) {
1013       Type element_type =
1014           GetElementTypeFromOperand(body_result_type, result_type);
1015       potential_refined_type = CreateTensorType(
1016           body_result_type.hasRank() ? body_result_type.getShape()
1017                                      : llvm::Optional<ArrayRef<int64_t>>(),
1018           element_type);
1019     } else {
1020       TensorType operand_type = std::get<0>(entry).template cast<TensorType>();
1021       Type element_type = GetElementTypeFromOperand(operand_type, result_type);
1022       potential_refined_type = CreateTensorType(
1023           result_type.hasRank() ? result_type.getShape()
1024                                 : llvm::Optional<ArrayRef<int64_t>>(),
1025           element_type);
1026     }
1027     changed |= RefineResultType(op, result, potential_refined_type);
1028   }
1029   return changed;
1030 }
1031 
1032 bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
1033   LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
1034              llvm::dbgs() << "\n");
1035   assert(tf_dialect_ == op->getDialect());
1036   // The shape function of these ops sometimes does not propagate subtypes
1037   // (handle shapes) for resource and variant types. We use a simple passthrough
1038   // to make sure they are preserved in the output.
1039   if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp>(op)) {
1040     return RefineTypeForPassThroughOperands(op, op->getOperands(),
1041                                             op->getResults());
1042   }
1043 
1044   // If no result for this op needs shape inference, we have a fast-path return.
1045   // But if the type is a resource/variant, we do not skip it because we might
1046   // not have the handle shapes.
1047   if (none_of(op->getResultTypes(), CanBeRefined)) {
1048     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
1049                             << op->getName() << "'.\n");
1050     return false;
1051   }
1052 
1053   // Handle call operations by looking up callee and inferring return shape as
1054   // needed.
1055   if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1056 
1057   // tf.Cast ops are only inferred if they have at least one user in the TF
1058   // dialect or feed into the function return. This is necessary to avoid
1059   // inserting casts which cannot be refined.
1060   if (auto cast_op = dyn_cast<CastOp>(op)) return InferShapeForCast(cast_op);
1061 
1062   // Handle IfOp here by inferring the shape from the else/then function
1063   // results. Since `output_shapes` is a derived attribute, avoid going down the
1064   // TF InferenceContext path as IfOp shape inference is implemented as just
1065   // a lookup of the output_shapes attribute.
1066   if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);
1067 
1068   // Handle IfRegion operations by inferring return shape from the then and else
1069   // branches.
1070   if (auto if_region = dyn_cast<IfRegionOp>(op))
1071     return InferShapeForIfRegion(if_region);
1072 
1073   if (auto while_op = dyn_cast<WhileOp>(op))
1074     return InferShapeForWhile(while_op,
1075                               while_op.body_function().getType().getResults());
1076 
1077   if (auto while_region = dyn_cast<WhileRegionOp>(op))
1078     return InferShapeForWhile(
1079         while_region,
1080         while_region.body().front().getTerminator()->getOperandTypes());
1081 
1082   // Handle TensorList init operations by inferring shape from TensorList write
1083   // operations. If we are unable to refine element shape here, proceed to use
1084   // the InferenceContext below to get more precise shapes.
1085   if (IsTensorListInitOp(op) && InferShapeForTensorListInitOps(op)) return true;
1086 
1087   // Return operand as a constant attribute.
1088   auto operand_as_constant_fn = [&](Value operand) {
1089     ValuePort vp(operand);
1090     Attribute attr = ComputeOutputComponent(vp);
1091     if (!attr && matchPattern(operand, m_Constant(&attr)))
1092       RecordValue(vp, attr);
1093     return attr;
1094   };
1095 
1096   // Return op result as a shape.
1097   auto op_result_as_shape_fn = [&](InferenceContext& context,
1098                                    OpResult op_result) {
1099     return ComputeOutputAsShape(op_result, &context);
1100   };
1101 
1102   // Return result element type at `index`.
1103   auto result_element_type_fn = [&](int index) {
1104     return op->getResult(index).getType().cast<TensorType>().getElementType();
1105   };
1106 
1107   llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes;
1108   if (failed(InferReturnTypeComponentsForTFOp(
1109           /*location=*/None, op, graph_version_, operand_as_constant_fn,
1110           op_result_as_shape_fn, result_element_type_fn,
1111           inferred_return_shapes)))
1112     return false;
1113 
1114   // Update the shape of each operation result if the InferenceContext has
1115   // more precise shapes recorded.
1116   bool changed = false;
1117   for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) {
1118     Value op_result = std::get<0>(result);
1119     if (!CanBeRefined(op_result.getType())) continue;
1120 
1121     ShapedTypeComponents inferred = std::get<1>(result);
1122     TensorType inferred_type;
1123     if (inferred.hasRank())
1124       inferred_type =
1125           RankedTensorType::get(inferred.getDims(), inferred.getElementType());
1126     else
1127       inferred_type = UnrankedTensorType::get(inferred.getElementType());
1128 
1129     if (op_result.getType() == inferred_type) continue;
1130     UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result);
1131     changed = true;
1132   }
1133 
1134   if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
1135   return changed;
1136 }
1137 
1138 LogicalResult ShapeInference::PropagateShapeToFunctions(
1139     ModuleOp module, TypeRange input_types, ArrayRef<FuncOp> functions,
1140     int64_t max_iteration) {
1141   bool all_succeeded = true;
1142   // If shape propagation fails for one function, return failure, but do not
1143   // exit early; instead, attempt to propagate shapes for all provided
1144   // functions to get a best-effort propagation.
1145   for (FuncOp func : functions) {
1146     DCOMMENT("Propating shape to" << func.getName());
1147     ArrayRef<Operation*> callers = GetCallers(func);
1148     if (!llvm::hasSingleElement(callers) &&
1149         !llvm::all_of(callers.drop_front(), [&](Operation* caller) {
1150           /// TODO(aminim): this is overly conservative as some operations
1151           /// (like TPUPartitionedCallOp) may have extra operands that aren't
1152           /// propagated to the callee.
1153           return isa<CallOpInterface>(caller) &&
1154                  std::equal(caller->getOperandTypes().begin(),
1155                             caller->getOperandTypes().end(),
1156                             callers.front()->getOperandTypes().begin());
1157         })) {
1158       all_succeeded = false;
1159       if (llvm::any_of(callers, [](Operation* op) {
1160             return isa<IfOp, WhileOp, CaseOp>(op);
1161           }))
1162         func.emitWarning(formatv(
1163             "expected control flow function @{0} to have exactly 1 use, "
1164             "found {1}.",
1165             func.getName(), callers.size()));
1166 
1167       continue;
1168     }
1169     FunctionType func_type = func.getType();
1170     func.setType(FunctionType::get(func.getContext(), input_types,
1171                                    func_type.getResults()));
1172 
1173     auto res =
1174         PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
1175     if (failed(res)) {
1176       all_succeeded = false;
1177       continue;
1178     }
1179 
1180     InferShapeForFunctionReturnType(func);
1181   }
1182   return success(all_succeeded);
1183 }
1184 
1185 LogicalResult ShapeInference::PropagateShapeToRegions(TypeRange input_types,
1186                                                       ArrayRef<Region*> regions,
1187                                                       int64_t max_iteration) {
1188   DCOMMENT("\tPropagating shapes to regions");
1189   bool all_succeeded = true;
1190   // If shape propagation fails for one region, return failure, but do not
1191   // exit early; instead, attempt to propagate shapes for all provided regions
1192   // to get a best-effort propagation.
1193   for (auto region : regions) {
1194     // Refine region arguments.
1195     Block& entry = region->front();
1196     assert(llvm::size(input_types) == entry.getNumArguments());
1197     for (auto it : llvm::zip(entry.getArguments(), input_types)) {
1198       BlockArgument arg = std::get<0>(it);
1199       Type type = std::get<1>(it);
1200       arg.setType(type);
1201     }
1202 
1203     // Propagate shapes into the region.
1204     all_succeeded = succeeded(InferShapeUntilFixPoint(region, max_iteration)) &&
1205                     all_succeeded;
1206   }
1207   return success(all_succeeded);
1208 }
1209 
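// Example (illustrative sketch): if the sole call site passes a constant, e.g.
//   %c = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
//   %r = "tf.StatefulPartitionedCall"(%c) {f = @callee, ...}
// then either the constant is cloned into @callee and replaces all uses of the
// matching argument (when propagate_caller_callee_constants_ is set), or the
// computed constant is recorded for that argument's ValuePort so later value
// queries inside the callee can reuse it.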
1210 void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
1211                                                FuncOp func, ModuleOp module) {
1212   auto callers = GetCallers(func);
1213   if (!llvm::hasSingleElement(callers)) return;
1214 
1215   OpBuilder builder(&func.front().front());
1216   Operation* op = call_op.getOperation();
1217   // If this is the only caller, and an operand is a constant, propagate
1218   // the constant value inside the function.
1219   for (auto arg : func.getArguments()) {
1220     auto operand = op->getOperand(arg.getArgNumber());
1221     if (propagate_caller_callee_constants_) {
1222       if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
1223         arg.replaceAllUsesWith(
1224             builder.clone(*operand.getDefiningOp())->getResult(0));
1225       }
1226       continue;
1227     }
1228 
1229     auto known_constant = ComputeOutputComponent(ValuePort(operand));
1230     if (!known_constant) continue;
1231     LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to callee: ");
1232                known_constant.print(llvm::dbgs() << " constant ");
1233                llvm::dbgs() << "\n");
1234     RecordValue(ValuePort(arg), known_constant);
1235   }
1236 }
1237 
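// Example (illustrative sketch): if @callee's terminator returns a value that
// is (or computes to) a constant, that constant is either cloned right after
// the call op and used in place of the corresponding call result (when
// propagate_caller_callee_constants_ is set), or recorded for the call
// result's ValuePort so consumers in the caller can be folded/refined.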
1238 void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
1239                                                  FuncOp func, ModuleOp module) {
1240   // If the return value is a constant, use the constant as the value of
1241   // the call return.
1242   Operation* op = call_op.getOperation();
1243   OpBuilder builder(op);
1244   builder.setInsertionPointAfter(op);
1245   for (auto retval :
1246        llvm::enumerate(func.front().getTerminator()->getOperands())) {
1247     if (propagate_caller_callee_constants_) {
1248       auto retval_op = retval.value().getDefiningOp();
1249       if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
1250         op->getResult(retval.index())
1251             .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
1252       }
1253       continue;
1254     }
1255 
1256     ValuePort vp(retval.value());
1257     if (auto known_constant = ComputeOutputComponent(vp)) {
1258       LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
1259                  call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
1260       RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
1261     }
1262   }
1263 }
1264 
1265 bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
1266   return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank();
1267 }
1268 
1269 // Creates a compatible RankedTensorType where mismatched dimensions are
1270 // replaced with dynamic sizes.
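// For example (illustrative values): merging tensor<1x2x3xf32> with
// tensor<1x5x3xf32> yields tensor<1x?x3xf32>; equal dimensions are kept and
// mismatched ones become ShapedType::kDynamicSize.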
1271 RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs,
1272                                                RankedTensorType rhs) {
1273   assert(lhs.getRank() == rhs.getRank());
1274   llvm::SmallVector<int64_t, 4> dims;
1275   dims.reserve(lhs.getRank());
1276   for (auto dim : llvm::zip(lhs.getShape(), rhs.getShape())) {
1277     int64_t lhs_dim = std::get<0>(dim);
1278     if (lhs_dim == std::get<1>(dim)) {
1279       dims.push_back(lhs_dim);
1280     } else {
1281       dims.push_back(ShapedType::kDynamicSize);
1282     }
1283   }
1284   return RankedTensorType::get(dims, GetElementTypeFromOperand(lhs, rhs));
1285 }
1286 
1287 // Finds compatible types to propagate into functions/regions of a shape
1288 // invariant tf.While/tf.WhileRegion. If operand and result types are the same,
1289 // that type is returned. If operand and result types are of the same rank, a
1290 // compatible type with matching dimensions is used. Otherwise the function/
1291 // region argument types are kept, with the handle type taken from the operand.
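// For example (illustrative types): with operand tensor<1x2xf32>, result
// tensor<1x3xf32>, and region argument tensor<*xf32>, ranks match and the
// compatible type is tensor<1x?xf32>. With operand tensor<2xf32> and result
// tensor<2x2xf32>, ranks differ, so the region argument type (tensor<*xf32>
// here) is used with its element type taken from the operand.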
1292 llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
1293     TypeRange operand_types, TypeRange result_types,
1294     TypeRange region_argument_types) {
1295   llvm::SmallVector<Type, 4> types;
1296   types.reserve(operand_types.size());
1297   for (auto entry :
1298        llvm::zip(operand_types, result_types, region_argument_types)) {
1299     auto operand_type = std::get<0>(entry).cast<TensorType>();
1300     auto result_type = std::get<1>(entry).cast<TensorType>();
1301     if (operand_type == result_type) {
1302       types.push_back(operand_type);
1303     } else if (RankedAndSameRank(operand_type, result_type)) {
1304       auto potential_refined_type =
1305           GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(),
1306                                         result_type.cast<RankedTensorType>());
1307       types.push_back(potential_refined_type);
1308     } else {
1309       auto region_argument_type = std::get<2>(entry).cast<TensorType>();
1310       Type element_type = GetElementTypeFromOperand(
1311           operand_type.cast<TensorType>(), region_argument_type);
1312       Type potential_refined_type = CreateTensorType(
1313           region_argument_type.hasRank() ? region_argument_type.getShape()
1314                                          : llvm::Optional<ArrayRef<int64_t>>(),
1315           element_type);
1316       types.push_back(potential_refined_type);
1317     }
1318   }
1319   return types;
1320 }
1321 
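// Example (illustrative sketch): for
//   %0 = "tf.If"(%pred, %x) {then_branch = @then, else_branch = @else, ...}
//          : (tensor<i1>, tensor<4x8xf32>) -> tensor<*xf32>
// the type of %x (tensor<4x8xf32>) is propagated into the arguments of both
// @then and @else. For a shape-invariant tf.While, operand, result, and body
// argument types are first reconciled via GetWhileCompatibleTypes before being
// propagated into the cond and body functions.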
1322 LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
1323     Operation* op, int64_t max_iteration) {
1324   ModuleOp module = op->getParentOfType<ModuleOp>();
1325   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
1326     DCOMMENT("Propagating shapes into If");
1327     return PropagateShapeToFunctions(
1328         module, if_op.input().getTypes(),
1329         {if_op.then_function(), if_op.else_function()}, max_iteration);
1330   } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
1331     SmallVector<FuncOp, 4> branches;
1332     case_op.get_branch_functions(branches);
1333     return PropagateShapeToFunctions(module, case_op.input().getTypes(),
1334                                      branches, max_iteration);
1335   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
1336     // If `shape_invariant` is set, operand shapes cannot be simply propagated
1337     // to result shapes, as the op may have different intermediate shapes (e.g.
1338     // a While op can have result shapes that differ from its operand shapes).
1339     // Compatible shapes must be determined before propagating them.
1340     if (while_op.shape_invariant()) {
1341       auto compatible_types = GetWhileCompatibleTypes(
1342           while_op.input().getTypes(), while_op.output().getTypes(),
1343           while_op.body_function().getType().getInputs());
1344       return PropagateShapeToFunctions(
1345           module, compatible_types,
1346           {while_op.cond_function(), while_op.body_function()}, max_iteration);
1347     }
1348     return PropagateShapeToFunctions(
1349         module, while_op.input().getTypes(),
1350         {while_op.cond_function(), while_op.body_function()}, max_iteration);
1351   } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
1352     if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) {
1353       PropagateConstantToCallee(call_op, func, module);
1354       if (failed(PropagateShapeToFunctions(module,
1355                                            call_op.getArgOperands().getTypes(),
1356                                            {func}, max_iteration))) {
1357         return failure();
1358       }
1359       PropagateConstantFromCallee(call_op, func, module);
1360       return success();
1361     }
1362   }
1363 
1364   // TODO(ycao): Implement support for Call op, including function reuse.
1365 
1366   return success();
1367 }
1368 
1369 LogicalResult ShapeInference::PropagateShapeIntoAttachedRegions(
1370     Operation* op, int64_t max_iteration) {
1371   if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
1372     // If `shape_invariant` is set, operand shapes cannot be simply propagated
1373     // to result shapes, as the op may have different intermediate shapes (e.g.
1374     // a While op can have result shapes that differ from its operand shapes).
1375     // Compatible shapes must be determined before propagating them.
1376     if (while_op.shape_invariant()) {
1377       auto compatible_types = GetWhileCompatibleTypes(
1378           while_op.input().getTypes(), while_op.output().getTypes(),
1379           while_op.body().getArgumentTypes());
1380       return PropagateShapeToRegions(compatible_types,
1381                                      {&while_op.cond(), &while_op.body()},
1382                                      max_iteration);
1383     }
1384     return PropagateShapeToRegions(while_op.input().getTypes(),
1385                                    {&while_op.cond(), &while_op.body()},
1386                                    max_iteration);
1387   }
1388   return success();
1389 }
1390 
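// Example (illustrative sketch): for
//   %s = "tf.Shape"(%t) : (tensor<1x2xf32>) -> tensor<2xi32>
// the fold hook can produce the constant [1, 2]; TryToFold records that
// ElementsAttr for %s's ValuePort (refining %s's type if needed) rather than
// materializing a new constant op, so later value queries can reuse it.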
1391 LogicalResult ShapeInference::TryToFold(Operation* op) {
1392   LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
1393   // If any output result is known, then the op has probably been computed
1394   // before.
1395   if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
1396     return success();
1397 
1398   SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
1399   SmallVector<OpFoldResult, 8> fold_results;
1400 
1401   // Check whether any operands of the operation are constant and whether
1402   // the operation knows how to constant fold itself.
1403   bool some_unknown = false;
1404   for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
1405     if (!(constant_operands[i] =
1406               ComputeOutputComponent(ValuePort(op->getOperand(i)))))
1407       some_unknown = true;
1408   }
1409 
1410   // Attempt to constant fold the operation.
1411   auto* abstract_op = op->getAbstractOperation();
1412   LogicalResult folded = failure();
1413   if (abstract_op) {
1414     folded = abstract_op->foldHook(op, constant_operands, fold_results);
1415   }
1416   // Attempt dialect fallback if op's fold hook failed.
1417   if (failed(folded)) {
1418     Dialect* dialect = op->getDialect();
1419     if (!dialect) return failure();
1420     // Only attempt TF dialect fallback if there are no unknown operands.
1421     if (some_unknown && dialect == tf_dialect_) return failure();
1422     auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
1423     if (!interface) return failure();
1424 
1425     if (failed(interface->fold(op, constant_operands, fold_results)))
1426       return failure();
1427   }
1428 
1429   for (auto result : zip(op->getResults(), fold_results)) {
1430     auto fold_result = std::get<1>(result);
1431     Attribute attr = nullptr;
1432     if ((attr = fold_result.dyn_cast<Attribute>())) {
1433       RecordValue(ValuePort(std::get<0>(result)), attr);
1434     } else {
1435       auto value = fold_result.get<Value>();
1436       if ((attr = ComputeOutputComponent(ValuePort(value)))) {
1437         DCOMMENT("\t\tValue Result mapped to " << attr);
1438         RecordValue(ValuePort(std::get<0>(result)), attr);
1439       } else {
1440         DCOMMENT("\t\tValue result unmapped, consider value type: " << value);
1441         RefineResultType(op, std::get<0>(result), value.getType());
1442       }
1443     }
1444 
1445     if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
1446       if (std::get<0>(result).getType() == eattr.getType()) continue;
1447 
1448       UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(),
1449                                               std::get<0>(result));
1450     }
1451   }
1452 
1453   return success();
1454 }
1455 
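// Example (illustrative sketch): if the function body ends with
//   %c = "tf.Cast"(%v) : (tensor<1x2xf32>) -> tensor<*xf32>
//   return %c : tensor<*xf32>
// the cast only drops shape information, so the return is rewritten to use %v
// directly and the function result type is tightened to tensor<1x2xf32>.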
1456 void ShapeInference::InferShapeForFunctionReturnType(FuncOp func) {
1457   LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName()
1458                           << "\n");
1459 
1460   // Find any return ops.
1461   SmallVector<ReturnOp, 4> return_ops;
1462   for (Block& block : func) {
1463     if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
1464       return_ops.push_back(return_op);
1465     }
1466   }
1467 
1468   // Right now we only handle the case of a single return op.
1469   // To handle multiple return ops, we would need to look at all their shapes
1470   // and come up with a common shape and insert appropriate casts.
1471   if (return_ops.size() != 1) return;
1472 
1473   // Find the return type.
1474   auto return_op = return_ops.front();
1475 
1476   // Manually fold tf.Cast ops that precede the return instruction and differ
1477   // only in shape refinement level.
1478   bool changed = false;
1479   for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
1480     Operation* arg_defining_op = arg_op.get().getDefiningOp();
1481     if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) {
1482       Value input = cast_op.x();
1483       Value result = cast_op.y();
1484       if (!CanRefineTypeWith(result.getType(), input.getType())) continue;
1485 
1486       LLVM_DEBUG({
1487         llvm::errs() << "\tfolding & updating return type ";
1488         cast_op.getResult().getType().print(llvm::errs());
1489         cast_op.getOperand().getType().print(llvm::errs() << " to ");
1490         llvm::errs() << "\n";
1491       });
1492 
1493       // Shape inference should not change the element type.
1494       if (HasCompatibleElementTypes(input.getType(), result.getType())) {
1495         arg_op.set(input);
1496       } else {
1497         OpBuilder b(return_op.getOperation());
1498         TensorType type;
1499         if (input.getType().cast<TensorType>().hasRank()) {
1500           type = RankedTensorType::get(
1501               input.getType().cast<TensorType>().getShape(),
1502               result.getType().cast<TensorType>().getElementType());
1503         } else {
1504           type = UnrankedTensorType::get(
1505               result.getType().cast<TensorType>().getElementType());
1506         }
1507         auto new_cast_op =
1508             b.create<TF::CastOp>(return_op.getLoc(), type, input,
1509                                  /*truncate=*/b.getBoolAttr(false));
1510         arg_op.set(new_cast_op);
1511       }
1512       if (cast_op.y().use_empty()) cast_op.erase();
1513       changed = true;
1514     }
1515   }
1516 
1517   DCOMMENT("Updating function type");
1518   func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
1519                                  return_op.getOperandTypes()));
1520 
1521   if (changed) EnqueueCallers(func);
1522 }
1523 
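// Example (illustrative sketch): refining one result can enable refining its
// users, e.g. once %a is inferred to be tensor<1x2xf32>, a consumer
//   %b = "tf.Identity"(%a) : (tensor<*xf32>) -> tensor<*xf32>
// can be refined on the same or a later walk; the loop repeats until a full
// walk makes no change, and a warning is emitted if max_iteration is reached
// while changes are still being made.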
1524 LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
1525                                                       int64_t max_iteration) {
1526   bool changed = true;
1527 
1528   // TODO(aminim): we could have a more efficient traversal by guiding the
1529   // traversal with a worklist and reconsider only the nodes for which an
1530   // operand type was inferred. This would require care when working on a
1531   // region that is not isolated.
1532   for (int iteration = 0; iteration < max_iteration && changed; ++iteration) {
1533     changed = false;
1534     LLVM_DEBUG(llvm::dbgs()
1535                << "Shape inference, iteration " << iteration << "\n");
1536     region->walk([&](Operation* op) {
1537       DCOMMENT_OP(op, "Inferring for");
1538       if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
1539         DCOMMENT("\tRefining with type op interface");
1540         changed |= RefineWithInferTypeOpInterface(infer_ti);
1541         return;
1542       }
1543 
1544       if (op->getDialect() != tf_dialect_) {
1545         DCOMMENT("\tInfer non-TF dialect");
1546         changed |= InferShapeForNonTFDialectOperation(op);
1547         return;
1548       }
1549 
1550       // Before attempting inference, just try to compute the folded
1551       // value/shape.
1552       if (succeeded(TryToFold(op)) &&
1553           // Folding can "succeed" and yet not refine all result types. In
1554           // such cases we still want to try `InferShapeForSingleOperation`.
1555           none_of(op->getResultTypes(), CanBeRefined))
1556         return;
1557 
1558       // Best-effort shape inference in attached functions. Do not return
1559       // failure even if it does not reach a fixed point.
1560       if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
1561         op->emitWarning() << "unable to refine shape of attached function "
1562                              "arguments and bodies";
1563       }
1564 
1565       if (failed(PropagateShapeIntoAttachedRegions(op, max_iteration))) {
1566         op->emitWarning() << "unable to refine shape of attached region "
1567                              "arguments and bodies";
1568       }
1569 
1570       changed |= InferShapeForSingleOperation(op);
1571     });
1572   }
1573 
1574   if (changed) {
1575     return region->getParentOp()->emitWarning()
1576            << "shape inference did not reach stable state after "
1577            << max_iteration << " iterations";
1578   }
1579   return success();
1580 }
1581 
1582 static LogicalResult InferShapeForFunction(ShapeInference& context, FuncOp func,
1583                                            int64_t max_iterations) {
1584   if (failed(context.InferShapeUntilFixPoint(&func.getBody(), max_iterations)))
1585     return failure();
1586   // TODO(b/156276510): Verify that it is always fine to refine a function's
1587   // return type, as long as we do not change the argument shapes.
1588   context.InferShapeForFunctionReturnType(func);
1589 
1590   return success();
1591 }
1592 
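// Example usage (hypothetical values; `func` is assumed to have exactly one
// argument): refine that argument to a static shape before running inference:
//   llvm::SmallVector<int64_t, 4> shape = {1, 224, 224, 3};
//   llvm::SmallVector<llvm::ArrayRef<int64_t>, 1> arg_shapes = {shape};
//   if (failed(InferShapeForFunction(func, arg_shapes, /*graph_version=*/0,
//                                    /*max_iterations=*/10)))
//     return failure();
// Passing an empty arg_shapes list runs inference on the existing argument
// types without modifying them.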
1593 LogicalResult InferShapeForFunction(FuncOp func,
1594                                     ArrayRef<ArrayRef<int64_t>> arg_shapes,
1595                                     int64_t graph_version,
1596                                     int64_t max_iterations) {
1597   ShapeInference context(graph_version, func.getContext(),
1598                          /*propagate_caller_callee_constants=*/true);
1599   if (arg_shapes.empty()) {
1600     return InferShapeForFunction(context, func, max_iterations);
1601   }
1602 
1603   FunctionType func_type = func.getType();
1604   bool needs_refinement = false;
1605   SmallVector<Type, 4> new_arg_types;
1606   new_arg_types.reserve(func_type.getNumInputs());
1607 
1608   // Update argument types in-place using the provided arg_shapes.
1609   for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
1610     ArrayRef<int64_t> shape = arg_shapes[i];
1611     Type element_type;
1612     if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
1613       if (input_ty.getRank() != shape.size()) {
1614         return failure();
1615       }
1616       element_type = input_ty.getElementType();
1617     } else {
1618       auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
1619       if (!unranked_input_ty) {
1620         return failure();
1621       }
1622       element_type = unranked_input_ty.getElementType();
1623     }
1624 
1625     auto new_arg_type = RankedTensorType::get(shape, element_type);
1626     if (new_arg_type != func_type.getInput(i)) {
1627       // If the new type is more detailed, trigger shape inference.
1628       func.getArgument(i).setType(new_arg_type);
1629       needs_refinement = true;
1630     }
1631     new_arg_types.push_back(new_arg_type);
1632   }
1633 
1634   if (!needs_refinement) return success();
1635 
1636   if (failed(context.InferShapeUntilFixPoint(&func.getBody(), max_iterations)))
1637     return failure();
1638 
1639   context.InferShapeForFunctionReturnType(func);
1640   func.setType(FunctionType::get(func.getContext(), new_arg_types,
1641                                  func.getType().getResults()));
1642 
1643   return success();
1644 }
1645 
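// Example usage (illustrative): the shape inference pass typically invokes
// this on the whole module, e.g.
//   if (failed(mlir::TF::InferModuleShape(module, /*max_iterations=*/10)))
//     signalPassFailure();
// The graph producer version is read from the module's attributes, so no
// per-argument shapes are supplied at this level.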
1646 LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
1647   auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
1648   if (!producer_or.ok()) {
1649     // TODO(jpienaar): Keeping the existing behavior for now but this could
1650     // be relaxed.
1651     LLVM_DEBUG(llvm::dbgs()
1652                << "Skipping inference; " << producer_or.status().ToString());
1653     return success();
1654   }
1655   int64_t producer = producer_or.ValueOrDie();
1656   // TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
1657   // longer needed.
1658   ShapeInference context(producer, module.getContext(),
1659                          /*propagate_caller_callee_constants=*/false);
1660   if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
1661     context.enqueue(main);
1662   for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
1663   // Arbitrarily upper bound the maximum number of functions that get processed
1664   // just to avoid pathological cases.
1665   auto max_iteration = context.QueueSize() * 4;
1666   while (!context.EmptyQueue()) {
1667     FuncOp func = context.front();
1668     if (failed(InferShapeForFunction(context, func, max_iterations)))
1669       return failure();
1670     context.pop_front();
1671 
1672     if ((--max_iteration) == 0) {
1673       return emitWarning(UnknownLoc::get(module.getContext()))
1674              << "shape inference did not reach stable state after "
1675              << max_iteration << " iterations";
1676     }
1677   }
1678   return success();
1679 }
1680 
1681 }  // namespace TF
1682 }  // namespace mlir
1683