• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include "tensorflow/c/c_api_internal.h"
19 #include "tensorflow/c/c_test_util.h"
20 #include "tensorflow/core/framework/function.pb.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/hash/hash.h"
24 #include "tensorflow/core/lib/strings/proto_serialization.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 // Specification for expected input/output and its type.
34 // DataType value of DT_INVALID signifies that we don't want to
35 // check the data type.
36 typedef std::pair<string, DataType> IOSpec;
37 
M(const std::initializer_list<string> & names)38 std::vector<IOSpec> M(const std::initializer_list<string>& names) {
39   std::vector<IOSpec> v;
40   for (const string& name : names) {
41     v.push_back(IOSpec(name, DT_INVALID));
42   }
43   return v;
44 }
45 
46 // Specification for an expected edge.
47 // src is either:
48 // - input name (as it appears in FunctionDef)
49 // - name of output tensor (in nested "add:z:0" format)
50 // dst is either:
51 // - output name (as it appears in FunctionDef)
52 // - <name_of_node>:<index_of_this_input_into_node> (this looks the same as
53 //      output tensor naming, but it the index is actually an input index)
54 struct EdgeSpec : public std::pair<string, string> {
55   typedef std::pair<string, string> Base;
56 
57   // Inherit the set of constructors
58   using Base::pair;
59 
ToStringtensorflow::__anon1a216e640111::EdgeSpec60   string ToString() const { return strings::StrCat(first, "->", second); }
61 };
62 
63 class CApiFunctionTest : public ::testing::Test {
64  protected:
CApiFunctionTest()65   CApiFunctionTest()
66       : s_(TF_NewStatus()),
67         func_graph_(TF_NewGraph()),
68         host_graph_(TF_NewGraph()),
69         func_(nullptr) {}
70 
SetUp()71   void SetUp() override {}
72 
~CApiFunctionTest()73   ~CApiFunctionTest() override {
74     TF_DeleteFunction(func_);
75     TF_DeleteGraph(host_graph_);
76     TF_DeleteGraph(func_graph_);
77     TF_DeleteStatus(s_);
78   }
79 
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,TF_Operation * output,int32_t expected_result)80   void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
81            TF_Operation* output, int32_t expected_result) {
82     Run(inputs, {{output, 0}}, {expected_result});
83   }
84 
85   // Run the host graph, which now contains a function and check that
86   // outputs are as expected.
87   // 'T' stands for 'tensor' since the outputs are tensors, not scalars.
RunT(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<std::vector<int32_t>> & expected_results)88   void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
89             std::initializer_list<TF_Output> outputs,
90             const std::vector<std::vector<int32_t>>& expected_results) {
91     // Create a session for this graph
92     CSession csession(host_graph_, s_);
93     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
94 
95     // Run
96     csession.SetInputs(inputs);
97     csession.SetOutputs(outputs);
98     csession.Run(s_);
99     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
100 
101     // Check results
102     for (int i = 0; i < expected_results.size(); ++i) {
103       TF_Tensor* out = csession.output_tensor(i);
104       ASSERT_TRUE(out != nullptr);
105       EXPECT_EQ(TF_INT32, TF_TensorType(out));
106       EXPECT_EQ(1, TF_NumDims(out));
107       CompareInt32Tensor(expected_results[i], out);
108     }
109   }
110 
111   // Run the host graph, which now contains a function and check that
112   // outputs are as expected.
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<int32_t> & expected_results)113   void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
114            std::initializer_list<TF_Output> outputs,
115            const std::vector<int32_t>& expected_results) {
116     // Create a session for this graph.
117     CSession csession(host_graph_, s_);
118     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
119 
120     csession.SetInputs(inputs);
121     csession.SetOutputs(outputs);
122     csession.Run(s_);
123     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
124 
125     for (int i = 0; i < expected_results.size(); ++i) {
126       TF_Tensor* out = csession.output_tensor(i);
127       ASSERT_TRUE(out != nullptr);
128       EXPECT_EQ(TF_INT32, TF_TensorType(out));
129       EXPECT_EQ(0, TF_NumDims(out));  // scalar
130       ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
131       int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out));
132       EXPECT_EQ(expected_results[i], *output_contents);
133     }
134   }
135 
CompareInt32Tensor(const std::vector<int32_t> & expected,TF_Tensor * t)136   void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) {
137     int32_t* data = static_cast<int32_t*>(TF_TensorData(t));
138     size_t size = TF_TensorByteSize(t);
139     ASSERT_EQ(expected.size() * sizeof(int32_t), size);
140     for (int i = 0; i < expected.size(); ++i) {
141       ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i;
142     }
143   }
144 
ToOutput(const std::vector<TF_Operation * > ops)145   std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) {
146     std::vector<TF_Output> out;
147     for (auto op : ops) {
148       out.push_back({op, 0});
149     }
150     return out;
151   }
152 
Define(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Operation * > & inputs,const std::vector<TF_Operation * > & outputs,const std::vector<string> & output_names,bool expect_failure=false)153   void Define(int num_opers, const std::vector<TF_Operation*>& opers,
154               const std::vector<TF_Operation*>& inputs,
155               const std::vector<TF_Operation*>& outputs,
156               const std::vector<string>& output_names,
157               bool expect_failure = false) {
158     DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names,
159             expect_failure);
160   }
161 
162   // Caller must delete[] the returned value
ToArray(const std::vector<string> & strs)163   static const char** ToArray(const std::vector<string>& strs) {
164     const char** ptr = nullptr;
165     if (!strs.empty()) {
166       ptr = new const char*[strs.size()];
167       for (size_t i = 0; i < strs.size(); ++i) {
168         ptr[i] = strs[i].c_str();
169       }
170     }
171     return ptr;
172   }
173 
174   // An explicit `num_opers` is needed so that we can distinguish between the
175   // case of no operations specified (-1) and the case of an empty set of
176   // operations specified (0).
DefineT(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<string> & output_names,bool expect_failure=false)177   void DefineT(int num_opers, const std::vector<TF_Operation*>& opers,
178                const std::vector<TF_Output>& inputs,
179                const std::vector<TF_Output>& outputs,
180                const std::vector<string>& output_names,
181                bool expect_failure = false) {
182     ASSERT_EQ(func_, nullptr);
183     const char** output_names_ptr = ToArray(output_names);
184     func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers,
185                                num_opers == -1 ? nullptr : opers.data(),
186                                inputs.size(), inputs.data(), outputs.size(),
187                                outputs.data(), output_names_ptr,
188                                /*opts=*/nullptr, /*description=*/nullptr, s_);
189     delete[] output_names_ptr;
190     if (expect_failure) {
191       ASSERT_EQ(func_, nullptr);
192       return;
193     }
194 
195     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
196     ASSERT_NE(func_, nullptr);
197     ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_)));
198     TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
199     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
200   }
201 
Use(const std::vector<TF_Operation * > & inputs)202   TF_Operation* Use(const std::vector<TF_Operation*>& inputs) {
203     return UseT(ToOutput(inputs));
204   }
205 
UseT(const std::vector<TF_Output> & inputs)206   TF_Operation* UseT(const std::vector<TF_Output>& inputs) {
207     TF_Operation* op;
208     UseHelper(inputs, &op);
209     return op;
210   }
211 
212   // All the *Helper methods are used as a workaround for the restrictions that
213   // one cannot call ASSERT_* methods in non-void-returning functions (when
214   // exceptions are disabled during compilation)
UseHelper(const std::vector<TF_Output> & inputs,TF_Operation ** op)215   void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) {
216     TF_OperationDescription* desc =
217         TF_NewOperation(host_graph_, func_name_, func_node_name_);
218     for (auto input : inputs) {
219       TF_AddInput(desc, input);
220     }
221     // Set device to CPU because some ops inside the function might not be
222     // available on GPU.
223     TF_SetDevice(desc, "/cpu:0");
224     *op = TF_FinishOperation(desc, s_);
225     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
226     ASSERT_NE(*op, nullptr);
227   }
228 
fdef()229   FunctionDef fdef() {
230     tensorflow::FunctionDef fdef;
231     EXPECT_TRUE(GetFunctionDef(func_, &fdef));
232     return fdef;
233   }
234 
235   // logging utility
236   template <class Container>
ToString(const Container & v)237   string ToString(const Container& v) {
238     std::stringstream ss;
239     ss << "{";
240     size_t i = 0;
241     for (const auto& e : v) {
242       if (i != 0) {
243         ss << ", ";
244       }
245       ss << e.ToString();
246       ++i;
247     }
248     ss << "}";
249     return ss.str();
250   }
251 
VerifyFDefNodes(const tensorflow::FunctionDef & fdef,const std::unordered_set<string> & nodes)252   void VerifyFDefNodes(const tensorflow::FunctionDef& fdef,
253                        const std::unordered_set<string>& nodes) {
254     ASSERT_EQ(nodes.size(), fdef.node_def_size())
255         << "Got unexpected number of nodes. Expected: ["
256         << str_util::Join(nodes, ", ")
257         << "] Actual nodes in fdef: " << fdef.DebugString();
258     for (const NodeDef& node_def : fdef.node_def()) {
259       ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
260           << "Got unexpected node: " << node_def.name()
261           << " in fdef: " << fdef.DebugString();
262     }
263   }
264 
VerifyFDefInputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & inputs)265   void VerifyFDefInputs(const tensorflow::FunctionDef& fdef,
266                         const std::vector<IOSpec>& inputs) {
267     const OpDef& signature = fdef.signature();
268     ASSERT_EQ(inputs.size(), signature.input_arg_size());
269     for (int i = 0; i < inputs.size(); ++i) {
270       const OpDef::ArgDef& arg = signature.input_arg(i);
271       const IOSpec& in = inputs[i];
272       if (in.second != DT_INVALID) {
273         ASSERT_EQ(arg.type(), in.second)
274             << "Got unexpected type for input " << i
275             << ". fdef: " << fdef.DebugString();
276       }
277       ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i
278                                       << ". fdef: " << fdef.DebugString();
279     }
280   }
281 
VerifyFDefOutputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & outputs)282   void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef,
283                          const std::vector<IOSpec>& outputs) {
284     const OpDef& signature = fdef.signature();
285     ASSERT_EQ(outputs.size(), signature.output_arg_size());
286     for (int i = 0; i < outputs.size(); ++i) {
287       const OpDef::ArgDef& arg = signature.output_arg(i);
288       const IOSpec& out = outputs[i];
289       if (out.second != DT_INVALID) {
290         ASSERT_EQ(arg.type(), out.second)
291             << "Got unexpected type for output " << i
292             << ". fdef: " << fdef.DebugString();
293       }
294       ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i
295                                        << ". fdef: " << fdef.DebugString();
296     }
297   }
298 
VerifyFDefEdges(const tensorflow::FunctionDef & fdef,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)299   void VerifyFDefEdges(
300       const tensorflow::FunctionDef& fdef,
301       const std::vector<EdgeSpec>& e_edges,  // expected edges
302       const std::vector<EdgeSpec>& c_edges,  // expected ctrl edges
303       bool is_exact_edges = true) {
304     // Build a set of edges from fdef
305     std::set<EdgeSpec> a_edges;  // actual edges
306     // Get edges from inputs to body nodes and between body nodes
307     for (const NodeDef& node_def : fdef.node_def()) {
308       for (int i = 0; i < node_def.input_size(); ++i) {
309         const string& in = node_def.input(i);
310         const auto& v =
311             a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)});
312         ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> "
313                               << strings::StrCat(node_def.name(), ":", i)
314                               << ". fdef: " << fdef.DebugString();
315       }
316     }
317     // Get edges from body nodes to outputs and from inputs to outputs
318     for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) {
319       const auto& iter = fdef.ret().find(arg.name());
320       if (iter != fdef.ret().end()) {
321         const auto& v = a_edges.insert({iter->second, arg.name()});
322         ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> "
323                               << arg.name() << ". fdef: " << fdef.DebugString();
324       } else {
325         const auto& v = a_edges.insert({arg.name(), arg.name()});
326         ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> "
327                               << arg.name() << ". fdef: " << fdef.DebugString();
328       }
329     }
330 
331     // Verify edges
332     for (const EdgeSpec& e : e_edges) {
333       ASSERT_TRUE(a_edges.find(e) != a_edges.end())
334           << "Failed to find expected edge " << e.ToString()
335           << " in fdef: " << fdef.DebugString();
336     }
337     for (const EdgeSpec& e : c_edges) {
338       ASSERT_TRUE(a_edges.find(e) != a_edges.end())
339           << "Failed to find expected control edge " << e.ToString()
340           << " in fdef: " << fdef.DebugString();
341     }
342 
343     // If caller specified all edges, check that we have seen all
344     if (is_exact_edges) {
345       ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size())
346           << "Expected edges: " << ToString(e_edges)
347           << " Expected Control edges: " << ToString(c_edges)
348           << " Actual edges: " << ToString(a_edges)
349           << " in fdef: " << fdef.DebugString();
350     }
351   }
352 
VerifyFDef(const std::unordered_set<string> & nodes,const std::vector<IOSpec> & inputs,const std::vector<IOSpec> & outputs,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)353   void VerifyFDef(const std::unordered_set<string>& nodes,
354                   const std::vector<IOSpec>& inputs,
355                   const std::vector<IOSpec>& outputs,
356                   const std::vector<EdgeSpec>& e_edges,  // expected edges
357                   const std::vector<EdgeSpec>& c_edges,  // expected ctrl edges
358                   bool is_exact_edges = true) {
359     tensorflow::FunctionDef fdef;
360     ASSERT_TRUE(GetFunctionDef(func_, &fdef));
361     VerifyFDefNodes(fdef, nodes);
362     VerifyFDefInputs(fdef, inputs);
363     VerifyFDefOutputs(fdef, outputs);
364     VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
365   }
366 
367   // Serialize func_ to fdef and import it back
Reincarnate()368   void Reincarnate() {
369     // func_ -> fdef
370     tensorflow::FunctionDef fdef;
371     ASSERT_TRUE(GetFunctionDef(func_, &fdef));
372     TF_DeleteFunction(func_);
373 
374     // fdef -> func_
375     string buf;
376     ASSERT_TRUE(fdef.SerializeToString(&buf));
377     func_ = TF_FunctionImportFunctionDef(buf.data(), buf.size(), s_);
378     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
379   }
380 
GetAttr(const char * attr_name,AttrValue * out_attr)381   void GetAttr(const char* attr_name, AttrValue* out_attr) {
382     TF_Buffer* attr_buf = TF_NewBuffer();
383     TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
384     ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
385     TF_DeleteBuffer(attr_buf);
386   }
387 
388   const char* func_name_ = "MyFunc";
389   const char* func_node_name_ = "MyFunc_0";
390   TF_Status* s_;
391   TF_Graph* func_graph_;
392   TF_Graph* host_graph_;
393   TF_Function* func_;
394 
395   // Workaround for not being able to initialize empty map using {}
396   std::unordered_set<string> empty_;
397 };
398 
TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) {
  /*
   *                constant
   *                   |
   *                   v
   */
  // Function body: a single constant that is also the only output.
  TF_Operation* const_op = ScalarConst(10, func_graph_, s_, "scalar10");
  Define(-1, {}, {}, {const_op}, {});

  // Call the function without arguments, run it and inspect the FunctionDef.
  TF_Operation* call = Use({});
  const int32_t kExpected = 10;
  Run({}, call, kExpected);
  VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}},
             {{"scalar10_0:output:0", "scalar10"}}, {});
}
415 
TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) {
  /*
   *                   |
   *                   v
   *                 negate
   *                   |
   *                   v
   */
  // Function body: placeholder feeding a Neg; output name is auto-generated.
  TF_Operation* input = Placeholder(func_graph_, s_);
  TF_Operation* negated = Neg(input, func_graph_, s_);
  Define(-1, {}, {input}, {negated}, {});

  // Feed 3 through the function and expect its negation.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  const int32_t kExpected = -3;
  Run({{host_input, Int32Tensor(3)}}, call, kExpected);
  VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}},
             {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {});
}
436 
TEST_F(CApiFunctionTest, OneOutput_OutputNames) {
  /*
   *                   |
   *                   v
   *                 negate
   *                   |
   *                   v
   */
  // Same graph as above, but the single output gets an explicit name.
  TF_Operation* input = Placeholder(func_graph_, s_);
  TF_Operation* negated = Neg(input, func_graph_, s_);
  Define(-1, {}, {input}, {negated}, {"negated_num"});

  // Run with 3 and check that the explicit output name is used.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  const int32_t kExpected = -3;
  Run({{host_input, Int32Tensor(3)}}, call, kExpected);
  VerifyFDef({"neg"}, {{"feed", DT_INT32}}, {{"negated_num", DT_INT32}},
             {{"feed", "neg:0"}, {"neg:y:0", "negated_num"}}, {});
}
457 
TEST_F(CApiFunctionTest, OutputNames_SameNameAsInput) {
  /*
   *                   |
   *                   v
   *                 negate
   *                   |
   *                   v
   */
  // The requested output name "negation" collides with the input
  // placeholder's name; the input is expected to show up as "negation_0".
  TF_Operation* input = Placeholder(func_graph_, s_, "negation");
  TF_Operation* negated = Neg(input, func_graph_, s_, "neg");
  Define(-1, {}, {input}, {negated}, {"negation"});

  // Run with 3 and check the disambiguated names.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  const int32_t kExpected = -3;
  Run({{host_input, Int32Tensor(3)}}, call, kExpected);
  VerifyFDef({"neg"}, {{"negation_0", DT_INT32}}, {{"negation", DT_INT32}},
             {{"negation_0", "neg:0"}, {"neg:y:0", "negation"}}, {});
}
478 
TEST_F(CApiFunctionTest, ZeroOps_Identity) {
  /*
   *                   |
   *                   |
   *                   |
   *                   v
   */
  // A function whose only input is passed straight through as its output.
  TF_Operation* passthrough = Placeholder(func_graph_, s_);
  Define(-1, {}, {passthrough}, {passthrough}, {});

  // The value fed in must come back unchanged; the body has no nodes.
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_input});
  const int32_t kExpected = 3;
  Run({{host_input, Int32Tensor(3)}}, call, kExpected);
  VerifyFDef(empty_, {{"feed_0", DT_INT32}}, {{"feed", DT_INT32}},
             {{"feed_0", "feed"}}, {});
}
497 
TEST_F(CApiFunctionTest, ZeroOps_Permutation) {
  /*
   *                   |   |
   *                   \  /
   *                    \/
   *                    x
   *                   /\
   *                  /  \
   *                 |   |
   *                 v   v
   */
  // Two inputs returned in swapped order, with no ops in between.
  TF_Operation* first_in = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* second_in = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {first_in, second_in}, {second_in, first_in}, {});

  // Call with (2, 3) and expect (3, 2) back.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});
  VerifyFDef(empty_, M({{"feed1_0"}, {"feed2_0"}}), M({{"feed2"}, {"feed1"}}),
             {{"feed1_0", "feed1"}, {"feed2_0", "feed2"}}, {});
}
522 
TEST_F(CApiFunctionTest, ZeroOps_Permutation_OutputNames) {
  /*
   *                   |   |
   *                   \  /
   *                    \/
   *                    x
   *                   /\
   *                  /  \
   *                 |   |
   *                 v   v
   */
  // Swapped pass-through of two inputs, with explicitly named outputs.
  TF_Operation* first_in = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* second_in = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {first_in, second_in}, {second_in, first_in},
         {"first", "second"});

  // Call with (2, 3), expect (3, 2), and check the named outputs.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});
  VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"first"}, {"second"}}),
             {{"feed1", "second"}, {"feed2", "first"}}, {});
}
547 
TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                   v
   */
  // Function body: sum of two placeholders.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // 2 + 3 through the function.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_input});
  const int32_t kExpected = 2 + 3;
  Run({{host_input, Int32Tensor(3)}}, call, kExpected);
  const std::vector<EdgeSpec> expected_edges = {
      {"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}};
  VerifyFDef({"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
             expected_edges, {});
}
571 
TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *
   *            (output ignored)
   */
  // The Add is part of the body but none of its outputs are exported.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {}, {});

  // Only instantiate the call (there is nothing to fetch), then check the
  // FunctionDef.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  Use({const_two, host_input});
  const std::vector<EdgeSpec> expected_edges = {{"feed1", "add:0"},
                                                {"feed2", "add:1"}};
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {}, expected_edges, {});
}
593 
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) {
  /*
   *                  |  |   |
   *                  v  v   /
   *                  add1  /
   *                   |   |
   *                   v   v
   *                   add2
   *                    |
   *                    v
   */
  // Chained adds: (feed1 + feed2) + feed3.
  TF_Operation* in_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in_b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in_c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(in_a, in_b, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, in_c, func_graph_, s_, "add2");
  Define(-1, {}, {in_a, in_b, in_c}, {total}, {});

  // 2 + 10 + 3 through the function.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const_ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, const_ten, host_input});
  const int32_t kExpected = 2 + 10 + 3;
  Run({{host_input, Int32Tensor(3)}}, call, kExpected);
  const std::vector<EdgeSpec> expected_edges = {{"feed1", "add1:0"},
                                                {"feed2", "add1:1"},
                                                {"add1:sum:0", "add2_0:0"},
                                                {"feed3", "add2_0:1"},
                                                {"add2_0:sum:0", "add2"}};
  VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add2"}}), expected_edges, {});
}
628 
TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                 +-+-+
   *                 |   |
   *                 v   v
   */
  // The same Add result is exported twice under auto-generated names.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {});

  // Both outputs should be 2 + 3.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});
  const std::vector<EdgeSpec> expected_edges = {{"feed1", "add_1:0"},
                                                {"feed2", "add_1:1"},
                                                {"add_1:sum:0", "add"},
                                                {"add_1:sum:0", "add_0"}};
  VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}),
             expected_edges, {});
}
657 
TEST_F(CApiFunctionTest, TwoDuplicateOutputs_OutputNames) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                 +-+-+
   *                 |   |
   *                 v   v
   */
  // The same Add result is exported twice, as "out1" and "out2".
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {"out1", "out2"});

  // Both outputs should be 2 + 3.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});
  const std::vector<EdgeSpec> expected_edges = {{"feed1", "add:0"},
                                                {"feed2", "add:1"},
                                                {"add:sum:0", "out1"},
                                                {"add:sum:0", "out2"}};
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), M({{"out1"}, {"out2"}}),
             expected_edges, {});
}
686 
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) {
  /*
   *                  |  |  |
   *                  v  v  /
   *                  add  /
   *                   |  |
   *                 +-+  |
   *                 | |  |
   *                 | v  v
   *                 | add
   *                 |  |
   *                 v  v
   */
  // Both the intermediate sum and the final sum are exported.
  TF_Operation* in_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in_b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in_c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(in_a, in_b, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, in_c, func_graph_, s_, "add2");
  Define(-1, {}, {in_a, in_b, in_c}, {partial, total}, {});

  // With (2, 10, 3): partial = 12, total = 15.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const_ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, const_ten, host_input});
  Run({{host_input, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});
  const std::vector<EdgeSpec> expected_edges = {{"feed1", "add1_0:0"},
                                                {"feed2", "add1_0:1"},
                                                {"add1_0:sum:0", "add2_0:0"},
                                                {"feed3", "add2_0:1"},
                                                {"add1_0:sum:0", "add1"},
                                                {"add2_0:sum:0", "add2"}};
  VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add1"}, {"add2"}}), expected_edges, {});
}
724 
TEST_F(CApiFunctionTest, FromSubsetOfOps) {
  /*
   *                  |  |  |
   *                  v  v  /
   *                  add  /
   *                   |  |
   *               +---+--+---+
   *  Ops used     |   |  |   |
   *  for func     |   v  v   |
   *     |         |   add    |
   *     +-------> |    |     |
   *               |    v     |
   *               |          |
   *               +----------+
   */
  // Build two chained adds, but only the second one goes into the function.
  // Its inputs are the first add's output and feed3.
  TF_Operation* in_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in_b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in_c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(in_a, in_b, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, in_c, func_graph_, s_, "add2");
  Define(1, {total}, {partial, in_c}, {total}, {});

  // The resulting function computes just x + y: 2 + 3.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* host_input = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_input});
  const int32_t kExpected = 2 + 3;
  Run({{host_input, Int32Tensor(3)}}, call, kExpected);
  const std::vector<EdgeSpec> expected_edges = {
      {"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}};
  VerifyFDef({"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}),
             expected_edges, {});
}
758 
// Function that returns only output 1 of a three-way split of its input.
TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) {
  /*
   *                      feed
   *                       |
   *             +---------+---+
   *             | const0  |   |
   *             |    |    |   |
   *             |    v    /   |
   *             |    split    |
   *             |   |  |  |   |
   *             |   v  |  v   |
   *             |      |      |
   *             +------+------+
   *                    |
   *                    v
   *
   *  Only the second output from split is used as function output
   */
  // Define
  TF_Operation* feed = Placeholder(func_graph_, s_);
  TF_Operation* split = Split3(feed, func_graph_, s_);
  DefineT(-1, {}, {{feed, 0}}, {{split, 1}}, {});

  // Use, run, and verify: {1..6} split in three -> middle chunk is {3, 4}.
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed});
  RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{func_op, 0}},
       {{3, 4}});
  VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}),
             {{"split3_const0:output:0", "split3_0:0"},
              {"feed", "split3_0:1"},
              {"split3_0:output:1", "split3"}},
             {});
}

