• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
16 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
17 
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/TypeID.h"  // from @llvm-project
24 
25 namespace mlir {
26 namespace TFL {
27 namespace tac {
28 
// Fixed cost assigned to ops by default.
static constexpr float kDefaultFixedValuedCost = 1e6f;

// Cost charged per byte moved between two different hardwares.
// Placeholder value ("fake data"), not derived from measurements.
static constexpr float kCrossHardwareTransferPerByteCost = 5.0f;

// Fixed cost charged for any transfer between two different hardwares.
// Placeholder value ("fake data"), not derived from measurements.
static constexpr float kCrossHardwareTransferFixedCost = 10.0f;
37 
38 // Interface for an Operation capabilities which should be tied to
39 // a specific hardware.
40 // Users should implement the interface and use TargetHardwareOpRegistration
41 // for registering the operation.
42 class TargetHardwareOperation {
43  public:
~TargetHardwareOperation()44   virtual ~TargetHardwareOperation() {}
45 
46   virtual double GetOpCost(mlir::Operation* op) const = 0;
47 
48   virtual bool IsOpSupported(mlir::Operation* op) const = 0;
49 };
50 
51 // Abstract base class for a hardware.
52 // To introduce new hardware
53 // users should implement the interface and use TargetHardwareRegistration
54 // for registering the hardware.
55 // Subclasses must implement the pure virtual function interface and
56 // define static member variable that retrieves string identifying the Target
57 // Hardware. Example,
58 // class MyType : public TargetHardware {
59 //  public:
60 //   static constexpr char kId[] = "MyHardware";
61 // };
62 class TargetHardware {
63  public:
~TargetHardware()64   virtual ~TargetHardware() {}
65 
66   // Initializes all TargetHardwareOperation registered for this hardware.
67   // Users overriding this function, should call the base class method to
68   // initialize the ops.
69   virtual bool Init();
70 
71   // Returns the cost of running 'op' on this Hardware.
72   virtual double GetOpCost(mlir::Operation* op) const;
73 
74   // Returns the cost of running the whole function on this hardware.
75   // By default this is the sum of the cost of individual cost for each op.
76   virtual double GetFuncCost(FuncOp* func) const;
77 
78   // Returns true if 'op' can run on this Hardware.
79   virtual bool IsOpSupported(mlir::Operation* op) const;
80 
81   // Switching cost between from hardware and this hardware.
82   // If both the hardwares are the same, the transfer cost is basically 0.
83   virtual double GetHardwareSwitchingCost(const TargetHardware* from,
84                                           size_t buffer_size) const = 0;
85 
86   // Returns a list of all patterns to apply for this hardware.
87   virtual mlir::OwningRewritePatternList GetTransformations(
88       MLIRContext* context) const = 0;
89 
90   // Returns TypeId for the provided hardware.
91   // Usually should be something like mlir::TypeID::get<MyType>()
92   virtual mlir::TypeID GetTypeId() const = 0;
93 
94  protected:
95   // All registered hardware ops.
96   std::vector<std::unique_ptr<TargetHardwareOperation>> hardware_ops_;
97 };
98 
99 // Returns pointer to the Hardware identified by 'hardware_name'.
100 // If not found nullptr is returned.
101 // DEPRECATED: Do not use, prefer GetTargetHardwareFactory instead.
102 const TargetHardware* GetTargetHardware(const std::string& hardware_name);
103 
104 // Returns the factory method for the requested hardware if present.
105 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
106     const std::string& hardware_name);
107 
108 namespace internal {
109 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
110 void RegisterTargetHardware(
111     const std::string& unique_name, const std::string& description,
112     mlir::TypeID type_id,
113     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
114 
115 // DEPRECATED: Do not use, prefer using RegisterTargetHardwareFactory instead.
116 template <typename T>
RegisterTargetHardware(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)117 void RegisterTargetHardware(
118     const std::string& description,
119     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
120   RegisterTargetHardware(T::kId, description, mlir::TypeID::get<T>(),
121                          target_hardware_factory);
122 }
123 
124 void RegisterTargetHardwareFactory(
125     const std::string& unique_name, const std::string& description,
126     mlir::TypeID type_id,
127     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory);
128 
129 // Registers the provided target hardware factory.
130 template <typename T>
RegisterTargetHardwareFactory(const std::string & description,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)131 void RegisterTargetHardwareFactory(
132     const std::string& description,
133     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
134   RegisterTargetHardwareFactory(T::kId, description, mlir::TypeID::get<T>(),
135                                 target_hardware_factory);
136 }
137 
138 // DEPRECATED: Do not use, prefer RegisterTargetHardwareOpFactory intstead.
139 void RegisterTargetHardwareOp(
140     mlir::TypeID hardware_type, mlir::TypeID op_type,
141     std::function<std::unique_ptr<TargetHardwareOperation>()>
142         target_hardware_op_factory);
143 
144 void RegisterTargetHardwareOpFactory(
145     mlir::TypeID hardware_type, mlir::TypeID op_type,
146     std::function<std::unique_ptr<TargetHardwareOperation>()>
147         target_hardware_op_factory);
148 }  // namespace internal
149 
150 // Register target hardware.
151 template <typename Hardware>
152 struct TargetHardwareRegistration {
TargetHardwareRegistrationTargetHardwareRegistration153   TargetHardwareRegistration(const std::string& description,
154                              std::function<std::unique_ptr<TargetHardware>()>
155                                  target_hardware_factory) {
156     // TODO(b/177376459): remove this.
157     internal::RegisterTargetHardware<Hardware>(description,
158                                                target_hardware_factory);
159     internal::RegisterTargetHardwareFactory<Hardware>(description,
160                                                       target_hardware_factory);
161   }
162 };
163 
164 // Register Op capabilities for specific hardware.
165 template <typename Hardware, typename Op>
166 struct TargetHardwareOpRegistration {
TargetHardwareOpRegistrationTargetHardwareOpRegistration167   explicit TargetHardwareOpRegistration(
168       std::function<std::unique_ptr<TargetHardwareOperation>()>
169           target_hardware_op_factory) {
170     // TODO(b/177376459): remove this.
171     internal::RegisterTargetHardwareOp(mlir::TypeID::get<Hardware>(),
172                                        mlir::TypeID::get<Op>(),
173                                        target_hardware_op_factory);
174     internal::RegisterTargetHardwareOpFactory(mlir::TypeID::get<Hardware>(),
175                                               mlir::TypeID::get<Op>(),
176                                               target_hardware_op_factory);
177   }
178 };
179 
180 //======== util functions ==========
181 
182 // Process user specified device specs, will always add CPU if it's not there.
183 // specified_deivce_specs: ',' separated, like "GPU,DSP,CPU".
184 // device_specs: processed device specs enum.
185 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
186                           std::vector<std::string>* device_specs);
187 
188 // Check whether two hardwares are the same.
IsSameHardware(const TargetHardware * lhs,const TargetHardware * rhs)189 inline bool IsSameHardware(const TargetHardware* lhs,
190                            const TargetHardware* rhs) {
191   return lhs->GetTypeId() == rhs->GetTypeId();
192 }
193 
194 // Returns the ID identifying 'hardware'. This should match the ID defined
195 // in the hardware field ID.
196 // For example, if MyHardware is passed the value returned should match
197 // MyHardware::kId.
198 std::string GetHardwareName(const TargetHardware* hardware);
199 
200 }  // namespace tac
201 }  // namespace TFL
202 }  // namespace mlir
203 
204 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TARGET_HARDWARE_H_
205