1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertNotEquals; 20 import static org.junit.Assert.assertNotNull; 21 import static org.junit.Assert.assertTrue; 22 import static org.junit.Assert.fail; 23 24 import java.util.Arrays; 25 import java.util.HashSet; 26 import java.util.Set; 27 28 import org.junit.Test; 29 import org.junit.runner.RunWith; 30 import org.junit.runners.JUnit4; 31 32 /** Unit tests for {@link org.tensorflow.GraphOperation}. */ 33 @RunWith(JUnit4.class) 34 public class GraphOperationTest { 35 36 @Test outputListLengthFailsOnInvalidName()37 public void outputListLengthFailsOnInvalidName() { 38 try (Graph g = new Graph()) { 39 Operation op = 40 g.opBuilder("Add", "Add") 41 .addInput(TestUtil.constant(g, "x", 1)) 42 .addInput(TestUtil.constant(g, "y", 2)) 43 .build(); 44 assertEquals(1, op.outputListLength("z")); 45 46 try { 47 op.outputListLength("unknown"); 48 fail("Did not catch bad name"); 49 } catch (IllegalArgumentException iae) { 50 // expected 51 } 52 } 53 } 54 55 @Test operationEquality()56 public void operationEquality() { 57 GraphOperation op1; 58 try (Graph g = new Graph()) { 59 op1 = TestUtil.constantOp(g, "op1", 1); 60 GraphOperation op2 = TestUtil.constantOp(g, "op2", 2); 61 GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); 62 GraphOperation op4 = g.operation("op1"); 63 assertEquals(op1, op1); 64 assertNotEquals(op1, op2); 65 assertEquals(op1, op3); 66 assertEquals(op1.hashCode(), op3.hashCode()); 67 assertEquals(op1, op4); 68 assertEquals(op1.hashCode(), op4.hashCode()); 69 assertEquals(op3, op4); 70 assertNotEquals(op2, op3); 71 assertNotEquals(op2, op4); 72 } 73 try (Graph g = new Graph()) { 74 Operation newOp1 = TestUtil.constant(g, "op1", 1).op(); 75 assertNotEquals(op1, newOp1); 76 } 77 } 78 79 @Test operationCollection()80 public void operationCollection() { 81 try (Graph g = new Graph()) { 82 GraphOperation op1 = TestUtil.constantOp(g, "op1", 1); 83 GraphOperation op2 = TestUtil.constantOp(g, "op2", 2); 84 GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); 85 GraphOperation op4 = g.operation("op1"); 86 Set<Operation> ops = new HashSet<>(); 87 ops.addAll(Arrays.asList(op1, op2, op3, op4)); 88 assertEquals(2, ops.size()); 89 assertTrue(ops.contains(op1)); 90 assertTrue(ops.contains(op2)); 91 assertTrue(ops.contains(op3)); 92 assertTrue(ops.contains(op4)); 93 } 94 } 95 96 @Test operationToString()97 public void operationToString() { 98 try (Graph g = new Graph()) { 99 Operation op = TestUtil.constant(g, "c", new int[] {1}).op(); 100 assertNotNull(op.toString()); 101 } 102 } 103 104 @Test outputEquality()105 public void outputEquality() { 106 try (Graph g = new Graph()) { 107 Output<Integer> output = TestUtil.constant(g, "c", 1); 108 Output<Integer> output1 = output.op().<Integer>output(0); 109 Output<Integer> output2 = g.operation("c").<Integer>output(0); 110 assertEquals(output, output1); 111 assertEquals(output.hashCode(), output1.hashCode()); 112 assertEquals(output, output2); 113 assertEquals(output.hashCode(), output2.hashCode()); 114 } 115 } 116 117 @Test outputCollection()118 public void outputCollection() { 119 try (Graph g = new Graph()) { 120 Output<Integer> output = TestUtil.constant(g, "c", 1); 121 Output<Integer> output1 = output.op().<Integer>output(0); 122 Output<Integer> output2 = g.operation("c").<Integer>output(0); 123 Set<Output<Integer>> ops = new HashSet<>(); 124 ops.addAll(Arrays.asList(output, output1, output2)); 125 assertEquals(1, ops.size()); 126 assertTrue(ops.contains(output)); 127 assertTrue(ops.contains(output1)); 128 assertTrue(ops.contains(output2)); 129 } 130 } 131 132 @Test outputToString()133 public void outputToString() { 134 try (Graph g = new Graph()) { 135 Output<Integer> output = TestUtil.constant(g, "c", new int[] {1}); 136 assertNotNull(output.toString()); 137 } 138 } 139 140 @Test outputListLength()141 public void outputListLength() { 142 assertEquals(1, split(new int[] {0, 1}, 1)); 143 assertEquals(2, split(new int[] {0, 1}, 2)); 144 assertEquals(3, split(new int[] {0, 1, 2}, 3)); 145 } 146 147 @Test inputListLength()148 public void inputListLength() { 149 assertEquals(1, splitWithInputList(new int[] {0, 1}, 1, "split_dim")); 150 try { 151 splitWithInputList(new int[] {0, 1}, 2, "inputs"); 152 } catch (IllegalArgumentException iae) { 153 // expected 154 } 155 } 156 157 @Test outputList()158 public void outputList() { 159 try (Graph g = new Graph()) { 160 Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3); 161 Output<?>[] outputs = split.outputList(1, 2); 162 assertNotNull(outputs); 163 assertEquals(2, outputs.length); 164 for (int i = 0; i < outputs.length; ++i) { 165 assertEquals(i + 1, outputs[i].index()); 166 } 167 } 168 } 169 170 @Test outputTensorNotSupported()171 public void outputTensorNotSupported() { 172 try (Graph g = new Graph()) { 173 Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3); 174 try { 175 split.output(0).tensor(); 176 fail(); 177 } catch (IllegalStateException e) { 178 } 179 } 180 } 181 split(int[] values, int num_split)182 private static int split(int[] values, int num_split) { 183 try (Graph g = new Graph()) { 184 return g.opBuilder("Split", "Split") 185 .addInput(TestUtil.constant(g, "split_dim", 0)) 186 .addInput(TestUtil.constant(g, "values", values)) 187 .setAttr("num_split", num_split) 188 .build() 189 .outputListLength("output"); 190 } 191 } 192 splitWithInputList(int[] values, int num_split, String name)193 private static int splitWithInputList(int[] values, int num_split, String name) { 194 try (Graph g = new Graph()) { 195 return g.opBuilder("Split", "Split") 196 .addInput(TestUtil.constant(g, "split_dim", 0)) 197 .addInput(TestUtil.constant(g, "values", values)) 198 .setAttr("num_split", num_split) 199 .build() 200 .inputListLength(name); 201 } 202 } 203 } 204