// Function that returns outputs 0 and 2 of a three-way split of its input.
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) {
  /*
   *                      feed
   *                       |
   *             +---------+---+
   *             | const0  |   |
   *             |    |    |   |
   *             |    v    /   |
   *             |    split    |
   *             |   |  |  |   |
   *             |   |  v  |   |
   *             |   |     |   |
   *             +---+-----+---+
   *                 |     |
   *                 v     v
   *
   *  Second output from split is not used as function output
   */
  // Define
  TF_Operation* feed = Placeholder(func_graph_, s_);
  TF_Operation* split = Split3(feed, func_graph_, s_);
  DefineT(-1, {}, {{feed, 0}}, {{split, 0}, {split, 2}}, {});

  // Use, run, and verify: first and last chunks of {1..6}.
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed});
  RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}},
       {{func_op, 0}, {func_op, 1}}, {{1, 2}, {5, 6}});
  VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}),
             M({{"split3"}, {"split3_0"}}),
             {{"split3_const0:output:0", "split3_1:0"},
              {"feed", "split3_1:1"},
              {"split3_1:output:0", "split3"},
              {"split3_1:output:2", "split3_0"}},
             {});
}

// Function whose body is a single add (num_opers == 1) fed by outputs 0 and 2
// of a split that stays outside the body; those two outputs are the inputs.
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) {
  /*
   *                    |
   *                    v
   *                  split
   *                 |  |  |
   *                 |  v  |
   *                 |     |
   *             +---+-----+---+
   *             |   |     |   |
   *             |   v     v   |
   *             |     add     |
   *             |      |      |
   *             |      |      |
   *             +------+------+
   *                    |
   *                    v
   */
  // Define
  TF_Operation* feed = Placeholder(func_graph_, s_);
  TF_Operation* split = Split3(feed, func_graph_, s_);
  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(1, {add}, {{split, 0}, {split, 2}}, {{add, 0}}, {});

  // Use, run, and verify
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
  VerifyFDef(
      {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}),
      {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}},
      {});
}

