• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 #include "tensorflow/c/c_test_util.h"
18 #include "tensorflow/core/platform/logging.h"
19 #include "tensorflow/core/platform/strcat.h"
20 #include "tensorflow/core/platform/test.h"
21 
22 using tensorflow::GraphDef;
23 
24 namespace {
25 
26 class CApiWhileLoopTest : public ::testing::Test {
27  protected:
CApiWhileLoopTest()28   CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
29 
~CApiWhileLoopTest()30   ~CApiWhileLoopTest() override {
31     TF_DeleteGraph(graph_);
32     TF_DeleteStatus(s_);
33   }
34 
Init(int ninputs)35   void Init(int ninputs) {
36     DCHECK(inputs_.empty());
37     DCHECK_GT(ninputs, 0);
38 
39     for (int i = 0; i < ninputs; ++i) {
40       TF_Operation* placeholder = Placeholder(
41           graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
42       DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
43       inputs_.push_back({placeholder, 0});
44     }
45 
46     original_graph_description_ = GraphDebugString();
47 
48     params_.reset(new TF_WhileParams(
49         TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
50     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
51     ASSERT_EQ(original_graph_description_, GraphDebugString())
52         << "TF_NewWhile() altered graph";
53 
54     params_->name = "test_loop";
55 
56     // Initialize outputs_ so we can easily detect errors/bugs
57     outputs_.resize(ninputs, {nullptr, -1});
58   }
59 
ExpectOK()60   void ExpectOK() {
61     TF_FinishWhile(params_.get(), s_, &outputs_[0]);
62     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
63   }
64 
ExpectError(TF_Code expected_code,const string & expected_msg)65   void ExpectError(TF_Code expected_code, const string& expected_msg) {
66     TF_FinishWhile(params_.get(), s_, &outputs_[0]);
67     EXPECT_EQ(expected_code, TF_GetCode(s_));
68     EXPECT_EQ(expected_msg, TF_Message(s_));
69     // TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
70     // ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
71     //     "TF_FinishWhile() altered graph on error";
72   }
73 
Run(std::initializer_list<int> input_values)74   void Run(std::initializer_list<int> input_values) {
75     Run(outputs_, input_values);
76   }
77 
Run(const std::vector<TF_Output> & run_outputs,std::initializer_list<int> input_values)78   void Run(const std::vector<TF_Output>& run_outputs,
79            std::initializer_list<int> input_values) {
80     DCHECK_EQ(inputs_.size(), input_values.size());
81     std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
82     int i = 0;
83     for (int v : input_values) {
84       inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
85       ++i;
86     }
87     // TODO(skyewm): use std::make_unique or absl::make_unique when possible.
88     csession_.reset(new CSession(graph_, s_));
89     csession_->SetInputs(inputs);
90     csession_->SetOutputs(run_outputs);
91     csession_->Run(s_);
92     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
93   }
94 
ExpectOutputValue(int idx,int expected_value)95   void ExpectOutputValue(int idx, int expected_value) {
96     TF_Tensor* out = csession_->output_tensor(idx);
97     ASSERT_TRUE(out != nullptr);
98     EXPECT_EQ(TF_INT32, TF_TensorType(out));
99     EXPECT_EQ(0, TF_NumDims(out));
100     ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
101     int32_t* data = static_cast<int32_t*>(TF_TensorData(out));
102     EXPECT_EQ(expected_value, *data);
103   }
104 
105   // Create a valid conditional graph. Useful for testing unrelated errors.
CreateCondGraph()106   void CreateCondGraph() {
107     TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
108     TF_Operation* less_than =
109         LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
110     DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
111     params_->cond_output = {less_than, 0};
112   }
113 
GraphDebugString() const114   string GraphDebugString() const {
115     TF_Buffer* buf = TF_NewBuffer();
116     TF_GraphToGraphDef(graph_, buf, s_);
117     DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
118     GraphDef def;
119     bool success = def.ParseFromArray(buf->data, buf->length);
120     DCHECK(success);
121     TF_DeleteBuffer(buf);
122     return def.DebugString();
123   }
124 
125   TF_Status* s_;
126   TF_Graph* graph_;
127   std::vector<TF_Output> inputs_;   // The inputs to the while loop
128   std::vector<TF_Output> outputs_;  // The final outputs of the while loop
129   std::unique_ptr<TF_WhileParams> params_;
130   std::unique_ptr<CSession> csession_;
131 
132  private:
133   // Used to verify that errors don't change graph_
134   string original_graph_description_;
135 };
136 
TEST_F(CApiWhileLoopTest,BasicLoop)137 TEST_F(CApiWhileLoopTest, BasicLoop) {
138   Init(2);
139 
140   // Validate TF_WhileParams returned by TF_NewWhile()
141   EXPECT_TRUE(params_->body_graph != nullptr);
142   EXPECT_TRUE(params_->cond_graph != nullptr);
143 
144   EXPECT_EQ(params_->ninputs, 2);
145 
146   ASSERT_TRUE(params_->cond_inputs != nullptr);
147   ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
148   EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr);
149 
150   ASSERT_TRUE(params_->body_inputs != nullptr);
151   EXPECT_TRUE(params_->body_inputs[0].oper != nullptr);
152   EXPECT_TRUE(params_->body_inputs[1].oper != nullptr);
153 
154   ASSERT_TRUE(params_->body_outputs != nullptr);
155 
156   // Create loop: while (input1 < input2) input1 += input2 + 1
157   TF_Operation* less_than =
158       LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
159                params_->cond_graph, s_);
160   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
161   params_->cond_output = {less_than, 0};
162 
163   TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
164                            params_->body_graph, s_, "add1");
165   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
166   TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
167   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
168   TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
169   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
170   params_->body_outputs[0] = {add2, 0};
171   params_->body_outputs[1] = params_->body_inputs[1];
172 
173   // Finalize while loop
174   ExpectOK();
175 
176   // Validate while loop outputs returned by TF_FinishWhile()
177   EXPECT_TRUE(outputs_[0].oper != nullptr);
178   EXPECT_GE(outputs_[0].index, 0);
179   EXPECT_TRUE(outputs_[1].oper != nullptr);
180   EXPECT_GE(outputs_[1].index, 0);
181 
182   // Check that cond and body inputs are not present
183   for (int i = 0; i < params_->ninputs; ++i) {
184     string cond_name =
185         ::tensorflow::strings::StrCat(params_->name, "/cond/cond_input", i);
186     string body_name =
187         ::tensorflow::strings::StrCat(params_->name, "/body/body_input", i);
188     EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_name.c_str()) == nullptr);
189     EXPECT_TRUE(TF_GraphOperationByName(graph_, body_name.c_str()) == nullptr);
190   }
191 
192   // Run the graph
193   Run({-9, 2});
194   ExpectOutputValue(0, 3);
195   ExpectOutputValue(1, 2);
196 }
197 
TEST_F(CApiWhileLoopTest,NestedLoop)198 TEST_F(CApiWhileLoopTest, NestedLoop) {
199   Init(2);
200   // Create nested loop:
201   //  while (input1 < 6) {
202   //    inner_input1 = input1
203   //    while (inner_input1 < 3) {
204   //      input2 += 1
205   //      inner_input1 += 2
206   //    }
207   //    input1 += input2
208   //  }
209   //
210   // Expected execution with initial values input1 = input2 = 0:
211   //
212   // outer inner               inner_
213   // step# step# input1 input2 input1
214   // ------------------------------------
215   //   0     0     0      0      0
216   //   0     1     0      1      2
217   //   0     2     0      2      4
218   //   0     -     2      2      -
219   //   1     0     2      2      2
220   //   1     1     2      3      4
221   //   1     -     5      3      -
222   //   2     0     5      3      5
223   //   2     -     8      3      -
224 
225   // Create outer cond graph
226   TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
227   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
228   TF_Operation* less_than =
229       LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
230   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
231   params_->cond_output = {less_than, 0};
232 
233   // Create outer body graph
234   // Init inner graph
235   TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
236   TF_WhileParams inner_params =
237       TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
238   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
239   inner_params.name = "inner_loop";
240 
241   // Create inner cond graph
242   TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
243   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
244   TF_Operation* inner_less_than = LessThan(
245       inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
246   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
247   inner_params.cond_output = {inner_less_than, 0};
248 
249   // Create inner body graph
250   TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
251   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
252   TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
253   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
254 
255   TF_Operation* input2_add =
256       Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
257   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
258   inner_params.body_outputs[1] = {input2_add, 0};
259 
260   TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
261                                        inner_params.body_graph, s_, "add2");
262   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
263   inner_params.body_outputs[0] = {inner_input1_add, 0};
264 
265   // Finalize inner graph
266   TF_Output inner_outputs[2] = {{nullptr, -1}};
267   TF_FinishWhile(&inner_params, s_, inner_outputs);
268   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
269 
270   TF_Operation* input1_add =
271       Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
272   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
273   params_->body_outputs[0] = {input1_add, 0};
274 
275   params_->body_outputs[1] = inner_outputs[1];
276 
277   // Finalize outer graph
278   ExpectOK();
279 
280   // Check for a few expected nodes
281   const char* node_name = "test_loop/cond/scalar";
282   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
283   node_name = "test_loop/body/add";
284   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
285   node_name = "test_loop/body/inner_loop/body/one";
286   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
287   node_name = "test_loop/body/inner_loop/cond/less_than";
288   EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
289 
290   // Run the graph
291   Run({0, 0});
292   ExpectOutputValue(0, 8);
293   ExpectOutputValue(1, 3);
294 }
295 
// TF_FinishWhile() must reject params whose cond_output was never assigned.
TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
  Init(1);
  // Wire the body as an identity so the only missing piece is cond_output.
  params_->body_outputs[0] = params_->body_inputs[0];
  const char* const expected_msg =
      "TF_WhileParams `cond_output` field isn't set";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
