/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/jit/xla_cluster_util.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/cc/framework/ops.h"
21 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
22 #include "tensorflow/cc/ops/function_ops.h"
23 #include "tensorflow/cc/ops/functional_ops.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
28 #include "tensorflow/core/framework/function_testlib.h"
29 #include "tensorflow/core/framework/graph_to_functiondef.h"
30 #include "tensorflow/core/graph/algorithm.h"
31 #include "tensorflow/core/graph/graph_def_builder.h"
32 #include "tensorflow/core/graph/testlib.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/public/version.h"
36
37 namespace tensorflow {
38 namespace {
39
TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
  // Builds a -> enter -> exit -> b alongside the direct edge a -> b.  The
  // cycle detection graph must see the connectivity through the Enter/Exit
  // frame, so contracting the direct a -> b edge is rejected.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
  Output enter =
      ops::internal::Enter(root.WithOpName("enter"), a, "only_frame");
  Output exit = ops::internal::Exit(root.WithOpName("exit"), enter);
  Output b = ops::Add(root.WithOpName("b"), a, exit);

  FixupSourceAndSinkEdges(root.graph());

  GraphCycles cycles;
  TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
  EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
}
55
TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
  // Same as ConnectivityThroughEnterExitRegion, but `a` feeds two distinct
  // frames ("frame_0" and "frame_1"); `b` depends on `a` both directly and
  // through the second frame, so the a -> b edge must not be contractible.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
  Output enter_0 =
      ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0");
  Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0);
  Output enter_1 =
      ops::internal::Enter(root.WithOpName("enter_1"), a, "frame_1");
  Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);
  Output b = ops::Add(root.WithOpName("b"), a, exit_1);

  FixupSourceAndSinkEdges(root.graph());

  GraphCycles cycles;
  TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
  EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
}
74
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
  // The output of one Enter/Exit region feeds a second Enter of the *same*
  // frame name ("frame_0").  CreateCycleDetectionGraph cannot represent this
  // today and is expected to report failure by returning false.
  // TODO(b/127521408): We can lift this limitation with some work.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
  Output enter_0 =
      ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0");
  Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0);

  Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0);

  Output enter_1 =
      ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0");
  Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);

  FixupSourceAndSinkEdges(root.graph());

  GraphCycles cycles;
  TF_ASSERT_OK_AND_ASSIGN(bool ok,
                          CreateCycleDetectionGraph(root.graph(), &cycles));
  EXPECT_FALSE(ok);
}
97
// Fully qualified device names used by the IsSingleGpuGraph tests below.
// Declared with top-level const so the pointers themselves cannot be
// reseated by accident.
const char* const kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0";
const char* const kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0";
const char* const kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
101
TEST(IsSingleGpuGraph, ReturnsTrue) {
  // All nodes are assigned to the same GPU device, so the graph counts as a
  // single-GPU graph.
  Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();

  Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
  Output b = ops::Add(root.WithOpName("b"), a, a);
  Output c = ops::Add(root.WithOpName("c"), b, b);

  FixupSourceAndSinkEdges(root.graph());

  EXPECT_TRUE(IsSingleGpuGraph(*root.graph()));
}
113
TEST(IsSingleGpuGraph, ReturnsFalseForCpuGraph) {
  // The graph lives entirely on a CPU device, so it is not a single-GPU
  // graph even though only one device is involved.
  Scope root = Scope::NewRootScope().WithAssignedDevice(kCPU0).ExitOnError();

  Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
  Output b = ops::Add(root.WithOpName("b"), a, a);
  Output c = ops::Add(root.WithOpName("c"), b, b);

  FixupSourceAndSinkEdges(root.graph());

  EXPECT_FALSE(IsSingleGpuGraph(*root.graph()));
}
125
TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) {
  // One node ("b") is assigned to a second GPU, so the graph spans multiple
  // GPUs and must not be classified as single-GPU.
  Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();

  Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
  Output b = ops::Add(root.WithOpName("b").WithAssignedDevice(kGPU1), a, a);
  Output c = ops::Add(root.WithOpName("c"), b, b);

  FixupSourceAndSinkEdges(root.graph());

  EXPECT_FALSE(IsSingleGpuGraph(*root.graph()));
}
137
GetNodesRelatedToRefVarsSorted(const Scope & scope,FunctionLibraryDefinition * flib_def=nullptr)138 StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
139 const Scope& scope, FunctionLibraryDefinition* flib_def = nullptr) {
140 FunctionDefLibrary flib;
141 FunctionLibraryDefinition flib_def_local(OpRegistry::Global(), flib);
142 if (flib_def == nullptr) {
143 flib_def = &flib_def_local;
144 }
145
146 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
147
148 TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
149
150 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
151 new ProcessFunctionLibraryRuntime(
152 nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
153 flib_def, OptimizerOptions{}));
154 FunctionLibraryRuntime* lib_runtime =
155 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
156
157 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> nodes_related_to_ref_vars,
158 GetNodesRelatedToRefVariables(*graph, lib_runtime));
159
160 std::vector<string> names;
161 absl::c_transform(nodes_related_to_ref_vars, std::back_inserter(names),
162 [](Node* n) { return n->name(); });
163 absl::c_sort(names);
164 return names;
165 }
166
CreateSubgraphTouchingRefVar(const Scope & s)167 void CreateSubgraphTouchingRefVar(const Scope& s) {
168 Output variable =
169 ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
170 Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
171 Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
172 Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
173
174 Output constant =
175 ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
176 s.graph()->AddControlEdge(constant.node(), variable.node());
177 }
178
CreateSubgraphNotTouchingRefVar(const Scope & s)179 void CreateSubgraphNotTouchingRefVar(const Scope& s) {
180 Output constant =
181 ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
182 Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
183 Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
184 }
185
CreateSubgraphCallingFunctionWithRefVar(const Scope & s)186 void CreateSubgraphCallingFunctionWithRefVar(const Scope& s) {
187 NameAttrList ref_float_function;
188 ref_float_function.set_name("RefFloatFn");
189 ops::PartitionedCall call(s.WithOpName("RefFloat"), {absl::Span<Input>{}},
190 {DT_FLOAT}, ref_float_function);
191 Output constant =
192 ops::Const(s.WithOpName("constant_ref_pco"), Input::Initializer(0.0));
193 s.graph()->AddControlEdge(call.operation.node(), constant.node());
194 }
195
CreateSubgraphCallingFunctionWithoutRefVar(const Scope & s)196 void CreateSubgraphCallingFunctionWithoutRefVar(const Scope& s) {
197 NameAttrList regular_float_function;
198 regular_float_function.set_name("RegularFloatFn");
199 ops::PartitionedCall call(s.WithOpName("RegularFloat"), {absl::Span<Input>{}},
200 {DT_FLOAT}, regular_float_function);
201 Output constant =
202 ops::Const(s.WithOpName("constant_normal_pco"), Input::Initializer(0.0));
203 s.graph()->AddControlEdge(call.operation.node(), constant.node());
204 }
205
AddRefFunctionFunctionDef(FunctionDefLibrary * fdef_lib)206 void AddRefFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
207 FunctionDef make_ref_float = FunctionDefHelper::Define(
208 "RefFloatFn", {}, {"r:float"}, {},
209 {{{"var"},
210 "VariableV2",
211 {},
212 {{"dtype", DT_FLOAT}, {"shape", TensorShape({})}}},
213 {{"r"}, "Identity", {"var"}, {{"T", DT_FLOAT}}}});
214 *fdef_lib->add_function() = make_ref_float;
215 }
216
AddRegularFunctionFunctionDef(FunctionDefLibrary * fdef_lib)217 void AddRegularFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
218 Tensor seven(DT_FLOAT, {});
219 seven.scalar<float>()() = 7;
220 FunctionDef make_regular_float = FunctionDefHelper::Define(
221 "RegularFloatFn", {}, {"r:float"}, {},
222 {{{"r"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", seven}}}});
223 *fdef_lib->add_function() = make_regular_float;
224 }
225
TEST(NodesRelatedToRefVariables, Basic) {
  // Builds four side-by-side subgraphs — one touching a ref variable
  // directly, one not, one calling a ref-using function, and one calling a
  // ref-free function — and checks that exactly the ref-related nodes are
  // reported.
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib;

  CreateSubgraphTouchingRefVar(root);
  CreateSubgraphNotTouchingRefVar(root);

  AddRefFunctionFunctionDef(&fdef_lib);
  CreateSubgraphCallingFunctionWithRefVar(root);

  AddRegularFunctionFunctionDef(&fdef_lib);
  CreateSubgraphCallingFunctionWithoutRefVar(root);

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);

  TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
                          GetNodesRelatedToRefVarsSorted(root, &flib_def));

  // Sorted list: the direct ref-variable chain, plus the ref-using call
  // ("RefFloat") and its control-successor constant.  None of the "_normal"
  // nodes appear.
  std::vector<string> expected({
      "RefFloat",
      "add_ref",
      "constant_ref",
      "constant_ref_pco",
      "negate_ref",
      "read_ref_var",
      "variable",
  });

  EXPECT_EQ(names, expected);
}
257
MakeLoop(Scope s,Output init_value,absl::string_view loop_name)258 Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) {
259 s = s.NewSubScope(std::string(loop_name));
260 ops::internal::Enter enter(s.WithOpName("init_value"), init_value, loop_name);
261 ops::Merge merge(s.WithOpName("merge"), {init_value, init_value});
262 Output next_iteration =
263 ops::NextIteration(s.WithOpName("next_itr"), merge.output);
264 return s.graph()->UpdateEdge(next_iteration.node(), 0, merge.output.node(),
265 1);
266 }
267
TEST(NodesRelatedToRefVariables, Cycles) {
  // Two loops with back edges (cycles): one fed by a read of a ref variable,
  // one fed by a constant.  The analysis must terminate despite the cycles
  // and report only the ref-fed loop's nodes.
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  TF_ASSERT_OK(
      MakeLoop(root, ops::Identity(root.WithOpName("read_ref_var"), variable),
               "ref_loop"));
  TF_ASSERT_OK(MakeLoop(
      root, ops::Const(root.WithOpName("constant"), Input::Initializer(0.0)),
      "normal_loop"));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
                          GetNodesRelatedToRefVarsSorted(root));
  // Only "ref_loop/*" nodes (plus the variable and its read) are related;
  // nothing from "normal_loop" appears.
  std::vector<string> expected({"read_ref_var", "ref_loop/init_value",
                                "ref_loop/merge", "ref_loop/next_itr",
                                "variable"});

  EXPECT_EQ(names, expected);
}
287 } // namespace
288 } // namespace tensorflow
289