1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/while_loop.h"
17 #include "tensorflow/cc/client/client_session.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/tensor_testutil.h"
20 #include "tensorflow/core/graph/while_context.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace tensorflow {
25
26 namespace {
27
28 class WhileLoopTest : public ::testing::Test {
29 protected:
WhileLoopTest()30 WhileLoopTest() : scope_(Scope::NewRootScope()) {}
31
Init(int num_inputs,DataType dtype=DT_INT32)32 void Init(int num_inputs, DataType dtype = DT_INT32) {
33 for (int i = 0; i < num_inputs; ++i) {
34 inputs_.push_back(ops::Placeholder(scope_, dtype));
35 }
36 }
37
CreateLoop(const ops::CondGraphBuilderFn & cond,const ops::BodyGraphBuilderFn & body,error::Code error_code=error::OK,const string & error_msg="")38 void CreateLoop(const ops::CondGraphBuilderFn& cond,
39 const ops::BodyGraphBuilderFn& body,
40 error::Code error_code = error::OK,
41 const string& error_msg = "") {
42 Status s =
43 ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_);
44 EXPECT_EQ(s.code(), error_code);
45 EXPECT_EQ(s.error_message(), error_msg);
46 }
47
48 template <typename T>
Run(const std::vector<Input::Initializer> & input_values,const std::vector<T> & expected_output_values)49 void Run(const std::vector<Input::Initializer>& input_values,
50 const std::vector<T>& expected_output_values) {
51 ClientSession session(scope_);
52
53 DCHECK_EQ(input_values.size(), inputs_.size());
54 ClientSession::FeedType feeds;
55 for (int i = 0; i < inputs_.size(); ++i) {
56 feeds.emplace(inputs_[i], input_values[i]);
57 }
58
59 std::vector<Tensor> out_tensors;
60 TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors));
61 ASSERT_EQ(out_tensors.size(), outputs_.size());
62
63 DCHECK_EQ(expected_output_values.size(), out_tensors.size());
64 for (int i = 0; i < out_tensors.size(); ++i) {
65 test::ExpectTensorEqual<T>(
66 out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {}));
67 }
68 }
69
70 Scope scope_;
71 std::vector<Output> inputs_;
72 std::vector<Output> outputs_;
73
74 static const char* const kFrameName;
75 };
76
77 const char* const WhileLoopTest::kFrameName = "test_loop";
78
LessThanTenCond(const Scope & s,const std::vector<Output> & inputs,Output * output)79 Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
80 Output* output) {
81 *output = ops::Less(s, inputs[0], 10);
82 return s.status();
83 }
84
AddOneBody(const Scope & s,const std::vector<Output> & inputs,std::vector<Output> * outputs)85 Status AddOneBody(const Scope& s, const std::vector<Output>& inputs,
86 std::vector<Output>* outputs) {
87 outputs->push_back(ops::Add(s, inputs[0], 1));
88 return s.status();
89 }
90
TEST_F(WhileLoopTest, Basic) {
  // Create loop: while (i < 10) i += 1
  Init(1);
  CreateLoop(LessThanTenCond, AddOneBody);

  // Verify some output invariants: every output is an Exit node with a
  // non-null WhileContext, the first one's frame is kFrameName, and all
  // outputs share that same context.
  // Without outputs the loop below would check nothing.
  ASSERT_FALSE(outputs_.empty());
  WhileContext* while_ctx = nullptr;
  for (size_t i = 0; i < outputs_.size(); ++i) {
    Node* node = outputs_[i].node();
    ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n"
                                << node->DebugString();
    ASSERT_TRUE(node->while_ctx() != nullptr) << i;
    if (i == 0) {
      while_ctx = node->while_ctx();
      EXPECT_EQ(while_ctx->frame_name(), kFrameName);
    } else {
      EXPECT_EQ(node->while_ctx(), while_ctx) << i;
    }
  }

  // Run the loop and test we get the expected results: starting below 10
  // counts up to 10, starting at or above 10 leaves the value unchanged.
  Run<int>({1}, {10});
  Run<int>({11}, {11});
}
115
TEST_F(WhileLoopTest, WrongCondOutputType) {
  Init(1);
  // A condition whose result is a float placeholder rather than a boolean.
  auto float_cond = [](const Scope& s, const std::vector<Output>& /*inputs*/,
                       Output* output) {
    *output = ops::Placeholder(s, DT_FLOAT);
    return s.status();
  };
  const string expected_msg =
      "BuildWhileLoop: 'cond' argument must return a boolean output, got "
      "float";
  CreateLoop(float_cond, AddOneBody, error::INVALID_ARGUMENT, expected_msg);
}
127
128 // TODO(skyewm): test bad cond output shape
129
TEST_F(WhileLoopTest, NullCondOutputNode) {
  Init(1);
  // A condition that returns an Output wrapping a null node.
  auto null_node_cond = [](const Scope& s,
                           const std::vector<Output>& /*inputs*/,
                           Output* output) {
    *output = {nullptr, 0};
    return s.status();
  };
  // TODO(skyewm): improve error message
  CreateLoop(null_node_cond, AddOneBody, error::INVALID_ARGUMENT,
             "Node is null");
}
140
TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
  Init(1);
  // A condition that points at output index 100 of its Less node, which the
  // expected error reports as having only one output.
  auto out_of_range_cond = [](const Scope& s,
                              const std::vector<Output>& inputs,
                              Output* output) {
    auto less_than = ops::Less(s, inputs[0], 10);
    *output = {less_than.node(), 100};
    return s.status();
  };
  CreateLoop(out_of_range_cond, AddOneBody, error::OUT_OF_RANGE,
             "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not "
             "have output 100");
}
153
TEST_F(WhileLoopTest, UnsetCondOutput) {
  Init(1);
  // A condition that never writes to *output.
  auto noop_cond = [](const Scope& s, const std::vector<Output>& /*inputs*/,
                      Output* /*output*/) { return s.status(); };
  CreateLoop(noop_cond, AddOneBody, error::INVALID_ARGUMENT, "Node is null");
}
160
161 // TODO(skyewm): test bad body output type
162 // TODO(skyewm): test bad body output shape
163
TEST_F(WhileLoopTest, NullBodyOutputNode) {
  Init(1);
  // A body whose single output wraps a null node.
  auto null_node_body = [](const Scope& s,
                           const std::vector<Output>& /*inputs*/,
                           std::vector<Output>* outputs) {
    outputs->push_back({nullptr, 0});
    return s.status();
  };
  // TODO(skyewm): improve error message
  CreateLoop(LessThanTenCond, null_node_body, error::INVALID_ARGUMENT,
             "Node is null");
}
175
TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
  Init(1);
  // A body that points at output index 100 of its Add node, which the
  // expected error reports as having only one output.
  auto out_of_range_body = [](const Scope& s,
                              const std::vector<Output>& inputs,
                              std::vector<Output>* outputs) {
    auto sum = ops::Add(s, inputs[0], 1);
    outputs->emplace_back(sum.node(), 100);
    return s.status();
  };
  CreateLoop(LessThanTenCond, out_of_range_body, error::OUT_OF_RANGE,
             "Node 'body/Add' (type: 'Add', num of outputs: 1) does not "
             "have output 100");
}
189
TEST_F(WhileLoopTest, UnsetBodyOutputs) {
  Init(1);
  // A body that emits no outputs even though the loop has one variable.
  auto empty_body = [](const Scope& s, const std::vector<Output>& /*inputs*/,
                       std::vector<Output>* /*outputs*/) {
    return s.status();
  };
  CreateLoop(LessThanTenCond, empty_body, error::INVALID_ARGUMENT,
             "BuildWhileLoop: 'body' argument expected to return 1 "
             "output(s), got 0");
}
199
200 } // namespace
201 } // namespace tensorflow
202