1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Tests custom device registration and op placement using a simple logging
// device.
17 #include <memory>
18
19 #include "absl/strings/match.h"
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/eager/c_api.h"
22 #include "tensorflow/c/eager/c_api_experimental.h"
23 #include "tensorflow/c/eager/c_api_test_util.h"
24 #include "tensorflow/c/eager/custom_device_testutil.h"
25 #include "tensorflow/c/tf_status.h"
26 #include "tensorflow/core/lib/gtl/cleanup.h"
27 #include "tensorflow/core/platform/test.h"
28
// Registers the logging custom device under a fully-qualified name, copies a
// CPU matrix onto it, and runs a MatMul explicitly placed on that device.
// The test expects `arrived` to be set by the copy and `executed` to be set
// only once the op has been executed.
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  using TensorHandlePtr =
      std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>;

  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  // Declared before every handle and op below so that it is destroyed after
  // them, including when a fatal ASSERT_* returns from the test early.
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr hcpu(TestMatrixTensorHandle(context.get()),
                       TFE_DeleteTensorHandle);
  ASSERT_FALSE(arrived);
  TensorHandlePtr hdevice(
      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), name,
                                   status.get()),
      TFE_DeleteTensorHandle);
  // Check the status before the flags so that a failed copy reports its error
  // message instead of a bare "arrived is false".
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(arrived);
  ASSERT_FALSE(executed);

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
      MatMulOp(context.get(), hcpu.get(), hdevice.get()), TFE_DeleteOp);
  TFE_OpSetDevice(matmul.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Initialised so that the owner below holds nullptr (and deletes nothing)
  // if TFE_Execute does not write a result.
  TFE_TensorHandle* retval = nullptr;
  int num_retvals = 1;
  TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
  TensorHandlePtr retval_owner(retval, TFE_DeleteTensorHandle);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
}
64
// Checks that TFE_OpReset re-targets an existing op: first onto the custom
// logging device and then onto the CPU device, reading the placement back with
// TFE_OpGetDevice after each reset.
TEST(CUSTOM_DEVICE, ResetOperation) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  bool arrived = false;
  bool executed = false;
  const char* custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* cpu_device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
  RegisterLoggingDevice(context.get(), custom_device_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
      TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
  // The calls below reuse the same status object, so verify that op creation
  // succeeded before its result can be overwritten.
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Reset onto the custom device.
  TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const char* device_after_custom_reset =
      TFE_OpGetDevice(reused_op.get(), status.get());
  // Validate the query before building a string from its result: constructing
  // a string from a null pointer is undefined behaviour.
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(device_after_custom_reset != nullptr);
  ASSERT_EQ(tensorflow::string(device_after_custom_reset),
            tensorflow::string(custom_device_name));

  // Reset the same op back onto the CPU device.
  TFE_OpReset(reused_op.get(), "Identity", cpu_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const char* device_after_cpu_reset =
      TFE_OpGetDevice(reused_op.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(device_after_cpu_reset != nullptr);
  ASSERT_EQ(tensorflow::string(device_after_cpu_reset),
            tensorflow::string(cpu_device_name));
}
96
// Creates a resource variable with every op explicitly placed on the logging
// custom device (registered with strict_scope_placement=true), assigns 111 to
// it, reads it back, checks the value and its backing device, and finally
// destroys the resource.
TEST(CUSTOM_DEVICE, MakeVariable) {
  using StatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
  using OptionsPtr =
      std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)>;
  using ContextPtr = std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)>;
  using OpPtr = std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)>;
  using HandlePtr =
      std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>;
  using TensorPtr = std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>;

  StatusPtr status(TF_NewStatus(), TF_DeleteStatus);
  OptionsPtr opts(TFE_NewContextOptions(), TFE_DeleteContextOptions);
  ContextPtr context(TFE_NewContext(opts.get(), status.get()),
                     TFE_DeleteContext);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());

  bool arrived = false;
  bool executed = false;
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), device_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());

  // Step 1: build and run a VarHandleOp on the custom device.
  OpPtr op(TFE_NewOp(context.get(), "VarHandleOp", status.get()),
           TFE_DeleteOp);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
  TFE_OpSetDevice(op.get(), device_name, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_TensorHandle* var_handle = nullptr;
  int output_count = 1;
  executed = false;
  TFE_Execute(op.get(), &var_handle, &output_count, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  // Releases the variable handle when the test scope exits.
  auto delete_var_handle = tensorflow::gtl::MakeCleanup(
      [var_handle]() { TFE_DeleteTensorHandle(var_handle); });

  // Step 2: assign a scalar to the variable; the op is placed on the custom
  // device.
  HandlePtr scalar_value(TestScalarTensorHandle(context.get(), 111.f),
                         TFE_DeleteTensorHandle);
  op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpAddInput(op.get(), scalar_value.get(), status.get());
  TFE_OpSetDevice(op.get(), device_name, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  executed = false;
  output_count = 0;
  TFE_Execute(op.get(), nullptr, &output_count, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(executed);

  // Step 3: read the value back through the custom device.
  op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), device_name, status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  executed = false;
  output_count = 1;
  TFE_TensorHandle* var_value = nullptr;
  TFE_Execute(op.get(), &var_value, &output_count, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  auto delete_var_value = tensorflow::gtl::MakeCleanup(
      [var_value]() { TFE_DeleteTensorHandle(var_value); });
  ASSERT_EQ(tensorflow::string(device_name),
            tensorflow::string(
                TFE_TensorHandleBackingDeviceName(var_value, status.get())));
  // NOTE(review): the unpacked handle is never deleted here; presumably it is
  // borrowed from `var_value` rather than owned -- confirm against
  // UnpackTensorHandle.
  TFE_TensorHandle* unpacked_value =
      UnpackTensorHandle(var_value, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  TensorPtr host_tensor(TFE_TensorHandleResolve(unpacked_value, status.get()),
                        TF_DeleteTensor);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(host_tensor.get())));

  // Step 4: destroy the resource backing the variable.
  op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), device_name, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  output_count = 0;
  TFE_Execute(op.get(), nullptr, &output_count, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
}
181
// Creates and assigns a variable on the logging custom device (registered with
// strict_scope_placement=false), then reads it with a ReadVariableOp that has
// no explicit device set. The test expects the read to run on the custom
// device anyway and the result handle to report the custom device's name.
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  bool executed = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
                        &arrived, &executed, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle placed on the custom device.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  executed = false;
  TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  auto handle_cleaner = tensorflow::gtl::MakeCleanup(
      [var_handle]() { TFE_DeleteTensorHandle(var_handle); });

  // Assign to the variable, copying to the custom device.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
      TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
  op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpAddInput(op.get(), one.get(), status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  executed = false;
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);

  // Read the variable's value. No device is set on this op; placement is left
  // to be derived from its input.
  op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpAddInput(op.get(), var_handle, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
  executed = false;
  num_retvals = 1;
  TFE_TensorHandle* var_value = nullptr;
  TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
  // Take ownership immediately so the handle is released even if one of the
  // fatal assertions below returns from the test early.
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      var_value_owner(var_value, TFE_DeleteTensorHandle);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  const char* value_device_name =
      TFE_TensorHandleDeviceName(var_value, status.get());
  // Validate the query before building a string from its result: constructing
  // a string from a null pointer is undefined behaviour.
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_TRUE(value_device_name != nullptr);
  ASSERT_EQ(tensorflow::string(name), tensorflow::string(value_device_name));
  // Release the value before the variable's resource is destroyed, as the
  // test has always done.
  var_value_owner.reset();

  // Free the backing buffer for the variable.
  op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
  TFE_OpAddInput(op.get(), var_handle, status.get());
  TFE_OpSetDevice(op.get(), name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
257
// Exercises op placement derived from input handles with two logging devices
// registered: CUSTOM:0 with strict_scope_placement=false and CUSTOM:1 with
// strict_scope_placement=true. Each scenario below states the outcome the
// test expects.
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
  using HandlePtr =
      std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>;
  using OpPtr = std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)>;

  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());

  // Both devices report into the same pair of flags.
  bool arrived = false;
  bool executed = false;
  const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  RegisterLoggingDevice(context.get(), custom0,
                        /*strict_scope_placement=*/false, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  RegisterLoggingDevice(context.get(), custom1,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());

  // One CPU matrix plus a copy of it on each custom device.
  HandlePtr hcpu(TestMatrixTensorHandle(context.get()),
                 TFE_DeleteTensorHandle);
  ASSERT_FALSE(arrived);
  HandlePtr hcustom0(
      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
                                   status.get()),
      TFE_DeleteTensorHandle);
  ASSERT_TRUE(arrived);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  arrived = false;
  HandlePtr hcustom1(
      TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
                                   status.get()),
      TFE_DeleteTensorHandle);
  ASSERT_TRUE(arrived);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());

  // Base case: two CPU inputs execute fine.
  OpPtr matmul(MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
  TFE_TensorHandle* result;
  int result_count = 1;
  TFE_Execute(matmul.get(), &result, &result_count, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteTensorHandle(result);

  // Custom device: both inputs on the same custom device works.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
  result_count = 1;
  executed = false;
  TFE_Execute(matmul.get(), &result, &result_count, status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(executed);
  TFE_DeleteTensorHandle(result);

  // Custom device: inputs on different custom devices fail, and the error
  // message is expected to name both devices.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
  result_count = 1;
  TFE_Execute(matmul.get(), &result, &result_count, status.get());
  ASSERT_NE(TF_OK, TF_GetCode(status.get()));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
  ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));

  // Custom device: a mix of custom and physical inputs places the op on the
  // custom device.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
  result_count = 1;
  executed = false;
  TFE_Execute(matmul.get(), &result, &result_count, status.get());
  EXPECT_TRUE(executed);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteTensorHandle(result);

  // Explicit placement still forces the op onto the requested device; with a
  // custom-device input the execution is expected to fail.
  matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
  TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
                  status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  result_count = 1;
  executed = false;
  TFE_Execute(matmul.get(), &result, &result_count, status.get());
  EXPECT_FALSE(executed);
  ASSERT_FALSE(TF_OK == TF_GetCode(status.get()));

  // Custom devices can refuse to do type-based dispatch (as custom1 is
  // configured to do).
  matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
  result_count = 1;
  executed = false;
  TFE_Execute(matmul.get(), &result, &result_count, status.get());
  EXPECT_FALSE(executed);
  ASSERT_FALSE(TF_OK == TF_GetCode(status.get()));
}
352
// Checks the status codes the test expects from custom device registration:
// a name that is not fully qualified, a name that was already registered, and
// a name equal to the CPU device's.
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());

  bool arrived = false;
  bool executed = false;
  const char* partial_name = "/device:CUSTOM:0";
  const char* full_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* cpu_name = "/job:localhost/replica:0/task:0/device:CPU:0";

  // A name without job/replica/task is expected to be TF_INVALID_ARGUMENT.
  RegisterLoggingDevice(context.get(), partial_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_INVALID_ARGUMENT == TF_GetCode(status.get()))
      << TF_Message(status.get());

  // The first registration under the full name succeeds ...
  RegisterLoggingDevice(context.get(), full_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_OK == TF_GetCode(status.get())) << TF_Message(status.get());
  // ... and repeating it is expected to be TF_ALREADY_EXISTS.
  RegisterLoggingDevice(context.get(), full_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_ALREADY_EXISTS == TF_GetCode(status.get()))
      << TF_Message(status.get());

  // Registering under the CPU device's name is expected to be
  // TF_ALREADY_EXISTS as well.
  RegisterLoggingDevice(context.get(), cpu_name,
                        /*strict_scope_placement=*/true, &arrived, &executed,
                        status.get());
  ASSERT_TRUE(TF_ALREADY_EXISTS == TF_GetCode(status.get()))
      << TF_Message(status.get());
}
384