1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api.h"
17 
18 #include <string.h>
19 
20 #include <string>
21 
22 // clang-format off
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/platform.h"
26 // clang-format on
27 
28 #include "absl/strings/match.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/c_api_test_util.h"
32 #include "tensorflow/c/eager/tfe_op_internal.h"
33 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
34 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
35 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
36 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
37 #include "tensorflow/core/framework/function.pb.h"
38 #include "tensorflow/core/platform/casts.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/macros.h"
41 #include "tensorflow/core/platform/protobuf.h"
42 #include "tensorflow/core/platform/strcat.h"
43 #include "tensorflow/core/platform/test.h"
44 #include "tensorflow/core/platform/test_benchmark.h"
45 #include "tensorflow/core/protobuf/cluster.pb.h"
46 #include "tensorflow/core/protobuf/config.pb.h"
47 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
48 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
49 
50 #ifdef PLATFORM_GOOGLE
51 #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
52 #endif
53 
54 using tensorflow::string;
55 
56 namespace {
57 
BM_InitOp(::testing::benchmark::State & state)58 void BM_InitOp(::testing::benchmark::State& state) {
59   TF_Status* status = TF_NewStatus();
60   TFE_ContextOptions* opts = TFE_NewContextOptions();
61   TFE_Context* ctx = TFE_NewContext(opts, status);
62   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
63   TFE_DeleteContextOptions(opts);
64 
65   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
66   for (auto s : state) {
67     TFE_Op* matmul = MatMulOp(ctx, m, m);
68     TFE_DeleteOp(matmul);
69   }
70   TFE_DeleteTensorHandle(m);
71   TFE_DeleteContext(ctx);
72   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
73   TF_DeleteStatus(status);
74 }
75 BENCHMARK(BM_InitOp);
76 
// Benchmarks repeated MatMul execution through the eager C API.
// state.range(0) selects sync (0) vs async (1) execution. The op object is
// created once and re-armed with TFE_OpReset every iteration, so op
// construction cost is excluded from the measurement. In async mode the
// per-thread executor is drained on the final iteration so queued work is
// included in the timing.
// NOTE(review): retvals[0] is never released between iterations, so the
// returned handles accumulate for the benchmark's duration.
void BM_Execute(::testing::benchmark::State& state) {
  const int async = state.range(0);
  state.SetLabel(async ? "ExecuteAsync" : "Execute");
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  for (auto s : state) {
    // Re-initialize the existing op rather than allocating a fresh one.
    TFE_OpReset(matmul, "MatMul", nullptr, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(matmul, m, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(matmul, m, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(matmul, &retvals[0], &num_retvals, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    if (state.iterations() >= state.max_iterations && async) {
      // Last iteration in async mode: wait for all pending nodes so the
      // timed region covers the work that was actually enqueued.
      TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
      TFE_ExecutorWaitForAllPendingNodes(executor, status);
      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteExecutor(executor);
    }
  }
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
BENCHMARK(BM_Execute)->Arg(0)->Arg(1);
115 
BM_Execute_Identity(::testing::benchmark::State & state)116 void BM_Execute_Identity(::testing::benchmark::State& state) {
117   const int async = state.range(0);
118   state.SetLabel(async ? "ExecuteIdentityAsync" : "ExecuteIdentity");
119   TF_Status* status = TF_NewStatus();
120   TFE_ContextOptions* opts = TFE_NewContextOptions();
121   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
122   TFE_Context* ctx = TFE_NewContext(opts, status);
123   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
124   TFE_DeleteContextOptions(opts);
125 
126   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
127   TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
128   TFE_TensorHandle* retvals[1];
129   int num_retvals = 1;
130   for (auto s : state) {
131     TFE_OpReset(identity, "Identity", nullptr, status);
132     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
133     TFE_OpAddInput(identity, m, status);
134     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
135     TFE_Execute(identity, &retvals[0], &num_retvals, status);
136     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
137     if (state.iterations() >= state.max_iterations && async) {
138       TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
139       TFE_ExecutorWaitForAllPendingNodes(executor, status);
140       ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
141       TFE_DeleteExecutor(executor);
142     }
143   }
144   TFE_DeleteOp(identity);
145   TFE_DeleteTensorHandle(m);
146   TFE_DeleteContext(ctx);
147   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
148   TF_DeleteStatus(status);
149 }
150 BENCHMARK(BM_Execute_Identity)->Arg(0)->Arg(1);
151 
// Basic context lifecycle: create a context, list its devices, and verify
// the device list remains usable after the context itself is deleted.
TEST(CAPI, Context) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Deliberately delete the context before inspecting the device list below.
  TFE_DeleteContext(ctx);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const int num_devices = TF_DeviceListCount(devices);
  EXPECT_GE(num_devices, 1) << "At least one CPU device should exist";
  for (int i = 0; i < num_devices; ++i) {
    const char* name = TF_DeviceListName(devices, i, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // EXPECT_NE("", char*) would compare pointer values, which is vacuously
    // true; check the string contents instead.
    ASSERT_NE(name, nullptr) << i;
    EXPECT_STRNE("", name) << i;
  }
  TF_DeleteDeviceList(devices);
  TF_DeleteStatus(status);
}
173 
// Creates the 2x2 float test matrix, resolves it to a TF_Tensor, and checks
// dtype, byte size, and element values.
TEST(CAPI, TensorHandle) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));

  TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
  // Fail fast if the resolve failed; `t` would be unusable below. (This
  // status check was previously missing.)
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(16, TF_TensorByteSize(t));
  float data[4] = {0};
  memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
  EXPECT_EQ(1.0, data[0]);
  EXPECT_EQ(2.0, data[1]);
  EXPECT_EQ(3.0, data[2]);
  EXPECT_EQ(4.0, data[3]);
  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(h);
  TFE_DeleteContext(ctx);
}
197 
TensorHandleCopyBetweenDevices(bool async)198 void TensorHandleCopyBetweenDevices(bool async) {
199   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
200       TF_NewStatus(), TF_DeleteStatus);
201   TFE_ContextOptions* opts = TFE_NewContextOptions();
202   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
203   TFE_Context* ctx = TFE_NewContext(opts, status.get());
204   TFE_DeleteContextOptions(opts);
205   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
206 
207   TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
208   TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
209   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
210 
211   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
212   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
213   const int num_devices = TF_DeviceListCount(devices);
214 
215   const char* kCPUDevice = "CPU:0";
216   for (int i = 0; i < num_devices; ++i) {
217     const string name(TF_DeviceListName(devices, i, status.get()));
218     if (TF_GetCode(status.get()) != TF_OK) {
219       ADD_FAILURE() << i << " -- " << TF_Message(status.get());
220       continue;
221     }
222     auto tag = tensorflow::strings::StrCat("Device #", i, " (", name, ")");
223     // Copy to device
224     TFE_TensorHandle* hdevice =
225         TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
226     if (TF_GetCode(status.get()) != TF_OK) {
227       ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
228       continue;
229     }
230     // Copy from device to the same device.
231     TFE_TensorHandle* hdevice2 =
232         TFE_TensorHandleCopyToDevice(hdevice, ctx, name.c_str(), status.get());
233     if (TF_GetCode(status.get()) != TF_OK) {
234       ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
235       continue;
236     }
237     TFE_DeleteTensorHandle(hdevice);
238     // Copy back to CPU
239     TFE_TensorHandle* hcopy =
240         TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
241     if (TF_GetCode(status.get()) != TF_OK) {
242       ADD_FAILURE() << tag << " -- " << TF_Message(status.get());
243       continue;
244     }
245     TFE_DeleteTensorHandle(hdevice2);
246 
247     // Ensure that the contents are the same!
248     TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
249     TFE_DeleteTensorHandle(hcopy);
250     if (TF_GetCode(status.get()) != TF_OK) {
251       ADD_FAILURE() << tag;
252       continue;
253     }
254     EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)) << tag;
255     EXPECT_EQ(
256         0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)))
257         << tag;
258     TF_DeleteTensor(tcopy);
259   }
260 
261   TF_DeleteDeviceList(devices);
262   TF_DeleteTensor(t);
263   TFE_DeleteTensorHandle(hcpu);
264   TFE_DeleteContext(ctx);
265 }
266 
// Synchronous variant.
TEST(CAPI, TensorHandleCopyBetweenDevices) {
  TensorHandleCopyBetweenDevices(false);
}

// Asynchronous-executor variant.
TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) {
  TensorHandleCopyBetweenDevices(true);
}
274 
// Attempts a copy to a nonexistent device, which must fail with a
// recognizable error message, then confirms the context is still usable by
// performing a valid copy to CPU and draining the per-thread executor.
// `async` toggles the async executor.
void TensorHandleCopyBetweenDevicesError(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* kErrorDevice = "NoSuchDevice:0";
  TFE_TensorHandle* hdevice =
      TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
  EXPECT_NE(TF_OK, TF_GetCode(status.get()));
  const char* msg = "NoSuchDevice:0 unknown device";
  EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr)
      << TF_Message(status.get());
  // Clear the error so the subsequent valid copy can be checked cleanly.
  TF_SetStatus(status.get(), TF_OK, "");
  const char* kCPUDevice = "CPU:0";
  TFE_TensorHandle* hcopy =
      TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteTensorHandle(hcopy);
  TFE_DeleteTensorHandle(hcpu);
  // The failed copy may or may not have produced a handle; free it if so.
  if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
  TFE_DeleteContext(ctx);
}
306 
// Synchronous variant.
TEST(CAPI, TensorHandleCopyBetweenDevicesError) {
  TensorHandleCopyBetweenDevicesError(false);
}

