• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/cluster_scoping_pass.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/compiler/jit/defs.h"
20 #include "tensorflow/compiler/jit/test_util.h"
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/op.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/public/session_options.h"
30 
31 namespace tensorflow {
32 namespace {
33 
ClusterScoping(std::unique_ptr<Graph> * graph)34 Status ClusterScoping(std::unique_ptr<Graph>* graph) {
35   FixupSourceAndSinkEdges(graph->get());
36 
37   GraphOptimizationPassWrapper wrapper;
38   wrapper.session_options.config.mutable_graph_options()
39       ->mutable_optimizer_options()
40       ->set_global_jit_level(OptimizerOptions::ON_2);
41   GraphOptimizationPassOptions opt_options =
42       wrapper.CreateGraphOptimizationPassOptions(graph);
43 
44   ClusterScopingPass pass;
45   return pass.Run(opt_options);
46 }
47 
GetXlaInternalScopes(const Graph & graph)48 absl::flat_hash_map<string, string> GetXlaInternalScopes(const Graph& graph) {
49   absl::flat_hash_map<string, string> scopes;
50   for (Node* node : graph.nodes()) {
51     string scope;
52     if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) {
53       scopes[node->name()] = scope;
54     }
55   }
56 
57   if (VLOG_IS_ON(2)) {
58     VLOG(2) << "_XlaInternalScopes:";
59     for (const auto& p : scopes) {
60       VLOG(2) << " " << p.first << " -> " << p.second;
61     }
62   }
63   return scopes;
64 }
65 
BuildStageNode(GraphDefBuilder & builder,string name,std::initializer_list<DataType> dtypes,absl::Span<const ops::NodeOut> values)66 Node* BuildStageNode(GraphDefBuilder& builder, string name,
67                      std::initializer_list<DataType> dtypes,
68                      absl::Span<const ops::NodeOut> values) {
69   auto opts = builder.opts()
70                   .WithName(std::move(name))
71                   .WithAttr("dtypes", std::move(dtypes));
72   if (opts.HaveError()) {
73     return nullptr;
74   }
75 
76   NodeBuilder node_builder(name, "Stage", opts.op_registry());
77   node_builder.Input(values);
78   return opts.FinalizeBuilder(&node_builder);
79 }
80 
// Builds one pipeline that feeds a Stage node and a second pipeline that
// starts from an Unstage node (the two share only the constant `b`), runs
// ClusterScoping, and checks that the Add/Relu pair of each pipeline ends up
// with the same scope while the two pipelines get different scopes.
TEST(XlaCompilationTest, StagePipelinePreserved) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    // Graph, annotated with the scoping the assertions below expect
    // (ClusterX and ClusterY stand for two distinct scope values):
    //
    //       b
    //       |
    //       v
    // a -> add0 (ClusterX) -> relu0 (ClusterX) -> stage
    //
    //             b
    //             |
    //             v
    // unstage -> add1 (ClusterY) -> relu1 (ClusterY)
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);

    // Creates a float Const node with the given name and an empty Tensor
    // value.
    const auto float_const = [&builder](const char* node_name) {
      return ops::SourceOp("Const", builder.opts()
                                        .WithName(node_name)
                                        .WithAttr("dtype", DT_FLOAT)
                                        .WithAttr("value", Tensor()));
    };
    Node* const_a = float_const("a");
    Node* const_b = float_const("b");
    Node* unstage_node = ops::SourceOp(
        "Unstage",
        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));

    Node* staged_add =
        ops::BinaryOp("Add", const_a, const_b, builder.opts().WithName("add0"));
    Node* unstaged_add = ops::BinaryOp("Add", unstage_node, const_b,
                                       builder.opts().WithName("add1"));
    Node* staged_relu =
        ops::UnaryOp("Relu", staged_add, builder.opts().WithName("relu0"));
    ops::UnaryOp("Relu", unstaged_add, builder.opts().WithName("relu1"));
    BuildStageNode(builder, "stage", {DT_FLOAT}, {staged_relu});

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(ClusterScoping(&graph));

  auto scope_of = GetXlaInternalScopes(*graph);
  // The two pipelines must not share a scope ...
  EXPECT_NE(scope_of["add0"], scope_of["add1"]);
  // ... but the nodes within each pipeline must.
  EXPECT_EQ(scope_of["add0"], scope_of["relu0"]);
  EXPECT_EQ(scope_of["add1"], scope_of["relu1"]);
}
124 
// Same graph shape as StagePipelinePreserved, but every Add/Relu node starts
// with an explicit kXlaInternalScopeAttr. Checks that after ClusterScoping the
// two Add nodes (which start in the same scope) are in different scopes, and
// that each Add still differs from the Relu that consumes it.
TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    // Graph, annotated with the initial scope assigned to each node below:
    //
    //       b
    //       |
    //       v
    // a -> add0 (ClusterA) -> relu0 (ClusterB) -> stage
    //
    //             b
    //             |
    //             v
    // unstage -> add1 (ClusterA) -> relu1 (ClusterD)
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);

    // Creates a float Const node with the given name and an empty Tensor
    // value.
    const auto float_const = [&builder](const char* node_name) {
      return ops::SourceOp("Const", builder.opts()
                                        .WithName(node_name)
                                        .WithAttr("dtype", DT_FLOAT)
                                        .WithAttr("value", Tensor()));
    };
    Node* const_a = float_const("a");
    Node* const_b = float_const("b");
    Node* unstage_node = ops::SourceOp(
        "Unstage",
        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));

    // add0 and add1 deliberately start out in the same scope; the
    // ClusterScopingPass is expected to pull them apart.
    Node* staged_add = ops::BinaryOp("Add", const_a, const_b,
                                     builder.opts().WithName("add0").WithAttr(
                                         kXlaInternalScopeAttr, "ClusterA"));
    Node* unstaged_add =
        ops::BinaryOp("Add", unstage_node, const_b,
                      builder.opts().WithName("add1").WithAttr(
                          kXlaInternalScopeAttr, "ClusterA"));
    Node* staged_relu =
        ops::UnaryOp("Relu", staged_add,
                     builder.opts().WithName("relu0").WithAttr(
                         kXlaInternalScopeAttr, "ClusterB"));
    ops::UnaryOp("Relu", unstaged_add,
                 builder.opts().WithName("relu1").WithAttr(
                     kXlaInternalScopeAttr, "ClusterD"));
    BuildStageNode(builder, "stage", {DT_FLOAT}, {staged_relu});

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(ClusterScoping(&graph));

  auto scope_of = GetXlaInternalScopes(*graph);
  // Nodes that shared an initial scope but sit in different pipelines must
  // end up separated.
  EXPECT_NE(scope_of["add0"], scope_of["add1"]);
  // Nodes that started in different scopes must stay in different scopes.
  EXPECT_NE(scope_of["add0"], scope_of["relu0"]);
  EXPECT_NE(scope_of["add1"], scope_of["relu1"]);
}
177 
178 }  // namespace
179 }  // namespace tensorflow
180