1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <iterator>
21 #include <memory>
22 #include <vector>
23
24 #include "tensorflow/c/c_test_util.h"
25 #include "tensorflow/c/tf_status.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/cc/saved_model/tag_constants.h"
28 #include "tensorflow/core/example/example.pb.h"
29 #include "tensorflow/core/example/feature.pb.h"
30 #include "tensorflow/core/framework/api_def.pb.h"
31 #include "tensorflow/core/framework/common_shape_fns.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/kernel_def.pb.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/node_def_util.h"
36 #include "tensorflow/core/framework/op.h"
37 #include "tensorflow/core/framework/op_def.pb.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/framework/partial_tensor_shape.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/framework/tensor_shape.pb.h"
42 #include "tensorflow/core/framework/types.pb.h"
43 #include "tensorflow/core/graph/tensor_id.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/io/path.h"
46 #include "tensorflow/core/platform/path.h"
47 #include "tensorflow/core/platform/resource_loader.h"
48 #include "tensorflow/core/platform/str_util.h"
49 #include "tensorflow/core/platform/strcat.h"
50 #include "tensorflow/core/platform/test.h"
51 #include "tensorflow/core/protobuf/error_codes.pb.h"
52 #include "tensorflow/core/protobuf/meta_graph.pb.h"
53 #include "tensorflow/core/util/equal_graph_def.h"
54
55 namespace tensorflow {
56 TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
57 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
58
59 namespace {
60
ExpectHasSubstr(StringPiece s,StringPiece expected)61 static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
62 EXPECT_TRUE(absl::StrContains(s, expected))
63 << "'" << s << "' does not contain '" << expected << "'";
64 }
65
66 // Returns the GPU device name if there is one (with arbitrary tie breaking if
67 // there are more than one), or "" otherwise.
GPUDeviceName(TF_Session * session)68 string GPUDeviceName(TF_Session* session) {
69 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
70 TF_NewStatus(), TF_DeleteStatus);
71 TF_Status* s = status.get();
72 std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> list(
73 TF_SessionListDevices(session, s), TF_DeleteDeviceList);
74 TF_DeviceList* device_list = list.get();
75
76 CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
77
78 const int num_devices = TF_DeviceListCount(device_list);
79 LOG(INFO) << "There are " << num_devices << " devices.";
80 for (int i = 0; i < num_devices; ++i) {
81 const char* device_name = TF_DeviceListName(device_list, i, s);
82 CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
83 const char* device_type = TF_DeviceListType(device_list, i, s);
84 CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
85 LOG(INFO) << "Device " << i << " has name " << device_name << ", type "
86 << device_type;
87 if (string(device_type) == DEVICE_GPU) {
88 return device_name;
89 }
90 }
91 // No GPU device found.
92 return "";
93 }
94
GPUDeviceName()95 string GPUDeviceName() {
96 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
97 TF_NewStatus(), TF_DeleteStatus);
98 TF_Status* s = status.get();
99 std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph(TF_NewGraph(),
100 TF_DeleteGraph);
101
102 TF_SessionOptions* opts = TF_NewSessionOptions();
103 TF_Session* sess = TF_NewSession(graph.get(), opts, s);
104 TF_DeleteSessionOptions(opts);
105
106 const string gpu_device_name = GPUDeviceName(sess);
107 TF_DeleteSession(sess, s);
108 CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
109 return gpu_device_name;
110 }
111
TEST(CAPI, Version) {
  // The library must report a non-empty version string.
  const char* version = TF_Version();
  EXPECT_STRNE("", version);
}
113
TEST(CAPI, Status) {
  TF_Status* status = TF_NewStatus();

  // A fresh status is OK and carries an empty message.
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  EXPECT_EQ(string(), TF_Message(status));

  // The code and message set via TF_SetStatus are reported back.
  TF_SetStatus(status, TF_CANCELLED, "cancel");
  EXPECT_EQ(TF_CANCELLED, TF_GetCode(status));
  EXPECT_EQ(string("cancel"), TF_Message(status));

  TF_DeleteStatus(status);
}
123
Deallocator(void * data,size_t,void * arg)124 void Deallocator(void* data, size_t, void* arg) {
125 tensorflow::cpu_allocator()->DeallocateRaw(data);
126 *reinterpret_cast<bool*>(arg) = true;
127 }
128
TEST(CAPI, Tensor) {
  // Wrap a caller-allocated 2x3 float buffer in a TF_Tensor, check the
  // metadata getters, and check that deletion invokes the deallocator.
  const int byte_count = 6 * sizeof(float);
  void* raw = tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES,
                                                       byte_count);
  float* buffer = static_cast<float*>(raw);
  int64_t shape[] = {2, 3};
  bool freed = false;
  TF_Tensor* tensor = TF_NewTensor(TF_FLOAT, shape, 2, buffer, byte_count,
                                   &Deallocator, &freed);
  EXPECT_FALSE(freed);

  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
  EXPECT_EQ(2, TF_NumDims(tensor));
  EXPECT_EQ(shape[0], TF_Dim(tensor, 0));
  EXPECT_EQ(shape[1], TF_Dim(tensor, 1));
  EXPECT_EQ(byte_count, TF_TensorByteSize(tensor));
  EXPECT_EQ(static_cast<void*>(buffer), TF_TensorData(tensor));

  TF_DeleteTensor(tensor);
  EXPECT_TRUE(freed);
}
148
// Deallocator that deliberately does nothing; used where the tensor has no
// buffer that needs releasing.
void NoOpDeallocator(void* /*data*/, size_t /*len*/, void* /*arg*/) {
  // Intentionally empty.
}
150
TEST(CAPI, MalformedTensor) {
  // Regression test for https://github.com/tensorflow/tensorflow/issues/7394.
  // num_dims = 0 means a scalar, which needs at least 4 bytes of float data,
  // so a null, zero-length buffer must be rejected with a null result.
  TF_Tensor* const tensor =
      TF_NewTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                   /*data=*/nullptr, /*len=*/0, &NoOpDeallocator,
                   /*deallocator_arg=*/nullptr);
  ASSERT_TRUE(tensor == nullptr);
}
159
TEST(CAPI, AllocateTensor) {
  // Let TF_AllocateTensor provide the storage for a 2x3 float tensor and
  // verify the metadata it reports.
  int64_t shape[] = {2, 3};
  const int byte_count = 6 * sizeof(float);
  TF_Tensor* tensor = TF_AllocateTensor(TF_FLOAT, shape, 2, byte_count);

  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
  EXPECT_EQ(2, TF_NumDims(tensor));
  EXPECT_EQ(shape[0], TF_Dim(tensor, 0));
  EXPECT_EQ(shape[1], TF_Dim(tensor, 1));
  EXPECT_EQ(byte_count, TF_TensorByteSize(tensor));
  EXPECT_EQ(6, TF_TensorElementCount(tensor));

  TF_DeleteTensor(tensor);
}
172
TEST(CAPI, MaybeMove) {
  // A tensor wrapping caller-provided memory must not be handed out by
  // TF_TensorMaybeMove.
  const int byte_count = 6 * sizeof(float);
  float* buffer = static_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
      EIGEN_MAX_ALIGN_BYTES, byte_count));
  int64_t shape[] = {2, 3};
  bool freed = false;
  TF_Tensor* tensor = TF_NewTensor(TF_FLOAT, shape, 2, buffer, byte_count,
                                   &Deallocator, &freed);

  TF_Tensor* moved = TF_TensorMaybeMove(tensor);
  // It is unsafe to move memory TF might not own.
  ASSERT_TRUE(moved == nullptr);

  TF_DeleteTensor(tensor);
  EXPECT_TRUE(freed);
}
188
// Loads test_op1.so and checks its op list, then checks that TF_GetAllOpList
// includes the "TestCApi" op.
TEST(CAPI, LibraryLoadFunctions) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  {
    // Load the library.
    TF_Status* status = TF_NewStatus();
    string lib_path = tensorflow::GetDataDependencyFilepath(
        tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
    TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
    TF_Code code = TF_GetCode(status);
    string status_msg(TF_Message(status));
    TF_DeleteStatus(status);
    ASSERT_EQ(TF_OK, code) << status_msg;

    // Test op list. The returned buffer is not freed here; presumably its
    // storage belongs to `lib` (TODO confirm), so parse it before releasing
    // the handle. Releasing before the ASSERT avoids leaking on failure.
    TF_Buffer op_list_buf = TF_GetOpList(lib);
    tensorflow::OpList op_list;
    EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
    TF_DeleteLibraryHandle(lib);
    ASSERT_EQ(op_list.op_size(), 1);
    EXPECT_EQ("TestCApi1", op_list.op(0).name());
  }
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  {
    TF_Buffer* op_list_buffer = TF_GetAllOpList();
    tensorflow::OpList op_list;
    // Previously the parse result was ignored; a corrupt buffer would then
    // surface only as a confusing "op not found" failure.
    const bool parsed =
        op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
    // The proto holds its own copy, so the buffer can be freed before any
    // ASSERT that might return early.
    TF_DeleteBuffer(op_list_buffer);
    ASSERT_TRUE(parsed);
    ASSERT_GE(op_list.op_size(), 1);
    const auto& ops = op_list.op();
    const bool found = std::any_of(ops.begin(), ops.end(),
                                   [](const tensorflow::OpDef& op_def) {
                                     return op_def.name() == "TestCApi";
                                   });
    EXPECT_TRUE(found);
  }
}
230
TestEncodeDecode(int line,const std::vector<string> & data)231 void TestEncodeDecode(int line, const std::vector<string>& data) {
232 const tensorflow::int64 n = data.size();
233 Status status;
234 for (const std::vector<tensorflow::int64>& dims :
235 std::vector<std::vector<tensorflow::int64>>{
236 {n}, {1, n}, {n, 1}, {n / 2, 2}}) {
237 // Create C++ Tensor
238 Tensor src(tensorflow::DT_STRING, TensorShape(dims));
239 for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
240 src.flat<tstring>()(i) = data[i];
241 }
242 TF_Tensor* dst = TF_TensorFromTensor(src, &status);
243 ASSERT_TRUE(status.ok()) << status.error_message();
244
245 // Convert back to a C++ Tensor and ensure we get expected output.
246 Tensor output;
247 ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line;
248 ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
249 for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
250 ASSERT_EQ(data[i], output.flat<tstring>()(i)) << line;
251 }
252
253 TF_DeleteTensor(dst);
254 }
255 }
256
TEST(CAPI, TensorEncodeDecodeStrings) {
  // Empty, single-element and multi-element inputs.
  TestEncodeDecode(__LINE__, {});
  TestEncodeDecode(__LINE__, {"hello"});
  TestEncodeDecode(__LINE__, {"the", "quick", "brown", "fox", "jumped", "over"});

  // Short strings mixed with a 1000-character one.
  const string long_str(1000, 'a');
  TestEncodeDecode(__LINE__, {"small", long_str, "small2"});
}
266
TEST(CAPI, SessionOptions) {
  // Smoke test: session options can be created and destroyed.
  TF_DeleteSessionOptions(TF_NewSessionOptions());
}
271
TEST(CAPI, DeprecatedSession) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* options = TF_NewSessionOptions();
  TF_DeprecatedSession* session = TF_NewDeprecatedSession(options, status);
  TF_DeleteSessionOptions(options);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Running a session that was never given a graph must fail with
  // INVALID_ARGUMENT and this exact message.
  TF_Buffer* options_buf = TF_NewBufferFromString("", 0);
  TF_Buffer* metadata_buf = TF_NewBuffer();
  TF_Run(session, options_buf, /*input_names=*/nullptr, /*inputs=*/nullptr,
         /*ninputs=*/0, /*output_names=*/nullptr, /*outputs=*/nullptr,
         /*noutputs=*/0, /*target_oper_names=*/nullptr, /*ntargets=*/0,
         metadata_buf, status);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("Session was not created with a graph before Run()!",
            string(TF_Message(status)));
  TF_DeleteBuffer(metadata_buf);
  TF_DeleteBuffer(options_buf);

  TF_DeleteDeprecatedSession(session, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
}
294
TEST(CAPI, DataTypeEnum) {
  // Every TF_DataType must share its numeric value with the C++ DataType.
  struct DTypePair {
    TF_DataType c_type;
    tensorflow::DataType cc_type;
  };
  const DTypePair kPairs[] = {
      {TF_FLOAT, tensorflow::DT_FLOAT},
      {TF_DOUBLE, tensorflow::DT_DOUBLE},
      {TF_INT32, tensorflow::DT_INT32},
      {TF_UINT8, tensorflow::DT_UINT8},
      {TF_INT16, tensorflow::DT_INT16},
      {TF_INT8, tensorflow::DT_INT8},
      {TF_STRING, tensorflow::DT_STRING},
      {TF_COMPLEX64, tensorflow::DT_COMPLEX64},
      {TF_INT64, tensorflow::DT_INT64},
      {TF_BOOL, tensorflow::DT_BOOL},
      {TF_QINT8, tensorflow::DT_QINT8},
      {TF_QUINT8, tensorflow::DT_QUINT8},
      {TF_QINT32, tensorflow::DT_QINT32},
      {TF_BFLOAT16, tensorflow::DT_BFLOAT16},
      {TF_QINT16, tensorflow::DT_QINT16},
      {TF_QUINT16, tensorflow::DT_QUINT16},
      {TF_UINT16, tensorflow::DT_UINT16},
      {TF_COMPLEX128, tensorflow::DT_COMPLEX128},
      {TF_HALF, tensorflow::DT_HALF},
  };
  for (const DTypePair& pair : kPairs) {
    EXPECT_EQ(pair.c_type, static_cast<TF_DataType>(pair.cc_type));
  }

  // TF_COMPLEX is expected to be the same value as TF_COMPLEX64.
  EXPECT_EQ(TF_COMPLEX, TF_COMPLEX64);

  // Sizes reported by the C API agree with the C++ runtime.
  EXPECT_EQ(TF_DataTypeSize(TF_DOUBLE),
            tensorflow::DataTypeSize(tensorflow::DT_DOUBLE));
  EXPECT_EQ(TF_DataTypeSize(TF_STRING),
            tensorflow::DataTypeSize(tensorflow::DT_STRING));
  // Test with invalid type; should always return 0 as documented
  EXPECT_EQ(TF_DataTypeSize(static_cast<TF_DataType>(0)), 0);
}
323
TEST(CAPI, StatusEnum) {
  // Every TF_Code must share its numeric value with tensorflow::error::Code.
  struct CodePair {
    TF_Code c_code;
    tensorflow::error::Code cc_code;
  };
  const CodePair kPairs[] = {
      {TF_OK, tensorflow::error::OK},
      {TF_CANCELLED, tensorflow::error::CANCELLED},
      {TF_UNKNOWN, tensorflow::error::UNKNOWN},
      {TF_INVALID_ARGUMENT, tensorflow::error::INVALID_ARGUMENT},
      {TF_DEADLINE_EXCEEDED, tensorflow::error::DEADLINE_EXCEEDED},
      {TF_NOT_FOUND, tensorflow::error::NOT_FOUND},
      {TF_ALREADY_EXISTS, tensorflow::error::ALREADY_EXISTS},
      {TF_PERMISSION_DENIED, tensorflow::error::PERMISSION_DENIED},
      {TF_UNAUTHENTICATED, tensorflow::error::UNAUTHENTICATED},
      {TF_RESOURCE_EXHAUSTED, tensorflow::error::RESOURCE_EXHAUSTED},
      {TF_FAILED_PRECONDITION, tensorflow::error::FAILED_PRECONDITION},
      {TF_ABORTED, tensorflow::error::ABORTED},
      {TF_OUT_OF_RANGE, tensorflow::error::OUT_OF_RANGE},
      {TF_UNIMPLEMENTED, tensorflow::error::UNIMPLEMENTED},
      {TF_INTERNAL, tensorflow::error::INTERNAL},
      {TF_UNAVAILABLE, tensorflow::error::UNAVAILABLE},
      {TF_DATA_LOSS, tensorflow::error::DATA_LOSS},
  };
  for (const CodePair& pair : kPairs) {
    EXPECT_EQ(pair.c_code, static_cast<TF_Code>(pair.cc_code));
  }
}
353
TEST(CAPI, GetAllOpList) {
  // The registered-op list exposed by the C API must parse as a non-empty
  // OpList.
  TF_Buffer* serialized = TF_GetAllOpList();
  tensorflow::OpList ops;
  const bool parsed = ops.ParseFromArray(serialized->data, serialized->length);
  EXPECT_TRUE(parsed);
  EXPECT_GT(ops.op_size(), 0);
  TF_DeleteBuffer(serialized);
}
361
// Exercises TF_GraphSetTensorShape, TF_GraphGetTensorShape and
// TF_GraphGetTensorNumDims on a placeholder output and on a scalar constant.
TEST(CAPI, SetShape) {
  // Owned via unique_ptr so an ASSERT_* that returns early cannot leak them.
  // Declaration order makes the graph be deleted before the status, as the
  // original explicit cleanup did.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_owner(
      TF_NewGraph(), TF_DeleteGraph);
  TF_Status* s = status.get();
  TF_Graph* graph = graph_owner.get();

  TF_Operation* feed = Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Output feed_out_0 = TF_Output{feed, 0};
  int num_dims;

  // Fetch the shape, it should be completely unknown.
  num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(-1, num_dims);

  // Set the shape to be unknown, expect no change.
  TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(-1, num_dims);

  // Set the shape to be 2 x Unknown
  int64_t dims[] = {2, -1};
  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Fetch the shape and validate it is 2 by -1. This must be fatal:
  // `num_dims` is used below as the element count for the 2-entry
  // `returned_dims` buffer.
  num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_EQ(2, num_dims);

  // Read the dimensions back into a buffer sized for the rank checked above.
  int64_t returned_dims[2];
  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(dims[0], returned_dims[0]);
  EXPECT_EQ(dims[1], returned_dims[1]);

  // Set to a new valid shape: [2, 3]
  dims[1] = 3;
  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Fetch and see that the new value is returned.
  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(dims[0], returned_dims[0]);
  EXPECT_EQ(dims[1], returned_dims[1]);

  // Try to set 'unknown' with unknown rank on the shape and see that
  // it doesn't change.
  TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(2, num_dims);
  EXPECT_EQ(2, returned_dims[0]);
  EXPECT_EQ(3, returned_dims[1]);

  // Try to set 'unknown' with same rank on the shape and see that
  // it doesn't change.
  dims[0] = -1;
  dims[1] = -1;
  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  // Fetch and see that the previous [2, 3] value is still returned.
  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(2, num_dims);
  EXPECT_EQ(2, returned_dims[0]);
  EXPECT_EQ(3, returned_dims[1]);

  // Try to fetch a shape with the wrong num_dims
  TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);

  // Try to set an invalid shape (cannot change 2x3 to a 2x5).
  dims[1] = 5;
  TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);

  // Test for a scalar.
  TF_Operation* three = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Output three_out_0 = TF_Output{three, 0};

  num_dims = TF_GraphGetTensorNumDims(graph, three_out_0, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(0, num_dims);
  TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Graph and status are released by their unique_ptr owners.
}
458
// Builds feed -> AddN(feed, 3) -> Neg and checks the operation query
// functions, consumer tracking, GraphDef/NodeDef serialization, lookup by
// name, and iteration over the graph's operations.
TEST(CAPI, Graph) {
  // Owned via unique_ptr so an ASSERT_* that returns early cannot leak them.
  // Declaration order makes the graph be deleted before the status.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_owner(
      TF_NewGraph(), TF_DeleteGraph);
  TF_Status* s = status.get();
  TF_Graph* graph = graph_owner.get();

  // Make a placeholder operation.
  TF_Operation* feed = Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Test TF_Operation*() query functions.
  EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
  EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed)));
  EXPECT_EQ(string(""), string(TF_OperationDevice(feed)));
  EXPECT_EQ(1, TF_OperationNumOutputs(feed));
  EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{feed, 0}));
  EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s));
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(0, TF_OperationNumInputs(feed));
  EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
  EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
  EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));

  tensorflow::AttrValue attr_value;
  ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s);
  EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);

  // Test not found errors in TF_Operation*() query functions.
  EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s));
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));

  ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
  EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."),
            string(TF_Message(s)));

  // Make a constant oper with the scalar "3".
  TF_Operation* three = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add oper.
  TF_Operation* add = Add(feed, three, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Test TF_Operation*() query functions.
  EXPECT_EQ(string("add"), string(TF_OperationName(add)));
  EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add)));
  EXPECT_EQ(string(""), string(TF_OperationDevice(add)));
  EXPECT_EQ(1, TF_OperationNumOutputs(add));
  EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{add, 0}));
  EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s));
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(2, TF_OperationNumInputs(add));
  EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s));
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 0}));
  EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 1}));
  TF_Output add_in_0 = TF_OperationInput(TF_Input{add, 0});
  EXPECT_EQ(feed, add_in_0.oper);
  EXPECT_EQ(0, add_in_0.index);
  TF_Output add_in_1 = TF_OperationInput(TF_Input{add, 1});
  EXPECT_EQ(three, add_in_1.oper);
  EXPECT_EQ(0, add_in_1.index);
  EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{add, 0}));
  EXPECT_EQ(0, TF_OperationNumControlInputs(add));
  EXPECT_EQ(0, TF_OperationNumControlOutputs(add));

  ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << TF_Message(s);
  EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
  ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s);
  EXPECT_EQ(attr_value.i(), 2);

  // Placeholder oper now has a consumer.
  ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
  TF_Input feed_port;
  EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Output{feed, 0}, &feed_port, 1));
  EXPECT_EQ(add, feed_port.oper);
  EXPECT_EQ(0, feed_port.index);

  // The scalar const oper also has a consumer.
  ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{three, 0}));
  TF_Input three_port;
  EXPECT_EQ(1,
            TF_OperationOutputConsumers(TF_Output{three, 0}, &three_port, 1));
  EXPECT_EQ(add, three_port.oper);
  EXPECT_EQ(1, three_port.index);

  // Serialize to GraphDef.
  GraphDef graph_def;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def));

  // Validate GraphDef is what we expect.
  bool found_placeholder = false;
  bool found_scalar_const = false;
  bool found_add = false;
  for (const auto& n : graph_def.node()) {
    if (IsPlaceholder(n)) {
      EXPECT_FALSE(found_placeholder);
      found_placeholder = true;
    } else if (IsScalarConst(n, 3)) {
      EXPECT_FALSE(found_scalar_const);
      found_scalar_const = true;
    } else if (IsAddN(n, 2)) {
      EXPECT_FALSE(found_add);
      found_add = true;
    } else {
      ADD_FAILURE() << "Unexpected NodeDef: " << n.DebugString();
    }
  }
  EXPECT_TRUE(found_placeholder);
  EXPECT_TRUE(found_scalar_const);
  EXPECT_TRUE(found_add);

  // Add another oper to the graph.
  TF_Operation* neg = Neg(add, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Serialize to NodeDef.
  NodeDef node_def;
  ASSERT_TRUE(GetNodeDef(neg, &node_def));

  // Validate NodeDef is what we expect.
  EXPECT_TRUE(IsNeg(node_def, "add"));

  // Serialize to GraphDef.
  GraphDef graph_def2;
  ASSERT_TRUE(GetGraphDef(graph, &graph_def2));

  // Compare with first GraphDef + added NodeDef.
  NodeDef* added_node = graph_def.add_node();
  *added_node = node_def;
  EXPECT_EQ(graph_def.DebugString(), graph_def2.DebugString());

  // Look up some nodes by name.
  TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg");
  EXPECT_TRUE(neg == neg2);
  NodeDef node_def2;
  ASSERT_TRUE(GetNodeDef(neg2, &node_def2));
  EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());

  TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed");
  EXPECT_TRUE(feed == feed2);
  ASSERT_TRUE(GetNodeDef(feed, &node_def));
  ASSERT_TRUE(GetNodeDef(feed2, &node_def2));
  EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());

  // Test iterating through the nodes of a graph.
  found_placeholder = false;
  found_scalar_const = false;
  found_add = false;
  bool found_neg = false;
  size_t pos = 0;
  TF_Operation* oper;
  while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
    if (oper == feed) {
      EXPECT_FALSE(found_placeholder);
      found_placeholder = true;
    } else if (oper == three) {
      EXPECT_FALSE(found_scalar_const);
      found_scalar_const = true;
    } else if (oper == add) {
      EXPECT_FALSE(found_add);
      found_add = true;
    } else if (oper == neg) {
      EXPECT_FALSE(found_neg);
      found_neg = true;
    } else {
      ASSERT_TRUE(GetNodeDef(oper, &node_def));
      ADD_FAILURE() << "Unexpected Node: " << node_def.DebugString();
    }
  }
  EXPECT_TRUE(found_placeholder);
  EXPECT_TRUE(found_scalar_const);
  EXPECT_TRUE(found_add);
  EXPECT_TRUE(found_neg);

  // Graph and status are released by their unique_ptr owners.
}
636
// Builds neg(add(one, two)) and checks that TF_UpdateEdge rewires neg's input
// from `add` to `one`.
TEST(CAPI, UpdateEdge) {
  // Owned via unique_ptr so an ASSERT_* that returns early cannot leak them.
  // Declaration order makes the graph be deleted before the status.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_owner(
      TF_NewGraph(), TF_DeleteGraph);
  TF_Status* s = status.get();
  TF_Graph* graph = graph_owner.get();

  // Make two scalar constants.
  TF_Operation* one = ScalarConst(1, graph, s, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* two = ScalarConst(2, graph, s, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add oper.
  TF_Operation* add = Add(one, two, graph, s, "add");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add another oper to the graph.
  TF_Operation* neg = Neg(add, graph, s, "neg");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  NodeDef node_def_neg;
  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("add"), node_def_neg.input(0));

  // Rewire neg's first input to `one`. The status was previously ignored, so
  // a failed update surfaced only as a confusing input mismatch below.
  TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("one:0"), node_def_neg.input(0));

  // Graph and status are released by their unique_ptr owners.
}
670
671 /*
672 TODO(skyewm): this test currently DCHECKs, change to bad status
673
674 TEST(CAPI, InputFromDifferentGraphError) {
675 TF_Status* s = TF_NewStatus();
676 TF_Graph* g1 = TF_NewGraph();
677 TF_Graph* g2 = TF_NewGraph();
678
679 TF_Operation* feed = Placeholder(g1, s);
680 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
681
682 // Attempt to create node in g2 with input from g1
683 Neg(feed, g2, s);
684 EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
685 EXPECT_STREQ("foo", TF_Message(s));
686
687 TF_DeleteGraph(g1);
688 TF_DeleteGraph(g2);
689 TF_DeleteStatus(s);
690 }
691 */
692
TEST(CAPI,ImportGraphDef)693 TEST(CAPI, ImportGraphDef) {
694 TF_Status* s = TF_NewStatus();
695 TF_Graph* graph = TF_NewGraph();
696
697 // Create a simple graph.
698 Placeholder(graph, s);
699 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
700 ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
701 TF_Operation* oper = ScalarConst(3, graph, s);
702 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
703 ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
704 Neg(oper, graph, s);
705 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
706 ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
707
708 // Export to a GraphDef.
709 TF_Buffer* graph_def = TF_NewBuffer();
710 TF_GraphToGraphDef(graph, graph_def, s);
711 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
712
713 // Import it, with a prefix, in a fresh graph.
714 TF_DeleteGraph(graph);
715 graph = TF_NewGraph();
716 TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
717 TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
718 TF_GraphImportGraphDef(graph, graph_def, opts, s);
719 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
720
721 TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar");
722 TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed");
723 TF_Operation* neg = TF_GraphOperationByName(graph, "imported/neg");
724 ASSERT_TRUE(scalar != nullptr);
725 ASSERT_TRUE(feed != nullptr);
726 ASSERT_TRUE(neg != nullptr);
727
728 // Test basic structure of the imported graph.
729 EXPECT_EQ(0, TF_OperationNumInputs(scalar));
730 EXPECT_EQ(0, TF_OperationNumInputs(feed));
731 ASSERT_EQ(1, TF_OperationNumInputs(neg));
732 TF_Output neg_input = TF_OperationInput({neg, 0});
733 EXPECT_EQ(scalar, neg_input.oper);
734 EXPECT_EQ(0, neg_input.index);
735
736 // Test that we can't see control edges involving the source and sink nodes.
737 TF_Operation* control_ops[100];
738 EXPECT_EQ(0, TF_OperationNumControlInputs(scalar));
739 EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100));
740 EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar));
741 EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100));
742
743 EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
744 EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100));
745 EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
746 EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100));
747
748 EXPECT_EQ(0, TF_OperationNumControlInputs(neg));
749 EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100));
750 EXPECT_EQ(0, TF_OperationNumControlOutputs(neg));
751 EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100));
752
753 // Import it again, with an input mapping, return outputs, and a return
754 // operation, into the same graph.
755 TF_DeleteImportGraphDefOptions(opts);
756 opts = TF_NewImportGraphDefOptions();
757 TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
758 TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
759 TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
760 TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
761 EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
762 TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
763 EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
764 TF_ImportGraphDefResults* results =
765 TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
766 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
767
768 TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
769 TF_Operation* feed2 = TF_GraphOperationByName(graph, "imported2/feed");
770 TF_Operation* neg2 = TF_GraphOperationByName(graph, "imported2/neg");
771 ASSERT_TRUE(scalar2 != nullptr);
772 ASSERT_TRUE(feed2 != nullptr);
773 ASSERT_TRUE(neg2 != nullptr);
774
775 // Check input mapping
776 neg_input = TF_OperationInput({neg, 0});
777 EXPECT_EQ(scalar, neg_input.oper);
778 EXPECT_EQ(0, neg_input.index);
779
780 // Check return outputs
781 TF_Output* return_outputs;
782 int num_return_outputs;
783 TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs,
784 &return_outputs);
785 ASSERT_EQ(2, num_return_outputs);
786 EXPECT_EQ(feed2, return_outputs[0].oper);
787 EXPECT_EQ(0, return_outputs[0].index);
788 EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
789 EXPECT_EQ(0, return_outputs[1].index);
790
791 // Check return operation
792 TF_Operation** return_opers;
793 int num_return_opers;
794 TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers,
795 &return_opers);
796 ASSERT_EQ(1, num_return_opers);
797 EXPECT_EQ(scalar2, return_opers[0]); // not remapped
798
799 TF_DeleteImportGraphDefResults(results);
800
801 // Import again, with control dependencies, into the same graph.
802 TF_DeleteImportGraphDefOptions(opts);
803 opts = TF_NewImportGraphDefOptions();
804 TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
805 TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
806 TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
807 TF_GraphImportGraphDef(graph, graph_def, opts, s);
808 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
809
810 TF_Operation* scalar3 = TF_GraphOperationByName(graph, "imported3/scalar");
811 TF_Operation* feed3 = TF_GraphOperationByName(graph, "imported3/feed");
812 TF_Operation* neg3 = TF_GraphOperationByName(graph, "imported3/neg");
813 ASSERT_TRUE(scalar3 != nullptr);
814 ASSERT_TRUE(feed3 != nullptr);
815 ASSERT_TRUE(neg3 != nullptr);
816
817 // Check that newly-imported scalar and feed have control deps (neg3 will
818 // inherit them from input)
819 TF_Operation* control_inputs[100];
820 int num_control_inputs = TF_OperationGetControlInputs(
821 scalar3, control_inputs, TF_OperationNumControlInputs(scalar3));
822 ASSERT_EQ(2, num_control_inputs);
823 EXPECT_EQ(feed, control_inputs[0]);
824 EXPECT_EQ(feed2, control_inputs[1]);
825
826 num_control_inputs = TF_OperationGetControlInputs(
827 feed3, control_inputs, TF_OperationNumControlInputs(feed3));
828 ASSERT_EQ(2, num_control_inputs);
829 EXPECT_EQ(feed, control_inputs[0]);
830 EXPECT_EQ(feed2, control_inputs[1]);
831
832 // Export to a graph def so we can import a graph with control dependencies
833 TF_DeleteBuffer(graph_def);
834 graph_def = TF_NewBuffer();
835 TF_GraphToGraphDef(graph, graph_def, s);
836 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
837
838 // Import again, with remapped control dependency, into the same graph
839 TF_DeleteImportGraphDefOptions(opts);
840 opts = TF_NewImportGraphDefOptions();
841 TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
842 TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
843 TF_GraphImportGraphDef(graph, graph_def, opts, s);
844 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
845
846 TF_Operation* scalar4 =
847 TF_GraphOperationByName(graph, "imported4/imported3/scalar");
848 TF_Operation* feed4 =
849 TF_GraphOperationByName(graph, "imported4/imported2/feed");
850
851 // Check that imported `imported3/scalar` has remapped control dep from
852 // original graph and imported control dep
853 num_control_inputs = TF_OperationGetControlInputs(
854 scalar4, control_inputs, TF_OperationNumControlInputs(scalar4));
855 ASSERT_EQ(2, num_control_inputs);
856 EXPECT_EQ(feed, control_inputs[0]);
857 EXPECT_EQ(feed4, control_inputs[1]);
858
859 TF_DeleteImportGraphDefOptions(opts);
860 TF_DeleteBuffer(graph_def);
861
862 // Can add nodes to the imported graph without trouble.
863 Add(feed, scalar, graph, s);
864 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
865
866 TF_DeleteGraph(graph);
867 TF_DeleteStatus(s);
868 }
869
// Imports a GraphDef into a fresh graph while requesting return outputs, and
// checks that the returned TF_Outputs point at the imported nodes.
TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Create a graph with three nodes: feed (a placeholder), scalar (the
  // constant 3) and neg (the negation of scalar).
  Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
  TF_Operation* oper = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
  Neg(oper, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);

  // Export to a GraphDef.
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(graph, graph_def, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Import it in a fresh graph, requesting feed:0 and scalar:0 back.
  TF_DeleteGraph(graph);
  graph = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
  TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
  // One constant sizes the output array, the expectation and the import call,
  // so the three cannot drift apart.
  constexpr int kNumReturnOutputs = 2;
  EXPECT_EQ(kNumReturnOutputs, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
  TF_Output return_outputs[kNumReturnOutputs];
  TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
                                          return_outputs, kNumReturnOutputs, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
  TF_Operation* feed = TF_GraphOperationByName(graph, "feed");
  TF_Operation* neg = TF_GraphOperationByName(graph, "neg");
  ASSERT_TRUE(scalar != nullptr);
  ASSERT_TRUE(feed != nullptr);
  ASSERT_TRUE(neg != nullptr);

  // Check return outputs, in the order they were requested.
  EXPECT_EQ(feed, return_outputs[0].oper);
  EXPECT_EQ(0, return_outputs[0].index);
  EXPECT_EQ(scalar, return_outputs[1].oper);
  EXPECT_EQ(0, return_outputs[1].index);

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
920
// Imports a GraphDef twice into the same graph. The second import has one
// input mapping that matches a node in the GraphDef and one that does not.
// Only the non-matching mapping should be reported as missing/unused.
TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Create a graph with three nodes: feed (a placeholder), scalar (the
  // constant 3) and neg (the negation of scalar).
  Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
  TF_Operation* oper = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
  Neg(oper, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);

  // Export to a GraphDef.
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(graph, graph_def, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Import it in a fresh graph.
  TF_DeleteGraph(graph);
  graph = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_GraphImportGraphDef(graph, graph_def, opts, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // The mappings below use this op as their destination, so it must exist.
  TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
  ASSERT_TRUE(scalar != nullptr);

  // Import it again, into the same graph under the "imported" prefix, with
  // one mapping whose source ("scalar") exists in the GraphDef and one whose
  // source ("fake") does not.
  TF_DeleteImportGraphDefOptions(opts);
  opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
  TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
  TF_ImportGraphDefResults* results =
      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Check unused input mappings: only "fake":0 should be reported.
  int num_unused_input_mappings;
  const char** src_names;
  int* src_indexes;
  TF_ImportGraphDefResultsMissingUnusedInputMappings(
      results, &num_unused_input_mappings, &src_names, &src_indexes);
  ASSERT_EQ(1, num_unused_input_mappings);
  EXPECT_EQ(string("fake"), string(src_names[0]));
  EXPECT_EQ(0, src_indexes[0]);

  TF_DeleteImportGraphDefResults(results);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
976
// Runs `input + 2` through a CSession. Then grows the graph with a Neg node
// after the session already exists and runs up to the new node.
TEST(CAPI, Session) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Graph under test: input + 2, where `input` is fed at run time.
  TF_Operation* input = Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* const_two = ScalarConst(2, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* sum = Add(input, const_two, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  CSession csession(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // First run: feed 3 and fetch the int32 scalar sum.
  csession.SetInputs({{input, Int32Tensor(3)}});
  csession.SetOutputs({sum});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Tensor* result = csession.output_tensor(0);
  ASSERT_TRUE(result != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(result));
  EXPECT_EQ(0, TF_NumDims(result));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(result));
  EXPECT_EQ(3 + 2, *static_cast<int32*>(TF_TensorData(result)));

  // Extend the graph after session creation; the same session is then
  // expected to run the newly added node.
  TF_Operation* negated = Neg(sum, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Second run: feed 7 and fetch -(7 + 2).
  csession.SetInputs({{input, Int32Tensor(7)}});
  csession.SetOutputs({negated});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  result = csession.output_tensor(0);
  ASSERT_TRUE(result != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(result));
  EXPECT_EQ(0, TF_NumDims(result));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(result));
  EXPECT_EQ(-(7 + 2), *static_cast<int32*>(TF_TensorData(result)));

  // Tear down the session before the graph it references.
  csession.CloseAndDelete(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1033
1034 // If `device` is non-empty, run Min op on that device.
1035 // Otherwise run it on the default device (CPU).
RunMinTest(const string & device,bool use_XLA)1036 void RunMinTest(const string& device, bool use_XLA) {
1037 TF_Status* s = TF_NewStatus();
1038 TF_Graph* graph = TF_NewGraph();
1039
1040 // Make a placeholder operation.
1041 TF_Operation* feed = Placeholder(graph, s);
1042 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1043
1044 // Make a constant operation with the scalar "0", for axis.
1045 TF_Operation* one = ScalarConst(0, graph, s);
1046 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1047
1048 // Create a session for this graph.
1049 CSession csession(graph, s, use_XLA);
1050 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1051
1052 if (!device.empty()) {
1053 LOG(INFO) << "Setting op Min on device " << device;
1054 }
1055 TF_Operation* min = MinWithDevice(feed, one, graph, device, s);
1056 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1057
1058 // Run the graph.
1059 csession.SetInputs({{feed, Int32Tensor({3, 2, 5})}});
1060 csession.SetOutputs({min});
1061 csession.Run(s);
1062 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1063 TF_Tensor* out = csession.output_tensor(0);
1064 ASSERT_TRUE(out != nullptr);
1065 EXPECT_EQ(TF_INT32, TF_TensorType(out));
1066 EXPECT_EQ(0, TF_NumDims(out)); // scalar
1067 ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
1068 int32* output_contents = static_cast<int32*>(TF_TensorData(out));
1069 EXPECT_EQ(2, *output_contents);
1070
1071 // Clean up
1072 csession.CloseAndDelete(s);
1073 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1074 TF_DeleteGraph(graph);
1075 TF_DeleteStatus(s);
1076 }
1077
TEST(CAPI, Session_Min_CPU) {
  // Default device, XLA disabled.
  RunMinTest(/*device=*/"", /*use_XLA=*/false);
}
1079
TEST(CAPI, Session_Min_XLA_CPU) {
  // Default device, XLA enabled.
  RunMinTest(/*device=*/"", /*use_XLA=*/true);
}
1081
TEST(CAPI, Session_Min_GPU) {
  // Runs only when a GPU is present; otherwise the test body is a no-op.
  const string device_name = GPUDeviceName();
  if (!device_name.empty()) {
    RunMinTest(device_name, /*use_XLA=*/false);
  }
}
1089
TEST(CAPI, Session_Min_XLA_GPU) {
  // Runs only when a GPU is present; otherwise the test body is a no-op.
  const string device_name = GPUDeviceName();
  if (!device_name.empty()) {
    RunMinTest(device_name, /*use_XLA=*/true);
  }
}
1097
// Exercises partial runs: computes A + 2 + B in two TF_SessionPRun calls that
// share one partial-run handle.
TEST(CAPI, SessionPRun) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Construct the graph: A + 2 + B
  TF_Operation* a = Placeholder(graph, s, "A");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* b = Placeholder(graph, s, "B");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* two = ScalarConst(2, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* plus2 = Add(a, two, graph, s, "plus2");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* plusB = Add(plus2, b, graph, s, "plusB");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Setup a session and a partial run handle. The partial run will allow
  // computation of A + 2 + B in two phases (calls to TF_SessionPRun):
  // 1. Feed A and get (A+2)
  // 2. Feed B and get (A+2)+B
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* sess = TF_NewSession(graph, opts, s);
  TF_DeleteSessionOptions(opts);
  // Session creation can fail; don't hand a bad session to PRunSetup.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Output feeds[] = {TF_Output{a, 0}, TF_Output{b, 0}};
  TF_Output fetches[] = {TF_Output{plus2, 0}, TF_Output{plusB, 0}};

  const char* handle = nullptr;
  TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
                      TF_ARRAYSIZE(fetches), nullptr, 0, &handle, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Feed A and fetch A + 2.
  TF_Output feeds1[] = {TF_Output{a, 0}};
  TF_Output fetches1[] = {TF_Output{plus2, 0}};
  TF_Tensor* feedValues1[] = {Int32Tensor(1)};
  TF_Tensor* fetchValues1[1];
  TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1,
                 1, nullptr, 0, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(3, *(static_cast<int32*>(TF_TensorData(fetchValues1[0]))));
  TF_DeleteTensor(feedValues1[0]);
  TF_DeleteTensor(fetchValues1[0]);

  // Feed B and fetch (A + 2) + B.
  TF_Output feeds2[] = {TF_Output{b, 0}};
  TF_Output fetches2[] = {TF_Output{plusB, 0}};
  TF_Tensor* feedValues2[] = {Int32Tensor(4)};
  TF_Tensor* fetchValues2[1];
  TF_SessionPRun(sess, handle, feeds2, feedValues2, 1, fetches2, fetchValues2,
                 1, nullptr, 0, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  EXPECT_EQ(7, *(static_cast<int32*>(TF_TensorData(fetchValues2[0]))));
  TF_DeleteTensor(feedValues2[0]);
  TF_DeleteTensor(fetchValues2[0]);

  // Clean up.
  TF_DeletePRunHandle(handle);
  TF_DeleteSession(sess, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1165
TEST(CAPI, ShapeInferenceError) {
  // Adding an operation whose shape cannot be inferred must make
  // TF_FinishOperation fail. Provoke that by adding an int8 vector of length 2
  // to one of length 3, which cannot be added together.
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Backing bytes shared by both constants (each reads only what it needs).
  const char shared_values[] = {1, 2, 3};

  const int64_t len2_dims[] = {2};
  unique_tensor_ptr len2_tensor(
      Int8Tensor(len2_dims, TF_ARRAYSIZE(len2_dims), shared_values),
      TF_DeleteTensor);
  TF_Operation* len2_const = Const(len2_tensor.get(), graph, status, "vec2");
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const int64_t len3_dims[] = {3};
  unique_tensor_ptr len3_tensor(
      Int8Tensor(len3_dims, TF_ARRAYSIZE(len3_dims), shared_values),
      TF_DeleteTensor);
  TF_Operation* len3_const = Const(len3_tensor.get(), graph, status, "vec3");
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // The failing add: no operation is produced and the status is an error.
  TF_Operation* sum = AddNoCheck(len2_const, len3_const, graph, status);
  ASSERT_NE(TF_OK, TF_GetCode(status));
  ASSERT_TRUE(sum == nullptr);

  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}
1194
// TF_GraphGetOpDef returns the registry's serialized OpDef for a known op and
// TF_NOT_FOUND for an unregistered one.
TEST(CAPI, GetOpDef) {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_Buffer* buffer = TF_NewBuffer();

  // Registered op: the buffer must hold exactly the serialized OpDef that the
  // global registry has for "Add".
  TF_GraphGetOpDef(graph, "Add", buffer, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const OpDef* expected_op_def = nullptr;
  TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
  string expected_serialized;
  expected_op_def->SerializeToString(&expected_serialized);
  string actual_string(reinterpret_cast<const char*>(buffer->data),
                       buffer->length);
  EXPECT_EQ(expected_serialized, actual_string);

  // Unregistered op: reported as TF_NOT_FOUND with a descriptive message.
  TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status);
  EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status));
  ExpectHasSubstr(TF_Message(status),
                  "Op type not registered 'MyFakeOp' in binary");

  TF_DeleteBuffer(buffer);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}
1219
// Exposes the strings in `v` as two parallel C arrays: (*ptrs)[i] points at
// the bytes of v[i] and (*lens)[i] holds v[i].size(). This is the layout that
// SetViaStringList passes to TF_SetAttrStringList. The pointers borrow from
// `v`, so `v` must outlive every use of *ptrs.
void StringVectorToArrays(const std::vector<string>& v,
                          std::unique_ptr<const void*[]>* ptrs,
                          std::unique_ptr<size_t[]>* lens) {
  const size_t count = v.size();
  ptrs->reset(new const void*[count]);
  lens->reset(new size_t[count]);
  size_t slot = 0;
  for (const string& entry : v) {
    ptrs->get()[slot] = entry.data();
    lens->get()[slot] = entry.size();
    ++slot;
  }
}
1230
1231 class CApiColocationTest : public ::testing::Test {
1232 protected:
CApiColocationTest()1233 CApiColocationTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
1234
SetUp()1235 void SetUp() override {
1236 feed1_ = Placeholder(graph_, s_, "feed1");
1237 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1238
1239 feed2_ = Placeholder(graph_, s_, "feed2");
1240 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1241
1242 constant_ = ScalarConst(10, graph_, s_);
1243 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1244
1245 desc_ = TF_NewOperation(graph_, "AddN", "add");
1246 TF_Output inputs[] = {{feed1_, 0}, {constant_, 0}};
1247 TF_AddInputList(desc_, inputs, TF_ARRAYSIZE(inputs));
1248 }
1249
~CApiColocationTest()1250 ~CApiColocationTest() override {
1251 TF_DeleteGraph(graph_);
1252 TF_DeleteStatus(s_);
1253 }
1254
SetViaStringList(TF_OperationDescription * desc,const std::vector<string> & list)1255 void SetViaStringList(TF_OperationDescription* desc,
1256 const std::vector<string>& list) {
1257 std::unique_ptr<const void*[]> list_ptrs;
1258 std::unique_ptr<size_t[]> list_lens;
1259 StringVectorToArrays(list, &list_ptrs, &list_lens);
1260 TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(),
1261 list_lens.get(), list.size());
1262 }
1263
SetViaProto(TF_OperationDescription * desc,const std::vector<string> & list)1264 void SetViaProto(TF_OperationDescription* desc,
1265 const std::vector<string>& list) {
1266 tensorflow::AttrValue attr;
1267 for (const string& v : list) {
1268 attr.mutable_list()->add_s(v);
1269 }
1270 string bytes;
1271 attr.SerializeToString(&bytes);
1272 TF_SetAttrValueProto(desc, tensorflow::kColocationAttrName, bytes.data(),
1273 bytes.size(), s_);
1274 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1275 }
1276
VerifyCollocation(TF_Operation * op,const std::vector<string> & expected)1277 void VerifyCollocation(TF_Operation* op,
1278 const std::vector<string>& expected) {
1279 TF_AttrMetadata m =
1280 TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
1281 if (expected.empty()) {
1282 ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1283 EXPECT_EQ("Operation 'add' has no attr named '_class'.",
1284 string(TF_Message(s_)));
1285 return;
1286 }
1287 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1288 EXPECT_EQ(1, m.is_list);
1289 EXPECT_EQ(expected.size(), m.list_size);
1290 EXPECT_EQ(TF_ATTR_STRING, m.type);
1291 std::vector<void*> values(expected.size());
1292 std::vector<size_t> lens(expected.size());
1293 std::unique_ptr<char[]> storage(new char[m.total_size]);
1294 TF_OperationGetAttrStringList(op, tensorflow::kColocationAttrName,
1295 values.data(), lens.data(), expected.size(),
1296 storage.get(), m.total_size, s_);
1297 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1298 for (int i = 0; i < expected.size(); ++i) {
1299 EXPECT_EQ(expected[i],
1300 string(static_cast<const char*>(values[i]), lens[i]));
1301 }
1302 }
1303
FinishAndVerify(TF_OperationDescription * desc,const std::vector<string> & expected)1304 void FinishAndVerify(TF_OperationDescription* desc,
1305 const std::vector<string>& expected) {
1306 TF_Operation* op = TF_FinishOperation(desc_, s_);
1307 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1308 VerifyCollocation(op, expected);
1309 }
1310
1311 TF_Status* s_;
1312 TF_Graph* graph_;
1313 TF_Operation* feed1_;
1314 TF_Operation* feed2_;
1315 TF_Operation* constant_;
1316 TF_OperationDescription* desc_;
1317 };
1318
TEST_F(CApiColocationTest, ColocateWith) {
  // A single TF_ColocateWith call is expected to produce a one-entry list.
  const std::vector<string> expected = {"loc:@feed1"};
  TF_ColocateWith(desc_, feed1_);
  FinishAndVerify(desc_, expected);
}
1323
TEST_F(CApiColocationTest, StringList) {
  // A colocation list set as a string-list attr is expected back verbatim.
  const std::vector<string> locations = {"loc:@feed1"};
  SetViaStringList(desc_, locations);
  FinishAndVerify(desc_, locations);
}
1328
TEST_F(CApiColocationTest, Proto) {
  // A colocation list set via a serialized AttrValue is expected back verbatim.
  const std::vector<string> locations = {"loc:@feed1"};
  SetViaProto(desc_, locations);
  FinishAndVerify(desc_, locations);
}
1333
TEST_F(CApiColocationTest, ColocateWith_StringList) {
  // A string-list attr set after TF_ColocateWith is expected to replace it.
  const std::vector<string> replacement = {"loc:@feed2"};
  TF_ColocateWith(desc_, feed1_);
  SetViaStringList(desc_, replacement);
  FinishAndVerify(desc_, replacement);
}
1339
TEST_F(CApiColocationTest, ColocateWith_Proto) {
  // A proto attr set after TF_ColocateWith is expected to replace it.
  const std::vector<string> replacement = {"loc:@feed2"};
  TF_ColocateWith(desc_, feed1_);
  SetViaProto(desc_, replacement);
  FinishAndVerify(desc_, replacement);
}
1345
TEST_F(CApiColocationTest, StringList_ColocateWith) {
  // TF_ColocateWith after a string-list attr is expected to add to the list
  // rather than replace it.
  const std::vector<string> initial = {"loc:@feed2"};
  const std::vector<string> combined = {"loc:@feed1", "loc:@feed2"};
  SetViaStringList(desc_, initial);
  TF_ColocateWith(desc_, feed1_);
  FinishAndVerify(desc_, combined);
}
1351
TEST_F(CApiColocationTest, Proto_ColocateWith) {
  // TF_ColocateWith after a proto attr is expected to add to the list rather
  // than replace it.
  const std::vector<string> initial = {"loc:@feed2"};
  const std::vector<string> combined = {"loc:@feed1", "loc:@feed2"};
  SetViaProto(desc_, initial);
  TF_ColocateWith(desc_, feed1_);
  FinishAndVerify(desc_, combined);
}
1357
TEST_F(CApiColocationTest, ColocateWith_ColocateWith) {
  // Repeated TF_ColocateWith calls are expected to accumulate.
  const std::vector<string> combined = {"loc:@feed1", "loc:@feed2"};
  TF_ColocateWith(desc_, feed1_);
  TF_ColocateWith(desc_, feed2_);
  FinishAndVerify(desc_, combined);
}
1363
TEST_F(CApiColocationTest, Proto_StringList) {
  // When the attr is set twice, the later string-list setter is expected to
  // win.
  const std::vector<string> first = {"loc:@feed1"};
  const std::vector<string> last = {"loc:@feed2"};
  SetViaProto(desc_, first);
  SetViaStringList(desc_, last);
  FinishAndVerify(desc_, last);
}
1369
TEST_F(CApiColocationTest, StringList_Proto) {
  // When the attr is set twice, the later proto setter is expected to win.
  const std::vector<string> first = {"loc:@feed1"};
  const std::vector<string> last = {"loc:@feed2"};
  SetViaStringList(desc_, first);
  SetViaProto(desc_, last);
  FinishAndVerify(desc_, last);
}
1375
TEST_F(CApiColocationTest, ClearViaStringList) {
  // An empty string list is expected to remove the colocation attr that
  // TF_ColocateWith added.
  const std::vector<string> none;
  TF_ColocateWith(desc_, feed1_);
  SetViaStringList(desc_, none);
  FinishAndVerify(desc_, none);
}
1381
TEST_F(CApiColocationTest, ClearViaProto) {
  // An empty proto list is expected to remove the colocation attr that
  // TF_ColocateWith added.
  const std::vector<string> none;
  TF_ColocateWith(desc_, feed1_);
  SetViaProto(desc_, none);
  FinishAndVerify(desc_, none);
}
1387
// Loads the half_plus_two SavedModel and runs its "regress_x_to_y" signature
// on four serialized tensorflow::Example inputs.
TEST(CAPI, SavedModel) {
  // Load the saved model.
  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
                               "half_plus_two", "00000123"));
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
  TF_Buffer* metagraph = TF_NewBuffer();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
  TF_DeleteBuffer(run_options);
  TF_DeleteSessionOptions(opt);
  tensorflow::MetaGraphDef metagraph_def;
  const bool metagraph_parsed =
      metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
  TF_DeleteBuffer(metagraph);

  // Fatal on failure: continuing would wrap a null session and look up
  // signatures in an empty MetaGraphDef.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(metagraph_parsed);
  CSession csession(session);

  // Retrieve the regression signature from meta graph def. Bind by reference
  // to avoid copying the signature map and the SignatureDef proto.
  const auto& signature_def_map = metagraph_def.signature_def();
  const auto& signature_def = signature_def_map.at("regress_x_to_y");

  const string input_name =
      signature_def.inputs().at(tensorflow::kRegressInputs).name();
  const string output_name =
      signature_def.outputs().at(tensorflow::kRegressOutputs).name();

  // Write {0, 1, 2, 3} as tensorflow::Example inputs.
  Tensor input(tensorflow::DT_STRING, TensorShape({4}));
  for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {
    tensorflow::Example example;
    auto* feature_map = example.mutable_features()->mutable_feature();
    (*feature_map)["x"].mutable_float_list()->add_value(i);
    input.flat<tstring>()(i) = example.SerializeAsString();
  }

  const tensorflow::string input_op_name(
      tensorflow::ParseTensorName(input_name).first);
  TF_Operation* input_op =
      TF_GraphOperationByName(graph, input_op_name.c_str());
  ASSERT_TRUE(input_op != nullptr);
  Status status;
  csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
  ASSERT_TRUE(status.ok()) << status.error_message();

  const tensorflow::string output_op_name(
      tensorflow::ParseTensorName(output_name).first);
  TF_Operation* output_op =
      TF_GraphOperationByName(graph, output_op_name.c_str());
  ASSERT_TRUE(output_op != nullptr);
  csession.SetOutputs({output_op});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
  EXPECT_EQ(2, TF_NumDims(out));
  EXPECT_EQ(4, TF_Dim(out, 0));
  EXPECT_EQ(1, TF_Dim(out, 1));
  float* values = static_cast<float*>(TF_TensorData(out));
  // These values are defined to be (input / 2) + 2.
  EXPECT_EQ(2, values[0]);
  EXPECT_EQ(2.5, values[1]);
  EXPECT_EQ(3, values[2]);
  EXPECT_EQ(3.5, values[3]);

  csession.CloseAndDelete(s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1464
// TF_LoadSessionFromSavedModel must accept null run_options and a null
// meta_graph_def output buffer.
TEST(CAPI, SavedModelNullArgsAreValid) {
  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
                               "half_plus_two", "00000123"));
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  // NULL run_options and meta_graph_def should work.
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, nullptr, saved_model_dir.c_str(), tags, 1, graph, nullptr, s);
  TF_DeleteSessionOptions(opt);
  // Fatal on failure: closing/deleting below would dereference a null session.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(session != nullptr);
  TF_CloseSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1485
TEST(CAPI, DeletingNullPointerIsSafe) {
  // Every TF_Delete* entry point must accept nullptr as a no-op. Those that
  // report through a status must also leave that status OK.
  TF_Status* st = TF_NewStatus();

  // Deleters that take no status.
  TF_DeleteStatus(nullptr);
  TF_DeleteBuffer(nullptr);
  TF_DeleteTensor(nullptr);
  TF_DeleteSessionOptions(nullptr);
  TF_DeleteGraph(nullptr);
  TF_DeleteImportGraphDefOptions(nullptr);
  TF_DeleteImportGraphDefResults(nullptr);
  TF_DeleteFunction(nullptr);

  // Session deleters report through `st`.
  TF_DeleteSession(nullptr, st);
  EXPECT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
  TF_DeletePRunHandle(nullptr);
  TF_DeleteDeprecatedSession(nullptr, st);
  EXPECT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);

  // Remaining status-free deleters.
  TF_DeleteDeviceList(nullptr);
  TF_DeleteLibraryHandle(nullptr);
  TF_DeleteApiDefMap(nullptr);

  TF_DeleteStatus(st);
}
1508
// TF_TensorBitcastFrom makes `b` share `a`'s buffer under a new shape; writes
// through either tensor must then be visible through the other.
TEST(CAPI, TestBitcastFrom_Reshape) {
  int64_t dims[] = {2, 3};
  TF_Tensor* a = TF_AllocateTensor(TF_UINT64, dims, TF_ARRAYSIZE(dims),
                                   6 * TF_DataTypeSize(TF_UINT64));
  TF_Tensor* b =
      TF_AllocateTensor(TF_UINT64, nullptr, 0, TF_DataTypeSize(TF_UINT64));
  // Fatal: everything below dereferences these tensors.
  ASSERT_NE(a, nullptr);
  ASSERT_NE(b, nullptr);

  EXPECT_EQ(6, TF_TensorElementCount(a));
  EXPECT_EQ(1, TF_TensorElementCount(b));
  EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a));
  EXPECT_EQ(TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b));

  int64_t new_dims[] = {3, 2};
  TF_Status* status = TF_NewStatus();
  TF_TensorBitcastFrom(a, TF_UINT64, b, new_dims, TF_ARRAYSIZE(new_dims),
                       status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  EXPECT_EQ(6, TF_TensorElementCount(a));
  EXPECT_EQ(6, TF_TensorElementCount(b));
  EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a));
  EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b));

  // Check that a write to one tensor shows up in the other.
  *(static_cast<int64_t*>(TF_TensorData(a))) = 4;
  EXPECT_EQ(4, *(static_cast<int64_t*>(TF_TensorData(b))));
  *(static_cast<int64_t*>(TF_TensorData(b))) = 6;
  EXPECT_EQ(6, *(static_cast<int64_t*>(TF_TensorData(a))));

  TF_DeleteTensor(a);
  TF_DeleteTensor(b);
}
1543
// Test-only op with no registered gradient. CApiGradientsTest's error-path
// test expects TF_AddGradients to reject it with the message
// "No gradient defined for op: TestOpWithNoGradient. ...".
REGISTER_OP("TestOpWithNoGradient")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, double}")
    .Doc(R"doc(
Test op with no grad registered.

x: input
y: output
)doc")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1555
1556 class CApiGradientsTest : public ::testing::Test {
1557 protected:
// Allocates the shared status, the graph that gradients are added to
// (`graph_`) and the graph holding the expected result (`expected_graph_`).
// All three are released by the destructor.
CApiGradientsTest()
    : s_(TF_NewStatus()),
      graph_(TF_NewGraph()),
      expected_graph_(TF_NewGraph()) {}
1562
// Releases everything the constructor allocated.
~CApiGradientsTest() override {
  for (TF_Graph* owned_graph : {graph_, expected_graph_}) {
    TF_DeleteGraph(owned_graph);
  }
  TF_DeleteStatus(s_);
}
1568
TestGradientsSuccess(bool grad_inputs_provided)1569 void TestGradientsSuccess(bool grad_inputs_provided) {
1570 TF_Output inputs[2];
1571 TF_Output outputs[1];
1572 TF_Output grad_outputs[2];
1573 TF_Output expected_grad_outputs[2];
1574
1575 BuildSuccessGraph(inputs, outputs);
1576 BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
1577
1578 AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1,
1579 grad_outputs);
1580 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1581
1582 // Compare that the graphs match.
1583 GraphDef expected_gdef;
1584 GraphDef gdef;
1585 EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef));
1586 EXPECT_TRUE(GetGraphDef(graph_, &gdef));
1587 TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);
1588
1589 // Compare that the output of the gradients of both graphs match.
1590 RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
1591 }
1592
TestGradientsError(bool grad_inputs_provided)1593 void TestGradientsError(bool grad_inputs_provided) {
1594 TF_Output inputs[1];
1595 TF_Output outputs[1];
1596 TF_Output grad_outputs[1];
1597
1598 BuildErrorGraph(inputs, outputs);
1599
1600 AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
1601 grad_outputs);
1602
1603 string expected_msg =
1604 "No gradient defined for op: TestOpWithNoGradient. Please see "
1605 "https://www.tensorflow.org/code/"
1606 "tensorflow/cc/gradients/README.md"
1607 " for instructions on how to add C++ gradients.";
1608 EXPECT_EQ(expected_msg, TF_Message(s_));
1609 }
1610
1611 // Run the graph and ensure that the gradient values are as expected.
RunGraphsAndCompareOutputs(TF_Output * grad_outputs,TF_Output * expected_grad_outputs)1612 void RunGraphsAndCompareOutputs(TF_Output* grad_outputs,
1613 TF_Output* expected_grad_outputs) {
1614 std::unique_ptr<CSession> csession(new CSession(graph_, s_));
1615 std::unique_ptr<CSession> expected_csession(
1616 new CSession(expected_graph_, s_));
1617
1618 std::vector<TF_Output> grad_outputs_vec;
1619 grad_outputs_vec.assign(grad_outputs, grad_outputs + 2);
1620 csession->SetOutputs(grad_outputs_vec);
1621 csession->Run(s_);
1622 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1623 TF_Tensor* out0 = csession->output_tensor(0);
1624 TF_Tensor* out1 = csession->output_tensor(1);
1625
1626 std::vector<TF_Output> expected_grad_outputs_vec;
1627 expected_grad_outputs_vec.assign(expected_grad_outputs,
1628 expected_grad_outputs + 2);
1629 expected_csession->SetOutputs(expected_grad_outputs_vec);
1630 expected_csession->Run(s_);
1631 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1632 TF_Tensor* expected_out0 = expected_csession->output_tensor(0);
1633 TF_Tensor* expected_out1 = expected_csession->output_tensor(1);
1634
1635 CompareTensors(out0, expected_out0);
1636 CompareTensors(out1, expected_out1);
1637 }
1638
CompareTensors(TF_Tensor * a,TF_Tensor * b)1639 void CompareTensors(TF_Tensor* a, TF_Tensor* b) {
1640 float* a_data = static_cast<float*>(TF_TensorData(a));
1641 float* b_data = static_cast<float*>(TF_TensorData(b));
1642 EXPECT_EQ(*a_data, *b_data);
1643 }
1644
AddGradients(bool grad_inputs_provided,const char * prefix,TF_Output * inputs,int ninputs,TF_Output * outputs,int noutputs,TF_Output * grad_outputs)1645 void AddGradients(bool grad_inputs_provided, const char* prefix,
1646 TF_Output* inputs, int ninputs, TF_Output* outputs,
1647 int noutputs, TF_Output* grad_outputs) {
1648 if (grad_inputs_provided) {
1649 TF_Output grad_inputs[1];
1650 const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
1651 TF_Operation* grad_inputs_op =
1652 FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
1653 grad_inputs[0] = TF_Output{grad_inputs_op, 0};
1654 TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
1655 ninputs, grad_inputs, s_, grad_outputs);
1656 } else {
1657 TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
1658 ninputs, nullptr, s_, grad_outputs);
1659 }
1660 }
1661
BuildErrorGraph(TF_Output * inputs,TF_Output * outputs)1662 void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) {
1663 const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1664 TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1665 TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad");
1666 inputs[0] = TF_Output{const0, 0};
1667 outputs[0] = TF_Output{nograd, 0};
1668 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1669 }
1670
BuildSuccessGraph(TF_Output * inputs,TF_Output * outputs)1671 void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
1672 // Construct the following graph:
1673 // |
1674 // z|
1675 // |
1676 // MatMul
1677 // / \
1678 // ^ ^
1679 // | |
1680 // x| y|
1681 // | |
1682 // | |
1683 // Const_0 Const_1
1684 //
1685 const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1686 const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1687 TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1688 TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
1689 TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul");
1690 inputs[0] = TF_Output{const0, 0};
1691 inputs[1] = TF_Output{const1, 0};
1692 outputs[0] = TF_Output{matmul, 0};
1693 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1694 }
1695
BuildExpectedGraph(bool grad_inputs_provided,TF_Output * expected_grad_outputs)1696 void BuildExpectedGraph(bool grad_inputs_provided,
1697 TF_Output* expected_grad_outputs) {
1698 // The expected graph looks like this if grad_inputs_provided.
1699 // If grad_inputs_provided is false, Const_0 will be a OnesLike op.
1700 // ^ ^
1701 // dy| dx| // MatMul Gradient Graph
1702 // | |
1703 // MatMul_2 MatMul_1
1704 // ^ ^ ^ ^
1705 // | |----------| |
1706 // | ^ |
1707 // | dz| |
1708 // | | |
1709 // | Const_3 |
1710 // | |
1711 // | ^ |
1712 // | z| | // MatMul Forward Graph
1713 // | | |
1714 // | MatMul |
1715 // | / \ |
1716 // | ^ ^ |
1717 // | | | |
1718 // |---x| y|----|
1719 // | |
1720 // | |
1721 // Const_0 Const_1
1722 //
1723 const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1724 const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1725 TF_Operation* const0 =
1726 FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
1727 TF_Operation* const1 =
1728 FloatConst2x2(expected_graph_, s_, const1_val, "Const_1");
1729 TF_Operation* matmul =
1730 MatMul(expected_graph_, s_, const0, const1, "MatMul");
1731
1732 TF_Operation* const3;
1733 if (grad_inputs_provided) {
1734 const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
1735 const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
1736 } else {
1737 const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike");
1738 }
1739
1740 TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1,
1741 "gradients/MatMul", false, true);
1742 TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3,
1743 "gradients/MatMul_1", true, false);
1744 expected_grad_outputs[0] = {matmul1, 0};
1745 expected_grad_outputs[1] = {matmul2, 0};
1746 }
1747
FloatTensor2x2(const float * values)1748 TF_Tensor* FloatTensor2x2(const float* values) {
1749 const int64_t dims[2] = {2, 2};
1750 TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
1751 memcpy(TF_TensorData(t), values, sizeof(float) * 4);
1752 return t;
1753 }
1754
FloatConst2x2(TF_Graph * graph,TF_Status * s,const float * values,const char * name)1755 TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s,
1756 const float* values, const char* name) {
1757 unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor);
1758 TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
1759 TF_SetAttrTensor(desc, "value", tensor.get(), s);
1760 if (TF_GetCode(s) != TF_OK) return nullptr;
1761 TF_SetAttrType(desc, "dtype", TF_FLOAT);
1762 TF_Operation* op = TF_FinishOperation(desc, s);
1763 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1764 return op;
1765 }
1766
MatMul(TF_Graph * graph,TF_Status * s,TF_Operation * l,TF_Operation * r,const char * name,bool transpose_a=false,bool transpose_b=false)1767 TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l,
1768 TF_Operation* r, const char* name,
1769 bool transpose_a = false, bool transpose_b = false) {
1770 TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
1771 if (transpose_a) {
1772 TF_SetAttrBool(desc, "transpose_a", 1);
1773 }
1774 if (transpose_b) {
1775 TF_SetAttrBool(desc, "transpose_b", 1);
1776 }
1777 TF_AddInput(desc, {l, 0});
1778 TF_AddInput(desc, {r, 0});
1779 TF_Operation* op = TF_FinishOperation(desc, s);
1780 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1781 return op;
1782 }
1783
OnesLike(TF_Graph * graph,TF_Status * s,TF_Operation * in,const char * name)1784 TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1785 const char* name) {
1786 TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name);
1787 TF_AddInput(desc, {in, 0});
1788 TF_Operation* op = TF_FinishOperation(desc, s);
1789 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1790 return op;
1791 }
1792
NoGradientOp(TF_Graph * graph,TF_Status * s,TF_Operation * in,const char * name)1793 TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1794 const char* name) {
1795 TF_OperationDescription* desc =
1796 TF_NewOperation(graph, "TestOpWithNoGradient", name);
1797 TF_AddInput(desc, {in, 0});
1798 TF_Operation* op = TF_FinishOperation(desc, s);
1799 EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1800 return op;
1801 }
1802
BuildGraphAndAddGradientsWithPrefixes(const char * prefix1,const char * prefix2=nullptr)1803 void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
1804 const char* prefix2 = nullptr) {
1805 TF_Output inputs[2];
1806 TF_Output outputs[1];
1807 TF_Output grad_outputs[2];
1808
1809 BuildSuccessGraph(inputs, outputs);
1810
1811 AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs);
1812 if (prefix2 != nullptr) {
1813 AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs);
1814 }
1815 }
1816
1817 TF_Status* s_;
1818 TF_Graph* graph_;
1819 TF_Graph* expected_graph_;
1820 };
1821
TEST_F(CApiGradientsTest,Gradients_GradInputs)1822 TEST_F(CApiGradientsTest, Gradients_GradInputs) { TestGradientsSuccess(true); }
1823
TEST_F(CApiGradientsTest,Gradients_NoGradInputs)1824 TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
1825 TestGradientsSuccess(false);
1826 }
1827
TEST_F(CApiGradientsTest,OpWithNoGradientRegistered_GradInputs)1828 TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
1829 TestGradientsError(true);
1830 }
1831
TEST_F(CApiGradientsTest,OpWithNoGradientRegistered_NoGradInputs)1832 TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
1833 TestGradientsError(false);
1834 }
1835
TEST_F(CApiGradientsTest,GradientsPrefix_PrefixIsOk)1836 TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) {
1837 BuildGraphAndAddGradientsWithPrefixes("gradients");
1838 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1839 }
1840
TEST_F(CApiGradientsTest,GradientsPrefix_TwoGradientsWithDistinctPrefixes)1841 TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) {
1842 BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1");
1843 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1844 }
1845
TEST_F(CApiGradientsTest,GradientsPrefix_TwoGradientsInSameScope)1846 TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) {
1847 BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1");
1848 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1849 }
1850
TEST_F(CApiGradientsTest,GradientsPrefix_TwoGradientsInDifferentScopes)1851 TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) {
1852 BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients");
1853 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1854 }
1855
TEST_F(CApiGradientsTest,GradientsPrefix_2ndGradientsAsSubScopeOf1st)1856 TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) {
1857 BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub");
1858 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1859 }
1860
TEST_F(CApiGradientsTest,GradientsPrefix_PrefixMatchesExistingNodeName)1861 TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) {
1862 BuildGraphAndAddGradientsWithPrefixes("Const_0");
1863 ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1864 }
1865
TEST_F(CApiGradientsTest,GradientsPrefix_TwoGradientsWithIdenticalPrefixes)1866 TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) {
1867 BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients");
1868 ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1869 }
1870
TEST_F(CApiGradientsTest,GradientsPrefix_2ndGradientsMatchingNodeOf1st)1871 TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) {
1872 BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul");
1873 ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1874 }
1875
TEST_F(CApiGradientsTest,GradientsPrefix_1stGradientsMatchingNodeOf2nd)1876 TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) {
1877 BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients");
1878 ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1879 }
1880
TEST_F(CApiGradientsTest,GradientsPrefix_2ndGradientsAsParentScopeOf1st)1881 TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) {
1882 BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients");
1883 ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1884 }
1885
ScalarFloatFromTensor(const TF_Tensor * t,float * f)1886 void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
1887 ASSERT_TRUE(t != nullptr);
1888 ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
1889 ASSERT_EQ(0, TF_NumDims(t));
1890 ASSERT_EQ(4, TF_TensorByteSize(t));
1891 float* p = static_cast<float*>(TF_TensorData(t));
1892 *f = *p;
1893 }
1894
TEST_F(CApiGradientsTest,MultipleCallsToAddGradients)1895 TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) {
1896 const float X = 3.0f, Y = 7.0f;
1897 TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT);
1898 TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT);
1899 TF_Operation* xy = Mul(x, y, graph_, s_, "xy");
1900 TF_Output dxy_dx, dxy_dy;
1901
1902 TF_Output outputs[1] = {{xy, 0}};
1903 TF_Output inputs[1] = {{x, 0}};
1904 TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
1905 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1906
1907 inputs[0] = {y, 0};
1908 TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
1909 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1910
1911 TF_SessionOptions* opts = TF_NewSessionOptions();
1912 TF_Session* sess = TF_NewSession(graph_, opts, s_);
1913 TF_DeleteSessionOptions(opts);
1914 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1915
1916 TF_Output feeds[] = {{x, 0}, {y, 0}};
1917 TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)};
1918 TF_Output fetches[] = {dxy_dx, dxy_dy};
1919 TF_Tensor* fetchValues[] = {nullptr, nullptr};
1920
1921 TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches,
1922 fetchValues, 2, nullptr /* target_opers */, 0,
1923 nullptr /* run_metadata */, s_);
1924 TF_DeleteTensor(feedValues[0]);
1925 TF_DeleteTensor(feedValues[1]);
1926 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1927 TF_DeleteSession(sess, s_);
1928 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1929
1930 float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f;
1931 ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue);
1932 EXPECT_EQ(Y, dxy_dxValue);
1933
1934 ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue);
1935 EXPECT_EQ(X, dxy_dyValue);
1936
1937 TF_DeleteTensor(fetchValues[0]);
1938 TF_DeleteTensor(fetchValues[1]);
1939 }
1940
// REGISTER_OP for CApiAttributesTest test cases.
// For an attribute type `type`, registers two ops, each with a single
// attribute called 'v':
//   - "CApiAttributesTestOp<type>", whose 'v' has type `type`, and
//   - "CApiAttributesTestOpList<type>", whose 'v' has type `list(<type>)`.
// Both ops use UnknownShape as their shape function. The names must stay in
// sync with CApiAttributesTest::init, which rebuilds them from a type string.
#define ATTR_TEST_REGISTER_OP(type)                           \
  REGISTER_OP("CApiAttributesTestOp" #type)                   \
      .Attr("v: " #type)                                      \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
  REGISTER_OP("CApiAttributesTestOpList" #type)               \
      .Attr("v: list(" #type ")")                             \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape)
// One pair of ops per attribute type exercised by the tests below.
ATTR_TEST_REGISTER_OP(string);
ATTR_TEST_REGISTER_OP(int);
ATTR_TEST_REGISTER_OP(float);
ATTR_TEST_REGISTER_OP(bool);
ATTR_TEST_REGISTER_OP(type);
ATTR_TEST_REGISTER_OP(shape);
ATTR_TEST_REGISTER_OP(tensor);
#undef ATTR_TEST_REGISTER_OP
1960
1961 class CApiAttributesTest : public ::testing::Test {
1962 protected:
CApiAttributesTest()1963 CApiAttributesTest()
1964 : s_(TF_NewStatus()), graph_(TF_NewGraph()), counter_(0) {}
1965
~CApiAttributesTest()1966 ~CApiAttributesTest() override {
1967 TF_DeleteGraph(graph_);
1968 TF_DeleteStatus(s_);
1969 }
1970
init(string type)1971 TF_OperationDescription* init(string type) {
1972 // Construct op_name to match the name used by REGISTER_OP in the
1973 // ATTR_TEST_REGISTER calls above.
1974 string op_name = "CApiAttributesTestOp";
1975 if (type.find("list(") == 0) {
1976 op_name += "List";
1977 type = type.replace(0, 5, "");
1978 type = type.replace(type.size() - 1, 1, "");
1979 }
1980 op_name += type;
1981 return TF_NewOperation(
1982 graph_, op_name.c_str(),
1983 ::tensorflow::strings::StrCat("name", counter_++).c_str());
1984 }
1985
1986 TF_Status* s_;
1987
1988 private:
1989 TF_Graph* graph_;
1990 int counter_;
1991 };
1992
// Helper macros for the TF_OperationGetAttr* tests.
// TODO(ashankar): Use gmock matchers instead?
// (https://github.com/google/googletest/blob/master/googlemock/docs/CookBook.md#writing-new-parameterized-matchers-quickly)
// That will require setting up the tensorflow build with gmock.
//
// EXPECT_TF_META(attr_name, expected_list_size, expected_type,
//                expected_total_size)
// Calls TF_OperationGetAttrMetadata for `attr_name` on a variable named `oper`,
// using the fixture status `s_`. Both names must be in scope where the macro
// is expanded. It then checks:
//   - is_list is 1 iff expected_list_size >= 0 (pass -1 for a non-list attr);
//   - list_size == expected_list_size;
//   - type == expected_type;
//   - total_size == expected_total_size.
// All checks are EXPECT_*, so the calling test keeps running on mismatch.
// Arguments are substituted unparenthesized, so pass simple expressions.
#define EXPECT_TF_META(attr_name, expected_list_size, expected_type, \
                       expected_total_size)                          \
  do {                                                               \
    auto m = TF_OperationGetAttrMetadata(oper, attr_name, s_);       \
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);              \
    const unsigned char e = expected_list_size >= 0 ? 1 : 0;        \
    EXPECT_EQ(e, m.is_list);                                         \
    EXPECT_EQ(expected_list_size, m.list_size);                      \
    EXPECT_EQ(expected_type, m.type);                                \
    EXPECT_EQ(expected_total_size, m.total_size);                    \
  } while (0)
2008
TEST_F(CApiAttributesTest,String)2009 TEST_F(CApiAttributesTest, String) {
2010 auto desc = init("string");
2011 TF_SetAttrString(desc, "v", "bunny", 5);
2012
2013 auto oper = TF_FinishOperation(desc, s_);
2014 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2015 EXPECT_TF_META("v", -1, TF_ATTR_STRING, 5);
2016 std::unique_ptr<char[]> value(new char[5]);
2017
2018 TF_OperationGetAttrString(oper, "v", value.get(), 5, s_);
2019 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2020 EXPECT_EQ("bunny", string(static_cast<const char*>(value.get()), 5));
2021 }
2022
TEST_F(CApiAttributesTest,StringList)2023 TEST_F(CApiAttributesTest, StringList) {
2024 std::vector<string> list = {"bugs", "bunny", "duck"};
2025 std::unique_ptr<const void*[]> list_ptrs;
2026 std::unique_ptr<size_t[]> list_lens;
2027 StringVectorToArrays(list, &list_ptrs, &list_lens);
2028 int list_total_size = 0;
2029 for (const auto& s : list) {
2030 list_total_size += s.size();
2031 }
2032
2033 auto desc = init("list(string)");
2034 TF_SetAttrStringList(desc, "v", list_ptrs.get(), list_lens.get(),
2035 list.size());
2036
2037 auto oper = TF_FinishOperation(desc, s_);
2038 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2039
2040 EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size);
2041 std::unique_ptr<void*[]> values(new void*[list.size()]);
2042 std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
2043 std::unique_ptr<char[]> storage(new char[list_total_size]);
2044 TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(),
2045 list.size(), storage.get(), list_total_size,
2046 s_);
2047 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2048 for (size_t i = 0; i < list.size(); ++i) {
2049 EXPECT_EQ(list[i].size(), lens[i]) << i;
2050 EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
2051 << i;
2052 }
2053 }
2054
TEST_F(CApiAttributesTest,Int)2055 TEST_F(CApiAttributesTest, Int) {
2056 auto desc = init("int");
2057 TF_SetAttrInt(desc, "v", 31415);
2058
2059 auto oper = TF_FinishOperation(desc, s_);
2060 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2061 EXPECT_TF_META("v", -1, TF_ATTR_INT, -1);
2062
2063 int64_t value;
2064 TF_OperationGetAttrInt(oper, "v", &value, s_);
2065 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2066 EXPECT_EQ(31415, value);
2067 }
2068
TEST_F(CApiAttributesTest,IntList)2069 TEST_F(CApiAttributesTest, IntList) {
2070 const int64_t list[] = {1, 2, 3, 4};
2071 const size_t list_size = TF_ARRAYSIZE(list);
2072
2073 auto desc = init("list(int)");
2074 TF_SetAttrIntList(desc, "v", list, list_size);
2075
2076 auto oper = TF_FinishOperation(desc, s_);
2077 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2078
2079 int64_t values[list_size];
2080 EXPECT_TF_META("v", list_size, TF_ATTR_INT, -1);
2081 TF_OperationGetAttrIntList(oper, "v", values, list_size, s_);
2082 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2083 EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
2084 }
2085
TEST_F(CApiAttributesTest,Float)2086 TEST_F(CApiAttributesTest, Float) {
2087 auto desc = init("float");
2088 TF_SetAttrFloat(desc, "v", 2.718);
2089
2090 auto oper = TF_FinishOperation(desc, s_);
2091 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2092 EXPECT_TF_META("v", -1, TF_ATTR_FLOAT, -1);
2093
2094 float value;
2095 TF_OperationGetAttrFloat(oper, "v", &value, s_);
2096 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2097 EXPECT_FLOAT_EQ(2.718, value);
2098 }
2099
TEST_F(CApiAttributesTest,FloatList)2100 TEST_F(CApiAttributesTest, FloatList) {
2101 const float list[] = {1.414, 2.718, 3.1415};
2102 const size_t list_size = TF_ARRAYSIZE(list);
2103
2104 auto desc = init("list(float)");
2105 TF_SetAttrFloatList(desc, "v", list, list_size);
2106
2107 auto oper = TF_FinishOperation(desc, s_);
2108 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2109
2110 float values[list_size];
2111 EXPECT_TF_META("v", list_size, TF_ATTR_FLOAT, -1);
2112 TF_OperationGetAttrFloatList(oper, "v", values, list_size, s_);
2113 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2114 EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
2115 }
2116
TEST_F(CApiAttributesTest,Bool)2117 TEST_F(CApiAttributesTest, Bool) {
2118 auto desc = init("bool");
2119 TF_SetAttrBool(desc, "v", 1);
2120
2121 auto oper = TF_FinishOperation(desc, s_);
2122 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2123 EXPECT_TF_META("v", -1, TF_ATTR_BOOL, -1);
2124
2125 unsigned char value;
2126 TF_OperationGetAttrBool(oper, "v", &value, s_);
2127 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2128 EXPECT_EQ(1, value);
2129 }
2130
TEST_F(CApiAttributesTest,BoolList)2131 TEST_F(CApiAttributesTest, BoolList) {
2132 const unsigned char list[] = {0, 1, 1, 0, 0, 1, 1};
2133 const size_t list_size = TF_ARRAYSIZE(list);
2134
2135 auto desc = init("list(bool)");
2136 TF_SetAttrBoolList(desc, "v", list, list_size);
2137
2138 auto oper = TF_FinishOperation(desc, s_);
2139 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2140
2141 unsigned char values[list_size];
2142 EXPECT_TF_META("v", list_size, TF_ATTR_BOOL, -1);
2143 TF_OperationGetAttrBoolList(oper, "v", values, list_size, s_);
2144 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2145 EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
2146 }
2147
TEST_F(CApiAttributesTest,Type)2148 TEST_F(CApiAttributesTest, Type) {
2149 auto desc = init("type");
2150 TF_SetAttrType(desc, "v", TF_COMPLEX128);
2151
2152 auto oper = TF_FinishOperation(desc, s_);
2153 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2154 EXPECT_TF_META("v", -1, TF_ATTR_TYPE, -1);
2155
2156 TF_DataType value;
2157 TF_OperationGetAttrType(oper, "v", &value, s_);
2158 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2159 EXPECT_EQ(TF_COMPLEX128, value);
2160 }
2161
TEST_F(CApiAttributesTest,TypeList)2162 TEST_F(CApiAttributesTest, TypeList) {
2163 const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
2164 const size_t list_size = TF_ARRAYSIZE(list);
2165
2166 auto desc = init("list(type)");
2167 TF_SetAttrTypeList(desc, "v", list, list_size);
2168
2169 auto oper = TF_FinishOperation(desc, s_);
2170 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2171
2172 TF_DataType values[list_size];
2173 EXPECT_TF_META("v", list_size, TF_ATTR_TYPE, -1);
2174 TF_OperationGetAttrTypeList(oper, "v", values, list_size, s_);
2175 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2176 EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
2177 }
2178
TEST_F(CApiAttributesTest,Shape)2179 TEST_F(CApiAttributesTest, Shape) {
2180 // Unknown shape
2181 auto desc = init("shape");
2182 TF_SetAttrShape(desc, "v", nullptr, -1);
2183 auto oper = TF_FinishOperation(desc, s_);
2184 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2185 EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, -1);
2186 TF_OperationGetAttrShape(oper, "v", nullptr, 10, s_);
2187 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2188
2189 // Partially specified shape
2190 const int64_t partial_shape[] = {17, -1};
2191 const size_t sz = TF_ARRAYSIZE(partial_shape);
2192 desc = init("shape");
2193 TF_SetAttrShape(desc, "v", partial_shape, sz);
2194 oper = TF_FinishOperation(desc, s_);
2195 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2196 EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, sz);
2197 int64_t values[sz];
2198 TF_OperationGetAttrShape(oper, "v", values, sz, s_);
2199 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2200 EXPECT_TRUE(
2201 std::equal(std::begin(partial_shape), std::end(partial_shape), values));
2202 }
2203
TEST_F(CApiAttributesTest,ShapeList)2204 TEST_F(CApiAttributesTest, ShapeList) {
2205 const int64_t shape_1[] = {1, 3};
2206 const int64_t shape_2[] = {2, 4, 6};
2207 const int64_t* list[] = {&shape_1[0], &shape_2[0]};
2208 const size_t list_size = TF_ARRAYSIZE(list);
2209 const int ndims[] = {TF_ARRAYSIZE(shape_1), TF_ARRAYSIZE(shape_2)};
2210 const int total_ndims = 5; // ndims[0] + ndims[1]
2211
2212 auto desc = init("list(shape)");
2213 TF_SetAttrShapeList(desc, "v", list, ndims, list_size);
2214 auto oper = TF_FinishOperation(desc, s_);
2215 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2216
2217 EXPECT_TF_META("v", list_size, TF_ATTR_SHAPE, total_ndims);
2218 int64_t* values[list_size];
2219 int values_ndims[list_size];
2220 int64_t storage[total_ndims];
2221 TF_OperationGetAttrShapeList(oper, "v", values, values_ndims, list_size,
2222 storage, total_ndims, s_);
2223 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2224 for (size_t i = 0; i < list_size; ++i) {
2225 EXPECT_EQ(ndims[i], values_ndims[i]) << i;
2226 for (int j = 0; j < values_ndims[i]; ++j) {
2227 EXPECT_EQ(list[i][j], values[i][j]) << "(" << i << ", " << j << ")";
2228 }
2229 }
2230 }
2231
TEST_F(CApiAttributesTest,TensorShapeProto)2232 TEST_F(CApiAttributesTest, TensorShapeProto) {
2233 const tensorflow::int64 pts[] = {2, 4, -1, 8};
2234 tensorflow::TensorShapeProto proto;
2235 tensorflow::PartialTensorShape(pts).AsProto(&proto);
2236 string bytes;
2237 proto.SerializeToString(&bytes);
2238
2239 auto desc = init("shape");
2240 TF_SetAttrTensorShapeProto(desc, "v", bytes.data(), bytes.length(), s_);
2241 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2242 auto oper = TF_FinishOperation(desc, s_);
2243 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2244
2245 EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, 4);
2246 TF_Buffer* value = TF_NewBuffer();
2247 TF_OperationGetAttrTensorShapeProto(oper, "v", value, s_);
2248 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2249 EXPECT_EQ(bytes.length(), value->length);
2250 EXPECT_EQ(0, memcmp(bytes.data(), value->data, value->length));
2251 TF_DeleteBuffer(value);
2252 }
2253
TEST_F(CApiAttributesTest,TensorShapeProtoList)2254 TEST_F(CApiAttributesTest, TensorShapeProtoList) {
2255 string bytes1, bytes2;
2256 tensorflow::TensorShapeProto proto;
2257
2258 const tensorflow::int64 pts1[] = {2, 4, -1, 8};
2259 tensorflow::PartialTensorShape(pts1).AsProto(&proto);
2260 proto.SerializeToString(&bytes1);
2261
2262 const tensorflow::int64 pts2[] = {1, 3, 5, 7};
2263 tensorflow::PartialTensorShape(pts2).AsProto(&proto);
2264 proto.SerializeToString(&bytes2);
2265
2266 std::unique_ptr<const void*[]> list_ptrs;
2267 std::unique_ptr<size_t[]> list_lens;
2268 const std::vector<string> list = {bytes1, bytes2};
2269 StringVectorToArrays(list, &list_ptrs, &list_lens);
2270
2271 auto desc = init("list(shape)");
2272 TF_SetAttrTensorShapeProtoList(desc, "v", list_ptrs.get(), list_lens.get(),
2273 list.size(), s_);
2274 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2275 auto oper = TF_FinishOperation(desc, s_);
2276 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2277
2278 EXPECT_TF_META("v", 2, TF_ATTR_SHAPE, 8);
2279 TF_Buffer* values[2];
2280 TF_OperationGetAttrTensorShapeProtoList(oper, "v", values, 2, s_);
2281 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2282 for (int i = 0; i < 2; ++i) {
2283 int le = list_lens[i];
2284 int la = values[i]->length;
2285 const void* e = list_ptrs[i];
2286 const void* a = values[i]->data;
2287 EXPECT_EQ(le, la) << i;
2288 EXPECT_EQ(0, memcmp(e, a, std::min(le, la))) << i;
2289 TF_DeleteBuffer(values[i]);
2290 }
2291 }
2292
TEST_F(CApiAttributesTest,Tensor)2293 TEST_F(CApiAttributesTest, Tensor) {
2294 const char tensor[] = {5, 7};
2295 const int64_t dims[] = {1, 2};
2296 const size_t ndims = TF_ARRAYSIZE(dims);
2297
2298 auto desc = init("tensor");
2299 unique_tensor_ptr v(Int8Tensor(dims, ndims, tensor), TF_DeleteTensor);
2300 TF_SetAttrTensor(desc, "v", v.get(), s_);
2301 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2302
2303 auto oper = TF_FinishOperation(desc, s_);
2304 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2305
2306 EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
2307 TF_Tensor* value;
2308 TF_OperationGetAttrTensor(oper, "v", &value, s_);
2309 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2310 ASSERT_NE(nullptr, value);
2311 EXPECT_EQ(TF_INT8, TF_TensorType(value));
2312 EXPECT_EQ(ndims, TF_NumDims(value));
2313 for (int i = 0; i < TF_NumDims(value); ++i) {
2314 EXPECT_EQ(dims[i], TF_Dim(value, i)) << i;
2315 }
2316 EXPECT_EQ(sizeof(char) * TF_ARRAYSIZE(tensor), TF_TensorByteSize(value));
2317 EXPECT_EQ(0, memcmp(tensor, TF_TensorData(value), TF_TensorByteSize(value)));
2318 TF_DeleteTensor(value);
2319 }
2320
TEST_F(CApiAttributesTest,StringTensor)2321 TEST_F(CApiAttributesTest, StringTensor) {
2322 // Create the string-Tensor "attribute" value.
2323 const char test_string[] =
2324 "borkborkborkborkborkborkborkbork"; // >24bytes to force heap alloc
2325 TF_TString tstr[1];
2326 TF_TString_Init(&tstr[0]);
2327 TF_TString_Copy(&tstr[0], test_string, sizeof(test_string) - 1);
2328
2329 auto deallocator = [](void* data, size_t len, void* arg) {};
2330 unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &tstr[0],
2331 sizeof(tstr), deallocator, nullptr),
2332 TF_DeleteTensor);
2333
2334 // Create a TF_Operation with the attribute t_in
2335 auto desc = init("tensor");
2336 TF_SetAttrTensor(desc, "v", t_in.get(), s_);
2337 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2338
2339 auto oper = TF_FinishOperation(desc, s_);
2340 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2341
2342 // Fetch the attribute back.
2343 EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
2344 TF_Tensor* t_out = nullptr;
2345 TF_OperationGetAttrTensor(oper, "v", &t_out, s_);
2346 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
2347 EXPECT_EQ(TF_STRING, TF_TensorType(t_out));
2348 EXPECT_EQ(0, TF_NumDims(t_out));
2349 ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out));
2350 TF_TString* t_in_tstr = static_cast<TF_TString*>(TF_TensorData(t_in.get()));
2351 TF_TString* t_out_tstr = static_cast<TF_TString*>(TF_TensorData(t_out));
2352 EXPECT_EQ(absl::string_view(test_string),
2353 absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
2354 TF_TString_GetSize(t_out_tstr)));
2355 EXPECT_EQ(absl::string_view(TF_TString_GetDataPointer(t_in_tstr),
2356 TF_TString_GetSize(t_in_tstr)),
2357 absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
2358 TF_TString_GetSize(t_out_tstr)));
2359 TF_DeleteTensor(t_out);
2360 TF_TString_Dealloc(&tstr[0]);
2361 }
2362
TEST_F(CApiAttributesTest, TensorList) {
  const char tensor1[] = {5, 7};
  const int64_t dims1[] = {1, 2};
  const size_t ndims1 = TF_ARRAYSIZE(dims1);

  const char tensor2[] = {2, 4, 6, 8};
  const int64_t dims2[] = {2, 2};
  const size_t ndims2 = TF_ARRAYSIZE(dims2);

  auto desc = init("list(tensor)");
  TF_Tensor* tmp[] = {
      Int8Tensor(dims1, ndims1, tensor1),
      Int8Tensor(dims2, ndims2, tensor2),
  };
  TF_SetAttrTensorList(desc, "v", tmp, TF_ARRAYSIZE(tmp), s_);
  // The inputs are released right after setting the attribute (the values
  // are still read back from the op below), before any ASSERT can return.
  for (TF_Tensor* t : tmp) {
    TF_DeleteTensor(t);
  }
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", 2, TF_ATTR_TENSOR, -1);
  // Null-initialized so slots left unset by a failing call are never treated
  // as valid tensors.
  TF_Tensor* values[2] = {nullptr, nullptr};
  TF_OperationGetAttrTensorList(oper, "v", &values[0], TF_ARRAYSIZE(values),
                                s_);
  // Take ownership of every fetched tensor up front: the ASSERTs below return
  // early on failure, and this guarantees none of the tensors leak.
  std::vector<unique_tensor_ptr> owned;
  owned.reserve(TF_ARRAYSIZE(values));
  for (TF_Tensor* t : values) {
    owned.emplace_back(t, TF_DeleteTensor);
  }
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  const char* tensor_data[] = {&tensor1[0], &tensor2[0]};
  const size_t tensor_size[] = {TF_ARRAYSIZE(tensor1), TF_ARRAYSIZE(tensor2)};
  const int64_t* tensor_dims[] = {&dims1[0], &dims2[0]};
  const size_t tensor_ndims[] = {ndims1, ndims2};
  for (size_t i = 0; i < owned.size(); ++i) {
    TF_Tensor* v = owned[i].get();
    ASSERT_NE(nullptr, v) << i;
    EXPECT_EQ(TF_INT8, TF_TensorType(v)) << i;
    EXPECT_EQ(static_cast<int>(tensor_ndims[i]), TF_NumDims(v)) << i;
    for (int j = 0; j < TF_NumDims(v); ++j) {
      EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j))
          << "Tensor #" << i << ", dimension #" << j;
    }
    EXPECT_EQ(sizeof(char) * tensor_size[i], TF_TensorByteSize(v)) << i;
    EXPECT_EQ(0,
              memcmp(tensor_data[i], TF_TensorData(v), TF_TensorByteSize(v)));
  }
}
2410
TEST_F(CApiAttributesTest, EmptyList) {
  // A zero-length list(int) attribute: with a count of 0 the values pointer
  // may be null.
  auto desc = init("list(int)");
  TF_SetAttrIntList(desc, "v", /*values=*/nullptr, /*num_values=*/0);
  TF_Operation* oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The attribute metadata must report an int list of length 0.
  EXPECT_TF_META("v", 0, TF_ATTR_INT, -1);
}
2418
TEST_F(CApiAttributesTest, Errors) {
  // Build an op whose "v" attribute is an int...
  auto desc = init("int");
  TF_SetAttrInt(desc, "v", 3);
  TF_Operation* oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // ...then read it back as a string; the type mismatch must be rejected.
  TF_OperationGetAttrString(oper, "v", /*value=*/nullptr, /*max_length=*/0,
                            s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
}
2427
TEST(TestApiDef, TestCreateApiDef) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

  // Every C handle is owned by a unique_ptr so that a failing ASSERT, which
  // returns from the test body, cannot leak it. Declaration order makes the
  // teardown order: api_def_buf, api_def_map, op_list_buf, status.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> op_list_buf(
      TF_GetAllOpList(), TF_DeleteBuffer);
  std::unique_ptr<TF_ApiDefMap, decltype(&TF_DeleteApiDefMap)> api_def_map(
      TF_NewApiDefMap(op_list_buf.get(), status.get()), TF_DeleteApiDefMap);
  // ASSERT (not EXPECT): the map is used unconditionally below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, api_def_map.get());

  const string op_name = "TestCApi";
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> api_def_buf(
      TF_ApiDefMapGet(api_def_map.get(), op_name.c_str(), op_name.size(),
                      status.get()),
      TF_DeleteBuffer);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  // Check the buffer itself before dereferencing it, in case the lookup
  // yields no buffer without reporting an error status.
  ASSERT_NE(nullptr, api_def_buf.get());

  tensorflow::ApiDef api_def;
  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
  EXPECT_EQ(op_name, api_def.graph_op_name());
  EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
}
2455
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

  // Every C handle is owned by a unique_ptr so that a failing ASSERT, which
  // returns from the test body, cannot leak it. Declaration order makes the
  // teardown order: api_def_buf, api_def_map, op_list_buf, status.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> op_list_buf(
      TF_GetAllOpList(), TF_DeleteBuffer);
  std::unique_ptr<TF_ApiDefMap, decltype(&TF_DeleteApiDefMap)> api_def_map(
      TF_NewApiDefMap(op_list_buf.get(), status.get()), TF_DeleteApiDefMap);
  // ASSERT (not EXPECT): the map is used unconditionally below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, api_def_map.get());

  // Text-format overrides that replace the summary of "TestCApi".
  const string api_def_overwrites = R"(op: <
  graph_op_name: "TestCApi"
  summary: "New summary"
