/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is a simple include file used to simplify the splitting of the
// tf_ops.cc file. The helpers in here should be refactored and moved to
// tf_verifiers or tf_ops.
// TODO(jpienaar): Remove this file post refactoring.

//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//

// Returns the RankedTensorType for the given operand. TensorFlow constant ops
// may have non-static shape because the shape is not propagated during constant
// folding. If the defining op for the given operand is a constant op, this
// routine uses the constant op's attribute to get the actual shape.
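//
// Illustrative sketch (hypothetical IR, not taken from this file): a constant
// whose result type has lost its static shape still carries a ranked type in
// its value attribute, which this helper recovers:
//
//   %cst = "tf.Const"() {value = dense<1.0> : tensor<2x2xf32>}
//            : () -> tensor<*xf32>
//
// For the result of %cst this helper returns tensor<2x2xf32>, taken from the
// constant attribute rather than from the unranked operand type.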
static RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
  DenseElementsAttr attr;
  if (matchPattern(operand, m_Constant(&attr))) {
    return attr.getType().dyn_cast<RankedTensorType>();
  }
  return operand.getType().dyn_cast<RankedTensorType>();
}

// Returns true if the given `type` is a ranked float tensor type with the
// given `rank`.
static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) {
  return type && type.getRank() == rank &&
         type.getElementType().isa<FloatType>();
}

// Returns true if the given `value` has the specified rank or has unranked
// type.
static inline bool IsOfRankOrUnranked(Value value, int64_t rank) {
  RankedTensorType type = GetRankedTensorTypeForOperand(value);
  return !type || type.getRank() == rank;
}

// Returns true if the given `value` has at least the specified rank or has
// unranked type.
static inline bool HasRankAtLeast(Value value, int64_t rank) {
  RankedTensorType type = GetRankedTensorTypeForOperand(value);
  return !type || type.getRank() >= rank;
}

// Returns true if the given `value` has at most the specified rank or has
// unranked type.
static inline bool HasRankAtMost(Value value, int64_t rank) {
  RankedTensorType type = GetRankedTensorTypeForOperand(value);
  return !type || type.getRank() <= rank;
}

static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
  return dim_or_rank == -1;
}

// Returns the tf.Equal/tf.NotEqual result type given the `x` and `y` inputs.
// If `incompatible_shape_error` is true, reports an error if `x` and `y` have
// incompatible shapes. Otherwise, returns a tensor type with unknown rank.
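//
// Worked example (illustrative): with x of type tensor<2x1xf32> and y of type
// tensor<2x3xf32>, the broadcasted shape is 2x3 and the deduced result type is
// tensor<2x3xi1>. If the operands are not broadcastable and
// `incompatible_shape_error` is false, the result is tensor<*xi1>.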
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
                                 Value y, BoolAttr incompatible_shape_error) {
  auto result_type =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  if (!result_type) {
    if (incompatible_shape_error.getValue()) {
      mlir::emitError(loc, "non-broadcastable operands");
    } else {
      return UnrankedTensorType::get(builder->getI1Type());
    }
  }

  auto ranked_type = result_type.dyn_cast<RankedTensorType>();
  if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type());

  return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type());
}

// Returns the dimension index for the given TensorFlow axis, supporting
// negative indexing.
static int64_t GetDimForAxis(int64_t axis, int64_t rank) {
  return axis >= 0 ? axis : axis + rank;
}

// Infers the output type for reduction ops such as SumOp, MaxOp, etc.
// TODO(b/e667204a): Move this logic to shape inference once it supports custom
// inference functions.
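//
// Worked example (illustrative): for an input of type tensor<2x3x4xf32> with
// constant reduction indices {1}, the inferred type is tensor<2x4xf32> when
// `keep_dims` is false and tensor<2x1x4xf32> when `keep_dims` is true. If the
// indices are not constant and `keep_dims` is false, the result is
// tensor<*xf32>.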
static Type InferReductionOpType(Value input, Value reduction_indices,
                                 BoolAttr keep_dims, Builder *builder) {
  Type input_ty = input.getType();
  Type element_ty = getElementTypeOrSelf(input_ty);

  // Output type is unranked if input type is not ranked.
  auto ranked_ty = input_ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return UnrankedTensorType::get(element_ty);
  int64_t rank = ranked_ty.getRank();

  DenseIntElementsAttr indices;
  if (!matchPattern(reduction_indices, m_Constant(&indices))) {
    // Output type is unranked if reduction indices are not constant and reduced
    // dimensions are not kept.
    if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty);

    // Otherwise, the output type has the same rank as the input.
    return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty);
  }

  int64_t num_reduce_dim = 0;
  llvm::SmallVector<bool, 4> is_reduce_dim(rank, false);
  for (const APInt &index : indices.getValues<APInt>()) {
    int64_t dim = GetDimForAxis(index.getSExtValue(), rank);
    // Invalid input.
    if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty);

    if (!is_reduce_dim[dim]) {
      is_reduce_dim[dim] = true;
      num_reduce_dim++;
    }
  }

  ArrayRef<int64_t> shape = ranked_ty.getShape();
  SmallVector<int64_t, 4> out_shape;
  out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim));
  for (int64_t i = 0; i < rank; ++i) {
    if (!is_reduce_dim[i])
      out_shape.push_back(shape[i]);
    else if (keep_dims.getValue())
      out_shape.push_back(1);
  }
  return RankedTensorType::get(out_shape, element_ty);
}

// Verifies that the given types are cast compatible. If not, emits an
// appropriate error for the given op. If mask_one_dim is set to true, then the
// types are allowed to have one mismatching dimension. Masking one of the
// dimensions is useful for ops like Concat that require all ranked inputs to
// have the same rank and matching dimension sizes for all but one of the
// dimensions.
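//
// Illustrative example: for a Concat-like op with mask_one_dim set to true,
// operand types tensor<2x3xf32>, tensor<2x5xf32>, and tensor<2x7xf32> are
// accepted (only the dimension at index 1 mismatches, so it is masked), while
// an extra operand of type tensor<4x3xf32> would be rejected because the
// dimension at index 0 also mismatches.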
static LogicalResult VerifyTypesCompatibility(
    Operation::operand_type_range types, bool mask_one_dim, Operation *op) {
  constexpr int64_t kUninitialized = -1;
  int64_t common_rank = kUninitialized;
  llvm::SmallVector<int64_t, 4> common_dims;
  int64_t dim_to_mask = kUninitialized;

  // Initialize common_rank with rank of the first ranked type and verify that
  // following ranked types have the same rank.
  // Similarly, initialize each of the dimensions with the first type that has
  // the dimension size available and verify that all following types have the
  // same size for the dimension. However, if mask_one_dim is true, note down
  // the dimension index on the first mismatch and ignore dimension at that
  // index in following types.
  for (Type ty : types) {
    RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
    if (!ranked_ty) continue;

    int64_t rank = ranked_ty.getRank();
    if (common_rank == kUninitialized) {
      common_rank = rank;
      common_dims.resize(common_rank, kUninitialized);
    } else if (common_rank != rank) {
      return op->emitError()
             << "operand type " << ranked_ty
             << " is not compatible with preceding operands; expected rank: "
             << common_rank;
    }

    for (int64_t i = 0, e = common_rank; i != e; i++) {
      if (i == dim_to_mask) continue;

      int64_t dim = ranked_ty.getDimSize(i);
      if (dim == kUninitialized) continue;

      int64_t &common_dim = common_dims[i];
      if (common_dim == kUninitialized) {
        common_dim = dim;
      } else if (common_dim != dim) {
        // If mask_one_dim is true, do not emit an error if this is the only
        // dimension with mismatches. Note down the dimension to mask it from
        // the following types.
        if (mask_one_dim && dim_to_mask == kUninitialized) {
          dim_to_mask = i;
          continue;
        }

        return op->emitError() << "operand type " << ranked_ty
                               << " is not compatible with preceding operands; "
                                  "expected dimension at index "
                               << i << ": " << common_dim;
      }
    }
  }
  return success();
}

// This is a helper for the Select to SelectV2 canonicalization. The `data` rank
// refers to the rank of `t`/`e` (these two inputs have equal rank; this is
// checked in the verifier).
//
// In most cases, the predicate for Select can be used directly as the predicate
// for SelectV2. However, there is one case that varies, which is when the
// predicate is a tensor and the data is multidimensional. In this case, Select
// op semantics dictate that the predicate tensor length must match the size of
// the first data dimension. This varies from normal broadcasting semantics
// (which are used in SelectV2), so we must reshape the tensor in this case to
// be compatible.
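//
// Illustrative example: with cond of type tensor<4xi1> and data rank 3, the
// predicate is reshaped to tensor<4x1x1xi1> so that it broadcasts against the
// first data dimension under SelectV2 semantics. A scalar predicate, or a
// predicate whose rank is not 1, is returned unchanged.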
static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc,
                                          Value cond, int data_rank) {
  auto cond_tensor = cond.getType().cast<RankedTensorType>();
  // Reshape is only needed in the case that the cond rank is 1 (i.e. it is
  // a vector) AND t/e rank is > 1.
  if (cond_tensor.getRank() != 1 || data_rank <= 1) {
    // No reshape necessary. Leave cond as it is.
    return cond;
  }

  // This is the case where a reshape is needed. We want to construct the
  // shape [x,1,...1], where x is the length of the cond vector and the
  // length of the shape is equal to data_rank.
  SmallVector<int64_t, 8> shape(data_rank, 1);
  shape[0] = cond_tensor.getShape().front();
  auto new_shape_type =
      RankedTensorType::get({data_rank}, builder->getIntegerType(64));
  auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape);
  auto new_shape = builder->create<ConstOp>(loc, shape_attr);
  return builder->create<ReshapeOp>(loc, cond, new_shape);
}

//===----------------------------------------------------------------------===//
// Helper functions to detect device capabilities from RuntimeDevices.
//===----------------------------------------------------------------------===//

namespace {
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;

bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) {
  return device.type == ::tensorflow::DEVICE_GPU;
}

}  // namespace

// Returns true if at least one GPU device is available at runtime.
bool CanUseGpuDevice(const RuntimeDevices &devices) {
  return llvm::any_of(devices.device_names(), IsGpuDevice);
}

// Returns true if all of the GPUs available at runtime support TensorCores
// (NVIDIA compute capability >= 7.0).
bool CanUseTensorCores(const RuntimeDevices &devices) {
  auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) {
    auto md = devices.GetGpuDeviceMetadata(device);
    return md ? md->cc_major().getInt() >= 7 : false;
  };
  return llvm::all_of(
      llvm::make_filter_range(devices.device_names(), IsGpuDevice),
      has_tensor_cores);
}

// Returns true if the operation does not have an explicit device placement
// that would prevent it from running on a GPU device.
bool CanUseGpuDevice(Operation *op) {
  auto device_attr = op->getAttrOfType<StringAttr>("device");
  if (!device_attr || device_attr.getValue().empty()) return true;

  DeviceNameUtils::ParsedName device;
  if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device))
    return false;

  // We can't use the GPU if the operation is explicitly placed on a non-GPU
  // device.
  return !device.has_type || device.type == ::tensorflow::DEVICE_GPU;
}

//===----------------------------------------------------------------------===//
// TF op helper functions to work with layout transformation.
//===----------------------------------------------------------------------===//

SmallVector<int64_t, 4> ReversePermutation(ArrayRef<int64_t> permutation) {
  SmallVector<int64_t, 4> reverse(permutation.size());
  for (size_t i = 0; i < permutation.size(); ++i) {
    reverse[permutation[i]] = i;
  }
  return reverse;
}

SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) {
  if (from == "NHWC" && to == "NCHW") {
    return {0, 3, 1, 2};
  } else if (from == "NCHW" && to == "NHWC") {
    return {0, 2, 3, 1};
  } else {
    return {};
  }
}

// Shuffles elements in the `attr` according to the permutation. The optional
// `inner_size` allows shuffling array attributes created from rank 2 tensors
// on the outer dimension only.
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
                           int inner_size = 1) {
  if (attr.size() == 0) return attr;

  assert(attr.size() % inner_size == 0);
  assert(attr.size() / inner_size == permutation.size());

  SmallVector<Attribute, 8> values{attr.begin(), attr.end()};
  SmallVector<Attribute, 8> shuffled(values.size());

  for (size_t i = 0; i < permutation.size(); ++i) {
    for (size_t j = 0; j < inner_size; ++j) {
      shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j];
    }
  }

  return ArrayAttr::get(attr.getContext(), shuffled);
}

// Shuffles ranked tensor dimensions according to the permutation.
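//
// Illustrative example: with permutation {0, 3, 1, 2} (NHWC -> NCHW),
// tensor<1x224x224x3xf32> becomes tensor<1x3x224x224xf32>; non-ranked types
// are returned unchanged.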
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
  if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
    ArrayRef<int64_t> shape = ranked_type.getShape();
    assert(permutation.size() == shape.size());

    SmallVector<int64_t, 4> new_shape(permutation.size());
    for (size_t i = 0; i < permutation.size(); ++i)
      new_shape[i] = shape[permutation[i]];

    return RankedTensorType::get(new_shape, ranked_type.getElementType());
  }

  return type;
}

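// Returns true if `perm0` and `perm1` are inverse permutations of each other,
// i.e. applying one after the other restores the original order.
//
// Illustrative example: perm0 = {0, 2, 3, 1} and perm1 = {0, 3, 1, 2} cancel
// each other (an NCHW->NHWC shuffle followed by NHWC->NCHW), while
// perm0 = perm1 = {0, 2, 3, 1} do not.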
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
                                       DenseIntElementsAttr perm1) {
  if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
  if (perm0.getNumElements() != perm1.getNumElements()) return false;

  SmallVector<int64_t, 8> perm0_values;
  for (const auto &value : perm0.getIntValues())
    perm0_values.push_back(value.getSExtValue());

  SmallVector<int64_t, 8> perm1_values;
  for (const auto &value : perm1.getIntValues())
    perm1_values.push_back(value.getSExtValue());

  for (int i = 0; i < perm0_values.size(); ++i) {
    if (perm0_values[perm1_values[i]] != i) return false;
  }

  return true;
}

// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
// layout sensitive operations that do not have any additional layout dependent
// attributes besides the `data_format` string.
template <typename Op>
LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
  auto perm = GetDataFormatPermutation(op->data_format(), data_format);
  if (perm.empty()) return failure();

  // Update data format attribute.
  (*op)->setAttr("data_format", StringAttr::get(op->getContext(), data_format));

  // Update types for all layout sensitive results.
  auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
  for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
    OpResult result = op->getOperation()->getResult(idx);
    result.setType(ShuffleRankedTensorType(result.getType(), perm));
  }

  return success();
}

// Default implementation for folding operand transpose into the operation.
// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
template <typename Op>
LogicalResult FoldOperandsPermutation(
    ArrayRef<int64_t> permutation, Op *op,
    ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
  MLIRContext *context =
      (*op)->template getParentOfType<ModuleOp>().getContext();

  // We only support NHWC <-> NCHW permutations.
  static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
  static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};

  // Operation data format after folding `permutation`.
  StringRef target_data_format = [&]() -> StringRef {
    if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
      return "NCHW";  // cancel NCHW->NHWC operand permutation
    } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
      return "NHWC";  // cancel NHWC->NCHW operand permutation
    } else {
      return "";
    }
  }();
  if (target_data_format.empty()) return failure();

  // To fold operand `permutation` into the `op` we need to shuffle all layout
  // dependent attributes and types with the reverse permutation, and change
  // the operation data format to `target_data_format`.
  //
  // Example:
  //   %1 = SomeOp(...)   {data_format = NHWC}
  //   %2 = Transpose(%1) {permutation = NHWC->NCHW}
  //   %3 = Op(%2)        {data_format = NCHW}
  //
  // To bypass %2 we have to shuffle layout dependent attributes and types from
  // NCHW to NHWC, which is the reverse of the operand permutation (the
  // function argument).
  auto reverse_permutation =
      GetDataFormatPermutation(op->data_format(), target_data_format);
  if (reverse_permutation.empty()) return failure();

  (*op)->setAttr("data_format", StringAttr::get(context, target_data_format));

  for (auto pair : shuffle_attrs) {
    StringRef attr_name = pair.first;
    ArrayAttr attr_value = pair.second;
    (*op)->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
  }

  auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
  for (unsigned idx : fold.GetLayoutDependentResults()) {
    OpResult result = op->getOperation()->getResult(idx);
    result.setType(
        ShuffleRankedTensorType(result.getType(), reverse_permutation));
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Rewrite Pattern for removing trivial Arithmetic op.
//===----------------------------------------------------------------------===//

namespace {
// Folds the arithmetic op if one of the operands is a constant known to be an
// identity (e.g. X+0, X*1, etc.). For commutative operations, fold if the
// known identity value is either the lhs or the rhs.
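//
// Illustrative example (hypothetical IR): given a splat zero constant,
//
//   %zero = "tf.Const"() {value = dense<0.0> : tensor<4xf32>}
//             : () -> tensor<4xf32>
//   %sum = "tf.AddV2"(%x, %zero)
//             : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//
// the folder returns %x directly: adding zero is an identity, and because all
// shapes are equal and static no broadcasting error can be hidden.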
template <
    typename OpT,
    typename std::enable_if<llvm::is_one_of<
        OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
                                        ArrayRef<Attribute> operands) {
  auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
  auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
  auto result_type =
      arithmetic_op.getResult().getType().template cast<ShapedType>();

  // We can fold the arithmetic operation only if we can prove that we will
  // not accidentally hide a broadcasting error.
  auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
                                  ShapedType result_ty) -> bool {
    // Scalar identity is broadcastable to any operand shape, we only need to
    // check that the operand has the same shape as the result.
    bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
    if (scalar_identity) return operand_ty == result_ty;

    // If the identity is not a scalar, we must verify that all shapes are
    // equal and statically known.
    //
    // TODO(ezhulenev): Fold if the identity shape is statically known to be
    // broadcastable to the operand shape.
    return operand_ty == result_ty && identity_ty == result_ty &&
           result_ty.hasStaticShape();
  };

  // Check that we have a constant operand on one side (candidate for identity).
  const bool is_commutative =
      (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
  auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  if (!rhs_attr && !(is_commutative && lhs_attr)) return {};

  // Mul and Div ops have an identity value of one, while AddV2 and SubOp have
  // an identity value of zero.
  const int identity =
      (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
       std::is_same<OpT, RealDivOp>::value)
          ? 1
          : 0;

  Type element_ty = lhs_type.getElementType();
  Attribute identity_attr;
  if (auto ty = element_ty.template dyn_cast<FloatType>()) {
    identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
  } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
    identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
  } else {
    return {};
  }

  // Fold: Op(Operand, Identity) -> Operand.
  if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
    if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr)
      return arithmetic_op.x();
  }

  // Fold: Op(Identity, Operand) -> Operand for commutative operations.
  if (lhs_attr && is_commutative &&
      is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
    if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr)
      return arithmetic_op.y();
  }

  return {};
}
}  // namespace

// Verifies a reduction op's `input` and reduction `dims`.
static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
                                                 Location loc) {
  auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
  if (!dims_type) return success();
  if (dims_type.getRank() > 1)
    return emitError(loc, "dimensions can only be 0D or 1D tensor");

  auto input_type = input.getType().dyn_cast<RankedTensorType>();
  if (!input_type) return success();
  int64_t rank = input_type.getRank();

  DenseIntElementsAttr dims_attr;
  if (!matchPattern(dims, m_Constant(&dims_attr))) return success();
  for (const auto &dim_pair : llvm::enumerate(dims_attr)) {
    int64_t cur_dim = dim_pair.value().getSExtValue();
    if (cur_dim < -rank || cur_dim >= rank)
      return emitError(loc)
             << dim_pair.index() << "-th dimension should be in the range of [-"
             << rank << ", " << rank << ")";
  }

  return success();
}

// A type range with a description (in singular form) attached to it.
using TypeRangeWithDesc = std::pair<TypeRange, StringRef>;

LogicalResult VerifyTypeRangesAreCompatible(Operation *op,
                                            TypeRangeWithDesc range0,
                                            TypeRangeWithDesc range1) {
  if (range0.first.size() != range1.first.size()) {
    return op->emitOpError()
           << range0.second << "s (size = " << range0.first.size() << ")"
           << " should have the same number of values as " << range1.second
           << "s (size = " << range1.first.size() << ")";
  }

  for (auto it : llvm::enumerate(llvm::zip(range0.first, range1.first))) {
    int index = it.index();
    Type type0 = std::get<0>(it.value());
    Type type1 = std::get<1>(it.value());
    if (!AreCastCompatible({type0, type1}))
      return op->emitOpError(llvm::formatv(
          "{0} type {1} is incompatible with {2} type {3} at index {4}",
          range0.second, type0, range1.second, type1, index));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Function control flow canonicalization.
//===----------------------------------------------------------------------===//

// Eliminates attributes that are not needed, but can get attached to Ops
// during import.
template <typename Op>
struct DropAttributes : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  // Drop the "output_shapes" attribute.
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    bool found = !!op.removeAttr("output_shapes");
    return success(found);
  }
};

//===----------------------------------------------------------------------===//
// TF op helper functions for handling resource handles and ids.
//===----------------------------------------------------------------------===//

// Returns the device of the op if present. If the op has no device set, an
// empty string ref is returned instead.
llvm::StringRef GetDeviceOrEmpty(Operation *op) {
  if (auto device_attr = op->getAttrOfType<StringAttr>("device"))
    return device_attr.getValue();
  return llvm::StringRef();
}

// Returns the resource handle value and id for a resource op based on its
// attributes. If a resource handle is anonymous, a new id is always returned.
ResourceHandleValueAndId GetResourceHandleValueAndIdBase(
    llvm::StringRef container, llvm::StringRef shared_name,
    llvm::StringRef device, Value resource,
    llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map,
    int64_t &next_id) {
  // Always create a new ID for an anonymous handle.
  if (IsResourceHandleAnonymous(shared_name)) return {resource, next_id++};

  ResourceHandle handle(container, shared_name, device, /*op=*/nullptr);
  auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id);
  // New ID created, increment next_id.
  if (emplace_res.second) ++next_id;
  return {resource, emplace_res.first->second};
}