1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
17 #define TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
18
19 #include <string>
20
21 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
22 #include "mlir/IR/Diagnostics.h" // from @llvm-project
23 #include "mlir/IR/Dialect.h" // from @llvm-project
24 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
25
26 // Include the dialect class generated from dialect.td.
27 // The constructor and the printing/parsing of dialect types are manually
28 // implemented (see ops.cpp).
29 #include "tensorflow/core/ir/types/dialect.h.inc"
30
31 // Include the Type classes declaration generated from types.td
32 #define GET_TYPEDEF_CLASSES
33 #include "tensorflow/core/ir/types/types.h.inc"
34
35 namespace mlir {
36 namespace tf_type {
37
38 //===----------------------------------------------------------------------===//
39 // TensorFlow types
40 //===----------------------------------------------------------------------===//
41
42 // The base class in the TensorFlow type hierarchy.
43 class TensorFlowType : public Type {
44 public:
45 using Type::Type;
46
47 // Support method to enable LLVM-style type casting.
48 static bool classof(Type type);
49 };
50
51 // Returns true if the specified type is a valid TensorFlow element type.
IsValidTFElementType(Type type)52 inline bool IsValidTFElementType(Type type) {
53 return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
54 }
55
56 // Returns true if this is a valid TensorFlow tensor type.
IsValidTFTensorType(Type type)57 inline bool IsValidTFTensorType(Type type) {
58 // TensorFlow types should be tensors of one of the valid TensorFlow element
59 // types.
60 if (auto tensor_ty = type.dyn_cast<TensorType>())
61 return IsValidTFElementType(tensor_ty.getElementType());
62 return false;
63 }
64
65 namespace detail {
66 // Common implementation of TensorFlow types. The template argument indicates
67 // the concrete derived class per CRTP.
68 template <typename Derived>
69 class TensorFlowTypeImpl
70 : public Type::TypeBase<Derived, TensorFlowType, TypeStorage> {
71 public:
72 using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
73 using TFBase = TensorFlowTypeImpl<Derived>;
74 using Base::Base;
75 };
76 } // namespace detail
77
78 // TensorFlowRefType class supports all the ref types in TensorFlow dialect.
79 class TensorFlowRefType : public TensorFlowType {
80 public:
81 using TensorFlowType::TensorFlowType;
82
83 // Checks if a type is TensorFlow Ref type.
84 static bool classof(Type type);
85
86 // Converts a type to the corresponding TensorFlowRef type.
87 static TensorFlowType get(Type type);
getChecked(Type type,MLIRContext * context,Location loc)88 static TensorFlowType getChecked(Type type, MLIRContext* context,
89 Location loc) {
90 if (failed(verify(loc, type))) {
91 return TensorFlowRefType();
92 }
93 return get(type);
94 }
95
verify(Location loc,Type type)96 static LogicalResult verify(Location loc, Type type) {
97 // type should be a valid TensorFlow type.
98 if (!IsValidTFTensorType(type)) {
99 return emitError(loc) << "invalid TensorFlow type: " << type;
100 }
101 return success();
102 }
103
104 // Converts a TensorFlowRef type to the corresponding TensorFlow or standard
105 // type.
106 Type RemoveRef();
107 };
108
// X-macro expansion: declare one `<tftype>Type` class for each plain
// TensorFlow dtype listed in types.def. Each class inherits its constructors
// (and the TypeBase `get` machinery) from detail::TensorFlowTypeImpl.
// Entries routed through HANDLE_CUSTOM_TF_TYPE expand to nothing here because
// those types are declared by hand. NOTE(review): presumably these are the
// ResourceType/VariantType classes declared further down; confirm against
// types.def. Also presumably types.def #undefs these macros after use; verify.
#define HANDLE_TF_TYPE(tftype, enumerant, name)                          \
  class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> { \
   public:                                                               \
    using TFBase::TFBase;                                                \
  };
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
#include "tensorflow/core/ir/types/types.def"
118
119 namespace detail {
120 // Storage type contains inferred subtypes for TypeWithSubtype.
121 class TypeWithSubtypeStorage : public TypeStorage {
122 public:
123 using KeyTy = ArrayRef<TensorType>;
124
125 // NOLINTNEXTLINE
construct(TypeStorageAllocator & allocator,const KeyTy & key)126 static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
127 const KeyTy& key) {
128 ArrayRef<TensorType> subtypes = allocator.copyInto(key);
129 return new (allocator.allocate<TypeWithSubtypeStorage>())
130 TypeWithSubtypeStorage(subtypes);
131 }
132
TypeWithSubtypeStorage(const KeyTy & key)133 explicit TypeWithSubtypeStorage(const KeyTy& key) : subtypes_(key) {}
134
135 bool operator==(const KeyTy& key) const { return key == subtypes_; }
136
hashKey(const KeyTy & key)137 static llvm::hash_code hashKey(const KeyTy& key) {
138 return llvm::hash_combine_range(key.begin(), key.end());
139 }
140
141 KeyTy subtypes_;
142 };
143
144 // Common implementation of TensorFlow types with subtypes. These subtypes are
145 // opaque and their interpretation depends on the actual underlying type.
146 // The template argument indicates the concrete derived class per CRTP. Concrete
147 // classes must implement the following:
148 // - `static std::string getTypeName()` that returns the name of the type for
149 // verification logging.
150 template <typename Derived>
151 class TypeWithSubtypeImpl
152 : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage> {
153 public:
154 using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
155 using TFBase = TypeWithSubtypeImpl<Derived>;
156 using Base::Base;
157
get(ArrayRef<TensorType> subtypes,MLIRContext * context)158 static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context) {
159 return Base::get(context, subtypes);
160 }
161
getChecked(ArrayRef<TensorType> subtypes,MLIRContext * context,Location loc)162 static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
163 Location loc) {
164 return Base::getChecked(loc, subtypes);
165 }
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<TensorType> subtypes)166 static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
167 MLIRContext* context,
168 ArrayRef<TensorType> subtypes) {
169 return Base::getChecked(emitError, context, subtypes);
170 }
171
get(MLIRContext * context)172 static Derived get(MLIRContext* context) { return get({}, context); }
173
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<TensorType> subtypes)174 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
175 ArrayRef<TensorType> subtypes) {
176 // Each of the subtypes should be a valid TensorFlow type.
177 for (TensorType subtype : subtypes) {
178 if (!IsValidTFTensorType(subtype)) {
179 return emitError() << "invalid " << Derived::getTypeName()
180 << " subtype: " << subtype;
181 }
182 }
183 return success();
184 }
185
getSubtypes()186 ArrayRef<TensorType> getSubtypes() { return Base::getImpl()->subtypes_; }
187 };
188 } // namespace detail
189
190 // TensorFlowTypeWithSubtype class supports all the types with subtypes in
191 // TensorFlow dialect.
192 class TensorFlowTypeWithSubtype : public TensorFlowType {
193 public:
194 using TensorFlowType::TensorFlowType;
195
196 // Checks if a type is TensorFlow type with subtypes.
197 static bool classof(Type type);
198
199 // Converts a TypeWithSubtype type to the same type but without its subtypes.
200 Type RemoveSubtypes();
201
202 // Clone the current Type with new subtypes.
203 TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes);
204
205 // Returns the subtypes.
206 ArrayRef<TensorType> GetSubtypes();
207 };
208
209 // Returns the corresponding TensorFlow type with subtypes but without its
210 // subtypes.
GetDefaultTypeOf(TensorFlowTypeWithSubtype type)211 inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
212 return type.RemoveSubtypes();
213 }
214
215 // TensorFlow resource type is used to support TensorFlow resource variables,
216 // which represent shared, persistent state manipulated by a TensorFlow program.
217 // ResourceType stores shape and datatype for subtypes unlike most other data
218 // types that don't have any associated information.
219 class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
220 public:
221 using TFBase::TFBase;
getTypeName()222 static std::string getTypeName() { return "ResourceType"; }
223 };
224
225 // TensorFlow variant type is used to support arbitrary custom C++ data types.
226 // VariantType stores inferred shape and datatype for subtypes unlike most other
227 // data types that don't have any associated information. For example, variants
228 // encoding TensorList type stores the common shape and dtype of the list
229 // elements as the only subtype.
230 class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
231 public:
232 using TFBase::TFBase;
getTypeName()233 static std::string getTypeName() { return "VariantType"; }
234 };
235
// Given two types `a` and `b`, returns a refined type that is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them.
// Returns a null Type if the input types are not cast compatible.
// If `may_ignore_ref_type_a` is true, a ref type on `a` may be ignored. This is
// useful for TF ops that allow an operand to be either the same as the result
// type or the ref type corresponding to it.
Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a = false);

