• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.h"
17 
18 #include <algorithm>
19 
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
22 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/Support/LLVM.h"  // from @llvm-project
25 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
26 // Include the auto-generated dialect defs.
27 #include "tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.cc.inc"
28 
29 namespace mlir {
30 namespace tfrt {
31 
initialize()32 void PythonTestAttrsDialect::initialize() {}
33 
verifyRegionArgAttribute(::mlir::Operation * op,unsigned regionIndex,unsigned argIndex,::mlir::NamedAttribute attribute)34 ::mlir::LogicalResult PythonTestAttrsDialect::verifyRegionArgAttribute(
35     ::mlir::Operation* op, unsigned regionIndex, unsigned argIndex,
36     ::mlir::NamedAttribute attribute) {
37   const auto& arg = op->getRegion(regionIndex).getArguments()[argIndex];
38 
39   // Only verify at the tensor level. We are interested in the correct attribute
40   // values when processing the Tensorflow dialect IR.
41   auto arg_type = arg.getType().dyn_cast<RankedTensorType>();
42   if (!arg_type) return success();
43 
44   if (attribute.getName() == GetStaticTypeAttrName()) {
45     auto type_attr = attribute.getValue().dyn_cast<TypeAttr>();
46     if (!type_attr) {
47       return op->emitError()
48              << GetStaticTypeAttrName()
49              << " argument attribute of other type than TypeAttr";
50     }
51 
52     auto attr_type = type_attr.getValue().dyn_cast<RankedTensorType>();
53     if (!attr_type) {
54       return op->emitError()
55              << GetStaticTypeAttrName()
56              << " argument type attribute is not a ranked tensor type";
57     }
58     if (attr_type.getNumDynamicDims() > 0) {
59       return op->emitError() << GetStaticTypeAttrName()
60                              << " argument type attribute is a ranked tensor "
61                                 "type with dynamic dimensions";
62     }
63     if (attr_type.getRank() != arg_type.getRank()) {
64       return op->emitError()
65              << GetStaticTypeAttrName()
66              << " argument type attribute is a ranked tensor type with a "
67                 "different rank than the rank of the argument tensor";
68     }
69     auto compatible = [&](Type a, Type b) {
70       if (a == b) {
71         return true;
72       }
73       if (!a.isa<IntegerType>() || !b.isa<IntegerType>()) {
74         return false;
75       }
76       auto width_a = a.dyn_cast<IntegerType>().getWidth();
77       auto width_b = b.dyn_cast<IntegerType>().getWidth();
78       return width_a == width_b || std::max(width_a, width_b) == 8;
79     };
80     if (!compatible(attr_type.getElementType(), arg_type.getElementType())) {
81       return op->emitError()
82              << GetStaticTypeAttrName()
83              << " argument type attribute is a ranked tensor type with a "
84                 "different element type than the element type of the argument "
85                 "tensor";
86     }
87     const auto& attr_shape = attr_type.getShape();
88     const auto& arg_shape = arg_type.getShape();
89     for (int64_t i = 0; i < attr_shape.size(); ++i) {
90       if (!arg_type.isDynamicDim(i) && arg_shape[i] != attr_shape[i]) {
91         return op->emitError()
92                << GetStaticTypeAttrName()
93                << " argument type attribute is a ranked tensor type with a "
94                   "shape that doesn't match the static dimensions of the "
95                   "argument tensor";
96       }
97     }
98   } else if (attribute.getName() == GetShapeValueAttrName()) {
99     auto dense_attr = attribute.getValue().dyn_cast<DenseIntElementsAttr>();
100     if (!dense_attr) {
101       return op->emitError()
102              << GetShapeValueAttrName()
103              << " argument attribute is not a dense int elements attribute";
104     }
105 
106     if (dense_attr.getType() != arg_type) {
107       return op->emitError() << GetShapeValueAttrName()
108                              << " argument elements attribute has a different "
109                                 "type than the argument type";
110     }
111 
112     // We expect a valid shape value, therefore check that the dimension values
113     // are not negative.
114     for (auto&& dim : dense_attr) {
115       if (dim.isNegative()) {
116         return op->emitError()
117                << GetShapeValueAttrName()
118                << " argument elements attribute has a negative dimension value";
119       }
120     }
121   }
122   return success();
123 }
124 
125 }  // namespace tfrt
126 }  // namespace mlir
127