1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow; 17 18 import static org.junit.Assert.assertEquals; 19 import static org.junit.Assert.fail; 20 21 import org.junit.Test; 22 import org.junit.runner.RunWith; 23 import org.junit.runners.JUnit4; 24 25 /** Unit tests for {@link EagerOperation} class. */ 26 @RunWith(JUnit4.class) 27 public class EagerOperationTest { 28 29 @Test failToCreateIfSessionIsClosed()30 public void failToCreateIfSessionIsClosed() { 31 EagerSession session = EagerSession.create(); 32 session.close(); 33 try { 34 new EagerOperation(session, 1L, new long[] {1L}, "Add", "add"); 35 fail(); 36 } catch (IllegalStateException e) { 37 // expected 38 } 39 } 40 41 @Test outputDataTypeAndShape()42 public void outputDataTypeAndShape() { 43 try (EagerSession session = EagerSession.create(); 44 Tensor<Integer> t = Tensors.create(new int[2][3])) { 45 EagerOperation op = 46 opBuilder(session, "Const", "OutputAttrs") 47 .setAttr("dtype", DataType.INT32) 48 .setAttr("value", t) 49 .build(); 50 assertEquals(DataType.INT32, op.dtype(0)); 51 assertEquals(2, op.shape(0)[0]); 52 assertEquals(3, op.shape(0)[1]); 53 } 54 } 55 56 @Test outputTensor()57 public void outputTensor() { 58 try (EagerSession session = EagerSession.create()) { 59 EagerOperation add = 60 opBuilder(session, "Add", "CompareResult") 61 .addInput(TestUtil.constant(session, "Const1", 2)) 62 .addInput(TestUtil.constant(session, "Const2", 4)) 63 .build(); 64 assertEquals(6, add.tensor(0).intValue()); 65 66 // Validate that we retrieve the right shape and datatype from the tensor 67 // that has been resolved 68 assertEquals(0, add.shape(0).length); 69 assertEquals(DataType.INT32, add.dtype(0)); 70 } 71 } 72 73 @Test inputAndOutputListLengths()74 public void inputAndOutputListLengths() { 75 try (EagerSession session = EagerSession.create()) { 76 Output<Float> c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f}); 77 Output<Float> c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f}); 78 79 EagerOperation acc = 80 opBuilder(session, "AddN", "InputListLength") 81 .addInputList(new Output<?>[] {c1, c2}) 82 .build(); 83 assertEquals(2, acc.inputListLength("inputs")); 84 assertEquals(1, acc.outputListLength("sum")); 85 86 EagerOperation split = 87 opBuilder(session, "Split", "OutputListLength") 88 .addInput(TestUtil.constant(session, "Axis", 0)) 89 .addInput(c1) 90 .setAttr("num_split", 2) 91 .build(); 92 assertEquals(1, split.inputListLength("split_dim")); 93 assertEquals(2, split.outputListLength("output")); 94 95 try { 96 split.inputListLength("no_such_input"); 97 fail(); 98 } catch (IllegalArgumentException e) { 99 // expected 100 } 101 102 try { 103 split.outputListLength("no_such_output"); 104 fail(); 105 } catch (IllegalArgumentException e) { 106 // expected 107 } 108 } 109 } 110 111 @Test numOutputs()112 public void numOutputs() { 113 try (EagerSession session = EagerSession.create()) { 114 EagerOperation op = 115 opBuilder(session, "UniqueWithCountsV2", "unq") 116 .addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1})) 117 .addInput(TestUtil.constant(session, "Axis", new int[] {0})) 118 .setAttr("out_idx", DataType.INT32) 119 .build(); 120 assertEquals(3, op.numOutputs()); 121 } 122 } 123 124 @Test opNotAccessibleIfSessionIsClosed()125 public void opNotAccessibleIfSessionIsClosed() { 126 EagerSession session = EagerSession.create(); 127 EagerOperation add = 128 opBuilder(session, "Add", "SessionClosed") 129 .addInput(TestUtil.constant(session, "Const1", 2)) 130 .addInput(TestUtil.constant(session, "Const2", 4)) 131 .build(); 132 assertEquals(1, add.outputListLength("z")); 133 session.close(); 134 try { 135 add.outputListLength("z"); 136 fail(); 137 } catch (IllegalStateException e) { 138 // expected 139 } 140 } 141 142 @Test outputIndexOutOfBounds()143 public void outputIndexOutOfBounds() { 144 try (EagerSession session = EagerSession.create()) { 145 EagerOperation add = 146 opBuilder(session, "Add", "OutOfRange") 147 .addInput(TestUtil.constant(session, "Const1", 2)) 148 .addInput(TestUtil.constant(session, "Const2", 4)) 149 .build(); 150 try { 151 add.getUnsafeNativeHandle(1); 152 fail(); 153 } catch (IndexOutOfBoundsException e) { 154 // expected 155 } 156 try { 157 add.shape(1); 158 fail(); 159 } catch (IndexOutOfBoundsException e) { 160 // expected 161 } 162 try { 163 add.dtype(1); 164 fail(); 165 } catch (IndexOutOfBoundsException e) { 166 // expected 167 } 168 try { 169 add.tensor(1); 170 fail(); 171 } catch (IndexOutOfBoundsException e) { 172 // expected 173 } 174 } 175 } 176 opBuilder(EagerSession session, String type, String name)177 private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) { 178 return new EagerOperationBuilder(session, type, name); 179 } 180 } 181