1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.op.core; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.assertFalse; 20 21 import java.util.List; 22 23 import org.junit.Test; 24 import org.junit.runner.RunWith; 25 import org.junit.runners.JUnit4; 26 import org.tensorflow.Graph; 27 import org.tensorflow.Session; 28 import org.tensorflow.Tensor; 29 import org.tensorflow.op.Scope; 30 import org.tensorflow.types.UInt8; 31 32 @RunWith(JUnit4.class) 33 public class ZerosTest { 34 private static final float EPSILON = 1e-7f; 35 36 @Test createIntZeros()37 public void createIntZeros() { 38 try (Graph g = new Graph(); 39 Session sess = new Session(g)) { 40 Scope scope = new Scope(g); 41 long[] shape = {2, 2}; 42 Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class); 43 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 44 int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]); 45 for (int i = 0; i < actual.length; ++i) { 46 for (int j = 0; j < actual[i].length; ++j) { 47 assertEquals(0, actual[i][j]); 48 } 49 } 50 } 51 } 52 } 53 54 @Test createFloatZeros()55 public void createFloatZeros() { 56 try (Graph g = new Graph(); 57 Session sess = new Session(g)) { 58 Scope scope = new Scope(g); 59 long[] shape = {2, 2}; 60 Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class); 61 try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { 62 float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]); 63 for (int i = 0; i < actual.length; ++i) { 64 for (int j = 0; j < actual[i].length; ++j) { 65 assertEquals(0.0f, actual[i][j], EPSILON); 66 } 67 } 68 } 69 } 70 } 71 72 @Test createDoubleZeros()73 public void createDoubleZeros() { 74 try (Graph g = new Graph(); 75 Session sess = new Session(g)) { 76 Scope scope = new Scope(g); 77 long[] shape = {2, 2}; 78 Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class); 79 try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { 80 double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]); 81 for (int i = 0; i < actual.length; ++i) { 82 for (int j = 0; j < actual[i].length; ++j) { 83 assertEquals(0.0, actual[i][j], EPSILON); 84 } 85 } 86 } 87 } 88 } 89 90 @Test createLongZeros()91 public void createLongZeros() { 92 try (Graph g = new Graph(); 93 Session sess = new Session(g)) { 94 Scope scope = new Scope(g); 95 long[] shape = {2, 2}; 96 Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class); 97 try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { 98 long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]); 99 for (int i = 0; i < actual.length; ++i) { 100 for (int j = 0; j < actual[i].length; ++j) { 101 assertEquals(0L, actual[i][j]); 102 } 103 } 104 } 105 } 106 } 107 108 @Test createBooleanZeros()109 public void createBooleanZeros() { 110 try (Graph g = new Graph(); 111 Session sess = new Session(g)) { 112 Scope scope = new Scope(g); 113 long[] shape = {2, 2}; 114 Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class); 115 try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { 116 boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]); 117 for (int i = 0; i < actual.length; ++i) { 118 for (int j = 0; j < actual[i].length; ++j) { 119 assertFalse(actual[i][j]); 120 } 121 } 122 } 123 } 124 } 125 126 @Test createUInt8Zeros()127 public void createUInt8Zeros() { 128 try (Graph g = new Graph(); 129 Session sess = new Session(g)) { 130 Scope scope = new Scope(g); 131 long[] shape = {2, 2}; 132 Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class); 133 try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { 134 byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]); 135 result.copyTo(actual); 136 for (int i = 0; i < actual.length; ++i) { 137 for (int j = 0; j < actual[i].length; ++j) { 138 assertEquals(0, actual[i][j]); 139 } 140 } 141 } 142 } 143 } 144 145 @Test(expected = IllegalArgumentException.class) cannotCreateStringZeros()146 public void cannotCreateStringZeros() { 147 try (Graph g = new Graph(); 148 Session sess = new Session(g)) { 149 Scope scope = new Scope(g); 150 long[] shape = {2, 2}; 151 Zeros.create(scope, Constant.create(scope, shape), String.class); 152 } 153 } 154 155 @Test operationsComposingZerosAreCorrectlyNamed()156 public void operationsComposingZerosAreCorrectlyNamed() { 157 try (Graph g = new Graph(); 158 Session sess = new Session(g)) { 159 Scope scope = new Scope(g); 160 long[] shape = {2, 2}; 161 Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class); 162 List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); 163 } 164 } 165 } 166