1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file defines the types used in the standard MLIR TensorFlow dialect.
17
18 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
19 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
20
21 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
22 #include "mlir/IR/Diagnostics.h" // from @llvm-project
23 #include "mlir/IR/Location.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
26 #include "mlir/IR/Types.h" // from @llvm-project
27
28 namespace mlir {
29 namespace TF {
30 //===----------------------------------------------------------------------===//
31 // Utility iterators
32 //===----------------------------------------------------------------------===//
33
34 // An iterator for the tensor shapes of an op's operands of shaped types.
35 // Returns llvm::None if a operand is unranked; returns ArrayRef<int64_t> as the
36 // shape otherwise.
37 class OperandShapeIterator final
38 : public llvm::mapped_iterator<Operation::operand_iterator,
39 llvm::Optional<ArrayRef<int64_t>> (*)(
40 Value)> {
41 public:
42 using reference = llvm::Optional<ArrayRef<int64_t>>;
43
44 /// Initializes the operand shape iterator to the specified operand iterator.
45 explicit OperandShapeIterator(Operation::operand_iterator it);
46 };
47
48 using OperandShapeRange = iterator_range<OperandShapeIterator>;
49
50 // An iterator for the tensor shapes of an op's results of shaped types.
51 // Returns llvm::None if a result is unranked; returns ArrayRef<int64_t> as the
52 // shape otherwise.
53 class ResultShapeIterator final
54 : public llvm::mapped_iterator<Operation::result_iterator,
55 llvm::Optional<ArrayRef<int64_t>> (*)(
56 Value)> {
57 public:
58 using reference = llvm::Optional<ArrayRef<int64_t>>;
59
60 /// Initializes the result shape iterator to the specified result iterator.
61 explicit ResultShapeIterator(Operation::result_iterator it);
62 };
63
64 using ResultShapeRange = iterator_range<ResultShapeIterator>;
65
66 //===----------------------------------------------------------------------===//
67 // TensorFlow types
68 //===----------------------------------------------------------------------===//
69
70 // The base class in the TensorFlow type hierarchy.
71 class TensorFlowType : public Type {
72 public:
73 using Type::Type;
74
75 // Support method to enable LLVM-style type casting.
76 static bool classof(Type type);
77 };
78
79 // Returns true if the specified type is a valid TensorFlow element type.
IsValidTFElementType(Type type)80 static inline bool IsValidTFElementType(Type type) {
81 return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
82 }
83
84 // Returns true if this is a valid TensorFlow tensor type.
IsValidTFTensorType(Type type)85 static inline bool IsValidTFTensorType(Type type) {
86 // TensorFlow types should be tensors of one of the valid TensorFlow element
87 // types.
88 if (auto tensor_ty = type.dyn_cast<TensorType>())
89 return IsValidTFElementType(tensor_ty.getElementType());
90 return false;
91 }
92
93 namespace detail {
94 // Common implementation of TensorFlow types. The template argument indicates
95 // the concrete derived class per CRTP.
96 template <typename Derived>
97 class TensorFlowTypeImpl
98 : public Type::TypeBase<Derived, TensorFlowType, TypeStorage> {
99 public:
100 using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
101 using TFBase = TensorFlowTypeImpl<Derived>;
102 using Base::Base;
103 };
104 } // namespace detail
105
106 // TensorFlowRefType class supports all the ref types in TensorFlow dialect.
107 class TensorFlowRefType : public TensorFlowType {
108 public:
109 using TensorFlowType::TensorFlowType;
110
111 // Checks if a type is TensorFlow Ref type.
112 static bool classof(Type type);
113
114 // Converts a type to the corresponding TensorFlowRef type.
115 static TensorFlowType get(Type type);
getChecked(Type type,MLIRContext * context,Location loc)116 static TensorFlowType getChecked(Type type, MLIRContext* context,
117 Location loc) {
118 if (failed(verifyConstructionInvariants(loc, type))) {
119 return TensorFlowRefType();
120 }
121 return get(type);
122 }
123
verifyConstructionInvariants(Location loc,Type type)124 static LogicalResult verifyConstructionInvariants(Location loc, Type type) {
125 // type should be a valid TensorFlow type.
126 if (!IsValidTFTensorType(type)) {
127 return emitError(loc) << "invalid TensorFlow type: " << type;
128 }
129 return success();
130 }
131
132 // Converts a TensorFlowRef type to the corresponding TensorFlow or standard
133 // type.
134 Type RemoveRef();
135 };
136
137 // Returns the corresponding TensorFlow or standard type from TensorFlowRef
138 // type.
GetDefaultTypeOf(TensorFlowRefType type)139 static inline Type GetDefaultTypeOf(TensorFlowRefType type) {
140 return type.RemoveRef();
141 }
142
143 // Returns the element type if `type` is a `ShapedType` and the type itself
144 // otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
145 // standard type if necessary.
GetElementTypeOrSelfResolveRef(Type type)146 static inline Type GetElementTypeOrSelfResolveRef(Type type) {
147 Type element_type = mlir::getElementTypeOrSelf(type);
148 if (auto ref_type = element_type.dyn_cast<mlir::TF::TensorFlowRefType>()) {
149 element_type = ref_type.RemoveRef();
150 }
151 return element_type;
152 }
153
// X-macro hook for tf_types.def: each HANDLE_TF_TYPE entry expands to a class
// named `<tftype>Type` that derives from detail::TensorFlowTypeImpl and
// inherits its constructors. The `enumerant` and `name` arguments are not used
// by this expansion.
#define HANDLE_TF_TYPE(tftype, enumerant, name)                          \
  class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> { \
   public:                                                               \
    using TFBase::TFBase;                                                \
  };

// Custom TensorFlow types are defined separately, so their entries in the .def
// file expand to nothing here.
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)

// Expands the HANDLE_* macros above for the entries listed in tf_types.def.
// NOLINTNEXTLINE
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
165
166 namespace detail {
167 // Storage type contains inferred subtypes for TypeWithSubtype.
168 class TypeWithSubtypeStorage : public TypeStorage {
169 public:
170 using KeyTy = ArrayRef<TensorType>;
171
172 // NOLINTNEXTLINE
construct(TypeStorageAllocator & allocator,const KeyTy & key)173 static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
174 const KeyTy& key) {
175 ArrayRef<TensorType> subtypes = allocator.copyInto(key);
176 return new (allocator.allocate<TypeWithSubtypeStorage>())
177 TypeWithSubtypeStorage(subtypes);
178 }
179
TypeWithSubtypeStorage(const KeyTy & key)180 explicit TypeWithSubtypeStorage(const KeyTy& key) : subtypes_(key) {}
181
182 bool operator==(const KeyTy& key) const { return key == subtypes_; }
183
hashKey(const KeyTy & key)184 static llvm::hash_code hashKey(const KeyTy& key) {
185 return llvm::hash_combine_range(key.begin(), key.end());
186 }
187
188 KeyTy subtypes_;
189 };
190
191 // Common implementation of TensorFlow types with subtypes. These subtypes are
192 // opaque and their interpretation depends on the actual underlying type.
193 // The template argument indicates the concrete derived class per CRTP. Concrete
194 // classes must implement the following:
195 // - `static std::string getTypeName()` that returns the name of the type for
196 // verification logging.
197 template <typename Derived>
198 class TypeWithSubtypeImpl
199 : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage> {
200 public:
201 using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
202 using TFBase = TypeWithSubtypeImpl<Derived>;
203 using Base::Base;
204
get(ArrayRef<TensorType> subtypes,MLIRContext * context)205 static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context) {
206 return Base::get(context, subtypes);
207 }
208
getChecked(ArrayRef<TensorType> subtypes,MLIRContext * context,Location loc)209 static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
210 Location loc) {
211 return Base::getChecked(loc, subtypes);
212 }
213
get(MLIRContext * context)214 static Derived get(MLIRContext* context) { return get({}, context); }
215
verifyConstructionInvariants(Location loc,ArrayRef<TensorType> subtypes)216 static LogicalResult verifyConstructionInvariants(
217 Location loc, ArrayRef<TensorType> subtypes) {
218 // Each of the subtypes should be a valid TensorFlow type.
219 for (TensorType subtype : subtypes) {
220 if (!IsValidTFTensorType(subtype)) {
221 return emitError(loc) << "invalid " << Derived::getTypeName()
222 << " subtype: " << subtype;
223 }
224 }
225 return success();
226 }
227
getSubtypes()228 ArrayRef<TensorType> getSubtypes() { return Base::getImpl()->subtypes_; }
229 };
230 } // namespace detail
231
232 // TensorFlowTypeWithSubtype class supports all the types with subtypes in
233 // TensorFlow dialect.
234 class TensorFlowTypeWithSubtype : public TensorFlowType {
235 public:
236 using TensorFlowType::TensorFlowType;
237
238 // Checks if a type is TensorFlow type with subtypes.
239 static bool classof(Type type);
240
241 // Converts a TypeWithSubtype type to the same type but without its subtypes.
242 Type RemoveSubtypes();
243
244 // Returns the subtypes.
245 ArrayRef<TensorType> GetSubtypes();
246 };
247
248 // Returns the corresponding TensorFlow type with subtypes but without its
249 // subtypes.
GetDefaultTypeOf(TensorFlowTypeWithSubtype type)250 static inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
251 return type.RemoveSubtypes();
252 }
253
254 // TensorFlow resource type is used to support TensorFlow resource variables,
255 // which represent shared, persistent state manipulated by a TensorFlow program.
256 // ResourceType stores shape and datatype for subtypes unlike most other data
257 // types that don't have any associated information.
258 class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
259 public:
260 using TFBase::TFBase;
getTypeName()261 static std::string getTypeName() { return "ResourceType"; }
262 };
263
264 // TensorFlow variant type is used to support arbitrary custom C++ data types.
265 // VariantType stores inferred shape and datatype for subtypes unlike most other
266 // data types that don't have any associated information. For example, variants
267 // encoding TensorList type stores the common shape and dtype of the list
268 // elements as the only subtype.
269 class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
270 public:
271 using TFBase::TFBase;
getTypeName()272 static std::string getTypeName() { return "VariantType"; }
273 };
274
// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them.
// Returns an empty (null) Type if the input types are not cast compatible.
// When `may_ignore_ref_type_a` is true, a ref type on `a` may be ignored. This
// is useful for TF ops that allow an operand either to have the same type as
// the result or to be the ref type corresponding to it.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
                                 bool may_ignore_ref_type_a);

