1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertTrue; 20 import static org.junit.Assert.fail; 21 22 import org.junit.Ignore; 23 import org.junit.Test; 24 import org.junit.runner.RunWith; 25 import org.junit.runners.JUnit4; 26 27 /** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */ 28 @RunWith(JUnit4.class) 29 public class GraphOperationBuilderTest { 30 // TODO(ashankar): Restore this test once the C API gracefully handles mixing graphs and 31 // operations instead of segfaulting. 32 @Test 33 @Ignore failWhenMixingOperationsOnDifferentGraphs()34 public void failWhenMixingOperationsOnDifferentGraphs() { 35 try (Graph g1 = new Graph(); 36 Graph g2 = new Graph()) { 37 Output<Integer> c1 = TestUtil.constant(g1, "C1", 3); 38 Output<Integer> c2 = TestUtil.constant(g2, "C2", 3); 39 TestUtil.addN(g1, c1, c1); 40 try { 41 TestUtil.addN(g2, c1, c2); 42 } catch (Exception e) { 43 fail(e.toString()); 44 } 45 } 46 } 47 48 @Test failOnUseAfterBuild()49 public void failOnUseAfterBuild() { 50 try (Graph g = new Graph(); 51 Tensor<Integer> t = Tensors.create(1)) { 52 OperationBuilder b = 53 g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); 54 b.build(); 55 try { 56 b.setAttr("dtype", t.dataType()); 57 } catch (IllegalStateException e) { 58 // expected exception. 59 } 60 } 61 } 62 63 @Test failOnUseAfterGraphClose()64 public void failOnUseAfterGraphClose() { 65 OperationBuilder b = null; 66 try (Graph g = new Graph(); 67 Tensor<Integer> t = Tensors.create(1)) { 68 b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); 69 } 70 try { 71 b.build(); 72 } catch (IllegalStateException e) { 73 // expected exception. 74 } 75 } 76 77 @Test setAttr()78 public void setAttr() { 79 // The effect of setting an attribute may not easily be visible from the other parts of this 80 // package's API. Thus, for now, the test simply executes the various setAttr variants to see 81 // that there are no exceptions. If an attribute is "visible", test for that in a separate test 82 // (like setAttrShape). 83 // 84 // This is a bit of an awkward test since it has to find operations with attributes of specific 85 // types that aren't inferred from the input arguments. 86 try (Graph g = new Graph()) { 87 // dtype, tensor attributes. 88 try (Tensor<Integer> t = Tensors.create(1)) { 89 g.opBuilder("Const", "DataTypeAndTensor") 90 .setAttr("dtype", DataType.INT32) 91 .setAttr("value", t) 92 .build() 93 .output(0); 94 assertTrue(hasNode(g, "DataTypeAndTensor")); 95 } 96 // string, bool attributes. 97 g.opBuilder("Abort", "StringAndBool") 98 .setAttr("error_msg", "SomeErrorMessage") 99 .setAttr("exit_without_error", false) 100 .build(); 101 assertTrue(hasNode(g, "StringAndBool")); 102 // int (TF "int" attributes are 64-bit signed, so a Java long). 103 g.opBuilder("RandomUniform", "Int") 104 .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1})) 105 .setAttr("seed", 10) 106 .setAttr("dtype", DataType.FLOAT) 107 .build(); 108 assertTrue(hasNode(g, "Int")); 109 // list(int) 110 g.opBuilder("MaxPool", "IntList") 111 .addInput(TestUtil.constant(g, "MaxPoolInput", new float[2][2][2][2])) 112 .setAttr("ksize", new long[] {1, 1, 1, 1}) 113 .setAttr("strides", new long[] {1, 1, 1, 1}) 114 .setAttr("padding", "SAME") 115 .build(); 116 assertTrue(hasNode(g, "IntList")); 117 // list(float) 118 g.opBuilder("FractionalMaxPool", "FloatList") 119 .addInput(TestUtil.constant(g, "FractionalMaxPoolInput", new float[2][2][2][2])) 120 .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) 121 .build(); 122 assertTrue(hasNode(g, "FloatList")); 123 // Missing tests: float, list(dtype), list(tensor), list(string), list(bool) 124 } 125 } 126 127 @Test setAttrShape()128 public void setAttrShape() { 129 try (Graph g = new Graph()) { 130 Output<?> n = 131 g.opBuilder("Placeholder", "unknown") 132 .setAttr("dtype", DataType.FLOAT) 133 .setAttr("shape", Shape.unknown()) 134 .build() 135 .output(0); 136 assertEquals(-1, n.shape().numDimensions()); 137 assertEquals(DataType.FLOAT, n.dataType()); 138 139 n = 140 g.opBuilder("Placeholder", "batch_of_vectors") 141 .setAttr("dtype", DataType.FLOAT) 142 .setAttr("shape", Shape.make(-1, 784)) 143 .build() 144 .output(0); 145 assertEquals(2, n.shape().numDimensions()); 146 assertEquals(-1, n.shape().size(0)); 147 assertEquals(784, n.shape().size(1)); 148 assertEquals(DataType.FLOAT, n.dataType()); 149 } 150 } 151 152 @Test setAttrShapeList()153 public void setAttrShapeList() { 154 // Those shapes match tensors ones, so no exception is thrown 155 testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)}); 156 try { 157 // Those shapes do not match tensors ones, exception is thrown 158 testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)}); 159 fail("Shapes are incompatible and an exception was expected"); 160 } catch (IllegalArgumentException e) { 161 // expected 162 } 163 } 164 165 @Test addControlInput()166 public void addControlInput() { 167 try (Graph g = new Graph(); 168 Session s = new Session(g); 169 Tensor<Boolean> yes = Tensors.create(true); 170 Tensor<Boolean> no = Tensors.create(false)) { 171 Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class); 172 GraphOperation check = 173 g.opBuilder("Assert", "assert") 174 .addInput(placeholder) 175 .addInputList(new Output<?>[] {placeholder}) 176 .build(); 177 Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build(); 178 179 // No problems when the Assert check succeeds 180 s.runner().feed(placeholder, yes).addTarget(noop).run(); 181 182 // Exception thrown by the execution of the Assert node 183 try { 184 s.runner().feed(placeholder, no).addTarget(noop).run(); 185 fail("Did not run control operation."); 186 } catch (IllegalArgumentException e) { 187 // expected 188 } 189 } 190 } 191 testSetAttrShapeList(Shape[] shapes)192 private static void testSetAttrShapeList(Shape[] shapes) { 193 try (Graph g = new Graph(); 194 Session s = new Session(g)) { 195 int[][] matrix = new int[][] {{0, 0}, {0, 0}}; 196 Output<?> queue = 197 g.opBuilder("FIFOQueue", "queue") 198 .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32}) 199 .setAttr("shapes", shapes) 200 .build() 201 .output(0); 202 assertTrue(hasNode(g, "queue")); 203 Output<Integer> c1 = TestUtil.constant(g, "const1", matrix); 204 Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix}); 205 Operation enqueue = 206 g.opBuilder("QueueEnqueue", "enqueue") 207 .addInput(queue) 208 .addInputList(new Output<?>[] {c1, c2}) 209 .build(); 210 assertTrue(hasNode(g, "enqueue")); 211 212 s.runner().addTarget(enqueue).run(); 213 } 214 } 215 hasNode(Graph g, String name)216 private static boolean hasNode(Graph g, String name) { 217 return g.operation(name) != null; 218 } 219 } 220