/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace parallel_device {

using ::testing::HasSubstr;

TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

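  // Group the two virtual CPU devices into one parallel device; each op
  // executed through it runs once on every component device.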
  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  auto outputs =
      parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                              "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                              /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  std::vector<ParallelTensor*> handle_inputs;
  handle_inputs.reserve(handles.size());
  for (auto& handle : handles) {
    handle_inputs.push_back(handle.get());
  }
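  // The variables created above were never assigned a value, so reading them
  // is expected to fail on each component device.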
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
  parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
                          TFE_OpGetAttrs(read_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
  TF_SetStatus(status.get(), TF_OK, "");

  // Check that ops still run successfully on the device.
  parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                          "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  CancellationManager cancellation_manager;
  parallel_device.StartExecute(context.get(), std::vector<ParallelTensor*>(),
                               "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                               /*expected_max_outputs=*/1,
                               cancellation_manager);
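  // Passing an explicit output shape to Join sets the ParallelTensor's shape
  // directly, rather than inspecting the component tensors to compute one.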
  auto outputs = parallel_device.Join(
      /*expected_output_shapes=*/{PartialTensorShape({})}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  const std::vector<int64_t>* shape;
  Status s = handles[0]->Shape(&shape);
  ASSERT_TRUE(s.ok());
  EXPECT_EQ(0, shape->size());
}

TEST(PARALLEL_DEVICE_LIB, TestCancelOnError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(devices);
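  // "AssertAndCollective" asserts `condition` and then runs a CollectiveReduce
  // across both devices. If the Assert fails on one device, the collective
  // pending on the other device can only finish if it is cancelled; that
  // cancel-on-error path is what this test exercises.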
  const FunctionDef assert_and_collective = FunctionDefHelper::Define(
      // Name
      "AssertAndCollective",
      // Args
      {"x: float", "condition: bool"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"assert"},
           "Assert",
           {"condition", "x"},
           {{"T", std::vector<DataType>{DT_FLOAT}}}},
          {{"y"},
           "CollectiveReduce",
           {"x"},
           {{"T", DT_FLOAT},
            {"group_size", static_cast<int>(devices.size())},
            {"group_key", 0},
            {"instance_key", 0},
            {"merge_op", "Add"},
            {"final_op", "Id"},
            {"subdiv_offsets", std::vector<int>()}},
           /*dep=*/{"assert"}},
      });
  TF_ASSERT_OK(ContextFromInterface(unwrap(context.get()))
                   ->AddFunctionDef(assert_and_collective));

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> call_op(
      TFE_NewOp(context.get(), "AssertAndCollective", status.get()),
      TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> reduced_values =
      parallel_device.ScalarsFromSequence<float>({1.0, 2.0}, context.get(),
                                                 status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> run_collective =
      parallel_device.ScalarsFromSequence<bool>({true, true}, context.get(),
                                                status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  auto outputs = parallel_device.Execute(
      context.get(), {reduced_values.get(), run_collective.get()},
      "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
      /*expected_max_outputs=*/1, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ASSERT_EQ(outputs->size(), 1);
  ParallelTensor* parallel_result = (*outputs)[0].get();
  ExpectScalarEq<float>(parallel_result->tensor(0), 3.);
  ExpectScalarEq<float>(parallel_result->tensor(1), 3.);

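  // Fail the Assert on the second device only. The collective already launched
  // on the first device should be cancelled rather than blocking forever, and
  // the assertion error, not the cancellation, should be reported.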
  run_collective = parallel_device.ScalarsFromSequence<bool>(
      {true, false}, context.get(), status.get());
  parallel_device.Execute(context.get(),
                          {reduced_values.get(), run_collective.get()},
                          "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  EXPECT_NE(TF_GetCode(status.get()), TF_CANCELLED);
  EXPECT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT);
  EXPECT_THAT(TF_Message(status.get()), HasSubstr("assertion failed"));

  // Note that future collectives with the same context do not work at the
  // moment; once canceled, the collective executor requires the program to be
  // restarted / context to be reset.
}

TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  TensorHandlePtr two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr three_vector =
      VectorFloatTensorHandle({5., 6., 7.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<TensorHandlePtr> vector_handles;
  vector_handles.reserve(2);
  vector_handles.push_back(std::move(two_vector));
  vector_handles.push_back(std::move(three_vector));
  std::unique_ptr<ParallelTensor> unknown_length_vector =
      ParallelTensor::FromTensorHandles(
          parallel_device, std::move(vector_handles), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
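  // The components share a rank but have different lengths, so there is no
  // fully-defined shape to report and Shape is expected to fail.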
  const std::vector<int64_t>* shape;
  Status s = unknown_length_vector->Shape(&shape);
  EXPECT_FALSE(s.ok());

  TensorHandlePtr scalar = FloatTensorHandle(2., status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::vector<TensorHandlePtr> mixed_handles;
  mixed_handles.reserve(2);
  mixed_handles.push_back(std::move(scalar));
  mixed_handles.push_back(std::move(two_vector));
  std::unique_ptr<ParallelTensor> unknown_dims_vector =
      ParallelTensor::FromTensorHandles(parallel_device,
                                        std::move(mixed_handles), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Can't take the shape of a parallel tensor with varying numbers of axes,
  // but running operations on them is OK.
  s = unknown_dims_vector->Shape(&shape);
  EXPECT_FALSE(s.ok());
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> size_op(
      TFE_NewOp(context.get(), "Size", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  auto result = parallel_device.Execute(
      context.get(), {unknown_dims_vector.get()}, "Size",
      TFE_OpGetAttrs(size_op.get()), /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  s = (*result)[0]->Shape(&shape);
  ASSERT_TRUE(s.ok());
  EXPECT_EQ(0, shape->size());
}

TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
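  // ScalarsFromSequence packs one scalar per component device, in the order
  // the devices were listed.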
  {
    std::unique_ptr<ParallelTensor> float_tensors =
        parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(),
                                                   status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<float>(float_tensors->tensor(0), 10.0);
    ExpectScalarEq<float>(float_tensors->tensor(1), 11.0);
  }

  {
    std::unique_ptr<ParallelTensor> int_tensors =
        parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(),
                                                 status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<int>(int_tensors->tensor(0), 5);
    ExpectScalarEq<int>(int_tensors->tensor(1), 6);
  }
}

}  // namespace parallel_device
}  // namespace tensorflow