1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
16 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
17
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/TypeID.h"  // from @llvm-project
24
25 namespace mlir {
26 namespace TFL {
27 namespace tac {
28
// Default fixed cost value for ops (1e6).
// NOTE(review): presumably used when no hardware-specific estimate exists —
// confirm at the use sites.
// (A namespace-scope constexpr variable already has internal linkage, so no
// `static` is needed in this header.)
constexpr float kDefaultFixedValuedCost = 1e6f;

// Per-byte cost of moving data across two different hardwares.
// Placeholder value ("fake data"), not a measurement.
constexpr float kCrossHardwareTransferPerByteCost = 5.0f;

// Fixed cost of a transfer across two different hardwares.
// Placeholder value ("fake data"), not a measurement.
constexpr float kCrossHardwareTransferFixedCost = 10.0f;
37
38 // Interface for an Operation capabilities which should be tied to
39 // a specific hardware.
40 // Users should implement the interface and use TargetHardwareOpRegistration
41 // for registering the operation.
42 class TargetHardwareOperation {
43 public:
~TargetHardwareOperation()44 virtual ~TargetHardwareOperation() {}
45
46 virtual double GetOpCost(mlir::Operation* op) const = 0;
47
48 virtual bool IsOpSupported(mlir::Operation* op) const = 0;
49 };
50
51 // Abstract base class for a hardware.
52 // To introduce new hardware
53 // users should implement the interface and use TargetHardwareRegistration
54 // for registering the hardware.
55 // Subclasses must implement the pure virtual function interface and
56 // define static member variable that retrieves string identifying the Target
57 // Hardware. Example,
58 // class MyType : public TargetHardware {
59 // public:
60 // static constexpr char kId[] = "MyHardware";
61 // };
62 class TargetHardware {
63 public:
~TargetHardware()64 virtual ~TargetHardware() {}
65
66 // Initializes all TargetHardwareOperation registered for this hardware.
67 // Users overriding this function, should call the base class method to
68 // initialize the ops.
69 virtual bool Init();
70
71 // Returns the cost of running 'op' on this Hardware.
72 virtual double GetOpCost(mlir::Operation* op) const;
73
74 // Returns the cost of running the whole function on this hardware.
75 // By default this is the sum of the cost of individual cost for each op.
76 virtual double GetFuncCost(FuncOp* func) const;
77
78 // Returns true if 'op' can run on this Hardware.
79 virtual bool IsOpSupported(mlir::Operation* op) const;
80
81 // Switching cost between from hardware and this hardware.
82 // If both the hardwares are the same, the transfer cost is basically 0.
83 virtual double GetHardwareSwitchingCost(const TargetHardware* from,
84 size_t buffer_size) const = 0;
85
86 // Returns a list of all patterns to apply for this hardware.
87 virtual mlir::OwningRewritePatternList GetTransformations(
88 MLIRContext* context) const = 0;
89
90 // Returns TypeId for the provided hardware.
91 // Usually should be something like mlir::TypeID::get<MyType>()
92 virtual mlir::TypeID GetTypeId() const = 0;
93
94 protected:
95 // All registered hardware ops.
96 std::vector<std::unique_ptr<TargetHardwareOperation>> hardware_ops_;
97 };
98
99 // Returns pointer to the Hardware identified by 'hardware_name'.
100 // If not found nullptr is returned.
101 // DEPRECATED: Do not use, prefer GetTargetHardwareFactory instead.
102 const TargetHardware* GetTargetHardware(const std::string& hardware_name);
103
104 // Returns the factory method for the requested hardware if present.
105 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
106 const std::string& hardware_name);
107
108 namespace internal {
109 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
110 void RegisterTargetHardware(
111 const std::string& unique_name, const std::string& description,
112 mlir::TypeID type_id,
113 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
114
115 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
116 template <typename T>
RegisterTargetHardware(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)117 void RegisterTargetHardware(
118 const std::string& description,
119 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
120 RegisterTargetHardware(T::kId, description, mlir::TypeID::get<T>(),
121 target_hardware_factory);
122 }
123
124 void RegisterTargetHardwareFactory(
125 const std::string& unique_name, const std::string& description,
126 mlir::TypeID type_id,
127 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
128
129 // Registers the provided target hardware factory.
130 template <typename T>
RegisterTargetHardwareFactory(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)131 void RegisterTargetHardwareFactory(
132 const std::string& description,
133 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
134 RegisterTargetHardwareFactory(T::kId, description, mlir::TypeID::get<T>(),
135 target_hardware_factory);
136 }
137
138 // DEPRECATED: Do not use, prefer RegisterTargetHardwareOpFactory intstead.
139 void RegisterTargetHardwareOp(
140 mlir::TypeID hardware_type, mlir::TypeID op_type,
141 std::function<std::unique_ptr<TargetHardwareOperation>()>
142 target_hardware_op_factory);
143
144 void RegisterTargetHardwareOpFactory(
145 mlir::TypeID hardware_type, mlir::TypeID op_type,
146 std::function<std::unique_ptr<TargetHardwareOperation>()>
147 target_hardware_op_factory);
148 } // namespace internal
149
150 // Register target hardware.
151 template <typename Hardware>
152 struct TargetHardwareRegistration {
TargetHardwareRegistrationTargetHardwareRegistration153 TargetHardwareRegistration(const std::string& description,
154 std::function<std::unique_ptr<TargetHardware>()>
155 target_hardware_factory) {
156 // TODO(b/177376459): remove this.
157 internal::RegisterTargetHardware<Hardware>(description,
158 target_hardware_factory);
159 internal::RegisterTargetHardwareFactory<Hardware>(description,
160 target_hardware_factory);
161 }
162 };
163
164 // Register Op capabilities for specific hardware.
165 template <typename Hardware, typename Op>
166 struct TargetHardwareOpRegistration {
TargetHardwareOpRegistrationTargetHardwareOpRegistration167 explicit TargetHardwareOpRegistration(
168 std::function<std::unique_ptr<TargetHardwareOperation>()>
169 target_hardware_op_factory) {
170 // TODO(b/177376459): remove this.
171 internal::RegisterTargetHardwareOp(mlir::TypeID::get<Hardware>(),
172 mlir::TypeID::get<Op>(),
173 target_hardware_op_factory);
174 internal::RegisterTargetHardwareOpFactory(mlir::TypeID::get<Hardware>(),
175 mlir::TypeID::get<Op>(),
176 target_hardware_op_factory);
177 }
178 };
179
180 //======== util functions ==========
181
182 // Process user specified device specs, will always add CPU if it's not there.
183 // specified_deivce_specs: ',' separated, like "GPU,DSP,CPU".
184 // device_specs: processed device specs enum.
185 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
186 std::vector<std::string>* device_specs);
187
188 // Check whether two hardwares are the same.
IsSameHardware(const TargetHardware * lhs,const TargetHardware * rhs)189 inline bool IsSameHardware(const TargetHardware* lhs,
190 const TargetHardware* rhs) {
191 return lhs->GetTypeId() == rhs->GetTypeId();
192 }
193
194 // Returns the ID identifying 'hardware'. This should match the ID defined
195 // in the hardware field ID.
196 // For example, if MyHardware is passed the value returned should match
197 // MyHardware::kId.
198 std::string GetHardwareName(const TargetHardware* hardware);
199
200 } // namespace tac
201 } // namespace TFL
202 } // namespace mlir
203
204 #endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
205