1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/jit/create_xla_launch_op.h"
16
17 #include "absl/memory/memory.h"
18 #include "tensorflow/compiler/jit/defs.h"
19 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
20 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
21 #include "tensorflow/compiler/tf2xla/const_analysis.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/framework/node_def_builder.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/util/ptr_util.h"
27
28 namespace tensorflow {
29 namespace {
30
// Utility that answers membership queries against a sorted list while walking
// over that list at most once in total, no matter how many queries are made.
//
// The scan position only ever moves forward: if a call to ScanForValue steps
// past an element, that element is never looked at again by later calls.
// Callers therefore need to issue their queries in ascending order.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`; `values` must outlive this object.
  // `values` must be sorted in ascending order.
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  // Advances the internal position past every element that is <= `value`,
  // stopping early (just after the match) if `value` itself is found.
  //
  // Returns true iff `value` occurs at or after the position the scan was at
  // when this call started. An element that has been matched is consumed, so
  // asking for the same value again returns false unless it is duplicated in
  // the list.
  //
  // Not thread-safe.
  bool ScanForValue(int value) {
    const std::vector<int>& sorted = *values_;
    while (current_index_ < sorted.size() && sorted[current_index_] <= value) {
      // Post-increment: the element is consumed whether or not it matches.
      if (sorted[current_index_++] == value) {
        return true;
      }
    }
    return false;
  }

 private:
  // Index of the next element to inspect. Uses the vector's own size type so
  // that the comparison against size() is not a signed/unsigned mix.
  std::vector<int>::size_type current_index_;
  // Not owned.
  const std::vector<int>* values_;
};
67
CompilationRequested(const FunctionLibraryRuntime & flr,const NodeDef & node_def)68 Status CompilationRequested(const FunctionLibraryRuntime& flr,
69 const NodeDef& node_def) {
70 const FunctionDef* function_def =
71 flr.GetFunctionLibraryDefinition()->Find(node_def.name());
72 if (function_def == nullptr) {
73 // The node def is not calling a function. Individual ops can be
74 // run directly using on-demand mode, no need to create XlaLaunch
75 // kernel for them.
76 // TODO(b/110359382): Make custom kernel creation return a bool instead of
77 // status.
78 // We don't set error messages here to avoid unnecessary string copy.
79 // Similarly below.
80 return Status(error::INVALID_ARGUMENT, "");
81 }
82
83 // If kXlaCompileAttr is set on the node_def, use its value.
84 const auto& it = node_def.attr().find(kXlaCompileAttr);
85 if (it != node_def.attr().end()) {
86 return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, "");
87 }
88
89 // kXlaCompileAttr is not set on node_def, check if it is set on
90 // FunctionDef.
91 bool xla_compile = false;
92 Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
93 node_def, kXlaCompileAttr, &xla_compile);
94 if (!status.ok() || !xla_compile) {
95 if (VLOG_IS_ON(3)) {
96 if (!status.ok()) {
97 VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
98 << node_def.op() << ". status=" << status.ToString();
99 } else {
100 VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
101 }
102 }
103 return Status(error::INVALID_ARGUMENT, "");
104 }
105 return Status::OK();
106 }
107
108 // Given a FunctionLibraryRuntime and a NodeDef calling a function in the
109 // runtime, returns this function's body in `fbody` as well as the indices
110 // of its constant and resource arguments.
111 // `fbody` is owned by `flr`.
112 // `constant_arg_indices` and `resource_arg_indices` should be empty vector.
113 // They are sorted in ascending order on this function's return.
GetBodyAndConstantsAndResources(FunctionLibraryRuntime * flr,const NodeDef & node_def,const FunctionBody ** fbody,std::vector<int> * constant_arg_indices,std::vector<int> * resource_arg_indices)114 Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
115 const NodeDef& node_def,
116 const FunctionBody** fbody,
117 std::vector<int>* constant_arg_indices,
118 std::vector<int>* resource_arg_indices) {
119 FunctionLibraryRuntime::Handle handle;
120 // If node_def is not instantiable, e.g., the function does not exist,
121 // simply bail out.
122 TF_RETURN_IF_ERROR(
123 flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
124 *fbody = flr->GetFunctionBody(handle);
125 CHECK(*fbody); // Can't be nullptr since we just instantiated it.
126 const DataTypeVector& arg_types = (*fbody)->arg_types;
127 std::vector<bool> const_args(arg_types.size());
128 // If we can't analyze the const args. Bail out.
129 TF_RETURN_IF_ERROR(
130 BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
131 /*compile_time_const_nodes=*/nullptr, flr));
132
133 for (int i = 0; i < const_args.size(); ++i) {
134 if (const_args[i]) {
135 constant_arg_indices->push_back(i);
136 }
137 }
138
139 // There can be hundreds of resource variables. Reserve the space for them.
140 // We don't reserve for constants above as they are usually few.
141 resource_arg_indices->reserve(arg_types.size());
142 for (int i = 0; i < arg_types.size(); ++i) {
143 if (arg_types[i] == DT_RESOURCE) {
144 resource_arg_indices->push_back(i);
145 }
146 }
147
148 return Status::OK();
149 }
150
151 } // namespace
152
CreateXlaLaunchOp(FunctionLibraryRuntime * flr,const NodeDef & node_def,std::unique_ptr<OpKernel> * kernel)153 Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
154 std::unique_ptr<OpKernel>* kernel) {
155 TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
156
157 VLOG(3) << "Attemping to create XlaLaunchOp for " << node_def.DebugString();
158
159 // Make sure that kernels have been registered on the JIT device.
160 XlaOpRegistry::RegisterCompilationKernels();
161 if (!IsCompilable(flr, node_def)) {
162 VLOG(1) << "Not creating XlaLaunchOp because function invoked by the "
163 "following node is not compilable: "
164 << node_def.DebugString();
165 // node_def is calling a function that XLA can't compile.
166 return errors::InvalidArgument("Not compilable: ",
167 node_def.ShortDebugString());
168 }
169
170 // Get function body, constant args, and resource args.
171 const FunctionBody* fbody = nullptr;
172 std::vector<int> constant_arg_indices;
173 std::vector<int> resource_arg_indices;
174 TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
175 flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
176
177 // Set input and output memory types.
178 MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
179 // These indices are used only for optimization purposes. They allow us
180 // to loop over constant_arg_indices and resource_arg_indices only once
181 // while iterating over all the function arguments checking if it is a
182 // resource or a constant.
183 // The reason we optimized this code is because functions can have a lot of
184 // captured arguments. For example, the backward pass of ResNet50 takes in all
185 // 214 variables and a similar number of activations.
186 SinglePassSearch constants_search(&constant_arg_indices);
187 SinglePassSearch resources_search(&resource_arg_indices);
188 for (int i = 0; i < fbody->arg_types.size(); ++i) {
189 if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
190 // Compile-time constants and resource handles are expected to be in
191 // host memory.
192 input_memory_types[i] = HOST_MEMORY;
193 }
194 }
195 // One might wonder, about the case where a compile-time constant argument
196 // (which must be in host memory) is also used as an input into an op,
197 // e.g. Add, that expects its inputs in device memory. Here is how it
198 // works now.
199 // First, what do we mean by "op expects an input in XYZ memory"?
200 // There are two types of "ops" here: the tf2xla kernel and the HLO
201 // computation it builds. The tf2xla kernel needs to retrieve the actual
202 // numeric value of the compile-time constant tensors, so it really expects
203 // them to be on in host memory. However, for other inputs, it refers to them
204 // using xla::ComputationDataHandle, which is just a symbolic handle that
205 // xla::ComputationBuilder assigns. How does this handle gets assigned for
206 // constant arguments? Even constant arguments get an _Arg node in the graph
207 // instatiated for Function compilation. The tf2xla kernel for constant _Arg
208 // nodes takes the constant value, converts it to XlaLiteral, and feeds it
209 // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
210 // constant XlaLiteral is included in the HLO graph, and subsequently, in
211 // the actual executable, which is copied to the device before being
212 // executed. Thus, when this executable runs, the constant is available in
213 // device memory.
214
215 // XlaLaunch kernel keeps all outputs (including constants, which it copies),
216 // in device memory except for resources.
217 MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
218 for (int i = 0; i < fbody->ret_types.size(); ++i) {
219 if (fbody->ret_types[i] == DT_RESOURCE) {
220 output_memory_types[i] = HOST_MEMORY;
221 }
222 }
223
224 // Create the kernel.
225 NameAttrList function;
226 function.set_name(node_def.op());
227 *(function.mutable_attr()) = node_def.attr();
228
229 Device* dev = flr->device();
230 Status s;
231 OpKernelConstruction construction(
232 DeviceType(dev->device_type()), dev,
233 dev->GetAllocator(AllocatorAttributes()), &node_def,
234 &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
235 fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
236
237 *kernel = absl::make_unique<XlaLocalLaunchBase>(
238 &construction, constant_arg_indices, resource_arg_indices, function);
239 return s;
240 }
241
242 namespace {
243
RegisterLaunchOpCreator()244 bool RegisterLaunchOpCreator() {
245 RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
246 return true;
247 }
248
249 static bool register_me = RegisterLaunchOpCreator();
250
251 } // end namespace
252 } // namespace tensorflow
253