• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <iterator>
21 #include <memory>
22 #include <vector>
23 
24 #include "tensorflow/c/c_test_util.h"
25 #include "tensorflow/c/tf_status.h"
26 #include "tensorflow/cc/saved_model/signature_constants.h"
27 #include "tensorflow/cc/saved_model/tag_constants.h"
28 #include "tensorflow/core/example/example.pb.h"
29 #include "tensorflow/core/example/feature.pb.h"
30 #include "tensorflow/core/framework/api_def.pb.h"
31 #include "tensorflow/core/framework/common_shape_fns.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/kernel_def.pb.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/node_def_util.h"
36 #include "tensorflow/core/framework/op.h"
37 #include "tensorflow/core/framework/op_def.pb.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/framework/partial_tensor_shape.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/framework/tensor_shape.pb.h"
42 #include "tensorflow/core/framework/types.pb.h"
43 #include "tensorflow/core/graph/tensor_id.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/io/path.h"
46 #include "tensorflow/core/platform/path.h"
47 #include "tensorflow/core/platform/resource_loader.h"
48 #include "tensorflow/core/platform/str_util.h"
49 #include "tensorflow/core/platform/strcat.h"
50 #include "tensorflow/core/platform/test.h"
51 #include "tensorflow/core/protobuf/error_codes.pb.h"
52 #include "tensorflow/core/protobuf/meta_graph.pb.h"
53 #include "tensorflow/core/util/equal_graph_def.h"
54 
55 namespace tensorflow {
56 TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
57 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
58 
59 namespace {
60 
ExpectHasSubstr(StringPiece s,StringPiece expected)61 static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
62   EXPECT_TRUE(absl::StrContains(s, expected))
63       << "'" << s << "' does not contain '" << expected << "'";
64 }
65 
66 // Returns the GPU device name if there is one (with arbitrary tie breaking if
67 // there are more than one), or "" otherwise.
GPUDeviceName(TF_Session * session)68 string GPUDeviceName(TF_Session* session) {
69   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
70       TF_NewStatus(), TF_DeleteStatus);
71   TF_Status* s = status.get();
72   std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> list(
73       TF_SessionListDevices(session, s), TF_DeleteDeviceList);
74   TF_DeviceList* device_list = list.get();
75 
76   CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
77 
78   const int num_devices = TF_DeviceListCount(device_list);
79   LOG(INFO) << "There are " << num_devices << " devices.";
80   for (int i = 0; i < num_devices; ++i) {
81     const char* device_name = TF_DeviceListName(device_list, i, s);
82     CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
83     const char* device_type = TF_DeviceListType(device_list, i, s);
84     CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
85     LOG(INFO) << "Device " << i << " has name " << device_name << ", type "
86               << device_type;
87     if (string(device_type) == DEVICE_GPU) {
88       return device_name;
89     }
90   }
91   // No GPU device found.
92   return "";
93 }
94 
GPUDeviceName()95 string GPUDeviceName() {
96   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
97       TF_NewStatus(), TF_DeleteStatus);
98   TF_Status* s = status.get();
99   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph(TF_NewGraph(),
100                                                              TF_DeleteGraph);
101 
102   TF_SessionOptions* opts = TF_NewSessionOptions();
103   TF_Session* sess = TF_NewSession(graph.get(), opts, s);
104   TF_DeleteSessionOptions(opts);
105 
106   const string gpu_device_name = GPUDeviceName(sess);
107   TF_DeleteSession(sess, s);
108   CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
109   return gpu_device_name;
110 }
111 
// TF_Version() must report a non-empty version string.
TEST(CAPI, Version) {
  const char* version = TF_Version();
  EXPECT_STRNE("", version);
}
113 
// A freshly created TF_Status is OK with an empty message; TF_SetStatus
// replaces both the code and the message.
TEST(CAPI, Status) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> owner(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Status* status = owner.get();

  EXPECT_EQ(TF_OK, TF_GetCode(status));
  EXPECT_EQ(string(), TF_Message(status));

  TF_SetStatus(status, TF_CANCELLED, "cancel");
  EXPECT_EQ(TF_CANCELLED, TF_GetCode(status));
  EXPECT_EQ(string("cancel"), TF_Message(status));
}
123 
Deallocator(void * data,size_t,void * arg)124 void Deallocator(void* data, size_t, void* arg) {
125   tensorflow::cpu_allocator()->DeallocateRaw(data);
126   *reinterpret_cast<bool*>(arg) = true;
127 }
128 
// TF_NewTensor wraps a caller-provided buffer (no copy). The supplied
// deallocator runs only once the tensor is deleted.
TEST(CAPI, Tensor) {
  const int num_bytes = 6 * sizeof(float);
  float* buffer = static_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
      EIGEN_MAX_ALIGN_BYTES, num_bytes));
  int64_t shape[] = {2, 3};
  bool freed = false;
  TF_Tensor* tensor = TF_NewTensor(TF_FLOAT, shape, 2, buffer, num_bytes,
                                   &Deallocator, &freed);
  // The buffer must stay alive for as long as the tensor does.
  EXPECT_FALSE(freed);
  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
  EXPECT_EQ(2, TF_NumDims(tensor));
  for (int axis = 0; axis < 2; ++axis) {
    EXPECT_EQ(shape[axis], TF_Dim(tensor, axis));
  }
  EXPECT_EQ(num_bytes, TF_TensorByteSize(tensor));
  // The tensor aliases the caller's memory rather than copying it.
  EXPECT_EQ(static_cast<void*>(buffer), TF_TensorData(tensor));
  TF_DeleteTensor(tensor);
  EXPECT_TRUE(freed);
}
148 
// Deallocator that deliberately does nothing. It is used where a test needs a
// valid deallocator callback but owns no memory that must be released.
void NoOpDeallocator(void* /*data*/, size_t /*len*/, void* /*arg*/) {}
150 
// Regression test for https://github.com/tensorflow/tensorflow/issues/7394.
// num_dims == 0 describes a scalar, which needs at least sizeof(float) bytes
// of backing data. TF_NewTensor must therefore reject a null, zero-length
// buffer by returning nullptr.
TEST(CAPI, MalformedTensor) {
  TF_Tensor* scalar =
      TF_NewTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                   /*data=*/nullptr, /*len=*/0, &NoOpDeallocator,
                   /*deallocator_arg=*/nullptr);
  ASSERT_TRUE(scalar == nullptr);
}
159 
// TF_AllocateTensor yields a tensor with the requested dtype and 2x3 shape.
// Its byte size and element count must match that layout.
TEST(CAPI, AllocateTensor) {
  int64_t shape[] = {2, 3};
  const int byte_count = 6 * sizeof(float);
  TF_Tensor* tensor = TF_AllocateTensor(TF_FLOAT, shape, 2, byte_count);
  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
  EXPECT_EQ(2, TF_NumDims(tensor));
  for (int axis = 0; axis < 2; ++axis) {
    EXPECT_EQ(shape[axis], TF_Dim(tensor, axis));
  }
  EXPECT_EQ(byte_count, TF_TensorByteSize(tensor));
  EXPECT_EQ(6, TF_TensorElementCount(tensor));
  TF_DeleteTensor(tensor);
}
172 
// TF_TensorMaybeMove returns nullptr for a tensor that wraps a caller-owned
// buffer. The original tensor keeps its deallocator, which runs on delete.
TEST(CAPI, MaybeMove) {
  const int num_bytes = 6 * sizeof(float);
  float* buffer = static_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
      EIGEN_MAX_ALIGN_BYTES, num_bytes));
  int64_t shape[] = {2, 3};
  bool freed = false;
  TF_Tensor* tensor = TF_NewTensor(TF_FLOAT, shape, 2, buffer, num_bytes,
                                   &Deallocator, &freed);

  TF_Tensor* moved = TF_TensorMaybeMove(tensor);
  // It is unsafe to move memory TF might not own.
  ASSERT_TRUE(moved == nullptr);
  TF_DeleteTensor(tensor);
  EXPECT_TRUE(freed);
}
188 
// Exercises the op-library entry points. When shared objects are supported, it
// loads test_op1.so and checks its OpList via TF_GetOpList. It then checks
// that TF_GetAllOpList parses and contains the "TestCApi" op.
TEST(CAPI, LibraryLoadFunctions) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  {
    // Load the library.
    TF_Status* status = TF_NewStatus();
    string lib_path = tensorflow::GetDataDependencyFilepath(
        tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
    TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
    TF_Code code = TF_GetCode(status);
    string status_msg(TF_Message(status));
    TF_DeleteStatus(status);
    ASSERT_EQ(TF_OK, code) << status_msg;

    // Test op list. The TF_Buffer returned here is not freed separately; per
    // c_api.h its memory is owned by `lib`.
    TF_Buffer op_list_buf = TF_GetOpList(lib);
    tensorflow::OpList op_list;
    EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
    ASSERT_EQ(1, op_list.op_size());
    EXPECT_EQ("TestCApi1", op_list.op(0).name());
    TF_DeleteLibraryHandle(lib);
  }
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  {
    TF_Buffer* op_list_buffer = TF_GetAllOpList();
    tensorflow::OpList op_list;
    const bool parsed =
        op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
    // op_list now holds its own copy of the data. Release the buffer before
    // any ASSERT below can return early and leak it.
    TF_DeleteBuffer(op_list_buffer);
    EXPECT_TRUE(parsed);
    ASSERT_GE(op_list.op_size(), 1);
    const auto& ops = op_list.op();
    const bool found =
        std::any_of(ops.begin(), ops.end(), [](const tensorflow::OpDef& op_def) {
          return op_def.name() == "TestCApi";
        });
    EXPECT_TRUE(found);
  }
}
230 
TestEncodeDecode(int line,const std::vector<string> & data)231 void TestEncodeDecode(int line, const std::vector<string>& data) {
232   const tensorflow::int64 n = data.size();
233   Status status;
234   for (const std::vector<tensorflow::int64>& dims :
235        std::vector<std::vector<tensorflow::int64>>{
236            {n}, {1, n}, {n, 1}, {n / 2, 2}}) {
237     // Create C++ Tensor
238     Tensor src(tensorflow::DT_STRING, TensorShape(dims));
239     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
240       src.flat<tstring>()(i) = data[i];
241     }
242     TF_Tensor* dst = TF_TensorFromTensor(src, &status);
243     ASSERT_TRUE(status.ok()) << status.error_message();
244 
245     // Convert back to a C++ Tensor and ensure we get expected output.
246     Tensor output;
247     ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line;
248     ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
249     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
250       ASSERT_EQ(data[i], output.flat<tstring>()(i)) << line;
251     }
252 
253     TF_DeleteTensor(dst);
254   }
255 }
256 
// String tensors of several sizes must round-trip unchanged: empty, a single
// element, several elements, and one element far longer than its neighbours.
TEST(CAPI, TensorEncodeDecodeStrings) {
  TestEncodeDecode(__LINE__, {});
  TestEncodeDecode(__LINE__, {"hello"});
  TestEncodeDecode(__LINE__,
                   {"the", "quick", "brown", "fox", "jumped", "over"});
  const string long_string(1000, 'a');
  TestEncodeDecode(__LINE__, {"small", long_string, "small2"});
}
266 
// Smoke test: session options can be created and destroyed.
TEST(CAPI, SessionOptions) {
  TF_SessionOptions* options = TF_NewSessionOptions();
  TF_DeleteSessionOptions(options);
}
271 
// A TF_DeprecatedSession created without a graph must reject TF_Run with
// INVALID_ARGUMENT, and it must still be deletable without error.
TEST(CAPI, DeprecatedSession) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* options = TF_NewSessionOptions();
  TF_DeprecatedSession* session = TF_NewDeprecatedSession(options, status);
  TF_DeleteSessionOptions(options);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Buffer* run_options = TF_NewBufferFromString("", 0);
  TF_Buffer* run_metadata = TF_NewBuffer();
  // Run with no inputs, outputs or targets.
  TF_Run(session, run_options,
         /*input_names=*/nullptr, /*inputs=*/nullptr, /*ninputs=*/0,
         /*output_names=*/nullptr, /*outputs=*/nullptr, /*noutputs=*/0,
         /*target_oper_names=*/nullptr, /*ntargets=*/0, run_metadata, status);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("Session was not created with a graph before Run()!",
            string(TF_Message(status)));
  TF_DeleteBuffer(run_metadata);
  TF_DeleteBuffer(run_options);

  TF_DeleteDeprecatedSession(session, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
}
294 
// Each TF_DataType value must be numerically identical to the corresponding
// tensorflow::DataType value, so the static_casts between them are sound.
// TF_DataTypeSize must also agree with tensorflow::DataTypeSize. Keep this
// list in sync with both enums when new types are added.
TEST(CAPI, DataTypeEnum) {
  EXPECT_EQ(TF_FLOAT, static_cast<TF_DataType>(tensorflow::DT_FLOAT));
  EXPECT_EQ(TF_DOUBLE, static_cast<TF_DataType>(tensorflow::DT_DOUBLE));
  EXPECT_EQ(TF_INT32, static_cast<TF_DataType>(tensorflow::DT_INT32));
  EXPECT_EQ(TF_UINT8, static_cast<TF_DataType>(tensorflow::DT_UINT8));
  EXPECT_EQ(TF_INT16, static_cast<TF_DataType>(tensorflow::DT_INT16));
  EXPECT_EQ(TF_INT8, static_cast<TF_DataType>(tensorflow::DT_INT8));
  EXPECT_EQ(TF_STRING, static_cast<TF_DataType>(tensorflow::DT_STRING));
  EXPECT_EQ(TF_COMPLEX64, static_cast<TF_DataType>(tensorflow::DT_COMPLEX64));
  EXPECT_EQ(TF_COMPLEX, TF_COMPLEX64);
  EXPECT_EQ(TF_INT64, static_cast<TF_DataType>(tensorflow::DT_INT64));
  EXPECT_EQ(TF_BOOL, static_cast<TF_DataType>(tensorflow::DT_BOOL));
  EXPECT_EQ(TF_QINT8, static_cast<TF_DataType>(tensorflow::DT_QINT8));
  EXPECT_EQ(TF_QUINT8, static_cast<TF_DataType>(tensorflow::DT_QUINT8));
  EXPECT_EQ(TF_QINT32, static_cast<TF_DataType>(tensorflow::DT_QINT32));
  EXPECT_EQ(TF_BFLOAT16, static_cast<TF_DataType>(tensorflow::DT_BFLOAT16));
  EXPECT_EQ(TF_QINT16, static_cast<TF_DataType>(tensorflow::DT_QINT16));
  EXPECT_EQ(TF_QUINT16, static_cast<TF_DataType>(tensorflow::DT_QUINT16));
  EXPECT_EQ(TF_UINT16, static_cast<TF_DataType>(tensorflow::DT_UINT16));
  EXPECT_EQ(TF_COMPLEX128, static_cast<TF_DataType>(tensorflow::DT_COMPLEX128));
  EXPECT_EQ(TF_HALF, static_cast<TF_DataType>(tensorflow::DT_HALF));
  EXPECT_EQ(TF_RESOURCE, static_cast<TF_DataType>(tensorflow::DT_RESOURCE));
  EXPECT_EQ(TF_VARIANT, static_cast<TF_DataType>(tensorflow::DT_VARIANT));
  EXPECT_EQ(TF_UINT32, static_cast<TF_DataType>(tensorflow::DT_UINT32));
  EXPECT_EQ(TF_UINT64, static_cast<TF_DataType>(tensorflow::DT_UINT64));
  EXPECT_EQ(TF_DataTypeSize(TF_DOUBLE),
            tensorflow::DataTypeSize(tensorflow::DT_DOUBLE));
  EXPECT_EQ(TF_DataTypeSize(TF_STRING),
            tensorflow::DataTypeSize(tensorflow::DT_STRING));
  // Test with invalid type; should always return 0 as documented
  EXPECT_EQ(TF_DataTypeSize(static_cast<TF_DataType>(0)), 0);
}
323 
// Every TF_Code constant must have the same numeric value as its
// tensorflow::error::Code counterpart. The table pairs them up explicitly.
TEST(CAPI, StatusEnum) {
  struct CodePair {
    TF_Code c_code;
    tensorflow::error::Code cc_code;
  };
  const CodePair kPairs[] = {
      {TF_OK, tensorflow::error::OK},
      {TF_CANCELLED, tensorflow::error::CANCELLED},
      {TF_UNKNOWN, tensorflow::error::UNKNOWN},
      {TF_INVALID_ARGUMENT, tensorflow::error::INVALID_ARGUMENT},
      {TF_DEADLINE_EXCEEDED, tensorflow::error::DEADLINE_EXCEEDED},
      {TF_NOT_FOUND, tensorflow::error::NOT_FOUND},
      {TF_ALREADY_EXISTS, tensorflow::error::ALREADY_EXISTS},
      {TF_PERMISSION_DENIED, tensorflow::error::PERMISSION_DENIED},
      {TF_UNAUTHENTICATED, tensorflow::error::UNAUTHENTICATED},
      {TF_RESOURCE_EXHAUSTED, tensorflow::error::RESOURCE_EXHAUSTED},
      {TF_FAILED_PRECONDITION, tensorflow::error::FAILED_PRECONDITION},
      {TF_ABORTED, tensorflow::error::ABORTED},
      {TF_OUT_OF_RANGE, tensorflow::error::OUT_OF_RANGE},
      {TF_UNIMPLEMENTED, tensorflow::error::UNIMPLEMENTED},
      {TF_INTERNAL, tensorflow::error::INTERNAL},
      {TF_UNAVAILABLE, tensorflow::error::UNAVAILABLE},
      {TF_DATA_LOSS, tensorflow::error::DATA_LOSS},
  };
  for (const CodePair& pair : kPairs) {
    EXPECT_EQ(pair.c_code, static_cast<TF_Code>(pair.cc_code))
        << "TF_Code " << pair.c_code;
  }
}
353 
// TF_GetAllOpList returns a serialized OpList that parses and is non-empty.
TEST(CAPI, GetAllOpList) {
  TF_Buffer* serialized = TF_GetAllOpList();
  tensorflow::OpList ops;
  const bool parsed = ops.ParseFromArray(serialized->data, serialized->length);
  EXPECT_TRUE(parsed);
  EXPECT_GT(ops.op_size(), 0);
  TF_DeleteBuffer(serialized);
}
361 
TEST(CAPI,SetShape)362 TEST(CAPI, SetShape) {
363   TF_Status* s = TF_NewStatus();
364   TF_Graph* graph = TF_NewGraph();
365 
366   TF_Operation* feed = Placeholder(graph, s);
367   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
368   TF_Output feed_out_0 = TF_Output{feed, 0};
369   int num_dims;
370 
371   // Fetch the shape, it should be completely unknown.
372   num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
373   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
374   EXPECT_EQ(-1, num_dims);
375 
376   // Set the shape to be unknown, expect no change.
377   TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
378   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
379   num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
380   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
381   EXPECT_EQ(-1, num_dims);
382 
383   // Set the shape to be 2 x Unknown
384   int64_t dims[] = {2, -1};
385   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
386   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
387 
388   // Fetch the shape and validate it is 2 by -1.
389   num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
390   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
391   EXPECT_EQ(2, num_dims);
392 
393   // Resize the dimension vector appropriately.
394   int64_t returned_dims[2];
395   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
396   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
397   EXPECT_EQ(dims[0], returned_dims[0]);
398   EXPECT_EQ(dims[1], returned_dims[1]);
399 
400   // Set to a new valid shape: [2, 3]
401   dims[1] = 3;
402   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
403   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
404 
405   // Fetch and see that the new value is returned.
406   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
407   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
408   EXPECT_EQ(dims[0], returned_dims[0]);
409   EXPECT_EQ(dims[1], returned_dims[1]);
410 
411   // Try to set 'unknown' with unknown rank on the shape and see that
412   // it doesn't change.
413   TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
414   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
415   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
416   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
417   EXPECT_EQ(2, num_dims);
418   EXPECT_EQ(2, returned_dims[0]);
419   EXPECT_EQ(3, returned_dims[1]);
420 
421   // Try to set 'unknown' with same rank on the shape and see that
422   // it doesn't change.
423   dims[0] = -1;
424   dims[1] = -1;
425   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
426   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
427   // Fetch and see that the new value is returned.
428   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
429   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
430   EXPECT_EQ(2, num_dims);
431   EXPECT_EQ(2, returned_dims[0]);
432   EXPECT_EQ(3, returned_dims[1]);
433 
434   // Try to fetch a shape with the wrong num_dims
435   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
436   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
437 
438   // Try to set an invalid shape (cannot change 2x3 to a 2x5).
439   dims[1] = 5;
440   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
441   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
442 
443   // Test for a scalar.
444   TF_Operation* three = ScalarConst(3, graph, s);
445   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
446   TF_Output three_out_0 = TF_Output{three, 0};
447 
448   num_dims = TF_GraphGetTensorNumDims(graph, three_out_0, s);
449   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
450   EXPECT_EQ(0, num_dims);
451   TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s);
452   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
453 
454   // Clean up
455   TF_DeleteGraph(graph);
456   TF_DeleteStatus(s);
457 }
458 
TEST(CAPI,Graph)459 TEST(CAPI, Graph) {
460   TF_Status* s = TF_NewStatus();
461   TF_Graph* graph = TF_NewGraph();
462 
463   // Make a placeholder operation.
464   TF_Operation* feed = Placeholder(graph, s);
465   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
466 
467   // Test TF_Operation*() query functions.
468   EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
469   EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed)));
470   EXPECT_EQ(string(""), string(TF_OperationDevice(feed)));
471   EXPECT_EQ(1, TF_OperationNumOutputs(feed));
472   EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{feed, 0}));
473   EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s));
474   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
475   EXPECT_EQ(0, TF_OperationNumInputs(feed));
476   EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
477   EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
478   EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
479 
480   tensorflow::AttrValue attr_value;
481   ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s);
482   EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
483 
484   // Test not found errors in TF_Operation*() query functions.
485   EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s));
486   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
487 
488   ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
489   EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."),
490             string(TF_Message(s)));
491 
492   // Make a constant oper with the scalar "3".
493   TF_Operation* three = ScalarConst(3, graph, s);
494   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
495 
496   // Add oper.
497   TF_Operation* add = Add(feed, three, graph, s);
498   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
499 
500   // Test TF_Operation*() query functions.
501   EXPECT_EQ(string("add"), string(TF_OperationName(add)));
502   EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add)));
503   EXPECT_EQ(string(""), string(TF_OperationDevice(add)));
504   EXPECT_EQ(1, TF_OperationNumOutputs(add));
505   EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{add, 0}));
506   EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s));
507   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
508   EXPECT_EQ(2, TF_OperationNumInputs(add));
509   EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s));
510   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
511   EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 0}));
512   EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 1}));
513   TF_Output add_in_0 = TF_OperationInput(TF_Input{add, 0});
514   EXPECT_EQ(feed, add_in_0.oper);
515   EXPECT_EQ(0, add_in_0.index);
516   TF_Output add_in_1 = TF_OperationInput(TF_Input{add, 1});
517   EXPECT_EQ(three, add_in_1.oper);
518   EXPECT_EQ(0, add_in_1.index);
519   EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{add, 0}));
520   EXPECT_EQ(0, TF_OperationNumControlInputs(add));
521   EXPECT_EQ(0, TF_OperationNumControlOutputs(add));
522 
523   ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << TF_Message(s);
524   EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
525   ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s);
526   EXPECT_EQ(attr_value.i(), 2);
527 
528   // Placeholder oper now has a consumer.
529   ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
530   TF_Input feed_port;
531   EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Output{feed, 0}, &feed_port, 1));
532   EXPECT_EQ(add, feed_port.oper);
533   EXPECT_EQ(0, feed_port.index);
534 
535   // The scalar const oper also has a consumer.
536   ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{three, 0}));
537   TF_Input three_port;
538   EXPECT_EQ(1,
539             TF_OperationOutputConsumers(TF_Output{three, 0}, &three_port, 1));
540   EXPECT_EQ(add, three_port.oper);
541   EXPECT_EQ(1, three_port.index);
542 
543   // Serialize to GraphDef.
544   GraphDef graph_def;
545   ASSERT_TRUE(GetGraphDef(graph, &graph_def));
546 
547   // Validate GraphDef is what we expect.
548   bool found_placeholder = false;
549   bool found_scalar_const = false;
550   bool found_add = false;
551   for (const auto& n : graph_def.node()) {
552     if (IsPlaceholder(n)) {
553       EXPECT_FALSE(found_placeholder);
554       found_placeholder = true;
555     } else if (IsScalarConst(n, 3)) {
556       EXPECT_FALSE(found_scalar_const);
557       found_scalar_const = true;
558     } else if (IsAddN(n, 2)) {
559       EXPECT_FALSE(found_add);
560       found_add = true;
561     } else {
562       ADD_FAILURE() << "Unexpected NodeDef: " << n.DebugString();
563     }
564   }
565   EXPECT_TRUE(found_placeholder);
566   EXPECT_TRUE(found_scalar_const);
567   EXPECT_TRUE(found_add);
568 
569   // Add another oper to the graph.
570   TF_Operation* neg = Neg(add, graph, s);
571   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
572 
573   // Serialize to NodeDef.
574   NodeDef node_def;
575   ASSERT_TRUE(GetNodeDef(neg, &node_def));
576 
577   // Validate NodeDef is what we expect.
578   EXPECT_TRUE(IsNeg(node_def, "add"));
579 
580   // Serialize to GraphDef.
581   GraphDef graph_def2;
582   ASSERT_TRUE(GetGraphDef(graph, &graph_def2));
583 
584   // Compare with first GraphDef + added NodeDef.
585   NodeDef* added_node = graph_def.add_node();
586   *added_node = node_def;
587   EXPECT_EQ(graph_def.DebugString(), graph_def2.DebugString());
588 
589   // Look up some nodes by name.
590   TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg");
591   EXPECT_TRUE(neg == neg2);
592   NodeDef node_def2;
593   ASSERT_TRUE(GetNodeDef(neg2, &node_def2));
594   EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());
595 
596   TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed");
597   EXPECT_TRUE(feed == feed2);
598   ASSERT_TRUE(GetNodeDef(feed, &node_def));
599   ASSERT_TRUE(GetNodeDef(feed2, &node_def2));
600   EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());
601 
602   // Test iterating through the nodes of a graph.
603   found_placeholder = false;
604   found_scalar_const = false;
605   found_add = false;
606   bool found_neg = false;
607   size_t pos = 0;
608   TF_Operation* oper;
609   while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
610     if (oper == feed) {
611       EXPECT_FALSE(found_placeholder);
612       found_placeholder = true;
613     } else if (oper == three) {
614       EXPECT_FALSE(found_scalar_const);
615       found_scalar_const = true;
616     } else if (oper == add) {
617       EXPECT_FALSE(found_add);
618       found_add = true;
619     } else if (oper == neg) {
620       EXPECT_FALSE(found_neg);
621       found_neg = true;
622     } else {
623       ASSERT_TRUE(GetNodeDef(oper, &node_def));
624       ADD_FAILURE() << "Unexpected Node: " << node_def.DebugString();
625     }
626   }
627   EXPECT_TRUE(found_placeholder);
628   EXPECT_TRUE(found_scalar_const);
629   EXPECT_TRUE(found_add);
630   EXPECT_TRUE(found_neg);
631 
632   // Clean up
633   TF_DeleteGraph(graph);
634   TF_DeleteStatus(s);
635 }
636 
// Builds one, two -> add -> neg, then uses TF_UpdateEdge to redirect neg's
// input 0 from add:0 to one:0. Checks that the call succeeds and that both the
// serialized NodeDef and the graph's live input edge reflect the new producer.
TEST(CAPI, UpdateEdge) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Make two scalar constants.
  TF_Operation* one = ScalarConst(1, graph, s, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* two = ScalarConst(2, graph, s, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add oper.
  TF_Operation* add = Add(one, two, graph, s, "add");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add another oper to the graph.
  TF_Operation* neg = Neg(add, graph, s, "neg");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  NodeDef node_def_neg;
  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("add"), node_def_neg.input(0));

  // Update the edge feeding neg. Check the status here so that a failed
  // update is reported directly, not only as a NodeDef mismatch below.
  TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("one:0"), node_def_neg.input(0));

  // The graph's input edge must point at the new producer as well.
  const TF_Output neg_input = TF_OperationInput(TF_Input{neg, 0});
  EXPECT_EQ(one, neg_input.oper);
  EXPECT_EQ(0, neg_input.index);

  // Clean up
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
670 
671 /*
672 TODO(skyewm): this test currently DCHECKs, change to bad status
673 
674 TEST(CAPI, InputFromDifferentGraphError) {
675   TF_Status* s = TF_NewStatus();
676   TF_Graph* g1 = TF_NewGraph();
677   TF_Graph* g2 = TF_NewGraph();
678 
679   TF_Operation* feed = Placeholder(g1, s);
680   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
681 
682   // Attempt to create node in g2 with input from g1
683   Neg(feed, g2, s);
684   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
685   EXPECT_STREQ("foo", TF_Message(s));
686 
687   TF_DeleteGraph(g1);
688   TF_DeleteGraph(g2);
689   TF_DeleteStatus(s);
690 }
691 */
692 
TEST(CAPI,ImportGraphDef)693 TEST(CAPI, ImportGraphDef) {
694   TF_Status* s = TF_NewStatus();
695   TF_Graph* graph = TF_NewGraph();
696 
697   // Create a simple graph.
698   Placeholder(graph, s);
699   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
700   ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
701   TF_Operation* oper = ScalarConst(3, graph, s);
702   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
703   ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
704   Neg(oper, graph, s);
705   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
706   ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
707 
708   // Export to a GraphDef.
709   TF_Buffer* graph_def = TF_NewBuffer();
710   TF_GraphToGraphDef(graph, graph_def, s);
711   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
712 
713   // Import it, with a prefix, in a fresh graph.
714   TF_DeleteGraph(graph);
715   graph = TF_NewGraph();
716   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
717   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
718   TF_GraphImportGraphDef(graph, graph_def, opts, s);
719   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
720 
721   TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar");
722   TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed");
723   TF_Operation* neg = TF_GraphOperationByName(graph, "imported/neg");
724   ASSERT_TRUE(scalar != nullptr);
725   ASSERT_TRUE(feed != nullptr);
726   ASSERT_TRUE(neg != nullptr);
727 
728   // Test basic structure of the imported graph.
729   EXPECT_EQ(0, TF_OperationNumInputs(scalar));
730   EXPECT_EQ(0, TF_OperationNumInputs(feed));
731   ASSERT_EQ(1, TF_OperationNumInputs(neg));
732   TF_Output neg_input = TF_OperationInput({neg, 0});
733   EXPECT_EQ(scalar, neg_input.oper);
734   EXPECT_EQ(0, neg_input.index);
735 
736   // Test that we can't see control edges involving the source and sink nodes.
737   TF_Operation* control_ops[100];
738   EXPECT_EQ(0, TF_OperationNumControlInputs(scalar));
739   EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100));
740   EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar));
741   EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100));
742 
743   EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
744   EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100));
745   EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
746   EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100));
747 
748   EXPECT_EQ(0, TF_OperationNumControlInputs(neg));
749   EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100));
750   EXPECT_EQ(0, TF_OperationNumControlOutputs(neg));
751   EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100));
752 
753   // Import it again, with an input mapping, return outputs, and a return
754   // operation, into the same graph.
755   TF_DeleteImportGraphDefOptions(opts);
756   opts = TF_NewImportGraphDefOptions();
757   TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
758   TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
759   TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
760   TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
761   EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
762   TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
763   EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
764   TF_ImportGraphDefResults* results =
765       TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
766   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
767 
768   TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
769   TF_Operation* feed2 = TF_GraphOperationByName(graph, "imported2/feed");
770   TF_Operation* neg2 = TF_GraphOperationByName(graph, "imported2/neg");
771   ASSERT_TRUE(scalar2 != nullptr);
772   ASSERT_TRUE(feed2 != nullptr);
773   ASSERT_TRUE(neg2 != nullptr);
774 
775   // Check input mapping
776   neg_input = TF_OperationInput({neg, 0});
777   EXPECT_EQ(scalar, neg_input.oper);
778   EXPECT_EQ(0, neg_input.index);
779 
780   // Check return outputs
781   TF_Output* return_outputs;
782   int num_return_outputs;
783   TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs,
784                                         &return_outputs);
785   ASSERT_EQ(2, num_return_outputs);
786   EXPECT_EQ(feed2, return_outputs[0].oper);
787   EXPECT_EQ(0, return_outputs[0].index);
788   EXPECT_EQ(scalar, return_outputs[1].oper);  // remapped
789   EXPECT_EQ(0, return_outputs[1].index);
790 
791   // Check return operation
792   TF_Operation** return_opers;
793   int num_return_opers;
794   TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers,
795                                            &return_opers);
796   ASSERT_EQ(1, num_return_opers);
797   EXPECT_EQ(scalar2, return_opers[0]);  // not remapped
798 
799   TF_DeleteImportGraphDefResults(results);
800 
801   // Import again, with control dependencies, into the same graph.
802   TF_DeleteImportGraphDefOptions(opts);
803   opts = TF_NewImportGraphDefOptions();
804   TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
805   TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
806   TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
807   TF_GraphImportGraphDef(graph, graph_def, opts, s);
808   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
809 
810   TF_Operation* scalar3 = TF_GraphOperationByName(graph, "imported3/scalar");
811   TF_Operation* feed3 = TF_GraphOperationByName(graph, "imported3/feed");
812   TF_Operation* neg3 = TF_GraphOperationByName(graph, "imported3/neg");
813   ASSERT_TRUE(scalar3 != nullptr);
814   ASSERT_TRUE(feed3 != nullptr);
815   ASSERT_TRUE(neg3 != nullptr);
816 
817   // Check that newly-imported scalar and feed have control deps (neg3 will
818   // inherit them from input)
819   TF_Operation* control_inputs[100];
820   int num_control_inputs = TF_OperationGetControlInputs(
821       scalar3, control_inputs, TF_OperationNumControlInputs(scalar3));
822   ASSERT_EQ(2, num_control_inputs);
823   EXPECT_EQ(feed, control_inputs[0]);
824   EXPECT_EQ(feed2, control_inputs[1]);
825 
826   num_control_inputs = TF_OperationGetControlInputs(
827       feed3, control_inputs, TF_OperationNumControlInputs(feed3));
828   ASSERT_EQ(2, num_control_inputs);
829   EXPECT_EQ(feed, control_inputs[0]);
830   EXPECT_EQ(feed2, control_inputs[1]);
831 
832   // Export to a graph def so we can import a graph with control dependencies
833   TF_DeleteBuffer(graph_def);
834   graph_def = TF_NewBuffer();
835   TF_GraphToGraphDef(graph, graph_def, s);
836   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
837 
838   // Import again, with remapped control dependency, into the same graph
839   TF_DeleteImportGraphDefOptions(opts);
840   opts = TF_NewImportGraphDefOptions();
841   TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
842   TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
843   TF_GraphImportGraphDef(graph, graph_def, opts, s);
844   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
845 
846   TF_Operation* scalar4 =
847       TF_GraphOperationByName(graph, "imported4/imported3/scalar");
848   TF_Operation* feed4 =
849       TF_GraphOperationByName(graph, "imported4/imported2/feed");
850 
851   // Check that imported `imported3/scalar` has remapped control dep from
852   // original graph and imported control dep
853   num_control_inputs = TF_OperationGetControlInputs(
854       scalar4, control_inputs, TF_OperationNumControlInputs(scalar4));
855   ASSERT_EQ(2, num_control_inputs);
856   EXPECT_EQ(feed, control_inputs[0]);
857   EXPECT_EQ(feed4, control_inputs[1]);
858 
859   TF_DeleteImportGraphDefOptions(opts);
860   TF_DeleteBuffer(graph_def);
861 
862   // Can add nodes to the imported graph without trouble.
863   Add(feed, scalar, graph, s);
864   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
865 
866   TF_DeleteGraph(graph);
867   TF_DeleteStatus(s);
868 }
869 
// Imports a GraphDef into a fresh graph while requesting two return outputs
// ("feed":0 and "scalar":0), then checks that the returned TF_Outputs refer to
// the newly imported operations.
TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Create a graph with three nodes: a placeholder ("feed"), the constant 3
  // ("scalar"), and its negation ("neg").
  Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
  TF_Operation* oper = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
  Neg(oper, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);

  // Export to a GraphDef.
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(graph, graph_def, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Import it in a fresh graph with return outputs.
  TF_DeleteGraph(graph);
  graph = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
  TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
  // One constant ties together the requested count, the output array size,
  // and the count passed to the import call.
  constexpr int kNumReturnOutputs = 2;
  EXPECT_EQ(kNumReturnOutputs, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
  TF_Output return_outputs[kNumReturnOutputs];
  TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
                                          return_outputs, kNumReturnOutputs, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
  TF_Operation* feed = TF_GraphOperationByName(graph, "feed");
  TF_Operation* neg = TF_GraphOperationByName(graph, "neg");
  ASSERT_TRUE(scalar != nullptr);
  ASSERT_TRUE(feed != nullptr);
  ASSERT_TRUE(neg != nullptr);

  // Check return outputs, in the order they were requested.
  EXPECT_EQ(feed, return_outputs[0].oper);
  EXPECT_EQ(0, return_outputs[0].index);
  EXPECT_EQ(scalar, return_outputs[1].oper);
  EXPECT_EQ(0, return_outputs[1].index);

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
920 
// Verifies that an input mapping whose source node ("fake") does not exist in
// the imported GraphDef is reported by
// TF_ImportGraphDefResultsMissingUnusedInputMappings, while a mapping that is
// used ("scalar") is not.
TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Create a graph with three nodes: a placeholder ("feed"), the constant 3
  // ("scalar"), and its negation ("neg").
  Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
  TF_Operation* oper = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
  Neg(oper, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);

  // Export to a GraphDef.
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(graph, graph_def, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Import it in a fresh graph.
  TF_DeleteGraph(graph);
  graph = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_GraphImportGraphDef(graph, graph_def, opts, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // `scalar` is the destination of both input mappings below, so it must
  // exist before we build them.
  TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
  ASSERT_TRUE(scalar != nullptr);

  // Import it a second time into the same graph, under the "imported" prefix,
  // with one mapping that is used ("scalar") and one whose source node does
  // not exist in the GraphDef ("fake").
  TF_DeleteImportGraphDefOptions(opts);
  opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
  TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
  TF_ImportGraphDefResults* results =
      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Only the "fake" mapping should be reported as missing/unused.
  int num_unused_input_mappings;
  const char** src_names;
  int* src_indexes;
  TF_ImportGraphDefResultsMissingUnusedInputMappings(
      results, &num_unused_input_mappings, &src_names, &src_indexes);
  ASSERT_EQ(1, num_unused_input_mappings);
  EXPECT_EQ(string("fake"), string(src_names[0]));
  EXPECT_EQ(0, src_indexes[0]);

  TF_DeleteImportGraphDefResults(results);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
976 
// Runs feed + 2 through a CSession. It then adds a Neg op to the graph after
// the session was created and runs again, checking that the new op can be
// fetched.
TEST(CAPI, Session) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Graph: sum = feed + 2.
  TF_Operation* feed = Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* const_two = ScalarConst(2, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* sum = Add(feed, const_two, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  CSession csession(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // First run: feed 3 and fetch the sum, expecting the int32 scalar 5.
  csession.SetInputs({{feed, Int32Tensor(3)}});
  csession.SetOutputs({sum});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Tensor* result = csession.output_tensor(0);
  ASSERT_TRUE(result != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(result));
  EXPECT_EQ(0, TF_NumDims(result));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(result));
  EXPECT_EQ(3 + 2, *static_cast<int32*>(TF_TensorData(result)));

  // Add a new op after the session exists: negated = -(feed + 2).
  TF_Operation* negated = Neg(sum, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Second run: feed 7 and fetch the new op, expecting -9.
  csession.SetInputs({{feed, Int32Tensor(7)}});
  csession.SetOutputs({negated});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  result = csession.output_tensor(0);
  ASSERT_TRUE(result != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(result));
  EXPECT_EQ(0, TF_NumDims(result));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(result));
  EXPECT_EQ(-(7 + 2), *static_cast<int32*>(TF_TensorData(result)));

  // Tear down in reverse order of creation.
  csession.CloseAndDelete(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1033 
1034 // If `device` is non-empty, run Min op on that device.
1035 // Otherwise run it on the default device (CPU).
RunMinTest(const string & device,bool use_XLA)1036 void RunMinTest(const string& device, bool use_XLA) {
1037   TF_Status* s = TF_NewStatus();
1038   TF_Graph* graph = TF_NewGraph();
1039 
1040   // Make a placeholder operation.
1041   TF_Operation* feed = Placeholder(graph, s);
1042   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1043 
1044   // Make a constant operation with the scalar "0", for axis.
1045   TF_Operation* one = ScalarConst(0, graph, s);
1046   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1047 
1048   // Create a session for this graph.
1049   CSession csession(graph, s, use_XLA);
1050   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1051 
1052   if (!device.empty()) {
1053     LOG(INFO) << "Setting op Min on device " << device;
1054   }
1055   TF_Operation* min = MinWithDevice(feed, one, graph, device, s);
1056   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1057 
1058   // Run the graph.
1059   csession.SetInputs({{feed, Int32Tensor({3, 2, 5})}});
1060   csession.SetOutputs({min});
1061   csession.Run(s);
1062   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1063   TF_Tensor* out = csession.output_tensor(0);
1064   ASSERT_TRUE(out != nullptr);
1065   EXPECT_EQ(TF_INT32, TF_TensorType(out));
1066   EXPECT_EQ(0, TF_NumDims(out));  // scalar
1067   ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
1068   int32* output_contents = static_cast<int32*>(TF_TensorData(out));
1069   EXPECT_EQ(2, *output_contents);
1070 
1071   // Clean up
1072   csession.CloseAndDelete(s);
1073   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1074   TF_DeleteGraph(graph);
1075   TF_DeleteStatus(s);
1076 }
1077 
// Min reduction on the default (CPU) device, without XLA.
TEST(CAPI, Session_Min_CPU) {
  const string default_device;  // empty => no explicit device placement
  RunMinTest(default_device, /*use_XLA=*/false);
}
1079 
// Min reduction on the default (CPU) device, with XLA enabled on the session.
TEST(CAPI, Session_Min_XLA_CPU) {
  const string default_device;  // empty => no explicit device placement
  RunMinTest(default_device, /*use_XLA=*/true);
}
1081 
// Min reduction placed on the device named by GPUDeviceName(), without XLA.
// When that name is empty (no GPU available) the test does nothing.
TEST(CAPI, Session_Min_GPU) {
  const string gpu_device = GPUDeviceName();
  if (!gpu_device.empty()) {
    RunMinTest(gpu_device, /*use_XLA=*/false);
  }
}
1089 
// Min reduction placed on the device named by GPUDeviceName(), with XLA
// enabled. When that name is empty (no GPU available) the test does nothing.
TEST(CAPI, Session_Min_XLA_GPU) {
  const string gpu_device = GPUDeviceName();
  if (!gpu_device.empty()) {
    RunMinTest(gpu_device, /*use_XLA=*/true);
  }
}
1097 
TEST(CAPI,SessionPRun)1098 TEST(CAPI, SessionPRun) {
1099   TF_Status* s = TF_NewStatus();
1100   TF_Graph* graph = TF_NewGraph();
1101 
1102   // Construct the graph: A + 2 + B
1103   TF_Operation* a = Placeholder(graph, s, "A");
1104   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1105 
1106   TF_Operation* b = Placeholder(graph, s, "B");
1107   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1108 
1109   TF_Operation* two = ScalarConst(2, graph, s);
1110   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1111 
1112   TF_Operation* plus2 = Add(a, two, graph, s, "plus2");
1113   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1114 
1115   TF_Operation* plusB = Add(plus2, b, graph, s, "plusB");
1116   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1117 
1118   // Setup a session and a partial run handle.  The partial run will allow
1119   // computation of A + 2 + B in two phases (calls to TF_SessionPRun):
1120   // 1. Feed A and get (A+2)
1121   // 2. Feed B and get (A+2)+B
1122   TF_SessionOptions* opts = TF_NewSessionOptions();
1123   TF_Session* sess = TF_NewSession(graph, opts, s);
1124   TF_DeleteSessionOptions(opts);
1125 
1126   TF_Output feeds[] = {TF_Output{a, 0}, TF_Output{b, 0}};
1127   TF_Output fetches[] = {TF_Output{plus2, 0}, TF_Output{plusB, 0}};
1128 
1129   const char* handle = nullptr;
1130   TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
1131                       TF_ARRAYSIZE(fetches), nullptr, 0, &handle, s);
1132   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1133 
1134   // Feed A and fetch A + 2.
1135   TF_Output feeds1[] = {TF_Output{a, 0}};
1136   TF_Output fetches1[] = {TF_Output{plus2, 0}};
1137   TF_Tensor* feedValues1[] = {Int32Tensor(1)};
1138   TF_Tensor* fetchValues1[1];
1139   TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1,
1140                  1, nullptr, 0, s);
1141   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1142   EXPECT_EQ(3, *(static_cast<int32*>(TF_TensorData(fetchValues1[0]))));
1143   TF_DeleteTensor(feedValues1[0]);
1144   TF_DeleteTensor(fetchValues1[0]);
1145 
1146   // Feed B and fetch (A + 2) + B.
1147   TF_Output feeds2[] = {TF_Output{b, 0}};
1148   TF_Output fetches2[] = {TF_Output{plusB, 0}};
1149   TF_Tensor* feedValues2[] = {Int32Tensor(4)};
1150   TF_Tensor* fetchValues2[1];
1151   TF_SessionPRun(sess, handle, feeds2, feedValues2, 1, fetches2, fetchValues2,
1152                  1, nullptr, 0, s);
1153   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1154   EXPECT_EQ(7, *(static_cast<int32*>(TF_TensorData(fetchValues2[0]))));
1155   TF_DeleteTensor(feedValues2[0]);
1156   TF_DeleteTensor(fetchValues2[0]);
1157 
1158   // Clean up.
1159   TF_DeletePRunHandle(handle);
1160   TF_DeleteSession(sess, s);
1161   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1162   TF_DeleteGraph(graph);
1163   TF_DeleteStatus(s);
1164 }
1165 
// TF_FinishOperation must fail, and AddNoCheck must return nullptr, when shape
// inference rejects the new operation. Here the failure comes from adding an
// int8 tensor of shape [2] to one of shape [3], which cannot be added.
TEST(CAPI, ShapeInferenceError) {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Both constants are built from the same backing values; presumably the
  // shape-[2] tensor takes only the first two of them.
  const char data[] = {1, 2, 3};
  const int64_t shape2[] = {2};
  const int64_t shape3[] = {3};
  unique_tensor_ptr tensor2(
      Int8Tensor(shape2, TF_ARRAYSIZE(shape2), data), TF_DeleteTensor);
  unique_tensor_ptr tensor3(
      Int8Tensor(shape3, TF_ARRAYSIZE(shape3), data), TF_DeleteTensor);

  TF_Operation* vec2 = Const(tensor2.get(), graph, status, "vec2");
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_Operation* vec3 = Const(tensor3.get(), graph, status, "vec3");
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // The incompatible Add must be rejected with no operation returned.
  TF_Operation* bad_add = AddNoCheck(vec2, vec3, graph, status);
  ASSERT_NE(TF_OK, TF_GetCode(status));
  ASSERT_TRUE(bad_add == nullptr);

  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}
1194 
// TF_GraphGetOpDef must return exactly the registry's serialized OpDef for a
// registered op ("Add"), and TF_NOT_FOUND for an unregistered one.
TEST(CAPI, GetOpDef) {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_Buffer* buffer = TF_NewBuffer();

  TF_GraphGetOpDef(graph, "Add", buffer, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const OpDef* expected_op_def = nullptr;
  TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
  string expected_serialized;
  // Comparing against an empty string after a failed serialization would
  // produce a misleading mismatch, so check the serialization itself first.
  ASSERT_TRUE(expected_op_def->SerializeToString(&expected_serialized));
  string actual_string(reinterpret_cast<const char*>(buffer->data),
                       buffer->length);
  EXPECT_EQ(expected_serialized, actual_string);

  // The same buffer is reused for the lookup of an unregistered op.
  TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status);
  EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status)) << TF_Message(status);
  ExpectHasSubstr(TF_Message(status),
                  "Op type not registered 'MyFakeOp' in binary");

  TF_DeleteBuffer(buffer);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}
1219 
// Builds parallel arrays over `v` for C APIs that take (pointer, length)
// string lists: (*ptrs)[i] points at v[i]'s bytes and (*lens)[i] is its size.
// The pointers borrow from `v`, so `v` must outlive any use of *ptrs.
void StringVectorToArrays(const std::vector<string>& v,
                          std::unique_ptr<const void*[]>* ptrs,
                          std::unique_ptr<size_t[]>* lens) {
  const size_t count = v.size();
  std::unique_ptr<const void*[]> data_ptrs(new const void*[count]);
  std::unique_ptr<size_t[]> data_lens(new size_t[count]);
  size_t idx = 0;
  for (const string& item : v) {
    data_ptrs[idx] = item.data();
    data_lens[idx] = item.size();
    ++idx;
  }
  ptrs->reset(data_ptrs.release());
  lens->reset(data_lens.release());
}
1230 
1231 class CApiColocationTest : public ::testing::Test {
1232  protected:
CApiColocationTest()1233   CApiColocationTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
1234 
SetUp()1235   void SetUp() override {
1236     feed1_ = Placeholder(graph_, s_, "feed1");
1237     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1238 
1239     feed2_ = Placeholder(graph_, s_, "feed2");
1240     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1241 
1242     constant_ = ScalarConst(10, graph_, s_);
1243     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1244 
1245     desc_ = TF_NewOperation(graph_, "AddN", "add");
1246     TF_Output inputs[] = {{feed1_, 0}, {constant_, 0}};
1247     TF_AddInputList(desc_, inputs, TF_ARRAYSIZE(inputs));
1248   }
1249 
~CApiColocationTest()1250   ~CApiColocationTest() override {
1251     TF_DeleteGraph(graph_);
1252     TF_DeleteStatus(s_);
1253   }
1254 
SetViaStringList(TF_OperationDescription * desc,const std::vector<string> & list)1255   void SetViaStringList(TF_OperationDescription* desc,
1256                         const std::vector<string>& list) {
1257     std::unique_ptr<const void*[]> list_ptrs;
1258     std::unique_ptr<size_t[]> list_lens;
1259     StringVectorToArrays(list, &list_ptrs, &list_lens);
1260     TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(),
1261                          list_lens.get(), list.size());
1262   }
1263 
SetViaProto(TF_OperationDescription * desc,const std::vector<string> & list)1264   void SetViaProto(TF_OperationDescription* desc,
1265                    const std::vector<string>& list) {
1266     tensorflow::AttrValue attr;
1267     for (const string& v : list) {
1268       attr.mutable_list()->add_s(v);
1269     }
1270     string bytes;
1271     attr.SerializeToString(&bytes);
1272     TF_SetAttrValueProto(desc, tensorflow::kColocationAttrName, bytes.data(),
1273                          bytes.size(), s_);
1274     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1275   }
1276 
VerifyCollocation(TF_Operation * op,const std::vector<string> & expected)1277   void VerifyCollocation(TF_Operation* op,
1278                          const std::vector<string>& expected) {
1279     TF_AttrMetadata m =
1280         TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
1281     if (expected.empty()) {
1282       ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1283       EXPECT_EQ("Operation 'add' has no attr named '_class'.",
1284                 string(TF_Message(s_)));
1285       return;
1286     }
1287     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1288     EXPECT_EQ(1, m.is_list);
1289     EXPECT_EQ(expected.size(), m.list_size);
1290     EXPECT_EQ(TF_ATTR_STRING, m.type);
1291     std::vector<void*> values(expected.size());
1292     std::vector<size_t> lens(expected.size());
1293     std::unique_ptr<char[]> storage(new char[m.total_size]);
1294     TF_OperationGetAttrStringList(op, tensorflow::kColocationAttrName,
1295                                   values.data(), lens.data(), expected.size(),
1296                                   storage.get(), m.total_size, s_);
1297     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1298     for (int i = 0; i < expected.size(); ++i) {
1299       EXPECT_EQ(expected[i],
1300                 string(static_cast<const char*>(values[i]), lens[i]));
1301     }
1302   }
1303 
FinishAndVerify(TF_OperationDescription * desc,const std::vector<string> & expected)1304   void FinishAndVerify(TF_OperationDescription* desc,
1305                        const std::vector<string>& expected) {
1306     TF_Operation* op = TF_FinishOperation(desc_, s_);
1307     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1308     VerifyCollocation(op, expected);
1309   }
1310 
1311   TF_Status* s_;
1312   TF_Graph* graph_;
1313   TF_Operation* feed1_;
1314   TF_Operation* feed2_;
1315   TF_Operation* constant_;
1316   TF_OperationDescription* desc_;
1317 };
1318 
// A single TF_ColocateWith call yields exactly one "loc:@<op>" entry.
TEST_F(CApiColocationTest, ColocateWith) {
  const std::vector<string> expected = {"loc:@feed1"};
  TF_ColocateWith(desc_, feed1_);
  FinishAndVerify(desc_, expected);
}
1323 
// A colocation list set via TF_SetAttrStringList is read back unchanged.
TEST_F(CApiColocationTest, StringList) {
  const std::vector<string> colocation = {"loc:@feed1"};
  SetViaStringList(desc_, colocation);
  FinishAndVerify(desc_, colocation);
}
1328 
// A colocation list set via TF_SetAttrValueProto is read back unchanged.
TEST_F(CApiColocationTest, Proto) {
  const std::vector<string> colocation = {"loc:@feed1"};
  SetViaProto(desc_, colocation);
  FinishAndVerify(desc_, colocation);
}
1333 
// An explicit string list set after TF_ColocateWith replaces it entirely.
TEST_F(CApiColocationTest, ColocateWith_StringList) {
  const std::vector<string> replacement = {"loc:@feed2"};
  TF_ColocateWith(desc_, feed1_);
  SetViaStringList(desc_, replacement);
  FinishAndVerify(desc_, replacement);
}
1339 
// An explicit proto value set after TF_ColocateWith replaces it entirely.
TEST_F(CApiColocationTest, ColocateWith_Proto) {
  const std::vector<string> replacement = {"loc:@feed2"};
  TF_ColocateWith(desc_, feed1_);
  SetViaProto(desc_, replacement);
  FinishAndVerify(desc_, replacement);
}
1345 
// TF_ColocateWith after an explicit string list merges with it rather than
// replacing it.
TEST_F(CApiColocationTest, StringList_ColocateWith) {
  const std::vector<string> merged = {"loc:@feed1", "loc:@feed2"};
  SetViaStringList(desc_, {"loc:@feed2"});
  TF_ColocateWith(desc_, feed1_);
  FinishAndVerify(desc_, merged);
}
1351 
// TF_ColocateWith after an explicit proto value merges with it rather than
// replacing it.
TEST_F(CApiColocationTest, Proto_ColocateWith) {
  const std::vector<string> merged = {"loc:@feed1", "loc:@feed2"};
  SetViaProto(desc_, {"loc:@feed2"});
  TF_ColocateWith(desc_, feed1_);
  FinishAndVerify(desc_, merged);
}
1357 
// Repeated TF_ColocateWith calls accumulate entries.
TEST_F(CApiColocationTest, ColocateWith_ColocateWith) {
  const std::vector<string> accumulated = {"loc:@feed1", "loc:@feed2"};
  TF_ColocateWith(desc_, feed1_);
  TF_ColocateWith(desc_, feed2_);
  FinishAndVerify(desc_, accumulated);
}
1363 
// Of two explicit settings (proto, then string list), the later one wins.
TEST_F(CApiColocationTest, Proto_StringList) {
  const std::vector<string> latest = {"loc:@feed2"};
  SetViaProto(desc_, {"loc:@feed1"});
  SetViaStringList(desc_, latest);
  FinishAndVerify(desc_, latest);
}
1369 
// Of two explicit settings (string list, then proto), the later one wins.
TEST_F(CApiColocationTest, StringList_Proto) {
  const std::vector<string> latest = {"loc:@feed2"};
  SetViaStringList(desc_, {"loc:@feed1"});
  SetViaProto(desc_, latest);
  FinishAndVerify(desc_, latest);
}
1375 
// An empty string list clears an earlier TF_ColocateWith, leaving no
// colocation attribute on the finished op.
TEST_F(CApiColocationTest, ClearViaStringList) {
  const std::vector<string> no_colocation;
  TF_ColocateWith(desc_, feed1_);
  SetViaStringList(desc_, no_colocation);
  FinishAndVerify(desc_, no_colocation);
}
1381 
// An empty proto list clears an earlier TF_ColocateWith, leaving no
// colocation attribute on the finished op.
TEST_F(CApiColocationTest, ClearViaProto) {
  const std::vector<string> no_colocation;
  TF_ColocateWith(desc_, feed1_);
  SetViaProto(desc_, no_colocation);
  FinishAndVerify(desc_, no_colocation);
}
1387 
TEST(CAPI,SavedModel)1388 TEST(CAPI, SavedModel) {
1389   // Load the saved model.
1390   const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
1391       tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
1392                                "half_plus_two", "00000123"));
1393   TF_SessionOptions* opt = TF_NewSessionOptions();
1394   TF_Buffer* run_options = TF_NewBufferFromString("", 0);
1395   TF_Buffer* metagraph = TF_NewBuffer();
1396   TF_Status* s = TF_NewStatus();
1397   const char* tags[] = {tensorflow::kSavedModelTagServe};
1398   TF_Graph* graph = TF_NewGraph();
1399   TF_Session* session = TF_LoadSessionFromSavedModel(
1400       opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
1401   TF_DeleteBuffer(run_options);
1402   TF_DeleteSessionOptions(opt);
1403   tensorflow::MetaGraphDef metagraph_def;
1404   metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
1405   TF_DeleteBuffer(metagraph);
1406 
1407   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1408   CSession csession(session);
1409 
1410   // Retrieve the regression signature from meta graph def.
1411   const auto signature_def_map = metagraph_def.signature_def();
1412   const auto signature_def = signature_def_map.at("regress_x_to_y");
1413 
1414   const string input_name =
1415       signature_def.inputs().at(tensorflow::kRegressInputs).name();
1416   const string output_name =
1417       signature_def.outputs().at(tensorflow::kRegressOutputs).name();
1418 
1419   // Write {0, 1, 2, 3} as tensorflow::Example inputs.
1420   Tensor input(tensorflow::DT_STRING, TensorShape({4}));
1421   for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {
1422     tensorflow::Example example;
1423     auto* feature_map = example.mutable_features()->mutable_feature();
1424     (*feature_map)["x"].mutable_float_list()->add_value(i);
1425     input.flat<tstring>()(i) = example.SerializeAsString();
1426   }
1427 
1428   const tensorflow::string input_op_name(
1429       tensorflow::ParseTensorName(input_name).first);
1430   TF_Operation* input_op =
1431       TF_GraphOperationByName(graph, input_op_name.c_str());
1432   ASSERT_TRUE(input_op != nullptr);
1433   Status status;
1434   csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
1435   ASSERT_TRUE(status.ok()) << status.error_message();
1436 
1437   const tensorflow::string output_op_name(
1438       tensorflow::ParseTensorName(output_name).first);
1439   TF_Operation* output_op =
1440       TF_GraphOperationByName(graph, output_op_name.c_str());
1441   ASSERT_TRUE(output_op != nullptr);
1442   csession.SetOutputs({output_op});
1443   csession.Run(s);
1444   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1445 
1446   TF_Tensor* out = csession.output_tensor(0);
1447   ASSERT_TRUE(out != nullptr);
1448   EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
1449   EXPECT_EQ(2, TF_NumDims(out));
1450   EXPECT_EQ(4, TF_Dim(out, 0));
1451   EXPECT_EQ(1, TF_Dim(out, 1));
1452   float* values = static_cast<float*>(TF_TensorData(out));
1453   // These values are defined to be (input / 2) + 2.
1454   EXPECT_EQ(2, values[0]);
1455   EXPECT_EQ(2.5, values[1]);
1456   EXPECT_EQ(3, values[2]);
1457   EXPECT_EQ(3.5, values[3]);
1458 
1459   csession.CloseAndDelete(s);
1460   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1461   TF_DeleteGraph(graph);
1462   TF_DeleteStatus(s);
1463 }
1464 
// TF_LoadSessionFromSavedModel must accept nullptr for both run_options and
// meta_graph_def.
TEST(CAPI, SavedModelNullArgsAreValid) {
  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
                               "half_plus_two", "00000123"));
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  // NULL run_options and meta_graph_def should work.
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, nullptr, saved_model_dir.c_str(), tags, 1, graph, nullptr, s);
  TF_DeleteSessionOptions(opt);
  // `session` is handed to TF_CloseSession/TF_DeleteSession below; never pass
  // them the result of a failed load.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(session != nullptr);
  TF_CloseSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1485 
