/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
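
// Tests for symbolic gradients through while loops built with
// ops::BuildWhileLoop (tensorflow/cc/ops/while_loop.h).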

namespace tensorflow {

namespace {

class WhileGradientsTest : public ::testing::Test {
 protected:
  WhileGradientsTest() : scope_(Scope::NewRootScope()) {}

  // Creates `num_inputs` placeholder inputs of the given dtype.
  void Init(int num_inputs, DataType dtype = DT_INT32) {
    for (int i = 0; i < num_inputs; ++i) {
      inputs_.push_back(ops::Placeholder(scope_, dtype));
    }
  }

  // Builds a while loop with the given condition and body functions over
  // `inputs` (defaults to inputs_). The loop outputs are stored in outputs_.
  void CreateLoop(const ops::CondGraphBuilderFn& cond,
                  const ops::BodyGraphBuilderFn& body,
                  const std::vector<Output>* inputs = nullptr) {
    if (inputs == nullptr) inputs = &inputs_;
    TF_ASSERT_OK(ops::BuildWhileLoop(scope_, *inputs, cond, body, "test_loop",
                                     &outputs_));
  }

  // Adds gradient ops for outputs_ w.r.t. inputs_ to the graph, storing the
  // resulting gradients in grad_outputs_.
  void CreateBackprop() {
    TF_ASSERT_OK(
        AddSymbolicGradients(scope_, outputs_, inputs_, &grad_outputs_));
    ASSERT_EQ(grad_outputs_.size(), inputs_.size());
  }

  template <typename T>
  void Run(const std::vector<Input::Initializer>& input_values,
           const std::vector<T>& expected_grad_values) {
    Run<T>(ClientSession(scope_), input_values, expected_grad_values);
  }

  // Feeds `input_values` to inputs_, evaluates grad_outputs_ in `session`,
  // and checks that each computed gradient equals the corresponding expected
  // scalar value.
  template <typename T>
  void Run(const ClientSession& session,
           const std::vector<Input::Initializer>& input_values,
           const std::vector<T>& expected_grad_values,
           const RunOptions& run_options = RunOptions(),
           RunMetadata* run_metadata = nullptr) {
    DCHECK_EQ(input_values.size(), inputs_.size());
    ClientSession::FeedType feeds;
    for (int i = 0; i < inputs_.size(); ++i) {
      feeds.emplace(inputs_[i], input_values[i]);
    }

    std::vector<Operation> run_outputs;
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(run_options, feeds, grad_outputs_, run_outputs,
                             &out_tensors, run_metadata));
    ASSERT_EQ(out_tensors.size(), grad_outputs_.size());

    DCHECK_EQ(expected_grad_values.size(), out_tensors.size());
    for (int i = 0; i < out_tensors.size(); ++i) {
      test::ExpectTensorEqual<T>(
          out_tensors[i], test::AsTensor<T>({expected_grad_values[i]}, {}));
    }
  }

  Scope scope_;
  std::vector<Output> inputs_;
  std::vector<Output> outputs_;
  std::vector<Output> grad_outputs_;
};
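
// Note: CreateBackprop() calls the AddSymbolicGradients overload without
// explicit initial gradients, which (per its documentation) seeds each loop
// output with OnesLike. The expected values in the tests below are therefore
// d(sum of loop outputs)/d(input) for each input.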

TEST_F(WhileGradientsTest, Basic) {
  // Create loop: while (i < 10) i += 1
  Init(1);
  CreateLoop(
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        // Use AddN rather than Add, because AddN's gradient function doesn't
        // depend on the input shapes, so we don't need to store intermediate
        // values in a stack.
        outputs->push_back(ops::AddN(s, {inputs[0], 1}));
        return s.status();
      });
  CreateBackprop();

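  // The body adds a constant each iteration, so output = input + trip_count
  // and d(output)/d(input) = 1. With input 11 the body never runs (the output
  // is the input), and the gradient is still 1.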
  Run<int>({1}, {1});
  Run<int>({11}, {1});
}

TEST_F(WhileGradientsTest, MultipleLoopVars) {
  // Create loop: while (i < 10) i += j; j += 1; k = k
  Init(3);
  CreateLoop(
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        outputs->push_back(ops::AddN(s, {inputs[0], inputs[1]}));
        outputs->push_back(ops::AddN(s, {inputs[1], 1}));
        outputs->push_back(inputs[2]);
        return s.status();
      });
  CreateBackprop();

  // The following execution traces illustrate why we expect dF/dj to be 5:
  //
  //  i  j  k
  //  ---------
  //  0  1  2 <-- initial values
  //  1  2  2
  //  3  3  2
  //  6  4  2
  // 10  5  2 <-- while output values
  // outputs sum = 17
  //
  //  i  j  k
  //  ---------
  //  0  2  2 <-- initial values (add 1 to j)
  //  2  3  2
  //  5  4  2
  //  9  5  2
  // 14  6  2 <-- while output values
  // outputs sum = 22
  //
  // Calculate the "slope" between j=1 and j=2:
  // 22 - 17 = 5 => dF/dj = 5
  Run<int>({0, 1, 2}, {1, 5, 1});

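  // More generally, i accumulates j once per iteration and j itself appears
  // in the output sum, so dF/dj = trip_count + 1. The first case below also
  // runs four iterations (dF/dj = 5); starting from j = 0 takes five
  // iterations, giving dF/dj = 6.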
  Run<int>({1, 1, 0}, {1, 5, 1});
  Run<int>({0, 0, 0}, {1, 6, 1});
}

TEST_F(WhileGradientsTest, Chaining) {
  Init(2, DT_DOUBLE);

  // Multiply each input by 2 before passing it to the while loop to make sure
  // chaining works properly.
  std::vector<Output> loop_inputs = {ops::Multiply(scope_, inputs_[0], 2.0),
                                     ops::Multiply(scope_, inputs_[1], 2.0)};

  // Create loop: while (i > 0 && j > 0) i -= 1
  CreateLoop(
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::LogicalAnd(s, ops::Greater(s, inputs[0], 0.0),
                                  ops::Greater(s, inputs[1], 0.0));
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        outputs->push_back(ops::AddN(s, {inputs[0], -1.0}));
        outputs->push_back(inputs[1]);
        return s.status();
      },
      &loop_inputs);

  // Take the negative of the first output to make sure chaining works
  // properly.
  outputs_[0] = ops::Neg(scope_, outputs_[0]);

  CreateBackprop();

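  // The final outputs are -(2i - trip_count) and 2j, so by the chain rule the
  // gradients w.r.t. the original inputs are -2 and 2, whether or not the
  // loop body runs (with inputs {0.0, 0.0} it doesn't).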
  Run<double>({1.0, 1.0}, {-2.0, 2.0});
  Run<double>({0.0, 0.0}, {-2.0, 2.0});
}

TEST_F(WhileGradientsTest, MultipleDevices) {
  // Make sure the loop is created on cpu0.
  scope_ = scope_.WithDevice("/cpu:0");

  // Create loop: while (i < 10) i += j
  Init(2);
  CreateLoop(
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        // Place the loop body on cpu1.
        Scope cpu1_scope = s.WithDevice("/cpu:1");
        outputs->push_back(ops::AddN(cpu1_scope, {inputs[0], inputs[1]}));
        outputs->push_back(inputs[1]);
        return cpu1_scope.status();
      });

  // Build the gradient graph on cpu1.
  Scope cpu1_scope = scope_.WithDevice("/cpu:1");
  TF_ASSERT_OK(
      AddSymbolicGradients(cpu1_scope, outputs_, inputs_, &grad_outputs_));
  ASSERT_EQ(grad_outputs_.size(), inputs_.size());

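  // From i = 0, j = 1 the loop runs ten iterations (i accumulates j each
  // time), so, as in MultipleLoopVars, the expected gradients are dF/di = 1
  // and dF/dj = 10 + 1 = 11.
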
  // Run with two CPU devices and output partition graphs.
  SessionOptions session_options;
  (*session_options.config.mutable_device_count())["CPU"] = 2;
  RunOptions run_options;
  run_options.set_output_partition_graphs(true);
  RunMetadata run_metadata;
  Run<int>(ClientSession(scope_, session_options), {0, 1}, {1, 11}, run_options,
           &run_metadata);

  // Check that at least one node ran on each device.
  ASSERT_EQ(run_metadata.partition_graphs().size(), 2);
  for (const GraphDef& partition_graph : run_metadata.partition_graphs()) {
    EXPECT_GE(partition_graph.node().size(), 1);
  }
}

}  // namespace
}  // namespace tensorflow