/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;

// Implementations of operators that generate XLA code should usually subclass
// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
// an XlaOpKernel produces and consumes symbolic values during compilation.
//
// See the comments in xla_context.h for more details.
class XlaOpKernel : public OpKernel {
 public:
  explicit XlaOpKernel(OpKernelConstruction* construction);

  // Subclasses should implement Compile(), much as standard OpKernels
  // implement Compute().
  virtual void Compile(XlaOpKernelContext* context) = 0;

 private:
  // Bridges the standard OpKernel interface to Compile(). Marked `final` so
  // subclasses can only customize behavior through Compile().
  void Compute(OpKernelContext* context) final;
};

// The context passed to the Compile() method of XlaOpKernel. An
// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
// implementing operators that perform symbolic execution as part of the XLA
// compiler. The key difference is that XlaOpKernelContext produces and
// consumes data as XLA computations, rather than as standard Tensors.
//
// Under the hood, symbolic execution communicates using special Tensors that
// wrap XlaExpression objects, however this is an implementation detail that
// this class hides. The *only* correct way to allocate a Tensor during
// compilation is using the XlaOpKernelContext methods, since they ensure there
// is a valid XlaExpression backing the tensor. No Op should ever call
// allocate_output or allocate_temp directly on the underlying OpKernelContext.
//
// NOTE(review): this header mixes `int64` (TF alias) and `int64_t` in the
// declarations below; presumably they alias the same type during the TF
// int64 -> int64_t migration — left untouched here since they are part of the
// declared interface.
class XlaOpKernelContext {
 public:
  explicit XlaOpKernelContext(OpKernelContext* context);

  // Returns the XlaContext this kernel is being compiled into.
  XlaContext* xla_context() const;

  // Returns the XLA XlaBuilder containing the output of compilation.
  xla::XlaBuilder* builder() const;

  // Returns the value-inference engine used to resolve constant/bound/dynamism
  // queries on inputs (see the ConstantInput* and ResolveInputDynamism*
  // helpers below).
  xla::ValueInference& value_inference();

  // Inputs

  // Returns the number of inputs to the operator.
  int num_inputs() const { return context_->num_inputs(); }

  // Returns the type of input `index`.
  DataType input_type(int index) const;

  // Returns the type of input `name`.
  DataType InputType(absl::string_view name);

  // Returns the type of input `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType input_xla_type(int index);

  // Returns the type of input `name` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType InputXlaType(absl::string_view name);

  // Returns the shape of input at `index` or input the given `name`. Note
  // that in case the shape of the input is not static, then the returned
  // shape has bounds as the dimension size instead of having unknown
  // dimensions. Use InputXlaShape instead that provides shapes with dynamism
  // information.
  ABSL_DEPRECATED(
      "Prefer InputXlaShape which handles dynamic shapes accurately.")
  TensorShape InputShape(int index);
  ABSL_DEPRECATED(
      "Prefer InputXlaShape which handles dynamic shapes accurately.")
  TensorShape InputShape(absl::string_view name);

  // Returns input `index` as a XlaOp. Unlike OpKernelContext::Input, returns
  // a symbolic value rather than a concrete Tensor.
  xla::XlaOp Input(int index);
  // Returns input `name` as a XlaOp.
  xla::XlaOp Input(absl::string_view name);

