1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/cluster_scoping_pass.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/compiler/jit/defs.h"
20 #include "tensorflow/compiler/jit/test_util.h"
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/public/session_options.h"
30
31 namespace tensorflow {
32 namespace {
33
ClusterScoping(std::unique_ptr<Graph> * graph)34 Status ClusterScoping(std::unique_ptr<Graph>* graph) {
35 FixupSourceAndSinkEdges(graph->get());
36
37 GraphOptimizationPassWrapper wrapper;
38 wrapper.session_options.config.mutable_graph_options()
39 ->mutable_optimizer_options()
40 ->set_global_jit_level(OptimizerOptions::ON_2);
41 GraphOptimizationPassOptions opt_options =
42 wrapper.CreateGraphOptimizationPassOptions(graph);
43
44 ClusterScopingPass pass;
45 return pass.Run(opt_options);
46 }
47
GetXlaInternalScopes(const Graph & graph)48 absl::flat_hash_map<string, string> GetXlaInternalScopes(const Graph& graph) {
49 absl::flat_hash_map<string, string> scopes;
50 for (Node* node : graph.nodes()) {
51 string scope;
52 if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) {
53 scopes[node->name()] = scope;
54 }
55 }
56
57 if (VLOG_IS_ON(2)) {
58 VLOG(2) << "_XlaInternalScopes:";
59 for (const auto& p : scopes) {
60 VLOG(2) << " " << p.first << " -> " << p.second;
61 }
62 }
63 return scopes;
64 }
65
BuildStageNode(GraphDefBuilder & builder,string name,std::initializer_list<DataType> dtypes,absl::Span<const ops::NodeOut> values)66 Node* BuildStageNode(GraphDefBuilder& builder, string name,
67 std::initializer_list<DataType> dtypes,
68 absl::Span<const ops::NodeOut> values) {
69 auto opts = builder.opts()
70 .WithName(std::move(name))
71 .WithAttr("dtypes", std::move(dtypes));
72 if (opts.HaveError()) {
73 return nullptr;
74 }
75
76 NodeBuilder node_builder(name, "Stage", opts.op_registry());
77 node_builder.Input(values);
78 return opts.FinalizeBuilder(&node_builder);
79 }
80
// Builds a graph with a Stage/Unstage pipeline and checks three things after
// ClusterScoping:
// - the nodes feeding `stage` (add0, relu0) share one scope;
// - the nodes fed by `unstage` (add1, relu1) share another scope;
// - the two scopes differ.
TEST(XlaCompilationTest, StagePipelinePreserved) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    // Graph:
    //     b
    //     |
    //     v
    // a -> add0 (ClusterX) -> relu0 (ClusterX) -> stage
    //
    //             b
    //             |
    //             v
    // unstage -> add1 (ClusterY) -> relu1 (ClusterY)
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("a")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::SourceOp("Const", builder.opts()
                                         .WithName("b")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* unstage = ops::SourceOp(
        "Unstage",
        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));

    Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
    Node* add1 =
        ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
    Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
    ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
    BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0});

    // If the graph cannot be built, every later check would be about an
    // incomplete graph, so stop the test here.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(ClusterScoping(&graph));

  auto scopes = GetXlaInternalScopes(*graph);
  // operator[] below default-inserts "" for a missing node, which would let
  // EXPECT_EQ pass vacuously. Require that each checked node got a scope.
  for (const char* node_name : {"add0", "add1", "relu0", "relu1"}) {
    ASSERT_TRUE(scopes.contains(node_name))
        << "no internal scope assigned to " << node_name;
  }
  EXPECT_NE(scopes["add0"], scopes["add1"]);
  EXPECT_EQ(scopes["add0"], scopes["relu0"]);
  EXPECT_EQ(scopes["add1"], scopes["relu1"]);
}
124
// Like StagePipelinePreserved, but nodes start with explicit internal scopes.
// Checks after ClusterScoping:
// - add0 and add1, which start in the same scope, are separated by the
//   pipeline stage;
// - nodes that start in different scopes stay in different scopes.
TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    // Graph (initial scopes in parentheses):
    //     b
    //     |
    //     v
    // a -> add0 (ClusterA) -> relu0 (ClusterB) -> stage
    //
    //             b
    //             |
    //             v
    // unstage -> add1 (ClusterA) -> relu1 (ClusterD)
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("a")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::SourceOp("Const", builder.opts()
                                         .WithName("b")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* unstage = ops::SourceOp(
        "Unstage",
        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));

    // Intentionally give add0 and add1 the same initial scope but they should
    // be separated by the ClusterScopingPass.
    Node* add0 = ops::BinaryOp("Add", a, b,
                               builder.opts().WithName("add0").WithAttr(
                                   kXlaInternalScopeAttr, "ClusterA"));
    Node* add1 = ops::BinaryOp("Add", unstage, b,
                               builder.opts().WithName("add1").WithAttr(
                                   kXlaInternalScopeAttr, "ClusterA"));
    Node* relu0 = ops::UnaryOp("Relu", add0,
                               builder.opts().WithName("relu0").WithAttr(
                                   kXlaInternalScopeAttr, "ClusterB"));
    ops::UnaryOp("Relu", add1,
                 builder.opts().WithName("relu1").WithAttr(
                     kXlaInternalScopeAttr, "ClusterD"));
    BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0});

    // If the graph cannot be built, every later check would be about an
    // incomplete graph, so stop the test here.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(ClusterScoping(&graph));

  auto scopes = GetXlaInternalScopes(*graph);
  EXPECT_NE(scopes["add0"], scopes["add1"]);
  EXPECT_NE(scopes["add0"], scopes["relu0"]);
  EXPECT_NE(scopes["add1"], scopes["relu1"]);
}
177
178 } // namespace
179 } // namespace tensorflow
180