/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/constant_folding.h"

#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

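// Test fixture with helpers that check whether a node has been replaced by a
// Const node carrying the expected value and shape.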
class ConstantFoldingTest : public ::testing::Test {
 protected:
  template <typename T>
  void ExpectNodeClose(const Node* n, gtl::ArraySlice<T> values,
                       TensorShape shape) {
    EXPECT_TRUE(n->IsConstant());
    const TensorProto* tensor_proto;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
    DataType dtype;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
    Tensor t(dtype);
    EXPECT_TRUE(t.FromProto(*tensor_proto));
    test::ExpectClose(t, test::AsTensor(values, shape));
  }

  template <typename T>
  void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
                       TensorShape shape) {
    EXPECT_TRUE(n->IsConstant());
    const TensorProto* tensor_proto;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
    DataType dtype;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
    Tensor t(dtype);
    EXPECT_TRUE(t.FromProto(*tensor_proto));
    test::ExpectTensorEqual<T>(t, test::AsTensor(values, shape));
  }

  // Constructs the following graph.
  /*
        s1  s2
        |    |
        m1   m2
       / \  / \
      a    b   c
  */
  void BuildSimpleGraph(Scope* scope) {
    Scope& s = *scope;
    auto a = ops::Const<float>(s, {1.0, 0.0, 0.0, 1.0}, {2, 2});
    auto b = ops::Const<float>(s, {1.0, 2.0, 3.0, 4.0}, {2, 2});
    auto c = ops::Const<float>(s, {0.0, 1.0, 1.0, 0.0}, {2, 2});
    auto m1 = ops::MatMul(s, a, b);
    auto s1 = ops::_Send(s.WithOpName("s1"), m1, "m1", "sender", 0, "receiver");
    auto m2 = ops::MatMul(s.WithOpName("m2"), b, c);
    auto s2 = ops::_Send(s.WithOpName("s2"), m2, "m2", "sender", 0, "receiver");
  }
};

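// Minimal Device stub, used below as a non-CPU partition device (see
// NoReplacePartialOutput).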
class FakeDevice : public Device {
 private:
  explicit FakeDevice(const DeviceAttributes& device_attributes)
      : Device(nullptr, device_attributes) {}

 public:
  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }

  Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }

  static std::unique_ptr<Device> Make(const string& name, const string& type) {
    DeviceAttributes device_attributes;
    device_attributes.set_name(name);
    device_attributes.set_device_type(DeviceType(type).type());
    return std::unique_ptr<Device>(new FakeDevice(device_attributes));
  }
};

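// Folds both MatMul nodes of the simple graph, so each _Send node ends up
// reading from a Const node.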
TEST_F(ConstantFoldingTest, Basic) {
  Scope s = Scope::NewRootScope();
  BuildSimpleGraph(&s);
  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(&g));

  bool was_mutated;
  TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* s1 = index.at("s1");
  Node* s2 = index.at("s2");
  // Nodes s1 and s2 should now each have a constant input.
  EXPECT_EQ(1, s1->num_inputs());
  ExpectNodeClose<float>(*(s1->in_nodes().begin()), {1.0, 2.0, 3.0, 4.0},
                         {2, 2});
  EXPECT_EQ(1, s2->num_inputs());
  ExpectNodeClose<float>(*(s2->in_nodes().begin()), {2.0, 1.0, 4.0, 3.0},
                         {2, 2});
}

// Tests that different node creation orderings produce the same graph after
// constant folding.
TEST_F(ConstantFoldingTest, DeterministicFolding) {
  auto build_graph_and_constant_folding = [](Graph& g, bool swap) -> Status {
    Scope s = Scope::NewRootScope();
    auto a = ops::Const<float>(s, {1.0}, {});
    auto b = ops::Const<float>(s, {2.0}, {});

    if (swap) {
      auto add1 = ops::Add(s.WithOpName("add1"), a, b);
      auto add2 = ops::Add(s.WithOpName("add2"), a, b);
      auto s1 =
          ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver");
      auto s2 =
          ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver");
    } else {
      // Swap the order of node creation.
      auto add2 = ops::Add(s.WithOpName("add2"), a, b);
      auto add1 = ops::Add(s.WithOpName("add1"), a, b);
      auto s1 =
          ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver");
      auto s2 =
          ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver");
    }

    TF_CHECK_OK(s.ToGraph(&g));
    bool was_mutated;
    int64 unique_id = 0;
    auto generate_new_name = [&unique_id](Graph* graph, string old_name) {
      return strings::StrCat(graph->NewName(old_name), "__cf__", unique_id++);
    };
    ConstantFoldingOptions opt{};
    opt.generate_new_name = generate_new_name;
    TF_CHECK_OK(
        ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
    return Status::OK();
  };

  Graph g1(OpRegistry::Global());
  TF_ASSERT_OK(build_graph_and_constant_folding(g1, false));
  Graph g2(OpRegistry::Global());
  TF_ASSERT_OK(build_graph_and_constant_folding(g2, true));
  EXPECT_EQ(g1.num_nodes(), g2.num_nodes());
  auto index = g2.BuildNodeNameIndex();

  // All the nodes in g1 are expected to be present in g2.
  for (int64 i = 0; i < g1.num_nodes(); ++i) {
    Node* n1 = g1.FindNodeId(i);
    EXPECT_GT(index.count(n1->name()), 0);
  }
}

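// Uses ConstantFoldingOptions::consider to exclude m2 from folding while m1 is
// still folded.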
TEST_F(ConstantFoldingTest, ConsiderFunction) {
  Scope s = Scope::NewRootScope();
  BuildSimpleGraph(&s);
  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(&g));

  ConstantFoldingOptions opts;
  // Do not allow constant folding of m2.
  opts.consider = [](const Node* n) { return "m2" != n->name(); };
  bool was_mutated;
  TF_ASSERT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* s1 = index.at("s1");
  Node* s2 = index.at("s2");
  Node* m2 = index.at("m2");

  // Node s1 should now have a constant input.
  EXPECT_EQ(1, s1->num_inputs());
  ExpectNodeClose<float>(*(s1->in_nodes().begin()), {1.0, 2.0, 3.0, 4.0},
                         {2, 2});
  // s2's input should still be m2.
  EXPECT_EQ(1, s2->num_inputs());
  EXPECT_EQ(*(s2->in_nodes().begin()), m2);
}

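// A node that is already a Const (d) must not be replaced, even though the
// rest of the graph is folded.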
TEST_F(ConstantFoldingTest, TestNoReplaceAnotherConstant) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    BuildSimpleGraph(&s);
    auto d = ops::Const<float>(s.WithOpName("d"), {1.0, 0.0, 0.0, 1.0}, {2, 2});
    auto s3 = ops::_Send(s.WithOpName("s3"), d, "d", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  bool was_mutated;
  TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* d = index.at("d");
  Node* s3 = index.at("s3");

  // Node s3 should still have d as its input.
  EXPECT_EQ(1, s3->num_inputs());
  EXPECT_EQ(*(s3->in_nodes().begin()), d);
}

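// Folds a node with two outputs (BroadcastGradientArgs); both outputs should
// be replaced by constants.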
TEST_F(ConstantFoldingTest, TwoOutputs) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto s0 = ops::Const<int>(s, {1}, {1});
    auto s1 = ops::Const<int>(s, {2, 2}, {2});
    auto b = ops::internal::BroadcastGradientArgs(s, s0, s1);
    auto b0 = ops::_Send(s.WithOpName("b0"), ops::Identity(s, b.r0), "b0",
                         "sender", 0, "receiver");
    auto b1 = ops::_Send(s.WithOpName("b1"), ops::Identity(s, b.r1), "b1",
                         "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  bool was_mutated;
  TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* b0 = index.at("b0");
  Node* b1 = index.at("b1");

  EXPECT_EQ(1, b0->num_inputs());
  ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2});
  EXPECT_EQ(1, b1->num_inputs());
  ExpectNodeEqual<int>(*(b1->in_nodes().begin()), {}, {0});
}

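// Excludes b1_ident from folding via the consider callback: the r0 output is
// folded away, while b1_ident survives but receives a constant input.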
TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto s0 = ops::Const<int>(s, {1}, {1});
    auto s1 = ops::Const<int>(s, {2, 2}, {2});
    auto b = ops::internal::BroadcastGradientArgs(s, s0, s1);
    auto b0 = ops::_Send(s.WithOpName("b0"), ops::Identity(s, b.r0), "b0",
                         "sender", 0, "receiver");
    auto b1_ident = ops::Identity(s.WithOpName("b1_ident"), b.r1);
    auto b1 =
        ops::_Send(s.WithOpName("b1"), b1_ident, "b1", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  ConstantFoldingOptions opts;
  opts.consider = [](const Node* n) { return "b1_ident" != n->name(); };
  bool was_mutated;
  TF_ASSERT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* b0 = index.at("b0");
  Node* b1 = index.at("b1");
  Node* b1_ident = index.at("b1_ident");

  // The 0th output of b should have been folded.
  ASSERT_EQ(1, b0->num_inputs());
  ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2});
  // The 1st output of b should still feed b1_ident. However, b1_ident's input
  // must have been replaced with a constant.
  ASSERT_EQ(1, b1->num_inputs());
  EXPECT_EQ(*(b1->in_nodes().begin()), b1_ident);

  ASSERT_EQ(1, b1_ident->num_inputs());
  ExpectNodeEqual<int>(*(b1_ident->in_nodes().begin()), {}, {0});
}

TEST_F(ConstantFoldingTest, TestNoReplaceLargeConstant) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto s0 = ops::Const<int>(s, 0, {5 * 1024 * 256});
    auto s1 = ops::Const<int>(s, 0, {5 * 1024 * 256 + 1});
    auto concat_dim = ops::Const<int>(s, 0);
    auto concat = ops::Concat(s, {s0, s1}, concat_dim);
    auto concat_send = ops::_Send(s.WithOpName("concat_send"), concat,
                                  "concat_send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  // The above concat should not have been constant folded.
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_FALSE(was_mutated);

  // Increase the limit and the concat should now be constant folded.
  ConstantFoldingOptions opt;
  opt.max_constant_size_in_bytes = 10 * 1024 * 1024 + 4;
  TF_EXPECT_OK(
      ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);
}

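// Function call nodes (here a call to XTimesTwo) must not be constant folded.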
TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
  FunctionDefLibrary flib;
  *flib.add_function() = test::function::XTimesTwo();

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
  Graph g(flib_def);
  {
    Scope s = Scope::NewRootScope();
    auto c = ops::Const<int32>(s.WithOpName("c"), {1}, {1});
    TF_EXPECT_OK(s.graph()->AddFunctionLibrary(flib));

    // TODO(phawkins): there is no way to make a function call using the C++
    // graph builder API.
    NodeDef def;
    TF_ASSERT_OK(
        NodeDefBuilder("times_two", "XTimesTwo", s.graph()->op_registry())
            .Input(c.name(), 0, DT_INT32)
            .Finalize(&def));
    Status status;
    Node* times_two = s.graph()->AddNode(def, &status);
    TF_ASSERT_OK(status);
    TF_ASSERT_OK(s.DoShapeInference(times_two));
    s.graph()->AddEdge(c.node(), 0, times_two, 0);

    auto times_two_send =
        ops::_Send(s.WithOpName("times_two_send"), Output(times_two),
                   "times_two_send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  // The above function call should not have been constant folded.
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_FALSE(was_mutated);
}

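// Test-only op for TestNoReplaceNonCPUOp below; no kernel is registered for it
// in this file, so the constant folder has nothing to run it on.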
REGISTER_OP("ConstantFoldingTestOp")
    .Input("a: int64")
    .Output("b: int64")
    .SetShapeFn(shape_inference::UnknownShape);

TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto aconst = ops::Const<int64>(s, 0, {5});

    NodeDef def;
    TF_ASSERT_OK(NodeDefBuilder("testop", "ConstantFoldingTestOp")
                     .Input(aconst.name(), 0, DT_INT64)
                     .Finalize(&def));
    Status status;
    Node* non_cpu = s.graph()->AddNode(def, &status);
    TF_ASSERT_OK(status);
    TF_ASSERT_OK(s.DoShapeInference(non_cpu));

    auto non_cpu_send =
        ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),
                   "non_cpu_send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  // The non-CPU op should not have been constant folded.
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_FALSE(was_mutated);
}

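// Control dependencies of folded nodes must be forwarded: the replacement
// Const is expected to keep control edges from recv1 and recv2.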
TEST_F(ConstantFoldingTest, ControlDependencies) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto c0 = ops::Const<int>(s, 1);
    auto recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1", "sender",
                            0, "receiver");
    auto c1 = ops::Const<int>(s.WithControlDependencies(recv1), 2);
    auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
                            0, "receiver");
    auto c2 = ops::Const<int>(s.WithControlDependencies(recv2), 3);
    auto add = ops::Add(s.WithControlDependencies(c2), c0, c1);
    auto send =
        ops::_Send(s.WithOpName("send"), add, "send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* recv1 = index.at("recv1");
  Node* recv2 = index.at("recv2");
  Node* send = index.at("send");

  ASSERT_EQ(1, send->num_inputs());
  Node* p = *(send->in_nodes().begin());
  ExpectNodeEqual<int>(p, {3}, {});

  ASSERT_EQ(2, p->in_edges().size());
  for (const Edge* e : p->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv1 || e->src() == recv2) << e->src()->name();
  }
}

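// With fully known input shapes supplied via ConstantFoldingOptions::shape_map,
// Shape/ShapeN/Rank/Size become foldable even though their inputs are _Recv
// nodes.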
TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
                              "sender", 0, "receiver");
    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
                              "sender", 0, "receiver");
    auto shape_n = ops::ShapeN(s.WithOpName("shape_n"), {recv0, recv1});
    auto rank = ops::Rank(s.WithOpName("rank"), recv0);
    auto size = ops::Size(s.WithOpName("size"), recv1);
    auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
                            0, "receiver");
    auto c = ops::Const<int>(s.WithControlDependencies(recv2), 3);
    auto add0 = ops::Add(s.WithControlDependencies(c), rank, size);
    auto add1 = ops::Add(s, shape, shape_n[0]);
    auto add2 = ops::Add(s, shape_n[1], shape_n[1]);
    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
                            "receiver");
    auto send1 = ops::_Send(s.WithOpName("send1"), add1, "send1", "sender", 0,
                            "receiver");
    auto send2 = ops::_Send(s.WithOpName("send2"), add2, "send2", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
  Node* recv0 = orig_index.at("recv0");
  Node* recv1 = orig_index.at("recv1");
  PartialTensorShape ps0;
  int r0_dims[] = {1, 2};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
  PartialTensorShape ps1;
  int r1_dims[] = {2, 3, 4};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape<int>(r1_dims, 3, &ps1));
  std::unordered_map<string, std::vector<PartialTensorShape>> map;
  map[recv0->name()].push_back(ps0);
  map[recv1->name()].push_back(ps1);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* recv2 = index.at("recv2");
  Node* send0 = index.at("send0");
  Node* send1 = index.at("send1");
  Node* send2 = index.at("send2");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {26}, {});

  ASSERT_EQ(1, send1->num_inputs());
  Node* cf1 = *(send1->in_nodes().begin());
  ExpectNodeEqual<int>(cf1, {2, 4}, {2});

  ASSERT_EQ(1, send2->num_inputs());
  Node* cf2 = *(send2->in_nodes().begin());
  ExpectNodeEqual<int>(cf2, {4, 6, 8}, {3});

  ASSERT_EQ(3, cf0->in_edges().size());
  for (const Edge* e : cf0->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1 || e->src() == recv2)
        << e->src()->name();
  }

  ASSERT_EQ(2, cf1->in_edges().size());
  for (const Edge* e : cf1->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
  }

  ASSERT_EQ(2, cf2->in_edges().size());
  for (const Edge* e : cf2->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
  }
}

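// With only partially known shapes: Rank folds when the rank is known
// (recv0 has shape {-1, -1}), while Shape and Size of recv0 and Rank of the
// unknown-rank recv1 must be left alone.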
TEST_F(ConstantFoldingTest, PartialShape) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
                              "sender", 0, "receiver");
    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
                              "sender", 0, "receiver");
    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
    auto rank0 = ops::Rank(s.WithOpName("rank0"), recv0);
    auto rank1 = ops::Rank(s.WithOpName("rank1"), recv1);
    auto size = ops::Size(s.WithOpName("size"), recv0);
    auto send0 = ops::_Send(s.WithOpName("send0"), rank0, "send0", "sender", 0,
                            "receiver");
    auto send1 = ops::_Send(s.WithOpName("send1"), shape, "send1", "sender", 0,
                            "receiver");
    auto send2 = ops::_Send(s.WithOpName("send2"), size, "send2", "sender", 0,
                            "receiver");
    auto send3 = ops::_Send(s.WithOpName("send3"), rank1, "send3", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
  Node* recv0 = orig_index.at("recv0");
  Node* recv1 = orig_index.at("recv1");
  PartialTensorShape ps0;
  int r0_dims[] = {-1, -1};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
  PartialTensorShape ps1;
  std::unordered_map<string, std::vector<PartialTensorShape>> map;
  map[recv0->name()].push_back(ps0);
  map[recv1->name()].push_back(ps1);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* shape = index.at("shape");
  Node* size = index.at("size");
  Node* rank1 = index.at("rank1");
  Node* send0 = index.at("send0");
  Node* send1 = index.at("send1");
  Node* send2 = index.at("send2");
  Node* send3 = index.at("send3");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {2}, {});

  ASSERT_EQ(1, send1->num_inputs());
  Node* ncf1 = *(send1->in_nodes().begin());
  EXPECT_EQ(ncf1, shape);

  ASSERT_EQ(1, send2->num_inputs());
  Node* ncf2 = *(send2->in_nodes().begin());
  EXPECT_EQ(ncf2, size);

  ASSERT_EQ(1, send3->num_inputs());
  Node* ncf3 = *(send3->in_nodes().begin());
  EXPECT_EQ(ncf3, rank1);
}

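// A scalar shape entry for a Const input: Rank(c0) is 0, so the Add folds to
// 0, and the control dependency on recv0 is preserved on the replacement.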
TEST_F(ConstantFoldingTest, ConstShapeKnown) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0", "sender",
                            0, "receiver");
    auto c0 =
        ops::Const<int>(s.WithOpName("c0").WithControlDependencies(recv0), 1);
    auto rank = ops::Rank(s.WithOpName("rank"), c0);
    auto add0 = ops::Add(s, rank, rank);
    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
  Node* c0 = orig_index.at("c0");
  PartialTensorShape ps0;
  int c0_dims[] = {};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(c0_dims, 0, &ps0));
  std::unordered_map<string, std::vector<PartialTensorShape>> map;
  map[c0->name()].push_back(ps0);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* recv0 = index.at("recv0");
  Node* send0 = index.at("send0");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {0}, {});

  ASSERT_EQ(1, cf0->in_edges().size());
  for (const Edge* e : cf0->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0) << e->src()->name();
  }
}

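// Runs constant folding with a fake GPU partition device; the TopK outputs
// must not be replaced, so the graph is expected to stay unmodified.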
TEST_F(ConstantFoldingTest, NoReplacePartialOutput) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope().ExitOnError().WithAssignedDevice("/gpu:0");

    auto c0 = ops::Const<float>(s.WithOpName("c0"), {5.0, 2.0, 8.0, 1.0}, {4});
    auto k = ops::Const<int>(s.WithOpName("k"), 3);
    auto topK =
        ops::TopK(s.WithOpName("topK"), c0, k, ops::TopK::Sorted(false));
    auto send_values = ops::_Send(s.WithOpName("send_values"), topK.values,
                                  "send_values", "sender", 0, "receiver");
    auto send_indices = ops::_Send(s.WithOpName("send_indices"), topK.indices,
                                   "send_indices", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(
      ConstantFoldingOptions{}, nullptr, Env::Default(),
      FakeDevice::Make("/job:tpu_worker/replica:0/task:0/device:GPU:0",
                       DEVICE_GPU)
          .get(),
      &g, &was_mutated));
  EXPECT_FALSE(was_mutated);
}

namespace {

const char kTestMemRegionName[] = "test://test";

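// Read-only memory region backed by an in-memory buffer; served by
// TestTFFileSystem below for the kTestMemRegionName path.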
class TestReadOnlyMemoryRegion : public ::tensorflow::ReadOnlyMemoryRegion {
 public:
  ~TestReadOnlyMemoryRegion() override = default;
  TestReadOnlyMemoryRegion(const void* data, uint64 length)
      : data_(data), length_(length) {}
  const void* data() override { return data_; }
  uint64 length() override { return length_; }

 protected:
  const void* data_;
  uint64 length_;
};

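// File system that maps kTestMemRegionName to the bytes of a 2x2 double
// tensor, so ImmutableConst can be folded without touching real files.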
class TestTFFileSystem : public ::tensorflow::NullFileSystem {
 public:
  TestTFFileSystem()
      : ::tensorflow::NullFileSystem(),
        data_tensor_(test::AsTensor<double>({1., 2., 3., 4.}, {2, 2})) {}

  using ::tensorflow::NullFileSystem::NewReadOnlyMemoryRegionFromFile;

  ::tensorflow::Status NewReadOnlyMemoryRegionFromFile(
      const string& fname, ::tensorflow::TransactionToken* token,
      std::unique_ptr<::tensorflow::ReadOnlyMemoryRegion>* result) override {
    if (fname != kTestMemRegionName) {
      return ::tensorflow::errors::Unimplemented(
          "NewReadOnlyMemoryRegionFromFile unimplemented");
    }
    const ::tensorflow::StringPiece sp = data_tensor_.tensor_data();
    *result = std::unique_ptr<::tensorflow::ReadOnlyMemoryRegion>(
        new TestReadOnlyMemoryRegion(sp.data(), sp.size()));
    return ::tensorflow::Status::OK();
  }

 protected:
  ::tensorflow::Tensor data_tensor_;
};

// A test TF environment that checks that the environment was used.
class TestTFEnvironment : public ::tensorflow::EnvWrapper {
 public:
  using tf_base = ::tensorflow::EnvWrapper;
  TestTFEnvironment() : ::tensorflow::EnvWrapper(Default()) {}
  ::tensorflow::Status GetFileSystemForFile(
      const string& fname, ::tensorflow::FileSystem** result) override {
    was_used_ = true;
    if (fname == "test://test") {
      *result = &test_filesystem_;
      return ::tensorflow::Status::OK();
    }
    return tf_base::GetFileSystemForFile(fname, result);
  }
  bool was_used() const { return was_used_; }

 protected:
  TestTFFileSystem test_filesystem_;
  bool was_used_ = false;
};
}  // namespace

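// ImmutableConst reads its value through the Env's file system: folding fails
// with the default Env but succeeds with TestTFEnvironment, which serves
// kTestMemRegionName from memory.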
TEST_F(ConstantFoldingTest, TestImmutableConst) {
  Graph g(OpRegistry::Global());
  Scope root = Scope::NewRootScope();

  auto a = ops::ImmutableConst(root, DT_DOUBLE, {2, 2}, kTestMemRegionName);
  auto b = ops::Const<double>(root, {1.0, 2.0, 3.0, 4.0}, {2, 2});
  auto c = ops::RandomGamma(root, {2, 2}, 2.0);
  auto result1 = ops::MatMul(root, a, b);
  auto result2 = ops::MatMul(root, result1, c);
  TF_ASSERT_OK(root.ToGraph(&g));
  TestTFEnvironment test_env;
  bool was_mutated;
  Status status = ConstantFold(ConstantFoldingOptions{}, nullptr,
                               Env::Default(), nullptr, &g, &was_mutated);
  EXPECT_FALSE(was_mutated);
  EXPECT_FALSE(status.ok());
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, &test_env,
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);
}

}  // namespace
}  // namespace tensorflow