// Asynchronous-executor variant.
TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) {
  TensorHandleCopyBetweenDevicesError(true);
}
314 
TensorHandleCopyBetweenTwoGPUDevices(bool async)315 void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
316   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
317       TF_NewStatus(), TF_DeleteStatus);
318   TFE_ContextOptions* opts = TFE_NewContextOptions();
319   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
320   TFE_Context* ctx = TFE_NewContext(opts, status.get());
321   TFE_DeleteContextOptions(opts);
322   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
323 
324   TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
325   TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
326   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
327 
328   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
329   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
330   const int num_devices = TF_DeviceListCount(devices);
331   bool has_gpu0 = false;
332   bool has_gpu1 = false;
333   for (int i = 0; i < num_devices; ++i) {
334     const char* dev = TF_DeviceListName(devices, i, status.get());
335     ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
336     string device_name(dev);
337     if (device_name.find("GPU:0") != string::npos) {
338       has_gpu0 = true;
339     }
340     if (device_name.find("GPU:1") != string::npos) {
341       has_gpu1 = true;
342     }
343   }
344 
345   const char* kCPUDevice = "CPU:0";
346   if (!has_gpu0 || !has_gpu1) {
347     TF_DeleteDeviceList(devices);
348     TF_DeleteTensor(t);
349     TFE_DeleteTensorHandle(hcpu);
350     TFE_DeleteContext(ctx);
351     return;
352   }
353   const string gpu_1_name(TF_DeviceListName(devices, 1, status.get()));
354   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
355   const string gpu_2_name(TF_DeviceListName(devices, 2, status.get()));
356   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
357   TFE_TensorHandle* hdevice =
358       TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get());
359   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
360 
361   TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice(
362       hdevice, ctx, gpu_2_name.c_str(), status.get());
363   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
364   TFE_DeleteTensorHandle(hdevice);
365   // Copy back to CPU
366   TFE_TensorHandle* hcopy =
367       TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
368   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
369   TFE_DeleteTensorHandle(hdevice2);
370 
371   // Ensure that the contents are the same!
372   TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
373   TFE_DeleteTensorHandle(hcopy);
374   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
375   EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy));
376   EXPECT_EQ(
377       0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)));
378   TF_DeleteTensor(tcopy);
379 
380   TF_DeleteDeviceList(devices);
381   TF_DeleteTensor(t);
382   TFE_DeleteTensorHandle(hcpu);
383   TFE_DeleteContext(ctx);
384 }
385 
// Synchronous variant.
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
  TensorHandleCopyBetweenTwoGPUDevices(false);
}

// Asynchronous-executor variant.
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
  TensorHandleCopyBetweenTwoGPUDevices(true);
}
393 
// Runs a MatMul with one CPU input and one GPU input and verifies that,
// under a SILENT device placement policy, the CPU input is implicitly
// copied: a local mirror appears on the GPU device after execution. The GPU
// portion is skipped entirely when no GPU is present.
//   async: use the async executor.
//   global_policy: context-wide placement policy set at creation.
//   thread_policy: thread-local override, applied only when it differs from
//     the global policy.
//   cpu_op: place the MatMul on CPU instead of GPU.
void TensorHandleSilentCopy(bool async,
                            TFE_ContextDevicePlacementPolicy global_policy,
                            TFE_ContextDevicePlacementPolicy thread_policy,
                            bool cpu_op) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  if (thread_policy != global_policy) {
    TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
  }
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    // Unwrap the C API handles so the mirror state can be inspected through
    // the internal TensorHandle interface.
    auto cpu_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
    auto gpu_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
    auto gpu_device = gpu_arg->device();
    // Precondition: no mirror exists before the op runs.
    ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));

    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
    if (cpu_op) {
      string cpu_device_name;
      ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
      TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status.get());
    } else {
      TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
    }
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
    ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());

    // The CPU handle should have been copied and have a mirror on the GPU
    ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));

    TFE_DeleteOp(matmul);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(hcpu);
  // Drain any pending async work before tearing the context down.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
