/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;

// Implementations of operators that generate XLA code should usually subclass
// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
// an XlaOpKernel produces and consumes symbolic values during compilation.
//
// See the comments in xla_context.h for more details.
class XlaOpKernel : public OpKernel {
 public:
  explicit XlaOpKernel(OpKernelConstruction* construction);

  // Subclasses should implement Compile(), much as standard OpKernels
  // implement Compute().
  virtual void Compile(XlaOpKernelContext* context) = 0;

 private:
  void Compute(OpKernelContext* context) final;
};
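
// Example (illustrative sketch, not part of this header): a hypothetical
// XlaOpKernel subclass that negates its input. The class and op names are
// made up; real kernels are registered via REGISTER_XLA_OP, declared in
// xla_op_registry.h (not included here).
//
//   class HypotheticalNegOp : public XlaOpKernel {
//    public:
//     explicit HypotheticalNegOp(OpKernelConstruction* ctx)
//         : XlaOpKernel(ctx) {}
//     void Compile(XlaOpKernelContext* ctx) override {
//       // Consume a symbolic input and produce a symbolic output.
//       ctx->SetOutput(0, xla::Neg(ctx->Input(0)));
//     }
//   };
//   REGISTER_XLA_OP(Name("HypotheticalNeg"), HypotheticalNegOp);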

// The context passed to the Compile() method of XlaOpKernel. An
// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
// implementing operators that perform symbolic execution as part of the XLA
// compiler. The key difference is that XlaOpKernelContext produces and consumes
// data as XLA computations, rather than as standard Tensors.
//
// Under the hood, symbolic execution communicates using special Tensors that
// wrap XlaExpression objects; however, this is an implementation detail that
// this class hides. The *only* correct way to allocate a Tensor during
// compilation is using the XlaOpKernelContext methods, since they ensure there
// is a valid XlaExpression backing the tensor. No Op should ever call
// allocate_output or allocate_temp directly on the underlying OpKernelContext.
class XlaOpKernelContext {
 public:
  explicit XlaOpKernelContext(OpKernelContext* context);

  XlaContext* xla_context() const;

  // Returns the XLA XlaBuilder containing the output of compilation.
  xla::XlaBuilder* builder() const;

  // Inputs

  // Returns the number of inputs to the operator.
  int num_inputs() const { return context_->num_inputs(); }

  // Returns the type of input `index`.
  DataType input_type(int index) const;

  // Returns the type of input `name`.
  DataType InputType(absl::string_view name);

  // Returns the type of input `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType input_xla_type(int index);

  // Returns the shape of input `index`.
  TensorShape InputShape(int index);

  // Returns the shape of input with name `name`.
  TensorShape InputShape(absl::string_view name);

  // Returns input `index` as an XlaOp. Unlike OpKernelContext::Input, this
  // returns a symbolic value rather than a concrete Tensor.
  xla::XlaOp Input(int index);
  // Returns input `name` as an XlaOp.
  xla::XlaOp Input(absl::string_view name);

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef.  If the named input is not list-valued,
  // returns a one-element list.
  Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                   std::vector<TensorShape>* shapes);

  // Helper methods for constant inputs.

  // Evaluates input `index` and stores it in `*constant_literal`. If the
  // expression cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-OK status.
  Status ConstantInput(int index, xla::Literal* constant_literal);
  Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);

  // Converts a constant scalar int32 or int64 tensor into an int64.
  Status ConstantInputAsIntScalar(int index, int64* out);
  Status ConstantInputAsIntScalar(absl::string_view name, int64* out);

  // Converts a constant scalar float32 or float64 tensor into a float64.
  Status ConstantInputAsFloatScalar(int index, double* out);

  // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
  Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
  Status ConstantInputAsIntVector(absl::string_view name,
                                  std::vector<int64>* out);

  // Reshapes and converts a constant int32 or int64 tensor into a vector of
  // int64s.
  Status ConstantInputReshapedToIntVector(int index, std::vector<int64>* out);
  Status ConstantInputReshapedToIntVector(absl::string_view name,
                                          std::vector<int64>* out);

  // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
  Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
  Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);

  // Converts a constant 1D int32 or int64 tensor into a TensorShape.
  Status ConstantInputAsShape(int index, TensorShape* shape);

  // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1,
  // into a PartialTensorShape.
  Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef.  If the named input is not list-valued,
  // returns a one-element list.
  Status ConstantInputList(absl::string_view name,
                           std::vector<xla::Literal>* literals);
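
  // Example (illustrative sketch): reading a compile-time constant input
  // inside Compile(). This assumes a hypothetical op whose second input is a
  // 1D tensor of axes that must be known at compile time.
  //
  //   void Compile(XlaOpKernelContext* ctx) override {
  //     std::vector<int64> axes;
  //     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes));
  //     // `axes` now holds concrete values usable for shape computations,
  //     // while ctx->Input(0) remains a symbolic xla::XlaOp.
  //   }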

  // Returns an XlaExpression describing the value of 'index'.
  const XlaExpression& InputExpression(int index);
  const XlaExpression& InputExpression(absl::string_view name);

  // Outputs

  int num_outputs() const { return context_->num_outputs(); }
  DataType expected_output_dtype(int index) const {
    return context_->expected_output_dtype(index);
  }

  // Returns the type of output `index` as an xla::PrimitiveType. If the type
  // is not representable as an XLA type, sets an error status and returns
  // xla::PRIMITIVE_TYPE_INVALID.
  xla::PrimitiveType output_xla_type(int index);

  // Sets output `index` to the XlaOp `handle`.
  // All outputs should be set using SetOutput and SetConstantOutput, not
  // via the underlying OpKernelContext.
  void SetOutput(int index, const xla::XlaOp& handle);

  // Sets output `index` to compile-time constant `host_tensor`, where
  // `host_tensor` is a tensor in host memory. It is preferable to use
  // SetConstantOutput where possible.
  void SetConstantOutput(int index, const Tensor& host_tensor);
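
  // Example (illustrative sketch): emitting a compile-time constant output.
  // The value 42 and the output index 0 are arbitrary.
  //
  //   Tensor answer(DT_INT32, TensorShape({}));
  //   answer.scalar<int32>()() = 42;
  //   ctx->SetConstantOutput(0, answer);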

  // Sets output `index` to the XlaExpression `expression`.
  void SetOutputExpression(int index, const XlaExpression& expression);

  // Sets output `index` to the Tensor List `handle`.
  void SetTensorListOutput(int index, const xla::XlaOp& handle);

  // Status handling.
  void SetStatus(const Status& status) { context_->SetStatus(status); }
  Status status() { return context_->status(); }

  // Variables

  // Sets `*resource` to the resource associated with input `index`.
  Status GetResourceInput(int index, XlaResource** resource);

  // Sets output `index` to be a reference to resource `resource`.
  void SetResourceOutput(int index, XlaResource* resource);

  // Sets `*type` and `*shape` to the current type and shape of a variable's
  // value.
  Status GetVariableTypeAndShape(int index, DataType* type,
                                 TensorShape* shape) const;

  // Reads the current value of the resource variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if
  // its type does not match `type`.
  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
                           xla::XlaOp* value);
  // Reads the current value of the resource variable referred to by input
  // `name`.
  Status ReadVariableInput(absl::string_view name, DataType type,
                           TensorShape* shape, xla::XlaOp* value);

  // Assigns the value `handle` to the variable referenced by input
  // `input_index`. The variable must be of `type`. Returns an error if the
  // variable has been initialized with a different type or with a
  // different shape.
  Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
  // Assigns the value `handle` to the variable referenced by input `name`.
  Status AssignVariable(absl::string_view name, DataType type,
                        xla::XlaOp handle);
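
  // Example (illustrative sketch): reading and updating a resource variable
  // inside Compile(). Input 0 is assumed to be the variable resource and
  // input 1 a value to add to it; the indices and DT_FLOAT are arbitrary.
  //
  //   TensorShape var_shape;
  //   xla::XlaOp var_value;
  //   OP_REQUIRES_OK(
  //       ctx, ctx->ReadVariableInput(0, DT_FLOAT, &var_shape, &var_value));
  //   xla::XlaOp updated = xla::Add(var_value, ctx->Input(1));
  //   OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, DT_FLOAT, updated));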

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return context_->call_frame(); }

  FunctionLibraryRuntime* function_library() const {
    return context_->function_library();
  }

  const OpKernel& op_kernel() const { return context_->op_kernel(); }

  // Returns the underlying OpKernelContext. Use rarely.
  OpKernelContext* op_kernel_context() const { return context_; }

  // Returns the XlaCompiler that is performing the compilation. Used, e.g.,
  // by While to compile nested computations.
  XlaCompiler* compiler() const;

  // TODO(phawkins): find a better home for these helpers.

  // Gets an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMax(const DataType type);

  // Gets an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMin(const DataType type);

  // Gets an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateAdd(const DataType type);

  // Gets an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMul(const DataType type);
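
  // Example (illustrative sketch): using the cached Add computation to build
  // a sum reduction over dimension 0. xla::Zero is assumed to come from
  // xla/client/lib/constants.h (not included here); the reduced dimension is
  // arbitrary.
  //
  //   const xla::XlaComputation* add = ctx->GetOrCreateAdd(ctx->input_type(0));
  //   xla::XlaOp sum =
  //       xla::Reduce(ctx->Input(0),
  //                   xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
  //                   *add, {0});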

 private:
  // Returns the tensor of input `name`.
  const Tensor& GetInputTensorByName(absl::string_view name);

  // Evaluates input `index`, reshapes it to `new_dims` if `new_dims` differs
  // from InputShape(index), and stores the result in `*constant_literal`. If
  // the input cannot be evaluated, e.g., because it depends on unbound
  // parameters, returns a non-OK status. Returns an error status if the number
  // of elements in InputShape(index) does not match the number of elements
  // implied by `new_dims`.
  Status ConstantInputReshaped(int index, absl::Span<const int64> new_dims,
                               xla::Literal* constant_literal);

  OpKernelContext* const context_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_