1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/jit/xla_kernel_creator.h"
16
17 #include "absl/memory/memory.h"
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_format.h"
20 #include "tensorflow/compiler/jit/compilability_check_util.h"
21 #include "tensorflow/compiler/jit/defs.h"
22 #include "tensorflow/compiler/jit/flags.h"
23 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
24 #include "tensorflow/compiler/tf2xla/const_analysis.h"
25 #include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/framework/node_def_builder.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/util/ptr_util.h"
32
33 namespace tensorflow {
34
35 // Returns true iff 'ndef' is a call to a function that is compilable. A
36 // function is compilable iff every operator in the function body is
37 // compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
38 // null, we will populate 'uncompilable_node_info' with uncompilable node info.
IsCompilable(FunctionLibraryRuntime * flr,const NodeDef & ndef,RecursiveCompilabilityChecker::UncompilableNodesMap * uncompilable_node_info)39 static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
40 RecursiveCompilabilityChecker::UncompilableNodesMap*
41 uncompilable_node_info) {
42 Device* device = flr->device();
43 const XlaOpRegistry::DeviceRegistration* registration;
44 CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
45 ®istration));
46
47 // We can always *compile* resource operations, stateful RNGs and dummy ops,
48 // even if we are sometimes unable to auto-cluster them.
49 RecursiveCompilabilityChecker::OperationFilter op_filter;
50 op_filter.allow_resource_ops_in_called_functions = true;
51 op_filter.allow_stack_ops = true;
52 op_filter.allow_tensor_array_ops = true;
53 op_filter.allow_stateful_rng_ops = true;
54 op_filter.allow_control_trigger = true;
55 op_filter.allow_eliding_assert_and_checknumerics_ops = true;
56 op_filter.allow_ops_producing_or_consuming_variant = true;
57 op_filter.allow_slow_ops = true;
58 op_filter.allow_inaccurate_ops = true;
59
60 RecursiveCompilabilityChecker checker{
61 op_filter, DeviceType{registration->compilation_device_name}};
62 if (!uncompilable_node_info) {
63 // We do not need uncompilable node info. Just return the result.
64 return checker.IsCompilableCall(ndef, flr);
65 }
66
67 RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
68 checker.FindUncompilableNodes(ndef, flr);
69 uncompilable_node_info->swap(uncompilable_node_result);
70 return uncompilable_node_info->empty();
71 }
72
CanCreateKernel(const FunctionLibraryRuntime & flr,const std::shared_ptr<const NodeProperties> & props) const73 bool XlaKernelCreator::CanCreateKernel(
74 const FunctionLibraryRuntime& flr,
75 const std::shared_ptr<const NodeProperties>& props) const {
76 return CanCreateXlaKernel(props->node_def) &&
77 !XlaOpRegistry::IsCompilationDevice(flr.device()->device_type());
78 }
79
CreateXlaKernel(FunctionLibraryRuntime * flr,const NodeDef & node_def,std::unique_ptr<OpKernel> * kernel)80 static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
81 const NodeDef& node_def,
82 std::unique_ptr<OpKernel>* kernel) {
83 if (!CanCreateXlaKernel(node_def)) {
84 return errors::Internal("Invalid node: ", node_def.ShortDebugString());
85 }
86
87 VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
88
89 // Make sure that kernels have been registered on the JIT device.
90 XlaOpRegistry::RegisterCompilationKernels();
91
92 // Get function body, constant args, and resource args.
93 NameAttrList function;
94 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
95 const FunctionBody* fbody = nullptr;
96 std::vector<int> constant_arg_indices;
97 std::vector<int> resource_arg_indices;
98 TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
99 flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
100
101 // Only check for compilability if the MLIR bridge is not enabled.
102 absl::optional<ConfigProto> config_proto;
103 if (flr->config_proto()) {
104 config_proto = *flr->config_proto();
105 }
106 // There is no easy way to check if we have uninitialized resource args here
107 // so we assume there are uninitialized resource args. This means that we
108 // might run the compilability checker in cases where we don't need to (when
109 // MLIR bridge is run later). Note that this is just temporary until
110 // b/171732021 gets fixed.
111 // We should also revisit if this check provides any value, otherwise we
112 // should remove it.
113 MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
114 *fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true);
115 if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
116 RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
117 if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
118 std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
119 uncompilable_node_info;
120 for (const auto& it : uncompilable_nodes_map) {
121 for (const auto& info : it.second.second) {
122 uncompilable_node_info.emplace_back(info);
123 }
124 }
125 std::string message = absl::StrCat(
126 "Function invoked by the following node is not compilable: ",
127 SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
128 absl::StrAppend(&message, "Uncompilable operations:");
129 for (const auto& node_info : uncompilable_node_info) {
130 std::string node_message = absl::StrCat(
131 "\n", node_info.name, ": ", node_info.uncompilable_reason, "\n",
132 "The op is created at:\n");
133 if (node_info.stack_trace.back().stack_trace) {
134 AbstractStackTrace::TracePrintingOptions opts;
135 opts.show_line_contents = true;
136 opts.filter_common_prefix = true;
137 opts.drop_internal_frames = true;
138 absl::StrAppend(
139 &node_message,
140 node_info.stack_trace.back().stack_trace->ToString(opts));
141 } else {
142 absl::StrAppend(&node_message, "<Unavailable>\n");
143 }
144 absl::StrAppend(&message, node_message);
145 }
146 VLOG(1) << message;
147 return errors::InvalidArgument(message);
148 }
149 }
150
151 MemoryTypeVector input_memory_types =
152 GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
153 MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
154
155 // Create the kernel.
156 Device* dev = flr->device();
157 Status s;
158 auto props = std::make_shared<NodeProperties>(
159 &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
160 OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
161 dev->GetAllocator(AllocatorAttributes()),
162 flr, dev->resource_manager(), props,
163 input_memory_types, output_memory_types,
164 flr->graph_def_version(), &s);
165
166 *kernel = absl::make_unique<XlaLocalLaunchBase>(
167 &construction, constant_arg_indices, resource_arg_indices, function,
168 /*has_ref_vars=*/false);
169 return s;
170 }
171
CreateKernel(FunctionLibraryRuntime * flr,const std::shared_ptr<const NodeProperties> & props,std::unique_ptr<OpKernel> * kernel) const172 Status XlaKernelCreator::CreateKernel(
173 FunctionLibraryRuntime* flr,
174 const std::shared_ptr<const NodeProperties>& props,
175 std::unique_ptr<OpKernel>* kernel) const {
176 return CreateXlaKernel(flr, props->node_def, kernel);
177 }
178
RegisterLaunchOpCreator()179 static bool RegisterLaunchOpCreator() {
180 XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator();
181 RegisterDefaultCustomKernelCreator(xla_kernel_creator);
182 return true;
183 }
184
185 static bool register_me = RegisterLaunchOpCreator();
186
187 } // namespace tensorflow
188