1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ 17 #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ 18 19 #include <atomic> 20 21 #include "tensorflow/compiler/jit/xla_compilation_cache.h" 22 #include "tensorflow/compiler/jit/xla_device.h" 23 #include "tensorflow/compiler/jit/xla_launch_util.h" 24 #include "tensorflow/compiler/jit/xla_platform_info.h" 25 #include "tensorflow/core/framework/allocator.h" 26 #include "tensorflow/core/framework/op.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/platform/macros.h" 30 #include "tensorflow/core/util/stream_executor_util.h" 31 #include "tensorflow/stream_executor/tf_allocator_adapter.h" 32 33 namespace tensorflow { 34 35 36 // XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. 37 // The only difference is that it does not require arguments to follow 38 // the "constants, then regular args, then resources" order. 39 // It takes vectors of constant and resource arguments explicitly. 40 // It does not have corresponding OpDef because it is never present 41 // in the GraphDef. 42 // Currently, it is used by eager runtime. FunctionLibraryRuntime creates 43 // this kernel when asked to create a kernel for an XLA-compiled function. 44 // 45 // `has_ref_vars`: whether the input computation can have reference variables. 46 // TODO(cheshire): instead derive this information from the input graph. 47 class XlaLocalLaunchBase : public OpKernel { 48 public: 49 XlaLocalLaunchBase(OpKernelConstruction* ctx, 50 const std::vector<int>& constants, 51 const std::vector<int>& resources, 52 const NameAttrList& function, bool has_ref_vars); 53 XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; 54 XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; 55 ~XlaLocalLaunchBase() override = default; 56 57 void Compute(OpKernelContext* ctx) override; 58 59 protected: 60 // Indexes of compile-time constant inputs 61 const std::vector<int> constants_; 62 // Indexes of resource inputs 63 const std::vector<int> resources_; 64 65 const NameAttrList function_; 66 const XlaPlatformInfo platform_info_; 67 68 bool has_ref_vars_; 69 }; 70 71 // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph 72 // which will be compiled and executed using XLA. The XlaLocalLaunchOp is 73 // responsible for handling interactions with the TensorFlow executor. 74 // Once all inputs are present, and their shapes are known, the op can 75 // use a 'XlaCompilationCache' to compile and execute code which is specific 76 // to the shapes of input Tensors. 77 // XlaLocalLaunchOp uses xla::LocalClient::Compile() and 78 // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device 79 // memory. 80 class XlaLocalLaunchOp : public XlaLocalLaunchBase { 81 public: 82 explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); 83 ~XlaLocalLaunchOp() override; 84 85 private: 86 TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); 87 }; 88 89 class XlaCompileOp : public OpKernel { 90 public: 91 explicit XlaCompileOp(OpKernelConstruction* ctx); 92 93 void Compute(OpKernelContext* ctx) override; 94 95 private: 96 // Indexes of compile-time constant inputs 97 const std::vector<int> constants_; 98 // Indexes of resource inputs 99 const std::vector<int> resources_; 100 101 const NameAttrList function_; 102 103 XlaPlatformInfo platform_info_; 104 105 const bool must_compile_; 106 107 // Whether the graph has TF reference variables. 108 const bool has_ref_vars_; 109 110 // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented 111 // error when compiling the cluster this _XlaCompile is supposed to compile. 112 // If `cannot_compile_cluster_` is true then we avoid compiling this cluster 113 // on any future calls to _XlaCompile. 114 bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) = 115 false; 116 117 mutex cannot_compile_cluster_mu_; 118 }; 119 120 class XlaRunOp : public OpKernel { 121 public: 122 explicit XlaRunOp(OpKernelConstruction* ctx); 123 124 void Compute(OpKernelContext* ctx) override; 125 126 private: 127 const XlaPlatformInfo platform_info_; 128 }; 129 130 class XlaMergeOp : public OpKernel { 131 public: 132 explicit XlaMergeOp(OpKernelConstruction* ctx); 133 134 void Compute(OpKernelContext* ctx) override; 135 }; 136 137 } // namespace tensorflow 138 139 #endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ 140