302 
// Returning the (int32) loop variable itself as the condition must be
// rejected, because the cond output has to be boolean.
TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
  Init(1);
  TF_WhileParams* const loop = params_.get();
  loop->cond_output = loop->cond_inputs[0];
  loop->body_outputs[0] = loop->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "BuildWhileLoop: 'cond' argument must return a boolean output, "
              "got int32");
}
311 
// Using the outer-graph placeholder "p0" as cond_output is invalid; the error
// reports that tensor 'p0:0' cannot be found.
TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
  Init(1);
  const TF_Output outer_placeholder = inputs_[0];
  params_->cond_output = outer_placeholder;
  params_->body_outputs[0] = params_->body_inputs[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
322 
// A cond_output that names a valid cond node but a nonexistent output index
// must be rejected.
TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
  Init(1);
  CreateCondGraph();
  params_->body_outputs[0] = params_->body_inputs[0];
  params_->cond_output.index = 100;  // 'less_than' has a single output.
  ExpectError(TF_INVALID_ARGUMENT,
              "Invalid return output 100 of node 'less_than', which has 1 "
              "output(s)");
}
332 
333 // TODO(skyewm): test bad cond output shape
334 
// TF_FinishWhile() must reject params whose body_outputs were never assigned.
TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
  Init(1);
  CreateCondGraph();
  // body_outputs[0] is deliberately left unassigned.
  const string expected = "TF_WhileParams `body_outputs[0]` field isn't set";
  ExpectError(TF_INVALID_ARGUMENT, expected);
}
341 
342 // TODO(skyewm): enable this when it works (currently doesn't error)
343 // TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
344 //   Init(1);
345 //   CreateCondGraph();
346 //   TF_Operation* double_scalar =
347 //       ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
348 //   params_->body_outputs[0] = {double_scalar, 0};
349 //   ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
350 // }
351 
// Pointing a body output at the outer-graph placeholder "p0" instead of a node
// in the body subgraph must be rejected.
TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
  Init(1);
  CreateCondGraph();
  TF_Output& body_out = params_->body_outputs[0];
  body_out = inputs_[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
362 
363 // TODO(skyewm): enable this when it works (currently segfaults!)
364 // TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
365 //   Init(1);
366 //   CreateCondGraph();
367 //   params_->body_outputs[0] = params_->body_inputs[0];
368 //   params_->body_outputs[0].index = 100;
369 //   ExpectError(TF_INVALID_ARGUMENT,
370 //               "Invalid return output 100 of node 'less_than', which has 1 "
371 //               "output(s)");
372 // }
373 
374 // TODO(skyewm): test bad body output shape
375 
// An otherwise valid loop with a null `name` must be rejected.
TEST_F(CApiWhileLoopTest, NullName) {
  Init(1);
  CreateCondGraph();
  TF_WhileParams& loop = *params_;
  loop.body_outputs[0] = loop.body_inputs[0];
  loop.name = nullptr;
  ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
}
383 
// A body output taken from the outer graph rather than the body subgraph must
// be rejected.
TEST_F(CApiWhileLoopTest, WrongGraph) {
  Init(1);
  CreateCondGraph();
  const TF_Output from_outer_graph = inputs_[0];
  params_->body_outputs[0] = from_outer_graph;
  // TODO(skyewm): improve error message
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
393 
// Adding an op to the body subgraph whose input type does not match the int32
// body input fails in TF_FinishOperation(). The test then abandons the loop
// with TF_AbortWhile().
TEST_F(CApiWhileLoopTest, BadTypes) {
  Init(1);
  CreateCondGraph();
  // FakeQuantWithMinMaxArgs has a float input and output.
  TF_OperationDescription* float_op_desc = TF_NewOperation(
      params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
  TF_AddInput(float_op_desc, params_->body_inputs[0]);
  TF_FinishOperation(float_op_desc, s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string error(TF_Message(s_));
  const string expected_fragment =
      "Input 'inputs' passed int32 expected float while building NodeDef "
      "'float_op'";
  EXPECT_NE(error.find(expected_fragment), string::npos);
  TF_AbortWhile(params_.get());
}
409 
410 // This is a basic test to make sure the C++ gradient code can handle while
411 // loops created by the C API (which calls the C++ API under the hood). There
412 // are more while loop gradient tests in cc/framework/while_gradients_test.cc.
TEST_F(CApiWhileLoopTest, Gradients) {
  Init(1);

  // Create loop: while (i < 10) i += 1
  // Every op creation is followed by an ASSERT. The original DCHECK is
  // compiled out in opt builds, so a failure would have gone unnoticed and a
  // null op would have been passed to the next builder call.
  TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* add =
      Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {add, 0};

  ExpectOK();

  // Create backprop graph: gradient of the loop output with respect to the
  // single loop input (dx = nullptr). Start from a sentinel so an unset
  // result is not read as uninitialized memory.
  TF_Output grad_output = {nullptr, -1};
  TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
                  nullptr, s_, &grad_output);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Run gradient
  Run({grad_output}, {0});
  ExpectOutputValue(0, 1);
}
441 
442 }  // namespace
443