• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
17 #define TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
18 
19 #include <string>
20 
21 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
22 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
23 #include "mlir/IR/Dialect.h"  // from @llvm-project
24 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
25 
26 // Include the dialect class generated from dialect.td.
27 // The constructor and the printing/parsing of dialect types are manually
28 // implemented (see ops.cpp).
29 #include "tensorflow/core/ir/types/dialect.h.inc"
30 
31 // Include the Type classes declaration generated from types.td
32 #define GET_TYPEDEF_CLASSES
33 #include "tensorflow/core/ir/types/types.h.inc"
34 
35 namespace mlir {
36 namespace tf_type {
37 
38 //===----------------------------------------------------------------------===//
39 // TensorFlow types
40 //===----------------------------------------------------------------------===//
41 
42 // The base class in the TensorFlow type hierarchy.
43 class TensorFlowType : public Type {
44  public:
45   using Type::Type;
46 
47   // Support method to enable LLVM-style type casting.
48   static bool classof(Type type);
49 };
50 
51 // Returns true if the specified type is a valid TensorFlow element type.
IsValidTFElementType(Type type)52 inline bool IsValidTFElementType(Type type) {
53   return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
54 }
55 
56 // Returns true if this is a valid TensorFlow tensor type.
IsValidTFTensorType(Type type)57 inline bool IsValidTFTensorType(Type type) {
58   // TensorFlow types should be tensors of one of the valid TensorFlow element
59   // types.
60   if (auto tensor_ty = type.dyn_cast<TensorType>())
61     return IsValidTFElementType(tensor_ty.getElementType());
62   return false;
63 }
64 
65 namespace detail {
66 // Common implementation of TensorFlow types. The template argument indicates
67 // the concrete derived class per CRTP.
68 template <typename Derived>
69 class TensorFlowTypeImpl
70     : public Type::TypeBase<Derived, TensorFlowType, TypeStorage> {
71  public:
72   using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
73   using TFBase = TensorFlowTypeImpl<Derived>;
74   using Base::Base;
75 };
76 }  // namespace detail
77 
78 // TensorFlowRefType class supports all the ref types in TensorFlow dialect.
79 class TensorFlowRefType : public TensorFlowType {
80  public:
81   using TensorFlowType::TensorFlowType;
82 
83   // Checks if a type is TensorFlow Ref type.
84   static bool classof(Type type);
85 
86   // Converts a type to the corresponding TensorFlowRef type.
87   static TensorFlowType get(Type type);
getChecked(Type type,MLIRContext * context,Location loc)88   static TensorFlowType getChecked(Type type, MLIRContext* context,
89                                    Location loc) {
90     if (failed(verify(loc, type))) {
91       return TensorFlowRefType();
92     }
93     return get(type);
94   }
95 
verify(Location loc,Type type)96   static LogicalResult verify(Location loc, Type type) {
97     // type should be a valid TensorFlow type.
98     if (!IsValidTFTensorType(type)) {
99       return emitError(loc) << "invalid TensorFlow type: " << type;
100     }
101     return success();
102   }
103 
104   // Converts a TensorFlowRef type to the corresponding TensorFlow or standard
105   // type.
106   Type RemoveRef();
107 };
108 
// Define a class for each individual TensorFlow type (dtype); see types.def
// for the list. Each HANDLE_TF_TYPE(tftype, enumerant, name) entry in
// types.def expands to a class named `tftype##Type` that derives from
// detail::TensorFlowTypeImpl and inherits its constructors. The `enumerant` and
// `name` arguments are unused by this expansion. Entries listed through
// HANDLE_CUSTOM_TF_TYPE expand to nothing here.
// NOTE(review): presumably the custom types are declared by hand elsewhere
// (e.g. ResourceType / VariantType in this header) and types.def #undefs both
// macros after use — confirm against types.def.
#define HANDLE_TF_TYPE(tftype, enumerant, name)                          \
  class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> { \
   public:                                                               \
    using TFBase::TFBase;                                                \
  };
#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
#include "tensorflow/core/ir/types/types.def"
118 
119 namespace detail {
120 // Storage type contains inferred subtypes for TypeWithSubtype.
121 class TypeWithSubtypeStorage : public TypeStorage {
122  public:
123   using KeyTy = ArrayRef<TensorType>;
124 
125   // NOLINTNEXTLINE
construct(TypeStorageAllocator & allocator,const KeyTy & key)126   static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
127                                            const KeyTy& key) {
128     ArrayRef<TensorType> subtypes = allocator.copyInto(key);
129     return new (allocator.allocate<TypeWithSubtypeStorage>())
130         TypeWithSubtypeStorage(subtypes);
131   }
132 
TypeWithSubtypeStorage(const KeyTy & key)133   explicit TypeWithSubtypeStorage(const KeyTy& key) : subtypes_(key) {}
134 
135   bool operator==(const KeyTy& key) const { return key == subtypes_; }
136 
hashKey(const KeyTy & key)137   static llvm::hash_code hashKey(const KeyTy& key) {
138     return llvm::hash_combine_range(key.begin(), key.end());
139   }
140 
141   KeyTy subtypes_;
142 };
143 
144 // Common implementation of TensorFlow types with subtypes. These subtypes are
145 // opaque and their interpretation depends on the actual underlying type.
146 // The template argument indicates the concrete derived class per CRTP. Concrete
147 // classes must implement the following:
148 //   - `static std::string getTypeName()` that returns the name of the type for
149 //     verification logging.
150 template <typename Derived>
151 class TypeWithSubtypeImpl
152     : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage> {
153  public:
154   using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
155   using TFBase = TypeWithSubtypeImpl<Derived>;
156   using Base::Base;
157 
get(ArrayRef<TensorType> subtypes,MLIRContext * context)158   static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context) {
159     return Base::get(context, subtypes);
160   }
161 
getChecked(ArrayRef<TensorType> subtypes,MLIRContext * context,Location loc)162   static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
163                             Location loc) {
164     return Base::getChecked(loc, subtypes);
165   }
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<TensorType> subtypes)166   static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
167                             MLIRContext* context,
168                             ArrayRef<TensorType> subtypes) {
169     return Base::getChecked(emitError, context, subtypes);
170   }
171 
get(MLIRContext * context)172   static Derived get(MLIRContext* context) { return get({}, context); }
173 
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<TensorType> subtypes)174   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
175                               ArrayRef<TensorType> subtypes) {
176     // Each of the subtypes should be a valid TensorFlow type.
177     for (TensorType subtype : subtypes) {
178       if (!IsValidTFTensorType(subtype)) {
179         return emitError() << "invalid " << Derived::getTypeName()
180                            << " subtype: " << subtype;
181       }
182     }
183     return success();
184   }
185 
getSubtypes()186   ArrayRef<TensorType> getSubtypes() { return Base::getImpl()->subtypes_; }
187 };
188 }  // namespace detail
189 
190 // TensorFlowTypeWithSubtype class supports all the types with subtypes in
191 // TensorFlow dialect.
192 class TensorFlowTypeWithSubtype : public TensorFlowType {
193  public:
194   using TensorFlowType::TensorFlowType;
195 
196   // Checks if a type is TensorFlow type with subtypes.
197   static bool classof(Type type);
198 
199   // Converts a TypeWithSubtype type to the same type but without its subtypes.
200   Type RemoveSubtypes();
201 
202   // Clone the current Type with new subtypes.
203   TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes);
204 
205   // Returns the subtypes.
206   ArrayRef<TensorType> GetSubtypes();
207 };
208 
209 // Returns the corresponding TensorFlow type with subtypes but without its
210 // subtypes.
GetDefaultTypeOf(TensorFlowTypeWithSubtype type)211 inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
212   return type.RemoveSubtypes();
213 }
214 
215 // TensorFlow resource type is used to support TensorFlow resource variables,
216 // which represent shared, persistent state manipulated by a TensorFlow program.
217 // ResourceType stores shape and datatype for subtypes unlike most other data
218 // types that don't have any associated information.
219 class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
220  public:
221   using TFBase::TFBase;
getTypeName()222   static std::string getTypeName() { return "ResourceType"; }
223 };
224 
225 // TensorFlow variant type is used to support arbitrary custom C++ data types.
226 // VariantType stores inferred shape and datatype for subtypes unlike most other
227 // data types that don't have any associated information. For example, variants
228 // encoding TensorList type stores the common shape and dtype of the list
229 // elements as the only subtype.
230 class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
231  public:
232   using TFBase::TFBase;
getTypeName()233   static std::string getTypeName() { return "VariantType"; }
234 };
235 
// Given two types `a` and `b`, returns a refined type that is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them.
// Returns a null (empty) Type if the input types are not cast compatible.
// If `may_ignore_ref_type_a` is true, ref types on `a` may be ignored. This is
// useful for TF ops that allow an operand to be either the same as the result
// type or the ref type corresponding to it.
Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a = false);

