• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.util.HashSet;
import java.util.Iterator;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
28 
29 /** Unit tests for {@link org.tensorflow.Graph}. */
30 @RunWith(JUnit4.class)
31 public class GraphTest {
32 
33   @Test
graphDefRoundTrip()34   public void graphDefRoundTrip() {
35     byte[] graphDef;
36     // Create a graph for A * X + B
37     try (Graph g = new Graph()) {
38       TestUtil.transpose_A_times_X(g, new int[2][2]);
39       graphDef = g.toGraphDef();
40     }
41     // Import the GraphDef and find all the nodes.
42     try (Graph g = new Graph()) {
43       g.importGraphDef(graphDef);
44       validateImportedGraph(g, "");
45     }
46     try (Graph g = new Graph()) {
47       g.importGraphDef(graphDef, "BugsBunny");
48       validateImportedGraph(g, "BugsBunny/");
49     }
50   }
51 
52   // Helper function whose implementation is based on knowledge of how
53   // TestUtil.transpose_A_times_X is implemented.
validateImportedGraph(Graph g, String prefix)54   private static void validateImportedGraph(Graph g, String prefix) {
55     Operation op = g.operation(prefix + "A");
56     assertNotNull(op);
57     assertEquals(prefix + "A", op.name());
58     assertEquals("Const", op.type());
59     assertEquals(1, op.numOutputs());
60     assertEquals(op, op.output(0).op());
61 
62     op = g.operation(prefix + "X");
63     assertNotNull(op);
64     assertEquals(prefix + "X", op.name());
65     assertEquals("Placeholder", op.type());
66     assertEquals(1, op.numOutputs());
67     assertEquals(op, op.output(0).op());
68 
69     op = g.operation(prefix + "Y");
70     assertNotNull(op);
71     assertEquals(prefix + "Y", op.name());
72     assertEquals("MatMul", op.type());
73     assertEquals(1, op.numOutputs());
74     assertEquals(op, op.output(0).op());
75   }
76 
77   @Test
iterateOverOperations()78   public void iterateOverOperations() {
79     try (Graph g = new Graph()) {
80       Iterator<Operation> iterator = g.operations();
81       HashSet<Operation> operations;
82 
83       assertFalse(iterator.hasNext());
84 
85       operations = new HashSet<>();
86       operations.add(TestUtil.constant(g, "Const-A", Float.valueOf(1.0f)).op());
87       operations.add(TestUtil.constant(g, "Const-B", Integer.valueOf(23)).op());
88       operations.add(TestUtil.constant(g, "Const-C", Double.valueOf(1.618)).op());
89 
90       iterator = g.operations();
91 
92       assertTrue(iterator.hasNext());
93       assertTrue(operations.remove(iterator.next()));
94 
95       assertTrue(iterator.hasNext());
96       assertTrue(operations.remove(iterator.next()));
97 
98       assertTrue(iterator.hasNext());
99       assertTrue(operations.remove(iterator.next()));
100 
101       assertFalse(iterator.hasNext());
102     }
103   }
104 
105   @Test
failImportOnInvalidGraphDefs()106   public void failImportOnInvalidGraphDefs() {
107     try (Graph g = new Graph()) {
108       try {
109         g.importGraphDef(null);
110       } catch (IllegalArgumentException e) {
111         // expected exception.
112       }
113 
114       try {
115         g.importGraphDef(new byte[] {1});
116       } catch (IllegalArgumentException e) {
117         // expected exception.
118       }
119     }
120   }
121 
122   @Test
failOnUseAfterClose()123   public void failOnUseAfterClose() {
124     Graph g = new Graph();
125     g.close();
126     try {
127       g.toGraphDef();
128     } catch (IllegalStateException e) {
129       // expected exception.
130     }
131   }
132 
133   @Test
addGradientsToGraph()134   public void addGradientsToGraph() {
135     try (Graph g = new Graph();
136         Session s = new Session(g)) {
137 
138       Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
139       Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
140       Output<Float> y0 = TestUtil.square(g, "y0", x1);
141       Output<Float> y1 = TestUtil.square(g, "y1", y0);
142       Output<Float> y2 = TestUtil.addN(g, y0, x2);
143 
144       Output<?>[] grads0 = g.addGradients(y1, toArray(x1));
145       assertNotNull(grads0);
146       assertEquals(1, grads0.length);
147       assertEquals(DataType.FLOAT, grads0[0].dataType());
148 
149       Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2));
150       assertNotNull(grads1);
151       assertEquals(2, grads1.length);
152       assertEquals(DataType.FLOAT, grads1[0].dataType());
153       assertEquals(DataType.FLOAT, grads1[1].dataType());
154 
155       try (Tensor<Float> c1 = Tensors.create(3.0f);
156           Tensor<Float> c2 = Tensors.create(2.0f);
157           TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
158               s.runner()
159                   .feed(x1, c1)
160                   .feed(x2, c2)
161                   .fetch(grads0[0])
162                   .fetch(grads1[0])
163                   .fetch(grads1[1])
164                   .run())) {
165 
166         assertEquals(3, outputs.size());
167         assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
168         assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f);
169         assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f);
170       }
171     }
172   }
173 
174   @Test
addGradientSumsToGraph()175   public void addGradientSumsToGraph() {
176     try (Graph g = new Graph();
177         Session s = new Session(g)) {
178 
179       Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
180       Output<Float> y0 = TestUtil.square(g, "y0", x);
181       Output<Float> y1 = TestUtil.square(g, "y1", y0);
182 
183       Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
184       assertNotNull(grad);
185       assertEquals(1, grad.length);
186       assertEquals(DataType.FLOAT, grad[0].dataType());
187 
188       try (Tensor<Float> c = Tensors.create(3.0f);
189           Tensor<?> output = s.runner()
190               .feed(x, c)
191               .fetch(grad[0])
192               .run()
193               .get(0)) {
194 
195         assertEquals(114.0f, output.floatValue(), 0.0f);
196       }
197     }
198   }
199 
200   @Test
addGradientsWithInitialValuesToGraph()201   public void addGradientsWithInitialValuesToGraph() {
202     try (Graph g = new Graph();
203         Session s = new Session(g)) {
204 
205       Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
206       Output<Float> y0 = TestUtil.square(g, "y0", x);
207       Output<Float> y1 = TestUtil.square(g, "y1", y0);
208 
209       Output<?>[] grad0 = g.addGradients(y1, toArray(y0));
210       assertNotNull(grad0);
211       assertEquals(1, grad0.length);
212       assertEquals(DataType.FLOAT, grad0[0].dataType());
213 
214       Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
215       assertNotNull(grad1);
216       assertEquals(1, grad1.length);
217       assertEquals(DataType.FLOAT, grad1[0].dataType());
218 
219       try (Tensor<Float> c = Tensors.create(3.0f);
220           Tensor<?> output = s.runner()
221               .feed(x, c)
222               .fetch(grad1[0])
223               .run()
224               .get(0)) {
225 
226         assertEquals(108.0f, output.floatValue(), 0.0f);
227       }
228     }
229   }
230 
231   @Test
validateGradientsNames()232   public void validateGradientsNames() {
233     try (Graph g = new Graph()) {
234 
235       Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
236       Output<Float> y0 = TestUtil.square(g, "y0", x);
237 
238       Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
239       assertTrue(grad0[0].op().name().startsWith("gradients/"));
240 
241       Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null);
242       assertTrue(grad1[0].op().name().startsWith("gradients_1/"));
243 
244       Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
245       assertTrue(grad2[0].op().name().startsWith("more_gradients/"));
246 
247       Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
248       assertTrue(grad3[0].op().name().startsWith("even_more_gradients/"));
249 
250       try {
251         g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
252       } catch (IllegalArgumentException e) {
253         // expected exception
254       }
255     }
256   }
257 
258   @Test
buildWhileLoopSingleInput()259   public void buildWhileLoopSingleInput() {
260     try (Graph g = new Graph();
261         Session s = new Session(g)) {
262 
263       Output<?> input = TestUtil.placeholder(g, "input1", Integer.class);
264 
265       // could write this using lambda after Java 8
266       Graph.WhileSubgraphBuilder condGraphBuilder =
267           new Graph.WhileSubgraphBuilder() {
268             @Override
269             public void buildSubgraph(
270                 Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
271               Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
272               // condInputs[0] < 16
273               Output<?> condOutput =
274                   condGraph
275                       .opBuilder("Less", "cond")
276                       .addInput(condInputs[0])
277                       .addInput(sixteen)
278                       .build()
279                       .output(0);
280 
281               condOutputs[0] = condOutput;
282             }
283           };
284 
285       // could write this using lambda after Java 8
286       Graph.WhileSubgraphBuilder bodyGraphBuilder =
287           new Graph.WhileSubgraphBuilder() {
288             @Override
289             public void buildSubgraph(
290                 Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
291               bodyOutputs[0] = TestUtil.square(bodyGraph, "square", bodyInputs[0]);
292             }
293           };
294 
295       Output<?>[] loopOutputs =
296           g.whileLoop(toArray(input), condGraphBuilder, bodyGraphBuilder, "test_loop");
297 
298       try (Tensor<Integer> c = Tensors.create(2);
299           Tensor<?> output = s.runner().feed(input, c).fetch(loopOutputs[0]).run().get(0)) {
300 
301         assertEquals(16, output.intValue()); // ((2^2)^2)
302       }
303     }
304   }
305 
306   @Test
buildWhileLoopMultipleInputs()307   public void buildWhileLoopMultipleInputs() {
308     try (Graph g = new Graph();
309         Session s = new Session(g)) {
310 
311       Output<?> input1 = TestUtil.placeholder(g, "input1", Integer.class);
312       Output<?> input2 = TestUtil.placeholder(g, "input2", Integer.class);
313       Output<?>[] inputs = toArray(input1, input2);
314 
315       // could write this using lambda after Java 8
316       Graph.WhileSubgraphBuilder condGraphBuilder =
317           new Graph.WhileSubgraphBuilder() {
318             @Override
319             public void buildSubgraph(
320                 Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
321               Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
322               Output<?> condOutput =
323                   condGraph
324                       .opBuilder("Less", "cond")
325                       .addInput(condInputs[0])
326                       .addInput(sixteen)
327                       .build()
328                       .output(0); // condInputs[0] < 16
329 
330               condOutputs[0] = condOutput;
331             }
332           };
333 
334       // could write this using lambda after Java 8
335       Graph.WhileSubgraphBuilder bodyGraphBuilder =
336           new Graph.WhileSubgraphBuilder() {
337             @Override
338             public void buildSubgraph(
339                 Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
340               bodyOutputs[0] = TestUtil.square(bodyGraph, "square1", bodyInputs[0]);
341               bodyOutputs[1] = TestUtil.square(bodyGraph, "square2", bodyInputs[1]);
342             }
343           };
344 
345       Output<?>[] loopOutputs =
346           g.whileLoop(inputs, condGraphBuilder, bodyGraphBuilder, "test_loop");
347 
348       try (Tensor<Integer> c1 = Tensors.create(2);
349           Tensor<Integer> c2 = Tensors.create(5);
350           TestUtil.AutoCloseableList<Tensor<?>> outputs =
351               new TestUtil.AutoCloseableList<>(
352                   s.runner()
353                       .feed(input1, c1)
354                       .feed(input2, c2)
355                       .fetch(loopOutputs[0])
356                       .fetch(loopOutputs[1])
357                       .run())) {
358 
359         assertEquals(2, outputs.size());
360         assertEquals(16, outputs.get(0).intValue()); // ((2^2)^2)
361         assertEquals(625, outputs.get(1).intValue()); // ((5^2)^2)
362       }
363     }
364   }
365 
toArray(Output<?>.... outputs)366   private static Output<?>[] toArray(Output<?>... outputs) {
367     return outputs;
368   }
369 }
370