1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
17 #define TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
18
19 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
20 #include "mlir/IR/Diagnostics.h" // from @llvm-project
21 #include "mlir/IR/Dialect.h" // from @llvm-project
22 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
23
24 // Include the dialect class generated from dialect.td.
25 // The constructor and the printing/parsing of dialect types are manually
26 // implemented (see ops.cpp).
27 #include "tensorflow/core/ir/types/dialect.h.inc"
28
29 // Include the Type classes declaration generated from types.td
30 #define GET_TYPEDEF_CLASSES
31 #include "tensorflow/core/ir/types/types.h.inc"
32
33 namespace mlir {
34 namespace tf_type {
35
36 //===----------------------------------------------------------------------===//
37 // TensorFlow types
38 //===----------------------------------------------------------------------===//
39
40 // The base class in the TensorFlow type hierarchy.
41 class TensorFlowType : public Type {
42 public:
43 using Type::Type;
44
45 // Support method to enable LLVM-style type casting.
46 static bool classof(Type type);
47 };
48
49 // Returns true if the specified type is a valid TensorFlow element type.
IsValidTFElementType(Type type)50 inline bool IsValidTFElementType(Type type) {
51 return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
52 }
53
54 // Returns true if this is a valid TensorFlow tensor type.
IsValidTFTensorType(Type type)55 inline bool IsValidTFTensorType(Type type) {
56 // TensorFlow types should be tensors of one of the valid TensorFlow element
57 // types.
58 if (auto tensor_ty = type.dyn_cast<TensorType>())
59 return IsValidTFElementType(tensor_ty.getElementType());
60 return false;
61 }
62
63 namespace detail {
64 // Common implementation of TensorFlow types. The template argument indicates
65 // the concrete derived class per CRTP.
66 template <typename Derived>
67 class TensorFlowTypeImpl
68 : public Type::TypeBase<Derived, TensorFlowType, TypeStorage> {
69 public:
70 using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
71 using TFBase = TensorFlowTypeImpl<Derived>;
72 using Base::Base;
73 };
74 } // namespace detail
75
76 // TensorFlowRefType class supports all the ref types in TensorFlow dialect.
77 class TensorFlowRefType : public TensorFlowType {
78 public:
79 using TensorFlowType::TensorFlowType;
80
81 // Checks if a type is TensorFlow Ref type.
82 static bool classof(Type type);
83
84 // Converts a type to the corresponding TensorFlowRef type.
85 static TensorFlowType get(Type type);
getChecked(Type type,MLIRContext * context,Location loc)86 static TensorFlowType getChecked(Type type, MLIRContext* context,
87 Location loc) {
88 if (failed(verify(loc, type))) {
89 return TensorFlowRefType();
90 }
91 return get(type);
92 }
93
verify(Location loc,Type type)94 static LogicalResult verify(Location loc, Type type) {
95 // type should be a valid TensorFlow type.
96 if (!IsValidTFTensorType(type)) {
97 return emitError(loc) << "invalid TensorFlow type: " << type;
98 }
99 return success();
100 }
101
102 // Converts a TensorFlowRef type to the corresponding TensorFlow or standard
103 // type.
104 Type RemoveRef();
105 };
106
// Declare one concrete class per TensorFlow type (dtype) listed in types.def.
// Each invocation HANDLE_TF_TYPE(Foo, ENUMERANT, "name") expands to
//   class FooType : public detail::TensorFlowTypeImpl<FooType> { ... };
// which inherits its constructors through TFBase. The `enumerant` and `name`
// arguments are not used by this expansion.
// HANDLE_CUSTOM_TF_TYPE is defined to expand to nothing, so entries routed
// through it get no class here. Presumably those are declared by hand (e.g.
// ResourceType and VariantType later in this header) -- confirm in types.def.
// NOTE(review): assumes types.def #undefs these macros after use -- confirm.
#define HANDLE_TF_TYPE(tftype, enumerant, name)                          \
  class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> { \
   public:                                                               \
    using TFBase::TFBase;                                                \
  };
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
#include "tensorflow/core/ir/types/types.def"
116
117 namespace detail {
118 // Storage type contains inferred subtypes for TypeWithSubtype.
119 class TypeWithSubtypeStorage : public TypeStorage {
120 public:
121 using KeyTy = ArrayRef<TensorType>;
122
123 // NOLINTNEXTLINE
construct(TypeStorageAllocator & allocator,const KeyTy & key)124 static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
125 const KeyTy& key) {
126 ArrayRef<TensorType> subtypes = allocator.copyInto(key);
127 return new (allocator.allocate<TypeWithSubtypeStorage>())
128 TypeWithSubtypeStorage(subtypes);
129 }
130
TypeWithSubtypeStorage(const KeyTy & key)131 explicit TypeWithSubtypeStorage(const KeyTy& key) : subtypes_(key) {}
132
133 bool operator==(const KeyTy& key) const { return key == subtypes_; }
134
hashKey(const KeyTy & key)135 static llvm::hash_code hashKey(const KeyTy& key) {
136 return llvm::hash_combine_range(key.begin(), key.end());
137 }
138
139 KeyTy subtypes_;
140 };
141
142 // Common implementation of TensorFlow types with subtypes. These subtypes are
143 // opaque and their interpretation depends on the actual underlying type.
144 // The template argument indicates the concrete derived class per CRTP. Concrete
145 // classes must implement the following:
146 // - `static std::string getTypeName()` that returns the name of the type for
147 // verification logging.
148 template <typename Derived>
149 class TypeWithSubtypeImpl
150 : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage> {
151 public:
152 using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
153 using TFBase = TypeWithSubtypeImpl<Derived>;
154 using Base::Base;
155
get(ArrayRef<TensorType> subtypes,MLIRContext * context)156 static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context) {
157 return Base::get(context, subtypes);
158 }
159
getChecked(ArrayRef<TensorType> subtypes,MLIRContext * context,Location loc)160 static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
161 Location loc) {
162 return Base::getChecked(loc, subtypes);
163 }
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<TensorType> subtypes)164 static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
165 MLIRContext* context,
166 ArrayRef<TensorType> subtypes) {
167 return Base::getChecked(emitError, context, subtypes);
168 }
169
get(MLIRContext * context)170 static Derived get(MLIRContext* context) { return get({}, context); }
171
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<TensorType> subtypes)172 static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
173 ArrayRef<TensorType> subtypes) {
174 // Each of the subtypes should be a valid TensorFlow type.
175 for (TensorType subtype : subtypes) {
176 if (!IsValidTFTensorType(subtype)) {
177 return emitError() << "invalid " << Derived::getTypeName()
178 << " subtype: " << subtype;
179 }
180 }
181 return success();
182 }
183
getSubtypes()184 ArrayRef<TensorType> getSubtypes() { return Base::getImpl()->subtypes_; }
185 };
186 } // namespace detail
187
188 // TensorFlowTypeWithSubtype class supports all the types with subtypes in
189 // TensorFlow dialect.
190 class TensorFlowTypeWithSubtype : public TensorFlowType {
191 public:
192 using TensorFlowType::TensorFlowType;
193
194 // Checks if a type is TensorFlow type with subtypes.
195 static bool classof(Type type);
196
197 // Converts a TypeWithSubtype type to the same type but without its subtypes.
198 Type RemoveSubtypes();
199
200 // Clone the current Type with new subtypes.
201 TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes);
202
203 // Returns the subtypes.
204 ArrayRef<TensorType> GetSubtypes();
205 };
206
207 // Returns the corresponding TensorFlow type with subtypes but without its
208 // subtypes.
GetDefaultTypeOf(TensorFlowTypeWithSubtype type)209 inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
210 return type.RemoveSubtypes();
211 }
212
213 // TensorFlow resource type is used to support TensorFlow resource variables,
214 // which represent shared, persistent state manipulated by a TensorFlow program.
215 // ResourceType stores shape and datatype for subtypes unlike most other data
216 // types that don't have any associated information.
217 class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
218 public:
219 using TFBase::TFBase;
getTypeName()220 static std::string getTypeName() { return "ResourceType"; }
221 };
222
223 // TensorFlow variant type is used to support arbitrary custom C++ data types.
224 // VariantType stores inferred shape and datatype for subtypes unlike most other
225 // data types that don't have any associated information. For example, variants
226 // encoding TensorList type stores the common shape and dtype of the list
227 // elements as the only subtype.
228 class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
229 public:
230 using TFBase::TFBase;
getTypeName()231 static std::string getTypeName() { return "VariantType"; }
232 };
233
234 // Given two types `a` and `b`, returns a refined type which is cast compatible
235 // with both `a` and `b` and is equal to or more precise than both of them. It
236 // returns empty Type if the input types are not cast compatible.
237 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
238 // might allow operands to either be same as result type or be a ref type
239 // corresponding to it.
240 Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a);
241
242 // Returns whether two arrays of Type are broadcast compatible.
243 bool BroadcastCompatible(TypeRange lhs, TypeRange rhs);
244
// Returns whether the two element types are compatible. They are compatible
// if:
// - the types are statically equal, or
// - they could be dynamically equal, where:
//   - dynamic shapes are considered equal unless contradictory information is
//     known;
//   - element types are equivalent, modulo subtypes possibly being less exact
//     (e.g., a resource type without subtypes is considered compatible with a
//     resource type with known subtypes).
// `may_ignore_ref_type_lhs` provides the option to ignore ref types on `lhs`.
254 bool HasCompatibleElementTypes(Type lhs, Type rhs,
255 bool may_ignore_ref_type_lhs = false);
256
257 // Returns true if all TensorFlow types can be cast to one
258 // another. In other words, a single run-time value is legal for all the types.
259 // For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
260 // compatible.
261 bool AreCastCompatible(TypeRange types);
262
263 // Returns true if corresponding elements of lhs and rhs AreCastCompatible and
264 // lhs and rhs are the same length.
265 bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs);
266
267 // If `ty` is a tensor type and its element type has subtypes, then returns a
268 // new type of same shape but dropped subtypes for the element type.
269 // Otherwise, if `ty` has subtypes, then returns corresponding type with dropped
270 // subtypes.
271 // Otherwise, returns the original type `ty`.
272 Type DropSubTypes(Type ty);
273
274 // If `ty` is a tensor type and has elements of a ref type, then returns a new
275 // type of same shape but corresponding non-ref type as element type.
276 // Otherwise, if `ty` is a ref type, then returns corresponding non-ref type.
277 // Otherwise, returns the original type `ty`.
278 Type DropRefType(Type ty);
279
280 // Convenience call for executing both `DropRefType` and `DropSubTypes`.
281 Type DropRefAndSubTypes(Type ty);
282
283 //===----------------------------------------------------------------------===//
284 // Utility iterators
285 //===----------------------------------------------------------------------===//
286
287 // An iterator for the tensor shapes of an op's operands of shaped types.
288 // Returns llvm::None if a operand is unranked; returns ArrayRef<int64_t> as the
289 // shape otherwise.
290 class OperandShapeIterator final
291 : public llvm::mapped_iterator<Operation::operand_iterator,
292 llvm::Optional<ArrayRef<int64_t>> (*)(
293 Value)> {
294 public:
295 using reference = llvm::Optional<ArrayRef<int64_t>>;
296
297 /// Initializes the operand shape iterator to the specified operand iterator.
298 explicit OperandShapeIterator(Operation::operand_iterator it);
299 };
300
301 using OperandShapeRange = iterator_range<OperandShapeIterator>;
302
303 // An iterator for the tensor shapes of an op's results of shaped types.
304 // Returns llvm::None if a result is unranked; returns ArrayRef<int64_t> as the
305 // shape otherwise.
306 class ResultShapeIterator final
307 : public llvm::mapped_iterator<Operation::result_iterator,
308 llvm::Optional<ArrayRef<int64_t>> (*)(
309 Value)> {
310 public:
311 using reference = llvm::Optional<ArrayRef<int64_t>>;
312
313 /// Initializes the result shape iterator to the specified result iterator.
314 explicit ResultShapeIterator(Operation::result_iterator it);
315 };
316
317 using ResultShapeRange = iterator_range<ResultShapeIterator>;
318
319 // Returns a range with just resource type values from the input range
320 // preserved.
321 template <typename RangeT>
filter_resources(RangeT && range)322 auto filter_resources(RangeT&& range) {
323 return llvm::make_filter_range(std::forward<RangeT>(range), [](Value val) {
324 return getElementTypeOrSelf(val.getType()).isa<ResourceType>();
325 });
326 }
327
328 // Returns the element type if `type` is a `ShapedType` and the type itself
329 // otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
330 // standard type if necessary.
GetElementTypeOrSelfResolveRef(Type type)331 inline Type GetElementTypeOrSelfResolveRef(Type type) {
332 Type element_type = getElementTypeOrSelf(type);
333 if (auto ref_type = element_type.dyn_cast<TensorFlowRefType>()) {
334 element_type = ref_type.RemoveRef();
335 }
336 return element_type;
337 }
338
339 } // namespace tf_type
340 } // namespace mlir
341
342 //===----------------------------------------------------------------------===//
343 // Tablegen Attribute Declarations
344 //===----------------------------------------------------------------------===//
345
346 #define GET_ATTRDEF_CLASSES
347 #include "tensorflow/core/ir/types/attributes.h.inc"
348
349 #endif // TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
350