1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.h"
17
18 #include <algorithm>
19
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
22 #include "mlir/IR/MLIRContext.h" // from @llvm-project
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/Support/LLVM.h" // from @llvm-project
25 #include "mlir/Support/LogicalResult.h" // from @llvm-project
26 // Include the auto-generated dialect defs.
27 #include "tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.cc.inc"
28
29 namespace mlir {
30 namespace tfrt {
31
initialize()32 void PythonTestAttrsDialect::initialize() {}
33
verifyRegionArgAttribute(::mlir::Operation * op,unsigned regionIndex,unsigned argIndex,::mlir::NamedAttribute attribute)34 ::mlir::LogicalResult PythonTestAttrsDialect::verifyRegionArgAttribute(
35 ::mlir::Operation* op, unsigned regionIndex, unsigned argIndex,
36 ::mlir::NamedAttribute attribute) {
37 const auto& arg = op->getRegion(regionIndex).getArguments()[argIndex];
38
39 // Only verify at the tensor level. We are interested in the correct attribute
40 // values when processing the Tensorflow dialect IR.
41 auto arg_type = arg.getType().dyn_cast<RankedTensorType>();
42 if (!arg_type) return success();
43
44 if (attribute.getName() == GetStaticTypeAttrName()) {
45 auto type_attr = attribute.getValue().dyn_cast<TypeAttr>();
46 if (!type_attr) {
47 return op->emitError()
48 << GetStaticTypeAttrName()
49 << " argument attribute of other type than TypeAttr";
50 }
51
52 auto attr_type = type_attr.getValue().dyn_cast<RankedTensorType>();
53 if (!attr_type) {
54 return op->emitError()
55 << GetStaticTypeAttrName()
56 << " argument type attribute is not a ranked tensor type";
57 }
58 if (attr_type.getNumDynamicDims() > 0) {
59 return op->emitError() << GetStaticTypeAttrName()
60 << " argument type attribute is a ranked tensor "
61 "type with dynamic dimensions";
62 }
63 if (attr_type.getRank() != arg_type.getRank()) {
64 return op->emitError()
65 << GetStaticTypeAttrName()
66 << " argument type attribute is a ranked tensor type with a "
67 "different rank than the rank of the argument tensor";
68 }
69 auto compatible = [&](Type a, Type b) {
70 if (a == b) {
71 return true;
72 }
73 if (!a.isa<IntegerType>() || !b.isa<IntegerType>()) {
74 return false;
75 }
76 auto width_a = a.dyn_cast<IntegerType>().getWidth();
77 auto width_b = b.dyn_cast<IntegerType>().getWidth();
78 return width_a == width_b || std::max(width_a, width_b) == 8;
79 };
80 if (!compatible(attr_type.getElementType(), arg_type.getElementType())) {
81 return op->emitError()
82 << GetStaticTypeAttrName()
83 << " argument type attribute is a ranked tensor type with a "
84 "different element type than the element type of the argument "
85 "tensor";
86 }
87 const auto& attr_shape = attr_type.getShape();
88 const auto& arg_shape = arg_type.getShape();
89 for (int64_t i = 0; i < attr_shape.size(); ++i) {
90 if (!arg_type.isDynamicDim(i) && arg_shape[i] != attr_shape[i]) {
91 return op->emitError()
92 << GetStaticTypeAttrName()
93 << " argument type attribute is a ranked tensor type with a "
94 "shape that doesn't match the static dimensions of the "
95 "argument tensor";
96 }
97 }
98 } else if (attribute.getName() == GetShapeValueAttrName()) {
99 auto dense_attr = attribute.getValue().dyn_cast<DenseIntElementsAttr>();
100 if (!dense_attr) {
101 return op->emitError()
102 << GetShapeValueAttrName()
103 << " argument attribute is not a dense int elements attribute";
104 }
105
106 if (dense_attr.getType() != arg_type) {
107 return op->emitError() << GetShapeValueAttrName()
108 << " argument elements attribute has a different "
109 "type than the argument type";
110 }
111
112 // We expect a valid shape value, therefore check that the dimension values
113 // are not negative.
114 for (auto&& dim : dense_attr) {
115 if (dim.isNegative()) {
116 return op->emitError()
117 << GetShapeValueAttrName()
118 << " argument elements attribute has a negative dimension value";
119 }
120 }
121 }
122 return success();
123 }
124
125 } // namespace tfrt
126 } // namespace mlir
127