• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17 
18 #include <numeric>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/tf2xla/literal_util.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
25 #include "tensorflow/compiler/tf2xla/xla_context.h"
26 #include "tensorflow/compiler/xla/client/value_inference.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/client/xla_computation.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/core/common_runtime/dma_helper.h"
31 #include "tensorflow/core/platform/errors.h"
32 
33 namespace tensorflow {
34 
XlaOpKernelContext(OpKernelContext * context)35 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
36     : context_(context),
37       dynamic_dimension_is_minus_one_(false),
38       value_inference_(xla_context()->builder()) {}
39 
ValidateInputsAreSameShape(OpKernel * op)40 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
41   return context_->ValidateInputsAreSameShape(op);
42 }
43 
xla_context() const44 XlaContext* XlaOpKernelContext::xla_context() const {
45   return &XlaContext::Get(context_);
46 }
47 
builder() const48 xla::XlaBuilder* XlaOpKernelContext::builder() const {
49   return xla_context()->builder();
50 }
51 
value_inference()52 xla::ValueInference& XlaOpKernelContext::value_inference() {
53   return value_inference_;
54 }
55 
compiler() const56 XlaCompiler* XlaOpKernelContext::compiler() const {
57   return xla_context()->compiler();
58 }
59 
InputExpression(int index)60 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
61   return *XlaExpression::CastExpressionFromTensor(context_->input(index));
62 }
63 
InputExpression(absl::string_view name)64 const XlaExpression& XlaOpKernelContext::InputExpression(
65     absl::string_view name) {
66   return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
67 }
68 
Input(int index)69 xla::XlaOp XlaOpKernelContext::Input(int index) {
70   return InputExpression(index).AsXlaOp(builder());
71 }
72 
Input(absl::string_view name)73 xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
74   return InputExpression(name).AsXlaOp(builder());
75 }
76 
InputShape(int index)77 TensorShape XlaOpKernelContext::InputShape(int index) {
78   return context_->input(index).shape();
79 }
80 
InputShape(absl::string_view name)81 TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
82   return GetInputTensorByName(name).shape();
83 }
84 
InputXlaShape(int index)85 StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
86   return builder()->GetShape(Input(index));
87 }
88 
InputXlaShape(absl::string_view name)89 StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(absl::string_view name) {
90   return builder()->GetShape(Input(name));
91 }
92 
input_type(int index) const93 DataType XlaOpKernelContext::input_type(int index) const {
94   DataType type = context_->input_dtype(index);
95   if (type == DT_UINT8) {
96     // Masqueraded XlaExpression could have different type. See
97     // XlaOpKernelContext::SetOutputExpression for details.
98     auto expression =
99         XlaExpression::CastExpressionFromTensor(context_->input(index));
100     type = expression->dtype();
101   }
102   return type;
103 }
104 
InputType(absl::string_view name)105 DataType XlaOpKernelContext::InputType(absl::string_view name) {
106   const Tensor& tensor = GetInputTensorByName(name);
107   DataType type = tensor.dtype();
108   if (type == DT_UINT8) {
109     // Masqueraded XlaExpression could have different type. See
110     // XlaOpKernelContext::SetOutputExpression for details.
111     auto expression = XlaExpression::CastExpressionFromTensor(tensor);
112     type = expression->dtype();
113   }
114   return type;
115 }
116 
input_xla_type(int index)117 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
118   xla::PrimitiveType type;
119   Status status = DataTypeToPrimitiveType(input_type(index), &type);
120   if (!status.ok()) {
121     SetStatus(status);
122     return xla::PRIMITIVE_TYPE_INVALID;
123   }
124   return type;
125 }
126 
InputXlaType(absl::string_view name)127 xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
128   xla::PrimitiveType type;
129   Status status = DataTypeToPrimitiveType(InputType(name), &type);
130   if (!status.ok()) {
131     SetStatus(status);
132     return xla::PRIMITIVE_TYPE_INVALID;
133   }
134   return type;
135 }
136 
ConstantInput(int index,xla::Literal * constant_literal,xla::ValueInferenceMode mode)137 Status XlaOpKernelContext::ConstantInput(int index,
138                                          xla::Literal* constant_literal,
139                                          xla::ValueInferenceMode mode) {
140   if (this->InputXlaShape(index)->is_dynamic()) {
141     return errors::InvalidArgument(
142         "Reading input as constant from a dynamic tensor is not yet supported. "
143         "Xla shape: ",
144         this->InputXlaShape(index)->ToString());
145   }
146   return ConstantInputReshaped(index,
147                                context_->input(index).shape().dim_sizes(),
148                                constant_literal, mode);
149 }
150 
InputIndex(XlaOpKernelContext * context,absl::string_view name)151 static StatusOr<int> InputIndex(XlaOpKernelContext* context,
152                                 absl::string_view name) {
153   int start, stop;
154   TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
155   if (stop != start + 1) {
156     return errors::InvalidArgument("OpKernel used list-valued input name '",
157                                    name,
158                                    "' when single-valued input was "
159                                    "expected");
160   }
161   return start;
162 }
163 
ResolveInputDynamism(int index,xla::Literal * dynamism_literal)164 Status XlaOpKernelContext::ResolveInputDynamism(
165     int index, xla::Literal* dynamism_literal) {
166   return ResolveInputDynamismReshaped(
167       index, context_->input(index).shape().dim_sizes(), dynamism_literal);
168 }
169 
ResolveInputDynamism(absl::string_view name,xla::Literal * dynamism_literal)170 Status XlaOpKernelContext::ResolveInputDynamism(
171     absl::string_view name, xla::Literal* dynamism_literal) {
172   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
173   return ResolveInputDynamism(index, dynamism_literal);
174 }
175 
ConstantInput(absl::string_view name,xla::Literal * constant_literal,xla::ValueInferenceMode mode)176 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
177                                          xla::Literal* constant_literal,
178                                          xla::ValueInferenceMode mode) {
179   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
180   return ConstantInput(index, constant_literal, mode);
181 }
182 
ConstantInputReshaped(int index,absl::Span<const int64> new_dims,xla::Literal * constant_literal,xla::ValueInferenceMode mode)183 Status XlaOpKernelContext::ConstantInputReshaped(
184     int index, absl::Span<const int64> new_dims, xla::Literal* constant_literal,
185     xla::ValueInferenceMode mode) {
186   XlaExpression e = InputExpression(index);
187   auto* client = compiler() ? compiler()->client() : nullptr;
188   StatusOr<absl::optional<Tensor>> constant_or_status =
189       e.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode);
190   if (!constant_or_status.ok()) {
191     Status status = constant_or_status.status();
192     errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
193                             context_->op_kernel().type_string(),
194                             " operator as a compile-time constant.");
195     return status;
196   }
197   absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
198   if (!constant.has_value()) {
199     return errors::InvalidArgument(
200         "Input ", index, " to node `", context_->op_kernel().name(),
201         "` with op ", context_->op_kernel().type_string(),
202         " must be a compile-time constant.\n\n"
203         "XLA compilation requires that operator arguments that represent "
204         "shapes or dimensions be evaluated to concrete values at compile time. "
205         "This error means that a shape or dimension argument could not be "
206         "evaluated at compile time, usually because the value of the argument "
207         "depends on a parameter to the computation, on a variable, or on a "
208         "stateful operation such as a random number generator.");
209   }
210 
211   Tensor temp(constant->dtype());
212   if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
213     return errors::InvalidArgument(
214         context_->op_kernel().name(), " input ", index, " has shape ",
215         constant->shape().DebugString(),
216         " but was asked to be reshaped to incompatible shape ",
217         TensorShape(new_dims).DebugString());
218   }
219 
220   TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
221   return Status::OK();
222 }
223 
224 // Converts an int32 or int64 scalar literal to an int64.
LiteralToInt64Scalar(const xla::LiteralSlice & literal,int64 * out)225 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
226                                    int64* out) {
227   if (literal.shape().rank() != 0) {
228     return errors::InvalidArgument("value is not a scalar");
229   }
230   if (literal.shape().element_type() == xla::S32) {
231     *out = literal.Get<int32>({});
232   } else if (literal.shape().element_type() == xla::S64) {
233     *out = literal.Get<int64>({});
234   } else {
235     return errors::InvalidArgument("value must be either int32 or int64");
236   }
237   return Status::OK();
238 }
239 
240 // Converts an float32 or float64 scalar literal to a float64.
LiteralToFloat64Scalar(const xla::LiteralSlice & literal,double * out)241 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
242                                      double* out) {
243   if (literal.shape().rank() != 0) {
244     return errors::InvalidArgument("value is not a scalar");
245   }
246   if (literal.shape().element_type() == xla::F32) {
247     *out = literal.Get<float>({});
248   } else if (literal.shape().element_type() == xla::F64) {
249     *out = literal.Get<double>({});
250   } else {
251     return errors::InvalidArgument("value must be either float32 or float64");
252   }
253   return Status::OK();
254 }
255 
ConstantInputAsIntScalar(int index,int64 * out,xla::ValueInferenceMode mode)256 Status XlaOpKernelContext::ConstantInputAsIntScalar(
257     int index, int64* out, xla::ValueInferenceMode mode) {
258   xla::Literal literal;
259   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
260   return LiteralToInt64Scalar(literal, out);
261 }
262 
ConstantInputAsIntScalar(absl::string_view name,int64 * out,xla::ValueInferenceMode mode)263 Status XlaOpKernelContext::ConstantInputAsIntScalar(
264     absl::string_view name, int64* out, xla::ValueInferenceMode mode) {
265   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
266   return ConstantInputAsIntScalar(index, out, mode);
267 }
268 
ConstantInputAsFloatScalar(int index,double * out,xla::ValueInferenceMode mode)269 Status XlaOpKernelContext::ConstantInputAsFloatScalar(
270     int index, double* out, xla::ValueInferenceMode mode) {
271   xla::Literal literal;
272   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
273   return LiteralToFloat64Scalar(literal, out);
274 }
275 
LiteralToPredVector(const xla::LiteralSlice & literal,std::vector<bool> * out)276 static Status LiteralToPredVector(const xla::LiteralSlice& literal,
277                                   std::vector<bool>* out) {
278   if (literal.shape().rank() != 1) {
279     return errors::InvalidArgument("value is not 1D, rank: ",
280                                    literal.shape().rank());
281   }
282   int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
283   if (literal.shape().element_type() != xla::PRED) {
284     return errors::InvalidArgument("value is not PRED");
285   }
286   for (int64_t i = 0; i < size; ++i) {
287     out->push_back(literal.Get<bool>({i}));
288   }
289   return Status::OK();
290 }
291 
ResolveInputDynamismIntoPred(int index,bool * out)292 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
293   xla::Literal literal;
294   XlaExpression e = InputExpression(index);
295   auto* client = compiler() ? compiler()->client() : nullptr;
296   StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
297   if (!dynamism_or_status.ok()) {
298     // When failed to resolve dynamism, conservatively consider the value
299     // dynamic. This could happen if the input depends on some ops like
300     // custom-call that is not supported generally for dynamism computation.
301     //
302     // TODO(b/176993339): Support resolving dynamism across computations so
303     // resolving dynamism will not fail in those cases.
304     *out = true;
305     return Status::OK();
306   }
307   Tensor dynamism = dynamism_or_status.ValueOrDie();
308 
309   Tensor temp(dynamism.dtype());
310   TensorShape tensor_shape({});
311   if (!temp.CopyFrom(dynamism, tensor_shape)) {
312     return errors::InvalidArgument(
313         context_->op_kernel().name(), " input ", index, " has shape ",
314         dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
315   }
316 
317   TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
318   *out = literal.Get<bool>({});
319   return Status::OK();
320 }
321 
ResolveInputDynamismIntoPredVector(absl::string_view name,std::vector<bool> * out)322 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
323     absl::string_view name, std::vector<bool>* out) {
324   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
325   return ResolveInputDynamismIntoPredVector(index, out);
326 }
327 
ResolveInputDynamismIntoPred(absl::string_view name,bool * out)328 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
329                                                         bool* out) {
330   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
331   return ResolveInputDynamismIntoPred(index, out);
332 }
333 
ResolveInputDynamismReshaped(int index,absl::Span<const int64> new_dims,xla::Literal * dynamism_literal)334 Status XlaOpKernelContext::ResolveInputDynamismReshaped(
335     int index, absl::Span<const int64> new_dims,
336     xla::Literal* dynamism_literal) {
337   XlaExpression e = InputExpression(index);
338   auto* client = compiler() ? compiler()->client() : nullptr;
339   StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
340   if (!dynamism_or_status.ok()) {
341     xla::Literal true_literal = xla::LiteralUtil::CreateR0<bool>(true);
342     // When failed to resolve dynamism, conservatively consider the value
343     // dynamic. This could happen if the input depends on some ops like
344     // custom-call that is not supported generally for dynamism computation.
345     *dynamism_literal =
346         true_literal
347             .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {})
348             .ValueOrDie();
349 
350     return Status::OK();
351   }
352   Tensor dynamism = dynamism_or_status.ValueOrDie();
353 
354   Tensor temp(dynamism.dtype());
355   if (!temp.CopyFrom(dynamism, TensorShape(new_dims))) {
356     return errors::InvalidArgument(
357         context_->op_kernel().name(), " input ", index, " has shape ",
358         dynamism.shape().DebugString(),
359         " but was asked to be reshaped to incompatible shape ",
360         TensorShape(new_dims).DebugString());
361   }
362 
363   TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp));
364   return Status::OK();
365 }
366 
ResolveInputDynamismIntoPredVector(int index,std::vector<bool> * out)367 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
368     int index, std::vector<bool>* out) {
369   xla::Literal literal;
370   TF_RETURN_IF_ERROR(ResolveInputDynamismReshaped(
371       index, {InputShape(index).num_elements()}, &literal));
372 
373   return LiteralToPredVector(literal, out);
374 }
375 
376 // Converts an int32 or int64 1D literal to an int64 vector.
LiteralToInt64Vector(const xla::LiteralSlice & literal,std::vector<int64> * out)377 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
378                                    std::vector<int64>* out) {
379   if (literal.shape().rank() != 1) {
380     return errors::InvalidArgument("value is not 1D, rank: ",
381                                    literal.shape().rank());
382   }
383   int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
384   if (literal.shape().element_type() == xla::S32) {
385     for (int64_t i = 0; i < size; ++i) {
386       out->push_back(literal.Get<int32>({i}));
387     }
388   } else if (literal.shape().element_type() == xla::S64) {
389     for (int64_t i = 0; i < size; ++i) {
390       out->push_back(literal.Get<int64>({i}));
391     }
392   } else {
393     return errors::InvalidArgument("value must be either int32 or int64");
394   }
395   return Status::OK();
396 }
397 
ConstantInputAsIntVector(int index,std::vector<int64> * out,xla::ValueInferenceMode mode)398 Status XlaOpKernelContext::ConstantInputAsIntVector(
399     int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
400   xla::Literal literal;
401   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
402   return LiteralToInt64Vector(literal, out);
403 }
404 
ConstantInputAsIntVector(absl::string_view name,std::vector<int64> * out,xla::ValueInferenceMode mode)405 Status XlaOpKernelContext::ConstantInputAsIntVector(
406     absl::string_view name, std::vector<int64>* out,
407     xla::ValueInferenceMode mode) {
408   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
409   return ConstantInputAsIntVector(index, out, mode);
410 }
411 
ConstantInputReshapedToIntVector(int index,std::vector<int64> * out,xla::ValueInferenceMode mode)412 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
413     int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
414   xla::Literal literal;
415   TF_RETURN_IF_ERROR(ConstantInputReshaped(
416       index, {InputShape(index).num_elements()}, &literal, mode));
417   return LiteralToInt64Vector(literal, out);
418 }
419 
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64> * out,xla::ValueInferenceMode mode)420 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
421     absl::string_view name, std::vector<int64>* out,
422     xla::ValueInferenceMode mode) {
423   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
424   xla::Literal literal;
425   TF_RETURN_IF_ERROR(ConstantInputReshaped(
426       index, {InputShape(index).num_elements()}, &literal, mode));
427   return LiteralToInt64Vector(literal, out);
428 }
429 
ConstantInputAsInt64Literal(int index,xla::Literal * out,xla::ValueInferenceMode mode)430 Status XlaOpKernelContext::ConstantInputAsInt64Literal(
431     int index, xla::Literal* out, xla::ValueInferenceMode mode) {
432   xla::Literal literal;
433   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
434   switch (literal.shape().element_type()) {
435     case xla::S32: {
436       *out = xla::Literal(
437           xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
438       auto src_data = literal.data<int32>();
439       for (int64_t i = 0; i < src_data.size(); ++i) {
440         out->data<int64>()[i] = src_data[i];
441       }
442       return Status::OK();
443     }
444     case xla::S64:
445       *out = std::move(literal);
446       return Status::OK();
447 
448     default:
449       return errors::InvalidArgument(
450           "Invalid argument to ConstantInputAsInt64Literal: ",
451           xla::ShapeUtil::HumanString(literal.shape()));
452   }
453 }
454 
ConstantInputAsInt64Literal(absl::string_view name,xla::Literal * out,xla::ValueInferenceMode mode)455 Status XlaOpKernelContext::ConstantInputAsInt64Literal(
456     absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
457   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
458   return ConstantInputAsInt64Literal(index, out, mode);
459 }
460 
461 // TODO(phawkins): validate that the dimensions form a valid shape, fail
462 // gracefully if they do not.
ConstantInputAsShape(int index,TensorShape * shape,xla::ValueInferenceMode mode)463 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
464                                                 xla::ValueInferenceMode mode) {
465   xla::Literal literal;
466   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
467   std::vector<int64> dims;
468   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
469   *shape = TensorShape(dims);
470   return Status::OK();
471 }
472 
ConstantInputAsPartialShape(int index,PartialTensorShape * shape)473 Status XlaOpKernelContext::ConstantInputAsPartialShape(
474     int index, PartialTensorShape* shape) {
475   xla::Literal literal;
476   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
477   // If `literal` is a scalar it's value must be -1.
478   if (literal.shape().rank() == 0) {
479     int64_t shape_val;
480     TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
481     if (shape_val != -1) {
482       return errors::InvalidArgument(
483           "Cannot convert value to PartialTensorShape: ", shape_val);
484     }
485     *shape = PartialTensorShape();  // Shape with unknown rank.
486     return Status::OK();
487   }
488   std::vector<int64> dims;
489   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
490   *shape = PartialTensorShape(dims);
491   return Status::OK();
492 }
493 
InputList(absl::string_view name,std::vector<xla::XlaOp> * handles,std::vector<TensorShape> * shapes)494 Status XlaOpKernelContext::InputList(absl::string_view name,
495                                      std::vector<xla::XlaOp>* handles,
496                                      std::vector<TensorShape>* shapes) {
497   OpInputList inputs;
498   TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
499   handles->clear();
500   shapes->clear();
501   for (const Tensor& input : inputs) {
502     handles->push_back(
503         XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
504     shapes->push_back(input.shape());
505   }
506   return Status::OK();
507 }
508 
ConstantInputList(absl::string_view name,std::vector<xla::Literal> * outputs,xla::ValueInferenceMode mode)509 Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
510                                              std::vector<xla::Literal>* outputs,
511                                              xla::ValueInferenceMode mode) {
512   int start, stop;
513   TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
514   outputs->resize(stop - start);
515   for (int i = start; i < stop; ++i) {
516     TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i], mode));
517   }
518   return Status::OK();
519 }
520 
521 namespace {
522 
ReadVariableInputTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,TensorShape * shape,xla::XlaOp * value)523 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
524                                const XlaOpKernelContext* ctx,
525                                TensorShape* shape, xla::XlaOp* value) {
526   const XlaExpression* expression =
527       XlaExpression::CastExpressionFromTensor(tensor);
528   XlaResource* variable = expression->resource();
529   TF_RET_CHECK(variable != nullptr);
530   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
531   if (!variable->initialized()) {
532     return errors::FailedPrecondition(
533         "Read variable failure ", variable->name(),
534         ". It could mean the variable is uninitialized or the variable is on "
535         "another device ");
536   }
537   if (variable->type() != type) {
538     return errors::InvalidArgument(
539         "Type mismatch for read of variable ", variable->name(), ". Expected ",
540         DataTypeString(type), "; got ", DataTypeString(variable->type()));
541   }
542   if (shape) {
543     *shape = variable->shape();
544   }
545 
546   if (!variable->IsOverwritten() && expression->constant_value()) {
547     TF_ASSIGN_OR_RETURN(xla::Literal literal,
548                         HostTensorToLiteral(*expression->constant_value()));
549     *value = xla::ConstantLiteral(ctx->builder(), literal);
550     return Status::OK();
551   }
552 
553   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
554                       ctx->compiler()->options().shape_representation_fn(
555                           variable->shape(), variable->type(),
556                           /*use_fast_memory=*/false));
557   xla::Shape xla_shape;
558   TF_RETURN_IF_ERROR(
559       TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
560   if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
561     *value = variable->value();
562   } else {
563     *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
564   }
565   return Status::OK();
566 }
567 
568 }  // namespace
569 
ReadVariableInput(int index,DataType type,TensorShape * shape,xla::XlaOp * value)570 Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
571                                              TensorShape* shape,
572                                              xla::XlaOp* value) {
573   return ReadVariableInputTensor(context_->input(index), type, this, shape,
574                                  value);
575 }
576 
ReadVariableInput(absl::string_view name,DataType type,TensorShape * shape,xla::XlaOp * value)577 Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
578                                              DataType type, TensorShape* shape,
579                                              xla::XlaOp* value) {
580   return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
581                                  value);
582 }
583 
GetVariableTypeAndShape(int index,DataType * type,TensorShape * shape) const584 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
585                                                    TensorShape* shape) const {
586   const Tensor& tensor = context_->input(index);
587   const XlaExpression* expression =
588       XlaExpression::CastExpressionFromTensor(tensor);
589   XlaResource* variable = expression->resource();
590   TF_RET_CHECK(variable != nullptr);
591   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
592   if (!variable->initialized()) {
593     return errors::InvalidArgument(
594         "Read variable failure ", variable->name(),
595         ". It could mean the variable is uninitialized or the variable is on "
596         "another device ");
597   }
598   *type = variable->type();
599   *shape = variable->shape();
600   return Status::OK();
601 }
602 
SetOutputExpression(int index,const XlaExpression & expression)603 void XlaOpKernelContext::SetOutputExpression(int index,
604                                              const XlaExpression& expression) {
605   Status status = [&] {
606     // The step's default allocator is the dummy XlaCompilationAllocator which
607     // simply allocates a metadata buffer to hold the expression to which it
608     // corresponds.
609     // Provides a special behavior for DT_VARIANT and other types that are not
610     // trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
611     if (!DataTypeCanUseMemcpy(expression.dtype())) {
612       // tensor_data() is not supported for tensors that cannot be copied via
613       // memcpy, as the copy logic might try to inspect the stored data (e.g.
614       // a std::string). This is likely to fail, as the data is invalid given
615       // that it actually encodes an XlaExpression. Using a uint8 tensor is
616       // always safe, so simply do that.
617       // TODO(jpienaar): This should be refactored to stop masquerading
618       // XlaExpressions as Tensors.
619       Tensor output;
620       TensorShape tensor_shape;
621       TF_RETURN_IF_ERROR(
622           context_->allocate_temp(DT_UINT8, tensor_shape, &output));
623       context_->set_output(index, output);
624     } else {
625       Tensor* output = nullptr;
626       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
627       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
628     }
629     XlaExpression::AssignExpressionToTensor(expression,
630                                             context_->mutable_output(index));
631     return Status::OK();
632   }();
633   if (!status.ok()) {
634     SetStatus(status);
635   }
636 }
637 
output_xla_type(int index)638 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
639   xla::PrimitiveType type;
640   Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
641   if (!status.ok()) {
642     SetStatus(status);
643     return xla::PRIMITIVE_TYPE_INVALID;
644   }
645   return type;
646 }
647 
SetOutput(int index,const xla::XlaOp & handle)648 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
649   SetOutputExpression(
650       index,
651       XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
652 }
653 
SetConstantOutput(int index,const Tensor & constant)654 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
655   SetOutputExpression(index, XlaExpression::Constant(constant));
656 }
657 
SetTensorListOutput(int index,const xla::XlaOp & handle)658 void XlaOpKernelContext::SetTensorListOutput(int index,
659                                              const xla::XlaOp& handle) {
660   SetOutputExpression(index, XlaExpression::TensorList(handle));
661 }
662 
SetResourceOutput(int index,XlaResource * resource)663 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
664   SetOutputExpression(index, XlaExpression::Resource(resource));
665 }
666 
GetResourceInput(int index,XlaResource ** resource)667 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
668   const XlaExpression* expression =
669       XlaExpression::CastExpressionFromTensor(context_->input(index));
670   TF_RET_CHECK(expression->resource() != nullptr);
671   *resource = expression->resource();
672   return Status::OK();
673 }
674 
675 namespace {
676 
AssignVariableTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,xla::XlaOp handle,xla::XlaBuilder * builder)677 Status AssignVariableTensor(const Tensor& tensor, DataType type,
678                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
679                             xla::XlaBuilder* builder) {
680   const XlaExpression* expression =
681       XlaExpression::CastExpressionFromTensor(tensor);
682   XlaResource* variable = expression->resource();
683   TF_RET_CHECK(variable != nullptr);
684   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
685 
686   auto shape_or_status = builder->GetShape(handle);
687   if (!shape_or_status.ok()) {
688     return shape_or_status.status();
689   }
690   TensorShape shape;
691   TF_RETURN_IF_ERROR(
692       XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
693 
694   TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
695 
696   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
697                       ctx->compiler()->options().shape_representation_fn(
698                           shape, type,
699                           /*use_fast_memory=*/false));
700   xla::Shape xla_shape;
701   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
702   if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
703     handle = xla::Reshape(handle,
704                           xla::AsInt64Slice(representation_shape.dimensions()));
705   }
706   variable->SetRepresentationShape(representation_shape);
707   return variable->SetValue(handle);
708 }
709 
710 }  // namespace
711 
AssignVariable(int input_index,DataType type,xla::XlaOp handle)712 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
713                                           xla::XlaOp handle) {
714   TF_RET_CHECK(handle.valid());
715   return AssignVariableTensor(context_->input(input_index), type, this, handle,
716                               builder());
717 }
718 
AssignVariable(absl::string_view name,DataType type,xla::XlaOp handle)719 Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
720                                           xla::XlaOp handle) {
721   TF_RET_CHECK(handle.valid());
722   return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
723                               builder());
724 }
725 
// Returns `s` unchanged unless it is an INVALID_ARGUMENT error. In that case
// it returns a new Status with the same code, whose message is the original
// message followed by a newline and ctx->StackTrace().
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() != error::INVALID_ARGUMENT) {
    return s;
  }
  const std::string annotated_message =
      absl::StrCat(s.error_message(), "\n", ctx->StackTrace());
  return Status(s.code(), annotated_message);
}
734 
CtxFailure(const Status & s)735 void XlaOpKernelContext::CtxFailure(const Status& s) {
736   context_->CtxFailure(GetStatusWithStackTrace(s, this));
737 }
CtxFailureWithWarning(const Status & s)738 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
739   context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
740 }
741 
CtxFailure(const char * file,int line,const Status & s)742 void XlaOpKernelContext::CtxFailure(const char* file, int line,
743                                     const Status& s) {
744   context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
745 }
CtxFailureWithWarning(const char * file,int line,const Status & s)746 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
747                                                const Status& s) {
748   context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
749 }
750 
GetOrCreateMax(const DataType type)751 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
752     const DataType type) {
753   return xla_context()->GetOrCreateMax(type);
754 }
755 
GetOrCreateMin(const DataType type)756 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
757     const DataType type) {
758   return xla_context()->GetOrCreateMin(type);
759 }
760 
GetOrCreateAdd(const DataType type)761 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
762     const DataType type) {
763   return xla_context()->GetOrCreateAdd(type);
764 }
765 
GetOrCreateMul(const DataType type)766 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
767     const DataType type) {
768   return xla_context()->GetOrCreateMul(type);
769 }
770 
GetInputTensorByName(absl::string_view name)771 const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
772   const Tensor* tensor;
773   CHECK(context_->input(name, &tensor).ok());
774   return *tensor;
775 }
776 
XlaOpKernel(OpKernelConstruction * context)777 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
778 
Compute(OpKernelContext * context)779 void XlaOpKernel::Compute(OpKernelContext* context) {
780   XlaOpKernelContext xla_context(context);
781   Compile(&xla_context);
782 }
783 
StackTrace() const784 std::string XlaOpKernelContext::StackTrace() const {
785   if (const AbstractStackTrace* stack_trace =
786           xla_context()->StackTraceForNodeName(op_kernel().name())) {
787     AbstractStackTrace::TracePrintingOptions opts;
788     opts.show_line_contents = true;
789     opts.filter_common_prefix = true;
790     opts.drop_internal_frames = true;
791     return absl::StrCat("\nStack trace for op definition: \n",
792                         stack_trace->ToString(opts), "\n");
793   } else {
794     return "";
795   }
796 }
797 
798 }  // namespace tensorflow
799