// Returns whether two arrays of Type are broadcast compatible.
bool BroadcastCompatible(TypeRange lhs, TypeRange rhs);

// Returns whether `lhs` and `rhs` have compatible element types. They are
// compatible if:
// - the types are statically equal, or
// - they could be dynamically equal:
//   - dynamic shapes are considered equal unless contradictory information is
//     known;
//   - element types are equivalent, modulo subtypes possibly being less exact
//     (e.g., a resource type without a subtype is considered compatible with a
//     resource type with a known subtype).
// If `may_ignore_ref_type_lhs` is true, ref types on `lhs` may be ignored.
bool HasCompatibleElementTypes(Type lhs, Type rhs,
                               bool may_ignore_ref_type_lhs = false);

// Returns true if all TensorFlow types in `types` can be cast to one another;
// in other words, a single run-time value is legal for all of the types.
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
// compatible.
bool AreCastCompatible(TypeRange types);

// Returns true if `lhs` and `rhs` have the same length and each pair of
// corresponding elements is AreCastCompatible.
bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs);

// If `ty` is a tensor type whose element type has subtypes, returns a new type
// of the same shape with the subtypes dropped from the element type.
// Otherwise, if `ty` itself has subtypes, returns the corresponding type with
// the subtypes dropped.
// Otherwise, returns the original type `ty`.
Type DropSubTypes(Type ty);

