• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A simple logging device to test custom device registration.
17 #include <memory>
18 
19 #include "absl/strings/match.h"
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/eager/c_api.h"
22 #include "tensorflow/c/eager/c_api_experimental.h"
23 #include "tensorflow/c/eager/c_api_test_util.h"
24 #include "tensorflow/c/eager/custom_device_testutil.h"
25 #include "tensorflow/c/tf_status.h"
26 #include "tensorflow/core/lib/gtl/cleanup.h"
27 #include "tensorflow/core/platform/test.h"
28 
// Registers a logging custom device, copies a CPU tensor onto it, and checks
// that a MatMul explicitly placed on the custom device is executed by it.
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // The context and every handle/op below are owned by unique_ptrs so they are
  // released even when an ASSERT_* returns early. Declaration order matters:
  // objects are destroyed in reverse, so ops and handles go before the context.
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts, status.get()), TFE_DeleteContext);
  TFE_DeleteContextOptions(opts);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
      TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
  ASSERT_FALSE(arrived);
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hdevice(
      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), name,
                                   status.get()),
      TFE_DeleteTensorHandle);
  // Check the copy's status first so a failed copy reports its error message
  // instead of only "arrived is false".
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(arrived);
  ASSERT_FALSE(executed);
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
      MatMulOp(context.get(), hcpu.get(), hdevice.get()), TFE_DeleteOp);
  TFE_OpSetDevice(matmul.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* retval = nullptr;
  int num_retvals = 1;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> result(
      retval, TFE_DeleteTensorHandle);
  ASSERT_TRUE(executed);
}
64 
// Checks that TFE_OpReset can move a reused op between a custom device and a
// physical device, and that TFE_OpGetDevice reports the new placement each time.
TEST(CUSTOM_DEVICE, ResetOperation) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts, status.get()), TFE_DeleteContext);
  TFE_DeleteContextOptions(opts);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* cpu_device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
  RegisterLoggingDevice(context.get(), custom_device_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
      TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Reset onto the custom device.
  TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Check the status before building a string from the returned name: on
  // failure the returned pointer may not be usable (e.g. null).
  const char* device = TFE_OpGetDevice(reused_op.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(tensorflow::string(device), tensorflow::string(custom_device_name));

  // Reset the same op back onto the physical CPU device.
  TFE_OpReset(reused_op.get(), "Identity", cpu_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  device = TFE_OpGetDevice(reused_op.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(tensorflow::string(device), tensorflow::string(cpu_device_name));
}
96 
TEST(CUSTOM_DEVICE,MakeVariable)97 TEST(CUSTOM_DEVICE, MakeVariable) {
98   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
99       TF_NewStatus(), TF_DeleteStatus);
100   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
101       TFE_NewContextOptions(), TFE_DeleteContextOptions);
102   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
103       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
104   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
105   bool arrived = false;
106   bool executed = false;
107   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
108   RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
109                         &arrived, &executed, status.get());
110   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
111 
112   // Create a variable handle placed on the custom device.
113   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
114       TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
115   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
116   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
117   TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
118   TFE_OpSetAttrString(op.get(), "container", "", 0);
119   TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
120   TFE_OpSetDevice(op.get(), name, status.get());
121   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
122   TFE_TensorHandle* var_handle = nullptr;
123   int num_retvals = 1;
124   executed = false;
125   TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
126   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
127   ASSERT_TRUE(executed);
128   auto handle_cleaner = tensorflow::gtl::MakeCleanup(
129       [var_handle]() { TFE_DeleteTensorHandle(var_handle); });
130 
131   // Assign to the variable, copying to the custom device.
132   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
133       TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
134   op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
135   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
136   TFE_OpAddInput(op.get(), var_handle, status.get());
137   TFE_OpAddInput(op.get(), one.get(), status.get());
138   TFE_OpSetDevice(op.get(), name, status.get());
139   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
140   executed = false;
141   num_retvals = 0;
142   TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
143   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
144   ASSERT_TRUE(executed);
145 
146   // Read the variable's value.
147   op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
148   TFE_OpAddInput(op.get(), var_handle, status.get());
149   TFE_OpSetDevice(op.get(), name, status.get());
150   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
151   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
152   executed = false;
153   num_retvals = 1;
154   TFE_TensorHandle* var_value = nullptr;
155   TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
156   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
157   ASSERT_TRUE(executed);
158   auto value_cleaner = tensorflow::gtl::MakeCleanup(
159       [var_value]() { TFE_DeleteTensorHandle(var_value); });
160   ASSERT_EQ(tensorflow::string(name),
161             tensorflow::string(
162                 TFE_TensorHandleBackingDeviceName(var_value, status.get())));
163   TFE_TensorHandle* var_value_unpacked =
164       UnpackTensorHandle(var_value, status.get());
165   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
166   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
167       TFE_TensorHandleResolve(var_value_unpacked, status.get()),
168       TF_DeleteTensor);
169   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
170   ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
171 
172   // Free the backing buffer for the variable.
173   op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
174   TFE_OpAddInput(op.get(), var_handle, status.get());
175   TFE_OpSetDevice(op.get(), name, status.get());
176   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
177   num_retvals = 0;
178   TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
179   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
180 }
181 
TEST(CUSTOM_DEVICE,AccessVariableOnCustomDevice)182 TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
183   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
184       TF_NewStatus(), TF_DeleteStatus);
185   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
186       TFE_NewContextOptions(), TFE_DeleteContextOptions);
187   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
188       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
189   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
190   bool arrived = false;
191   bool executed = false;
192   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
193   RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
194                         &arrived, &executed, status.get());
195   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
196 
197   // Create a variable handle placed on the custom device.
198   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
199       TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
200   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
201   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
202   TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
203   TFE_OpSetAttrString(op.get(), "container", "", 0);
204   TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
205   TFE_OpSetDevice(op.get(), name, status.get());
206   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
207   TFE_TensorHandle* var_handle = nullptr;
208   int num_retvals = 1;
209   executed = false;
210   TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
211   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
212   ASSERT_TRUE(executed);
213   auto handle_cleaner = tensorflow::gtl::MakeCleanup(
214       [var_handle]() { TFE_DeleteTensorHandle(var_handle); });
215 
216   // Assign to the variable, copying to the custom device.
217   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
218       TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
219   op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
220   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
221   TFE_OpAddInput(op.get(), var_handle, status.get());
222   TFE_OpAddInput(op.get(), one.get(), status.get());
223   TFE_OpSetDevice(op.get(), name, status.get());
224   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
225   executed = false;
226   num_retvals = 0;
227   TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
228   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
229   ASSERT_TRUE(executed);
230 
231   // Read the variable's value.
232   op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
233   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
234   TFE_OpAddInput(op.get(), var_handle, status.get());
235   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
236   TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
237   executed = false;
238   num_retvals = 1;
239   TFE_TensorHandle* var_value = nullptr;
240   TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
241   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
242   ASSERT_TRUE(executed);
243   ASSERT_EQ(
244       tensorflow::string(name),
245       tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
246   TFE_DeleteTensorHandle(var_value);
247 
248   // Free the backing buffer for the variable.
249   op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
250   TFE_OpAddInput(op.get(), var_handle, status.get());
251   TFE_OpSetDevice(op.get(), name, status.get());
252   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
253   num_retvals = 0;
254   TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
255   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
256 }
257 
TEST(CUSTOM_DEVICE,InputBasedPlacement)258 TEST(CUSTOM_DEVICE, InputBasedPlacement) {
259   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
260       TF_NewStatus(), TF_DeleteStatus);
261   std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
262       TFE_NewContextOptions(), TFE_DeleteContextOptions);
263   std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
264       TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
265   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
266 
267   const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
268   const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
269   bool arrived = false;
270   bool executed = false;
271   RegisterLoggingDevice(context.get(), custom0,
272                         /*strict_scope_placement=*/false, &arrived, &executed,
273                         status.get());
274   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
275   RegisterLoggingDevice(context.get(), custom1,
276                         /*strict_scope_placement=*/true, &arrived, &executed,
277                         status.get());
278   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
279 
280   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
281       TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
282   ASSERT_FALSE(arrived);
283   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
284       TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
285                                    status.get()),
286       TFE_DeleteTensorHandle);
287   ASSERT_TRUE(arrived);
288   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
289   arrived = false;
290   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
291       TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
292                                    status.get()),
293       TFE_DeleteTensorHandle);
294   ASSERT_TRUE(arrived);
295   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
296 
297   // Base case: two CPU inputs executes fine.
298   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
299       MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
300   TFE_TensorHandle* retval;
301   int num_retvals = 1;
302   TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
303   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
304   TFE_DeleteTensorHandle(retval);
305 
306   // Custom device: inputs in same custom device works.
307   matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
308   num_retvals = 1;
309   executed = false;
310   TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
311   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
312   ASSERT_TRUE(executed);
313   TFE_DeleteTensorHandle(retval);
314 
315   // Custom device: inputs in different custom devices fails.
316   matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
317   num_retvals = 1;
318   TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
319   ASSERT_NE(TF_OK, TF_GetCode(status.get()));
320   ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
321   ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
322 
323   // Custom device: mix of custom/physical places the op on the custom device.
324   matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
325   num_retvals = 1;
326   executed = false;
327   TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
328   EXPECT_TRUE(executed);
329   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
330   TFE_DeleteTensorHandle(retval);
331 
332   // Explicit placement still forces the op onto the requested device
333   matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
334   TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
335                   status.get());
336   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
337   num_retvals = 1;
338   executed = false;
339   TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
340   EXPECT_FALSE(executed);
341   ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
342 
343   // Custom devices can refuse to do type-based dispatch (as hcustom1 is
344   // configured to do)
345   matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
346   num_retvals = 1;
347   executed = false;
348   TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
349   EXPECT_FALSE(executed);
350   ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
351 }
352 
// Exercises RegisterLoggingDevice's error reporting:
//   - "/device:CUSTOM:0" is rejected with TF_INVALID_ARGUMENT;
//   - reusing a name already registered as a custom device fails with
//     TF_ALREADY_EXISTS;
//   - so does using the name of the existing CPU device.
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  bool arrived = false;
  bool executed = false;
  // Attempts a strictly-scoped logging device registration under
  // `device_name` and returns the resulting code. The full message stays in
  // `status` for failure output.
  auto register_as = [&](const char* device_name) {
    RegisterLoggingDevice(context.get(), device_name,
                          /*strict_scope_placement=*/true, &arrived, &executed,
                          status.get());
    return TF_GetCode(status.get());
  };

  // A name without job/replica/task components is not accepted.
  ASSERT_TRUE(register_as("/device:CUSTOM:0") == TF_INVALID_ARGUMENT)
      << TF_Message(status.get());

  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  // The first registration under the full name succeeds...
  ASSERT_TRUE(register_as(name) == TF_OK) << TF_Message(status.get());
  // ...and a second registration under the same name is refused.
  ASSERT_TRUE(register_as(name) == TF_ALREADY_EXISTS)
      << TF_Message(status.get());

  // The existing CPU device's name cannot be claimed either.
  ASSERT_TRUE(register_as("/job:localhost/replica:0/task:0/device:CPU:0") ==
              TF_ALREADY_EXISTS)
      << TF_Message(status.get());
}
384