• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/model.h"
17 
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "absl/status/status.h"
21 
22 namespace tflite {
23 namespace gpu {
24 namespace {
25 
26 using ::testing::ElementsAre;
27 using ::testing::UnorderedElementsAre;
28 
// Wires one node between a single input value and a single output value and
// checks every lookup the graph exposes for that topology.
TEST(Model, SingleNode) {
  // Topology under test: source -> op -> sink
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* source = graph.NewValue();
  Value* sink = graph.NewValue();
  const auto op_id = op->id;
  const auto source_id = source->id;
  const auto sink_id = sink->id;
  ASSERT_TRUE(graph.AddConsumer(op_id, source_id).ok());
  ASSERT_TRUE(graph.SetProducer(op_id, sink_id).ok());

  // Whole-graph views.
  EXPECT_THAT(graph.nodes(), ElementsAre(op));
  EXPECT_THAT(graph.values(), UnorderedElementsAre(source, sink));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(source));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(sink));

  // Lookups keyed by the node.
  EXPECT_THAT(graph.FindInputs(op_id), UnorderedElementsAre(source));
  EXPECT_THAT(graph.FindOutputs(op_id), UnorderedElementsAre(sink));

  // Lookups keyed by the values: the source has a consumer but no producer,
  // the sink has a producer but no consumers.
  EXPECT_THAT(graph.FindConsumers(source_id), UnorderedElementsAre(op));
  EXPECT_THAT(graph.FindProducer(sink_id), ::testing::Eq(op));
  EXPECT_THAT(graph.FindConsumers(sink_id), UnorderedElementsAre());
  EXPECT_THAT(graph.FindProducer(source_id), ::testing::Eq(nullptr));
}
49 
// One node can be registered as the producer of several values at once.
TEST(Model, SingleNodeMultipleOutputs) {
  // Topology under test: source -> op -> (first_out, second_out)
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* source = graph.NewValue();
  Value* first_out = graph.NewValue();
  Value* second_out = graph.NewValue();
  const auto op_id = op->id;
  ASSERT_TRUE(graph.AddConsumer(op_id, source->id).ok());
  ASSERT_TRUE(graph.SetProducer(op_id, first_out->id).ok());
  ASSERT_TRUE(graph.SetProducer(op_id, second_out->id).ok());

  // Both values are listed as outputs of the node and both resolve back to it.
  EXPECT_THAT(graph.FindOutputs(op_id),
              UnorderedElementsAre(first_out, second_out));
  EXPECT_THAT(graph.FindProducer(first_out->id), ::testing::Eq(op));
  EXPECT_THAT(graph.FindProducer(second_out->id), ::testing::Eq(op));
}
65 
// Registering the same (node, value) consumer pair a second time is rejected.
TEST(Model, SetSameConsumer) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* source = graph.NewValue();
  const auto op_id = op->id;
  const auto source_id = source->id;
  // First registration succeeds ...
  ASSERT_TRUE(graph.AddConsumer(op_id, source_id).ok());
  // ... the duplicate must report a non-ok status.
  EXPECT_FALSE(graph.AddConsumer(op_id, source_id).ok());
}
73 
// Detaching a consumer updates both directions of the link, turns the
// now-unconsumed value into a graph output, and cannot be repeated.
TEST(Model, RemoveConsumer) {
  // Topology under test: (first_in, second_in) -> op
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* first_in = graph.NewValue();
  Value* second_in = graph.NewValue();
  const auto op_id = op->id;
  const auto first_id = first_in->id;
  const auto second_id = second_in->id;
  ASSERT_TRUE(graph.AddConsumer(op_id, first_id).ok());
  ASSERT_TRUE(graph.AddConsumer(op_id, second_id).ok());

  // Baseline: both values feed the node and nothing is a graph output.
  EXPECT_THAT(graph.FindConsumers(first_id), UnorderedElementsAre(op));
  EXPECT_THAT(graph.FindConsumers(second_id), UnorderedElementsAre(op));
  EXPECT_THAT(graph.FindInputs(op_id),
              UnorderedElementsAre(first_in, second_in));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre());

  // Detach first_in from the node.
  ASSERT_TRUE(graph.RemoveConsumer(op_id, first_id).ok());
  EXPECT_THAT(graph.FindConsumers(first_id), UnorderedElementsAre());
  EXPECT_THAT(graph.FindInputs(op_id), UnorderedElementsAre(second_in));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(first_in));

  // A second removal of the same link has to fail.
  ASSERT_FALSE(graph.RemoveConsumer(op_id, first_id).ok());
}
99 
// Setting the same node as producer of the same value twice is rejected.
TEST(Model, SetSameProducer) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* sink = graph.NewValue();
  const auto op_id = op->id;
  const auto sink_id = sink->id;
  // First assignment succeeds ...
  ASSERT_TRUE(graph.SetProducer(op_id, sink_id).ok());
  // ... repeating it must report a non-ok status.
  EXPECT_FALSE(graph.SetProducer(op_id, sink_id).ok());
}
107 
// ReplaceInput swaps one input of a node for another value and keeps the
// position of that input within the node's ordered input list.
TEST(Model, ReplaceInput) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* first = graph.NewValue();
  Value* second = graph.NewValue();
  Value* third = graph.NewValue();
  Value* replacement = graph.NewValue();
  const auto op_id = op->id;

  // Inputs are attached in the order first, second, third.
  ASSERT_TRUE(graph.AddConsumer(op_id, first->id).ok());
  ASSERT_TRUE(graph.AddConsumer(op_id, second->id).ok());
  ASSERT_TRUE(graph.AddConsumer(op_id, third->id).ok());
  EXPECT_THAT(graph.FindInputs(op_id), ElementsAre(first, second, third));

  // The middle input is exchanged; its neighbours stay where they were.
  ASSERT_TRUE(graph.ReplaceInput(op_id, second->id, replacement->id).ok());
  EXPECT_THAT(graph.FindInputs(op_id), ElementsAre(first, replacement, third));
}
122 
// Clearing a value's producer turns the value into a graph input; clearing it
// a second time is an error.
TEST(Model, RemoveProducer) {
  GraphFloat32 graph;
  Node* op = graph.NewNode();
  Value* sink = graph.NewValue();
  const auto sink_id = sink->id;

  // While the value is produced by the node it is not a graph input.
  ASSERT_TRUE(graph.SetProducer(op->id, sink_id).ok());
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
  EXPECT_THAT(graph.FindProducer(sink_id), ::testing::Eq(op));

  // Without a producer the value is reported among the graph inputs.
  ASSERT_TRUE(graph.RemoveProducer(sink_id).ok());
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(sink));
  EXPECT_THAT(graph.FindProducer(sink_id), ::testing::Eq(nullptr));

  // There is no producer left to remove.
  ASSERT_FALSE(graph.RemoveProducer(sink_id).ok());
}
139 
140 class OneNodeModel : public testing::Test {
141  protected:
SetUp()142   void SetUp() override {
143     node_ = graph_.NewNode();
144     Value* graph_input = graph_.NewValue();
145     Value* graph_output = graph_.NewValue();
146     ASSERT_TRUE(graph_.AddConsumer(node_->id, graph_input->id).ok());
147     ASSERT_TRUE(graph_.SetProducer(node_->id, graph_output->id).ok());
148     EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input));
149     EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output));
150     EXPECT_THAT(graph_.nodes(), ElementsAre(node_));
151   }
152   GraphFloat32 graph_;
153   Node* node_;
154 };
155 
// Removing the only node with RemoveSimpleNodeKeepInput leaves a graph that
// reports no inputs, no outputs and no nodes.
TEST_F(OneNodeModel, DeleteNodeKeepInput) {
  const auto removal = RemoveSimpleNodeKeepInput(&graph_, node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.inputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.outputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.nodes(), ::testing::IsEmpty());
}
162 
// Removing the only node with RemoveSimpleNodeKeepOutput leaves a graph that
// reports no inputs, no outputs and no nodes.
TEST_F(OneNodeModel, DeleteNodeKeepOutput) {
  const auto removal = RemoveSimpleNodeKeepOutput(&graph_, node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.inputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.outputs(), ::testing::IsEmpty());
  EXPECT_THAT(graph_.nodes(), ::testing::IsEmpty());
}
169 
170 class TwoNodesModel : public testing::Test {
171  protected:
SetUp()172   void SetUp() override {
173     graph_input_ = graph_.NewValue();
174     first_node_ = graph_.NewNode();
175     value_ = graph_.NewValue();
176     second_node_ = graph_.NewNode();
177     graph_output_ = graph_.NewValue();
178 
179     ASSERT_TRUE(graph_.AddConsumer(first_node_->id, graph_input_->id).ok());
180     ASSERT_TRUE(graph_.SetProducer(first_node_->id, value_->id).ok());
181     ASSERT_TRUE(graph_.AddConsumer(second_node_->id, value_->id).ok());
182     ASSERT_TRUE(graph_.SetProducer(second_node_->id, graph_output_->id).ok());
183     EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
184     EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
185     EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, second_node_));
186   }
187   GraphFloat32 graph_;
188   Node* first_node_;
189   Node* second_node_;
190   Value* graph_input_;
191   Value* value_;
192   Value* graph_output_;
193 };
194 
// Removing the first node while keeping its input: the original graph input
// and output are still reported and only the second node remains.
TEST_F(TwoNodesModel, DeleteFirstNodeKeepInput) {
  const auto removal = RemoveSimpleNodeKeepInput(&graph_, first_node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(second_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
201 
// Removing the first node while keeping its output: the intermediate value
// becomes the graph input and only the second node remains.
TEST_F(TwoNodesModel, DeleteFirstNodeKeepOutput) {
  const auto removal = RemoveSimpleNodeKeepOutput(&graph_, first_node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(second_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(value_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
208 
// Removing the second node while keeping its input: the intermediate value
// becomes the graph output and only the first node remains.
TEST_F(TwoNodesModel, DeleteSecondNodeKeepInput) {
  const auto removal = RemoveSimpleNodeKeepInput(&graph_, second_node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(value_));
}
215 
// Removing the second node while keeping its output: the original graph input
// and output are still reported and only the first node remains.
TEST_F(TwoNodesModel, DeleteSecondNodeKeepOutput) {
  const auto removal = RemoveSimpleNodeKeepOutput(&graph_, second_node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
222 
223 class ThreeNodesModel : public testing::Test {
224  protected:
SetUp()225   void SetUp() override {
226     first_node_ = graph_.NewNode();
227     second_node_ = graph_.NewNode();
228     third_node_ = graph_.NewNode();
229     graph_input_ = graph_.NewValue();
230     value0_ = graph_.NewValue();
231     value1_ = graph_.NewValue();
232     graph_output_ = graph_.NewValue();
233 
234     ASSERT_TRUE(graph_.AddConsumer(first_node_->id, graph_input_->id).ok());
235     ASSERT_TRUE(graph_.SetProducer(first_node_->id, value0_->id).ok());
236     ASSERT_TRUE(graph_.AddConsumer(second_node_->id, value0_->id).ok());
237     ASSERT_TRUE(graph_.SetProducer(second_node_->id, value1_->id).ok());
238     ASSERT_TRUE(graph_.AddConsumer(third_node_->id, value1_->id).ok());
239     ASSERT_TRUE(graph_.SetProducer(third_node_->id, graph_output_->id).ok());
240     EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
241     EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
242     EXPECT_THAT(graph_.nodes(),
243                 ElementsAre(first_node_, second_node_, third_node_));
244   }
245   GraphFloat32 graph_;
246   Node* first_node_;
247   Node* second_node_;
248   Node* third_node_;
249   Value* graph_input_;
250   Value* value0_;
251   Value* value1_;
252   Value* graph_output_;
253 };
254 
// Removing the middle node while keeping its input: value0_ survives and
// value1_ no longer appears among the graph's values.
TEST_F(ThreeNodesModel, DeleteMiddleNodeKeepInput) {
  const auto removal = RemoveSimpleNodeKeepInput(&graph_, second_node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, third_node_));
  EXPECT_THAT(graph_.values(),
              UnorderedElementsAre(graph_input_, value0_, graph_output_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
263 
// Removing the middle node while keeping its output: value1_ survives and
// value0_ no longer appears among the graph's values.
TEST_F(ThreeNodesModel, DeleteMiddleNodeKeepOutput) {
  const auto removal = RemoveSimpleNodeKeepOutput(&graph_, second_node_);
  ASSERT_TRUE(removal.ok());
  EXPECT_THAT(graph_.nodes(), ElementsAre(first_node_, third_node_));
  EXPECT_THAT(graph_.values(),
              UnorderedElementsAre(graph_input_, value1_, graph_output_));
  EXPECT_THAT(graph_.inputs(), UnorderedElementsAre(graph_input_));
  EXPECT_THAT(graph_.outputs(), UnorderedElementsAre(graph_output_));
}
272 
// Removes a node whose input value is shared with another node and checks
// that the input order of the remaining nodes is preserved.
//
// The topology is given as an edge list rather than an ASCII drawing so that
// no `//` comment line ends in a backslash (which would splice the following
// line into the comment and trigger -Wcomment).
TEST(Model, RemoveSimpleNodeKeepInputComplexCase) {
  // Before:
  //   n0: inputs (v0, v1) -> output o1
  //   n1: inputs (v1)     -> output v2
  //   n2: inputs (v2, v3) -> output o2
  //
  // After removing n1 while keeping its input v1 (v2 disappears):
  //   n0: inputs (v0, v1) -> output o1
  //   n2: inputs (v1, v3) -> output o2
  GraphFloat32 graph;
  Node* n0 = graph.NewNode();
  Node* n1 = graph.NewNode();  // node to remove
  Node* n2 = graph.NewNode();
  Value* v0 = graph.NewValue();
  Value* v1 = graph.NewValue();
  Value* v2 = graph.NewValue();  // value to be removed
  Value* v3 = graph.NewValue();
  Value* o1 = graph.NewValue();
  Value* o2 = graph.NewValue();

  ASSERT_TRUE(graph.AddConsumer(n0->id, v0->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n0->id, v1->id).ok());
  ASSERT_TRUE(graph.SetProducer(n0->id, o1->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n1->id, v1->id).ok());
  ASSERT_TRUE(graph.SetProducer(n1->id, v2->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n2->id, v2->id).ok());
  ASSERT_TRUE(graph.AddConsumer(n2->id, v3->id).ok());
  ASSERT_TRUE(graph.SetProducer(n2->id, o2->id).ok());
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(v0, v1, v3));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2));
  EXPECT_THAT(graph.nodes(), ElementsAre(n0, n1, n2));

  // The keep-output variant must refuse: n1 is not the only consumer of its
  // input value (v1 also feeds n0).
  ASSERT_FALSE(RemoveSimpleNodeKeepOutput(&graph, n1).ok());

  // The keep-input variant succeeds and rewires n2 from v2 to v1, keeping v1
  // in the slot v2 used to occupy.
  ASSERT_TRUE(RemoveSimpleNodeKeepInput(&graph, n1).ok());
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(v0, v1, v3));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(o1, o2));
  EXPECT_THAT(graph.nodes(), ElementsAre(n0, n2));
  EXPECT_THAT(graph.values(), UnorderedElementsAre(v0, v1, v3, o1, o2));
  EXPECT_THAT(graph.FindInputs(n0->id), ElementsAre(v0, v1));
  EXPECT_THAT(graph.FindInputs(n2->id), ElementsAre(v1, v3));
}
331 
// A node may not both consume and produce the same value, regardless of which
// of the two links is requested first.
TEST(Model, CircularDependency) {
  {
    // Consumer link first, then the producer link must be refused.
    GraphFloat32 graph;
    Node* op = graph.NewNode();
    Value* looped = graph.NewValue();
    const auto op_id = op->id;
    const auto looped_id = looped->id;
    ASSERT_TRUE(graph.AddConsumer(op_id, looped_id).ok());
    EXPECT_FALSE(graph.SetProducer(op_id, looped_id).ok());
  }
  {
    // Producer link first, then the consumer link must be refused.
    GraphFloat32 graph;
    Node* op = graph.NewNode();
    Value* looped = graph.NewValue();
    const auto op_id = op->id;
    const auto looped_id = looped->id;
    ASSERT_TRUE(graph.SetProducer(op_id, looped_id).ok());
    EXPECT_FALSE(graph.AddConsumer(op_id, looped_id).ok());
  }
}
348 
// Calling SetProducer for a value that already has a producer moves the value
// to the new producer and detaches it from the old one.
TEST(Model, ReassignValue) {
  // Before:
  //   source -> first  -> sink
  //   source -> second
  GraphFloat32 graph;
  Node* first = graph.NewNode();
  Node* second = graph.NewNode();
  Value* source = graph.NewValue();
  Value* sink = graph.NewValue();
  const auto first_id = first->id;
  const auto second_id = second->id;
  const auto source_id = source->id;
  const auto sink_id = sink->id;
  ASSERT_TRUE(graph.AddConsumer(first_id, source_id).ok());
  ASSERT_TRUE(graph.SetProducer(first_id, sink_id).ok());
  ASSERT_TRUE(graph.AddConsumer(second_id, source_id).ok());

  // After:
  //   source -> first
  //   source -> second -> sink
  ASSERT_TRUE(graph.SetProducer(second_id, sink_id).ok());

  EXPECT_THAT(graph.nodes(), ElementsAre(first, second));

  // Inputs of both nodes are unaffected; only the output moved.
  EXPECT_THAT(graph.FindInputs(first_id), UnorderedElementsAre(source));
  EXPECT_THAT(graph.FindInputs(second_id), UnorderedElementsAre(source));
  EXPECT_THAT(graph.FindOutputs(first_id), UnorderedElementsAre());
  EXPECT_THAT(graph.FindOutputs(second_id), UnorderedElementsAre(sink));

  EXPECT_THAT(graph.FindConsumers(source_id),
              UnorderedElementsAre(first, second));
  EXPECT_THAT(graph.FindProducer(sink_id), ::testing::Eq(second));
  EXPECT_THAT(graph.FindConsumers(sink_id), UnorderedElementsAre());
}
377 
// Deletes the values of a two-node chain one by one and checks that each
// deletion removes the value from the graph and from the adjacent nodes.
TEST(Model, DeleteValue) {
  // Topology under test: source -> head -> mid -> tail -> sink
  GraphFloat32 graph;
  Node* head = graph.NewNode();
  Node* tail = graph.NewNode();
  Value* source = graph.NewValue();
  Value* sink = graph.NewValue();
  Value* mid = graph.NewValue();
  const auto head_id = head->id;
  const auto tail_id = tail->id;
  const auto source_id = source->id;
  const auto sink_id = sink->id;
  const auto mid_id = mid->id;
  ASSERT_TRUE(graph.AddConsumer(head_id, source_id).ok());
  ASSERT_TRUE(graph.SetProducer(head_id, mid_id).ok());
  ASSERT_TRUE(graph.AddConsumer(tail_id, mid_id).ok());
  ASSERT_TRUE(graph.SetProducer(tail_id, sink_id).ok());

  // Baseline: mid links head to tail.
  EXPECT_THAT(graph.values(), UnorderedElementsAre(source, sink, mid));
  EXPECT_THAT(graph.FindConsumers(mid_id), UnorderedElementsAre(tail));
  EXPECT_THAT(graph.FindProducer(mid_id), ::testing::Eq(head));
  EXPECT_THAT(graph.FindInputs(tail_id), UnorderedElementsAre(mid));
  EXPECT_THAT(graph.FindOutputs(head_id), UnorderedElementsAre(mid));

  // Delete the intermediate value: both neighbours lose their link to it.
  ASSERT_TRUE(graph.DeleteValue(mid_id).ok());
  mid = nullptr;  // the pointer must not be used after deletion
  EXPECT_THAT(graph.values(), UnorderedElementsAre(source, sink));
  EXPECT_THAT(graph.FindInputs(tail_id), UnorderedElementsAre());
  EXPECT_THAT(graph.FindOutputs(head_id), UnorderedElementsAre());

  // Delete the graph input.
  ASSERT_TRUE(graph.DeleteValue(source_id).ok());
  source = nullptr;
  EXPECT_THAT(graph.values(), UnorderedElementsAre(sink));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre());
  EXPECT_THAT(graph.FindInputs(head_id), UnorderedElementsAre());

  // Delete the graph output; no values remain.
  ASSERT_TRUE(graph.DeleteValue(sink_id).ok());
  sink = nullptr;
  EXPECT_THAT(graph.values(), UnorderedElementsAre());
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre());
  EXPECT_THAT(graph.FindOutputs(tail_id), UnorderedElementsAre());
}
416 
TEST(Model,DeleteNode)417 TEST(Model, DeleteNode) {
418   // graph_input -> node1 -> value  -> node2 -> graph_output
419   //                               \-> node3 -> graph_output2
420   GraphFloat32 graph;
421   Node* node1 = graph.NewNode();
422   Node* node2 = graph.NewNode();
423   Node* node3 = graph.NewNode();
424   Value* graph_input = graph.NewValue();
425   Value* graph_output = graph.NewValue();
426   Value* graph_output2 = graph.NewValue();
427   Value* value = graph.NewValue();
428   ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
429   ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
430   ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
431   ASSERT_TRUE(graph.AddConsumer(node3->id, value->id).ok());
432   ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok());
433   ASSERT_TRUE(graph.SetProducer(node3->id, graph_output2->id).ok());
434 
435   EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2, node3));
436   EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input));
437   EXPECT_THAT(graph.outputs(),
438               UnorderedElementsAre(graph_output, graph_output2));
439   EXPECT_THAT(graph.FindConsumers(value->id),
440               UnorderedElementsAre(node2, node3));
441   EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1));
442   EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value));
443   EXPECT_THAT(graph.FindInputs(node3->id), UnorderedElementsAre(value));
444 
445   // graph_input  -> node1 -> value -> node2 -> graph_output
446   // graph_output2
447   ASSERT_TRUE(graph.DeleteNode(node3->id).ok());
448   node3 = nullptr;
449   EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2));
450   EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input, graph_output2));
451   EXPECT_THAT(graph.outputs(),
452               UnorderedElementsAre(graph_output, graph_output2));
453   EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2));
454 
455   // value -> node2 -> graph_output
456   // graph_input
457   // graph_output2
458   ASSERT_TRUE(graph.DeleteNode(node1->id).ok());
459   node1 = nullptr;
460   EXPECT_THAT(graph.nodes(), ElementsAre(node2));
461   EXPECT_THAT(graph.inputs(),
462               UnorderedElementsAre(value, graph_output2, graph_input));
463   EXPECT_THAT(graph.outputs(),
464               UnorderedElementsAre(graph_input, graph_output, graph_output2));
465   EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2));
466   EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(nullptr));
467 
468   ASSERT_TRUE(graph.DeleteNode(node2->id).ok());
469   node2 = nullptr;
470   EXPECT_THAT(graph.nodes(), ElementsAre());
471   EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_output, graph_output2,
472                                                    graph_input, value));
473   EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output, graph_output2,
474                                                     graph_input, value));
475   EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre());
476   EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(nullptr));
477 }
478 
// InsertNodeAfter places a fresh node directly behind the given node in the
// graph's node order and reports kOutOfRange for an unknown node id.
TEST(Model, InsertNodeAfter) {
  // graph_input -> node1 -> value -> node2 -> graph_output
  GraphFloat32 graph;
  Node* node1 = graph.NewNode();
  Node* node2 = graph.NewNode();
  Value* graph_input = graph.NewValue();
  Value* graph_output = graph.NewValue();
  Value* value = graph.NewValue();
  ASSERT_TRUE(graph.AddConsumer(node1->id, graph_input->id).ok());
  ASSERT_TRUE(graph.SetProducer(node1->id, value->id).ok());
  ASSERT_TRUE(graph.AddConsumer(node2->id, value->id).ok());
  ASSERT_TRUE(graph.SetProducer(node2->id, graph_output->id).ok());

  EXPECT_THAT(graph.nodes(), ElementsAre(node1, node2));
  EXPECT_THAT(graph.inputs(), UnorderedElementsAre(graph_input));
  EXPECT_THAT(graph.outputs(), UnorderedElementsAre(graph_output));
  EXPECT_THAT(graph.FindConsumers(value->id), UnorderedElementsAre(node2));
  EXPECT_THAT(graph.FindProducer(value->id), ::testing::Eq(node1));
  EXPECT_THAT(graph.FindInputs(node2->id), UnorderedElementsAre(value));

  // Out-pointers are initialized so they never carry an indeterminate value,
  // in particular across the call below that is expected to fail.
  Node* new_node1 = nullptr;
  absl::Status status = graph.InsertNodeAfter(node1->id, &new_node1);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2));

  // Id 100 does not belong to any node of this graph.
  Node* new_node2 = nullptr;
  status = graph.InsertNodeAfter(/*id=*/100, &new_node2);
  EXPECT_EQ(status.code(), absl::StatusCode::kOutOfRange);

  // Inserting after the last node appends to the node order.
  status = graph.InsertNodeAfter(node2->id, &new_node2);
  ASSERT_TRUE(status.ok());
  EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2, new_node2));
}
512 
// A graph without any values is reported as batch-consistent.
TEST(BatchMatchingTest, EmptyGraph) {
  GraphFloat32 empty_graph;
  const auto batches_agree = IsBatchMatchesForAllValues(empty_graph);
  ASSERT_TRUE(batches_agree);
}
517 
// Two values that share the same BHWC shape (batch 1) are batch-consistent.
TEST(BatchMatchingTest, AllMatch) {
  GraphFloat32 graph;
  Value* lhs = graph.NewValue();
  Value* rhs = graph.NewValue();
  const BHWC unit_shape(1, 1, 1, 1);
  lhs->tensor.shape = unit_shape;
  rhs->tensor.shape = unit_shape;
  const auto batches_agree = IsBatchMatchesForAllValues(graph);
  ASSERT_TRUE(batches_agree);
}
526 
// Values whose shapes differ in the first BHWC component (1 vs 2) are
// reported as not batch-consistent.
TEST(BatchMatchingTest, NotAllMatch) {
  GraphFloat32 graph;
  Value* lhs = graph.NewValue();
  Value* rhs = graph.NewValue();
  const BHWC single_batch(1, 1, 1, 1);
  const BHWC double_batch(2, 1, 1, 1);
  lhs->tensor.shape = single_batch;
  rhs->tensor.shape = double_batch;
  const auto batches_agree = IsBatchMatchesForAllValues(graph);
  ASSERT_FALSE(batches_agree);
}
535 
536 }  // namespace
537 }  // namespace gpu
538 }  // namespace tflite
539