1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/model.h"
17
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "absl/status/status.h"
21
22 namespace tflite {
23 namespace gpu {
24 namespace {
25
26 using ::testing::ElementsAre;
27 using ::testing::UnorderedElementsAre;
28
// One node consuming one value and producing another: both values sit on the
// graph boundary and every lookup reflects that single edge pair.
TEST(Model, SingleNode) {
  // in -> op -> out
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* in = graph.NewValue();
  Value* out = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(op->id, in->id).ok());
  ASSERT_TRUE(graph.SetProducer(op->id, out->id).ok());

  // Graph-level views.
  EXPECT_THAT(graph.nodes(), ElementsAre(op));
  EXPECT_THAT(graph.values(), UnorderedElementsAre(in, out));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(in));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(out));

  // Node-level views.
  EXPECT_THAT(graph.FindInputs(op->id), UnorderedElementsAre(in));
  EXPECT_THAT(graph.FindOutputs(op->id), UnorderedElementsAre(out));

  // Value-level views.
  EXPECT_THAT(graph.FindConsumers(in->id), UnorderedElementsAre(op));
  EXPECT_EQ(graph.FindProducer(out->id), op);
  EXPECT_THAT(graph.FindConsumers(out->id), ::testing::IsEmpty());
  EXPECT_THAT(graph.FindProducer(in->id), ::testing::IsNull());
}
49
// A node may produce several values; each of them must report that node as
// its producer.
TEST(Model, SingleNodeMultipleOutputs) {
  // in -> op -> (out_a, out_b)
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* in = graph.NewValue();
  Value* out_a = graph.NewValue();
  Value* out_b = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(op->id, in->id).ok());
  ASSERT_TRUE(graph.SetProducer(op->id, out_a->id).ok());
  ASSERT_TRUE(graph.SetProducer(op->id, out_b->id).ok());

  EXPECT_THAT(graph.FindOutputs(op->id), UnorderedElementsAre(out_a, out_b));
  for (const Value* out : {out_a, out_b}) {
    EXPECT_EQ(graph.FindProducer(out->id), op);
  }
}
65
// Registering the same node as a consumer of the same value twice is rejected.
TEST(Model, SetSameConsumer) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* in = graph.NewValue();

  const absl::Status first_attempt = graph.AddConsumer(op->id, in->id);
  ASSERT_TRUE(first_attempt.ok());

  const absl::Status second_attempt = graph.AddConsumer(op->id, in->id);
  EXPECT_FALSE(second_attempt.ok());
}
73
// Detaching one of two inputs from a node keeps the other one attached. The
// detached value, now consumed by nobody, shows up as a graph output.
TEST(Model, RemoveConsumer) {
  // (in_a, in_b) -> op
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* in_a = graph.NewValue();
  Value* in_b = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(op->id, in_a->id).ok());
  ASSERT_TRUE(graph.AddConsumer(op->id, in_b->id).ok());
  EXPECT_THAT(graph.FindConsumers(in_a->id), UnorderedElementsAre(op));
  EXPECT_THAT(graph.FindConsumers(in_b->id), UnorderedElementsAre(op));
  EXPECT_THAT(graph.FindInputs(op->id), UnorderedElementsAre(in_a, in_b));
  EXPECT_THAT(graph.outputs(), ::testing::IsEmpty());

  // Drop in_a from op.
  ASSERT_TRUE(graph.RemoveConsumer(op->id, in_a->id).ok());
  EXPECT_THAT(graph.FindConsumers(in_a->id), ::testing::IsEmpty());
  EXPECT_THAT(graph.FindInputs(op->id), UnorderedElementsAre(in_b));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(in_a));

  // Removing the same edge a second time must fail.
  ASSERT_FALSE(graph.RemoveConsumer(op->id, in_a->id).ok());
}
99
// Registering the same node as the producer of the same value twice is
// rejected.
TEST(Model, SetSameProducer) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* out = graph.NewValue();

  const absl::Status first_attempt = graph.SetProducer(op->id, out->id);
  ASSERT_TRUE(first_attempt.ok());

  const absl::Status second_attempt = graph.SetProducer(op->id, out->id);
  EXPECT_FALSE(second_attempt.ok());
}
107
// ReplaceInput swaps one input for another while keeping the replaced value's
// position in the node's ordered input list.
TEST(Model, ReplaceInput) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* a = graph.NewValue();
  Value* b = graph.NewValue();
  Value* c = graph.NewValue();
  Value* replacement = graph.NewValue();
  for (Value* input : {a, b, c}) {
    ASSERT_TRUE(graph.AddConsumer(op->id, input->id).ok());
  }
  EXPECT_THAT(graph.FindInputs(op->id), ElementsAre(a, b, c));

  ASSERT_TRUE(graph.ReplaceInput(op->id, b->id, replacement->id).ok());
  EXPECT_THAT(graph.FindInputs(op->id), ElementsAre(a, replacement, c));
}
122
// Removing a value's producer turns the value into a graph input. Removing it
// a second time must fail.
TEST(Model, RemoveProducer) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* out = graph.NewValue();

  ASSERT_TRUE(graph.SetProducer(op->id, out->id).ok());
  EXPECT_THAT(graph.inputs(), ::testing::IsEmpty());
  EXPECT_EQ(graph.FindProducer(out->id), op);

  ASSERT_TRUE(graph.RemoveProducer(out->id).ok());
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(out));
  EXPECT_THAT(graph.FindProducer(out->id), ::testing::IsNull());

  // The producer is already gone; a repeated removal is an error.
  ASSERT_FALSE(graph.RemoveProducer(out->id).ok());
}
139
140 class OneNodeModel : public testing::Test {
141 protected:
SetUp()142 void SetUp() override {
143 node_ = graph_.NewNode();
144 Value* graph_input = graph_.NewValue();
145 Value* graph_output = graph_.NewValue();
146 ASSERT_TRUE(graph_.AddConsumer(node_->id, graph_input->id).ok());
147 ASSERT_TRUE(graph_.SetProducer(node_->id, graph_output->id).ok());
148 EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input));
149 EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output));
150 EXPECT_THAT(graph_.nodes(), ElementsAre(node_));
151 }
152 GraphFloat32 graph_;
153 Node* node_;
154 };
155
// Removing the only node while keeping its input leaves no nodes and no graph
// inputs or outputs.
TEST_F(OneNodeModel, DeleteNodeKeepInput) {
  const absl::Status status = RemoveSimpleNodeKeepInput(&graph_, node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.inputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.outputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.nodes(), ::testing::IsEmpty());
}
162
// Removing the only node while keeping its output leaves no nodes and no graph
// inputs or outputs.
TEST_F(OneNodeModel, DeleteNodeKeepOutput) {
  const absl::Status status = RemoveSimpleNodeKeepOutput(&graph_, node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.inputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.outputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.nodes(), ::testing::IsEmpty());
}
169
170 class TwoNodesModel : public testing::Test {
171 protected:
SetUp()172 void SetUp() override {
173 graph_input_ = graph_.NewValue();
174 first_node_ = graph_.NewNode();
175 value_ = graph_.NewValue();
176 second_node_ = graph_.NewNode();
177 graph_output_ = graph_.NewValue();
178
179 ASSERT_TRUE(graph_.AddConsumer(first_node_->id, graph_input_->id).ok());
180 ASSERT_TRUE(graph_.SetProducer(first_node_->id, value_->id).ok());
181 ASSERT_TRUE(graph_.AddConsumer(second_node_->id, value_->id).ok());
182 ASSERT_TRUE(graph_.SetProducer(second_node_->id, graph_output_->id).ok());
183 EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
184 EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
185 EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, second_node_));
186 }
187 GraphFloat32 graph_;
188 Node* first_node_;
189 Node* second_node_;
190 Value* graph_input_;
191 Value* value_;
192 Value* graph_output_;
193 };
194
// Removing the head node while keeping its input: graph_input_ and
// graph_output_ stay the graph boundary and only second_node_ remains.
TEST_F(TwoNodesModel, DeleteFirstNodeKeepInput) {
  const absl::Status status = RemoveSimpleNodeKeepInput(&graph_, first_node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(second_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
201
// Removing the head node while keeping its output: value_ becomes the new
// graph input.
TEST_F(TwoNodesModel, DeleteFirstNodeKeepOutput) {
  const absl::Status status = RemoveSimpleNodeKeepOutput(&graph_, first_node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(second_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(value_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
208
// Removing the tail node while keeping its input: value_ becomes the new graph
// output.
TEST_F(TwoNodesModel, DeleteSecondNodeKeepInput) {
  const absl::Status status = RemoveSimpleNodeKeepInput(&graph_, second_node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(value_));
}
215
// Removing the tail node while keeping its output: graph_input_ and
// graph_output_ stay the graph boundary and only first_node_ remains.
TEST_F(TwoNodesModel, DeleteSecondNodeKeepOutput) {
  const absl::Status status =
      RemoveSimpleNodeKeepOutput(&graph_, second_node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
222
223 class ThreeNodesModel : public testing::Test {
224 protected:
SetUp()225 void SetUp() override {
226 first_node_ = graph_.NewNode();
227 second_node_ = graph_.NewNode();
228 third_node_ = graph_.NewNode();
229 graph_input_ = graph_.NewValue();
230 value0_ = graph_.NewValue();
231 value1_ = graph_.NewValue();
232 graph_output_ = graph_.NewValue();
233
234 ASSERT_TRUE(graph_.AddConsumer(first_node_->id, graph_input_->id).ok());
235 ASSERT_TRUE(graph_.SetProducer(first_node_->id, value0_->id).ok());
236 ASSERT_TRUE(graph_.AddConsumer(second_node_->id, value0_->id).ok());
237 ASSERT_TRUE(graph_.SetProducer(second_node_->id, value1_->id).ok());
238 ASSERT_TRUE(graph_.AddConsumer(third_node_->id, value1_->id).ok());
239 ASSERT_TRUE(graph_.SetProducer(third_node_->id, graph_output_->id).ok());
240 EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
241 EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
242 EXPECT_THAT(graph_.nodes(),
243 ElementsAre(first_node_, second_node_, third_node_));
244 }
245 GraphFloat32 graph_;
246 Node* first_node_;
247 Node* second_node_;
248 Node* third_node_;
249 Value* graph_input_;
250 Value* value0_;
251 Value* value1_;
252 Value* graph_output_;
253 };
254
// Dropping the middle node while keeping its input: value0_ survives and
// value1_ is no longer part of the graph.
TEST_F(ThreeNodesModel, DeleteMiddleNodeKeepInput) {
  const absl::Status status = RemoveSimpleNodeKeepInput(&graph_, second_node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, third_node_));
  EXPECT_THAT(graph_.values(),
              UnorderedElementsAre(graph_input_, value0_, graph_output_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
263
// Dropping the middle node while keeping its output: value1_ survives and
// value0_ is no longer part of the graph.
TEST_F(ThreeNodesModel, DeleteMiddleNodeKeepOutput) {
  const absl::Status status =
      RemoveSimpleNodeKeepOutput(&graph_, second_node_);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, third_node_));
  EXPECT_THAT(graph_.values(),
              UnorderedElementsAre(graph_input_, value1_, graph_output_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
272
TEST(Model, RemoveSimpleNodeKeepInputComplexCase) {
  // The diagrams below live in a block comment on purpose. A `//` line that
  // ends in a backslash is spliced with the following line by the
  // preprocessor, which triggers -Wcomment and can silently swallow code.
  /*
   * We have this graph and we are going to delete n1 while preserving the
   * order of inputs v0, v1 for n0 and v2, v3 for n2:
   *
   *    v0  v1
   *     \ /  \
   *      n0   n1
   *      |     \
   *      o1    v2  v3
   *              \ /
   *               n2
   *               |
   *               o2
   *
   * After the removal, v1 takes v2's slot as n2's first input:
   *
   *    v0  v1
   *     \ /  \
   *      n0   \
   *      |     \
   *      o1     \  v3
   *              \ /
   *               n2
   *               |
   *               o2
   */
  GraphFloat32 graph;
  Node* n0 = graph.NewNode();
  Node* n1 = graph.NewNode();  // node to remove
  Node* n2 = graph.NewNode();
  Value* v0 = graph.NewValue();
  Value* v1 = graph.NewValue();
  Value* v2 = graph.NewValue();  // value to be removed
  Value* v3 = graph.NewValue();
  Value* o1 = graph.NewValue();
  Value* o2 = graph.NewValue();

  ASSERT_TRUE(graph.AddConsumer(n0->id, v0->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n0->id, v1->id).ok());
  ASSERT_TRUE(graph.SetProducer(n0->id, o1->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n1->id, v1->id).ok());
  ASSERT_TRUE(graph.SetProducer(n1->id, v2->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n2->id, v2->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n2->id, v3->id).ok());
  ASSERT_TRUE(graph.SetProducer(n2->id, o2->id).ok());
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(v0, v1, v3));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2));
  EXPECT_THAT(graph.nodes(), ElementsAre(n0, n1, n2));

  // A node can be deleted with RemoveSimpleNodeKeepOutput only if it is the
  // sole consumer of its input value. Here v1 is also consumed by n0, so the
  // call must fail.
  ASSERT_FALSE(RemoveSimpleNodeKeepOutput(&graph, n1).ok());

  ASSERT_TRUE(RemoveSimpleNodeKeepInput(&graph, n1).ok());
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(v0, v1, v3));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2));
  EXPECT_THAT(graph.nodes(), ElementsAre(n0, n2));
  EXPECT_THAT(graph.values(), UnorderedElementsAre(v0, v1, v3, o1, o2));
  // Input order is preserved on both surviving consumers of v1.
  EXPECT_THAT(graph.FindInputs(n0->id), ElementsAre(v0, v1));
  EXPECT_THAT(graph.FindInputs(n2->id), ElementsAre(v1, v3));
}
331
// A node may not both consume and produce the same value, whichever of the two
// edges is added first.
TEST(Model, CircularDependency) {
  // Consumer edge first; the producer edge is then rejected.
  {
    GraphFloat32 graph;
    Node* op = graph.NewNode();
    Value* loop_value = graph.NewValue();
    ASSERT_TRUE(graph.AddConsumer(op->id, loop_value->id).ok());
    EXPECT_FALSE(graph.SetProducer(op->id, loop_value->id).ok());
  }
  // Producer edge first; the consumer edge is then rejected.
  {
    GraphFloat32 graph;
    Node* op = graph.NewNode();
    Value* loop_value = graph.NewValue();
    ASSERT_TRUE(graph.SetProducer(op->id, loop_value->id).ok());
    EXPECT_FALSE(graph.AddConsumer(op->id, loop_value->id).ok());
  }
}
348
// Calling SetProducer on a value that already has a producer hands the value
// over to the new node; the old producer no longer lists it as an output.
TEST(Model, ReassignValue) {
  // Before:
  //   in -> first -> out
  //     \-> second
  GraphFloat32 graph;
  Node* first = graph.NewNode();
  Node* second = graph.NewNode();
  Value* in = graph.NewValue();
  Value* out = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(first->id, in->id).ok());
  ASSERT_TRUE(graph.SetProducer(first->id, out->id).ok());
  ASSERT_TRUE(graph.AddConsumer(second->id, in->id).ok());

  // After:
  //   in -> first
  //     \-> second -> out
  ASSERT_TRUE(graph.SetProducer(second->id, out->id).ok());

  EXPECT_THAT(graph.nodes(), ElementsAre(first, second));
  EXPECT_THAT(graph.FindInputs(first->id), UnorderedElementsAre(in));
  EXPECT_THAT(graph.FindInputs(second->id), UnorderedElementsAre(in));
  EXPECT_THAT(graph.FindOutputs(first->id), ::testing::IsEmpty());
  EXPECT_THAT(graph.FindOutputs(second->id), UnorderedElementsAre(out));
  EXPECT_THAT(graph.FindConsumers(in->id),
              UnorderedElementsAre(first, second));
  EXPECT_EQ(graph.FindProducer(out->id), second);
  EXPECT_THAT(graph.FindConsumers(out->id), ::testing::IsEmpty());
}
377
// Deleting a value detaches it from the nodes around it, and deleting a
// boundary value removes it from the graph inputs/outputs.
TEST(Model, DeleteValue) {
  // in -> head -> mid -> tail -> out
  GraphFloat32 graph;
  Node* head = graph.NewNode();
  Node* tail = graph.NewNode();
  Value* in = graph.NewValue();
  Value* out = graph.NewValue();
  Value* mid = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(head->id, in->id).ok());
  ASSERT_TRUE(graph.SetProducer(head->id, mid->id).ok());
  ASSERT_TRUE(graph.AddConsumer(tail->id, mid->id).ok());
  ASSERT_TRUE(graph.SetProducer(tail->id, out->id).ok());

  EXPECT_THAT(graph.values(), UnorderedElementsAre(in, out, mid));
  EXPECT_THAT(graph.FindConsumers(mid->id), UnorderedElementsAre(tail));
  EXPECT_EQ(graph.FindProducer(mid->id), head);
  EXPECT_THAT(graph.FindInputs(tail->id), UnorderedElementsAre(mid));
  EXPECT_THAT(graph.FindOutputs(head->id), UnorderedElementsAre(mid));

  // Delete the intermediate value and drop the stale pointer.
  ASSERT_TRUE(graph.DeleteValue(mid->id).ok());
  mid = nullptr;
  EXPECT_THAT(graph.values(), UnorderedElementsAre(in, out));
  EXPECT_THAT(graph.FindInputs(tail->id), ::testing::IsEmpty());
  EXPECT_THAT(graph.FindOutputs(head->id), ::testing::IsEmpty());

  // Delete the graph input.
  ASSERT_TRUE(graph.DeleteValue(in->id).ok());
  in = nullptr;
  EXPECT_THAT(graph.values(), UnorderedElementsAre(out));
  EXPECT_THAT(graph.inputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph.FindInputs(head->id), ::testing::IsEmpty());

  // Delete the graph output.
  ASSERT_TRUE(graph.DeleteValue(out->id).ok());
  out = nullptr;
  EXPECT_THAT(graph.values(), ::testing::IsEmpty());
  EXPECT_THAT(graph.outputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph.FindOutputs(tail->id), ::testing::IsEmpty());
}
416
// Deleting nodes one by one leaves their former inputs and outputs behind as
// dangling values that count as both graph inputs and graph outputs.
TEST(Model, DeleteNode) {
  // in -> head -> mid -> left -> out_left
  //                \-> right -> out_right
  GraphFloat32 graph;
  Node* head = graph.NewNode();
  Node* left = graph.NewNode();
  Node* right = graph.NewNode();
  Value* in = graph.NewValue();
  Value* out_left = graph.NewValue();
  Value* out_right = graph.NewValue();
  Value* mid = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(head->id, in->id).ok());
  ASSERT_TRUE(graph.SetProducer(head->id, mid->id).ok());
  ASSERT_TRUE(graph.AddConsumer(left->id, mid->id).ok());
  ASSERT_TRUE(graph.AddConsumer(right->id, mid->id).ok());
  ASSERT_TRUE(graph.SetProducer(left->id, out_left->id).ok());
  ASSERT_TRUE(graph.SetProducer(right->id, out_right->id).ok());

  EXPECT_THAT(graph.nodes(), ElementsAre(head, left, right));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(in));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(out_left, out_right));
  EXPECT_THAT(graph.FindConsumers(mid->id),
              UnorderedElementsAre(left, right));
  EXPECT_EQ(graph.FindProducer(mid->id), head);
  EXPECT_THAT(graph.FindInputs(left->id), UnorderedElementsAre(mid));
  EXPECT_THAT(graph.FindInputs(right->id), UnorderedElementsAre(mid));

  // Delete `right`:
  //   in -> head -> mid -> left -> out_left
  //   out_right
  ASSERT_TRUE(graph.DeleteNode(right->id).ok());
  right = nullptr;
  EXPECT_THAT(graph.nodes(), ElementsAre(head, left));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(in, out_right));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(out_left, out_right));
  EXPECT_THAT(graph.FindConsumers(mid->id), UnorderedElementsAre(left));

  // Delete `head`:
  //   mid -> left -> out_left
  //   in
  //   out_right
  ASSERT_TRUE(graph.DeleteNode(head->id).ok());
  head = nullptr;
  EXPECT_THAT(graph.nodes(), ElementsAre(left));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(mid, out_right, in));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(in, out_left, out_right));
  EXPECT_THAT(graph.FindConsumers(mid->id), UnorderedElementsAre(left));
  EXPECT_THAT(graph.FindProducer(mid->id), ::testing::IsNull());

  // Delete the last node: every value is now both an input and an output.
  ASSERT_TRUE(graph.DeleteNode(left->id).ok());
  left = nullptr;
  EXPECT_THAT(graph.nodes(), ::testing::IsEmpty());
  EXPECT_THAT(graph.inputs(),
              UnorderedElementsAre(out_left, out_right, in, mid));
  EXPECT_THAT(graph.outputs(),
              UnorderedElementsAre(out_left, out_right, in, mid));
  EXPECT_THAT(graph.FindConsumers(mid->id), ::testing::IsEmpty());
  EXPECT_THAT(graph.FindProducer(mid->id), ::testing::IsNull());
}
478
// InsertNodeAfter places a fresh node directly after the given node in
// execution order and rejects an id that names no existing node.
TEST(Model, InsertNodeAfter) {
  // graph_input -> node1 -> value -> node2 -> graph_output
  GraphFloat32 graph;
  Node* node1 = graph.NewNode();
  Node* node2 = graph.NewNode();
  Value* graph_input = graph.NewValue();
  Value* graph_output = graph.NewValue();
  Value* value = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
  ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
  ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
  ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok());

  EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output));
  EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2));
  EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1));
  EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value));

  // Out-params start as nullptr so a failed call never leaves them holding
  // an indeterminate pointer.
  Node* new_node1 = nullptr;
  absl::Status status = graph.InsertNodeAfter(node1->id, &new_node1);
  ASSERT_TRUE(status.ok());
  ASSERT_NE(new_node1, nullptr);
  EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2));

  // Only three nodes exist at this point, so id 100 is out of range.
  Node* new_node2 = nullptr;
  status = graph.InsertNodeAfter(/*id=*/100, &new_node2);
  EXPECT_EQ(status.code(), absl::StatusCode::kOutOfRange);

  status = graph.InsertNodeAfter(node2->id, &new_node2);
  ASSERT_TRUE(status.ok());
  ASSERT_NE(new_node2, nullptr);
  EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2, new_node2));
}
512
// A graph without any values has nothing that could disagree on batch size.
TEST(BatchMatchingTest, EmptyGraph) {
  GraphFloat32 empty_graph;
  const bool batches_match = IsBatchMatchesForAllValues(empty_graph);
  ASSERT_TRUE(batches_match);
}
517
// Two values that share batch size 1 are reported as matching.
TEST(BatchMatchingTest, AllMatch) {
  GraphFloat32 graph;
  Value* first = graph.NewValue();
  Value* second = graph.NewValue();
  first->tensor.shape = BHWC(1, 1, 1, 1);
  second->tensor.shape = BHWC(1, 1, 1, 1);

  const bool batches_match = IsBatchMatchesForAllValues(graph);
  ASSERT_TRUE(batches_match);
}
526
// Two values whose batch sizes differ (1 vs 2) are reported as mismatching.
TEST(BatchMatchingTest, NotAllMatch) {
  GraphFloat32 graph;
  Value* first = graph.NewValue();
  Value* second = graph.NewValue();
  first->tensor.shape = BHWC(1, 1, 1, 1);
  second->tensor.shape = BHWC(2, 1, 1, 1);

  const bool batches_match = IsBatchMatchesForAllValues(graph);
  ASSERT_FALSE(batches_match);
}
535
536 } // namespace
537 } // namespace gpu
538 } // namespace tflite
539