1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ 18 19 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 20 #include "tensorflow/compiler/tf2xla/xla_context.h" 21 #include "tensorflow/compiler/tf2xla/xla_expression.h" 22 #include "tensorflow/compiler/tf2xla/xla_resource.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/compiler/xla/client/xla_computation.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/platform/macros.h" 28 29 namespace tensorflow { 30 31 class XlaOpKernelContext; 32 33 // Implementations of operators that generate XLA code should usually subclass 34 // XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel, 35 // an XlaOpKernel produces and consumes symbolic values during compilation. 36 // 37 // See the comments in xla_context.h for more details. 38 class XlaOpKernel : public OpKernel { 39 public: 40 explicit XlaOpKernel(OpKernelConstruction* construction); 41 42 // Subclasses should implement Compile(), much as standard OpKernels implement 43 // Compute(). 44 virtual void Compile(XlaOpKernelContext* context) = 0; 45 46 private: 47 void Compute(OpKernelContext* context) final; 48 }; 49 50 // The context passed to the Compile() method of XlaOpKernel. An 51 // XlaOpKernelContext is a variant of the standard OpKernel class, tailored for 52 // implementing operators that perform symbolic execution as part of the XLA 53 // compiler. The key difference is that XlaOpKernelContext produces and consumes 54 // data as XLA computations, rather than as standard Tensors. 55 // 56 // Under the hood, symbolic execution communicates using special Tensors that 57 // wrap XlaExpression objects, however this is an implementation detail that 58 // this class hides. The *only* correct way to allocate a Tensor during 59 // compilation is using the XlaOpKernelContext methods, since they ensure there 60 // is a valid XlaExpression backing the tensor. No Op should ever call 61 // allocate_output or allocate_temp directly on the underlying OpKernelContext. 62 class XlaOpKernelContext { 63 public: 64 explicit XlaOpKernelContext(OpKernelContext* context); 65 66 XlaContext* xla_context() const; 67 68 // Returns the XLA XlaBuilder containing the output of compilation. 69 xla::XlaBuilder* builder() const; 70 71 // Inputs 72 73 // Returns the number of inputs to the operator. num_inputs()74 int num_inputs() const { return context_->num_inputs(); } 75 76 // Returns the type of input `index`. 77 DataType input_type(int index) const; 78 79 // Returns the type of input `name`. 80 DataType InputType(absl::string_view name); 81 82 // Returns the type of input `index` as an xla::PrimitiveType. If the type 83 // is not representable as an XLA type, sets an error status and returns 84 // xla::PRIMITIVE_TYPE_INVALID. 85 xla::PrimitiveType input_xla_type(int index); 86 87 // Returns the type of input `name` as an xla::PrimitiveType. If the type 88 // is not representable as an XLA type, sets an error status and returns 89 // xla::PRIMITIVE_TYPE_INVALID. 90 xla::PrimitiveType InputXlaType(absl::string_view name); 91 92 // Returns the shape of input `index`. 93 TensorShape InputShape(int index); 94 95 // Returns the shape of input with name `name`. 96 TensorShape InputShape(absl::string_view name); 97 98 // Returns input `index` as a XlaOp. Unlike 99 // OpKernelContext::Input returns a symbolic value rather than a concrete 100 // Tensor. 101 xla::XlaOp Input(int index); 102 // Returns input `name` as a XlaOp. 103 xla::XlaOp Input(absl::string_view name); 104 105 // Returns the xla input shape for a given index. 106 xla::StatusOr<xla::Shape> InputXlaShape(int index); 107 xla::StatusOr<xla::Shape> InputXlaShape(absl::string_view name); 108 109 // Returns true if all inputs are the same shape, otherwise sets the 110 // status to a non-OK value and returns false. 111 // Usage: if (!context->ValidateInputsAreSameShape(this)) return; 112 bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT; 113 114 // Returns the named list-valued immutable input in "list", as 115 // defined in the OpDef. If the named output is not list-valued, 116 // returns a one-element list. 117 Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles, 118 std::vector<TensorShape>* shapes); 119 // Evaluates input and returns their dynamism vector in a vector of 120 // predicates. 121 Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out); 122 Status ResolveInputDynamismIntoPred(int index, bool* out); 123 // Helper methods for constant inputs. 124 125 // Evaluates input `index` and stores it in `*constant_literal`. If the 126 // expression cannot be evaluated, e.g., because it depends on unbound 127 // parameters, returns a non-OK status. 128 Status ConstantInput(int index, xla::Literal* constant_literal); 129 Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); 130 131 // Converts a constant scalar int32 or int64 tensor into an int64. 132 Status ConstantInputAsIntScalar(int index, int64* out); 133 Status ConstantInputAsIntScalar(absl::string_view name, int64* out); 134 135 // Converts a constant scalar float32 or float64 tensor into a float64. 136 Status ConstantInputAsFloatScalar(int index, double* out); 137 138 // Converts a constant 1D int32 or int64 tensor into a vector of int64s. 139 Status ConstantInputAsIntVector(int index, std::vector<int64>* out); 140 Status ConstantInputAsIntVector(absl::string_view name, 141 std::vector<int64>* out); 142 143 // Reshapes and converts a constant int32 or int64 tensor into a vector of 144 // int64s. 145 Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out); 146 Status ConstantInputReshapedToIntVector(absl::string_view name, 147 std::vector<int64>* out); 148 149 // Converts a constant int32 or int64 Tensor into an xla int64 Literal. 150 Status ConstantInputAsInt64Literal(int index, xla::Literal* out); 151 Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out); 152 153 // Converts a constant 1D int32 or int64 tensor into a TensorShape. 154 Status ConstantInputAsShape(int index, TensorShape* shape); 155 156 // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 157 // into a PartialTensorShape. 158 Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); 159 160 // Returns the named list-valued immutable input in "list", as 161 // defined in the OpDef. If the named output is not list-valued, 162 // returns a one-element list. 163 Status ConstantInputList(absl::string_view name, 164 std::vector<xla::Literal>* literals); 165 166 // Returns an XlaExpression describing the value of 'index'. 167 const XlaExpression& InputExpression(int index); 168 const XlaExpression& InputExpression(absl::string_view name); 169 170 // Outputs 171 num_outputs()172 int num_outputs() const { return context_->num_outputs(); } expected_output_dtype(int index)173 DataType expected_output_dtype(int index) const { 174 return context_->expected_output_dtype(index); 175 } 176 177 // Returns the type of output `index` as an xla::PrimitiveType. If the type 178 // is not representable as an XLA type, sets an error status and returns 179 // xla::PRIMITIVE_TYPE_INVALID. 180 xla::PrimitiveType output_xla_type(int index); 181 182 // Sets output `index` to the XlaOp `handle`. 183 // All outputs should be set using SetOutput and SetConstantOutput, not 184 // via the underlying OpKernelContext. 185 void SetOutput(int index, const xla::XlaOp& handle); 186 187 // Sets output `index` to compile-time constant `host_tensor`, where 188 // `host_tensor` is a tensor in host memory. It is preferable to use 189 // SetConstantOutput where possible. 190 void SetConstantOutput(int index, const Tensor& host_tensor); 191 192 // Returns an XlaExpression describing the value of 'index'. 193 void SetOutputExpression(int index, const XlaExpression& expression); 194 195 // Sets output `index` to the Tensor List `handle`. 196 void SetTensorListOutput(int index, const xla::XlaOp& handle); 197 198 // Status handling. SetStatus(const Status & status)199 void SetStatus(const Status& status) { context_->SetStatus(status); } status()200 Status status() { return context_->status(); } 201 202 // Variables 203 204 // Sets `*resource` to the resource associated with input `index`. 205 Status GetResourceInput(int index, XlaResource** resource); 206 207 // Sets output `index` to be a reference to resource `resource`. 208 void SetResourceOutput(int index, XlaResource* resource); 209 210 // Sets `*type` and `*shape` to the current type and shape of a variable's 211 // value. 212 Status GetVariableTypeAndShape(int index, DataType* type, 213 TensorShape* shape) const; 214 215 // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension 216 // returns "-1", this is useful when the underlying ops expect explicit 217 // dynamic index like reshape. set_dynamic_dimension_is_minus_one(bool value)218 void set_dynamic_dimension_is_minus_one(bool value) { 219 dynamic_dimension_is_minus_one_ = value; 220 } 221 dynamic_dimension_is_minus_one()222 bool dynamic_dimension_is_minus_one() const { 223 return dynamic_dimension_is_minus_one_; 224 } 225 is_dynamic_dimension(int64 dim_size)226 bool is_dynamic_dimension(int64 dim_size) { return dim_size == -1; } 227 228 // Reads the current value of the resource variable referred to by input 229 // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the 230 // variable. Returns an error if the variable has not been initialized, or if 231 // its type does not match `type`. 232 Status ReadVariableInput(int index, DataType type, TensorShape* shape, 233 xla::XlaOp* value); 234 // Reads the current value of the resource variable referred to by input 235 // `name`. 236 Status ReadVariableInput(absl::string_view name, DataType type, 237 TensorShape* shape, xla::XlaOp* value); 238 239 // Assigns the value `handle` to the variable referenced by input 240 // `input_index`. The variable must be of `type`. Returns an error if the 241 // variable has been initialized with a different type or with a 242 // different shape. 243 Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); 244 // Assigns the value `handle` to the variable referenced by input `name`. 245 Status AssignVariable(absl::string_view name, DataType type, 246 xla::XlaOp handle); 247 248 // Helper routines for the OP_REQUIRES macros 249 void CtxFailure(const Status& s); 250 void CtxFailureWithWarning(const Status& s); 251 void CtxFailure(const char* file, int line, const Status& s); 252 void CtxFailureWithWarning(const char* file, int line, const Status& s); 253 254 // If this kernel invocation is within a function execution, 255 // call_frame() returns the call frame for the function call. call_frame()256 CallFrameInterface* call_frame() const { return context_->call_frame(); } 257 function_library()258 FunctionLibraryRuntime* function_library() const { 259 return context_->function_library(); 260 } 261 op_kernel()262 const OpKernel& op_kernel() const { return context_->op_kernel(); } 263 264 // Returns the underlying OpKernelContext. Use rarely. op_kernel_context()265 OpKernelContext* op_kernel_context() const { return context_; } 266 267 // Returns the XlaCompiler that is performing the compilation. Used for, e.g., 268 // While to compile nested computations. 269 XlaCompiler* compiler() const; 270 271 // TODO(phawkins): find a better home for these helpers. 272 273 // Gets an XLA lambda to compute Max. This is cached in the 274 // XlaContext since it may be used by multiple Ops. There is a 275 // separate specialization of the computation for each DataType. 276 const xla::XlaComputation* GetOrCreateMax(const DataType type); 277 278 // Gets an XLA lambda to compute Min. This is cached in the 279 // XlaContext since it may be used by multiple Ops. There is a 280 // separate specialization of the computation for each DataType. 281 const xla::XlaComputation* GetOrCreateMin(const DataType type); 282 283 // Gets an XLA lambda to compute Add. This is cached in the 284 // XlaContext since it may be used by multiple Ops. There is a 285 // separate specialization of the computation for each DataType. 286 const xla::XlaComputation* GetOrCreateAdd(const DataType type); 287 288 // Gets an XLA lambda to compute Mul. This is cached in the 289 // XlaContext since it may be used by multiple Ops. There is a 290 // separate specialization of the computation for each DataType. 291 const xla::XlaComputation* GetOrCreateMul(const DataType type); 292 293 // Returns stack trace encoded as a string at a given module, or an empty 294 // string if none found. 295 std::string StackTrace() const; 296 297 private: 298 // Returns the tensor of input `name`. 299 const Tensor& GetInputTensorByName(absl::string_view name); 300 301 // Evaluates input `index`, reshapes it to `new_shape` if new_shape != 302 // InputShape(index), and stores it in `*constant_literal`. If the input 303 // cannot be evaluated, e.g., because it depends on unbound parameters, 304 // returns a non-Ok status. If InputShape(index).num_elements() != 305 // new_shape.num_elements(), returns an error status. 306 Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims, 307 xla::Literal* constant_literal); 308 309 OpKernelContext* const context_; 310 bool dynamic_dimension_is_minus_one_; 311 }; 312 313 } // namespace tensorflow 314 315 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ 316