// Returns whether two arrays of Type are broadcast compatible.
bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);

// Returns whether the two element types are compatible. The types are
// considered compatible if:
//   - the types are statically equal;
//   - the types could be dynamically equal, considering dynamic shapes equal
//     unless contradictory information is known;
//   - the element types are equivalent, modulo subtypes possibly being less
//     exact (e.g., a resource type without a subtype is considered compatible
//     with a resource type with a known subtype).
// When `may_ignore_ref_type_lhs` is true, a ref type on `lhs` may be ignored.
bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs = false);

// Returns true if all of the given TensorFlow types can be cast to one
// another. In other words, a single run-time value is legal for all of the
// types. For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
// compatible.
bool AreCastCompatible(ArrayRef<Type> types);

// Returns true if `lhs` and `rhs` have the same length and each pair of
// corresponding elements satisfies AreCastCompatible.
bool ArraysAreCastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);

// If `ty` is a tensor type and its element type has subtypes, then returns a
// new type of same shape but dropped subtypes for the element type.
// Otherwise, if `ty` has subtypes, then returns corresponding type with dropped
// subtypes.
// Otherwise, returns the original type `ty`.
Type DropSubTypes(Type ty);

// If `ty` is a tensor type and has elements of a ref type, then returns a new
// type of same shape but corresponding non-ref type as element type.
// Otherwise, if `ty` is a ref type, then returns corresponding non-ref type.
// Otherwise, returns the original type `ty`.
Type DropRefType(Type ty);

// Convenience call for executing both `DropRefType` and `DropSubTypes`.
Type DropRefAndSubTypes(Type ty);
324
325 } // end namespace TF
326 } // end namespace mlir
327
328 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
329