1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
16
17 #include <algorithm>
18 #include <cctype>
19 #include <memory>
20
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
24 #include "mlir/Support/TypeID.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
26 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
27 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
28
29 namespace mlir {
30 namespace TFL {
31 namespace tac {
32 namespace {
33 struct RegisteredTargetHardware {
34 // TODO(b/177376459): Remove this constructor.
RegisteredTargetHardwaremlir::TFL::tac::__anona2b81f870111::RegisteredTargetHardware35 RegisteredTargetHardware(const std::string& name,
36 const std::string& description, mlir::TypeID type_id,
37 std::unique_ptr<TargetHardware> target_hardware)
38 : unique_name(GetCanonicalHardwareName(name)),
39 description(description),
40 type_id(type_id),
41 target_hardware(std::move(target_hardware)) {}
42
RegisteredTargetHardwaremlir::TFL::tac::__anona2b81f870111::RegisteredTargetHardware43 RegisteredTargetHardware(
44 const std::string& name, const std::string& description,
45 mlir::TypeID type_id,
46 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory)
47 : unique_name(GetCanonicalHardwareName(name)),
48 description(description),
49 target_hardware_factory(target_hardware_factory) {}
50
51 std::string unique_name;
52 std::string description;
53 mlir::TypeID type_id;
54 std::unique_ptr<TargetHardware> target_hardware;
55 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory;
56 };
57
58 struct RegisteredTargetHardwareOps {
RegisteredTargetHardwareOpsmlir::TFL::tac::__anona2b81f870111::RegisteredTargetHardwareOps59 explicit RegisteredTargetHardwareOps(mlir::TypeID hardware_type)
60 : hardware_typeid(hardware_type) {}
61 // Key is the Operation TypeID
62 llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>
63 target_hardware_ops;
64 // Key is the Operation TypeID
65 llvm::DenseMap<mlir::TypeID,
66 std::function<std::unique_ptr<TargetHardwareOperation>()>>
67 target_hardware_ops_factory;
68 mlir::TypeID hardware_typeid;
69 };
70
71 std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
GetRegisteredTargetHardwareOps()72 GetRegisteredTargetHardwareOps() {
73 static std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
74 hardwares_ops =
75 []() -> std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>* {
76 return new std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>();
77 }();
78 return hardwares_ops;
79 }
80
GetRegisteredHardwares()81 std::vector<RegisteredTargetHardware>* GetRegisteredHardwares() {
82 static std::vector<RegisteredTargetHardware>* hardwares =
83 []() -> std::vector<RegisteredTargetHardware>* {
84 return new std::vector<RegisteredTargetHardware>();
85 }();
86 return hardwares;
87 }
88
89 llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>*
getRegisteredOperationsForHardware(mlir::TypeID type_id)90 getRegisteredOperationsForHardware(mlir::TypeID type_id) {
91 auto* hardwares = GetRegisteredTargetHardwareOps();
92 for (auto& hardware : *hardwares) {
93 if (hardware->hardware_typeid == type_id) {
94 return &hardware->target_hardware_ops;
95 }
96 }
97 return nullptr;
98 }
99
100 // A deny list for op cost computation since those ops are not arithemtic.
IsNonArithmeticOp(mlir::Operation * op)101 inline bool IsNonArithmeticOp(mlir::Operation* op) {
102 if (llvm::isa<ReturnOp, FuncOp>(op)) return true;
103 if (op->hasTrait<OpTrait::ConstantLike>()) return true;
104 if (llvm::isa<QConstOp, SparseQConstOp>(op)) return true;
105 if (!IsTFLNonQuantDequantizeOp(op)) return true;
106 return false;
107 }
108
109 } // namespace
110
Init()111 bool TargetHardware::Init() {
112 auto* hardware_ops_factory = GetRegisteredTargetHardwareOps();
113 for (auto& hardware_ops : *hardware_ops_factory) {
114 if (hardware_ops->hardware_typeid != this->GetTypeId()) continue;
115 auto& op_factories = hardware_ops->target_hardware_ops_factory;
116 for (auto& op_factory : op_factories) {
117 hardware_ops_.emplace_back(op_factory.getSecond()());
118 }
119 break;
120 }
121 return true;
122 }
123
GetOpCost(mlir::Operation * op) const124 double TargetHardware::GetOpCost(mlir::Operation* op) const {
125 auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
126 if (registered_ops == nullptr) {
127 return kDefaultFixedValuedCost;
128 }
129 auto* abstract_op = op->getAbstractOperation();
130 auto hardware_op = registered_ops->find(abstract_op->typeID);
131 if (hardware_op == registered_ops->end()) return kDefaultFixedValuedCost;
132 return hardware_op->second->GetOpCost(op);
133 }
134
IsOpSupported(mlir::Operation * op) const135 bool TargetHardware::IsOpSupported(mlir::Operation* op) const {
136 auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
137 if (registered_ops == nullptr) {
138 return false;
139 }
140 auto* abstract_op = op->getAbstractOperation();
141 auto hardware_op = registered_ops->find(abstract_op->typeID);
142 if (hardware_op == registered_ops->end()) return false;
143 return hardware_op->second->IsOpSupported(op);
144 }
145
GetFuncCost(FuncOp * func) const146 double TargetHardware::GetFuncCost(FuncOp* func) const {
147 double total_cost = 0.0;
148 func->walk([&](Operation* op) {
149 if (IsNonArithmeticOp(op)) return;
150 // We will always defer to the hardware to decide the cost.
151 total_cost += GetOpCost(op);
152 });
153 return total_cost;
154 }
155
GetTargetHardware(const std::string & hardware_name)156 const TargetHardware* GetTargetHardware(const std::string& hardware_name) {
157 const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
158 // Just loop for now, we don't expect number of hardwares to be huge.
159 // Revisit to have map if number of elements increased.
160 auto* registered_hardwares = GetRegisteredHardwares();
161 for (const auto& hardware : *registered_hardwares) {
162 if (hardware.unique_name == canonical_name) {
163 return hardware.target_hardware.get();
164 }
165 }
166 return nullptr;
167 }
168
GetTargetHardwareFactory(const std::string & hardware_name)169 std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
170 const std::string& hardware_name) {
171 const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
172 // Just loop for now, we don't expect number of hardwares to be huge.
173 // Revisit to have map if number of elements increased.
174 auto* registered_hardwares = GetRegisteredHardwares();
175 for (const auto& hardware : *registered_hardwares) {
176 if (hardware.unique_name == canonical_name) {
177 return hardware.target_hardware_factory;
178 }
179 }
180 return nullptr;
181 }
182
183 namespace internal {
184
RegisterTargetHardware(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)185 void RegisterTargetHardware(
186 const std::string& unique_name, const std::string& description,
187 mlir::TypeID type_id,
188 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
189 auto* registered_hardwares = GetRegisteredHardwares();
190 for (const auto& hardware : *registered_hardwares) {
191 if (hardware.unique_name == unique_name) {
192 llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
193 << " already registered\n";
194 return;
195 }
196 }
197 registered_hardwares->push_back(RegisteredTargetHardware(
198 unique_name, description, type_id, target_hardware_factory()));
199 }
200
RegisterTargetHardwareFactory(const std::string & unique_name,const std::string & description,mlir::TypeID type_id,std::function<std::unique_ptr<TargetHardware> ()> target_hardware_factory)201 void RegisterTargetHardwareFactory(
202 const std::string& unique_name, const std::string& description,
203 mlir::TypeID type_id,
204 std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
205 auto* registered_hardwares = GetRegisteredHardwares();
206 for (auto& hardware : *registered_hardwares) {
207 if (hardware.unique_name == unique_name) {
208 llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
209 << " already registered\n";
210 hardware.target_hardware_factory = target_hardware_factory;
211 return;
212 }
213 }
214 registered_hardwares->push_back(RegisteredTargetHardware(
215 unique_name, description, type_id, target_hardware_factory));
216 }
217
RegisterTargetHardwareOp(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)218 void RegisterTargetHardwareOp(
219 mlir::TypeID hardware_type, mlir::TypeID op_type,
220 std::function<std::unique_ptr<TargetHardwareOperation>()>
221 target_hardware_op_factory) {
222 auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
223 for (auto& hardware : *registered_hardware_ops) {
224 if (hardware->hardware_typeid == hardware_type) {
225 if (hardware->target_hardware_ops.count(op_type)) {
226 llvm::errs() << "Trying to register duplicate Op";
227 return;
228 }
229 hardware->target_hardware_ops[op_type] = target_hardware_op_factory();
230 return;
231 }
232 }
233 registered_hardware_ops->push_back(
234 std::make_unique<RegisteredTargetHardwareOps>(
235 RegisteredTargetHardwareOps(hardware_type)));
236 registered_hardware_ops->back()->target_hardware_ops[op_type] =
237 target_hardware_op_factory();
238 }
239
RegisterTargetHardwareOpFactory(mlir::TypeID hardware_type,mlir::TypeID op_type,std::function<std::unique_ptr<TargetHardwareOperation> ()> target_hardware_op_factory)240 void RegisterTargetHardwareOpFactory(
241 mlir::TypeID hardware_type, mlir::TypeID op_type,
242 std::function<std::unique_ptr<TargetHardwareOperation>()>
243 target_hardware_op_factory) {
244 auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
245 for (auto& hardware : *registered_hardware_ops) {
246 if (hardware->hardware_typeid == hardware_type) {
247 if (hardware->target_hardware_ops_factory.count(op_type)) {
248 llvm::errs() << "Trying to register duplicate Op";
249 return;
250 }
251 hardware->target_hardware_ops_factory[op_type] =
252 target_hardware_op_factory;
253 return;
254 }
255 }
256 registered_hardware_ops->push_back(
257 std::make_unique<RegisteredTargetHardwareOps>(
258 RegisteredTargetHardwareOps(hardware_type)));
259 registered_hardware_ops->back()->target_hardware_ops_factory[op_type] =
260 target_hardware_op_factory;
261 }
262
263 } // namespace internal
264
ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,std::vector<std::string> * device_specs)265 bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
266 std::vector<std::string>* device_specs) {
267 bool cpu_include = false;
268 for (auto& device_spec : specified_device_specs) {
269 auto device = GetCanonicalHardwareName(device_spec);
270
271 if (device == "CPU") cpu_include = true;
272 device_specs->push_back(device);
273 }
274 if (!cpu_include) {
275 device_specs->push_back("CPU");
276 }
277
278 // Make sure all the devices are registered.
279 for (const std::string& device : *device_specs) {
280 if (GetTargetHardware(device) == nullptr) return false;
281 }
282
283 return true;
284 }
285
GetHardwareName(const TargetHardware * hardware)286 std::string GetHardwareName(const TargetHardware* hardware) {
287 const auto* registered_hardwares = GetRegisteredHardwares();
288 for (const auto& registered_hardware : *registered_hardwares) {
289 if (registered_hardware.type_id == hardware->GetTypeId())
290 return registered_hardware.unique_name;
291 }
292 return "";
293 }
294
295 } // namespace tac
296 } // namespace TFL
297 } // namespace mlir
298