• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertTrue;
20 import static org.junit.Assert.fail;
21 
22 import org.junit.Ignore;
23 import org.junit.Test;
24 import org.junit.runner.RunWith;
25 import org.junit.runners.JUnit4;
26 
27 /** Unit tests for {@link org.tensorflow.OperationBuilder}. */
28 @RunWith(JUnit4.class)
29 public class OperationBuilderTest {
30   // TODO(ashankar): Restore this test once the C API gracefully handles mixing graphs and
31   // operations instead of segfaulting.
32   @Test
33   @Ignore
failWhenMixingOperationsOnDifferentGraphs()34   public void failWhenMixingOperationsOnDifferentGraphs() {
35     try (Graph g1 = new Graph();
36         Graph g2 = new Graph()) {
37       Output<Integer> c1 = TestUtil.constant(g1, "C1", 3);
38       Output<Integer> c2 = TestUtil.constant(g2, "C2", 3);
39       TestUtil.addN(g1, c1, c1);
40       try {
41         TestUtil.addN(g2, c1, c2);
42       } catch (Exception e) {
43         fail(e.toString());
44       }
45     }
46   }
47 
48   @Test
failOnUseAfterBuild()49   public void failOnUseAfterBuild() {
50     try (Graph g = new Graph();
51         Tensor<Integer> t = Tensors.create(1)) {
52       OperationBuilder b =
53           g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
54       b.build();
55       try {
56         b.setAttr("dtype", t.dataType());
57       } catch (IllegalStateException e) {
58         // expected exception.
59       }
60     }
61   }
62 
63   @Test
failOnUseAfterGraphClose()64   public void failOnUseAfterGraphClose() {
65     OperationBuilder b = null;
66     try (Graph g = new Graph();
67         Tensor<Integer> t = Tensors.create(1)) {
68       b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
69     }
70     try {
71       b.build();
72     } catch (IllegalStateException e) {
73       // expected exception.
74     }
75   }
76 
77   @Test
setAttr()78   public void setAttr() {
79     // The effect of setting an attribute may not easily be visible from the other parts of this
80     // package's API. Thus, for now, the test simply executes the various setAttr variants to see
81     // that there are no exceptions. If an attribute is "visible", test for that in a separate test
82     // (like setAttrShape).
83     //
84     // This is a bit of an awkward test since it has to find operations with attributes of specific
85     // types that aren't inferred from the input arguments.
86     try (Graph g = new Graph()) {
87       // dtype, tensor attributes.
88       try (Tensor<Integer> t = Tensors.create(1)) {
89         g.opBuilder("Const", "DataTypeAndTensor")
90             .setAttr("dtype", DataType.INT32)
91             .setAttr("value", t)
92             .build()
93             .output(0);
94         assertTrue(hasNode(g, "DataTypeAndTensor"));
95       }
96       // string, bool attributes.
97       g.opBuilder("Abort", "StringAndBool")
98           .setAttr("error_msg", "SomeErrorMessage")
99           .setAttr("exit_without_error", false)
100           .build();
101       assertTrue(hasNode(g, "StringAndBool"));
102       // int (TF "int" attributes are 64-bit signed, so a Java long).
103       g.opBuilder("RandomUniform", "Int")
104           .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1}))
105           .setAttr("seed", 10)
106           .setAttr("dtype", DataType.FLOAT)
107           .build();
108       assertTrue(hasNode(g, "Int"));
109       // list(int)
110       g.opBuilder("MaxPool", "IntList")
111           .addInput(TestUtil.constant(g, "MaxPoolInput", new float[2][2][2][2]))
112           .setAttr("ksize", new long[] {1, 1, 1, 1})
113           .setAttr("strides", new long[] {1, 1, 1, 1})
114           .setAttr("padding", "SAME")
115           .build();
116       assertTrue(hasNode(g, "IntList"));
117       // list(float)
118       g.opBuilder("FractionalMaxPool", "FloatList")
119           .addInput(TestUtil.constant(g, "FractionalMaxPoolInput", new float[2][2][2][2]))
120           .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
121           .build();
122       assertTrue(hasNode(g, "FloatList"));
123       // Missing tests: float, list(dtype), list(tensor), list(string), list(bool)
124     }
125   }
126 
127   @Test
setAttrShape()128   public void setAttrShape() {
129     try (Graph g = new Graph()) {
130       Output<?> n =
131           g.opBuilder("Placeholder", "unknown")
132               .setAttr("dtype", DataType.FLOAT)
133               .setAttr("shape", Shape.unknown())
134               .build()
135               .output(0);
136       assertEquals(-1, n.shape().numDimensions());
137       assertEquals(DataType.FLOAT, n.dataType());
138 
139       n = g.opBuilder("Placeholder", "batch_of_vectors")
140               .setAttr("dtype", DataType.FLOAT)
141               .setAttr("shape", Shape.make(-1, 784))
142               .build()
143               .output(0);
144       assertEquals(2, n.shape().numDimensions());
145       assertEquals(-1, n.shape().size(0));
146       assertEquals(784, n.shape().size(1));
147       assertEquals(DataType.FLOAT, n.dataType());
148     }
149   }
150 
151   @Test
setAttrShapeList()152   public void setAttrShapeList() {
153     // Those shapes match tensors ones, so no exception is thrown
154     testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)});
155     try {
156       // Those shapes do not match tensors ones, exception is thrown
157       testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)});
158       fail("Shapes are incompatible and an exception was expected");
159     } catch (IllegalArgumentException e) {
160       // expected
161     }
162   }
163 
164   @Test
addControlInput()165   public void addControlInput() {
166     try (Graph g = new Graph();
167         Session s = new Session(g);
168         Tensor<Boolean> yes = Tensors.create(true);
169         Tensor<Boolean> no = Tensors.create(false)) {
170       Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class);
171       Operation check =
172           g.opBuilder("Assert", "assert")
173               .addInput(placeholder)
174               .addInputList(new Output<?>[] {placeholder})
175               .build();
176       Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build();
177 
178       // No problems when the Assert check succeeds
179       s.runner().feed(placeholder, yes).addTarget(noop).run();
180 
181       // Exception thrown by the execution of the Assert node
182       try {
183         s.runner().feed(placeholder, no).addTarget(noop).run();
184         fail("Did not run control operation.");
185       } catch (IllegalArgumentException e) {
186         // expected
187       }
188     }
189   }
190 
testSetAttrShapeList(Shape[] shapes)191   private static void testSetAttrShapeList(Shape[] shapes) {
192     try (Graph g = new Graph();
193         Session s = new Session(g)) {
194       int[][] matrix = new int[][] {{0, 0}, {0, 0}};
195       Output<?> queue =
196           g.opBuilder("FIFOQueue", "queue")
197               .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
198               .setAttr("shapes", shapes)
199               .build()
200               .output(0);
201       assertTrue(hasNode(g, "queue"));
202       Output<Integer> c1 = TestUtil.constant(g, "const1", matrix);
203       Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix});
204       Operation enqueue =
205           g.opBuilder("QueueEnqueue", "enqueue")
206               .addInput(queue)
207               .addInputList(new Output<?>[] {c1, c2})
208               .build();
209       assertTrue(hasNode(g, "enqueue"));
210 
211       s.runner().addTarget(enqueue).run();
212     }
213   }
214 
hasNode(Graph g, String name)215   private static boolean hasNode(Graph g, String name) {
216     return g.operation(name) != null;
217   }
218 }
219