• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api_experimental.h"
17 
18 #include <string.h>
19 
20 #include "tensorflow/c/eager/c_api.h"
21 #include "tensorflow/c/eager/c_api_test_util.h"
22 #include "tensorflow/core/lib/monitoring/collection_registry.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/platform/str_util.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/test_benchmark.h"
28 
29 using tensorflow::string;
30 
31 namespace tensorflow {
32 namespace {
33 
HasSubstr(absl::string_view base,absl::string_view substr)34 static bool HasSubstr(absl::string_view base, absl::string_view substr) {
35   bool ok = absl::StrContains(base, substr);
36   EXPECT_TRUE(ok) << base << ", expected substring " << substr;
37   return ok;
38 }
39 
// Unlabeled counter: increments are visible through the cell and through the
// default collection registry, and deleting the counter removes its metric.
TEST(CAPI, MonitoringCounter0) {
  // The status object is only needed while creating the counter.
  TF_Status* s = TF_NewStatus();
  auto* test_counter =
      TFE_MonitoringNewCounter0("test/counter", s, "description");
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteStatus(s);

  auto* counter_cell = TFE_MonitoringGetCellCounter0(test_counter);
  TFE_MonitoringCounterCellIncrementBy(counter_cell, 1);
  EXPECT_EQ(TFE_MonitoringCounterCellValue(counter_cell), 1);

  // The same value must be exported through the registry.
  auto* registry = monitoring::CollectionRegistry::Default();
  monitoring::CollectionRegistry::CollectMetricsOptions collect_opts;
  std::unique_ptr<monitoring::CollectedMetrics> collected =
      registry->CollectMetrics(collect_opts);
  EXPECT_EQ("test/counter",
            collected->point_set_map.at("test/counter")->metric_name);
  EXPECT_EQ(1, collected->point_set_map.at("test/counter")
                   ->points.at(0)
                   ->int64_value);

  // Increments accumulate: 1 + 5 == 6.
  TFE_MonitoringCounterCellIncrementBy(counter_cell, 5);
  EXPECT_EQ(TFE_MonitoringCounterCellValue(counter_cell), 6);
  collected = registry->CollectMetrics(collect_opts);
  EXPECT_EQ(6, collected->point_set_map.at("test/counter")
                   ->points.at(0)
                   ->int64_value);

  // After deletion the metric no longer appears in a fresh collection.
  TFE_MonitoringDeleteCounter0(test_counter);
  collected = registry->CollectMetrics(collect_opts);
  EXPECT_EQ(collected->point_set_map.end(),
            collected->point_set_map.find("test/counter"));
}
70 
// One- and two-label counters can coexist, and each cell tracks its own value.
TEST(CAPI, MonitoringCounterMultiple) {
  TF_Status* s = TF_NewStatus();

  // Counter with a single label.
  auto* one_label_counter = TFE_MonitoringNewCounter1(
      "test/counter1", s, "description", "label1");
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* one_label_cell =
      TFE_MonitoringGetCellCounter1(one_label_counter, "test");
  TFE_MonitoringCounterCellIncrementBy(one_label_cell, 1);
  EXPECT_EQ(TFE_MonitoringCounterCellValue(one_label_cell), 1);

  // Counter with two labels, created while the first is still alive.
  auto* two_label_counter = TFE_MonitoringNewCounter2(
      "test/counter2", s, "description", "label1", "label2");
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteStatus(s);
  auto* two_label_cell =
      TFE_MonitoringGetCellCounter2(two_label_counter, "foo", "bar");
  TFE_MonitoringCounterCellIncrementBy(two_label_cell, 2);
  EXPECT_EQ(TFE_MonitoringCounterCellValue(two_label_cell), 2);

  TFE_MonitoringDeleteCounter1(one_label_counter);
  TFE_MonitoringDeleteCounter2(two_label_counter);
}
91 
// Unlabeled int gauge: values set on the cell are exported through the
// default collection registry, with each set replacing the previous value.
TEST(CAPI, MonitoringGauge0) {
  TF_Status* s = TF_NewStatus();
  auto* int_gauge = TFE_MonitoringNewIntGauge0("test/gauge", s, "test");
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* gauge_cell = TFE_MonitoringGetCellIntGauge0(int_gauge);

  TFE_MonitoringIntGaugeCellSet(gauge_cell, 1);
  EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(gauge_cell), 1);
  auto* registry = monitoring::CollectionRegistry::Default();
  monitoring::CollectionRegistry::CollectMetricsOptions collect_opts;
  std::unique_ptr<monitoring::CollectedMetrics> collected =
      registry->CollectMetrics(collect_opts);

  EXPECT_EQ("test/gauge",
            collected->point_set_map.at("test/gauge")->metric_name);
  EXPECT_EQ(1, collected->point_set_map.at("test/gauge")
                   ->points.at(0)
                   ->int64_value);

  // A gauge overwrites instead of accumulating: setting 5 yields 5, not 6.
  TFE_MonitoringIntGaugeCellSet(gauge_cell, 5);
  collected = registry->CollectMetrics(collect_opts);
  EXPECT_EQ(5, collected->point_set_map.at("test/gauge")
                   ->points.at(0)
                   ->int64_value);
  TFE_MonitoringDeleteIntGauge0(int_gauge);
  TF_DeleteStatus(s);
}
115 
// Exercises a one-label bool gauge and a two-label string gauge. The string
// value is read back through a TF_Buffer.
TEST(CAPI, MonitoringMultipleGauge) {
  TF_Status* status = TF_NewStatus();

  // Bool gauge with one label.
  auto* gauge1 =
      TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo");
  TFE_MonitoringBoolGaugeCellSet(cell1, true);
  EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1));
  TFE_MonitoringDeleteBoolGauge1(gauge1);

  // String gauge with two labels.
  auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test",
                                               "label1", "label2");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar");
  TFE_MonitoringStringGaugeCellSet(cell2, "str");

  // Allocate with the C API's own constructor so that TF_DeleteBuffer below
  // releases a buffer it knows how to free. A bare `new TF_Buffer`
  // default-initializes the C struct and leaves its fields indeterminate.
  TF_Buffer* buf = TF_NewBuffer();
  TFE_MonitoringStringGaugeCellValue(cell2, buf);
  // Copy the bytes out before the buffer is released.
  string data(static_cast<const char*>(buf->data), buf->length);
  TF_DeleteBuffer(buf);
  EXPECT_EQ(data, "str");

  TFE_MonitoringDeleteStringGauge2(gauge2);
  TF_DeleteStatus(status);
}
139 
// Unlabeled sampler: the exported histogram's sum tracks every sample added.
TEST(CAPI, MonitoringSampler0) {
  TF_Status* s = TF_NewStatus();
  auto* bucket_spec = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
  auto* test_sampler =
      TFE_MonitoringNewSampler0("test/sampler", bucket_spec, s, "test");
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* sampler_cell = TFE_MonitoringGetCellSampler0(test_sampler);
  TFE_MonitoringSamplerCellAdd(sampler_cell, 1.0);

  auto* registry = monitoring::CollectionRegistry::Default();
  monitoring::CollectionRegistry::CollectMetricsOptions collect_opts;
  std::unique_ptr<monitoring::CollectedMetrics> collected =
      registry->CollectMetrics(collect_opts);
  EXPECT_EQ("test/sampler",
            collected->point_set_map.at("test/sampler")->metric_name);
  EXPECT_EQ(1.0, collected->point_set_map.at("test/sampler")
                     ->points.at(0)
                     ->histogram_value.sum());

  // A second sample adds to the running sum: 1.0 + 5.0 == 6.0.
  TFE_MonitoringSamplerCellAdd(sampler_cell, 5.0);
  collected = registry->CollectMetrics(collect_opts);
  EXPECT_EQ(6.0, collected->point_set_map.at("test/sampler")
                     ->points.at(0)
                     ->histogram_value.sum());

  TFE_MonitoringDeleteBuckets(bucket_spec);
  TFE_MonitoringDeleteSampler0(test_sampler);
  TF_DeleteStatus(s);
}
168 
// One- and two-label samplers sharing a bucket spec. Each cell's value is
// serialized into a TF_Buffer and parsed as a HistogramProto.
TEST(CAPI, MonitoringMultipleSampler) {
  TF_Status* s = TF_NewStatus();
  auto* bucket_spec = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);

  // One label: samples 1.0 and 2.0 sum to 3.0.
  auto* one_label_sampler = TFE_MonitoringNewSampler1(
      "test/sampler1", bucket_spec, s, "test", "label1");
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* one_label_cell = TFE_MonitoringGetCellSampler1(one_label_sampler, "foo");
  TFE_MonitoringSamplerCellAdd(one_label_cell, 1.0);
  TFE_MonitoringSamplerCellAdd(one_label_cell, 2.0);
  TF_Buffer* serialized1 = TF_NewBuffer();
  TFE_MonitoringSamplerCellValue(one_label_cell, serialized1);
  tensorflow::HistogramProto parsed1;
  EXPECT_TRUE(parsed1.ParseFromString(
      {reinterpret_cast<const char*>(serialized1->data),
       serialized1->length}));
  EXPECT_EQ(parsed1.sum(), 3.0);
  TF_DeleteBuffer(serialized1);
  TFE_MonitoringDeleteSampler1(one_label_sampler);

  // Two labels: samples 2.0 and 3.0 sum to 5.0.
  auto* two_label_sampler = TFE_MonitoringNewSampler2(
      "test/sampler2", bucket_spec, s, "test", "label1", "label2");
  CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  auto* two_label_cell =
      TFE_MonitoringGetCellSampler2(two_label_sampler, "foo", "bar");
  TFE_MonitoringSamplerCellAdd(two_label_cell, 2.0);
  TFE_MonitoringSamplerCellAdd(two_label_cell, 3.0);
  TF_Buffer* serialized2 = TF_NewBuffer();
  TFE_MonitoringSamplerCellValue(two_label_cell, serialized2);
  tensorflow::HistogramProto parsed2;
  EXPECT_TRUE(parsed2.ParseFromString(
      {reinterpret_cast<const char*>(serialized2->data),
       serialized2->length}));
  EXPECT_EQ(parsed2.sum(), 5.0);
  TF_DeleteBuffer(serialized2);
  TFE_MonitoringDeleteSampler2(two_label_sampler);

  TFE_MonitoringDeleteBuckets(bucket_spec);
  TF_DeleteStatus(s);
}
205 
// A new cancellation manager starts un-cancelled and reports cancellation
// after StartCancel.
TEST(CAPI, CancellationManager) {
  TFE_CancellationManager* manager = TFE_NewCancellationManager();
  EXPECT_FALSE(TFE_CancellationManagerIsCancelled(manager));

  TFE_CancellationManagerStartCancel(manager);
  EXPECT_TRUE(TFE_CancellationManagerIsCancelled(manager));

  TFE_DeleteCancellationManager(manager);
}
213 
// Installs an executor on a context and then destroys the pair in both
// possible orders. Nothing is asserted after teardown, so the test passes if
// neither order crashes.
TEST(CAPI, ExecutorContextDestructionOrder) {
  TF_Status* status = TF_NewStatus();

  // Order 1: delete the context first, then its executor.
  {
    TFE_ContextOptions* ctx_opts = TFE_NewContextOptions();
    TFE_Context* context = TFE_NewContext(ctx_opts, status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    TFE_DeleteContextOptions(ctx_opts);
    TFE_Executor* sync_executor = TFE_NewExecutor(
        /*is_async=*/false, /*enable_streaming_enqueue=*/true);
    TFE_ContextSetExecutorForThread(context, sync_executor);

    TFE_DeleteContext(context);
    TFE_DeleteExecutor(sync_executor);
  }

  // Order 2: delete the executor first, then the context.
  {
    TFE_ContextOptions* ctx_opts = TFE_NewContextOptions();
    TFE_Context* context = TFE_NewContext(ctx_opts, status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    TFE_DeleteContextOptions(ctx_opts);
    TFE_Executor* sync_executor = TFE_NewExecutor(
        /*is_async=*/false, /*enable_streaming_enqueue=*/true);
    TFE_ContextSetExecutorForThread(context, sync_executor);

    TFE_DeleteExecutor(sync_executor);
    TFE_DeleteContext(context);
  }

  TF_DeleteStatus(status);
}
244 
TEST(CAPI,Function_ident_CPU)245 TEST(CAPI, Function_ident_CPU) {
246   // First create a simple identity function.
247   TF_Graph* function_graph = TF_NewGraph();
248   TF_OperationDescription* arg_descr =
249       TF_NewOperation(function_graph, "Placeholder", "arg");
250   TF_SetAttrType(arg_descr, "dtype", TF_INT32);
251   TF_Status* status = TF_NewStatus();
252   TF_Operation* arg = TF_FinishOperation(arg_descr, status);
253   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
254   TF_OperationDescription* id_descr =
255       TF_NewOperation(function_graph, "Identity", "id");
256   TF_SetAttrType(id_descr, "T", TF_INT32);
257   TF_AddInput(id_descr, {arg, 0});
258   TF_Operation* id = TF_FinishOperation(id_descr, status);
259   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
260   TF_Output input{arg, 0};
261   TF_Output output{id, 0};
262   TF_Function* fn =
263       TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
264                          &output, nullptr, nullptr, "test", status);
265   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
266   TF_DeleteGraph(function_graph);
267   TFE_ContextOptions* opts = TFE_NewContextOptions();
268   TFE_Context* ctx = TFE_NewContext(opts, status);
269   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
270   TFE_DeleteContextOptions(opts);
271   TFE_ContextAddFunction(ctx, fn, status);
272   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
273   TF_DeleteFunction(fn);
274 
275   for (bool async : {false, true, false}) {
276     TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
277     TFE_Executor* executor = TFE_NewExecutor(
278         /*is_async=*/async, /*enable_streaming_enqueue=*/true);
279     TFE_ContextSetExecutorForThread(ctx, executor);
280     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
281 
282     TF_Tensor* t =
283         TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
284     *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
285     TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
286     ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
287     TF_DeleteTensor(t);
288 
289     TFE_Op* op = TFE_NewOp(ctx, "ident", status);
290     ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
291     TFE_OpAddInput(op, h, status);
292     ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
293 
294     std::vector<TFE_TensorHandle*> result;
295     result.push_back(nullptr);
296     int num_retvals = 1;
297     TFE_Execute(op, result.data(), &num_retvals, status);
298     TFE_DeleteOp(op);
299     ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
300     ASSERT_EQ(num_retvals, 1);
301 
302     TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
303     ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
304     EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
305     TFE_ContextSetExecutorForThread(ctx, old_executor);
306     TFE_ExecutorWaitForAllPendingNodes(executor, status);
307     ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
308     TFE_DeleteExecutor(executor);
309     TFE_DeleteExecutor(old_executor);
310     TFE_DeleteTensorHandle(h);
311     TF_DeleteTensor(r);
312     TFE_DeleteTensorHandle(result[0]);
313   }
314   TFE_ContextRemoveFunction(ctx, "ident", status);
315   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
316   TFE_DeleteContext(ctx);
317   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
318   TF_DeleteStatus(status);
319 }
320 
Executor_MatMul_CPU(bool async)321 void Executor_MatMul_CPU(bool async) {
322   TF_Status* status = TF_NewStatus();
323   TFE_ContextOptions* opts = TFE_NewContextOptions();
324   TFE_Context* ctx = TFE_NewContext(opts, status);
325   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
326   TFE_DeleteContextOptions(opts);
327 
328   TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
329   TFE_Executor* executor = TFE_NewExecutor(
330       /*is_async=*/async, /*enable_streaming_enqueue=*/true);
331   TFE_ContextSetExecutorForThread(ctx, executor);
332 
333   TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
334   TFE_Op* matmul = MatMulOp(ctx, m, m);
335   TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
336   int num_retvals = 2;
337   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
338   EXPECT_EQ(1, num_retvals);
339   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
340   TFE_DeleteOp(matmul);
341   TFE_DeleteTensorHandle(m);
342   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
343 
344   TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
345   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
346   TFE_DeleteTensorHandle(retvals[0]);
347   TFE_ContextSetExecutorForThread(ctx, old_executor);
348   TFE_ExecutorWaitForAllPendingNodes(executor, status);
349   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
350   TFE_DeleteExecutor(executor);
351   TFE_DeleteExecutor(old_executor);
352   TFE_DeleteContext(ctx);
353   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
354   float product[4] = {0};
355   EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
356   memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
357   TF_DeleteTensor(t);
358   EXPECT_EQ(7, product[0]);
359   EXPECT_EQ(10, product[1]);
360   EXPECT_EQ(15, product[2]);
361   EXPECT_EQ(22, product[3]);
362   TF_DeleteStatus(status);
363 }
// Run the MatMul executor scenario with a synchronous executor...
TEST(CAPI, Executor_MatMul_CPU) {
  Executor_MatMul_CPU(/*async=*/false);
}
// ...and with an asynchronous one.
TEST(CAPI, Executor_MatMul_CPUAsync) {
  Executor_MatMul_CPU(/*async=*/true);
}
366 
Deleter(void * data,size_t unused,void * tensor_handle)367 void Deleter(void* data, size_t unused, void* tensor_handle) {
368   TFE_DeleteTensorHandle(static_cast<TFE_TensorHandle*>(tensor_handle));
369 }
370 
// For every device, this test:
//   1. copies the test matrix to the device;
//   2. wraps the copy's raw device buffer in a new 2x2 float handle;
//   3. copies that alias back to CPU:0 and checks the bytes match the source.
TEST(CAPI, TensorHandleOnDeviceMemory) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* ctx_opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(ctx_opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(ctx_opts);

  // Reference values every round-trip is compared against.
  TFE_TensorHandle* source = TestMatrixTensorHandle(ctx);
  TF_Tensor* source_tensor = TFE_TensorHandleResolve(source, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float* expected = static_cast<float*>(TF_TensorData(source_tensor));

  TF_DeviceList* device_list = TFE_ContextListDevices(ctx, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const int device_count = TF_DeviceListCount(device_list);
  for (int i = 0; i < device_count; ++i) {
    const char* device_name = TF_DeviceListName(device_list, i, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    TFE_TensorHandle* on_device =
        TFE_TensorHandleCopyToDevice(source, ctx, device_name, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    void* device_ptr = TFE_TensorHandleDevicePointer(on_device, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    size_t device_bytes = TFE_TensorHandleDeviceMemorySize(on_device, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // The alias takes `on_device` as the Deleter argument, so destroying the
    // alias also destroys `on_device`.
    int64_t shape[] = {2, 2};
    TFE_TensorHandle* alias = TFE_NewTensorHandleFromDeviceMemory(
        ctx, device_name, TF_FLOAT, shape, 2, device_ptr, device_bytes,
        &Deleter, on_device, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    TFE_TensorHandle* host_copy =
        TFE_TensorHandleCopyToDevice(alias, ctx, "CPU:0", status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TF_Tensor* host_tensor = TFE_TensorHandleResolve(host_copy, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    const float* actual = static_cast<const float*>(TF_TensorData(host_tensor));
    EXPECT_EQ(0, memcmp(expected, actual, 4 * sizeof(float)));

    TF_DeleteTensor(host_tensor);
    TFE_DeleteTensorHandle(alias);  // Also deletes `on_device` via Deleter.
    TFE_DeleteTensorHandle(host_copy);
  }

  TF_DeleteDeviceList(device_list);
  TF_DeleteTensor(source_tensor);
  TFE_DeleteTensorHandle(source);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
416 
// Device-type and device-id queries must reject a null handle. Each should
// report INVALID_ARGUMENT with the message "Invalid handle" and return its
// sentinel value (nullptr and -1 respectively).
TEST(CAPI, TensorHandleNullptr) {
  TFE_TensorHandle* null_handle = nullptr;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Status* s = status.get();

  const char* type_name = TFE_TensorHandleDeviceType(null_handle, s);
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
  ASSERT_EQ(type_name, nullptr);
  ASSERT_EQ("Invalid handle", string(TF_Message(s)));

  // Reset before probing the second query.
  TF_SetStatus(s, TF_OK, "");

  int id = TFE_TensorHandleDeviceID(null_handle, s);
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
  ASSERT_EQ(id, -1);
  ASSERT_EQ("Invalid handle", string(TF_Message(s)));
}
434 
// A freshly created handle reports device type CPU and id 0. When a GPU is
// available, a Shape op pinned to it produces an output reporting GPU, id 0.
TEST(CAPI, TensorHandleDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Status* s = status.get();
  TFE_ContextOptions* ctx_opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(ctx_opts, s);
  TFE_DeleteContextOptions(ctx_opts);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TFE_TensorHandle* cpu_handle = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(cpu_handle, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(cpu_handle, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_EQ(0, device_id) << device_id;

  // The GPU half is skipped when no GPU device is present.
  string gpu_name;
  if (GetDeviceName(ctx, &gpu_name, "GPU")) {
    TFE_TensorHandle* gpu_handle =
        TFE_TensorHandleCopyToDevice(cpu_handle, ctx, gpu_name.c_str(), s);
    ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);

    TFE_Op* shape_op = ShapeOp(ctx, gpu_handle);
    TFE_OpSetDevice(shape_op, gpu_name.c_str(), s);
    ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);
    TFE_TensorHandle* shape_out[1];
    int num_outputs = 1;
    TFE_Execute(shape_op, &shape_out[0], &num_outputs, s);
    ASSERT_TRUE(TF_GetCode(s) == TF_OK) << TF_Message(s);

    device_type = TFE_TensorHandleDeviceType(shape_out[0], s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;

    device_id = TFE_TensorHandleDeviceID(shape_out[0], s);
    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    ASSERT_EQ(0, device_id) << device_id;

    TFE_DeleteOp(shape_op);
    TFE_DeleteTensorHandle(shape_out[0]);
    TFE_DeleteTensorHandle(gpu_handle);
  }

  // Drain pending work on the thread's executor before tearing down.
  TFE_DeleteTensorHandle(cpu_handle);
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
486 
// A handle created without an explicit device reports CPU, id 0. Copying it
// explicitly to /device:CPU:0 yields a handle with the same device type and id.
TEST(CAPI, TensorHandleDefaults) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
      h_default, ctx, "/device:CPU:0", status.get());
  // Check the copy before using its result. Without this check, a failed copy
  // surfaces only through the next call on `h_cpu`, and that call overwrites
  // the status, hiding the copy's own error message (a null handle, for
  // example, reports just "Invalid handle").
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
  int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id_cpu) << device_id_cpu;

  // Drain pending work on the thread's executor before tearing down.
  TFE_DeleteTensorHandle(h_default);
  TFE_DeleteTensorHandle(h_cpu);
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
520 
521 }  // namespace
522 }  // namespace tensorflow
523