/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.op.core;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;

import java.util.List;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.op.Scope;
import org.tensorflow.types.UInt8;

/**
 * Tests for the {@link Zeros} op.
 *
 * <p>Each {@code create*Zeros} test builds a {@code Zeros} op for a 2x2 shape with one element
 * type, runs it in a fresh {@link Session}, copies the result into a Java array and checks that
 * every element is the zero value of that type.
 */
@RunWith(JUnit4.class)
public class ZerosTest {
  /** Tolerance used when comparing floating-point elements against zero. */
  private static final float EPSILON = 1e-7f;

  /** A {@code Zeros<Integer>} op yields a tensor whose elements are all {@code 0}. */
  @Test
  public void createIntZeros() {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class);
      // Fetch through asOutput(), like every other test in this class.
      try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
        int[][] actual =
            result.expect(Integer.class).copyTo(new int[(int) shape[0]][(int) shape[1]]);
        for (int i = 0; i < actual.length; ++i) {
          for (int j = 0; j < actual[i].length; ++j) {
            assertEquals(0, actual[i][j]);
          }
        }
      }
    }
  }

  /** A {@code Zeros<Float>} op yields a tensor whose elements are all {@code 0.0f}. */
  @Test
  public void createFloatZeros() {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class);
      try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
        float[][] actual =
            result.expect(Float.class).copyTo(new float[(int) shape[0]][(int) shape[1]]);
        for (int i = 0; i < actual.length; ++i) {
          for (int j = 0; j < actual[i].length; ++j) {
            assertEquals(0.0f, actual[i][j], EPSILON);
          }
        }
      }
    }
  }

  /** A {@code Zeros<Double>} op yields a tensor whose elements are all {@code 0.0}. */
  @Test
  public void createDoubleZeros() {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class);
      try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
        double[][] actual =
            result.expect(Double.class).copyTo(new double[(int) shape[0]][(int) shape[1]]);
        for (int i = 0; i < actual.length; ++i) {
          for (int j = 0; j < actual[i].length; ++j) {
            assertEquals(0.0, actual[i][j], EPSILON);
          }
        }
      }
    }
  }

  /** A {@code Zeros<Long>} op yields a tensor whose elements are all {@code 0L}. */
  @Test
  public void createLongZeros() {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class);
      try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
        long[][] actual =
            result.expect(Long.class).copyTo(new long[(int) shape[0]][(int) shape[1]]);
        for (int i = 0; i < actual.length; ++i) {
          for (int j = 0; j < actual[i].length; ++j) {
            assertEquals(0L, actual[i][j]);
          }
        }
      }
    }
  }

  /** A {@code Zeros<Boolean>} op yields a tensor whose elements are all {@code false}. */
  @Test
  public void createBooleanZeros() {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class);
      try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
        boolean[][] actual =
            result.expect(Boolean.class).copyTo(new boolean[(int) shape[0]][(int) shape[1]]);
        for (int i = 0; i < actual.length; ++i) {
          for (int j = 0; j < actual[i].length; ++j) {
            assertFalse(actual[i][j]);
          }
        }
      }
    }
  }

  /** A {@code Zeros<UInt8>} op yields a tensor whose bytes are all {@code 0}. */
  @Test
  public void createUInt8Zeros() {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class);
      try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
        // A single copy is enough; the tensor contents are read into `actual` here.
        byte[][] actual =
            result.expect(UInt8.class).copyTo(new byte[(int) shape[0]][(int) shape[1]]);
        for (int i = 0; i < actual.length; ++i) {
          for (int j = 0; j < actual[i].length; ++j) {
            assertEquals(0, actual[i][j]);
          }
        }
      }
    }
  }

  /**
   * Requesting a {@code Zeros} op of {@code String} elements is rejected with an
   * {@link IllegalArgumentException}.
   */
  @Test
  public void cannotCreateStringZeros() {
    // No Session is opened: the op is only constructed, never run.
    try (Graph g = new Graph()) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros.create(scope, Constant.create(scope, shape), String.class);
      fail("Expected IllegalArgumentException when creating Zeros of String");
    } catch (IllegalArgumentException expected) {
      // Expected: this is the behavior under test.
    }
  }

  /**
   * The operations that make up a {@code Zeros} op created under the sub-scope {@code "test"} can
   * be addressed as {@code test/Zeros/Zero} and {@code test/Zeros/Fill}.
   *
   * <p>The test has no explicit assertion; it presumably relies on the runner rejecting a target
   * name that does not exist in the graph.
   */
  @Test
  public void operationsComposingZerosAreCorrectlyNamed() {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);
      long[] shape = {2, 2};
      Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class);
      List<Tensor<?>> results =
          sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
      // Only targets are requested, so no tensors are expected back; release any that are.
      for (Tensor<?> tensor : results) {
        tensor.close();
      }
    }
  }
}