// With num_opers == -1, using outputs of a multi-output node (split) as
// function inputs must fail with TF_INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) {
  /*
   *                    |
   *                    v
   *                  split
   *                 |  |  |
   *                 |  v  |
   *                 |     |
   *       input --->|     |<--- input
   *                 |     |
   *                 v     v
   *                   add
   *                    |
   *                    |
   *                    v
   */
  // Define. The tensor is owned by a unique_ptr so it is released even when
  // one of the ASSERT_* checks below returns from the test early.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor_123(
      Int32Tensor({1, 2, 3}), TF_DeleteTensor);
  TF_Operation* c = Const(tensor_123.get(), func_graph_, s_, "const_array");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* split = Split3(c, func_graph_, s_);
  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in "
                   "`inputs` must have a single output. Node split3 has "
                   "3 outputs. Encountered while creating function 'MyFunc'"),
            string(TF_Message(s_)));
}

TEST_F(CApiFunctionTest,FunctionWithWhileLoop)900 TEST_F(CApiFunctionTest, FunctionWithWhileLoop) {
901   // Inputs to the while loop and the function as a whole
902   TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
903   TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
904 
905   // Outputs of the while loop corresponding to the two inputs above
906   // The first one will the function's output
907   std::vector<TF_Output> outputs;
908 
909   // Add while loop to func_graph_
910   {
911     // The inputs to the while loop
912     std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
913     std::unique_ptr<TF_WhileParams> params(new TF_WhileParams(
914         TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_)));
915     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
916     params->name = "test_loop";
917 
918     // Initialize outputs so we can easily detect errors/bugs
919     outputs.resize(2, {nullptr, -1});
920 
921     // Create loop: while (input1 < input2) input1 += input2 + 1
922     TF_Operation* less_than = LessThan(
923         params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_);
924     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
925     params->cond_output = {less_than, 0};
926 
927     TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1],
928                              params->body_graph, s_, "add1");
929     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
930     TF_Operation* one = ScalarConst(1, params->body_graph, s_);
931     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
932     TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2");
933     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
934     params->body_outputs[0] = {add2, 0};
935     params->body_outputs[1] = params->body_inputs[1];
936 
937     // Finalize while loop
938     TF_FinishWhile(params.get(), s_, &outputs[0]);
939     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
940   }
941 
942   // Define function, use it in graph, and run
943   DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, {});
944   TF_Operation* five = ScalarConst(5, host_graph_, s_, "five");
945   TF_Operation* func_feed = Placeholder(host_graph_, s_);
946   TF_Operation* func_op = Use({func_feed, five});
947   Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1);
948 
949   // Verify input, output, and subset of edges in fdef.
950   // The subset of edges we verify is a chain between feed1 and output to
951   // make sure that the correct output is picked.
952   tensorflow::FunctionDef fdef;
953   ASSERT_TRUE(GetFunctionDef(func_, &fdef));
954   VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}}));
955   VerifyFDefOutputs(fdef, M({{"test_loop_exit"}}));
956   VerifyFDefEdges(fdef,
957                   {{"feed1", "test_loop/Enter:0"},
958                    {"test_loop/Enter:output:0", "test_loop/Merge:0"},
959                    {"test_loop/Merge:output:0", "test_loop/Switch:0"},
960                    {"test_loop/Switch:output_false:0", "test_loop/Exit:0"},
961                    {"test_loop/Exit:output:0", "test_loop_exit"}},
962                   {}, false);
963 }
964 
// A control edge from a const inside the body (num_opers == -1) is kept in
// the FunctionDef as a "^scalar" control input of the add.
TEST_F(CApiFunctionTest, ControlDependency) {
  /*
   *                  |  |    scalar
   *                  |  |    .
   *                  v  v   . <---- control dependency
   *                  add < -
   *                   |
   *                   v
   */
  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* five = ScalarConst(5, func_graph_, s_);
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed2}, {add}, {});

  // Use, run, and verify
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
  VerifyFDef(
      {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
      {{"^scalar", "add_0:2"}});
}