// Calling each TF_Delete* entry point with a null handle must not fail. The
// deleters that take a TF_Status must leave it TF_OK.
TEST(CAPI, DeletingNullPointerIsSafe) {
  TF_Status* s = TF_NewStatus();

  // Core object deleters.
  TF_DeleteStatus(nullptr);
  TF_DeleteBuffer(nullptr);
  TF_DeleteTensor(nullptr);
  TF_DeleteSessionOptions(nullptr);
  TF_DeleteGraph(nullptr);

  // Graph import / function deleters.
  TF_DeleteImportGraphDefOptions(nullptr);
  TF_DeleteImportGraphDefResults(nullptr);
  TF_DeleteFunction(nullptr);

  // Session-related deleters; those that report through `s` must succeed.
  TF_DeleteSession(nullptr, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeletePRunHandle(nullptr);
  TF_DeleteDeprecatedSession(nullptr, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Miscellaneous handles.
  TF_DeleteDeviceList(nullptr);
  TF_DeleteLibraryHandle(nullptr);
  TF_DeleteApiDefMap(nullptr);

  TF_DeleteStatus(s);
}
1508 
// TF_TensorBitcastFrom from a {2, 3} TF_UINT64 tensor `a` into `b` with new
// dims {3, 2} must give `b` six elements and make writes through either
// tensor visible through the other.
TEST(CAPI, TestBitcastFrom_Reshape) {
  int64_t dims[] = {2, 3};
  TF_Tensor* a =
      TF_AllocateTensor(TF_UINT64, dims, 2, 6 * TF_DataTypeSize(TF_UINT64));
  TF_Tensor* b =
      TF_AllocateTensor(TF_UINT64, nullptr, 0, TF_DataTypeSize(TF_UINT64));
  // ASSERT, not EXPECT: everything below dereferences both tensors.
  ASSERT_NE(a, nullptr);
  ASSERT_NE(b, nullptr);

  EXPECT_EQ(6, TF_TensorElementCount(a));
  EXPECT_EQ(1, TF_TensorElementCount(b));
  EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a));
  EXPECT_EQ(TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b));

  int64_t new_dims[] = {3, 2};
  TF_Status* status = TF_NewStatus();
  TF_TensorBitcastFrom(a, TF_UINT64, b, new_dims, 2, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  EXPECT_EQ(6, TF_TensorElementCount(a));
  EXPECT_EQ(6, TF_TensorElementCount(b));
  EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a));
  EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b));

  // Check that a write to one tensor shows up in the other. Access the data
  // as uint64_t to match the tensors' TF_UINT64 dtype.
  *(static_cast<uint64_t*>(TF_TensorData(a))) = 4;
  EXPECT_EQ(uint64_t{4}, *(static_cast<uint64_t*>(TF_TensorData(b))));
  *(static_cast<uint64_t*>(TF_TensorData(b))) = 6;
  EXPECT_EQ(uint64_t{6}, *(static_cast<uint64_t*>(TF_TensorData(a))));

  TF_DeleteTensor(a);
  TF_DeleteTensor(b);
}
1543 
1544 REGISTER_OP("TestOpWithNoGradient")
1545     .Input("x: T")
1546     .Output("y: T")
1547     .Attr("T: {float, double}")
1548     .Doc(R"doc(
1549 Test op with no grad registered.
1550 
1551 x: input
1552 y: output
1553 )doc")
1554     .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1555 
1556 class CApiGradientsTest : public ::testing::Test {
1557  protected:
CApiGradientsTest()1558   CApiGradientsTest()
1559       : s_(TF_NewStatus()),
1560         graph_(TF_NewGraph()),
1561         expected_graph_(TF_NewGraph()) {}
1562 
~CApiGradientsTest()1563   ~CApiGradientsTest() override {
1564     TF_DeleteGraph(graph_);
1565     TF_DeleteGraph(expected_graph_);
1566     TF_DeleteStatus(s_);
1567   }
1568 
TestGradientsSuccess(bool grad_inputs_provided)1569   void TestGradientsSuccess(bool grad_inputs_provided) {
1570     TF_Output inputs[2];
1571     TF_Output outputs[1];
1572     TF_Output grad_outputs[2];
1573     TF_Output expected_grad_outputs[2];
1574 
1575     BuildSuccessGraph(inputs, outputs);
1576     BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
1577 
1578     AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1,
1579                  grad_outputs);
1580     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1581 
1582     // Compare that the graphs match.
1583     GraphDef expected_gdef;
1584     GraphDef gdef;
1585     EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef));
1586     EXPECT_TRUE(GetGraphDef(graph_, &gdef));
1587     TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);
1588 
1589     // Compare that the output of the gradients of both graphs match.
1590     RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
1591   }
1592 
TestGradientsError(bool grad_inputs_provided)1593   void TestGradientsError(bool grad_inputs_provided) {
1594     TF_Output inputs[1];
1595     TF_Output outputs[1];
1596     TF_Output grad_outputs[1];
1597 
1598     BuildErrorGraph(inputs, outputs);
1599 
1600     AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
1601                  grad_outputs);
1602 
1603     string expected_msg =
1604         "No gradient defined for op: TestOpWithNoGradient. Please see "
1605         "https://www.tensorflow.org/code/"
1606         "tensorflow/cc/gradients/README.md"
1607         " for instructions on how to add C++ gradients.";
1608     EXPECT_EQ(expected_msg, TF_Message(s_));
1609   }
1610 
1611   // Run the graph and ensure that the gradient values are as expected.
RunGraphsAndCompareOutputs(TF_Output * grad_outputs,TF_Output * expected_grad_outputs)1612   void RunGraphsAndCompareOutputs(TF_Output* grad_outputs,
1613                                   TF_Output* expected_grad_outputs) {
1614     std::unique_ptr<CSession> csession(new CSession(graph_, s_));
1615     std::unique_ptr<CSession> expected_csession(
1616         new CSession(expected_graph_, s_));
1617 
1618     std::vector<TF_Output> grad_outputs_vec;
1619     grad_outputs_vec.assign(grad_outputs, grad_outputs + 2);
1620     csession->SetOutputs(grad_outputs_vec);
1621     csession->Run(s_);
1622     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1623     TF_Tensor* out0 = csession->output_tensor(0);
1624     TF_Tensor* out1 = csession->output_tensor(1);
1625 
1626     std::vector<TF_Output> expected_grad_outputs_vec;
1627     expected_grad_outputs_vec.assign(expected_grad_outputs,
1628                                      expected_grad_outputs + 2);
1629     expected_csession->SetOutputs(expected_grad_outputs_vec);
1630     expected_csession->Run(s_);
1631     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1632     TF_Tensor* expected_out0 = expected_csession->output_tensor(0);
1633     TF_Tensor* expected_out1 = expected_csession->output_tensor(1);
1634 
1635     CompareTensors(out0, expected_out0);
1636     CompareTensors(out1, expected_out1);
1637   }
1638 
CompareTensors(TF_Tensor * a,TF_Tensor * b)1639   void CompareTensors(TF_Tensor* a, TF_Tensor* b) {
1640     float* a_data = static_cast<float*>(TF_TensorData(a));
1641     float* b_data = static_cast<float*>(TF_TensorData(b));
1642     EXPECT_EQ(*a_data, *b_data);
1643   }
1644 
AddGradients(bool grad_inputs_provided,const char * prefix,TF_Output * inputs,int ninputs,TF_Output * outputs,int noutputs,TF_Output * grad_outputs)1645   void AddGradients(bool grad_inputs_provided, const char* prefix,
1646                     TF_Output* inputs, int ninputs, TF_Output* outputs,
1647                     int noutputs, TF_Output* grad_outputs) {
1648     if (grad_inputs_provided) {
1649       TF_Output grad_inputs[1];
1650       const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
1651       TF_Operation* grad_inputs_op =
1652           FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
1653       grad_inputs[0] = TF_Output{grad_inputs_op, 0};
1654       TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
1655                                 ninputs, grad_inputs, s_, grad_outputs);
1656     } else {
1657       TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
1658                                 ninputs, nullptr, s_, grad_outputs);
1659     }
1660   }
1661 
BuildErrorGraph(TF_Output * inputs,TF_Output * outputs)1662   void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) {
1663     const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1664     TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1665     TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad");
1666     inputs[0] = TF_Output{const0, 0};
1667     outputs[0] = TF_Output{nograd, 0};
1668     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1669   }
1670 
BuildSuccessGraph(TF_Output * inputs,TF_Output * outputs)1671   void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
1672     // Construct the following graph:
1673     //            |
1674     //           z|
1675     //            |
1676     //          MatMul
1677     //         /       \
1678     //        ^         ^
1679     //        |         |
1680     //       x|        y|
1681     //        |         |
1682     //        |         |
1683     //      Const_0    Const_1
1684     //
1685     const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1686     const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1687     TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1688     TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
1689     TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul");
1690     inputs[0] = TF_Output{const0, 0};
1691     inputs[1] = TF_Output{const1, 0};
1692     outputs[0] = TF_Output{matmul, 0};
1693     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1694   }
1695 
BuildExpectedGraph(bool grad_inputs_provided,TF_Output * expected_grad_outputs)1696   void BuildExpectedGraph(bool grad_inputs_provided,
1697                           TF_Output* expected_grad_outputs) {
1698     // The expected graph looks like this if grad_inputs_provided.
1699     // If grad_inputs_provided is false, Const_0 will be a OnesLike op.
1700     //      ^             ^
1701     //    dy|           dx|        // MatMul Gradient Graph
1702     //      |             |
1703     //   MatMul_2      MatMul_1
1704     //   ^   ^          ^    ^
1705     //   |   |----------|    |
1706     //   |        ^          |
1707     //   |      dz|          |
1708     //   |        |          |
1709     //   |     Const_3       |
1710     //   |                   |
1711     //   |        ^          |
1712     //   |       z|          |     // MatMul Forward Graph
1713     //   |        |          |
1714     //   |      MatMul       |
1715     //   |     /       \     |
1716     //   |    ^         ^    |
1717     //   |    |         |    |
1718     //   |---x|        y|----|
1719     //        |         |
1720     //        |         |
1721     //      Const_0   Const_1
1722     //
1723     const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1724     const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1725     TF_Operation* const0 =
1726         FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
1727     TF_Operation* const1 =
1728         FloatConst2x2(expected_graph_, s_, const1_val, "Const_1");
1729     TF_Operation* matmul =
1730         MatMul(expected_graph_, s_, const0, const1, "MatMul");
1731 
1732     TF_Operation* const3;
1733     if (grad_inputs_provided) {
1734       const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
1735       const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
1736     } else {
1737       const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike");
1738     }
1739 
1740     TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1,
1741                                    "gradients/MatMul", false, true);
1742     TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3,
1743                                    "gradients/MatMul_1", true, false);
1744     expected_grad_outputs[0] = {matmul1, 0};
1745     expected_grad_outputs[1] = {matmul2, 0};
1746   }
1747 
FloatTensor2x2(const float * values)1748   TF_Tensor* FloatTensor2x2(const float* values) {
1749     const int64_t dims[2] = {2, 2};
1750     TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
1751     memcpy(TF_TensorData(t), values, sizeof(float) * 4);
1752     return t;
1753   }
1754 
FloatConst2x2(TF_Graph * graph,TF_Status * s,const float * values,const char * name)1755   TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s,
1756                               const float* values, const char* name) {
1757     unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor);
1758     TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
1759     TF_SetAttrTensor(desc, "value", tensor.get(), s);
1760     if (TF_GetCode(s) != TF_OK) return nullptr;
1761     TF_SetAttrType(desc, "dtype", TF_FLOAT);
1762     TF_Operation* op = TF_FinishOperation(desc, s);
1763     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1764     return op;
1765   }
1766 
MatMul(TF_Graph * graph,TF_Status * s,TF_Operation * l,TF_Operation * r,const char * name,bool transpose_a=false,bool transpose_b=false)1767   TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l,
1768                        TF_Operation* r, const char* name,
1769                        bool transpose_a = false, bool transpose_b = false) {
1770     TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
1771     if (transpose_a) {
1772       TF_SetAttrBool(desc, "transpose_a", 1);
1773     }
1774     if (transpose_b) {
1775       TF_SetAttrBool(desc, "transpose_b", 1);
1776     }
1777     TF_AddInput(desc, {l, 0});
1778     TF_AddInput(desc, {r, 0});
1779     TF_Operation* op = TF_FinishOperation(desc, s);
1780     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1781     return op;
1782   }
1783 
OnesLike(TF_Graph * graph,TF_Status * s,TF_Operation * in,const char * name)1784   TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1785                          const char* name) {
1786     TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name);
1787     TF_AddInput(desc, {in, 0});
1788     TF_Operation* op = TF_FinishOperation(desc, s);
1789     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1790     return op;
1791   }
1792 
NoGradientOp(TF_Graph * graph,TF_Status * s,TF_Operation * in,const char * name)1793   TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1794                              const char* name) {
1795     TF_OperationDescription* desc =
1796         TF_NewOperation(graph, "TestOpWithNoGradient", name);
1797     TF_AddInput(desc, {in, 0});
1798     TF_Operation* op = TF_FinishOperation(desc, s);
1799     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1800     return op;
1801   }
1802 
BuildGraphAndAddGradientsWithPrefixes(const char * prefix1,const char * prefix2=nullptr)1803   void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
1804                                              const char* prefix2 = nullptr) {
1805     TF_Output inputs[2];
1806     TF_Output outputs[1];
1807     TF_Output grad_outputs[2];
1808 
1809     BuildSuccessGraph(inputs, outputs);
1810 
1811     AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs);
1812     if (prefix2 != nullptr) {
1813       AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs);
1814     }
1815   }
1816 
1817   TF_Status* s_;
1818   TF_Graph* graph_;
1819   TF_Graph* expected_graph_;
1820 };
1821 
// MatMul gradients, with a 2x2 ones constant supplied as the upstream gradient.
TEST_F(CApiGradientsTest, Gradients_GradInputs) {
  const bool grad_inputs_provided = true;
  TestGradientsSuccess(grad_inputs_provided);
}
1823 
// MatMul gradients, letting the gradient builder create the upstream gradient.
TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
  const bool grad_inputs_provided = false;
  TestGradientsSuccess(grad_inputs_provided);
}
1827 
// Differentiating through an op without a gradient reports the README pointer,
// when an explicit upstream gradient is supplied.
TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
  const bool grad_inputs_provided = true;
  TestGradientsError(grad_inputs_provided);
}
1831 
// Differentiating through an op without a gradient reports the README pointer,
// when no upstream gradient is supplied.
TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
  const bool grad_inputs_provided = false;
  TestGradientsError(grad_inputs_provided);
}
1835 
// A single gradient subgraph under the prefix "gradients" is accepted.
TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) {
  BuildGraphAndAddGradientsWithPrefixes("gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}
1840 
// Two gradient subgraphs under distinct top-level prefixes are accepted.
TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) {
  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}
1845 
// Two gradient subgraphs with distinct names inside one scope are accepted.
TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) {
  BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}
