1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertTrue; 20 import static org.junit.Assert.fail; 21 22 import org.junit.Ignore; 23 import org.junit.Test; 24 import org.junit.runner.RunWith; 25 import org.junit.runners.JUnit4; 26 27 /** Unit tests for {@link org.tensorflow.OperationBuilder}. */ 28 @RunWith(JUnit4.class) 29 public class OperationBuilderTest { 30 // TODO(ashankar): Restore this test once the C API gracefully handles mixing graphs and 31 // operations instead of segfaulting. 32 @Test 33 @Ignore failWhenMixingOperationsOnDifferentGraphs()34 public void failWhenMixingOperationsOnDifferentGraphs() { 35 try (Graph g1 = new Graph(); 36 Graph g2 = new Graph()) { 37 Output<Integer> c1 = TestUtil.constant(g1, "C1", 3); 38 Output<Integer> c2 = TestUtil.constant(g2, "C2", 3); 39 TestUtil.addN(g1, c1, c1); 40 try { 41 TestUtil.addN(g2, c1, c2); 42 } catch (Exception e) { 43 fail(e.toString()); 44 } 45 } 46 } 47 48 @Test failOnUseAfterBuild()49 public void failOnUseAfterBuild() { 50 try (Graph g = new Graph(); 51 Tensor<Integer> t = Tensors.create(1)) { 52 OperationBuilder b = 53 g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); 54 b.build(); 55 try { 56 b.setAttr("dtype", t.dataType()); 57 } catch (IllegalStateException e) { 58 // expected exception. 59 } 60 } 61 } 62 63 @Test failOnUseAfterGraphClose()64 public void failOnUseAfterGraphClose() { 65 OperationBuilder b = null; 66 try (Graph g = new Graph(); 67 Tensor<Integer> t = Tensors.create(1)) { 68 b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); 69 } 70 try { 71 b.build(); 72 } catch (IllegalStateException e) { 73 // expected exception. 74 } 75 } 76 77 @Test setAttr()78 public void setAttr() { 79 // The effect of setting an attribute may not easily be visible from the other parts of this 80 // package's API. Thus, for now, the test simply executes the various setAttr variants to see 81 // that there are no exceptions. If an attribute is "visible", test for that in a separate test 82 // (like setAttrShape). 83 // 84 // This is a bit of an awkward test since it has to find operations with attributes of specific 85 // types that aren't inferred from the input arguments. 86 try (Graph g = new Graph()) { 87 // dtype, tensor attributes. 88 try (Tensor<Integer> t = Tensors.create(1)) { 89 g.opBuilder("Const", "DataTypeAndTensor") 90 .setAttr("dtype", DataType.INT32) 91 .setAttr("value", t) 92 .build() 93 .output(0); 94 assertTrue(hasNode(g, "DataTypeAndTensor")); 95 } 96 // string, bool attributes. 97 g.opBuilder("Abort", "StringAndBool") 98 .setAttr("error_msg", "SomeErrorMessage") 99 .setAttr("exit_without_error", false) 100 .build(); 101 assertTrue(hasNode(g, "StringAndBool")); 102 // int (TF "int" attributes are 64-bit signed, so a Java long). 103 g.opBuilder("RandomUniform", "Int") 104 .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1})) 105 .setAttr("seed", 10) 106 .setAttr("dtype", DataType.FLOAT) 107 .build(); 108 assertTrue(hasNode(g, "Int")); 109 // list(int) 110 g.opBuilder("MaxPool", "IntList") 111 .addInput(TestUtil.constant(g, "MaxPoolInput", new float[2][2][2][2])) 112 .setAttr("ksize", new long[] {1, 1, 1, 1}) 113 .setAttr("strides", new long[] {1, 1, 1, 1}) 114 .setAttr("padding", "SAME") 115 .build(); 116 assertTrue(hasNode(g, "IntList")); 117 // list(float) 118 g.opBuilder("FractionalMaxPool", "FloatList") 119 .addInput(TestUtil.constant(g, "FractionalMaxPoolInput", new float[2][2][2][2])) 120 .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) 121 .build(); 122 assertTrue(hasNode(g, "FloatList")); 123 // Missing tests: float, list(dtype), list(tensor), list(string), list(bool) 124 } 125 } 126 127 @Test setAttrShape()128 public void setAttrShape() { 129 try (Graph g = new Graph()) { 130 Output<?> n = 131 g.opBuilder("Placeholder", "unknown") 132 .setAttr("dtype", DataType.FLOAT) 133 .setAttr("shape", Shape.unknown()) 134 .build() 135 .output(0); 136 assertEquals(-1, n.shape().numDimensions()); 137 assertEquals(DataType.FLOAT, n.dataType()); 138 139 n = g.opBuilder("Placeholder", "batch_of_vectors") 140 .setAttr("dtype", DataType.FLOAT) 141 .setAttr("shape", Shape.make(-1, 784)) 142 .build() 143 .output(0); 144 assertEquals(2, n.shape().numDimensions()); 145 assertEquals(-1, n.shape().size(0)); 146 assertEquals(784, n.shape().size(1)); 147 assertEquals(DataType.FLOAT, n.dataType()); 148 } 149 } 150 151 @Test setAttrShapeList()152 public void setAttrShapeList() { 153 // Those shapes match tensors ones, so no exception is thrown 154 testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)}); 155 try { 156 // Those shapes do not match tensors ones, exception is thrown 157 testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)}); 158 fail("Shapes are incompatible and an exception was expected"); 159 } catch (IllegalArgumentException e) { 160 // expected 161 } 162 } 163 164 @Test addControlInput()165 public void addControlInput() { 166 try (Graph g = new Graph(); 167 Session s = new Session(g); 168 Tensor<Boolean> yes = Tensors.create(true); 169 Tensor<Boolean> no = Tensors.create(false)) { 170 Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class); 171 Operation check = 172 g.opBuilder("Assert", "assert") 173 .addInput(placeholder) 174 .addInputList(new Output<?>[] {placeholder}) 175 .build(); 176 Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build(); 177 178 // No problems when the Assert check succeeds 179 s.runner().feed(placeholder, yes).addTarget(noop).run(); 180 181 // Exception thrown by the execution of the Assert node 182 try { 183 s.runner().feed(placeholder, no).addTarget(noop).run(); 184 fail("Did not run control operation."); 185 } catch (IllegalArgumentException e) { 186 // expected 187 } 188 } 189 } 190 testSetAttrShapeList(Shape[] shapes)191 private static void testSetAttrShapeList(Shape[] shapes) { 192 try (Graph g = new Graph(); 193 Session s = new Session(g)) { 194 int[][] matrix = new int[][] {{0, 0}, {0, 0}}; 195 Output<?> queue = 196 g.opBuilder("FIFOQueue", "queue") 197 .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32}) 198 .setAttr("shapes", shapes) 199 .build() 200 .output(0); 201 assertTrue(hasNode(g, "queue")); 202 Output<Integer> c1 = TestUtil.constant(g, "const1", matrix); 203 Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix}); 204 Operation enqueue = 205 g.opBuilder("QueueEnqueue", "enqueue") 206 .addInput(queue) 207 .addInputList(new Output<?>[] {c1, c2}) 208 .build(); 209 assertTrue(hasNode(g, "enqueue")); 210 211 s.runner().addTarget(enqueue).run(); 212 } 213 } 214 hasNode(Graph g, String name)215 private static boolean hasNode(Graph g, String name) { 216 return g.operation(name) != null; 217 } 218 } 219