// A control edge whose source is outside an explicit body (num_opers == 1)
// must be rejected. Node creation order matters: the expected message embeds
// the edge id.
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
  /*
   *                  |  |    scalar
   *                  |  |    .
   *                  v  v   . <---- control dependency
   *                  add < -
   *                   |
   *                   v
   */
  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* five = ScalarConst(5, func_graph_, s_);
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(1, {add}, {feed1, feed2}, {add}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] "
                   "is not in the body. Encountered while creating "
                   "function 'MyFunc'"),
            string(TF_Message(s_)));
}

// A control edge from a node that is also a function input (feed1) is
// accepted and appears as "^feed1" in the FunctionDef.
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
  /*
   *                  |  |.
   *                  |  |  .
   *                  |  |   .
   *                  v  v   . <---- control dependency
   *                  add < -
   *                   |
   *                   v
   */
  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed2}, {add}, {});

  // Use, run, and verify
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
  VerifyFDef(
      {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
      {{"^feed1", "add_0:2"}});
}

// Listing the same TF_Output twice in the input list is rejected.
TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {
  /*
   *                  feed
   *                   |
   *                  +++
   *                  | |
   *              +---+-+---+
   *              |   | |   |
   *              |   v v   |
   *              |   add   |
   *              |    |    |
   *              |    |    |
   *              +----+----+
   *                   |
   *                   v
   */
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* add = Add(feed1, feed1, func_graph_, s_);
  Define(-1, {}, {feed1, feed1}, {add}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(
      string("TF_Output feed1:0 appears more than once in the input list"),
      string(TF_Message(s_)));
}

// Giving two function outputs the same name is rejected.
TEST_F(CApiFunctionTest, DuplicateOutputNamesAreNotAllowed) {
  /*
   *                  |  |  |
   *                  v  v  /
   *                  add  /
   *                   |  |
   *                 +-+  |
   *                 | |  |
   *                 | v  v
   *                 | add
   *                 |  |
   *                 v  v
   */
  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
  TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
  Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, {"my_out", "my_out"},
         true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot have duplicate output names. Name 'my_out' "
                   "appears more than once in 'output_names' array."),
            string(TF_Message(s_)));
}

