• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
17 
18 #include "absl/strings/match.h"
19 #include "tensorflow/cc/ops/const_op.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/resource_variable_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/core/framework/graph_to_functiondef.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class GraphHelper {
33  public:
GraphHelper(const Graph & graph)34   explicit GraphHelper(const Graph& graph) : graph_(graph) {
35     for (Node* node : graph.nodes()) {
36       nodes_by_name_[node->name()] = node;
37     }
38   }
39 
GetNodeByName(const string & name)40   Node* GetNodeByName(const string& name) {
41     const auto it = nodes_by_name_.find(name);
42     if (it != nodes_by_name_.end()) {
43       return it->second;
44     }
45     for (const auto& entry : nodes_by_name_) {
46       if (absl::StartsWith(entry.first, name)) {
47         return entry.second;
48       }
49     }
50     return nullptr;
51   }
52 
SetAssignedDevice(const string & node_name,const string & device_name)53   void SetAssignedDevice(const string& node_name, const string& device_name) {
54     CHECK_NOTNULL(GetNodeByName(node_name))
55         ->set_assigned_device_name(device_name);
56   }
57 
CheckArgNum(const int expected_num)58   void CheckArgNum(const int expected_num) {
59     int arg_num = 0;
60     for (Node* node : graph_.op_nodes()) {
61       if (node->IsArg()) {
62         arg_num++;
63       }
64     }
65     EXPECT_EQ(arg_num, expected_num);
66   }
67 
CheckAssignedDevice(const string & node_name,const string & expected_device_name)68   void CheckAssignedDevice(const string& node_name,
69                            const string& expected_device_name) {
70     EXPECT_EQ(expected_device_name,
71               CHECK_NOTNULL(GetNodeByName(node_name))->assigned_device_name());
72   }
73 
74  private:
75   const Graph& graph_;
76   // Maps from a node name to a Node* in the graph.
77   absl::flat_hash_map<string, Node*> nodes_by_name_;
78 };
79 
// Replicating a graph with one composite device backed by two TPUs: the
// _Arg and the AssignVariableOp placed on the composite device are
// duplicated per replica, while nodes on physical devices stay unchanged.
TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output resource = ops::_Arg(root.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_op =
      ops::ReadVariableOp(root.WithOpName("read"), resource, DT_INT32);
  auto const_one = ops::Const<int32>(root.WithOpName("one"), 1);
  auto assign =
      ops::AssignVariableOp(root.WithOpName("write"), resource, const_one);
  auto retval = ops::_Retval(
      root.WithOpName("ret").WithControlDependencies({assign}), read_op, 0);

  const std::vector<string> replica_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replica_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));
  {
    // Before replication:
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU_COMPOSITE:0);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 5);
    GraphHelper setup(graph);
    setup.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    setup.SetAssignedDevice("read", "TPU:0");
    setup.SetAssignedDevice("one", "CPU:0");
    setup.SetAssignedDevice("write", "TPU_COMPOSITE:0");
    setup.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After replication:
    // _Arg(TPU:0, TPU:1) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU:0, TPU:1);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 7);
    GraphHelper verify(graph);
    verify.CheckArgNum(2);
    verify.CheckAssignedDevice("arg/R0", "TPU:0");
    verify.CheckAssignedDevice("arg/R1", "TPU:1");
    verify.CheckAssignedDevice("read", "TPU:0");
    verify.CheckAssignedDevice("one", "CPU:0");
    verify.CheckAssignedDevice("write/R0", "TPU:0");
    verify.CheckAssignedDevice("write/R1", "TPU:1");
    verify.CheckAssignedDevice("ret", "CPU:0");
  }
}
127 
// A composite device with exactly one underlying device needs no
// replication: the _Arg is simply reassigned to the single TPU and no new
// nodes are introduced.
TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output resource = ops::_Arg(root.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_op =
      ops::ReadVariableOp(root.WithOpName("read"), resource, DT_INT32);
  auto retval = ops::_Retval(root.WithOpName("ret"), read_op, 0);

  const std::vector<string> replica_devices = {"TPU:0"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replica_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));
  {
    // Before: _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper setup(graph);
    setup.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    setup.SetAssignedDevice("read", "TPU:0");
    setup.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After: _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper verify(graph);
    verify.CheckArgNum(1);
    verify.CheckAssignedDevice("arg", "TPU:0");
    verify.CheckAssignedDevice("read", "TPU:0");
    verify.CheckAssignedDevice("ret", "CPU:0");
  }
}
162 
TEST(ReplicatePerReplicaNodesTest,MultipleCompositeDevices)163 TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
164   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
165   Output arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_RESOURCE, 0);
166   Output arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 0);
167   auto read0 = ops::ReadVariableOp(scope.WithOpName("read0"), arg0, DT_INT32);
168   auto read1 = ops::ReadVariableOp(scope.WithOpName("read1"), arg1, DT_INT32);
169   auto identity0 = ops::Identity(scope.WithOpName("identity0"), read0);
170   auto identity1 = ops::Identity(scope.WithOpName("identity1"), read1);
171   auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
172   auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);
173 
174   const std::vector<string> underlying_devices_0 = {"TPU:0", "TPU:1"};
175   const std::vector<string> underlying_devices_1 = {"TPU:2", "TPU:3"};
176   const absl::flat_hash_map<string, const std::vector<string>*>
177       composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0},
178                            {"TPU_COMPOSITE:1", &underlying_devices_1}};
179 
180   Graph graph(OpRegistry::Global());
181   TF_ASSERT_OK(scope.ToGraph(&graph));
182   {
183     // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU_COMPOSITE:0) ->
184     // Identity(TPU:1)
185     // _Arg(TPU_COMPOSITE:1) -> ReadVariableOp(TPU_COMPOSITE:1)
186     // -> Identity(TPU:3)
187     // Identity(TPU:1), Identity(TPU:3) -> Add(TPU:0)-> _Retval(CPU:0)
188     ASSERT_EQ(graph.num_op_nodes(), 8);
189     GraphHelper helper(graph);
190     helper.SetAssignedDevice("arg0", "TPU_COMPOSITE:0");
191     helper.SetAssignedDevice("read0", "TPU_COMPOSITE:0");
192     helper.SetAssignedDevice("identity0", "TPU:1");
193     helper.SetAssignedDevice("arg1", "TPU_COMPOSITE:1");
194     helper.SetAssignedDevice("read1", "TPU_COMPOSITE:1");
195     helper.SetAssignedDevice("identity1", "TPU:3");
196     helper.SetAssignedDevice("add", "TPU:0");
197     helper.SetAssignedDevice("ret", "CPU:0");
198   }
199 
200   TF_EXPECT_OK(
201       ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));
202 
203   {
204     // _Arg(TPU:0, TPU:1, TPU:2, TPU:3) -> ReadVariableOp(TPU:0, TPU:1, TPU:2,
205     // TPU:3) -> Identity(TPU:1, TPU:3) -> Add(TPU:0)-> _Retval(CPU:0)
206     EXPECT_EQ(graph.num_op_nodes(), 12);
207     GraphHelper helper(graph);
208     helper.CheckArgNum(4);
209     helper.CheckAssignedDevice("arg0/R0", "TPU:0");
210     helper.CheckAssignedDevice("arg0/R1", "TPU:1");
211     helper.CheckAssignedDevice("arg1/R0", "TPU:2");
212     helper.CheckAssignedDevice("arg1/R1", "TPU:3");
213     helper.CheckAssignedDevice("read0/R0", "TPU:0");
214     helper.CheckAssignedDevice("read0/R1", "TPU:1");
215     helper.CheckAssignedDevice("read1/R0", "TPU:2");
216     helper.CheckAssignedDevice("read1/R1", "TPU:3");
217     helper.CheckAssignedDevice("identity0", "TPU:1");
218     helper.CheckAssignedDevice("identity1", "TPU:3");
219     helper.CheckAssignedDevice("add", "TPU:0");
220     helper.CheckAssignedDevice("ret", "CPU:0");
221   }
222 }
223 
TEST(ReplicatePerReplicaNodesTest,NestedFunctions)224 TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
225   const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
226   const absl::flat_hash_map<string, const std::vector<string>*>
227       composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};
228 
229   FunctionDefLibrary fdef_lib;
230   FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
231   {
232     Scope scope = Scope::NewRootScope().ExitOnError();
233     auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
234     auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
235     auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);
236     Graph graph(OpRegistry::Global());
237     TF_ASSERT_OK(scope.ToGraph(&graph));
238     GraphHelper helper(graph);
239     helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
240     helper.SetAssignedDevice("read", "TPU:0");
241     helper.SetAssignedDevice("ret", "CPU:0");
242     FunctionDef fdef;
243     TF_ASSERT_OK(GraphToFunctionDef(graph, "Func", &fdef));
244     *fdef_lib.add_function() = fdef;
245     TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
246   }
247 
248   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
249   Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
250   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
251   NodeDef def;
252   TF_ASSERT_OK(NodeDefBuilder("func", "Func", &flib_def)
253                    .Input(arg.name(), 0, DT_RESOURCE)
254                    .Finalize(&def));
255   Status status;
256   Node* func = scope.graph()->AddNode(def, &status);
257   TF_ASSERT_OK(status);
258   scope.graph()->AddEdge(arg.node(), 0, func, 0);
259   auto ret = ops::_Retval(scope.WithOpName("ret"), Output(func), 0);
260   Graph graph(OpRegistry::Global());
261   TF_ASSERT_OK(scope.ToGraph(&graph));
262   {
263     // _Arg(TPU_COMPOSITE:0) -> Func(CPU:0) -> _Retval(CPU:0)
264     GraphHelper helper(graph);
265     EXPECT_EQ(graph.num_op_nodes(), 3);
266     helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
267     helper.SetAssignedDevice("func", "CPU:0");
268     helper.SetAssignedDevice("ret", "CPU:0");
269   }
270 
271   TF_EXPECT_OK(
272       ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));
273 
274   {
275     // _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
276     EXPECT_EQ(graph.num_op_nodes(), 5);
277     GraphHelper helper(graph);
278     helper.CheckArgNum(2);
279     helper.CheckAssignedDevice("arg/R0", "TPU:0");
280     helper.CheckAssignedDevice("arg/R1", "TPU:1");
281     helper.CheckAssignedDevice("arg/Packed", "CPU:0");
282     helper.CheckAssignedDevice("func", "CPU:0");
283     helper.CheckAssignedDevice("ret", "CPU:0");
284     const EdgeSet& packed_in_edges =
285         helper.GetNodeByName("arg/Packed")->in_edges();
286     EXPECT_EQ(packed_in_edges.size(), 2);
287     auto it = packed_in_edges.begin();
288     EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src());
289     EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src());
290     const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges();
291     EXPECT_EQ(func_in_edges.size(), 1);
292     EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
293               (*func_in_edges.begin())->src());
294   }
295 }
296 
// If a replica of the _Arg ends up with no consumers, it is dead and the
// pass removes it: only "arg/R0" (feeding the read on TPU:0) survives.
TEST(ReplicatePerReplicaNodesTest, DeadArgNodes) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  Output resource = ops::_Arg(root.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_op =
      ops::ReadVariableOp(root.WithOpName("read"), resource, DT_INT32);
  auto retval = ops::_Retval(root.WithOpName("ret"), read_op, 0);

  const std::vector<string> replica_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replica_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));
  {
    // Before: _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper setup(graph);
    setup.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    setup.SetAssignedDevice("read", "TPU:0");
    setup.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After: _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    // "arg/R1" is a dead node, so gets removed.
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper verify(graph);
    verify.CheckArgNum(1);
    verify.CheckAssignedDevice("arg/R0", "TPU:0");
    verify.CheckAssignedDevice("read", "TPU:0");
    verify.CheckAssignedDevice("ret", "CPU:0");
  }
}
332 
333 }  // namespace
334 }  // namespace tensorflow
335