• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the op traits used in the MLIR TensorFlow dialect.
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
19 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
20 
21 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
22 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
24 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
25 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
27 
28 namespace mlir {
29 namespace OpTrait {
30 namespace TF {
31 
32 // Verifies if 'ref_type' is a REF type corresponding to 'type'.
VerifyRefTypeMatch(mlir::Type type,mlir::Type maybe_ref_type)33 static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
34                                                mlir::Type maybe_ref_type) {
35   if (auto ref_type = maybe_ref_type.dyn_cast<mlir::TF::TensorFlowRefType>())
36     return success(ref_type.RemoveRef().getTypeID() == type.getTypeID());
37   return failure();
38 }
39 
40 // This class provides verification for ops that are known to have the same
41 // result types and all operands are either of the same type as result or a REF
42 // type corresponding to the result type.
43 // TODO(jpienaar): Update the name and the description.
44 template <typename ConcreteType>
45 class OperandsSameAsResultsTypeOrRef
46     : public TraitBase<ConcreteType, OperandsSameAsResultsTypeOrRef> {
47  public:
verifyTrait(Operation * op)48   static LogicalResult verifyTrait(Operation* op) {
49     LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
50     if (failed(shapeMatch)) return shapeMatch;
51     Type type = op->getResult(0).getType();
52     // Verify that the first result type is same as the rest of the results.
53     // We skip the comparison against itself.
54     for (auto result_type : llvm::drop_begin(op->getResultTypes(), 1)) {
55       if (!mlir::TF::HasCompatibleElementTypes(type, result_type))
56         return op->emitOpError()
57                << "requires all return types to have compatible element types";
58     }
59     for (auto operand_type : op->getOperandTypes()) {
60       if (!mlir::TF::HasCompatibleElementTypes(
61               operand_type, type, /*may_ignore_ref_type_lhs=*/true))
62         return op->emitError() << "requires all operands and results to have "
63                                   "compatible element types";
64     }
65     return success();
66   }
67 };
68 
69 namespace detail {
verifySameOperandsAndResultElementTypeResolveRef(Operation * op)70 inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
71     Operation* op) {
72   Type element_type;
73   if (op->getNumResults() > 0) {
74     element_type =
75         mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType());
76   } else if (op->getNumOperands() > 0) {
77     element_type =
78         mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType());
79   } else {
80     // Nothing to check.
81     return success();
82   }
83   // Verify that all result element types are compatible to `element_type`.
84   for (const auto& result_type : op->getResultTypes()) {
85     if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != element_type) {
86       return op->emitOpError(
87           "requires compatible element types for all operands and results");
88     }
89   }
90   // Verify that all operand element types are compatible to `element_type`.
91   for (const auto& operand_type : op->getOperandTypes()) {
92     if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) !=
93         element_type) {
94       return op->emitOpError(
95           "requires compatible element types for all operands and results");
96     }
97   }
98   return success();
99 }
100 }  // namespace detail
101 
102 // Verifies that op has the same operand and result element types (or type
103 // itself, if scalar) after resolving reference types (i.e., after converting
104 // reference types to their corresponding TensorFlow or standard types).
105 template <typename ConcreteType>
106 class SameOperandsAndResultElementTypeResolveRef
107     : public TraitBase<ConcreteType,
108                        SameOperandsAndResultElementTypeResolveRef> {
109  public:
verifyTrait(Operation * op)110   static LogicalResult verifyTrait(Operation* op) {
111     return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
112   }
113 };
114 
115 // Verifies that op has the same operand and result types after resolving
116 // reference types (i.e., after converting reference types to their
117 // corresponding TensorFlow or standard types).
118 template <typename ConcreteType>
119 class SameOperandsAndResultTypeResolveRef
120     : public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef> {
121  public:
verifyTrait(Operation * op)122   static LogicalResult verifyTrait(Operation* op) {
123     if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
124     return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
125   }
126 };
127 
128 // Layout agnostic operations do not depend on the operands data layout (data
129 // format), as and example all element wise operations are layout agnostic.
130 template <typename ConcreteType>
131 class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic> {};
132 
133 // Trait to indicate operations that cannot be duplicated as they might carry
134 // certain state around within their implementations.
135 template <typename ConcreteType>
136 class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate> {
137  public:
verifyTrait(Operation * op)138   static LogicalResult verifyTrait(Operation* op) {
139     if (MemoryEffectOpInterface::hasNoEffect(op))
140       return op->emitError(
141           "operations with no side effects cannot have CannotDuplicate trait");
142     return success();
143   }
144 };
145 
146 // Trait to indicate an operation cannot be constant folded.
147 template <typename ConcreteType>
148 class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold> {};
149 
150 // Coefficient-wise binary operation with implicit broadcasting support, for
151 // example tf.Sub operation.
152 template <typename ConcreteType>
153 class CwiseBinary : public TraitBase<ConcreteType, CwiseBinary> {};
154 
155 // Coefficient-wise unary operation, for example tf.Sqrt operation.
156 template <typename ConcreteType>
157 class CwiseUnary : public TraitBase<ConcreteType, CwiseUnary> {};
158 
159 }  // namespace TF
160 }  // namespace OpTrait
161 }  // namespace mlir
162 
163 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
164