• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
16 
17 #include <algorithm>
18 #include <cctype>
19 #include <memory>
20 
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/Support/TypeID.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
26 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
27 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
28 
29 namespace mlir {
30 namespace TFL {
31 namespace tac {
32 namespace {
33 struct RegisteredTargetHardware {
34   // TODO(b/177376459): Remove this constructor.
RegisteredTargetHardwaremlir::TFL::tac::__anona2b81f870111::RegisteredTargetHardware35   RegisteredTargetHardware(const std::string& name,
36                            const std::string& description, mlir::TypeID type_id,
37                            std::unique_ptr<TargetHardware> target_hardware)
38       : unique_name(GetCanonicalHardwareName(name)),
39         description(description),
40         type_id(type_id),
41         target_hardware(std::move(target_hardware)) {}
42 
RegisteredTargetHardwaremlir::TFL::tac::__anona2b81f870111::RegisteredTargetHardware43   RegisteredTargetHardware(
44       const std::string& name, const std::string& description,
45       mlir::TypeID type_id,
46       std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory)
47       : unique_name(GetCanonicalHardwareName(name)),
48         description(description),
49         target_hardware_factory(target_hardware_factory) {}
50 
51   std::string unique_name;
52   std::string description;
53   mlir::TypeID type_id;
54   std::unique_ptr<TargetHardware> target_hardware;
55   std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory;
56 };
57 
58 struct RegisteredTargetHardwareOps {
RegisteredTargetHardwareOpsmlir::TFL::tac::__anona2b81f870111::RegisteredTargetHardwareOps59   explicit RegisteredTargetHardwareOps(mlir::TypeID hardware_type)
60       : hardware_typeid(hardware_type) {}
61   // Key is the Operation TypeID
62   llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>
63       target_hardware_ops;
64   // Key is the Operation TypeID
65   llvm::DenseMap<mlir::TypeID,
66                  std::function<std::unique_ptr<TargetHardwareOperation>()>>
67       target_hardware_ops_factory;
68   mlir::TypeID hardware_typeid;
69 };
70 
71 std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
GetRegisteredTargetHardwareOps()72 GetRegisteredTargetHardwareOps() {
73   static std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
74       hardwares_ops =
75           []() -> std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>* {
76     return new std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>();
77   }();
78   return hardwares_ops;
79 }
80 
GetRegisteredHardwares()81 std::vector<RegisteredTargetHardware>* GetRegisteredHardwares() {
82   static std::vector<RegisteredTargetHardware>* hardwares =
83       []() -> std::vector<RegisteredTargetHardware>* {
84     return new std::vector<RegisteredTargetHardware>();
85   }();
86   return hardwares;
87 }
88 
89 llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>*
getRegisteredOperationsForHardware(mlir::TypeID type_id)90 getRegisteredOperationsForHardware(mlir::TypeID type_id) {
91   auto* hardwares = GetRegisteredTargetHardwareOps();
92   for (auto& hardware : *hardwares) {
93     if (hardware->hardware_typeid == type_id) {
94       return &hardware->target_hardware_ops;
95     }
96   }
97   return nullptr;
98 }
99 
100 // A deny list for op cost computation since those ops are not arithemtic.
IsNonArithmeticOp(mlir::Operation * op)101 inline bool IsNonArithmeticOp(mlir::Operation* op) {
102   if (llvm::isa<ReturnOp, FuncOp>(op)) return true;
103   if (op->hasTrait<OpTrait::ConstantLike>()) return true;
104   if (llvm::isa<QConstOp, SparseQConstOp>(op)) return true;
105   if (!IsTFLNonQuantDequantizeOp(op)) return true;
106   return false;
107 }
108 
109 }  // namespace
110 
Init()111 bool TargetHardware::Init() {
112   auto* hardware_ops_factory = GetRegisteredTargetHardwareOps();
113   for (auto& hardware_ops : *hardware_ops_factory) {
114     if (hardware_ops->hardware_typeid != this->GetTypeId()) continue;
115     auto& op_factories = hardware_ops->target_hardware_ops_factory;
116     for (auto& op_factory : op_factories) {
117       hardware_ops_.emplace_back(op_factory.getSecond()());
118     }
119     break;
120   }
121   return true;
122 }
123 
GetOpCost(mlir::Operation * op) const124 double TargetHardware::GetOpCost(mlir::Operation* op) const {
125   auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
126   if (registered_ops == nullptr) {
127     return kDefaultFixedValuedCost;
128   }
129   auto* abstract_op = op->getAbstractOperation();
130   auto hardware_op = registered_ops->find(abstract_op->typeID);
131   if (hardware_op == registered_ops->end()) return kDefaultFixedValuedCost;
132   return hardware_op->second->GetOpCost(op);
133 }
134 
IsOpSupported(mlir::Operation * op) const135 bool TargetHardware::IsOpSupported(mlir::Operation* op) const {
136   auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
137   if (registered_ops == nullptr) {
138     return false;
139   }
140   auto* abstract_op = op->getAbstractOperation();
141   auto hardware_op = registered_ops->find(abstract_op->typeID);
142   if (hardware_op == registered_ops->end()) return false;
143   return hardware_op->second->IsOpSupported(op);
144 }
145 
GetFuncCost(FuncOp * func) const146 double TargetHardware::GetFuncCost(FuncOp* func) const {
147   double total_cost = 0.0;
148   func->walk([&](Operation* op) {
149     if (IsNonArithmeticOp(op)) return;
150     // We will always defer to the hardware to decide the cost.
151     total_cost += GetOpCost(op);
152   });
153   return total_cost;
154 }
155 
GetTargetHardware(const std::string & hardware_name)156 const TargetHardware* GetTargetHardware(const std::string& hardware_name) {
157   const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
158   // Just loop for now, we don't expect number of hardwares to be huge.
159   // Revisit to have map if number of elements increased.
160   auto* registered_hardwares = GetRegisteredHardwares();
161   for (const auto& hardware : *registered_hardwares) {
162     if (hardware.unique_name == canonical_name) {
163       return hardware.target_hardware.get();
164     }
165   }
166   return nullptr;
167 }
168 
GetTargetHardwareFactory(const std::string & hardware_name)169 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
170     const std::string& hardware_name) {
171   const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
172   // Just loop for now, we don't expect number of hardwares to be huge.
173   // Revisit to have map if number of elements increased.
174   auto* registered_hardwares = GetRegisteredHardwares();
175   for (const auto& hardware : *registered_hardwares) {
176     if (hardware.unique_name == canonical_name) {
177       return hardware.target_hardware_factory;
178     }
179   }
180   return nullptr;
181 }
182 
183 namespace internal {
184 
RegisterTargetHardware(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)185 void RegisterTargetHardware(
186     const std::string& unique_name, const std::string& description,
187     mlir::TypeID type_id,
188     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
189   auto* registered_hardwares = GetRegisteredHardwares();
190   for (const auto& hardware : *registered_hardwares) {
191     if (hardware.unique_name == unique_name) {
192       llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
193                    << " already registered\n";
194       return;
195     }
196   }
197   registered_hardwares->push_back(RegisteredTargetHardware(
198       unique_name, description, type_id, target_hardware_factory()));
199 }
200 
RegisterTargetHardwareFactory(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)201 void RegisterTargetHardwareFactory(
202     const std::string& unique_name, const std::string& description,
203     mlir::TypeID type_id,
204     std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
205   auto* registered_hardwares = GetRegisteredHardwares();
206   for (auto& hardware : *registered_hardwares) {
207     if (hardware.unique_name == unique_name) {
208       llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
209                    << " already registered\n";
210       hardware.target_hardware_factory = target_hardware_factory;
211       return;
212     }
213   }
214   registered_hardwares->push_back(RegisteredTargetHardware(
215       unique_name, description, type_id, target_hardware_factory));
216 }
217 
RegisterTargetHardwareOp(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)218 void RegisterTargetHardwareOp(
219     mlir::TypeID hardware_type, mlir::TypeID op_type,
220     std::function<std::unique_ptr<TargetHardwareOperation>()>
221         target_hardware_op_factory) {
222   auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
223   for (auto& hardware : *registered_hardware_ops) {
224     if (hardware->hardware_typeid == hardware_type) {
225       if (hardware->target_hardware_ops.count(op_type)) {
226         llvm::errs() << "Trying to register duplicate Op";
227         return;
228       }
229       hardware->target_hardware_ops[op_type] = target_hardware_op_factory();
230       return;
231     }
232   }
233   registered_hardware_ops->push_back(
234       std::make_unique<RegisteredTargetHardwareOps>(
235           RegisteredTargetHardwareOps(hardware_type)));
236   registered_hardware_ops->back()->target_hardware_ops[op_type] =
237       target_hardware_op_factory();
238 }
239 
RegisterTargetHardwareOpFactory(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)240 void RegisterTargetHardwareOpFactory(
241     mlir::TypeID hardware_type, mlir::TypeID op_type,
242     std::function<std::unique_ptr<TargetHardwareOperation>()>
243         target_hardware_op_factory) {
244   auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
245   for (auto& hardware : *registered_hardware_ops) {
246     if (hardware->hardware_typeid == hardware_type) {
247       if (hardware->target_hardware_ops_factory.count(op_type)) {
248         llvm::errs() << "Trying to register duplicate Op";
249         return;
250       }
251       hardware->target_hardware_ops_factory[op_type] =
252           target_hardware_op_factory;
253       return;
254     }
255   }
256   registered_hardware_ops->push_back(
257       std::make_unique<RegisteredTargetHardwareOps>(
258           RegisteredTargetHardwareOps(hardware_type)));
259   registered_hardware_ops->back()->target_hardware_ops_factory[op_type] =
260       target_hardware_op_factory;
261 }
262 
263 }  // namespace internal
264 
ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,std::vector<std::string> * device_specs)265 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
266                           std::vector<std::string>* device_specs) {
267   bool cpu_include = false;
268   for (auto& device_spec : specified_device_specs) {
269     auto device = GetCanonicalHardwareName(device_spec);
270 
271     if (device == "CPU") cpu_include = true;
272     device_specs->push_back(device);
273   }
274   if (!cpu_include) {
275     device_specs->push_back("CPU");
276   }
277 
278   // Make sure all the devices are registered.
279   for (const std::string& device : *device_specs) {
280     if (GetTargetHardware(device) == nullptr) return false;
281   }
282 
283   return true;
284 }
285 
GetHardwareName(const TargetHardware * hardware)286 std::string GetHardwareName(const TargetHardware* hardware) {
287   const auto* registered_hardwares = GetRegisteredHardwares();
288   for (const auto& registered_hardware : *registered_hardwares) {
289     if (registered_hardware.type_id == hardware->GetTypeId())
290       return registered_hardware.unique_name;
291   }
292   return "";
293 }
294 
295 }  // namespace tac
296 }  // namespace TFL
297 }  // namespace mlir
298