1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
17
18 #include "absl/strings/match.h"
19 #include "tensorflow/cc/ops/const_op.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/resource_variable_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/core/framework/graph_to_functiondef.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30 namespace {
31
32 class GraphHelper {
33 public:
GraphHelper(const Graph & graph)34 explicit GraphHelper(const Graph& graph) : graph_(graph) {
35 for (Node* node : graph.nodes()) {
36 nodes_by_name_[node->name()] = node;
37 }
38 }
39
GetNodeByName(const string & name)40 Node* GetNodeByName(const string& name) {
41 const auto it = nodes_by_name_.find(name);
42 if (it != nodes_by_name_.end()) {
43 return it->second;
44 }
45 for (const auto& entry : nodes_by_name_) {
46 if (absl::StartsWith(entry.first, name)) {
47 return entry.second;
48 }
49 }
50 return nullptr;
51 }
52
SetAssignedDevice(const string & node_name,const string & device_name)53 void SetAssignedDevice(const string& node_name, const string& device_name) {
54 CHECK_NOTNULL(GetNodeByName(node_name))
55 ->set_assigned_device_name(device_name);
56 }
57
CheckArgNum(const int expected_num)58 void CheckArgNum(const int expected_num) {
59 int arg_num = 0;
60 for (Node* node : graph_.op_nodes()) {
61 if (node->IsArg()) {
62 arg_num++;
63 }
64 }
65 EXPECT_EQ(arg_num, expected_num);
66 }
67
CheckAssignedDevice(const string & node_name,const string & expected_device_name)68 void CheckAssignedDevice(const string& node_name,
69 const string& expected_device_name) {
70 EXPECT_EQ(expected_device_name,
71 CHECK_NOTNULL(GetNodeByName(node_name))->assigned_device_name());
72 }
73
74 private:
75 const Graph& graph_;
76 // Maps from a node name to a Node* in the graph.
77 absl::flat_hash_map<string, Node*> nodes_by_name_;
78 };
79
// Replicating a graph with one composite device backed by two TPUs: the _Arg
// and the AssignVariableOp placed on the composite device are duplicated per
// replica, while nodes on concrete devices are left alone.
TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_var = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
  auto const_one = ops::Const<int32>(scope.WithOpName("one"), 1);
  auto assign = ops::AssignVariableOp(scope.WithOpName("write"), arg,
                                      const_one);
  auto retval = ops::_Retval(
      scope.WithOpName("ret").WithControlDependencies({assign}), read_var, 0);

  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // Before replication:
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU_COMPOSITE:0);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 5);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("one", "CPU:0");
    helper.SetAssignedDevice("write", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After replication:
    // _Arg(TPU:0, TPU:1) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU:0, TPU:1);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 7);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("arg/R1", "TPU:1");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("one", "CPU:0");
    helper.CheckAssignedDevice("write/R0", "TPU:0");
    helper.CheckAssignedDevice("write/R1", "TPU:1");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
127
// A composite device backed by a single underlying device: no nodes need to
// be duplicated; the _Arg is simply re-assigned to the one underlying device.
TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_var = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("ret"), read_var, 0);

  const std::vector<string> underlying_devices = {"TPU:0"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // Before: _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After: _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.CheckArgNum(1);
    helper.CheckAssignedDevice("arg", "TPU:0");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
162
// Two independent composite devices in one graph: replication resolves each
// composite placement against its own list of underlying devices, and the
// per-replica nodes that feed concrete-device consumers survive while unused
// replicas are pruned.
TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_RESOURCE, 0);
  // Each _Arg in a function graph must carry a distinct index; the second
  // argument therefore uses index 1 (not 0).
  Output arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
  auto read0 = ops::ReadVariableOp(scope.WithOpName("read0"), arg0, DT_INT32);
  auto read1 = ops::ReadVariableOp(scope.WithOpName("read1"), arg1, DT_INT32);
  auto identity0 = ops::Identity(scope.WithOpName("identity0"), read0);
  auto identity1 = ops::Identity(scope.WithOpName("identity1"), read1);
  auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
  auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);

  const std::vector<string> underlying_devices_0 = {"TPU:0", "TPU:1"};
  const std::vector<string> underlying_devices_1 = {"TPU:2", "TPU:3"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices_0},
                           {"TPU_COMPOSITE:1", &underlying_devices_1}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // Before replication:
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU_COMPOSITE:0) ->
    // Identity(TPU:1)
    // _Arg(TPU_COMPOSITE:1) -> ReadVariableOp(TPU_COMPOSITE:1)
    // -> Identity(TPU:3)
    // Identity(TPU:1), Identity(TPU:3) -> Add(TPU:0)-> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 8);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg0", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read0", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("identity0", "TPU:1");
    helper.SetAssignedDevice("arg1", "TPU_COMPOSITE:1");
    helper.SetAssignedDevice("read1", "TPU_COMPOSITE:1");
    helper.SetAssignedDevice("identity1", "TPU:3");
    helper.SetAssignedDevice("add", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After replication: only the replicas consumed downstream remain.
    // _Arg(TPU:0, TPU:3) -> ReadVariableOp(TPU:1, TPU:3) -> Identity(TPU:1,
    // TPU:3) -> Add(TPU:0)-> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 8);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg0/R1", "TPU:1");
    helper.CheckAssignedDevice("arg1/R1", "TPU:3");
    helper.CheckAssignedDevice("read0/R1", "TPU:1");
    helper.CheckAssignedDevice("read1/R1", "TPU:3");
    helper.CheckAssignedDevice("identity0", "TPU:1");
    helper.CheckAssignedDevice("identity1", "TPU:3");
    helper.CheckAssignedDevice("add", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
219
// When the consumer of a composite-device resource is a function call, the
// replicated _Args are packed back into a single tensor (a "Pack" node) that
// feeds the function, since the function body is not rewritten here.
TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  FunctionDefLibrary fdef_lib;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  {
    // Build the nested function "Func", which reads a resource argument.
    Scope func_scope = Scope::NewRootScope().ExitOnError();
    auto func_arg = ops::_Arg(func_scope.WithOpName("arg"), DT_RESOURCE, 0);
    auto func_read =
        ops::ReadVariableOp(func_scope.WithOpName("read"), func_arg, DT_INT32);
    auto func_ret = ops::_Retval(func_scope.WithOpName("ret"), func_read, 0);
    Graph func_graph(OpRegistry::Global());
    TF_ASSERT_OK(func_scope.ToGraph(&func_graph));
    GraphHelper func_helper(func_graph);
    func_helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    func_helper.SetAssignedDevice("read", "TPU:0");
    func_helper.SetAssignedDevice("ret", "CPU:0");
    FunctionDef fdef;
    TF_ASSERT_OK(GraphToFunctionDef(func_graph, "Func", &fdef));
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
  }

  // Build the outer graph: a resource _Arg feeding a call to "Func".
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
  NodeDef func_def;
  TF_ASSERT_OK(NodeDefBuilder("func", "Func", &flib_def)
                   .Input(arg.name(), 0, DT_RESOURCE)
                   .Finalize(&func_def));
  Status add_status;
  Node* func_node = scope.graph()->AddNode(func_def, &add_status);
  TF_ASSERT_OK(add_status);
  scope.graph()->AddEdge(arg.node(), 0, func_node, 0);
  auto ret = ops::_Retval(scope.WithOpName("ret"), Output(func_node), 0);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // Before: _Arg(TPU_COMPOSITE:0) -> Func(CPU:0) -> _Retval(CPU:0)
    GraphHelper helper(graph);
    EXPECT_EQ(graph.num_op_nodes(), 3);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("func", "CPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After:
    // _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 5);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("arg/R1", "TPU:1");
    helper.CheckAssignedDevice("arg/Packed", "CPU:0");
    helper.CheckAssignedDevice("func", "CPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
    // The Pack node must take both replicas as inputs, in order.
    const EdgeSet& pack_inputs = helper.GetNodeByName("arg/Packed")->in_edges();
    EXPECT_EQ(pack_inputs.size(), 2);
    auto edge_it = pack_inputs.begin();
    EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*edge_it++)->src());
    EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*edge_it)->src());
    // The function call's only data input is the packed tensor.
    const EdgeSet& func_inputs = helper.GetNodeByName("func")->in_edges();
    EXPECT_EQ(func_inputs.size(), 1);
    EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
              (*func_inputs.begin())->src());
  }
}
292
// A replica of the _Arg with no consumers ("arg/R1") is dead after
// replication and must be removed from the graph.
TEST(ReplicatePerReplicaNodesTest, DeadArgNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_var = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("ret"), read_var, 0);

  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // Before: _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // After: _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    // "arg/R1" is a dead node, so gets removed.
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.CheckArgNum(1);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
328
329 } // namespace
330 } // namespace tensorflow
331