1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.op.core; 17 18 import static org.junit.Assert.assertArrayEquals; 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertTrue; 21 22 import java.io.ByteArrayOutputStream; 23 import java.io.DataOutputStream; 24 import java.io.IOException; 25 import java.nio.ByteBuffer; 26 import java.nio.DoubleBuffer; 27 import java.nio.FloatBuffer; 28 import java.nio.IntBuffer; 29 import java.nio.LongBuffer; 30 31 import org.junit.Test; 32 import org.junit.runner.RunWith; 33 import org.junit.runners.JUnit4; 34 import org.tensorflow.Graph; 35 import org.tensorflow.Session; 36 import org.tensorflow.Tensor; 37 import org.tensorflow.op.Scope; 38 39 @RunWith(JUnit4.class) 40 public class ConstantTest { 41 private static final float EPSILON = 1e-7f; 42 43 @Test createInt()44 public void createInt() { 45 int value = 1; 46 47 try (Graph g = new Graph(); 48 Session sess = new Session(g)) { 49 Scope scope = new Scope(g); 50 Constant<Integer> op = Constant.create(scope, value); 51 try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) { 52 assertEquals(value, result.intValue()); 53 } 54 } 55 } 56 57 @Test createIntBuffer()58 public void createIntBuffer() { 59 int[] ints = {1, 2, 3, 4}; 60 long[] shape = {4}; 61 62 try (Graph g = new Graph(); 63 Session sess = new Session(g)) { 64 Scope scope = new Scope(g); 65 Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints)); 66 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 67 int[] actual = new int[ints.length]; 68 assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual)); 69 } 70 } 71 } 72 73 @Test createFloat()74 public void createFloat() { 75 float value = 1; 76 77 try (Graph g = new Graph(); 78 Session sess = new Session(g)) { 79 Scope scope = new Scope(g); 80 Constant<Float> op = Constant.create(scope, value); 81 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 82 assertEquals(value, result.expect(Float.class).floatValue(), 0.0f); 83 } 84 } 85 } 86 87 @Test createFloatBuffer()88 public void createFloatBuffer() { 89 float[] floats = {1, 2, 3, 4}; 90 long[] shape = {4}; 91 92 try (Graph g = new Graph(); 93 Session sess = new Session(g)) { 94 Scope scope = new Scope(g); 95 Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); 96 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 97 float[] actual = new float[floats.length]; 98 assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON); 99 } 100 } 101 } 102 103 @Test createDouble()104 public void createDouble() { 105 double value = 1; 106 107 try (Graph g = new Graph(); 108 Session sess = new Session(g)) { 109 Scope scope = new Scope(g); 110 Constant<Double> op = Constant.create(scope, value); 111 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 112 assertEquals(value, result.expect(Double.class).doubleValue(), 0.0); 113 } 114 } 115 } 116 117 @Test createDoubleBuffer()118 public void createDoubleBuffer() { 119 double[] doubles = {1, 2, 3, 4}; 120 long[] shape = {4}; 121 122 try (Graph g = new Graph(); 123 Session sess = new Session(g)) { 124 Scope scope = new Scope(g); 125 Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); 126 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 127 double[] actual = new double[doubles.length]; 128 assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON); 129 } 130 } 131 } 132 133 @Test createLong()134 public void createLong() { 135 long value = 1; 136 137 try (Graph g = new Graph(); 138 Session sess = new Session(g)) { 139 Scope scope = new Scope(g); 140 Constant<Long> op = Constant.create(scope, value); 141 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 142 assertEquals(value, result.expect(Long.class).longValue()); 143 } 144 } 145 } 146 147 @Test createLongBuffer()148 public void createLongBuffer() { 149 long[] longs = {1, 2, 3, 4}; 150 long[] shape = {4}; 151 152 try (Graph g = new Graph(); 153 Session sess = new Session(g)) { 154 Scope scope = new Scope(g); 155 Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs)); 156 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 157 long[] actual = new long[longs.length]; 158 assertArrayEquals(longs, result.expect(Long.class).copyTo(actual)); 159 } 160 } 161 } 162 163 @Test createBoolean()164 public void createBoolean() { 165 boolean value = true; 166 167 try (Graph g = new Graph(); 168 Session sess = new Session(g)) { 169 Scope scope = new Scope(g); 170 Constant<Boolean> op = Constant.create(scope, value); 171 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 172 assertEquals(value, result.expect(Boolean.class).booleanValue()); 173 } 174 } 175 } 176 177 @Test createStringBuffer()178 public void createStringBuffer() throws IOException { 179 byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; 180 long[] shape = {}; 181 182 // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer, 183 // followed by a varint encoded size, followed by the data. 184 ByteArrayOutputStream baout = new ByteArrayOutputStream(); 185 DataOutputStream out = new DataOutputStream(baout); 186 // Offset in array. 187 out.writeLong(0L); 188 // Varint encoded length of buffer. 189 // For any number < 0x80, the varint encoding is simply the number itself. 190 // https://developers.google.com/protocol-buffers/docs/encoding#varints 191 assertTrue(data.length < 0x80); 192 out.write(data.length); 193 out.write(data); 194 out.close(); 195 byte[] content = baout.toByteArray(); 196 197 try (Graph g = new Graph(); 198 Session sess = new Session(g)) { 199 Scope scope = new Scope(g); 200 Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content)); 201 try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { 202 assertArrayEquals(data, result.expect(String.class).bytesValue()); 203 } 204 } 205 } 206 } 207