• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.op;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertNotNull;
20 import static org.junit.Assert.fail;
21 
22 import java.util.HashMap;
23 import java.util.Map;
24 import org.junit.Test;
25 import org.junit.runner.RunWith;
26 import org.junit.runners.JUnit4;
27 import org.tensorflow.Graph;
28 import org.tensorflow.Output;
29 import org.tensorflow.Session;
30 import org.tensorflow.Tensor;
31 import org.tensorflow.Tensors;
32 import org.tensorflow.types.UInt8;
33 
34 /** Unit tests for {@link org.tensorflow.Scope}. */
35 @RunWith(JUnit4.class)
36 public class ScopeTest {
37 
38   @Test
basicNames()39   public void basicNames() {
40     try (Graph g = new Graph()) {
41       Scope root = new Scope(g);
42       assertEquals("add", root.makeOpName("add"));
43       assertEquals("add_1", root.makeOpName("add"));
44       assertEquals("add_2", root.makeOpName("add"));
45       assertEquals("mul", root.makeOpName("mul"));
46     }
47   }
48 
49   @Test
hierarchicalNames()50   public void hierarchicalNames() {
51     try (Graph g = new Graph()) {
52       Scope root = new Scope(g);
53       Scope child = root.withSubScope("child");
54       assertEquals("child/add", child.makeOpName("add"));
55       assertEquals("child/add_1", child.makeOpName("add"));
56       assertEquals("child/mul", child.makeOpName("mul"));
57 
58       Scope child_1 = root.withSubScope("child");
59       assertEquals("child_1/add", child_1.makeOpName("add"));
60       assertEquals("child_1/add_1", child_1.makeOpName("add"));
61       assertEquals("child_1/mul", child_1.makeOpName("mul"));
62 
63       Scope c_c = root.withSubScope("c").withSubScope("c");
64       assertEquals("c/c/add", c_c.makeOpName("add"));
65 
66       Scope c_1 = root.withSubScope("c");
67       Scope c_1_c = c_1.withSubScope("c");
68       assertEquals("c_1/c/add", c_1_c.makeOpName("add"));
69 
70       Scope c_1_c_1 = c_1.withSubScope("c");
71       assertEquals("c_1/c_1/add", c_1_c_1.makeOpName("add"));
72     }
73   }
74 
75   @Test
scopeAndOpNames()76   public void scopeAndOpNames() {
77     try (Graph g = new Graph()) {
78       Scope root = new Scope(g);
79 
80       Scope child = root.withSubScope("child");
81 
82       assertEquals("child/add", child.makeOpName("add"));
83       assertEquals("child_1", root.makeOpName("child"));
84       assertEquals("child_2/p", root.withSubScope("child").makeOpName("p"));
85     }
86   }
87 
88   @Test
validateNames()89   public void validateNames() {
90     try (Graph g = new Graph()) {
91       Scope root = new Scope(g);
92 
93       final String[] invalid_names = {
94         "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.]
95         null, "", "a$", // Invalid characters
96         "a/b", // slashes not allowed
97       };
98 
99       for (String name : invalid_names) {
100         try {
101           root.withName(name);
102           fail("failed to catch invalid op name.");
103         } catch (IllegalArgumentException ex) {
104           // expected
105         }
106         // Subscopes follow the same rules
107         try {
108           root.withSubScope(name);
109           fail("failed to catch invalid scope name: " + name);
110         } catch (IllegalArgumentException ex) {
111           // expected
112         }
113       }
114 
115       // Unusual but valid names.
116       final String[] valid_names = {".", "..", "._-.", "a--."};
117 
118       for (String name : valid_names) {
119         root.withName(name);
120         root.withSubScope(name);
121       }
122     }
123   }
124 
125   @Test
basic()126   public void basic() {
127     try (Graph g = new Graph()) {
128       Scope s = new Scope(g);
129       Const<Integer> c1 = Const.create(s, 42);
130       assertEquals("Const", c1.output().op().name());
131       Const<Integer> c2 = Const.create(s, 7);
132       assertEquals("Const_1", c2.output().op().name());
133       Const<Integer> c3 = Const.create(s.withName("four"), 4);
134       assertEquals("four", c3.output().op().name());
135       Const<Integer> c4 = Const.create(s.withName("four"), 4);
136       assertEquals("four_1", c4.output().op().name());
137     }
138   }
139 
140   @Test
hierarchy()141   public void hierarchy() {
142     try (Graph g = new Graph()) {
143       Scope root = new Scope(g);
144       Scope child = root.withSubScope("child");
145       assertEquals("child/Const", Const.create(child, 42).output().op().name());
146       assertEquals("child/four", Const.create(child.withName("four"), 4).output().op().name());
147     }
148   }
149 
150   @Test
composite()151   public void composite() {
152     try (Graph g = new Graph();
153         Session sess = new Session(g)) {
154       Scope s = new Scope(g);
155       Output<Integer> data =
156           Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
157 
158       // Create a composite op with a customized name
159       Variance<Integer> var1 = Variance.create(s.withName("example"), data, Integer.class);
160       assertEquals("example/variance", var1.output().op().name());
161 
162       // Confirm internally added ops have the right names.
163       assertNotNull(g.operation("example/squared_deviation"));
164       assertNotNull(g.operation("example/Mean"));
165       // assertNotNull(g.operation("example/zero"));
166 
167       // Same composite op with a default name
168       Variance<Integer> var2 = Variance.create(s, data, Integer.class);
169       assertEquals("variance/variance", var2.output().op().name());
170 
171       // Confirm internally added ops have the right names.
172       assertNotNull(g.operation("variance/squared_deviation"));
173       assertNotNull(g.operation("variance/Mean"));
174       // assertNotNull(g.operation("variance/zero"));
175 
176       // Verify correct results as well.
177       Tensor<Integer> result =
178           sess.runner().fetch(var1.output()).run().get(0).expect(Integer.class);
179       assertEquals(21704, result.intValue());
180       result = sess.runner().fetch(var2.output()).run().get(0).expect(Integer.class);
181       assertEquals(21704, result.intValue());
182     }
183   }
184 
185   // "handwritten" sample operator classes
186   private static final class Const<T> {
187     private final Output<T> output;
188 
create(Scope s, int v)189     static Const<Integer> create(Scope s, int v) {
190       return create(s, Tensors.create(v));
191     }
192 
create(Scope s, int[] v)193     static Const<Integer> create(Scope s, int[] v) {
194       return create(s, Tensors.create(v));
195     }
196 
create(Scope s, Tensor<T> value)197     static <T> Const<T> create(Scope s, Tensor<T> value) {
198       return new Const<T>(
199           s.graph()
200               .opBuilder("Const", s.makeOpName("Const"))
201               .setAttr("dtype", value.dataType())
202               .setAttr("value", value)
203               .build()
204               .<T>output(0));
205     }
206 
create(Scope s, Object v, Class<T> type)207     static <T> Const<T> create(Scope s, Object v, Class<T> type) {
208       try (Tensor<T> value = Tensor.create(v, type)) {
209         return new Const<T>(
210             s.graph()
211                 .opBuilder("Const", s.makeOpName("Const"))
212                 .setAttr("dtype", value.dataType())
213                 .setAttr("value", value)
214                 .build()
215                 .<T>output(0));
216       }
217     }
218 
Const(Output<T> o)219     Const(Output<T> o) {
220       output = o;
221     }
222 
output()223     Output<T> output() {
224       return output;
225     }
226   }
227 
228   private static final class Mean<T> {
229     private final Output<T> output;
230 
create(Scope s, Output<T> input, Output<T> reductionIndices)231     static <T> Mean<T> create(Scope s, Output<T> input, Output<T> reductionIndices) {
232       return new Mean<T>(
233           s.graph()
234               .opBuilder("Mean", s.makeOpName("Mean"))
235               .addInput(input)
236               .addInput(reductionIndices)
237               .build()
238               .<T>output(0));
239     }
240 
Mean(Output<T> o)241     Mean(Output<T> o) {
242       output = o;
243     }
244 
output()245     Output<T> output() {
246       return output;
247     }
248   }
249 
250   private static final class SquaredDifference<T> {
251     private final Output<T> output;
252 
create(Scope s, Output<T> x, Output<T> y)253     static <T> SquaredDifference<T> create(Scope s, Output<T> x, Output<T> y) {
254       return new SquaredDifference<T>(
255           s.graph()
256               .opBuilder("SquaredDifference", s.makeOpName("SquaredDifference"))
257               .addInput(x)
258               .addInput(y)
259               .build()
260               .<T>output(0));
261     }
262 
SquaredDifference(Output<T> o)263     SquaredDifference(Output<T> o) {
264       output = o;
265     }
266 
output()267     Output<T> output() {
268       return output;
269     }
270   }
271 
272   /**
273    * Returns the zero value of type described by {@code c}, or null if the type (e.g., string) is
274    * not numeric and therefore has no zero value.
275    *
276    * @param c The class describing the TensorFlow type of interest.
277    */
zeroValue(Class<?> c)278   public static Object zeroValue(Class<?> c) {
279     return zeros.get(c);
280   }
281 
282   private static final Map<Class<?>, Object> zeros = new HashMap<>();
283 
284   static {
zeros.put(Float.class, 0.0f)285     zeros.put(Float.class, 0.0f);
zeros.put(Double.class, 0.0)286     zeros.put(Double.class, 0.0);
zeros.put(Integer.class, 0)287     zeros.put(Integer.class, 0);
zeros.put(UInt8.class, (byte) 0)288     zeros.put(UInt8.class, (byte) 0);
zeros.put(Long.class, 0L)289     zeros.put(Long.class, 0L);
zeros.put(Boolean.class, false)290     zeros.put(Boolean.class, false);
zeros.put(String.class, null)291     zeros.put(String.class, null); // no zero value
292   }
293 
294   private static final class Variance<T> {
295     private final Output<T> output;
296 
create(Scope base, Output<T> x, Class<T> type)297     static <T> Variance<T> create(Scope base, Output<T> x, Class<T> type) {
298       Scope s = base.withSubScope("variance");
299       Output<T> zero = Const.create(base, zeroValue(type), type).output();
300       Output<T> sqdiff =
301           SquaredDifference.create(
302                   s.withName("squared_deviation"), x, Mean.create(s, x, zero).output())
303               .output();
304 
305       return new Variance<T>(Mean.create(s.withName("variance"), sqdiff, zero).output());
306     }
307 
Variance(Output<T> o)308     Variance(Output<T> o) {
309       output = o;
310     }
311 
output()312     Output<T> output() {
313       return output;
314     }
315   }
316 }
317