/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;

// Implementations of operators that generate XLA code should usually subclass
// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
// an XlaOpKernel produces and consumes symbolic values during compilation.
//
// See the comments in xla_context.h for more details.
class XlaOpKernel : public OpKernel {
 public:
  explicit XlaOpKernel(OpKernelConstruction* construction);

  // Subclasses should implement Compile(), much as standard OpKernels
  // implement Compute().
  virtual void Compile(XlaOpKernelContext* context) = 0;

 private:
  void Compute(OpKernelContext* context) final;
};
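
// For illustration, a minimal XlaOpKernel subclass might look like the sketch
// below. The kernel class "NegOp" is hypothetical and is not declared in this
// header; the REGISTER_XLA_OP registration macro comes from
// tensorflow/compiler/tf2xla/xla_op_registry.h.
//
//   class NegOp : public XlaOpKernel {
//    public:
//     explicit NegOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
//     void Compile(XlaOpKernelContext* ctx) override {
//       xla::XlaOp x = ctx->Input(0);    // Symbolic input, not a Tensor.
//       ctx->SetOutput(0, xla::Neg(x));  // Emit XLA ops for the output.
//     }
//   };
//   REGISTER_XLA_OP(Name("Neg"), NegOp);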

// The context passed to the Compile() method of XlaOpKernel. An
// XlaOpKernelContext is a variant of the standard OpKernelContext, tailored
// for implementing operators that perform symbolic execution as part of the
// XLA compiler. The key difference is that XlaOpKernelContext produces and
// consumes data as XLA computations, rather than as standard Tensors.
//
// Under the hood, symbolic execution communicates using special Tensors that
// wrap XlaExpression objects; however, this is an implementation detail that
// this class hides. The *only* correct way to allocate a Tensor during
// compilation is using the XlaOpKernelContext methods, since they ensure there
// is a valid XlaExpression backing the tensor. No Op should ever call
// allocate_output or allocate_temp directly on the underlying OpKernelContext.
class XlaOpKernelContext {
 public:
  explicit XlaOpKernelContext(OpKernelContext* context);

  XlaContext* xla_context() const;

  // Returns the XLA XlaBuilder containing the output of compilation.
  xla::XlaBuilder* builder() const;

  // Inputs

  // Returns the number of inputs to the operator.
  int num_inputs() const { return context_->num_inputs(); }

  // Returns the type of input `index`.
  DataType input_type(int index) const;

  // Returns the type of input `name`.
  DataType InputType(absl::string_view name);

  // Returns the type of input `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType input_xla_type(int index);

  // Returns the shape of input `index`.
  TensorShape InputShape(int index);

  // Returns the shape of input with name `name`.
  TensorShape InputShape(absl::string_view name);

  // Returns input `index` as an XlaOp. Unlike OpKernelContext::Input, this
  // returns a symbolic value rather than a concrete Tensor.
  xla::XlaOp Input(int index);
  // Returns input `name` as an XlaOp.
  xla::XlaOp Input(absl::string_view name);

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;

  // Returns the named list-valued immutable input, as defined in the OpDef.
  // If the named input is not list-valued, returns a one-element list.
  Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                   std::vector<TensorShape>* shapes);

  // Helper methods for constant inputs.

  // Evaluates input `index` and stores it in `*constant_literal`. If the
  // expression cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-OK status.
  Status ConstantInput(int index, xla::Literal* constant_literal);
  Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);

  // Converts a constant scalar int32 or int64 tensor into an int64.
  Status ConstantInputAsIntScalar(int index, int64* out);
  Status ConstantInputAsIntScalar(absl::string_view name, int64* out);

  // Converts a constant scalar float32 or float64 tensor into a float64.
  Status ConstantInputAsFloatScalar(int index, double* out);

  // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
  Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
  Status ConstantInputAsIntVector(absl::string_view name,
                                  std::vector<int64>* out);

  // Reshapes and converts a constant int32 or int64 tensor into a vector of
  // int64s.
  Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
  Status ConstantInputReshapedToIntVector(absl::string_view name,
                                          std::vector<int64>* out);

  // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
  Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
  Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);

  // Converts a constant 1D int32 or int64 tensor into a TensorShape.
  Status ConstantInputAsShape(int index, TensorShape* shape);

  // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1,
  // into a PartialTensorShape.
  Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);

  // Returns the named list-valued immutable input, as defined in the OpDef.
  // If the named input is not list-valued, returns a one-element list.
  Status ConstantInputList(absl::string_view name,
                           std::vector<xla::Literal>* literals);
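
  // For illustration, a kernel whose second input must be a compile-time
  // constant vector (the input index 1 here is hypothetical) might fetch it
  // inside Compile() as sketched below. OP_REQUIRES_OK works with this class
  // because it provides the CtxFailure* helpers declared further down.
  //
  //   std::vector<int64> dims;
  //   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &dims));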

  // Returns an XlaExpression describing the value of `index`.
  const XlaExpression& InputExpression(int index);
  const XlaExpression& InputExpression(absl::string_view name);

  // Outputs

  int num_outputs() const { return context_->num_outputs(); }
  DataType expected_output_dtype(int index) const {
    return context_->expected_output_dtype(index);
  }

  // Returns the type of output `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType output_xla_type(int index);

  // Sets output `index` to the XlaOp `handle`.
  // All outputs should be set using SetOutput and SetConstantOutput, not
  // via the underlying OpKernelContext.
  void SetOutput(int index, const xla::XlaOp& handle);

  // Sets output `index` to compile-time constant `host_tensor`, where
  // `host_tensor` is a tensor in host memory. It is preferable to use
  // SetConstantOutput rather than SetOutput where possible.
  void SetConstantOutput(int index, const Tensor& host_tensor);

  // Sets output `index` to the XlaExpression `expression`.
  void SetOutputExpression(int index, const XlaExpression& expression);

  // Sets output `index` to the Tensor List `handle`.
  void SetTensorListOutput(int index, const xla::XlaOp& handle);

  // Status handling.
  void SetStatus(const Status& status) { context_->SetStatus(status); }
  Status status() { return context_->status(); }

  // Variables

  // Sets `*resource` to the resource associated with input `index`.
  Status GetResourceInput(int index, XlaResource** resource);

  // Sets output `index` to be a reference to resource `resource`.
  void SetResourceOutput(int index, XlaResource* resource);

  // Sets `*type` and `*shape` to the current type and shape of a variable's
  // value.
  Status GetVariableTypeAndShape(int index, DataType* type,
                                 TensorShape* shape) const;

  // Reads the current value of the resource variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if
  // its type does not match `type`.
  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
                           xla::XlaOp* value);
  // Reads the current value of the resource variable referred to by input
  // `name`.
  Status ReadVariableInput(absl::string_view name, DataType type,
                           TensorShape* shape, xla::XlaOp* value);

  // Assigns the value `handle` to the variable referenced by input
  // `input_index`. The variable must be of `type`. Returns an error if the
  // variable has been initialized with a different type or with a
  // different shape.
  Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
  // Assigns the value `handle` to the variable referenced by input `name`.
  Status AssignVariable(absl::string_view name, DataType type,
                        xla::XlaOp handle);
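
  // For illustration, a kernel that increments a float resource variable
  // passed as input 0 by the value of input 1 might combine these methods as
  // sketched below (the input indices and DT_FLOAT type are hypothetical):
  //
  //   TensorShape var_shape;
  //   xla::XlaOp var_value;
  //   OP_REQUIRES_OK(
  //       ctx, ctx->ReadVariableInput(0, DT_FLOAT, &var_shape, &var_value));
  //   xla::XlaOp delta = ctx->Input(1);
  //   OP_REQUIRES_OK(
  //       ctx, ctx->AssignVariable(0, DT_FLOAT, xla::Add(var_value, delta)));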

  // Helper routines for the OP_REQUIRES macros.
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return context_->call_frame(); }

  FunctionLibraryRuntime* function_library() const {
    return context_->function_library();
  }

  const OpKernel& op_kernel() const { return context_->op_kernel(); }

  // Returns the underlying OpKernelContext. Use rarely.
  OpKernelContext* op_kernel_context() const { return context_; }

  // Returns the XlaCompiler that is performing the compilation. Used, e.g.,
  // by While to compile nested computations.
  XlaCompiler* compiler() const;

  // TODO(phawkins): find a better home for these helpers.

  // Gets an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMax(const DataType type);

  // Gets an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMin(const DataType type);

  // Gets an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateAdd(const DataType type);

  // Gets an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMul(const DataType type);

 private:
  // Returns the tensor of input `name`.
  const Tensor& GetInputTensorByName(absl::string_view name);

  // Evaluates input `index`, reshapes it to `new_dims` if `new_dims` differs
  // from InputShape(index), and stores the result in `*constant_literal`. If
  // the input cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-OK status. Also returns an error status if
  // InputShape(index).num_elements() does not match the number of elements
  // implied by `new_dims`.
  Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
                               xla::Literal* constant_literal);

  OpKernelContext* const context_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_