• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertTrue;
20 import static org.junit.Assert.fail;
21 
22 import org.junit.Ignore;
23 import org.junit.Test;
24 import org.junit.runner.RunWith;
25 import org.junit.runners.JUnit4;
26 
27 /** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */
28 @RunWith(JUnit4.class)
29 public class GraphOperationBuilderTest {
30   // TODO(ashankar): Restore this test once the C API gracefully handles mixing graphs and
31   // operations instead of segfaulting.
32   @Test
33   @Ignore
failWhenMixingOperationsOnDifferentGraphs()34   public void failWhenMixingOperationsOnDifferentGraphs() {
35     try (Graph g1 = new Graph();
36         Graph g2 = new Graph()) {
37       Output<Integer> c1 = TestUtil.constant(g1, "C1", 3);
38       Output<Integer> c2 = TestUtil.constant(g2, "C2", 3);
39       TestUtil.addN(g1, c1, c1);
40       try {
41         TestUtil.addN(g2, c1, c2);
42       } catch (Exception e) {
43         fail(e.toString());
44       }
45     }
46   }
47 
48   @Test
failOnUseAfterBuild()49   public void failOnUseAfterBuild() {
50     try (Graph g = new Graph();
51         Tensor<Integer> t = Tensors.create(1)) {
52       OperationBuilder b =
53           g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
54       b.build();
55       try {
56         b.setAttr("dtype", t.dataType());
57       } catch (IllegalStateException e) {
58         // expected exception.
59       }
60     }
61   }
62 
63   @Test
failOnUseAfterGraphClose()64   public void failOnUseAfterGraphClose() {
65     OperationBuilder b = null;
66     try (Graph g = new Graph();
67         Tensor<Integer> t = Tensors.create(1)) {
68       b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
69     }
70     try {
71       b.build();
72     } catch (IllegalStateException e) {
73       // expected exception.
74     }
75   }
76 
77   @Test
setAttr()78   public void setAttr() {
79     // The effect of setting an attribute may not easily be visible from the other parts of this
80     // package's API. Thus, for now, the test simply executes the various setAttr variants to see
81     // that there are no exceptions. If an attribute is "visible", test for that in a separate test
82     // (like setAttrShape).
83     //
84     // This is a bit of an awkward test since it has to find operations with attributes of specific
85     // types that aren't inferred from the input arguments.
86     try (Graph g = new Graph()) {
87       // dtype, tensor attributes.
88       try (Tensor<Integer> t = Tensors.create(1)) {
89         g.opBuilder("Const", "DataTypeAndTensor")
90             .setAttr("dtype", DataType.INT32)
91             .setAttr("value", t)
92             .build()
93             .output(0);
94         assertTrue(hasNode(g, "DataTypeAndTensor"));
95       }
96       // string, bool attributes.
97       g.opBuilder("Abort", "StringAndBool")
98           .setAttr("error_msg", "SomeErrorMessage")
99           .setAttr("exit_without_error", false)
100           .build();
101       assertTrue(hasNode(g, "StringAndBool"));
102       // int (TF "int" attributes are 64-bit signed, so a Java long).
103       g.opBuilder("RandomUniform", "Int")
104           .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1}))
105           .setAttr("seed", 10)
106           .setAttr("dtype", DataType.FLOAT)
107           .build();
108       assertTrue(hasNode(g, "Int"));
109       // list(int)
110       g.opBuilder("MaxPool", "IntList")
111           .addInput(TestUtil.constant(g, "MaxPoolInput", new float[2][2][2][2]))
112           .setAttr("ksize", new long[] {1, 1, 1, 1})
113           .setAttr("strides", new long[] {1, 1, 1, 1})
114           .setAttr("padding", "SAME")
115           .build();
116       assertTrue(hasNode(g, "IntList"));
117       // list(float)
118       g.opBuilder("FractionalMaxPool", "FloatList")
119           .addInput(TestUtil.constant(g, "FractionalMaxPoolInput", new float[2][2][2][2]))
120           .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
121           .build();
122       assertTrue(hasNode(g, "FloatList"));
123       // Missing tests: float, list(dtype), list(tensor), list(string), list(bool)
124     }
125   }
126 
127   @Test
setAttrShape()128   public void setAttrShape() {
129     try (Graph g = new Graph()) {
130       Output<?> n =
131           g.opBuilder("Placeholder", "unknown")
132               .setAttr("dtype", DataType.FLOAT)
133               .setAttr("shape", Shape.unknown())
134               .build()
135               .output(0);
136       assertEquals(-1, n.shape().numDimensions());
137       assertEquals(DataType.FLOAT, n.dataType());
138 
139       n =
140           g.opBuilder("Placeholder", "batch_of_vectors")
141               .setAttr("dtype", DataType.FLOAT)
142               .setAttr("shape", Shape.make(-1, 784))
143               .build()
144               .output(0);
145       assertEquals(2, n.shape().numDimensions());
146       assertEquals(-1, n.shape().size(0));
147       assertEquals(784, n.shape().size(1));
148       assertEquals(DataType.FLOAT, n.dataType());
149     }
150   }
151 
152   @Test
setAttrShapeList()153   public void setAttrShapeList() {
154     // Those shapes match tensors ones, so no exception is thrown
155     testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)});
156     try {
157       // Those shapes do not match tensors ones, exception is thrown
158       testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)});
159       fail("Shapes are incompatible and an exception was expected");
160     } catch (IllegalArgumentException e) {
161       // expected
162     }
163   }
164 
165   @Test
addControlInput()166   public void addControlInput() {
167     try (Graph g = new Graph();
168         Session s = new Session(g);
169         Tensor<Boolean> yes = Tensors.create(true);
170         Tensor<Boolean> no = Tensors.create(false)) {
171       Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class);
172       GraphOperation check =
173           g.opBuilder("Assert", "assert")
174               .addInput(placeholder)
175               .addInputList(new Output<?>[] {placeholder})
176               .build();
177       Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build();
178 
179       // No problems when the Assert check succeeds
180       s.runner().feed(placeholder, yes).addTarget(noop).run();
181 
182       // Exception thrown by the execution of the Assert node
183       try {
184         s.runner().feed(placeholder, no).addTarget(noop).run();
185         fail("Did not run control operation.");
186       } catch (IllegalArgumentException e) {
187         // expected
188       }
189     }
190   }
191 
testSetAttrShapeList(Shape[] shapes)192   private static void testSetAttrShapeList(Shape[] shapes) {
193     try (Graph g = new Graph();
194         Session s = new Session(g)) {
195       int[][] matrix = new int[][] {{0, 0}, {0, 0}};
196       Output<?> queue =
197           g.opBuilder("FIFOQueue", "queue")
198               .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
199               .setAttr("shapes", shapes)
200               .build()
201               .output(0);
202       assertTrue(hasNode(g, "queue"));
203       Output<Integer> c1 = TestUtil.constant(g, "const1", matrix);
204       Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix});
205       Operation enqueue =
206           g.opBuilder("QueueEnqueue", "enqueue")
207               .addInput(queue)
208               .addInputList(new Output<?>[] {c1, c2})
209               .build();
210       assertTrue(hasNode(g, "enqueue"));
211 
212       s.runner().addTarget(enqueue).run();
213     }
214   }
215 
hasNode(Graph g, String name)216   private static boolean hasNode(Graph g, String name) {
217     return g.operation(name) != null;
218   }
219 }
220