/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;

// Implementations of operators that generate XLA code should usually subclass
// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
// an XlaOpKernel produces and consumes symbolic values during compilation.
//
// See the comments in xla_context.h for more details.
class XlaOpKernel : public OpKernel {
 public:
  explicit XlaOpKernel(OpKernelConstruction* construction);

  // Subclasses should implement Compile(), much as standard OpKernels implement
  // Compute(). (See the example sketch following this class.)
  virtual void Compile(XlaOpKernelContext* context) = 0;

 private:
  void Compute(OpKernelContext* context) final;
};
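
// Example (a minimal sketch): a hypothetical XlaOpKernel subclass that negates
// its single input. Assumes the REGISTER_XLA_OP macro from xla_op_registry.h
// and xla::Neg from xla_builder.h; the op name "MyNeg" is illustrative only.
//
//   class MyNegOp : public XlaOpKernel {
//    public:
//     explicit MyNegOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
//
//     void Compile(XlaOpKernelContext* ctx) override {
//       // Input(0) and SetOutput() exchange symbolic xla::XlaOp values,
//       // not concrete Tensors.
//       ctx->SetOutput(0, xla::Neg(ctx->Input(0)));
//     }
//   };
//
//   REGISTER_XLA_OP(Name("MyNeg"), MyNegOp);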

// The context passed to the Compile() method of XlaOpKernel. An
// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
// implementing operators that perform symbolic execution as part of the XLA
// compiler. The key difference is that XlaOpKernelContext produces and consumes
// data as XLA computations, rather than as standard Tensors.
//
// Under the hood, symbolic execution communicates using special Tensors that
// wrap XlaExpression objects; however, this is an implementation detail that
// this class hides. The *only* correct way to allocate a Tensor during
// compilation is using the XlaOpKernelContext methods, since they ensure there
// is a valid XlaExpression backing the tensor. No Op should ever call
// allocate_output or allocate_temp directly on the underlying OpKernelContext.
class XlaOpKernelContext {
 public:
  explicit XlaOpKernelContext(OpKernelContext* context);

  XlaContext* xla_context() const;

  // Returns the XLA XlaBuilder containing the output of compilation.
  xla::XlaBuilder* builder() const;

  // Inputs

  // Returns the number of inputs to the operator.
  int num_inputs() const { return context_->num_inputs(); }

  // Returns the type of input `index`.
  DataType input_type(int index) const;

  // Returns the type of input `name`.
  DataType InputType(absl::string_view name);

  // Returns the type of input `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType input_xla_type(int index);

  // Returns the type of input `name` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType InputXlaType(absl::string_view name);

  // Returns the shape of input `index`.
  TensorShape InputShape(int index);

  // Returns the shape of input with name `name`.
  TensorShape InputShape(absl::string_view name);

  // Returns input `index` as an XlaOp. Unlike OpKernelContext::Input, this
  // returns a symbolic value rather than a concrete Tensor.
  xla::XlaOp Input(int index);
  // Returns input `name` as an XlaOp.
  xla::XlaOp Input(absl::string_view name);

  // Returns the xla input shape for a given index.
  xla::StatusOr<xla::Shape> InputXlaShape(int index);
  xla::StatusOr<xla::Shape> InputXlaShape(absl::string_view name);

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;

  // Returns the named list-valued immutable input in `handles` and `shapes`,
  // as defined in the OpDef. If the named input is not list-valued, returns a
  // one-element list.
  Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                   std::vector<TensorShape>* shapes);
  // Evaluates input `index` and returns its dynamism as a vector of
  // predicates.
  Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
  Status ResolveInputDynamismIntoPred(int index, bool* out);
  // Helper methods for constant inputs. (See the usage sketch following this
  // class.)

  // Evaluates input `index` and stores it in `*constant_literal`. If the
  // expression cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-OK status.
  Status ConstantInput(int index, xla::Literal* constant_literal);
  Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);

  // Converts a constant scalar int32 or int64 tensor into an int64.
  Status ConstantInputAsIntScalar(int index, int64* out);
  Status ConstantInputAsIntScalar(absl::string_view name, int64* out);

  // Converts a constant scalar float32 or float64 tensor into a float64.
  Status ConstantInputAsFloatScalar(int index, double* out);

  // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
  Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
  Status ConstantInputAsIntVector(absl::string_view name,
                                  std::vector<int64>* out);

  // Reshapes and converts a constant int32 or int64 tensor into a vector of
  // int64s.
  Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
  Status ConstantInputReshapedToIntVector(absl::string_view name,
                                          std::vector<int64>* out);

  // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
  Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
  Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);

  // Converts a constant 1D int32 or int64 tensor into a TensorShape.
  Status ConstantInputAsShape(int index, TensorShape* shape);

  // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1,
  // into a PartialTensorShape.
  Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);

  // Returns the named list-valued immutable input in `literals`, as defined in
  // the OpDef. If the named input is not list-valued, returns a one-element
  // list.
  Status ConstantInputList(absl::string_view name,
                           std::vector<xla::Literal>* literals);

  // Returns an XlaExpression describing the value of 'index'.
  const XlaExpression& InputExpression(int index);
  const XlaExpression& InputExpression(absl::string_view name);

  // Outputs

  int num_outputs() const { return context_->num_outputs(); }
  DataType expected_output_dtype(int index) const {
    return context_->expected_output_dtype(index);
  }

  // Returns the type of output `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType output_xla_type(int index);

  // Sets output `index` to the XlaOp `handle`.
  // All outputs should be set using SetOutput and SetConstantOutput, not
  // via the underlying OpKernelContext.
  void SetOutput(int index, const xla::XlaOp& handle);

  // Sets output `index` to compile-time constant `host_tensor`, where
  // `host_tensor` is a tensor in host memory. It is preferable to use
  // SetConstantOutput rather than SetOutput where possible.
  void SetConstantOutput(int index, const Tensor& host_tensor);

  // Sets output `index` to the XlaExpression `expression`.
  void SetOutputExpression(int index, const XlaExpression& expression);

  // Sets output `index` to the Tensor List `handle`.
  void SetTensorListOutput(int index, const xla::XlaOp& handle);

  // Status handling.
  void SetStatus(const Status& status) { context_->SetStatus(status); }
  Status status() { return context_->status(); }

  // Variables

  // Sets `*resource` to the resource associated with input `index`.
  Status GetResourceInput(int index, XlaResource** resource);

  // Sets output `index` to be a reference to resource `resource`.
  void SetResourceOutput(int index, XlaResource* resource);

  // Sets `*type` and `*shape` to the current type and shape of a variable's
  // value.
  Status GetVariableTypeAndShape(int index, DataType* type,
                                 TensorShape* shape) const;

  // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension
  // returns "-1". This is useful when the underlying op (e.g., reshape)
  // expects an explicit dynamic index.
  void set_dynamic_dimension_is_minus_one(bool value) {
    dynamic_dimension_is_minus_one_ = value;
  }

  bool dynamic_dimension_is_minus_one() const {
    return dynamic_dimension_is_minus_one_;
  }

  bool is_dynamic_dimension(int64 dim_size) { return dim_size == -1; }

  // Reads the current value of the resource variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if
  // its type does not match `type`. (See the usage sketch following this
  // class.)
  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
                           xla::XlaOp* value);
  // Reads the current value of the resource variable referred to by input
  // `name`.
  Status ReadVariableInput(absl::string_view name, DataType type,
                           TensorShape* shape, xla::XlaOp* value);

  // Assigns the value `handle` to the variable referenced by input
  // `input_index`. The variable must be of `type`. Returns an error if the
  // variable has been initialized with a different type or with a
  // different shape.
  Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
  // Assigns the value `handle` to the variable referenced by input `name`.
  Status AssignVariable(absl::string_view name, DataType type,
                        xla::XlaOp handle);

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return context_->call_frame(); }

  FunctionLibraryRuntime* function_library() const {
    return context_->function_library();
  }

  const OpKernel& op_kernel() const { return context_->op_kernel(); }

  // Returns the underlying OpKernelContext. Use rarely.
  OpKernelContext* op_kernel_context() const { return context_; }

  // Returns the XlaCompiler that is performing the compilation. Used, e.g., by
  // While to compile nested computations.
  XlaCompiler* compiler() const;

  // TODO(phawkins): find a better home for these helpers.

  // Gets an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMax(const DataType type);

  // Gets an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMin(const DataType type);

  // Gets an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateAdd(const DataType type);

  // Gets an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMul(const DataType type);

  // Returns the stack trace for this op, encoded as a string, or an empty
  // string if none is found.
  std::string StackTrace() const;

 private:
  // Returns the tensor of input `name`.
  const Tensor& GetInputTensorByName(absl::string_view name);

  // Evaluates input `index`, reshapes it to `new_dims` if new_dims !=
  // InputShape(index), and stores it in `*constant_literal`. If the input
  // cannot be evaluated, e.g., because it depends on unbound parameters,
  // returns a non-OK status. If InputShape(index).num_elements() != the number
  // of elements implied by `new_dims`, returns an error status.
  Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
                               xla::Literal* constant_literal);

  OpKernelContext* const context_;
  bool dynamic_dimension_is_minus_one_;
};
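
// Example (a minimal sketch): how a hypothetical reduction op might use the
// constant-input helpers and the cached Add computation inside Compile().
// Assumes XlaHelpers::Zero from xla_helpers.h, xla::Reduce from xla_builder.h,
// and REGISTER_XLA_OP / CompileTimeConstantInput from xla_op_registry.h; the
// op and input names ("MySum", "reduction_dims") are illustrative only.
//
//   class MySumOp : public XlaOpKernel {
//    public:
//     explicit MySumOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
//
//     void Compile(XlaOpKernelContext* ctx) override {
//       // Input 1 must be a compile-time constant; fold it into a host-side
//       // vector of dimension indices.
//       std::vector<int64> dims;
//       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &dims));
//
//       const DataType type = ctx->input_type(0);
//       xla::XlaOp init = XlaHelpers::Zero(ctx->builder(), type);
//       ctx->SetOutput(0, xla::Reduce(ctx->Input(0), init,
//                                     *ctx->GetOrCreateAdd(type), dims));
//     }
//   };
//
//   // The constant input must be declared at registration time so the
//   // compiler knows to resolve it:
//   REGISTER_XLA_OP(Name("MySum").CompileTimeConstantInput("reduction_dims"),
//                   MySumOp);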
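
// Example (a minimal sketch): a hypothetical op that adds a value to a
// resource variable using ReadVariableInput() and AssignVariable(). Assumes
// xla::Add from xla_builder.h; the op name is illustrative only.
//
//   class MyAddToVariableOp : public XlaOpKernel {
//    public:
//     explicit MyAddToVariableOp(OpKernelConstruction* ctx)
//         : XlaOpKernel(ctx) {}
//
//     void Compile(XlaOpKernelContext* ctx) override {
//       const DataType type = ctx->input_type(1);
//       TensorShape shape;
//       xla::XlaOp value;
//       // Input 0 is the variable resource; read its current symbolic value.
//       OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &shape, &value));
//       // Write the updated symbolic value back to the variable.
//       OP_REQUIRES_OK(
//           ctx, ctx->AssignVariable(0, type, xla::Add(value, ctx->Input(1))));
//     }
//   };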

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_