• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/while_loop.h"
17 #include "tensorflow/cc/client/client_session.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/tensor_testutil.h"
20 #include "tensorflow/core/graph/while_context.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 
28 class WhileLoopTest : public ::testing::Test {
29  protected:
WhileLoopTest()30   WhileLoopTest() : scope_(Scope::NewRootScope()) {}
31 
Init(int num_inputs,DataType dtype=DT_INT32)32   void Init(int num_inputs, DataType dtype = DT_INT32) {
33     for (int i = 0; i < num_inputs; ++i) {
34       inputs_.push_back(ops::Placeholder(scope_, dtype));
35     }
36   }
37 
CreateLoop(const ops::CondGraphBuilderFn & cond,const ops::BodyGraphBuilderFn & body,error::Code error_code=error::OK,const string & error_msg="")38   void CreateLoop(const ops::CondGraphBuilderFn& cond,
39                   const ops::BodyGraphBuilderFn& body,
40                   error::Code error_code = error::OK,
41                   const string& error_msg = "") {
42     Status s =
43         ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_);
44     EXPECT_EQ(s.code(), error_code);
45     EXPECT_EQ(s.error_message(), error_msg);
46   }
47 
48   template <typename T>
Run(const std::vector<Input::Initializer> & input_values,const std::vector<T> & expected_output_values)49   void Run(const std::vector<Input::Initializer>& input_values,
50            const std::vector<T>& expected_output_values) {
51     ClientSession session(scope_);
52 
53     DCHECK_EQ(input_values.size(), inputs_.size());
54     ClientSession::FeedType feeds;
55     for (int i = 0; i < inputs_.size(); ++i) {
56       feeds.emplace(inputs_[i], input_values[i]);
57     }
58 
59     std::vector<Tensor> out_tensors;
60     TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors));
61     ASSERT_EQ(out_tensors.size(), outputs_.size());
62 
63     DCHECK_EQ(expected_output_values.size(), out_tensors.size());
64     for (int i = 0; i < out_tensors.size(); ++i) {
65       test::ExpectTensorEqual<T>(
66           out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {}));
67     }
68   }
69 
70   Scope scope_;
71   std::vector<Output> inputs_;
72   std::vector<Output> outputs_;
73 
74   static const char* const kFrameName;
75 };
76 
77 const char* const WhileLoopTest::kFrameName = "test_loop";
78 
LessThanTenCond(const Scope & s,const std::vector<Output> & inputs,Output * output)79 Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
80                        Output* output) {
81   *output = ops::Less(s, inputs[0], 10);
82   return s.status();
83 }
84 
AddOneBody(const Scope & s,const std::vector<Output> & inputs,std::vector<Output> * outputs)85 Status AddOneBody(const Scope& s, const std::vector<Output>& inputs,
86                   std::vector<Output>* outputs) {
87   outputs->push_back(ops::Add(s, inputs[0], 1));
88   return s.status();
89 }
90 
TEST_F(WhileLoopTest, Basic) {
  // Create loop: while (i < 10) i += 1
  Init(1);
  CreateLoop(LessThanTenCond, AddOneBody);

  // Verify some output invariants. Expect one output per loop variable;
  // without this check an empty `outputs_` would skip the loop below and the
  // invariants would pass vacuously.
  ASSERT_EQ(outputs_.size(), inputs_.size());
  WhileContext* while_ctx = nullptr;
  for (size_t i = 0; i < outputs_.size(); ++i) {
    Node* node = outputs_[i].node();
    ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n"
                                << node->DebugString();
    ASSERT_TRUE(node->while_ctx() != nullptr) << i;
    if (i == 0) {
      // All outputs must share the WhileContext of the first one.
      while_ctx = node->while_ctx();
      EXPECT_EQ(while_ctx->frame_name(), kFrameName);
    } else {
      EXPECT_EQ(node->while_ctx(), while_ctx) << i;
    }
  }

  // Run the loop and test we get the expected results
  Run<int>({1}, {10});
  Run<int>({11}, {11});
}
115 
// A cond builder that yields a non-boolean (float) output must be rejected
// with INVALID_ARGUMENT.
TEST_F(WhileLoopTest, WrongCondOutputType) {
  Init(1);
  auto float_cond = [](const Scope& s, const std::vector<Output>& inputs,
                       Output* output) {
    *output = ops::Placeholder(s, DT_FLOAT);
    return s.status();
  };
  const char* const expected_msg =
      "BuildWhileLoop: 'cond' argument must return a boolean output, got "
      "float";
  CreateLoop(float_cond, AddOneBody, error::INVALID_ARGUMENT, expected_msg);
}
127 
128 // TODO(skyewm): test bad cond output shape
129 
// A cond builder whose output refers to a null node must fail with
// INVALID_ARGUMENT.
TEST_F(WhileLoopTest, NullCondOutputNode) {
  Init(1);
  auto null_node_cond = [](const Scope& s, const std::vector<Output>& inputs,
                           Output* output) {
    *output = {nullptr, 0};
    return s.status();
  };
  // TODO(skyewm): improve error message
  CreateLoop(null_node_cond, AddOneBody, error::INVALID_ARGUMENT,
             "Node is null");
}
140 
// A cond output that points at a nonexistent output index of a real node must
// fail with OUT_OF_RANGE.
TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
  Init(1);
  auto bad_index_cond = [](const Scope& s, const std::vector<Output>& inputs,
                           Output* output) {
    auto less_op = ops::Less(s, inputs[0], 10);
    *output = {less_op.node(), 100};
    return s.status();
  };
  const char* const expected_msg =
      "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output "
      "100";
  CreateLoop(bad_index_cond, AddOneBody, error::OUT_OF_RANGE, expected_msg);
}
153 
// A cond builder that never assigns `*output` is expected to be reported as a
// null node (INVALID_ARGUMENT).
TEST_F(WhileLoopTest, UnsetCondOutput) {
  Init(1);
  auto untouched_cond = [](const Scope& s, const std::vector<Output>& inputs,
                           Output* output) { return s.status(); };
  CreateLoop(untouched_cond, AddOneBody, error::INVALID_ARGUMENT,
             "Node is null");
}
160 
161 // TODO(skyewm): test bad body output type
162 // TODO(skyewm): test bad body output shape
163 
// A body builder that emits an output referring to a null node must fail with
// INVALID_ARGUMENT.
TEST_F(WhileLoopTest, NullBodyOutputNode) {
  Init(1);
  auto null_node_body = [](const Scope& s, const std::vector<Output>& inputs,
                           std::vector<Output>* outputs) {
    outputs->push_back({nullptr, 0});
    return s.status();
  };
  // TODO(skyewm): improve error message
  CreateLoop(LessThanTenCond, null_node_body, error::INVALID_ARGUMENT,
             "Node is null");
}
175 
// A body output that points at a nonexistent output index of a real node must
// fail with OUT_OF_RANGE.
TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
  Init(1);
  auto bad_index_body = [](const Scope& s, const std::vector<Output>& inputs,
                           std::vector<Output>* outputs) {
    auto add_op = ops::Add(s, inputs[0], 1);
    outputs->emplace_back(add_op.node(), 100);
    return s.status();
  };
  const char* const expected_msg =
      "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have "
      "output 100";
  CreateLoop(LessThanTenCond, bad_index_body, error::OUT_OF_RANGE,
             expected_msg);
}
189 
// A body builder that returns no outputs for a single loop variable must fail
// with INVALID_ARGUMENT, naming the expected and actual output counts.
TEST_F(WhileLoopTest, UnsetBodyOutputs) {
  Init(1);
  auto empty_body = [](const Scope& s, const std::vector<Output>& inputs,
                       std::vector<Output>* outputs) { return s.status(); };
  const char* const expected_msg =
      "BuildWhileLoop: 'body' argument expected to return 1 output(s), got 0";
  CreateLoop(LessThanTenCond, empty_body, error::INVALID_ARGUMENT,
             expected_msg);
}
199 
200 }  // namespace
201 }  // namespace tensorflow
202