// Sync execution, global SILENT placement policy.
TEST(CAPI, TensorHandleSilentCopy) {
  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
// Async execution, global SILENT placement policy.
TEST(CAPI, TensorHandleSilentCopyAsync) {
  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
// Global EXPLICIT policy overridden by a thread-local SILENT policy.
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
// Async variant of the thread-local policy override.
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
                         TFE_DEVICE_PLACEMENT_SILENT, false);
}
475 
SetAndGetOpDevices(bool async)476 void SetAndGetOpDevices(bool async) {
477   TF_Status* status = TF_NewStatus();
478   TFE_ContextOptions* opts = TFE_NewContextOptions();
479   TFE_Context* ctx = TFE_NewContext(opts, status);
480   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
481   TFE_DeleteContextOptions(opts);
482 
483   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
484   TFE_Op* matmul = MatMulOp(ctx, m, m);
485 
486   // Disable the test if no GPU is present.
487   string gpu_device_name;
488   if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
489     TFE_OpSetDevice(matmul, "GPU:0", status);
490     ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
491     const char* device_name = TFE_OpGetDevice(matmul, status);
492     ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr);
493 
494     TFE_OpSetDevice(matmul, "CPU:0", status);
495     ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
496     device_name = TFE_OpGetDevice(matmul, status);
497     ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr);
498   }
499 
500   TFE_DeleteOp(matmul);
501   TFE_DeleteTensorHandle(m);
502   TFE_DeleteContext(ctx);
503   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
504   TF_DeleteStatus(status);
505 }
506 
// Every handle accessor must reject a null TFE_TensorHandle with
// TF_INVALID_ARGUMENT, the message "Invalid handle", and the documented
// sentinel return value (nullptr or -1). The status is reset between calls
// so each accessor is checked independently.
TEST(CAPI, TensorHandleNullptr) {
  TFE_TensorHandle* h = nullptr;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(t, nullptr);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));

  TF_SetStatus(status.get(), TF_OK, "");

  const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_name, nullptr);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));

  TF_SetStatus(status.get(), TF_OK, "");

  device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_name, nullptr);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));

  TF_SetStatus(status.get(), TF_OK, "");

  int num_dims = TFE_TensorHandleNumDims(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(num_dims, -1);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));

  TF_SetStatus(status.get(), TF_OK, "");

  int dim = TFE_TensorHandleDim(h, 0, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(dim, -1);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
545 
// Checks TFE_TensorHandleDeviceName vs TFE_TensorHandleBackingDeviceName:
// a CPU tensor reports CPU for both; for the output of a Shape op placed on
// GPU, the handle's device is the GPU while the backing device is CPU. The
// GPU portion is skipped when no GPU is present.
TEST(CAPI, TensorHandleDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
  const char* backing_device_name =
      TFE_TensorHandleBackingDeviceName(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
      << backing_device_name;

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // .device of shape is GPU since the op is executed on GPU
    device_name = TFE_TensorHandleDeviceName(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name;

    // .backing_device of shape is CPU since the tensor is backed by CPU
    backing_device_name =
        TFE_TensorHandleBackingDeviceName(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
        << backing_device_name;

    TFE_DeleteOp(shape_op);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  TFE_DeleteTensorHandle(hcpu);
  // Drain any pending async work before tearing the context down.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
603 
// Runs an Add on a 100x100 matrix and checks both the numeric result and
// the input-buffer forwarding behavior, by comparing the raw data pointer of
// the input tensor against that of the output tensor.
//   async:         use the async executor.
//   forward_input: release the input handle before executing, allowing the
//                  kernel to reuse (forward) the input buffer for the output.
//   tfrt:          enable the TFRT backend via TFE_ContextOptionsSetTfrt.
void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
#ifdef PLATFORM_WINDOWS
  // On windows, we flakily get a failure due to pointer instability.
  // Disable the 4 tests using this helper until we fix the issue.
  return;
#else
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, tfrt);
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* n = TestMatrixTensorHandle100x100(ctx);
  // If a GPU exists, copy the handle to GPU so that we can exercise
  // unprotecting a mirror.
  std::string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* n_gpu =
        TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteTensorHandle(n);
    n = n_gpu;
  }

  TFE_TensorHandle* m = TestMatrixTensorHandle100x100(ctx);

  // Store pointer to raw buffer for validation of forwarding behaviour.
  TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
  void* orig_ptr = TF_TensorData(orig);
  TF_DeleteTensor(orig);

  TFE_Op* add_op = AddOp(ctx, n, m);
  std::string cpu_device_name;
  ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
  TFE_OpSetDevice(add_op, cpu_device_name.c_str(), status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  if (forward_input) {
    // Release our reference now so the kernel is free to forward the buffer.
    TFE_DeleteTensorHandle(n);
  }

  int num_retvals = 1;
  if (async) {
    // Enqueue dummy ops so we backlog async execution & actually test async.
    // This is usually unnecessary, but we've experienced the occasional test
    // failure when testing async mode with no explicit forwarding.
    for (int i = 0; i < 100000; ++i) {
      TFE_Op* add_op_dummy = AddOp(ctx, m, m);
      TFE_OpSetDevice(add_op_dummy, cpu_device_name.c_str(), status);
      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_TensorHandle* dummy = nullptr;
      TFE_Execute(add_op_dummy, &dummy, &num_retvals, status);
      ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
      TFE_DeleteTensorHandle(dummy);
      TFE_DeleteOp(add_op_dummy);
    }
  }
  TFE_TensorHandle* retval = nullptr;
  TFE_Execute(add_op, &retval, &num_retvals, status);
  EXPECT_EQ(1, num_retvals);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  if (!forward_input) {
    TFE_DeleteTensorHandle(n);
  }
  TFE_DeleteOp(add_op);

  TF_Tensor* t = TFE_TensorHandleResolve(retval, status);
  if (async) {
    if (forward_input) {
      // Since the input was forwarded, we released the input handle right away
      // and hence expect the input to be forwarded to the return tensor.
      EXPECT_EQ(orig_ptr, TF_TensorData(t));
    } else {
      // In async mode we expect forwarding to work without releasing the input
      // handle since by the time the kernel is executed we have released the
      // handle in the client code.
      EXPECT_EQ(orig_ptr, TF_TensorData(t));
    }
  } else {
    if (forward_input) {
      // Since the input was forwarded, we released the input handle right away
      // and hence expect the input to be forwarded to the return tensor.
      EXPECT_EQ(orig_ptr, TF_TensorData(t));
    } else {
      // In sync mode, forwarding can't really happen since the client code will
      // have a reference count on the input tensor while the kernel is being
      // executed and thus it cannot be re-used for the return tensor.
      EXPECT_NE(orig_ptr, TF_TensorData(t));
    }
  }

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteTensorHandle(retval);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // 1 + 1 everywhere: verify every element of the 100x100 result is 2.
  float result[100 * 100] = {0};
  EXPECT_EQ(sizeof(result), TF_TensorByteSize(t));
  memcpy(&result[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  for (int i = 0; i < 100 * 100; ++i) {
    EXPECT_EQ(2.0f, result[i]);
  }
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
#endif  // PLATFORM_WINDOWS
}
TEST(CAPI,ExecuteAdd)712 TEST(CAPI, ExecuteAdd) {
713   ExecuteAdd(
714       /*async=*/false,
715       /*forward_input*/ false,
716       /*tfrt*/ false);
717 }
// Asynchronous Add, without input forwarding.
TEST(CAPI, ExecuteAddAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input=*/false,
      /*tfrt=*/false);
}
// Synchronous Add with the input handle released early so the runtime may
// forward the input buffer to the output.
TEST(CAPI, ExecuteAddForward) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/true,
      /*tfrt=*/false);
}
// Asynchronous Add with input forwarding.
TEST(CAPI, ExecuteAddForwardAsync) {
  ExecuteAdd(
      /*async=*/true,
      /*forward_input=*/true,
      /*tfrt=*/false);
}
736 #ifdef PLATFORM_GOOGLE
737 // TODO(b/153349425): Add forwarding tests for TFRT
738 // TODO(b/178003466): Fix and re-enable.
// TFRT variant of the Add test; currently disabled (see b/178003466 above).
TEST(CAPI, DISABLED_ExecuteAddTfrt) {
  ExecuteAdd(
      /*async=*/false,
      /*forward_input=*/false,
      /*tfrt=*/true);
}
745 #endif
746 
// Runs MatMul(m, m) on the 2x2 test matrix through the eager C API and checks
// the numeric product. `async` toggles the asynchronous executor.
void Execute_MatMul_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = MatMulOp(ctx, m, m);
  // Deliberately over-provision the output array: TFE_Execute must trim
  // num_retvals down to the op's actual output count (1 for MatMul).
  TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
  int num_retvals = 2;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(1, num_retvals);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);

  // Resolving the handle materializes the result (and, in async mode,
  // synchronizes with the pending computation).
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  // Expected entries of the test matrix multiplied by itself.
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TF_DeleteStatus(status);
}
// Sync and async variants of the MatMul smoke test.
TEST(CAPI, Execute_MatMul_CPU) { Execute_MatMul_CPU(false); }
TEST(CAPI, Execute_MatMul_CPUAsync) { Execute_MatMul_CPU(true); }
782 
// Runs a shape-incompatible MatMul ([2,2] x [3,2]) and verifies how the error
// is surfaced in sync vs. async mode, then checks that the context recovers
// and can execute a valid MatMul afterwards.
void Execute_MatMul_CPU_Runtime_Error(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(ctx);
  // `matmul` has incompatible operand shapes; `matmul2` is valid.
  TFE_Op* matmul = MatMulOp(ctx, m1, m2);
  TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
                  status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* matmul2 = MatMulOp(ctx, m1, m1);
  TFE_OpSetDevice(matmul2, "/job:localhost/replica:0/task:0/device:CPU:0",
                  status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  TFE_DeleteOp(matmul);
  if (!async) {
    // In sync mode the shape error is reported directly by TFE_Execute.
    EXPECT_NE(TF_OK, TF_GetCode(status));
  } else {
    // In async mode TFE_Execute itself returns OK; the error only surfaces
    // when the output handle is resolved.
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    EXPECT_EQ(nullptr, t);
    const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
    EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
        << TF_Message(status);
    // Resetting the TF_Status below does not clear the async executor's
    // error state: executor operations keep failing until
    // TFE_ExecutorClearError is called.
    TF_SetStatus(status, TF_OK, "");
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    TF_SetStatus(status, TF_OK, "");
    retvals[0] = nullptr;
    // Even a valid op fails while the executor is still in an error state.
    TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
    EXPECT_NE(TF_OK, TF_GetCode(status));
    TFE_ExecutorClearError(executor);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteExecutor(executor);
  }
  // Following works in async mode since TFE_ExecutorClearError was called.
  TF_SetStatus(status, TF_OK, "");
  if (retvals[0] != nullptr) {
    TFE_DeleteTensorHandle(retvals[0]);
  }
  retvals[0] = nullptr;
  TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  TF_DeleteTensor(t);
  TFE_DeleteOp(matmul2);
  TFE_DeleteTensorHandle(m1);
  TFE_DeleteTensorHandle(m2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
// Sync and async variants of the runtime-error recovery test.
TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) {
  Execute_MatMul_CPU_Runtime_Error(false);
}
TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) {
  Execute_MatMul_CPU_Runtime_Error(true);
}
854 
// Feeds a float matrix and a double matrix to MatMul (mismatched dtypes) and
// expects TFE_Execute to fail in both sync and async mode.
void Execute_MatMul_CPU_Type_Error(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(ctx);
  TFE_Op* matmul = MatMulOp(ctx, m1, m2);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  // The dtype mismatch must be reported as an error.
  EXPECT_NE(TF_OK, TF_GetCode(status));
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m1);
  TFE_DeleteTensorHandle(m2);
  // In async mode an output handle may still have been produced.
  if (retvals[0] != nullptr) {
    TFE_DeleteTensorHandle(retvals[0]);
  }
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
879 
// Sync and async variants of the dtype-mismatch test.
TEST(CAPI, Execute_MatMul_CPU_Type_Error) {
  Execute_MatMul_CPU_Type_Error(false);
}
TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) {
  Execute_MatMul_CPU_Type_Error(true);
}
// Reduces the 2x2 test matrix with the Min op along the test axis and checks
// the reduced values {1, 3}.
TEST(CAPI, Execute_Min_CPU) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
  TFE_Op* minOp = MinOp(ctx, input, axis);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(minOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(minOp);
  TFE_DeleteTensorHandle(input);
  TFE_DeleteTensorHandle(axis);
  ASSERT_EQ(1, num_retvals);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  float output[2] = {0};
  EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
  memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(1, output[0]);
  EXPECT_EQ(3, output[1]);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
918 
ExecuteWithTracing(bool async)919 void ExecuteWithTracing(bool async) {
920   TF_Status* status = TF_NewStatus();
921   TFE_ContextOptions* opts = TFE_NewContextOptions();
922   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
923   TFE_Context* ctx = TFE_NewContext(opts, status);
924   TFE_ContextEnableRunMetadata(ctx);
925   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
926   TFE_DeleteContextOptions(opts);
927 
928   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
929   TFE_Op* matmul = MatMulOp(ctx, m, m);
930   TFE_TensorHandle* retvals[1] = {nullptr};
931   int num_retvals = 1;
932   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
933   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
934   TFE_DeleteOp(matmul);
935   TFE_DeleteTensorHandle(m);
936   TF_Buffer* b = TF_NewBuffer();
937   TFE_ContextExportRunMetadata(ctx, b, status);
938   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
939   tensorflow::RunMetadata rm;
940   EXPECT_TRUE(
941       rm.ParseFromString({reinterpret_cast<const char*>(b->data), b->length}));
942   TF_DeleteBuffer(b);
943   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
944   ASSERT_EQ(1, num_retvals);
945 
946   TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
947   TFE_DeleteTensorHandle(retvals[0]);
948   TFE_DeleteContext(ctx);
949   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
950   float product[4] = {0};
951   EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
952   memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
953   TF_DeleteTensor(t);
954   EXPECT_EQ(7, product[0]);
955   EXPECT_EQ(10, product[1]);
956   EXPECT_EQ(15, product[2]);
957   EXPECT_EQ(22, product[3]);
958   TF_DeleteStatus(status);
959 }
// Sync and async variants of the tracing test.
TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); }
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); }
962 
// Ops used to test Unavailable-error propagation: a plain op, and one marked
// as distributed communication (the tests below expect the two to surface
// different error codes).
REGISTER_OP("TestNonCommUnavailable")
    .Output("out: string")
    .Doc(R"doc(Test non-communication op throwing Unavailable error.)doc");

REGISTER_OP("TestCommUnavailable")
    .Output("out: string")
    .SetIsDistributedCommunication()
    .Doc(R"doc(Test communication op throwing Unavailable error.)doc");
