1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.op.core; 17 18 import static org.junit.Assert.assertArrayEquals; 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertTrue; 21 22 import java.io.ByteArrayOutputStream; 23 import java.io.DataOutputStream; 24 import java.io.IOException; 25 import java.nio.ByteBuffer; 26 import java.nio.DoubleBuffer; 27 import java.nio.FloatBuffer; 28 import java.nio.IntBuffer; 29 import java.nio.LongBuffer; 30 import org.junit.Test; 31 import org.junit.runner.RunWith; 32 import org.junit.runners.JUnit4; 33 import org.tensorflow.Graph; 34 import org.tensorflow.Session; 35 import org.tensorflow.Tensor; 36 import org.tensorflow.op.Scope; 37 38 @RunWith(JUnit4.class) 39 public class ConstantTest { 40 private static final float EPSILON = 1e-7f; 41 42 @Test createInt()43 public void createInt() { 44 int value = 1; 45 46 try (Graph g = new Graph(); 47 Session sess = new Session(g)) { 48 Scope scope = new Scope(g); 49 Constant<Integer> op = Constant.create(scope, value); 50 try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) { 51 assertEquals(value, result.intValue()); 52 } 53 } 54 } 55 56 @Test createIntBuffer()57 public void createIntBuffer() { 58 int[] ints = {1, 2, 3, 4}; 59 long[] shape = {4}; 60 61 try (Graph g = new Graph(); 62 Session sess = new Session(g)) { 63 Scope scope = new Scope(g); 64 Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints)); 65 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 66 int[] actual = new int[ints.length]; 67 assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual)); 68 } 69 } 70 } 71 72 @Test createFloat()73 public void createFloat() { 74 float value = 1; 75 76 try (Graph g = new Graph(); 77 Session sess = new Session(g)) { 78 Scope scope = new Scope(g); 79 Constant<Float> op = Constant.create(scope, value); 80 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 81 assertEquals(value, result.expect(Float.class).floatValue(), 0.0f); 82 } 83 } 84 } 85 86 @Test createFloatBuffer()87 public void createFloatBuffer() { 88 float[] floats = {1, 2, 3, 4}; 89 long[] shape = {4}; 90 91 try (Graph g = new Graph(); 92 Session sess = new Session(g)) { 93 Scope scope = new Scope(g); 94 Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); 95 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 96 float[] actual = new float[floats.length]; 97 assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON); 98 } 99 } 100 } 101 102 @Test createDouble()103 public void createDouble() { 104 double value = 1; 105 106 try (Graph g = new Graph(); 107 Session sess = new Session(g)) { 108 Scope scope = new Scope(g); 109 Constant<Double> op = Constant.create(scope, value); 110 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 111 assertEquals(value, result.expect(Double.class).doubleValue(), 0.0); 112 } 113 } 114 } 115 116 @Test createDoubleBuffer()117 public void createDoubleBuffer() { 118 double[] doubles = {1, 2, 3, 4}; 119 long[] shape = {4}; 120 121 try (Graph g = new Graph(); 122 Session sess = new Session(g)) { 123 Scope scope = new Scope(g); 124 Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); 125 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 126 double[] actual = new double[doubles.length]; 127 assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON); 128 } 129 } 130 } 131 132 @Test createLong()133 public void createLong() { 134 long value = 1; 135 136 try (Graph g = new Graph(); 137 Session sess = new Session(g)) { 138 Scope scope = new Scope(g); 139 Constant<Long> op = Constant.create(scope, value); 140 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 141 assertEquals(value, result.expect(Long.class).longValue()); 142 } 143 } 144 } 145 146 @Test createLongBuffer()147 public void createLongBuffer() { 148 long[] longs = {1, 2, 3, 4}; 149 long[] shape = {4}; 150 151 try (Graph g = new Graph(); 152 Session sess = new Session(g)) { 153 Scope scope = new Scope(g); 154 Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs)); 155 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 156 long[] actual = new long[longs.length]; 157 assertArrayEquals(longs, result.expect(Long.class).copyTo(actual)); 158 } 159 } 160 } 161 162 @Test createBoolean()163 public void createBoolean() { 164 boolean value = true; 165 166 try (Graph g = new Graph(); 167 Session sess = new Session(g)) { 168 Scope scope = new Scope(g); 169 Constant<Boolean> op = Constant.create(scope, value); 170 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 171 assertEquals(value, result.expect(Boolean.class).booleanValue()); 172 } 173 } 174 } 175 176 @Test createStringBuffer()177 public void createStringBuffer() throws IOException { 178 byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; 179 long[] shape = {}; 180 181 ByteArrayOutputStream baout = new ByteArrayOutputStream(); 182 DataOutputStream out = new DataOutputStream(baout); 183 // We construct a TF_TString_Small tstring, which has the capacity for a 22 byte string. 184 // The first 6 most significant bits of the first byte represent length; the remaining 185 // 2-bits are type indicators, and are left as 0b00 to denote a TF_TSTR_SMALL type. 186 assertTrue(data.length <= 22); 187 out.writeByte(data.length << 2); 188 out.write(data); 189 out.close(); 190 byte[] content = baout.toByteArray(); 191 192 try (Graph g = new Graph(); 193 Session sess = new Session(g)) { 194 Scope scope = new Scope(g); 195 Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content)); 196 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 197 assertArrayEquals(data, result.expect(String.class).bytesValue()); 198 } 199 } 200 } 201 } 202