• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertNotEquals;
20 import static org.junit.Assert.assertNotNull;
21 import static org.junit.Assert.assertTrue;
22 import static org.junit.Assert.fail;
23 
24 import java.util.Arrays;
25 import java.util.HashSet;
26 import java.util.Set;
27 
28 import org.junit.Test;
29 import org.junit.runner.RunWith;
30 import org.junit.runners.JUnit4;
31 
32 /** Unit tests for {@link org.tensorflow.GraphOperation}. */
33 @RunWith(JUnit4.class)
34 public class GraphOperationTest {
35 
36   @Test
outputListLengthFailsOnInvalidName()37   public void outputListLengthFailsOnInvalidName() {
38     try (Graph g = new Graph()) {
39       Operation op =
40           g.opBuilder("Add", "Add")
41               .addInput(TestUtil.constant(g, "x", 1))
42               .addInput(TestUtil.constant(g, "y", 2))
43               .build();
44       assertEquals(1, op.outputListLength("z"));
45 
46       try {
47         op.outputListLength("unknown");
48         fail("Did not catch bad name");
49       } catch (IllegalArgumentException iae) {
50         // expected
51       }
52     }
53   }
54 
55   @Test
operationEquality()56   public void operationEquality() {
57     GraphOperation op1;
58     try (Graph g = new Graph()) {
59       op1 = TestUtil.constantOp(g, "op1", 1);
60       GraphOperation op2 = TestUtil.constantOp(g, "op2", 2);
61       GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle());
62       GraphOperation op4 = g.operation("op1");
63       assertEquals(op1, op1);
64       assertNotEquals(op1, op2);
65       assertEquals(op1, op3);
66       assertEquals(op1.hashCode(), op3.hashCode());
67       assertEquals(op1, op4);
68       assertEquals(op1.hashCode(), op4.hashCode());
69       assertEquals(op3, op4);
70       assertNotEquals(op2, op3);
71       assertNotEquals(op2, op4);
72     }
73     try (Graph g = new Graph()) {
74       Operation newOp1 = TestUtil.constant(g, "op1", 1).op();
75       assertNotEquals(op1, newOp1);
76     }
77   }
78 
79   @Test
operationCollection()80   public void operationCollection() {
81     try (Graph g = new Graph()) {
82       GraphOperation op1 = TestUtil.constantOp(g, "op1", 1);
83       GraphOperation op2 = TestUtil.constantOp(g, "op2", 2);
84       GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle());
85       GraphOperation op4 = g.operation("op1");
86       Set<Operation> ops = new HashSet<>();
87       ops.addAll(Arrays.asList(op1, op2, op3, op4));
88       assertEquals(2, ops.size());
89       assertTrue(ops.contains(op1));
90       assertTrue(ops.contains(op2));
91       assertTrue(ops.contains(op3));
92       assertTrue(ops.contains(op4));
93     }
94   }
95 
96   @Test
operationToString()97   public void operationToString() {
98     try (Graph g = new Graph()) {
99       Operation op = TestUtil.constant(g, "c", new int[] {1}).op();
100       assertNotNull(op.toString());
101     }
102   }
103 
104   @Test
outputEquality()105   public void outputEquality() {
106     try (Graph g = new Graph()) {
107       Output<Integer> output = TestUtil.constant(g, "c", 1);
108       Output<Integer> output1 = output.op().<Integer>output(0);
109       Output<Integer> output2 = g.operation("c").<Integer>output(0);
110       assertEquals(output, output1);
111       assertEquals(output.hashCode(), output1.hashCode());
112       assertEquals(output, output2);
113       assertEquals(output.hashCode(), output2.hashCode());
114     }
115   }
116 
117   @Test
outputCollection()118   public void outputCollection() {
119     try (Graph g = new Graph()) {
120       Output<Integer> output = TestUtil.constant(g, "c", 1);
121       Output<Integer> output1 = output.op().<Integer>output(0);
122       Output<Integer> output2 = g.operation("c").<Integer>output(0);
123       Set<Output<Integer>> ops = new HashSet<>();
124       ops.addAll(Arrays.asList(output, output1, output2));
125       assertEquals(1, ops.size());
126       assertTrue(ops.contains(output));
127       assertTrue(ops.contains(output1));
128       assertTrue(ops.contains(output2));
129     }
130   }
131 
132   @Test
outputToString()133   public void outputToString() {
134     try (Graph g = new Graph()) {
135       Output<Integer> output = TestUtil.constant(g, "c", new int[] {1});
136       assertNotNull(output.toString());
137     }
138   }
139 
140   @Test
outputListLength()141   public void outputListLength() {
142     assertEquals(1, split(new int[] {0, 1}, 1));
143     assertEquals(2, split(new int[] {0, 1}, 2));
144     assertEquals(3, split(new int[] {0, 1, 2}, 3));
145   }
146 
147   @Test
inputListLength()148   public void inputListLength() {
149     assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim"));
150     try {
151       splitWithInputList(new int[] {0, 1}, 2, "inputs");
152     } catch (IllegalArgumentException iae) {
153       // expected
154     }
155   }
156 
157   @Test
outputList()158   public void outputList() {
159     try (Graph g = new Graph()) {
160       Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
161       Output<?>[] outputs = split.outputList(1, 2);
162       assertNotNull(outputs);
163       assertEquals(2, outputs.length);
164       for (int i = 0; i < outputs.length; ++i) {
165         assertEquals(i + 1, outputs[i].index());
166       }
167     }
168   }
169 
170   @Test
outputTensorNotSupported()171   public void outputTensorNotSupported() {
172     try (Graph g = new Graph()) {
173       Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
174       try {
175         split.output(0).tensor();
176         fail();
177       } catch (IllegalStateException e) {
178       }
179     }
180   }
181 
split(int[] values, int num_split)182   private static int split(int[] values, int num_split) {
183     try (Graph g = new Graph()) {
184       return g.opBuilder("Split", "Split")
185           .addInput(TestUtil.constant(g, "split_dim", 0))
186           .addInput(TestUtil.constant(g, "values", values))
187           .setAttr("num_split", num_split)
188           .build()
189           .outputListLength("output");
190     }
191   }
192 
splitWithInputList(int[] values, int num_split, String name)193   private static int splitWithInputList(int[] values, int num_split, String name) {
194     try (Graph g = new Graph()) {
195       return g.opBuilder("Split", "Split")
196           .addInput(TestUtil.constant(g, "split_dim", 0))
197           .addInput(TestUtil.constant(g, "values", values))
198           .setAttr("num_split", num_split)
199           .build()
200           .inputListLength(name);
201     }
202   }
203 }
204