1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17
18 #include <string.h>
19
20 #include <string>
21
22 // clang-format off
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/platform.h"
26 // clang-format on
27
28 #include "absl/strings/match.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/c_api_test_util.h"
32 #include "tensorflow/c/eager/tfe_op_internal.h"
33 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
34 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
35 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
36 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
37 #include "tensorflow/core/framework/function.pb.h"
38 #include "tensorflow/core/platform/casts.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/macros.h"
41 #include "tensorflow/core/platform/protobuf.h"
42 #include "tensorflow/core/platform/strcat.h"
43 #include "tensorflow/core/platform/test.h"
44 #include "tensorflow/core/platform/test_benchmark.h"
45 #include "tensorflow/core/protobuf/cluster.pb.h"
46 #include "tensorflow/core/protobuf/config.pb.h"
47 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
48 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
49
50 #ifdef PLATFORM_GOOGLE
51 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
52 #endif
53
54 using tensorflow::string;
55
56 namespace {
57
// Measures the fixed cost of constructing and destroying an eager op
// (TFE_Op) via MatMulOp, without executing it.
void BM_InitOp(::testing::benchmark::State& state) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  // Timed region: op construction + deletion only; the context and the
  // input handle are created outside the loop.
  for (auto s : state) {
    TFE_Op* matmul = MatMulOp(ctx, m, m);
    TFE_DeleteOp(matmul);
  }
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
BENCHMARK(BM_InitOp);
76
// Measures end-to-end MatMul dispatch cost through the C API, reusing a
// single TFE_Op via TFE_OpReset. Arg(0) runs sync, Arg(1) async.
void BM_Execute(::testing::benchmark::State& state) {
  const int async = state.range(0);
  state.SetLabel(async ? "ExecuteAsync" : "Execute");
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  // NOTE(review): retvals[0] is never deleted inside the loop, so each
  // iteration leaks one output handle; deleting it would add to the
  // timed work — confirm whether this is intentional for the benchmark.
  for (auto s : state) {
    // Re-arm the op in place instead of re-creating it each iteration.
    TFE_OpReset(matmul, "MatMul", nullptr, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(matmul, m, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(matmul, m, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(matmul, &retvals[0], &num_retvals, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // On the final iteration of an async run, drain the executor so the
    // asynchronously queued work is included in the measurement.
    if (state.iterations() >= state.max_iterations && async) {
      TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
      TFE_ExecutorWaitForAllPendingNodes(executor, status);
      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteExecutor(executor);
    }
  }
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
BENCHMARK(BM_Execute)->Arg(0)->Arg(1);
115
BM_Execute_Identity(::testing::benchmark::State & state)116 void BM_Execute_Identity(::testing::benchmark::State& state) {
117 const int async = state.range(0);
118 state.SetLabel(async ? "ExecuteIdentityAsync" : "ExecuteIdentity");
119 TF_Status* status = TF_NewStatus();
120 TFE_ContextOptions* opts = TFE_NewContextOptions();
121 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
122 TFE_Context* ctx = TFE_NewContext(opts, status);
123 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
124 TFE_DeleteContextOptions(opts);
125
126 TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
127 TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
128 TFE_TensorHandle* retvals[1];
129 int num_retvals = 1;
130 for (auto s : state) {
131 TFE_OpReset(identity, "Identity", nullptr, status);
132 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
133 TFE_OpAddInput(identity, m, status);
134 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
135 TFE_Execute(identity, &retvals[0], &num_retvals, status);
136 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
137 if (state.iterations() >= state.max_iterations && async) {
138 TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
139 TFE_ExecutorWaitForAllPendingNodes(executor, status);
140 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
141 TFE_DeleteExecutor(executor);
142 }
143 }
144 TFE_DeleteOp(identity);
145 TFE_DeleteTensorHandle(m);
146 TFE_DeleteContext(ctx);
147 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
148 TF_DeleteStatus(status);
149 }
150 BENCHMARK(BM_Execute_Identity)->Arg(0)->Arg(1);
151
// Exercises basic context lifetime: create a context, list its devices,
// and verify the device list remains valid after the context is deleted.
TEST(CAPI, Context) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  // Previously unchecked: abort early if context creation failed rather
  // than using a potentially-invalid ctx below.
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // The device list must stay usable after the context that produced it
  // is destroyed; the loop below reads it post-deletion on purpose.
  TFE_DeleteContext(ctx);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const int num_devices = TF_DeviceListCount(devices);
  EXPECT_GE(num_devices, 1) << "At least one CPU device should exist";
  for (int i = 0; i < num_devices; ++i) {
    EXPECT_NE("", TF_DeviceListName(devices, i, status)) << i;
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  TF_DeleteDeviceList(devices);
  TF_DeleteStatus(status);
}
173
// Builds the canonical 2x2 float test matrix, resolves it to a
// TF_Tensor, and verifies dtype, byte size, and element values.
TEST(CAPI, TensorHandle) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));

  TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
  // Previously unchecked: Resolve can fail, in which case `t` must not
  // be dereferenced.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(16, TF_TensorByteSize(t));  // 4 floats.
  float data[4] = {0};
  memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
  EXPECT_EQ(1.0, data[0]);
  EXPECT_EQ(2.0, data[1]);
  EXPECT_EQ(3.0, data[2]);
  EXPECT_EQ(4.0, data[3]);
  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(h);
  TFE_DeleteContext(ctx);
}
197
TensorHandleCopyBetweenDevices(bool async)198 void TensorHandleCopyBetweenDevices(bool async) {
199 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
200 TF_NewStatus(), TF_DeleteStatus);
201 TFE_ContextOptions* opts = TFE_NewContextOptions();
202 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
203 TFE_Context* ctx = TFE_NewContext(opts, status.get());
204 TFE_DeleteContextOptions(opts);
205 ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
206
207 TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
208 TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
209 ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
210
211 TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
212 ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
213 const int num_devices = TF_DeviceListCount(devices);
214
215 const char* kCPUDevice = "CPU:0";
216 for (int i = 0; i < num_devices; ++i) {
217 const string name(TF_DeviceListName(devices, i, status.get()));
218 if (TF_GetCode(status.get()) != TF_OK) {
219 ADD_FAILURE() << i << " -- " << TF_Message(status.get());
220 continue;
221 }
222 auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")");
223 // Copy to device
224 TFE_TensorHandle* hdevice =
225 TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
226 if (TF_GetCode(status.get()) != TF_OK) {
227 ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
228 continue;
229 }
230 // Copy from device to the same device.
231 TFE_TensorHandle* hdevice2 =
232 TFE_TensorHandleCopyToDevice(hdevice, ctx, name.c_str(), status.get());
233 if (TF_GetCode(status.get()) != TF_OK) {
234 ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
235 continue;
236 }
237 TFE_DeleteTensorHandle(hdevice);
238 // Copy back to CPU
239 TFE_TensorHandle* hcopy =
240 TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
241 if (TF_GetCode(status.get()) != TF_OK) {
242 ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
243 continue;
244 }
245 TFE_DeleteTensorHandle(hdevice2);
246
247 // Ensure that the contents are the same!
248 TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
249 TFE_DeleteTensorHandle(hcopy);
250 if (TF_GetCode(status.get()) != TF_OK) {
251 ADD_FAILURE() << tag;
252 continue;
253 }
254 EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag;
255 EXPECT_EQ(
256 0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)))
257 << tag;
258 TF_DeleteTensor(tcopy);
259 }
260
261 TF_DeleteDeviceList(devices);
262 TF_DeleteTensor(t);
263 TFE_DeleteTensorHandle(hcpu);
264 TFE_DeleteContext(ctx);
265 }
266
// Round-trip device copies with a synchronous eager executor.
TEST(CAPI, TensorHandleCopyBetweenDevices) {
  TensorHandleCopyBetweenDevices(false);
}

// Round-trip device copies with an asynchronous eager executor.
TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) {
  TensorHandleCopyBetweenDevices(true);
}
274
// Attempts a copy to a nonexistent device, verifies the expected error,
// then confirms the context remains usable for a subsequent valid copy.
void TensorHandleCopyBetweenDevicesError(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* kErrorDevice = "NoSuchDevice:0";
  TFE_TensorHandle* hdevice =
      TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
  EXPECT_NE(TF_OK, TF_GetCode(status.get()));
  const char* msg = "NoSuchDevice:0 unknown device";
  EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr)
      << TF_Message(status.get());
  // Clear the error so the next call starts from a clean status.
  TF_SetStatus(status.get(), TF_OK, "");
  const char* kCPUDevice = "CPU:0";
  TFE_TensorHandle* hcopy =
      TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Drain pending async work before tearing anything down.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteTensorHandle(hcopy);
  TFE_DeleteTensorHandle(hcpu);
  // The failed copy may or may not have produced a handle.
  if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
  TFE_DeleteContext(ctx);
}
306
// Invalid-device copy error handling with a synchronous executor.
TEST(CAPI, TensorHandleCopyBetweenDevicesError) {
  TensorHandleCopyBetweenDevicesError(false);
}

