• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow.op.core;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.assertFalse;
20 
21 import java.util.List;
22 
23 import org.junit.Test;
24 import org.junit.runner.RunWith;
25 import org.junit.runners.JUnit4;
26 import org.tensorflow.Graph;
27 import org.tensorflow.Session;
28 import org.tensorflow.Tensor;
29 import org.tensorflow.op.Scope;
30 import org.tensorflow.types.UInt8;
31 
32 @RunWith(JUnit4.class)
33 public class ZerosTest {
34   private static final float EPSILON = 1e-7f;
35 
36   @Test
createIntZeros()37   public void createIntZeros() {
38     try (Graph g = new Graph();
39         Session sess = new Session(g)) {
40       Scope scope = new Scope(g);
41       long[] shape = {2, 2};
42       Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class);
43       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
44         int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]);
45         for (int i = 0; i < actual.length; ++i) {
46           for (int j = 0; j < actual[i].length; ++j) {
47             assertEquals(0, actual[i][j]);
48           }
49         }
50       }
51     }
52   }
53 
54   @Test
createFloatZeros()55   public void createFloatZeros() {
56     try (Graph g = new Graph();
57         Session sess = new Session(g)) {
58       Scope scope = new Scope(g);
59       long[] shape = {2, 2};
60       Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class);
61       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
62         float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]);
63         for (int i = 0; i < actual.length; ++i) {
64           for (int j = 0; j < actual[i].length; ++j) {
65             assertEquals(0.0f, actual[i][j], EPSILON);
66           }
67         }
68       }
69     }
70   }
71 
72   @Test
createDoubleZeros()73   public void createDoubleZeros() {
74     try (Graph g = new Graph();
75         Session sess = new Session(g)) {
76       Scope scope = new Scope(g);
77       long[] shape = {2, 2};
78       Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class);
79       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
80         double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]);
81         for (int i = 0; i < actual.length; ++i) {
82           for (int j = 0; j < actual[i].length; ++j) {
83             assertEquals(0.0, actual[i][j], EPSILON);
84           }
85         }
86       }
87     }
88   }
89 
90   @Test
createLongZeros()91   public void createLongZeros() {
92     try (Graph g = new Graph();
93         Session sess = new Session(g)) {
94       Scope scope = new Scope(g);
95       long[] shape = {2, 2};
96       Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class);
97       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
98         long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]);
99         for (int i = 0; i < actual.length; ++i) {
100           for (int j = 0; j < actual[i].length; ++j) {
101             assertEquals(0L, actual[i][j]);
102           }
103         }
104       }
105     }
106   }
107 
108   @Test
createBooleanZeros()109   public void createBooleanZeros() {
110     try (Graph g = new Graph();
111         Session sess = new Session(g)) {
112       Scope scope = new Scope(g);
113       long[] shape = {2, 2};
114       Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class);
115       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
116         boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]);
117         for (int i = 0; i < actual.length; ++i) {
118           for (int j = 0; j < actual[i].length; ++j) {
119             assertFalse(actual[i][j]);
120           }
121         }
122       }
123     }
124   }
125 
126   @Test
createUInt8Zeros()127   public void createUInt8Zeros() {
128     try (Graph g = new Graph();
129         Session sess = new Session(g)) {
130       Scope scope = new Scope(g);
131       long[] shape = {2, 2};
132       Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class);
133       try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
134         byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]);
135         result.copyTo(actual);
136         for (int i = 0; i < actual.length; ++i) {
137           for (int j = 0; j < actual[i].length; ++j) {
138             assertEquals(0, actual[i][j]);
139           }
140         }
141       }
142     }
143   }
144 
145   @Test(expected = IllegalArgumentException.class)
cannotCreateStringZeros()146   public void cannotCreateStringZeros() {
147     try (Graph g = new Graph();
148         Session sess = new Session(g)) {
149       Scope scope = new Scope(g);
150       long[] shape = {2, 2};
151       Zeros.create(scope, Constant.create(scope, shape), String.class);
152     }
153   }
154 
155   @Test
operationsComposingZerosAreCorrectlyNamed()156   public void operationsComposingZerosAreCorrectlyNamed() {
157     try (Graph g = new Graph();
158         Session sess = new Session(g)) {
159       Scope scope = new Scope(g);
160       long[] shape = {2, 2};
161       Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class);
162       List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
163     }
164   }
165 }
166