1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <initializer_list>
21 #include <iterator>
22 #include <queue>
23
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/ADT/None.h"
26 #include "llvm/ADT/PointerUnion.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "llvm/Support/Casting.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/FormatVariadic.h"
33 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
34 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
35 #include "mlir/IR/Attributes.h" // from @llvm-project
36 #include "mlir/IR/Block.h" // from @llvm-project
37 #include "mlir/IR/Builders.h" // from @llvm-project
38 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
39 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
40 #include "mlir/IR/Diagnostics.h" // from @llvm-project
41 #include "mlir/IR/Location.h" // from @llvm-project
42 #include "mlir/IR/Operation.h" // from @llvm-project
43 #include "mlir/IR/OperationSupport.h" // from @llvm-project
44 #include "mlir/IR/SymbolTable.h" // from @llvm-project
45 #include "mlir/IR/Value.h" // from @llvm-project
46 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
47 #include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
48 #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
49 #include "mlir/Pass/Pass.h" // from @llvm-project
50 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
51 #include "mlir/Support/LLVM.h" // from @llvm-project
52 #include "mlir/Support/LogicalResult.h" // from @llvm-project
53 #include "mlir/Transforms/FoldUtils.h" // from @llvm-project
54 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
55 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
56 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
57 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
58 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
59 #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
60 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
61 #include "tensorflow/core/framework/shape_inference.h"
62 #include "tensorflow/core/framework/types.pb.h"
63
64 #define DEBUG_TYPE "tf-shape-inference"
65
66 #define DCOMMENT(MSG) LLVM_DEBUG(llvm::dbgs() << MSG << "\n")
67 #define DCOMMENT_OP(OP, MSG) \
68 LLVM_DEBUG(OP->print(llvm::dbgs() << MSG << " "); llvm::dbgs() << "\n")
69
70 using ::tensorflow::int64;
71 using tensorflow::shape_inference::DimensionHandle;
72 using tensorflow::shape_inference::InferenceContext;
73 using tensorflow::shape_inference::ShapeHandle;
74
75 namespace mlir {
76 namespace TF {
77 namespace {
78
79 // Returns whether type can be further refined.
CanBeRefined(Type type)80 bool CanBeRefined(Type type) {
81 auto shape_type = type.dyn_cast<ShapedType>();
82 if (!shape_type) return false;
83
84 // Returns whether type with subtypes can be further refined.
85 auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) {
86 return tws.GetSubtypes().empty() ||
87 llvm::any_of(tws.GetSubtypes(), CanBeRefined);
88 };
89 auto type_with_subtype =
90 shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
91 if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;
92
93 return !shape_type.hasStaticShape();
94 }
95
96 // Returns whether `original_type` type can be refined with
97 // `potential_refined_type` type.
CanRefineTypeWith(Type original_type,Type potential_refined_type)98 bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
99 if (original_type == potential_refined_type || !CanBeRefined(original_type))
100 return false;
101
102 auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
103 if (!shape_type) return false;
104 if (shape_type.hasRank()) return true;
105
106 auto element_type_with_subtype =
107 shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
108 return element_type_with_subtype &&
109 !element_type_with_subtype.GetSubtypes().empty();
110 }
111
112 // Returns if the shape inference pass supports an op outside the TF dialect.
IsSupportedNonTFOp(Operation * op)113 bool IsSupportedNonTFOp(Operation* op) {
114 return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp,
115 tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp,
116 tf_executor::GraphOp, tf_executor::IslandOp,
117 tf_executor::LoopCondOp, tf_executor::MergeOp,
118 tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
119 tf_executor::SwitchOp, tf_executor::YieldOp>(op);
120 }
121
122 // Returns whether a cast back would need to be inserted, e.g., whether the
123 // operation of which use is an operand allows for shape refinement without
124 // a cast.
NeedsCastBack(OpOperand & use,Dialect * tf_dialect)125 bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
126 return use.getOwner()->getDialect() != tf_dialect &&
127 !IsSupportedNonTFOp(use.getOwner());
128 }
129
CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape,Type element_type)130 TensorType CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape,
131 Type element_type) {
132 if (shape.hasValue())
133 return RankedTensorType::get(shape.getValue(), element_type);
134 return UnrankedTensorType::get(element_type);
135 }
136
137 // Returns true if the op creates a TensorList.
IsTensorListInitOp(Operation * op)138 bool IsTensorListInitOp(Operation* op) {
139 return isa<TensorListReserveOp>(op) || isa<EmptyTensorListOp>(op) ||
140 isa<TensorListFromTensorOp>(op);
141 }
142
143 // Returns the `element_shape` operand of the ops that create a TensorList.
GetElementShapeOperand(Operation * op)144 Value GetElementShapeOperand(Operation* op) {
145 if (auto empty_tl = dyn_cast<EmptyTensorListOp>(op))
146 return empty_tl.element_shape();
147 if (auto tl_reserve = dyn_cast<TensorListReserveOp>(op))
148 return tl_reserve.element_shape();
149 if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op))
150 return tl_from_tensor.element_shape();
151 llvm_unreachable("unsupported TensorList op");
152 }
153
154 // Utility function to create a ranked tensor type after dropping the first
155 // dimension from the input type.
DropFirstDimension(Type type)156 RankedTensorType DropFirstDimension(Type type) {
157 RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>();
158 if (!ranked_type) return {};
159 llvm::ArrayRef<int64_t> dims_except_first =
160 ranked_type.getShape().drop_front();
161 return RankedTensorType::get(dims_except_first, ranked_type.getElementType());
162 }
163
164 // Follow the use chain of TensorList and return true iff all elements written
165 // to TensorList have same static shape. If all elements have same shape, assign
166 // it to `potential_element_type`.
167 //
168 // This can handle multiple mutations of a TensorList object and would return
169 // true if across all mutations the elements written have the same shape.
CanInferTensorListElementType(Value tensorlist,Value initial_element_shape,RankedTensorType * potential_element_type)170 bool CanInferTensorListElementType(Value tensorlist,
171 Value initial_element_shape,
172 RankedTensorType* potential_element_type) {
173 // Verifies if the new element type has static shape and matches the potential
174 // type passed from caller. Updates the potential_element_type, if not defined
175 // yet.
176 auto verify_and_update_potential_element_type =
177 [&](RankedTensorType new_element_type) -> bool {
178 if (!new_element_type || !new_element_type.hasStaticShape()) return false;
179 if (!*potential_element_type) {
180 *potential_element_type = new_element_type;
181 return true;
182 }
183 return *potential_element_type == new_element_type;
184 };
185
186 // TensorLists are semantically immutable. For example, TensorListSetItem
187 // takes a TensorList as input and produces a TensorList as output. So to
188 // traverse modifications to TensorList and verify that all elements written
189 // to it have the same shape, we need to follow use-def chain of ops that
190 // (conceptually) modify it i.e., ops that take an input TensorList and
191 // produce an output TensorList.
192 for (auto& use : tensorlist.getUses()) {
193 if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
194 auto element_type = push.tensor().getType().dyn_cast<RankedTensorType>();
195 if (!verify_and_update_potential_element_type(element_type)) return false;
196 if (!CanInferTensorListElementType(push.output_handle(),
197 initial_element_shape,
198 potential_element_type))
199 return false;
200 continue;
201 }
202 if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
203 use.getOwner())) {
204 // For scatter op we can get the element shape by dropping the first
205 // dimension of the input tensor.
206 RankedTensorType element_type =
207 DropFirstDimension(scatter.tensor().getType());
208 if (!verify_and_update_potential_element_type(element_type)) return false;
209 if (!CanInferTensorListElementType(scatter.output_handle(),
210 initial_element_shape,
211 potential_element_type))
212 return false;
213 continue;
214 }
215 if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
216 auto element_type =
217 set_item.item().getType().dyn_cast<RankedTensorType>();
218 if (!verify_and_update_potential_element_type(element_type)) return false;
219 if (!CanInferTensorListElementType(set_item.output_handle(),
220 initial_element_shape,
221 potential_element_type))
222 return false;
223 continue;
224 }
225 if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
226 if (!CanInferTensorListElementType(pop.output_handle(),
227 initial_element_shape,
228 potential_element_type))
229 return false;
230 continue;
231 }
232 if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
233 if (!CanInferTensorListElementType(resize.output_handle(),
234 initial_element_shape,
235 potential_element_type))
236 return false;
237 continue;
238 }
239 // WhileRegionOp can explicitly capture TensorList value to be used inside
240 // its regions. So we check the uses of corresponding block argument in each
241 // region and the use of TensorList returned using YieldOp.
242 if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
243 for (auto branch : while_region.getRegions()) {
244 if (!CanInferTensorListElementType(
245 branch->getArgument(use.getOperandNumber()),
246 initial_element_shape, potential_element_type))
247 return false;
248 }
249 continue;
250 }
251 if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
252 Operation* parent = yield->getParentOp();
253 if (!CanInferTensorListElementType(
254 parent->getResult(use.getOperandNumber()), initial_element_shape,
255 potential_element_type))
256 return false;
257 continue;
258 }
259 // Refining the tensor list element type might change the output of
260 // TensorListElementShape which is expected tp be the originally assigned
261 // shape to TensorList init ops. So replace it with the original element
262 // shape value.
263 if (auto tl_element_shape =
264 dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
265 tl_element_shape.replaceAllUsesWith(initial_element_shape);
266 continue;
267 }
268 // Ignore ops that just consume a TensorList and do not output another
269 // TensorList.
270 if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
271 TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
272 continue;
273
274 // For any other unknown users of the TensorList, we are conservative and
275 // stop element shape inference.
276 return false;
277 }
278 return true;
279 }
280 } // namespace
281
282 // Combination of value producer and port of value produced (e.g.,
283 // <value result output>:<value in output tensor>,
284 // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
285 // scalar value).
286 struct ValuePort {
287 PointerUnion<Operation*, BlockArgument> producer;
288 SmallVector<unsigned int, 2> port;
289
operator ==mlir::TF::ValuePort290 bool operator==(const ValuePort& other) const {
291 return producer == other.producer && port == other.port;
292 }
293
294 // Convert output value to ValuePort.
ValuePortmlir::TF::ValuePort295 explicit ValuePort(Value v) {
296 OpResult opr = v.dyn_cast<OpResult>();
297 if (opr) {
298 producer = opr.getOwner();
299 port = {opr.getResultNumber()};
300 } else {
301 producer = v.cast<BlockArgument>();
302 port = {0};
303 }
304 }
ValuePortmlir::TF::ValuePort305 ValuePort(PointerUnion<Operation*, BlockArgument> producer,
306 SmallVector<unsigned int, 2> port)
307 : producer(producer), port(port) {}
308
printmlir::TF::ValuePort309 raw_ostream& print(raw_ostream& os) const {
310 if (auto* op = producer.dyn_cast<Operation*>())
311 os << "op " << op->getName();
312 if (auto ba = producer.dyn_cast<BlockArgument>())
313 os << "block_arg " << ba.getArgNumber();
314 os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
315 return os;
316 }
317 };
318
319 struct ValuePortHasher {
operator ()mlir::TF::ValuePortHasher320 std::size_t operator()(const ValuePort& other) const {
321 return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
322 hash_value(ArrayRef<unsigned int>(other.port)));
323 }
324 };
325
326 using ValuePortResultMap =
327 std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
328 using ComputedQueryFn = function_ref<bool(ValuePort)>;
329 using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
330 using ValuePortInputs = SmallVectorImpl<ValuePort>;
331
332 // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
333 // intended to be switched to op interfaces once more refined.
ComputeInputsRequiredForOutput(ValuePort value_port,ComputedQueryFn has_been_computed,ValuePortInputs * inputs)334 LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
335 ComputedQueryFn has_been_computed,
336 ValuePortInputs* inputs) {
337 auto op = value_port.producer.dyn_cast<Operation*>();
338 auto& port = value_port.port;
339 if (!op) return failure();
340
341 // No inputs required for constants.
342 if (matchPattern(op, m_Constant())) return success();
343
344 // Note: this focusses only on the trivial pack op case and this could be
345 // generalized.
346 if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
347 auto type = pack_op.getType().cast<TensorType>();
348 if (!type.hasRank() || type.getRank() != 1) return failure();
349 if (port.size() != 2) return failure();
350 assert(port[0] == 0);
351 ValuePort req(pack_op.getOperand(port[1]));
352 if (!has_been_computed(req)) inputs->push_back(req);
353 return success();
354 }
355
356 return failure();
357 }
358
359 // Computes the output produced by ValuePort using the query function of
360 // existing computed values.
ComputeOutputComponent(const ValuePort & value_port,ValueQueryFn values)361 Attribute ComputeOutputComponent(const ValuePort& value_port,
362 ValueQueryFn values) {
363 LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
364 if (auto known = values(value_port)) return known;
365
366 auto op = value_port.producer.dyn_cast<Operation*>();
367 if (!op) return nullptr;
368 auto& port = value_port.port;
369
370 if (port.empty()) {
371 LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n");
372 return nullptr;
373 }
374
375 ElementsAttr attr;
376 if (matchPattern(op, m_Constant(&attr))) {
377 if (port.size() == 1 && port[0] == 0) return attr;
378 return nullptr;
379 }
380
381 // Note: this focusses only on the trivial pack op case and this could be
382 // generalized.
383 if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
384 TensorType type = pack_op.getType().cast<TensorType>();
385 if (!type.hasRank() || type.getRank() != 1) return nullptr;
386 if (port.size() != 2 || port[0] != 0) return nullptr;
387 ValuePort op_port(op->getOperand(port[1]));
388 return values(op_port);
389 }
390
391 if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
392 if (port.size() == 1)
393 return ComputeOutputComponent(
394 ValuePort(graph.GetFetch().fetches()[port[0]]), values);
395 return nullptr;
396 }
397
398 if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
399 if (port.size() == 1)
400 return ComputeOutputComponent(
401 ValuePort(island.GetYield().fetches()[port[0]]), values);
402 return nullptr;
403 }
404
405 return nullptr;
406 }
407
408 // Context used during ShapeInference. This class contains common information
409 // that is required by the individual shape inference helper functions (e.g.,
410 // TF Graph version, constant values computed, etc.)
411 class ShapeInference {
412 public:
413 ShapeInference(int64_t graph_version, MLIRContext* context,
414 bool propagate_caller_callee_constants);
415
ComputeInputsRequiredForOutput(ValuePort value_port,ValuePortInputs * inputs)416 LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
417 ValuePortInputs* inputs) {
418 return ::mlir::TF::ComputeInputsRequiredForOutput(
419 value_port,
420 [this](const ValuePort& port) {
421 return results_.find(port) != results_.end();
422 },
423 inputs);
424 }
425
ComputeOutputComponent(const ValuePort & value_port)426 Attribute ComputeOutputComponent(const ValuePort& value_port) {
427 if (auto known_attr = results_[value_port]) return known_attr;
428 auto attr = ::mlir::TF::ComputeOutputComponent(
429 value_port, [this](const ValuePort& port) { return results_[port]; });
430 RecordValue(value_port, attr);
431 return attr;
432 }
433
434 // Returns ShapeHandle if the op result could be computed as shape.
435 ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
436
RecordValue(const ValuePort & value_port,Attribute value)437 void RecordValue(const ValuePort& value_port, Attribute value) {
438 LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
439 << value << "\n");
440 results_[value_port] = value;
441 }
442
443 // Infers shape of tf.While/tf.WhileRegion. If `shape_invariant` attribute is
444 // set, operand types are set as result types if associated body result types
445 // match the operand type (does not change per loop iteration). If operand and
446 // body result types are not the same, only handle types are propagated to
447 // result types. This is necessary to not incorrectly change result shapes
448 // when the While op will have a different result shape. Otherwise operand
449 // shapes are propagated to result shapes.
450 template <typename WhileOpTy>
451 bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types);
452
453 // Performs shape inference on the provided op and return true if the type of
454 // at least one result has been changed.
455 // A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
456 // `graph_version` indicates the current GraphDef compatibility versions
457 // (the versions field in graph.proto).
458 bool InferShapeForSingleOperation(Operation* op);
459
460 // Infers shape on the provided region, including nested ones, iterate until
461 // fix point with a limit of max_iteration. Returns success if fix point is
462 // reached before max_iteration.
463 LogicalResult InferShapeUntilFixPoint(Region* region, int64_t max_iterations);
464
465 // Updates input types and refine shapes inside body of functions that are
466 // attached to ControlFlow ops (If/While) or Calls. These functions include
467 // Then/Else branches of IfOp and Cond/Body functions of WhileOp. Functions
468 // attached to control flow share following common properties:
469 // 1) They are never reused, ie. having a single use in module.
470 // 2) Their input types match those of their parent ops (excluding inputs
471 // like predicate).
472 // For calls, functions can be reused across multiple call sites. In this case
473 // we propagate the types when all call sites have the same operand types.
474 LogicalResult PropagateShapeToFunctions(ModuleOp module,
475 TypeRange input_types,
476 ArrayRef<FuncOp> functions,
477 int64_t max_iteration);
478
479 // Propagates shapes to regions given the shapes of the inputs of the regions.
480 // All regions provided in `regions` are assumed to have inputs of type
481 // `input_types`.
482 LogicalResult PropagateShapeToRegions(TypeRange input_types,
483 ArrayRef<Region*> regions,
484 int64_t max_iteration);
485
486 // Shape propagation for call/control flow ops.
487 LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
488 int64_t max_iteration);
489
490 // Shape propagation for region based control flow.
491 LogicalResult PropagateShapeIntoAttachedRegions(Operation* op,
492 int64_t max_iterations);
493
494 // Propagates any constant operand of call_op to the called function body's
495 // corresponding argument if the callee has only one use.
496 //
497 // TODO(b/154065712): Move this to a more general inter-procedural constant
498 // folding pass.
499 void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func,
500 ModuleOp module);
501
502 // Propagates any constant return value of the callee function to the call
503 // op's corresponding result.
504 void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func,
505 ModuleOp module);
506
507 // Tries to compute the result of folding the op. This doesn't actually
508 // perform constant folding, it is just computes the equivalent constants.
509 // Returns whether it was able to compute constant values.
510 LogicalResult TryToFold(Operation* op);
511
512 // Makes result types match the operand types (the i-th result type will
513 // match the i-th operand type). Returns true if anything is changed.
514 bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands,
515 ResultRange results);
516
517 // Makes result type's shape match the corresponding operand's shape.
518 // Returns whether any change was made.
519 bool RefineShapeForPassThroughOps(Operation* op);
520
521 // Infers shape for necessary ops that are not in the TF dialect. Returns
522 // whether any result type changed.
523 bool InferShapeForNonTFDialectOperation(Operation* op);
524
525 // Infers shape for function return type and returns whether changed.
526 void InferShapeForFunctionReturnType(FuncOp func);
527
528 // Enqueues function for processing.
enqueue(FuncOp fn)529 void enqueue(FuncOp fn) {
530 LLVM_DEBUG(llvm::dbgs()
531 << "enqueue " << fn.getName() << " ("
532 << (queue_set_.count(fn) ? "already inserted" : "newly inserted")
533 << ")\n");
534 if (queue_set_.insert(fn).second) queue_.push(fn);
535 }
536
537 // Enqueues callers on functions.
538 void EnqueueCallers(FuncOp fn);
539
540 // Returns the function at the front of the queue.
front()541 FuncOp front() { return queue_.front(); }
542
543 // Returns whether work queue is empty.
EmptyQueue() const544 bool EmptyQueue() const { return queue_.empty(); }
545
546 // Returns function from the front of the work queue.
pop_front()547 FuncOp pop_front() {
548 FuncOp ret = queue_.front();
549 queue_.pop();
550 queue_set_.erase(ret);
551 return ret;
552 }
553
554 // Returns the current size of the queue.
QueueSize() const555 std::queue<FuncOp>::size_type QueueSize() const { return queue_.size(); }
556
557 Dialect* const tf_dialect_;
558
559 private:
560 // Updates the result of an operation to a new inferred type. Also inserts
561 // tf.Cast operation for uses that are incompatible with the new type.
562 void UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, Value result);
563
564 // Refines the type of `result` of `op` using the type
565 // `potential_refined_type`. Return true if the type was changed.
566 bool RefineResultType(Operation* op, Value result,
567 Type potential_refined_type);
568
569 // Infers the shape from a (Stateful)PartionedCall operation by looking up the
570 // called function and propagating the return type.
571 bool InferShapeForCall(CallOpInterface call_op);
572
573 bool InferShapeForCast(CastOp op);
574
575 // Infers the shape IfOp outputs based on the shapes of the then and else
576 // function result types.
577 bool InferShapeForIf(IfOp op);
578
579 // Infers the shape IfRegion outputs based on the shapes of the then and else
580 // yields.
581 bool InferShapeForIfRegion(IfRegionOp op);
582
583 // Infers the shape of ops that create TensorList. Specifically,
584 // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
585 // refines the element shape if all tensors written to the list across all
586 // mutations have identical static shape.
587 bool InferShapeForTensorListInitOps(Operation* op);
588
589 bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti);
590
591 // Returns all the callers of a function.
592 // Note: Usage of the return value of this function may not be interleaved
593 // with insertions to the callers map. This could occur if GetCallers is
594 // called with two separate functions, the 2nd one incurs a resize and then
595 // both first and 2nd stored callers are used.
596 ArrayRef<Operation*> GetCallers(FuncOp fn);
597
598 // Mapping between ValuePort (which corresponds to an OpResult or smaller,
599 // e.g., first element of OpResult produced) to an Attribute if the ValuePort
600 // corresponds to a constant value.
601 ValuePortResultMap results_;
602
603 // Map from a function to the callers of that function.
604 llvm::DenseMap<FuncOp, SmallVector<Operation*, 4>> callers_of_func_;
605
606 // Queue of functions being processed.
607 llvm::DenseSet<FuncOp> queue_set_;
608 std::queue<FuncOp> queue_;
609
610 int64_t graph_version_;
611
612 // TODO(b/154065712): Remove propagate_caller_callee_constants once using
613 // SCCP pass instead.
614 bool propagate_caller_callee_constants_;
615 };
616
ShapeInference(int64_t graph_version,MLIRContext * context,bool propagate_caller_callee_constants)617 ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
618 bool propagate_caller_callee_constants)
619 : tf_dialect_(context->getLoadedDialect<TensorFlowDialect>()),
620 graph_version_(graph_version),
621 propagate_caller_callee_constants_(propagate_caller_callee_constants) {}
622
GetCallers(FuncOp fn)623 ArrayRef<Operation*> ShapeInference::GetCallers(FuncOp fn) {
624 auto pair = callers_of_func_.try_emplace(fn);
625 if (pair.second) {
626 ModuleOp module = fn->getParentOfType<ModuleOp>();
627 auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module);
628 if (uses) {
629 pair.first->second.reserve(pair.first->second.size());
630 for (auto use : *uses) {
631 pair.first->second.push_back(use.getUser());
632 }
633 }
634 }
635 return pair.first->second;
636 }
637
EnqueueCallers(FuncOp fn)638 void ShapeInference::EnqueueCallers(FuncOp fn) {
639 for (auto user : GetCallers(fn)) enqueue(user->getParentOfType<FuncOp>());
640 }
641
UpdateTypeAndInsertIncompatibleUseCasts(Type new_type,Value result)642 void ShapeInference::UpdateTypeAndInsertIncompatibleUseCasts(Type new_type,
643 Value result) {
644 // A tf.Cast operation is lazily created on the first use requires a cast.
645 TF::CastOp cast_op;
646 auto get_cast_op = [&]() {
647 if (!cast_op) {
648 Operation* op = result.getDefiningOp();
649 OpBuilder b(op);
650 b.setInsertionPointAfter(op);
651 cast_op = b.create<TF::CastOp>(op->getLoc(), result.getType(), result,
652 /*truncate=*/b.getBoolAttr(false));
653 }
654 return Value(cast_op);
655 };
656 // First insert cast back for uses that need a cast and then
657 // update the type.
658 bool enqueue_callers = false;
659 for (OpOperand& use : make_early_inc_range(result.getUses())) {
660 if (isa<ReturnOp>(use.getOwner()))
661 enqueue_callers = true;
662 else if (NeedsCastBack(use, tf_dialect_))
663 use.set(get_cast_op());
664 }
665
666 result.setType(new_type);
667 if (enqueue_callers)
668 EnqueueCallers(result.getDefiningOp()->getParentOfType<FuncOp>());
669 }
670
RefineResultType(Operation * op,Value result,Type potential_refined_type)671 bool ShapeInference::RefineResultType(Operation* op, Value result,
672 Type potential_refined_type) {
673 if (!CanRefineTypeWith(result.getType(), potential_refined_type))
674 return false;
675
676 UpdateTypeAndInsertIncompatibleUseCasts(potential_refined_type, result);
677 return true;
678 }
679
680 // Infers the shape from a (Stateful)PartionedCall operation by looking up the
681 // called function and propagating the return type.
InferShapeForCall(CallOpInterface call_op)682 bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
683 FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
684 if (!func) return false;
685
686 LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName());
687 Operation* op = call_op.getOperation();
688 bool changed = false;
689 // Map each of the results of the call to the returned type of the
690 // function.
691 for (auto result : zip(op->getResults(), func.getType().getResults())) {
692 changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
693 changed;
694 }
695 LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n");
696
697 return changed;
698 }
699
InferShapeForCast(CastOp op)700 bool ShapeInference::InferShapeForCast(CastOp op) {
701 DCOMMENT_OP(op.getOperation(), "Inferring shape for ");
702 Value result = op.getResult();
703 if (!CanBeRefined(result.getType())) return false;
704
705 Type operand_type = op.getOperand().getType();
706 auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>();
707 if (!ranked_op_type) return false;
708 auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
709 if (ranked_res_type &&
710 ranked_op_type.getShape() == ranked_res_type.getShape())
711 return false;
712
713 // Avoid inserting a cast where no users types could be refined (e.g., where
714 // there would need to be a cast inserted for every user again).
715 if (llvm::all_of(result.getUses(), [this](OpOperand& use) {
716 return NeedsCastBack(use, tf_dialect_);
717 }))
718 return false;
719
720 auto new_type = RankedTensorType::get(
721 ranked_op_type.getShape(),
722 result.getType().cast<ShapedType>().getElementType());
723
724 UpdateTypeAndInsertIncompatibleUseCasts(new_type, op.getResult());
725 return true;
726 }
727
InferShapeForIf(IfOp op)728 bool ShapeInference::InferShapeForIf(IfOp op) {
729 DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
730 bool changed = false;
731 auto then_results = op.then_function().getType().getResults();
732 auto else_results = op.else_function().getType().getResults();
733 for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
734 // If then and else types do not match, skip refinement for that result.
735 if (std::get<1>(it) != std::get<2>(it)) continue;
736 changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
737 }
738 return changed;
739 }
740
InferShapeForIfRegion(IfRegionOp op)741 bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {
742 bool changed = false;
743
744 Operation* then_yield = op.then_branch().front().getTerminator();
745 Operation* else_yield = op.else_branch().front().getTerminator();
746 for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
747 else_yield->getOperandTypes())) {
748 // If then and else types do not match, skip refinement for that result.
749 if (std::get<1>(result) != std::get<2>(result)) continue;
750 changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
751 changed;
752 }
753 return changed;
754 }
755
InferShapeForTensorListInitOps(Operation * op)756 bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
757 DCOMMENT_OP(op, "Inferring shape for TensorList ");
758 Value handle = op->getResult(0);
759 Value initial_element_shape = GetElementShapeOperand(op);
760 RankedTensorType element_type;
761 if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) {
762 // For TensorListFromTensor op we can infer element shape by dropping the
763 // first dimension of input tensor.
764 element_type = DropFirstDimension(tl_from_tensor.tensor().getType());
765 if (!element_type || !element_type.hasStaticShape()) return false;
766 }
767 if (!CanInferTensorListElementType(handle, initial_element_shape,
768 &element_type))
769 return false;
770 if (!element_type || !element_type.hasStaticShape()) return false;
771 auto variant_type = VariantType::get(element_type, op->getContext());
772 auto tensor_type = RankedTensorType::get({}, variant_type);
773 bool changed = RefineResultType(op, handle, tensor_type);
774 if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
775 return changed;
776 }
777
RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti)778 bool ShapeInference::RefineWithInferTypeOpInterface(
779 InferTypeOpInterface infer_ti) {
780 Operation* op = infer_ti.getOperation();
781 SmallVector<Type, 4> inferred;
782 LogicalResult res = infer_ti.inferReturnTypes(
783 op->getContext(), op->getLoc(), op->getOperands(),
784 op->getAttrDictionary(), op->getRegions(), inferred);
785 if (failed(res)) {
786 op->emitOpError("failed to refine type as inference failed");
787 return false;
788 }
789
790 if (inferred == op->getResultTypes()) return false;
791
792 // Map each of the results of the call to the returned type of the
793 // function.
794 bool changed = false;
795 for (auto result : zip(op->getResults(), inferred)) {
796 if (std::get<0>(result).getType() == std::get<1>(result)) continue;
797
798 UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result),
799 std::get<0>(result));
800 changed = true;
801 }
802 return changed;
803 }
804
ComputeOutputAsShape(OpResult result,InferenceContext * ic)805 ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
806 InferenceContext* ic) {
807 LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
808 auto rt = result.getType().dyn_cast<RankedTensorType>();
809 if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
810 int dim_size = rt.getDimSize(0);
811
812 // Worklist to direct partial evaluation.
813 SmallVector<ValuePort, 4> worklist;
814
815 // Simple evaluator that attempts to partially evaluate the input value even
816 // if unable to evaluate the complete output. Below follows a simple stack
817 // based evaluation where it queries what operands/part of operands need to
818 // be evaluated and attempting to partially evaluate those operands. It does
819 // so by pushing the operands that need to be required on to the worklist
820 // before enqueuing the operation requiering those values.
821 std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim());
822 for (unsigned int i = 0, e = dims.size(); i != e; ++i) {
823 LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n");
824
825 worklist.push_back(
826 ValuePort{result.getOwner(), {result.getResultNumber(), i}});
827 while (!worklist.empty()) {
828 auto front = worklist.pop_back_val();
829 LLVM_DEBUG(front.print(llvm::dbgs() << "\nWorklist front "));
830
831 SmallVector<ValuePort, 4> inputs;
832 auto res = ComputeInputsRequiredForOutput(front, &inputs);
833 if (failed(res)) {
834 // Abort if unable to find which required inputs need to be computed.
835 worklist.clear();
836 break;
837 }
838
839 if (!inputs.empty()) {
840 // Enqueue required computation followed by its required operands in
841 // stack.
842 worklist.push_back(std::move(front));
843 for (auto& it : inputs) worklist.push_back(std::move(it));
844 continue;
845 }
846
847 auto ret = ComputeOutputComponent(front);
848 if (!ret) continue;
849
850 LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
851
852 // If worklist is empty, then this is the root query op.
853 if (worklist.empty()) {
854 LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
855 if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
856 if (dea.getNumElements() != 1) {
857 LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n");
858 return {};
859 }
860 int64_t val = (*dea.getIntValues().begin()).getSExtValue();
861 dims[i] = ic->MakeDim(val);
862 }
863 }
864 }
865 }
866 return ic->MakeShape(dims);
867 }
868
RefineTypeForPassThroughOperands(Operation * op,OperandRange operands,ResultRange results)869 bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
870 OperandRange operands,
871 ResultRange results) {
872 bool changed = false;
873 for (auto entry : zip(operands, results)) {
874 Type operand_type = std::get<0>(entry).getType();
875 Value result = std::get<1>(entry);
876 TensorType result_type = result.getType().cast<TensorType>();
877 if (operand_type == result_type) continue;
878 // Pass through nodes may remove ref types, don't consider that as
879 // refinement.
880 // TODO(jpienaar): There could be refinement in addition to this, so
881 // refine this.
882 if (operand_type.cast<TensorType>()
883 .getElementType()
884 .isa<TF::TensorFlowRefType>() &&
885 !result_type.cast<TensorType>()
886 .getElementType()
887 .isa<TF::TensorFlowRefType>())
888 continue;
889
890 UpdateTypeAndInsertIncompatibleUseCasts(operand_type, result);
891 changed = true;
892 }
893 return changed;
894 }
895
RefineShapeForPassThroughOps(Operation * op)896 bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
897 DCOMMENT_OP(op, "Pass through op");
898 auto is_allowed_dtype = [](Type t) {
899 // Skip if element type is not in standard or TF dialect.
900 // TODO(jpienaar): The tf.Cast op, which is uniformly inserted at the
901 // moment, cannot handle arbirary types (e.g., it can't handle quantized
902 // types). This restriction can be relaxed if not only tf.Cast is used.
903 return t.getDialect().getNamespace().empty() ||
904 isa<TensorFlowDialect>(t.getDialect());
905 };
906
907 bool changed = false;
908 for (auto entry : zip(op->getOperands(), op->getResults())) {
909 TensorType operand_type = std::get<0>(entry).getType().cast<TensorType>();
910 Value result = std::get<1>(entry);
911 TensorType result_type = result.getType().cast<TensorType>();
912 if (operand_type == result_type) continue;
913 if (!operand_type.hasRank()) continue;
914 if (result_type.hasRank() &&
915 result_type.getShape() == operand_type.getShape())
916 continue;
917 if (!is_allowed_dtype(operand_type.getElementType()) ||
918 !is_allowed_dtype(result_type.getElementType()))
919 continue;
920
921 auto new_type = RankedTensorType::get(operand_type.getShape(),
922 result_type.getElementType());
923 UpdateTypeAndInsertIncompatibleUseCasts(new_type, result);
924 changed = true;
925 }
926 return changed;
927 }
928
InferShapeForNonTFDialectOperation(Operation * op)929 bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
930 if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
931 return RefineTypeForPassThroughOperands(
932 graph_op.GetFetch(), graph_op.GetFetch().fetches(), op->getResults());
933 }
934 if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
935 return RefineTypeForPassThroughOperands(
936 island_op.GetYield(), island_op.GetYield().fetches(), op->getResults());
937 }
938 if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
939 auto iter_source = cast<tf_executor::NextIterationSourceOp>(
940 iter_sink.token().getDefiningOp());
941 return RefineTypeForPassThroughOperands(
942 op, iter_sink.getOperands().drop_front().take_front(),
943 iter_source.getResults());
944 }
945 if (auto launch_op = dyn_cast<tf_device::LaunchOp>(op)) {
946 auto terminator = launch_op.GetBody().getTerminator();
947 return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
948 op->getResults());
949 }
950 if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
951 auto terminator = cluster_op.GetBody().getTerminator();
952 return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
953 op->getResults());
954 }
955 if (op->hasTrait<OpTrait::SameOperandsAndResultShape>() ||
956 isa<tensor::CastOp>(op)) {
957 return RefineShapeForPassThroughOps(op);
958 }
959 if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
960 return false;
961 }
962
963 // Finds element type to be used for result from operand, with special handling
964 // for handle types.
GetElementTypeFromOperand(TensorType operand_type,TensorType result_type)965 Type GetElementTypeFromOperand(TensorType operand_type,
966 TensorType result_type) {
967 auto operand_handle_type =
968 operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
969 if (!operand_handle_type) return result_type.getElementType();
970 auto result_handle_type =
971 result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
972 if (operand_handle_type.GetSubtypes().empty() ||
973 !result_handle_type.GetSubtypes().empty())
974 return result_type.getElementType();
975 return operand_handle_type;
976 }
977
978 // Checks if one tensor type can refine another type for tf.While/
979 // tf.WhileRegion. If rank differs or static dimensions can be lost, the other
980 // type cannot be used for refinement.
CanWhileTypeBeRefinedWith(TensorType current_type,TensorType potential_refined_type)981 bool CanWhileTypeBeRefinedWith(TensorType current_type,
982 TensorType potential_refined_type) {
983 if (!current_type.hasRank()) return true;
984 if (!potential_refined_type.hasRank()) return false;
985 if (current_type.getRank() != potential_refined_type.getRank()) return false;
986 for (auto dim :
987 llvm::zip(current_type.getShape(), potential_refined_type.getShape())) {
988 int64_t current_dim = std::get<0>(dim);
989 int64_t potential_refined_dim = std::get<1>(dim);
990 if (current_dim != potential_refined_dim &&
991 current_dim != ShapedType::kDynamicSize)
992 return false;
993 }
994 return true;
995 }
996
997 template <typename WhileOpTy>
InferShapeForWhile(WhileOpTy op,TypeRange body_result_types)998 bool ShapeInference::InferShapeForWhile(WhileOpTy op,
999 TypeRange body_result_types) {
1000 if (!op.shape_invariant())
1001 return RefineTypeForPassThroughOperands(op, op.input(), op.output());
1002
1003 bool changed = false;
1004 for (auto entry :
1005 zip(op.input().getTypes(), op.output(), body_result_types)) {
1006 Value result = std::get<1>(entry);
1007 TensorType body_result_type =
1008 std::get<2>(entry).template cast<TensorType>();
1009 auto result_type = result.getType().cast<TensorType>();
1010
1011 Type potential_refined_type;
1012 if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) {
1013 Type element_type =
1014 GetElementTypeFromOperand(body_result_type, result_type);
1015 potential_refined_type = CreateTensorType(
1016 body_result_type.hasRank() ? body_result_type.getShape()
1017 : llvm::Optional<ArrayRef<int64_t>>(),
1018 element_type);
1019 } else {
1020 TensorType operand_type = std::get<0>(entry).template cast<TensorType>();
1021 Type element_type = GetElementTypeFromOperand(operand_type, result_type);
1022 potential_refined_type = CreateTensorType(
1023 result_type.hasRank() ? result_type.getShape()
1024 : llvm::Optional<ArrayRef<int64_t>>(),
1025 element_type);
1026 }
1027 changed |= RefineResultType(op, result, potential_refined_type);
1028 }
1029 return changed;
1030 }
1031
InferShapeForSingleOperation(Operation * op)1032 bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
1033 LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
1034 llvm::dbgs() << "\n");
1035 assert(tf_dialect_ == op->getDialect());
1036 // The shape function of these ops sometimes does not propagate subtypes
1037 // (handle shapes) for resource and variant types. We use a simple passthrough
1038 // to make sure they are preserved in the output.
1039 if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp>(op)) {
1040 return RefineTypeForPassThroughOperands(op, op->getOperands(),
1041 op->getResults());
1042 }
1043
1044 // If no result for this op needs shape inference, we have a fast-path return.
1045 // But if the type is a resource/variant, we do not skip it because we might
1046 // not have the handle shapes.
1047 if (none_of(op->getResultTypes(), CanBeRefined)) {
1048 LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
1049 << op->getName() << "'.\n");
1050 return false;
1051 }
1052
1053 // Handle call operations by looking up callee and inferring return shape as
1054 // needed.
1055 if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1056
1057 // tf.Cast are only inferred if they have at least one user in the TF dialect
1058 // or feeding into the function return. This is necessary to avoid inserting
1059 // casts which cannot be refined.
1060 if (auto cast_op = dyn_cast<CastOp>(op)) return InferShapeForCast(cast_op);
1061
1062 // Handle IfOp here by inferring the shape from the else/then function
1063 // results. Since `output_shapes` is a derived attribute, avoid going down the
1064 // TF InferenceContext path as IfOp shape inference is implemented as just
1065 // a lookup of the output_shapes attribute.
1066 if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);
1067
1068 // Handle IfRegion operations by inferring return shape from the then and else
1069 // branches.
1070 if (auto if_region = dyn_cast<IfRegionOp>(op))
1071 return InferShapeForIfRegion(if_region);
1072
1073 if (auto while_op = dyn_cast<WhileOp>(op))
1074 return InferShapeForWhile(while_op,
1075 while_op.body_function().getType().getResults());
1076
1077 if (auto while_region = dyn_cast<WhileRegionOp>(op))
1078 return InferShapeForWhile(
1079 while_region,
1080 while_region.body().front().getTerminator()->getOperandTypes());
1081
1082 // Handle TensorList init operations by inferring shape from TensorList write
1083 // operations. If we are unable to refine element shape here, proceed to use
1084 // the InferenceContext below to get more precise shapes.
1085 if (IsTensorListInitOp(op) && InferShapeForTensorListInitOps(op)) return true;
1086
1087 // Return operand as a constant attribute.
1088 auto operand_as_constant_fn = [&](Value operand) {
1089 ValuePort vp(operand);
1090 Attribute attr = ComputeOutputComponent(vp);
1091 if (!attr && matchPattern(operand, m_Constant(&attr)))
1092 RecordValue(vp, attr);
1093 return attr;
1094 };
1095
1096 // Return op result as a shape.
1097 auto op_result_as_shape_fn = [&](InferenceContext& context,
1098 OpResult op_result) {
1099 return ComputeOutputAsShape(op_result, &context);
1100 };
1101
1102 // Return result element type at `index`.
1103 auto result_element_type_fn = [&](int index) {
1104 return op->getResult(index).getType().cast<TensorType>().getElementType();
1105 };
1106
1107 llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes;
1108 if (failed(InferReturnTypeComponentsForTFOp(
1109 /*location=*/None, op, graph_version_, operand_as_constant_fn,
1110 op_result_as_shape_fn, result_element_type_fn,
1111 inferred_return_shapes)))
1112 return false;
1113
1114 // Update the shape for each of the operation result if the InferenceContext
1115 // has more precise shapes recorded.
1116 bool changed = false;
1117 for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) {
1118 Value op_result = std::get<0>(result);
1119 if (!CanBeRefined(op_result.getType())) continue;
1120
1121 ShapedTypeComponents inferred = std::get<1>(result);
1122 TensorType inferred_type;
1123 if (inferred.hasRank())
1124 inferred_type =
1125 RankedTensorType::get(inferred.getDims(), inferred.getElementType());
1126 else
1127 inferred_type = UnrankedTensorType::get(inferred.getElementType());
1128
1129 if (op_result.getType() == inferred_type) continue;
1130 UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result);
1131 changed = true;
1132 }
1133
1134 if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
1135 return changed;
1136 }
1137
PropagateShapeToFunctions(ModuleOp module,TypeRange input_types,ArrayRef<FuncOp> functions,int64_t max_iteration)1138 LogicalResult ShapeInference::PropagateShapeToFunctions(
1139 ModuleOp module, TypeRange input_types, ArrayRef<FuncOp> functions,
1140 int64_t max_iteration) {
1141 bool all_succeeded = true;
1142 // If shape propagation fails for one function, return failure, but do not
1143 // early exit and attempt to propagate shapes for all provided functions to
1144 // have a best-effort propagation.
1145 for (FuncOp func : functions) {
1146 DCOMMENT("Propating shape to" << func.getName());
1147 ArrayRef<Operation*> callers = GetCallers(func);
1148 if (!llvm::hasSingleElement(callers) &&
1149 !llvm::all_of(callers.drop_front(), [&](Operation* caller) {
1150 /// TODO(aminim): this is overly conservative as some operations
1151 /// (like TPUPartitionedCallOp) may have extra operands that aren't
1152 /// propagated to the callee.
1153 return isa<CallOpInterface>(caller) &&
1154 std::equal(caller->getOperandTypes().begin(),
1155 caller->getOperandTypes().end(),
1156 callers.front()->getOperandTypes().begin());
1157 })) {
1158 all_succeeded = false;
1159 if (llvm::any_of(callers, [](Operation* op) {
1160 return isa<IfOp, WhileOp, CaseOp>(op);
1161 }))
1162 func.emitWarning(formatv(
1163 "expected control flow function @{0} to have exactly 1 use, "
1164 "found {1}.",
1165 func.getName(), callers.size()));
1166
1167 continue;
1168 }
1169 FunctionType func_type = func.getType();
1170 func.setType(FunctionType::get(func.getContext(), input_types,
1171 func_type.getResults()));
1172
1173 auto res =
1174 PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
1175 if (failed(res)) {
1176 all_succeeded = false;
1177 continue;
1178 }
1179
1180 InferShapeForFunctionReturnType(func);
1181 }
1182 return success(all_succeeded);
1183 }
1184
PropagateShapeToRegions(TypeRange input_types,ArrayRef<Region * > regions,int64_t max_iteration)1185 LogicalResult ShapeInference::PropagateShapeToRegions(TypeRange input_types,
1186 ArrayRef<Region*> regions,
1187 int64_t max_iteration) {
1188 DCOMMENT("\tPropagating shapes to regions");
1189 bool all_succeeded = true;
1190 // If shape propagation fails for one region, return failure, but do not
1191 // early exit and attempt to propagate shapes for all provided regions to
1192 // have a best-effort propagation.
1193 for (auto region : regions) {
1194 // Refine region arguments.
1195 Block& entry = region->front();
1196 assert(llvm::size(input_types) == entry.getNumArguments());
1197 for (auto it : llvm::zip(entry.getArguments(), input_types)) {
1198 BlockArgument arg = std::get<0>(it);
1199 Type type = std::get<1>(it);
1200 arg.setType(type);
1201 }
1202
1203 // Propagate shapes into the region.
1204 all_succeeded = succeeded(InferShapeUntilFixPoint(region, max_iteration)) &&
1205 all_succeeded;
1206 }
1207 return success(all_succeeded);
1208 }
1209
PropagateConstantToCallee(CallOpInterface call_op,FuncOp func,ModuleOp module)1210 void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
1211 FuncOp func, ModuleOp module) {
1212 auto callers = GetCallers(func);
1213 if (!llvm::hasSingleElement(callers)) return;
1214
1215 OpBuilder builder(&func.front().front());
1216 Operation* op = call_op.getOperation();
1217 // If this is the only caller, and an operand is a constant, propagate
1218 // the constant value inside the function.
1219 for (auto arg : func.getArguments()) {
1220 auto operand = op->getOperand(arg.getArgNumber());
1221 if (propagate_caller_callee_constants_) {
1222 if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
1223 arg.replaceAllUsesWith(
1224 builder.clone(*operand.getDefiningOp())->getResult(0));
1225 }
1226 continue;
1227 }
1228
1229 auto known_constant = ComputeOutputComponent(ValuePort(operand));
1230 if (!known_constant) continue;
1231 LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to calee: ");
1232 known_constant.print(llvm::dbgs() << " constant ");
1233 llvm::dbgs() << "\n");
1234 RecordValue(ValuePort(arg), known_constant);
1235 }
1236 }
1237
PropagateConstantFromCallee(CallOpInterface call_op,FuncOp func,ModuleOp module)1238 void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
1239 FuncOp func, ModuleOp module) {
1240 // If the return value is a constant, use the constant as the value of
1241 // the call return.
1242 Operation* op = call_op.getOperation();
1243 OpBuilder builder(op);
1244 builder.setInsertionPointAfter(op);
1245 for (auto retval :
1246 llvm::enumerate(func.front().getTerminator()->getOperands())) {
1247 if (propagate_caller_callee_constants_) {
1248 auto retval_op = retval.value().getDefiningOp();
1249 if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
1250 op->getResult(retval.index())
1251 .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
1252 }
1253 continue;
1254 }
1255
1256 ValuePort vp(retval.value());
1257 if (auto known_constant = ComputeOutputComponent(vp)) {
1258 LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
1259 call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
1260 RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
1261 }
1262 }
1263 }
1264
RankedAndSameRank(TensorType lhs,TensorType rhs)1265 bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
1266 return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank();
1267 }
1268
1269 // Creates a compatible RankedTensorType where mismatched dimensions are
1270 // replaced with dynamic sizes.
GetCompatibleRankedTensorType(RankedTensorType lhs,RankedTensorType rhs)1271 RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs,
1272 RankedTensorType rhs) {
1273 assert(lhs.getRank() == rhs.getRank());
1274 llvm::SmallVector<int64_t, 4> dims;
1275 dims.reserve(lhs.getRank());
1276 for (auto dim : llvm::zip(lhs.getShape(), rhs.getShape())) {
1277 int64_t lhs_dim = std::get<0>(dim);
1278 if (lhs_dim == std::get<1>(dim)) {
1279 dims.push_back(lhs_dim);
1280 } else {
1281 dims.push_back(ShapedType::kDynamicSize);
1282 }
1283 }
1284 return RankedTensorType::get(dims, GetElementTypeFromOperand(lhs, rhs));
1285 }
1286
1287 // Finds compatible types to propagate into functions/regions of a shape
1288 // invariant tf.While/tf.WhileRegion. If operand and result types are the same,
1289 // that type is returned. If operand and result types are of the same rank, a
1290 // compatible type with matching dimensions is used. Otherwise functions/regions
1291 // arguments are returned but with the handle type from the operand type.
GetWhileCompatibleTypes(TypeRange operand_types,TypeRange result_types,TypeRange region_argument_types)1292 llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
1293 TypeRange operand_types, TypeRange result_types,
1294 TypeRange region_argument_types) {
1295 llvm::SmallVector<Type, 4> types;
1296 types.reserve(operand_types.size());
1297 for (auto entry :
1298 llvm::zip(operand_types, result_types, region_argument_types)) {
1299 auto operand_type = std::get<0>(entry).cast<TensorType>();
1300 auto result_type = std::get<1>(entry).cast<TensorType>();
1301 if (operand_type == result_type) {
1302 types.push_back(operand_type);
1303 } else if (RankedAndSameRank(operand_type, result_type)) {
1304 auto potential_refined_type =
1305 GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(),
1306 result_type.cast<RankedTensorType>());
1307 types.push_back(potential_refined_type);
1308 } else {
1309 auto region_argument_type = std::get<2>(entry).cast<TensorType>();
1310 Type element_type = GetElementTypeFromOperand(
1311 operand_type.cast<TensorType>(), region_argument_type);
1312 Type potential_refined_type = CreateTensorType(
1313 region_argument_type.hasRank() ? region_argument_type.getShape()
1314 : llvm::Optional<ArrayRef<int64_t>>(),
1315 element_type);
1316 types.push_back(potential_refined_type);
1317 }
1318 }
1319 return types;
1320 }
1321
LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
    Operation* op, int64_t max_iteration) {
  ModuleOp module = op->getParentOfType<ModuleOp>();
  if (auto if_op = dyn_cast<TF::IfOp>(op)) {
    DCOMMENT("Propagating shapes into If");
    return PropagateShapeToFunctions(
        module, if_op.input().getTypes(),
        {if_op.then_function(), if_op.else_function()}, max_iteration);
  } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
    SmallVector<FuncOp, 4> branches;
    case_op.get_branch_functions(branches);
    return PropagateShapeToFunctions(module, case_op.input().getTypes(),
                                     branches, max_iteration);
  } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
    // If `shape_invariant` is set, operand shapes cannot simply be propagated
    // to result shapes, as the op may have different intermediate shapes (such
    // a While op can have result shapes that differ from its operand shapes).
    // Compatible shapes must be determined before propagating them.
    if (while_op.shape_invariant()) {
      auto compatible_types = GetWhileCompatibleTypes(
          while_op.input().getTypes(), while_op.output().getTypes(),
          while_op.body_function().getType().getInputs());
      return PropagateShapeToFunctions(
          module, compatible_types,
          {while_op.cond_function(), while_op.body_function()}, max_iteration);
    }
    return PropagateShapeToFunctions(
        module, while_op.input().getTypes(),
        {while_op.cond_function(), while_op.body_function()}, max_iteration);
  } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
    if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) {
      PropagateConstantToCallee(call_op, func, module);
      if (failed(PropagateShapeToFunctions(module,
                                           call_op.getArgOperands().getTypes(),
                                           {func}, max_iteration))) {
        return failure();
      }
      PropagateConstantFromCallee(call_op, func, module);
      return success();
    }
  }

  // TODO(ycao): Implement support for Call op, including function reuse.

  return success();
}

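// Propagates operand types of `op` into the regions attached to it (currently
// only the cond/body regions of tf.WhileRegion).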
LogicalResult ShapeInference::PropagateShapeIntoAttachedRegions(
    Operation* op, int64_t max_iteration) {
  if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
    // If `shape_invariant` is set, operand shapes cannot simply be propagated
    // to result shapes, as the op may have different intermediate shapes (such
    // a While op can have result shapes that differ from its operand shapes).
    // Compatible shapes must be determined before propagating them.
    if (while_op.shape_invariant()) {
      auto compatible_types = GetWhileCompatibleTypes(
          while_op.input().getTypes(), while_op.output().getTypes(),
          while_op.body().getArgumentTypes());
      return PropagateShapeToRegions(compatible_types,
                                     {&while_op.cond(), &while_op.body()},
                                     max_iteration);
    }
    return PropagateShapeToRegions(while_op.input().getTypes(),
                                   {&while_op.cond(), &while_op.body()},
                                   max_iteration);
  }
  return success();
}

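// Attempts to constant fold `op` using already known constant operand values.
// Folded results are recorded as known constants; results for which only a
// more precise type is derived are refined in place.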
LogicalResult ShapeInference::TryToFold(Operation* op) {
  LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
  // If any output result is known, then the op has probably been computed
  // before.
  if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
    return success();

  SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
  SmallVector<OpFoldResult, 8> fold_results;

  // Check to see if any operand to the operation is constant and whether
  // the operation knows how to constant fold itself.
  bool some_unknown = false;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    if (!(constant_operands[i] =
              ComputeOutputComponent(ValuePort(op->getOperand(i)))))
      some_unknown = true;
  }

  // Attempt to constant fold the operation.
  auto* abstract_op = op->getAbstractOperation();
  LogicalResult folded = failure();
  if (abstract_op) {
    folded = abstract_op->foldHook(op, constant_operands, fold_results);
  }
  // Attempt dialect fallback if op's fold hook failed.
  if (failed(folded)) {
    Dialect* dialect = op->getDialect();
    if (!dialect) return failure();
    // Only attempt TF dialect fallback if there are no unknown operands.
    if (some_unknown && dialect == tf_dialect_) return failure();
    auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
    if (!interface) return failure();

    if (failed(interface->fold(op, constant_operands, fold_results)))
      return failure();
  }

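  // Record each folded result as a known constant (or refine the result type
  // when only a value is available). When the folded attribute carries a more
  // precise type, update the result type, inserting casts for incompatible
  // uses.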
  for (auto result : zip(op->getResults(), fold_results)) {
    auto fold_result = std::get<1>(result);
    Attribute attr = nullptr;
    if ((attr = fold_result.dyn_cast<Attribute>())) {
      RecordValue(ValuePort(std::get<0>(result)), attr);
    } else {
      auto value = fold_result.get<Value>();
      if ((attr = ComputeOutputComponent(ValuePort(value)))) {
        DCOMMENT("\t\tValue Result mapped to " << attr);
        RecordValue(ValuePort(std::get<0>(result)), attr);
      } else {
        DCOMMENT("\t\tValue result unmapped, consider value type:" << value);
        RefineResultType(op, std::get<0>(result), value.getType());
      }
    }

    if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
      if (std::get<0>(result).getType() == eattr.getType()) continue;

      UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(),
                                              std::get<0>(result));
    }
  }

  return success();
}

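// Infers a more precise return type for `func` when it has a single return
// op, folding shape-refining tf.Cast ops that feed that return op and
// updating the function type from the resulting operand types.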
void ShapeInference::InferShapeForFunctionReturnType(FuncOp func) {
  LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName()
                          << "\n");

  // Find any return ops.
  SmallVector<ReturnOp, 4> return_ops;
  for (Block& block : func) {
    if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
      return_ops.push_back(return_op);
    }
  }

  // Right now we only handle the case of a single return op.
  // To handle multiple return ops, we would need to look at all their shapes,
  // compute a common shape, and insert appropriate casts.
  if (return_ops.size() != 1) return;

  // Find the return type.
  auto return_op = return_ops.front();

  // Manually fold any tf.Cast that precedes the return op and differs only in
  // its shape refinement level.
  bool changed = false;
  for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
    Operation* arg_defining_op = arg_op.get().getDefiningOp();
    if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) {
      Value input = cast_op.x();
      Value result = cast_op.y();
      if (!CanRefineTypeWith(result.getType(), input.getType())) continue;

      LLVM_DEBUG({
        llvm::errs() << "\tfolding & updating return type ";
        cast_op.getResult().getType().print(llvm::errs());
        cast_op.getOperand().getType().print(llvm::errs() << " to ");
        llvm::errs() << "\n";
      });

      // Shape inference should not change the element type.
      if (HasCompatibleElementTypes(input.getType(), result.getType())) {
        arg_op.set(input);
      } else {
        OpBuilder b(return_op.getOperation());
        TensorType type;
        if (input.getType().cast<TensorType>().hasRank()) {
          type = RankedTensorType::get(
              input.getType().cast<TensorType>().getShape(),
              result.getType().cast<TensorType>().getElementType());
        } else {
          type = UnrankedTensorType::get(
              result.getType().cast<TensorType>().getElementType());
        }
        auto new_cast_op =
            b.create<TF::CastOp>(return_op.getLoc(), type, input,
                                 /*truncate=*/b.getBoolAttr(false));
        arg_op.set(new_cast_op);
      }
      if (cast_op.y().use_empty()) cast_op.erase();
      changed = true;
    }
  }

  DCOMMENT("Updating function type");
  func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
                                 return_op.getOperandTypes()));

  if (changed) EnqueueCallers(func);
}

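// Performs shape inference on all operations in `region`, iterating until no
// further refinements are made or `max_iteration` iterations have run.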
LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
                                                      int64_t max_iteration) {
  bool changed = true;

  // TODO(aminim): we could have a more efficient traversal by guiding the
  // traversal with a worklist and reconsider only the nodes for which an
  // operand type was inferred. This would need to be careful if working on a
  // region that would not be isolated.
  for (int iteration = 0; iteration < max_iteration && changed; ++iteration) {
    changed = false;
    LLVM_DEBUG(llvm::dbgs()
               << "Shape inference, iteration " << iteration << "\n");
    region->walk([&](Operation* op) {
      DCOMMENT_OP(op, "Inferring for");
      if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
        DCOMMENT("\tRefining with type op interface");
        changed |= RefineWithInferTypeOpInterface(infer_ti);
        return;
      }

      if (op->getDialect() != tf_dialect_) {
        DCOMMENT("\tInfer non-TF dialect");
        changed |= InferShapeForNonTFDialectOperation(op);
        return;
      }

      // Before attempting inference, just try to compute the folded
      // value/shape.
      if (succeeded(TryToFold(op)) &&
          // Folding can "succeed" and yet not all types be refined. In such
          // cases we still want to give `InferShapeForSingleOperation` a try.
          none_of(op->getResultTypes(), CanBeRefined))
        return;

      // Best-effort shape inference in attached functions. Do not return
      // failure even if it doesn't get to fixed point.
      if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
        op->emitWarning() << "unable to refine shape of attached function "
                             "arguments and bodies";
      }

      if (failed(PropagateShapeIntoAttachedRegions(op, max_iteration))) {
        op->emitWarning() << "unable to refine shape of attached region "
                             "arguments and bodies";
      }

      changed |= InferShapeForSingleOperation(op);
    });
  }

  if (changed) {
    return region->getParentOp()->emitWarning()
           << "shape inference did not reach stable state after "
           << max_iteration << " iterations";
  }
  return success();
}

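// Runs shape inference on the body of `func` until a fixed point is reached,
// then refines the function's return type.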
static LogicalResult InferShapeForFunction(ShapeInference& context, FuncOp func,
                                           int64_t max_iterations) {
  if (failed(context.InferShapeUntilFixPoint(&func.getBody(), max_iterations)))
    return failure();
  // TODO(b/156276510): Verify that it is always fine to refine a function's
  // return type, as long as we do not change the argument shapes.
  context.InferShapeForFunctionReturnType(func);

  return success();
}

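// Runs shape inference on `func`. If `arg_shapes` is non-empty, the function's
// argument types are first refined to ranked tensor types with those shapes
// before inference is run.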
LogicalResult InferShapeForFunction(FuncOp func,
                                    ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                    int64_t graph_version,
                                    int64_t max_iterations) {
  ShapeInference context(graph_version, func.getContext(),
                         /*propagate_caller_callee_constants=*/true);
  if (arg_shapes.empty()) {
    return InferShapeForFunction(context, func, max_iterations);
  }

  FunctionType func_type = func.getType();
  bool needs_refinement = false;
  SmallVector<Type, 4> new_arg_types;
  new_arg_types.reserve(func_type.getNumInputs());

  // Update argument types in-place using the provided arg_shapes.
  for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
    ArrayRef<int64_t> shape = arg_shapes[i];
    Type element_type;
    if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
      if (input_ty.getRank() != shape.size()) {
        return failure();
      }
      element_type = input_ty.getElementType();
    } else {
      auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
      if (!unranked_input_ty) {
        return failure();
      }
      element_type = unranked_input_ty.getElementType();
    }

    auto new_arg_type = RankedTensorType::get(shape, element_type);
    if (new_arg_type != func_type.getInput(i)) {
      // If the new type is more detailed, trigger shape inference.
      func.getArgument(i).setType(new_arg_type);
      needs_refinement = true;
    }
    new_arg_types.push_back(new_arg_type);
  }

  if (!needs_refinement) return success();

  if (failed(context.InferShapeUntilFixPoint(&func.getBody(), max_iterations)))
    return failure();

  context.InferShapeForFunctionReturnType(func);
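  // Update the function type with the refined argument types, keeping the
  // (possibly refined) result types.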
  func.setType(FunctionType::get(func.getContext(), new_arg_types,
                                 func.getType().getResults()));

  return success();
}

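// Runs shape inference over all functions in `module`, re-enqueuing callers
// when a function's signature is refined, up to a bounded number of function
// visits.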
LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
  auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
  if (!producer_or.ok()) {
    // TODO(jpienaar): Keeping the existing behavior for now but this could
    // be relaxed.
    LLVM_DEBUG(llvm::dbgs()
               << "Skipping inference; " << producer_or.status().ToString());
    return success();
  }
  int64_t producer = producer_or.ValueOrDie();
  // TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if
  // it is no longer needed.
  ShapeInference context(producer, module.getContext(),
                         /*propagate_caller_callee_constants=*/false);
  if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
    context.enqueue(main);
  for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
  // Arbitrarily upper bound the maximum number of functions that get processed
  // just to avoid pathological cases.
  const auto max_iteration = context.QueueSize() * 4;
  auto remaining_iterations = max_iteration;
  while (!context.EmptyQueue()) {
    FuncOp func = context.front();
    if (failed(InferShapeForFunction(context, func, max_iterations)))
      return failure();
    context.pop_front();

    if ((--remaining_iterations) == 0) {
      return emitWarning(UnknownLoc::get(module.getContext()))
             << "shape inference did not reach stable state after "
             << max_iteration << " iterations";
    }
  }
  return success();
}

}  // namespace TF
}  // namespace mlir