// If `ty` is a tensor type with elements of a ref type, returns a new type of
// the same shape with the corresponding non-ref type as element type.
// Otherwise, if `ty` is a ref type, returns the corresponding non-ref type.
// Otherwise, returns the original type `ty`.
Type DropRefType(Type ty);

// Convenience function that applies both `DropRefType` and `DropSubTypes`.
Type DropRefAndSubTypes(Type ty);
284 
285 //===----------------------------------------------------------------------===//
286 // Utility iterators
287 //===----------------------------------------------------------------------===//
288 
289 // An iterator for the tensor shapes of an op's operands of shaped types.
290 // Returns llvm::None if a operand is unranked; returns ArrayRef<int64_t> as the
291 // shape otherwise.
292 class OperandShapeIterator final
293     : public llvm::mapped_iterator<Operation::operand_iterator,
294                                    llvm::Optional<ArrayRef<int64_t>> (*)(
295                                        Value)> {
296  public:
297   using reference = llvm::Optional<ArrayRef<int64_t>>;
298 
299   /// Initializes the operand shape iterator to the specified operand iterator.
300   explicit OperandShapeIterator(Operation::operand_iterator it);
301 };
302 
303 using OperandShapeRange = iterator_range<OperandShapeIterator>;
304 
305 // An iterator for the tensor shapes of an op's results of shaped types.
306 // Returns llvm::None if a result is unranked; returns ArrayRef<int64_t> as the
307 // shape otherwise.
308 class ResultShapeIterator final
309     : public llvm::mapped_iterator<Operation::result_iterator,
310                                    llvm::Optional<ArrayRef<int64_t>> (*)(
311                                        Value)> {
312  public:
313   using reference = llvm::Optional<ArrayRef<int64_t>>;
314 
315   /// Initializes the result shape iterator to the specified result iterator.
316   explicit ResultShapeIterator(Operation::result_iterator it);
317 };
318 
319 using ResultShapeRange = iterator_range<ResultShapeIterator>;
320 
321 // Returns a range with just resource type values from the input range
322 // preserved.
323 template <typename RangeT>
filter_resources(RangeT && range)324 auto filter_resources(RangeT&& range) {
325   return llvm::make_filter_range(std::forward<RangeT>(range), [](Value val) {
326     return getElementTypeOrSelf(val.getType()).isa<ResourceType>();
327   });
328 }
329 
330 // Returns the element type if `type` is a `ShapedType` and the type itself
331 // otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
332 // standard type if necessary.
GetElementTypeOrSelfResolveRef(Type type)333 inline Type GetElementTypeOrSelfResolveRef(Type type) {
334   Type element_type = getElementTypeOrSelf(type);
335   if (auto ref_type = element_type.dyn_cast<TensorFlowRefType>()) {
336     element_type = ref_type.RemoveRef();
337   }
338   return element_type;
339 }
340 
341 }  // namespace tf_type
342 }  // namespace mlir
343 
344 //===----------------------------------------------------------------------===//
345 // Tablegen Attribute Declarations
346 //===----------------------------------------------------------------------===//
347 
348 #define GET_ATTRDEF_CLASSES
349 #include "tensorflow/core/ir/types/attributes.h.inc"
350 #include "tensorflow/core/ir/types/attributes_enum.h.inc"
351 
352 #endif  // TENSORFLOW_CORE_IR_TYPE_DIALECT_H_
353