1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
18 #include "tensorflow/c/c_api_internal.h"
19 #include "tensorflow/c/c_test_util.h"
20 #include "tensorflow/core/framework/function.pb.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/hash/hash.h"
24 #include "tensorflow/core/lib/strings/proto_serialization.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/test.h"
29
30 namespace tensorflow {
31 namespace {
32
33 // Specification for expected input/output and its type.
34 // DataType value of DT_INVALID signifies that we don't want to
35 // check the data type.
36 typedef std::pair<string, DataType> IOSpec;
37
M(const std::initializer_list<string> & names)38 std::vector<IOSpec> M(const std::initializer_list<string>& names) {
39 std::vector<IOSpec> v;
40 for (const string& name : names) {
41 v.push_back(IOSpec(name, DT_INVALID));
42 }
43 return v;
44 }
45
46 // Specification for an expected edge.
47 // src is either:
48 // - input name (as it appears in FunctionDef)
49 // - name of output tensor (in nested "add:z:0" format)
50 // dst is either:
51 // - output name (as it appears in FunctionDef)
52 // - <name_of_node>:<index_of_this_input_into_node> (this looks the same as
53 // output tensor naming, but it the index is actually an input index)
54 struct EdgeSpec : public std::pair<string, string> {
55 typedef std::pair<string, string> Base;
56
57 // Inherit the set of constructors
58 using Base::pair;
59
ToStringtensorflow::__anon1a216e640111::EdgeSpec60 string ToString() const { return strings::StrCat(first, "->", second); }
61 };
62
63 class CApiFunctionTest : public ::testing::Test {
64 protected:
CApiFunctionTest()65 CApiFunctionTest()
66 : s_(TF_NewStatus()),
67 func_graph_(TF_NewGraph()),
68 host_graph_(TF_NewGraph()),
69 func_(nullptr) {}
70
SetUp()71 void SetUp() override {}
72
~CApiFunctionTest()73 ~CApiFunctionTest() override {
74 TF_DeleteFunction(func_);
75 TF_DeleteGraph(host_graph_);
76 TF_DeleteGraph(func_graph_);
77 TF_DeleteStatus(s_);
78 }
79
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,TF_Operation * output,int32_t expected_result)80 void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
81 TF_Operation* output, int32_t expected_result) {
82 Run(inputs, {{output, 0}}, {expected_result});
83 }
84
85 // Run the host graph, which now contains a function and check that
86 // outputs are as expected.
87 // 'T' stands for 'tensor' since the outputs are tensors, not scalars.
RunT(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<std::vector<int32_t>> & expected_results)88 void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
89 std::initializer_list<TF_Output> outputs,
90 const std::vector<std::vector<int32_t>>& expected_results) {
91 // Create a session for this graph
92 CSession csession(host_graph_, s_);
93 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
94
95 // Run
96 csession.SetInputs(inputs);
97 csession.SetOutputs(outputs);
98 csession.Run(s_);
99 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
100
101 // Check results
102 for (int i = 0; i < expected_results.size(); ++i) {
103 TF_Tensor* out = csession.output_tensor(i);
104 ASSERT_TRUE(out != nullptr);
105 EXPECT_EQ(TF_INT32, TF_TensorType(out));
106 EXPECT_EQ(1, TF_NumDims(out));
107 CompareInt32Tensor(expected_results[i], out);
108 }
109 }
110
111 // Run the host graph, which now contains a function and check that
112 // outputs are as expected.
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<int32_t> & expected_results)113 void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
114 std::initializer_list<TF_Output> outputs,
115 const std::vector<int32_t>& expected_results) {
116 // Create a session for this graph.
117 CSession csession(host_graph_, s_);
118 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
119
120 csession.SetInputs(inputs);
121 csession.SetOutputs(outputs);
122 csession.Run(s_);
123 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
124
125 for (int i = 0; i < expected_results.size(); ++i) {
126 TF_Tensor* out = csession.output_tensor(i);
127 ASSERT_TRUE(out != nullptr);
128 EXPECT_EQ(TF_INT32, TF_TensorType(out));
129 EXPECT_EQ(0, TF_NumDims(out)); // scalar
130 ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
131 int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out));
132 EXPECT_EQ(expected_results[i], *output_contents);
133 }
134 }
135
CompareInt32Tensor(const std::vector<int32_t> & expected,TF_Tensor * t)136 void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) {
137 int32_t* data = static_cast<int32_t*>(TF_TensorData(t));
138 size_t size = TF_TensorByteSize(t);
139 ASSERT_EQ(expected.size() * sizeof(int32_t), size);
140 for (int i = 0; i < expected.size(); ++i) {
141 ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i;
142 }
143 }
144
ToOutput(const std::vector<TF_Operation * > ops)145 std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) {
146 std::vector<TF_Output> out;
147 for (auto op : ops) {
148 out.push_back({op, 0});
149 }
150 return out;
151 }
152
Define(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Operation * > & inputs,const std::vector<TF_Operation * > & outputs,const std::vector<string> & output_names,bool expect_failure=false)153 void Define(int num_opers, const std::vector<TF_Operation*>& opers,
154 const std::vector<TF_Operation*>& inputs,
155 const std::vector<TF_Operation*>& outputs,
156 const std::vector<string>& output_names,
157 bool expect_failure = false) {
158 DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names,
159 expect_failure);
160 }
161
162 // Caller must delete[] the returned value
ToArray(const std::vector<string> & strs)163 static const char** ToArray(const std::vector<string>& strs) {
164 const char** ptr = nullptr;
165 if (!strs.empty()) {
166 ptr = new const char*[strs.size()];
167 for (size_t i = 0; i < strs.size(); ++i) {
168 ptr[i] = strs[i].c_str();
169 }
170 }
171 return ptr;
172 }
173
174 // An explicit `num_opers` is needed so that we can distinguish between the
175 // case of no operations specified (-1) and the case of an empty set of
176 // operations specified (0).
DefineT(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<string> & output_names,bool expect_failure=false)177 void DefineT(int num_opers, const std::vector<TF_Operation*>& opers,
178 const std::vector<TF_Output>& inputs,
179 const std::vector<TF_Output>& outputs,
180 const std::vector<string>& output_names,
181 bool expect_failure = false) {
182 ASSERT_EQ(func_, nullptr);
183 const char** output_names_ptr = ToArray(output_names);
184 func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers,
185 num_opers == -1 ? nullptr : opers.data(),
186 inputs.size(), inputs.data(), outputs.size(),
187 outputs.data(), output_names_ptr,
188 /*opts=*/nullptr, /*description=*/nullptr, s_);
189 delete[] output_names_ptr;
190 if (expect_failure) {
191 ASSERT_EQ(func_, nullptr);
192 return;
193 }
194
195 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
196 ASSERT_NE(func_, nullptr);
197 ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_)));
198 TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
199 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
200 }
201
Use(const std::vector<TF_Operation * > & inputs)202 TF_Operation* Use(const std::vector<TF_Operation*>& inputs) {
203 return UseT(ToOutput(inputs));
204 }
205
UseT(const std::vector<TF_Output> & inputs)206 TF_Operation* UseT(const std::vector<TF_Output>& inputs) {
207 TF_Operation* op;
208 UseHelper(inputs, &op);
209 return op;
210 }
211
212 // All the *Helper methods are used as a workaround for the restrictions that
213 // one cannot call ASSERT_* methods in non-void-returning functions (when
214 // exceptions are disabled during compilation)
UseHelper(const std::vector<TF_Output> & inputs,TF_Operation ** op)215 void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) {
216 TF_OperationDescription* desc =
217 TF_NewOperation(host_graph_, func_name_, func_node_name_);
218 for (auto input : inputs) {
219 TF_AddInput(desc, input);
220 }
221 // Set device to CPU because some ops inside the function might not be
222 // available on GPU.
223 TF_SetDevice(desc, "/cpu:0");
224 *op = TF_FinishOperation(desc, s_);
225 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
226 ASSERT_NE(*op, nullptr);
227 }
228
fdef()229 FunctionDef fdef() {
230 tensorflow::FunctionDef fdef;
231 EXPECT_TRUE(GetFunctionDef(func_, &fdef));
232 return fdef;
233 }
234
235 // logging utility
236 template <class Container>
ToString(const Container & v)237 string ToString(const Container& v) {
238 std::stringstream ss;
239 ss << "{";
240 size_t i = 0;
241 for (const auto& e : v) {
242 if (i != 0) {
243 ss << ", ";
244 }
245 ss << e.ToString();
246 ++i;
247 }
248 ss << "}";
249 return ss.str();
250 }
251
VerifyFDefNodes(const tensorflow::FunctionDef & fdef,const std::unordered_set<string> & nodes)252 void VerifyFDefNodes(const tensorflow::FunctionDef& fdef,
253 const std::unordered_set<string>& nodes) {
254 ASSERT_EQ(nodes.size(), fdef.node_def_size())
255 << "Got unexpected number of nodes. Expected: ["
256 << str_util::Join(nodes, ", ")
257 << "] Actual nodes in fdef: " << fdef.DebugString();
258 for (const NodeDef& node_def : fdef.node_def()) {
259 ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
260 << "Got unexpected node: " << node_def.name()
261 << " in fdef: " << fdef.DebugString();
262 }
263 }
264
VerifyFDefInputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & inputs)265 void VerifyFDefInputs(const tensorflow::FunctionDef& fdef,
266 const std::vector<IOSpec>& inputs) {
267 const OpDef& signature = fdef.signature();
268 ASSERT_EQ(inputs.size(), signature.input_arg_size());
269 for (int i = 0; i < inputs.size(); ++i) {
270 const OpDef::ArgDef& arg = signature.input_arg(i);
271 const IOSpec& in = inputs[i];
272 if (in.second != DT_INVALID) {
273 ASSERT_EQ(arg.type(), in.second)
274 << "Got unexpected type for input " << i
275 << ". fdef: " << fdef.DebugString();
276 }
277 ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i
278 << ". fdef: " << fdef.DebugString();
279 }
280 }
281
VerifyFDefOutputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & outputs)282 void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef,
283 const std::vector<IOSpec>& outputs) {
284 const OpDef& signature = fdef.signature();
285 ASSERT_EQ(outputs.size(), signature.output_arg_size());
286 for (int i = 0; i < outputs.size(); ++i) {
287 const OpDef::ArgDef& arg = signature.output_arg(i);
288 const IOSpec& out = outputs[i];
289 if (out.second != DT_INVALID) {
290 ASSERT_EQ(arg.type(), out.second)
291 << "Got unexpected type for output " << i
292 << ". fdef: " << fdef.DebugString();
293 }
294 ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i
295 << ". fdef: " << fdef.DebugString();
296 }
297 }
298
VerifyFDefEdges(const tensorflow::FunctionDef & fdef,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)299 void VerifyFDefEdges(
300 const tensorflow::FunctionDef& fdef,
301 const std::vector<EdgeSpec>& e_edges, // expected edges
302 const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
303 bool is_exact_edges = true) {
304 // Build a set of edges from fdef
305 std::set<EdgeSpec> a_edges; // actual edges
306 // Get edges from inputs to body nodes and between body nodes
307 for (const NodeDef& node_def : fdef.node_def()) {
308 for (int i = 0; i < node_def.input_size(); ++i) {
309 const string& in = node_def.input(i);
310 const auto& v =
311 a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)});
312 ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> "
313 << strings::StrCat(node_def.name(), ":", i)
314 << ". fdef: " << fdef.DebugString();
315 }
316 }
317 // Get edges from body nodes to outputs and from inputs to outputs
318 for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) {
319 const auto& iter = fdef.ret().find(arg.name());
320 if (iter != fdef.ret().end()) {
321 const auto& v = a_edges.insert({iter->second, arg.name()});
322 ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> "
323 << arg.name() << ". fdef: " << fdef.DebugString();
324 } else {
325 const auto& v = a_edges.insert({arg.name(), arg.name()});
326 ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> "
327 << arg.name() << ". fdef: " << fdef.DebugString();
328 }
329 }
330
331 // Verify edges
332 for (const EdgeSpec& e : e_edges) {
333 ASSERT_TRUE(a_edges.find(e) != a_edges.end())
334 << "Failed to find expected edge " << e.ToString()
335 << " in fdef: " << fdef.DebugString();
336 }
337 for (const EdgeSpec& e : c_edges) {
338 ASSERT_TRUE(a_edges.find(e) != a_edges.end())
339 << "Failed to find expected control edge " << e.ToString()
340 << " in fdef: " << fdef.DebugString();
341 }
342
343 // If caller specified all edges, check that we have seen all
344 if (is_exact_edges) {
345 ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size())
346 << "Expected edges: " << ToString(e_edges)
347 << " Expected Control edges: " << ToString(c_edges)
348 << " Actual edges: " << ToString(a_edges)
349 << " in fdef: " << fdef.DebugString();
350 }
351 }
352
VerifyFDef(const std::unordered_set<string> & nodes,const std::vector<IOSpec> & inputs,const std::vector<IOSpec> & outputs,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)353 void VerifyFDef(const std::unordered_set<string>& nodes,
354 const std::vector<IOSpec>& inputs,
355 const std::vector<IOSpec>& outputs,
356 const std::vector<EdgeSpec>& e_edges, // expected edges
357 const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
358 bool is_exact_edges = true) {
359 tensorflow::FunctionDef fdef;
360 ASSERT_TRUE(GetFunctionDef(func_, &fdef));
361 VerifyFDefNodes(fdef, nodes);
362 VerifyFDefInputs(fdef, inputs);
363 VerifyFDefOutputs(fdef, outputs);
364 VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
365 }
366
367 // Serialize func_ to fdef and import it back
Reincarnate()368 void Reincarnate() {
369 // func_ -> fdef
370 tensorflow::FunctionDef fdef;
371 ASSERT_TRUE(GetFunctionDef(func_, &fdef));
372 TF_DeleteFunction(func_);
373
374 // fdef -> func_
375 string buf;
376 ASSERT_TRUE(fdef.SerializeToString(&buf));
377 func_ = TF_FunctionImportFunctionDef(buf.data(), buf.size(), s_);
378 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
379 }
380
GetAttr(const char * attr_name,AttrValue * out_attr)381 void GetAttr(const char* attr_name, AttrValue* out_attr) {
382 TF_Buffer* attr_buf = TF_NewBuffer();
383 TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
384 ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
385 TF_DeleteBuffer(attr_buf);
386 }
387
388 const char* func_name_ = "MyFunc";
389 const char* func_node_name_ = "MyFunc_0";
390 TF_Status* s_;
391 TF_Graph* func_graph_;
392 TF_Graph* host_graph_;
393 TF_Function* func_;
394
395 // Workaround for not being able to initialize empty map using {}
396 std::unordered_set<string> empty_;
397 };
398
TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) {
  /*
   * constant
   * |
   * v
   */
  // Function body: a single constant returned as the only output.
  TF_Operation* ten = ScalarConst(10, func_graph_, s_, "scalar10");
  Define(-1, {}, {}, {ten}, {});

  // Call the function with no arguments; it must produce 10.
  TF_Operation* call = Use({});
  Run({}, call, 10);

  // Expected: body node "scalar10_0" feeding the output named "scalar10".
  const std::vector<EdgeSpec> data_edges = {
      {"scalar10_0:output:0", "scalar10"}};
  VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}}, data_edges, {});
}
415
TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) {
  /*
   * |
   * v
   * negate
   * |
   * v
   */
  // Function body: input -> Neg -> output.
  TF_Operation* input = Placeholder(func_graph_, s_);
  TF_Operation* negated = Neg(input, func_graph_, s_);
  Define(-1, {}, {input}, {negated}, {});

  // Feed 3 through the function; expect -3.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  Run({{host_input, Int32Tensor(3)}}, call, -3);

  const std::vector<EdgeSpec> data_edges = {{"feed", "neg_0:0"},
                                            {"neg_0:y:0", "neg"}};
  VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}}, data_edges,
             {});
}
436
TEST_F(CApiFunctionTest, OneOutput_OutputNames) {
  /*
   * |
   * v
   * negate
   * |
   * v
   */
  // Function body: input -> Neg, with the output explicitly named.
  TF_Operation* input = Placeholder(func_graph_, s_);
  TF_Operation* negated = Neg(input, func_graph_, s_);
  Define(-1, {}, {input}, {negated}, {"negated_num"});

  // Feed 3 through the function; expect -3.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  Run({{host_input, Int32Tensor(3)}}, call, -3);

  const std::vector<EdgeSpec> data_edges = {{"feed", "neg:0"},
                                            {"neg:y:0", "negated_num"}};
  VerifyFDef({"neg"}, {{"feed", DT_INT32}}, {{"negated_num", DT_INT32}},
             data_edges, {});
}
457
TEST_F(CApiFunctionTest, OutputNames_SameNameAsInput) {
  /*
   * |
   * v
   * negate
   * |
   * v
   */
  // The placeholder shares its name with the requested output name.
  TF_Operation* input = Placeholder(func_graph_, s_, "negation");
  TF_Operation* negated = Neg(input, func_graph_, s_, "neg");
  Define(-1, {}, {input}, {negated}, {"negation"});

  // Feed 3 through the function; expect -3.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  Run({{host_input, Int32Tensor(3)}}, call, -3);

  // Expected: output keeps "negation"; the input arg becomes "negation_0".
  const std::vector<EdgeSpec> data_edges = {{"negation_0", "neg:0"},
                                            {"neg:y:0", "negation"}};
  VerifyFDef({"neg"}, {{"negation_0", DT_INT32}}, {{"negation", DT_INT32}},
             data_edges, {});
}
478
TEST_F(CApiFunctionTest, ZeroOps_Identity) {
  /*
   * |
   * |
   * |
   * v
   */
  // No body ops: the single input is returned unchanged.
  TF_Operation* input = Placeholder(func_graph_, s_);
  Define(-1, {}, {input}, {input}, {});

  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  Run({{host_input, Int32Tensor(3)}}, call, 3);

  const std::vector<EdgeSpec> data_edges = {{"feed_0", "feed"}};
  VerifyFDef(empty_, {{"feed_0", DT_INT32}}, {{"feed", DT_INT32}}, data_edges,
             {});
}
497
TEST_F(CApiFunctionTest, ZeroOps_Permutation) {
  /*
   * | |
   * \ /
   * \/
   * x
   * /\
   * / \
   * | |
   * v v
   */
  // No body ops: the two inputs are returned in swapped order.
  TF_Operation* first_in = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* second_in = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {first_in, second_in}, {second_in, first_in}, {});

  // Call with (2, 3); expect (3, 2).
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  const std::vector<EdgeSpec> data_edges = {{"feed1_0", "feed1"},
                                            {"feed2_0", "feed2"}};
  VerifyFDef(empty_, M({{"feed1_0"}, {"feed2_0"}}), M({{"feed2"}, {"feed1"}}),
             data_edges, {});
}
522
TEST_F(CApiFunctionTest, ZeroOps_Permutation_OutputNames) {
  /*
   * | |
   * \ /
   * \/
   * x
   * /\
   * / \
   * | |
   * v v
   */
  // No body ops: swapped inputs, returned under explicit output names.
  TF_Operation* first_in = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* second_in = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {first_in, second_in}, {second_in, first_in},
         {"first", "second"});

  // Call with (2, 3); expect (3, 2).
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  const std::vector<EdgeSpec> data_edges = {{"feed1", "second"},
                                            {"feed2", "first"}};
  VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"first"}, {"second"}}),
             data_edges, {});
}
547
TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) {
  /*
   * | |
   * v v
   * add
   * |
   * v
   */
  // Function body: sum of two inputs.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // Call with (2, 3); expect 5.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_input});
  Run({{host_input, Int32Tensor(3)}}, call, 2 + 3);

  const std::vector<EdgeSpec> data_edges = {
      {"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}};
  VerifyFDef({"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), data_edges,
             {});
}
571
TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) {
  /*
   * | |
   * v v
   * add
   *
   * (output ignored)
   */
  // Function body: a sum that is computed but not exposed as an output.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {}, {});

  // Only instantiate the call; there is nothing to fetch.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  Use({two, host_input});

  const std::vector<EdgeSpec> data_edges = {{"feed1", "add:0"},
                                            {"feed2", "add:1"}};
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {}, data_edges, {});
}
593
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) {
  /*
   * | | |
   * v v /
   * add1 /
   * | |
   * v v
   * add2
   * |
   * v
   */
  // Function body: (in1 + in2) + in3.
  TF_Operation* in1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial_sum = Add(in1, in2, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial_sum, in3, func_graph_, s_, "add2");
  Define(-1, {}, {in1, in2, in3}, {total}, {});

  // Call with (2, 10, 3); expect 15.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, ten, host_input});
  Run({{host_input, Int32Tensor(3)}}, call, 2 + 10 + 3);

  const std::vector<EdgeSpec> data_edges = {{"feed1", "add1:0"},
                                            {"feed2", "add1:1"},
                                            {"add1:sum:0", "add2_0:0"},
                                            {"feed3", "add2_0:1"},
                                            {"add2_0:sum:0", "add2"}};
  VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add2"}}), data_edges, {});
}
628
TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) {
  /*
   * | |
   * v v
   * add
   * |
   * +-+-+
   * | |
   * v v
   */
  // Function body: one sum, returned as two separate outputs.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {});

  // Call with (2, 3); both outputs must be 5.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  const std::vector<EdgeSpec> data_edges = {{"feed1", "add_1:0"},
                                            {"feed2", "add_1:1"},
                                            {"add_1:sum:0", "add"},
                                            {"add_1:sum:0", "add_0"}};
  VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}),
             data_edges, {});
}
657
TEST_F(CApiFunctionTest, TwoDuplicateOutputs_OutputNames) {
  /*
   * | |
   * v v
   * add
   * |
   * +-+-+
   * | |
   * v v
   */
  // Function body: one sum, returned twice under explicit output names.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {"out1", "out2"});

  // Call with (2, 3); both outputs must be 5.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  const std::vector<EdgeSpec> data_edges = {{"feed1", "add:0"},
                                            {"feed2", "add:1"},
                                            {"add:sum:0", "out1"},
                                            {"add:sum:0", "out2"}};
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), M({{"out1"}, {"out2"}}),
             data_edges, {});
}
686
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) {
  /*
   * | | |
   * v v /
   * add /
   * | |
   * +-+ |
   * | | |
   * | v v
   * | add
   * | |
   * v v
   */
  // Function body: returns both the partial sum (in1 + in2) and the total.
  TF_Operation* in1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial_sum = Add(in1, in2, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial_sum, in3, func_graph_, s_, "add2");
  Define(-1, {}, {in1, in2, in3}, {partial_sum, total}, {});

  // Call with (2, 10, 3); expect (12, 15).
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, ten, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});

  const std::vector<EdgeSpec> data_edges = {{"feed1", "add1_0:0"},
                                            {"feed2", "add1_0:1"},
                                            {"add1_0:sum:0", "add2_0:0"},
                                            {"feed3", "add2_0:1"},
                                            {"add1_0:sum:0", "add1"},
                                            {"add2_0:sum:0", "add2"}};
  VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add1"}, {"add2"}}), data_edges, {});
}
724
TEST_F(CApiFunctionTest, FromSubsetOfOps) {
  /*
   * | | |
   * v v /
   * add /
   * | |
   * +---+--+---+
   * Ops used | | | |
   * for func | v v |
   * | | add |
   * +-------> | | |
   * | v |
   * | |
   * +----------+
   */
  // Build (in1 + in2) + in3, but make only the second Add the function body;
  // the first Add's result and in3 become the function's inputs.
  TF_Operation* in1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial_sum = Add(in1, in2, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial_sum, in3, func_graph_, s_, "add2");
  Define(1, {total}, {partial_sum, in3}, {total}, {});

  // Call with (2, 3); expect 5.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_input});
  Run({{host_input, Int32Tensor(3)}}, call, 2 + 3);

  const std::vector<EdgeSpec> data_edges = {
      {"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}};
  VerifyFDef({"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}), data_edges,
             {});
}
758
TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) {
  /*
   * feed
   * |
   * +---------+---+
   * | const0 | |
   * | | | |
   * | v / |
   * | split |
   * | | | | |
   * | v | v |
   * | | |
   * +------+------+
   * |
   * v
   *
   * Only the second output from split is used as function output
   */
  TF_Operation* input = Placeholder(func_graph_, s_);
  TF_Operation* split = Split3(input, func_graph_, s_);
  // Export only the middle piece of the three-way split.
  DefineT(-1, {}, {{input, 0}}, {{split, 1}}, {});

  // Splitting {1..6} into three yields {3, 4} as the middle piece.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  RunT({{host_input, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{call, 0}},
       {{3, 4}});

  const std::vector<EdgeSpec> data_edges = {
      {"split3_const0:output:0", "split3_0:0"},
      {"feed", "split3_0:1"},
      {"split3_0:output:1", "split3"}};
  VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}),
             data_edges, {});
}
793
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) {
  /*
   * feed
   * |
   * +---------+---+
   * | const0 | |
   * | | | |
   * | v / |
   * | split |
   * | | | | |
   * | | v | |
   * | | | |
   * +---+-----+---+
   * | |
   * v v
   *
   * Second output from split is not used as function output
   */
  TF_Operation* input = Placeholder(func_graph_, s_);
  TF_Operation* split = Split3(input, func_graph_, s_);
  // Export the first and last pieces of the three-way split.
  DefineT(-1, {}, {{input, 0}}, {{split, 0}, {split, 2}}, {});

  // Splitting {1..6} into three yields {1, 2} and {5, 6} at the ends.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  RunT({{host_input, Int32Tensor({1, 2, 3, 4, 5, 6})}},
       {{call, 0}, {call, 1}}, {{1, 2}, {5, 6}});

  const std::vector<EdgeSpec> data_edges = {
      {"split3_const0:output:0", "split3_1:0"},
      {"feed", "split3_1:1"},
      {"split3_1:output:0", "split3"},
      {"split3_1:output:2", "split3_0"}};
  VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}),
             M({{"split3"}, {"split3_0"}}), data_edges, {});
}
830
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) {
  /*
   * |
   * v
   * split
   * | | |
   * | v |
   * | |
   * +---+-----+---+
   * | | | |
   * | v v |
   * | add |
   * | | |
   * | | |
   * +------+------+
   * |
   * v
   */
  TF_Operation* input = Placeholder(func_graph_, s_);
  TF_Operation* split = Split3(input, func_graph_, s_);
  TF_Operation* sum = Add({split, 0}, {split, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  // The body is just the Add; the two split outputs it consumes become the
  // function's inputs.
  DefineT(1, {sum}, {{split, 0}, {split, 2}}, {{sum, 0}}, {});

  // Call with (2, 3); expect 5.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_input});
  Run({{host_input, Int32Tensor(3)}}, call, 2 + 3);

  const std::vector<EdgeSpec> data_edges = {
      {"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}};
  VerifyFDef({"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}),
             data_edges, {});
}
866
TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) {
  /*
   * |
   * v
   * split
   * | | |
   * | v |
   * | |
   * input --->| |<--- input
   * | |
   * v v
   * add
   * |
   * |
   * v
   */
  // Define
  // The test owns the constant's tensor. Holding it in a unique_ptr releases
  // it even when one of the ASSERTs below returns early.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor_123(
      Int32Tensor({1, 2, 3}), TF_DeleteTensor);
  TF_Operation* c =
      Const(tensor_123.get(), func_graph_, s_, "const_array");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* split = Split3(c, func_graph_, s_);
  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // With num_opers == -1, inputs must come from single-output nodes, so
  // using outputs of the 3-way split as inputs must be rejected.
  DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in "
                   "`inputs` must have a single output. Node split3 has "
                   "3 outputs. Encountered while creating function 'MyFunc'"),
            string(TF_Message(s_)));
}
899
// Wraps a while loop in a function. For inputs (x, y) the loop runs
// `while (x < y) x += y + 1`, and the function returns the final x.
TEST_F(CApiFunctionTest, FunctionWithWhileLoop) {
  // Inputs to the while loop and the function as a whole
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Outputs of the while loop corresponding to the two inputs above.
  // The first one will be the function's output.
  std::vector<TF_Output> outputs;

  // Add while loop to func_graph_
  {
    // The inputs to the while loop
    std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
    std::unique_ptr<TF_WhileParams> params(new TF_WhileParams(
        TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_)));
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->name = "test_loop";

    // Initialize outputs so we can easily detect errors/bugs
    outputs.resize(2, {nullptr, -1});

    // Create loop: while (input1 < input2) input1 += input2 + 1
    TF_Operation* less_than = LessThan(
        params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->cond_output = {less_than, 0};

    TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1],
                             params->body_graph, s_, "add1");
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Operation* one = ScalarConst(1, params->body_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2");
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->body_outputs[0] = {add2, 0};
    params->body_outputs[1] = params->body_inputs[1];

    // Finalize while loop. outputs[0] is consumed by DefineT below, so a
    // failure here must stop the test instead of only being recorded.
    TF_FinishWhile(params.get(), s_, &outputs[0]);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  }

  // Define function, use it in graph, and run
  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, {});
  TF_Operation* five = ScalarConst(5, host_graph_, s_, "five");
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed, five});
  Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1);

  // Verify input, output, and subset of edges in fdef.
  // The subset of edges we verify is a chain between feed1 and output to
  // make sure that the correct output is picked.
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}}));
  VerifyFDefOutputs(fdef, M({{"test_loop_exit"}}));
  VerifyFDefEdges(fdef,
                  {{"feed1", "test_loop/Enter:0"},
                   {"test_loop/Enter:output:0", "test_loop/Merge:0"},
                   {"test_loop/Merge:output:0", "test_loop/Switch:0"},
                   {"test_loop/Switch:output_false:0", "test_loop/Exit:0"},
                   {"test_loop/Exit:output:0", "test_loop_exit"}},
                  {}, false);
}
964
// A constant inside the body feeds `add` only through a control edge; the
// generated FunctionDef must carry it as a "^scalar" control input.
TEST_F(CApiFunctionTest, ControlDependency) {
  /*
   *   |    |    scalar
   *   |    |    .
   *   v    v    . <---- control dependency
   *     add < -
   *      |
   *      v
   */
  // Function body: feed1 + feed2, gated on a constant via a control edge.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* gate = ScalarConst(5, func_graph_, s_);
  TF_Operation* sum = AddWithCtrlDependency(lhs, rhs, func_graph_, gate, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // Call the function from the host graph with (2, feed) where feed = 3.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, call, 2 + 3);

  // Data inputs occupy slots 0 and 1; the control input follows in slot 2.
  VerifyFDef({"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
             {{"feed1", "add_0:0"},
              {"feed2", "add_0:1"},
              {"add_0:sum:0", "add"}},
             {{"^scalar", "add_0:2"}});
}
993
// Listing only `add` as the body (num_opers == 1) leaves the source of its
// control edge outside the body, which TF_GraphToFunction must reject.
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
  /*
   *   |    |    scalar
   *   |    |    .
   *   v    v    . <---- control dependency
   *     add < -
   *      |
   *      v
   */
  // Build feed1 + feed2 with a control edge from a constant named "scalar".
  // Creation order matters: the expected message embeds the edge id.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* gate = ScalarConst(5, func_graph_, s_);
  TF_Operation* sum = AddWithCtrlDependency(lhs, rhs, func_graph_, gate, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The body is just `sum`, so "scalar" lies outside it.
  Define(1, {sum}, {lhs, rhs}, {sum}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] "
                   "is not in the body. Encountered while creating "
                   "function 'MyFunc'"),
            string(TF_Message(s_)));
}
1017
// A control edge whose source is one of the function's inputs (feed1) is
// accepted and shows up as a "^feed1" control input on add_0.
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
  /*
   *   |    |.
   *   |    | .
   *   |    |  .
   *   v    v  . <---- control dependency
   *     add < -
   *      |
   *      v
   */
  // Body: feed1 + feed2, with an extra control edge from feed1 itself.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = AddWithCtrlDependency(lhs, rhs, func_graph_, lhs, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // Call the function from the host graph with (2, feed) where feed = 3.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, call, 2 + 3);

  VerifyFDef({"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
             {{"feed1", "add_0:0"},
              {"feed2", "add_0:1"},
              {"add_0:sum:0", "add"}},
             {{"^feed1", "add_0:2"}});
}
1046
// Listing the same TF_Output twice in the function's inputs is an error.
TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {
  /*
   *          feed
   *            |
   *           +++
   *           | |
   *       +---+-+---+
   *       |   | |   |
   *       |   v v   |
   *       |   add   |
   *       |    |    |
   *       |    |    |
   *       +----+----+
   *            |
   *            v
   */
  // Body: feed1 + feed1, with feed1 passed twice as a function input.
  TF_Operation* feed = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* doubled = Add(feed, feed, func_graph_, s_);
  Define(-1, {}, {feed, feed}, {doubled}, {}, true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(
      string("TF_Output feed1:0 appears more than once in the input list"),
      string(TF_Message(s_)));
}
1071
// Giving two function outputs the same name is rejected.
TEST_F(CApiFunctionTest, DuplicateOutputNamesAreNotAllowed) {
  /*
   *  |  |  |
   *  v  v  /
   *  add  /
   *   |  |
   *  +-+ |
   *  | | |
   *  | v v
   *  | add
   *  |  |
   *  v  v
   */
  // Body: add2 = (feed1 + feed2) + feed3; both sums are function outputs.
  TF_Operation* in_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in_b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in_c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(in_a, in_b, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, in_c, func_graph_, s_, "add2");

  // Both outputs are deliberately named "my_out".
  Define(-1, {}, {in_a, in_b, in_c}, {partial, total}, {"my_out", "my_out"},
         true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot have duplicate output names. Name 'my_out' "
                   "appears more than once in 'output_names' array."),
            string(TF_Message(s_)));
}
1098
// An input TF_Output whose index exceeds its node's output count is reported
// as OUT_OF_RANGE, naming the offending input position.
TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
  /*
   *   |  |
   *   v  v
   *    add
   *     |
   *     v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);

  // A Placeholder has a single output, so index 2 does not exist.
  const TF_Output bad_input = {rhs, 2};
  DefineT(-1, {}, {{lhs, 0}, bad_input}, {{sum, 0}}, {}, true);
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
                   "not have output 2\n\tEncountered while processing "
                   "input 1 into function 'MyFunc'"),
            string(TF_Message(s_)));
}
1117
// An input TF_Output with a null operation is rejected as INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) {
  /*
   *   |  |
   *   v  v
   *    add
   *     |
   *     v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);

  // The second function input points at no node at all.
  const TF_Output null_input = {nullptr, 0};
  DefineT(-1, {}, {{lhs, 0}, null_input}, {{sum, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 "
                   "into function 'MyFunc'"),
            string(TF_Message(s_)));
}
1135
// An output TF_Output whose index exceeds its node's output count is reported
// as OUT_OF_RANGE, naming the offending output position.
TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
  /*
   *   |  |
   *   v  v
   *    add
   *     |
   *     v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);

  // The add node has one output, so index 3 is out of range.
  const TF_Output bad_output = {sum, 3};
  DefineT(-1, {}, {{lhs, 0}, {rhs, 0}}, {bad_output}, {}, true);
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
                   "not have output 3\n\tEncountered while processing "
                   "output 0 from function 'MyFunc'"),
            string(TF_Message(s_)));
}
1154
// An output TF_Output with a null operation is rejected as INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) {
  /*
   *   |  |
   *   v  v
   *    add
   *     |
   *     v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  // The add node is created only so the graph has a body; its handle is
  // intentionally not used as an output.
  Add(lhs, rhs, func_graph_, s_);

  const TF_Output null_output = {nullptr, 3};
  DefineT(-1, {}, {{lhs, 0}, {rhs, 0}}, {null_output}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 "
                   "from function 'MyFunc'"),
            string(TF_Message(s_)));
}
1172
// A body node whose input is neither a function input nor produced inside the
// body makes TF_GraphToFunction fail with a descriptive message.
TEST_F(CApiFunctionTest, NodeMissingInput) {
  /*
   *  input---> |   | <----missing input
   *            v   v
   *  body----> add
   *             |
   *             v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);

  // Body is only `sum`, and only feed1 is an input, so feed2 is unreachable.
  DefineT(1, {sum}, {{lhs, 0}}, {{sum, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' "
                   "is not available. You might need to include it in inputs "
                   "or include its source node in the body"),
            string(TF_Message(s_)));
}
1191
// A function output produced by a node that is neither in the body nor a
// function input is rejected.
TEST_F(CApiFunctionTest, OutputOpNotInBody) {
  /*
   *  |  |
   *  v  v
   *  add    scalar      (scalar not included in body)
   *   |       |
   *   v       v         (function has two outputs)
   */
  // The constant is created before `add`, matching the original graph order.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* stray = ScalarConst(2, func_graph_, s_);
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);

  // Body is only `sum`, yet `stray` is requested as the second output.
  Define(1, {sum}, {lhs, rhs}, {sum, stray}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor "
                   "among function inputs. Encountered while creating "
                   "function 'MyFunc'"),
            string(TF_Message(s_)));
}
1212
DefineFunction(const char * name,TF_Function ** func,const char * description=nullptr,bool append_hash=false)1213 void DefineFunction(const char* name, TF_Function** func,
1214 const char* description = nullptr,
1215 bool append_hash = false) {
1216 std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1217 TF_NewGraph(), TF_DeleteGraph);
1218 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1219 TF_DeleteStatus);
1220
1221 TF_Operation* feed = Placeholder(func_graph.get(), s.get());
1222 TF_Operation* neg = Neg(feed, func_graph.get(), s.get());
1223
1224 TF_Output inputs[] = {{feed, 0}};
1225 TF_Output outputs[] = {{neg, 0}};
1226 *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
1227 /*opers=*/nullptr, 1, inputs, 1, outputs,
1228 /*output_names=*/nullptr,
1229 /*opts=*/nullptr, description, s.get());
1230 ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1231 ASSERT_NE(*func, nullptr);
1232 }
1233
1234 REGISTER_OP("CustomOp")
1235 .Output("output: float32")
1236 .Attr("index: int")
1237 .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1238
NodeWithPlaceholderAttrHelper(TF_Graph * graph,TF_Status * s,const char * name,const char * placeholder,TF_Operation ** op)1239 void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s,
1240 const char* name, const char* placeholder,
1241 TF_Operation** op) {
1242 TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name);
1243 TF_SetAttrPlaceholder(desc, "index", placeholder);
1244 *op = TF_FinishOperation(desc, s);
1245 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1246 ASSERT_NE(*op, nullptr);
1247 }
1248
// Attr placeholders used by body nodes ("v1" twice, "v2" once) must each be
// lifted into the FunctionDef signature exactly once, typed as int.
TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  // The helper ASSERTs internally, which only returns from the helper.
  // ASSERT_NO_FATAL_FAILURE propagates such failures so this test stops
  // instead of building a function from a failed node.
  TF_Operation* node1 = nullptr;
  TF_Operation* node2 = nullptr;
  TF_Operation* node3 = nullptr;
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node1", "v1", &node1));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node2", "v1", &node2));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node3", "v2", &node3));

  // The function has no inputs. Pass nullptr rather than declaring
  // `TF_Output inputs[] = {}`: an array of unknown bound cannot be
  // initialized from an empty brace list in standard C++ (it only compiles
  // as a compiler extension).
  TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, 0, /*inputs=*/nullptr, 3, outputs,
      /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that FunctionDef has 2 attributes, "v1" and "v2".
  ASSERT_EQ(func_->fdef.signature().attr().size(), 2);
  EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1");
  EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int");
  EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2");
  EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
}
1280
// Registers MyFunc with MyGrad as its gradient in the host graph and checks
// the resulting GraphDef. It then verifies that repeated registration is a
// no-op and finally runs MyFunc.
TEST_F(CApiFunctionTest, SetGradientAndRun) {
  // The function and its gradient are both built by DefineFunction (Neg).
  DefineFunction(func_name_, &func_);
  TF_Function* grad;
  DefineFunction("MyGrad", &grad);

  // Copy the function into host_graph_ and link the gradient in one call.
  TF_GraphCopyFunction(host_graph_, func_, grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The library must now list both functions and one gradient mapping.
  GraphDef before;
  GetGraphDef(host_graph_, &before);
  const std::vector<string> names = GetFuncNames(before);
  ASSERT_EQ(2, names.size());
  ASSERT_EQ(func_name_, names[0]);
  ASSERT_EQ("MyGrad", names[1]);
  const std::vector<std::pair<string, string>> grad_pairs =
      GetGradDefs(before);
  ASSERT_EQ(1, grad_pairs.size());
  ASSERT_EQ(func_name_, grad_pairs[0].first);
  ASSERT_EQ("MyGrad", grad_pairs[0].second);

  // Copying again, with the same gradient or with none, must be a no-op.
  TF_GraphCopyFunction(host_graph_, func_, grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The host graph holds its own copy, so the gradient handle can go now.
  TF_DeleteFunction(grad);

  // The GraphDef must be unchanged by the no-op copies.
  GraphDef after;
  GetGraphDef(host_graph_, &after);
  ASSERT_EQ(before.DebugString(), after.DebugString());

  // MyFunc negates its input: 3 -> -3.
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_feed});
  Run({{host_feed, Int32Tensor(3)}}, call, -3);
}
1323
// A single gradient function may be registered for two different functions.
TEST_F(CApiFunctionTest, SameGradForTwoFunctions) {
  TF_Function* foo1;
  TF_Function* foo2;
  TF_Function* shared_grad;
  DefineFunction("FooFunc1", &foo1);
  DefineFunction("FooFunc2", &foo2);
  DefineFunction("MyGrad", &shared_grad);

  // Link MyGrad as the gradient of both FooFunc1 and FooFunc2.
  TF_GraphCopyFunction(host_graph_, foo1, shared_grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, foo2, shared_grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Expect one gradient entry per function, both naming MyGrad.
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  const std::vector<std::pair<string, string>> links = GetGradDefs(gdef);
  ASSERT_EQ(2, links.size());
  ASSERT_EQ("FooFunc1", links[0].first);
  ASSERT_EQ("MyGrad", links[0].second);
  ASSERT_EQ("FooFunc2", links[1].first);
  ASSERT_EQ("MyGrad", links[1].second);

  TF_DeleteFunction(foo1);
  TF_DeleteFunction(foo2);
  TF_DeleteFunction(shared_grad);
}
1353
// Two functions added independently can later be linked as function/gradient
// by copying the function again together with the gradient.
TEST_F(CApiFunctionTest, AddFunctionsThenMakeOneGradientOfAnother) {
  TF_Function* primary;
  TF_Function* grad;
  DefineFunction("FooFunc", &primary);
  DefineFunction("MyGrad", &grad);

  // First register each function on its own, without any gradient link.
  TF_GraphCopyFunction(host_graph_, primary, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, grad, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Both are in the library, but no gradient mapping exists yet.
  GraphDef unlinked;
  GetGraphDef(host_graph_, &unlinked);
  const std::vector<string> names = GetFuncNames(unlinked);
  ASSERT_EQ(2, names.size());
  ASSERT_EQ("FooFunc", names[0]);
  ASSERT_EQ("MyGrad", names[1]);
  ASSERT_EQ(0, GetGradDefs(unlinked).size());

  // Now declare MyGrad as FooFunc's gradient.
  TF_GraphCopyFunction(host_graph_, primary, grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // A fresh GraphDef snapshot must show the link.
  GraphDef linked;
  GetGraphDef(host_graph_, &linked);
  const std::vector<std::pair<string, string>> links = GetGradDefs(linked);
  ASSERT_EQ(1, links.size());
  ASSERT_EQ("FooFunc", links[0].first);
  ASSERT_EQ("MyGrad", links[0].second);

  TF_DeleteFunction(primary);
  TF_DeleteFunction(grad);
}
1391
// TF_GraphCopyFunction rejects a null `func` and refuses to replace an
// already-registered gradient.
TEST_F(CApiFunctionTest, GradientErrorCases) {
  DefineFunction(func_name_, &func_);
  TF_Function* first_grad;
  TF_Function* second_grad;
  DefineFunction("MyGrad1", &first_grad);
  DefineFunction("MyGrad2", &second_grad);

  // A null `func` is rejected, even when a gradient argument is supplied.
  TF_GraphCopyFunction(host_graph_, nullptr, func_, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("'func' argument to TF_GraphCopyFunction cannot be null"),
            string(TF_Message(s_)));

  // After MyGrad1 is linked to MyFunc, linking MyGrad2 instead must fail.
  TF_GraphCopyFunction(host_graph_, func_, first_grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, second_grad, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot assign gradient function 'MyGrad2' to 'MyFunc' "
                   "because it already has gradient function 'MyGrad1'"),
            string(TF_Message(s_)));

  TF_DeleteFunction(first_grad);
  TF_DeleteFunction(second_grad);
}
1418
// Round-trips a two-output function with custom output names through its
// FunctionDef (via Reincarnate) and checks that the re-imported copy behaves
// the same.
TEST_F(CApiFunctionTest, ImportFunctionDef) {
  /*
   * Using a fairly complex function with output names
   *
   *                 |  |  |
   *                 v  v  /
   *                 add  /
   *                  |  |
   *           +------+  |
   *           |      |  |
   *           |      v  v
   *           |      add
   *           |       |
   *           v       v
   *    internal_out final_out
   */
  // Body: add1 = feed1 + feed2, add2 = add1 + feed3; both are outputs.
  TF_Operation* in1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* inner = Add(in1, in2, func_graph_, s_, "add1");
  TF_Operation* outer = Add(inner, in3, func_graph_, s_, "add2");
  Define(-1, {}, {in1, in2, in3}, {inner, outer},
         {"internal_out", "final_out"});

  // Save func_ to FunctionDef and import it back.
  Reincarnate();

  // Inputs (2, 10, 3): internal_out = 2 + 10 = 12, final_out = 12 + 3 = 15.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, ten, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});

  VerifyFDef({"add1", "add2"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"internal_out"}, {"final_out"}}),
             {{"feed1", "add1:0"},
              {"feed2", "add1:1"},
              {"add1:sum:0", "add2:0"},
              {"feed3", "add2:1"},
              {"add1:sum:0", "internal_out"},
              {"add2:sum:0", "final_out"}},
             {});
}
1463
// Importing bytes that do not parse as a FunctionDef yields no function and
// an INVALID_ARGUMENT status.
TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
  // Invalid protobuf data (protos cannot start with 4 bytes of zeros).
  char garbage[] = {0x0, 0x0, 0x0, 0x0};
  // Derive the length from the array rather than repeating the literal 4.
  func_ = TF_FunctionImportFunctionDef(garbage, sizeof(garbage), s_);
  EXPECT_TRUE(func_ == nullptr);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Invalid FunctionDef given to TF_FunctionImportFunctionDef"),
            string(TF_Message(s_)));
}
1473
// Exercises the function-attr accessors: reading a missing attr fails, and a
// set attr reads back intact before and after a save/restore (Reincarnate).
TEST_F(CApiFunctionTest, Attribute) {
  DefineFunction(func_name_, &func_);

  // Get non existent attribute
  TF_Buffer* attr_buf = TF_NewBuffer();
  TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
            string(TF_Message(s_)));
  TF_DeleteBuffer(attr_buf);

  // Set attr. Serialization reports failure through its return value; check
  // it so a bad payload cannot make the comparisons below fail misleadingly.
  tensorflow::AttrValue attr;
  attr.set_s("test_attr_value");
  string bytes;
  ASSERT_TRUE(attr.SerializeToString(&bytes));
  TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
                               bytes.size(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get attr
  AttrValue read_attr;
  GetAttr("test_attr_name", &read_attr);
  ASSERT_EQ(attr.DebugString(), read_attr.DebugString());

  // Retrieve the same attr after save/restore
  Reincarnate();
  AttrValue read_attr2;
  GetAttr("test_attr_name", &read_attr2);
  ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
}
1505
// The description given at definition time is stored in the signature.
TEST_F(CApiFunctionTest, Description) {
  const char* const kDescription = "Return something";
  DefineFunction(func_name_, &func_, kDescription);

  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string(kDescription), fdef.signature().description());
}
1512
// Without hash appending, the requested name is used verbatim.
TEST_F(CApiFunctionTest, Name) {
  const char* const kRequestedName = "long_func_name";
  DefineFunction(kRequestedName, &func_, "Return something",
                 /*append_hash=*/false);

  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string(kRequestedName), fdef.signature().name());
}
1520
// With append_hash=true, a hash suffix derived from the function is appended
// to the base name. The expected suffix differs by host byte order.
TEST_F(CApiFunctionTest, AppendHash) {
  DefineFunction("func_name_base", &func_, "Return something",
                 /*append_hash=*/true);
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  // Guard with defined(): compilers that do not predefine these macros
  // (e.g. MSVC) would otherwise evaluate `0 == 0` in #if and wrongly pick the
  // big-endian expectation.
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
  ASSERT_EQ(string("func_name_base_ZpgUD4x8oqk"), fdef.signature().name());
#else
  ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
#endif
}
1532
// After copying a function into a graph, TF_GraphGetOpDef returns an OpDef
// for it: one input, one output, and not stateful (the body is just Neg).
TEST_F(CApiFunctionTest, GetOpDef) {
  DefineFunction(func_name_, &func_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve function OpDef from graph. The buffer is owned by a
  // smart pointer so it is freed even when an ASSERT below returns early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check returned OpDef; a parse failure must not pass silently.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 1);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_FALSE(op_def.is_stateful());
}
1554
DefineStatefulFunction(const char * name,TF_Function ** func)1555 void DefineStatefulFunction(const char* name, TF_Function** func) {
1556 std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1557 TF_NewGraph(), TF_DeleteGraph);
1558 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1559 TF_DeleteStatus);
1560
1561 TF_Tensor* tensor_shape = Int32Tensor({37, 1});
1562 TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape");
1563 TF_Operation* random =
1564 RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
1565
1566 TF_Output inputs[] = {};
1567 TF_Output outputs[] = {{random, 0}};
1568 *func = TF_GraphToFunction(func_graph.get(), name,
1569 /*append_hash_to_fn_name=*/false, -1,
1570 /*opers=*/nullptr, 0, inputs, 1, outputs,
1571 /*output_names=*/nullptr,
1572 /*opts=*/nullptr, "", s.get());
1573 ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1574 ASSERT_NE(*func, nullptr);
1575 TF_DeleteTensor(tensor_shape);
1576 }
1577
// A function containing RandomUniform is reported as stateful in its OpDef.
TEST_F(CApiFunctionTest, StatefulOpDef) {
  DefineStatefulFunction(func_name_, &func_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve function OpDef from graph. The buffer is owned by a
  // smart pointer so it is freed even when an ASSERT below returns early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check returned OpDef; a parse failure must not pass silently.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 0);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_TRUE(op_def.is_stateful());
}
1599
AssertEqual(TF_Function * f1,TF_Function * f2)1600 void AssertEqual(TF_Function* f1, TF_Function* f2) {
1601 string s1, s2;
1602 tensorflow::FunctionDef fdef1, fdef2;
1603 ASSERT_TRUE(GetFunctionDef(f1, &fdef1));
1604 ASSERT_TRUE(GetFunctionDef(f2, &fdef2));
1605 SerializeToStringDeterministic(fdef1, &s1);
1606 SerializeToStringDeterministic(fdef2, &s2);
1607 ASSERT_EQ(s1, s2);
1608 }
1609
GetName(TF_Function * func)1610 string GetName(TF_Function* func) {
1611 tensorflow::FunctionDef fdef;
1612 GetFunctionDef(func, &fdef);
1613 return fdef.signature().name();
1614 }
1615
// Exercises TF_GraphNumFunctions / TF_GraphGetFunctions on a graph holding
// zero, one and two functions, with caller arrays of various lengths.
TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  TF_Function* fetched[2];

  // Empty library: zero functions, and a null zero-length fetch succeeds.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 0);
  TF_GraphGetFunctions(host_graph_, nullptr, 0, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Register FooFunc0 in host_graph_.
  TF_Function* foo0;
  DefineFunction("FooFunc0", &foo0);
  TF_GraphCopyFunction(host_graph_, foo0, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // One function: a length-0 array receives nothing; lengths 1 and 2 both
  // receive exactly FooFunc0. The test deletes each returned function.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 1);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 1, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(foo0, fetched[0]);
  TF_DeleteFunction(fetched[0]);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 2, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(foo0, fetched[0]);
  TF_DeleteFunction(fetched[0]);

  // Register FooFunc1 as well.
  TF_Function* foo1;
  DefineFunction("FooFunc1", &foo1);
  TF_GraphCopyFunction(host_graph_, foo1, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Two functions: both come back from a length-2 fetch.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 2);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 2, s_), 2);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The test accepts either order of the returned pair; match by name.
  const bool in_order = GetName(fetched[0]) == GetName(foo0);
  AssertEqual(foo0, fetched[in_order ? 0 : 1]);
  AssertEqual(foo1, fetched[in_order ? 1 : 0]);

  TF_DeleteFunction(fetched[0]);
  TF_DeleteFunction(fetched[1]);

  TF_DeleteFunction(foo0);
  TF_DeleteFunction(foo1);
}
1669
1670 // This test only works when the TF build includes XLA compiler. One way to set
1671 // this up is via bazel build option "--define with_xla_support=true".
1672 //
1673 // FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
1674 // something like TENSORFLOW_CAPI_USE_XLA.
1675 #ifdef TENSORFLOW_EAGER_USE_XLA
// Runs a StatelessIf on the XLA CPU device whose then/else branches are both
// the Neg function built by DefineFunction, so feeding 17 must produce -17.
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
  TF_Function* branch;
  const std::string branch_name = "BranchFunc";
  DefineFunction(branch_name.c_str(), &branch);
  TF_GraphCopyFunction(host_graph_, branch, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* feed = Placeholder(host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Assemble IfNode = StatelessIf(cond, [feed]) with int32 in/out types.
  TF_OperationDescription* if_desc =
      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
  TF_AddInput(if_desc, {true_cond, 0});
  TF_Output branch_inputs[] = {{feed, 0}};
  TF_AddInputList(if_desc, branch_inputs, TF_ARRAYSIZE(branch_inputs));
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_SetAttrType(if_desc, "Tcond", TF_BOOL);
  TF_DataType io_type = TF_INT32;
  TF_SetAttrTypeList(if_desc, "Tin", &io_type, 1);
  TF_SetAttrTypeList(if_desc, "Tout", &io_type, 1);
  TF_SetAttrFuncName(if_desc, "then_branch", branch_name.data(),
                     branch_name.size());
  TF_SetAttrFuncName(if_desc, "else_branch", branch_name.data(),
                     branch_name.size());
  TF_SetDevice(if_desc, "/device:XLA_CPU:0");
  auto if_op = TF_FinishOperation(if_desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(if_op, nullptr);

  // Create an XLA-enabled session for the host graph.
  CSession session(host_graph_, s_, /*use_XLA*/ true);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Feed 17 and fetch the If result.
  session.SetInputs({{feed, Int32Tensor(17)}});
  session.SetOutputs({if_op});
  session.Run(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Expect a scalar int32 holding -17.
  TF_Tensor* result = session.output_tensor(0);
  ASSERT_TRUE(result != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(result));
  EXPECT_EQ(0, TF_NumDims(result));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(result));
  int32* value = static_cast<int32*>(TF_TensorData(result));
  EXPECT_EQ(-17, *value);

  // Clean up
  session.CloseAndDelete(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_DeleteFunction(branch);
}
1729 #endif // TENSORFLOW_EAGER_USE_XLA
1730
1731 } // namespace
1732 } // namespace tensorflow
1733