• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import static org.junit.Assert.assertArrayEquals;
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertTrue;
21 import static org.junit.Assert.fail;
22 
23 import org.junit.Test;
24 import org.junit.runner.RunWith;
25 import org.junit.runners.JUnit4;
26 
27 /** Unit tests for {@link org.tensorflow.Session}. */
28 @RunWith(JUnit4.class)
29 public class SessionTest {
30 
31   @Test
runUsingOperationNames()32   public void runUsingOperationNames() {
33     try (Graph g = new Graph();
34         Session s = new Session(g)) {
35       TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
36       try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
37           TestUtil.AutoCloseableList<Tensor<?>> outputs =
38               new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
39         assertEquals(1, outputs.size());
40         final int[][] expected = {{31}};
41         assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
42       }
43     }
44   }
45 
46   @Test
runUsingOperationHandles()47   public void runUsingOperationHandles() {
48     try (Graph g = new Graph();
49         Session s = new Session(g)) {
50       TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
51       Output<Integer> feed = g.operation("X").output(0);
52       Output<Integer> fetch = g.operation("Y").output(0);
53       try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
54           TestUtil.AutoCloseableList<Tensor<?>> outputs =
55               new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
56         assertEquals(1, outputs.size());
57         final int[][] expected = {{31}};
58         assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
59       }
60     }
61   }
62 
63   @Test
runUsingColonSeparatedNames()64   public void runUsingColonSeparatedNames() {
65     try (Graph g = new Graph();
66         Session s = new Session(g)) {
67       Operation split =
68           g.opBuilder("Split", "Split")
69               .addInput(TestUtil.constant(g, "split_dim", 0))
70               .addInput(TestUtil.constant(g, "value", new int[] {1, 2, 3, 4}))
71               .setAttr("num_split", 2)
72               .build();
73       g.opBuilder("Add", "Add")
74           .addInput(split.output(0))
75           .addInput(split.output(1))
76           .build()
77           .output(0);
78       // Fetch using colon separated names.
79       try (Tensor<Integer> fetched =
80           s.runner().fetch("Split:1").run().get(0).expect(Integer.class)) {
81         final int[] expected = {3, 4};
82         assertArrayEquals(expected, fetched.copyTo(new int[2]));
83       }
84       // Feed using colon separated names.
85       try (Tensor<Integer> fed = Tensors.create(new int[] {4, 3, 2, 1});
86           Tensor<Integer> fetched =
87               s.runner()
88                   .feed("Split:0", fed)
89                   .feed("Split:1", fed)
90                   .fetch("Add")
91                   .run()
92                   .get(0)
93                   .expect(Integer.class)) {
94         final int[] expected = {8, 6, 4, 2};
95         assertArrayEquals(expected, fetched.copyTo(new int[4]));
96       }
97     }
98   }
99 
100   @Test
runWithMetadata()101   public void runWithMetadata() {
102     try (Graph g = new Graph();
103         Session s = new Session(g)) {
104       TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
105       try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}})) {
106         Session.Run result =
107             s.runner()
108                 .feed("X", x)
109                 .fetch("Y")
110                 .setOptions(fullTraceRunOptions())
111                 .runAndFetchMetadata();
112         // Sanity check on outputs.
113         TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs);
114         assertEquals(1, outputs.size());
115         final int[][] expected = {{31}};
116         assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
117         // Sanity check on metadata
118         // See comments in fullTraceRunOptions() for an explanation about
119         // why this check is really silly. Ideally, this would be:
120         /*
121             RunMetadata md = RunMetadata.parseFrom(result.metadata);
122             assertTrue(md.toString(), md.hasStepStats());
123         */
124         assertTrue(result.metadata.length > 0);
125         outputs.close();
126       }
127     }
128   }
129 
130   @Test
runMultipleOutputs()131   public void runMultipleOutputs() {
132     try (Graph g = new Graph();
133         Session s = new Session(g)) {
134       TestUtil.constant(g, "c1", 2718);
135       TestUtil.constant(g, "c2", 31415);
136       TestUtil.AutoCloseableList<Tensor<?>> outputs =
137           new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
138       assertEquals(2, outputs.size());
139       assertEquals(31415, outputs.get(0).intValue());
140       assertEquals(2718, outputs.get(1).intValue());
141       outputs.close();
142     }
143   }
144 
145   @Test
failOnUseAfterClose()146   public void failOnUseAfterClose() {
147     try (Graph g = new Graph()) {
148       Session s = new Session(g);
149       s.close();
150       try {
151         s.runner().run();
152         fail("methods on a session should fail after close() is called");
153       } catch (IllegalStateException e) {
154         // expected exception
155       }
156     }
157   }
158 
159   @Test
createWithConfigProto()160   public void createWithConfigProto() {
161     try (Graph g = new Graph();
162         Session s = new Session(g, singleThreadConfigProto())) {}
163   }
164 
fullTraceRunOptions()165   private static byte[] fullTraceRunOptions() {
166     // Ideally this would use the generated Java sources for protocol buffers
167     // and end up with something like the snippet below. However, generating
168     // the Java files for the .proto files in tensorflow/core:protos_all is
169     // a bit cumbersome in bazel until the proto_library rule is setup.
170     //
171     // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
172     // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
173     // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
174     //
175     // For this test, for now, the use of specific bytes suffices.
176     return new byte[] {0x08, 0x03};
177     /*
178     return org.tensorflow.framework.RunOptions.newBuilder()
179         .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
180         .build()
181         .toByteArray();
182     */
183   }
184 
singleThreadConfigProto()185   public static byte[] singleThreadConfigProto() {
186     // Ideally this would use the generated Java sources for protocol buffers
187     // and end up with something like the snippet below. However, generating
188     // the Java files for the .proto files in tensorflow/core:protos_all is
189     // a bit cumbersome in bazel until the proto_library rule is setup.
190     //
191     // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
192     // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
193     // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
194     //
195     // For this test, for now, the use of specific bytes suffices.
196     return new byte[] {0x10, 0x01, 0x28, 0x01};
197     /*
198     return org.tensorflow.framework.ConfigProto.newBuilder()
199         .setInterOpParallelismThreads(1)
200         .setIntraOpParallelismThreads(1)
201         .build()
202         .toByteArray();
203      */
204   }
205 }
206