  // Returns the xla input shape for a given index.
  StatusOr<xla::Shape> InputXlaShape(int index);
  StatusOr<xla::Shape> InputXlaShape(absl::string_view name);

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named output is not list-valued,
  // returns a one-element list.
  Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                   std::vector<TensorShape>* shapes);

  // Evaluates input and returns their dynamism vector in a vector of
  // predicates.
  Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
  Status ResolveInputDynamismIntoPred(int index, bool* out);
  Status ResolveInputDynamismIntoPredVector(absl::string_view name,
                                            std::vector<bool>* out);
  Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out);

  Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal);
  Status ResolveInputDynamism(absl::string_view name,
                              xla::Literal* dynamism_literal);

  Status ResolveInputDynamismReshaped(int index,
                                      absl::Span<const int64> new_dims,
                                      xla::Literal* dynamism_literal);

  // Helper methods for constant inputs.

  // Evaluates input `index` and stores it in `*constant_literal`. If the
  // expression cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-OK status. This function can also be used to
  // infer constant input upper or lower bounds, by changing the `mode`
  // parameter.
  Status ConstantInput(
      int index, xla::Literal* constant_literal,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInput(
      absl::string_view name, xla::Literal* constant_literal,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant scalar int32 or int64 tensor into an int64.
  Status ConstantInputAsIntScalar(
      int index, int64* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputAsIntScalar(
      absl::string_view name, int64* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant scalar float32 or float64 tensor into a float64.
  Status ConstantInputAsFloatScalar(
      int index, double* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
  Status ConstantInputAsIntVector(
      int index, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputAsIntVector(
      absl::string_view name, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Reshapes and converts a constant int32 or int64 tensor into a vector of
  // int64s.
  Status ConstantInputReshapedToIntVector(
      int index, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputReshapedToIntVector(
      absl::string_view name, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
  Status ConstantInputAsInt64Literal(
      int index, xla::Literal* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputAsInt64Literal(
      absl::string_view name, xla::Literal* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant 1D int32 or int64 tensor into a TensorShape.
  Status ConstantInputAsShape(
      int index, TensorShape* shape,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1
  // into a PartialTensorShape.
  Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named output is not list-valued,
  // returns a one-element list.
  Status ConstantInputList(
      absl::string_view name, std::vector<xla::Literal>* outputs,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Returns an XlaExpression describing the value of 'index'.
  const XlaExpression& InputExpression(int index);
  const XlaExpression& InputExpression(absl::string_view name);

  // Outputs

  // Returns the number of outputs of the operator.
  int num_outputs() const { return context_->num_outputs(); }

  // Returns the expected type of output `index`.
  DataType expected_output_dtype(int index) const {
    return context_->expected_output_dtype(index);
  }

  // Returns the type of output `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType output_xla_type(int index);

  // Sets output `index` to the XlaOp `handle`.
  // All outputs should be set using SetOutput and SetConstantOutput, not
  // via the underlying OpKernelContext.
  void SetOutput(int index, const xla::XlaOp& handle);

  // Sets output `index` to compile-time constant `host_tensor`, where
  // `host_tensor` is a tensor in host memory. When the output is known to be
  // a constant, prefer SetConstantOutput over SetOutput where possible.
  void SetConstantOutput(int index, const Tensor& host_tensor);

  // Sets output `index` to the value described by `expression`.
  void SetOutputExpression(int index, const XlaExpression& expression);

  // Sets output `index` to the Tensor List `handle`.
  void SetTensorListOutput(int index, const xla::XlaOp& handle);

  // Status handling.
  void SetStatus(const Status& status) { context_->SetStatus(status); }
  Status status() { return context_->status(); }

  // Variables

  // Sets `*resource` to the resource associated with input `index`.
  Status GetResourceInput(int index, XlaResource** resource);

  // Sets output `index` to be a reference to resource `resource`.
  void SetResourceOutput(int index, XlaResource* resource);

  // Sets `*type` and `*shape` to the current type and shape of a variable's
  // value.
  Status GetVariableTypeAndShape(int index, DataType* type,
                                 TensorShape* shape) const;

  // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension
  // returns "-1", this is useful when the underlying ops expect explicit
  // dynamic index like reshape.
  void set_dynamic_dimension_is_minus_one(bool value) {
    dynamic_dimension_is_minus_one_ = value;
  }

  bool dynamic_dimension_is_minus_one() const {
    return dynamic_dimension_is_minus_one_;
  }

  // Returns true iff `dim_size` is the -1 sentinel used for dynamic
  // dimensions (see set_dynamic_dimension_is_minus_one above).
  bool is_dynamic_dimension(int64_t dim_size) { return dim_size == -1; }

  // Reads the current value of the resource variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or
  // if its type does not match `type`.
  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
                           xla::XlaOp* value);
  // Reads the current value of the resource variable referred to by input
  // `name`.
  Status ReadVariableInput(absl::string_view name, DataType type,
                           TensorShape* shape, xla::XlaOp* value);

  // Assigns the value `handle` to the variable referenced by input
  // `input_index`. The variable must be of `type`. Returns an error if the
  // variable has been initialized with a different type or with a
  // different shape.
  Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
  // Assigns the value `handle` to the variable referenced by input `name`.
  Status AssignVariable(absl::string_view name, DataType type,
                        xla::XlaOp handle);

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return context_->call_frame(); }

  FunctionLibraryRuntime* function_library() const {
    return context_->function_library();
  }

  const OpKernel& op_kernel() const { return context_->op_kernel(); }

  // Returns the underlying OpKernelContext. Use rarely.
  OpKernelContext* op_kernel_context() const { return context_; }

  // Returns the XlaCompiler that is performing the compilation. Used for,
  // e.g., While to compile nested computations.
  XlaCompiler* compiler() const;

  // TODO(phawkins): find a better home for these helpers.

  // Gets an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMax(const DataType type);

  // Gets an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMin(const DataType type);

  // Gets an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateAdd(const DataType type);

  // Gets an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMul(const DataType type);

  // Returns stack trace encoded as a string at a given module, or an empty
  // string if none found.
  std::string StackTrace() const;

 private:
  // Returns the tensor of input `name`.
  const Tensor& GetInputTensorByName(absl::string_view name);

  // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
  // InputShape(index), and stores it in `*constant_literal`. If the input
  // cannot be evaluated, e.g., because it depends on unbound parameters,
  // returns a non-Ok status. If InputShape(index).num_elements() !=
  // new_shape.num_elements(), returns an error status.
  Status ConstantInputReshaped(
      int index, absl::Span<const int64> new_dims,
      xla::Literal* constant_literal,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Underlying TF kernel context; owned by the caller, never null after
  // construction.
  OpKernelContext* const context_;
  // See set_dynamic_dimension_is_minus_one().
  bool dynamic_dimension_is_minus_one_;
  // Backs value_inference(); used by the ConstantInput*/ResolveInputDynamism*
  // helpers.
  xla::ValueInference value_inference_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_