/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/constant_folding.h"

#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

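// Test fixture with helpers that assert a node has been replaced by a Const
// node holding the expected tensor value.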
class ConstantFoldingTest : public ::testing::Test {
 protected:
  template <typename T>
  void ExpectNodeClose(const Node* n, gtl::ArraySlice<T> values,
                       TensorShape shape) {
    EXPECT_TRUE(n->IsConstant());
    const TensorProto* tensor_proto;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
    DataType dtype;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
    Tensor t(dtype);
    EXPECT_TRUE(t.FromProto(*tensor_proto));
    test::ExpectClose(t, test::AsTensor(values, shape));
  }

  template <typename T>
  void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
                       TensorShape shape) {
    EXPECT_TRUE(n->IsConstant());
    const TensorProto* tensor_proto;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
    DataType dtype;
    TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
    Tensor t(dtype);
    EXPECT_TRUE(t.FromProto(*tensor_proto));
    test::ExpectTensorEqual<T>(t, test::AsTensor(values, shape));
  }

  // Constructs the following graph.
  /*
        s1  s2
        |    |
        m1   m2
        / \ / \
       a   b   c
  */
  void BuildSimpleGraph(Scope* scope) {
    Scope& s = *scope;
    auto a = ops::Const<float>(s, {1.0, 0.0, 0.0, 1.0}, {2, 2});
    auto b = ops::Const<float>(s, {1.0, 2.0, 3.0, 4.0}, {2, 2});
    auto c = ops::Const<float>(s, {0.0, 1.0, 1.0, 0.0}, {2, 2});
    auto m1 = ops::MatMul(s, a, b);
    auto s1 = ops::_Send(s.WithOpName("s1"), m1, "m1", "sender", 0, "receiver");
    auto m2 = ops::MatMul(s.WithOpName("m2"), b, c);
    auto s2 = ops::_Send(s.WithOpName("s2"), m2, "m2", "sender", 0, "receiver");
  }
};

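// A stub Device carrying only DeviceAttributes; used to hand ConstantFold a
// non-CPU device.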
class FakeDevice : public Device {
 private:
  explicit FakeDevice(const DeviceAttributes& device_attributes)
      : Device(nullptr, device_attributes) {}

 public:
  Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }

  Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }

  static std::unique_ptr<Device> Make(const string& name, const string& type) {
    DeviceAttributes device_attributes;
    device_attributes.set_name(name);
    device_attributes.set_device_type(DeviceType(type).type());
    return std::unique_ptr<Device>(new FakeDevice(device_attributes));
  }
};

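// Folds the two MatMul subgraphs built by BuildSimpleGraph and checks that
// both _Send nodes end up fed by constants holding the expected products.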
TEST_F(ConstantFoldingTest, Basic) {
  Scope s = Scope::NewRootScope();
  BuildSimpleGraph(&s);
  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(&g));

  bool was_mutated;
  TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* s1 = index.at("s1");
  Node* s2 = index.at("s2");
  // Nodes s1 and s2 should now each have a constant input.
  EXPECT_EQ(1, s1->num_inputs());
  ExpectNodeClose<float>(*(s1->in_nodes().begin()), {1.0, 2.0, 3.0, 4.0},
                         {2, 2});
  EXPECT_EQ(1, s2->num_inputs());
  ExpectNodeClose<float>(*(s2->in_nodes().begin()), {2.0, 1.0, 4.0, 3.0},
                         {2, 2});
}

// Tests that different node creation orderings produce the same graph after
// constant folding.
TEST_F(ConstantFoldingTest, DeterministicFolding) {
  auto build_graph_and_constant_folding = [](Graph& g, bool swap) -> Status {
    Scope s = Scope::NewRootScope();
    auto a = ops::Const<float>(s, {1.0}, {});
    auto b = ops::Const<float>(s, {2.0}, {});

    if (swap) {
      auto add1 = ops::Add(s.WithOpName("add1"), a, b);
      auto add2 = ops::Add(s.WithOpName("add2"), a, b);
      auto s1 =
          ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver");
      auto s2 =
          ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver");
    } else {
      // Swap the order of node creation.
      auto add2 = ops::Add(s.WithOpName("add2"), a, b);
      auto add1 = ops::Add(s.WithOpName("add1"), a, b);
      auto s1 =
          ops::_Send(s.WithOpName("s1"), add1, "add1", "sender", 0, "receiver");
      auto s2 =
          ops::_Send(s.WithOpName("s2"), add2, "add2", "sender", 0, "receiver");
    }

    TF_CHECK_OK(s.ToGraph(&g));
    bool was_mutated;
    int64 unique_id = 0;
    auto generate_new_name = [&unique_id](Graph* graph, string old_name) {
      return strings::StrCat(graph->NewName(old_name), "__cf__", unique_id++);
    };
    ConstantFoldingOptions opt{};
    opt.generate_new_name = generate_new_name;
    TF_CHECK_OK(
        ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
    return Status::OK();
  };

  Graph g1(OpRegistry::Global());
  TF_ASSERT_OK(build_graph_and_constant_folding(g1, false));
  Graph g2(OpRegistry::Global());
  TF_ASSERT_OK(build_graph_and_constant_folding(g2, true));
  EXPECT_EQ(g1.num_nodes(), g2.num_nodes());
  auto index = g2.BuildNodeNameIndex();

  // All the nodes in g1 are expected to be present in g2.
  for (int64 i = 0; i < g1.num_nodes(); ++i) {
    Node* n1 = g1.FindNodeId(i);
    EXPECT_GT(index.count(n1->name()), 0);
  }
}

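// Tests that the 'consider' predicate can exclude individual nodes from
// folding: m1 is folded while m2 is kept.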
TEST_F(ConstantFoldingTest, ConsiderFunction) {
  Scope s = Scope::NewRootScope();
  BuildSimpleGraph(&s);
  Graph g(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(&g));

  ConstantFoldingOptions opts;
  // Do not allow constant folding of m2.
  opts.consider = [](const Node* n) { return "m2" != n->name(); };
  bool was_mutated;
  TF_ASSERT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* s1 = index.at("s1");
  Node* s2 = index.at("s2");
  Node* m2 = index.at("m2");

  // Node s1 should now have a constant input.
  EXPECT_EQ(1, s1->num_inputs());
  ExpectNodeClose<float>(*(s1->in_nodes().begin()), {1.0, 2.0, 3.0, 4.0},
                         {2, 2});
  // s2's input should still be m2.
  EXPECT_EQ(1, s2->num_inputs());
  EXPECT_EQ(*(s2->in_nodes().begin()), m2);
}

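// Tests that a node that is already a constant is not replaced by a new
// constant.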
TEST_F(ConstantFoldingTest, TestNoReplaceAnotherConstant) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    BuildSimpleGraph(&s);
    auto d = ops::Const<float>(s.WithOpName("d"), {1.0, 0.0, 0.0, 1.0}, {2, 2});
    auto s3 = ops::_Send(s.WithOpName("s3"), d, "d", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  bool was_mutated;
  TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* d = index.at("d");
  Node* s3 = index.at("s3");

  // Node s3 should still have d as its input.
  EXPECT_EQ(1, s3->num_inputs());
  EXPECT_EQ(*(s3->in_nodes().begin()), d);
}

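// Tests folding of an op with two outputs (BroadcastGradientArgs): both
// outputs should be replaced by constants.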
TEST_F(ConstantFoldingTest, TwoOutputs) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto s0 = ops::Const<int>(s, {1}, {1});
    auto s1 = ops::Const<int>(s, {2, 2}, {2});
    auto b = ops::internal::BroadcastGradientArgs(s, s0, s1);
    auto b0 = ops::_Send(s.WithOpName("b0"), ops::Identity(s, b.r0), "b0",
                         "sender", 0, "receiver");
    auto b1 = ops::_Send(s.WithOpName("b1"), ops::Identity(s, b.r1), "b1",
                         "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  bool was_mutated;
  TF_ASSERT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* b0 = index.at("b0");
  Node* b1 = index.at("b1");

  EXPECT_EQ(1, b0->num_inputs());
  ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2});
  EXPECT_EQ(1, b1->num_inputs());
  ExpectNodeEqual<int>(*(b1->in_nodes().begin()), {}, {0});
}

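// Like TwoOutputs, but the 'consider' predicate excludes b1_ident, so the
// graph is only folded up to that node.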
TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto s0 = ops::Const<int>(s, {1}, {1});
    auto s1 = ops::Const<int>(s, {2, 2}, {2});
    auto b = ops::internal::BroadcastGradientArgs(s, s0, s1);
    auto b0 = ops::_Send(s.WithOpName("b0"), ops::Identity(s, b.r0), "b0",
                         "sender", 0, "receiver");
    auto b1_ident = ops::Identity(s.WithOpName("b1_ident"), b.r1);
    auto b1 =
        ops::_Send(s.WithOpName("b1"), b1_ident, "b1", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  ConstantFoldingOptions opts;
  opts.consider = [](const Node* n) { return "b1_ident" != n->name(); };
  bool was_mutated;
  TF_ASSERT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* b0 = index.at("b0");
  Node* b1 = index.at("b1");
  Node* b1_ident = index.at("b1_ident");

  // 0th output of b should have been folded.
  ASSERT_EQ(1, b0->num_inputs());
  ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2});
  // 1st output of b should still be b1_ident. However, b1_ident's input must
  // have been replaced with a constant.
  ASSERT_EQ(1, b1->num_inputs());
  EXPECT_EQ(*(b1->in_nodes().begin()), b1_ident);

  ASSERT_EQ(1, b1_ident->num_inputs());
  ExpectNodeEqual<int>(*(b1_ident->in_nodes().begin()), {}, {0});
}

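// Tests that folding is skipped when the resulting constant would exceed
// max_constant_size_in_bytes, and performed once the limit is raised.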
TEST_F(ConstantFoldingTest, TestNoReplaceLargeConstant) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto s0 = ops::Const<int>(s, 0, {5 * 1024 * 256});
    auto s1 = ops::Const<int>(s, 0, {5 * 1024 * 256 + 1});
    auto concat_dim = ops::Const<int>(s, 0);
    auto concat = ops::Concat(s, {s0, s1}, concat_dim);
    auto concat_send = ops::_Send(s.WithOpName("concat_send"), concat,
                                  "concat_send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  // The above concat should not have been constant folded.
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_FALSE(was_mutated);

  // Increase the limit and the concat should now be constant folded.
  ConstantFoldingOptions opt;
  opt.max_constant_size_in_bytes = 10 * 1024 * 1024 + 4;
  TF_EXPECT_OK(
      ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);
}

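// Tests that a function call node is not constant folded.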
TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
  FunctionDefLibrary flib;
  *flib.add_function() = test::function::XTimesTwo();

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
  Graph g(flib_def);
  {
    Scope s = Scope::NewRootScope();
    auto c = ops::Const<int32>(s.WithOpName("c"), {1}, {1});
    TF_EXPECT_OK(s.graph()->AddFunctionLibrary(flib));

    // TODO(phawkins): there is no way to make a function call using the C++
    // graph builder API.
    NodeDef def;
    TF_ASSERT_OK(
        NodeDefBuilder("times_two", "XTimesTwo", s.graph()->op_registry())
            .Input(c.name(), 0, DT_INT32)
            .Finalize(&def));
    Status status;
    Node* times_two = s.graph()->AddNode(def, &status);
    TF_ASSERT_OK(status);
    TF_ASSERT_OK(s.DoShapeInference(times_two));
    s.graph()->AddEdge(c.node(), 0, times_two, 0);

    auto times_two_send =
        ops::_Send(s.WithOpName("times_two_send"), Output(times_two),
                   "times_two_send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  // The above function call should not have been constant folded.
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_FALSE(was_mutated);
}

REGISTER_OP("ConstantFoldingTestOp")
    .Input("a: int64")
    .Output("b: int64")
    .SetShapeFn(shape_inference::UnknownShape);

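// Tests that a node whose op cannot run on the CPU (ConstantFoldingTestOp
// above has no registered kernel) is not constant folded.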
TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto aconst = ops::Const<int64>(s, 0, {5});

    NodeDef def;
    TF_ASSERT_OK(NodeDefBuilder("testop", "ConstantFoldingTestOp")
                     .Input(aconst.name(), 0, DT_INT64)
                     .Finalize(&def));
    Status status;
    Node* non_cpu = s.graph()->AddNode(def, &status);
    TF_ASSERT_OK(status);
    TF_ASSERT_OK(s.DoShapeInference(non_cpu));

    auto non_cpu_send =
        ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),
                   "non_cpu_send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }

  // The non-CPU op should not have been constant folded.
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_FALSE(was_mutated);
}

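// Tests that control dependencies of folded nodes are forwarded to the
// replacement constant.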
TEST_F(ConstantFoldingTest, ControlDependencies) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto c0 = ops::Const<int>(s, 1);
    auto recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1", "sender",
                            0, "receiver");
    auto c1 = ops::Const<int>(s.WithControlDependencies(recv1), 2);
    auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
                            0, "receiver");
    auto c2 = ops::Const<int>(s.WithControlDependencies(recv2), 3);
    auto add = ops::Add(s.WithControlDependencies(c2), c0, c1);
    auto send =
        ops::_Send(s.WithOpName("send"), add, "send", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* recv1 = index.at("recv1");
  Node* recv2 = index.at("recv2");
  Node* send = index.at("send");

  ASSERT_EQ(1, send->num_inputs());
  Node* p = *(send->in_nodes().begin());
  ExpectNodeEqual<int>(p, {3}, {});

  ASSERT_EQ(2, p->in_edges().size());
  for (const Edge* e : p->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv1 || e->src() == recv2) << e->src()->name();
  }
}

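// Tests that Shape, ShapeN, Rank, and Size are folded when the input shapes
// are fully known via opts.shape_map.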
TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
                              "sender", 0, "receiver");
    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
                              "sender", 0, "receiver");
    auto shape_n = ops::ShapeN(s.WithOpName("shape_n"), {recv0, recv1});
    auto rank = ops::Rank(s.WithOpName("rank"), recv0);
    auto size = ops::Size(s.WithOpName("size"), recv1);
    auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
                            0, "receiver");
    auto c = ops::Const<int>(s.WithControlDependencies(recv2), 3);
    auto add0 = ops::Add(s.WithControlDependencies(c), rank, size);
    auto add1 = ops::Add(s, shape, shape_n[0]);
    auto add2 = ops::Add(s, shape_n[1], shape_n[1]);
    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
                            "receiver");
    auto send1 = ops::_Send(s.WithOpName("send1"), add1, "send1", "sender", 0,
                            "receiver");
    auto send2 = ops::_Send(s.WithOpName("send2"), add2, "send2", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
  Node* recv0 = orig_index.at("recv0");
  Node* recv1 = orig_index.at("recv1");
  PartialTensorShape ps0;
  int r0_dims[] = {1, 2};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
  PartialTensorShape ps1;
  int r1_dims[] = {2, 3, 4};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape<int>(r1_dims, 3, &ps1));
  std::unordered_map<string, std::vector<PartialTensorShape>> map;
  map[recv0->name()].push_back(ps0);
  map[recv1->name()].push_back(ps1);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* recv2 = index.at("recv2");
  Node* send0 = index.at("send0");
  Node* send1 = index.at("send1");
  Node* send2 = index.at("send2");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {26}, {});

  ASSERT_EQ(1, send1->num_inputs());
  Node* cf1 = *(send1->in_nodes().begin());
  ExpectNodeEqual<int>(cf1, {2, 4}, {2});

  ASSERT_EQ(1, send2->num_inputs());
  Node* cf2 = *(send2->in_nodes().begin());
  ExpectNodeEqual<int>(cf2, {4, 6, 8}, {3});

  ASSERT_EQ(3, cf0->in_edges().size());
  for (const Edge* e : cf0->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1 || e->src() == recv2)
        << e->src()->name();
  }

  ASSERT_EQ(2, cf1->in_edges().size());
  for (const Edge* e : cf1->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
  }

  ASSERT_EQ(2, cf2->in_edges().size());
  for (const Edge* e : cf2->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
  }
}

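// Tests that with only partially known input shapes, Rank can still be folded
// while Shape and Size cannot.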
TEST_F(ConstantFoldingTest, PartialShape) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
                              "sender", 0, "receiver");
    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
                              "sender", 0, "receiver");
    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
    auto rank0 = ops::Rank(s.WithOpName("rank0"), recv0);
    auto rank1 = ops::Rank(s.WithOpName("rank1"), recv1);
    auto size = ops::Size(s.WithOpName("size"), recv0);
    auto send0 = ops::_Send(s.WithOpName("send0"), rank0, "send0", "sender", 0,
                            "receiver");
    auto send1 = ops::_Send(s.WithOpName("send1"), shape, "send1", "sender", 0,
                            "receiver");
    auto send2 = ops::_Send(s.WithOpName("send2"), size, "send2", "sender", 0,
                            "receiver");
    auto send3 = ops::_Send(s.WithOpName("send3"), rank1, "send3", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
  Node* recv0 = orig_index.at("recv0");
  Node* recv1 = orig_index.at("recv1");
  PartialTensorShape ps0;
  int r0_dims[] = {-1, -1};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
  PartialTensorShape ps1;
  std::unordered_map<string, std::vector<PartialTensorShape>> map;
  map[recv0->name()].push_back(ps0);
  map[recv1->name()].push_back(ps1);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* shape = index.at("shape");
  Node* size = index.at("size");
  Node* rank1 = index.at("rank1");
  Node* send0 = index.at("send0");
  Node* send1 = index.at("send1");
  Node* send2 = index.at("send2");
  Node* send3 = index.at("send3");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {2}, {});

  ASSERT_EQ(1, send1->num_inputs());
  Node* ncf1 = *(send1->in_nodes().begin());
  EXPECT_EQ(ncf1, shape);

  ASSERT_EQ(1, send2->num_inputs());
  Node* ncf2 = *(send2->in_nodes().begin());
  EXPECT_EQ(ncf2, size);

  ASSERT_EQ(1, send3->num_inputs());
  Node* ncf3 = *(send3->in_nodes().begin());
  EXPECT_EQ(ncf3, rank1);
}

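// Tests folding of Rank on a Const whose shape is supplied via shape_map; the
// Const's control dependency is forwarded to the folded constant.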
TEST_F(ConstantFoldingTest, ConstShapeKnown) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0", "sender",
                            0, "receiver");
    auto c0 =
        ops::Const<int>(s.WithOpName("c0").WithControlDependencies(recv0), 1);
    auto rank = ops::Rank(s.WithOpName("rank"), c0);
    auto add0 = ops::Add(s, rank, rank);
    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = g.BuildNodeNameIndex();
  Node* c0 = orig_index.at("c0");
  PartialTensorShape ps0;
  int c0_dims[] = {};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(c0_dims, 0, &ps0));
  std::unordered_map<string, std::vector<PartialTensorShape>> map;
  map[c0->name()].push_back(ps0);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = g.BuildNodeNameIndex();
  Node* recv0 = index.at("recv0");
  Node* send0 = index.at("send0");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {0}, {});

  ASSERT_EQ(1, cf0->in_edges().size());
  for (const Edge* e : cf0->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0) << e->src()->name();
  }
}

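// Tests that a multi-output node (TopK) assigned to a GPU device is not
// constant folded.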
TEST_F(ConstantFoldingTest, NoReplacePartialOutput) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope().ExitOnError().WithAssignedDevice("/gpu:0");

    auto c0 = ops::Const<float>(s.WithOpName("c0"), {5.0, 2.0, 8.0, 1.0}, {4});
    auto k = ops::Const<int>(s.WithOpName("k"), 3);
    auto topK =
        ops::TopK(s.WithOpName("topK"), c0, k, ops::TopK::Sorted(false));
    auto send_values = ops::_Send(s.WithOpName("send_values"), topK.values,
                                  "send_values", "sender", 0, "receiver");
    auto send_indices = ops::_Send(s.WithOpName("send_indices"), topK.indices,
                                   "send_indices", "sender", 0, "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  bool was_mutated;
  TF_EXPECT_OK(ConstantFold(
      ConstantFoldingOptions{}, nullptr, Env::Default(),
      FakeDevice::Make("/job:tpu_worker/replica:0/task:0/device:GPU:0",
                       DEVICE_GPU)
          .get(),
      &g, &was_mutated));
  EXPECT_FALSE(was_mutated);
}

namespace {

const char kTestMemRegionName[] = "test://test";

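// A ReadOnlyMemoryRegion backed by an in-memory buffer.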
class TestReadOnlyMemoryRegion : public ::tensorflow::ReadOnlyMemoryRegion {
 public:
  ~TestReadOnlyMemoryRegion() override = default;
  TestReadOnlyMemoryRegion(const void* data, uint64 length)
      : data_(data), length_(length) {}
  const void* data() override { return data_; }
  uint64 length() override { return length_; }

 protected:
  const void* data_;
  uint64 length_;
};

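// A file system that serves kTestMemRegionName from an in-memory tensor.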
class TestTFFileSystem : public ::tensorflow::NullFileSystem {
 public:
  TestTFFileSystem()
      : ::tensorflow::NullFileSystem(),
        data_tensor_(test::AsTensor<double>({1., 2., 3., 4.}, {2, 2})) {}

  using ::tensorflow::NullFileSystem::NewReadOnlyMemoryRegionFromFile;

  ::tensorflow::Status NewReadOnlyMemoryRegionFromFile(
      const string& fname, ::tensorflow::TransactionToken* token,
      std::unique_ptr<::tensorflow::ReadOnlyMemoryRegion>* result) override {
    if (fname != kTestMemRegionName) {
      return ::tensorflow::errors::Unimplemented(
          "NewReadOnlyMemoryRegionFromFile unimplemented");
    }
    const ::tensorflow::StringPiece sp = data_tensor_.tensor_data();
    *result = std::unique_ptr<::tensorflow::ReadOnlyMemoryRegion>(
        new TestReadOnlyMemoryRegion(sp.data(), sp.size()));
    return ::tensorflow::Status::OK();
  }

 protected:
  ::tensorflow::Tensor data_tensor_;
};

// A test TF environment that checks that the environment was used.
class TestTFEnvironment : public ::tensorflow::EnvWrapper {
 public:
  using tf_base = ::tensorflow::EnvWrapper;
  TestTFEnvironment() : ::tensorflow::EnvWrapper(Default()) {}
  ::tensorflow::Status GetFileSystemForFile(
      const string& fname, ::tensorflow::FileSystem** result) override {
    was_used_ = true;
    if (fname == "test://test") {
      *result = &test_filesystem_;
      return ::tensorflow::Status::OK();
    }
    return tf_base::GetFileSystemForFile(fname, result);
  }
  bool was_used() const { return was_used_; }

 protected:
  TestTFFileSystem test_filesystem_;
  bool was_used_ = false;
};
}  // namespace

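// Tests that folding an ImmutableConst fails with the default environment but
// succeeds with a custom Env that can resolve the backing memory region.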
TEST_F(ConstantFoldingTest, TestImmutableConst) {
  Graph g(OpRegistry::Global());
  Scope root = Scope::NewRootScope();

  auto a = ops::ImmutableConst(root, DT_DOUBLE, {2, 2}, kTestMemRegionName);
  auto b = ops::Const<double>(root, {1.0, 2.0, 3.0, 4.0}, {2, 2});
  auto c = ops::RandomGamma(root, {2, 2}, 2.0);
  auto result1 = ops::MatMul(root, a, b);
  auto result2 = ops::MatMul(root, result1, c);
  TF_ASSERT_OK(root.ToGraph(&g));
  TestTFEnvironment test_env;
  bool was_mutated;
  Status status = ConstantFold(ConstantFoldingOptions{}, nullptr,
                               Env::Default(), nullptr, &g, &was_mutated);
  EXPECT_FALSE(was_mutated);
  EXPECT_FALSE(status.ok());
  TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, &test_env,
                            nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);
}

}  // namespace
}  // namespace tensorflow