// Returns whether two arrays of Type are broadcast compatible.
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs);

// Returns whether `lhs` and `rhs` have compatible element types. Two types are
// compatible if:
//  - they are statically equal, or
//  - they could be dynamically equal: dynamic shapes are considered equal
//    unless contradictory information is known, and element types must be
//    equivalent, except that subtypes may be less exact (e.g., a resource type
//    without a subtype is considered compatible with a resource type with a
//    known subtype).
// If `may_ignore_ref_type_lhs` is true, ref types on `lhs` may be ignored.
bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs = false);

// Returns true if all the given TensorFlow types can be cast to one another;
// in other words, a single run-time value is legal for all of the types.
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
// compatible.
bool AreCastCompatible(TypeRange types);

// Returns true if `lhs` and `rhs` have the same length and each pair of
// corresponding elements is AreCastCompatible.
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs);

// If `ty` is a tensor type whose element type has subtypes, returns a new type
// of the same shape with the element type's subtypes dropped.
// Otherwise, if `ty` itself has subtypes, returns the corresponding type with
// the subtypes dropped.
// Otherwise, returns `ty` unchanged.
Type DropSubTypes(Type ty);

// If `ty` is a tensor type with elements of a ref type, returns a new type of
// the same shape with the corresponding non-ref type as its element type.
// Otherwise, if `ty` is a ref type, returns the corresponding non-ref type.
// Otherwise, returns `ty` unchanged.
Type DropRefType(Type ty);

