/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;

// Implementations of operators that generate XLA code should usually subclass
// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
// an XlaOpKernel produces and consumes symbolic values during compilation.
//
// See the comments in xla_context.h for more details.
class XlaOpKernel : public OpKernel {
 public:
  explicit XlaOpKernel(OpKernelConstruction* construction);

  // Subclasses should implement Compile(), much as standard OpKernels
  // implement Compute().
  virtual void Compile(XlaOpKernelContext* context) = 0;

 private:
  void Compute(OpKernelContext* context) final;
};
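
// Example (illustrative sketch, not part of the real API surface): a
// hypothetical kernel that negates its input. REGISTER_XLA_OP and xla::Neg
// are assumed to come from xla_op_registry.h and the XLA client builder
// library, respectively.
//
//   class NegOp : public XlaOpKernel {
//    public:
//     explicit NegOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
//     void Compile(XlaOpKernelContext* ctx) override {
//       // Inputs and outputs are symbolic xla::XlaOp values, not Tensors.
//       xla::XlaOp x = ctx->Input(0);
//       ctx->SetOutput(0, xla::Neg(x));
//     }
//   };
//   REGISTER_XLA_OP(Name("Neg"), NegOp);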

// The context passed to the Compile() method of XlaOpKernel. An
// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
// implementing operators that perform symbolic execution as part of the XLA
// compiler. The key difference is that XlaOpKernelContext produces and
// consumes data as XLA computations, rather than as standard Tensors.
//
// Under the hood, symbolic execution communicates using special Tensors that
// wrap XlaExpression objects; however, this is an implementation detail that
// this class hides. The *only* correct way to allocate a Tensor during
// compilation is using the XlaOpKernelContext methods, since they ensure there
// is a valid XlaExpression backing the tensor. No Op should ever call
// allocate_output or allocate_temp directly on the underlying OpKernelContext.
class XlaOpKernelContext {
 public:
  explicit XlaOpKernelContext(OpKernelContext* context);

  XlaContext* xla_context() const;

  // Returns the XLA XlaBuilder containing the output of compilation.
  xla::XlaBuilder* builder() const;

  xla::ValueInference& value_inference();

  // Inputs

  // Returns the number of inputs to the operator.
  int num_inputs() const { return context_->num_inputs(); }

  // Returns the type of input `index`.
  DataType input_type(int index) const;

  // Returns the type of input `name`.
  DataType InputType(absl::string_view name);

  // Returns the type of input `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType input_xla_type(int index);

  // Returns the type of input `name` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType InputXlaType(absl::string_view name);

  // Returns the shape of input `index` or input `name`. Note that if the
  // shape of the input is not static, the returned shape uses the dimension
  // bounds as the dimension sizes instead of reporting unknown dimensions.
  // Prefer InputXlaShape, which provides shapes with dynamism information.
  ABSL_DEPRECATED(
      "Prefer InputXlaShape which handles dynamic shapes accurately.")
  TensorShape InputShape(int index);
  ABSL_DEPRECATED(
      "Prefer InputXlaShape which handles dynamic shapes accurately.")
  TensorShape InputShape(absl::string_view name);

  // Returns input `index` as an XlaOp. Unlike OpKernelContext::Input, this
  // returns a symbolic value rather than a concrete Tensor.
  xla::XlaOp Input(int index);
  // Returns input `name` as an XlaOp.
  xla::XlaOp Input(absl::string_view name);

  // Returns the XLA input shape for a given index or name.
  StatusOr<xla::Shape> InputXlaShape(int index);
  StatusOr<xla::Shape> InputXlaShape(absl::string_view name);

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;

  // Returns the named list-valued immutable input in `handles` and `shapes`,
  // as defined in the OpDef. If the named input is not list-valued, returns a
  // one-element list.
  Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                   std::vector<TensorShape>* shapes);
  // Evaluates the input and returns its dynamism as a vector of predicates,
  // one per element.
  Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
  Status ResolveInputDynamismIntoPred(int index, bool* out);
  Status ResolveInputDynamismIntoPredVector(absl::string_view name,
                                            std::vector<bool>* out);
  Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out);

  Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal);
  Status ResolveInputDynamism(absl::string_view name,
                              xla::Literal* dynamism_literal);

  Status ResolveInputDynamismReshaped(int index,
                                      absl::Span<const int64> new_dims,
                                      xla::Literal* dynamism_literal);
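
  // Example (illustrative sketch): using the dynamism helpers inside a
  // hypothetical Compile() body to pick a lowering. The predicate semantics
  // (true == value not known at compile time) are an assumption here.
  //
  //   bool is_dynamic = false;
  //   OP_REQUIRES_OK(ctx, ctx->ResolveInputDynamismIntoPred(0, &is_dynamic));
  //   if (!is_dynamic) {
  //     // Input 0 can be treated as a compile-time constant; a specialized
  //     // lowering can be emitted here.
  //   }
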
  // Helper methods for constant inputs.

  // Evaluates input `index` and stores it in `*constant_literal`. If the
  // expression cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-OK status. This function can also be used to
  // infer constant input upper or lower bounds, by changing the `mode`
  // parameter.
  Status ConstantInput(
      int index, xla::Literal* constant_literal,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInput(
      absl::string_view name, xla::Literal* constant_literal,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant scalar int32 or int64 tensor into an int64.
  Status ConstantInputAsIntScalar(
      int index, int64* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputAsIntScalar(
      absl::string_view name, int64* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant scalar float32 or float64 tensor into a float64.
  Status ConstantInputAsFloatScalar(
      int index, double* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
  Status ConstantInputAsIntVector(
      int index, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputAsIntVector(
      absl::string_view name, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Reshapes and converts a constant int32 or int64 tensor into a vector of
  // int64s.
  Status ConstantInputReshapedToIntVector(
      int index, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputReshapedToIntVector(
      absl::string_view name, std::vector<int64>* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
  Status ConstantInputAsInt64Literal(
      int index, xla::Literal* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
  Status ConstantInputAsInt64Literal(
      absl::string_view name, xla::Literal* out,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant 1D int32 or int64 tensor into a TensorShape.
  Status ConstantInputAsShape(
      int index, TensorShape* shape,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1,
  // into a PartialTensorShape.
  Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);

  // Returns the named list-valued immutable input in `outputs`, as defined in
  // the OpDef. If the named input is not list-valued, returns a one-element
  // list.
  Status ConstantInputList(
      absl::string_view name, std::vector<xla::Literal>* outputs,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

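  // Example (illustrative sketch): many ops require certain inputs (shapes,
  // axes, permutations) to be compile-time constants. Inside a hypothetical
  // Compile() body, input 1 is assumed to carry a permutation:
  //
  //   std::vector<int64> perm;
  //   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &perm));
  //   // `perm` now holds concrete values; compilation fails with a non-OK
  //   // status if input 1 cannot be resolved to a constant.
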
  // Returns an XlaExpression describing the value of 'index'.
  const XlaExpression& InputExpression(int index);
  const XlaExpression& InputExpression(absl::string_view name);

  // Outputs

  int num_outputs() const { return context_->num_outputs(); }
  DataType expected_output_dtype(int index) const {
    return context_->expected_output_dtype(index);
  }

  // Returns the type of output `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType output_xla_type(int index);

  // Sets output `index` to the XlaOp `handle`.
  // All outputs should be set using SetOutput and SetConstantOutput, not
  // via the underlying OpKernelContext.
  void SetOutput(int index, const xla::XlaOp& handle);

  // Sets output `index` to compile-time constant `host_tensor`, where
  // `host_tensor` is a tensor in host memory. Prefer SetConstantOutput over
  // SetOutput when the value is a compile-time constant.
  void SetConstantOutput(int index, const Tensor& host_tensor);

  // Sets output `index` to the XlaExpression `expression`.
  void SetOutputExpression(int index, const XlaExpression& expression);

  // Sets output `index` to the Tensor List `handle`.
  void SetTensorListOutput(int index, const xla::XlaOp& handle);

  // Status handling.
  void SetStatus(const Status& status) { context_->SetStatus(status); }
  Status status() { return context_->status(); }

  // Variables

  // Sets `*resource` to the resource associated with input `index`.
  Status GetResourceInput(int index, XlaResource** resource);

  // Sets output `index` to be a reference to resource `resource`.
  void SetResourceOutput(int index, XlaResource* resource);

  // Sets `*type` and `*shape` to the current type and shape of a variable's
  // value.
  Status GetVariableTypeAndShape(int index, DataType* type,
                                 TensorShape* shape) const;

  // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension
  // returns -1. This is useful when the underlying op expects an explicit
  // dynamic index, e.g., reshape.
  void set_dynamic_dimension_is_minus_one(bool value) {
    dynamic_dimension_is_minus_one_ = value;
  }

  bool dynamic_dimension_is_minus_one() const {
    return dynamic_dimension_is_minus_one_;
  }

  bool is_dynamic_dimension(int64_t dim_size) { return dim_size == -1; }

  // Reads the current value of the resource variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if
  // its type does not match `type`.
  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
                           xla::XlaOp* value);
  // Reads the current value of the resource variable referred to by input
  // `name`.
  Status ReadVariableInput(absl::string_view name, DataType type,
                           TensorShape* shape, xla::XlaOp* value);

  // Assigns the value `handle` to the variable referenced by input
  // `input_index`. The variable must be of `type`. Returns an error if the
  // variable has been initialized with a different type or with a
  // different shape.
  Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
  // Assigns the value `handle` to the variable referenced by input `name`.
  Status AssignVariable(absl::string_view name, DataType type,
                        xla::XlaOp handle);
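
  // Example (illustrative sketch): a hypothetical Compile() body that
  // increments a DT_FLOAT resource variable passed as input 0. xla::ScalarLike
  // and xla::Add are assumed from the XLA client libraries.
  //
  //   TensorShape shape;
  //   xla::XlaOp value;
  //   OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, DT_FLOAT, &shape, &value));
  //   xla::XlaOp one = xla::ScalarLike(value, 1);
  //   OP_REQUIRES_OK(ctx,
  //                  ctx->AssignVariable(0, DT_FLOAT, xla::Add(value, one)));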

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return context_->call_frame(); }

  FunctionLibraryRuntime* function_library() const {
    return context_->function_library();
  }

  const OpKernel& op_kernel() const { return context_->op_kernel(); }

  // Returns the underlying OpKernelContext. Use rarely.
  OpKernelContext* op_kernel_context() const { return context_; }

  // Returns the XlaCompiler that is performing the compilation. Used, e.g.,
  // by While to compile nested computations.
  XlaCompiler* compiler() const;

  // TODO(phawkins): find a better home for these helpers.

  // Gets an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMax(const DataType type);

  // Gets an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMin(const DataType type);

  // Gets an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateAdd(const DataType type);

  // Gets an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMul(const DataType type);

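  // Example (illustrative sketch): using the cached Add computation to build a
  // sum reduction over all dimensions of a hypothetical xla::XlaOp `x` with
  // element type `dtype`. XlaHelpers::Zero and `all_dims` are assumptions.
  //
  //   const xla::XlaComputation* add = ctx->GetOrCreateAdd(dtype);
  //   xla::XlaOp sum = xla::Reduce(x, XlaHelpers::Zero(ctx->builder(), dtype),
  //                                *add, /*dimensions_to_reduce=*/all_dims);
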
  // Returns the stack trace, encoded as a string, or an empty string if none
  // is found.
  std::string StackTrace() const;

 private:
  // Returns the tensor of input `name`.
  const Tensor& GetInputTensorByName(absl::string_view name);
  // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
  // InputShape(index), and stores it in `*constant_literal`. If the input
  // cannot be evaluated, e.g., because it depends on unbound parameters,
  // returns a non-OK status. If InputShape(index).num_elements() !=
  // new_shape.num_elements(), returns an error status.
  Status ConstantInputReshaped(
      int index, absl::Span<const int64> new_dims,
      xla::Literal* constant_literal,
      xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);

  OpKernelContext* const context_;
  bool dynamic_dimension_is_minus_one_;
  xla::ValueInference value_inference_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_