• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/jit/xla_kernel_creator.h"
16 
17 #include "absl/memory/memory.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_format.h"
20 #include "tensorflow/compiler/jit/compilability_check_util.h"
21 #include "tensorflow/compiler/jit/defs.h"
22 #include "tensorflow/compiler/jit/flags.h"
23 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
24 #include "tensorflow/compiler/tf2xla/const_analysis.h"
25 #include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/framework/node_def_builder.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/util/ptr_util.h"
32 
33 namespace tensorflow {
34 
35 // Returns true iff 'ndef' is a call to a function that is compilable.  A
36 // function is compilable iff every operator in the function body is
37 // compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
38 // null, we will populate 'uncompilable_node_info' with uncompilable node info.
IsCompilable(FunctionLibraryRuntime * flr,const NodeDef & ndef,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_node_info)39 static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
40                          RecursiveCompilabilityChecker::UncompilableNodesMap*
41                              uncompilable_node_info) {
42   Device* device = flr->device();
43   const XlaOpRegistry::DeviceRegistration* registration;
44   CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
45                                             &registration));
46 
47   // We can always *compile* resource operations, stateful RNGs and dummy ops,
48   // even if we are sometimes unable to auto-cluster them.
49   RecursiveCompilabilityChecker::OperationFilter op_filter;
50   op_filter.allow_resource_ops_in_called_functions = true;
51   op_filter.allow_stack_ops = true;
52   op_filter.allow_tensor_array_ops = true;
53   op_filter.allow_stateful_rng_ops = true;
54   op_filter.allow_control_trigger = true;
55   op_filter.allow_eliding_assert_and_checknumerics_ops = true;
56   op_filter.allow_ops_producing_or_consuming_variant = true;
57   op_filter.allow_slow_ops = true;
58   op_filter.allow_inaccurate_ops = true;
59 
60   RecursiveCompilabilityChecker checker{
61       op_filter, DeviceType{registration->compilation_device_name}};
62   if (!uncompilable_node_info) {
63     // We do not need uncompilable node info. Just return the result.
64     return checker.IsCompilableCall(ndef, flr);
65   }
66 
67   RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
68       checker.FindUncompilableNodes(ndef, flr);
69   uncompilable_node_info->swap(uncompilable_node_result);
70   return uncompilable_node_info->empty();
71 }
72 
CanCreateKernel(const FunctionLibraryRuntime & flr,const std::shared_ptr<const NodeProperties> & props) const73 bool XlaKernelCreator::CanCreateKernel(
74     const FunctionLibraryRuntime& flr,
75     const std::shared_ptr<const NodeProperties>& props) const {
76   return CanCreateXlaKernel(props->node_def) &&
77          !XlaOpRegistry::IsCompilationDevice(flr.device()->device_type());
78 }
79 
CreateXlaKernel(FunctionLibraryRuntime * flr,const NodeDef & node_def,std::unique_ptr<OpKernel> * kernel)80 static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
81                               const NodeDef& node_def,
82                               std::unique_ptr<OpKernel>* kernel) {
83   if (!CanCreateXlaKernel(node_def)) {
84     return errors::Internal("Invalid node: ", node_def.ShortDebugString());
85   }
86 
87   VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
88 
89   // Make sure that kernels have been registered on the JIT device.
90   XlaOpRegistry::RegisterCompilationKernels();
91 
92   // Get function body, constant args, and resource args.
93   NameAttrList function;
94   TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
95   const FunctionBody* fbody = nullptr;
96   std::vector<int> constant_arg_indices;
97   std::vector<int> resource_arg_indices;
98   TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
99       flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
100 
101   // Only check for compilability if the MLIR bridge is not enabled.
102   absl::optional<ConfigProto> config_proto;
103   if (flr->config_proto()) {
104     config_proto = *flr->config_proto();
105   }
106   // There is no easy way to check if we have uninitialized resource args here
107   // so we assume there are uninitialized resource args. This means that we
108   // might run the compilability checker in cases where we don't need to (when
109   // MLIR bridge is run later). Note that this is just temporary until
110   // b/171732021 gets fixed.
111   // We should also revisit if this check provides any value, otherwise we
112   // should remove it.
113   MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
114       *fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true);
115   if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
116     RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
117     if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
118       std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
119           uncompilable_node_info;
120       for (const auto& it : uncompilable_nodes_map) {
121         for (const auto& info : it.second.second) {
122           uncompilable_node_info.emplace_back(info);
123         }
124       }
125       std::string message = absl::StrCat(
126           "Function invoked by the following node is not compilable: ",
127           SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
128       absl::StrAppend(&message, "Uncompilable operations:");
129       for (const auto& node_info : uncompilable_node_info) {
130         std::string node_message = absl::StrCat(
131             "\n", node_info.name, ": ", node_info.uncompilable_reason, "\n",
132             "The op is created at:\n");
133         if (node_info.stack_trace.back().stack_trace) {
134           AbstractStackTrace::TracePrintingOptions opts;
135           opts.show_line_contents = true;
136           opts.filter_common_prefix = true;
137           opts.drop_internal_frames = true;
138           absl::StrAppend(
139               &node_message,
140               node_info.stack_trace.back().stack_trace->ToString(opts));
141         } else {
142           absl::StrAppend(&node_message, "<Unavailable>\n");
143         }
144         absl::StrAppend(&message, node_message);
145       }
146       VLOG(1) << message;
147       return errors::InvalidArgument(message);
148     }
149   }
150 
151   MemoryTypeVector input_memory_types =
152       GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
153   MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
154 
155   // Create the kernel.
156   Device* dev = flr->device();
157   Status s;
158   auto props = std::make_shared<NodeProperties>(
159       &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
160   OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
161                                     dev->GetAllocator(AllocatorAttributes()),
162                                     flr, dev->resource_manager(), props,
163                                     input_memory_types, output_memory_types,
164                                     flr->graph_def_version(), &s);
165 
166   *kernel = absl::make_unique<XlaLocalLaunchBase>(
167       &construction, constant_arg_indices, resource_arg_indices, function,
168       /*has_ref_vars=*/false);
169   return s;
170 }
171 
CreateKernel(FunctionLibraryRuntime * flr,const std::shared_ptr<const NodeProperties> & props,std::unique_ptr<OpKernel> * kernel) const172 Status XlaKernelCreator::CreateKernel(
173     FunctionLibraryRuntime* flr,
174     const std::shared_ptr<const NodeProperties>& props,
175     std::unique_ptr<OpKernel>* kernel) const {
176   return CreateXlaKernel(flr, props->node_def, kernel);
177 }
178 
RegisterLaunchOpCreator()179 static bool RegisterLaunchOpCreator() {
180   XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator();
181   RegisterDefaultCustomKernelCreator(xla_kernel_creator);
182   return true;
183 }
184 
185 static bool register_me = RegisterLaunchOpCreator();
186 
187 }  // namespace tensorflow
188