1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <initializer_list>
21 #include <iterator>
22 #include <queue>
23 #include <stack>
24 
25 #include "llvm/ADT/Hashing.h"
26 #include "llvm/ADT/None.h"
27 #include "llvm/ADT/PointerUnion.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/iterator_range.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
35 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
36 #include "mlir/IR/Attributes.h"  // from @llvm-project
37 #include "mlir/IR/Block.h"  // from @llvm-project
38 #include "mlir/IR/Builders.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinDialect.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
41 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
42 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
43 #include "mlir/IR/Location.h"  // from @llvm-project
44 #include "mlir/IR/Operation.h"  // from @llvm-project
45 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
46 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
47 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
48 #include "mlir/IR/Value.h"  // from @llvm-project
49 #include "mlir/IR/Visitors.h"  // from @llvm-project
50 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
51 #include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
52 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
53 #include "mlir/Pass/Pass.h"  // from @llvm-project
54 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
55 #include "mlir/Support/LLVM.h"  // from @llvm-project
56 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
57 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
58 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
59 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
60 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
61 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
62 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
63 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
64 #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
65 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
66 #include "tensorflow/core/framework/shape_inference.h"
67 #include "tensorflow/core/framework/types.pb.h"
68 #include "tensorflow/core/ir/types/dialect.h"
69 
70 #define DEBUG_TYPE "tf-shape-inference"
71 
72 #define DCOMMENT(MSG) LLVM_DEBUG(llvm::dbgs() << MSG << "\n")
73 #define DCOMMENT_OP(OP, MSG) \
74   LLVM_DEBUG(OP->print(llvm::dbgs() << MSG << " "); llvm::dbgs() << "\n")
75 
76 using ::tensorflow::int64;
77 using tensorflow::shape_inference::DimensionHandle;
78 using tensorflow::shape_inference::InferenceContext;
79 using tensorflow::shape_inference::ShapeHandle;
80 
81 namespace mlir {
82 namespace TF {
83 namespace {
84 
85 // Computes a refined type between two types `lhs` and `rhs`: the result type
86 // is always at least as refined (i.e. has at least as much static information)
87 // as `lhs`. This method actually merges the information contained in the two
88 // types; it is capable of refining:
89 //   tensor<!tf_type.variant<tensor<?x8xf32>>>
90 // and:
91 //   tensor<!tf_type.variant<tensor<10x?xf32>>>
92 // into:
93 //   tensor<!tf_type.variant<tensor<10x8xf32>>>
94 //
95 // In case of inconsistencies (rank disagreement for example), it returns `lhs`.
96 Type TypeMeet(Type lhs, Type rhs) {
97   DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs);
98   if (lhs == rhs) return lhs;
99 
100   auto rhs_shape_type = rhs.dyn_cast<ShapedType>();
101   if (!rhs_shape_type) return lhs;
102   auto lhs_shape_type = lhs.cast<ShapedType>();
103   if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() &&
104       lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
105     DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs);
106     return lhs;
107   }
108 
109   SmallVector<int64_t> shape;
110   bool refined_shape = false;
111   // Build the shape of the refined type: if lhs is unranked, the rhs shape is
112   // used directly as the refined shape; otherwise we merge the two shapes by
113   // taking the most specialized dimension for each entry. This combines
114   // `10x?x?` and `?x?x8` into `10x?x8`.
115   if (!lhs_shape_type.hasRank()) {
116     if (rhs_shape_type.hasRank()) {
117       shape.append(rhs_shape_type.getShape().begin(),
118                    rhs_shape_type.getShape().end());
119       refined_shape = true;
120     }
121   } else if (rhs_shape_type.hasRank()) {
122     for (auto shape_elts : llvm::enumerate(
123              llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) {
124       if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) &&
125           !ShapedType::isDynamic(std::get<1>(shape_elts.value()))) {
126         shape.push_back(std::get<1>(shape_elts.value()));
127         refined_shape = true;
128         DCOMMENT("-> refining shape element #" << shape_elts.index());
129       } else {
130         DCOMMENT("-> not refining shape element #" << shape_elts.index());
131         shape.push_back(std::get<0>(shape_elts.value()));
132       }
133     }
134   }
135 
136 // Some tensors have an element type wrapping a subtensor, like resources and
137 // variants. In this case we may recurse on the wrapped subtype.
138   // `element_type` will contain the refined inferred element type for the
139   // returned type.
140   auto lhs_element_type = lhs_shape_type.getElementType();
141   auto rhs_element_type_with_subtype =
142       rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
143 // Look for a resource or variant element type and ensure we refine the
144 // subtype. We only support a single subtype at the moment; we won't handle
145 // something like:
146 //   tensor<!tf_type.variant<tensor<10xf32>, tensor<8xf32>>>
147   if (rhs_element_type_with_subtype &&
148       rhs_element_type_with_subtype.GetSubtypes().size() == 1) {
149     auto lhs_element_type_with_subtype =
150         lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
151     TensorType subtype;
152     if (!lhs_element_type_with_subtype) {
153       DCOMMENT(
154           "Unexpected inferred `TensorFlowTypeWithSubtype` when original "
155           "result isn't");
156     } else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) {
157       DCOMMENT(
158           "Unexpected `TensorFlowTypeWithSubtype` original type with size>1");
159     } else if (lhs_element_type_with_subtype.GetSubtypes().empty()) {
160       subtype = rhs_element_type_with_subtype.GetSubtypes().front();
161     } else {
162       // Recurse on the subtypes in the variant/resource. Basically if the input
163       // were:
164       //   tensor<!tf_type.variant<tensor<?x8xf32>>>
165       // and:
166       //   tensor<!tf_type.variant<tensor<10x8xf32>>>
167       // we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>.
168       auto refined_subtype =
169           TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(),
170                    rhs_element_type_with_subtype.GetSubtypes().front())
171               .cast<TensorType>();
172       if (refined_subtype !=
173           lhs_element_type_with_subtype.GetSubtypes().front())
174         subtype = refined_subtype;
175     }
176     // If we managed to refine the subtype, recreate the element type itself
177     // (i.e. the tf.variant or tf.resource).
178     if (subtype) {
179       lhs_element_type = lhs_element_type_with_subtype.clone({subtype});
180     }
181   }
182   if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) {
183     Type new_type;
184     if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank())
185       new_type = UnrankedTensorType::get(lhs_element_type);
186     else
187       new_type = lhs_shape_type.clone(shape, lhs_element_type);
188     DCOMMENT("Refined to: " << new_type);
189     return new_type;
190   }
191   DCOMMENT("No refinement " << lhs);
192   return lhs;
193 }
194 
195 // Returns whether the `original_type` can be refined with the
196 // `potential_refined_type`.
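// For example (an illustrative case): tensor<?x8xf32> can be refined with
// tensor<10x8xf32> (their meet is tensor<10x8xf32>), whereas the reverse
// direction returns false because no static information would be added.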
197 bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
198   return original_type != TypeMeet(original_type, potential_refined_type);
199 }
200 
201 // Returns if the shape inference pass supports an op outside the TF dialect.
202 bool IsSupportedNonTFOp(Operation* op) {
203   return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp,
204              tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp,
205              tf_executor::GraphOp, tf_executor::IslandOp,
206              tf_executor::LoopCondOp, tf_executor::MergeOp,
207              tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
208              tf_executor::SwitchOp, tf_executor::YieldOp>(op) ||
209          isa<InferTypeOpInterface>(op);
210 }
211 
212 // Returns whether a cast back would need to be inserted, i.e., whether the
213 // operation that owns `use` does not allow its operand's shape to be refined
214 // without a cast.
215 bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
216   return use.getOwner()->getDialect() != tf_dialect &&
217          !IsSupportedNonTFOp(use.getOwner());
218 }
219 
220 TensorType CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape,
221                             Type element_type) {
222   if (shape.hasValue())
223     return RankedTensorType::get(shape.getValue(), element_type);
224   return UnrankedTensorType::get(element_type);
225 }
226 
227 // Returns true if the op creates a TensorList.
228 bool IsTensorListInitOp(Operation* op) {
229   return isa<TensorListReserveOp>(op) || isa<EmptyTensorListOp>(op) ||
230          isa<TensorListFromTensorOp>(op);
231 }
232 
233 // Returns the `element_shape` operand of the ops that create a TensorList.
234 Value GetElementShapeOperand(Operation* op) {
235   if (auto empty_tl = dyn_cast<EmptyTensorListOp>(op))
236     return empty_tl.element_shape();
237   if (auto tl_reserve = dyn_cast<TensorListReserveOp>(op))
238     return tl_reserve.element_shape();
239   if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op))
240     return tl_from_tensor.element_shape();
241   llvm_unreachable("unsupported TensorList op");
242 }
243 
244 // Utility function to create a ranked tensor type after dropping the first
245 // dimension from the input type.
246 RankedTensorType DropFirstDimension(Type type) {
247   RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>();
248   if (!ranked_type) return {};
249   llvm::ArrayRef<int64_t> dims_except_first =
250       ranked_type.getShape().drop_front();
251   return RankedTensorType::get(dims_except_first, ranked_type.getElementType());
252 }
253 
254 Operation* InsertCast(OpBuilder& b, Location loc, Type dst_type, Value input) {
255   Type element_type = getElementTypeOrSelf(dst_type);
256   if (element_type.isa<IndexType>())
257     return b.create<tensor::CastOp>(loc, dst_type, input);
258   if (isa<TensorFlowDialect, BuiltinDialect>(element_type.getDialect()))
259     return b.create<TF::CastOp>(loc, dst_type, input,
260                                 /*truncate=*/b.getBoolAttr(false));
261   return nullptr;
262 }
263 
264 // Follows the use chain of a TensorList and returns true iff all elements
265 // written to the TensorList have the same static shape. If they do, that shape
266 // is assigned to `potential_element_type`.
267 //
268 // This can handle multiple mutations of a TensorList object and returns true
269 // only if, across all mutations, the elements written have the same shape.
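// For example (an illustrative sketch, not IR from an actual test), a chain
// such as:
//   %tl  = "tf.TensorListReserve"(%element_shape, %num_elements)
//   %tl1 = "tf.TensorListSetItem"(%tl, %index, %item0)   // %item0: tensor<8xf32>
//   %tl2 = "tf.TensorListPushBack"(%tl1, %item1)         // %item1: tensor<8xf32>
// lets this function infer tensor<8xf32> as the element type, whereas writing a
// tensor<4xf32> anywhere in the chain makes it return false.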
270 bool CanInferTensorListElementType(Value tensorlist,
271                                    Value initial_element_shape,
272                                    RankedTensorType* potential_element_type) {
273   DCOMMENT("CanInferTensorListElementType " << tensorlist << " with initial "
274                                             << initial_element_shape);
275   // Verifies that the new element type has a static shape and matches the
276   // potential type passed from the caller. Updates `potential_element_type` if
277   // it is not defined yet.
278   auto verify_and_update_potential_element_type =
279       [&](RankedTensorType new_element_type) -> bool {
280     DCOMMENT("\t\tConsidering " << new_element_type << " with old "
281                                 << *potential_element_type);
282     if (!new_element_type || !new_element_type.hasStaticShape()) return false;
283     if (!*potential_element_type) {
284       DCOMMENT("\t\tUpdating potential_element_type " << new_element_type);
285       *potential_element_type = new_element_type;
286       return true;
287     }
288     return *potential_element_type == new_element_type;
289   };
290 
291   std::stack<Value> worklist;
292   worklist.emplace(tensorlist);
293 
294   while (!worklist.empty()) {
295     tensorlist = worklist.top();
296     worklist.pop();
297 
298     // TensorLists are semantically immutable. For example, TensorListSetItem
299     // takes a TensorList as input and produces a TensorList as output. So to
300     // traverse modifications to a TensorList and verify that all elements
301     // written to it have the same shape, we need to follow the use-def chain of
302     // ops that (conceptually) modify it, i.e. ops that take an input TensorList
303     // and produce an output TensorList.
304     for (auto& use : tensorlist.getUses()) {
305       if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
306         auto element_type =
307             push.tensor().getType().dyn_cast<RankedTensorType>();
308         if (!verify_and_update_potential_element_type(element_type))
309           return false;
310         worklist.emplace(push.output_handle());
311         continue;
312       }
313       if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
314               use.getOwner())) {
315         // For scatter op we can get the element shape by dropping the first
316         // dimension of the input tensor.
317         RankedTensorType element_type =
318             DropFirstDimension(scatter.tensor().getType());
319         if (!verify_and_update_potential_element_type(element_type))
320           return false;
321         worklist.emplace(scatter.output_handle());
322         continue;
323       }
324       if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
325         auto element_type =
326             set_item.item().getType().dyn_cast<RankedTensorType>();
327         DCOMMENT("\tTensorListSetItemOp " << element_type);
328         if (!verify_and_update_potential_element_type(element_type))
329           return false;
330         worklist.emplace(set_item.output_handle());
331         continue;
332       }
333       if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
334         worklist.emplace(pop.output_handle());
335         continue;
336       }
337       if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
338         worklist.emplace(resize.output_handle());
339         continue;
340       }
341       // WhileRegionOp can explicitly capture a TensorList value to be used
342       // inside its regions. So we check the uses of the corresponding block
343       // argument in each region and the use of the TensorList returned by YieldOp.
344       if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
345         DCOMMENT("\tTL WhileRegion");
346         for (auto branch : while_region.getRegions())
347           worklist.emplace(branch->getArgument(use.getOperandNumber()));
348         continue;
349       }
350       if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
351         Operation* parent = yield->getParentOp();
352         worklist.emplace(parent->getResult(use.getOperandNumber()));
353         continue;
354       }
355       // TODO(jpienaar): This can be generalized.
356       if (isa<IdentityOp, IdentityNOp, StopGradientOp>(use.getOwner())) {
357         worklist.emplace(use.getOwner()->getResult(use.getOperandNumber()));
358         continue;
359       }
360       // Refining the tensor list element type might change the output of
361       // TensorListElementShape, which is expected to be the shape originally
362       // assigned to the TensorList init op. So replace it with the original
363       // element shape value.
364       if (auto tl_element_shape =
365               dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
366         // If element types match, we can do a direct replacement.
367         if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
368             getElementTypeOrSelf(initial_element_shape.getType())) {
369           tl_element_shape.replaceAllUsesWith(initial_element_shape);
370         } else {
371           OpBuilder b(use.getOwner());
372           Operation* cast_op = InsertCast(
373               b, use.getOwner()->getLoc(),
374               tl_element_shape.getResult().getType(), initial_element_shape);
375           if (!cast_op) return false;
376           tl_element_shape.replaceAllUsesWith(cast_op->getResult(0));
377         }
378         continue;
379       }
380       // Ignore ops that just consume a TensorList and do not output another
381       // TensorList.
382       if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
383               TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
384         continue;
385 
386       // For any other unknown users of the TensorList, we are conservative and
387       // stop element shape inference.
388       DCOMMENT("TensorListType infer, unknown op " << *use.getOwner());
389       return false;
390     }
391   }
392   return true;
393 }
394 }  // namespace
395 
396 // Returns whether type can be further refined.
397 bool CanBeRefined(Type type) {
398   auto shape_type = type.dyn_cast<ShapedType>();
399   if (!shape_type) return false;
400 
401   // Returns whether type with subtypes can be further refined.
402   auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) {
403     return tws.GetSubtypes().empty() ||
404            llvm::any_of(tws.GetSubtypes(), CanBeRefined);
405   };
406   auto type_with_subtype =
407       shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
408   if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;
409 
410   return !shape_type.hasStaticShape();
411 }
412 
413 // Combination of value producer and port of value produced (e.g.,
414 //   <value result output>:<value in output tensor>,
415 // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
416 // scalar value).
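// As an illustrative example: for a hypothetical %r = tf.Pack(%a, %b) producing
// a tensor<2xi32>, a ValuePort on %r with port = {0, 1} denotes the second
// element of result 0, i.e. the scalar contributed by %b (see the PackOp
// handling in ComputeOutputComponent below).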
417 struct ValuePort {
418   PointerUnion<Operation*, BlockArgument> producer;
419   SmallVector<unsigned int, 2> port;
420 
421   bool operator==(const ValuePort& other) const {
422     return producer == other.producer && port == other.port;
423   }
424 
425   // Convert output value to ValuePort.
426   explicit ValuePort(Value v) {
427     OpResult opr = v.dyn_cast<OpResult>();
428     if (opr) {
429       producer = opr.getOwner();
430       port = {opr.getResultNumber()};
431     } else {
432       producer = v.cast<BlockArgument>();
433       port = {0};
434     }
435   }
436   ValuePort(PointerUnion<Operation*, BlockArgument> producer,
437             SmallVector<unsigned int, 2> port)
438       : producer(producer), port(port) {}
439 
440   raw_ostream& print(raw_ostream& os) const {
441     if (auto* op = producer.dyn_cast<Operation*>())
442       os << "op " << op->getName();
443     if (auto ba = producer.dyn_cast<BlockArgument>())
444       os << "block_arg " << ba.getArgNumber();
445     os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
446     return os;
447   }
448 };
449 
450 struct ValuePortHasher {
451   std::size_t operator()(const ValuePort& other) const {
452     return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
453                         hash_value(ArrayRef<unsigned int>(other.port)));
454   }
455 };
456 
457 using ValuePortResultMap =
458     std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
459 using ComputedQueryFn = function_ref<bool(ValuePort)>;
460 using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
461 using ValuePortInputs = SmallVectorImpl<ValuePort>;
462 
463 // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
464 // intended to be switched to op interfaces once more refined.
465 LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
466                                              ComputedQueryFn has_been_computed,
467                                              ValuePortInputs* inputs) {
468   auto op = value_port.producer.dyn_cast<Operation*>();
469   auto& port = value_port.port;
470   if (!op) return failure();
471 
472   // No inputs required for constants.
473   if (matchPattern(op, m_Constant())) return success();
474 
475   // Note: this focuses only on the trivial pack op case and could be
476   // generalized.
477   if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
478     auto type = pack_op.getType().cast<TensorType>();
479     if (!type.hasRank() || type.getRank() != 1) return failure();
480     if (port.size() != 2) return failure();
481     assert(port[0] == 0);
482     ValuePort req(pack_op.getOperand(port[1]));
483     if (!has_been_computed(req)) inputs->push_back(req);
484     return success();
485   }
486 
487   return failure();
488 }
489 
490 // Computes the output produced by ValuePort using the query function of
491 // existing computed values.
492 Attribute ComputeOutputComponent(const ValuePort& value_port,
493                                  ValueQueryFn values) {
494   LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
495   if (auto known = values(value_port)) return known;
496 
497   auto op = value_port.producer.dyn_cast<Operation*>();
498   if (!op) return nullptr;
499   auto& port = value_port.port;
500 
501   if (port.empty()) {
502     LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n");
503     return nullptr;
504   }
505 
506   ElementsAttr attr;
507   if (matchPattern(op, m_Constant(&attr))) {
508     if (port.size() == 1 && port[0] == 0) return attr;
509     return nullptr;
510   }
511 
512   if (auto id = dyn_cast<IdentityOp>(op)) {
513     if (port.size() == 1 && port[0] == 0)
514       return ComputeOutputComponent(ValuePort(id.input()), values);
515     return nullptr;
516   }
517 
518   // Note: this focuses only on the trivial pack op case and could be
519   // generalized.
520   if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
521     TensorType type = pack_op.getType().cast<TensorType>();
522     if (!type.hasRank() || type.getRank() != 1) return nullptr;
523     if (port.size() != 2 || port[0] != 0) return nullptr;
524     ValuePort op_port(op->getOperand(port[1]));
525     return values(op_port);
526   }
527 
528   if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
529     if (port.size() == 1)
530       return ComputeOutputComponent(
531           ValuePort(graph.GetFetch().fetches()[port[0]]), values);
532     return nullptr;
533   }
534 
535   if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
536     if (port.size() == 1)
537       return ComputeOutputComponent(
538           ValuePort(island.GetYield().fetches()[port[0]]), values);
539     return nullptr;
540   }
541 
542   return nullptr;
543 }
544 
545 // Context used during ShapeInference. This class contains common information
546 // that is required by the individual shape inference helper functions (e.g.,
547 // TF Graph version, constant values computed, etc.)
548 class ShapeInference {
549  public:
550   ShapeInference(int64_t graph_version, ModuleOp module,
551                  bool propagate_caller_callee_constants);
552 
553   LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
554                                                ValuePortInputs* inputs) {
555     return ::mlir::TF::ComputeInputsRequiredForOutput(
556         value_port,
557         [this](const ValuePort& port) {
558           return results_.find(port) != results_.end();
559         },
560         inputs);
561   }
562 
563   Attribute ComputeOutputComponent(const ValuePort& value_port) {
564     if (auto known_attr = results_[value_port]) return known_attr;
565     auto attr = ::mlir::TF::ComputeOutputComponent(
566         value_port, [this](const ValuePort& port) { return results_[port]; });
567     RecordValue(value_port, attr);
568     return attr;
569   }
570 
571   // Returns ShapeHandle if the op result could be computed as shape.
572   ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
573 
574   void RecordValue(const ValuePort& value_port, Attribute value) {
575     LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
576                << value << "\n");
577     results_[value_port] = value;
578   }
579 
580   // Infers shape of tf.While/tf.WhileRegion. If the `shape_invariant`
581   // attribute is set, operand types are used as result types when the
582   // associated body result type matches the operand type (i.e. it does not
583   // change per loop iteration). If operand and body result types differ, only
584   // handle types are propagated to result types. This is necessary to avoid
585   // incorrectly changing result shapes when the While op has a different
586   // result shape. Otherwise, operand shapes are propagated to result shapes.
587   template <typename WhileOpTy>
588   bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types);
589 
590   // Performs shape inference on the provided op and returns true if the type
591   // of at least one result has been changed.
592   // A tf.Cast() is inserted for any use that isn't in the TensorFlow dialect.
593   // `graph_version` indicates the current GraphDef compatibility version
594   // (the versions field in graph.proto).
595   bool InferShapeForSingleOperation(Operation* op);
596 
597   // Infers shapes in the provided region, including nested regions, iterating
598   // until a fixed point is reached or `max_iterations` is hit.
599   // Returns failure() on error; otherwise returns true if it reached
600   // convergence and false otherwise.
601   FailureOr<bool> InferShapeUntilFixPoint(Region* region,
602                                           int64_t max_iterations);
603 
604   // Updates input types and refines shapes inside the bodies of functions that
605   // are attached to ControlFlow ops (If/While) or Calls. These functions include
606   // the Then/Else branches of IfOp and the Cond/Body functions of WhileOp.
607   // Functions attached to control flow share the following common properties:
608   //   1) They are never reused, i.e. each has a single use in the module.
609   //   2) Their input types match those of their parent ops (excluding inputs
610   //      like predicate).
611   // For calls, functions can be reused across multiple call sites. In this case
612   // we propagate the types when all call sites have the same operand types.
613   // Returns failure() on error; otherwise returns true if it reached
614   // convergence and false otherwise.
615   FailureOr<bool> PropagateShapeToFunctions(ModuleOp module,
616                                             TypeRange input_types,
617                                             ArrayRef<FuncOp> functions,
618                                             int64_t max_iteration);
619 
620   // Propagates shapes to regions given the shapes of the inputs of the regions.
621   // All regions provided in `regions` are assumed to have inputs of type
622   // `input_types`.
623   // Returns failure() on error; otherwise returns true if it reached
624   // convergence and false otherwise.
625   FailureOr<bool> PropagateShapeToRegions(TypeRange input_types,
626                                           ArrayRef<Region*> regions,
627                                           int64_t max_iteration);
628 
629   // Shape propagation for call/control flow ops.
630   // Returns failure() on error; otherwise returns true if it reached
631   // convergence and false otherwise.
632   FailureOr<bool> PropagateShapeIntoAttachedFunctions(Operation* op,
633                                                       int64_t max_iteration);
634 
635   // Shape propagation for region based control flow.
636   // Returns failure() on error; otherwise returns true if it reached
637   // convergence and false otherwise.
638   FailureOr<bool> PropagateShapeIntoAttachedRegions(Operation* op,
639                                                     int64_t max_iterations);
640 
641   // Propagates any constant operand of call_op to the called function body's
642   // corresponding argument if the callee has only one use.
643   //
644   // TODO(b/154065712): Move this to a more general inter-procedural constant
645   // folding pass.
646   void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func,
647                                  ModuleOp module);
648 
649   // Propagates any constant return value of the callee function to the call
650   // op's corresponding result.
651   void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func,
652                                    ModuleOp module);
653 
654   // Tries to compute the result of folding the op. This doesn't actually
655   // perform constant folding; it just computes the equivalent constants.
656   // Returns whether it was able to compute constant values.
657   LogicalResult TryToFold(Operation* op);
658 
659   // Makes result types match the operand types (the i-th result type will
660   // match the i-th operand type). Returns true if anything is changed.
661   bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands,
662                                         ResultRange results);
663 
664   // Makes result type's shape match the corresponding operand's shape.
665   // Returns whether any change was made.
666   bool RefineShapeForPassThroughOps(Operation* op);
667 
668   // Infers shape for necessary ops that are not in the TF dialect. Returns
669   // whether any result type changed.
670   bool InferShapeForNonTFDialectOperation(Operation* op);
671 
672   // Infers shape for function return type and returns whether changed.
673   LogicalResult InferShapeForFunctionReturnType(FuncOp func);
674 
675   // Enqueues function for processing.
676   void enqueue(FuncOp fn) {
677     LLVM_DEBUG(llvm::dbgs()
678                << "enqueue " << fn.getName() << " ("
679                << (queue_set_.count(fn) ? "already inserted" : "newly inserted")
680                << ")\n");
681     if (queue_set_.insert(fn).second) queue_.push(fn);
682   }
683 
684   // Enqueues callers on functions.
685   void EnqueueCallers(FuncOp fn);
686 
687   // Returns the function at the front of the queue.
688   FuncOp front() { return queue_.front(); }
689 
690   // Returns whether work queue is empty.
691   bool EmptyQueue() const { return queue_.empty(); }
692 
693   // Returns function from the front of the work queue.
694   FuncOp pop_front() {
695     FuncOp ret = queue_.front();
696     queue_.pop();
697     queue_set_.erase(ret);
698     return ret;
699   }
700 
701   // Returns the current size of the queue.
702   std::queue<FuncOp>::size_type QueueSize() const { return queue_.size(); }
703 
704   Dialect* const tf_dialect_;
705 
706  private:
707   // Returns whether the result of an operation could be updated to a new
708   // inferred type. Also inserts cast operations for uses that are incompatible
709   // with the new type.
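  // For example (illustrative): if a result of type tensor<?xf32> is refined to
  // tensor<4xf32>, any use owned by an op that cannot accept the refined type is
  // rewired through a tf.Cast from tensor<4xf32> back to tensor<?xf32>, so only
  // TF-dialect (and supported non-TF) users observe the new type.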
710   bool UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, Value result);
711 
712   // Refines the type of `result` of `op` using the type
713   // `potential_refined_type`. Returns true if the type was changed.
714   bool RefineResultType(Operation* op, Value result,
715                         Type potential_refined_type);
716 
717   // Infers the shape from a (Stateful)PartitionedCall operation by looking up
718   // the called function and propagating the return type.
719   bool InferShapeForCall(CallOpInterface call_op);
720 
721   bool InferShapeForCast(Operation* op);
722 
723   // Infers the shapes of IfOp outputs based on the shapes of the then and else
724   // function result types.
725   bool InferShapeForIf(IfOp op);
726 
727   // Infers the shapes of IfRegion outputs based on the shapes of the then and
728   // else yields.
729   bool InferShapeForIfRegion(IfRegionOp op);
730 
731   // Infers the shape of _XlaHostComputeMlir based on the host computation
732   // module.  Returns true if a return type was changed.
733   bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op);
734 
735   // Infers the shape for MapDatasetOp and its associated function. Returns
736   // whether either op or function type was changed.
737   bool InferShapeForMapDataset(MapDatasetOp op);
738 
739   // Infers the shape of ops that create TensorList. Specifically,
740   // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
741   // refines the element shape if all tensors written to the list across all
742   // mutations have identical static shape.
743   bool InferShapeForTensorListInitOps(Operation* op);
744 
745   bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti);
746 
747   // Returns all the callers of a function.
748   // Note: Usage of the return value of this function may not be interleaved
749   // with insertions into the callers map. This could occur if GetCallers is
750   // called with two separate functions, the second call incurs a resize, and
751   // then the callers stored from both calls are used.
752   ArrayRef<Operation*> GetCallers(FuncOp fn);
753 
754   // Mapping between ValuePort (which corresponds to an OpResult or smaller,
755   // e.g., first element of OpResult produced) to an Attribute if the ValuePort
756   // corresponds to a constant value.
757   ValuePortResultMap results_;
758 
759   // Map from a function to the callers of that function.
760   SymbolTableCollection symbol_table_;
761   SymbolUserMap symbol_users_;
762 
763   // Queue of functions being processed.
764   llvm::DenseSet<FuncOp> queue_set_;
765   std::queue<FuncOp> queue_;
766 
767   int64_t graph_version_;
768 
769   // TODO(b/154065712): Remove propagate_caller_callee_constants once using
770   // SCCP pass instead.
771   bool propagate_caller_callee_constants_;
772 };
773 
774 ShapeInference::ShapeInference(int64_t graph_version, ModuleOp module,
775                                bool propagate_caller_callee_constants)
776     : tf_dialect_(module->getContext()->getLoadedDialect<TensorFlowDialect>()),
777       symbol_users_(symbol_table_, module),
778       graph_version_(graph_version),
779       propagate_caller_callee_constants_(propagate_caller_callee_constants) {}
780 
781 ArrayRef<Operation*> ShapeInference::GetCallers(FuncOp fn) {
782   return symbol_users_.getUsers(fn);
783 }
784 
785 void ShapeInference::EnqueueCallers(FuncOp fn) {
786   for (auto user : GetCallers(fn)) enqueue(user->getParentOfType<FuncOp>());
787 }
788 
789 bool ShapeInference::UpdateTypeAndInsertIncompatibleUseCasts(Type new_type,
790                                                              Value result) {
791   Operation* cast_op = nullptr;
792   // First insert cast back for uses that need a cast and then
793   // update the type.
794   bool enqueue_callers = false;
795   for (OpOperand& use : make_early_inc_range(result.getUses())) {
796     if (isa<ReturnOp>(use.getOwner())) {
797       enqueue_callers = true;
798     } else if (NeedsCastBack(use, tf_dialect_)) {
799       if (!cast_op) {
800         Operation* op = result.getDefiningOp();
801         OpBuilder b(op);
802         b.setInsertionPointAfter(op);
803         cast_op = InsertCast(b, op->getLoc(), result.getType(), result);
804         if (!cast_op) return false;
805       }
806       use.set(Value(cast_op->getResult(0)));
807     }
808   }
809 
810   result.setType(new_type);
811   if (enqueue_callers)
812     EnqueueCallers(result.getDefiningOp()->getParentOfType<FuncOp>());
813   return true;
814 }
815 
816 bool ShapeInference::RefineResultType(Operation* op, Value result,
817                                       Type potential_refined_type) {
818   if (!CanRefineTypeWith(result.getType(), potential_refined_type))
819     return false;
820 
821   return UpdateTypeAndInsertIncompatibleUseCasts(potential_refined_type,
822                                                  result);
823 }
824 
825 // Infers the shape from a (Stateful)PartitionedCall operation by looking up
826 // the called function and propagating the return type.
827 bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
828   FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
829   if (!func) return false;
830 
831   DCOMMENT("Infer shape for call " << func.getName());
832   Operation* op = call_op.getOperation();
833   bool changed = false;
834   // Map each of the results of the call to the returned type of the
835   // function.
836   for (auto result : zip(op->getResults(), func.getType().getResults())) {
837     changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
838               changed;
839   }
840   DCOMMENT(" - call " << func.getName() << " changed ? " << changed << "\n");
841 
842   return changed;
843 }
844 
845 bool ShapeInference::InferShapeForCast(Operation* op) {
846   DCOMMENT_OP(op, "Inferring shape for ");
847   Value result = op->getResult(0);
848   if (!CanBeRefined(result.getType())) return false;
849 
850   Type operand_type = op->getOperand(0).getType();
851   auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>();
852   if (!ranked_op_type) return false;
853   auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
854   if (ranked_res_type &&
855       ranked_op_type.getShape() == ranked_res_type.getShape())
856     return false;
857 
858   // Avoid inserting a cast where no user's type could be refined (e.g., where
859   // a cast would need to be inserted for every user again).
860   if (llvm::all_of(result.getUses(), [this](OpOperand& use) {
861         return NeedsCastBack(use, tf_dialect_);
862       }))
863     return false;
864 
865   auto new_type = RankedTensorType::get(
866       ranked_op_type.getShape(),
867       result.getType().cast<ShapedType>().getElementType());
868 
869   return UpdateTypeAndInsertIncompatibleUseCasts(new_type, op->getResult(0));
870 }
871 
872 bool ShapeInference::InferShapeForIf(IfOp op) {
873   DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
874   bool changed = false;
875   auto then_results = op.then_function().getType().getResults();
876   auto else_results = op.else_function().getType().getResults();
877   for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
878     // If then and else types do not match, skip refinement for that result.
879     if (std::get<1>(it) != std::get<2>(it)) continue;
880     changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
881   }
882   return changed;
883 }
884 
885 bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {
886   bool changed = false;
887 
888   Operation* then_yield = op.then_branch().front().getTerminator();
889   Operation* else_yield = op.else_branch().front().getTerminator();
890   for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
891                          else_yield->getOperandTypes())) {
892     // If then and else types do not match, skip refinement for that result.
893     if (std::get<1>(result) != std::get<2>(result)) continue;
894     changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
895               changed;
896   }
897   return changed;
898 }
899 
900 bool ShapeInference::InferShapeForXlaHostComputeMlir(
901     _XlaHostComputeMlirOp host_compute_op) {
902   // Extract the module and function.
903   // The `_XlaHostComputeMlir` verifier verifies that `host_mlir_module`
904   // attribute is well formed, so we just return in case of an error in
905   // extracting the host function since it should never occur.
906   StringAttr host_module =
907       host_compute_op->getAttrOfType<StringAttr>("host_mlir_module");
908   if (host_module.getValue().empty()) return false;
909 
910   mlir::OwningModuleRef module_for_func;
911   FuncOp func = host_compute_op.GetHostFunc(&module_for_func);
912 
913   // Update/use input shapes for function.
914   FunctionType func_type = func.getType();
915   func.setType(FunctionType::get(func.getContext(),
916                                  host_compute_op.getOperandTypes(),
917                                  func_type.getResults()));
918 
919   // Run shape inference on the function.
920   if (failed(PropagateShapeToRegions(host_compute_op.getOperandTypes(),
921                                      {&func.getBody()}, 10)))
922     return false;
923   if (failed(InferShapeForFunctionReturnType(func))) return false;
924 
925   bool changed = false;
926   // Use refined function return shape for XlaHostComputeMlirOp.
927   for (auto result :
928        zip(host_compute_op.getResults(), func.getType().getResults())) {
929     changed = RefineResultType(host_compute_op, std::get<0>(result),
930                                std::get<1>(result)) ||
931               changed;
932   }
933 
934   return changed;
935 }
936 
937 bool ShapeInference::InferShapeForMapDataset(MapDatasetOp op) {
938   // MapDatasetOp's relationship with its associated function is as
939   // follows: the first M function params are dictated by the set of
940   // output shapes and types, and the next N are the last N inputs of the
941   // MapDataset op. The MapDataset op always has N+1 inputs.
942   // TODO(jpienaar): Avoid this lookup.
943   auto module = op->getParentOfType<ModuleOp>();
944   auto f = module.lookupSymbol<FuncOp>(op.f());
945   // Skip if the function is not found or it has more than one caller.
946   if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
947 
948   int N = op.getNumOperands() - 1;
949   int M = f.getNumArguments() - N;
950   DCOMMENT_OP(op, "Inferring shape for MapDataset with N = " << N << " and M = " << M);
951 
952   // Initialize with function input types.
953   SmallVector<Type> input_types(f.getArgumentTypes());
954 
955   // Track if changed to skip enqueueing.
956   bool changed = false;
957   auto it = input_types.begin();
958   // First set first M argument shapes & types.
959   for (int i = 0; i < M; ++i) {
960     auto shape = op.output_shapes()[i].cast<tf_type::ShapeAttr>();
961     auto type = op.output_types()[i].cast<TypeAttr>();
962     Type t;
963     if (shape.hasRank())
964       t = RankedTensorType::get(shape.getShape(), type.getValue());
965     else
966       t = UnrankedTensorType::get(type.getValue());
967     t = TypeMeet(*it, t);
968     changed = changed || (t != *it);
969     ++it;
970   }
971   // Now the remaining N from operand types.
972   for (auto t : llvm::drop_begin(op.getOperandTypes())) {
973     auto meet = TypeMeet(*it, t);
974     changed = changed || (meet != *it);
975     *it = meet;
976     ++it;
977   }
978   if (!changed) return false;
979 
980   // TODO(jpienaar): Pipe the max_iteration value through.
981   FailureOr<bool> res =
982       PropagateShapeToFunctions(module, input_types, {f}, /*max_iteration=*/10);
983   if (failed(res)) {
984     LOG(ERROR) << "Propagating shapes for MapDataset failed";
985     return false;
986   }
987   return *res;
988 }
989 
990 bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
991   DCOMMENT_OP(op, "Inferring shape for TensorList ");
992   Value handle = op->getResult(0);
993   Value initial_element_shape = GetElementShapeOperand(op);
994   RankedTensorType element_type;
995   if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) {
996     // For TensorListFromTensor op we can infer element shape by dropping the
997     // first dimension of input tensor.
998     element_type = DropFirstDimension(tl_from_tensor.tensor().getType());
999     if (!element_type || !element_type.hasStaticShape()) return false;
1000   }
1001   if (!CanInferTensorListElementType(handle, initial_element_shape,
1002                                      &element_type)) {
1003     DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
1004     return false;
1005   }
1006   DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred "
1007                                        << element_type);
1008   if (!element_type || !element_type.hasStaticShape()) return false;
1009   auto variant_type = VariantType::get(element_type, op->getContext());
1010   auto tensor_type = RankedTensorType::get({}, variant_type);
1011   bool changed = RefineResultType(op, handle, tensor_type);
1012   if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
1013   return changed;
1014 }
1015 
1016 bool ShapeInference::RefineWithInferTypeOpInterface(
1017     InferTypeOpInterface infer_ti) {
1018   Operation* op = infer_ti.getOperation();
1019   SmallVector<Type, 4> inferred;
1020   LogicalResult res = infer_ti.inferReturnTypes(
1021       op->getContext(), op->getLoc(), op->getOperands(),
1022       op->getAttrDictionary(), op->getRegions(), inferred);
1023   if (failed(res)) {
1024     op->emitOpError("failed to refine type as inference failed");
1025     return false;
1026   }
1027 
1028   if (inferred == op->getResultTypes()) return false;
1029 
1030   // Update each result whose type differs from the corresponding inferred
1031   // type, inserting casts for incompatible uses.
1032   bool changed = false;
1033   for (auto result : zip(op->getResults(), inferred)) {
1034     if (std::get<0>(result).getType() == std::get<1>(result)) continue;
1035 
1036     if (!UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result),
1037                                                  std::get<0>(result)))
1038       continue;
1039     changed = true;
1040   }
1041   return changed;
1042 }
1043 
1044 ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
1045                                                  InferenceContext* ic) {
1046   LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
1047   auto rt = result.getType().dyn_cast<RankedTensorType>();
1048   if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
1049   int dim_size = rt.getDimSize(0);
1050 
1051   // Worklist to direct partial evaluation.
1052   SmallVector<ValuePort, 4> worklist;
1053 
1054   // Simple evaluator that attempts to partially evaluate the input value even
1055   // if unable to evaluate the complete output. Below follows a simple
1056   // stack-based evaluation in which it queries which operands (or parts of
1057   // operands) need to be evaluated and attempts to partially evaluate those
1058   // operands. It does so by pushing the operands that are required onto the
1059   // worklist before enqueuing the operation requiring those values.
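  // As an illustrative example: for a hypothetical %s = tf.Pack(%a, %c) where
  // %c is produced by a tf.Const, querying dimension 1 pushes the ValuePort for
  // %c, evaluates it to its constant attribute, and resolves dims[1] to a known
  // DimensionHandle, while dims[0] (fed by the non-constant %a) stays unknown.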
1060   std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim());
1061   for (unsigned int i = 0, e = dims.size(); i != e; ++i) {
1062     LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n");
1063 
1064     worklist.push_back(
1065         ValuePort{result.getOwner(), {result.getResultNumber(), i}});
1066     while (!worklist.empty()) {
1067       auto front = worklist.pop_back_val();
1068       LLVM_DEBUG(front.print(llvm::dbgs() << "\nWorklist front "));
1069 
1070       SmallVector<ValuePort, 4> inputs;
1071       auto res = ComputeInputsRequiredForOutput(front, &inputs);
1072       if (failed(res)) {
1073         // Abort if unable to find which required inputs need to be computed.
1074         worklist.clear();
1075         break;
1076       }
1077 
1078       if (!inputs.empty()) {
1079         // Enqueue required computation followed by its required operands in
1080         // stack.
1081         worklist.push_back(std::move(front));
1082         for (auto& it : inputs) worklist.push_back(std::move(it));
1083         continue;
1084       }
1085 
1086       auto ret = ComputeOutputComponent(front);
1087       if (!ret) continue;
1088 
1089       LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
1090 
1091       // If worklist is empty, then this is the root query op.
1092       if (worklist.empty()) {
1093         LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
1094         if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
1095           if (dea.getNumElements() != 1) {
1096             LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n");
1097             return {};
1098           }
1099           int64_t val = (*dea.getIntValues().begin()).getSExtValue();
1100           dims[i] = ic->MakeDim(val);
1101         }
1102       }
1103     }
1104   }
1105   return ic->MakeShape(dims);
1106 }
1107 
1108 bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
1109                                                       OperandRange operands,
1110                                                       ResultRange results) {
1111   bool changed = false;
1112   for (auto entry : llvm::zip(operands, results)) {
1113     Type operand_type = std::get<0>(entry).getType();
1114     Value result = std::get<1>(entry);
1115     TensorType result_type = result.getType().cast<TensorType>();
1116     Type inferred_type = TypeMeet(result_type, operand_type);
1117     if (result_type == inferred_type) continue;
1118 
1119     if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1120       continue;
1121     changed = true;
1122   }
1123   return changed;
1124 }
1125 
1126 bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
1127   DCOMMENT_OP(op, "Pass through op");
1128   bool changed = false;
1129   for (auto entry : llvm::zip(op->getOperands(), op->getResults())) {
1130     Value operand = std::get<0>(entry);
1131     Value result = std::get<1>(entry);
1132     Type inferred_type = TypeMeet(result.getType(), operand.getType());
1133     if (result.getType() == inferred_type) continue;
1134     if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1135       continue;
1136     changed = true;
1137   }
1138   return changed;
1139 }
1140 
1141 bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
1142   if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
1143     return RefineTypeForPassThroughOperands(
1144         graph_op.GetFetch(), graph_op.GetFetch().fetches(), op->getResults());
1145   }
1146   if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
1147     return RefineTypeForPassThroughOperands(
1148         island_op.GetYield(), island_op.GetYield().fetches(), op->getResults());
1149   }
1150   if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
1151     auto iter_source = cast<tf_executor::NextIterationSourceOp>(
1152         iter_sink.token().getDefiningOp());
1153     return RefineTypeForPassThroughOperands(
1154         op, iter_sink.getOperands().drop_front().take_front(),
1155         iter_source.getResults());
1156   }
1157   if (auto launch_op = dyn_cast<tf_device::LaunchOp>(op)) {
1158     auto terminator = launch_op.GetBody().getTerminator();
1159     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1160                                             op->getResults());
1161   }
1162   if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
1163     auto terminator = cluster_op.GetBody().getTerminator();
1164     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1165                                             op->getResults());
1166   }
1167   if (op->hasTrait<OpTrait::SameOperandsAndResultShape>())
1168     return RefineShapeForPassThroughOps(op);
1169   if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1170   return false;
1171 }
1172 
1173 // Finds the element type to be used for the result based on the operand, with
1174 // special handling for handle types.
1175 Type GetElementTypeFromOperand(TensorType operand_type,
1176                                TensorType result_type) {
1177   auto operand_handle_type =
1178       operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
1179   if (!operand_handle_type) return result_type.getElementType();
1180   auto result_handle_type =
1181       result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
1182   if (operand_handle_type.GetSubtypes().empty() ||
1183       !result_handle_type.GetSubtypes().empty())
1184     return result_type.getElementType();
1185   return operand_handle_type;
1186 }
1187 
1188 // Checks if one tensor type can refine another type for tf.While/
1189 // tf.WhileRegion. If rank differs or static dimensions can be lost, the other
1190 // type cannot be used for refinement.
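// For example (illustrative): tensor<?x4xf32> can be refined with
// tensor<2x4xf32>, but tensor<2x4xf32> cannot be refined with tensor<?x4xf32>
// since the static dimension 2 would be lost.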
1191 bool CanWhileTypeBeRefinedWith(TensorType current_type,
1192                                TensorType potential_refined_type) {
1193   if (!current_type.hasRank()) return true;
1194   if (!potential_refined_type.hasRank()) return false;
1195   if (current_type.getRank() != potential_refined_type.getRank()) return false;
1196   for (auto dim :
1197        llvm::zip(current_type.getShape(), potential_refined_type.getShape())) {
1198     int64_t current_dim = std::get<0>(dim);
1199     int64_t potential_refined_dim = std::get<1>(dim);
1200     if (current_dim != potential_refined_dim &&
1201         current_dim != ShapedType::kDynamicSize)
1202       return false;
1203   }
1204   return true;
1205 }
1206 
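// Infers result types for tf.While/tf.WhileRegion. Without `shape_invariant`,
// operand types are simply passed through. With `shape_invariant`, each result
// is refined from the corresponding body result type when doing so cannot lose
// rank or static dimensions; otherwise the existing result shape is kept and
// only handle subtypes may be picked up from the operand.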
1207 template <typename WhileOpTy>
1208 bool ShapeInference::InferShapeForWhile(WhileOpTy op,
1209                                         TypeRange body_result_types) {
1210   if (!op.shape_invariant())
1211     return RefineTypeForPassThroughOperands(op, op.input(), op.output());
1212 
1213   bool changed = false;
1214   for (auto entry :
1215        zip(op.input().getTypes(), op.output(), body_result_types)) {
1216     Value result = std::get<1>(entry);
1217     TensorType body_result_type =
1218         std::get<2>(entry).template cast<TensorType>();
1219     auto result_type = result.getType().cast<TensorType>();
1220 
1221     Type potential_refined_type;
1222     if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) {
1223       Type element_type =
1224           GetElementTypeFromOperand(body_result_type, result_type);
1225       potential_refined_type = CreateTensorType(
1226           body_result_type.hasRank() ? body_result_type.getShape()
1227                                      : llvm::Optional<ArrayRef<int64_t>>(),
1228           element_type);
1229     } else {
1230       TensorType operand_type = std::get<0>(entry).template cast<TensorType>();
1231       Type element_type = GetElementTypeFromOperand(operand_type, result_type);
1232       potential_refined_type = CreateTensorType(
1233           result_type.hasRank() ? result_type.getShape()
1234                                 : llvm::Optional<ArrayRef<int64_t>>(),
1235           element_type);
1236     }
1237     changed |= RefineResultType(op, result, potential_refined_type);
1238   }
1239   return changed;
1240 }
1241 
1242 bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
1243   LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
1244              llvm::dbgs() << "\n");
1245   assert(tf_dialect_ == op->getDialect());
1246   // The shape function of these ops sometimes does not propagate subtypes
1247   // (handle shapes) for resource and variant types. We use a simple passthrough
1248   // to make sure they are preserved in the output.
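  // For example (illustrative), a tf.Identity of a
  // tensor<!tf.resource<tensor<8xf32>>> operand keeps the tensor<8xf32>
  // subtype on its result instead of degrading it to an unsubtyped resource.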
1249   if (isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp, TF::ZerosLikeOp>(
1250           op)) {
1251     return RefineTypeForPassThroughOperands(op, op->getOperands(),
1252                                             op->getResults());
1253   }
1254 
1255   // If no result for this op needs shape inference, we have a fast-path return.
1256   // But if the type is a resource/variant, we do not skip it because we might
1257   // not have the handle shapes.
1258   if (none_of(op->getResultTypes(), CanBeRefined)) {
1259     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
1260                             << op->getName() << "'.\n");
1261     return false;
1262   }
1263 
1264   // Handle call operations by looking up callee and inferring return shape as
1265   // needed.
1266   if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1267 
1268   // tf.Cast and tensor::Cast are only inferred if they have at least one user
1269   // in the TF dialect or feeding into the function return. This is necessary to
1270   // avoid inserting casts which cannot be refined.
1271   if (isa<CastOp, tensor::CastOp>(op)) return InferShapeForCast(op);
1272 
1273   // Handle IfOp here by inferring the shape from the else/then function
1274   // results. Since `output_shapes` is a derived attribute, avoid going down the
1275   // TF InferenceContext path as IfOp shape inference is implemented as just
1276   // a lookup of the output_shapes attribute.
1277   if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);
1278 
1279   // Handle IfRegion operations by inferring return shape from the then and else
1280   // branches.
1281   if (auto if_region = dyn_cast<IfRegionOp>(op))
1282     return InferShapeForIfRegion(if_region);
1283 
1284   if (auto while_op = dyn_cast<WhileOp>(op))
1285     return InferShapeForWhile(while_op,
1286                               while_op.body_function().getType().getResults());
1287 
1288   if (auto while_region = dyn_cast<WhileRegionOp>(op))
1289     return InferShapeForWhile(
1290         while_region,
1291         while_region.body().front().getTerminator()->getOperandTypes());
1292 
1293   if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) {
1294     return InferShapeForXlaHostComputeMlir(host_compute_op);
1295   }
1296 
1297   if (auto map_dataset_op = dyn_cast<MapDatasetOp>(op))
1298     return InferShapeForMapDataset(map_dataset_op);
1299 
1300   // Handle TensorList init operations by inferring shape from TensorList write
1301   // operations. If we are unable to refine element shape here, proceed to use
1302   // the InferenceContext below to get more precise shapes.
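  // For example (illustrative), a tf.EmptyTensorList whose element shape is
  // unknown here may still get a refined variant subtype from a later
  // tf.TensorListPushBack with a statically shaped element.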
1303   if (IsTensorListInitOp(op) && InferShapeForTensorListInitOps(op)) return true;
1304 
1305   // Return operand as a constant attribute.
1306   auto operand_as_constant_fn = [&](Value operand) {
1307     ValuePort vp(operand);
1308     Attribute attr = ComputeOutputComponent(vp);
1309     if (!attr && matchPattern(operand, m_Constant(&attr)))
1310       RecordValue(vp, attr);
1311     return attr;
1312   };
1313 
1314   // Return op result as a shape.
1315   auto op_result_as_shape_fn = [&](InferenceContext& context,
1316                                    OpResult op_result) {
1317     return ComputeOutputAsShape(op_result, &context);
1318   };
1319 
1320   // Return result element type at `index`.
1321   auto result_element_type_fn = [&](int index) {
1322     return op->getResult(index).getType().cast<TensorType>().getElementType();
1323   };
1324 
1325   llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes;
1326   if (failed(InferReturnTypeComponentsForTFOp(
1327           /*location=*/None, op, graph_version_, operand_as_constant_fn,
1328           op_result_as_shape_fn, result_element_type_fn,
1329           inferred_return_shapes)))
1330     return false;
1331 
1332   // Update the shape for each of the operation results if the InferenceContext
1333   // has more precise shapes recorded.
1334   bool changed = false;
1335   for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) {
1336     Value op_result = std::get<0>(result);
1337     if (!CanBeRefined(op_result.getType())) continue;
1338 
1339     ShapedTypeComponents inferred = std::get<1>(result);
1340     TensorType inferred_type;
1341     if (inferred.hasRank())
1342       inferred_type =
1343           RankedTensorType::get(inferred.getDims(), inferred.getElementType());
1344     else
1345       inferred_type = UnrankedTensorType::get(inferred.getElementType());
1346 
1347     inferred_type =
1348         TypeMeet(op_result.getType(), inferred_type).cast<TensorType>();
1349     if (op_result.getType() == inferred_type) continue;
1350     if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result))
1351       continue;
1352     changed = true;
1353   }
1354 
1355   if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
1356   return changed;
1357 }
1358 
1359 FailureOr<bool> ShapeInference::PropagateShapeToFunctions(
1360     ModuleOp module, TypeRange input_types, ArrayRef<FuncOp> functions,
1361     int64_t max_iteration) {
1362   bool any_failure = false;
1363   bool any_nonconvergence = false;
1364   // If shape propagation fails for one function, return failure, but do not
1365   // early exit and attempt to propagate shapes for all provided functions to
1366   // have a best-effort propagation.
1367   for (FuncOp func : functions) {
1368     DCOMMENT("Propagating shape to " << func.getName());
1369     ArrayRef<Operation*> callers = GetCallers(func);
1370     if (!llvm::hasSingleElement(callers) &&
1371         !llvm::all_of(callers.drop_front(), [&](Operation* caller) {
1372           /// TODO(aminim): this is overly conservative as some operations
1373           /// (like TPUPartitionedCallOp) may have extra operands that aren't
1374           /// propagated to the callee.
1375           return isa<CallOpInterface>(caller) &&
1376                  std::equal(caller->getOperandTypes().begin(),
1377                             caller->getOperandTypes().end(),
1378                             callers.front()->getOperandTypes().begin());
1379         })) {
1380       if (llvm::any_of(callers, [](Operation* op) {
1381             return isa<IfOp, WhileOp, CaseOp>(op);
1382           }))
1383         func.emitWarning(formatv(
1384             "expected control flow function @{0} to have exactly 1 use, "
1385             "found {1}.",
1386             func.getName(), callers.size()));
1387 
1388       continue;
1389     }
1390     FunctionType func_type = func.getType();
1391     func.setType(FunctionType::get(func.getContext(), input_types,
1392                                    func_type.getResults()));
1393 
1394     FailureOr<bool> failure_or_converged =
1395         PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
1396     if (failed(failure_or_converged)) {
1397       any_failure = true;
1398       continue;
1399     }
1400     any_nonconvergence = any_nonconvergence || !failure_or_converged.getValue();
1401     if (failed(InferShapeForFunctionReturnType(func))) any_failure = true;
1402   }
1403   if (any_failure) return failure();
1404   return any_nonconvergence;
1405 }
1406 
1407 FailureOr<bool> ShapeInference::PropagateShapeToRegions(
1408     TypeRange input_types, ArrayRef<Region*> regions, int64_t max_iteration) {
1409   DCOMMENT("\tPropagating shapes to regions");
1410   bool any_failure = false;
1411   bool any_nonconvergence = false;
1412   // If shape propagation fails for one region, return failure, but do not
1413   // early exit and attempt to propagate shapes for all provided regions to
1414   // have a best-effort propagation.
1415   for (auto region : regions) {
1416     // Refine region arguments.
1417     Block& entry = region->front();
1418     assert(llvm::size(input_types) == entry.getNumArguments());
1419     for (auto it : llvm::zip(entry.getArguments(), input_types)) {
1420       BlockArgument arg = std::get<0>(it);
1421       Type type = std::get<1>(it);
1422       arg.setType(type);
1423     }
1424 
1425     // Propagate shapes into the region.
1426     FailureOr<bool> failure_or_converged =
1427         InferShapeUntilFixPoint(region, max_iteration);
1428     if (failed(failure_or_converged))
1429       any_failure = true;
1430     else if (!failure_or_converged.getValue())
1431       any_nonconvergence = true;
1432   }
1433   if (any_failure) return failure();
1434   return any_nonconvergence;
1435 }
1436 
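// Propagates constant operands of `call_op` into the arguments of its callee
// `func` when `func` has a single caller, either by cloning tf.Const ops into
// the callee or by recording the known constant value for the argument.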
1437 void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
1438                                                FuncOp func, ModuleOp module) {
1439   auto callers = GetCallers(func);
1440   if (!llvm::hasSingleElement(callers)) return;
1441 
1442   OpBuilder builder(&func.front().front());
1443   Operation* op = call_op.getOperation();
1444   // If this is the only caller, and an operand is a constant, propagate
1445   // the constant value inside the function.
1446   for (auto arg : func.getArguments()) {
1447     auto operand = op->getOperand(arg.getArgNumber());
1448     if (propagate_caller_callee_constants_) {
1449       if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
1450         arg.replaceAllUsesWith(
1451             builder.clone(*operand.getDefiningOp())->getResult(0));
1452       }
1453       continue;
1454     }
1455 
1456     auto known_constant = ComputeOutputComponent(ValuePort(operand));
1457     if (!known_constant) continue;
1458     LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to callee: ");
1459                known_constant.print(llvm::dbgs() << " constant ");
1460                llvm::dbgs() << "\n");
1461     RecordValue(ValuePort(arg), known_constant);
1462   }
1463 }
1464 
1465 void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
1466                                                  FuncOp func, ModuleOp module) {
1467   // If the return value is a constant, use the constant as the value of
1468   // the call return.
1469   Operation* op = call_op.getOperation();
1470   OpBuilder builder(op);
1471   builder.setInsertionPointAfter(op);
1472   for (auto retval :
1473        llvm::enumerate(func.front().getTerminator()->getOperands())) {
1474     if (propagate_caller_callee_constants_) {
1475       auto retval_op = retval.value().getDefiningOp();
1476       if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
1477         op->getResult(retval.index())
1478             .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
1479       }
1480       continue;
1481     }
1482 
1483     ValuePort vp(retval.value());
1484     if (auto known_constant = ComputeOutputComponent(vp)) {
1485       LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
1486                  call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
1487       RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
1488     }
1489   }
1490 }
1491 
1492 bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
1493   return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank();
1494 }
1495 
1496 // Creates a compatible RankedTensorType where mismatched dimensions are
1497 // replaced with dynamic sizes.
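// For example (illustrative), tensor<2x3xf32> and tensor<2x5xf32> yield
// tensor<2x?xf32>.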
1498 RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs,
1499                                                RankedTensorType rhs) {
1500   assert(lhs.getRank() == rhs.getRank());
1501   llvm::SmallVector<int64_t, 4> dims;
1502   dims.reserve(lhs.getRank());
1503   for (auto dim : llvm::zip(lhs.getShape(), rhs.getShape())) {
1504     int64_t lhs_dim = std::get<0>(dim);
1505     if (lhs_dim == std::get<1>(dim)) {
1506       dims.push_back(lhs_dim);
1507     } else {
1508       dims.push_back(ShapedType::kDynamicSize);
1509     }
1510   }
1511   return RankedTensorType::get(dims, GetElementTypeFromOperand(lhs, rhs));
1512 }
1513 
1514 // Finds compatible types to propagate into functions/regions of a shape
1515 // invariant tf.While/tf.WhileRegion. If operand and result types are the same,
1516 // that type is returned. If operand and result types are of the same rank, a
1517 // compatible type with matching dimensions is used. Otherwise the function/
1518 // region argument type is used, with the handle element type taken from the operand.
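// For example (illustrative), an operand tensor<2x3xf32> with result
// tensor<2x5xf32> yields tensor<2x?xf32>, while an operand tensor<2x3xf32>
// with an unranked result falls back to the region argument type (with the
// operand's handle element type, if any).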
1519 llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
1520     TypeRange operand_types, TypeRange result_types,
1521     TypeRange region_argument_types) {
1522   llvm::SmallVector<Type, 4> types;
1523   types.reserve(operand_types.size());
1524   for (auto entry :
1525        llvm::zip(operand_types, result_types, region_argument_types)) {
1526     auto operand_type = std::get<0>(entry).cast<TensorType>();
1527     auto result_type = std::get<1>(entry).cast<TensorType>();
1528     if (operand_type == result_type) {
1529       types.push_back(operand_type);
1530     } else if (RankedAndSameRank(operand_type, result_type)) {
1531       auto potential_refined_type =
1532           GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(),
1533                                         result_type.cast<RankedTensorType>());
1534       types.push_back(potential_refined_type);
1535     } else {
1536       auto region_argument_type = std::get<2>(entry).cast<TensorType>();
1537       Type element_type = GetElementTypeFromOperand(
1538           operand_type.cast<TensorType>(), region_argument_type);
1539       Type potential_refined_type = CreateTensorType(
1540           region_argument_type.hasRank() ? region_argument_type.getShape()
1541                                          : llvm::Optional<ArrayRef<int64_t>>(),
1542           element_type);
1543       types.push_back(potential_refined_type);
1544     }
1545   }
1546   return types;
1547 }
1548 
1549 FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedFunctions(
1550     Operation* op, int64_t max_iteration) {
1551   ModuleOp module = op->getParentOfType<ModuleOp>();
1552   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
1553     DCOMMENT("Propagating shapes into If");
1554     return PropagateShapeToFunctions(
1555         module, if_op.input().getTypes(),
1556         {if_op.then_function(), if_op.else_function()}, max_iteration);
1557   } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
1558     SmallVector<FuncOp, 4> branches;
1559     case_op.get_branch_functions(branches);
1560     return PropagateShapeToFunctions(module, case_op.input().getTypes(),
1561                                      branches, max_iteration);
1562   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
1563     // If `shape_invariant` is set, operand shapes cannot be simply propagated
1564     // to result shapes as the op may have different intermediate shapes (such
1565     // While ops can have different result shapes from operand shapes).
1566     // Compatible shapes must be determined before propagating them.
1567     if (while_op.shape_invariant()) {
1568       auto compatible_types = GetWhileCompatibleTypes(
1569           while_op.input().getTypes(), while_op.output().getTypes(),
1570           while_op.body_function().getType().getInputs());
1571       return PropagateShapeToFunctions(
1572           module, compatible_types,
1573           {while_op.cond_function(), while_op.body_function()}, max_iteration);
1574     }
1575     return PropagateShapeToFunctions(
1576         module, while_op.input().getTypes(),
1577         {while_op.cond_function(), while_op.body_function()}, max_iteration);
1578   } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
1579     if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) {
1580       PropagateConstantToCallee(call_op, func, module);
1581       FailureOr<bool> failure_or_converged = PropagateShapeToFunctions(
1582           module, call_op.getArgOperands().getTypes(), {func}, max_iteration);
1583       if (failed(failure_or_converged)) return failure();
1584       PropagateConstantFromCallee(call_op, func, module);
1585       return failure_or_converged;
1586     }
1587   }
1588 
1589   // TODO(ycao): Implement support for Call op, including function reuse.
1590 
1591   return true;
1592 }
1593 
1594 FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedRegions(
1595     Operation* op, int64_t max_iteration) {
1596   if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
1597     // If `shape_invariant` is set, operand shapes cannot be simply propagated
1598     // to result shapes as the op may have different intermediate shapes (such
1599     // While ops can have different result shapes from operand shapes).
1600     // Compatible shapes must be determined before propagating them.
1601     if (while_op.shape_invariant()) {
1602       auto compatible_types = GetWhileCompatibleTypes(
1603           while_op.input().getTypes(), while_op.output().getTypes(),
1604           while_op.body().getArgumentTypes());
1605       return PropagateShapeToRegions(compatible_types,
1606                                      {&while_op.cond(), &while_op.body()},
1607                                      max_iteration);
1608     }
1609     return PropagateShapeToRegions(while_op.input().getTypes(),
1610                                    {&while_op.cond(), &while_op.body()},
1611                                    max_iteration);
1612   }
1613   return true;
1614 }
1615 
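// Attempts to constant fold `op` using previously recorded constant operands
// and the op's fold hook (with a dialect fallback). Folded attributes are
// recorded for the results and, where the folded value implies a more refined
// type, the result type is updated as well.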
1616 LogicalResult ShapeInference::TryToFold(Operation* op) {
1617   LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
1618   // If any output result is known, then the op probably has been computed
1619   // before.
1620   if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
1621     return success();
1622 
1623   SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
1624   SmallVector<OpFoldResult, 8> fold_results;
1625 
1626   // Check to see if any operand to the operation is constant and whether
1627   // the operation knows how to constant fold itself.
1628   bool some_unknown = false;
1629   for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
1630     if (!(constant_operands[i] =
1631               ComputeOutputComponent(ValuePort(op->getOperand(i)))))
1632       some_unknown = true;
1633   }
1634 
1635   // Attempt to constant fold the operation.
1636   auto* abstract_op = op->getAbstractOperation();
1637   LogicalResult folded = failure();
1638   if (abstract_op) {
1639     folded = abstract_op->foldHook(op, constant_operands, fold_results);
1640   }
1641   // Attempt dialect fallback if op's fold hook failed.
1642   if (failed(folded)) {
1643     Dialect* dialect = op->getDialect();
1644     if (!dialect) return failure();
1645     // Only attempt TF dialect fallback if there are no unknown operands.
1646     if (some_unknown && dialect == tf_dialect_) return failure();
1647     auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
1648     if (!interface) return failure();
1649 
1650     if (failed(interface->fold(op, constant_operands, fold_results)))
1651       return failure();
1652   }
1653 
1654   for (auto result : zip(op->getResults(), fold_results)) {
1655     auto fold_result = std::get<1>(result);
1656     Attribute attr = nullptr;
1657     if ((attr = fold_result.dyn_cast<Attribute>())) {
1658       RecordValue(ValuePort(std::get<0>(result)), attr);
1659     } else {
1660       auto value = fold_result.get<Value>();
1661       if ((attr = ComputeOutputComponent(ValuePort(value)))) {
1662         DCOMMENT("\t\tValue Result mapped to " << attr);
1663         RecordValue(ValuePort(std::get<0>(result)), attr);
1664       } else {
1665         DCOMMENT("\t\tValue result unmapped, consider value type:" << value);
1666         RefineResultType(op, std::get<0>(result), value.getType());
1667       }
1668     }
1669 
1670     if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
1671       if (std::get<0>(result).getType() == eattr.getType()) continue;
1672 
1673       (void)UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(),
1674                                                     std::get<0>(result));
1675     }
1676   }
1677 
1678   return success();
1679 }
1680 
1681 LogicalResult ShapeInference::InferShapeForFunctionReturnType(FuncOp func) {
1682   LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName()
1683                           << "\n");
1684 
1685   // Find any return ops.
1686   SmallVector<ReturnOp, 4> return_ops;
1687   for (Block& block : func) {
1688     if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
1689       return_ops.push_back(return_op);
1690     }
1691   }
1692 
1693   // Skip functions without a return, but don't flag as failure here.
1694   if (return_ops.empty()) return success();
1695 
1696   // Right now we only handle the case of a single return op.
1697   // To handle multiple return ops, we would need to look at all their shapes
1698   // and come up with a common shape and insert appropriate casts.
1699   if (return_ops.size() != 1) return failure();
1700 
1701   // Find the return type.
1702   auto return_op = return_ops.front();
1703 
1704   // Manually fold tf.Cast that precedes the return instruction and only differs
1705   // in shape refinement level.
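  // For example (illustrative), a tensor.cast from tensor<1x2xf32> to
  // tensor<?x2xf32> feeding the return is bypassed so that the function result
  // keeps the more refined tensor<1x2xf32> type.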
1706   bool changed = false;
1707   for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
1708     Operation* arg_defining_op = arg_op.get().getDefiningOp();
1709     if (isa_and_nonnull<CastOp, tensor::CastOp>(arg_defining_op)) {
1710       Value input = arg_defining_op->getOperand(0);
1711       Value result = arg_defining_op->getResult(0);
1712       Type meet = TypeMeet(result.getType(), input.getType());
1713       if (meet == result.getType()) continue;
1714 
1715       LLVM_DEBUG({
1716         llvm::errs() << "\tfolding & updating return type ";
1717         result.getType().print(llvm::errs());
1718         input.getType().print(llvm::errs() << " to ");
1719         llvm::errs() << "\n";
1720       });
1721 
1722       // Shape inference should not change the element type.
1723       if (HasCompatibleElementTypes(input.getType(), result.getType()) &&
1724           meet == input.getType()) {
1725         arg_op.set(input);
1726       } else {
1727         OpBuilder b(return_op.getOperation());
1728         auto new_cast_op = InsertCast(b, return_op.getLoc(), meet, input);
1729         if (!new_cast_op) return failure();
1730         arg_op.set(new_cast_op->getResult(0));
1731       }
1732       if (result.use_empty()) arg_defining_op->erase();
1733       changed = true;
1734     }
1735   }
1736 
1737   DCOMMENT("Updating function type");
1738   func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
1739                                  return_op.getOperandTypes()));
1740 
1741   if (changed) EnqueueCallers(func);
1742   return success();
1743 }
1744 
1745 FailureOr<bool> ShapeInference::InferShapeUntilFixPoint(Region* region,
1746                                                         int64_t max_iteration) {
1747   bool changed = true;
1748 
1749   // TODO(aminim): we could have a more efficient traversal by guiding the
1750   // traversal with a worklist and reconsider only the nodes for which an
1751   // operand type was inferred. This would need to be careful if working on a
1752   // region that would not be isolated.
1753   for (int iteration = 0; iteration < max_iteration && changed; ++iteration) {
1754     changed = false;
1755     LLVM_DEBUG(llvm::dbgs()
1756                << "Shape inference, iteration " << iteration << "\n");
1757     auto res = region->walk([&](Operation* op) {
1758       DCOMMENT_OP(op, "Inferring for");
1759       if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
1760         DCOMMENT("\tRefining with type op interface");
1761         changed |= RefineWithInferTypeOpInterface(infer_ti);
1762         return WalkResult::advance();
1763       }
1764 
1765       if (op->getDialect() != tf_dialect_) {
1766         DCOMMENT("\tInfer non-TF dialect");
1767         changed |= InferShapeForNonTFDialectOperation(op);
1768         return WalkResult::advance();
1769       }
1770 
1771       // Before attempting inference, just try to compute the folded
1772       // value/shape.
1773       if (succeeded(TryToFold(op)) &&
1774           // Folding can "succeed" and yet not all types be refined. In such
1775           // cases we still want to give a try at `InferShapeForSingleOperation`
1776           none_of(op->getResultTypes(), CanBeRefined))
1777         return WalkResult::advance();
1778 
1779       // Best-effort shape inference in attached functions. Do not return
1780       // failure even if it doesn't get to fixed point, but propagate "real"
1781       // failure.
1782       if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
1783         op->emitWarning() << "unable to refine shape of attached function "
1784                              "arguments and bodies";
1785         return WalkResult::interrupt();
1786       }
1787 
1788       if (failed(PropagateShapeIntoAttachedRegions(op, max_iteration))) {
1789         op->emitWarning() << "unable to refine shape of attached region "
1790                              "arguments and bodies";
1791         return WalkResult::interrupt();
1792       }
1793 
1794       changed |= InferShapeForSingleOperation(op);
1795       return WalkResult::advance();
1796     });
1797     if (res.wasInterrupted()) return failure();
1798   }
1799 
1800   if (changed) {
1801     region->getParentOp()->emitWarning()
1802         << "shape inference did not reach stable state after " << max_iteration
1803         << " iterations";
1804   }
1805   return !changed;
1806 }
1807 
1808 static FailureOr<bool> InferShapeForFunction(ShapeInference& context,
1809                                              FuncOp func,
1810                                              int64_t max_iterations) {
1811   FailureOr<bool> failure_or_converged =
1812       context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
1813   if (failed(failure_or_converged) || !failure_or_converged.getValue())
1814     return failure_or_converged;
1815   // TODO(b/156276510): Verify that it is always fine to refine a function's
1816   // return type, as long as we do not change the argument shapes.
1817   if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
1818   return true;
1819 }
1820 
1821 FailureOr<bool> InferShapeForFunction(FuncOp func,
1822                                       ArrayRef<ArrayRef<int64_t>> arg_shapes,
1823                                       int64_t graph_version,
1824                                       int64_t max_iterations) {
1825   ShapeInference context(graph_version, func->getParentOfType<ModuleOp>(),
1826                          /*propagate_caller_callee_constants=*/true);
1827   if (arg_shapes.empty()) {
1828     return InferShapeForFunction(context, func, max_iterations);
1829   }
1830 
1831   FunctionType func_type = func.getType();
1832   bool needs_refinement = false;
1833   SmallVector<Type, 4> new_arg_types;
1834   new_arg_types.reserve(func_type.getNumInputs());
1835 
1836   // Update argument types in-place using the provided arg_shapes.
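  // For example (illustrative), arg_shapes {{1, 224, 224, 3}} refines a
  // tensor<?x224x224x3xf32> argument to tensor<1x224x224x3xf32>.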
1837   for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
1838     ArrayRef<int64_t> shape = arg_shapes[i];
1839     Type element_type;
1840     if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
1841       if (input_ty.getRank() != shape.size()) {
1842         return failure();
1843       }
1844       element_type = input_ty.getElementType();
1845     } else {
1846       auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
1847       if (!unranked_input_ty) {
1848         return failure();
1849       }
1850       element_type = unranked_input_ty.getElementType();
1851     }
1852 
1853     auto new_arg_type = RankedTensorType::get(shape, element_type);
1854     if (new_arg_type != func_type.getInput(i)) {
1855       // If the new type is more detailed, trigger shape inference.
1856       func.getArgument(i).setType(new_arg_type);
1857       needs_refinement = true;
1858     }
1859     new_arg_types.push_back(new_arg_type);
1860   }
1861 
1862   if (!needs_refinement) return true;
1863 
1864   FailureOr<bool> failure_or_converged =
1865       context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
1866   if (failed(failure_or_converged) || !failure_or_converged.getValue())
1867     return failure_or_converged;
1868 
1869   if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
1870   func.setType(FunctionType::get(func.getContext(), new_arg_types,
1871                                  func.getType().getResults()));
1872 
1873   return true;
1874 }
1875 
1876 FailureOr<bool> InferModuleShape(ModuleOp module, int64_t max_iterations) {
1877   auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
1878   if (!producer_or.ok()) {
1879     // TODO(jpienaar): Keeping the existing behavior for now but this could
1880     // be relaxed.
1881     LLVM_DEBUG(llvm::dbgs()
1882                << "Skipping inference; " << producer_or.status().ToString());
1883     return true;
1884   }
1885   int64_t producer = producer_or.ValueOrDie();
1886   // TODO(jpienaar): Clean up propagate_caller_callee_constants_ if
1887   // it is no longer needed.
1888   ShapeInference context(producer, module,
1889                          /*propagate_caller_callee_constants=*/false);
1890   if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
1891     context.enqueue(main);
1892   for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
1893   // Arbitrarily upper bound the maximum number of functions that get processed
1894   // just to avoid pathological cases.
1895   auto max_iteration = context.QueueSize() * 4;
1896   while (!context.EmptyQueue()) {
1897     FuncOp func = context.front();
1898     FailureOr<bool> failure_or_converged =
1899         InferShapeForFunction(context, func, max_iterations);
1900     if (failed(failure_or_converged) || !failure_or_converged.getValue())
1901       return failure_or_converged;
1902     context.pop_front();
1903 
1904     if ((--max_iteration) == 0) {
1905       emitWarning(UnknownLoc::get(module.getContext()))
1906           << "shape inference did not reach a stable state within the "
1907              "function processing limit";
1908       return false;
1909     }
1910   }
1911   return true;
1912 }
1913 
1914 }  // namespace TF
1915 }  // namespace mlir
1916