/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_

#include <atomic>

#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"

namespace tensorflow {

// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require arguments to follow
// the "constants, then regular args, then resources" order.
// Instead, it takes the vectors of constant and resource argument indexes
// explicitly.
// It does not have a corresponding OpDef because it is never present
// in a GraphDef.
// Currently it is used by the eager runtime: FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
//
// `has_ref_vars`: whether the input computation can have reference variables.
// TODO(cheshire): instead derive this information from the input graph.
class XlaLocalLaunchBase : public OpKernel {
 public:
  XlaLocalLaunchBase(OpKernelConstruction* ctx,
                     const std::vector<int>& constants,
                     const std::vector<int>& resources,
                     const NameAttrList& function, bool has_ref_vars);
  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
  ~XlaLocalLaunchBase() override = default;

  void Compute(OpKernelContext* ctx) override;

 protected:
  // Indexes of compile-time constant inputs
  const std::vector<int> constants_;
  // Indexes of resource inputs
  const std::vector<int> resources_;

  const NameAttrList function_;
  const XlaPlatformInfo platform_info_;

  bool has_ref_vars_;
};
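
// Illustrative sketch (hypothetical, not part of this header): a wrapper
// kernel can forward explicit constant/resource input indexes to the
// constructor above. The class name, index values, and `func` argument are
// made up for the example.
//
//   class EagerXlaLaunchOp : public XlaLocalLaunchBase {
//    public:
//     EagerXlaLaunchOp(OpKernelConstruction* ctx, const NameAttrList& func)
//         : XlaLocalLaunchBase(ctx, /*constants=*/{0}, /*resources=*/{2},
//                              func, /*has_ref_vars=*/false) {}
//   };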

// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// that will be compiled and executed using XLA.  The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// Once all inputs are present, and their shapes are known, the op can
// use an 'XlaCompilationCache' to compile and execute code specific
// to the shapes of the input Tensors.
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
class XlaLocalLaunchOp : public XlaLocalLaunchBase {
 public:
  explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
  ~XlaLocalLaunchOp() override;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
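
// Illustrative sketch of how a launch kernel of this kind is typically
// registered (the authoritative registrations live in the accompanying
// xla_ops.cc, and the exact builder chain there may differ):
//
//   REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU),
//                           XlaLocalLaunchOp);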
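// Kernel for the _XlaCompile op: compiles the wrapped TensorFlow function
// into an XLA executable and emits a key that _XlaRun uses to look the
// executable up, together with a boolean output indicating whether
// compilation succeeded (see xla_ops.cc for the exact behavior).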
class XlaCompileOp : public OpKernel {
 public:
  explicit XlaCompileOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  // Indexes of compile-time constant inputs
  const std::vector<int> constants_;
  // Indexes of resource inputs
  const std::vector<int> resources_;

  const NameAttrList function_;

  XlaPlatformInfo platform_info_;

  const bool must_compile_;

  // Whether the graph has TF reference variables.
  const bool has_ref_vars_;

  // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
  // error when compiling the cluster this _XlaCompile is supposed to compile.
  // If `cannot_compile_cluster_` is true then we avoid compiling this cluster
  // on any future calls to _XlaCompile.
  bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) =
      false;

  mutex cannot_compile_cluster_mu_;
};
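
// Illustrative sketch (hypothetical, not the actual Compute() body) of the
// pattern the two members above support: read the flag under the mutex
// before compiling, and set it when XLA reports an Unimplemented error so
// that later calls skip compilation. `CompileCluster` is a made-up helper.
//
//   bool cannot_compile;
//   {
//     mutex_lock guard(cannot_compile_cluster_mu_);
//     cannot_compile = cannot_compile_cluster_;
//   }
//   if (!cannot_compile) {
//     Status s = CompileCluster();  // hypothetical helper
//     if (errors::IsUnimplemented(s)) {
//       mutex_lock guard(cannot_compile_cluster_mu_);
//       cannot_compile_cluster_ = true;
//     }
//   }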
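// Kernel for the _XlaRun op: consumes the key produced by _XlaCompile and
// runs the corresponding XLA executable on the op's remaining inputs.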
class XlaRunOp : public OpKernel {
 public:
  explicit XlaRunOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  const XlaPlatformInfo platform_info_;
};

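// Kernel for the _XlaMerge op: roughly a Merge-style kernel that forwards
// whichever of its inputs was actually produced, joining the XLA (_XlaRun)
// path with the fallback non-XLA path.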
class XlaMergeOp : public OpKernel {
 public:
  explicit XlaMergeOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;
};
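
// Illustrative sketch of how these kernels are typically wired together by
// the XLA lazy-compilation graph rewrite (schematic only; the real graph is
// built by the rewrite pass, not by this header):
//
//   key, ok = _XlaCompile(args)           // try to compile the cluster
//   xla_out = _XlaRun(args, key)          // taken when `ok` is true
//   tf_out  = <fallback function call>    // taken when `ok` is false
//   out     = _XlaMerge(tf_out, xla_out)  // forwards whichever was produced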

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_