1850 
// The same gradient name inside two different scopes is accepted.
TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) {
  BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}
1855 
// A second gradient subgraph nested in a new sub-scope of the first is
// accepted.
TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) {
  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}
1860 
// A prefix equal to the name of an existing node is rejected with
// TF_INVALID_ARGUMENT.
TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) {
  BuildGraphAndAddGradientsWithPrefixes("Const_0");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}
1865 
// Reusing exactly the same prefix for a second gradient subgraph is rejected
// with TF_INVALID_ARGUMENT.
TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) {
  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}
1870 
// A second prefix that names a node created by the first gradient subgraph is
// rejected with TF_INVALID_ARGUMENT.
TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) {
  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}
1875 
// A second prefix that is the parent of the first one, where the second
// subgraph's nodes would collide with the first's scope, is rejected with
// TF_INVALID_ARGUMENT.
TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) {
  BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}
1880 
// A second prefix that is the parent scope of the first one is rejected with
// TF_INVALID_ARGUMENT.
TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) {
  BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}
1885 
ScalarFloatFromTensor(const TF_Tensor * t,float * f)1886 void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
1887   ASSERT_TRUE(t != nullptr);
1888   ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
1889   ASSERT_EQ(0, TF_NumDims(t));
1890   ASSERT_EQ(4, TF_TensorByteSize(t));
1891   float* p = static_cast<float*>(TF_TensorData(t));
1892   *f = *p;
1893 }
1894 
// Calls TF_AddGradients twice on the same graph, once per input of x * y, and
// checks that both gradient subgraphs evaluate correctly:
// d(xy)/dx == y and d(xy)/dy == x.
TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) {
  const float X = 3.0f, Y = 7.0f;
  TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT);
  TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT);
  TF_Operation* xy = Mul(x, y, graph_, s_, "xy");
  TF_Output dxy_dx, dxy_dy;

  TF_Output outputs[1] = {{xy, 0}};
  TF_Output inputs[1] = {{x, 0}};
  TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  inputs[0] = {y, 0};
  TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* sess = TF_NewSession(graph_, opts, s_);
  TF_DeleteSessionOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Output feeds[] = {{x, 0}, {y, 0}};
  TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)};
  TF_Output fetches[] = {dxy_dx, dxy_dy};
  TF_Tensor* fetchValues[] = {nullptr, nullptr};

  TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches,
                fetchValues, 2, nullptr /* target_opers */, 0,
                nullptr /* run_metadata */, s_);
  TF_DeleteTensor(feedValues[0]);
  TF_DeleteTensor(feedValues[1]);

  // Close and delete the session before checking the run status, so it is
  // released even if the run failed. A separate status keeps the run result
  // in s_ intact for the assertion below.
  TF_Status* teardown_status = TF_NewStatus();
  TF_CloseSession(sess, teardown_status);
  EXPECT_EQ(TF_OK, TF_GetCode(teardown_status))
      << TF_Message(teardown_status);
  TF_DeleteSession(sess, teardown_status);
  EXPECT_EQ(TF_OK, TF_GetCode(teardown_status))
      << TF_Message(teardown_status);
  TF_DeleteStatus(teardown_status);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f;
  ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue);
  EXPECT_EQ(Y, dxy_dxValue);

  ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue);
  EXPECT_EQ(X, dxy_dyValue);

  TF_DeleteTensor(fetchValues[0]);
  TF_DeleteTensor(fetchValues[1]);
}
1940 
1941 // REGISTER_OP for CApiAttributesTest test cases.
1942 // Registers two ops, each with a single attribute called 'v'.
1943 // The attribute in one op will have a type 'type', the other
1944 // will have list(type).
1945 #define ATTR_TEST_REGISTER_OP(type)                           \
1946   REGISTER_OP("CApiAttributesTestOp" #type)                   \
1947       .Attr("v: " #type)                                      \
1948       .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
1949   REGISTER_OP("CApiAttributesTestOpList" #type)               \
1950       .Attr("v: list(" #type ")")                             \
1951       .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1952 ATTR_TEST_REGISTER_OP(string);
1953 ATTR_TEST_REGISTER_OP(int);
1954 ATTR_TEST_REGISTER_OP(float);
1955 ATTR_TEST_REGISTER_OP(bool);
1956 ATTR_TEST_REGISTER_OP(type);
1957 ATTR_TEST_REGISTER_OP(shape);
1958 ATTR_TEST_REGISTER_OP(tensor);
1959 #undef ATTR_TEST_REGISTER_OP
1960 
1961 class CApiAttributesTest : public ::testing::Test {
1962  protected:
CApiAttributesTest()1963   CApiAttributesTest()
1964       : s_(TF_NewStatus()), graph_(TF_NewGraph()), counter_(0) {}
1965 
~CApiAttributesTest()1966   ~CApiAttributesTest() override {
1967     TF_DeleteGraph(graph_);
1968     TF_DeleteStatus(s_);
1969   }
1970 
init(string type)1971   TF_OperationDescription* init(string type) {
1972     // Construct op_name to match the name used by REGISTER_OP in the
1973     // ATTR_TEST_REGISTER calls above.
1974     string op_name = "CApiAttributesTestOp";
1975     if (type.find("list(") == 0) {
1976       op_name += "List";
1977       type = type.replace(0, 5, "");
1978       type = type.replace(type.size() - 1, 1, "");
1979     }
1980     op_name += type;
1981     return TF_NewOperation(
1982         graph_, op_name.c_str(),
1983         ::tensorflow::strings::StrCat("name", counter_++).c_str());
1984   }
1985 
1986   TF_Status* s_;
1987 
1988  private:
1989   TF_Graph* graph_;
1990   int counter_;
1991 };
1992 
// Helper macros for the TF_OperationGetAttr* tests.
// TODO(ashankar): Use gmock matchers instead?
// (https://github.com/google/googletest/blob/master/googlemock/docs/CookBook.md#writing-new-parameterized-matchers-quickly)
// That will require setting up the tensorflow build with gmock.
//
// EXPECT_TF_META checks the TF_AttrMetadata that TF_OperationGetAttrMetadata
// reports for attribute `attr_name` of the variable `oper` in the calling
// scope, using the fixture status `s_`. A negative `expected_list_size` means
// the attribute is expected not to be a list. Each argument is evaluated once
// and the sizes are converted to int64_t, so expressions and unsigned values
// such as list.size() compare as plain numbers.
#define EXPECT_TF_META(attr_name, expected_list_size, expected_type,        \
                       expected_total_size)                                 \
  do {                                                                      \
    auto tf_meta_ = TF_OperationGetAttrMetadata(oper, (attr_name), s_);     \
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);                     \
    const int64_t tf_meta_list_size_ =                                      \
        static_cast<int64_t>(expected_list_size);                           \
    const unsigned char tf_meta_is_list_ = tf_meta_list_size_ >= 0 ? 1 : 0; \
    EXPECT_EQ(tf_meta_is_list_, tf_meta_.is_list);                          \
    EXPECT_EQ(tf_meta_list_size_, tf_meta_.list_size);                      \
    EXPECT_EQ((expected_type), tf_meta_.type);                              \
    EXPECT_EQ(static_cast<int64_t>(expected_total_size),                    \
              tf_meta_.total_size);                                         \
  } while (0)
2008 
// A scalar string attribute round-trips through TF_SetAttrString and
// TF_OperationGetAttrString.
TEST_F(CApiAttributesTest, String) {
  auto desc = init("string");
  TF_SetAttrString(desc, "v", "bunny", 5);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_STRING, 5);
  std::unique_ptr<char[]> value(new char[5]);

  TF_OperationGetAttrString(oper, "v", value.get(), 5, s_);
  // ASSERT, not EXPECT: if the call failed, `value` may be unwritten and must
  // not be read below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ("bunny", string(static_cast<const char*>(value.get()), 5));
}
2022 
// A string list attribute round-trips through TF_SetAttrStringList and
// TF_OperationGetAttrStringList; total_size is the sum of the string lengths.
TEST_F(CApiAttributesTest, StringList) {
  std::vector<string> list = {"bugs", "bunny", "duck"};
  std::unique_ptr<const void*[]> list_ptrs;
  std::unique_ptr<size_t[]> list_lens;
  StringVectorToArrays(list, &list_ptrs, &list_lens);
  int list_total_size = 0;
  for (const auto& s : list) {
    list_total_size += s.size();
  }

  auto desc = init("list(string)");
  TF_SetAttrStringList(desc, "v", list_ptrs.get(), list_lens.get(),
                       list.size());

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size);
  std::unique_ptr<void*[]> values(new void*[list.size()]);
  std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
  std::unique_ptr<char[]> storage(new char[list_total_size]);
  TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(),
                                list.size(), storage.get(), list_total_size,
                                s_);
  // ASSERT, not EXPECT: if the call failed, `values` and `lens` may be
  // unwritten and must not be dereferenced below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  for (size_t i = 0; i < list.size(); ++i) {
    EXPECT_EQ(list[i].size(), lens[i]) << i;
    EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
        << i;
  }
}
2054 
// A scalar int attribute round-trips through TF_SetAttrInt and
// TF_OperationGetAttrInt.
TEST_F(CApiAttributesTest, Int) {
  auto desc = init("int");
  TF_SetAttrInt(desc, "v", 31415);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_INT, -1);

  int64_t value;
  TF_OperationGetAttrInt(oper, "v", &value, s_);
  // ASSERT, not EXPECT: if the call failed, `value` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(31415, value);
}
2068 
// An int list attribute round-trips through TF_SetAttrIntList and
// TF_OperationGetAttrIntList.
TEST_F(CApiAttributesTest, IntList) {
  const int64_t list[] = {1, 2, 3, 4};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(int)");
  TF_SetAttrIntList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  int64_t values[list_size];
  EXPECT_TF_META("v", list_size, TF_ATTR_INT, -1);
  TF_OperationGetAttrIntList(oper, "v", values, list_size, s_);
  // ASSERT, not EXPECT: if the call failed, `values` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2085 
// A scalar float attribute round-trips through TF_SetAttrFloat and
// TF_OperationGetAttrFloat.
TEST_F(CApiAttributesTest, Float) {
  auto desc = init("float");
  TF_SetAttrFloat(desc, "v", 2.718);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_FLOAT, -1);

  float value;
  TF_OperationGetAttrFloat(oper, "v", &value, s_);
  // ASSERT, not EXPECT: if the call failed, `value` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_FLOAT_EQ(2.718, value);
}
2099 
// A float list attribute round-trips through TF_SetAttrFloatList and
// TF_OperationGetAttrFloatList.
TEST_F(CApiAttributesTest, FloatList) {
  const float list[] = {1.414, 2.718, 3.1415};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(float)");
  TF_SetAttrFloatList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  float values[list_size];
  EXPECT_TF_META("v", list_size, TF_ATTR_FLOAT, -1);
  TF_OperationGetAttrFloatList(oper, "v", values, list_size, s_);
  // ASSERT, not EXPECT: if the call failed, `values` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2116 
// A scalar bool attribute round-trips through TF_SetAttrBool and
// TF_OperationGetAttrBool.
TEST_F(CApiAttributesTest, Bool) {
  auto desc = init("bool");
  TF_SetAttrBool(desc, "v", 1);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_BOOL, -1);

  unsigned char value;
  TF_OperationGetAttrBool(oper, "v", &value, s_);
  // ASSERT, not EXPECT: if the call failed, `value` may be uninitialized and
  // comparing it would only add a misleading second failure.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(1, value);
}
2130 
// A bool list attribute round-trips through TF_SetAttrBoolList and
// TF_OperationGetAttrBoolList.
TEST_F(CApiAttributesTest, BoolList) {
  const unsigned char list[] = {0, 1, 1, 0, 0, 1, 1};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(bool)");
  TF_SetAttrBoolList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  unsigned char values[list_size];
  EXPECT_TF_META("v", list_size, TF_ATTR_BOOL, -1);
  TF_OperationGetAttrBoolList(oper, "v", values, list_size, s_);
  // ASSERT, not EXPECT: if the call failed, `values` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2147 
// A scalar type attribute round-trips through TF_SetAttrType and
// TF_OperationGetAttrType.
TEST_F(CApiAttributesTest, Type) {
  auto desc = init("type");
  TF_SetAttrType(desc, "v", TF_COMPLEX128);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_TYPE, -1);

  TF_DataType value;
  TF_OperationGetAttrType(oper, "v", &value, s_);
  // ASSERT, not EXPECT: if the call failed, `value` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_COMPLEX128, value);
}
2161 
// A type list attribute round-trips through TF_SetAttrTypeList and
// TF_OperationGetAttrTypeList.
TEST_F(CApiAttributesTest, TypeList) {
  const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(type)");
  TF_SetAttrTypeList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_DataType values[list_size];
  EXPECT_TF_META("v", list_size, TF_ATTR_TYPE, -1);
  TF_OperationGetAttrTypeList(oper, "v", values, list_size, s_);
  // ASSERT, not EXPECT: if the call failed, `values` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2178 
// Shape attributes round-trip through TF_SetAttrShape and
// TF_OperationGetAttrShape, both for an unknown shape (num_dims == -1) and for
// a partially specified one.
TEST_F(CApiAttributesTest, Shape) {
  // Unknown shape
  auto desc = init("shape");
  TF_SetAttrShape(desc, "v", nullptr, -1);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, -1);
  TF_OperationGetAttrShape(oper, "v", nullptr, 10, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Partially specified shape
  const int64_t partial_shape[] = {17, -1};
  const size_t sz = TF_ARRAYSIZE(partial_shape);
  desc = init("shape");
  TF_SetAttrShape(desc, "v", partial_shape, sz);
  oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, sz);
  int64_t values[sz];
  TF_OperationGetAttrShape(oper, "v", values, sz, s_);
  // ASSERT, not EXPECT: if the call failed, `values` may be uninitialized.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(
      std::equal(std::begin(partial_shape), std::end(partial_shape), values));
}
2203 
TEST_F(CApiAttributesTest, ShapeList) {
  // Round-trips a list(shape) attribute holding two fully known shapes.
  const int64_t shape_1[] = {1, 3};
  const int64_t shape_2[] = {2, 4, 6};
  const int64_t* list[] = {&shape_1[0], &shape_2[0]};
  const size_t list_size = TF_ARRAYSIZE(list);
  // Rank of each shape and their sum (the storage needed to read them back).
  // Derived from the arrays above instead of being hard-coded, so the two
  // cannot drift apart when a shape is edited.
  constexpr int ndims[] = {TF_ARRAYSIZE(shape_1), TF_ARRAYSIZE(shape_2)};
  constexpr int total_ndims = ndims[0] + ndims[1];

  auto desc = init("list(shape)");
  TF_SetAttrShapeList(desc, "v", list, ndims, list_size);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", list_size, TF_ATTR_SHAPE, total_ndims);
  int64_t* values[list_size];
  int values_ndims[list_size];
  int64_t storage[total_ndims];
  TF_OperationGetAttrShapeList(oper, "v", values, values_ndims, list_size,
                               storage, total_ndims, s_);
  // ASSERT, not EXPECT: `values` has no initializer, so it must not be
  // dereferenced below unless the call reported success.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  for (size_t i = 0; i < list_size; ++i) {
    // ASSERT so that a rank mismatch cannot drive the inner loop past the end
    // of the expected shape array list[i].
    ASSERT_EQ(ndims[i], values_ndims[i]) << i;
    for (int j = 0; j < values_ndims[i]; ++j) {
      EXPECT_EQ(list[i][j], values[i][j]) << "(" << i << ", " << j << ")";
    }
  }
}
2231 
TEST_F(CApiAttributesTest, TensorShapeProto) {
  // Sets a shape attribute from a serialized TensorShapeProto and checks that
  // reading it back yields byte-identical serialized bytes.
  const tensorflow::int64 pts[] = {2, 4, -1, 8};
  tensorflow::TensorShapeProto proto;
  tensorflow::PartialTensorShape(pts).AsProto(&proto);
  string bytes;
  proto.SerializeToString(&bytes);

  auto desc = init("shape");
  TF_SetAttrTensorShapeProto(desc, "v", bytes.data(), bytes.length(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, 4);
  TF_Buffer* value = TF_NewBuffer();
  TF_OperationGetAttrTensorShapeProto(oper, "v", value, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(bytes.length(), value->length);
  // Compare only the bytes both sides actually hold. A length mismatch is
  // already reported above, and must not make memcmp read past `bytes`.
  EXPECT_EQ(0, memcmp(bytes.data(), value->data,
                      std::min(bytes.length(), value->length)));
  TF_DeleteBuffer(value);
}
2253 
TEST_F(CApiAttributesTest, TensorShapeProtoList) {
  // Sets a list(shape) attribute from two serialized TensorShapeProtos and
  // checks that each reads back as byte-identical serialized bytes.
  string bytes1, bytes2;
  tensorflow::TensorShapeProto proto;

  const tensorflow::int64 pts1[] = {2, 4, -1, 8};
  tensorflow::PartialTensorShape(pts1).AsProto(&proto);
  proto.SerializeToString(&bytes1);

  const tensorflow::int64 pts2[] = {1, 3, 5, 7};
  tensorflow::PartialTensorShape(pts2).AsProto(&proto);
  proto.SerializeToString(&bytes2);

  std::unique_ptr<const void*[]> list_ptrs;
  std::unique_ptr<size_t[]> list_lens;
  const std::vector<string> list = {bytes1, bytes2};
  StringVectorToArrays(list, &list_ptrs, &list_lens);

  auto desc = init("list(shape)");
  TF_SetAttrTensorShapeProtoList(desc, "v", list_ptrs.get(), list_lens.get(),
                                 list.size(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Two shapes of four dimensions each; expected total size 4 + 4 = 8.
  EXPECT_TF_META("v", 2, TF_ATTR_SHAPE, 8);
  TF_Buffer* values[2];
  TF_OperationGetAttrTensorShapeProtoList(oper, "v", values,
                                          TF_ARRAYSIZE(values), s_);
  // ASSERT, not EXPECT: `values` has no initializer, so it must not be
  // dereferenced below unless the call reported success.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  for (size_t i = 0; i < TF_ARRAYSIZE(values); ++i) {
    // Keep lengths as size_t; narrowing them to int is unnecessary.
    const size_t le = list_lens[i];
    const size_t la = values[i]->length;
    const void* e = list_ptrs[i];
    const void* a = values[i]->data;
    EXPECT_EQ(le, la) << i;
    EXPECT_EQ(0, memcmp(e, a, std::min(le, la))) << i;
    TF_DeleteBuffer(values[i]);
  }
}
2292 
TEST_F(CApiAttributesTest, Tensor) {
  // A 1x2 int8 tensor attribute must read back with the same dtype, shape,
  // byte size and contents.
  const char tensor[] = {5, 7};
  const int64_t dims[] = {1, 2};
  const size_t ndims = TF_ARRAYSIZE(dims);

  auto desc = init("tensor");
  unique_tensor_ptr attr_value(Int8Tensor(dims, ndims, tensor),
                               TF_DeleteTensor);
  TF_SetAttrTensor(desc, "v", attr_value.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
  TF_Tensor* fetched;
  TF_OperationGetAttrTensor(oper, "v", &fetched, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(nullptr, fetched);

  EXPECT_EQ(TF_INT8, TF_TensorType(fetched));
  EXPECT_EQ(ndims, TF_NumDims(fetched));
  const int rank = TF_NumDims(fetched);
  for (int d = 0; d < rank; ++d) {
    EXPECT_EQ(dims[d], TF_Dim(fetched, d)) << d;
  }

  const size_t byte_size = TF_TensorByteSize(fetched);
  EXPECT_EQ(sizeof(char) * TF_ARRAYSIZE(tensor), byte_size);
  EXPECT_EQ(0, memcmp(tensor, TF_TensorData(fetched), byte_size));
  TF_DeleteTensor(fetched);
}
2320 
TEST_F(CApiAttributesTest, StringTensor) {
  // Create the string-Tensor "attribute" value.
  const char test_string[] =
      "borkborkborkborkborkborkborkbork";  // >24bytes to force heap alloc
  TF_TString tstr[1];
  TF_TString_Init(&tstr[0]);
  TF_TString_Copy(&tstr[0], test_string, sizeof(test_string) - 1);

  // `t_in` wraps `tstr` without owning it (no-op deallocator). `tstr` itself
  // is released explicitly with TF_TString_Dealloc at the end of the test.
  auto deallocator = [](void* /*data*/, size_t /*len*/, void* /*arg*/) {};
  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &tstr[0],
                                      sizeof(tstr), deallocator, nullptr),
                         TF_DeleteTensor);

  // Create a TF_Operation with the attribute t_in
  auto desc = init("tensor");
  TF_SetAttrTensor(desc, "v", t_in.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Fetch the attribute back.
  EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
  TF_Tensor* fetched = nullptr;
  TF_OperationGetAttrTensor(oper, "v", &fetched, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  // Own the result immediately so it is released even if an ASSERT below
  // returns early, and fail cleanly rather than dereference a null tensor.
  unique_tensor_ptr t_out(fetched, TF_DeleteTensor);
  ASSERT_NE(nullptr, t_out.get());

  EXPECT_EQ(TF_STRING, TF_TensorType(t_out.get()));
  EXPECT_EQ(0, TF_NumDims(t_out.get()));
  ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out.get()));
  TF_TString* t_in_tstr = static_cast<TF_TString*>(TF_TensorData(t_in.get()));
  TF_TString* t_out_tstr =
      static_cast<TF_TString*>(TF_TensorData(t_out.get()));
  EXPECT_EQ(absl::string_view(test_string),
            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
                              TF_TString_GetSize(t_out_tstr)));
  EXPECT_EQ(absl::string_view(TF_TString_GetDataPointer(t_in_tstr),
                              TF_TString_GetSize(t_in_tstr)),
            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
                              TF_TString_GetSize(t_out_tstr)));
  // Keep the original release order: the fetched tensor first, then tstr.
  t_out.reset();
  TF_TString_Dealloc(&tstr[0]);
}
2362 
TEST_F(CApiAttributesTest, TensorList) {
  // Round-trips a list(tensor) attribute holding two int8 tensors of
  // different shapes.
  const char tensor1[] = {5, 7};
  const int64_t dims1[] = {1, 2};
  const size_t ndims1 = TF_ARRAYSIZE(dims1);

  const char tensor2[] = {2, 4, 6, 8};
  const int64_t dims2[] = {2, 2};
  const size_t ndims2 = TF_ARRAYSIZE(dims2);

  auto desc = init("list(tensor)");
  TF_Tensor* tmp[] = {
      Int8Tensor(dims1, ndims1, tensor1),
      Int8Tensor(dims2, ndims2, tensor2),
  };
  TF_SetAttrTensorList(desc, "v", tmp, TF_ARRAYSIZE(tmp), s_);
  // The inputs are released right after being set. A range-for avoids the
  // signed/unsigned comparison an `int` index against TF_ARRAYSIZE produced.
  for (TF_Tensor* t : tmp) {
    TF_DeleteTensor(t);
  }
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", 2, TF_ATTR_TENSOR, -1);
  TF_Tensor* values[2];
  TF_OperationGetAttrTensorList(oper, "v", &values[0], TF_ARRAYSIZE(values),
                                s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Expected data, byte size, dims and rank, indexed like `values`.
  const char* tensor_data[] = {&tensor1[0], &tensor2[0]};
  const size_t tensor_size[] = {TF_ARRAYSIZE(tensor1), TF_ARRAYSIZE(tensor2)};
  const int64_t* tensor_dims[] = {&dims1[0], &dims2[0]};
  const size_t tensor_ndims[] = {ndims1, ndims2};
  for (size_t i = 0; i < TF_ARRAYSIZE(values); ++i) {
    TF_Tensor* v = values[i];
    ASSERT_NE(nullptr, v) << i;
    EXPECT_EQ(TF_INT8, TF_TensorType(v)) << i;
    EXPECT_EQ(tensor_ndims[i], TF_NumDims(v)) << i;
    for (int j = 0; j < TF_NumDims(v); ++j) {
      EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j))
          << "Tensor #" << i << ", dimension #" << j;
    }
    EXPECT_EQ(sizeof(char) * tensor_size[i], TF_TensorByteSize(v)) << i;
    EXPECT_EQ(0,
              memcmp(tensor_data[i], TF_TensorData(v), TF_TensorByteSize(v)))
        << i;
    TF_DeleteTensor(v);
  }
}
2410 
TEST_F(CApiAttributesTest, EmptyList) {
  // A list(int) attribute with zero elements must be accepted, and its
  // metadata must then report a list of length 0.
  auto desc = init("list(int)");
  TF_SetAttrIntList(desc, "v", /*values=*/nullptr, /*num_values=*/0);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", 0, TF_ATTR_INT, -1);
}
2418 
TEST_F(CApiAttributesTest, Errors) {
  // Reading an int attribute through the string getter must be rejected
  // with TF_INVALID_ARGUMENT.
  auto desc = init("int");
  TF_SetAttrInt(desc, "v", 3);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Type mismatch: "v" holds an int but is queried as a string.
  TF_OperationGetAttrString(oper, "v", /*value=*/nullptr, /*max_length=*/0,
                            s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
}
2427 
TEST(TestApiDef, TestCreateApiDef) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

  // RAII owners, so nothing leaks when an ASSERT below returns early. A
  // single status is reused because each C API call overwrites it.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> op_list_buf(
      TF_GetAllOpList(), TF_DeleteBuffer);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  std::unique_ptr<TF_ApiDefMap, decltype(&TF_DeleteApiDefMap)> api_def_map(
      TF_NewApiDefMap(op_list_buf.get(), status.get()), TF_DeleteApiDefMap);
  // ASSERT, not EXPECT: the map is used unconditionally below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, api_def_map.get());

  const string op_name = "TestCApi";
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> api_def_buf(
      TF_ApiDefMapGet(api_def_map.get(), op_name.c_str(), op_name.size(),
                      status.get()),
      TF_DeleteBuffer);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  // Also guard the pointer itself, since it is dereferenced below.
  ASSERT_NE(nullptr, api_def_buf.get());

  tensorflow::ApiDef api_def;
  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
  EXPECT_EQ(op_name, api_def.graph_op_name());
  EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
}
2455 
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

  // RAII owners, so nothing leaks when an ASSERT below returns early. A
  // single status is reused because each C API call overwrites it.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> op_list_buf(
      TF_GetAllOpList(), TF_DeleteBuffer);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  std::unique_ptr<TF_ApiDefMap, decltype(&TF_DeleteApiDefMap)> api_def_map(
      TF_NewApiDefMap(op_list_buf.get(), status.get()), TF_DeleteApiDefMap);
  // ASSERT, not EXPECT: the map is used unconditionally below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, api_def_map.get());

  // Text-format ApiDefs replacing the summary of the TestCApi op.
  const string api_def_overwrites = R"(op: <
  graph_op_name: "TestCApi"
  summary: "New summary"
