1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.op; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertNotNull; 20 import static org.junit.Assert.fail; 21 22 import java.util.HashMap; 23 import java.util.Map; 24 import org.junit.Test; 25 import org.junit.runner.RunWith; 26 import org.junit.runners.JUnit4; 27 import org.tensorflow.Graph; 28 import org.tensorflow.Output; 29 import org.tensorflow.Session; 30 import org.tensorflow.Tensor; 31 import org.tensorflow.Tensors; 32 import org.tensorflow.types.UInt8; 33 34 /** Unit tests for {@link org.tensorflow.Scope}. */ 35 @RunWith(JUnit4.class) 36 public class ScopeTest { 37 38 @Test basicNames()39 public void basicNames() { 40 try (Graph g = new Graph()) { 41 Scope root = new Scope(g); 42 assertEquals("add", root.makeOpName("add")); 43 assertEquals("add_1", root.makeOpName("add")); 44 assertEquals("add_2", root.makeOpName("add")); 45 assertEquals("mul", root.makeOpName("mul")); 46 } 47 } 48 49 @Test hierarchicalNames()50 public void hierarchicalNames() { 51 try (Graph g = new Graph()) { 52 Scope root = new Scope(g); 53 Scope child = root.withSubScope("child"); 54 assertEquals("child/add", child.makeOpName("add")); 55 assertEquals("child/add_1", child.makeOpName("add")); 56 assertEquals("child/mul", child.makeOpName("mul")); 57 58 Scope child_1 = root.withSubScope("child"); 59 assertEquals("child_1/add", child_1.makeOpName("add")); 60 assertEquals("child_1/add_1", child_1.makeOpName("add")); 61 assertEquals("child_1/mul", child_1.makeOpName("mul")); 62 63 Scope c_c = root.withSubScope("c").withSubScope("c"); 64 assertEquals("c/c/add", c_c.makeOpName("add")); 65 66 Scope c_1 = root.withSubScope("c"); 67 Scope c_1_c = c_1.withSubScope("c"); 68 assertEquals("c_1/c/add", c_1_c.makeOpName("add")); 69 70 Scope c_1_c_1 = c_1.withSubScope("c"); 71 assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add")); 72 } 73 } 74 75 @Test scopeAndOpNames()76 public void scopeAndOpNames() { 77 try (Graph g = new Graph()) { 78 Scope root = new Scope(g); 79 80 Scope child = root.withSubScope("child"); 81 82 assertEquals("child/add", child.makeOpName("add")); 83 assertEquals("child_1", root.makeOpName("child")); 84 assertEquals("child_2/p", root.withSubScope("child").makeOpName("p")); 85 } 86 } 87 88 @Test validateNames()89 public void validateNames() { 90 try (Graph g = new Graph()) { 91 Scope root = new Scope(g); 92 93 final String[] invalid_names = { 94 "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.] 95 null, "", "a$", // Invalid characters 96 "a/b", // slashes not allowed 97 }; 98 99 for (String name : invalid_names) { 100 try { 101 root.withName(name); 102 fail("failed to catch invalid op name."); 103 } catch (IllegalArgumentException ex) { 104 // expected 105 } 106 // Subscopes follow the same rules 107 try { 108 root.withSubScope(name); 109 fail("failed to catch invalid scope name: " + name); 110 } catch (IllegalArgumentException ex) { 111 // expected 112 } 113 } 114 115 // Unusual but valid names. 116 final String[] valid_names = {".", "..", "._-.", "a--."}; 117 118 for (String name : valid_names) { 119 root.withName(name); 120 root.withSubScope(name); 121 } 122 } 123 } 124 125 @Test basic()126 public void basic() { 127 try (Graph g = new Graph()) { 128 Scope s = new Scope(g); 129 Const<Integer> c1 = Const.create(s, 42); 130 assertEquals("Const", c1.output().op().name()); 131 Const<Integer> c2 = Const.create(s, 7); 132 assertEquals("Const_1", c2.output().op().name()); 133 Const<Integer> c3 = Const.create(s.withName("four"), 4); 134 assertEquals("four", c3.output().op().name()); 135 Const<Integer> c4 = Const.create(s.withName("four"), 4); 136 assertEquals("four_1", c4.output().op().name()); 137 } 138 } 139 140 @Test hierarchy()141 public void hierarchy() { 142 try (Graph g = new Graph()) { 143 Scope root = new Scope(g); 144 Scope child = root.withSubScope("child"); 145 assertEquals("child/Const", Const.create(child, 42).output().op().name()); 146 assertEquals("child/four", Const.create(child.withName("four"), 4).output().op().name()); 147 } 148 } 149 150 @Test composite()151 public void composite() { 152 try (Graph g = new Graph(); 153 Session sess = new Session(g)) { 154 Scope s = new Scope(g); 155 Output<Integer> data = 156 Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output(); 157 158 // Create a composite op with a customized name 159 Variance<Integer> var1 = Variance.create(s.withName("example"), data, Integer.class); 160 assertEquals("example/variance", var1.output().op().name()); 161 162 // Confirm internally added ops have the right names. 163 assertNotNull(g.operation("example/squared_deviation")); 164 assertNotNull(g.operation("example/Mean")); 165 // assertNotNull(g.operation("example/zero")); 166 167 // Same composite op with a default name 168 Variance<Integer> var2 = Variance.create(s, data, Integer.class); 169 assertEquals("variance/variance", var2.output().op().name()); 170 171 // Confirm internally added ops have the right names. 172 assertNotNull(g.operation("variance/squared_deviation")); 173 assertNotNull(g.operation("variance/Mean")); 174 // assertNotNull(g.operation("variance/zero")); 175 176 // Verify correct results as well. 177 Tensor<Integer> result = 178 sess.runner().fetch(var1.output()).run().get(0).expect(Integer.class); 179 assertEquals(21704, result.intValue()); 180 result = sess.runner().fetch(var2.output()).run().get(0).expect(Integer.class); 181 assertEquals(21704, result.intValue()); 182 } 183 } 184 185 // "handwritten" sample operator classes 186 private static final class Const<T> { 187 private final Output<T> output; 188 create(Scope s, int v)189 static Const<Integer> create(Scope s, int v) { 190 return create(s, Tensors.create(v)); 191 } 192 create(Scope s, int[] v)193 static Const<Integer> create(Scope s, int[] v) { 194 return create(s, Tensors.create(v)); 195 } 196 create(Scope s, Tensor<T> value)197 static <T> Const<T> create(Scope s, Tensor<T> value) { 198 return new Const<T>( 199 s.graph() 200 .opBuilder("Const", s.makeOpName("Const")) 201 .setAttr("dtype", value.dataType()) 202 .setAttr("value", value) 203 .build() 204 .<T>output(0)); 205 } 206 create(Scope s, Object v, Class<T> type)207 static <T> Const<T> create(Scope s, Object v, Class<T> type) { 208 try (Tensor<T> value = Tensor.create(v, type)) { 209 return new Const<T>( 210 s.graph() 211 .opBuilder("Const", s.makeOpName("Const")) 212 .setAttr("dtype", value.dataType()) 213 .setAttr("value", value) 214 .build() 215 .<T>output(0)); 216 } 217 } 218 Const(Output<T> o)219 Const(Output<T> o) { 220 output = o; 221 } 222 output()223 Output<T> output() { 224 return output; 225 } 226 } 227 228 private static final class Mean<T> { 229 private final Output<T> output; 230 create(Scope s, Output<T> input, Output<T> reductionIndices)231 static <T> Mean<T> create(Scope s, Output<T> input, Output<T> reductionIndices) { 232 return new Mean<T>( 233 s.graph() 234 .opBuilder("Mean", s.makeOpName("Mean")) 235 .addInput(input) 236 .addInput(reductionIndices) 237 .build() 238 .<T>output(0)); 239 } 240 Mean(Output<T> o)241 Mean(Output<T> o) { 242 output = o; 243 } 244 output()245 Output<T> output() { 246 return output; 247 } 248 } 249 250 private static final class SquaredDifference<T> { 251 private final Output<T> output; 252 create(Scope s, Output<T> x, Output<T> y)253 static <T> SquaredDifference<T> create(Scope s, Output<T> x, Output<T> y) { 254 return new SquaredDifference<T>( 255 s.graph() 256 .opBuilder("SquaredDifference", s.makeOpName("SquaredDifference")) 257 .addInput(x) 258 .addInput(y) 259 .build() 260 .<T>output(0)); 261 } 262 SquaredDifference(Output<T> o)263 SquaredDifference(Output<T> o) { 264 output = o; 265 } 266 output()267 Output<T> output() { 268 return output; 269 } 270 } 271 272 /** 273 * Returns the zero value of type described by {@code c}, or null if the type (e.g., string) is 274 * not numeric and therefore has no zero value. 275 * 276 * @param c The class describing the TensorFlow type of interest. 277 */ zeroValue(Class<?> c)278 public static Object zeroValue(Class<?> c) { 279 return zeros.get(c); 280 } 281 282 private static final Map<Class<?>, Object> zeros = new HashMap<>(); 283 284 static { zeros.put(Float.class, 0.0f)285 zeros.put(Float.class, 0.0f); zeros.put(Double.class, 0.0)286 zeros.put(Double.class, 0.0); zeros.put(Integer.class, 0)287 zeros.put(Integer.class, 0); zeros.put(UInt8.class, (byte) 0)288 zeros.put(UInt8.class, (byte) 0); zeros.put(Long.class, 0L)289 zeros.put(Long.class, 0L); zeros.put(Boolean.class, false)290 zeros.put(Boolean.class, false); zeros.put(String.class, null)291 zeros.put(String.class, null); // no zero value 292 } 293 294 private static final class Variance<T> { 295 private final Output<T> output; 296 create(Scope base, Output<T> x, Class<T> type)297 static <T> Variance<T> create(Scope base, Output<T> x, Class<T> type) { 298 Scope s = base.withSubScope("variance"); 299 Output<T> zero = Const.create(base, zeroValue(type), type).output(); 300 Output<T> sqdiff = 301 SquaredDifference.create( 302 s.withName("squared_deviation"), x, Mean.create(s, x, zero).output()) 303 .output(); 304 305 return new Variance<T>(Mean.create(s.withName("variance"), sqdiff, zero).output()); 306 } 307 Variance(Output<T> o)308 Variance(Output<T> o) { 309 output = o; 310 } 311 output()312 Output<T> output() { 313 return output; 314 } 315 } 316 } 317