• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import static org.junit.Assert.assertEquals;
19 import static org.junit.Assert.fail;
20 
21 import org.junit.Test;
22 import org.junit.runner.RunWith;
23 import org.junit.runners.JUnit4;
24 
25 /** Unit tests for {@link EagerOperation} class. */
26 @RunWith(JUnit4.class)
27 public class EagerOperationTest {
28 
29   @Test
failToCreateIfSessionIsClosed()30   public void failToCreateIfSessionIsClosed() {
31     EagerSession session = EagerSession.create();
32     session.close();
33     try {
34       new EagerOperation(session, 1L, new long[] {1L}, "Add", "add");
35       fail();
36     } catch (IllegalStateException e) {
37       // expected
38     }
39   }
40 
41   @Test
outputDataTypeAndShape()42   public void outputDataTypeAndShape() {
43     try (EagerSession session = EagerSession.create();
44         Tensor<Integer> t = Tensors.create(new int[2][3])) {
45       EagerOperation op =
46           opBuilder(session, "Const", "OutputAttrs")
47               .setAttr("dtype", DataType.INT32)
48               .setAttr("value", t)
49               .build();
50       assertEquals(DataType.INT32, op.dtype(0));
51       assertEquals(2, op.shape(0)[0]);
52       assertEquals(3, op.shape(0)[1]);
53     }
54   }
55 
56   @Test
outputTensor()57   public void outputTensor() {
58     try (EagerSession session = EagerSession.create()) {
59       EagerOperation add =
60           opBuilder(session, "Add", "CompareResult")
61               .addInput(TestUtil.constant(session, "Const1", 2))
62               .addInput(TestUtil.constant(session, "Const2", 4))
63               .build();
64       assertEquals(6, add.tensor(0).intValue());
65 
66       // Validate that we retrieve the right shape and datatype from the tensor
67       // that has been resolved
68       assertEquals(0, add.shape(0).length);
69       assertEquals(DataType.INT32, add.dtype(0));
70     }
71   }
72 
73   @Test
inputAndOutputListLengths()74   public void inputAndOutputListLengths() {
75     try (EagerSession session = EagerSession.create()) {
76       Output<Float> c1 = TestUtil.constant(session, "Const1", new float[] {1f, 2f});
77       Output<Float> c2 = TestUtil.constant(session, "Const2", new float[] {3f, 4f});
78 
79       EagerOperation acc =
80           opBuilder(session, "AddN", "InputListLength")
81               .addInputList(new Output<?>[] {c1, c2})
82               .build();
83       assertEquals(2, acc.inputListLength("inputs"));
84       assertEquals(1, acc.outputListLength("sum"));
85 
86       EagerOperation split =
87           opBuilder(session, "Split", "OutputListLength")
88               .addInput(TestUtil.constant(session, "Axis", 0))
89               .addInput(c1)
90               .setAttr("num_split", 2)
91               .build();
92       assertEquals(1, split.inputListLength("split_dim"));
93       assertEquals(2, split.outputListLength("output"));
94 
95       try {
96         split.inputListLength("no_such_input");
97         fail();
98       } catch (IllegalArgumentException e) {
99         // expected
100       }
101 
102       try {
103         split.outputListLength("no_such_output");
104         fail();
105       } catch (IllegalArgumentException e) {
106         // expected
107       }
108     }
109   }
110 
111   @Test
numOutputs()112   public void numOutputs() {
113     try (EagerSession session = EagerSession.create()) {
114       EagerOperation op =
115           opBuilder(session, "UniqueWithCountsV2", "unq")
116               .addInput(TestUtil.constant(session, "Const1", new int[] {1, 2, 1}))
117               .addInput(TestUtil.constant(session, "Axis", new int[] {0}))
118               .setAttr("out_idx", DataType.INT32)
119               .build();
120       assertEquals(3, op.numOutputs());
121     }
122   }
123 
124   @Test
opNotAccessibleIfSessionIsClosed()125   public void opNotAccessibleIfSessionIsClosed() {
126     EagerSession session = EagerSession.create();
127     EagerOperation add =
128         opBuilder(session, "Add", "SessionClosed")
129             .addInput(TestUtil.constant(session, "Const1", 2))
130             .addInput(TestUtil.constant(session, "Const2", 4))
131             .build();
132     assertEquals(1, add.outputListLength("z"));
133     session.close();
134     try {
135       add.outputListLength("z");
136       fail();
137     } catch (IllegalStateException e) {
138       // expected
139     }
140   }
141 
142   @Test
outputIndexOutOfBounds()143   public void outputIndexOutOfBounds() {
144     try (EagerSession session = EagerSession.create()) {
145       EagerOperation add =
146           opBuilder(session, "Add", "OutOfRange")
147               .addInput(TestUtil.constant(session, "Const1", 2))
148               .addInput(TestUtil.constant(session, "Const2", 4))
149               .build();
150       try {
151         add.getUnsafeNativeHandle(1);
152         fail();
153       } catch (IndexOutOfBoundsException e) {
154         // expected
155       }
156       try {
157         add.shape(1);
158         fail();
159       } catch (IndexOutOfBoundsException e) {
160         // expected
161       }
162       try {
163         add.dtype(1);
164         fail();
165       } catch (IndexOutOfBoundsException e) {
166         // expected
167       }
168       try {
169         add.tensor(1);
170         fail();
171       } catch (IndexOutOfBoundsException e) {
172         // expected
173       }
174     }
175   }
176 
opBuilder(EagerSession session, String type, String name)177   private static EagerOperationBuilder opBuilder(EagerSession session, String type, String name) {
178     return new EagerOperationBuilder(session, type, name);
179   }
180 }
181