>
)";
  TF_ApiDefMapPut(api_def_map.get(), api_def_overwrites.c_str(),
                  api_def_overwrites.size(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  const string op_name = "TestCApi";
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> api_def_buf(
      TF_ApiDefMapGet(api_def_map.get(), op_name.c_str(), op_name.size(),
                      status.get()),
      TF_DeleteBuffer);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  // Also guard the pointer itself, since it is dereferenced below.
  ASSERT_NE(nullptr, api_def_buf.get());

  tensorflow::ApiDef api_def;
  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
  EXPECT_EQ(op_name, api_def.graph_op_name());
  EXPECT_EQ("New summary", api_def.summary());
}
2494 
2495 class DummyKernel : public tensorflow::OpKernel {
2496  public:
DummyKernel(tensorflow::OpKernelConstruction * context)2497   explicit DummyKernel(tensorflow::OpKernelConstruction* context)
2498       : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)2499   void Compute(tensorflow::OpKernelContext* context) override {}
2500 };
2501 
2502 // Test we can query kernels
2503 REGISTER_OP("TestOpWithSingleKernel")
2504     .Input("a: float")
2505     .Input("b: float")
2506     .Output("o: float");
2507 REGISTER_KERNEL_BUILDER(
2508     Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel);
2509 
TEST(TestKernel, TestGetAllRegisteredKernels) {
  // Listing all registered kernels must succeed and return a non-empty list.
  // RAII owners keep the status and buffer from leaking on early ASSERT exit.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetAllRegisteredKernels(status.get()), TF_DeleteBuffer);
  // ASSERT, not EXPECT: the buffer is dereferenced below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  EXPECT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_GT(kernel_list.kernel_size(), 0);
}
2520 
TEST(TestKernel, TestGetRegisteredKernelsForOp) {
  // TestOpWithSingleKernel has exactly one registered kernel, on CPU.
  // RAII owners keep the status and buffer from leaking on early ASSERT exit.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status.get()),
      TF_DeleteBuffer);
  // ASSERT, not EXPECT: the buffer is dereferenced below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  EXPECT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_EQ(kernel_list.kernel_size(), 1);
  EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel");
  EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
}
2534 
TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) {
  // Querying an op name with no registered kernels must succeed and yield an
  // empty KernelList.
  // RAII owners keep the status and buffer from leaking on early ASSERT exit.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetRegisteredKernelsForOp("Unknown", status.get()), TF_DeleteBuffer);
  // ASSERT, not EXPECT: the buffer is dereferenced below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  EXPECT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_EQ(kernel_list.kernel_size(), 0);
}
2545 
2546 #undef EXPECT_TF_META
2547 
TEST(CAPI, TestTensorAligned) {
  // A tensor from TF_AllocateTensor should report itself aligned whenever
  // Eigen is built with a non-zero EIGEN_MAX_ALIGN_BYTES.
  const int64_t dim = 7;
  const size_t tensor_size_bytes = dim * TF_DataTypeSize(TF_FLOAT);
  TF_Tensor* a = TF_AllocateTensor(/*dtype=*/TF_FLOAT, /*dims=*/&dim,
                                   /*num_dims=*/1,
                                   /*len=*/tensor_size_bytes);

  float* data = static_cast<float*>(TF_TensorData(a));
  std::fill_n(data, dim, 0.0f);

  if (EIGEN_MAX_ALIGN_BYTES > 0) {
    EXPECT_TRUE(TF_TensorIsAligned(a));
  }
  TF_DeleteTensor(a);
}
2563 
TEST(CAPI, TestTensorIsNotAligned) {
  // Test unaligned access via a Slice.
  Tensor x(DT_FLOAT, TensorShape({30}));
  x.flat<float>().setConstant(0.0);

  // Take an unaligned slice: it starts at element 1 of x's buffer.
  Tensor y = x.Slice(1, 13);
  Status status;
  TF_Tensor* a = TF_TensorFromTensor(y, &status);
  // The conversion status was previously ignored. Stop here on failure
  // instead of handing a possibly-null tensor to TF_TensorIsAligned.
  TF_ASSERT_OK(status);
  ASSERT_NE(nullptr, a);
  if (EIGEN_MAX_ALIGN_BYTES > 0) {
    EXPECT_FALSE(TF_TensorIsAligned(a));
  }
  TF_DeleteTensor(a);
}
2578 
2579 }  // namespace
2580 }  // namespace tensorflow
2581 
2582 // TODO(josh11b): Test:
2583 // * TF_SetDevice(desc, "/job:worker");
2584 // * control inputs / outputs
2585 // * targets
2586 // * TF_DeleteGraph() before TF_DeleteSession()
2587