1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <initializer_list>
21 #include <iterator>
22 #include <queue>
23 #include <stack>
24
25 #include "llvm/ADT/Hashing.h"
26 #include "llvm/ADT/None.h"
27 #include "llvm/ADT/PointerUnion.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/iterator_range.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
35 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
36 #include "mlir/IR/Attributes.h" // from @llvm-project
37 #include "mlir/IR/Block.h" // from @llvm-project
38 #include "mlir/IR/Builders.h" // from @llvm-project
39 #include "mlir/IR/BuiltinDialect.h" // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
41 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
42 #include "mlir/IR/Diagnostics.h" // from @llvm-project
43 #include "mlir/IR/Location.h" // from @llvm-project
44 #include "mlir/IR/Operation.h" // from @llvm-project
45 #include "mlir/IR/OperationSupport.h" // from @llvm-project
46 #include "mlir/IR/SymbolTable.h" // from @llvm-project
47 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
48 #include "mlir/IR/Value.h" // from @llvm-project
49 #include "mlir/IR/Visitors.h" // from @llvm-project
50 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
51 #include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
52 #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
53 #include "mlir/Pass/Pass.h" // from @llvm-project
54 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
55 #include "mlir/Support/LLVM.h" // from @llvm-project
56 #include "mlir/Support/LogicalResult.h" // from @llvm-project
57 #include "mlir/Transforms/FoldUtils.h" // from @llvm-project
58 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
59 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
60 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
61 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
62 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
63 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
64 #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
65 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
66 #include "tensorflow/core/framework/shape_inference.h"
67 #include "tensorflow/core/framework/types.pb.h"
68 #include "tensorflow/core/ir/types/dialect.h"
69
70 #define DEBUG_TYPE "tf-shape-inference"
71
72 #define DCOMMENT(MSG) LLVM_DEBUG(llvm::dbgs() << MSG << "\n")
73 #define DCOMMENT_OP(OP, MSG) \
74 LLVM_DEBUG(OP->print(llvm::dbgs() << MSG << " "); llvm::dbgs() << "\n")
75
76 using ::tensorflow::int64;
77 using tensorflow::shape_inference::DimensionHandle;
78 using tensorflow::shape_inference::InferenceContext;
79 using tensorflow::shape_inference::ShapeHandle;
80
81 namespace mlir {
82 namespace TF {
83 namespace {
84
// Computes a refined type between two types `lhs` and `rhs`; the result type
// is always at least as refined (i.e. has more static information) as `lhs`.
// This method merges the information contained in both types; for example, it
// is capable of refining:
89 // tensor<!tf_type.variant<tensor<?x8xf32>>>
90 // and:
91 // tensor<!tf_type.variant<tensor<10x?xf32>>>
92 // into:
93 // tensor<!tf_type.variant<tensor<10x8xf32>>>
94 //
95 // In case of inconsistencies (rank disagreement for example), it returns `lhs`.
Type TypeMeet(Type lhs, Type rhs) {
97 DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs);
98 if (lhs == rhs) return lhs;
99
100 auto rhs_shape_type = rhs.dyn_cast<ShapedType>();
101 if (!rhs_shape_type) return lhs;
102 auto lhs_shape_type = lhs.cast<ShapedType>();
103 if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() &&
104 lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
105 DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs);
106 return lhs;
107 }
108
109 SmallVector<int64_t> shape;
110 bool refined_shape = false;
// Build the shape of the refined type: if lhs is unranked, the refined shape
// is directly the shape of rhs; otherwise we merge dimension by dimension,
// taking the most specialized. This combines `10x?x?` and `?x?x8` into
// `10x?x8`.
115 if (!lhs_shape_type.hasRank()) {
116 if (rhs_shape_type.hasRank()) {
117 shape.append(rhs_shape_type.getShape().begin(),
118 rhs_shape_type.getShape().end());
119 refined_shape = true;
120 }
121 } else if (rhs_shape_type.hasRank()) {
122 for (auto shape_elts : llvm::enumerate(
123 llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) {
124 if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) &&
125 !ShapedType::isDynamic(std::get<1>(shape_elts.value()))) {
126 shape.push_back(std::get<1>(shape_elts.value()));
127 refined_shape = true;
128 DCOMMENT("-> refining shape element #" << shape_elts.index());
129 } else {
130 DCOMMENT("-> not refining shape element #" << shape_elts.index());
131 shape.push_back(std::get<0>(shape_elts.value()));
132 }
133 }
134 }
135
// Some tensors have an element type wrapping a subtensor, like resources and
// variants. In this case we may recurse on the wrapped subtype.
138 // `element_type` will contain the refined inferred element type for the
139 // returned type.
140 auto lhs_element_type = lhs_shape_type.getElementType();
141 auto rhs_element_type_with_subtype =
142 rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
143 // Look for resource or variant element type and ensure we refine the subtype.
// We only support a single subtype at the moment, so we won't handle something
145 // like:
146 // tensor<!tf_type.variant<tensor<10xf32>, tensor<8xf32>>
147 if (rhs_element_type_with_subtype &&
148 rhs_element_type_with_subtype.GetSubtypes().size() == 1) {
149 auto lhs_element_type_with_subtype =
150 lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
151 TensorType subtype;
152 if (!lhs_element_type_with_subtype) {
153 DCOMMENT(
154 "Unexpected inferred `TensorFlowTypeWithSubtype` when original "
155 "result isn't");
156 } else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) {
157 DCOMMENT(
158 "Unexpected `TensorFlowTypeWithSubtype` original type with size>1");
159 } else if (lhs_element_type_with_subtype.GetSubtypes().empty()) {
160 subtype = rhs_element_type_with_subtype.GetSubtypes().front();
161 } else {
162 // Recurse on the subtypes in the variant/resource. Basically if the input
163 // were:
164 // tensor<!tf_type.variant<tensor<?x8xf32>>>
165 // and:
166 // tensor<!tf_type.variant<tensor<10x8xf32>>>
167 // we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>.
168 auto refined_subtype =
169 TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(),
170 rhs_element_type_with_subtype.GetSubtypes().front())
171 .cast<TensorType>();
172 if (refined_subtype !=
173 lhs_element_type_with_subtype.GetSubtypes().front())
174 subtype = refined_subtype;
175 }
176 // If we managed to refine the subtype, recreate the element type itself
177 // (i.e. the tf.variant or tf.resource).
178 if (subtype) {
179 lhs_element_type = lhs_element_type_with_subtype.clone({subtype});
180 }
181 }
182 if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) {
183 Type new_type;
184 if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank())
185 new_type = UnrankedTensorType::get(lhs_element_type);
186 else
187 new_type = lhs_shape_type.clone(shape, lhs_element_type);
188 DCOMMENT("Refined to: " << new_type);
189 return new_type;
190 }
191 DCOMMENT("No refinement " << lhs);
192 return lhs;
193 }
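
// Illustrative examples of the merging behavior above (derived from the rules
// in TypeMeet, not an exhaustive specification):
//   TypeMeet(tensor<?x8xf32>, tensor<10x?xf32>) -> tensor<10x8xf32>
//   TypeMeet(tensor<*xf32>,   tensor<10xf32>)   -> tensor<10xf32>
//   TypeMeet(tensor<2x3xf32>, tensor<3x2xf32>)  -> tensor<2x3xf32>
//     (conflicting static dims: lhs is kept unchanged)
//   TypeMeet(tensor<2xf32>,   tensor<2x3xf32>)  -> tensor<2xf32>
//     (rank mismatch: lhs is returned unchanged)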
194
195 // Returns whether `original_type` type can be refined with
196 // `potential_refined_type` type.
bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
198 return original_type != TypeMeet(original_type, potential_refined_type);
199 }
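
// For example (illustrative): CanRefineTypeWith(tensor<?xf32>, tensor<10xf32>)
// is true, while CanRefineTypeWith(tensor<10xf32>, tensor<?xf32>) is false,
// since nothing new can be learned from a less refined type.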
200
201 // Returns if the shape inference pass supports an op outside the TF dialect.
bool IsSupportedNonTFOp(Operation* op) {
203 return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp,
204 tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp,
205 tf_executor::GraphOp, tf_executor::IslandOp,
206 tf_executor::LoopCondOp, tf_executor::MergeOp,
207 tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
208 tf_executor::SwitchOp, tf_executor::YieldOp>(op) ||
209 isa<InferTypeOpInterface>(op);
210 }
211
// Returns whether a cast back would need to be inserted, i.e., whether the
// operation of which `use` is an operand does not allow for shape refinement
// without a cast.
bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
216 return use.getOwner()->getDialect() != tf_dialect &&
217 !IsSupportedNonTFOp(use.getOwner());
218 }
219
TensorType CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape,
221 Type element_type) {
222 if (shape.hasValue())
223 return RankedTensorType::get(shape.getValue(), element_type);
224 return UnrankedTensorType::get(element_type);
225 }
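
// For example (illustrative): CreateTensorType(llvm::None, f32) yields
// tensor<*xf32>, while CreateTensorType(ArrayRef<int64_t>{2, 3}, f32) yields
// tensor<2x3xf32>.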
226
227 // Returns true if the op creates a TensorList.
bool IsTensorListInitOp(Operation* op) {
229 return isa<TensorListReserveOp>(op) || isa<EmptyTensorListOp>(op) ||
230 isa<TensorListFromTensorOp>(op);
231 }
232
233 // Returns the `element_shape` operand of the ops that create a TensorList.
Value GetElementShapeOperand(Operation* op) {
235 if (auto empty_tl = dyn_cast<EmptyTensorListOp>(op))
236 return empty_tl.element_shape();
237 if (auto tl_reserve = dyn_cast<TensorListReserveOp>(op))
238 return tl_reserve.element_shape();
239 if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op))
240 return tl_from_tensor.element_shape();
241 llvm_unreachable("unsupported TensorList op");
242 }
243
244 // Utility function to create a ranked tensor type after dropping the first
245 // dimension from the input type.
RankedTensorType DropFirstDimension(Type type) {
247 RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>();
248 if (!ranked_type) return {};
249 llvm::ArrayRef<int64_t> dims_except_first =
250 ranked_type.getShape().drop_front();
251 return RankedTensorType::get(dims_except_first, ranked_type.getElementType());
252 }
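
// For example (illustrative): tensor<5x8x9xf32> becomes tensor<8x9xf32>, and
// an unranked or non-RankedTensorType input yields a null type.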
253
Operation* InsertCast(OpBuilder& b, Location loc, Type dst_type, Value input) {
255 Type element_type = getElementTypeOrSelf(dst_type);
256 if (element_type.isa<IndexType>())
257 return b.create<tensor::CastOp>(loc, dst_type, input);
258 if (isa<TensorFlowDialect, BuiltinDialect>(element_type.getDialect()))
259 return b.create<TF::CastOp>(loc, dst_type, input,
260 /*truncate=*/b.getBoolAttr(false));
261 return nullptr;
262 }
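
// Note: InsertCast picks the cast op based on the destination element type: a
// tensor::CastOp for index element types, a TF::CastOp when the element type
// belongs to the TF or builtin dialect, and nullptr (no cast possible)
// otherwise; callers must handle the nullptr case.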
263
// Follows the use chain of a TensorList and returns true iff all elements
// written to the TensorList have the same static shape. If all elements have
// the same shape, it is assigned to `potential_element_type`.
267 //
268 // This can handle multiple mutations of a TensorList object and would return
269 // true if across all mutations the elements written have the same shape.
bool CanInferTensorListElementType(Value tensorlist,
271 Value initial_element_shape,
272 RankedTensorType* potential_element_type) {
273 DCOMMENT("CanInferTensorListElementType " << tensorlist << " with initial "
274 << initial_element_shape);
// Verifies that the new element type has a static shape and matches the
// potential type passed from the caller. Updates `potential_element_type`
// if it is not defined yet.
278 auto verify_and_update_potential_element_type =
279 [&](RankedTensorType new_element_type) -> bool {
280 DCOMMENT("\t\tConsidering " << new_element_type << " with old "
281 << *potential_element_type);
282 if (!new_element_type || !new_element_type.hasStaticShape()) return false;
283 if (!*potential_element_type) {
284 DCOMMENT("\t\tUpdating potential_element_type " << new_element_type);
285 *potential_element_type = new_element_type;
286 return true;
287 }
288 return *potential_element_type == new_element_type;
289 };
290
291 std::stack<Value> worklist;
292 worklist.emplace(tensorlist);
293
294 while (!worklist.empty()) {
295 tensorlist = worklist.top();
296 worklist.pop();
297
298 // TensorLists are semantically immutable. For example, TensorListSetItem
299 // takes a TensorList as input and produces a TensorList as output. So to
300 // traverse modifications to TensorList and verify that all elements written
// to it have the same shape, we need to follow the use-def chain of ops that
// (conceptually) modify it, i.e., ops that take an input TensorList and
303 // produce an output TensorList.
304 for (auto& use : tensorlist.getUses()) {
305 if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
306 auto element_type =
307 push.tensor().getType().dyn_cast<RankedTensorType>();
308 if (!verify_and_update_potential_element_type(element_type))
309 return false;
310 worklist.emplace(push.output_handle());
311 continue;
312 }
313 if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
314 use.getOwner())) {
315 // For scatter op we can get the element shape by dropping the first
316 // dimension of the input tensor.
317 RankedTensorType element_type =
318 DropFirstDimension(scatter.tensor().getType());
319 if (!verify_and_update_potential_element_type(element_type))
320 return false;
321 worklist.emplace(scatter.output_handle());
322 continue;
323 }
324 if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
325 auto element_type =
326 set_item.item().getType().dyn_cast<RankedTensorType>();
327 DCOMMENT("\tTensorListSetItemOp " << element_type);
328 if (!verify_and_update_potential_element_type(element_type))
329 return false;
330 worklist.emplace(set_item.output_handle());
331 continue;
332 }
333 if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
334 worklist.emplace(pop.output_handle());
335 continue;
336 }
337 if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
338 worklist.emplace(resize.output_handle());
339 continue;
340 }
341 // WhileRegionOp can explicitly capture TensorList value to be used inside
342 // its regions. So we check the uses of corresponding block argument in
343 // each region and the use of TensorList returned using YieldOp.
344 if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
345 DCOMMENT("\tTL WhileRegion");
346 for (auto branch : while_region.getRegions())
347 worklist.emplace(branch->getArgument(use.getOperandNumber()));
348 continue;
349 }
350 if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
351 Operation* parent = yield->getParentOp();
352 worklist.emplace(parent->getResult(use.getOperandNumber()));
353 continue;
354 }
355 // TODO(jpienaar): This can be generalized.
356 if (isa<IdentityOp, IdentityNOp, StopGradientOp>(use.getOwner())) {
357 worklist.emplace(use.getOwner()->getResult(use.getOperandNumber()));
358 continue;
359 }
360 // Refining the tensor list element type might change the output of
361 // TensorListElementShape which is expected to be the originally assigned
362 // shape to TensorList init ops. So replace it with the original element
363 // shape value.
364 if (auto tl_element_shape =
365 dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
366 // If element types match, we can do a direct replacement.
367 if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
368 getElementTypeOrSelf(initial_element_shape.getType())) {
369 tl_element_shape.replaceAllUsesWith(initial_element_shape);
370 } else {
371 OpBuilder b(use.getOwner());
372 Operation* cast_op = InsertCast(
373 b, use.getOwner()->getLoc(),
374 tl_element_shape.getResult().getType(), initial_element_shape);
375 if (!cast_op) return false;
376 tl_element_shape.replaceAllUsesWith(cast_op->getResult(0));
377 }
378 continue;
379 }
380 // Ignore ops that just consume a TensorList and do not output another
381 // TensorList.
382 if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
383 TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
384 continue;
385
386 // For any other unknown users of the TensorList, we are conservative and
387 // stop element shape inference.
388 DCOMMENT("TensorListType infer, unknown op " << *use.getOwner());
389 return false;
390 }
391 }
392 return true;
393 }
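
// For example (illustrative): for a list produced by TensorListReserve whose
// only mutations are two TensorListSetItem ops writing tensor<2x3xf32> items,
// the walk above succeeds and `potential_element_type` becomes
// tensor<2x3xf32>; if one write were tensor<2x4xf32> instead, the shapes
// disagree and the function returns false.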
394 } // namespace
395
396 // Returns whether type can be further refined.
bool CanBeRefined(Type type) {
398 auto shape_type = type.dyn_cast<ShapedType>();
399 if (!shape_type) return false;
400
401 // Returns whether type with subtypes can be further refined.
402 auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) {
403 return tws.GetSubtypes().empty() ||
404 llvm::any_of(tws.GetSubtypes(), CanBeRefined);
405 };
406 auto type_with_subtype =
407 shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
408 if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;
409
410 return !shape_type.hasStaticShape();
411 }
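
// For example (illustrative): tensor<*xf32> and tensor<?x4xf32> can be
// refined, tensor<2x4xf32> cannot, and tensor<!tf_type.resource> with no
// subtype can still be refined by adding a subtype.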
412
413 // Combination of value producer and port of value produced (e.g.,
414 // <value result output>:<value in output tensor>,
415 // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
416 // scalar value).
417 struct ValuePort {
418 PointerUnion<Operation*, BlockArgument> producer;
419 SmallVector<unsigned int, 2> port;
420
bool operator==(const ValuePort& other) const {
422 return producer == other.producer && port == other.port;
423 }
424
425 // Convert output value to ValuePort.
explicit ValuePort(Value v) {
427 OpResult opr = v.dyn_cast<OpResult>();
428 if (opr) {
429 producer = opr.getOwner();
430 port = {opr.getResultNumber()};
431 } else {
432 producer = v.cast<BlockArgument>();
433 port = {0};
434 }
435 }
ValuePort(PointerUnion<Operation*, BlockArgument> producer,
437 SmallVector<unsigned int, 2> port)
438 : producer(producer), port(port) {}
439
raw_ostream& print(raw_ostream& os) const {
441 if (auto* op = producer.dyn_cast<Operation*>())
442 os << "op " << op->getName();
443 if (auto ba = producer.dyn_cast<BlockArgument>())
444 os << "block_arg " << ba.getArgNumber();
445 os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
446 return os;
447 }
448 };
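
// For example (illustrative): for `%r:2 = tf.Foo(...)`, ValuePort(%r#1) has
// the op as producer and port {1}; the partial evaluator below additionally
// uses two-element ports such as {0, i} to name element `i` of a rank-1
// result 0 (see the tf.Pack handling in ComputeInputsRequiredForOutput).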
449
450 struct ValuePortHasher {
std::size_t operator()(const ValuePort& other) const {
452 return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
453 hash_value(ArrayRef<unsigned int>(other.port)));
454 }
455 };
456
457 using ValuePortResultMap =
458 std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
459 using ComputedQueryFn = function_ref<bool(ValuePort)>;
460 using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
461 using ValuePortInputs = SmallVectorImpl<ValuePort>;
462
463 // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
464 // intended to be switched to op interfaces once more refined.
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
466 ComputedQueryFn has_been_computed,
467 ValuePortInputs* inputs) {
468 auto op = value_port.producer.dyn_cast<Operation*>();
469 auto& port = value_port.port;
470 if (!op) return failure();
471
472 // No inputs required for constants.
473 if (matchPattern(op, m_Constant())) return success();
474
// Note: this focuses only on the trivial pack op case and this could be
476 // generalized.
477 if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
478 auto type = pack_op.getType().cast<TensorType>();
479 if (!type.hasRank() || type.getRank() != 1) return failure();
480 if (port.size() != 2) return failure();
481 assert(port[0] == 0);
482 ValuePort req(pack_op.getOperand(port[1]));
483 if (!has_been_computed(req)) inputs->push_back(req);
484 return success();
485 }
486
487 return failure();
488 }
489
490 // Computes the output produced by ValuePort using the query function of
491 // existing computed values.
Attribute ComputeOutputComponent(const ValuePort& value_port,
493 ValueQueryFn values) {
494 LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
495 if (auto known = values(value_port)) return known;
496
497 auto op = value_port.producer.dyn_cast<Operation*>();
498 if (!op) return nullptr;
499 auto& port = value_port.port;
500
501 if (port.empty()) {
502 LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n");
503 return nullptr;
504 }
505
506 ElementsAttr attr;
507 if (matchPattern(op, m_Constant(&attr))) {
508 if (port.size() == 1 && port[0] == 0) return attr;
509 return nullptr;
510 }
511
512 if (auto id = dyn_cast<IdentityOp>(op)) {
513 if (port.size() == 1 && port[0] == 0)
514 return ComputeOutputComponent(ValuePort(id.input()), values);
515 return nullptr;
516 }
517
// Note: this focuses only on the trivial pack op case and this could be
519 // generalized.
520 if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
521 TensorType type = pack_op.getType().cast<TensorType>();
522 if (!type.hasRank() || type.getRank() != 1) return nullptr;
523 if (port.size() != 2 || port[0] != 0) return nullptr;
524 ValuePort op_port(op->getOperand(port[1]));
525 return values(op_port);
526 }
527
528 if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
529 if (port.size() == 1)
530 return ComputeOutputComponent(
531 ValuePort(graph.GetFetch().fetches()[port[0]]), values);
532 return nullptr;
533 }
534
535 if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
536 if (port.size() == 1)
537 return ComputeOutputComponent(
538 ValuePort(island.GetYield().fetches()[port[0]]), values);
539 return nullptr;
540 }
541
542 return nullptr;
543 }
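
// For example (illustrative): given
//   %c = tf.Const {value = dense<[4, 16]>}
//   %i = tf.Identity(%c)
// ComputeOutputComponent(ValuePort(%i)) follows the Identity to the constant
// and returns the dense<[4, 16]> attribute.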
544
545 // Context used during ShapeInference. This class contains common information
546 // that is required by the individual shape inference helper functions (e.g.,
547 // TF Graph version, constant values computed, etc.)
548 class ShapeInference {
549 public:
550 ShapeInference(int64_t graph_version, ModuleOp module,
551 bool propagate_caller_callee_constants);
552
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
554 ValuePortInputs* inputs) {
555 return ::mlir::TF::ComputeInputsRequiredForOutput(
556 value_port,
557 [this](const ValuePort& port) {
558 return results_.find(port) != results_.end();
559 },
560 inputs);
561 }
562
Attribute ComputeOutputComponent(const ValuePort& value_port) {
564 if (auto known_attr = results_[value_port]) return known_attr;
565 auto attr = ::mlir::TF::ComputeOutputComponent(
566 value_port, [this](const ValuePort& port) { return results_[port]; });
567 RecordValue(value_port, attr);
568 return attr;
569 }
570
571 // Returns ShapeHandle if the op result could be computed as shape.
572 ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
573
void RecordValue(const ValuePort& value_port, Attribute value) {
575 LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
576 << value << "\n");
577 results_[value_port] = value;
578 }
579
580 // Infers shape of tf.While/tf.WhileRegion. If `shape_invariant` attribute is
581 // set, operand types are set as result types if associated body result types
582 // match the operand type (does not change per loop iteration). If operand and
583 // body result types are not the same, only handle types are propagated to
584 // result types. This is necessary to not incorrectly change result shapes
585 // when the While op will have a different result shape. Otherwise operand
586 // shapes are propagated to result shapes.
587 template <typename WhileOpTy>
588 bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types);
589
// Performs shape inference on the provided op and returns true if the type
// of at least one result has been changed.
// A tf.Cast() is inserted for any use that isn't in the TensorFlow dialect.
593 // `graph_version` indicates the current GraphDef compatibility versions
594 // (the versions field in graph.proto).
595 bool InferShapeForSingleOperation(Operation* op);
596
// Infers shapes in the provided region, including nested ones, iterating
// until a fixed point is reached, with a limit of max_iterations.
// Returns a failure() on error; otherwise returns true to indicate that it
// reached convergence, false otherwise.
601 FailureOr<bool> InferShapeUntilFixPoint(Region* region,
602 int64_t max_iterations);
603
// Updates input types and refines shapes inside the bodies of functions that
// are attached to ControlFlow ops (If/While) or Calls. These functions include
606 // Then/Else branches of IfOp and Cond/Body functions of WhileOp. Functions
// attached to control flow share the following common properties:
// 1) They are never reused, i.e. each has a single use in the module.
609 // 2) Their input types match those of their parent ops (excluding inputs
610 // like predicate).
611 // For calls, functions can be reused across multiple call sites. In this case
612 // we propagate the types when all call sites have the same operand types.
613 // Returns a failure() on error, otherwise returns true to indicate that it
614 // reached convergence, false otherwise.
615 FailureOr<bool> PropagateShapeToFunctions(ModuleOp module,
616 TypeRange input_types,
617 ArrayRef<FuncOp> functions,
618 int64_t max_iteration);
619
620 // Propagates shapes to regions given the shapes of the inputs of the regions.
621 // All regions provided in `regions` are assumed to have inputs of type
622 // `input_types`.
623 // Returns a failure() on error, otherwise returns true to indicate that it
624 // reached convergence, false otherwise.
625 FailureOr<bool> PropagateShapeToRegions(TypeRange input_types,
626 ArrayRef<Region*> regions,
627 int64_t max_iteration);
628
629 // Shape propagation for call/control flow ops.
630 // Returns a failure() on error, otherwise returns true to indicate that it
631 // reached convergence, false otherwise.
632 FailureOr<bool> PropagateShapeIntoAttachedFunctions(Operation* op,
633 int64_t max_iteration);
634
635 // Shape propagation for region based control flow.
636 // Returns a failure() on error, otherwise returns true to indicate that it
637 // reached convergence, false otherwise.
638 FailureOr<bool> PropagateShapeIntoAttachedRegions(Operation* op,
639 int64_t max_iterations);
640
641 // Propagates any constant operand of call_op to the called function body's
642 // corresponding argument if the callee has only one use.
643 //
644 // TODO(b/154065712): Move this to a more general inter-procedural constant
645 // folding pass.
646 void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func,
647 ModuleOp module);
648
649 // Propagates any constant return value of the callee function to the call
650 // op's corresponding result.
651 void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func,
652 ModuleOp module);
653
654 // Tries to compute the result of folding the op. This doesn't actually
// perform constant folding, it just computes the equivalent constants.
656 // Returns whether it was able to compute constant values.
657 LogicalResult TryToFold(Operation* op);
658
659 // Makes result types match the operand types (the i-th result type will
660 // match the i-th operand type). Returns true if anything is changed.
661 bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands,
662 ResultRange results);
663
664 // Makes result type's shape match the corresponding operand's shape.
665 // Returns whether any change was made.
666 bool RefineShapeForPassThroughOps(Operation* op);
667
668 // Infers shape for necessary ops that are not in the TF dialect. Returns
669 // whether any result type changed.
670 bool InferShapeForNonTFDialectOperation(Operation* op);
671
672 // Infers shape for function return type and returns whether changed.
673 LogicalResult InferShapeForFunctionReturnType(FuncOp func);
674
675 // Enqueues function for processing.
void enqueue(FuncOp fn) {
677 LLVM_DEBUG(llvm::dbgs()
678 << "enqueue " << fn.getName() << " ("
679 << (queue_set_.count(fn) ? "already inserted" : "newly inserted")
680 << ")\n");
681 if (queue_set_.insert(fn).second) queue_.push(fn);
682 }
683
684 // Enqueues callers on functions.
685 void EnqueueCallers(FuncOp fn);
686
687 // Returns the function at the front of the queue.
FuncOp front() { return queue_.front(); }
689
690 // Returns whether work queue is empty.
bool EmptyQueue() const { return queue_.empty(); }
692
693 // Returns function from the front of the work queue.
FuncOp pop_front() {
695 FuncOp ret = queue_.front();
696 queue_.pop();
697 queue_set_.erase(ret);
698 return ret;
699 }
700
701 // Returns the current size of the queue.
std::queue<FuncOp>::size_type QueueSize() const { return queue_.size(); }
703
704 Dialect* const tf_dialect_;
705
706 private:
707 // Returns whether the result of an operation could be updated to a new
// inferred type. Also inserts a cast operation for uses that are incompatible
709 // with the new type.
710 bool UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, Value result);
711
712 // Refines the type of `result` of `op` using the type
// `potential_refined_type`. Returns true if the type was changed.
714 bool RefineResultType(Operation* op, Value result,
715 Type potential_refined_type);
716
// Infers the shape from a (Stateful)PartitionedCall operation by looking up
// the called function and propagating the return type.
719 bool InferShapeForCall(CallOpInterface call_op);
720
721 bool InferShapeForCast(Operation* op);
722
// Infers the shape of IfOp outputs based on the shapes of the then and else
724 // function result types.
725 bool InferShapeForIf(IfOp op);
726
// Infers the shape of IfRegion outputs based on the shapes of the then and else
728 // yields.
729 bool InferShapeForIfRegion(IfRegionOp op);
730
731 // Infers the shape of _XlaHostComputeMlir based on the host computation
732 // module. Returns true if a return type was changed.
733 bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op);
734
735 // Infers the shape for MapDatasetOp and its associated function. Returns
736 // whether either op or function type was changed.
737 bool InferShapeForMapDataset(MapDatasetOp op);
738
739 // Infers the shape of ops that create TensorList. Specifically,
740 // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
741 // refines the element shape if all tensors written to the list across all
742 // mutations have identical static shape.
743 bool InferShapeForTensorListInitOps(Operation* op);
744
745 bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti);
746
747 // Returns all the callers of a function.
748 // Note: Usage of the return value of this function may not be interleaved
749 // with insertions to the callers map. This could occur if GetCallers is
// called with two separate functions, where the 2nd call incurs a resize and
// then both the first and 2nd stored callers are used.
752 ArrayRef<Operation*> GetCallers(FuncOp fn);
753
754 // Mapping between ValuePort (which corresponds to an OpResult or smaller,
755 // e.g., first element of OpResult produced) to an Attribute if the ValuePort
756 // corresponds to a constant value.
757 ValuePortResultMap results_;
758
759 // Map from a function to the callers of that function.
760 SymbolTableCollection symbol_table_;
761 SymbolUserMap symbol_users_;
762
763 // Queue of functions being processed.
764 llvm::DenseSet<FuncOp> queue_set_;
765 std::queue<FuncOp> queue_;
766
767 int64_t graph_version_;
768
769 // TODO(b/154065712): Remove propagate_caller_callee_constants once using
770 // SCCP pass instead.
771 bool propagate_caller_callee_constants_;
772 };
773
ShapeInference::ShapeInference(int64_t graph_version, ModuleOp module,
775 bool propagate_caller_callee_constants)
776 : tf_dialect_(module->getContext()->getLoadedDialect<TensorFlowDialect>()),
777 symbol_users_(symbol_table_, module),
778 graph_version_(graph_version),
779 propagate_caller_callee_constants_(propagate_caller_callee_constants) {}
780
ArrayRef<Operation*> ShapeInference::GetCallers(FuncOp fn) {
782 return symbol_users_.getUsers(fn);
783 }
784
void ShapeInference::EnqueueCallers(FuncOp fn) {
786 for (auto user : GetCallers(fn)) enqueue(user->getParentOfType<FuncOp>());
787 }
788
bool ShapeInference::UpdateTypeAndInsertIncompatibleUseCasts(Type new_type,
790 Value result) {
791 Operation* cast_op = nullptr;
792 // First insert cast back for uses that need a cast and then
793 // update the type.
794 bool enqueue_callers = false;
795 for (OpOperand& use : make_early_inc_range(result.getUses())) {
796 if (isa<ReturnOp>(use.getOwner())) {
797 enqueue_callers = true;
798 } else if (NeedsCastBack(use, tf_dialect_)) {
799 if (!cast_op) {
800 Operation* op = result.getDefiningOp();
801 OpBuilder b(op);
802 b.setInsertionPointAfter(op);
803 cast_op = InsertCast(b, op->getLoc(), result.getType(), result);
804 if (!cast_op) return false;
805 }
806 use.set(Value(cast_op->getResult(0)));
807 }
808 }
809
810 result.setType(new_type);
811 if (enqueue_callers)
812 EnqueueCallers(result.getDefiningOp()->getParentOfType<FuncOp>());
813 return true;
814 }
815
bool ShapeInference::RefineResultType(Operation* op, Value result,
817 Type potential_refined_type) {
818 if (!CanRefineTypeWith(result.getType(), potential_refined_type))
819 return false;
820
821 return UpdateTypeAndInsertIncompatibleUseCasts(potential_refined_type,
822 result);
823 }
824
// Infers the shape from a (Stateful)PartitionedCall operation by looking up
// the called function and propagating the return type.
bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
828 FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
829 if (!func) return false;
830
831 DCOMMENT("Infer shape for call " << func.getName());
832 Operation* op = call_op.getOperation();
833 bool changed = false;
834 // Map each of the results of the call to the returned type of the
835 // function.
836 for (auto result : zip(op->getResults(), func.getType().getResults())) {
837 changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
838 changed;
839 }
840 DCOMMENT(" - call " << func.getName() << "changed ? " << changed << "\n");
841
842 return changed;
843 }
844
bool ShapeInference::InferShapeForCast(Operation* op) {
846 DCOMMENT_OP(op, "Inferring shape for ");
847 Value result = op->getResult(0);
848 if (!CanBeRefined(result.getType())) return false;
849
850 Type operand_type = op->getOperand(0).getType();
851 auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>();
852 if (!ranked_op_type) return false;
853 auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
854 if (ranked_res_type &&
855 ranked_op_type.getShape() == ranked_res_type.getShape())
856 return false;
857
// Avoid inserting a cast where no user's types could be refined (e.g., where
859 // there would need to be a cast inserted for every user again).
860 if (llvm::all_of(result.getUses(), [this](OpOperand& use) {
861 return NeedsCastBack(use, tf_dialect_);
862 }))
863 return false;
864
865 auto new_type = RankedTensorType::get(
866 ranked_op_type.getShape(),
867 result.getType().cast<ShapedType>().getElementType());
868
869 return UpdateTypeAndInsertIncompatibleUseCasts(new_type, op->getResult(0));
870 }
871
bool ShapeInference::InferShapeForIf(IfOp op) {
873 DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
874 bool changed = false;
875 auto then_results = op.then_function().getType().getResults();
876 auto else_results = op.else_function().getType().getResults();
877 for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
878 // If then and else types do not match, skip refinement for that result.
879 if (std::get<1>(it) != std::get<2>(it)) continue;
880 changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
881 }
882 return changed;
883 }
884
bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {
886 bool changed = false;
887
888 Operation* then_yield = op.then_branch().front().getTerminator();
889 Operation* else_yield = op.else_branch().front().getTerminator();
890 for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
891 else_yield->getOperandTypes())) {
892 // If then and else types do not match, skip refinement for that result.
893 if (std::get<1>(result) != std::get<2>(result)) continue;
894 changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
895 changed;
896 }
897 return changed;
898 }
899
bool ShapeInference::InferShapeForXlaHostComputeMlir(
901 _XlaHostComputeMlirOp host_compute_op) {
902 // Extract the module and function.
// The `_XlaHostComputeMlir` verifier verifies that the `host_mlir_module`
904 // attribute is well formed, so we just return in case of an error in
905 // extracting the host function since it should never occur.
906 StringAttr host_module =
907 host_compute_op->getAttrOfType<StringAttr>("host_mlir_module");
908 if (host_module.getValue().empty()) return false;
909
910 mlir::OwningModuleRef module_for_func;
911 FuncOp func = host_compute_op.GetHostFunc(&module_for_func);
912
913 // Update/use input shapes for function.
914 FunctionType func_type = func.getType();
915 func.setType(FunctionType::get(func.getContext(),
916 host_compute_op.getOperandTypes(),
917 func_type.getResults()));
918
919 // Run shape inference on the function.
920 if (failed(PropagateShapeToRegions(host_compute_op.getOperandTypes(),
921 {&func.getBody()}, 10)))
922 return false;
923 if (failed(InferShapeForFunctionReturnType(func))) return false;
924
925 bool changed = false;
926 // Use refined function return shape for XlaHostComputeMlirOp.
927 for (auto result :
928 zip(host_compute_op.getResults(), func.getType().getResults())) {
929 changed = RefineResultType(host_compute_op, std::get<0>(result),
930 std::get<1>(result)) ||
931 changed;
932 }
933
934 return changed;
935 }
936
bool ShapeInference::InferShapeForMapDataset(MapDatasetOp op) {
// MapDatasetOp's relationship with its associated function is as
// follows: the first M function params are dictated by the set of
// output shapes and types, and the next N are the last N inputs of the
// MapDataset op. The MapDataset op always has N+1 inputs.
942 // TODO(jpienaar): Avoid this lookup.
943 auto module = op->getParentOfType<ModuleOp>();
944 auto f = module.lookupSymbol<FuncOp>(op.f());
// Skip if the function is not found or has more than one caller.
946 if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
947
948 int N = op.getNumOperands() - 1;
949 int M = f.getNumArguments() - N;
950 DCOMMENT_OP(op, "Inferring shape for with N = " << N << " and M = " << M);
951
952 // Initialize with function input types.
953 SmallVector<Type> input_types(f.getArgumentTypes());
954
955 // Track if changed to skip enqueueing.
956 bool changed = false;
957 auto it = input_types.begin();
958 // First set first M argument shapes & types.
959 for (int i = 0; i < M; ++i) {
960 auto shape = op.output_shapes()[i].cast<tf_type::ShapeAttr>();
961 auto type = op.output_types()[i].cast<TypeAttr>();
962 Type t;
963 if (shape.hasRank())
964 t = RankedTensorType::get(shape.getShape(), type.getValue());
965 else
966 t = UnrankedTensorType::get(type.getValue());
967 t = TypeMeet(*it, t);
968 changed = changed || (t != *it);
969 ++it;
970 }
971 // Now the remaining N from operand types.
972 for (auto t : llvm::drop_begin(op.getOperandTypes())) {
973 auto meet = TypeMeet(*it, t);
974 changed = changed || (meet != *it);
975 *it = meet;
976 ++it;
977 }
978 if (!changed) return false;
979
980 // TODO(jpienaar): Pipe the max_iteration value through.
981 FailureOr<bool> res =
982 PropagateShapeToFunctions(module, input_types, {f}, /*max_iteration=*/10);
983 if (failed(res)) {
984 LOG(ERROR) << "Propagating shapes for MapDataset failed";
985 return false;
986 }
987 return *res;
988 }
989
bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
991 DCOMMENT_OP(op, "Inferring shape for TensorList ");
992 Value handle = op->getResult(0);
993 Value initial_element_shape = GetElementShapeOperand(op);
994 RankedTensorType element_type;
995 if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) {
996 // For TensorListFromTensor op we can infer element shape by dropping the
997 // first dimension of input tensor.
998 element_type = DropFirstDimension(tl_from_tensor.tensor().getType());
999 if (!element_type || !element_type.hasStaticShape()) return false;
1000 }
1001 if (!CanInferTensorListElementType(handle, initial_element_shape,
1002 &element_type)) {
1003 DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
1004 return false;
1005 }
1006 DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred "
1007 << element_type);
1008 if (!element_type || !element_type.hasStaticShape()) return false;
1009 auto variant_type = VariantType::get(element_type, op->getContext());
1010 auto tensor_type = RankedTensorType::get({}, variant_type);
1011 bool changed = RefineResultType(op, handle, tensor_type);
1012 if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
1013 return changed;
1014 }
1015
bool ShapeInference::RefineWithInferTypeOpInterface(
1017 InferTypeOpInterface infer_ti) {
1018 Operation* op = infer_ti.getOperation();
1019 SmallVector<Type, 4> inferred;
1020 LogicalResult res = infer_ti.inferReturnTypes(
1021 op->getContext(), op->getLoc(), op->getOperands(),
1022 op->getAttrDictionary(), op->getRegions(), inferred);
1023 if (failed(res)) {
1024 op->emitOpError("failed to refine type as inference failed");
1025 return false;
1026 }
1027
1028 if (inferred == op->getResultTypes()) return false;
1029
// Map each of the results of the op to the corresponding inferred type.
1032 bool changed = false;
1033 for (auto result : zip(op->getResults(), inferred)) {
1034 if (std::get<0>(result).getType() == std::get<1>(result)) continue;
1035
1036 if (!UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result),
1037 std::get<0>(result)))
1038 continue;
1039 changed = true;
1040 }
1041 return changed;
1042 }
1043
ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
1045 InferenceContext* ic) {
1046 LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
1047 auto rt = result.getType().dyn_cast<RankedTensorType>();
1048 if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
1049 int dim_size = rt.getDimSize(0);
1050
1051 // Worklist to direct partial evaluation.
1052 SmallVector<ValuePort, 4> worklist;
1053
// Simple evaluator that attempts to partially evaluate the input value even
// if unable to evaluate the complete output. Below follows a simple stack
// based evaluation where it queries which operands/parts of operands need to
// be evaluated and attempts to partially evaluate those operands. It does so
// by pushing the operands that are required onto the worklist before
// enqueuing the operation requiring those values.
1060 std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim());
1061 for (unsigned int i = 0, e = dims.size(); i != e; ++i) {
1062 LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n");
1063
1064 worklist.push_back(
1065 ValuePort{result.getOwner(), {result.getResultNumber(), i}});
1066 while (!worklist.empty()) {
1067 auto front = worklist.pop_back_val();
1068 LLVM_DEBUG(front.print(llvm::dbgs() << "\nWorklist front "));
1069
1070 SmallVector<ValuePort, 4> inputs;
1071 auto res = ComputeInputsRequiredForOutput(front, &inputs);
1072 if (failed(res)) {
1073 // Abort if unable to find which required inputs need to be computed.
1074 worklist.clear();
1075 break;
1076 }
1077
1078 if (!inputs.empty()) {
1079 // Enqueue required computation followed by its required operands in
1080 // stack.
1081 worklist.push_back(std::move(front));
1082 for (auto& it : inputs) worklist.push_back(std::move(it));
1083 continue;
1084 }
1085
1086 auto ret = ComputeOutputComponent(front);
1087 if (!ret) continue;
1088
1089 LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
1090
1091 // If worklist is empty, then this is the root query op.
1092 if (worklist.empty()) {
1093 LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
1094 if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
1095 if (dea.getNumElements() != 1) {
1096 LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n");
1097 return {};
1098 }
1099 int64_t val = (*dea.getIntValues().begin()).getSExtValue();
1100 dims[i] = ic->MakeDim(val);
1101 }
1102 }
1103 }
1104 }
1105 return ic->MakeShape(dims);
1106 }
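
// For example (illustrative): if `result` is the rank-1 output of
// tf.Pack(%dim0, %dim1) and %dim0 partially evaluates to the constant 8 while
// %dim1 cannot be computed, the returned ShapeHandle is [8, ?].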
1107
bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
1109 OperandRange operands,
1110 ResultRange results) {
1111 bool changed = false;
1112 for (auto entry : llvm::zip(operands, results)) {
1113 Type operand_type = std::get<0>(entry).getType();
1114 Value result = std::get<1>(entry);
1115 TensorType result_type = result.getType().cast<TensorType>();
1116 Type inferred_type = TypeMeet(result_type, operand_type);
1117 if (result_type == inferred_type) continue;
1118
1119 if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1120 continue;
1121 changed = true;
1122 }
1123 return changed;
1124 }
1125
bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
1127 DCOMMENT_OP(op, "Pass through op");
1128 bool changed = false;
1129 for (auto entry : llvm::zip(op->getOperands(), op->getResults())) {
1130 Value operand = std::get<0>(entry);
1131 Value result = std::get<1>(entry);
1132 Type inferred_type = TypeMeet(result.getType(), operand.getType());
1133 if (result.getType() == inferred_type) continue;
1134 if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1135 continue;
1136 changed = true;
1137 }
1138 return changed;
1139 }
1140
bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
1142 if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
1143 return RefineTypeForPassThroughOperands(
1144 graph_op.GetFetch(), graph_op.GetFetch().fetches(), op->getResults());
1145 }
1146 if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
1147 return RefineTypeForPassThroughOperands(
1148 island_op.GetYield(), island_op.GetYield().fetches(), op->getResults());
1149 }
1150 if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
1151 auto iter_source = cast<tf_executor::NextIterationSourceOp>(
1152 iter_sink.token().getDefiningOp());
1153 return RefineTypeForPassThroughOperands(
1154 op, iter_sink.getOperands().drop_front().take_front(),
1155 iter_source.getResults());
1156 }
1157 if (auto launch_op = dyn_cast<tf_device::LaunchOp>(op)) {
1158 auto terminator = launch_op.GetBody().getTerminator();
1159 return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1160 op->getResults());
1161 }
1162 if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
1163 auto terminator = cluster_op.GetBody().getTerminator();
1164 return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1165 op->getResults());
1166 }
1167 if (op->hasTrait<OpTrait::SameOperandsAndResultShape>())
1168 return RefineShapeForPassThroughOps(op);
1169 if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1170 return false;
1171 }
1172
// Finds the element type to be used for the result from the operand, with
// special handling for handle types.
Type GetElementTypeFromOperand(TensorType operand_type,
1176 TensorType result_type) {
1177 auto operand_handle_type =
1178 operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
1179 if (!operand_handle_type) return result_type.getElementType();
1180 auto result_handle_type =
1181 result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
1182 if (operand_handle_type.GetSubtypes().empty() ||
1183 !result_handle_type.GetSubtypes().empty())
1184 return result_type.getElementType();
1185 return operand_handle_type;
1186 }
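
// For example (illustrative): if the operand type is
// tensor<!tf_type.resource<tensor<8xf32>>> and the result type is
// tensor<!tf_type.resource>, the operand's handle type (with its subtype) is
// used; if the result already carries a subtype, or the operand has none, the
// result's element type is kept.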
1187
1188 // Checks if one tensor type can refine another type for tf.While/
1189 // tf.WhileRegion. If rank differs or static dimensions can be lost, the other
1190 // type cannot be used for refinement.
bool CanWhileTypeBeRefinedWith(TensorType current_type,
1192 TensorType potential_refined_type) {
1193 if (!current_type.hasRank()) return true;
1194 if (!potential_refined_type.hasRank()) return false;
1195 if (current_type.getRank() != potential_refined_type.getRank()) return false;
1196 for (auto dim :
1197 llvm::zip(current_type.getShape(), potential_refined_type.getShape())) {
1198 int64_t current_dim = std::get<0>(dim);
1199 int64_t potential_refined_dim = std::get<1>(dim);
1200 if (current_dim != potential_refined_dim &&
1201 current_dim != ShapedType::kDynamicSize)
1202 return false;
1203 }
1204 return true;
1205 }
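
// For example (illustrative): tensor<?x4xf32> can be refined with
// tensor<2x4xf32>, but not with tensor<2x?xf32> or tensor<2x4x1xf32>, since
// that would respectively drop the static dimension 4 or change the rank.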
1206
1207 template <typename WhileOpTy>
bool ShapeInference::InferShapeForWhile(WhileOpTy op,
1209 TypeRange body_result_types) {
1210 if (!op.shape_invariant())
1211 return RefineTypeForPassThroughOperands(op, op.input(), op.output());
1212
1213 bool changed = false;
1214 for (auto entry :
1215 zip(op.input().getTypes(), op.output(), body_result_types)) {
1216 Value result = std::get<1>(entry);
1217 TensorType body_result_type =
1218 std::get<2>(entry).template cast<TensorType>();
1219 auto result_type = result.getType().cast<TensorType>();
1220
1221 Type potential_refined_type;
1222 if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) {
1223 Type element_type =
1224 GetElementTypeFromOperand(body_result_type, result_type);
1225 potential_refined_type = CreateTensorType(
1226 body_result_type.hasRank() ? body_result_type.getShape()
1227 : llvm::Optional<ArrayRef<int64_t>>(),
1228 element_type);
1229 } else {
1230 TensorType operand_type = std::get<0>(entry).template cast<TensorType>();
1231 Type element_type = GetElementTypeFromOperand(operand_type, result_type);
1232 potential_refined_type = CreateTensorType(
1233 result_type.hasRank() ? result_type.getShape()
1234 : llvm::Optional<ArrayRef<int64_t>>(),
1235 element_type);
1236 }
1237 changed |= RefineResultType(op, result, potential_refined_type);
1238 }
1239 return changed;
1240 }
1241
bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
1243 LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
1244 llvm::dbgs() << "\n");
1245 assert(tf_dialect_ == op->getDialect());
1246 // The shape function of these ops sometimes does not propagate subtypes
1247 // (handle shapes) for resource and variant types. We use a simple passthrough
1248 // to make sure they are preserved in the output.
1249 if (isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp, TF::ZerosLikeOp>(
1250 op)) {
1251 return RefineTypeForPassThroughOperands(op, op->getOperands(),
1252 op->getResults());
1253 }
1254
1255 // If no result for this op needs shape inference, we have a fast-path return.
1256 // But if the type is a resource/variant, we do not skip it because we might
1257 // not have the handle shapes.
1258 if (none_of(op->getResultTypes(), CanBeRefined)) {
1259 LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
1260 << op->getName() << "'.\n");
1261 return false;
1262 }
1263
1264 // Handle call operations by looking up callee and inferring return shape as
1265 // needed.
1266 if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1267
1268 // tf.Cast and tensor::Cast are only inferred if they have at least one user
1269 // in the TF dialect or feeding into the function return. This is necessary to
  // avoid inserting casts which cannot be refined.
  if (isa<CastOp, tensor::CastOp>(op)) return InferShapeForCast(op);

  // Handle IfOp here by inferring the shape from the else/then function
  // results. Since `output_shapes` is a derived attribute, avoid going down the
  // TF InferenceContext path as IfOp shape inference is implemented as just
  // a lookup of the output_shapes attribute.
  if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);

  // Handle IfRegion operations by inferring return shape from the then and
  // else branches.
  if (auto if_region = dyn_cast<IfRegionOp>(op))
    return InferShapeForIfRegion(if_region);

  if (auto while_op = dyn_cast<WhileOp>(op))
    return InferShapeForWhile(while_op,
                              while_op.body_function().getType().getResults());

  if (auto while_region = dyn_cast<WhileRegionOp>(op))
    return InferShapeForWhile(
        while_region,
        while_region.body().front().getTerminator()->getOperandTypes());

  if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) {
    return InferShapeForXlaHostComputeMlir(host_compute_op);
  }

  if (auto map_dataset_op = dyn_cast<MapDatasetOp>(op))
    return InferShapeForMapDataset(map_dataset_op);

  // Handle TensorList init operations by inferring shape from TensorList write
  // operations. If we are unable to refine element shape here, proceed to use
  // the InferenceContext below to get more precise shapes.
  if (IsTensorListInitOp(op) && InferShapeForTensorListInitOps(op)) return true;

  // Return operand as a constant attribute.
  auto operand_as_constant_fn = [&](Value operand) {
    ValuePort vp(operand);
    Attribute attr = ComputeOutputComponent(vp);
    if (!attr && matchPattern(operand, m_Constant(&attr)))
      RecordValue(vp, attr);
    return attr;
  };

  // Return op result as a shape.
  auto op_result_as_shape_fn = [&](InferenceContext& context,
                                   OpResult op_result) {
    return ComputeOutputAsShape(op_result, &context);
  };

  // Return result element type at `index`.
  auto result_element_type_fn = [&](int index) {
    return op->getResult(index).getType().cast<TensorType>().getElementType();
  };

  llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes;
  if (failed(InferReturnTypeComponentsForTFOp(
          /*location=*/None, op, graph_version_, operand_as_constant_fn,
          op_result_as_shape_fn, result_element_type_fn,
          inferred_return_shapes)))
    return false;

  // Update the shape for each of the operation's results if the
  // InferenceContext has more precise shapes recorded.
  bool changed = false;
  for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) {
    Value op_result = std::get<0>(result);
    if (!CanBeRefined(op_result.getType())) continue;

    ShapedTypeComponents inferred = std::get<1>(result);
    TensorType inferred_type;
    if (inferred.hasRank())
      inferred_type =
          RankedTensorType::get(inferred.getDims(), inferred.getElementType());
    else
      inferred_type = UnrankedTensorType::get(inferred.getElementType());

    inferred_type =
        TypeMeet(op_result.getType(), inferred_type).cast<TensorType>();
    if (op_result.getType() == inferred_type) continue;
    if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result))
      continue;
    changed = true;
  }

  if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
  return changed;
}

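// Propagates `input_types` into the entry block arguments of each function in
// `functions` and runs shape inference on their bodies. Functions with
// multiple callers are only updated when all callers are call ops passing
// identical operand types. Returns failure if propagation fails for any
// function.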
FailureOr<bool> ShapeInference::PropagateShapeToFunctions(
    ModuleOp module, TypeRange input_types, ArrayRef<FuncOp> functions,
    int64_t max_iteration) {
  bool any_failure = false;
  bool any_nonconvergence = false;
  // If shape propagation fails for one function, return failure, but do not
  // exit early; attempt to propagate shapes for all provided functions to
  // get a best-effort propagation.
  for (FuncOp func : functions) {
    DCOMMENT("Propagating shape to " << func.getName());
    ArrayRef<Operation*> callers = GetCallers(func);
    if (!llvm::hasSingleElement(callers) &&
        !llvm::all_of(callers.drop_front(), [&](Operation* caller) {
          /// TODO(aminim): this is overly conservative as some operations
          /// (like TPUPartitionedCallOp) may have extra operands that aren't
          /// propagated to the callee.
          return isa<CallOpInterface>(caller) &&
                 std::equal(caller->getOperandTypes().begin(),
                            caller->getOperandTypes().end(),
                            callers.front()->getOperandTypes().begin());
        })) {
      if (llvm::any_of(callers, [](Operation* op) {
            return isa<IfOp, WhileOp, CaseOp>(op);
          }))
        func.emitWarning(formatv(
            "expected control flow function @{0} to have exactly 1 use, "
            "found {1}.",
            func.getName(), callers.size()));

      continue;
    }
    FunctionType func_type = func.getType();
    func.setType(FunctionType::get(func.getContext(), input_types,
                                   func_type.getResults()));

    FailureOr<bool> failure_or_converged =
        PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
    if (failed(failure_or_converged)) {
      any_failure = true;
      continue;
    }
    any_nonconvergence = any_nonconvergence || !failure_or_converged.getValue();
    if (failed(InferShapeForFunctionReturnType(func))) any_failure = true;
  }
  if (any_failure) return failure();
  return any_nonconvergence;
}

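// Sets the entry block argument types of each region in `regions` to
// `input_types` and then runs shape inference on the region until a fixed
// point is reached or `max_iteration` is exhausted. Returns failure if
// inference fails for any region.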
FailureOr<bool> ShapeInference::PropagateShapeToRegions(
    TypeRange input_types, ArrayRef<Region*> regions, int64_t max_iteration) {
  DCOMMENT("\tPropagating shapes to regions");
  bool any_failure = false;
  bool any_nonconvergence = false;
  // If shape propagation fails for one region, return failure, but do not
  // exit early; attempt to propagate shapes for all provided regions to
  // get a best-effort propagation.
  for (auto region : regions) {
    // Refine region arguments.
    Block& entry = region->front();
    assert(llvm::size(input_types) == entry.getNumArguments());
    for (auto it : llvm::zip(entry.getArguments(), input_types)) {
      BlockArgument arg = std::get<0>(it);
      Type type = std::get<1>(it);
      arg.setType(type);
    }

    // Propagate shapes into the region.
    FailureOr<bool> failure_or_converged =
        InferShapeUntilFixPoint(region, max_iteration);
    if (failed(failure_or_converged))
      any_failure = true;
    else if (!failure_or_converged.getValue())
      any_nonconvergence = true;
  }
  if (any_failure) return failure();
  return any_nonconvergence;
}

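// Propagates constant operands of `call_op` into the entry block arguments of
// the callee `func`, but only when `func` has a single caller. Depending on
// `propagate_caller_callee_constants_`, constants are either cloned into the
// callee or recorded in the shape inference results map.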
void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
                                               FuncOp func, ModuleOp module) {
  auto callers = GetCallers(func);
  if (!llvm::hasSingleElement(callers)) return;

  OpBuilder builder(&func.front().front());
  Operation* op = call_op.getOperation();
  // If this is the only caller, and an operand is a constant, propagate
  // the constant value inside the function.
  for (auto arg : func.getArguments()) {
    auto operand = op->getOperand(arg.getArgNumber());
    if (propagate_caller_callee_constants_) {
      if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
        arg.replaceAllUsesWith(
            builder.clone(*operand.getDefiningOp())->getResult(0));
      }
      continue;
    }

    auto known_constant = ComputeOutputComponent(ValuePort(operand));
    if (!known_constant) continue;
    LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to callee: ");
               known_constant.print(llvm::dbgs() << " constant ");
               llvm::dbgs() << "\n");
    RecordValue(ValuePort(arg), known_constant);
  }
}

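// Propagates constant results of the callee `func` back to the results of
// `call_op`. Depending on `propagate_caller_callee_constants_`, constants are
// either cloned next to the call or recorded in the shape inference results
// map.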
void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
                                                 FuncOp func, ModuleOp module) {
  // If the return value is a constant, use the constant as the value of
  // the call return.
  Operation* op = call_op.getOperation();
  OpBuilder builder(op);
  builder.setInsertionPointAfter(op);
  for (auto retval :
       llvm::enumerate(func.front().getTerminator()->getOperands())) {
    if (propagate_caller_callee_constants_) {
      auto retval_op = retval.value().getDefiningOp();
      if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
        op->getResult(retval.index())
            .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
      }
      continue;
    }

    ValuePort vp(retval.value());
    if (auto known_constant = ComputeOutputComponent(vp)) {
      LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
                 call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
      RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
    }
  }
}

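// Returns true if both tensor types are ranked and have the same rank.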
bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
  return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank();
}

// Creates a compatible RankedTensorType where mismatched dimensions are
// replaced with dynamic sizes.
RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs,
                                               RankedTensorType rhs) {
  assert(lhs.getRank() == rhs.getRank());
  llvm::SmallVector<int64_t, 4> dims;
  dims.reserve(lhs.getRank());
  for (auto dim : llvm::zip(lhs.getShape(), rhs.getShape())) {
    int64_t lhs_dim = std::get<0>(dim);
    if (lhs_dim == std::get<1>(dim)) {
      dims.push_back(lhs_dim);
    } else {
      dims.push_back(ShapedType::kDynamicSize);
    }
  }
  return RankedTensorType::get(dims, GetElementTypeFromOperand(lhs, rhs));
}

// Finds compatible types to propagate into functions/regions of a shape
// invariant tf.While/tf.WhileRegion. If operand and result types are the same,
// that type is returned. If operand and result types are of the same rank, a
// compatible type with matching dimensions is used. Otherwise the
// function/region argument type is returned, but with the handle type taken
// from the operand type.
llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
    TypeRange operand_types, TypeRange result_types,
    TypeRange region_argument_types) {
  llvm::SmallVector<Type, 4> types;
  types.reserve(operand_types.size());
  for (auto entry :
       llvm::zip(operand_types, result_types, region_argument_types)) {
    auto operand_type = std::get<0>(entry).cast<TensorType>();
    auto result_type = std::get<1>(entry).cast<TensorType>();
    if (operand_type == result_type) {
      types.push_back(operand_type);
    } else if (RankedAndSameRank(operand_type, result_type)) {
      auto potential_refined_type =
          GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(),
                                        result_type.cast<RankedTensorType>());
      types.push_back(potential_refined_type);
    } else {
      auto region_argument_type = std::get<2>(entry).cast<TensorType>();
      Type element_type = GetElementTypeFromOperand(
          operand_type.cast<TensorType>(), region_argument_type);
      Type potential_refined_type = CreateTensorType(
          region_argument_type.hasRank() ? region_argument_type.getShape()
                                         : llvm::Optional<ArrayRef<int64_t>>(),
          element_type);
      types.push_back(potential_refined_type);
    }
  }
  return types;
}

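// Propagates operand types of `op` into the functions attached to it (If/Case
// branches, While cond/body, or a called function), handling shape-invariant
// While ops specially. Returns failure if propagation into any attached
// function fails.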
FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedFunctions(
    Operation* op, int64_t max_iteration) {
  ModuleOp module = op->getParentOfType<ModuleOp>();
  if (auto if_op = dyn_cast<TF::IfOp>(op)) {
    DCOMMENT("Propagating shapes into If");
    return PropagateShapeToFunctions(
        module, if_op.input().getTypes(),
        {if_op.then_function(), if_op.else_function()}, max_iteration);
  } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
    SmallVector<FuncOp, 4> branches;
    case_op.get_branch_functions(branches);
    return PropagateShapeToFunctions(module, case_op.input().getTypes(),
                                     branches, max_iteration);
  } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
    // If `shape_invariant` is set, operand shapes cannot be simply propagated
    // to result shapes as the op may have different intermediate shapes (such
    // While ops can have result shapes that differ from their operand shapes).
    // Compatible shapes must be determined before propagating them.
    if (while_op.shape_invariant()) {
      auto compatible_types = GetWhileCompatibleTypes(
          while_op.input().getTypes(), while_op.output().getTypes(),
          while_op.body_function().getType().getInputs());
      return PropagateShapeToFunctions(
          module, compatible_types,
          {while_op.cond_function(), while_op.body_function()}, max_iteration);
    }
    return PropagateShapeToFunctions(
        module, while_op.input().getTypes(),
        {while_op.cond_function(), while_op.body_function()}, max_iteration);
  } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
    if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) {
      PropagateConstantToCallee(call_op, func, module);
      FailureOr<bool> failure_or_converged = PropagateShapeToFunctions(
          module, call_op.getArgOperands().getTypes(), {func}, max_iteration);
      if (failed(failure_or_converged)) return failure();
      PropagateConstantFromCallee(call_op, func, module);
      return failure_or_converged;
    }
  }

  // TODO(ycao): Implement support for Call op, including function reuse.

  return true;
}

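// Propagates operand types of `op` into its attached regions; currently only
// tf.WhileRegion (cond and body) is handled, with special casing for the
// shape-invariant form.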
FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedRegions(
    Operation* op, int64_t max_iteration) {
  if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
    // If `shape_invariant` is set, operand shapes cannot be simply propagated
    // to result shapes as the op may have different intermediate shapes (such
    // While ops can have result shapes that differ from their operand shapes).
    // Compatible shapes must be determined before propagating them.
    if (while_op.shape_invariant()) {
      auto compatible_types = GetWhileCompatibleTypes(
          while_op.input().getTypes(), while_op.output().getTypes(),
          while_op.body().getArgumentTypes());
      return PropagateShapeToRegions(compatible_types,
                                     {&while_op.cond(), &while_op.body()},
                                     max_iteration);
    }
    return PropagateShapeToRegions(while_op.input().getTypes(),
                                   {&while_op.cond(), &while_op.body()},
                                   max_iteration);
  }
  return true;
}

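// Attempts to constant fold `op` using operands that are known constants.
// Folded results are recorded in the shape inference results map, and result
// types are refined when the folded attribute carries a more precise type.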
LogicalResult ShapeInference::TryToFold(Operation* op) {
  LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
  // If any output result is known, then the op has probably been computed
  // before.
  if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
    return success();

  SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
  SmallVector<OpFoldResult, 8> fold_results;

  // Check to see if any operands to the operation are constant and whether
  // the operation knows how to constant fold itself.
  bool some_unknown = false;
  for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
    if (!(constant_operands[i] =
              ComputeOutputComponent(ValuePort(op->getOperand(i)))))
      some_unknown = true;
  }

  // Attempt to constant fold the operation.
  auto* abstract_op = op->getAbstractOperation();
  LogicalResult folded = failure();
  if (abstract_op) {
    folded = abstract_op->foldHook(op, constant_operands, fold_results);
  }
  // Attempt dialect fallback if op's fold hook failed.
  if (failed(folded)) {
    Dialect* dialect = op->getDialect();
    if (!dialect) return failure();
    // Only attempt TF dialect fallback if there are no unknown operands.
    if (some_unknown && dialect == tf_dialect_) return failure();
    auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
    if (!interface) return failure();

    if (failed(interface->fold(op, constant_operands, fold_results)))
      return failure();
  }

  for (auto result : zip(op->getResults(), fold_results)) {
    auto fold_result = std::get<1>(result);
    Attribute attr = nullptr;
    if ((attr = fold_result.dyn_cast<Attribute>())) {
      RecordValue(ValuePort(std::get<0>(result)), attr);
    } else {
      auto value = fold_result.get<Value>();
      if ((attr = ComputeOutputComponent(ValuePort(value)))) {
        DCOMMENT("\t\tValue Result mapped to " << attr);
        RecordValue(ValuePort(std::get<0>(result)), attr);
      } else {
        DCOMMENT("\t\tValue result unmapped, consider value type:" << value);
        RefineResultType(op, std::get<0>(result), value.getType());
      }
    }

    if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
      if (std::get<0>(result).getType() == eattr.getType()) continue;

      (void)UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(),
                                                    std::get<0>(result));
    }
  }

  return success();
}

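// Refines the return type of `func` from the types produced by its (single)
// return op, folding refinement-only casts that feed the return in the
// process.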
LogicalResult ShapeInference::InferShapeForFunctionReturnType(FuncOp func) {
  LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName()
                          << "\n");

  // Find any return ops.
  SmallVector<ReturnOp, 4> return_ops;
  for (Block& block : func) {
    if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) {
      return_ops.push_back(return_op);
    }
  }

  // Skip functions without a return, but don't flag as failure here.
  if (return_ops.empty()) return success();

  // Right now we only handle the case of a single return op.
  // To handle multiple return ops, we would need to look at all their shapes
  // and come up with a common shape and insert appropriate casts.
  if (return_ops.size() != 1) return failure();

  // Find the return type.
  auto return_op = return_ops.front();

  // Manually fold tf.Cast that precedes the return instruction and only
  // differs in shape refinement level.
  bool changed = false;
  for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
    Operation* arg_defining_op = arg_op.get().getDefiningOp();
    if (isa_and_nonnull<CastOp, tensor::CastOp>(arg_defining_op)) {
      Value input = arg_defining_op->getOperand(0);
      Value result = arg_defining_op->getResult(0);
      Type meet = TypeMeet(result.getType(), input.getType());
      if (meet == result.getType()) continue;

      LLVM_DEBUG({
        llvm::errs() << "\tfolding & updating return type ";
        result.getType().print(llvm::errs());
        input.getType().print(llvm::errs() << " to ");
        llvm::errs() << "\n";
      });

      // Shape inference should not change the element type.
      if (HasCompatibleElementTypes(input.getType(), result.getType()) &&
          meet == input.getType()) {
        arg_op.set(input);
      } else {
        OpBuilder b(return_op.getOperation());
        auto new_cast_op = InsertCast(b, return_op.getLoc(), meet, input);
        if (!new_cast_op) return failure();
        arg_op.set(new_cast_op->getResult(0));
      }
      if (result.use_empty()) arg_defining_op->erase();
      changed = true;
    }
  }

  DCOMMENT("Updating function type");
  func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
                                 return_op.getOperandTypes()));

  if (changed) EnqueueCallers(func);
  return success();
}

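// Runs shape inference on all operations in `region` until no more changes
// are made or `max_iteration` is reached. Returns failure if inference was
// interrupted, and otherwise whether a fixed point was reached.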
FailureOr<bool> ShapeInference::InferShapeUntilFixPoint(Region* region,
                                                        int64_t max_iteration) {
  bool changed = true;

  // TODO(aminim): we could have a more efficient traversal by guiding the
  // traversal with a worklist and reconsider only the nodes for which an
  // operand type was inferred. This would need to be careful if working on a
  // region that would not be isolated.
  for (int iteration = 0; iteration < max_iteration && changed; ++iteration) {
    changed = false;
    LLVM_DEBUG(llvm::dbgs()
               << "Shape inference, iteration " << iteration << "\n");
    auto res = region->walk([&](Operation* op) {
      DCOMMENT_OP(op, "Inferring for");
      if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
        DCOMMENT("\tRefining with type op interface");
        changed |= RefineWithInferTypeOpInterface(infer_ti);
        return WalkResult::advance();
      }

      if (op->getDialect() != tf_dialect_) {
        DCOMMENT("\tInfer non-TF dialect");
        changed |= InferShapeForNonTFDialectOperation(op);
        return WalkResult::advance();
      }

      // Before attempting inference, just try to compute the folded
      // value/shape.
      if (succeeded(TryToFold(op)) &&
          // Folding can "succeed" without all result types being refined. In
          // such cases we still want to try `InferShapeForSingleOperation`.
          none_of(op->getResultTypes(), CanBeRefined))
        return WalkResult::advance();

      // Best-effort shape inference in attached functions. Do not return
      // failure even if it doesn't get to fixed point, but propagate "real"
      // failure.
      if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
        op->emitWarning() << "unable to refine shape of attached function "
                             "arguments and bodies";
        return WalkResult::interrupt();
      }

      if (failed(PropagateShapeIntoAttachedRegions(op, max_iteration))) {
        op->emitWarning() << "unable to refine shape of attached region "
                             "arguments and bodies";
        return WalkResult::interrupt();
      }

      changed |= InferShapeForSingleOperation(op);
      return WalkResult::advance();
    });
    if (res.wasInterrupted()) return failure();
  }

  if (changed) {
    region->getParentOp()->emitWarning()
        << "shape inference did not reach stable state after " << max_iteration
        << " iterations";
  }
  return !changed;
}

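// Runs shape inference on the body of `func` with the given context and then
// refines the function return type. Returns failure or non-convergence from
// the underlying fixed-point inference.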
static FailureOr<bool> InferShapeForFunction(ShapeInference& context,
                                             FuncOp func,
                                             int64_t max_iterations) {
  FailureOr<bool> failure_or_converged =
      context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
  if (failed(failure_or_converged) || !failure_or_converged.getValue())
    return failure_or_converged;
  // TODO(b/156276510): Verify that it is always fine to refine a function's
  // return type, as long as we do not change the argument shapes.
  if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
  return true;
}

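// Refines the argument types of `func` with `arg_shapes` (when provided) and
// then runs shape inference over the function. The rank of each provided
// shape must match the corresponding ranked argument type of `func`.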
FailureOr<bool> InferShapeForFunction(FuncOp func,
                                      ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                      int64_t graph_version,
                                      int64_t max_iterations) {
  ShapeInference context(graph_version, func->getParentOfType<ModuleOp>(),
                         /*propagate_caller_callee_constants=*/true);
  if (arg_shapes.empty()) {
    return InferShapeForFunction(context, func, max_iterations);
  }

  FunctionType func_type = func.getType();
  bool needs_refinement = false;
  SmallVector<Type, 4> new_arg_types;
  new_arg_types.reserve(func_type.getNumInputs());

  // Update argument types in-place using the provided arg_shapes.
  for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
    ArrayRef<int64_t> shape = arg_shapes[i];
    Type element_type;
    if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
      if (input_ty.getRank() != shape.size()) {
        return failure();
      }
      element_type = input_ty.getElementType();
    } else {
      auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
      if (!unranked_input_ty) {
        return failure();
      }
      element_type = unranked_input_ty.getElementType();
    }

    auto new_arg_type = RankedTensorType::get(shape, element_type);
    if (new_arg_type != func_type.getInput(i)) {
      // If the new type is more detailed, trigger shape inference.
      func.getArgument(i).setType(new_arg_type);
      needs_refinement = true;
    }
    new_arg_types.push_back(new_arg_type);
  }

  if (!needs_refinement) return true;

  FailureOr<bool> failure_or_converged =
      context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
  if (failed(failure_or_converged) || !failure_or_converged.getValue())
    return failure_or_converged;

  if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
  func.setType(FunctionType::get(func.getContext(), new_arg_types,
                                 func.getType().getResults()));

  return true;
}

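// Runs shape inference for all functions in `module`, starting from "main"
// (if present) and re-processing callers enqueued during inference. The total
// number of functions processed is capped to avoid pathological cases.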
FailureOr<bool> InferModuleShape(ModuleOp module, int64_t max_iterations) {
  auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
  if (!producer_or.ok()) {
    // TODO(jpienaar): Keeping the existing behavior for now but this could
    // be relaxed.
    LLVM_DEBUG(llvm::dbgs()
               << "Skipping inference; " << producer_or.status().ToString());
    return true;
  }
  int64_t producer = producer_or.ValueOrDie();
  // TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if
  // it is no longer needed.
  ShapeInference context(producer, module,
                         /*propagate_caller_callee_constants=*/false);
  if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
    context.enqueue(main);
  for (auto func : module.getOps<FuncOp>()) context.enqueue(func);
  // Arbitrarily upper bound the maximum number of functions that get processed
  // just to avoid pathological cases.
  auto max_iteration = context.QueueSize() * 4;
  while (!context.EmptyQueue()) {
    FuncOp func = context.front();
    FailureOr<bool> failure_or_converged =
        InferShapeForFunction(context, func, max_iterations);
    if (failed(failure_or_converged) || !failure_or_converged.getValue())
      return failure_or_converged;
    context.pop_front();

    if ((--max_iteration) == 0) {
      emitWarning(UnknownLoc::get(module.getContext()))
          << "shape inference did not reach stable state after "
          << max_iteration << " iterations";
      return false;
    }
  }
  return true;
}

}  // namespace TF
}  // namespace mlir