971 
// Kernel that throws an Unavailable error.
class TestUnavailableErrorOp : public tensorflow::OpKernel {
 public:
  explicit TestUnavailableErrorOp(tensorflow::OpKernelConstruction* ctx)
      : tensorflow::OpKernel(ctx) {}
  // Unconditionally reports Unavailable; never produces an output.
  void Compute(tensorflow::OpKernelContext* ctx) override {
    ctx->SetStatus(tensorflow::errors::Unavailable("Test error."));
  }
};
// Register the erroring kernel for both test ops on the default device.
REGISTER_KERNEL_BUILDER(
    Name("TestNonCommUnavailable").Device(tensorflow::DEVICE_DEFAULT),
    TestUnavailableErrorOp);
REGISTER_KERNEL_BUILDER(
    Name("TestCommUnavailable").Device(tensorflow::DEVICE_DEFAULT),
    TestUnavailableErrorOp);
987 
FunctionWithErrorOp(const tensorflow::StringPiece op_name)988 string FunctionWithErrorOp(const tensorflow::StringPiece op_name) {
989   const std::string& func_str =
990       "    signature {"
991       "      name: 'FunctionWith__OP_NAME__'"
992       "      output_arg {"
993       "        name: 'out'"
994       "        type: DT_STRING"
995       "      }"
996       "    }"
997       "    node_def {"
998       "      name: 'error_op'"
999       "      op: '__OP_NAME__'"
1000       "    }"
1001       "    ret {"
1002       "      key: 'out'"
1003       "      value: 'error_op:out'"
1004       "    }";
1005   tensorflow::FunctionDef def;
1006   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
1007       tensorflow::str_util::StringReplace(func_str, "__OP_NAME__", op_name,
1008                                           /*replace_all=*/true),
1009       &def));
1010   return def.SerializeAsString();
1011 }
1012 
// Executes the erroring test ops directly and wrapped in functions. The
// kernel always reports Unavailable (see TestUnavailableErrorOp above); the
// test expects a plain op to surface as TF_INTERNAL while the op marked as
// distributed communication keeps TF_UNAVAILABLE — both directly and through
// a function.
TEST(CAPI, ExecuteOpAndFunctionWithError) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*async=*/false));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_Op* non_comm_op = TFE_NewOp(ctx, "TestNonCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* retval[1] = {};
  int num_retvals = 1;
  TFE_Execute(non_comm_op, retval, &num_retvals, status);
  EXPECT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(non_comm_op);

  TFE_Op* comm_op = TFE_NewOp(ctx, "TestCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(comm_op, retval, &num_retvals, status);
  EXPECT_EQ(TF_UNAVAILABLE, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(comm_op);

  // Same pair of checks, but with the ops wrapped in functions.
  const string& fdef1 = FunctionWithErrorOp("TestNonCommUnavailable");
  TFE_ContextAddFunctionDef(ctx, fdef1.data(), fdef1.size(), status);
  TFE_Op* fn1 = TFE_NewOp(ctx, "FunctionWithTestNonCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(fn1, retval, &num_retvals, status);
  EXPECT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(fn1);

  const string& fdef2 = FunctionWithErrorOp("TestCommUnavailable");
  TFE_ContextAddFunctionDef(ctx, fdef2.data(), fdef2.size(), status);
  TFE_Op* fn2 = TFE_NewOp(ctx, "FunctionWithTestCommUnavailable", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(fn2, retval, &num_retvals, status);
  EXPECT_EQ(TF_UNAVAILABLE, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(fn2);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
1054 
// Returns a serialized FunctionDef computing m = MatMul(a, a) for a float
// input. CHECK-fails if the textual proto does not parse.
string MatMulFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'MatMulFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'm'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'matmul'"
      "      op: 'MatMul'"
      "      input: 'a'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'm'"
      "      value: 'matmul:product'"
      "    }",
      &def));
  return def.SerializeAsString();
}
1088 
// Returns a serialized FunctionDef computing o = a + a for a float input.
// CHECK-fails if the textual proto does not parse.
string AddFunction() {
  tensorflow::FunctionDef def;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'AddFunction'"
      "      input_arg {"
      "        name: 'a'"
      "        type: DT_FLOAT"
      "      }"
      "      output_arg {"
      "        name: 'o'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'output'"
      "      op: 'Add'"
      "      input: 'a'"
      "      input: 'a'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'o'"
      "      value: 'output:z'"
      "    }",
      &def));
  return def.SerializeAsString();
}
1123 
// Registers MatMulFunction with the context and executes it three times,
// alternating whether the kernel/device caches are cleared first, verifying
// the numeric result each time.
void FunctionDefAndExecute(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  string function_def = MatMulFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Exercise both the cold path (caches cleared) and the warm path.
  for (bool clear_cache : {true, false, true}) {
    if (clear_cache) {
      TFE_ContextClearCaches(ctx);
    }
    TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
    TFE_TensorHandle* retval[1] = {nullptr};
    int num_retvals = 1;
    TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(op, m, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &retval[0], &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(1, num_retvals);
    TFE_DeleteOp(op);
    TFE_DeleteTensorHandle(m);
    TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
    TFE_DeleteTensorHandle(retval[0]);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    float product[4] = {0};
    EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
    memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
    // Expected entries of the test matrix multiplied by itself.
    EXPECT_EQ(7, product[0]);
    EXPECT_EQ(10, product[1]);
    EXPECT_EQ(15, product[2]);
    EXPECT_EQ(22, product[3]);
  }
  TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContext(ctx);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
// Sync and async variants of the function-execution test.
TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); }
TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); }
1173 
// Executes AddFunction (o = a + a) and checks the result, optionally on TFRT
// and optionally with grappler graph rewriting enabled. With grappler on,
// TFRT debug builds additionally verify that the Add was rewritten to a Mul.
void RunAddFunction(bool use_tfrt, bool enable_grappler) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  string function_def = AddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* retval[1] = {nullptr};
  int num_retvals = 1;
  TFE_Op* op = TFE_NewOp(ctx, "AddFunction", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Add a config_proto attr, to trigger grappler graph rewrites in the current
  // eager runtime.
  if (enable_grappler) {
    tensorflow::ConfigProto config;
    // Do not skip grappler optimization even for small graphs.
    config.mutable_graph_options()
        ->mutable_rewrite_options()
        ->set_min_graph_nodes(-1);
    string serialized_config;
    ASSERT_TRUE(config.SerializeToString(&serialized_config));
    TFE_OpSetAttrString(
        op, "config_proto",
        reinterpret_cast<const void*>(serialized_config.c_str()),
        serialized_config.length());
  }

  if (use_tfrt) {
    // Set some test-only graph compiler options.
    TFE_OpSetAttrBool(op, "TFRT_TEST_enable_native_ops", false);
    TFE_OpSetAttrBool(op, "TFRT_TEST_enable_grappler", enable_grappler);
  }

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_OpAddInput(op, m, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Execute(op, &retval[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  TFE_DeleteOp(op);
  TFE_DeleteTensorHandle(m);
  TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
  TFE_DeleteTensorHandle(retval[0]);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  // Each entry of the test matrix doubled (a + a).
  EXPECT_EQ(2, product[0]);
  EXPECT_EQ(4, product[1]);
  EXPECT_EQ(6, product[2]);
  EXPECT_EQ(8, product[3]);

  // When we turn on grappler, confirm that the tf.Add has been rewritten into a
  // tf.Mul.
  // This capability of checking the executed op names is currently only enabled
  // for TFRT debug build, for performance and simplicity reasons.
  if (use_tfrt) {
    TF_Buffer* buf = TF_NewBuffer();
    TFE_GetExecutedOpNames(ctx, buf, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
#ifndef NDEBUG
    if (enable_grappler)
      EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Mul"), nullptr);
    else
      EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Add"), nullptr);
#endif
    TF_DeleteBuffer(buf);
  }

  TFE_ContextRemoveFunction(ctx, "AddFunction", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContext(ctx);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
1259 
// Grappler-enabled variant on the default (non-TFRT) runtime.
TEST(CAPI, RunAddFunctionWithGrappler) {
  RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true);
}
1263 
1264 #ifdef PLATFORM_GOOGLE
// TFRT variants, with and without grappler.
TEST(CAPI, RunAddFunction_TFRT) {
  RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false);
}

TEST(CAPI, RunAddFunctionWithGrappler_TFRT) {
  RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true);
}
1272 #endif
1273 
BM_ExecuteFunction(::testing::benchmark::State & state)1274 void BM_ExecuteFunction(::testing::benchmark::State& state) {
1275   const int async = state.range(0);
1276   state.SetLabel(async ? "ExecuteFunctionAsync" : "ExecuteFunction");
1277   TF_Status* status = TF_NewStatus();
1278   TFE_ContextOptions* opts = TFE_NewContextOptions();
1279   TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
1280   TFE_Context* ctx = TFE_NewContext(opts, status);
1281   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1282   TFE_DeleteContextOptions(opts);
1283 
1284   string function_def = MatMulFunction();
1285   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
1286                             status);
1287   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1288 
1289   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
1290   TFE_TensorHandle* retval[1] = {nullptr};
1291   int num_retvals = 1;
1292   for (auto s : state) {
1293     TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
1294     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1295     TFE_OpAddInput(matmul, m, status);
1296     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1297     TFE_Execute(matmul, &retval[0], &num_retvals, status);
1298     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1299     TFE_DeleteOp(matmul);
1300     if (state.iterations() >= state.max_iterations && async) {
1301       TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
1302       TFE_ExecutorWaitForAllPendingNodes(executor, status);
1303       ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1304       TFE_DeleteExecutor(executor);
1305     }
1306   }
1307   TFE_DeleteTensorHandle(m);
1308   TFE_DeleteTensorHandle(retval[0]);
1309   TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
1310   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
1311   TFE_DeleteContext(ctx);
1312   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1313   TF_DeleteStatus(status);
1314 }
1315 BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
1316 
TEST(CAPI, Variables) {
  // Variables use resource handles, so this is really a test for resource
  // tensor handling. Creates a variable initialized to 12.0, reads it back
  // with ReadVariableOp, and checks the scalar value.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
  TFE_OpAddInput(op, var_handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  int num_retvals = 1;
  TFE_TensorHandle* value_handle = nullptr;
  TFE_Execute(op, &value_handle, &num_retvals, status);
  TFE_DeleteOp(op);

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);
  // The read must yield a float scalar (rank 0).
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle));
  EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status));
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float value = 0.0f;
  TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
  memcpy(&value, TF_TensorData(t), sizeof(float));
  TF_DeleteTensor(t);
  EXPECT_EQ(12.0, value);

  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteTensorHandle(value_handle);
  TFE_DeleteContext(ctx);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
