1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import static org.junit.Assert.assertArrayEquals; 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertTrue; 21 import static org.junit.Assert.fail; 22 23 import org.junit.Test; 24 import org.junit.runner.RunWith; 25 import org.junit.runners.JUnit4; 26 27 /** Unit tests for {@link org.tensorflow.Session}. */ 28 @RunWith(JUnit4.class) 29 public class SessionTest { 30 31 @Test runUsingOperationNames()32 public void runUsingOperationNames() { 33 try (Graph g = new Graph(); 34 Session s = new Session(g)) { 35 TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); 36 try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); 37 TestUtil.AutoCloseableList<Tensor<?>> outputs = 38 new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { 39 assertEquals(1, outputs.size()); 40 final int[][] expected = {{31}}; 41 assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); 42 } 43 } 44 } 45 46 @Test runUsingOperationHandles()47 public void runUsingOperationHandles() { 48 try (Graph g = new Graph(); 49 Session s = new Session(g)) { 50 TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); 51 Output<Integer> feed = g.operation("X").output(0); 52 Output<Integer> fetch = g.operation("Y").output(0); 53 try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); 54 TestUtil.AutoCloseableList<Tensor<?>> outputs = 55 new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { 56 assertEquals(1, outputs.size()); 57 final int[][] expected = {{31}}; 58 assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); 59 } 60 } 61 } 62 63 @Test runUsingColonSeparatedNames()64 public void runUsingColonSeparatedNames() { 65 try (Graph g = new Graph(); 66 Session s = new Session(g)) { 67 Operation split = 68 g.opBuilder("Split", "Split") 69 .addInput(TestUtil.constant(g, "split_dim", 0)) 70 .addInput(TestUtil.constant(g, "value", new int[] {1, 2, 3, 4})) 71 .setAttr("num_split", 2) 72 .build(); 73 g.opBuilder("Add", "Add") 74 .addInput(split.output(0)) 75 .addInput(split.output(1)) 76 .build() 77 .output(0); 78 // Fetch using colon separated names. 79 try (Tensor<Integer> fetched = 80 s.runner().fetch("Split:1").run().get(0).expect(Integer.class)) { 81 final int[] expected = {3, 4}; 82 assertArrayEquals(expected, fetched.copyTo(new int[2])); 83 } 84 // Feed using colon separated names. 85 try (Tensor<Integer> fed = Tensors.create(new int[] {4, 3, 2, 1}); 86 Tensor<Integer> fetched = 87 s.runner() 88 .feed("Split:0", fed) 89 .feed("Split:1", fed) 90 .fetch("Add") 91 .run() 92 .get(0) 93 .expect(Integer.class)) { 94 final int[] expected = {8, 6, 4, 2}; 95 assertArrayEquals(expected, fetched.copyTo(new int[4])); 96 } 97 } 98 } 99 100 @Test runWithMetadata()101 public void runWithMetadata() { 102 try (Graph g = new Graph(); 103 Session s = new Session(g)) { 104 TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); 105 try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}})) { 106 Session.Run result = 107 s.runner() 108 .feed("X", x) 109 .fetch("Y") 110 .setOptions(fullTraceRunOptions()) 111 .runAndFetchMetadata(); 112 // Sanity check on outputs. 113 TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs); 114 assertEquals(1, outputs.size()); 115 final int[][] expected = {{31}}; 116 assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); 117 // Sanity check on metadata 118 // See comments in fullTraceRunOptions() for an explanation about 119 // why this check is really silly. Ideally, this would be: 120 /* 121 RunMetadata md = RunMetadata.parseFrom(result.metadata); 122 assertTrue(md.toString(), md.hasStepStats()); 123 */ 124 assertTrue(result.metadata.length > 0); 125 outputs.close(); 126 } 127 } 128 } 129 130 @Test runMultipleOutputs()131 public void runMultipleOutputs() { 132 try (Graph g = new Graph(); 133 Session s = new Session(g)) { 134 TestUtil.constant(g, "c1", 2718); 135 TestUtil.constant(g, "c2", 31415); 136 TestUtil.AutoCloseableList<Tensor<?>> outputs = 137 new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); 138 assertEquals(2, outputs.size()); 139 assertEquals(31415, outputs.get(0).intValue()); 140 assertEquals(2718, outputs.get(1).intValue()); 141 outputs.close(); 142 } 143 } 144 145 @Test failOnUseAfterClose()146 public void failOnUseAfterClose() { 147 try (Graph g = new Graph()) { 148 Session s = new Session(g); 149 s.close(); 150 try { 151 s.runner().run(); 152 fail("methods on a session should fail after close() is called"); 153 } catch (IllegalStateException e) { 154 // expected exception 155 } 156 } 157 } 158 159 @Test createWithConfigProto()160 public void createWithConfigProto() { 161 try (Graph g = new Graph(); 162 Session s = new Session(g, singleThreadConfigProto())) {} 163 } 164 fullTraceRunOptions()165 private static byte[] fullTraceRunOptions() { 166 // Ideally this would use the generated Java sources for protocol buffers 167 // and end up with something like the snippet below. However, generating 168 // the Java files for the .proto files in tensorflow/core:protos_all is 169 // a bit cumbersome in bazel until the proto_library rule is setup. 170 // 171 // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 172 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 173 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 174 // 175 // For this test, for now, the use of specific bytes suffices. 176 return new byte[] {0x08, 0x03}; 177 /* 178 return org.tensorflow.framework.RunOptions.newBuilder() 179 .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) 180 .build() 181 .toByteArray(); 182 */ 183 } 184 singleThreadConfigProto()185 public static byte[] singleThreadConfigProto() { 186 // Ideally this would use the generated Java sources for protocol buffers 187 // and end up with something like the snippet below. However, generating 188 // the Java files for the .proto files in tensorflow/core:protos_all is 189 // a bit cumbersome in bazel until the proto_library rule is setup. 190 // 191 // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 192 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 193 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 194 // 195 // For this test, for now, the use of specific bytes suffices. 196 return new byte[] {0x10, 0x01, 0x28, 0x01}; 197 /* 198 return org.tensorflow.framework.ConfigProto.newBuilder() 199 .setInterOpParallelismThreads(1) 200 .setIntraOpParallelismThreads(1) 201 .build() 202 .toByteArray(); 203 */ 204 } 205 } 206