/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertNotEquals;
20 import static org.junit.Assert.assertNotNull;
21 import static org.junit.Assert.assertTrue;
22 import static org.junit.Assert.fail;
23 
24 import java.util.Arrays;
25 import java.util.HashSet;
26 import java.util.Set;
27 import org.junit.Test;
28 import org.junit.runner.RunWith;
29 import org.junit.runners.JUnit4;
30 
31 /** Unit tests for {@link org.tensorflow.Operation}. */
32 @RunWith(JUnit4.class)
33 public class OperationTest {
34 
35   @Test
outputListLengthFailsOnInvalidName()36   public void outputListLengthFailsOnInvalidName() {
37     try (Graph g = new Graph()) {
38       Operation op =
39           g.opBuilder("Add", "Add")
40               .addInput(TestUtil.constant(g, "x", 1))
41               .addInput(TestUtil.constant(g, "y", 2))
42               .build();
43       assertEquals(1, op.outputListLength("z"));
44 
45       try {
46         op.outputListLength("unknown");
47         fail("Did not catch bad name");
48       } catch (IllegalArgumentException iae) {
49         // expected
50       }
51     }
52   }
53 
54   @Test
operationEquality()55   public void operationEquality() {
56     Operation op1;
57     try (Graph g = new Graph()) {
58       op1 = TestUtil.constant(g, "op1", 1).op();
59       Operation op2 = TestUtil.constant(g, "op2", 2).op();
60       Operation op3 = new Operation(g, op1.getUnsafeNativeHandle());
61       Operation op4 = g.operation("op1");
62       assertEquals(op1, op1);
63       assertNotEquals(op1, op2);
64       assertEquals(op1, op3);
65       assertEquals(op1.hashCode(), op3.hashCode());
66       assertEquals(op1, op4);
67       assertEquals(op1.hashCode(), op4.hashCode());
68       assertEquals(op3, op4);
69       assertNotEquals(op2, op3);
70       assertNotEquals(op2, op4);
71     }
72     try (Graph g = new Graph()) {
73       Operation newOp1 = TestUtil.constant(g, "op1", 1).op();
74       assertNotEquals(op1, newOp1);
75     }
76   }
77 
78   @Test
operationCollection()79   public void operationCollection() {
80     try (Graph g = new Graph()) {
81       Operation op1 = TestUtil.constant(g, "op1", 1).op();
82       Operation op2 = TestUtil.constant(g, "op2", 2).op();
83       Operation op3 = new Operation(g, op1.getUnsafeNativeHandle());
84       Operation op4 = g.operation("op1");
85       Set<Operation> ops = new HashSet<>();
86       ops.addAll(Arrays.asList(op1, op2, op3, op4));
87       assertEquals(2, ops.size());
88       assertTrue(ops.contains(op1));
89       assertTrue(ops.contains(op2));
90       assertTrue(ops.contains(op3));
91       assertTrue(ops.contains(op4));
92     }
93   }
94 
95   @Test
operationToString()96   public void operationToString() {
97     try (Graph g = new Graph()) {
98       Operation op = TestUtil.constant(g, "c", new int[] {1}).op();
99       assertNotNull(op.toString());
100     }
101   }
102 
103   @Test
outputEquality()104   public void outputEquality() {
105     try (Graph g = new Graph()) {
106       Output<Integer> output = TestUtil.constant(g, "c", 1);
107       Output<Integer> output1 = output.op().<Integer>output(0);
108       Output<Integer> output2 = g.operation("c").<Integer>output(0);
109       assertEquals(output, output1);
110       assertEquals(output.hashCode(), output1.hashCode());
111       assertEquals(output, output2);
112       assertEquals(output.hashCode(), output2.hashCode());
113     }
114   }
115 
116   @Test
outputCollection()117   public void outputCollection() {
118     try (Graph g = new Graph()) {
119       Output<Integer> output = TestUtil.constant(g, "c", 1);
120       Output<Integer> output1 = output.op().<Integer>output(0);
121       Output<Integer> output2 = g.operation("c").<Integer>output(0);
122       Set<Output<Integer>> ops = new HashSet<>();
123       ops.addAll(Arrays.asList(output, output1, output2));
124       assertEquals(1, ops.size());
125       assertTrue(ops.contains(output));
126       assertTrue(ops.contains(output1));
127       assertTrue(ops.contains(output2));
128     }
129   }
130 
131   @Test
outputToString()132   public void outputToString() {
133     try (Graph g = new Graph()) {
134       Output<Integer> output = TestUtil.constant(g, "c", new int[] {1});
135       assertNotNull(output.toString());
136     }
137   }
138 
139   @Test
outputListLength()140   public void outputListLength() {
141     assertEquals(1, split(new int[] {0, 1}, 1));
142     assertEquals(2, split(new int[] {0, 1}, 2));
143     assertEquals(3, split(new int[] {0, 1, 2}, 3));
144   }
145 
146   @Test
inputListLength()147   public void inputListLength() {
148     assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim"));
149     try {
150       splitWithInputList(new int[] {0, 1}, 2, "inputs");
151     } catch (IllegalArgumentException iae) {
152       // expected
153     }
154   }
155 
156   @Test
outputList()157   public void outputList() {
158     try (Graph g = new Graph()) {
159       Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
160       Output<?>[] outputs = split.outputList(1, 2);
161       assertNotNull(outputs);
162       assertEquals(2, outputs.length);
163       for (int i = 0; i < outputs.length; ++i) {
164         assertEquals(i + 1, outputs[i].index());
165       }
166     }
167   }
168 
split(int[] values, int num_split)169   private static int split(int[] values, int num_split) {
170     try (Graph g = new Graph()) {
171       return g.opBuilder("Split", "Split")
172           .addInput(TestUtil.constant(g, "split_dim", 0))
173           .addInput(TestUtil.constant(g, "values", values))
174           .setAttr("num_split", num_split)
175           .build()
176           .outputListLength("output");
177     }
178   }
179 
splitWithInputList(int[] values, int num_split, String name)180   private static int splitWithInputList(int[] values, int num_split, String name) {
181     try (Graph g = new Graph()) {
182       return g.opBuilder("Split", "Split")
183           .addInput(TestUtil.constant(g, "split_dim", 0))
184           .addInput(TestUtil.constant(g, "values", values))
185           .setAttr("num_split", num_split)
186           .build()
187           .inputListLength(name);
188     }
189   }
190 }
191