/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

class GraphHelper {
 public:
  explicit GraphHelper(const Graph& graph) : graph_(graph) {
    for (Node* node : graph.nodes()) {
      nodes_by_name_[node->name()] = node;
    }
  }

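  // Returns the node named `name`; if there is no exact match, returns the
  // first node whose name starts with `name`, or nullptr if none exists.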
  Node* GetNodeByName(const string& name) {
    const auto it = nodes_by_name_.find(name);
    if (it != nodes_by_name_.end()) {
      return it->second;
    }
    for (const auto& entry : nodes_by_name_) {
      if (absl::StartsWith(entry.first, name)) {
        return entry.second;
      }
    }
    return nullptr;
  }

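  // Assigns `device_name` to the node named `node_name`; CHECK-fails if no
  // such node exists.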
  void SetAssignedDevice(const string& node_name, const string& device_name) {
    CHECK_NOTNULL(GetNodeByName(node_name))
        ->set_assigned_device_name(device_name);
  }

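  // Expects the graph to contain exactly `expected_num` _Arg nodes.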
  void CheckArgNum(const int expected_num) {
    int arg_num = 0;
    for (Node* node : graph_.op_nodes()) {
      if (node->IsArg()) {
        arg_num++;
      }
    }
    EXPECT_EQ(arg_num, expected_num);
  }

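  // Expects the node named `node_name` to be assigned to
  // `expected_device_name`.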
  void CheckAssignedDevice(const string& node_name,
                           const string& expected_device_name) {
    EXPECT_EQ(expected_device_name,
              CHECK_NOTNULL(GetNodeByName(node_name))->assigned_device_name());
  }

 private:
  const Graph& graph_;
  // Maps from a node name to a Node* in the graph.
  absl::flat_hash_map<string, Node*> nodes_by_name_;
};

TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
  auto one = ops::Const<int32>(scope.WithOpName("one"), 1);
  auto write = ops::AssignVariableOp(scope.WithOpName("write"), arg, one);
  auto ret = ops::_Retval(
      scope.WithOpName("ret").WithControlDependencies({write}), read, 0);

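  // "TPU_COMPOSITE:0" is a composite device backed by two replica devices,
  // TPU:0 and TPU:1.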
  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU_COMPOSITE:0);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 5);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("one", "CPU:0");
    helper.SetAssignedDevice("write", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:0, TPU:1) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU:0, TPU:1);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 7);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("arg/R1", "TPU:1");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("one", "CPU:0");
    helper.CheckAssignedDevice("write/R0", "TPU:0");
    helper.CheckAssignedDevice("write/R1", "TPU:1");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}

TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
  auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);

  const std::vector<string> underlying_devices = {"TPU:0"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.CheckArgNum(1);
    helper.CheckAssignedDevice("arg", "TPU:0");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}

TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_RESOURCE, 0);
  Output arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 0);
  auto read0 = ops::ReadVariableOp(scope.WithOpName("read0"), arg0, DT_INT32);
  auto read1 = ops::ReadVariableOp(scope.WithOpName("read1"), arg1, DT_INT32);
  auto identity0 = ops::Identity(scope.WithOpName("identity0"), read0);
  auto identity1 = ops::Identity(scope.WithOpName("identity1"), read1);
  auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
  auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);

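  // Two composite devices, each backed by a different pair of TPU replicas.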
  const std::vector<string> underlying_devices_0 = {"TPU:0", "TPU:1"};
  const std::vector<string> underlying_devices_1 = {"TPU:2", "TPU:3"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0},
                           {"TPU_COMPOSITE:1", &underlying_devices_1}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU_COMPOSITE:0) ->
    // Identity(TPU:1);
    // _Arg(TPU_COMPOSITE:1) -> ReadVariableOp(TPU_COMPOSITE:1) ->
    // Identity(TPU:3);
    // Identity(TPU:1), Identity(TPU:3) -> Add(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 8);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg0", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read0", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("identity0", "TPU:1");
    helper.SetAssignedDevice("arg1", "TPU_COMPOSITE:1");
    helper.SetAssignedDevice("read1", "TPU_COMPOSITE:1");
    helper.SetAssignedDevice("identity1", "TPU:3");
    helper.SetAssignedDevice("add", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:1, TPU:3) -> ReadVariableOp(TPU:1, TPU:3) -> Identity(TPU:1,
    // TPU:3) -> Add(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 8);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg0/R1", "TPU:1");
    helper.CheckAssignedDevice("arg1/R1", "TPU:3");
    helper.CheckAssignedDevice("read0/R1", "TPU:1");
    helper.CheckAssignedDevice("read1/R1", "TPU:3");
    helper.CheckAssignedDevice("identity0", "TPU:1");
    helper.CheckAssignedDevice("identity1", "TPU:3");
    helper.CheckAssignedDevice("add", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}

TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  FunctionDefLibrary fdef_lib;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
    auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
    auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
    FunctionDef fdef;
    TF_ASSERT_OK(GraphToFunctionDef(graph, "Func", &fdef));
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
  }

  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
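  // Manually add a call node "func" that invokes the library function "Func",
  // with the resource _Arg as its input.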
  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("func", "Func", &flib_def)
                   .Input(arg.name(), 0, DT_RESOURCE)
                   .Finalize(&def));
  Status status;
  Node* func = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  scope.graph()->AddEdge(arg.node(), 0, func, 0);
  auto ret = ops::_Retval(scope.WithOpName("ret"), Output(func), 0);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> Func(CPU:0) -> _Retval(CPU:0)
    GraphHelper helper(graph);
    EXPECT_EQ(graph.num_op_nodes(), 3);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("func", "CPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 5);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("arg/R1", "TPU:1");
    helper.CheckAssignedDevice("arg/Packed", "CPU:0");
    helper.CheckAssignedDevice("func", "CPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
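    // The "arg/Packed" node should consume the two replicated _Arg nodes in
    // replica order and feed the function call node.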
    const EdgeSet& packed_in_edges =
        helper.GetNodeByName("arg/Packed")->in_edges();
    EXPECT_EQ(packed_in_edges.size(), 2);
    auto it = packed_in_edges.begin();
    EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src());
    EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src());
    const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges();
    EXPECT_EQ(func_in_edges.size(), 1);
    EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
              (*func_in_edges.begin())->src());
  }
}

TEST(ReplicatePerReplicaNodesTest, DeadArgNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
  auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);

  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    // "arg/R1" is a dead node, so it gets removed.
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.CheckArgNum(1);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}

}  // namespace
}  // namespace tensorflow