• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow.op.core;
17 
18 import static org.junit.Assert.assertArrayEquals;
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertTrue;
21 
22 import java.io.ByteArrayOutputStream;
23 import java.io.DataOutputStream;
24 import java.io.IOException;
25 import java.nio.ByteBuffer;
26 import java.nio.DoubleBuffer;
27 import java.nio.FloatBuffer;
28 import java.nio.IntBuffer;
29 import java.nio.LongBuffer;
30 import org.junit.Test;
31 import org.junit.runner.RunWith;
32 import org.junit.runners.JUnit4;
33 import org.tensorflow.Graph;
34 import org.tensorflow.Session;
35 import org.tensorflow.Tensor;
36 import org.tensorflow.op.Scope;
37 
38 @RunWith(JUnit4.class)
39 public class ConstantTest {
40   private static final float EPSILON = 1e-7f;
41 
42   @Test
createInt()43   public void createInt() {
44     int value = 1;
45 
46     try (Graph g = new Graph();
47         Session sess = new Session(g)) {
48       Scope scope = new Scope(g);
49       Constant<Integer> op = Constant.create(scope, value);
50       try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) {
51         assertEquals(value, result.intValue());
52       }
53     }
54   }
55 
56   @Test
createIntBuffer()57   public void createIntBuffer() {
58     int[] ints = {1, 2, 3, 4};
59     long[] shape = {4};
60 
61     try (Graph g = new Graph();
62         Session sess = new Session(g)) {
63       Scope scope = new Scope(g);
64       Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
65       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
66         int[] actual = new int[ints.length];
67         assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual));
68       }
69     }
70   }
71 
72   @Test
createFloat()73   public void createFloat() {
74     float value = 1;
75 
76     try (Graph g = new Graph();
77         Session sess = new Session(g)) {
78       Scope scope = new Scope(g);
79       Constant<Float> op = Constant.create(scope, value);
80       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
81         assertEquals(value, result.expect(Float.class).floatValue(), 0.0f);
82       }
83     }
84   }
85 
86   @Test
createFloatBuffer()87   public void createFloatBuffer() {
88     float[] floats = {1, 2, 3, 4};
89     long[] shape = {4};
90 
91     try (Graph g = new Graph();
92         Session sess = new Session(g)) {
93       Scope scope = new Scope(g);
94       Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
95       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
96         float[] actual = new float[floats.length];
97         assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON);
98       }
99     }
100   }
101 
102   @Test
createDouble()103   public void createDouble() {
104     double value = 1;
105 
106     try (Graph g = new Graph();
107         Session sess = new Session(g)) {
108       Scope scope = new Scope(g);
109       Constant<Double> op = Constant.create(scope, value);
110       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
111         assertEquals(value, result.expect(Double.class).doubleValue(), 0.0);
112       }
113     }
114   }
115 
116   @Test
createDoubleBuffer()117   public void createDoubleBuffer() {
118     double[] doubles = {1, 2, 3, 4};
119     long[] shape = {4};
120 
121     try (Graph g = new Graph();
122         Session sess = new Session(g)) {
123       Scope scope = new Scope(g);
124       Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
125       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
126         double[] actual = new double[doubles.length];
127         assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON);
128       }
129     }
130   }
131 
132   @Test
createLong()133   public void createLong() {
134     long value = 1;
135 
136     try (Graph g = new Graph();
137         Session sess = new Session(g)) {
138       Scope scope = new Scope(g);
139       Constant<Long> op = Constant.create(scope, value);
140       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
141         assertEquals(value, result.expect(Long.class).longValue());
142       }
143     }
144   }
145 
146   @Test
createLongBuffer()147   public void createLongBuffer() {
148     long[] longs = {1, 2, 3, 4};
149     long[] shape = {4};
150 
151     try (Graph g = new Graph();
152         Session sess = new Session(g)) {
153       Scope scope = new Scope(g);
154       Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs));
155       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
156         long[] actual = new long[longs.length];
157         assertArrayEquals(longs, result.expect(Long.class).copyTo(actual));
158       }
159     }
160   }
161 
162   @Test
createBoolean()163   public void createBoolean() {
164     boolean value = true;
165 
166     try (Graph g = new Graph();
167         Session sess = new Session(g)) {
168       Scope scope = new Scope(g);
169       Constant<Boolean> op = Constant.create(scope, value);
170       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
171         assertEquals(value, result.expect(Boolean.class).booleanValue());
172       }
173     }
174   }
175 
176   @Test
createStringBuffer()177   public void createStringBuffer() throws IOException {
178     byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
179     long[] shape = {};
180 
181     ByteArrayOutputStream baout = new ByteArrayOutputStream();
182     DataOutputStream out = new DataOutputStream(baout);
183     // We construct a TF_TString_Small tstring, which has the capacity for a 22 byte string.
184     // The first 6 most significant bits of the first byte represent length; the remaining
185     // 2-bits are type indicators, and are left as 0b00 to denote a TF_TSTR_SMALL type.
186     assertTrue(data.length <= 22);
187     out.writeByte(data.length << 2);
188     out.write(data);
189     out.close();
190     byte[] content = baout.toByteArray();
191 
192     try (Graph g = new Graph();
193         Session sess = new Session(g)) {
194       Scope scope = new Scope(g);
195       Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content));
196       try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
197         assertArrayEquals(data, result.expect(String.class).bytesValue());
198       }
199     }
200   }
201 }
202