1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17 #include "tensorflow/c/c_test_util.h"
18 #include "tensorflow/core/platform/logging.h"
19 #include "tensorflow/core/platform/strcat.h"
20 #include "tensorflow/core/platform/test.h"
21
22 using tensorflow::GraphDef;
23
24 namespace {
25
26 class CApiWhileLoopTest : public ::testing::Test {
27 protected:
CApiWhileLoopTest()28 CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
29
~CApiWhileLoopTest()30 ~CApiWhileLoopTest() override {
31 TF_DeleteGraph(graph_);
32 TF_DeleteStatus(s_);
33 }
34
Init(int ninputs)35 void Init(int ninputs) {
36 DCHECK(inputs_.empty());
37 DCHECK_GT(ninputs, 0);
38
39 for (int i = 0; i < ninputs; ++i) {
40 TF_Operation* placeholder = Placeholder(
41 graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
42 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
43 inputs_.push_back({placeholder, 0});
44 }
45
46 original_graph_description_ = GraphDebugString();
47
48 params_.reset(new TF_WhileParams(
49 TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
50 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
51 ASSERT_EQ(original_graph_description_, GraphDebugString())
52 << "TF_NewWhile() altered graph";
53
54 params_->name = "test_loop";
55
56 // Initialize outputs_ so we can easily detect errors/bugs
57 outputs_.resize(ninputs, {nullptr, -1});
58 }
59
ExpectOK()60 void ExpectOK() {
61 TF_FinishWhile(params_.get(), s_, &outputs_[0]);
62 EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
63 }
64
ExpectError(TF_Code expected_code,const string & expected_msg)65 void ExpectError(TF_Code expected_code, const string& expected_msg) {
66 TF_FinishWhile(params_.get(), s_, &outputs_[0]);
67 EXPECT_EQ(expected_code, TF_GetCode(s_));
68 EXPECT_EQ(expected_msg, TF_Message(s_));
69 // TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
70 // ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
71 // "TF_FinishWhile() altered graph on error";
72 }
73
Run(std::initializer_list<int> input_values)74 void Run(std::initializer_list<int> input_values) {
75 Run(outputs_, input_values);
76 }
77
Run(const std::vector<TF_Output> & run_outputs,std::initializer_list<int> input_values)78 void Run(const std::vector<TF_Output>& run_outputs,
79 std::initializer_list<int> input_values) {
80 DCHECK_EQ(inputs_.size(), input_values.size());
81 std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
82 int i = 0;
83 for (int v : input_values) {
84 inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
85 ++i;
86 }
87 // TODO(skyewm): use std::make_unique or absl::make_unique when possible.
88 csession_.reset(new CSession(graph_, s_));
89 csession_->SetInputs(inputs);
90 csession_->SetOutputs(run_outputs);
91 csession_->Run(s_);
92 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
93 }
94
ExpectOutputValue(int idx,int expected_value)95 void ExpectOutputValue(int idx, int expected_value) {
96 TF_Tensor* out = csession_->output_tensor(idx);
97 ASSERT_TRUE(out != nullptr);
98 EXPECT_EQ(TF_INT32, TF_TensorType(out));
99 EXPECT_EQ(0, TF_NumDims(out));
100 ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
101 int32_t* data = static_cast<int32_t*>(TF_TensorData(out));
102 EXPECT_EQ(expected_value, *data);
103 }
104
105 // Create a valid conditional graph. Useful for testing unrelated errors.
CreateCondGraph()106 void CreateCondGraph() {
107 TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
108 TF_Operation* less_than =
109 LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
110 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
111 params_->cond_output = {less_than, 0};
112 }
113
GraphDebugString() const114 string GraphDebugString() const {
115 TF_Buffer* buf = TF_NewBuffer();
116 TF_GraphToGraphDef(graph_, buf, s_);
117 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
118 GraphDef def;
119 bool success = def.ParseFromArray(buf->data, buf->length);
120 DCHECK(success);
121 TF_DeleteBuffer(buf);
122 return def.DebugString();
123 }
124
125 TF_Status* s_;
126 TF_Graph* graph_;
127 std::vector<TF_Output> inputs_; // The inputs to the while loop
128 std::vector<TF_Output> outputs_; // The final outputs of the while loop
129 std::unique_ptr<TF_WhileParams> params_;
130 std::unique_ptr<CSession> csession_;
131
132 private:
133 // Used to verify that errors don't change graph_
134 string original_graph_description_;
135 };
136
TEST_F(CApiWhileLoopTest, BasicLoop) {
  ASSERT_NO_FATAL_FAILURE(Init(2));

  // Validate TF_WhileParams returned by TF_NewWhile(). Everything that is
  // dereferenced or handed to an op builder below is ASSERTed, so a bad value
  // fails the test instead of crashing it.
  ASSERT_TRUE(params_->body_graph != nullptr);
  ASSERT_TRUE(params_->cond_graph != nullptr);

  EXPECT_EQ(2, params_->ninputs);

  ASSERT_TRUE(params_->cond_inputs != nullptr);
  ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
  ASSERT_TRUE(params_->cond_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_inputs != nullptr);
  ASSERT_TRUE(params_->body_inputs[0].oper != nullptr);
  ASSERT_TRUE(params_->body_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_outputs != nullptr);

  // Create loop: while (input1 < input2) input1 += input2 + 1
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
               params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
                           params_->body_graph, s_, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {add2, 0};
  params_->body_outputs[1] = params_->body_inputs[1];

  // Finalize while loop
  ExpectOK();

  // Validate while loop outputs returned by TF_FinishWhile(). They are fetched
  // by Run() below, so an unset output must stop the test here.
  ASSERT_TRUE(outputs_[0].oper != nullptr);
  EXPECT_GE(outputs_[0].index, 0);
  ASSERT_TRUE(outputs_[1].oper != nullptr);
  EXPECT_GE(outputs_[1].index, 0);

  // Check that cond and body inputs are not present
  for (int i = 0; i < params_->ninputs; ++i) {
    string cond_name =
        ::tensorflow::strings::StrCat(params_->name, "/cond/cond_input", i);
    string body_name =
        ::tensorflow::strings::StrCat(params_->name, "/body/body_input", i);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_name.c_str()) == nullptr);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, body_name.c_str()) == nullptr);
  }

  // Run the graph: -9 -> -6 -> -3 -> 0 -> 3, which stops because 3 >= 2.
  Run({-9, 2});
  ExpectOutputValue(0, 3);
  ExpectOutputValue(1, 2);
}
197
TEST_F(CApiWhileLoopTest, NestedLoop) {
  ASSERT_NO_FATAL_FAILURE(Init(2));
  // Create nested loop:
  //  while (input1 < 6) {
  //    inner_input1 = input1
  //    while (inner_input1 < 3) {
  //      input2 += 1
  //      inner_input1 += 2
  //    }
  //    input1 += input2
  //  }
  //
  // Expected execution with initial values input1 = input2 = 0:
  //
  // outer  inner                   inner_
  // step#  step#  input1  input2  input1
  // ------------------------------------
  //   0      0      0       0       0
  //   0      1      0       1       2
  //   0      2      0       2       4
  //   0      -      2       2       -
  //   1      0      2       2       2
  //   1      1      2       3       4
  //   1      -      5       3       -
  //   2      0      5       3       5
  //   2      -      8       3       -

  // Create outer cond graph
  TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  // Create outer body graph
  // Init inner graph
  TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
  TF_WhileParams inner_params =
      TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.name = "inner_loop";

  // Create inner cond graph
  TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* inner_less_than = LessThan(
      inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.cond_output = {inner_less_than, 0};

  // Create inner body graph
  TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* input2_add =
      Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[1] = {input2_add, 0};

  TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
                                       inner_params.body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[0] = {inner_input1_add, 0};

  // Finalize inner graph. Give every element the invalid sentinel explicitly:
  // `{{nullptr, -1}}` would only set the first one and zero-fill the second
  // to {nullptr, 0}, which looks like a valid output index.
  TF_Output inner_outputs[2] = {{nullptr, -1}, {nullptr, -1}};
  TF_FinishWhile(&inner_params, s_, inner_outputs);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  // inner_outputs[1] feeds the outer body below, so it must have been set.
  ASSERT_TRUE(inner_outputs[1].oper != nullptr);
  ASSERT_GE(inner_outputs[1].index, 0);

  TF_Operation* input1_add =
      Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {input1_add, 0};

  params_->body_outputs[1] = inner_outputs[1];

  // Finalize outer graph
  ExpectOK();

  // Check for a few expected nodes
  const char* node_name = "test_loop/cond/scalar";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/add";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/body/one";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/cond/less_than";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);

  // Run the graph
  Run({0, 0});
  ExpectOutputValue(0, 8);
  ExpectOutputValue(1, 3);
}
295
// Leaves `cond_output` unassigned while giving the loop a trivial identity
// body; TF_FinishWhile() must reject the incomplete params.
TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
  Init(1);
  TF_WhileParams& loop = *params_;
  loop.body_outputs[0] = loop.body_inputs[0];  // identity body
  ExpectError(TF_INVALID_ARGUMENT,
              "TF_WhileParams `cond_output` field isn't set");
}
302
// Uses the loop variable itself as the predicate. The loop variable is not a
// boolean, so building the loop must fail.
TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
  Init(1);
  TF_WhileParams& loop = *params_;
  loop.cond_output = loop.cond_inputs[0];      // non-boolean predicate
  loop.body_outputs[0] = loop.body_inputs[0];  // identity body
  ExpectError(TF_INVALID_ARGUMENT,
              "BuildWhileLoop: 'cond' argument must return a boolean output, "
              "got int32");
}
311
// Points `cond_output` at a placeholder that lives in the parent graph rather
// than in cond_graph.
TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
  Init(1);
  TF_WhileParams& loop = *params_;
  const TF_Output parent_graph_output = inputs_[0];
  loop.cond_output = parent_graph_output;
  loop.body_outputs[0] = loop.body_inputs[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}
322
// Builds a valid condition, then corrupts the output index of its predicate
// node so that it is out of range.
TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
  Init(1);
  CreateCondGraph();
  TF_WhileParams& loop = *params_;
  loop.cond_output.index = 100;  // the predicate node has only one output
  loop.body_outputs[0] = loop.body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "Invalid return output 100 of node 'less_than', which has 1 "
              "output(s)");
}
332
333 // TODO(skyewm): test bad cond output shape
334
// Provides a valid condition but never assigns body_outputs[0].
TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
  constexpr char kExpectedMessage[] =
      "TF_WhileParams `body_outputs[0]` field isn't set";
  Init(1);
  CreateCondGraph();
  ExpectError(TF_INVALID_ARGUMENT, kExpectedMessage);
}
341
342 // TODO(skyewm): enable this when it works (currently doesn't error)
343 // TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
344 // Init(1);
345 // CreateCondGraph();
346 // TF_Operation* double_scalar =
347 // ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
348 // params_->body_outputs[0] = {double_scalar, 0};
349 // ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
350 // }
351
// Points body_outputs[0] at a placeholder that lives in the parent graph
// rather than in body_graph.
TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
  Init(1);
  CreateCondGraph();
  const TF_Output parent_graph_output = inputs_[0];
  params_->body_outputs[0] = parent_graph_output;
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  constexpr char kExpectedMessage[] =
      "Requested return tensor 'p0:0' not found in graph def";
  ExpectError(TF_INVALID_ARGUMENT, kExpectedMessage);
}
362
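// TODO(skyewm): enable this when it works (currently segfaults!)
// NOTE: the expected message below appears to have been copied from
// InvalidCondOutputIndex. It names the condition node 'less_than', but the
// bad output here comes from body_inputs[0]. It will likely need updating
// when this test is enabled.
// TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
//   Init(1);
//   CreateCondGraph();
//   params_->body_outputs[0] = params_->body_inputs[0];
//   params_->body_outputs[0].index = 100;
//   ExpectError(TF_INVALID_ARGUMENT,
//               "Invalid return output 100 of node 'less_than', which has 1 "
//               "output(s)");
// }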
373
374 // TODO(skyewm): test bad body output shape
375
// Sets up an otherwise valid loop whose `name` is null.
TEST_F(CApiWhileLoopTest, NullName) {
  Init(1);
  CreateCondGraph();
  TF_WhileParams& loop = *params_;
  loop.body_outputs[0] = loop.body_inputs[0];  // identity body
  loop.name = nullptr;
  ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
}
383
// Uses an output of the outer graph as the loop body's result. As currently
// written this performs the same steps, and expects the same error, as
// InvalidBodyOutputNode.
TEST_F(CApiWhileLoopTest, WrongGraph) {
  Init(1);
  CreateCondGraph();
  const TF_Output outer_graph_output = inputs_[0];
  params_->body_outputs[0] = outer_graph_output;
  // TODO(skyewm): improve error message
  constexpr char kExpectedMessage[] =
      "Requested return tensor 'p0:0' not found in graph def";
  ExpectError(TF_INVALID_ARGUMENT, kExpectedMessage);
}
393
// Adding an op to the body graph whose input type does not match the loop
// variable's type must fail when the op is finished. The loop is then
// abandoned with TF_AbortWhile().
TEST_F(CApiWhileLoopTest, BadTypes) {
  ASSERT_NO_FATAL_FAILURE(Init(1));
  ASSERT_NO_FATAL_FAILURE(CreateCondGraph());
  // Op that has a float input + output
  TF_OperationDescription* desc = TF_NewOperation(
      params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
  TF_AddInput(desc, params_->body_inputs[0]);
  TF_FinishOperation(desc, s_);

  // Snapshot the result and release the loop's resources *before* asserting.
  // Otherwise a failing ASSERT would return early and skip TF_AbortWhile().
  const TF_Code code = TF_GetCode(s_);
  const string msg(TF_Message(s_));
  TF_AbortWhile(params_.get());

  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << msg;
  EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while "
                     "building NodeDef 'float_op'"),
            string::npos)
      << msg;
}
409
410 // This is a basic test to make sure the C++ gradient code can handle while
411 // loops created by the C API (which calls the C++ API under the hood). There
412 // are more while loop gradient tests in cc/framework/while_gradients_test.cc.
TEST_F(CApiWhileLoopTest,Gradients)413 TEST_F(CApiWhileLoopTest, Gradients) {
414 Init(1);
415
416 // Create loop: while (i < 10) i += 1
417 TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
418 TF_Operation* less_than =
419 LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
420 DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
421 params_->cond_output = {less_than, 0};
422
423 TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
424 TF_Operation* add =
425 Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
426 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
427 params_->body_outputs[0] = {add, 0};
428
429 ExpectOK();
430
431 // Create backprop graph
432 TF_Output grad_output;
433 TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
434 nullptr, s_, &grad_output);
435 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
436
437 // Run gradient
438 Run({grad_output}, {0});
439 ExpectOutputValue(0, 1);
440 }
441
442 } // namespace
443