/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"

#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

32 class GraphHelper {
33 public:
GraphHelper(const Graph & graph)34 explicit GraphHelper(const Graph& graph) : graph_(graph) {
35 for (Node* node : graph.nodes()) {
36 nodes_by_name_[node->name()] = node;
37 }
38 }
39
GetNodeByName(const string & name)40 Node* GetNodeByName(const string& name) {
41 const auto it = nodes_by_name_.find(name);
42 if (it != nodes_by_name_.end()) {
43 return it->second;
44 }
45 for (const auto& entry : nodes_by_name_) {
46 if (absl::StartsWith(entry.first, name)) {
47 return entry.second;
48 }
49 }
50 return nullptr;
51 }
52
SetAssignedDevice(const string & node_name,const string & device_name)53 void SetAssignedDevice(const string& node_name, const string& device_name) {
54 CHECK_NOTNULL(GetNodeByName(node_name))
55 ->set_assigned_device_name(device_name);
56 }
57
CheckArgNum(const int expected_num)58 void CheckArgNum(const int expected_num) {
59 int arg_num = 0;
60 for (Node* node : graph_.op_nodes()) {
61 if (node->IsArg()) {
62 arg_num++;
63 }
64 }
65 EXPECT_EQ(arg_num, expected_num);
66 }
67
CheckAssignedDevice(const string & node_name,const string & expected_device_name)68 void CheckAssignedDevice(const string& node_name,
69 const string& expected_device_name) {
70 EXPECT_EQ(expected_device_name,
71 CHECK_NOTNULL(GetNodeByName(node_name))->assigned_device_name());
72 }
73
74 private:
75 const Graph& graph_;
76 // Maps from a node name to a Node* in the graph.
77 absl::flat_hash_map<string, Node*> nodes_by_name_;
78 };
79
TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Build: arg -> read -> ret, plus a write of the constant 1 that the
  // retval waits on via a control dependency.
  Output resource_arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_var =
      ops::ReadVariableOp(scope.WithOpName("read"), resource_arg, DT_INT32);
  auto one_const = ops::Const<int32>(scope.WithOpName("one"), 1);
  auto write_var =
      ops::AssignVariableOp(scope.WithOpName("write"), resource_arg, one_const);
  auto retval = ops::_Retval(
      scope.WithOpName("ret").WithControlDependencies({write_var}), read_var,
      0);

  // "TPU_COMPOSITE:0" stands for the pair of replica devices {TPU:0, TPU:1}.
  const std::vector<string> replica_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replica_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU_COMPOSITE:0);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 5);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("one", "CPU:0");
    helper.SetAssignedDevice("write", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // The arg and the write, both on the composite device, are replicated
    // once per underlying device; everything else stays put.
    // _Arg(TPU:0, TPU:1) -> ReadVariableOp(TPU:0);
    // Const(CPU:0) -> AssignVariableOp(TPU:0, TPU:1);
    // ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 7);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("arg/R1", "TPU:1");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("one", "CPU:0");
    helper.CheckAssignedDevice("write/R0", "TPU:0");
    helper.CheckAssignedDevice("write/R1", "TPU:1");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
127
TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Simple read-only pipeline: arg -> read -> ret.
  Output resource_arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_var =
      ops::ReadVariableOp(scope.WithOpName("read"), resource_arg, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("ret"), read_var, 0);

  // A composite device backed by a single underlying device.
  const std::vector<string> replica_devices = {"TPU:0"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replica_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // With one replica there is nothing to duplicate: the arg keeps its
    // name and is simply re-assigned to the lone underlying device.
    // _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.CheckArgNum(1);
    helper.CheckAssignedDevice("arg", "TPU:0");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
162
TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Two resource args; each is read and passed through an Identity before
  // the two values are added and returned.
  Output resource0 = ops::_Arg(scope.WithOpName("arg0"), DT_RESOURCE, 0);
  Output resource1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 0);
  auto value0 =
      ops::ReadVariableOp(scope.WithOpName("read0"), resource0, DT_INT32);
  auto value1 =
      ops::ReadVariableOp(scope.WithOpName("read1"), resource1, DT_INT32);
  auto passthrough0 = ops::Identity(scope.WithOpName("identity0"), value0);
  auto passthrough1 = ops::Identity(scope.WithOpName("identity1"), value1);
  auto sum = ops::Add(scope.WithOpName("add"), passthrough0, passthrough1);
  auto retval = ops::_Retval(scope.WithOpName("ret"), sum, 0);

  // Two distinct composite devices, two replicas each.
  const std::vector<string> replicas_0 = {"TPU:0", "TPU:1"};
  const std::vector<string> replicas_1 = {"TPU:2", "TPU:3"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replicas_0},
                           {"TPU_COMPOSITE:1", &replicas_1}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU_COMPOSITE:0) ->
    // Identity(TPU:1)
    // _Arg(TPU_COMPOSITE:1) -> ReadVariableOp(TPU_COMPOSITE:1)
    // -> Identity(TPU:3)
    // Identity(TPU:1), Identity(TPU:3) -> Add(TPU:0)-> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 8);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg0", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read0", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("identity0", "TPU:1");
    helper.SetAssignedDevice("arg1", "TPU_COMPOSITE:1");
    helper.SetAssignedDevice("read1", "TPU_COMPOSITE:1");
    helper.SetAssignedDevice("identity1", "TPU:3");
    helper.SetAssignedDevice("add", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // Each arg/read pair on a composite device becomes one node per replica:
    // _Arg(TPU:0, TPU:1, TPU:2, TPU:3) -> ReadVariableOp(TPU:0, TPU:1, TPU:2,
    // TPU:3) -> Identity(TPU:1, TPU:3) -> Add(TPU:0)-> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 12);
    GraphHelper helper(graph);
    helper.CheckArgNum(4);
    helper.CheckAssignedDevice("arg0/R0", "TPU:0");
    helper.CheckAssignedDevice("arg0/R1", "TPU:1");
    helper.CheckAssignedDevice("arg1/R0", "TPU:2");
    helper.CheckAssignedDevice("arg1/R1", "TPU:3");
    helper.CheckAssignedDevice("read0/R0", "TPU:0");
    helper.CheckAssignedDevice("read0/R1", "TPU:1");
    helper.CheckAssignedDevice("read1/R0", "TPU:2");
    helper.CheckAssignedDevice("read1/R1", "TPU:3");
    helper.CheckAssignedDevice("identity0", "TPU:1");
    helper.CheckAssignedDevice("identity1", "TPU:3");
    helper.CheckAssignedDevice("add", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}
223
// Verifies that when a composite-device resource arg feeds a function-call
// node, replication introduces per-replica args plus a "Packed" node that
// bundles them back into the single resource input the function expects.
TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
  const std::vector<string> underlying_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &underlying_devices}};

  FunctionDefLibrary fdef_lib;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  {
    // Build the body of function "Func": read the resource arg (itself
    // assigned to the composite device) and return the value.
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
    auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
    auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);
    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(scope.ToGraph(&graph));
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
    FunctionDef fdef;
    // Register "Func" both in the library proto (added to the outer graph
    // below) and in flib_def (used by NodeDefBuilder to resolve the op).
    TF_ASSERT_OK(GraphToFunctionDef(graph, "Func", &fdef));
    *fdef_lib.add_function() = fdef;
    TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
  }

  // Outer graph: a resource _Arg wired by hand into a call to "Func".
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("func", "Func", &flib_def)
                   .Input(arg.name(), 0, DT_RESOURCE)
                   .Finalize(&def));
  Status status;
  // The call node is added manually (no generated C++ wrapper exists for
  // "Func"), so the data edge must be added explicitly as well.
  Node* func = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  scope.graph()->AddEdge(arg.node(), 0, func, 0);
  auto ret = ops::_Retval(scope.WithOpName("ret"), Output(func), 0);
  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> Func(CPU:0) -> _Retval(CPU:0)
    GraphHelper helper(graph);
    EXPECT_EQ(graph.num_op_nodes(), 3);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("func", "CPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:0), _Arg(TPU:1) -> Pack(CPU:0) -> Func(CPU:0) -> _Retval(CPU:0)
    EXPECT_EQ(graph.num_op_nodes(), 5);
    GraphHelper helper(graph);
    helper.CheckArgNum(2);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("arg/R1", "TPU:1");
    helper.CheckAssignedDevice("arg/Packed", "CPU:0");
    helper.CheckAssignedDevice("func", "CPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
    // NOTE(review): the next three EXPECTs assume in_edges() enumerates the
    // packed node's two inputs in replica order (R0 then R1) — confirm the
    // EdgeSet iteration-order guarantee if this ever flakes.
    const EdgeSet& packed_in_edges =
        helper.GetNodeByName("arg/Packed")->in_edges();
    EXPECT_EQ(packed_in_edges.size(), 2);
    auto it = packed_in_edges.begin();
    EXPECT_EQ(helper.GetNodeByName("arg/R0"), (*it++)->src());
    EXPECT_EQ(helper.GetNodeByName("arg/R1"), (*it)->src());
    // The call node's only data input must now come from the packed node.
    const EdgeSet& func_in_edges = helper.GetNodeByName("func")->in_edges();
    EXPECT_EQ(func_in_edges.size(), 1);
    EXPECT_EQ(helper.GetNodeByName("arg/Packed"),
              (*func_in_edges.begin())->src());
  }
}
296
TEST(ReplicatePerReplicaNodesTest, DeadArgNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // Only the replica-0 value is consumed downstream, so the replica-1 arg
  // produced by replication has no users.
  Output resource_arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto read_var =
      ops::ReadVariableOp(scope.WithOpName("read"), resource_arg, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("ret"), read_var, 0);

  const std::vector<string> replica_devices = {"TPU:0", "TPU:1"};
  const absl::flat_hash_map<string, const std::vector<string>*>
      composite_devices = {{"TPU_COMPOSITE:0", &replica_devices}};

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(&graph));
  {
    // _Arg(TPU_COMPOSITE:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    ASSERT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.SetAssignedDevice("arg", "TPU_COMPOSITE:0");
    helper.SetAssignedDevice("read", "TPU:0");
    helper.SetAssignedDevice("ret", "CPU:0");
  }

  TF_EXPECT_OK(
      ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));

  {
    // _Arg(TPU:0) -> ReadVariableOp(TPU:0) -> _Retval(CPU:0)
    // "arg/R1" is a dead node, so gets removed.
    EXPECT_EQ(graph.num_op_nodes(), 3);
    GraphHelper helper(graph);
    helper.CheckArgNum(1);
    helper.CheckAssignedDevice("arg/R0", "TPU:0");
    helper.CheckAssignedDevice("read", "TPU:0");
    helper.CheckAssignedDevice("ret", "CPU:0");
  }
}

}  // namespace
}  // namespace tensorflow