1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// This is a simple include file used to simplify the splitting of the 17// tf_ops.cc file. The helpers in here should be refactored and moved to 18// tf_verifiers or tf_ops. 19// TODO(jpienaar): Remove this file post refactoring. 20 21//===----------------------------------------------------------------------===// 22// TF op helper functions 23//===----------------------------------------------------------------------===// 24 25// Returns the RankedTensorType for the given operand. TensorFlow constant ops 26// may have non-static shape because the shape is not propagated during constant 27// folding. If the defining op for the given operand is a constant op, this 28// routine uses the constant op's attribute to get the actual shape. 29static RankedTensorType GetRankedTensorTypeForOperand(Value operand) { 30 DenseElementsAttr attr; 31 if (matchPattern(operand, m_Constant(&attr))) { 32 return attr.getType().dyn_cast<RankedTensorType>(); 33 } 34 return operand.getType().dyn_cast<RankedTensorType>(); 35} 36 37// Returns true if the given `value` is of ranked float tensor type with the 38// given `rank`. 39static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { 40 return type && type.getRank() == rank && 41 type.getElementType().isa<FloatType>(); 42} 43 44// Returns true if the given `value` has the specified rank or has unranked 45// type. 46static inline bool IsOfRankOrUnranked(Value value, int64_t rank) { 47 RankedTensorType type = GetRankedTensorTypeForOperand(value); 48 return !type || type.getRank() == rank; 49} 50 51// Returns true if the given `value` has at least the specified rank or has 52// unranked type. 53static inline bool HasRankAtLeast(Value value, int64_t rank) { 54 RankedTensorType type = GetRankedTensorTypeForOperand(value); 55 return !type || type.getRank() >= rank; 56} 57 58// Returns true if the given `value` has at most the specified rank or has 59// unranked type. 60static inline bool HasRankAtMost(Value value, int64_t rank) { 61 RankedTensorType type = GetRankedTensorTypeForOperand(value); 62 return !type || type.getRank() <= rank; 63} 64 65static bool IsUnknownDimOrRank(int64_t dim_or_rank) { 66 return dim_or_rank == -1; 67} 68 69// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If 70// `incompatible_shape_error` is true, reports error if `x` and `y` has 71// incompatible shapes. Otherwise, returns a tensor type with unknown rank. 72static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, 73 Value y, BoolAttr incompatible_shape_error) { 74 auto result_type = 75 OpTrait::util::getBroadcastedType(x.getType(), y.getType()); 76 if (!result_type) { 77 if (incompatible_shape_error.getValue()) { 78 mlir::emitError(loc, "non-broadcastable operands"); 79 } else { 80 return UnrankedTensorType::get(builder->getI1Type()); 81 } 82 } 83 84 auto ranked_type = result_type.dyn_cast<RankedTensorType>(); 85 if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type()); 86 87 return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type()); 88} 89 90// Returns dimension index for the given TensorFlow axis that supports negative 91// indexing. 92static int64_t GetDimForAxis(int64_t axis, int64_t rank) { 93 return axis >= 0 ? axis : axis + rank; 94} 95 96// Infers output type for reduction ops such as SumOp, MaxOp etc. 97// TODO(b/e667204a): Move this logic to shape inference once it supports custom 98// inference functions. 99static Type InferReductionOpType(Value input, Value reduction_indices, 100 BoolAttr keep_dims, Builder *builder) { 101 Type input_ty = input.getType(); 102 Type element_ty = getElementTypeOrSelf(input_ty); 103 104 // Output type is unranked if input type is not ranked. 105 auto ranked_ty = input_ty.dyn_cast<RankedTensorType>(); 106 if (!ranked_ty) return UnrankedTensorType::get(element_ty); 107 int64_t rank = ranked_ty.getRank(); 108 109 DenseIntElementsAttr indices; 110 if (!matchPattern(reduction_indices, m_Constant(&indices))) { 111 // Output type is unranked if reduction indices are not constant and reduced 112 // dimensions are not kept. 113 if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); 114 115 // Otherwise, output type has same rank as the input. 116 return RankedTensorType::get(SmallVector<int64_t, 4>(rank, -1), element_ty); 117 } 118 119 int64_t num_reduce_dim = 0; 120 llvm::SmallVector<bool, 4> is_reduce_dim(rank, false); 121 for (const APInt &index : indices.getValues<APInt>()) { 122 int64_t dim = GetDimForAxis(index.getSExtValue(), rank); 123 // Invalid input. 124 if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty); 125 126 if (!is_reduce_dim[dim]) { 127 is_reduce_dim[dim] = true; 128 num_reduce_dim++; 129 } 130 } 131 132 ArrayRef<int64_t> shape = ranked_ty.getShape(); 133 SmallVector<int64_t, 4> out_shape; 134 out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim)); 135 for (int64_t i = 0; i < rank; ++i) { 136 if (!is_reduce_dim[i]) 137 out_shape.push_back(shape[i]); 138 else if (keep_dims.getValue()) 139 out_shape.push_back(1); 140 } 141 return RankedTensorType::get(out_shape, element_ty); 142} 143 144// Verifies that the given types are cast compatible. If not, emits appropriate 145// error for the given op. If mask_one_dim is set to true, then the types are 146// allowed to have one mismatching dimension. Masking one of the dimensions is 147// useful for ops like Concat that requires all ranked inputs to have the same 148// rank and match dimension sizes for all but one of the dimensions. 149static LogicalResult VerifyTypesCompatibility( 150 Operation::operand_type_range types, bool mask_one_dim, Operation *op) { 151 constexpr int64_t kUninitialized = -1; 152 int64_t common_rank = kUninitialized; 153 llvm::SmallVector<int64_t, 4> common_dims; 154 int64_t dim_to_mask = kUninitialized; 155 156 // Initialize common_rank with rank of the first ranked type and verify that 157 // following ranked types have the same rank. 158 // Similarly, initialize each of the dimensions with the first type that has 159 // the dimension size available and verify that all following types have the 160 // same size for the dimension. However, if mask_one_dim is true, note down 161 // the dimension index on the first mismatch and ignore dimension at that 162 // index in following types. 163 for (Type ty : types) { 164 RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>(); 165 if (!ranked_ty) continue; 166 167 int64_t rank = ranked_ty.getRank(); 168 if (common_rank == kUninitialized) { 169 common_rank = rank; 170 common_dims.resize(common_rank, kUninitialized); 171 } else if (common_rank != rank) { 172 return op->emitError() 173 << "operand type " << ranked_ty 174 << " is not compatible with preceding operands; expected rank: " 175 << common_rank; 176 } 177 178 for (int64_t i = 0, e = common_rank; i != e; i++) { 179 if (i == dim_to_mask) continue; 180 181 int64_t dim = ranked_ty.getDimSize(i); 182 if (dim == kUninitialized) continue; 183 184 int64_t &common_dim = common_dims[i]; 185 if (common_dim == kUninitialized) { 186 common_dim = dim; 187 } else if (common_dim != dim) { 188 // If mask_one_dim is true, do not emit an error if this is the only 189 // dimension with mismatches. Note down the dimension to mask it from 190 // the following types. 191 if (mask_one_dim && dim_to_mask == kUninitialized) { 192 dim_to_mask = i; 193 continue; 194 } 195 196 return op->emitError() << "operand type " << ranked_ty 197 << " is not compatible with preceding operands; " 198 "expected dimension at index " 199 << i << ": " << common_dim; 200 } 201 } 202 } 203 return success(); 204} 205 206// This is a helper for the Select to SelectV2 canonicalization. The `data` rank 207// refers to the rank of `t`/`e` (these two inputs have equal rank; this is 208// checked in the verifier). 209// 210// In most cases, the predicate for Select can be used directly as the predicate 211// for SelectV2. However, there is one case that varies, which is when the 212// predicate is a tensor and the data is multidimensional. In this case, Select 213// op semantics dictate that the predicate tensor length must match the size of 214// the first data dimension. This varies from normal broadcasting semantics 215// (which are used in SelectV2), so we must reshape the tensor in this case to 216// be compatible. 217static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, 218 Value cond, int data_rank) { 219 auto cond_tensor = cond.getType().cast<RankedTensorType>(); 220 // Reshape is only needed in the case that the cond rank is 1 (i.e. it is 221 // a vector) AND t/e rank is > 1. 222 if (cond_tensor.getRank() != 1 || data_rank <= 1) { 223 // No reshape necessary. Leave cond as it is. 224 return cond; 225 } 226 227 // This is the case where a reshape is needed. We want to construct the 228 // shape [x,1,...1], where x is the value in the pred tensor and the 229 // length of the shape is equal to data_rank. 230 SmallVector<int64_t, 8> shape(data_rank, 1); 231 shape[0] = cond_tensor.getShape().front(); 232 auto new_shape_type = 233 RankedTensorType::get({data_rank}, builder->getIntegerType(64)); 234 auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); 235 auto new_shape = builder->create<ConstOp>(loc, shape_attr); 236 return builder->create<ReshapeOp>(loc, cond, new_shape); 237} 238 239//===----------------------------------------------------------------------===// 240// Helper functions detect device capabilities from RuntimeDevices. 241//===----------------------------------------------------------------------===// 242 243namespace { 244using DeviceNameUtils = ::tensorflow::DeviceNameUtils; 245using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; 246 247bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) { 248 return device.type == ::tensorflow::DEVICE_GPU; 249} 250 251} // namespace 252 253// Returns true if at least one GPU device is available at runtime. 254bool CanUseGpuDevice(const RuntimeDevices &devices) { 255 return llvm::any_of(devices.device_names(), IsGpuDevice); 256} 257 258// Returns true if all of the GPUs available at runtime support TensorCores 259// (NVIDIA compute capability >= 7.0). 260bool CanUseTensorCores(const RuntimeDevices &devices) { 261 auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) { 262 auto md = devices.GetGpuDeviceMetadata(device); 263 return md ? md->cc_major().getInt() >= 7 : false; 264 }; 265 return llvm::all_of( 266 llvm::make_filter_range(devices.device_names(), IsGpuDevice), 267 has_tensor_cores); 268} 269 270// Returns true if operation does not have explicit device placement that would 271// prevent it from running on GPU device. 272bool CanUseGpuDevice(Operation *op) { 273 auto device_attr = op->getAttrOfType<StringAttr>("device"); 274 if (!device_attr || device_attr.getValue().empty()) return true; 275 276 DeviceNameUtils::ParsedName device; 277 if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device)) 278 return false; 279 280 // We can't use GPU if operation explicitly placed on non-GPU device. 281 return !device.has_type || device.type == ::tensorflow::DEVICE_GPU; 282} 283 284//===----------------------------------------------------------------------===// 285// TF op helper functions to work with layout transformation. 286//===----------------------------------------------------------------------===// 287 288SmallVector<int64_t, 4> ReversePermutation(ArrayRef<int64_t> permutation) { 289 SmallVector<int64_t, 4> reverse(permutation.size()); 290 for (size_t i = 0; i < permutation.size(); ++i) { 291 reverse[permutation[i]] = i; 292 } 293 return reverse; 294} 295 296SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) { 297 if (from == "NHWC" && to == "NCHW") { 298 return {0, 3, 1, 2}; 299 } else if (from == "NCHW" && to == "NHWC") { 300 return {0, 2, 3, 1}; 301 } else { 302 return {}; 303 } 304} 305 306// Shuffle elements in the `attr` according to the permutation. Optional 307// `inner_size` allows to shuffle array attributes created from rank 2 tensors 308// on outer dimension only. 309ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation, 310 int inner_size = 1) { 311 if (attr.size() == 0) return attr; 312 313 assert(attr.size() % inner_size == 0); 314 assert(attr.size() / inner_size == permutation.size()); 315 316 SmallVector<Attribute, 8> values{attr.begin(), attr.end()}; 317 SmallVector<Attribute, 8> shuffled(values.size()); 318 319 for (size_t i = 0; i < permutation.size(); ++i) { 320 for (size_t j = 0; j < inner_size; ++j) { 321 shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j]; 322 } 323 } 324 325 return ArrayAttr::get(attr.getContext(), shuffled); 326} 327 328// Shuffle ranked tensor dimensions according to the permutation. 329Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) { 330 if (auto ranked_type = type.dyn_cast<RankedTensorType>()) { 331 ArrayRef<int64_t> shape = ranked_type.getShape(); 332 assert(permutation.size() == shape.size()); 333 334 SmallVector<int64_t, 4> new_shape(permutation.size()); 335 for (size_t i = 0; i < permutation.size(); ++i) 336 new_shape[i] = shape[permutation[i]]; 337 338 return RankedTensorType::get(new_shape, ranked_type.getElementType()); 339 } 340 341 return type; 342} 343 344static bool AreCancellablePermutations(DenseIntElementsAttr perm0, 345 DenseIntElementsAttr perm1) { 346 if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false; 347 if (perm0.getNumElements() != perm1.getNumElements()) return false; 348 349 SmallVector<int64_t, 8> perm0_values; 350 for (const auto &value : perm0.getIntValues()) 351 perm0_values.push_back(value.getSExtValue()); 352 353 SmallVector<int64_t, 8> perm1_values; 354 for (const auto &value : perm1.getIntValues()) 355 perm1_values.push_back(value.getSExtValue()); 356 357 for (int i = 0; i < perm0_values.size(); ++i) { 358 if (perm0_values[perm1_values[i]] != i) return false; 359 } 360 361 return true; 362} 363 364// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for 365// layout sensitive operations that do not have any additional layout dependent 366// attributes besides `data_format` string. 367template <typename Op> 368LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { 369 auto perm = GetDataFormatPermutation(op->data_format(), data_format); 370 if (perm.empty()) return failure(); 371 372 // Update data format attribute. 373 (*op)->setAttr("data_format", StringAttr::get(op->getContext(), data_format)); 374 375 // Update types for all layout sensitive results. 376 auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation()); 377 for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) { 378 OpResult result = op->getOperation()->getResult(idx); 379 result.setType(ShuffleRankedTensorType(result.getType(), perm)); 380 } 381 382 return success(); 383} 384 385// Default implementation for folding operand transpose into the operation. 386// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`. 387template <typename Op> 388LogicalResult FoldOperandsPermutation( 389 ArrayRef<int64_t> permutation, Op *op, 390 ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) { 391 MLIRContext *context = (*op)->template getParentOfType<ModuleOp>().getContext(); 392 393 // We only support NHWC <-> NCHW permutations. 394 static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1}; 395 static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2}; 396 397 // Operation data format after folding `permutation`. 398 StringRef target_data_format = [&]() -> StringRef { 399 if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) { 400 return "NCHW"; // cancel NCHW->NHWC operand permutation 401 } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) { 402 return "NHWC"; // cancel NHWC->NCHW operand permutation 403 } else { 404 return ""; 405 } 406 }(); 407 if (target_data_format.empty()) return failure(); 408 409 // To fold operand `permutation` into the `op` we need shuffle all layout 410 // dependent attributes and types with a reverse permutation, and change 411 // operation data format to `target_data_format`. 412 // 413 // Example: 414 // %1 = SomeOp(...) {data_format = NHWC} 415 // %2 = Transpose(%1) {permutation = NHWC->NCHW} 416 // %3 = Op(%2) {data_format = NCHW} 417 // 418 // To bypass %2 we have to change data format to shuffle data format from NCHW 419 // to NHWC, which is the reverse of operand permutation (function argument). 420 auto reverse_permutation = 421 GetDataFormatPermutation(op->data_format(), target_data_format); 422 if (reverse_permutation.empty()) return failure(); 423 424 (*op)->setAttr("data_format", StringAttr::get(context, target_data_format)); 425 426 for (auto pair : shuffle_attrs) { 427 StringRef attr_name = pair.first; 428 ArrayAttr attr_value = pair.second; 429 (*op)->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation)); 430 } 431 432 auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation()); 433 for (unsigned idx : fold.GetLayoutDependentResults()) { 434 OpResult result = op->getOperation()->getResult(idx); 435 result.setType( 436 ShuffleRankedTensorType(result.getType(), reverse_permutation)); 437 } 438 439 return success(); 440} 441 442//===----------------------------------------------------------------------===// 443// Rewrite Pattern for removing trivial Arithmetic op. 444//===----------------------------------------------------------------------===// 445 446namespace { 447// Fold Arithmetic Op if one of the operands is a constant known to be an 448// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if 449// known identity value is either lhs or rhs. 450template < 451 typename OpT, 452 typename std::enable_if<llvm::is_one_of< 453 OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr> 454OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, 455 ArrayRef<Attribute> operands) { 456 auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>(); 457 auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>(); 458 auto result_type = 459 arithmetic_op.getResult().getType().template cast<ShapedType>(); 460 461 // We can fold arithmetic operation only of we can prove that we will not 462 // accidentally hide a broadcasting error. 463 auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty, 464 ShapedType result_ty) -> bool { 465 // Scalar identity is broadcastable to any operand shape, we only need to 466 // check that operand has the same shape as a result. 467 bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0; 468 if (scalar_identity) return operand_ty == result_ty; 469 470 // If identity is not a scalar, we must verify that all shapes are equal 471 // and statically known. 472 // 473 // TODO(ezhulenev): Fold if identity shape is statically know to be 474 // broadcastable to the operand shape. 475 return operand_ty == result_ty && identity_ty == result_ty && 476 result_ty.hasStaticShape(); 477 }; 478 479 // Check that we have a constant operand on one side (candidate for identity). 480 const bool is_commutative = 481 (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value); 482 auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>(); 483 auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>(); 484 if (!rhs_attr && !(is_commutative && lhs_attr)) return {}; 485 486 // Mul and Div ops have identity value one while AddV2 and SubOp have identity 487 // value zero. 488 const int identity = 489 (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value || 490 std::is_same<OpT, RealDivOp>::value) 491 ? 1 492 : 0; 493 494 Type element_ty = lhs_type.getElementType(); 495 Attribute identity_attr; 496 if (auto ty = element_ty.template dyn_cast<FloatType>()) { 497 identity_attr = FloatAttr::get(ty, static_cast<double>(identity)); 498 } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) { 499 identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity)); 500 } else { 501 return {}; 502 } 503 504 // Fold: Op(Operand, Identity) -> Operand. 505 if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) { 506 if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr) 507 return arithmetic_op.x(); 508 } 509 510 // Fold: Op(Identity, Operand) -> Operand for commutative operations. 511 if (lhs_attr && is_commutative && 512 is_valid_broadcasting(rhs_type, lhs_type, result_type)) { 513 if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr) 514 return arithmetic_op.y(); 515 } 516 517 return {}; 518} 519} // namespace 520 521// Verifies an reduction op's `input` and reduction `dims`. 522static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, 523 Location loc) { 524 auto dims_type = dims.getType().dyn_cast<RankedTensorType>(); 525 if (!dims_type) return success(); 526 if (dims_type.getRank() > 1) 527 return emitError(loc, "dimensions can only be 0D or 1D tensor"); 528 529 auto input_type = input.getType().dyn_cast<RankedTensorType>(); 530 if (!input_type) return success(); 531 int64_t rank = input_type.getRank(); 532 533 DenseIntElementsAttr dims_attr; 534 if (!matchPattern(dims, m_Constant(&dims_attr))) return success(); 535 for (const auto &dim_pair : llvm::enumerate(dims_attr)) { 536 int64_t cur_dim = dim_pair.value().getSExtValue(); 537 if (cur_dim < -rank || cur_dim >= rank) 538 return emitError(loc) 539 << dim_pair.index() << "-th dimension should be in the range of [-" 540 << rank << ", " << rank << ")"; 541 } 542 543 return success(); 544} 545 546// A type range with description (in singular form) attached to it. 547using TypeRangeWithDesc = std::pair<TypeRange, StringRef>; 548 549LogicalResult VerifyTypeRangesAreCompatible(Operation *op, 550 TypeRangeWithDesc range0, 551 TypeRangeWithDesc range1) { 552 if (range0.first.size() != range1.first.size()) { 553 return op->emitOpError() 554 << range0.second << "s (size = " << range0.first.size() << ")" 555 << " should have the same number of values as " << range1.second 556 << "s (size = " << range1.first.size() << ")"; 557 } 558 559 for (auto it : llvm::enumerate(llvm::zip(range0.first, range1.first))) { 560 int index = it.index(); 561 Type type0 = std::get<0>(it.value()); 562 Type type1 = std::get<1>(it.value()); 563 if (!AreCastCompatible({type0, type1})) 564 return op->emitOpError(llvm::formatv( 565 "{0} type {1} is incompatible with {2} type {3} at index {4}", 566 range0.second, type0, range1.second, type1, index)); 567 } 568 return success(); 569} 570 571//===----------------------------------------------------------------------===// 572// Function control flow canonicalization. 573//===----------------------------------------------------------------------===// 574 575// Eliminate attributes that are not needed, but can get attached to Ops 576// during import. 577template <typename Op> 578struct DropAttributes : public OpRewritePattern<Op> { 579 using OpRewritePattern<Op>::OpRewritePattern; 580 581 // Drop the "output_shapes" attribute. 582 LogicalResult matchAndRewrite(Op op, 583 PatternRewriter &rewriter) const override { 584 bool found = !!op.removeAttr("output_shapes"); 585 return success(found); 586 } 587}; 588 589//===----------------------------------------------------------------------===// 590// TF op helper functions for handling resource handles and ids. 591//===----------------------------------------------------------------------===// 592 593// Returns device of op if present. If op has no device set, an empty string ref 594// is returned instead. 595llvm::StringRef GetDeviceOrEmpty(Operation *op) { 596 if (auto device_attr = op->getAttrOfType<StringAttr>("device")) 597 return device_attr.getValue(); 598 return llvm::StringRef(); 599} 600 601// Returns resource handle value and id for resource op based on attributes. If 602// a resource handle is anonymous, a new id is always returned. 603ResourceHandleValueAndId GetResourceHandleValueAndIdBase( 604 llvm::StringRef container, llvm::StringRef shared_name, 605 llvm::StringRef device, Value resource, 606 llvm::SmallDenseMap<ResourceHandle, int64_t> &resource_handle_id_map, 607 int64_t &next_id) { 608 // Always create a new ID for anonymous handle. 609 if (IsResourceHandleAnonymous(shared_name)) return {resource, next_id++}; 610 611 ResourceHandle handle(container, shared_name, device, /*op=*/nullptr); 612 auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); 613 // New ID created, increment next_id. 614 if (emplace_res.second) ++next_id; 615 return {resource, emplace_res.first->second}; 616} 617