1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/model.h"
17
18 #include <stdint.h>
19
20 #include <algorithm>
21 #include <iterator>
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "tensorflow/lite/delegates/gpu/common/shape.h"
31 #include "tensorflow/lite/delegates/gpu/common/status.h"
32 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
33
34 namespace tflite {
35 namespace gpu {
36
nodes() const37 std::vector<Node*> GraphFloat32::nodes() const {
38 return FilterNodes([](const NodeDef&) { return true; });
39 }
40
values() const41 std::vector<Value*> GraphFloat32::values() const {
42 return FilterValues([](const ValueDef&) { return true; });
43 }
44
inputs() const45 std::vector<Value*> GraphFloat32::inputs() const {
46 return FilterValues([](const ValueDef& v) { return v.producer == nullptr; });
47 }
48
variable_inputs() const49 std::vector<Value*> GraphFloat32::variable_inputs() const {
50 return FilterValues(
51 [](const ValueDef& v) { return v.value->tensor.is_variable_input; });
52 }
53
outputs() const54 std::vector<Value*> GraphFloat32::outputs() const {
55 return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
56 }
57
FindInputs(NodeId id) const58 std::vector<Value*> GraphFloat32::FindInputs(NodeId id) const {
59 if (id >= nodes_.size()) {
60 return {};
61 }
62 return nodes_.at(id).inputs;
63 }
64
FindOutputs(NodeId id) const65 std::vector<Value*> GraphFloat32::FindOutputs(NodeId id) const {
66 if (id >= nodes_.size()) {
67 return {};
68 }
69 return nodes_.at(id).outputs;
70 }
71
IsGraphInput(ValueId id) const72 bool GraphFloat32::IsGraphInput(ValueId id) const {
73 if (id >= values_.size()) {
74 return false;
75 }
76 return values_[id].producer == nullptr;
77 }
78
IsGraphOutput(ValueId id) const79 bool GraphFloat32::IsGraphOutput(ValueId id) const {
80 if (id >= values_.size()) {
81 return false;
82 }
83 return values_[id].consumers.empty();
84 }
85
FindProducer(ValueId id) const86 Node* GraphFloat32::FindProducer(ValueId id) const {
87 if (id >= values_.size()) {
88 return nullptr;
89 }
90 return values_[id].producer;
91 }
92
FindConsumers(ValueId id) const93 std::vector<Node*> GraphFloat32::FindConsumers(ValueId id) const {
94 if (id >= values_.size()) {
95 return {};
96 }
97 return values_[id].consumers;
98 }
99
GetNode(NodeId id) const100 Node* GraphFloat32::GetNode(NodeId id) const {
101 if (id >= nodes_.size()) {
102 return {};
103 }
104 return nodes_.at(id).node.get();
105 }
106
GetValue(ValueId id) const107 Value* GraphFloat32::GetValue(ValueId id) const {
108 if (id >= values_.size()) {
109 return nullptr;
110 }
111 return values_[id].value.get();
112 }
113
NewNode()114 Node* GraphFloat32::NewNode() {
115 const NodeId new_id = nodes_.size();
116 NodeDef def;
117 def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
118 Node* node = def.node.get();
119 nodes_[new_id] = std::move(def);
120 execution_plan_.push_back(new_id);
121 return node;
122 }
123
InsertNodeAfter(NodeId id,Node ** new_node)124 absl::Status GraphFloat32::InsertNodeAfter(NodeId id, Node** new_node) {
125 if (id >= nodes_.size()) {
126 return absl::OutOfRangeError("NodeId is out of range");
127 }
128 int idx = 0;
129 while (idx < execution_plan_.size()) {
130 if (execution_plan_[idx] == id) break;
131 ++idx;
132 }
133 if (idx == execution_plan_.size()) {
134 return absl::OutOfRangeError("NodeId not in execution plan");
135 }
136
137 const NodeId new_id = nodes_.size();
138 NodeDef def;
139 def.node = absl::make_unique<Node>(Node{static_cast<NodeId>(new_id), {}});
140 *new_node = def.node.get();
141 nodes_[new_id] = std::move(def);
142 execution_plan_.insert(execution_plan_.begin() + idx + 1, new_id);
143 return absl::OkStatus();
144 }
145
NewValue()146 Value* GraphFloat32::NewValue() {
147 ValueDef def;
148 def.value =
149 absl::make_unique<Value>(Value{static_cast<ValueId>(values_.size()), {}});
150 Value* value = def.value.get();
151 values_.push_back(std::move(def));
152 return value;
153 }
154
SetProducer(NodeId producer,ValueId value)155 absl::Status GraphFloat32::SetProducer(NodeId producer, ValueId value) {
156 ValueDef* v;
157 RETURN_IF_ERROR(LookupValue(value, &v));
158 Value* value_ptr = v->value.get();
159 NodeDef* n;
160 RETURN_IF_ERROR(LookupNode(producer, &n));
161 Node* node_ptr = n->node.get();
162
163 // check if this value has the same producer already
164 if (node_ptr == v->producer) {
165 return absl::AlreadyExistsError(absl::StrCat(
166 "Node ", producer, " is already a producer of the value ", value));
167 }
168
169 // Check if the node is a consumer of this value.
170 if (IsInput(producer, value)) {
171 return absl::InvalidArgumentError("Node is a consumer of the value");
172 }
173
174 if (v->producer != nullptr) {
175 // value is no longer produced by it's previous producer.
176 Erase(&nodes_[v->producer->id].outputs, value_ptr);
177 }
178 v->producer = node_ptr;
179 n->outputs.push_back(value_ptr);
180 return absl::OkStatus();
181 }
182
RemoveProducer(ValueId value)183 absl::Status GraphFloat32::RemoveProducer(ValueId value) {
184 ValueDef* v;
185 RETURN_IF_ERROR(LookupValue(value, &v));
186 Value* value_ptr = v->value.get();
187 if (v->producer == nullptr) {
188 return absl::InvalidArgumentError("Value does not have a producer");
189 }
190 Erase(&nodes_[v->producer->id].outputs, value_ptr);
191 v->producer = nullptr;
192 return absl::OkStatus();
193 }
194
AddConsumer(NodeId consumer,ValueId value)195 absl::Status GraphFloat32::AddConsumer(NodeId consumer, ValueId value) {
196 ValueDef* v;
197 RETURN_IF_ERROR(LookupValue(value, &v));
198 Value* value_ptr = v->value.get();
199 NodeDef* n;
200 RETURN_IF_ERROR(LookupNode(consumer, &n));
201 Node* node_ptr = n->node.get();
202
203 // check if this value has the same producer already
204 if (node_ptr == v->producer) {
205 return absl::InvalidArgumentError("Node is a producer of the value");
206 }
207
208 // check if this value has the same consumer already
209 if (IsInput(consumer, value)) {
210 return absl::AlreadyExistsError(absl::StrCat(
211 "Node ", consumer, " is already a consumer of the value ", value));
212 }
213
214 n->inputs.push_back(value_ptr);
215 v->consumers.push_back(node_ptr);
216 return absl::OkStatus();
217 }
218
219 // Replace input value for given node.
ReplaceInput(NodeId node,ValueId old_value,ValueId new_value)220 absl::Status GraphFloat32::ReplaceInput(NodeId node, ValueId old_value,
221 ValueId new_value) {
222 ValueDef* v_old;
223 RETURN_IF_ERROR(LookupValue(old_value, &v_old));
224 Value* value_old_ptr = v_old->value.get();
225 ValueDef* v_new;
226 RETURN_IF_ERROR(LookupValue(new_value, &v_new));
227 Value* value_new_ptr = v_new->value.get();
228 NodeDef* n;
229 RETURN_IF_ERROR(LookupNode(node, &n));
230 Node* node_ptr = n->node.get();
231
232 // Check if the node is a consumer of old_value.
233 if (!IsInput(node, old_value)) {
234 return absl::InvalidArgumentError("old_value must be input of node.");
235 }
236
237 // Check if the node is not a consumer of new_value.
238 if (IsInput(node, new_value)) {
239 return absl::InvalidArgumentError("new_value can not be input of node.");
240 }
241
242 // Check if this value has the same producer already
243 if (node_ptr == v_new->producer) {
244 return absl::InvalidArgumentError("new_value can not be output of node.");
245 }
246
247 for (int i = 0; i < n->inputs.size(); ++i) {
248 if (n->inputs[i] == value_old_ptr) {
249 n->inputs[i] = value_new_ptr;
250 break;
251 }
252 }
253 v_new->consumers.push_back(node_ptr);
254 Erase(&v_old->consumers, node_ptr);
255 return absl::OkStatus();
256 }
257
RemoveConsumer(NodeId consumer,ValueId value)258 absl::Status GraphFloat32::RemoveConsumer(NodeId consumer, ValueId value) {
259 ValueDef* v;
260 RETURN_IF_ERROR(LookupValue(value, &v));
261 Value* value_ptr = v->value.get();
262 NodeDef* n;
263 RETURN_IF_ERROR(LookupNode(consumer, &n));
264 Node* node_ptr = n->node.get();
265 if (!IsInput(consumer, value)) {
266 return absl::InvalidArgumentError("Node is not a consumer of the value");
267 }
268 Erase(&n->inputs, value_ptr);
269 Erase(&v->consumers, node_ptr);
270 return absl::OkStatus();
271 }
272
DeleteNode(NodeId id)273 absl::Status GraphFloat32::DeleteNode(NodeId id) {
274 NodeDef* n;
275 RETURN_IF_ERROR(LookupNode(id, &n));
276 Node* node_ptr = n->node.get();
277 for (auto value : n->inputs) {
278 Erase(&values_[value->id].consumers, node_ptr);
279 }
280 for (auto value : n->outputs) {
281 values_[value->id].producer = nullptr;
282 }
283 n->inputs.clear();
284 n->outputs.clear();
285 n->node.reset();
286 return absl::OkStatus();
287 }
288
DeleteValue(ValueId id)289 absl::Status GraphFloat32::DeleteValue(ValueId id) {
290 ValueDef* v;
291 RETURN_IF_ERROR(LookupValue(id, &v));
292 Value* value_ptr = v->value.get();
293 if (v->producer != nullptr) {
294 Erase(&nodes_[v->producer->id].outputs, value_ptr);
295 }
296 if (!v->consumers.empty()) {
297 for (auto node : v->consumers) {
298 Erase(&nodes_[node->id].inputs, value_ptr);
299 }
300 }
301 v->producer = nullptr;
302 v->consumers.clear();
303 v->value.reset();
304 return absl::OkStatus();
305 }
306
MakeExactCopy(GraphFloat32 * model) const307 absl::Status GraphFloat32::MakeExactCopy(GraphFloat32* model) const {
308 model->nodes_.clear();
309 model->execution_plan_.clear();
310 model->values_.clear();
311 for (auto& value_def : values_) {
312 model->values_.push_back({});
313 if (value_def.value) {
314 model->values_.back().value = absl::make_unique<Value>(*value_def.value);
315 }
316 }
317 // Add all nodes first.
318 for (auto node_id : execution_plan_) {
319 model->execution_plan_.push_back(node_id);
320 model->nodes_[node_id] = {};
321 auto& node_def = nodes_.at(node_id);
322 if (node_def.node) {
323 model->nodes_[node_id].node = absl::make_unique<Node>(*node_def.node);
324 }
325 }
326 // Wire up dependencies between nodes.
327 for (auto node_id : execution_plan_) {
328 auto& node_def = nodes_.at(node_id);
329 if (node_def.node) {
330 for (auto output : node_def.outputs) {
331 RETURN_IF_ERROR(model->SetProducer(node_def.node->id, output->id));
332 }
333 for (auto input : node_def.inputs) {
334 RETURN_IF_ERROR(model->AddConsumer(node_def.node->id, input->id));
335 }
336 }
337 }
338 return absl::OkStatus();
339 }
340
IsInput(NodeId node,ValueId value)341 bool GraphFloat32::IsInput(NodeId node, ValueId value) {
342 if (node >= nodes_.size() || value >= values_.size()) {
343 return false;
344 }
345 const NodeDef& n = nodes_[node];
346 const ValueDef& v = values_[value];
347 if (!n.node || !v.value) {
348 return false;
349 }
350 return std::find(n.inputs.begin(), n.inputs.end(), v.value.get()) !=
351 n.inputs.end();
352 }
353
LookupNode(NodeId id,NodeDef ** node_def)354 absl::Status GraphFloat32::LookupNode(NodeId id, NodeDef** node_def) {
355 if (id >= nodes_.size()) {
356 return absl::OutOfRangeError("NodeId is out of range");
357 }
358 auto& n = nodes_[id];
359 if (!n.node) {
360 return absl::OutOfRangeError("Node is already deleted");
361 }
362 *node_def = &n;
363 return absl::OkStatus();
364 }
365
LookupValue(ValueId id,ValueDef ** value_def)366 absl::Status GraphFloat32::LookupValue(ValueId id, ValueDef** value_def) {
367 if (id >= values_.size()) {
368 return absl::OutOfRangeError("ValueId is out of range");
369 }
370 auto& v = values_[id];
371 if (!v.value) {
372 return absl::OutOfRangeError("Value is already deleted");
373 }
374 *value_def = &v;
375 return absl::OkStatus();
376 }
377
RemovePrecedingNode(GraphFloat32 * graph,const Node * to_remove,const Node * to_keep)378 absl::Status RemovePrecedingNode(GraphFloat32* graph, const Node* to_remove,
379 const Node* to_keep) {
380 // Make sure all outputs from to_remove are consumed by to_keep.
381 for (auto output : graph->FindOutputs(to_remove->id)) {
382 auto consumers = graph->FindConsumers(output->id);
383 if (consumers.size() > 1 ||
384 (consumers.size() == 1 && consumers[0] != to_keep)) {
385 return absl::InvalidArgumentError(
386 "Output from to_remove node has other consumers");
387 }
388 }
389
390 // Update all references
391 for (auto input : graph->FindInputs(to_remove->id)) {
392 RETURN_IF_ERROR(graph->AddConsumer(to_keep->id, input->id));
393 }
394 for (auto output : graph->FindOutputs(to_remove->id)) {
395 RETURN_IF_ERROR(graph->DeleteValue(output->id));
396 }
397 return graph->DeleteNode(to_remove->id);
398 }
399
RemoveFollowingNode(GraphFloat32 * graph,const Node * to_remove,const Node * to_keep)400 absl::Status RemoveFollowingNode(GraphFloat32* graph, const Node* to_remove,
401 const Node* to_keep) {
402 // Make sure all inputs to to_remove are produced by to_keep.
403 for (auto input : graph->FindInputs(to_remove->id)) {
404 Node* producer = graph->FindProducer(input->id);
405 if (producer->id != to_keep->id) {
406 return absl::InvalidArgumentError("To_remove node has other inputs");
407 }
408 }
409
410 for (auto input : graph->FindInputs(to_remove->id)) {
411 RETURN_IF_ERROR(graph->DeleteValue(input->id));
412 }
413 for (auto output : graph->FindOutputs(to_remove->id)) {
414 RETURN_IF_ERROR(graph->SetProducer(to_keep->id, output->id));
415 }
416 return graph->DeleteNode(to_remove->id);
417 }
418
RemoveSimpleNodeKeepInput(GraphFloat32 * graph,const Node * simple_node)419 absl::Status RemoveSimpleNodeKeepInput(GraphFloat32* graph,
420 const Node* simple_node) {
421 const auto inputs = graph->FindInputs(simple_node->id);
422 const auto outputs = graph->FindOutputs(simple_node->id);
423 if (inputs.size() != 1 || outputs.size() != 1) {
424 return absl::FailedPreconditionError(
425 "simple_node node must have 1 input and 1 output");
426 }
427 const auto input_id = inputs[0]->id;
428 const auto output_id = outputs[0]->id;
429 const Node* producer = graph->FindProducer(input_id);
430 const auto consumers = graph->FindConsumers(output_id);
431 RETURN_IF_ERROR(graph->DeleteNode(simple_node->id));
432 for (auto& consumer : consumers) {
433 RETURN_IF_ERROR(graph->ReplaceInput(consumer->id, output_id, input_id));
434 }
435 RETURN_IF_ERROR(graph->DeleteValue(output_id));
436 if (!producer && consumers.empty()) {
437 RETURN_IF_ERROR(graph->DeleteValue(input_id));
438 }
439 return absl::OkStatus();
440 }
441
RemoveSimpleNodeKeepOutput(GraphFloat32 * graph,const Node * simple_node)442 absl::Status RemoveSimpleNodeKeepOutput(GraphFloat32* graph,
443 const Node* simple_node) {
444 const auto inputs = graph->FindInputs(simple_node->id);
445 const auto outputs = graph->FindOutputs(simple_node->id);
446 if (inputs.size() != 1 || outputs.size() != 1) {
447 return absl::FailedPreconditionError(
448 "simple_node must have 1 input and 1 output");
449 }
450 const auto input_id = inputs[0]->id;
451 const auto output_id = outputs[0]->id;
452 const Node* producer = graph->FindProducer(input_id);
453 const auto input_consumers = graph->FindConsumers(input_id);
454 if (input_consumers.size() != 1) {
455 return absl::FailedPreconditionError(
456 "simple_node should be the only consumer on the node.");
457 }
458
459 RETURN_IF_ERROR(graph->DeleteNode(simple_node->id));
460 if (producer) {
461 RETURN_IF_ERROR(graph->RemoveProducer(input_id));
462 RETURN_IF_ERROR(graph->SetProducer(producer->id, output_id));
463 }
464
465 RETURN_IF_ERROR(graph->DeleteValue(input_id));
466
467 const auto output_consumers = graph->FindConsumers(output_id);
468 if (!producer && output_consumers.empty()) {
469 RETURN_IF_ERROR(graph->DeleteValue(output_id));
470 }
471 return absl::OkStatus();
472 }
473
AddOutput(GraphFloat32 * graph,const Node * from_node,Value ** output)474 absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
475 Value** output) {
476 auto link = graph->NewValue();
477 RETURN_IF_ERROR(graph->SetProducer(from_node->id, link->id));
478 *output = link;
479 return absl::OkStatus();
480 }
481
ConnectTwoNodes(GraphFloat32 * graph,const Node * from_node,const Node * to_node,Value ** output)482 absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
483 const Node* to_node, Value** output) {
484 const Node* output_producer =
485 *output ? graph->FindProducer((*output)->id) : nullptr;
486 // Output is already initialized, but producer is not from_node.
487 if (*output && output_producer && output_producer->id != from_node->id) {
488 return absl::InvalidArgumentError("Wrong output is passed.");
489 }
490 // Output is already initialized, and producer is from_node.
491 if (*output) {
492 RETURN_IF_ERROR(graph->AddConsumer(to_node->id, (*output)->id));
493 } else {
494 // Output is not initialized.
495 Value* link;
496 RETURN_IF_ERROR(AddOutput(graph, from_node, &link));
497 RETURN_IF_ERROR(graph->AddConsumer(to_node->id, link->id));
498 *output = link;
499 }
500 return absl::OkStatus();
501 }
502
IsBatchMatchesForAllValues(const GraphFloat32 & model)503 bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
504 if (model.values().empty()) return true;
505 const int32_t b = model.values()[0]->tensor.shape.b;
506 for (auto value : model.values()) {
507 if (value->tensor.shape.b != b) {
508 return false;
509 }
510 }
511 return true;
512 }
513
514 } // namespace gpu
515 } // namespace tflite
516