• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
18 
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
21 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
22 #include "mlir/IR/Location.h"  // from @llvm-project
23 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/TypeSupport.h"  // from @llvm-project
26 #include "mlir/IR/Types.h"  // from @llvm-project
27 
28 namespace mlir {
29 namespace TFR {
30 
31 class TFRType : public Type {
32  public:
33   using Type::Type;
34 
35   static bool classof(Type type);
36 };
37 
38 namespace detail {
39 
40 struct TFRTypeStorage final
41     : public TypeStorage,
42       public llvm::TrailingObjects<TFRTypeStorage, StringAttr> {
43   using KeyTy = ArrayRef<StringAttr>;
44 
TFRTypeStoragefinal45   explicit TFRTypeStorage(unsigned num_attrs) : num_attrs(num_attrs) {}
46 
constructfinal47   static TFRTypeStorage* construct(TypeStorageAllocator& allocator, KeyTy key) {
48     // Allocate a new storage instance.
49     auto byteSize = TFRTypeStorage::totalSizeToAlloc<StringAttr>(key.size());
50     auto rawMem = allocator.allocate(byteSize, alignof(TFRTypeStorage));
51     auto result = ::new (rawMem) TFRTypeStorage(key.size());
52 
53     // Copy in the string attributes into the trailing storage.
54     std::uninitialized_copy(key.begin(), key.end(),
55                             result->getTrailingObjects<StringAttr>());
56     return result;
57   }
58 
59   bool operator==(const KeyTy& attrs) const { return attrs == GetAttrs(); }
60 
GetAttrsfinal61   KeyTy GetAttrs() const {
62     return {getTrailingObjects<StringAttr>(), num_attrs};
63   }
64 
65   unsigned num_attrs;
66 };
67 
68 template <typename Derived>
69 class TFRTypeImpl : public Type::TypeBase<Derived, TFRType, TFRTypeStorage> {
70  public:
71   using Base = Type::TypeBase<Derived, TFRType, TFRTypeStorage>;
72   using TFRBase = TFRTypeImpl<Derived>;
73   using Base::Base;
74 
get(ArrayRef<StringAttr> attrs,MLIRContext * context)75   static Derived get(ArrayRef<StringAttr> attrs, MLIRContext* context) {
76     return Base::get(context, attrs);
77   }
78 
getChecked(ArrayRef<StringAttr> attrs,Location loc)79   static Derived getChecked(ArrayRef<StringAttr> attrs, Location loc) {
80     return Base::getChecked(loc, loc.getContext(), attrs);
81   }
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,ArrayRef<StringAttr> attrs)82   static Derived getChecked(function_ref<InFlightDiagnostic()> emitError,
83                             MLIRContext* context, ArrayRef<StringAttr> attrs) {
84     return Base::getChecked(emitError, context, attrs);
85   }
86 
get(MLIRContext * context)87   static Derived get(MLIRContext* context) { return get({}, context); }
88 
89   // TODO(fengliuai): fix the implementation
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<StringAttr> attrs)90   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
91                               ArrayRef<StringAttr> attrs) {
92     return success();
93   }
94 
getAttrKeys()95   ArrayRef<StringAttr> getAttrKeys() { return Base::getImpl()->GetAttrs(); }
96 };
97 }  // namespace detail
98 
99 class TFRTensorType : public detail::TFRTypeImpl<TFRTensorType> {
100  public:
101   using TFRBase::TFRBase;
getTypeName()102   static std::string getTypeName() { return "TFRTensorType"; }
103 };
104 
105 class TFRTensorListType : public detail::TFRTypeImpl<TFRTensorListType> {
106  public:
107   using TFRBase::TFRBase;
getTypeName()108   static std::string getTypeName() { return "TFRTensorListType"; }
109 };
110 
111 class TFRAttrType : public Type::TypeBase<TFRAttrType, TFRType, TypeStorage> {
112  public:
113   using Base::Base;
getTypeName()114   static std::string getTypeName() { return "TFRAttrType"; }
115 };
116 
117 }  // namespace TFR
118 }  // namespace mlir
119 
120 #endif  // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_
121