• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
17 #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
18 
#include <atomic>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h"
30 
31 namespace tensorflow {
32 
33 // Holds some information about the platform on which an
34 // XlaLaunch/_XlaCompile/_XlaRun op must run on.
35 class XlaPlatformInfo {
36  public:
XlaPlatformInfo()37   XlaPlatformInfo() : device_type_("") {}
38   XlaPlatformInfo(XlaPlatformInfo&&) = default;
XlaPlatformInfo(const DeviceType device_type,se::Platform::Id platform_id,const XlaDevice::Metadata * xla_device_metadata,std::unique_ptr<XlaAllocator> xla_allocator,xla::DeviceMemoryAllocator * device_allocator)39   explicit XlaPlatformInfo(const DeviceType device_type,
40                            se::Platform::Id platform_id,
41                            const XlaDevice::Metadata* xla_device_metadata,
42                            std::unique_ptr<XlaAllocator> xla_allocator,
43                            xla::DeviceMemoryAllocator* device_allocator)
44       : device_type_(device_type),
45         platform_id_(platform_id),
46         xla_device_metadata_(xla_device_metadata),
47         xla_allocator_(std::move(xla_allocator)),
48         device_allocator_(device_allocator) {
49     CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
50   }
51 
52   XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
53 
UseMultipleStreams()54   bool UseMultipleStreams() const {
55     return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
56   }
57 
allocator()58   xla::DeviceMemoryAllocator* allocator() const {
59     return device_allocator_ ? device_allocator_ : xla_allocator_.get();
60   }
device_type()61   DeviceType device_type() const { return device_type_; }
62 
63   // This is equal to xla_device_metadata()->platform()->id() if
64   // xla_device_metadata() is not nullptr.
platform_id()65   se::Platform::Id platform_id() const { return platform_id_; }
66 
67   // This may be null if the op this XlaPlatformInfo is for was not placed on an
68   // XLA device.
xla_device_metadata()69   const XlaDevice::Metadata* xla_device_metadata() const {
70     return xla_device_metadata_;
71   }
is_on_xla_device()72   bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
73 
74  private:
75   DeviceType device_type_;
76   se::Platform::Id platform_id_;
77 
78   // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
79   // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
80   // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
81   const XlaDevice::Metadata* xla_device_metadata_;
82 
83   // If the op associated with this XlaPlatformInfo is placed on an XLA device
84   // then device_allocator_ is the xla::Backend's memory allocator and
85   // xla_allocator_ is null.  If the op is placed on a regular CPU or GPU device
86   // then device_allocator_ is null and xla_allocator_ points to an appropriate
87   // XlaAllocator instance.
88   std::unique_ptr<XlaAllocator> xla_allocator_;
89   xla::DeviceMemoryAllocator* device_allocator_;
90 
91   TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
92 };
93 
94 // XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
95 // The only difference is that it does not require arguments to follow
96 // the "constants, then regular args, then resources" order.
97 // It takes vectors of constant and resource arguments explicitly.
98 // It does not have corresponding OpDef because it is never present
99 // in the GraphDef.
100 // Currently, it is used by eager runtime. FunctionLibraryRuntime creates
101 // this kernel when asked to create a kernel for an XLA-compiled function.
102 class XlaLocalLaunchBase : public OpKernel {
103  public:
104   XlaLocalLaunchBase(OpKernelConstruction* ctx,
105                      const std::vector<int>& constants,
106                      const std::vector<int>& resources,
107                      const NameAttrList& function);
108   XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
109   XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
110   ~XlaLocalLaunchBase() override = default;
111 
112   void Compute(OpKernelContext* ctx) override;
113 
114  protected:
115   // Indexes of compile-time constant inputs
116   const std::vector<int> constants_;
117   // Indexes of resource inputs
118   const std::vector<int> resources_;
119 
120   const NameAttrList function_;
121   const XlaPlatformInfo platform_info_;
122 };
123 
124 // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
125 // which will be compiled and executed using XLA.  The XlaLocalLaunchOp is
126 // responsible for handling interactions with the TensorFlow executor.
127 // Once all inputs are present, and their shapes are known, the op can
128 // use a 'XlaCompilationCache' to compile and execute code which is specific
129 // to the shapes of input Tensors.
130 // XlaLocalLaunchOp uses xla::LocalClient::Compile() and
131 // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
132 // memory.
133 class XlaLocalLaunchOp : public XlaLocalLaunchBase {
134  public:
135   explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
136   ~XlaLocalLaunchOp() override;
137 
138  private:
139   TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
140 };
141 
142 class XlaCompileOp : public OpKernel {
143  public:
144   explicit XlaCompileOp(OpKernelConstruction* ctx);
145 
146   void Compute(OpKernelContext* ctx) override;
147 
148  private:
149   // Indexes of compile-time constant inputs
150   const std::vector<int> constants_;
151   // Indexes of resource inputs
152   const std::vector<int> resources_;
153 
154   const NameAttrList function_;
155 
156   XlaPlatformInfo platform_info_;
157 
158   const bool must_compile_;
159 
160   // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
161   // error when compiling the cluster this _XlaCompile is supposed to compile.
162   // If `cannot_compile_cluster_` is true then we avoid compiling this cluster
163   // on any future calls to _XlaCompile.
164   bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;
165 
166   mutex cannot_compile_cluster_mu_;
167 };
168 
169 class XlaRunOp : public OpKernel {
170  public:
171   explicit XlaRunOp(OpKernelConstruction* ctx);
172 
173   void Compute(OpKernelContext* ctx) override;
174 
175  private:
176   const XlaPlatformInfo platform_info_;
177 };
178 
179 }  // namespace tensorflow
180 
#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
182