>
)";
  TF_ApiDefMapPut(api_def_map.get(), api_def_overwrites.c_str(),
                  api_def_overwrites.size(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  const string op_name = "TestCApi";
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> api_def_buf(
      TF_ApiDefMapGet(api_def_map.get(), op_name.c_str(), op_name.size(),
                      status.get()),
      TF_DeleteBuffer);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  // Check the buffer itself before dereferencing it, in case the lookup
  // yields no buffer without reporting an error status.
  ASSERT_NE(nullptr, api_def_buf.get());

  tensorflow::ApiDef api_def;
  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
  EXPECT_EQ(op_name, api_def.graph_op_name());
  EXPECT_EQ("New summary", api_def.summary());
}
2494
2495 class DummyKernel : public tensorflow::OpKernel {
2496 public:
DummyKernel(tensorflow::OpKernelConstruction * context)2497 explicit DummyKernel(tensorflow::OpKernelConstruction* context)
2498 : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)2499 void Compute(tensorflow::OpKernelContext* context) override {}
2500 };
2501
2502 // Test we can query kernels
2503 REGISTER_OP("TestOpWithSingleKernel")
2504 .Input("a: float")
2505 .Input("b: float")
2506 .Output("o: float");
2507 REGISTER_KERNEL_BUILDER(
2508 Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel);
2509
TEST(TestKernel, TestGetAllRegisteredKernels) {
  // RAII owners: a failing ASSERT returns early without leaking either handle.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetAllRegisteredKernels(status.get()), TF_DeleteBuffer);
  // ASSERT (not EXPECT): the buffer is dereferenced below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  ASSERT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_GT(kernel_list.kernel_size(), 0);
}
2520
TEST(TestKernel, TestGetRegisteredKernelsForOp) {
  // RAII owners: a failing ASSERT returns early without leaking either handle.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status.get()),
      TF_DeleteBuffer);
  // ASSERT (not EXPECT): the buffer is dereferenced below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  ASSERT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_EQ(kernel_list.kernel_size(), 1);
  EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel");
  EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
}
2534
TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) {
  // RAII owners: a failing ASSERT returns early without leaking either handle.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetRegisteredKernelsForOp("Unknown", status.get()), TF_DeleteBuffer);
  // An unknown op is not an error: it must yield an empty, parseable list.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  ASSERT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_EQ(kernel_list.kernel_size(), 0);
}
2545
2546 #undef EXPECT_TF_META
2547
TEST(CAPI, TestTensorAligned) {
  // A tensor obtained from TF_AllocateTensor is expected to report itself as
  // aligned.
  int64_t dim = 7;
  const size_t num_bytes = dim * TF_DataTypeSize(TF_FLOAT);
  TF_Tensor* a = TF_AllocateTensor(/*dtype=*/TF_FLOAT, /*dims=*/&dim,
                                   /*num_dims=*/1, /*len=*/num_bytes);
  float* values = static_cast<float*>(TF_TensorData(a));
  std::fill(values, values + dim, 0.0f);

  // The alignment check only applies when Eigen requires any alignment.
  if (EIGEN_MAX_ALIGN_BYTES > 0) {
    EXPECT_TRUE(TF_TensorIsAligned(a));
  }
  TF_DeleteTensor(a);
}
2563
TEST(CAPI, TestTensorIsNotAligned) {
  // Test unaligned access via a Slice.
  Tensor x(DT_FLOAT, TensorShape({30}));
  x.flat<float>().setConstant(0.0f);

  // Take an unaligned slice: it begins one element past the start of x.
  Tensor y = x.Slice(1, 13);
  Status status;
  unique_tensor_ptr a(TF_TensorFromTensor(y, &status), TF_DeleteTensor);
  // Without these checks a failed conversion would pass a null tensor to
  // TF_TensorIsAligned below.
  TF_ASSERT_OK(status);
  ASSERT_NE(nullptr, a.get());
  if (EIGEN_MAX_ALIGN_BYTES > 0) {
    EXPECT_FALSE(TF_TensorIsAligned(a.get()));
  }
}
2578
2579 } // namespace
2580 } // namespace tensorflow
2581
2582 // TODO(josh11b): Test:
2583 // * TF_SetDevice(desc, "/job:worker");
2584 // * control inputs / outputs
2585 // * targets
2586 // * TF_DeleteGraph() before TF_DeleteSession()
2587