1358 
BM_ReadVariable(::testing::benchmark::State & state)1359 void BM_ReadVariable(::testing::benchmark::State& state) {
1360   TF_Status* status = TF_NewStatus();
1361   TFE_ContextOptions* opts = TFE_NewContextOptions();
1362   TFE_Context* ctx = TFE_NewContext(opts, status);
1363   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1364   TFE_DeleteContextOptions(opts);
1365 
1366   TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0);
1367   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1368 
1369   int num_retvals = 1;
1370   TFE_TensorHandle* h = nullptr;
1371   for (auto s : state) {
1372     TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
1373     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1374     TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
1375     TFE_OpAddInput(op, var_handle, status);
1376     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1377     TFE_Execute(op, &h, &num_retvals, status);
1378     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1379     CHECK_EQ(1, num_retvals);
1380     CHECK(h);
1381     CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
1382     CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
1383     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1384     h = nullptr;
1385     TFE_DeleteOp(op);
1386   }
1387 
1388   TFE_DeleteTensorHandle(var_handle);
1389   TFE_DeleteContext(ctx);
1390   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1391   TF_DeleteStatus(status);
1392 }
1393 BENCHMARK(BM_ReadVariable);
1394 
TEST(CAPI, StringAttributes) {
  // Test that TFE_OpSetAttrString doesn't hold on to the value after it
  // returns: the same buffer is reused for two different attrs, and the op
  // (an AvgPool over a 1x1x1x1 input) must still execute correctly.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Single-element float input of shape [1, 1, 1, 1].
  std::vector<int64_t> dims(4, 1);
  TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* tensor =
      TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
  float tensor_data[] = {1};
  memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, tensor_handle, status);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(tensor_handle);

  // 1x1 window with stride 1 in every dimension.
  std::vector<int64_t> values(4, 1);
  TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size());
  TFE_OpSetAttrIntList(op, "strides", values.data(), values.size());

  const int BUFFER_SIZE = 10;
  char buffer[BUFFER_SIZE];
  std::strncpy(buffer, "VALID", BUFFER_SIZE);
  TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer));
  // Overwriting value in "buffer", should be fine since TFE_Op
  // shouldn't be holding on to it.
  std::strncpy(buffer, "NHWC", BUFFER_SIZE);
  TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer));

  TFE_OpSetAttrType(op, "T", TF_FLOAT);

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  tensor = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // A single float output: 4 bytes.
  EXPECT_EQ(4, TF_TensorByteSize(tensor));
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(op);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
1452 
1453 // Same test as above, expect use SetOpAttrValueScalar to set attrs.
TEST(CAPI, TestTFE_SetOpAttrs) {
  // Same AvgPool scenario as the StringAttributes test above, but every attr
  // is set through SetOpAttrValueScalar from an AttrValue proto instead of
  // the typed TFE_OpSetAttr* setters.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Single-element float input of shape [1, 1, 1, 1].
  std::vector<int64_t> dims(4, 1);
  TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* tensor =
      TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float));
  float tensor_data[] = {1};
  memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor));
  TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(op, tensor_handle, status);
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(tensor_handle);

  // int-list attr [1, 1, 1, 1], reused for both ksize and strides.
  tensorflow::AttrValue i_list_values;
  for (int i = 0; i < 4; ++i) {
    i_list_values.mutable_list()->add_i(1);
  }
  SetOpAttrValueScalar(ctx, op, i_list_values, "ksize", status);
  SetOpAttrValueScalar(ctx, op, i_list_values, "strides", status);

  tensorflow::AttrValue padding_value;
  *padding_value.mutable_s() = "VALID";
  tensorflow::SetOpAttrValueScalar(ctx, op, padding_value, "padding", status);

  tensorflow::AttrValue data_format_value;
  *data_format_value.mutable_s() = "NHWC";
  tensorflow::SetOpAttrValueScalar(ctx, op, data_format_value, "data_format",
                                   status);

  TFE_OpSetAttrType(op, "T", TF_FLOAT);

  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, &retvals[0], &num_retvals, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  ASSERT_EQ(1, num_retvals);

  tensor = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // A single float output: 4 bytes.
  EXPECT_EQ(4, TF_TensorByteSize(tensor));
  TF_DeleteTensor(tensor);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(op);

  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
