/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for {@link org.tensorflow.Graph}. */
@RunWith(JUnit4.class)
public class GraphTest {

  /**
   * Serializes a graph with {@link Graph#toGraphDef()} and re-imports it, first without and then
   * with a name prefix, checking that the expected operations exist after each import.
   */
  @Test
  public void graphDefRoundTrip() {
    byte[] graphDef;
    // Create a graph for Y = transpose(A) * X, as built by TestUtil.transpose_A_times_X.
    try (Graph g = new Graph()) {
      TestUtil.transpose_A_times_X(g, new int[2][2]);
      graphDef = g.toGraphDef();
    }
    // Import the GraphDef and find all the nodes.
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      validateImportedGraph(g, "");
    }
    // Import again under a name prefix; every node name should now start with it.
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef, "BugsBunny");
      validateImportedGraph(g, "BugsBunny/");
    }
  }

  // Helper function whose implementation is based on knowledge of how
  // TestUtil.transpose_A_times_X is implemented: it is expected to create a Const "A",
  // a Placeholder "X" and a MatMul "Y", each with a single output.
  private static void validateImportedGraph(Graph g, String prefix) {
    validateSingleOutputOperation(g, prefix + "A", "Const");
    validateSingleOutputOperation(g, prefix + "X", "Placeholder");
    validateSingleOutputOperation(g, prefix + "Y", "MatMul");
  }

  /**
   * Asserts that {@code g} contains an operation called {@code name} of type {@code type} that has
   * exactly one output, and that this output refers back to the operation.
   */
  private static void validateSingleOutputOperation(Graph g, String name, String type) {
    Operation op = g.operation(name);
    assertNotNull("missing operation " + name, op);
    assertEquals(name, op.name());
    assertEquals(type, op.type());
    assertEquals(1, op.numOutputs());
    assertEquals(op, op.output(0).op());
  }

  /**
   * Checks that {@link Graph#operations()} yields nothing for an empty graph, and yields each added
   * operation exactly once (in any order) afterwards.
   */
  @Test
  public void iterateOverOperations() {
    try (Graph g = new Graph()) {
      Iterator<Operation> iterator = g.operations();
      assertFalse(iterator.hasNext());

      Set<Operation> operations = new HashSet<>();
      operations.add(TestUtil.constant(g, "Const-A", Float.valueOf(1.0f)).op());
      operations.add(TestUtil.constant(g, "Const-B", Integer.valueOf(23)).op());
      operations.add(TestUtil.constant(g, "Const-C", Double.valueOf(1.618)).op());

      iterator = g.operations();

      // Removing from the set proves each visited operation is one we added, and visited once.
      assertTrue(iterator.hasNext());
      assertTrue(operations.remove(iterator.next()));

      assertTrue(iterator.hasNext());
      assertTrue(operations.remove(iterator.next()));

      assertTrue(iterator.hasNext());
      assertTrue(operations.remove(iterator.next()));

      assertFalse(iterator.hasNext());
    }
  }

  /** Importing a null or malformed GraphDef must throw {@link IllegalArgumentException}. */
  @Test
  public void failImportOnInvalidGraphDefs() {
    try (Graph g = new Graph()) {
      try {
        g.importGraphDef(null);
        // Without this, the test would silently pass if no exception were thrown.
        fail("importGraphDef(null) should throw IllegalArgumentException");
      } catch (IllegalArgumentException expected) {
        // expected exception.
      }

      try {
        g.importGraphDef(new byte[] {1});
        fail("importGraphDef on a malformed GraphDef should throw IllegalArgumentException");
      } catch (IllegalArgumentException expected) {
        // expected exception.
      }
    }
  }

  /** Using a graph after {@link Graph#close()} must throw {@link IllegalStateException}. */
  @Test
  public void failOnUseAfterClose() {
    Graph g = new Graph();
    g.close();
    try {
      g.toGraphDef();
      fail("toGraphDef() on a closed Graph should throw IllegalStateException");
    } catch (IllegalStateException expected) {
      // expected exception.
    }
  }

  /**
   * Adds gradients for single-output targets and evaluates them.
   *
   * <p>With y0 = x1^2, y1 = y0^2 = x1^4 and y2 = y0 + x2, evaluated at x1 = 3 and x2 = 2:
   * dy1/dx1 = 4 * x1^3 = 108, dy2/dx1 = 2 * x1 = 6 and dy2/dx2 = 1.
   */
  @Test
  public void addGradientsToGraph() {
    try (Graph g = new Graph();
        Session s = new Session(g)) {

      Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
      Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
      Output<Float> y0 = TestUtil.square(g, "y0", x1);
      Output<Float> y1 = TestUtil.square(g, "y1", y0);
      Output<Float> y2 = TestUtil.addN(g, y0, x2);

      Output<?>[] grads0 = g.addGradients(y1, toArray(x1));
      assertNotNull(grads0);
      assertEquals(1, grads0.length);
      assertEquals(DataType.FLOAT, grads0[0].dataType());

      Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2));
      assertNotNull(grads1);
      assertEquals(2, grads1.length);
      assertEquals(DataType.FLOAT, grads1[0].dataType());
      assertEquals(DataType.FLOAT, grads1[1].dataType());

      try (Tensor<Float> c1 = Tensors.create(3.0f);
          Tensor<Float> c2 = Tensors.create(2.0f);
          TestUtil.AutoCloseableList<Tensor<?>> outputs =
              new TestUtil.AutoCloseableList<>(
                  s.runner()
                      .feed(x1, c1)
                      .feed(x2, c2)
                      .fetch(grads0[0])
                      .fetch(grads1[0])
                      .fetch(grads1[1])
                      .run())) {

        assertEquals(3, outputs.size());
        assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
        assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f);
        assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f);
      }
    }
  }

  /**
   * Adds gradients for several targets at once. With y0 = x^2 and y1 = x^4 at x = 3, the expected
   * result is the summed gradient 2 * x + 4 * x^3 = 114.
   */
  @Test
  public void addGradientSumsToGraph() {
    try (Graph g = new Graph();
        Session s = new Session(g)) {

      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
      Output<Float> y0 = TestUtil.square(g, "y0", x);
      Output<Float> y1 = TestUtil.square(g, "y1", y0);

      Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
      assertNotNull(grad);
      assertEquals(1, grad.length);
      assertEquals(DataType.FLOAT, grad[0].dataType());

      try (Tensor<Float> c = Tensors.create(3.0f);
          Tensor<?> output = s.runner().feed(x, c).fetch(grad[0]).run().get(0)) {

        assertEquals(114.0f, output.floatValue(), 0.0f);
      }
    }
  }

  /**
   * Chains gradients by passing one gradient as the initial value of the next.
   *
   * <p>grad0 = dy1/dy0 = 2 * y0. Seeding dy0/dx with it gives 2 * x * 2 * x^2 = 4 * x^3, which is
   * 108 at x = 3.
   */
  @Test
  public void addGradientsWithInitialValuesToGraph() {
    try (Graph g = new Graph();
        Session s = new Session(g)) {

      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
      Output<Float> y0 = TestUtil.square(g, "y0", x);
      Output<Float> y1 = TestUtil.square(g, "y1", y0);

      Output<?>[] grad0 = g.addGradients(y1, toArray(y0));
      assertNotNull(grad0);
      assertEquals(1, grad0.length);
      assertEquals(DataType.FLOAT, grad0[0].dataType());

      Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
      assertNotNull(grad1);
      assertEquals(1, grad1.length);
      assertEquals(DataType.FLOAT, grad1[0].dataType());

      try (Tensor<Float> c = Tensors.create(3.0f);
          Tensor<?> output = s.runner().feed(x, c).fetch(grad1[0]).run().get(0)) {

        assertEquals(108.0f, output.floatValue(), 0.0f);
      }
    }
  }

  /**
   * Checks the name scopes that gradient operations are created under:
   * <ul>
   *   <li>a null prefix yields "gradients/", then "gradients_1/";</li>
   *   <li>an explicit prefix is used as given;</li>
   *   <li>reusing an explicit prefix throws {@link IllegalArgumentException}.</li>
   * </ul>
   */
  @Test
  public void validateGradientsNames() {
    try (Graph g = new Graph()) {

      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
      Output<Float> y0 = TestUtil.square(g, "y0", x);

      Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
      assertTrue(grad0[0].op().name().startsWith("gradients/"));

      Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null);
      assertTrue(grad1[0].op().name().startsWith("gradients_1/"));

      Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
      assertTrue(grad2[0].op().name().startsWith("more_gradients/"));

      Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
      assertTrue(grad3[0].op().name().startsWith("even_more_gradients/"));

      try {
        g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
        fail("reusing gradient prefix \"even_more_gradients\" should throw IllegalArgumentException");
      } catch (IllegalArgumentException expected) {
        // expected exception
      }
    }
  }

  /** Builds a while loop that squares a single input while it is below 16 (2 -> 4 -> 16). */
  @Test
  public void buildWhileLoopSingleInput() {
    try (Graph g = new Graph();
        Session s = new Session(g)) {

      Output<?> input = TestUtil.placeholder(g, "input1", Integer.class);

      // could write this using lambda after Java 8
      Graph.WhileSubgraphBuilder condGraphBuilder =
          new Graph.WhileSubgraphBuilder() {
            @Override
            public void buildSubgraph(
                Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
              Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
              // condInputs[0] < 16
              Output<?> condOutput =
                  condGraph
                      .opBuilder("Less", "cond")
                      .addInput(condInputs[0])
                      .addInput(sixteen)
                      .build()
                      .output(0);

              condOutputs[0] = condOutput;
            }
          };

      // could write this using lambda after Java 8
      Graph.WhileSubgraphBuilder bodyGraphBuilder =
          new Graph.WhileSubgraphBuilder() {
            @Override
            public void buildSubgraph(
                Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
              bodyOutputs[0] = TestUtil.square(bodyGraph, "square", bodyInputs[0]);
            }
          };

      Output<?>[] loopOutputs =
          g.whileLoop(toArray(input), condGraphBuilder, bodyGraphBuilder, "test_loop");

      try (Tensor<Integer> c = Tensors.create(2);
          Tensor<?> output = s.runner().feed(input, c).fetch(loopOutputs[0]).run().get(0)) {

        assertEquals(16, output.intValue()); // ((2^2)^2)
      }
    }
  }

  /**
   * Builds a two-variable while loop. Both inputs are squared each iteration, and the loop stops
   * once the first input is no longer below 16, so it runs twice.
   */
  @Test
  public void buildWhileLoopMultipleInputs() {
    try (Graph g = new Graph();
        Session s = new Session(g)) {

      Output<?> input1 = TestUtil.placeholder(g, "input1", Integer.class);
      Output<?> input2 = TestUtil.placeholder(g, "input2", Integer.class);
      Output<?>[] inputs = toArray(input1, input2);

      // could write this using lambda after Java 8
      Graph.WhileSubgraphBuilder condGraphBuilder =
          new Graph.WhileSubgraphBuilder() {
            @Override
            public void buildSubgraph(
                Graph condGraph, Output<?>[] condInputs, Output<?>[] condOutputs) {
              Output<Integer> sixteen = TestUtil.constant(condGraph, "sixteen", 16);
              Output<?> condOutput =
                  condGraph
                      .opBuilder("Less", "cond")
                      .addInput(condInputs[0])
                      .addInput(sixteen)
                      .build()
                      .output(0); // condInputs[0] < 16

              condOutputs[0] = condOutput;
            }
          };

      // could write this using lambda after Java 8
      Graph.WhileSubgraphBuilder bodyGraphBuilder =
          new Graph.WhileSubgraphBuilder() {
            @Override
            public void buildSubgraph(
                Graph bodyGraph, Output<?>[] bodyInputs, Output<?>[] bodyOutputs) {
              bodyOutputs[0] = TestUtil.square(bodyGraph, "square1", bodyInputs[0]);
              bodyOutputs[1] = TestUtil.square(bodyGraph, "square2", bodyInputs[1]);
            }
          };

      Output<?>[] loopOutputs =
          g.whileLoop(inputs, condGraphBuilder, bodyGraphBuilder, "test_loop");

      try (Tensor<Integer> c1 = Tensors.create(2);
          Tensor<Integer> c2 = Tensors.create(5);
          TestUtil.AutoCloseableList<Tensor<?>> outputs =
              new TestUtil.AutoCloseableList<>(
                  s.runner()
                      .feed(input1, c1)
                      .feed(input2, c2)
                      .fetch(loopOutputs[0])
                      .fetch(loopOutputs[1])
                      .run())) {

        assertEquals(2, outputs.size());
        assertEquals(16, outputs.get(0).intValue()); // ((2^2)^2)
        assertEquals(625, outputs.get(1).intValue()); // ((5^2)^2)
      }
    }
  }

  /** Varargs convenience that packs the given outputs into an array. */
  private static Output<?>[] toArray(Output<?>... outputs) {
    return outputs;
  }
}