• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.op.core;
17 
18 import static org.junit.Assert.assertArrayEquals;
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertTrue;
21 
22 import java.io.ByteArrayOutputStream;
23 import java.io.DataOutputStream;
24 import java.io.IOException;
25 import java.nio.ByteBuffer;
26 import java.nio.DoubleBuffer;
27 import java.nio.FloatBuffer;
28 import java.nio.IntBuffer;
29 import java.nio.LongBuffer;
30 
31 import org.junit.Test;
32 import org.junit.runner.RunWith;
33 import org.junit.runners.JUnit4;
34 import org.tensorflow.Graph;
35 import org.tensorflow.Session;
36 import org.tensorflow.Tensor;
37 import org.tensorflow.op.Scope;
38 
39 @RunWith(JUnit4.class)
40 public class ConstantTest {
41   private static final float EPSILON = 1e-7f;
42 
43   @Test
createInt()44   public void createInt() {
45     int value = 1;
46 
47     try (Graph g = new Graph();
48         Session sess = new Session(g)) {
49       Scope scope = new Scope(g);
50       Constant<Integer> op = Constant.create(scope, value);
51       try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) {
52         assertEquals(value, result.intValue());
53       }
54     }
55   }
56 
57   @Test
createIntBuffer()58   public void createIntBuffer() {
59     int[] ints = {1, 2, 3, 4};
60     long[] shape = {4};
61 
62     try (Graph g = new Graph();
63         Session sess = new Session(g)) {
64       Scope scope = new Scope(g);
65       Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
66       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
67         int[] actual = new int[ints.length];
68         assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual));
69       }
70     }
71   }
72 
73   @Test
createFloat()74   public void createFloat() {
75     float value = 1;
76 
77     try (Graph g = new Graph();
78         Session sess = new Session(g)) {
79       Scope scope = new Scope(g);
80       Constant<Float> op = Constant.create(scope, value);
81       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
82         assertEquals(value, result.expect(Float.class).floatValue(), 0.0f);
83       }
84     }
85   }
86 
87   @Test
createFloatBuffer()88   public void createFloatBuffer() {
89     float[] floats = {1, 2, 3, 4};
90     long[] shape = {4};
91 
92     try (Graph g = new Graph();
93         Session sess = new Session(g)) {
94       Scope scope = new Scope(g);
95       Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
96       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
97         float[] actual = new float[floats.length];
98         assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON);
99       }
100     }
101   }
102 
103   @Test
createDouble()104   public void createDouble() {
105     double value = 1;
106 
107     try (Graph g = new Graph();
108         Session sess = new Session(g)) {
109       Scope scope = new Scope(g);
110       Constant<Double> op = Constant.create(scope, value);
111       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
112         assertEquals(value, result.expect(Double.class).doubleValue(), 0.0);
113       }
114     }
115   }
116 
117   @Test
createDoubleBuffer()118   public void createDoubleBuffer() {
119     double[] doubles = {1, 2, 3, 4};
120     long[] shape = {4};
121 
122     try (Graph g = new Graph();
123         Session sess = new Session(g)) {
124       Scope scope = new Scope(g);
125       Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
126       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
127         double[] actual = new double[doubles.length];
128         assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON);
129       }
130     }
131   }
132 
133   @Test
createLong()134   public void createLong() {
135     long value = 1;
136 
137     try (Graph g = new Graph();
138         Session sess = new Session(g)) {
139       Scope scope = new Scope(g);
140       Constant<Long> op = Constant.create(scope, value);
141       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
142         assertEquals(value, result.expect(Long.class).longValue());
143       }
144     }
145   }
146 
147   @Test
createLongBuffer()148   public void createLongBuffer() {
149     long[] longs = {1, 2, 3, 4};
150     long[] shape = {4};
151 
152     try (Graph g = new Graph();
153         Session sess = new Session(g)) {
154       Scope scope = new Scope(g);
155       Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs));
156       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
157         long[] actual = new long[longs.length];
158         assertArrayEquals(longs, result.expect(Long.class).copyTo(actual));
159       }
160     }
161   }
162 
163   @Test
createBoolean()164   public void createBoolean() {
165     boolean value = true;
166 
167     try (Graph g = new Graph();
168         Session sess = new Session(g)) {
169       Scope scope = new Scope(g);
170       Constant<Boolean> op = Constant.create(scope, value);
171       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
172         assertEquals(value, result.expect(Boolean.class).booleanValue());
173       }
174     }
175   }
176 
177   @Test
createStringBuffer()178   public void createStringBuffer() throws IOException {
179     byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
180     long[] shape = {};
181 
182     // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
183     // followed by a varint encoded size, followed by the data.
184     ByteArrayOutputStream baout = new ByteArrayOutputStream();
185     DataOutputStream out = new DataOutputStream(baout);
186     // Offset in array.
187     out.writeLong(0L);
188     // Varint encoded length of buffer.
189     // For any number < 0x80, the varint encoding is simply the number itself.
190     // https://developers.google.com/protocol-buffers/docs/encoding#varints
191     assertTrue(data.length < 0x80);
192     out.write(data.length);
193     out.write(data);
194     out.close();
195     byte[] content = baout.toByteArray();
196 
197     try (Graph g = new Graph();
198         Session sess = new Session(g)) {
199       Scope scope = new Scope(g);
200       Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content));
201       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
202         assertArrayEquals(data, result.expect(String.class).bytesValue());
203       }
204     }
205   }
206 }
207