1514 
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
  // A handle produced by TFE_TensorHandleCopySharingTensor must resolve to
  // the same 2x2 float matrix {1, 2, 3, 4} as the handle it was copied from.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));

  TFE_TensorHandle* h_shares_tensor =
      TFE_TensorHandleCopySharingTensor(h, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get());
  // Verify the resolve succeeded before touching `t`; the original test read
  // the tensor without checking the status first.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(16, TF_TensorByteSize(t));  // 4 floats.
  float data[4] = {0};
  memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
  EXPECT_EQ(1.0, data[0]);
  EXPECT_EQ(2.0, data[1]);
  EXPECT_EQ(3.0, data[2]);
  EXPECT_EQ(4.0, data[3]);
  TF_DeleteTensor(t);

  TFE_DeleteTensorHandle(h);
  TFE_DeleteTensorHandle(h_shares_tensor);
  TFE_DeleteContext(ctx);
}
1544 
ExtractAttrs(TFE_Op * op)1545 tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
1546   tensorflow::AttrValueMap attr_values;
1547   tensorflow::EagerOperation* operation =
1548       tensorflow::OperationFromInterface(tensorflow::unwrap(op));
1549   operation->Attrs().FillAttrValueMap(&attr_values);
1550   return attr_values;
1551 }
1552 
TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
  // Adding inputs to a "Min" op should infer its type attrs without them
  // being set explicitly: "T" from the float matrix input and "Tidx" from
  // the int32 axis input.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
  TFE_Op* minOp = TFE_NewOp(ctx, "Min", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(minOp, input, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(minOp, axis, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Both attrs must already be present before execution.
  tensorflow::AttrValueMap attr_values = ExtractAttrs(minOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
  attr_found = attr_values.find("Tidx");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_INT32);

  // The op must also execute successfully with only the inferred attrs.
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(minOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(minOp);
  TFE_DeleteTensorHandle(input);
  TFE_DeleteTensorHandle(axis);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
}
1589 
TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
  // TFE_OpAddInputList on a homogeneous "Concat" input list should infer
  // both the element type attr "T" (float) and the list-length attr "N" (2).
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
  TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  // "Concat" takes the axis first, then the value list.
  TFE_OpAddInput(concatOp, dim, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInputList(concatOp, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.type(), tensorflow::DataType::DT_FLOAT);
  attr_found = attr_values.find("N");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.i(), 2);

  // The op must execute successfully with only the inferred attrs.
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(concatOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(concatOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteTensorHandle(dim);
  TFE_DeleteContext(ctx);
}
1629 
TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
  // "Assert" takes a heterogeneous input list; inference should record one
  // entry per input in the "T" type-list attr (bool, float, int32 here).
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* condition = TestScalarTensorHandle(ctx, true);
  TFE_TensorHandle* t1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* t2 = TestAxisTensorHandle(ctx);
  TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(assertOp, condition, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // The "data" list mixes a bool, a float matrix, and an int32 vector.
  TFE_TensorHandle* data[] = {condition, t1, t2};
  TFE_OpAddInputList(assertOp, data, 3, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values = ExtractAttrs(assertOp);
  tensorflow::AttrValueMap::const_iterator attr_found = attr_values.find("T");
  EXPECT_NE(attr_found, attr_values.cend());
  EXPECT_EQ(attr_found->second.list().type(0), tensorflow::DataType::DT_BOOL);
  EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT);
  EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32);

  // The op must execute successfully with only the inferred attrs.
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(assertOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(assertOp);
  TFE_DeleteTensorHandle(condition);
  TFE_DeleteTensorHandle(t1);
  TFE_DeleteTensorHandle(t2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
}
1668 
TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
  // Feeding a list argument element-by-element with TFE_OpAddInput (instead
  // of TFE_OpAddInputList) should drop the inference context, so neither "T"
  // nor "N" gets inferred for "Concat".
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
  TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInput(concatOp, dim, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  // The op def (inference context) is present until a list argument is fed
  // through the scalar-input API.
  CHECK(tensorflow::unwrap(concatOp)->OpDef());
  TFE_OpAddInput(concatOp, inputs[0], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
      << "Inference context is still present";
  TFE_OpAddInput(concatOp, inputs[1], status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values = ExtractAttrs(concatOp);
  EXPECT_EQ(attr_values.find("T"), attr_values.end());
  EXPECT_EQ(attr_values.find("N"), attr_values.end());

  TF_DeleteStatus(status);
  TFE_DeleteOp(concatOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(dim);
  TFE_DeleteContext(ctx);
}
1703 
TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
  // TFE_OpGetInputLength/TFE_OpGetOutputLength on "IdentityN": both fail
  // before any inputs are added (nothing to derive the list length from),
  // then report 2 once the two-element input list is set, both before and
  // after execution.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Try to retrieve lengths before building the attributes (should fail)
  EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
  CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
  CHECK_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInputList(identityOp, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Try to retrieve lengths before executing the op (should work)
  EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[2] = {nullptr};
  int num_retvals = 2;
  TFE_Execute(identityOp, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Try to retrieve lengths after executing the op (should work)
  EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(identityOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteTensorHandle(retvals[1]);
  TFE_DeleteContext(ctx);
}
1751 
TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
  // Querying the length of an argument name that is not in the op def
  // ("cheese") must return -1 and set TF_INVALID_ARGUMENT.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInputList(identityOp, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "cheese", status));
  CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "cheese", status));
  CHECK_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
  TFE_DeleteOp(identityOp);
  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteContext(ctx);
}
1778 
TestOpAddAttrs(bool use_tfrt)1779 void TestOpAddAttrs(bool use_tfrt) {
1780   TF_Status* status = TF_NewStatus();
1781   TFE_ContextOptions* opts = TFE_NewContextOptions();
1782   TFE_ContextOptionsSetTfrt(opts, use_tfrt);
1783   TFE_Context* ctx = TFE_NewContext(opts, status);
1784   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1785   TFE_DeleteContextOptions(opts);
1786 
1787   TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
1788   TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
1789   TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
1790   const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);
1791 
1792   TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
1793   TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
1794   TFE_OpAddAttrs(copy_op, attributes);
1795   unsigned char is_list = 0;
1796   ASSERT_EQ(TF_ATTR_TYPE,
1797             TFE_OpGetAttrType(copy_op, "dtype", &is_list, status));
1798   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1799   ASSERT_EQ(TF_ATTR_SHAPE,
1800             TFE_OpGetAttrType(copy_op, "shape", &is_list, status));
1801   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
1802 
1803   tensorflow::AttrValueMap attr_values;
1804   if (use_tfrt) {
1805 #ifdef PLATFORM_GOOGLE
1806     auto* op = tensorflow::down_cast<tfrt::tf::OperationInterface*>(
1807         tensorflow::unwrap(copy_op));
1808     auto* tfrt_op_attrs =
1809         tensorflow::down_cast<const tfrt::tf::OpAttrsInterface*>(
1810             op->GetOpAttrs());
1811     tensorflow::DataType result;
1812     tfrt_op_attrs->GetType("dtype", &result);
1813     EXPECT_EQ(tensorflow::DT_FLOAT, result);
1814     tfrt_op_attrs->GetFallbackAttrs()->FillAttrValueMap(&attr_values);
1815 #endif
1816   } else {
1817     tensorflow::EagerOperation* op =
1818         tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
1819     op->Attrs().FillAttrValueMap(&attr_values);
1820   }
1821   EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
1822 
1823   TF_DeleteStatus(status);
1824   TFE_DeleteOp(var_op);
1825   TFE_DeleteOp(copy_op);
1826   TFE_DeleteContext(ctx);
1827 }
1828 
// Exercises attribute copying on the default (non-TFRT) eager runtime.
TEST(CAPI, TestTFE_OpAddAttrs) { TestOpAddAttrs(/*use_tfrt=*/false); }
1830 
1831 #ifdef PLATFORM_GOOGLE
// Exercises attribute copying on the TFRT runtime (google-internal only).
TEST(CAPI, TestTFE_OpAddAttrs_TFRT) { TestOpAddAttrs(/*use_tfrt=*/true); }
1833 
1834 #endif
1835 
TEST(CAPI, TestTFE_OpAttrsSerialize) {
  // Round-trips op attributes through serialization: TFE_OpAttrsSerialize
  // must produce a parseable NameAttrList, and an individual AttrValue from
  // it must be settable on a fresh op via TFE_OpSetAttrValueProto.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op);

  TF_Buffer* serialized_attr_values = TF_NewBuffer();
  TFE_OpAttrsSerialize(attributes, serialized_attr_values, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  tensorflow::NameAttrList name_and_attrs;
  ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
                                            serialized_attr_values->length));
  ASSERT_EQ("VarHandleOp", name_and_attrs.name());
  ASSERT_EQ(tensorflow::DT_INT64,
            name_and_attrs.attr().find("dtype")->second.type());
  TF_DeleteBuffer(serialized_attr_values);

  TFE_Op* var_op_2 = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);  // was unchecked

  // Feed the serialized "dtype" AttrValue proto back into a fresh op.
  string serialized_dtype;
  ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
      &serialized_dtype));
  TFE_OpSetAttrValueProto(
      var_op_2, "dtype",
      reinterpret_cast<const void*>(serialized_dtype.c_str()),
      serialized_dtype.length(), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::AttrValueMap attr_values;
  tensorflow::EagerOperation* op =
      tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
  op->Attrs().FillAttrValueMap(&attr_values);
  EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());

  TF_DeleteStatus(status);
  TFE_DeleteOp(var_op);
  TFE_DeleteOp(var_op_2);
  TFE_DeleteContext(ctx);
}
1882 
1883 // Needs to work with a const TFE_Op since custom devices should not modify the
1884 // op they are called with.
TFE_Op* CloneOp(const TFE_Op* other) {
  // Builds a brand-new op with the same name, device, attributes, and flat
  // input list as `other`, using only const accessors on `other`. Caller
  // owns the returned op.
  TF_Status* status = TF_NewStatus();
  TFE_Context* context = TFE_OpGetContext(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const char* op_name = TFE_OpGetName(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* ret = TFE_NewOp(context, op_name, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const char* device = TFE_OpGetDevice(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetDevice(ret, device, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
  // Re-add inputs one at a time in their original flat order.
  int num_inputs = TFE_OpGetFlatInputCount(other, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  for (int input_index = 0; input_index < num_inputs; ++input_index) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpAddInput(ret, input, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  TF_DeleteStatus(status);
  return ret;
}
1909 
TEST(CAPI, TestTFE_OpRecreation) {
  // Verifies that CloneOp (above) reproduces an op faithfully enough to
  // execute: once for an op with attrs and an explicit device, once for an
  // op with inputs and no device set.
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Clone an op with attributes and a device set.
  TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
  TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpSetDevice(original_var_op,
                  "/job:localhost/replica:0/task:0/device:CPU:0", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* cloned = CloneOp(original_var_op);

  // The clone carries over both the explicit device and the op name.
  EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
            std::string(TFE_OpGetDevice(cloned, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  int num_retvals = 1;
  TFE_TensorHandle* ret;
  TFE_Execute(cloned, &ret, &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(ret);

  // Clone an op with inputs and no device set.
  TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
  TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_TensorHandle* inputs[] = {input1, input2};
  TFE_OpAddInputList(original_identity, inputs, 2, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Op* cloned_identity = CloneOp(original_identity);
  EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
  TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
  num_retvals = 2;
  TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_DeleteTensorHandle(input1);
  TFE_DeleteTensorHandle(input2);
  TFE_DeleteTensorHandle(identity_ret[0]);
  TFE_DeleteTensorHandle(identity_ret[1]);

  TFE_DeleteOp(cloned_identity);
  TFE_DeleteOp(original_identity);
  TFE_DeleteOp(original_var_op);
  TFE_DeleteOp(cloned);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}
1969 
// Returns a copy of `server_def` in which task `task_index` of the first job
// is re-pointed at a freshly picked unused local port.
tensorflow::ServerDef ReplaceTaskInServerDef(
    const tensorflow::ServerDef& server_def, int task_index) {
  tensorflow::ServerDef result = server_def;
  const int new_port = tensorflow::testing::PickUnusedPortOrDie();
  tensorflow::JobDef* job = result.mutable_cluster()->mutable_job(0);
  job->mutable_tasks()->at(task_index) =
      tensorflow::strings::StrCat("localhost:", new_port);
  return result;
}
1980 
CreateVarHandle(TFE_Context * ctx,const tensorflow::string & device_name,const tensorflow::string & variable_name)1981 TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx,
1982                                   const tensorflow::string& device_name,
1983                                   const tensorflow::string& variable_name) {
1984   TF_Status* status = TF_NewStatus();
1985   // Create the variable handle.
1986   TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
1987   if (TF_GetCode(status) != TF_OK) return nullptr;
1988   TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
1989   TFE_OpSetAttrShape(op, "shape", {}, 0, status);
1990   TFE_OpSetAttrString(op, "container", "localhost", 0);
1991   TFE_OpSetAttrString(op, "shared_name", variable_name.data(),
1992                       variable_name.size());
1993   if (!device_name.empty()) {
1994     TFE_OpSetDevice(op, device_name.c_str(), status);
1995   }
1996   if (TF_GetCode(status) != TF_OK) return nullptr;
1997   TFE_TensorHandle* var_handle = nullptr;
1998   int num_retvals = 1;
1999   TFE_Execute(op, &var_handle, &num_retvals, status);
2000   if (TF_GetCode(status) != TF_OK) return nullptr;
2001   TFE_DeleteOp(op);
2002   if (TF_GetCode(status) != TF_OK) return nullptr;
2003   CHECK_EQ(1, num_retvals);
2004   TF_DeleteStatus(status);
2005   return var_handle;
2006 }
2007 
CreateVariable(TFE_Context * ctx,float value,const tensorflow::string & device_name,const tensorflow::string & variable_name)2008 TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
2009                                  const tensorflow::string& device_name,
2010                                  const tensorflow::string& variable_name) {
2011   TF_Status* status = TF_NewStatus();
2012   TFE_TensorHandle* var_handle =
2013       CreateVarHandle(ctx, device_name, variable_name);
2014 
2015   // Assign 'value' to it.
2016   TFE_Op* op = TFE_NewOp(ctx, "AssignVariableOp", status);
2017   if (TF_GetCode(status) != TF_OK) return nullptr;
2018   TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
2019   TFE_OpAddInput(op, var_handle, status);
2020   if (!device_name.empty()) {
2021     TFE_OpSetDevice(op, device_name.c_str(), status);
2022   }
2023 
2024   // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
2025   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
2026       TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
2027   memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
2028 
2029   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
2030       value_handle(TFE_NewTensorHandle(t.get(), status),
2031                    TFE_DeleteTensorHandle);
2032   if (TF_GetCode(status) != TF_OK) return nullptr;
2033 
2034   TFE_OpAddInput(op, value_handle.get(), status);
2035   if (TF_GetCode(status) != TF_OK) return nullptr;
2036 
2037   int num_retvals = 0;
2038   TFE_Execute(op, nullptr, &num_retvals, status);
2039   TFE_DeleteOp(op);
2040   if (TF_GetCode(status) != TF_OK) return nullptr;
2041   CHECK_EQ(0, num_retvals);
2042   TF_DeleteStatus(status);
2043   return var_handle;
2044 }
2045 
TFE_Context* CreateContext(const string& serialized_server_def,
                           bool isolate_session_state) {
  // Builds a synchronous eager context with silent device placement and the
  // given session-state isolation setting, then connects it to the cluster
  // described by `serialized_server_def` (with keep-alive/timeout 0).
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  opts->session_options.options.config.set_isolate_session_state(
      isolate_session_state);
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(false));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_ContextSetServerDef(ctx, 0, serialized_server_def.data(),
                          serialized_server_def.size(), status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_DeleteContextOptions(opts);
  TF_DeleteStatus(status);
  return ctx;
}
2063 
ListDeviceNames(TFE_Context * ctx)2064 std::vector<std::string> ListDeviceNames(TFE_Context* ctx) {
2065   TF_Status* status = TF_NewStatus();
2066   std::vector<std::string> device_names;
2067   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
2068   EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
2069   const int num_devices = TF_DeviceListCount(devices);
2070   for (int i = 0; i < num_devices; ++i) {
2071     device_names.emplace_back(TF_DeviceListName(devices, i, status));
2072     EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
2073   }
2074   TF_DeleteDeviceList(devices);
2075   TF_DeleteStatus(status);
2076   return device_names;
2077 }
2078 
TEST(CAPI, ShareVariableAcrossContextsWorks) {
  // A variable created through one non-isolated context must be readable
  // through a second non-isolated context connected to the same cluster.
  // TODO(shreepadma): Add a test case with isolate_session_state set to true.
  tensorflow::ServerDef server_def_0 = GetServerDef(3);
  server_def_0.mutable_default_session_config()->set_isolate_session_state(
      false);
  tensorflow::ServerDef server_def_1 =
      ReplaceTaskInServerDef(server_def_0, /*task_index=*/0);

  // These server defs have task index set to 0.
  string serialized_server_def_0 = server_def_0.SerializeAsString();
  string serialized_server_def_1 = server_def_1.SerializeAsString();

  // Create two worker tasks.
  server_def_0.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def_0.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  // Two client contexts, both with session-state isolation off.
  TFE_Context* ctx_0 = CreateContext(serialized_server_def_0,
                                     /*isolate_session_state=*/false);
  TFE_Context* ctx_1 = CreateContext(serialized_server_def_1,
                                     /*isolate_session_state=*/false);

  // Remote device on `worker1`.
  const char remote_device[] = "/job:localhost/replica:0/task:1/device:CPU:0";
  // `ctx_0`, `ctx_1`, `ctx_2` contains `remote_device`.
  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_0);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_1);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  // Create a variable using `ctx_0`.
  // Read the variable using `ctx_1`. This read should succeed.
  // 1. Create a variable on `remote_device`, using `ctx_0`.
  TFE_TensorHandle* handle_0 =
      CreateVariable(ctx_0, 1.2, remote_device, /*variable_name=*/"var2");

  // 2. Wait for `var2` to be created and initialized on the worker.
  TF_Status* status = TF_NewStatus();
  TFE_ContextAsyncWait(ctx_0, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  // 3. Read `var_2` using `ctx_1`. This read should succeed since `ctx_1` was
  // created with `isolate_session_state` set to false.
  {
    // Create a handle to `var2`, using `ctx_1`.
    TFE_TensorHandle* var_handle =
        CreateVarHandle(ctx_1, remote_device, /*variable_name=*/"var2");

    TFE_TensorHandle* handle_1 = nullptr;
    int num_retvals = 1;
    TF_Status* status = TF_NewStatus();
    TFE_Op* op = TFE_NewOp(ctx_1, "ReadVariableOp", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
    TFE_OpAddInput(op, var_handle, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &handle_1, &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteOp(op);

    ASSERT_EQ(1, num_retvals);
    EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
    EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Read the value of tensor handle `handle_1`.
    float value = 0.0f;
    TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
    memcpy(&value, TF_TensorData(t), sizeof(float));
    TF_DeleteTensor(t);
    EXPECT_EQ(1.2f, value);
    TFE_DeleteTensorHandle(handle_1);
    TF_DeleteStatus(status);
    TFE_DeleteTensorHandle(var_handle);
  }

  TFE_DeleteTensorHandle(handle_0);

  TFE_DeleteContext(ctx_0);
  TFE_DeleteContext(ctx_1);

  // NOTE(review): release() deliberately leaks the worker servers instead of
  // destroying them — presumably to avoid a blocking/hanging shutdown at
  // test teardown; confirm.
  worker_server1.release();
  worker_server2.release();
}
2182 
ReplaceTaskInServerDef(tensorflow::ServerDef * server_def,int task_index,const string & host,int port)2183 void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index,
2184                             const string& host, int port) {
2185   tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
2186   job_def->mutable_tasks()->at(task_index) =
2187       tensorflow::strings::StrCat(host, ":", port);
2188 }
2189 
// Verifies that a resource variable created via one eager context remains
// readable from a second, non-isolated context even after a worker task is
// replaced and both contexts are refreshed with TFE_ContextUpdateServerDef.
TEST(CAPI, ShareVariableAcrossContextsAfterUpdateContextWorks) {
  tensorflow::ServerDef server_def_0 = GetServerDef(3);
  // Sessions must share state; otherwise the variable created through `ctx_0`
  // would not be visible to `ctx_1`.
  server_def_0.mutable_default_session_config()->set_isolate_session_state(
      false);
  tensorflow::ServerDef server_def_1 =
      ReplaceTaskInServerDef(server_def_0, /*task_index=*/0);

  // These server defs have task index set to 0.
  string serialized_server_def_0 = server_def_0.SerializeAsString();
  string serialized_server_def_1 = server_def_1.SerializeAsString();

  // Create two worker tasks.
  server_def_0.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def_0.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  // Create two contexts.
  TFE_Context* ctx_0 = CreateContext(serialized_server_def_0,
                                     /*isolate_session_state=*/false);
  TFE_Context* ctx_1 = CreateContext(serialized_server_def_1,
                                     /*isolate_session_state=*/false);

  // Remote device on `worker2`.
  const char remote_device[] = "/job:localhost/replica:0/task:2/device:CPU:0";
  // `ctx_0`, `ctx_1` contains `remote_device`.
  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_0);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  {
    const std::vector<std::string>& device_names = ListDeviceNames(ctx_1);
    ASSERT_TRUE(std::find(device_names.begin(), device_names.end(),
                          remote_device) != device_names.end());
  }

  // Create a variable using `ctx_0`.
  // Replace worker1 using a new worker, and update the contexts.
  // Read the variable using `ctx_1`. This read should succeed.
  //
  // 1. Create a variable on `remote_device`, using `ctx_0`.
  // The variable lives on task 2, which is NOT the task being replaced below.
  TFE_TensorHandle* handle_0 =
      CreateVariable(ctx_0, 1.2, remote_device, /*variable_name=*/"var");

  // 2. Wait for `var` to be created and initialized on the worker.
  TF_Status* status = TF_NewStatus();
  TFE_ContextAsyncWait(ctx_0, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  int port = tensorflow::testing::PickUnusedPortOrDie();
  // 3. Replace worker1 with a new worker in server_def_0 and server_def_1.
  ReplaceTaskInServerDef(&server_def_0, /*task_index=*/1, "localhost", port);
  ReplaceTaskInServerDef(&server_def_1, /*task_index=*/1, "localhost", port);
  // 4. Start a new task to replace worker1.
  server_def_0.set_task_index(1);
  // NOTE(review): the old server is released (leaked), not destroyed, before
  // its replacement starts — presumably because GrpcServer destruction can
  // hang; confirm before "fixing" this to a delete.
  worker_server1.release();
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def_0, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  // 5a. Update `ctx_0` with updated `server_def_0`.
  {
    server_def_0.set_task_index(0);
    string serialized_update = server_def_0.SerializeAsString();
    TF_Status* status = TF_NewStatus();
    TFE_ContextUpdateServerDef(ctx_0, 0, serialized_update.data(),
                               serialized_update.size(), status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TF_DeleteStatus(status);
  }

  // 5b. Update `ctx_1` with updated `server_def_1`.
  {
    server_def_1.set_task_index(0);
    string serialized_update = server_def_1.SerializeAsString();
    TF_Status* status = TF_NewStatus();
    TFE_ContextUpdateServerDef(ctx_1, 0, serialized_update.data(),
                               serialized_update.size(), status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TF_DeleteStatus(status);
  }

  // 6. Read `var` using `ctx_1`. This read should succeed since `ctx_1` was
  // created with `isolate_session_state` set to false, and update should
  // preserve it.
  {
    // Create a handle to `var`, using `ctx_1`.
    TFE_TensorHandle* var_handle =
        CreateVarHandle(ctx_1, remote_device, /*variable_name=*/"var");

    // Execute ReadVariableOp on `ctx_1` against the handle.
    TFE_TensorHandle* handle_1 = nullptr;
    int num_retvals = 1;
    TF_Status* status = TF_NewStatus();
    TFE_Op* op = TFE_NewOp(ctx_1, "ReadVariableOp", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
    TFE_OpAddInput(op, var_handle, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &handle_1, &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteOp(op);

    // Expect a scalar float result.
    ASSERT_EQ(1, num_retvals);
    EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
    EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Read the value of tensor handle `handle_1`.
    float value = 0.0f;
    TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
    memcpy(&value, TF_TensorData(t), sizeof(float));
    TF_DeleteTensor(t);
    EXPECT_EQ(1.2f, value);
    TFE_DeleteTensorHandle(handle_1);
    TF_DeleteStatus(status);
    TFE_DeleteTensorHandle(var_handle);
  }

  TFE_DeleteTensorHandle(handle_0);

  TFE_DeleteContext(ctx_0);
  TFE_DeleteContext(ctx_1);

  // NOTE(review): servers are intentionally leaked (released, not deleted) —
  // presumably to sidestep shutdown hangs at test teardown; confirm.
  worker_server1.release();
  worker_server2.release();
}
2330 
// Builds a "single host" ServerDef containing two jobs:
//   - a one-task "client" job listening on a freshly picked local port, and
//   - a "worker" job holding only the `host:port` entry at `task_index`,
//     copied from the first job of `cluster_server_def`.
// The returned def inherits `cluster_server_def`'s protocol, and has
// job_name "worker" / task_index 0; callers typically override these.
// If `task_index` is absent from the cluster's first job, the worker job is
// left empty.
tensorflow::ServerDef CreateSingleHostServerDef(
    const tensorflow::ServerDef& cluster_server_def, int task_index) {
  tensorflow::ServerDef single_host_server_def;
  single_host_server_def.set_job_name("worker");
  single_host_server_def.set_protocol(cluster_server_def.protocol());
  single_host_server_def.set_task_index(0);
  tensorflow::ClusterDef* cluster_def =
      single_host_server_def.mutable_cluster();
  tensorflow::JobDef* job_def = cluster_def->add_job();
  job_def->set_name("client");

  // Add a client.
  job_def->mutable_tasks()->insert(
      {0, tensorflow::strings::StrCat(
              "localhost:", tensorflow::testing::PickUnusedPortOrDie())});

  tensorflow::JobDef* job_def2 = cluster_def->add_job();
  job_def2->set_name("worker");

  // Copy over `host:port` at `task_index`. Use a direct map lookup instead of
  // iterating (and copying) every task entry in the job.
  const auto& tasks = cluster_server_def.cluster().job(0).tasks();
  const auto it = tasks.find(task_index);
  if (it != tasks.end()) {
    job_def2->mutable_tasks()->insert({it->first, it->second});
  }

  return single_host_server_def;
}
2359 
GetClusterServerDef(const string & worker_job_name,int num_workers)2360 tensorflow::ServerDef GetClusterServerDef(const string& worker_job_name,
2361                                           int num_workers) {
2362   tensorflow::ServerDef server_def = GetServerDef(worker_job_name, num_workers);
2363   tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
2364 
2365   // Add a client.
2366   tensorflow::JobDef* job_def2 = cluster_def->add_job();
2367   job_def2->set_name("client");
2368   job_def2->mutable_tasks()->insert(
2369       {0, tensorflow::strings::StrCat(
2370               "localhost:", tensorflow::testing::PickUnusedPortOrDie())});
2371   return server_def;
2372 }
2373 
// Verifies that a worker started from a trimmed-down "single host" server def
// still participates correctly in the full cluster: a variable created
// through a single-host context is readable through a context built from the
// complete cluster server def.
TEST(CAPI, SingleHostServerDefWorks) {
  // Create a server def that represents a 2-process cluster and a client.
  // Example:
  //
  // cluster { job { name: "worker"
  //                 tasks { key: 0 value: "localhost:14522" }
  //                 tasks { key: 1 value: "localhost:14523" }
  //               }
  //           job { name: "client"
  //                 tasks { key: 0 value: "localhost:14524" }
  //               }
  //         } job_name: "worker" protocol: "grpc"
  //
  tensorflow::ServerDef cluster_server_def = GetClusterServerDef("worker", 2);

  // Create two worker tasks, using single host server defs.
  // A single host server def contains a client and the remote host.
  // Example:
  //
  //  Worker1:
  //  cluster { job { name: "client" tasks { key: 0 value: "localhost:14525" } }
  //            job { name: "worker" tasks { key: 1 value: "localhost:14523" } }
  //          } job_name: "worker" task_index: 1 protocol: "grpc"
  //
  //  Worker0:
  //  cluster { job { name: "client" tasks { key: 0 value: "localhost:14526" } }
  //            job { name: "worker" tasks { key: 0 value: "localhost:14522" } }
  //          } job_name: "worker" protocol: "grpc"
  //

  // Create `worker_1` using single host server def `worker_1_server_def`.
  tensorflow::ServerDef worker_1_server_def =
      CreateSingleHostServerDef(cluster_server_def, 1);
  worker_1_server_def.set_task_index(1);
  worker_1_server_def.set_job_name("worker");

  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_1_server_def,
                                             tensorflow::Env::Default(),
                                             &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());

  // Create context `local_ctx` using single host server def -
  // `worker_1_server_def`. The same def is reused with the client identity
  // (job "client", task 0).
  worker_1_server_def.set_task_index(0);
  worker_1_server_def.set_job_name("client");
  TFE_Context* local_ctx =
      CreateContext(worker_1_server_def.SerializeAsString(),
                    /*isolate_session_state=*/false);

  const char remote_device[] = "/job:worker/replica:0/task:1/device:CPU:0";

  // Create a variable `var` on worker task 1 using `local_ctx`, then wait for
  // its initialization to finish on the remote worker.
  TFE_TensorHandle* handle_0 =
      CreateVariable(local_ctx, 1.2, remote_device, /*variable_name=*/"var");
  TF_Status* status = TF_NewStatus();
  TFE_ContextAsyncWait(local_ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  TFE_DeleteTensorHandle(handle_0);

  // Create `worker0` using single host server def `worker_0_server_def`.
  tensorflow::ServerDef worker_0_server_def =
      CreateSingleHostServerDef(cluster_server_def, 0);
  worker_0_server_def.set_task_index(0);

  std::unique_ptr<tensorflow::GrpcServer> worker_server0;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(worker_0_server_def,
                                             tensorflow::Env::Default(),
                                             &worker_server0)
                  .ok());
  ASSERT_TRUE(worker_server0->Start().ok());

  // Create a remote context, `remote_ctx`, using `cluster_server_def`.
  cluster_server_def.set_task_index(0);
  cluster_server_def.set_job_name("client");
  TFE_Context* remote_ctx =
      CreateContext(cluster_server_def.SerializeAsString(),
                    /*isolate_session_state=*/false);

  // Read variable `var` using `remote_ctx`, created using `cluster_server_def`.
  {
    // Create a handle to `var`.
    TFE_TensorHandle* var_handle =
        CreateVarHandle(remote_ctx, remote_device, /*variable_name=*/"var");

    // Execute ReadVariableOp against the handle.
    TFE_TensorHandle* handle_1 = nullptr;
    int num_retvals = 1;
    TF_Status* status = TF_NewStatus();
    TFE_Op* op = TFE_NewOp(remote_ctx, "ReadVariableOp", status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
    TFE_OpAddInput(op, var_handle, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_Execute(op, &handle_1, &num_retvals, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteOp(op);

    // Expect a scalar float result.
    ASSERT_EQ(1, num_retvals);
    EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(handle_1));
    EXPECT_EQ(0, TFE_TensorHandleNumDims(handle_1, status));
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Read the value of tensor handle `handle_1`.
    float value = 0.0f;
    TF_Tensor* t = TFE_TensorHandleResolve(handle_1, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(sizeof(float), TF_TensorByteSize(t));
    memcpy(&value, TF_TensorData(t), sizeof(float));
    TF_DeleteTensor(t);
    EXPECT_EQ(1.2f, value);
    TFE_DeleteTensorHandle(handle_1);
    TF_DeleteStatus(status);
    TFE_DeleteTensorHandle(var_handle);
  }

  TFE_DeleteContext(local_ctx);
  TFE_DeleteContext(remote_ctx);

  // NOTE(review): servers are intentionally leaked (released, not deleted) —
  // presumably to sidestep shutdown hangs at test teardown; confirm.
  worker_server1.release();
  worker_server0.release();
}
2497 
2498 }  // namespace
2499