// An input referring to a non-existent output index is TF_OUT_OF_RANGE.
TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                   v
   */
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
                   "not have output 2\n\tEncountered while processing "
                   "input 1 into function 'MyFunc'"),
            string(TF_Message(s_)));
}

// An input whose operation pointer is null is TF_INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                   v
   */
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  DefineT(-1, {}, {{feed1, 0}, {nullptr, 0}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 "
                   "into function 'MyFunc'"),
            string(TF_Message(s_)));
}

// An output referring to a non-existent output index is TF_OUT_OF_RANGE.
TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                   v
   */
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, {}, true);
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
                   "not have output 3\n\tEncountered while processing "
                   "output 0 from function 'MyFunc'"),
            string(TF_Message(s_)));
}

// An output whose operation pointer is null is TF_INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                   v
   */
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  // The add is created only to populate func_graph_; it is never referenced.
  Add(feed1, feed2, func_graph_, s_);
  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{nullptr, 3}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 "
                   "from function 'MyFunc'"),
            string(TF_Message(s_)));
}

// A body node whose input is neither a function input nor produced inside the
// body is rejected with TF_INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, NodeMissingInput) {
  /*
   *        input---> |  | <----missing input
   *                  v  v
   *        body----> add
   *                   |
   *                   v
   */
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  DefineT(1, {add}, {{feed1, 0}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' "
                   "is not available. You might need to include it in inputs "
                   "or include its source node in the body"),
            string(TF_Message(s_)));
}