// Convenience call that applies both `DropRefType` and `DropSubTypes`.
Type DropRefAndSubTypes(Type ty);
284
285 //===----------------------------------------------------------------------===//
286 // Utility iterators
287 //===----------------------------------------------------------------------===//
288
// An iterator over the tensor shapes of an op's operands of shaped types.
// Each element is llvm::None if the operand is unranked, and the operand's
// shape as an ArrayRef<int64_t> otherwise.
class OperandShapeIterator final
    : public llvm::mapped_iterator<Operation::operand_iterator,
                                   llvm::Optional<ArrayRef<int64_t>> (*)(
                                       Value)> {
 public:
  using reference = llvm::Optional<ArrayRef<int64_t>>;

  /// Initializes the operand shape iterator to the specified operand iterator.
  explicit OperandShapeIterator(Operation::operand_iterator it);
};

// Range over the operand shapes of an op (see OperandShapeIterator).
using OperandShapeRange = iterator_range<OperandShapeIterator>;
304
// An iterator over the tensor shapes of an op's results of shaped types.
// Each element is llvm::None if the result is unranked, and the result's
// shape as an ArrayRef<int64_t> otherwise.
class ResultShapeIterator final
    : public llvm::mapped_iterator<Operation::result_iterator,
                                   llvm::Optional<ArrayRef<int64_t>> (*)(
                                       Value)> {
 public:
  using reference = llvm::Optional<ArrayRef<int64_t>>;

  /// Initializes the result shape iterator to the specified result iterator.
  explicit ResultShapeIterator(Operation::result_iterator it);
};

// Range over the result shapes of an op (see ResultShapeIterator).
using ResultShapeRange = iterator_range<ResultShapeIterator>;
320
321 // Returns a range with just resource type values from the input range
322 // preserved.
323 template <typename RangeT>
filter_resources(RangeT && range)324 auto filter_resources(RangeT&& range) {
325 return llvm::make_filter_range(std::forward<RangeT>(range), [](Value val) {
326 return getElementTypeOrSelf(val.getType()).isa<ResourceType>();
327 });
328 }
329
330 // Returns the element type if `type` is a `ShapedType` and the type itself
331 // otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
332 // standard type if necessary.
GetElementTypeOrSelfResolveRef(Type type)333 inline Type GetElementTypeOrSelfResolveRef(Type type) {
334 Type element_type = getElementTypeOrSelf(type);
335 if (auto ref_type = element_type.dyn_cast<TensorFlowRefType>()) {
336 element_type = ref_type.RemoveRef();
337 }
338 return element_type;
339 }
340
341 } // namespace tf_type
342 } // namespace mlir
343
344 //===----------------------------------------------------------------------===//
345 // Tablegen Attribute Declarations
346 //===----------------------------------------------------------------------===//
347
348 #define GET_ATTRDEF_CLASSES
349 #include "tensorflow/core/ir/types/attributes.h.inc"
350 #include "tensorflow/core/ir/types/attributes_enum.h.inc"
351
352 #endif // TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
353