/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/build_xla_ops_pass.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

class BuildXlaOpsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // This is needed to register the XLA_* devices.
    CHECK(DeviceFactory::AddDevices(
              SessionOptions(), "/job:localhost/replica:0/task:0", &devices_)
              .ok());
  }

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

using ::tensorflow::testing::FindNodeByName;
using ::tensorflow::testing::matchers::Attr;
using ::tensorflow::testing::matchers::CtrlDeps;
using ::tensorflow::testing::matchers::Inputs;
using ::tensorflow::testing::matchers::NodeWith;
using ::tensorflow::testing::matchers::Op;
using ::tensorflow::testing::matchers::Out;
using ::testing::_;

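// Builds the graph defined in `s`, runs BuildXlaOpsPass over it (with lazy
// compilation enabled), and returns the rewritten graph in `result`.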
Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib,
                   std::unique_ptr<Graph>* result) {
  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
  FunctionLibraryDefinition flib_def(graph->op_registry(), fdef_lib);

  // Assign nodes without an explicit device request to the CPU device; nodes
  // with a request keep it as their assigned device.
  static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
  for (Node* n : graph->nodes()) {
    if (n->requested_device().empty()) {
      n->set_assigned_device_name(kCpuDevice);
    } else {
      n->set_assigned_device_name(n->requested_device());
    }
  }

  FixupSourceAndSinkEdges(graph.get());

  GraphOptimizationPassWrapper wrapper;
  GraphOptimizationPassOptions opt_options =
      wrapper.CreateGraphOptimizationPassOptions(&graph);
  opt_options.flib_def = &flib_def;

  BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
  TF_RETURN_IF_ERROR(pass.Run(opt_options));
  VLOG(3) << graph->ToGraphDefDebug().DebugString();
  *result = std::move(graph);
  return Status::OK();
}

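// Adds a node named `node_name` to `graph` that calls the function
// `callee_name`, marking it as an XLA-compiled kernel with the given numbers
// of constant and resource arguments.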
Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
                             const string& node_name, int num_constant_args,
                             int num_resource_args, Node** result) {
  NodeDef call_node;
  call_node.set_name(node_name);
  call_node.set_op(callee_name);
  AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
  AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
  AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
  Status s;
  *result = graph->AddNode(call_node, &s);
  return s;
}

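// Convenience overload for a kernel with no constant or resource arguments.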
Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
                             const string& node_name, Node** result) {
  return MakeXlaCompiledKernel(graph, callee_name, node_name,
                               /*num_constant_args=*/0,
                               /*num_resource_args=*/0, result);
}

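// Creates a scalar float resource variable named "Var_<id>" and returns the
// AssignVariableOp node that writes `value_to_write` to it.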
Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) {
  Output var_handle = ops::VarHandleOp(scope.WithOpName("Var_" + id), DT_FLOAT,
                                       TensorShape({}));
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignee_" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

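// Overload that writes the constant 1.0f.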
Node* MakeWrite(const Scope& scope, const string& id) {
  return MakeWrite(
      scope, ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f), id);
}

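// Returns a function library containing a single function `name` with no
// inputs that returns the float constant 1.0.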
FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
  FunctionDefLibrary fdef_lib;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
      /*attr_def=*/{},
      /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
      /*ret_def=*/{{"out", "out:output:0"}});
  *fdef_lib.add_function() = std::move(func);
  return fdef_lib;
}

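// Returns a function library containing a single function `name` that passes
// an int32 input through an Identity op.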
FunctionDefLibrary CreateFunctionDefLibWithInt32Input(const string& name) {
  FunctionDefLibrary fdef_lib;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/name, /*in_def=*/{"in: int32"},
      /*out_def=*/{"out: int32"},
      /*attr_def=*/{}, /*node_def=*/{{{"out"}, "Identity", {"in"}}},
      /*ret_def=*/{{"out", "out:output:0"}});
  *fdef_lib.add_function() = std::move(func);
  return fdef_lib;
}

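// Control edges leaving the cluster should be rewired so that they originate
// from the generated _XlaRun node.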
TEST_F(BuildXlaOpsTest, ControlDepsPreserved) {
  const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);
  call->set_requested_device(kXlaDeviceName);
  Node* write_op = MakeWrite(root, "write");
  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);
  root.graph()->AddControlEdge(call, write_op);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
  ASSERT_NE(write_op_new, nullptr);
  EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
}

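// A cluster whose attributes claim more constant and resource arguments than
// the callee actually takes should fail with a clean INVALID_ARGUMENT error
// rather than crashing the pass.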
TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));

  Node* call;
  TF_ASSERT_OK(
      MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));

  Node* write_op = MakeWrite(root, "write");
  root.graph()->AddControlEdge(call, write_op);

  std::unique_ptr<Graph> graph;
  Status failure_status = BuildXlaOps(root, fdef_lib, &graph);
  ASSERT_FALSE(failure_status.ok());
  EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
}

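// On a non-XLA device the cluster is lowered to a lazily compiled
// _XlaCompile/_XlaRun pair plus a PartitionedCall fallback, with the two
// results combined through an _XlaMerge node.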
TEST_F(BuildXlaOpsTest, OnNonXlaDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));

  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  TF_ASSERT_OK(root.DoShapeInference(call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  Node* write_op = MakeWrite(root, Output(call), "write_result");
  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);

  auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false));
  auto predicated_compilation_key =
      NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile)));
  auto xla_run =
      NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key)));
  auto tf_call =
      NodeWith(Op("PartitionedCall"),
               CtrlDeps(NodeWith(Op("Identity"),
                                 Inputs(Out(0, predicated_compilation_key)))));
  auto merge = NodeWith(Op("_XlaMerge"), Inputs(Out(tf_call), Out(xla_run)));
  auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge)));

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
  ASSERT_NE(write_op_new, nullptr);
  EXPECT_THAT(write_op_new, assign_var);
}

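// On an XLA device the cluster is lowered to an _XlaCompile/_XlaRun pair and
// the write consumes the _XlaRun output directly, with no _XlaMerge in
// between.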
TEST_F(BuildXlaOpsTest, OnXlaDevice) {
  const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));

  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  call->set_requested_device(kXlaDeviceName);
  TF_ASSERT_OK(root.DoShapeInference(call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  Node* write_op = MakeWrite(root, Output(call), "write_result");
  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  auto xla_op =
      NodeWith(Op("_XlaRun"), Inputs(Out(NodeWith(Op("_XlaCompile")))));
  auto assign_var =
      NodeWith(Op("AssignVariableOp"), Inputs(Out(NodeWith()), Out(xla_op)));

  Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
  ASSERT_NE(write_op_new, nullptr);
  EXPECT_THAT(write_op_new, assign_var);
}

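// A control edge from the cluster to the sink node should not introduce an
// extra _XlaMerge; the sink ends up depending directly on the _XlaRun, the
// PartitionedCall fallback, and a NoOp.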
TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) {
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* sink_node = graph->sink_node();
  EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
                                           NodeWith(Op("PartitionedCall")),
                                           NodeWith(Op("NoOp")))));
}

#ifdef GOOGLE_CUDA
// This tests a rewrite that only makes sense and is active in a CUDA-enabled
// build. Specifically we check that we insert an IdentityN op to avoid extra
// device-to-host copies.
TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
  const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:GPU:0";
  Scope root = Scope::NewRootScope()
                   .WithDevice(kXlaDeviceName)
                   .WithAssignedDevice(kXlaDeviceName)
                   .ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithInt32Input("cluster_int32");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
  Node* call;
  TF_ASSERT_OK(
      MakeXlaCompiledKernel(root.graph(), "cluster_int32", "C", &call));
  call->set_requested_device(kXlaDeviceName);
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  auto var =
      ops::VarHandleOp(root.WithOpName("var"), DT_INT32, TensorShape({}));
  auto int32_on_device =
      ops::ReadVariableOp(root.WithOpName("int32_on_device"), var, DT_INT32);

  root.graph()->AddEdge(int32_on_device.node(), 0, call, 0);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* partitioned_call_op = nullptr;
  for (Node* n : graph->op_nodes()) {
    if (n->type_string() == "PartitionedCall") {
      ASSERT_EQ(partitioned_call_op, nullptr);
      partitioned_call_op = n;
    }
  }

  ASSERT_NE(partitioned_call_op, nullptr);
  auto xla_compile = NodeWith(Op("_XlaCompile"));
  auto switch_on_compilation_pred =
      NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile)));
  auto ctrl_dep =
      NodeWith(Op("Identity"), Inputs(Out(0, switch_on_compilation_pred)));
  // Check that we pipe int32 inputs through an IdentityN to avoid extra D2H
  // copies.
  EXPECT_THAT(
      partitioned_call_op,
      NodeWith(Inputs(Out(NodeWith(Op("IdentityN"), CtrlDeps(ctrl_dep))))));
}
#endif

}  // namespace
}  // namespace tensorflow