/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/create_xla_launch_op.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take care to
// order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
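//
// Illustrative use (a sketch, not code from this file): callers probe values
// in ascending order, and values skipped over are never matched later.
//
//   std::vector<int> sorted = {0, 2, 5};  // must be sorted and outlive search
//   SinglePassSearch search(&sorted);
//   search.ScanForValue(0);  // true
//   search.ScanForValue(1);  // false; 1 is not in the list
//   search.ScanForValue(2);  // true
//   search.ScanForValue(1);  // false; values before the current position
//                            // are never revisited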
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  // Scans forward in the vector looking for "value", updating the internal
  // position into the vector.
  // Returns true iff the vector contains the given value at or after the
  // current position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  int current_index_;
  const std::vector<int>* values_;
};

Status CompilationRequested(const FunctionLibraryRuntime& flr,
                            const NodeDef& node_def) {
  const FunctionDef* function_def =
      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
  if (function_def == nullptr) {
    // The node def is not calling a function. Individual ops can be
    // run directly using on-demand mode; there is no need to create an
    // XlaLaunch kernel for them.
    // TODO(b/110359382): Make custom kernel creation return a bool instead of
    // a status.
    // We don't set error messages here to avoid an unnecessary string copy.
    // Similarly below.
    return Status(error::INVALID_ARGUMENT, "");
  }

  // If kXlaCompileAttr is set on the node_def, use its value.
  const auto& it = node_def.attr().find(kXlaCompileAttr);
  if (it != node_def.attr().end()) {
    return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, "");
  }

  // kXlaCompileAttr is not set on the node_def; check whether it is set on
  // the FunctionDef.
  bool xla_compile = false;
  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
      node_def, kXlaCompileAttr, &xla_compile);
  if (!status.ok() || !xla_compile) {
    if (VLOG_IS_ON(3)) {
      if (!status.ok()) {
        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
                << node_def.op() << ". status=" << status.ToString();
      } else {
        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
      }
    }
    return Status(error::INVALID_ARGUMENT, "");
  }
  return Status::OK();
}

// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vectors.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NodeDef& node_def,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  // If node_def is not instantiable, e.g., the function does not exist,
  // simply bail out.
  TF_RETURN_IF_ERROR(
      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args, bail out.
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
                             /*compile_time_const_nodes=*/nullptr, flr));

  for (int i = 0; i < const_args.size(); ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (int i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
}

}  // namespace

Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
                         std::unique_ptr<OpKernel>* kernel) {
  TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));

  VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();

  // Make sure that kernels have been registered on the JIT device.
  XlaOpRegistry::RegisterCompilationKernels();
  if (!IsCompilable(flr, node_def)) {
    VLOG(1) << "Not creating XlaLaunchOp because the function invoked by the "
               "following node is not compilable: "
            << node_def.DebugString();
    // node_def is calling a function that XLA can't compile.
    return errors::InvalidArgument("Not compilable: ",
                                   node_def.ShortDebugString());
  }

  // Get the function body, constant args, and resource args.
  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));

  // Set input and output memory types.
  MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
  // These indices are used only for optimization purposes. They allow us to
  // loop over constant_arg_indices and resource_arg_indices only once while
  // iterating over all the function arguments, checking whether each one is a
  // resource or a constant.
  // The reason we optimized this code is that functions can have a lot of
  // captured arguments. For example, the backward pass of ResNet50 takes in
  // all 214 variables and a similar number of activations.
  SinglePassSearch constants_search(&constant_arg_indices);
  SinglePassSearch resources_search(&resource_arg_indices);
  for (int i = 0; i < fbody->arg_types.size(); ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = HOST_MEMORY;
    }
  }
  // One might wonder about the case where a compile-time constant argument
  // (which must be in host memory) is also used as an input to an op,
  // e.g. Add, that expects its inputs in device memory. Here is how it
  // works now.
  // First, what do we mean by "op expects an input in XYZ memory"?
  // There are two types of "ops" here: the tf2xla kernel and the HLO
  // computation it builds. The tf2xla kernel needs to retrieve the actual
  // numeric value of the compile-time constant tensors, so it really expects
  // them to be in host memory. However, for other inputs, it refers to them
  // using xla::ComputationDataHandle, which is just a symbolic handle that
  // xla::ComputationBuilder assigns. How does this handle get assigned for
  // constant arguments? Even constant arguments get an _Arg node in the graph
  // instantiated for function compilation. The tf2xla kernel for constant
  // _Arg nodes takes the constant value, converts it to an XlaLiteral, and
  // feeds it to xla::ComputationBuilder.ConstantLiteral, which returns the
  // handle. This constant XlaLiteral is included in the HLO graph, and
  // subsequently in the actual executable, which is copied to the device
  // before being executed. Thus, when this executable runs, the constant is
  // available in device memory.
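  //
  // A rough sketch of the constant-_Arg path described above, using the names
  // from this comment (the real tf2xla kernel differs in detail; `builder`
  // and `literal` are assumed to be in scope):
  //
  //   xla::Literal literal = ...;  // built from the host-memory constant
  //   xla::ComputationDataHandle handle = builder->ConstantLiteral(literal);
  //   // The literal is now baked into the HLO, so it reaches device memory
  //   // as part of the compiled executable rather than as a runtime input.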

  // The XlaLaunch kernel keeps all outputs (including constants, which it
  // copies) in device memory, except for resources.
  MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
  for (int i = 0; i < fbody->ret_types.size(); ++i) {
    if (fbody->ret_types[i] == DT_RESOURCE) {
      output_memory_types[i] = HOST_MEMORY;
    }
  }

  // Create the kernel.
  NameAttrList function;
  function.set_name(node_def.op());
  *(function.mutable_attr()) = node_def.attr();

  Device* dev = flr->device();
  Status s;
  OpKernelConstruction construction(
      DeviceType(dev->device_type()), dev,
      dev->GetAllocator(AllocatorAttributes()), &node_def,
      &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);

  *kernel = absl::make_unique<XlaLocalLaunchBase>(
      &construction, constant_arg_indices, resource_arg_indices, function);
  return s;
}

namespace {

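// Registers CreateXlaLaunchOp as the default custom kernel creator; the
// static bool below forces the registration to run when this object file is
// loaded.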
bool RegisterLaunchOpCreator() {
  RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
  return true;
}

static bool register_me = RegisterLaunchOpCreator();

}  // end namespace
}  // namespace tensorflow