// Invalid-device copy error handling with an asynchronous executor.
TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) {
  TensorHandleCopyBetweenDevicesError(true);
}
314
// Copies a tensor CPU -> one GPU -> a second GPU -> back to CPU and
// verifies the bytes survive the round trip. Cleans up and returns
// early unless both GPU:0 and GPU:1 are present.
void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  const int num_devices = TF_DeviceListCount(devices);
  // Scan the device list for the two GPUs this test requires.
  bool has_gpu0 = false;
  bool has_gpu1 = false;
  for (int i = 0; i < num_devices; ++i) {
    const char* dev = TF_DeviceListName(devices, i, status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    string device_name(dev);
    if (device_name.find("GPU:0") != string::npos) {
      has_gpu0 = true;
    }
    if (device_name.find("GPU:1") != string::npos) {
      has_gpu1 = true;
    }
  }

  const char* kCPUDevice = "CPU:0";
  if (!has_gpu0 || !has_gpu1) {
    // Not enough GPUs: release everything and skip the test body.
    TF_DeleteDeviceList(devices);
    TF_DeleteTensor(t);
    TFE_DeleteTensorHandle(hcpu);
    TFE_DeleteContext(ctx);
    return;
  }
  // NOTE(review): assumes list slots 1 and 2 hold the two GPUs detected
  // above (slot 0 presumably being the CPU) — confirm device-list
  // ordering; a name-based lookup would be more robust.
  const string gpu_1_name(TF_DeviceListName(devices, 1, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  const string gpu_2_name(TF_DeviceListName(devices, 2, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  TFE_TensorHandle* hdevice =
      TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);

  // GPU-to-GPU copy.
  TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice(
      hdevice, ctx, gpu_2_name.c_str(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  TFE_DeleteTensorHandle(hdevice);
  // Copy back to CPU
  TFE_TensorHandle* hcopy =
      TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  TFE_DeleteTensorHandle(hdevice2);

  // Ensure that the contents are the same!
  TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
  TFE_DeleteTensorHandle(hcopy);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
  EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy));
  EXPECT_EQ(
      0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)));
  TF_DeleteTensor(tcopy);

  TF_DeleteDeviceList(devices);
  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(hcpu);
  TFE_DeleteContext(ctx);
}
385
// GPU-to-GPU round trip with a synchronous executor.
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
  TensorHandleCopyBetweenTwoGPUDevices(false);
}

// GPU-to-GPU round trip with an asynchronous executor.
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
  TensorHandleCopyBetweenTwoGPUDevices(true);
}
393
// Executes a MatMul whose inputs live on different devices (CPU and GPU)
// under the given device placement policies, and verifies that the CPU
// input is silently mirrored onto the GPU rather than rejected.
// The GPU-dependent body is skipped when no GPU is available.
void TensorHandleSilentCopy(bool async,
                            TFE_ContextDevicePlacementPolicy global_policy,
                            TFE_ContextDevicePlacementPolicy thread_policy,
                            bool cpu_op) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  // Only override the thread-local policy when it differs from the
  // global one, so "global only" configurations stay untouched.
  if (thread_policy != global_policy) {
    TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
  }
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    // Unwrap to the internal TensorHandle so we can observe mirroring.
    auto cpu_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
    auto gpu_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
    auto gpu_device = gpu_arg->device();
    // No mirror exists before the op runs.
    ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));

    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
    // Place the op on CPU or GPU depending on the test variant; either
    // way one input is on the "wrong" device and must be mirrored.
    if (cpu_op) {
      string cpu_device_name;
      ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
      TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
    } else {
      TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
    }
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    // The CPU handle should have been copied and have a mirror on the GPU
    ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));

    TFE_DeleteOp(matmul);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(hcpu);
  // Flush any pending async work before destroying the context.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