// An output produced by an op that is neither in the body nor an input is
// rejected with TF_INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, OutputOpNotInBody) {
  /*
   *                  |  |
   *                  v  v
   *                  add    scalar    (scalar not included in body)
   *                   |       |
   *                   v       v       (function has two outputs)
   */
  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* scalar = ScalarConst(2, func_graph_, s_);
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  Define(1, {add}, {feed1, feed2}, {add, scalar}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor "
                   "among function inputs. Encountered while creating "
                   "function 'MyFunc'"),
            string(TF_Message(s_)));
}

DefineFunction(const char * name,TF_Function ** func,const char * description=nullptr,bool append_hash=false)1213 void DefineFunction(const char* name, TF_Function** func,
1214                     const char* description = nullptr,
1215                     bool append_hash = false) {
1216   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1217       TF_NewGraph(), TF_DeleteGraph);
1218   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1219                                                            TF_DeleteStatus);
1220 
1221   TF_Operation* feed = Placeholder(func_graph.get(), s.get());
1222   TF_Operation* neg = Neg(feed, func_graph.get(), s.get());
1223 
1224   TF_Output inputs[] = {{feed, 0}};
1225   TF_Output outputs[] = {{neg, 0}};
1226   *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
1227                              /*opers=*/nullptr, 1, inputs, 1, outputs,
1228                              /*output_names=*/nullptr,
1229                              /*opts=*/nullptr, description, s.get());
1230   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1231   ASSERT_NE(*func, nullptr);
1232 }
1233 
1234 REGISTER_OP("CustomOp")
1235     .Output("output: float32")
1236     .Attr("index: int")
1237     .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1238 
NodeWithPlaceholderAttrHelper(TF_Graph * graph,TF_Status * s,const char * name,const char * placeholder,TF_Operation ** op)1239 void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s,
1240                                    const char* name, const char* placeholder,
1241                                    TF_Operation** op) {
1242   TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name);
1243   TF_SetAttrPlaceholder(desc, "index", placeholder);
1244   *op = TF_FinishOperation(desc, s);
1245   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1246   ASSERT_NE(*op, nullptr);
1247 }
1248 
// Three CustomOps bind their "index" attr to the placeholders "v1", "v1" and
// "v2". The resulting function signature must declare exactly one attr per
// distinct placeholder name, typed like the op's attr ("int").
TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  TF_Operation* node1 = nullptr;
  TF_Operation* node2 = nullptr;
  TF_Operation* node3 = nullptr;
  NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node1", "v1",
                                &node1);
  NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node2", "v1",
                                &node2);
  NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
                                &node3);

  TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
  // The function takes no inputs. An empty `TF_Output inputs[] = {}` would
  // declare a zero-length array, which standard C++ does not allow, so pass
  // nullptr together with ninputs == 0.
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, /*ninputs=*/0, /*inputs=*/nullptr, /*noutputs=*/3,
      outputs, /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that FunctionDef has 2 attributes, "v1" and "v2".
  ASSERT_EQ(2, func_->fdef.signature().attr().size());
  EXPECT_EQ("v1", func_->fdef.signature().attr(0).name());
  EXPECT_EQ("int", func_->fdef.signature().attr(0).type());
  EXPECT_EQ("v2", func_->fdef.signature().attr(1).name());
  EXPECT_EQ("int", func_->fdef.signature().attr(1).type());
}

