1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h"
17
18 #include "llvm/Support/FormatVariadic.h"
19 #include "mlir/IR/Attributes.h" // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
21 #include "mlir/IR/Matchers.h" // from @llvm-project
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
23
24 namespace mlir {
25 namespace TF {
26
27 // Verifies an reduction op's `input` and reduction `dims`.
VerifyReductionInputAndDims(Value input,Value dims,Location loc)28 LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
29 Location loc) {
30 auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
31 if (!dims_type) return success();
32 if (dims_type.getRank() > 1)
33 return emitError(loc, "dimensions can only be 0D or 1D tensor");
34
35 auto input_type = input.getType().dyn_cast<RankedTensorType>();
36 if (!input_type) return success();
37 int64_t rank = input_type.getRank();
38
39 DenseIntElementsAttr dims_attr;
40 if (!matchPattern(dims, m_Constant(&dims_attr))) return success();
41 for (const auto &dim_pair : llvm::enumerate(dims_attr)) {
42 int64_t cur_dim = dim_pair.value().getSExtValue();
43 if (cur_dim < -rank || cur_dim >= rank)
44 return emitError(loc)
45 << dim_pair.index() << "-th dimension should be in the range of [-"
46 << rank << ", " << rank << ")";
47 }
48
49 return success();
50 }
51
VerifyTypeRangesAreCompatible(Operation * op,TypeRangeWithDesc range0,TypeRangeWithDesc range1)52 LogicalResult VerifyTypeRangesAreCompatible(Operation *op,
53 TypeRangeWithDesc range0,
54 TypeRangeWithDesc range1) {
55 if (range0.first.size() != range1.first.size()) {
56 return op->emitOpError()
57 << range0.second << "s (size = " << range0.first.size() << ")"
58 << " should have the same number of values as " << range1.second
59 << "s (size = " << range1.first.size() << ")";
60 }
61
62 for (auto it : llvm::enumerate(llvm::zip(range0.first, range1.first))) {
63 int index = it.index();
64 Type type0 = std::get<0>(it.value());
65 Type type1 = std::get<1>(it.value());
66 if (!AreCastCompatible({type0, type1}))
67 return op->emitOpError(llvm::formatv(
68 "{0} type {1} is incompatible with {2} type {3} at index {4}",
69 range0.second, type0, range1.second, type1, index));
70 }
71 return success();
72 }
73
74 } // namespace TF
75 } // namespace mlir
76