• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_cluster_util.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/cc/framework/ops.h"
21 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
22 #include "tensorflow/cc/ops/function_ops.h"
23 #include "tensorflow/cc/ops/functional_ops.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
28 #include "tensorflow/core/framework/function_testlib.h"
29 #include "tensorflow/core/framework/graph_to_functiondef.h"
30 #include "tensorflow/core/graph/algorithm.h"
31 #include "tensorflow/core/graph/graph_def_builder.h"
32 #include "tensorflow/core/graph/testlib.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/public/version.h"
36 
37 namespace tensorflow {
38 namespace {
39 
// Builds a graph where "b" consumes "a" both directly and through an
// Enter("only_frame") -> Exit chain, then checks that the cycle detection
// graph refuses to contract the a -> b edge.
TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output source = ops::Const(scope.WithOpName("a"), Input::Initializer(0.0));
  Output frame_entry =
      ops::internal::Enter(scope.WithOpName("enter"), source, "only_frame");
  Output frame_exit =
      ops::internal::Exit(scope.WithOpName("exit"), frame_entry);
  // Two paths from "a" to "b": the direct edge and the one through the frame.
  Output sink = ops::Add(scope.WithOpName("b"), source, frame_exit);

  FixupSourceAndSinkEdges(scope.graph());

  GraphCycles cycle_graph;
  TF_ASSERT_OK(
      CreateCycleDetectionGraph(scope.graph(), &cycle_graph).status());

  const int source_id = source.node()->id();
  const int sink_id = sink.node()->id();
  EXPECT_FALSE(cycle_graph.CanContractEdge(source_id, sink_id));
}
55 
// Same shape as the single-frame case, but "a" feeds two separate
// Enter/Exit pairs ("frame_0" and "frame_1"). Only the exit of "frame_1" is
// consumed by "b"; the a -> b edge must still not be contractible.
TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output source = ops::Const(scope.WithOpName("a"), Input::Initializer(0.0));

  // First frame: its exit has no consumer in this graph.
  Output first_entry =
      ops::internal::Enter(scope.WithOpName("enter_0"), source, "frame_0");
  Output first_exit =
      ops::internal::Exit(scope.WithOpName("exit_0"), first_entry);

  // Second frame: its exit feeds "b".
  Output second_entry =
      ops::internal::Enter(scope.WithOpName("enter_1"), source, "frame_1");
  Output second_exit =
      ops::internal::Exit(scope.WithOpName("exit_1"), second_entry);

  Output sink = ops::Add(scope.WithOpName("b"), source, second_exit);

  FixupSourceAndSinkEdges(scope.graph());

  GraphCycles cycle_graph;
  TF_ASSERT_OK(
      CreateCycleDetectionGraph(scope.graph(), &cycle_graph).status());

  const int source_id = source.node()->id();
  const int sink_id = sink.node()->id();
  EXPECT_FALSE(cycle_graph.CanContractEdge(source_id, sink_id));
}
74 
// The output of an Exit from "frame_0" flows (through "add") into a second
// Enter that names the same frame, "frame_0". CreateCycleDetectionGraph is
// expected to return an OK status whose boolean payload is false for this
// graph.
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
  // TODO(b/127521408): We can lift this limitation with some work.
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output source = ops::Const(scope.WithOpName("a"), Input::Initializer(0.0));

  // First trip through "frame_0".
  Output first_entry =
      ops::internal::Enter(scope.WithOpName("enter_0"), source, "frame_0");
  Output first_exit =
      ops::internal::Exit(scope.WithOpName("exit_0"), first_entry);

  Output sum = ops::Add(scope.WithOpName("add"), first_exit, first_exit);

  // Second trip: re-enters the frame with the same name.
  Output second_entry =
      ops::internal::Enter(scope.WithOpName("enter_1"), sum, "frame_0");
  Output second_exit =
      ops::internal::Exit(scope.WithOpName("exit_1"), second_entry);

  FixupSourceAndSinkEdges(scope.graph());

  GraphCycles cycle_graph;
  TF_ASSERT_OK_AND_ASSIGN(
      bool built, CreateCycleDetectionGraph(scope.graph(), &cycle_graph));
  EXPECT_FALSE(built);
}
97 
// Fully-qualified device names assigned to nodes by the IsSingleGpuGraph
// tests. Declared as constexpr arrays (rather than mutable `const char*`
// pointers) so the names cannot be reassigned.
constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/device:CPU:0";
constexpr char kGPU0[] = "/job:localhost/replica:0/task:0/device:GPU:0";
constexpr char kGPU1[] = "/job:localhost/replica:0/task:0/device:GPU:1";
101 
// Every op is created under a scope whose assigned device is kGPU0, so
// IsSingleGpuGraph is expected to return true.
TEST(IsSingleGpuGraph, ReturnsTrue) {
  Scope gpu_scope =
      Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();

  Output zero = ops::Const(gpu_scope.WithOpName("a"), Input::Initializer(0.0));
  Output first_sum = ops::Add(gpu_scope.WithOpName("b"), zero, zero);
  Output second_sum = ops::Add(gpu_scope.WithOpName("c"), first_sum, first_sum);

  FixupSourceAndSinkEdges(gpu_scope.graph());

  const Graph& graph = *gpu_scope.graph();
  EXPECT_TRUE(IsSingleGpuGraph(graph));
}
113 
// Every op is created under a scope whose assigned device is kCPU0, so
// IsSingleGpuGraph is expected to return false.
TEST(IsSingleGpuGraph, ReturnsFalseForCpuGraph) {
  Scope cpu_scope =
      Scope::NewRootScope().WithAssignedDevice(kCPU0).ExitOnError();

  Output zero = ops::Const(cpu_scope.WithOpName("a"), Input::Initializer(0.0));
  Output first_sum = ops::Add(cpu_scope.WithOpName("b"), zero, zero);
  Output second_sum = ops::Add(cpu_scope.WithOpName("c"), first_sum, first_sum);

  FixupSourceAndSinkEdges(cpu_scope.graph());

  const Graph& graph = *cpu_scope.graph();
  EXPECT_FALSE(IsSingleGpuGraph(graph));
}
125 
// The scope defaults to kGPU0, but node "b" is explicitly assigned to kGPU1.
// With two distinct GPUs in the graph, IsSingleGpuGraph is expected to return
// false.
TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) {
  Scope gpu0_scope =
      Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();

  Output zero = ops::Const(gpu0_scope.WithOpName("a"), Input::Initializer(0.0));
  // The only node placed on the second GPU.
  Output on_gpu1 = ops::Add(
      gpu0_scope.WithOpName("b").WithAssignedDevice(kGPU1), zero, zero);
  Output back_on_gpu0 = ops::Add(gpu0_scope.WithOpName("c"), on_gpu1, on_gpu1);

  FixupSourceAndSinkEdges(gpu0_scope.graph());

  const Graph& graph = *gpu0_scope.graph();
  EXPECT_FALSE(IsSingleGpuGraph(graph));
}
137 
GetNodesRelatedToRefVarsSorted(const Scope & scope,FunctionLibraryDefinition * flib_def=nullptr)138 StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
139     const Scope& scope, FunctionLibraryDefinition* flib_def = nullptr) {
140   FunctionDefLibrary flib;
141   FunctionLibraryDefinition flib_def_local(OpRegistry::Global(), flib);
142   if (flib_def == nullptr) {
143     flib_def = &flib_def_local;
144   }
145 
146   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
147 
148   TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
149 
150   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
151       new ProcessFunctionLibraryRuntime(
152           nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
153           flib_def, OptimizerOptions{}));
154   FunctionLibraryRuntime* lib_runtime =
155       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
156 
157   TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> nodes_related_to_ref_vars,
158                       GetNodesRelatedToRefVariables(*graph, lib_runtime));
159 
160   std::vector<string> names;
161   absl::c_transform(nodes_related_to_ref_vars, std::back_inserter(names),
162                     [](Node* n) { return n->name(); });
163   absl::c_sort(names);
164   return names;
165 }
166 
CreateSubgraphTouchingRefVar(const Scope & s)167 void CreateSubgraphTouchingRefVar(const Scope& s) {
168   Output variable =
169       ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
170   Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
171   Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
172   Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
173 
174   Output constant =
175       ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
176   s.graph()->AddControlEdge(constant.node(), variable.node());
177 }
178 
CreateSubgraphNotTouchingRefVar(const Scope & s)179 void CreateSubgraphNotTouchingRefVar(const Scope& s) {
180   Output constant =
181       ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
182   Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
183   Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
184 }
185 
CreateSubgraphCallingFunctionWithRefVar(const Scope & s)186 void CreateSubgraphCallingFunctionWithRefVar(const Scope& s) {
187   NameAttrList ref_float_function;
188   ref_float_function.set_name("RefFloatFn");
189   ops::PartitionedCall call(s.WithOpName("RefFloat"), {absl::Span<Input>{}},
190                             {DT_FLOAT}, ref_float_function);
191   Output constant =
192       ops::Const(s.WithOpName("constant_ref_pco"), Input::Initializer(0.0));
193   s.graph()->AddControlEdge(call.operation.node(), constant.node());
194 }
195 
CreateSubgraphCallingFunctionWithoutRefVar(const Scope & s)196 void CreateSubgraphCallingFunctionWithoutRefVar(const Scope& s) {
197   NameAttrList regular_float_function;
198   regular_float_function.set_name("RegularFloatFn");
199   ops::PartitionedCall call(s.WithOpName("RegularFloat"), {absl::Span<Input>{}},
200                             {DT_FLOAT}, regular_float_function);
201   Output constant =
202       ops::Const(s.WithOpName("constant_normal_pco"), Input::Initializer(0.0));
203   s.graph()->AddControlEdge(call.operation.node(), constant.node());
204 }
205 
AddRefFunctionFunctionDef(FunctionDefLibrary * fdef_lib)206 void AddRefFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
207   FunctionDef make_ref_float = FunctionDefHelper::Define(
208       "RefFloatFn", {}, {"r:float"}, {},
209       {{{"var"},
210         "VariableV2",
211         {},
212         {{"dtype", DT_FLOAT}, {"shape", TensorShape({})}}},
213        {{"r"}, "Identity", {"var"}, {{"T", DT_FLOAT}}}});
214   *fdef_lib->add_function() = make_ref_float;
215 }
216 
AddRegularFunctionFunctionDef(FunctionDefLibrary * fdef_lib)217 void AddRegularFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
218   Tensor seven(DT_FLOAT, {});
219   seven.scalar<float>()() = 7;
220   FunctionDef make_regular_float = FunctionDefHelper::Define(
221       "RegularFloatFn", {}, {"r:float"}, {},
222       {{{"r"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", seven}}}});
223   *fdef_lib->add_function() = make_regular_float;
224 }
225 
// Builds four independent subgraphs (with/without a variable, and calling a
// function with/without a variable in its body) and checks that exactly the
// nodes from the two variable-related subgraphs are reported.
TEST(NodesRelatedToRefVariables, Basic) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  // Function library holding both callee definitions.
  FunctionDefLibrary fdef_lib;
  AddRefFunctionFunctionDef(&fdef_lib);
  AddRegularFunctionFunctionDef(&fdef_lib);
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);

  // Graph contents.
  CreateSubgraphTouchingRefVar(scope);
  CreateSubgraphNotTouchingRefVar(scope);
  CreateSubgraphCallingFunctionWithRefVar(scope);
  CreateSubgraphCallingFunctionWithoutRefVar(scope);

  TF_ASSERT_OK_AND_ASSIGN(std::vector<string> actual_names,
                          GetNodesRelatedToRefVarsSorted(scope, &flib_def));

  // Sorted lexicographically, matching the helper's output order.
  const std::vector<string> expected_names = {
      "RefFloat",          //
      "add_ref",           //
      "constant_ref",      //
      "constant_ref_pco",  //
      "negate_ref",        //
      "read_ref_var",      //
      "variable",          //
  };

  EXPECT_EQ(actual_names, expected_names);
}
257 
MakeLoop(Scope s,Output init_value,absl::string_view loop_name)258 Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) {
259   s = s.NewSubScope(std::string(loop_name));
260   ops::internal::Enter enter(s.WithOpName("init_value"), init_value, loop_name);
261   ops::Merge merge(s.WithOpName("merge"), {init_value, init_value});
262   Output next_iteration =
263       ops::NextIteration(s.WithOpName("next_itr"), merge.output);
264   return s.graph()->UpdateEdge(next_iteration.node(), 0, merge.output.node(),
265                                1);
266 }
267 
// Builds two loops with MakeLoop: "ref_loop", seeded from a read of a
// variable, and "normal_loop", seeded from a constant. Only the variable, its
// read, and the nodes of "ref_loop" are expected to be reported.
TEST(NodesRelatedToRefVariables, Cycles) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output ref_var = ops::Variable(scope.WithOpName("variable"),
                                 PartialTensorShape{}, DT_FLOAT);

  Output ref_read = ops::Identity(scope.WithOpName("read_ref_var"), ref_var);
  TF_ASSERT_OK(MakeLoop(scope, ref_read, "ref_loop"));

  Output plain_constant =
      ops::Const(scope.WithOpName("constant"), Input::Initializer(0.0));
  TF_ASSERT_OK(MakeLoop(scope, plain_constant, "normal_loop"));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<string> actual_names,
                          GetNodesRelatedToRefVarsSorted(scope));

  // Sorted lexicographically, matching the helper's output order.
  const std::vector<string> expected_names = {
      "read_ref_var",       //
      "ref_loop/init_value",  //
      "ref_loop/merge",     //
      "ref_loop/next_itr",  //
      "variable",           //
  };

  EXPECT_EQ(actual_names, expected_names);
}
287 }  // namespace
288 }  // namespace tensorflow
289