• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
17 #define TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
18 
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
21 #include "mlir/IR/Dialect.h"  // from @llvm-project
22 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
23 
24 // Include the dialect class generated from dialect.td.
25 // The constructor and the printing/parsing of dialect types are manually
26 // implemented (see ops.cpp).
27 #include "tensorflow/core/ir/types/dialect.h.inc"
28 
// Include the Type class declarations generated from types.td.
30 #define GET_TYPEDEF_CLASSES
31 #include "tensorflow/core/ir/types/types.h.inc"
32 
33 namespace mlir {
34 namespace tf_type {
35 
36 //===----------------------------------------------------------------------===//
37 // TensorFlow types
38 //===----------------------------------------------------------------------===//
39 
40 // The base class in the TensorFlow type hierarchy.
41 class TensorFlowType : public Type {
42  public:
43   using Type::Type;
44 
45   // Support method to enable LLVM-style type casting.
46   static bool classof(Type type);
47 };
48 
49 // Returns true if the specified type is a valid TensorFlow element type.
IsValidTFElementType(Type type)50 inline bool IsValidTFElementType(Type type) {
51   return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
52 }
53 
54 // Returns true if this is a valid TensorFlow tensor type.
IsValidTFTensorType(Type type)55 inline bool IsValidTFTensorType(Type type) {
56   // TensorFlow types should be tensors of one of the valid TensorFlow element
57   // types.
58   if (auto tensor_ty = type.dyn_cast<TensorType>())
59     return IsValidTFElementType(tensor_ty.getElementType());
60   return false;
61 }
62 
63 namespace detail {
64 // Common implementation of TensorFlow types. The template argument indicates
65 // the concrete derived class per CRTP.
66 template <typename Derived>
67 class TensorFlowTypeImpl
68     : public Type::TypeBase<Derived, TensorFlowType, TypeStorage> {
69  public:
70   using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
71   using TFBase = TensorFlowTypeImpl<Derived>;
72   using Base::Base;
73 };
74 }  // namespace detail
75 
76 // TensorFlowRefType class supports all the ref types in TensorFlow dialect.
77 class TensorFlowRefType : public TensorFlowType {
78  public:
79   using TensorFlowType::TensorFlowType;
80 
81   // Checks if a type is TensorFlow Ref type.
82   static bool classof(Type type);
83 
84   // Converts a type to the corresponding TensorFlowRef type.
85   static TensorFlowType get(Type type);
getChecked(Type type,MLIRContext * context,Location loc)86   static TensorFlowType getChecked(Type type, MLIRContext* context,
87                                    Location loc) {
88     if (failed(verify(loc, type))) {
89       return TensorFlowRefType();
90     }
91     return get(type);
92   }
93 
verify(Location loc,Type type)94   static LogicalResult verify(Location loc, Type type) {
95     // type should be a valid TensorFlow type.
96     if (!IsValidTFTensorType(type)) {
97       return emitError(loc) << "invalid TensorFlow type: " << type;
98     }
99     return success();
100   }
101 
102   // Converts a TensorFlowRef type to the corresponding TensorFlow or standard
103   // type.
104   Type RemoveRef();
105 };
106 
107 // Define a class for each individual TensorFlow type (dtype), see types.def
108 // for the list.
109 #define HANDLE_TF_TYPE(tftype, enumerant, name)                          \
110   class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> { \
111    public:                                                               \
112     using TFBase::TFBase;                                                \
113   };
114 #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
115 #include "tensorflow/core/ir/types/types.def"
116 
117 namespace detail {
118 // Storage type contains inferred subtypes for TypeWithSubtype.
119 class TypeWithSubtypeStorage : public TypeStorage {
120  public:
121   using KeyTy = ArrayRef<TensorType>;
122 
123   // NOLINTNEXTLINE
construct(TypeStorageAllocator & allocator,const KeyTy & key)124   static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
125                                            const KeyTy& key) {
126     ArrayRef<TensorType> subtypes = allocator.copyInto(key);
127     return new (allocator.allocate<TypeWithSubtypeStorage>())
128         TypeWithSubtypeStorage(subtypes);
129   }
130 
TypeWithSubtypeStorage(const KeyTy & key)131   explicit TypeWithSubtypeStorage(const KeyTy& key) : subtypes_(key) {}
132 
133   bool operator==(const KeyTy& key) const { return key == subtypes_; }
134 
hashKey(const KeyTy & key)135   static llvm::hash_code hashKey(const KeyTy& key) {
136     return llvm::hash_combine_range(key.begin(), key.end());
137   }
138 
139   KeyTy subtypes_;
140 };
141 
142 // Common implementation of TensorFlow types with subtypes. These subtypes are
143 // opaque and their interpretation depends on the actual underlying type.
144 // The template argument indicates the concrete derived class per CRTP. Concrete
145 // classes must implement the following:
146 //   - `static std::string getTypeName()` that returns the name of the type for
147 //     verification logging.
148 template <typename Derived>
149 class TypeWithSubtypeImpl
150     : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage> {
151  public:
152   using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
153   using TFBase = TypeWithSubtypeImpl<Derived>;
154   using Base::Base;
155 
get(ArrayRef<TensorType> subtypes,MLIRContext * context)156   static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context) {
157     return Base::get(context, subtypes);
158   }
159 
getChecked(ArrayRef<TensorType> subtypes,MLIRContext * context,Location loc)160   static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
161                             Location loc) {
162     return Base::getChecked(loc, subtypes);
163   }
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<TensorType> subtypes)164   static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
165                             MLIRContext* context,
166                             ArrayRef<TensorType> subtypes) {
167     return Base::getChecked(emitError, context, subtypes);
168   }
169 
get(MLIRContext * context)170   static Derived get(MLIRContext* context) { return get({}, context); }
171 
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<TensorType> subtypes)172   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
173                               ArrayRef<TensorType> subtypes) {
174     // Each of the subtypes should be a valid TensorFlow type.
175     for (TensorType subtype : subtypes) {
176       if (!IsValidTFTensorType(subtype)) {
177         return emitError() << "invalid " << Derived::getTypeName()
178                            << " subtype: " << subtype;
179       }
180     }
181     return success();
182   }
183 
getSubtypes()184   ArrayRef<TensorType> getSubtypes() { return Base::getImpl()->subtypes_; }
185 };
186 }  // namespace detail
187 
188 // TensorFlowTypeWithSubtype class supports all the types with subtypes in
189 // TensorFlow dialect.
190 class TensorFlowTypeWithSubtype : public TensorFlowType {
191  public:
192   using TensorFlowType::TensorFlowType;
193 
194   // Checks if a type is TensorFlow type with subtypes.
195   static bool classof(Type type);
196 
197   // Converts a TypeWithSubtype type to the same type but without its subtypes.
198   Type RemoveSubtypes();
199 
200   // Clone the current Type with new subtypes.
201   TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes);
202 
203   // Returns the subtypes.
204   ArrayRef<TensorType> GetSubtypes();
205 };
206 
207 // Returns the corresponding TensorFlow type with subtypes but without its
208 // subtypes.
GetDefaultTypeOf(TensorFlowTypeWithSubtype type)209 inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
210   return type.RemoveSubtypes();
211 }
212 
213 // TensorFlow resource type is used to support TensorFlow resource variables,
214 // which represent shared, persistent state manipulated by a TensorFlow program.
215 // ResourceType stores shape and datatype for subtypes unlike most other data
216 // types that don't have any associated information.
217 class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
218  public:
219   using TFBase::TFBase;
getTypeName()220   static std::string getTypeName() { return "ResourceType"; }
221 };
222 
223 // TensorFlow variant type is used to support arbitrary custom C++ data types.
224 // VariantType stores inferred shape and datatype for subtypes unlike most other
225 // data types that don't have any associated information. For example, variants
226 // encoding TensorList type stores the common shape and dtype of the list
227 // elements as the only subtype.
228 class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
229  public:
230   using TFBase::TFBase;
getTypeName()231   static std::string getTypeName() { return "VariantType"; }
232 };
233 
234 // Given two types `a` and `b`, returns a refined type which is cast compatible
235 // with both `a` and `b` and is equal to or more precise than both of them. It
236 // returns empty Type if the input types are not cast compatible.
237 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
238 // might allow operands to either be same as result type or be a ref type
239 // corresponding to it.
240 Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a);
241 
242 // Returns whether two arrays of Type are broadcast compatible.
243 bool BroadcastCompatible(TypeRange lhs, TypeRange rhs);
244 
245 // Returns whether the two elemental types are compatible. Shapes are compatible
246 // if:
247 // - the types are statically equal
248 // - could be dynamically equal
249 //   - considering dynamic shapes equal unless contradictory info known;
250 //   - element types are equivalent, modulo subtypes possible be less exact
251 //     (e.g., a resource type without subtype is considered compatible with
252 //      resource type with known subtype).
253 // Provide option to ignore ref types on 'lhs'.
254 bool HasCompatibleElementTypes(Type lhs, Type rhs,
255                                bool may_ignore_ref_type_lhs = false);
256 
257 // Returns true if all TensorFlow types can be cast to one
258 // another. In other words, a single run-time value is legal for all the types.
259 // For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
260 // compatible.
261 bool AreCastCompatible(TypeRange types);
262 
263 // Returns true if corresponding elements of lhs and rhs AreCastCompatible and
264 // lhs and rhs are the same length.
265 bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs);
266 
267 // If `ty` is a tensor type and its element type has subtypes, then returns a
268 // new type of same shape but dropped subtypes for the element type.
269 // Otherwise, if `ty` has subtypes, then returns corresponding type with dropped
270 // subtypes.
271 // Otherwise, returns the original type `ty`.
272 Type DropSubTypes(Type ty);
273 
274 // If `ty` is a tensor type and has elements of a ref type, then returns a new
275 // type of same shape but corresponding non-ref type as element type.
276 // Otherwise, if `ty` is a ref type, then returns corresponding non-ref type.
277 // Otherwise, returns the original type `ty`.
278 Type DropRefType(Type ty);
279 
280 // Convenience call for executing both `DropRefType` and `DropSubTypes`.
281 Type DropRefAndSubTypes(Type ty);
282 
283 //===----------------------------------------------------------------------===//
284 // Utility iterators
285 //===----------------------------------------------------------------------===//
286 
287 // An iterator for the tensor shapes of an op's operands of shaped types.
288 // Returns llvm::None if a operand is unranked; returns ArrayRef<int64_t> as the
289 // shape otherwise.
290 class OperandShapeIterator final
291     : public llvm::mapped_iterator<Operation::operand_iterator,
292                                    llvm::Optional<ArrayRef<int64_t>> (*)(
293                                        Value)> {
294  public:
295   using reference = llvm::Optional<ArrayRef<int64_t>>;
296 
297   /// Initializes the operand shape iterator to the specified operand iterator.
298   explicit OperandShapeIterator(Operation::operand_iterator it);
299 };
300 
301 using OperandShapeRange = iterator_range<OperandShapeIterator>;
302 
303 // An iterator for the tensor shapes of an op's results of shaped types.
304 // Returns llvm::None if a result is unranked; returns ArrayRef<int64_t> as the
305 // shape otherwise.
306 class ResultShapeIterator final
307     : public llvm::mapped_iterator<Operation::result_iterator,
308                                    llvm::Optional<ArrayRef<int64_t>> (*)(
309                                        Value)> {
310  public:
311   using reference = llvm::Optional<ArrayRef<int64_t>>;
312 
313   /// Initializes the result shape iterator to the specified result iterator.
314   explicit ResultShapeIterator(Operation::result_iterator it);
315 };
316 
317 using ResultShapeRange = iterator_range<ResultShapeIterator>;
318 
319 // Returns a range with just resource type values from the input range
320 // preserved.
321 template <typename RangeT>
filter_resources(RangeT && range)322 auto filter_resources(RangeT&& range) {
323   return llvm::make_filter_range(std::forward<RangeT>(range), [](Value val) {
324     return getElementTypeOrSelf(val.getType()).isa<ResourceType>();
325   });
326 }
327 
328 // Returns the element type if `type` is a `ShapedType` and the type itself
329 // otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
330 // standard type if necessary.
GetElementTypeOrSelfResolveRef(Type type)331 inline Type GetElementTypeOrSelfResolveRef(Type type) {
332   Type element_type = getElementTypeOrSelf(type);
333   if (auto ref_type = element_type.dyn_cast<TensorFlowRefType>()) {
334     element_type = ref_type.RemoveRef();
335   }
336   return element_type;
337 }
338 
339 }  // namespace tf_type
340 }  // namespace mlir
341 
342 //===----------------------------------------------------------------------===//
343 // Tablegen Attribute Declarations
344 //===----------------------------------------------------------------------===//
345 
346 #define GET_ATTRDEF_CLASSES
347 #include "tensorflow/core/ir/types/attributes.h.inc"
348 
349 #endif  // TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
350