// Copies a function plus its gradient into host_graph_, checks the resulting
// GraphDef, checks that repeated copies are no-ops, and runs the function.
TEST_F(CApiFunctionTest, SetGradientAndRun) {
  // Define the function and its grad
  DefineFunction(func_name_, &func_);
  TF_Function* grad_func = nullptr;
  DefineFunction("MyGrad", &grad_func);
  // Own grad_func so it is released even if an ASSERT_* below ends the test
  // before the explicit deletion further down.
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> grad_func_owner(
      grad_func, TF_DeleteFunction);

  // Add func and its gradient to host graph
  TF_GraphCopyFunction(host_graph_, func_, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that function and its grad are in host graph's GraphDef
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<string> func_names = GetFuncNames(gdef);
  ASSERT_EQ(2, func_names.size());
  ASSERT_EQ(func_name_, func_names[0]);
  ASSERT_EQ("MyGrad", func_names[1]);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(1, grads.size());
  ASSERT_EQ(func_name_, grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);

  // These calls must be noops
  TF_GraphCopyFunction(host_graph_, func_, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Delete the gradient func.
  // It is safe to delete after adding a copy to host graph.
  grad_func_owner.reset();
  grad_func = nullptr;

  // Check that GraphDef did not change
  GraphDef gdef2;
  GetGraphDef(host_graph_, &gdef2);
  ASSERT_EQ(gdef.DebugString(), gdef2.DebugString());

  // Use and run func
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, -3);
}

// One gradient function can be registered for two different functions.
TEST_F(CApiFunctionTest, SameGradForTwoFunctions) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;

  // Define the functions
  TF_Function* func1 = nullptr;
  TF_Function* func2 = nullptr;
  TF_Function* grad_func = nullptr;
  DefineFunction("FooFunc1", &func1);
  DefineFunction("FooFunc2", &func2);
  DefineFunction("MyGrad", &grad_func);
  // Released at scope exit, including when an ASSERT_* below ends the test.
  FunctionPtr func1_owner(func1, TF_DeleteFunction);
  FunctionPtr func2_owner(func2, TF_DeleteFunction);
  FunctionPtr grad_func_owner(grad_func, TF_DeleteFunction);

  // Make grad_func be a gradient of func1 and func2
  TF_GraphCopyFunction(host_graph_, func1, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func2, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that functions and their gradients are in host graph's GraphDef
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(2, grads.size());
  ASSERT_EQ("FooFunc1", grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);
  ASSERT_EQ("FooFunc2", grads[1].first);
  ASSERT_EQ("MyGrad", grads[1].second);
}

// Two functions copied separately are unlinked until a later copy names one
// as the gradient of the other.
TEST_F(CApiFunctionTest, AddFunctionsThenMakeOneGradientOfAnother) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;

  // Define the functions
  TF_Function* func = nullptr;
  TF_Function* grad_func = nullptr;
  DefineFunction("FooFunc", &func);
  DefineFunction("MyGrad", &grad_func);
  // Released at scope exit, including when an ASSERT_* below ends the test.
  FunctionPtr func_owner(func, TF_DeleteFunction);
  FunctionPtr grad_func_owner(grad_func, TF_DeleteFunction);

  // Add functions individually
  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, grad_func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Check that functions are added but not linked
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<string> func_names = GetFuncNames(gdef);
  ASSERT_EQ(2, func_names.size());
  ASSERT_EQ("FooFunc", func_names[0]);
  ASSERT_EQ("MyGrad", func_names[1]);
  ASSERT_EQ(0, GetGradDefs(gdef).size());

  // Make grad_func a gradient of func
  TF_GraphCopyFunction(host_graph_, func, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that function and its grad are linked
  gdef.Clear();
  GetGraphDef(host_graph_, &gdef);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(1, grads.size());
  ASSERT_EQ("FooFunc", grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);
}

// TF_GraphCopyFunction rejects a null `func` and refuses to replace an
// already-registered gradient.
TEST_F(CApiFunctionTest, GradientErrorCases) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;

  // Define the function
  DefineFunction(func_name_, &func_);
  TF_Function* grad_func1 = nullptr;
  TF_Function* grad_func2 = nullptr;
  DefineFunction("MyGrad1", &grad_func1);
  DefineFunction("MyGrad2", &grad_func2);
  // Released at scope exit, including when an ASSERT_* below ends the test.
  FunctionPtr grad_func1_owner(grad_func1, TF_DeleteFunction);
  FunctionPtr grad_func2_owner(grad_func2, TF_DeleteFunction);

  // func cannot be null
  TF_GraphCopyFunction(host_graph_, nullptr, func_, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("'func' argument to TF_GraphCopyFunction cannot be null"),
            string(TF_Message(s_)));

  // Cannot change gradient
  TF_GraphCopyFunction(host_graph_, func_, grad_func1, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, grad_func2, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot assign gradient function 'MyGrad2' to 'MyFunc' "
                   "because it already has gradient function 'MyGrad1'"),
            string(TF_Message(s_)));
}

// A function with named outputs survives a round trip through its serialized
// FunctionDef (Reincarnate) and still runs and verifies correctly.
TEST_F(CApiFunctionTest, ImportFunctionDef) {
  /*
   * Using a fairly complex function with output names
   *
   *                  |  |  |
   *                  v  v  /
   *                  add  /
   *                   |  |
   *            +------+  |
   *            |      |  |
   *            |      v  v
   *            |      add
   *            |       |
   *            v       v
   *    internal_out  final_out
   */
  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
  TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
  Define(-1, {}, {feed1, feed2, feed3}, {add1, add2},
         {"internal_out", "final_out"});

  // Save func_ to FunctionDef and import it back
  Reincarnate();

  // Use, run, and verify
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, ten, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15});
  VerifyFDef({"add1", "add2"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"internal_out"}, {"final_out"}}),
             {{"feed1", "add1:0"},
              {"feed2", "add1:1"},
              {"add1:sum:0", "add2:0"},
              {"feed3", "add2:1"},
              {"add1:sum:0", "internal_out"},
              {"add2:sum:0", "final_out"}},
             {});
}

TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
  // Four zero bytes are not valid protobuf data (protos cannot start with
  // 4 bytes of zeros), so the import must be rejected.
  const char kZeroBytes[4] = {0x0, 0x0, 0x0, 0x0};
  func_ = TF_FunctionImportFunctionDef(kZeroBytes, sizeof(kZeroBytes), s_);
  EXPECT_EQ(func_, nullptr);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_message(
      "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
  EXPECT_EQ(expected_message, string(TF_Message(s_)));
}
1473 
TEST_F(CApiFunctionTest, Attribute) {
  DefineFunction(func_name_, &func_);

  // Reading an attribute that was never set must fail with INVALID_ARGUMENT
  // and a descriptive message.
  TF_Buffer* missing_buf = TF_NewBuffer();
  TF_FunctionGetAttrValueProto(func_, "foo_attr", missing_buf, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
            string(TF_Message(s_)));
  TF_DeleteBuffer(missing_buf);

  // Attach a string-valued attribute via its serialized AttrValue proto.
  tensorflow::AttrValue expected;
  expected.set_s("test_attr_value");
  string serialized;
  expected.SerializeToString(&serialized);
  TF_FunctionSetAttrValueProto(func_, "test_attr_name", serialized.data(),
                               serialized.size(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The attribute must read back identically from the live function...
  AttrValue before_reimport;
  GetAttr("test_attr_name", &before_reimport);
  ASSERT_EQ(expected.DebugString(), before_reimport.DebugString());

  // ...and after the function is saved to a FunctionDef and imported back.
  Reincarnate();
  AttrValue after_reimport;
  GetAttr("test_attr_name", &after_reimport);
  ASSERT_EQ(expected.DebugString(), after_reimport.DebugString());
}
1505 
TEST_F(CApiFunctionTest, Description) {
  // The description given at definition time is stored verbatim in the
  // FunctionDef's signature.
  DefineFunction(func_name_, &func_, "Return something");
  tensorflow::FunctionDef func_def;
  ASSERT_TRUE(GetFunctionDef(func_, &func_def));
  const string& stored_description = func_def.signature().description();
  ASSERT_EQ(string("Return something"), stored_description);
}
1512 
TEST_F(CApiFunctionTest, Name) {
  // With append_hash disabled, the requested name is used unchanged.
  DefineFunction("long_func_name", &func_, "Return something",
                 /*append_hash=*/false);
  tensorflow::FunctionDef func_def;
  ASSERT_TRUE(GetFunctionDef(func_, &func_def));
  const string& stored_name = func_def.signature().name();
  ASSERT_EQ(string("long_func_name"), stored_name);
}
1520 
TEST_F(CApiFunctionTest, AppendHash) {
  // With append_hash enabled, a hash-based suffix is appended to the base
  // name. The expected suffix depends on the host byte order.
  DefineFunction("func_name_base", &func_, "Return something",
                 /*append_hash=*/true);
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  // __BYTE_ORDER__ / __ORDER_BIG_ENDIAN__ are GCC/Clang predefined macros.
  // In #if, undefined identifiers evaluate to 0. On compilers that don't
  // define them (e.g. MSVC), a bare `__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__`
  // is therefore `0 == 0`, which wrongly selects the big-endian expectation.
  // Require both macros to be defined before comparing.
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
  ASSERT_EQ(string("func_name_base_ZpgUD4x8oqk"), fdef.signature().name());
#else
  ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
#endif
}
1532 
TEST_F(CApiFunctionTest, GetOpDef) {
  DefineFunction(func_name_, &func_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve the function's OpDef from the graph. The buffer is
  // owned by a unique_ptr so it is released even if an ASSERT below returns
  // early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity-check the returned OpDef. Check the parse result so a corrupt
  // buffer is reported directly, not as unrelated field mismatches.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 1);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_FALSE(op_def.is_stateful());
}
1554 
DefineStatefulFunction(const char * name,TF_Function ** func)1555 void DefineStatefulFunction(const char* name, TF_Function** func) {
1556   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1557       TF_NewGraph(), TF_DeleteGraph);
1558   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1559                                                            TF_DeleteStatus);
1560 
1561   TF_Tensor* tensor_shape = Int32Tensor({37, 1});
1562   TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape");
1563   TF_Operation* random =
1564       RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
1565 
1566   TF_Output inputs[] = {};
1567   TF_Output outputs[] = {{random, 0}};
1568   *func = TF_GraphToFunction(func_graph.get(), name,
1569                              /*append_hash_to_fn_name=*/false, -1,
1570                              /*opers=*/nullptr, 0, inputs, 1, outputs,
1571                              /*output_names=*/nullptr,
1572                              /*opts=*/nullptr, "", s.get());
1573   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1574   ASSERT_NE(*func, nullptr);
1575   TF_DeleteTensor(tensor_shape);
1576 }
1577 
TEST_F(CApiFunctionTest, StatefulOpDef) {
  DefineStatefulFunction(func_name_, &func_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve the function's OpDef from the graph. The buffer is
  // owned by a unique_ptr so it is released even if an ASSERT below returns
  // early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity-check the returned OpDef. The function has no inputs, one output,
  // and must be marked stateful. Check the parse result so a corrupt buffer
  // is reported directly.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 0);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_TRUE(op_def.is_stateful());
}
1599 
AssertEqual(TF_Function * f1,TF_Function * f2)1600 void AssertEqual(TF_Function* f1, TF_Function* f2) {
1601   string s1, s2;
1602   tensorflow::FunctionDef fdef1, fdef2;
1603   ASSERT_TRUE(GetFunctionDef(f1, &fdef1));
1604   ASSERT_TRUE(GetFunctionDef(f2, &fdef2));
1605   SerializeToStringDeterministic(fdef1, &s1);
1606   SerializeToStringDeterministic(fdef2, &s2);
1607   ASSERT_EQ(s1, s2);
1608 }
1609 
GetName(TF_Function * func)1610 string GetName(TF_Function* func) {
1611   tensorflow::FunctionDef fdef;
1612   GetFunctionDef(func, &fdef);
1613   return fdef.signature().name();
1614 }
1615 
TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  // Null-initialized so a slot that was never filled is never deleted as an
  // indeterminate pointer.
  TF_Function* funcs[2] = {nullptr, nullptr};

  // Get functions from an empty graph: none are registered, and a query with
  // no output capacity must succeed and report zero.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 0);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, nullptr, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Define a function and add it to host_graph_. Ownership is held by a
  // unique_ptr so it is released even if an ASSERT below returns early.
  TF_Function* raw_func0 = nullptr;
  DefineFunction("FooFunc0", &raw_func0);
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> func0(
      raw_func0, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func0.get(), nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get this function from host_graph_. Functions returned through `funcs`
  // are deleted by this test. The returned count is ASSERTed because
  // funcs[0] is used immediately afterwards.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 1);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 1, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(func0.get(), funcs[0]);
  TF_DeleteFunction(funcs[0]);
  funcs[0] = nullptr;
  ASSERT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(func0.get(), funcs[0]);
  TF_DeleteFunction(funcs[0]);
  funcs[0] = nullptr;

  // Define a second function.
  TF_Function* raw_func1 = nullptr;
  DefineFunction("FooFunc1", &raw_func1);
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> func1(
      raw_func1, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func1.get(), nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get both functions from host_graph_. Their order is not fixed, so pair
  // them up by name before comparing.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 2);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 2);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  if (GetName(funcs[0]) == GetName(func0.get())) {
    AssertEqual(func0.get(), funcs[0]);
    AssertEqual(func1.get(), funcs[1]);
  } else {
    AssertEqual(func0.get(), funcs[1]);
    AssertEqual(func1.get(), funcs[0]);
  }

  TF_DeleteFunction(funcs[0]);
  TF_DeleteFunction(funcs[1]);
}
1669 
// This test only works when the TF build includes the XLA compiler. One way to
// enable it is via the bazel build option "--define with_xla_support=true".
1672 //
1673 // FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
1674 // something like TENSORFLOW_CAPI_USE_XLA.
1675 #ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
  // Both branches of the StatelessIf node use the same function. The checks
  // below expect feeding 17 to yield -17.
  const std::string funcName = "BranchFunc";
  TF_Function* raw_func = nullptr;
  DefineFunction(funcName.c_str(), &raw_func);
  // Owned by a unique_ptr so the function is released even if an ASSERT
  // below returns early. It is declared before `csession`, so it is
  // destroyed after the session, as in the explicit cleanup order.
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> func(
      raw_func, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func.get(), nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* feed = Placeholder(host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_OperationDescription* desc =
      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
  TF_AddInput(desc, {true_cond, 0});
  TF_Output inputs[] = {{feed, 0}};
  TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_SetAttrType(desc, "Tcond", TF_BOOL);
  TF_DataType inputType = TF_INT32;
  TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
  TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
  TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
  TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
  TF_SetDevice(desc, "/device:XLA_CPU:0");
  auto op = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(op, nullptr);

  // Create a session for this graph.
  CSession csession(host_graph_, s_, /*use_XLA*/ true);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Run the graph.
  csession.SetInputs({{feed, Int32Tensor(17)}});
  csession.SetOutputs({op});
  csession.Run(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_NE(out, nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(out));
  EXPECT_EQ(0, TF_NumDims(out));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
  EXPECT_EQ(-17, *output_contents);

  // Clean up. `func` is released by its unique_ptr at scope exit.
  csession.CloseAndDelete(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
1729 #endif  // TENSORFLOW_EAGER_USE_XLA
1730 
1731 }  // namespace
1732 }  // namespace tensorflow
1733