1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api_experimental.h"
17
18 #include <string.h>
19
20 #include "tensorflow/c/eager/c_api.h"
21 #include "tensorflow/c/eager/c_api_test_util.h"
22 #include "tensorflow/core/lib/monitoring/collection_registry.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/platform/str_util.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/test_benchmark.h"
28
29 using tensorflow::string;
30
31 namespace tensorflow {
32 namespace {
33
HasSubstr(absl::string_view base,absl::string_view substr)34 static bool HasSubstr(absl::string_view base, absl::string_view substr) {
35 bool ok = absl::StrContains(base, substr);
36 EXPECT_TRUE(ok) << base << ", expected substring " << substr;
37 return ok;
38 }
39
// Verifies the label-less counter C API end to end: a counter can be
// created, its single cell incremented, and its value observed both through
// the cell accessor and through the process-wide CollectionRegistry;
// deleting the counter unregisters the metric.
TEST(CAPI, MonitoringCounter0) {
  TF_Status* status = TF_NewStatus();
  auto* counter =
      TFE_MonitoringNewCounter0("test/counter", status, "description");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  auto* cell = TFE_MonitoringGetCellCounter0(counter);
  TFE_MonitoringCounterCellIncrementBy(cell, 1);
  EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 1);
  // The counter must also be visible in a registry snapshot, keyed by the
  // name it was created with.
  auto* collection_registry = monitoring::CollectionRegistry::Default();
  monitoring::CollectionRegistry::CollectMetricsOptions options;
  std::unique_ptr<monitoring::CollectedMetrics> metrics =
      collection_registry->CollectMetrics(options);

  EXPECT_EQ("test/counter",
            metrics->point_set_map.at("test/counter")->metric_name);
  EXPECT_EQ(
      1, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);

  // Increments accumulate: 1 + 5 == 6, both in the cell and in a fresh
  // registry snapshot.
  TFE_MonitoringCounterCellIncrementBy(cell, 5);
  EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 6);
  metrics = collection_registry->CollectMetrics(options);
  EXPECT_EQ(
      6, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);

  // Deleting the counter removes its point set from subsequent collections.
  TFE_MonitoringDeleteCounter0(counter);
  metrics = collection_registry->CollectMetrics(options);
  EXPECT_EQ(metrics->point_set_map.end(),
            metrics->point_set_map.find("test/counter"));
}
70
// Exercises the labeled counter variants: a one-label counter and a
// two-label counter can each be created, have a cell fetched for concrete
// label values, and be incremented independently.
TEST(CAPI, MonitoringCounterMultiple) {
  TF_Status* status = TF_NewStatus();

  auto* counter1 = TFE_MonitoringNewCounter1("test/counter1", status,
                                             "description", "label1");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* counter1_cell = TFE_MonitoringGetCellCounter1(counter1, "test");
  TFE_MonitoringCounterCellIncrementBy(counter1_cell, 1);
  EXPECT_EQ(1, TFE_MonitoringCounterCellValue(counter1_cell));

  auto* counter2 = TFE_MonitoringNewCounter2("test/counter2", status,
                                             "description", "label1", "label2");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  auto* counter2_cell = TFE_MonitoringGetCellCounter2(counter2, "foo", "bar");
  TFE_MonitoringCounterCellIncrementBy(counter2_cell, 2);
  EXPECT_EQ(2, TFE_MonitoringCounterCellValue(counter2_cell));

  TFE_MonitoringDeleteCounter1(counter1);
  TFE_MonitoringDeleteCounter2(counter2);
}
91
// Verifies the label-less int gauge C API: setting the cell overwrites the
// value (it does not accumulate), and the value is visible both through the
// cell accessor and through the CollectionRegistry.
TEST(CAPI, MonitoringGauge0) {
  TF_Status* status = TF_NewStatus();
  auto* gauge = TFE_MonitoringNewIntGauge0("test/gauge", status, "test");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell = TFE_MonitoringGetCellIntGauge0(gauge);
  TFE_MonitoringIntGaugeCellSet(cell, 1);
  EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(cell), 1);
  auto* collection_registry = monitoring::CollectionRegistry::Default();
  monitoring::CollectionRegistry::CollectMetricsOptions options;
  std::unique_ptr<monitoring::CollectedMetrics> metrics =
      collection_registry->CollectMetrics(options);

  EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
  EXPECT_EQ(1,
            metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);

  // A second Set replaces the previous value: the snapshot reads 5, not 6.
  TFE_MonitoringIntGaugeCellSet(cell, 5);
  metrics = collection_registry->CollectMetrics(options);
  EXPECT_EQ(5,
            metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
  TFE_MonitoringDeleteIntGauge0(gauge);
  TF_DeleteStatus(status);
}
115
// Exercises the labeled gauge variants: a one-label bool gauge and a
// two-label string gauge. The string gauge's value is read back through a
// TF_Buffer.
TEST(CAPI, MonitoringMultipleGauge) {
  TF_Status* status = TF_NewStatus();
  auto* gauge1 =
      TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo");
  TFE_MonitoringBoolGaugeCellSet(cell1, true);
  EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1));
  TFE_MonitoringDeleteBoolGauge1(gauge1);

  auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test",
                                               "label1", "label2");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar");
  TFE_MonitoringStringGaugeCellSet(cell2, "str");
  // Allocate via TF_NewBuffer() (not `new TF_Buffer`, which leaves the
  // struct's members — notably data_deallocator — indeterminate; the
  // deallocator is invoked by TF_DeleteBuffer). This also matches the
  // sampler tests below, which already use TF_NewBuffer().
  TF_Buffer* buf = TF_NewBuffer();
  TFE_MonitoringStringGaugeCellValue(cell2, buf);
  string data(static_cast<const char*>(buf->data), buf->length);
  TF_DeleteBuffer(buf);
  EXPECT_EQ(data, "str");
  TFE_MonitoringDeleteStringGauge2(gauge2);
  TF_DeleteStatus(status);
}
139
// Verifies the label-less sampler C API: samples added to the cell are
// reflected in the histogram sum reported by the CollectionRegistry.
TEST(CAPI, MonitoringSampler0) {
  TF_Status* status = TF_NewStatus();
  // Exponential buckets: initial width 1.0, growth factor 2.0, 2 buckets.
  auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
  auto* sampler =
      TFE_MonitoringNewSampler0("test/sampler", buckets, status, "test");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell = TFE_MonitoringGetCellSampler0(sampler);
  TFE_MonitoringSamplerCellAdd(cell, 1.0);
  auto* collection_registry = monitoring::CollectionRegistry::Default();
  monitoring::CollectionRegistry::CollectMetricsOptions options;
  std::unique_ptr<monitoring::CollectedMetrics> metrics =
      collection_registry->CollectMetrics(options);

  EXPECT_EQ("test/sampler",
            metrics->point_set_map.at("test/sampler")->metric_name);
  EXPECT_EQ(1.0, metrics->point_set_map.at("test/sampler")
                     ->points.at(0)
                     ->histogram_value.sum());

  // Samples accumulate in the histogram: sum becomes 1.0 + 5.0 == 6.0.
  TFE_MonitoringSamplerCellAdd(cell, 5.0);
  metrics = collection_registry->CollectMetrics(options);
  EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
                     ->points.at(0)
                     ->histogram_value.sum());
  TFE_MonitoringDeleteBuckets(buckets);
  TFE_MonitoringDeleteSampler0(sampler);
  TF_DeleteStatus(status);
}
168
// Exercises the labeled sampler variants: a one-label and a two-label
// sampler sharing a single buckets object. Each cell's state is read back
// by serializing it into a TF_Buffer and parsing it as a HistogramProto.
TEST(CAPI, MonitoringMultipleSampler) {
  TF_Status* status = TF_NewStatus();
  auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
  auto* sampler1 = TFE_MonitoringNewSampler1("test/sampler1", buckets, status,
                                             "test", "label1");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell1 = TFE_MonitoringGetCellSampler1(sampler1, "foo");
  TFE_MonitoringSamplerCellAdd(cell1, 1.0);
  TFE_MonitoringSamplerCellAdd(cell1, 2.0);
  TF_Buffer* result1 = TF_NewBuffer();
  TFE_MonitoringSamplerCellValue(cell1, result1);
  tensorflow::HistogramProto histogram1;
  EXPECT_TRUE(histogram1.ParseFromString(
      {reinterpret_cast<const char*>(result1->data), result1->length}));
  // Sum of the two samples added above.
  EXPECT_EQ(histogram1.sum(), 3.0);
  TF_DeleteBuffer(result1);
  TFE_MonitoringDeleteSampler1(sampler1);

  // The buckets object is reused for a second, two-label sampler.
  auto* sampler2 = TFE_MonitoringNewSampler2("test/sampler2", buckets, status,
                                             "test", "label1", "label2");
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* cell2 = TFE_MonitoringGetCellSampler2(sampler2, "foo", "bar");
  TFE_MonitoringSamplerCellAdd(cell2, 2.0);
  TFE_MonitoringSamplerCellAdd(cell2, 3.0);
  TF_Buffer* result2 = TF_NewBuffer();
  TFE_MonitoringSamplerCellValue(cell2, result2);
  tensorflow::HistogramProto histogram2;
  EXPECT_TRUE(histogram2.ParseFromString(
      {reinterpret_cast<const char*>(result2->data), result2->length}));
  EXPECT_EQ(histogram2.sum(), 5.0);
  TF_DeleteBuffer(result2);
  TFE_MonitoringDeleteSampler2(sampler2);

  // Buckets are deleted only after every sampler using them is gone.
  TFE_MonitoringDeleteBuckets(buckets);
  TF_DeleteStatus(status);
}
205
// A freshly created cancellation manager reports "not cancelled" until
// StartCancel is invoked, after which it reports "cancelled".
TEST(CAPI, CancellationManager) {
  TFE_CancellationManager* manager = TFE_NewCancellationManager();
  EXPECT_FALSE(TFE_CancellationManagerIsCancelled(manager));
  TFE_CancellationManagerStartCancel(manager);
  EXPECT_TRUE(TFE_CancellationManagerIsCancelled(manager));
  TFE_DeleteCancellationManager(manager);
}
213
// Verifies that a context and the thread-local executor installed on it can
// be destroyed in either order without crashing.
TEST(CAPI, ExecutorContextDestructionOrder) {
  TF_Status* status = TF_NewStatus();

  // Builds a context with a custom thread-local executor, then tears both
  // down; `context_first` selects which object is deleted first.
  auto build_and_destroy = [status](bool context_first) {
    TFE_ContextOptions* opts = TFE_NewContextOptions();
    TFE_Context* ctx = TFE_NewContext(opts, status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    TFE_DeleteContextOptions(opts);
    TFE_Executor* executor = TFE_NewExecutor(
        /*is_async=*/false, /*enable_streaming_enqueue=*/true);
    TFE_ContextSetExecutorForThread(ctx, executor);

    if (context_first) {
      TFE_DeleteContext(ctx);
      TFE_DeleteExecutor(executor);
    } else {
      TFE_DeleteExecutor(executor);
      TFE_DeleteContext(ctx);
    }
  };

  build_and_destroy(/*context_first=*/true);
  build_and_destroy(/*context_first=*/false);
  TF_DeleteStatus(status);
}
244
// Builds an identity function from a TF_Graph, registers it on an eager
// context, and executes it under both synchronous and asynchronous
// thread-local executors, checking that the int32 input round-trips.
TEST(CAPI, Function_ident_CPU) {
  // First create a simple identity function.
  TF_Graph* function_graph = TF_NewGraph();
  TF_OperationDescription* arg_descr =
      TF_NewOperation(function_graph, "Placeholder", "arg");
  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
  TF_Status* status = TF_NewStatus();
  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_OperationDescription* id_descr =
      TF_NewOperation(function_graph, "Identity", "id");
  TF_SetAttrType(id_descr, "T", TF_INT32);
  TF_AddInput(id_descr, {arg, 0});
  TF_Operation* id = TF_FinishOperation(id_descr, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_Output input{arg, 0};
  TF_Output output{id, 0};
  // Package the graph as a function named "ident" with one input and one
  // output; the graph itself can then be released.
  TF_Function* fn =
      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
                         &output, nullptr, nullptr, "test", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_DeleteGraph(function_graph);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContextOptions(opts);
  TFE_ContextAddFunction(ctx, fn, status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_DeleteFunction(fn);

  // Runs the function sync, then async, then sync again on the same
  // context; the extra `false` deliberately re-checks sync after async.
  for (bool async : {false, true, false}) {
    // Swap in a fresh per-thread executor; the previous one is restored
    // (and drained) before teardown below.
    TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
    TFE_Executor* executor = TFE_NewExecutor(
        /*is_async=*/async, /*enable_streaming_enqueue=*/true);
    TFE_ContextSetExecutorForThread(ctx, executor);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Scalar int32 tensor holding 42 as the function input.
    TF_Tensor* t =
        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    TF_DeleteTensor(t);

    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    TFE_OpAddInput(op, h, status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);

    std::vector<TFE_TensorHandle*> result;
    result.push_back(nullptr);
    int num_retvals = 1;
    TFE_Execute(op, result.data(), &num_retvals, status);
    TFE_DeleteOp(op);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    ASSERT_EQ(num_retvals, 1);

    // The identity function must return the input value unchanged.
    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
    // Restore the original executor and wait for the temporary one to
    // drain before deleting either.
    TFE_ContextSetExecutorForThread(ctx, old_executor);
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteExecutor(executor);
    TFE_DeleteExecutor(old_executor);
    TFE_DeleteTensorHandle(h);
    TF_DeleteTensor(r);
    TFE_DeleteTensorHandle(result[0]);
  }
  TFE_ContextRemoveFunction(ctx, "ident", status);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TFE_DeleteContext(ctx);
  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
  TF_DeleteStatus(status);
}
320
// Runs a 2x2 matmul of the test matrix with itself through a custom
// per-thread executor (synchronous or asynchronous, per `async`) and checks
// the numeric result. The teardown order matters: the previous executor is
// restored and the custom one drained before anything is deleted.
void Executor_MatMul_CPU(bool async) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_Executor* executor = TFE_NewExecutor(
      /*is_async=*/async, /*enable_streaming_enqueue=*/true);
  TFE_ContextSetExecutorForThread(ctx, executor);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TFE_Op* matmul = MatMulOp(ctx, m, m);
  // Pass room for two retvals; MatMul must report exactly one.
  TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
  int num_retvals = 2;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(1, num_retvals);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(m);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Resolve (blocks on the async executor if needed) before teardown.
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retvals[0]);
  // Restore the original executor, drain the custom one, then delete both
  // executors before the context.
  TFE_ContextSetExecutorForThread(ctx, old_executor);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  TFE_DeleteExecutor(old_executor);
  TFE_DeleteContext(ctx);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  // Expected entries of the 2x2 product (the test matrix squared —
  // presumably [[1,2],[3,4]]; values come from TestMatrixTensorHandle).
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);
  TF_DeleteStatus(status);
}
// Synchronous and asynchronous executor variants of the matmul test above.
TEST(CAPI, Executor_MatMul_CPU) { Executor_MatMul_CPU(false); }
TEST(CAPI, Executor_MatMul_CPUAsync) { Executor_MatMul_CPU(true); }
366
Deleter(void * data,size_t unused,void * tensor_handle)367 void Deleter(void* data, size_t unused, void* tensor_handle) {
368 TFE_DeleteTensorHandle(static_cast<TFE_TensorHandle*>(tensor_handle));
369 }
370
// For every available device: copies the test matrix there, wraps the
// copy's raw device buffer in a new handle via
// TFE_NewTensorHandleFromDeviceMemory (with `Deleter` releasing the
// original copy), then copies the aliased handle back to the CPU and checks
// the bytes match the original.
TEST(CAPI, TensorHandleOnDeviceMemory) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
  TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float* m_float = static_cast<float*>(TF_TensorData(m_data));
  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  int num_devices = TF_DeviceListCount(devices);
  for (int d = 0; d < num_devices; ++d) {
    const char* name = TF_DeviceListName(devices, d, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_TensorHandle* copy = TFE_TensorHandleCopyToDevice(m, ctx, name, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // Raw pointer and size of `copy`'s buffer on the target device.
    void* data = TFE_TensorHandleDevicePointer(copy, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    size_t size = TFE_TensorHandleDeviceMemorySize(copy, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    int64_t dims[] = {2, 2};
    // Wrap the same memory in a second handle; `copy` is handed to
    // `Deleter` as the deallocator argument, transferring ownership.
    TFE_TensorHandle* copy_aliased = TFE_NewTensorHandleFromDeviceMemory(
        ctx, name, TF_FLOAT, dims, 2, data, size, &Deleter, copy, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_TensorHandle* on_host =
        TFE_TensorHandleCopyToDevice(copy_aliased, ctx, "CPU:0", status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TF_Tensor* resolved = TFE_TensorHandleResolve(on_host, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // The round-tripped bytes must equal the original 2x2 float matrix.
    const float* resolved_data =
        static_cast<const float*>(TF_TensorData(resolved));
    EXPECT_EQ(0, memcmp(m_float, resolved_data, 4 * sizeof(float)));
    TF_DeleteTensor(resolved);
    TFE_DeleteTensorHandle(copy_aliased);  // Note that this will delete copy.
    TFE_DeleteTensorHandle(on_host);
  }
  TF_DeleteDeviceList(devices);
  TF_DeleteTensor(m_data);
  TFE_DeleteTensorHandle(m);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}
416
// Device-type and device-id queries on a null tensor handle must fail with
// TF_INVALID_ARGUMENT, the message "Invalid handle", and sentinel results
// (nullptr device type, device id -1).
TEST(CAPI, TensorHandleNullptr) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* handle = nullptr;

  const char* device_type = TFE_TensorHandleDeviceType(handle, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(nullptr, device_type);
  ASSERT_EQ(string(TF_Message(status.get())), "Invalid handle");

  // Reset to TF_OK so the next call's error is observed independently.
  TF_SetStatus(status.get(), TF_OK, "");

  int device_id = TFE_TensorHandleDeviceID(handle, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(-1, device_id);
  ASSERT_EQ(string(TF_Message(status.get())), "Invalid handle");
}
434
// Checks TFE_TensorHandleDeviceType/DeviceID: a host tensor reports a CPU
// device with id 0; when a GPU is present, the output of an op placed on
// the GPU reports a GPU device. Skips the GPU half when none is available.
TEST(CAPI, TensorHandleDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // Run Shape on the GPU so its output handle lives on the GPU.
    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;

    device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_EQ(0, device_id) << device_id;

    TFE_DeleteOp(shape_op);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  TFE_DeleteTensorHandle(hcpu);
  // Drain the thread's executor before deleting the context.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
486
// Checks that a handle created with no explicit placement reports a CPU
// device with id 0, and that an explicit copy to "/device:CPU:0" reports
// the same device type and id.
TEST(CAPI, TensorHandleDefaults) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  // An explicit CPU copy must report the same placement.
  TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
      h_default, ctx, "/device:CPU:0", status.get());
  const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
  int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id_cpu) << device_id_cpu;

  TFE_DeleteTensorHandle(h_default);
  TFE_DeleteTensorHandle(h_cpu);
  // Drain the thread's executor before deleting the context.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
520
521 } // namespace
522 } // namespace tensorflow
523