• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the types used in the standard MLIR TensorFlow dialect.
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
19 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
20 
21 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
22 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
23 #include "mlir/IR/Location.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
26 #include "mlir/IR/Types.h"  // from @llvm-project
27 
28 namespace mlir {
29 namespace TF {
30 //===----------------------------------------------------------------------===//
31 // Utility iterators
32 //===----------------------------------------------------------------------===//
33 
34 // An iterator for the tensor shapes of an op's operands of shaped types.
35 // Returns llvm::None if a operand is unranked; returns ArrayRef<int64_t> as the
36 // shape otherwise.
37 class OperandShapeIterator final
38     : public llvm::mapped_iterator<Operation::operand_iterator,
39                                    llvm::Optional<ArrayRef<int64_t>> (*)(
40                                        Value)> {
41  public:
42   using reference = llvm::Optional<ArrayRef<int64_t>>;
43 
44   /// Initializes the operand shape iterator to the specified operand iterator.
45   explicit OperandShapeIterator(Operation::operand_iterator it);
46 };
47 
48 using OperandShapeRange = iterator_range<OperandShapeIterator>;
49 
50 // An iterator for the tensor shapes of an op's results of shaped types.
51 // Returns llvm::None if a result is unranked; returns ArrayRef<int64_t> as the
52 // shape otherwise.
53 class ResultShapeIterator final
54     : public llvm::mapped_iterator<Operation::result_iterator,
55                                    llvm::Optional<ArrayRef<int64_t>> (*)(
56                                        Value)> {
57  public:
58   using reference = llvm::Optional<ArrayRef<int64_t>>;
59 
60   /// Initializes the result shape iterator to the specified result iterator.
61   explicit ResultShapeIterator(Operation::result_iterator it);
62 };
63 
64 using ResultShapeRange = iterator_range<ResultShapeIterator>;
65 
66 //===----------------------------------------------------------------------===//
67 // TensorFlow types
68 //===----------------------------------------------------------------------===//
69 
70 // The base class in the TensorFlow type hierarchy.
71 class TensorFlowType : public Type {
72  public:
73   using Type::Type;
74 
75   // Support method to enable LLVM-style type casting.
76   static bool classof(Type type);
77 };
78 
79 // Returns true if the specified type is a valid TensorFlow element type.
IsValidTFElementType(Type type)80 static inline bool IsValidTFElementType(Type type) {
81   return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
82 }
83 
84 // Returns true if this is a valid TensorFlow tensor type.
IsValidTFTensorType(Type type)85 static inline bool IsValidTFTensorType(Type type) {
86   // TensorFlow types should be tensors of one of the valid TensorFlow element
87   // types.
88   if (auto tensor_ty = type.dyn_cast<TensorType>())
89     return IsValidTFElementType(tensor_ty.getElementType());
90   return false;
91 }
92 
93 namespace detail {
94 // Common implementation of TensorFlow types. The template argument indicates
95 // the concrete derived class per CRTP.
96 template <typename Derived>
97 class TensorFlowTypeImpl
98     : public Type::TypeBase<Derived, TensorFlowType, TypeStorage> {
99  public:
100   using Base = typename Type::TypeBase<Derived, TensorFlowType, TypeStorage>;
101   using TFBase = TensorFlowTypeImpl<Derived>;
102   using Base::Base;
103 };
104 }  // namespace detail
105 
106 // TensorFlowRefType class supports all the ref types in TensorFlow dialect.
107 class TensorFlowRefType : public TensorFlowType {
108  public:
109   using TensorFlowType::TensorFlowType;
110 
111   // Checks if a type is TensorFlow Ref type.
112   static bool classof(Type type);
113 
114   // Converts a type to the corresponding TensorFlowRef type.
115   static TensorFlowType get(Type type);
getChecked(Type type,MLIRContext * context,Location loc)116   static TensorFlowType getChecked(Type type, MLIRContext* context,
117                                    Location loc) {
118     if (failed(verifyConstructionInvariants(loc, type))) {
119       return TensorFlowRefType();
120     }
121     return get(type);
122   }
123 
verifyConstructionInvariants(Location loc,Type type)124   static LogicalResult verifyConstructionInvariants(Location loc, Type type) {
125     // type should be a valid TensorFlow type.
126     if (!IsValidTFTensorType(type)) {
127       return emitError(loc) << "invalid TensorFlow type: " << type;
128     }
129     return success();
130   }
131 
132   // Converts a TensorFlowRef type to the corresponding TensorFlow or standard
133   // type.
134   Type RemoveRef();
135 };
136 
137 // Returns the corresponding TensorFlow or standard type from TensorFlowRef
138 // type.
GetDefaultTypeOf(TensorFlowRefType type)139 static inline Type GetDefaultTypeOf(TensorFlowRefType type) {
140   return type.RemoveRef();
141 }
142 
143 // Returns the element type if `type` is a `ShapedType` and the type itself
144 // otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or
145 // standard type if necessary.
GetElementTypeOrSelfResolveRef(Type type)146 static inline Type GetElementTypeOrSelfResolveRef(Type type) {
147   Type element_type = mlir::getElementTypeOrSelf(type);
148   if (auto ref_type = element_type.dyn_cast<mlir::TF::TensorFlowRefType>()) {
149     element_type = ref_type.RemoveRef();
150   }
151   return element_type;
152 }
153 
154 #define HANDLE_TF_TYPE(tftype, enumerant, name)                          \
155   class tftype##Type : public detail::TensorFlowTypeImpl<tftype##Type> { \
156    public:                                                               \
157     using TFBase::TFBase;                                                \
158   };
159 
160 // Custom TensorFlow types are defined separately.
161 #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name)
162 
163 // NOLINTNEXTLINE
164 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
165 
166 namespace detail {
167 // Storage type contains inferred subtypes for TypeWithSubtype.
168 class TypeWithSubtypeStorage : public TypeStorage {
169  public:
170   using KeyTy = ArrayRef<TensorType>;
171 
172   // NOLINTNEXTLINE
construct(TypeStorageAllocator & allocator,const KeyTy & key)173   static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator,
174                                            const KeyTy& key) {
175     ArrayRef<TensorType> subtypes = allocator.copyInto(key);
176     return new (allocator.allocate<TypeWithSubtypeStorage>())
177         TypeWithSubtypeStorage(subtypes);
178   }
179 
TypeWithSubtypeStorage(const KeyTy & key)180   explicit TypeWithSubtypeStorage(const KeyTy& key) : subtypes_(key) {}
181 
182   bool operator==(const KeyTy& key) const { return key == subtypes_; }
183 
hashKey(const KeyTy & key)184   static llvm::hash_code hashKey(const KeyTy& key) {
185     return llvm::hash_combine_range(key.begin(), key.end());
186   }
187 
188   KeyTy subtypes_;
189 };
190 
191 // Common implementation of TensorFlow types with subtypes. These subtypes are
192 // opaque and their interpretation depends on the actual underlying type.
193 // The template argument indicates the concrete derived class per CRTP. Concrete
194 // classes must implement the following:
195 //   - `static std::string getTypeName()` that returns the name of the type for
196 //     verification logging.
197 template <typename Derived>
198 class TypeWithSubtypeImpl
199     : public Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage> {
200  public:
201   using Base = Type::TypeBase<Derived, TensorFlowType, TypeWithSubtypeStorage>;
202   using TFBase = TypeWithSubtypeImpl<Derived>;
203   using Base::Base;
204 
get(ArrayRef<TensorType> subtypes,MLIRContext * context)205   static Derived get(ArrayRef<TensorType> subtypes, MLIRContext* context) {
206     return Base::get(context, subtypes);
207   }
208 
getChecked(ArrayRef<TensorType> subtypes,MLIRContext * context,Location loc)209   static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
210                             Location loc) {
211     return Base::getChecked(loc, subtypes);
212   }
213 
get(MLIRContext * context)214   static Derived get(MLIRContext* context) { return get({}, context); }
215 
verifyConstructionInvariants(Location loc,ArrayRef<TensorType> subtypes)216   static LogicalResult verifyConstructionInvariants(
217       Location loc, ArrayRef<TensorType> subtypes) {
218     // Each of the subtypes should be a valid TensorFlow type.
219     for (TensorType subtype : subtypes) {
220       if (!IsValidTFTensorType(subtype)) {
221         return emitError(loc) << "invalid " << Derived::getTypeName()
222                               << " subtype: " << subtype;
223       }
224     }
225     return success();
226   }
227 
getSubtypes()228   ArrayRef<TensorType> getSubtypes() { return Base::getImpl()->subtypes_; }
229 };
230 }  // namespace detail
231 
232 // TensorFlowTypeWithSubtype class supports all the types with subtypes in
233 // TensorFlow dialect.
234 class TensorFlowTypeWithSubtype : public TensorFlowType {
235  public:
236   using TensorFlowType::TensorFlowType;
237 
238   // Checks if a type is TensorFlow type with subtypes.
239   static bool classof(Type type);
240 
241   // Converts a TypeWithSubtype type to the same type but without its subtypes.
242   Type RemoveSubtypes();
243 
244   // Returns the subtypes.
245   ArrayRef<TensorType> GetSubtypes();
246 };
247 
248 // Returns the corresponding TensorFlow type with subtypes but without its
249 // subtypes.
GetDefaultTypeOf(TensorFlowTypeWithSubtype type)250 static inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) {
251   return type.RemoveSubtypes();
252 }
253 
254 // TensorFlow resource type is used to support TensorFlow resource variables,
255 // which represent shared, persistent state manipulated by a TensorFlow program.
256 // ResourceType stores shape and datatype for subtypes unlike most other data
257 // types that don't have any associated information.
258 class ResourceType : public detail::TypeWithSubtypeImpl<ResourceType> {
259  public:
260   using TFBase::TFBase;
getTypeName()261   static std::string getTypeName() { return "ResourceType"; }
262 };
263 
264 // TensorFlow variant type is used to support arbitrary custom C++ data types.
265 // VariantType stores inferred shape and datatype for subtypes unlike most other
266 // data types that don't have any associated information. For example, variants
267 // encoding TensorList type stores the common shape and dtype of the list
268 // elements as the only subtype.
269 class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
270  public:
271   using TFBase::TFBase;
getTypeName()272   static std::string getTypeName() { return "VariantType"; }
273 };
274 
275 // Given two types `a` and `b`, returns a refined type which is cast compatible
276 // with both `a` and `b` and is equal to or more precise than both of them. It
277 // returns empty Type if the input types are not cast compatible.
278 // Provides option to ignore ref types on 'a'. This is useful for TF ops that
279 // might allow operands to either be same as result type or be a ref type
280 // corresponding to it.
281 mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
282                                  bool may_ignore_ref_type_a);
283 
284 // Returns whether two arrays of Type are broadcast compatible.
285 bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
286 
287 // Returns whether the two elemental types are compatible. Shapes are compatible
288 // if:
289 // - the types are statically equal
290 // - could be dynamically equal
291 //   - considering dynamic shapes equal unless contradictory info known;
292 //   - element types are equivalent, modulo subtypes possible be less exact
293 //     (e.g., a resource type without subtype is considered compatible with
294 //      resource type with known subtype).
295 // Provide option to ignore ref types on 'lhs'.
296 bool HasCompatibleElementTypes(Type lhs, Type rhs,
297                                bool may_ignore_ref_type_lhs = false);
298 
299 // Returns true if all TensorFlow types can be cast to one
300 // another. In other words, a single run-time value is legal for both the types.
301 // For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
302 // compatible.
303 bool AreCastCompatible(ArrayRef<Type> types);
304 
305 // Returns true if corresponding elements of lhs and rhs AreCastCompatible and
306 // lhs and rhs are the same length.
307 bool ArraysAreCastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
308 
309 // If `ty` is a tensor type and its element type has subtypes, then returns a
310 // new type of same shape but dropped subtypes for the element type.
311 // Otherwise, if `ty` has subtypes, then returns corresponding type with dropped
312 // subtypes.
313 // Otherwise, returns the original type `ty`.
314 Type DropSubTypes(Type ty);
315 
316 // If `ty` is a tensor type and has elements of a ref type, then returns a new
317 // type of same shape but corresponding non-ref type as element type.
318 // Otherwise, if `ty` is a ref type, then returns corresponding non-ref type.
319 // Otherwise, returns the original type `ty`.
320 Type DropRefType(Type ty);
321 
322 // Convenience call for executing both `DropRefType` and `DropSubTypes`.
323 Type DropRefAndSubTypes(Type ty);
324 
325 }  // end namespace TF
326 }  // end namespace mlir
327 
328 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_
329