/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/build_xla_ops_pass.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/resource_variable_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/compiler/jit/defs.h"
23 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
24 #include "tensorflow/compiler/jit/node_matchers.h"
25 #include "tensorflow/compiler/jit/test_util.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/graph/algorithm.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/public/session_options.h"
32 
namespace tensorflow {
namespace {

class BuildXlaOpsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // This is needed to register the XLA_* devices.
    CHECK(DeviceFactory::AddDevices(
              SessionOptions(), "/job:localhost/replica:0/task:0", &devices_)
              .ok());
  }

 private:
  std::vector<std::unique_ptr<Device>> devices_;
};

using ::tensorflow::testing::FindNodeByName;
using ::tensorflow::testing::matchers::Attr;
using ::tensorflow::testing::matchers::CtrlDeps;
using ::tensorflow::testing::matchers::Inputs;
using ::tensorflow::testing::matchers::NodeWith;
using ::tensorflow::testing::matchers::Op;
using ::tensorflow::testing::matchers::Out;
using ::testing::_;

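// Converts `s` to a graph, assigns every unplaced node to the local CPU
// device, and runs BuildXlaOpsPass (with lazy compilation enabled) over the
// result, returning the rewritten graph in `result`.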
Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib,
                   std::unique_ptr<Graph>* result) {
  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
  FunctionLibraryDefinition flib_def(graph->op_registry(), fdef_lib);

  // Assign all nodes to the CPU device.
  static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
  for (Node* n : graph->nodes()) {
    if (n->requested_device().empty()) {
      n->set_assigned_device_name(kCpuDevice);
    } else {
      n->set_assigned_device_name(n->requested_device());
    }
  }

  FixupSourceAndSinkEdges(graph.get());

  GraphOptimizationPassWrapper wrapper;
  GraphOptimizationPassOptions opt_options =
      wrapper.CreateGraphOptimizationPassOptions(&graph);
  opt_options.flib_def = &flib_def;

  BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
  TF_RETURN_IF_ERROR(pass.Run(opt_options));
  VLOG(3) << graph->ToGraphDefDebug().DebugString();
  *result = std::move(graph);
  return Status::OK();
}

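// Adds a call node to the function `callee_name` to `graph`, carrying the
// attributes that mark it as an XLA-compiled cluster.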
Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
                             const string& node_name, int num_constant_args,
                             int num_resource_args, Node** result) {
  NodeDef call_node;
  call_node.set_name(node_name);
  call_node.set_op(callee_name);
  AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
  AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
  AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
  Status s;
  *result = graph->AddNode(call_node, &s);
  return s;
}

Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
                             const string& node_name, Node** result) {
  return MakeXlaCompiledKernel(graph, callee_name, node_name,
                               /*num_constant_args=*/0, /*num_resource_args=*/0,
                               result);
}

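// Creates a scalar float resource variable and returns the AssignVariableOp
// node that writes `value_to_write` into it.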
Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) {
  Output var_handle = ops::VarHandleOp(scope.WithOpName("Var_" + id), DT_FLOAT,
                                       TensorShape({}));
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignee_" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

Node* MakeWrite(const Scope& scope, const string& id) {
  return MakeWrite(
      scope, ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f), id);
}

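// Returns a function library containing a single function `name` that takes
// no inputs and returns a scalar float constant.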
FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
  FunctionDefLibrary fdef_lib;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
      /*attr_def=*/{}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
      /*ret_def=*/{{"out", "one:output:0"}});
  *fdef_lib.add_function() = std::move(func);
  return fdef_lib;
}

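// Returns a function library containing a single function `name` that passes
// an int32 input through an Identity node.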
FunctionDefLibrary CreateFunctionDefLibWithInt32Input(const string& name) {
  FunctionDefLibrary fdef_lib;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/name, /*in_def=*/{"in: int32"},
      /*out_def=*/{"out: int32"},
      /*attr_def=*/{}, /*node_def=*/{{{"out"}, "Identity", {"in"}}},
      /*ret_def=*/{{"out", "out:output:0"}});
  *fdef_lib.add_function() = std::move(func);
  return fdef_lib;
}

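// Control edges out of a cluster node must survive the rewrite: the
// downstream op should end up with a control dependency on the generated
// _XlaRun node.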
TEST_F(BuildXlaOpsTest, ControlDepsPreserved) {
  const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);
  call->set_requested_device(kXlaDeviceName);
  Node* write_op = MakeWrite(root, "write");
  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);
  root.graph()->AddControlEdge(call, write_op);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
  ASSERT_NE(write_op_new, nullptr);
  EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
}

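// A cluster node whose constant/resource-arg counts do not match the callee
// signature should make the pass fail cleanly with INVALID_ARGUMENT.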
TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));

  Node* call;
  TF_ASSERT_OK(
      MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));

  Node* write_op = MakeWrite(root, "write");
  root.graph()->AddControlEdge(call, write_op);

  std::unique_ptr<Graph> graph;
  Status failure_status = BuildXlaOps(root, fdef_lib, &graph);
  ASSERT_FALSE(failure_status.ok());
  EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
}

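// On a non-XLA device the rewrite keeps a TF fallback path: a Switch on the
// compilation predicate routes to either a PartitionedCall or an _XlaRun, and
// an _XlaMerge joins the two results before the downstream AssignVariableOp.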
TEST_F(BuildXlaOpsTest, OnNonXlaDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));

  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  TF_ASSERT_OK(root.DoShapeInference(call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  Node* write_op = MakeWrite(root, Output(call), "write_result");
  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);

  auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false));
  auto predicated_compilation_key =
      NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile)));
  auto xla_run =
      NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key)));
  auto tf_call =
      NodeWith(Op("PartitionedCall"),
               CtrlDeps(NodeWith(Op("Identity"),
                                 Inputs(Out(0, predicated_compilation_key)))));
  auto merge = NodeWith(Op("_XlaMerge"), Inputs(Out(tf_call), Out(xla_run)));
  auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge)));

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
  ASSERT_NE(write_op_new, nullptr);
  EXPECT_THAT(write_op_new, assign_var);
}

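// On an XLA device there is no fallback path: the cluster is rewritten to an
// _XlaCompile feeding directly into an _XlaRun.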
TEST_F(BuildXlaOpsTest, OnXlaDevice) {
  const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));

  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  call->set_requested_device(kXlaDeviceName);
  TF_ASSERT_OK(root.DoShapeInference(call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  Node* write_op = MakeWrite(root, Output(call), "write_result");
  write_op->AddAttr(kXlaHasReferenceVarsAttr, false);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  auto xla_op =
      NodeWith(Op("_XlaRun"), Inputs(Out(NodeWith(Op("_XlaCompile")))));
  auto assign_var =
      NodeWith(Op("AssignVariableOp"), Inputs(Out(NodeWith()), Out(xla_op)));

  Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
  ASSERT_NE(write_op_new, nullptr);
  EXPECT_THAT(write_op_new, assign_var);
}

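// A cluster whose only outgoing edges go to the sink node should not get an
// extra _XlaMerge; the sink simply picks up control edges from both the XLA
// and the TF execution paths.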
TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) {
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithConstFunction("cluster_0");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
  Node* call;
  TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* sink_node = graph->sink_node();
  EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
                                           NodeWith(Op("PartitionedCall")),
                                           NodeWith(Op("NoOp")))));
}

#ifdef GOOGLE_CUDA
// This tests a rewrite that only makes sense and is active in a CUDA-enabled
// build.  Specifically we check that we insert an IdentityN op to avoid extra
// device-to-host copies.
TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) {
  const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:GPU:0";
  Scope root = Scope::NewRootScope()
                   .WithDevice(kXlaDeviceName)
                   .WithAssignedDevice(kXlaDeviceName)
                   .ExitOnError();

  FunctionDefLibrary fdef_lib =
      CreateFunctionDefLibWithInt32Input("cluster_int32");
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
  Node* call;
  TF_ASSERT_OK(
      MakeXlaCompiledKernel(root.graph(), "cluster_int32", "C", &call));
  call->set_requested_device(kXlaDeviceName);
  call->AddAttr(kXlaHasReferenceVarsAttr, false);

  auto var =
      ops::VarHandleOp(root.WithOpName("var"), DT_INT32, TensorShape({}));
  auto int32_on_device =
      ops::ReadVariableOp(root.WithOpName("int32_on_device"), var, DT_INT32);

  root.graph()->AddEdge(int32_on_device.node(), 0, call, 0);

  std::unique_ptr<Graph> graph;
  TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph));

  Node* partitioned_call_op = nullptr;
  for (Node* n : graph->op_nodes()) {
    if (n->type_string() == "PartitionedCall") {
      ASSERT_EQ(partitioned_call_op, nullptr);
      partitioned_call_op = n;
    }
  }

  ASSERT_NE(partitioned_call_op, nullptr);
  auto xla_compile = NodeWith(Op("_XlaCompile"));
  auto switch_on_compilation_pred =
      NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile)));
  auto ctrl_dep =
      NodeWith(Op("Identity"), Inputs(Out(0, switch_on_compilation_pred)));
  // Check that we pipe int32 inputs through an IdentityN to avoid extra D2H
  // copies.
  EXPECT_THAT(
      partitioned_call_op,
      NodeWith(Inputs(Out(NodeWith(Op("IdentityN"), CtrlDeps(ctrl_dep))))));
}
#endif

}  // namespace
}  // namespace tensorflow