1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertNotEquals; 20 import static org.junit.Assert.assertNotNull; 21 import static org.junit.Assert.assertTrue; 22 import static org.junit.Assert.fail; 23 24 import java.util.Arrays; 25 import java.util.HashSet; 26 import java.util.Set; 27 import org.junit.Test; 28 import org.junit.runner.RunWith; 29 import org.junit.runners.JUnit4; 30 31 /** Unit tests for {@link org.tensorflow.Operation}. */ 32 @RunWith(JUnit4.class) 33 public class OperationTest { 34 35 @Test outputListLengthFailsOnInvalidName()36 public void outputListLengthFailsOnInvalidName() { 37 try (Graph g = new Graph()) { 38 Operation op = 39 g.opBuilder("Add", "Add") 40 .addInput(TestUtil.constant(g, "x", 1)) 41 .addInput(TestUtil.constant(g, "y", 2)) 42 .build(); 43 assertEquals(1, op.outputListLength("z")); 44 45 try { 46 op.outputListLength("unknown"); 47 fail("Did not catch bad name"); 48 } catch (IllegalArgumentException iae) { 49 // expected 50 } 51 } 52 } 53 54 @Test operationEquality()55 public void operationEquality() { 56 Operation op1; 57 try (Graph g = new Graph()) { 58 op1 = TestUtil.constant(g, "op1", 1).op(); 59 Operation op2 = TestUtil.constant(g, "op2", 2).op(); 60 Operation op3 = new Operation(g, op1.getUnsafeNativeHandle()); 61 Operation op4 = g.operation("op1"); 62 assertEquals(op1, op1); 63 assertNotEquals(op1, op2); 64 assertEquals(op1, op3); 65 assertEquals(op1.hashCode(), op3.hashCode()); 66 assertEquals(op1, op4); 67 assertEquals(op1.hashCode(), op4.hashCode()); 68 assertEquals(op3, op4); 69 assertNotEquals(op2, op3); 70 assertNotEquals(op2, op4); 71 } 72 try (Graph g = new Graph()) { 73 Operation newOp1 = TestUtil.constant(g, "op1", 1).op(); 74 assertNotEquals(op1, newOp1); 75 } 76 } 77 78 @Test operationCollection()79 public void operationCollection() { 80 try (Graph g = new Graph()) { 81 Operation op1 = TestUtil.constant(g, "op1", 1).op(); 82 Operation op2 = TestUtil.constant(g, "op2", 2).op(); 83 Operation op3 = new Operation(g, op1.getUnsafeNativeHandle()); 84 Operation op4 = g.operation("op1"); 85 Set<Operation> ops = new HashSet<>(); 86 ops.addAll(Arrays.asList(op1, op2, op3, op4)); 87 assertEquals(2, ops.size()); 88 assertTrue(ops.contains(op1)); 89 assertTrue(ops.contains(op2)); 90 assertTrue(ops.contains(op3)); 91 assertTrue(ops.contains(op4)); 92 } 93 } 94 95 @Test operationToString()96 public void operationToString() { 97 try (Graph g = new Graph()) { 98 Operation op = TestUtil.constant(g, "c", new int[] {1}).op(); 99 assertNotNull(op.toString()); 100 } 101 } 102 103 @Test outputEquality()104 public void outputEquality() { 105 try (Graph g = new Graph()) { 106 Output<Integer> output = TestUtil.constant(g, "c", 1); 107 Output<Integer> output1 = output.op().<Integer>output(0); 108 Output<Integer> output2 = g.operation("c").<Integer>output(0); 109 assertEquals(output, output1); 110 assertEquals(output.hashCode(), output1.hashCode()); 111 assertEquals(output, output2); 112 assertEquals(output.hashCode(), output2.hashCode()); 113 } 114 } 115 116 @Test outputCollection()117 public void outputCollection() { 118 try (Graph g = new Graph()) { 119 Output<Integer> output = TestUtil.constant(g, "c", 1); 120 Output<Integer> output1 = output.op().<Integer>output(0); 121 Output<Integer> output2 = g.operation("c").<Integer>output(0); 122 Set<Output<Integer>> ops = new HashSet<>(); 123 ops.addAll(Arrays.asList(output, output1, output2)); 124 assertEquals(1, ops.size()); 125 assertTrue(ops.contains(output)); 126 assertTrue(ops.contains(output1)); 127 assertTrue(ops.contains(output2)); 128 } 129 } 130 131 @Test outputToString()132 public void outputToString() { 133 try (Graph g = new Graph()) { 134 Output<Integer> output = TestUtil.constant(g, "c", new int[] {1}); 135 assertNotNull(output.toString()); 136 } 137 } 138 139 @Test outputListLength()140 public void outputListLength() { 141 assertEquals(1, split(new int[] {0, 1}, 1)); 142 assertEquals(2, split(new int[] {0, 1}, 2)); 143 assertEquals(3, split(new int[] {0, 1, 2}, 3)); 144 } 145 146 @Test inputListLength()147 public void inputListLength() { 148 assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim")); 149 try { 150 splitWithInputList(new int[] {0, 1}, 2, "inputs"); 151 } catch (IllegalArgumentException iae) { 152 // expected 153 } 154 } 155 156 @Test outputList()157 public void outputList() { 158 try (Graph g = new Graph()) { 159 Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3); 160 Output<?>[] outputs = split.outputList(1, 2); 161 assertNotNull(outputs); 162 assertEquals(2, outputs.length); 163 for (int i = 0; i < outputs.length; ++i) { 164 assertEquals(i + 1, outputs[i].index()); 165 } 166 } 167 } 168 split(int[] values, int num_split)169 private static int split(int[] values, int num_split) { 170 try (Graph g = new Graph()) { 171 return g.opBuilder("Split", "Split") 172 .addInput(TestUtil.constant(g, "split_dim", 0)) 173 .addInput(TestUtil.constant(g, "values", values)) 174 .setAttr("num_split", num_split) 175 .build() 176 .outputListLength("output"); 177 } 178 } 179 splitWithInputList(int[] values, int num_split, String name)180 private static int splitWithInputList(int[] values, int num_split, String name) { 181 try (Graph g = new Graph()) { 182 return g.opBuilder("Split", "Split") 183 .addInput(TestUtil.constant(g, "split_dim", 0)) 184 .addInput(TestUtil.constant(g, "values", values)) 185 .setAttr("num_split", num_split) 186 .build() 187 .inputListLength(name); 188 } 189 } 190 } 191