/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace parallel_device {

using ::testing::HasSubstr;

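// Running an op that fails on one of the underlying devices should surface
// the error without wedging the parallel device; later ops are expected to
// run normally.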
TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  auto outputs =
      parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                              "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                              /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  std::vector<ParallelTensor*> handle_inputs;
  handle_inputs.reserve(handles.size());
  for (auto& handle : handles) {
    handle_inputs.push_back(handle.get());
  }
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT);
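  // Reading a variable that has never been assigned should fail.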
  parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp",
                          TFE_OpGetAttrs(read_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
  TF_SetStatus(status.get(), TF_OK, "");

  // Check that ops still run successfully on the device.
  parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(),
                          "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}

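// Join can be passed explicit expected output shapes rather than deducing
// them from the computed tensors.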
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
                     status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  CancellationManager cancellation_manager;
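  // StartExecute launches the op on each underlying device; Join waits for
  // and collects the per-device results.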
  parallel_device.StartExecute(context.get(), std::vector<ParallelTensor*>(),
                               "VarHandleOp", TFE_OpGetAttrs(handle_op.get()),
                               /*expected_max_outputs=*/1,
                               cancellation_manager);
  auto outputs = parallel_device.Join(
      /*expected_output_shapes=*/{PartialTensorShape({})}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
  const std::vector<int64_t>* shape;
  Status s = handles[0]->Shape(&shape);
  ASSERT_TRUE(s.ok());
  EXPECT_EQ(0, shape->size());
}

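// When an op fails on one device, pending work on the other devices must be
// cancelled so that blocking collectives can return, and the original error
// rather than the cancellation should be reported.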
TEST(PARALLEL_DEVICE_LIB, TestCancelOnError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(devices);
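  // A function that asserts `condition` and then all-reduces `x` across
  // devices. If the assert fails on one device, the other device's collective
  // blocks until it is cancelled.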
  const FunctionDef assert_and_collective = FunctionDefHelper::Define(
      // Name
      "AssertAndCollective",
      // Args
      {"x: float", "condition: bool"},
      // Return values
      {"y: float"},
      // Attr def
      {},
      // Nodes
      {
          {{"assert"},
           "Assert",
           {"condition", "x"},
           {{"T", std::vector<DataType>{DT_FLOAT}}}},
          {{"y"},
           "CollectiveReduce",
           {"x"},
           {{"T", DT_FLOAT},
            {"group_size", static_cast<int>(devices.size())},
            {"group_key", 0},
            {"instance_key", 0},
            {"merge_op", "Add"},
            {"final_op", "Id"},
            {"subdiv_offsets", std::vector<int>()}},
           /*dep=*/{"assert"}},
      });
  TF_ASSERT_OK(ContextFromInterface(unwrap(context.get()))
                   ->AddFunctionDef(assert_and_collective));

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> call_op(
      TFE_NewOp(context.get(), "AssertAndCollective", status.get()),
      TFE_DeleteOp);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> reduced_values =
      parallel_device.ScalarsFromSequence<float>({1.0, 2.0}, context.get(),
                                                 status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  std::unique_ptr<ParallelTensor> run_collective =
      parallel_device.ScalarsFromSequence<bool>({true, true}, context.get(),
                                                status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  auto outputs = parallel_device.Execute(
      context.get(), {reduced_values.get(), run_collective.get()},
      "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
      /*expected_max_outputs=*/1, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  ASSERT_EQ(outputs->size(), 1);
  ParallelTensor* parallel_result = (*outputs)[0].get();
  ExpectScalarEq<float>(parallel_result->tensor(0), 3.);
  ExpectScalarEq<float>(parallel_result->tensor(1), 3.);

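  // Fail the assert on the second device only. The collective launched on the
  // first device must be cancelled for this call to return, and the reported
  // error should be the assertion failure rather than the cancellation.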
  run_collective = parallel_device.ScalarsFromSequence<bool>(
      {true, false}, context.get(), status.get());
  parallel_device.Execute(context.get(),
                          {reduced_values.get(), run_collective.get()},
                          "AssertAndCollective", TFE_OpGetAttrs(call_op.get()),
                          /*expected_max_outputs=*/1, status.get());
  EXPECT_NE(TF_GetCode(status.get()), TF_CANCELLED);
  EXPECT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT);
  EXPECT_THAT(TF_Message(status.get()), HasSubstr("assertion failed"));

  // Note that future collectives with the same context do not work at the
  // moment; once canceled, the collective executor requires the program to be
  // restarted / context to be reset.
}

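// The components of a ParallelTensor may have different shapes; Shape() only
// succeeds when they agree on a single static shape.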
TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  TensorHandlePtr two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr three_vector =
      VectorFloatTensorHandle({5., 6., 7.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

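  // Same rank but different lengths: no single static shape describes both
  // components, so Shape() should fail.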
  std::vector<TensorHandlePtr> vector_handles;
  vector_handles.reserve(2);
  vector_handles.push_back(std::move(two_vector));
  vector_handles.push_back(std::move(three_vector));
  std::unique_ptr<ParallelTensor> unknown_length_vector =
      ParallelTensor::FromTensorHandles(
          parallel_device, std::move(vector_handles), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const std::vector<int64_t>* shape;
  Status s = unknown_length_vector->Shape(&shape);
  EXPECT_FALSE(s.ok());

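  // Mix a scalar and a vector, so the components do not even agree on rank.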
  TensorHandlePtr scalar = FloatTensorHandle(2., status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  two_vector = VectorFloatTensorHandle({3., 4.}, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::vector<TensorHandlePtr> mixed_handles;
  mixed_handles.reserve(2);
  mixed_handles.push_back(std::move(scalar));
  mixed_handles.push_back(std::move(two_vector));
  std::unique_ptr<ParallelTensor> unknown_dims_vector =
      ParallelTensor::FromTensorHandles(parallel_device,
                                        std::move(mixed_handles), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  // Can't take the shape of a parallel tensor with varying numbers of axes,
  // but running operations on it is OK.
  s = unknown_dims_vector->Shape(&shape);
  EXPECT_FALSE(s.ok());
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> size_op(
      TFE_NewOp(context.get(), "Size", status.get()), TFE_DeleteOp);
  auto result = parallel_device.Execute(
      context.get(), {unknown_dims_vector.get()}, "Size",
      TFE_OpGetAttrs(size_op.get()), /*expected_max_outputs=*/1, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  // Size always produces a scalar, so the result has a known shape.
  s = (*result)[0]->Shape(&shape);
  ASSERT_TRUE(s.ok());
  EXPECT_EQ(0, shape->size());
}

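// ScalarsFromSequence places one scalar from the sequence on each underlying
// device, in order.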
TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*enable_xla_compilation=*/false,
          /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  std::vector<std::string> devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  ParallelDevice parallel_device(std::move(devices));
  {
    std::unique_ptr<ParallelTensor> float_tensors =
        parallel_device.ScalarsFromSequence<float>({10.0, 11.0}, context.get(),
                                                   status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<float>(float_tensors->tensor(0), 10.0);
    ExpectScalarEq<float>(float_tensors->tensor(1), 11.0);
  }

  {
    std::unique_ptr<ParallelTensor> int_tensors =
        parallel_device.ScalarsFromSequence<int>({5, 6}, context.get(),
                                                 status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
    ExpectScalarEq<int>(int_tensors->tensor(0), 5);
    ExpectScalarEq<int>(int_tensors->tensor(1), 6);
  }
}

}  // namespace parallel_device
}  // namespace tensorflow