• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow.op.core;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertFalse;
20 import static org.junit.Assert.fail;
21 
22 import java.util.List;
23 
24 import org.junit.Test;
25 import org.junit.runner.RunWith;
26 import org.junit.runners.JUnit4;
27 import org.tensorflow.Graph;
28 import org.tensorflow.Session;
29 import org.tensorflow.Tensor;
30 import org.tensorflow.op.Scope;
31 import org.tensorflow.types.UInt8;
32 
33 @RunWith(JUnit4.class)
34 public class ZerosTest {
35   private static final float EPSILON = 1e-7f;
36 
37   @Test
createIntZeros()38   public void createIntZeros() {
39     try (Graph g = new Graph();
40         Session sess = new Session(g)) {
41       Scope scope = new Scope(g);
42       long[] shape = {2, 2};
43       Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class);
44       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
45         int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]);
46         for (int i = 0; i < actual.length; ++i) {
47           for (int j = 0; j < actual[i].length; ++j) {
48             assertEquals(0, actual[i][j]);
49           }
50         }
51       }
52     }
53   }
54 
55   @Test
createFloatZeros()56   public void createFloatZeros() {
57     try (Graph g = new Graph();
58         Session sess = new Session(g)) {
59       Scope scope = new Scope(g);
60       long[] shape = {2, 2};
61       Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class);
62       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
63         float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]);
64         for (int i = 0; i < actual.length; ++i) {
65           for (int j = 0; j < actual[i].length; ++j) {
66             assertEquals(0.0f, actual[i][j], EPSILON);
67           }
68         }
69       }
70     }
71   }
72 
73   @Test
createDoubleZeros()74   public void createDoubleZeros() {
75     try (Graph g = new Graph();
76         Session sess = new Session(g)) {
77       Scope scope = new Scope(g);
78       long[] shape = {2, 2};
79       Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class);
80       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
81         double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]);
82         for (int i = 0; i < actual.length; ++i) {
83           for (int j = 0; j < actual[i].length; ++j) {
84             assertEquals(0.0, actual[i][j], EPSILON);
85           }
86         }
87       }
88     }
89   }
90 
91   @Test
createLongZeros()92   public void createLongZeros() {
93     try (Graph g = new Graph();
94         Session sess = new Session(g)) {
95       Scope scope = new Scope(g);
96       long[] shape = {2, 2};
97       Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class);
98       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
99         long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]);
100         for (int i = 0; i < actual.length; ++i) {
101           for (int j = 0; j < actual[i].length; ++j) {
102             assertEquals(0L, actual[i][j]);
103           }
104         }
105       }
106     }
107   }
108 
109   @Test
createBooleanZeros()110   public void createBooleanZeros() {
111     try (Graph g = new Graph();
112         Session sess = new Session(g)) {
113       Scope scope = new Scope(g);
114       long[] shape = {2, 2};
115       Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class);
116       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
117         boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]);
118         for (int i = 0; i < actual.length; ++i) {
119           for (int j = 0; j < actual[i].length; ++j) {
120             assertFalse(actual[i][j]);
121           }
122         }
123       }
124     }
125   }
126 
127   @Test
createUInt8Zeros()128   public void createUInt8Zeros() {
129     try (Graph g = new Graph();
130         Session sess = new Session(g)) {
131       Scope scope = new Scope(g);
132       long[] shape = {2, 2};
133       Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class);
134       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
135         byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]);
136         result.copyTo(actual);
137         for (int i = 0; i < actual.length; ++i) {
138           for (int j = 0; j < actual[i].length; ++j) {
139             assertEquals(0, actual[i][j]);
140           }
141         }
142       }
143     }
144   }
145 
146   @Test
cannotCreateStringZeros()147   public void cannotCreateStringZeros() {
148     try (Graph g = new Graph();
149         Session sess = new Session(g)) {
150       Scope scope = new Scope(g);
151       long[] shape = {2, 2};
152       Zeros.create(scope, Constant.create(scope, shape), String.class);
153       fail();
154     } catch (IllegalArgumentException expected) {}
155   }
156 
157   @Test
operationsComposingZerosAreCorrectlyNamed()158   public void operationsComposingZerosAreCorrectlyNamed() {
159     try (Graph g = new Graph();
160         Session sess = new Session(g)) {
161       Scope scope = new Scope(g);
162       long[] shape = {2, 2};
163       Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class);
164       List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
165     }
166   }
167 }
168