// Silent placement policy set globally, sync execution.
TEST(CAPI, TensorHandleSilentCopy) {
  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
// Silent placement policy set globally, async execution.
TEST(CAPI, TensorHandleSilentCopyAsync) {
  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
// Explicit global policy overridden to silent via the thread-local
// policy, sync execution.
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
// Explicit global policy overridden to silent via the thread-local
// policy, async execution.
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
475
// Verifies the TFE_OpSetDevice / TFE_OpGetDevice round trip on a MatMul
// op, switching between GPU:0 and CPU:0. The checks are skipped when no
// GPU is present.
// NOTE(review): the `async` parameter is unused — the context is always
// created with default (sync) options here; confirm whether an
// TFE_ContextOptionsSetAsync call was intended.
void SetAndGetOpDevices(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = MatMulOp(ctx, m, m);

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_OpSetDevice(matmul, "GPU:0", status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    const char* device_name = TFE_OpGetDevice(matmul, status);
    ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr);

    // Re-place the same op on CPU and confirm the getter reflects it.
    TFE_OpSetDevice(matmul, "CPU:0", status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    device_name = TFE_OpGetDevice(matmul, status);
    ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr);
  }

  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
506
// Verifies that every TFE_TensorHandle accessor rejects a null handle
// with TF_INVALID_ARGUMENT, an "Invalid handle" message, and the
// documented sentinel return value.
TEST(CAPI, TensorHandleNullptr) {
  TFE_TensorHandle* h = nullptr;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  // Asserts the status left by the preceding call and re-arms it so the
  // next accessor starts from TF_OK.
  auto check_invalid_and_reset = [&status]() {
    ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
    ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
    TF_SetStatus(status.get(), TF_OK, "");
  };

  TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
  ASSERT_EQ(t, nullptr);
  check_invalid_and_reset();

  const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
  ASSERT_EQ(device_name, nullptr);
  check_invalid_and_reset();

  device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
  ASSERT_EQ(device_name, nullptr);
  check_invalid_and_reset();

  int num_dims = TFE_TensorHandleNumDims(h, status.get());
  ASSERT_EQ(num_dims, -1);
  check_invalid_and_reset();

  int dim = TFE_TensorHandleDim(h, 0, status.get());
  ASSERT_EQ(dim, -1);
  check_invalid_and_reset();
}
545
// Checks TFE_TensorHandleDeviceName vs TFE_TensorHandleBackingDeviceName:
// a CPU tensor reports CPU for both; for a Shape op executed on GPU the
// handle's device is GPU while the backing memory stays on CPU.
TEST(CAPI, TensorHandleDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
  const char* backing_device_name =
      TFE_TensorHandleBackingDeviceName(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
      << backing_device_name;

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // .device of shape is GPU since the op is executed on GPU
    device_name = TFE_TensorHandleDeviceName(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name;

    // .backing_device of shape is CPU since the tensor is backed by CPU
    backing_device_name =
        TFE_TensorHandleBackingDeviceName(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
        << backing_device_name;

    TFE_DeleteOp(shape_op);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  TFE_DeleteTensorHandle(hcpu);
  // Flush any pending async work before destroying the context.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
603
ExecuteAdd(bool async,bool forward_input,bool tfrt)604 void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
605 #ifdef PLATFORM_WINDOWS
606 // On windows, we flakily get a failure due to pointer instability.
607 // Disable the 4 tests using this helper until we fix the issue.
608 return;
609 #else
610 TF_Status* status = TF_NewStatus();
611 TFE_ContextOptions* opts = TFE_NewContextOptions();
612 TFE_ContextOptionsSetTfrt(opts, tfrt);
613 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
614 TFE_Context* ctx = TFE_NewContext(opts, status);
615 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
616 TFE_DeleteContextOptions(opts);
617
618 TFE_TensorHandle* n = TestMatrixTensorHandle100x100(ctx);
619 // If a GPU exists, copy the handle to GPU so that we can exercise
620 // unprotecting a mirror.
621 std::string gpu_device_name;
622 if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
623 TFE_TensorHandle* n_gpu =
624 TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
625 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
626 TFE_DeleteTensorHandle(n);
627 n = n_gpu;
628 }
629
630 TFE_TensorHandle* m = TestMatrixTensorHandle100x100(ctx);
631
632 // Store pointer to raw buffer for validation of forwarding behaviour.
633 TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
634 void* orig_ptr = TF_TensorData(orig);
635 TF_DeleteTensor(orig);
636
637 TFE_Op* add_op = AddOp(ctx, n, m);
638 std::string cpu_device_name;
639 ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
640 TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
641 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
642 if (forward_input) {
643 TFE_DeleteTensorHandle(n);
644 }
645
646 int num_retvals = 1;
647 if (async) {
648 // Enqueue dummy ops so we backlog async execution & actually test async.
649 // This is usually unnecessary, but we've experienced the occasional test
650 // failure when testing async mode with no explicit forwarding.
651 for (int i = 0; i < 100000; ++i) {
652 TFE_Op* add_op_dummy = AddOp(ctx, m, m);
653 TFE_OpSetDevice(add_op_dummy, cpu_device_name.c_str(), status);
654 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
655 TFE_TensorHandle* dummy = nullptr;
656 TFE_Execute(add_op_dummy, &dummy, &num_retvals, status);
657 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
658 TFE_DeleteTensorHandle(dummy);
659 TFE_DeleteOp(add_op_dummy);
660 }
661 }
662 TFE_TensorHandle* retval = nullptr;
663 TFE_Execute(add_op, &retval, &num_retvals, status);
664 EXPECT_EQ(1, num_retvals);
665 EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
666 if (!forward_input) {
667 TFE_DeleteTensorHandle(n);
668 }
669 TFE_DeleteOp(add_op);
670
671 TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
672 if (async) {
673 if (forward_input) {
674 // Since the input was forwarded, we released the input handle right away
675 // and hence expect the input to be forwarded to the return tensor.
676 EXPECT_EQ(orig_ptr, TF_TensorData(t));
677 } else {
678 // In async mode we expect forwarding to work without releasing the input
679 // handle since by the time the kernel is executed we have released the
680 // handle in the client code.
681 EXPECT_EQ(orig_ptr, TF_TensorData(t));
682 }
683 } else {
684 if (forward_input) {
685 // Since the input was forwarded, we released the input handle right away
686 // and hence expect the input to be forwarded to the return tensor.
687 EXPECT_EQ(orig_ptr, TF_TensorData(t));
688 } else {
689 // In sync mode, forwarding can't really happen since the client code will
690 // have a reference count on the input tensor while the kernel is being
691 // executed and thus it cannot be re-used for the return tensor.
692 EXPECT_NE(orig_ptr, TF_TensorData(t));
693 }
694 }
695
696 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
697 TFE_DeleteTensorHandle(m);
698 TFE_DeleteTensorHandle(retval);
699 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
700
701 float result[100 * 100] = {0};
702 EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
703 memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
704 TF_DeleteTensor(t);
705 for (int i = 0; i < 100 * 100; ++i) {
706 EXPECT_EQ(2.0f, result[i]);
707 }
708 TFE_DeleteContext(ctx);
709 TF_DeleteStatus(status);
710 #endif // PLATFORM_WINDOWS
711 }
// Sync execution, no forwarding: output must use a fresh buffer.
TEST(CAPI, ExecuteAdd) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/false,
      /*tfrt=*/false);
}
// Async execution, no explicit forwarding: buffer still forwarded.
TEST(CAPI, ExecuteAddAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input=*/false,
      /*tfrt=*/false);
}
// Sync execution with the input handle released up front: forwarded.
TEST(CAPI, ExecuteAddForward) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/true,
      /*tfrt=*/false);
}
// Async execution with the input handle released up front: forwarded.
TEST(CAPI, ExecuteAddForwardAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input=*/true,
      /*tfrt=*/false);
}
736 #ifdef PLATFORM_GOOGLE
737 // TODO(b/153349425): Add forwarding tests for TFRT
738 // TODO(b/178003466): Fix and re-enable.
// TFRT variant of the sync/no-forwarding case; currently disabled (see
// the TODOs above).
TEST(CAPI, DISABLED_ExecuteAddTfrt) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/false,
      /*tfrt=*/true);
}
745 #endif
746
// Runs a single 2x2 MatMul of the test matrix with itself on CPU and
// verifies the numeric result [[7,10],[15,22]]. num_retvals is passed
// as 2 on purpose to check that TFE_Execute shrinks it to the actual
// output count (1).
void Execute_MatMul_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = MatMulOp(ctx, m, m);
  TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
  int num_retvals = 2;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(1, num_retvals);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  // In async mode this forces completion of the pending computation.
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_CPU) {
  Execute_MatMul_CPU(/*async=*/false);
}
TEST(CAPI, Execute_MatMul_CPUAsync) {
  Execute_MatMul_CPU(/*async=*/true);
}
782
// Multiplies a [2,2] matrix by a [3,2] matrix, which fails with a shape error
// at kernel-execution time (not at op-building time). Verifies how that
// runtime error is surfaced and then cleared in both sync and async modes.
void Execute_MatMul_CPU_Runtime_Error(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(ctx);
  // `matmul` has incompatible operand shapes and is expected to fail;
  // `matmul2` (m1 x m1) is valid and used below to probe the executor state.
  TFE_Op* matmul = MatMulOp(ctx, m1, m2);
  TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
                  status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* matmul2 = MatMulOp(ctx, m1, m1);
  TFE_OpSetDevice(matmul2, "/job:localhost/replica:0/task:0/device:CPU:0",
                  status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  TFE_DeleteOp(matmul);
  if (!async) {
    // Sync mode: the failure is reported directly by TFE_Execute.
    EXPECT_NE(TF_OK, TF_GetCode(status));
  } else {
    // Async mode: TFE_Execute itself returns OK; the error only materializes
    // when the result handle is resolved.
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    EXPECT_EQ(nullptr, t);
    const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
    EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
        << TF_Message(status);
    // Since error is not cleared, the following copy with correct device will
    // still fail.
    TF_SetStatus(status, TF_OK, "");
    TFE_DeleteTensorHandle(retvals[0]);
    // The thread's executor is now poisoned: waiting for pending nodes
    // re-reports the error, and further executions fail until
    // TFE_ExecutorClearError is called.
    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    TF_SetStatus(status, TF_OK, "");
    retvals[0] = nullptr;
    TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    TFE_ExecutorClearError(executor);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteExecutor(executor);
  }
  // Following works in async mode since TFE_ContextAsyncClearError was called.
  TF_SetStatus(status, TF_OK, "");
  if (retvals[0] != nullptr) {
    TFE_DeleteTensorHandle(retvals[0]);
  }
  retvals[0] = nullptr;
  TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  TF_DeleteTensor(t);
  TFE_DeleteOp(matmul2);
  TFE_DeleteTensorHandle(m1);
  TFE_DeleteTensorHandle(m2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
// Sync and async variants of the MatMul shape-mismatch runtime-error test.
TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) {
  Execute_MatMul_CPU_Runtime_Error(/*async=*/false);
}
TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) {
  Execute_MatMul_CPU_Runtime_Error(/*async=*/true);
}
854
Execute_MatMul_CPU_Type_Error(bool async)855 void Execute_MatMul_CPU_Type_Error(bool async) {
856 TF_Status* status = TF_NewStatus();
857 TFE_ContextOptions* opts = TFE_NewContextOptions();
858 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
859 TFE_Context* ctx = TFE_NewContext(opts, status);
860 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
861 TFE_DeleteContextOptions(opts);
862
863 TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
864 TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(ctx);
865 TFE_Op* matmul = MatMulOp(ctx, m1, m2);
866 TFE_TensorHandle* retvals[1] = {nullptr};
867 int num_retvals = 1;
868 TFE_Execute(matmul, &retvals[0], &num_retvals, status);
869 EXPECT_NE(TF_OK, TF_GetCode(status));
870 TFE_DeleteOp(matmul);
871 TFE_DeleteTensorHandle(m1);
872 TFE_DeleteTensorHandle(m2);
873 if (retvals[0] != nullptr) {
874 TFE_DeleteTensorHandle(retvals[0]);
875 }
876 TFE_DeleteContext(ctx);
877 TF_DeleteStatus(status);
878 }
879
// Sync and async variants of the MatMul type-mismatch test.
TEST(CAPI, Execute_MatMul_CPU_Type_Error) {
  Execute_MatMul_CPU_Type_Error(/*async=*/false);
}
TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) {
  Execute_MatMul_CPU_Type_Error(/*async=*/true);
}
// Runs Min(matrix, axis) on the CPU and checks the reduced values.
TEST(CAPI, Execute_Min_CPU) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* options = TFE_NewContextOptions();
  TFE_Context* context = TFE_NewContext(options, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(options);

  TFE_TensorHandle* matrix = TestMatrixTensorHandle(context);
  TFE_TensorHandle* reduce_axis = TestAxisTensorHandle(context);
  TFE_Op* min_op = MinOp(context, matrix, reduce_axis);
  TFE_TensorHandle* outputs[1] = {nullptr};
  int num_outputs = 1;
  TFE_Execute(min_op, &outputs[0], &num_outputs, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(min_op);
  TFE_DeleteTensorHandle(matrix);
  TFE_DeleteTensorHandle(reduce_axis);
  ASSERT_EQ(1, num_outputs);

  TF_Tensor* result = TFE_TensorHandleResolve(outputs[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(outputs[0]);
  // Reducing the [[1,2],[3,4]] test matrix yields the minima {1, 3}.
  float mins[2] = {0};
  EXPECT_EQ(sizeof(mins), TF_TensorByteSize(result));
  memcpy(&mins[0], TF_TensorData(result), TF_TensorByteSize(result));
  TF_DeleteTensor(result);
  EXPECT_EQ(1, mins[0]);
  EXPECT_EQ(3, mins[1]);
  TFE_DeleteContext(context);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
918
// Runs a MatMul with run-metadata collection enabled and checks that
// TFE_ContextExportRunMetadata yields a parseable RunMetadata proto, in both
// sync and async modes. Also verifies tracing does not change the result.
void ExecuteWithTracing(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_ContextEnableRunMetadata(ctx);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = MatMulOp(ctx, m, m);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  // Export the collected metadata and make sure it round-trips as a
  // RunMetadata proto.
  TF_Buffer* b = TF_NewBuffer();
  TFE_ContextExportRunMetadata(ctx, b, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  tensorflow::RunMetadata rm;
  EXPECT_TRUE(
      rm.ParseFromString({reinterpret_cast<const char*>(b->data), b->length}));
  TF_DeleteBuffer(b);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  // The numeric result must match the untraced MatMul: [[7,10],[15,22]].
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TF_DeleteStatus(status);
}
TEST(CAPI, ExecuteWithTracing) {
  ExecuteWithTracing(/*async=*/false);
}
TEST(CAPI, ExecuteWithTracingAsync) {
  ExecuteWithTracing(/*async=*/true);
}
962
// Test-only ops used to verify error-code propagation. Both are backed by a
// kernel that always fails with Unavailable; only the second is marked as a
// distributed-communication op.
REGISTER_OP("TestNonCommUnavailable")
    .Output("out: string")
    .Doc(R"doc(Test non-communication op throwing Unavailable error.)doc");

REGISTER_OP("TestCommUnavailable")
    .Output("out: string")
    .SetIsDistributedCommunication()
    .Doc(R"doc(Test communication op throwing Unavailable error.)doc");
971
972 // Kernel that throws an Unavailable error.
class TestUnavailableErrorOp : public tensorflow::OpKernel {
 public:
  explicit TestUnavailableErrorOp(tensorflow::OpKernelConstruction* ctx)
      : tensorflow::OpKernel(ctx) {}
  // Always fails with Unavailable; never produces an output.
  void Compute(tensorflow::OpKernelContext* ctx) override {
    ctx->SetStatus(tensorflow::errors::Unavailable("Test error."));
  }
};
// Register the always-failing kernel for both test ops on the default device.
REGISTER_KERNEL_BUILDER(
    Name("TestNonCommUnavailable").Device(tensorflow::DEVICE_DEFAULT),
    TestUnavailableErrorOp);
REGISTER_KERNEL_BUILDER(
    Name("TestCommUnavailable").Device(tensorflow::DEVICE_DEFAULT),
    TestUnavailableErrorOp);
987
FunctionWithErrorOp(const tensorflow::StringPiece op_name)988 string FunctionWithErrorOp(const tensorflow::StringPiece op_name) {
989 const std::string& func_str =
990 " signature {"
991 " name: 'FunctionWith__OP_NAME__'"
992 " output_arg {"
993 " name: 'out'"
994 " type: DT_STRING"
995 " }"
996 " }"
997 " node_def {"
998 " name: 'error_op'"
999 " op: '__OP_NAME__'"
1000 " }"
1001 " ret {"
1002 " key: 'out'"
1003 " value: 'error_op:out'"
1004 " }";
1005 tensorflow::FunctionDef def;
1006 CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
1007 tensorflow::str_util::StringReplace(func_str, "__OP_NAME__", op_name,
1008 /*replace_all=*/true),
1009 &def));
1010 return def.SerializeAsString();
1011 }
1012
// Executes the two error-raising test ops both directly and wrapped in
// functions, verifying the error-code mapping: a non-communication op
// reporting Unavailable is surfaced as Internal, while an op marked
// SetIsDistributedCommunication propagates Unavailable unchanged.
TEST(CAPI, ExecuteOpAndFunctionWithError) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*async=*/false));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_Op* non_comm_op = TFE_NewOp(ctx, "TestNonCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* retval[1] = {};
  int num_retvals = 1;
  TFE_Execute(non_comm_op, retval, &num_retvals, status);
  EXPECT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(non_comm_op);

  TFE_Op* comm_op = TFE_NewOp(ctx, "TestCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(comm_op, retval, &num_retvals, status);
  EXPECT_EQ(TF_UNAVAILABLE, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(comm_op);

  // Hold the serialized FunctionDefs by value instead of binding a const
  // reference to a temporary (needless lifetime-extension idiom).
  const string fdef1 = FunctionWithErrorOp("TestNonCommUnavailable");
  TFE_ContextAddFunctionDef(ctx, fdef1.data(), fdef1.size(), status);
  TFE_Op* fn1 = TFE_NewOp(ctx, "FunctionWithTestNonCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(fn1, retval, &num_retvals, status);
  EXPECT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(fn1);

  const string fdef2 = FunctionWithErrorOp("TestCommUnavailable");
  TFE_ContextAddFunctionDef(ctx, fdef2.data(), fdef2.size(), status);
  TFE_Op* fn2 = TFE_NewOp(ctx, "FunctionWithTestCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(fn2, retval, &num_retvals, status);
  EXPECT_EQ(TF_UNAVAILABLE, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(fn2);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
1054
MatMulFunction()1055 string MatMulFunction() {
1056 tensorflow::FunctionDef def;
1057 CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
1058 " signature {"
1059 " name: 'MatMulFunction'"
1060 " input_arg {"
1061 " name: 'a'"
1062 " type: DT_FLOAT"
1063 " }"
1064 " output_arg {"
1065 " name: 'm'"
1066 " type: DT_FLOAT"
1067 " }"
1068 " }"
1069 " node_def {"
1070 " name: 'matmul'"
1071 " op: 'MatMul'"
1072 " input: 'a'"
1073 " input: 'a'"
1074 " attr {"
1075 " key: 'T'"
1076 " value {"
1077 " type: DT_FLOAT"
1078 " }"
1079 " }"
1080 " }"
1081 " ret {"
1082 " key: 'm'"
1083 " value: 'matmul:product'"
1084 " }",
1085 &def));
1086 return def.SerializeAsString();
1087 }
1088
1089 // a + a
AddFunction()1090 string AddFunction() {
1091 tensorflow::FunctionDef def;
1092 CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
1093 " signature {"
1094 " name: 'AddFunction'"
1095 " input_arg {"
1096 " name: 'a'"
1097 " type: DT_FLOAT"
1098 " }"
1099 " output_arg {"
1100 " name: 'o'"
1101 " type: DT_FLOAT"
1102 " }"
1103 " }"
1104 " node_def {"
1105 " name: 'output'"
1106 " op: 'Add'"
1107 " input: 'a'"
1108 " input: 'a'"
1109 " attr {"
1110 " key: 'T'"
1111 " value {"
1112 " type: DT_FLOAT"
1113 " }"
1114 " }"
1115 " }"
1116 " ret {"
1117 " key: 'o'"
1118 " value: 'output:z'"
1119 " }",
1120 &def));
1121 return def.SerializeAsString();
1122 }
1123
// Registers MatMulFunction with the context and calls it several times,
// toggling TFE_ContextClearCaches between runs so that both the cached and
// the freshly-instantiated function paths are exercised.
void FunctionDefAndExecute(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  string function_def = MatMulFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  for (bool clear_cache : {true, false, true}) {
    if (clear_cache) {
      TFE_ContextClearCaches(ctx);
    }
    TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
    TFE_TensorHandle* retval[1] = {nullptr};
    int num_retvals = 1;
    TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(op, m, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &retval[0], &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(1, num_retvals);
    TFE_DeleteOp(op);
    TFE_DeleteTensorHandle(m);
    TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
    TFE_DeleteTensorHandle(retval[0]);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // MatMulFunction computes m x m; for [[1,2],[3,4]] that is [[7,10],[15,22]].
    float product[4] = {0};
    EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
    memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
    EXPECT_EQ(7, product[0]);
    EXPECT_EQ(10, product[1]);
    EXPECT_EQ(15, product[2]);
    EXPECT_EQ(22, product[3]);
  }
  TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContext(ctx);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
TEST(CAPI, FunctionDefAndExecute) {
  FunctionDefAndExecute(/*async=*/false);
}
TEST(CAPI, FunctionDefAndExecuteAsync) {
  FunctionDefAndExecute(/*async=*/true);
}
1173
// Executes the AddFunction (o = a + a), optionally on TFRT and optionally
// with grappler graph rewrites enabled, and checks the doubled matrix. On
// TFRT debug builds it additionally inspects the executed op names to verify
// whether grappler rewrote tf.Add into tf.Mul.
void RunAddFunction(bool use_tfrt, bool enable_grappler) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  string function_def = AddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* retval[1] = {nullptr};
  int num_retvals = 1;
  TFE_Op* op = TFE_NewOp(ctx, "AddFunction", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Add a config_proto attr, to trigger grappler graph rewrites in the current
  // eager runtime.
  if (enable_grappler) {
    tensorflow::ConfigProto config;
    // Do not skip grappler optimization even for small graphs.
    config.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_min_graph_nodes(-1);
    string serialized_config;
    ASSERT_TRUE(config.SerializeToString(&serialized_config));
    TFE_OpSetAttrString(
        op, "config_proto",
        reinterpret_cast<const void*>(serialized_config.c_str()),
        serialized_config.length());
  }

  if (use_tfrt) {
    // Set some test-only graph compiler options.
    TFE_OpSetAttrBool(op, "TFRT_TEST_enable_native_ops", false);
    TFE_OpSetAttrBool(op, "TFRT_TEST_enable_grappler", enable_grappler);
  }

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_OpAddInput(op, m, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(op, &retval[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TFE_DeleteOp(op);
  TFE_DeleteTensorHandle(m);
  TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
  TFE_DeleteTensorHandle(retval[0]);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // The [[1,2],[3,4]] test matrix added to itself doubles every element.
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(2, product[0]);
  EXPECT_EQ(4, product[1]);
  EXPECT_EQ(6, product[2]);
  EXPECT_EQ(8, product[3]);

  // When we turn on grappler, confirm that the tf.Add has been rewritten into a
  // tf.Mul.
  // This capability of checking the executed op names is currently only enabled
  // for TFRT debug build, for performance and simplicity reasons.
  if (use_tfrt) {
    TF_Buffer* buf = TF_NewBuffer();
    TFE_GetExecutedOpNames(ctx, buf, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
#ifndef NDEBUG
    if (enable_grappler)
      EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Mul"), nullptr);
    else
      EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Add"), nullptr);
#endif
    TF_DeleteBuffer(buf);
  }

  TFE_ContextRemoveFunction(ctx, "AddFunction", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContext(ctx);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
1259
// Grappler rewrites should also apply on the classic (non-TFRT) runtime.
TEST(CAPI, RunAddFunctionWithGrappler) {
  RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true);
}
1263
1264 #ifdef PLATFORM_GOOGLE
// TFRT variants of RunAddFunction, with and without grappler rewriting.
TEST(CAPI, RunAddFunction_TFRT) {
  RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false);
}

TEST(CAPI, RunAddFunctionWithGrappler_TFRT) {
  RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true);
}
1272 #endif
1273
// Benchmarks repeated execution of MatMulFunction. The benchmark Arg selects
// sync (0) vs async (1) execution; see the BENCHMARK registration below.
void BM_ExecuteFunction(::testing::benchmark::State& state) {
  const int async = state.range(0);
  state.SetLabel(async ? "ExecuteFunctionAsync" : "ExecuteFunction");
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  string function_def = MatMulFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* retval[1] = {nullptr};
  int num_retvals = 1;
  for (auto s : state) {
    TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(matmul, m, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // NOTE(review): each iteration overwrites retval[0] without deleting the
    // previous handle; only the last one is freed below — possible handle
    // leak across iterations. TODO: confirm and fix if unintended.
    TFE_Execute(matmul, &retval[0], &num_retvals, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteOp(matmul);
    // On the final iteration in async mode, drain the executor so queued
    // work is included before teardown.
    if (state.iterations() >= state.max_iterations && async) {
      TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
      TFE_ExecutorWaitForAllPendingNodes(executor, status);
      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteExecutor(executor);
    }
  }
  TFE_DeleteTensorHandle(m);
  TFE_DeleteTensorHandle(retval[0]);
  TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContext(ctx);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
1316
TEST(CAPI, Variables) {
  // Variables use resource handles, so this is really a test for resource
  // tensor handling.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Create a variable initialized to 12.0.
  TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
  TFE_OpAddInput(op, var_handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  int num_retvals = 1;
  TFE_TensorHandle* value_handle = nullptr;
  TFE_Execute(op, &value_handle, &num_retvals, status);
  TFE_DeleteOp(op);

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  // Reading the variable must yield a float scalar holding the initial value.
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
  EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status));
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float value = 0.0f;
  TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
  memcpy(&value, TF_TensorData(t), sizeof(float));
  TF_DeleteTensor(t);
  EXPECT_EQ(12.0, value);

  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteTensorHandle(value_handle);
  TFE_DeleteContext(ctx);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
1358
BM_ReadVariable(::testing::benchmark::State & state)1359 void BM_ReadVariable(::testing::benchmark::State& state) {
1360 TF_Status* status = TF_NewStatus();
1361 TFE_ContextOptions* opts = TFE_NewContextOptions();
1362 TFE_Context* ctx = TFE_NewContext(opts, status);
1363 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1364 TFE_DeleteContextOptions(opts);
1365
1366 TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
1367 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1368
1369 int num_retvals = 1;
1370 TFE_TensorHandle* h = nullptr;
1371 for (auto s : state) {
1372 TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
1373 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1374 TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
1375 TFE_OpAddInput(op, var_handle, status);
1376 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1377 TFE_Execute(op, &h, &num_retvals, status);
1378 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1379 CHECK_EQ(1, num_retvals);
1380 CHECK(h);
1381 CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
1382 CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
1383 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1384 h = nullptr;
1385 TFE_DeleteOp(op);
1386 }
1387
1388 TFE_DeleteTensorHandle(var_handle);
1389 TFE_DeleteContext(ctx);
1390 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1391 TF_DeleteStatus(status);
1392 }
1393 BENCHMARK(BM_ReadVariable);
1394
TEST(CAPI, StringAttributes) {
  // Test that TFE_OpSetAttrString doesn't hold on to the value after it
  // returns.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Build a 1x1x1x1 float input for AvgPool.
  std::vector<int64_t> dims(4, 1);
  TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* tensor =
      TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
  float tensor_data[] = {1};
  memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, tensor_handle, status);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(tensor_handle);

  // ksize = strides = [1, 1, 1, 1].
  std::vector<int64_t> values(4, 1);
  TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
  TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());

  const int BUFFER_SIZE = 10;
  char buffer[BUFFER_SIZE];
  std::strncpy(buffer, "VALID", BUFFER_SIZE);
  TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
  // Overwriting value in "buffer", should be fine since TFE_Op
  // shouldn't be holding on to it.
  std::strncpy(buffer, "NHWC", BUFFER_SIZE);
  TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));

  TFE_OpSetAttrType(op, "T", TF_FLOAT);

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  // Output is a single float (4 bytes): 1x1 pooling over a 1x1x1x1 input.
  tensor = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(4, TF_TensorByteSize(tensor));
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(op);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
1452
1453 // Same test as above, expect use SetOpAttrValueScalar to set attrs.
TEST(CAPI, TestTFE_SetOpAttrs) {
  // Sets every AvgPool attribute through SetOpAttrValueScalar with AttrValue
  // protos, instead of the typed TFE_OpSetAttr* setters used above, and
  // expects the same result.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Build a 1x1x1x1 float input for AvgPool.
  std::vector<int64_t> dims(4, 1);
  TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* tensor =
      TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
  float tensor_data[] = {1};
  memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, tensor_handle, status);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(tensor_handle);

  // ksize = strides = [1, 1, 1, 1], expressed as an AttrValue int list.
  tensorflow::AttrValue i_list_values;
  for (int i = 0; i < 4; ++i) {
    i_list_values.mutable_list()->add_i(1);
  }
  SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
  SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);

  tensorflow::AttrValue padding_value;
  *padding_value.mutable_s() = "VALID";
  tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);

  tensorflow::AttrValue data_format_value;
  *data_format_value.mutable_s() = "NHWC";
  tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
                                   status);

  TFE_OpSetAttrType(op, "T", TF_FLOAT);

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  // Single float output (4 bytes), as in StringAttributes.
  tensor = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(4, TF_TensorByteSize(tensor));
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(op);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
1514
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));

  // Make a second handle that shares the underlying tensor with `h`.
  TFE_TensorHandle* h_shares_tensor =
      TFE_TensorHandleCopySharingTensor(h, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Resolving the sharing copy must yield the original contents {1,2,3,4}.
  TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get());
  ASSERT_EQ(16, TF_TensorByteSize(t));
  float data[4] = {0};
  memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
  EXPECT_EQ(1.0, data[0]);
  EXPECT_EQ(2.0, data[1]);
  EXPECT_EQ(3.0, data[2]);
  EXPECT_EQ(4.0, data[3]);
  TF_DeleteTensor(t);

  // Both handles can be deleted independently.
  TFE_DeleteTensorHandle(h);
  TFE_DeleteTensorHandle(h_shares_tensor);
  TFE_DeleteContext(ctx);
}
1544
ExtractAttrs(TFE_Op * op)1545 tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
1546 tensorflow::AttrValueMap attr_values;
1547 tensorflow::EagerOperation* operation =
1548 tensorflow::OperationFromInterface(tensorflow::unwrap(op));
1549 operation->Attrs().FillAttrValueMap(&attr_values);
1550 return attr_values;
1551 }
1552
// Attribute inference: adding inputs to a Min op without explicitly setting
// "T"/"Tidx" should infer those type attrs from the input handles.
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
  TFE_Op* minOp = TFE_NewOp(ctx, "Min", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(minOp, input, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(minOp, axis, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // "T" should be inferred as float from `input`, "Tidx" as int32 from `axis`.
  tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
  attr_found = attr_values.find("Tidx");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_INT32);

  // The op must execute successfully with only inferred attrs.
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(minOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(minOp);
  TFE_DeleteTensorHandle(input);
  TFE_DeleteTensorHandle(axis);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
}
1589
// Verifies attribute inference for a homogeneous input list: adding a
// two-element list to "Concat" infers both the element type attribute
// "T" (DT_FLOAT) and the list-length attribute "N" (2).
TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
  TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  // "Concat" takes the concat dimension first, then the value list.
  TFE_OpAddInput(concatOp, dim, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInputList(concatOp, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
  attr_found = attr_values.find("N");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.i(), 2);

  // Execution must succeed with only the inferred attributes.
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(concatOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(concatOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteTensorHandle(dim);
  TFE_DeleteContext(ctx);
}
1629
// Verifies attribute inference for a heterogeneous input list: "Assert"
// accepts a list of mixed types, and the inferred list attribute "T" must
// record each element's dtype in order (bool, float, int32).
TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* condition = TestScalarTensorHandle(ctx, true);
  TFE_TensorHandle* t1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* t2 = TestAxisTensorHandle(ctx);
  TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(assertOp, condition, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // The data list deliberately mixes dtypes: bool, float matrix, int32 axis.
  TFE_TensorHandle* data[] = {condition, t1, t2};
  TFE_OpAddInputList(assertOp, data, 3, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
  EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
  EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(assertOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(assertOp);
  TFE_DeleteTensorHandle(condition);
  TFE_DeleteTensorHandle(t1);
  TFE_DeleteTensorHandle(t2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
}
1668
// Verifies that attribute inference is abandoned when a list-typed argument
// is fed one element at a time via TFE_OpAddInput instead of
// TFE_OpAddInputList: the op's inference context (its OpDef pointer) is
// dropped after the mismatched call, and neither "T" nor "N" is inferred.
TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
  TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInput(concatOp, dim, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // Inference context is still alive after the scalar "concat_dim" input.
  CHECK(tensorflow::unwrap(concatOp)->OpDef());
  // Adding a list element as a single input disables further inference.
  TFE_OpAddInput(concatOp, inputs[0], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
      << "Inference context is still present";
  TFE_OpAddInput(concatOp, inputs[1], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // With inference disabled, the list attributes were never filled in.
  tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
  EXPECT_EQ(attr_values.find("T"), attr_values.end());
  EXPECT_EQ(attr_values.find("N"), attr_values.end());

  TF_DeleteStatus(status);
  TFE_DeleteOp(concatOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(dim);
  TFE_DeleteContext(ctx);
}
1703
// Verifies TFE_OpGetInputLength/TFE_OpGetOutputLength on "IdentityN":
// they fail (returning -1) before any inputs establish the list length,
// succeed (returning 2) once the input list is added, and keep working
// after the op has executed.
TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Try to retrieve lengths before building the attributes (should fail)
  EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
  CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
  CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInputList(identityOp, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Try to retrieve lengths before executing the op (should work)
  EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[2] = {nullptr};
  int num_retvals = 2;
  TFE_Execute(identityOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Try to retrieve lengths after executing the op (should work)
  EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(identityOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteTensorHandle(retvals[1]);
  TFE_DeleteContext(ctx);
}
1751
// Verifies that querying input/output lengths for an argument name that is
// not part of the op's signature returns -1 and sets TF_INVALID_ARGUMENT.
TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInputList(identityOp, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // "cheese" is not an argument of IdentityN; both queries must fail.
  EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "cheese", status));
  CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "cheese", status));
  CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(identityOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteContext(ctx);
}
1778
TestOpAddAttrs(bool use_tfrt)1779 void TestOpAddAttrs(bool use_tfrt) {
1780 TF_Status* status = TF_NewStatus();
1781 TFE_ContextOptions* opts = TFE_NewContextOptions();
1782 TFE_ContextOptionsSetTfrt(opts, use_tfrt);
1783 TFE_Context* ctx = TFE_NewContext(opts, status);
1784 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1785 TFE_DeleteContextOptions(opts);
1786
1787 TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
1788 TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
1789 TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
1790 const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
1791
1792 TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
1793 TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
1794 TFE_OpAddAttrs(copy_op, attributes);
1795 unsigned char is_list = 0;
1796 ASSERT_EQ(TF_ATTR_TYPE,
1797 TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
1798 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1799 ASSERT_EQ(TF_ATTR_SHAPE,
1800 TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
1801 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1802
1803 tensorflow::AttrValueMap attr_values;
1804 if (use_tfrt) {
1805 #ifdef PLATFORM_GOOGLE
1806 auto* op = tensorflow::down_cast<tfrt::tf::OperationInterface*>(
1807 tensorflow::unwrap(copy_op));
1808 auto* tfrt_op_attrs =
1809 tensorflow::down_cast<const tfrt::tf::OpAttrsInterface*>(
1810 op->GetOpAttrs());
1811 tensorflow::DataType result;
1812 tfrt_op_attrs->GetType("dtype", &result);
1813 EXPECT_EQ(tensorflow::DT_FLOAT, result);
1814 tfrt_op_attrs->GetFallbackAttrs()->FillAttrValueMap(&attr_values);
1815 #endif
1816 } else {
1817 tensorflow::EagerOperation* op =
1818 tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
1819 op->Attrs().FillAttrValueMap(&attr_values);
1820 }
1821 EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
1822
1823 TF_DeleteStatus(status);
1824 TFE_DeleteOp(var_op);
1825 TFE_DeleteOp(copy_op);
1826 TFE_DeleteContext(ctx);
1827 }
1828
// Runs the attribute-copying test against the default (non-TFRT) runtime.
TEST(CAPI, TestTFE_OpAddAttrs) { TestOpAddAttrs(/*use_tfrt=*/false); }
1830
#ifdef PLATFORM_GOOGLE
// Runs the same attribute-copying test against TFRT; this variant is only
// compiled in PLATFORM_GOOGLE builds (see the #ifdef inside TestOpAddAttrs).
TEST(CAPI, TestTFE_OpAddAttrs_TFRT) { TestOpAddAttrs(/*use_tfrt=*/true); }

#endif
1835
// Round-trips op attributes through serialization: TFE_OpAttrsSerialize must
// produce a parseable NameAttrList carrying the op name and the "dtype"
// attribute, and TFE_OpSetAttrValueProto must accept a serialized AttrValue
// and install it on a fresh op.
//
// Fix: the original left `status` unchecked after TFE_OpSetAttrShape and
// after creating `var_op_2`; failures there would be silently ignored.
// CHECKs added to match the rest of the file.
TEST(CAPI, TestTFE_OpAttrsSerialize) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);

  // Serialize and parse back as a NameAttrList proto.
  TF_Buffer* serialized_attr_values = TF_NewBuffer();
  TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  tensorflow::NameAttrList name_and_attrs;
  ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
                                            serialized_attr_values->length));
  ASSERT_EQ("VarHandleOp", name_and_attrs.name());
  ASSERT_EQ(tensorflow::DT_INT64,
            name_and_attrs.attr().find("dtype")->second.type());
  TF_DeleteBuffer(serialized_attr_values);

  TFE_Op* var_op_2 = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Re-install the parsed "dtype" AttrValue onto the second op via the
  // proto-based setter, then verify it landed in the op's attribute map.
  string serialized_dtype;
  ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
      &serialized_dtype));
  TFE_OpSetAttrValueProto(
      var_op_2, "dtype",
      reinterpret_cast<const void*>(serialized_dtype.c_str()),
      serialized_dtype.length(), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  tensorflow::EagerOperation* op =
      tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
  op->Attrs().FillAttrValueMap(&attr_values);
  EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());

  TF_DeleteStatus(status);
  TFE_DeleteOp(var_op);
  TFE_DeleteOp(var_op_2);
  TFE_DeleteContext(ctx);
}
1882
// Needs to work with a const TFE_Op since custom devices should not modify the
// op they are called with.
//
// Builds a fresh op mirroring `other`: same context, op name, device,
// attributes, and flat inputs (added in order). CHECK-fails on any C API
// error, so it never returns a partially-built op.
TFE_Op* CloneOp(const TFE_Op* other) {
  TF_Status* status = TF_NewStatus();
  TFE_Context* context = TFE_OpGetContext(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const char* op_name = TFE_OpGetName(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* ret = TFE_NewOp(context, op_name, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const char* device = TFE_OpGetDevice(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetDevice(ret, device, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
  // Copy every input by flat index, preserving order.
  int num_inputs = TFE_OpGetFlatInputCount(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  for (int input_index = 0; input_index < num_inputs; ++input_index) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(ret, input, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  TF_DeleteStatus(status);
  return ret;
}
1909
// Verifies CloneOp (and the getter APIs it relies on) in two scenarios:
// an op with attributes and an explicit device, and an op with inputs and
// no device. Clones must carry over name, device, attributes, and inputs,
// and must execute successfully.
TEST(CAPI, TestTFE_OpRecreation) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Clone an op with attributes and a device set.
  TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // Device is empty until explicitly set below.
  EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetDevice(original_var_op,
                  "/job:localhost/replica:0/task:0/device:CPU:0", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* cloned = CloneOp(original_var_op);

  EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
            std::string(TFE_OpGetDevice(cloned, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // The clone must be executable (its attributes were carried over).
  int num_retvals = 1;
  TFE_TensorHandle* ret;
  TFE_Execute(cloned, &ret, &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(ret);

  // Clone an op with inputs and no device set.
  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInputList(original_identity, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* cloned_identity = CloneOp(original_identity);
  EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
  TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
  num_retvals = 2;
  TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(identity_ret[0]);
  TFE_DeleteTensorHandle(identity_ret[1]);

  TFE_DeleteOp(cloned_identity);
  TFE_DeleteOp(original_identity);
  TFE_DeleteOp(original_var_op);
  TFE_DeleteOp(cloned);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}
1969
// Returns a copy of `server_def` with task `task_index` of the first job
// reassigned to a freshly picked unused local port.
tensorflow::ServerDef ReplaceTaskInServerDef(
    const tensorflow::ServerDef& server_def, int task_index) {
  tensorflow::ServerDef updated = server_def;
  const int new_port = tensorflow::testing::PickUnusedPortOrDie();
  updated.mutable_cluster()->mutable_job(0)->mutable_tasks()->at(task_index) =
      tensorflow::strings::StrCat("localhost:", new_port);
  return updated;
}
1980
CreateVarHandle(TFE_Context * ctx,const tensorflow::string & device_name,const tensorflow::string & variable_name)1981 TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx,
1982 const tensorflow::string& device_name,
1983 const tensorflow::string& variable_name) {
1984 TF_Status* status = TF_NewStatus();
1985 // Create the variable handle.
1986 TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
1987 if (TF_GetCode(status) != TF_OK) return nullptr;
1988 TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
1989 TFE_OpSetAttrShape(op, "shape", {}, 0, status);
1990 TFE_OpSetAttrString(op, "container", "localhost", 0);
1991 TFE_OpSetAttrString(op, "shared_name", variable_name.data(),
1992 variable_name.size());
1993 if (!device_name.empty()) {
1994 TFE_OpSetDevice(op, device_name.c_str(), status);
1995 }
1996 if (TF_GetCode(status) != TF_OK) return nullptr;
1997 TFE_TensorHandle* var_handle = nullptr;
1998 int num_retvals = 1;
1999 TFE_Execute(op, &var_handle, &num_retvals, status);
2000 if (TF_GetCode(status) != TF_OK) return nullptr;
2001 TFE_DeleteOp(op);
2002 if (TF_GetCode(status) != TF_OK) return nullptr;
2003 CHECK_EQ(1, num_retvals);
2004 TF_DeleteStatus(status);
2005 return var_handle;
2006 }
2007
CreateVariable(TFE_Context * ctx,float value,const tensorflow::string & device_name,const tensorflow::string & variable_name)2008 TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
2009 const tensorflow::string& device_name,
2010 const tensorflow::string& variable_name) {
2011 TF_Status* status = TF_NewStatus();
2012 TFE_TensorHandle* var_handle =
2013 CreateVarHandle(ctx, device_name, variable_name);
2014
2015 // Assign 'value' to it.
2016 TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status);
2017 if (TF_GetCode(status) != TF_OK) return nullptr;
2018 TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
2019 TFE_OpAddInput(op, var_handle, status);
2020 if (!device_name.empty()) {
2021 TFE_OpSetDevice(op, device_name.c_str(), status);
2022 }
2023
2024 // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
2025 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
2026 TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
2027 memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
2028
2029 std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
2030 value_handle(TFE_NewTensorHandle(t.get(), status),
2031 TFE_DeleteTensorHandle);
2032 if (TF_GetCode(status) != TF_OK) return nullptr;
2033
2034 TFE_OpAddInput(op, value_handle.get(), status);
2035 if (TF_GetCode(status) != TF_OK) return nullptr;
2036
2037 int num_retvals = 0;
2038 TFE_Execute(op, nullptr, &num_retvals, status);
2039 TFE_DeleteOp(op);
2040 if (TF_GetCode(status) != TF_OK) return nullptr;
2041 CHECK_EQ(0, num_retvals);
2042 TF_DeleteStatus(status);
2043 return var_handle;
2044 }
2045
CreateContext(const string & serialized_server_def,bool isolate_session_state)2046 TFE_Context* CreateContext(const string& serialized_server_def,
2047 bool isolate_session_state) {
2048 TF_Status* status = TF_NewStatus();
2049 TFE_ContextOptions* opts = TFE_NewContextOptions();
2050 opts->session_options.options.config.set_isolate_session_state(
2051 isolate_session_state);
2052 TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(false));
2053 TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
2054 TFE_Context* ctx = TFE_NewContext(opts, status);
2055 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
2056 TFE_ContextSetServerDef(ctx, 0, serialized_server_def.data(),
2057 serialized_server_def.size(), status);
2058 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
2059 TFE_DeleteContextOptions(opts);
2060 TF_DeleteStatus(status);
2061 return ctx;
2062 }
2063
ListDeviceNames(TFE_Context * ctx)2064 std::vector<std::string> ListDeviceNames(TFE_Context* ctx) {
2065 TF_Status* status = TF_NewStatus();
2066 std::vector<std::string> device_names;
2067 TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
2068 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
2069 const int num_devices = TF_DeviceListCount(devices);
2070 for (int i = 0; i < num_devices; ++i) {
2071 device_names.emplace_back(TF_DeviceListName(devices, i, status));
2072 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
2073 }
2074 TF_DeleteDeviceList(devices);
2075 TF_DeleteStatus(status);
2076 return device_names;
2077 }
2078
// Verifies that with isolate_session_state=false, a resource variable
// created through one context is visible to (and readable from) a second
// context that shares the same cluster.
TEST(CAPI, ShareVariableAcrossContextsWorks) {
  // TODO(shreepadma): Add a test case with isolate_session_state set to true.
  tensorflow::ServerDef server_def_0 = GetServerDef(3);
  server_def_0.mutable_default_session_config()->set_isolate_session_state(
      false);
  // server_def_1 is the same cluster with task 0 moved to a fresh port, so
  // the two contexts below are distinct clients of the same workers.
  tensorflow::ServerDef server_def_1 =
      ReplaceTaskInServerDef(server_def_0, /*task_index=*/0);

  // These server defs have task index set to 0.
  string serialized_server_def_0 = server_def_0.SerializeAsString();
  string serialized_server_def_1 = server_def_1.SerializeAsString();

  // Create two worker tasks.
  server_def_0.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def_0.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  TFE_Context* ctx_0 = CreateContext(serialized_server_def_0,
                                     /*isolate_session_state=*/false);
  TFE_Context* ctx_1 = CreateContext(serialized_server_def_1,
                                     /*isolate_session_state=*/false);

  // Remote device on `worker1`.
  const char remote_device[] = "/job:localhost/replica:0/task:1/device:CPU:0";
  // Both `ctx_0` and `ctx_1` must see `remote_device`.
  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_0);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_1);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  // Create a variable using `ctx_0`.
  // Read the variable using `ctx_1`. This read should succeed.
  // 1. Create a variable on `remote_device`, using `ctx_0`.
  TFE_TensorHandle* handle_0 =
      CreateVariable(ctx_0, 1.2, remote_device, /*variable_name=*/"var2");

  // 2. Wait for `var2` to be created and initialized on the worker.
  TF_Status* status = TF_NewStatus();
  TFE_ContextAsyncWait(ctx_0, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  // 3. Read `var_2` using `ctx_1`. This read should succeed since `ctx_1` was
  // created with `isolate_session_state` set to false.
  {
    // Create a handle to `var2`, using `ctx_1`.
    TFE_TensorHandle* var_handle =
        CreateVarHandle(ctx_1, remote_device, /*variable_name=*/"var2");

    TFE_TensorHandle* handle_1 = nullptr;
    int num_retvals = 1;
    TF_Status* status = TF_NewStatus();
    TFE_Op* op = TFE_NewOp(ctx_1, "ReadVariableOp", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
    TFE_OpAddInput(op, var_handle, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &handle_1, &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteOp(op);

    ASSERT_EQ(1, num_retvals);
    EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
    EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Read the value of tensor handle `handle_1`.
    float value = 0.0f;
    TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
    memcpy(&value, TF_TensorData(t), sizeof(float));
    TF_DeleteTensor(t);
    EXPECT_EQ(1.2f, value);
    TFE_DeleteTensorHandle(handle_1);
    TF_DeleteStatus(status);
    TFE_DeleteTensorHandle(var_handle);
  }

  TFE_DeleteTensorHandle(handle_0);

  TFE_DeleteContext(ctx_0);
  TFE_DeleteContext(ctx_1);

  // NOTE(review): the GrpcServers are release()d, i.e. deliberately leaked,
  // rather than destroyed — presumably because orderly server shutdown is not
  // supported here; confirm before "fixing".
  worker_server1.release();
  worker_server2.release();
}
2182
ReplaceTaskInServerDef(tensorflow::ServerDef * server_def,int task_index,const string & host,int port)2183 void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index,
2184 const string& host, int port) {
2185 tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
2186 job_def->mutable_tasks()->at(task_index) =
2187 tensorflow::strings::StrCat(host, ":", port);
2188 }
2189
// Verifies that variable sharing (isolate_session_state=false) survives
// TFE_ContextUpdateServerDef: a variable created via one context remains
// readable through a second context even after a worker task is replaced
// and both contexts' server defs are updated.
TEST(CAPI, ShareVariableAcrossContextsAfterUpdateContextWorks) {
  tensorflow::ServerDef server_def_0 = GetServerDef(3);
  server_def_0.mutable_default_session_config()->set_isolate_session_state(
      false);
  tensorflow::ServerDef server_def_1 =
      ReplaceTaskInServerDef(server_def_0, /*task_index=*/0);

  // These server defs have task index set to 0.
  string serialized_server_def_0 = server_def_0.SerializeAsString();
  string serialized_server_def_1 = server_def_1.SerializeAsString();

  // Create two worker tasks.
  server_def_0.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def_0.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  // Create two contexts.
  TFE_Context* ctx_0 = CreateContext(serialized_server_def_0,
                                     /*isolate_session_state=*/false);
  TFE_Context* ctx_1 = CreateContext(serialized_server_def_1,
                                     /*isolate_session_state=*/false);

  // Remote device on `worker2`.
  const char remote_device[] = "/job:localhost/replica:0/task:2/device:CPU:0";
  // `ctx_0`, `ctx_1` contains `remote_device`.
  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_0);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_1);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  // Create a variable using `ctx_0`.
  // Replace worker1 using a new worker, and update the contexts.
  // Read the variable using `ctx_1`. This read should succeed.
  //
  // 1. Create a variable on `remote_device`, using `ctx_0`.
  // Note: the variable lives on task 2, while the task being replaced below
  // is task 1, so the variable itself is not disturbed by the update.
  TFE_TensorHandle* handle_0 =
      CreateVariable(ctx_0, 1.2, remote_device, /*variable_name=*/"var");

  // 2. Wait for `var` to be created and initialized on the worker.
  TF_Status* status = TF_NewStatus();
  TFE_ContextAsyncWait(ctx_0, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  int port = tensorflow::testing::PickUnusedPortOrDie();
  // 3. Replace worker1 with a new worker in server_def_0 and server_def_1.
  ReplaceTaskInServerDef(&server_def_0, /*task_index=*/1, "localhost", port);
  ReplaceTaskInServerDef(&server_def_1, /*task_index=*/1, "localhost", port);
  // 4. Start a new task to replace worker1.
  server_def_0.set_task_index(1);
  worker_server1.release();
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  // 5a. Update `ctx_0` with updated `server_def_0`.
  {
    server_def_0.set_task_index(0);
    string serialized_update = server_def_0.SerializeAsString();
    TF_Status* status = TF_NewStatus();
    TFE_ContextUpdateServerDef(ctx_0, 0, serialized_update.data(),
                               serialized_update.size(), status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TF_DeleteStatus(status);
  }

  // 5b. Update `ctx_1` with updated `server_def_1`.
  {
    server_def_1.set_task_index(0);
    string serialized_update = server_def_1.SerializeAsString();
    TF_Status* status = TF_NewStatus();
    TFE_ContextUpdateServerDef(ctx_1, 0, serialized_update.data(),
                               serialized_update.size(), status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TF_DeleteStatus(status);
  }

  // 6. Read `var` using `ctx_1`. This read should succeed since `ctx_1` was
  // created with `isolate_session_state` set to false, and update should
  // preserve it.
  {
    // Create a handle to `var`, using `ctx_1`.
    TFE_TensorHandle* var_handle =
        CreateVarHandle(ctx_1, remote_device, /*variable_name=*/"var");

    TFE_TensorHandle* handle_1 = nullptr;
    int num_retvals = 1;
    TF_Status* status = TF_NewStatus();
    TFE_Op* op = TFE_NewOp(ctx_1, "ReadVariableOp", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
    TFE_OpAddInput(op, var_handle, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &handle_1, &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteOp(op);

    ASSERT_EQ(1, num_retvals);
    EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
    EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Read the value of tensor handle `handle_1`.
    float value = 0.0f;
    TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
    memcpy(&value, TF_TensorData(t), sizeof(float));
    TF_DeleteTensor(t);
    EXPECT_EQ(1.2f, value);
    TFE_DeleteTensorHandle(handle_1);
    TF_DeleteStatus(status);
    TFE_DeleteTensorHandle(var_handle);
  }

  TFE_DeleteTensorHandle(handle_0);

  TFE_DeleteContext(ctx_0);
  TFE_DeleteContext(ctx_1);

  // NOTE(review): the GrpcServers are release()d, i.e. deliberately leaked,
  // rather than destroyed — presumably because orderly server shutdown is not
  // supported here; confirm before "fixing".
  worker_server1.release();
  worker_server2.release();
}
2330
// Builds a "single host" ServerDef from `cluster_server_def`: a cluster that
// contains a one-task "client" job on a freshly picked local port, plus a
// "worker" job holding only the `host:port` entry found at `task_index` in
// the first job of `cluster_server_def`'s cluster.
//
// The returned ServerDef copies the protocol from `cluster_server_def`, has
// job_name "worker", and task_index 0 (callers override these as needed).
tensorflow::ServerDef CreateSingleHostServerDef(
    const tensorflow::ServerDef& cluster_server_def, int task_index) {
  tensorflow::ServerDef single_host_server_def;
  single_host_server_def.set_job_name("worker");
  single_host_server_def.set_protocol(cluster_server_def.protocol());
  single_host_server_def.set_task_index(0);
  tensorflow::ClusterDef* cluster_def =
      single_host_server_def.mutable_cluster();
  tensorflow::JobDef* job_def = cluster_def->add_job();
  job_def->set_name("client");

  // Add a client task on an unused local port.
  job_def->mutable_tasks()->insert(
      {0, tensorflow::strings::StrCat(
              "localhost:", tensorflow::testing::PickUnusedPortOrDie())});

  tensorflow::JobDef* job_def2 = cluster_def->add_job();
  job_def2->set_name("worker");

  // Copy over the `host:port` at `task_index`. Iterate by const reference to
  // avoid copying every map entry (the original copied a pair<int32, string>
  // per iteration; see clang-tidy performance-for-range-copy).
  for (const auto& task : cluster_server_def.cluster().job(0).tasks()) {
    if (task.first == task_index) {
      job_def2->mutable_tasks()->insert({task.first, task.second});
    }
  }

  return single_host_server_def;
}
2359
GetClusterServerDef(const string & worker_job_name,int num_workers)2360 tensorflow::ServerDef GetClusterServerDef(const string& worker_job_name,
2361 int num_workers) {
2362 tensorflow::ServerDef server_def = GetServerDef(worker_job_name, num_workers);
2363 tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
2364
2365 // Add a client.
2366 tensorflow::JobDef* job_def2 = cluster_def->add_job();
2367 job_def2->set_name("client");
2368 job_def2->mutable_tasks()->insert(
2369 {0, tensorflow::strings::StrCat(
2370 "localhost:", tensorflow::testing::PickUnusedPortOrDie())});
2371 return server_def;
2372 }
2373
// Verifies that a variable created through a context built from a
// single-host ServerDef (client + one remote worker) is visible to a second
// context built from the full cluster ServerDef, because both contexts are
// created with `isolate_session_state` false and therefore share worker
// session state.
TEST(CAPI, SingleHostServerDefWorks) {
  // Create a server def that represents a 2-process cluster and a client.
  // Example:
  //
  //  cluster { job { name: "worker"
  //                  tasks { key: 0 value: "localhost:14522" }
  //                  tasks { key: 1 value: "localhost:14523" }
  //  } job { name: "client"
  //          tasks { key: 0 value: "localhost:14524" }
  //        }
  // } job_name: "worker" protocol: "grpc"
  //
  tensorflow::ServerDef cluster_server_def = GetClusterServerDef("worker", 2);

  // Create two worker tasks, using single host server defs.
  // A single host server def contains a client and the remote host.
  // Example:
  //
  //  Worker1:
  //  cluster { job { name: "client" tasks { key: 0 value: "localhost:14525" } }
  //            job { name: "worker" tasks { key: 1 value: "localhost:14523" } }
  //  } job_name: "worker" task_index: 1 protocol: "grpc"
  //
  //  Worker0:
  //  cluster { job { name: "client" tasks { key: 0 value: "localhost:14526" } }
  //            job { name: "worker" tasks { key: 0 value: "localhost:14522" } }
  //  } job_name: "worker" protocol: "grpc"
  //

  // Create `worker_1` using single host server def `worker_1_server_def`.
  tensorflow::ServerDef worker_1_server_def =
      CreateSingleHostServerDef(cluster_server_def, 1);
  worker_1_server_def.set_task_index(1);
  worker_1_server_def.set_job_name("worker");

  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_1_server_def,
                                             tensorflow::Env::Default(),
                                             &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  // Create context `local_ctx` using single host server def -
  // `worker_1_server_def`. The same ServerDef is reused with the job/task
  // switched to the client entry so the context runs as the client.
  worker_1_server_def.set_task_index(0);
  worker_1_server_def.set_job_name("client");
  TFE_Context* local_ctx =
      CreateContext(worker_1_server_def.SerializeAsString(),
                    /*isolate_session_state=*/false);

  const char remote_device[] = "/job:worker/replica:0/task:1/device:CPU:0";

  // Create a variable `var` on `worker2` using `local_ctx`.
  TFE_TensorHandle* handle_0 =
      CreateVariable(local_ctx, 1.2, remote_device, /*variable_name=*/"var");
  // Block until the remote variable creation/initialization has completed;
  // the handle itself is no longer needed after that.
  TF_Status* status = TF_NewStatus();
  TFE_ContextAsyncWait(local_ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_DeleteTensorHandle(handle_0);

  // Create `worker0` using single host server def `worker_0_server_def`.
  tensorflow::ServerDef worker_0_server_def =
      CreateSingleHostServerDef(cluster_server_def, 0);
  worker_0_server_def.set_task_index(0);

  std::unique_ptr<tensorflow::GrpcServer> worker_server0;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_0_server_def,
                                             tensorflow::Env::Default(),
                                             &worker_server0)
                  .ok());
  ASSERT_TRUE(worker_server0->Start().ok());

  // Create a remote context, `remote_ctx`, using `cluster_server_def`.
  cluster_server_def.set_task_index(0);
  cluster_server_def.set_job_name("client");
  TFE_Context* remote_ctx =
      CreateContext(cluster_server_def.SerializeAsString(),
                    /*isolate_session_state=*/false);

  // Read variable `var` using `remote_ctx`, created using `cluster_server_def`.
  // This succeeds only if the worker session state created via `local_ctx`
  // is shared (both contexts use isolate_session_state=false).
  {
    // Create a handle to `var`.
    TFE_TensorHandle* var_handle =
        CreateVarHandle(remote_ctx, remote_device, /*variable_name=*/"var");

    TFE_TensorHandle* handle_1 = nullptr;
    int num_retvals = 1;
    TF_Status* status = TF_NewStatus();
    TFE_Op* op = TFE_NewOp(remote_ctx, "ReadVariableOp", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
    TFE_OpAddInput(op, var_handle, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &handle_1, &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteOp(op);

    // Expect a single scalar float result.
    ASSERT_EQ(1, num_retvals);
    EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
    EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Read the value of tensor handle `handle_1`; it must match the 1.2
    // written through `local_ctx`.
    float value = 0.0f;
    TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
    memcpy(&value, TF_TensorData(t), sizeof(float));
    TF_DeleteTensor(t);
    EXPECT_EQ(1.2f, value);
    TFE_DeleteTensorHandle(handle_1);
    TF_DeleteStatus(status);
    TFE_DeleteTensorHandle(var_handle);
  }

  TFE_DeleteContext(local_ctx);
  TFE_DeleteContext(remote_ctx);

  // NOTE(review): the servers are intentionally leaked via release() rather
  // than destroyed — presumably because GrpcServer shutdown can hang in
  // tests; confirm against sibling tests in this file using the same pattern.
  worker_server1.release();
  worker_server0.release();
}
2497
2498 } // namespace
2499