/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17 
18 #include <numeric>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/tf2xla/literal_util.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
25 #include "tensorflow/compiler/tf2xla/xla_context.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/xla/client/value_inference.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/core/common_runtime/dma_helper.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/util/overflow.h"
34 
35 namespace tensorflow {
36 
XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
    : context_(context),
      dynamic_dimension_is_minus_one_(false),
      value_inference_(xla_context()->builder()) {}

bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  return context_->ValidateInputsAreSameShape(op);
}

XlaContext* XlaOpKernelContext::xla_context() const {
  return &XlaContext::Get(context_);
}

xla::XlaBuilder* XlaOpKernelContext::builder() const {
  return xla_context()->builder();
}

xla::ValueInference& XlaOpKernelContext::value_inference() {
  return value_inference_;
}

XlaCompiler* XlaOpKernelContext::compiler() const {
  return xla_context()->compiler();
}

const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
  return *XlaExpression::CastExpressionFromTensor(context_->input(index));
}

const XlaExpression& XlaOpKernelContext::InputExpression(
    absl::string_view name) {
  return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
}

xla::XlaOp XlaOpKernelContext::Input(int index) {
  return InputExpression(index).AsXlaOp(builder());
}

xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
  return InputExpression(name).AsXlaOp(builder());
}

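// Example usage (an illustrative sketch, not part of this file): inside a
// hypothetical XlaOpKernel::Compile, operands are fetched as symbolic XLA
// handles rather than concrete tensors; the kernel and operand index below
// are assumptions for the example.
//
//   void MyOp::Compile(XlaOpKernelContext* ctx) {
//     xla::XlaOp x = ctx->Input(0);    // operand 0 as an XlaOp
//     ctx->SetOutput(0, xla::Neg(x));  // bind an XLA op to output 0
//   }
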
TensorShape XlaOpKernelContext::InputShape(int index) {
  return context_->input(index).shape();
}

TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
  return GetInputTensorByName(name).shape();
}

StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
  return builder()->GetShape(Input(index));
}

StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(absl::string_view name) {
  return builder()->GetShape(Input(name));
}

DataType XlaOpKernelContext::input_type(int index) const {
  DataType type = context_->input_dtype(index);
  if (type == DT_UINT8) {
    // A masqueraded XlaExpression may have a different type. See
    // XlaOpKernelContext::SetOutputExpression for details.
    auto expression =
        XlaExpression::CastExpressionFromTensor(context_->input(index));
    type = expression->dtype();
  }
  return type;
}

DataType XlaOpKernelContext::InputType(absl::string_view name) {
  const Tensor& tensor = GetInputTensorByName(name);
  DataType type = tensor.dtype();
  if (type == DT_UINT8) {
    // A masqueraded XlaExpression may have a different type. See
    // XlaOpKernelContext::SetOutputExpression for details.
    auto expression = XlaExpression::CastExpressionFromTensor(tensor);
    type = expression->dtype();
  }
  return type;
}

xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
  xla::PrimitiveType type;
  Status status = DataTypeToPrimitiveType(input_type(index), &type);
  if (!status.ok()) {
    SetStatus(status);
    return xla::PRIMITIVE_TYPE_INVALID;
  }
  return type;
}

xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
  xla::PrimitiveType type;
  Status status = DataTypeToPrimitiveType(InputType(name), &type);
  if (!status.ok()) {
    SetStatus(status);
    return xla::PRIMITIVE_TYPE_INVALID;
  }
  return type;
}

Status XlaOpKernelContext::ConstantInput(int index,
                                         xla::Literal* constant_literal,
                                         xla::ValueInferenceMode mode) {
  if (this->InputXlaShape(index)->is_dynamic()) {
    return errors::InvalidArgument(
        "Reading input as constant from a dynamic tensor is not yet supported. "
        "Xla shape: ",
        this->InputXlaShape(index)->ToString());
  }
  return ConstantInputReshaped(index,
                               context_->input(index).shape().dim_sizes(),
                               constant_literal, mode);
}

static StatusOr<int> InputIndex(XlaOpKernelContext* context,
                                absl::string_view name) {
  int start, stop;
  TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  return start;
}

Status XlaOpKernelContext::ResolveInputDynamism(
    int index, xla::Literal* dynamism_literal) {
  return ResolveInputDynamismReshaped(
      index, context_->input(index).shape().dim_sizes(), dynamism_literal);
}

Status XlaOpKernelContext::ResolveInputDynamism(
    absl::string_view name, xla::Literal* dynamism_literal) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ResolveInputDynamism(index, dynamism_literal);
}

Status XlaOpKernelContext::ConstantInput(absl::string_view name,
                                         xla::Literal* constant_literal,
                                         xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInput(index, constant_literal, mode);
}

Status XlaOpKernelContext::ConstantInputReshaped(
    int index, absl::Span<const int64_t> new_dims,
    xla::Literal* constant_literal, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(Tensor constant, ConstantInputTensor(index, mode));
  Tensor temp(constant.dtype());
  if (!temp.CopyFrom(constant, TensorShape(new_dims))) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        constant.shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        TensorShape(new_dims).DebugString());
  }

  TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
  return OkStatus();
}

// Converts an int32 or int64 scalar literal to an int64.
static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
                                   int64_t* out) {
  if (literal.shape().rank() != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  if (literal.shape().element_type() == xla::S32) {
    *out = literal.Get<int32>({});
  } else if (literal.shape().element_type() == xla::S64) {
    *out = literal.Get<int64_t>({});
  } else {
    return errors::InvalidArgument("value must be either int32 or int64");
  }
  return OkStatus();
}

// Converts a float32 or float64 scalar literal to a float64.
static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
                                     double* out) {
  if (literal.shape().rank() != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  if (literal.shape().element_type() == xla::F32) {
    *out = literal.Get<float>({});
  } else if (literal.shape().element_type() == xla::F64) {
    *out = literal.Get<double>({});
  } else {
    return errors::InvalidArgument("value must be either float32 or float64");
  }
  return OkStatus();
}

Status XlaOpKernelContext::ConstantInputAsIntScalar(
    int index, int64_t* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  return LiteralToInt64Scalar(literal, out);
}

Status XlaOpKernelContext::ConstantInputAsIntScalar(
    absl::string_view name, int64_t* out, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsIntScalar(index, out, mode);
}

246 
ConstantInputAsFloatScalar(int index,double * out,xla::ValueInferenceMode mode)247 Status XlaOpKernelContext::ConstantInputAsFloatScalar(
248     int index, double* out, xla::ValueInferenceMode mode) {
249   xla::Literal literal;
250   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
251   return LiteralToFloat64Scalar(literal, out);
252 }
253 
LiteralToPredVector(const xla::LiteralSlice & literal,std::vector<bool> * out)254 static Status LiteralToPredVector(const xla::LiteralSlice& literal,
255                                   std::vector<bool>* out) {
256   if (literal.shape().rank() != 1) {
257     return errors::InvalidArgument("output_shape must be rank 1, got shape ",
258                                    literal.shape().DebugString());
259   }
260   int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
261   if (literal.shape().element_type() != xla::PRED) {
262     return errors::InvalidArgument("value is not PRED");
263   }
264   for (int64_t i = 0; i < size; ++i) {
265     out->push_back(literal.Get<bool>({i}));
266   }
267   return OkStatus();
268 }
269 
ResolveInputDynamismIntoPred(int index,bool * out)270 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
271   xla::Literal literal;
272   XlaExpression e = InputExpression(index);
273   auto* client = compiler() ? compiler()->client() : nullptr;
274   StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
275   if (!dynamism_or_status.ok()) {
276     // When failed to resolve dynamism, conservatively consider the value
277     // dynamic. This could happen if the input depends on some ops like
278     // custom-call that is not supported generally for dynamism computation.
279     //
280     // TODO(b/176993339): Support resolving dynamism across computations so
281     // resolving dynamism will not fail in those cases.
282     *out = true;
283     return OkStatus();
284   }
285   Tensor dynamism = dynamism_or_status.ValueOrDie();
286 
287   Tensor temp(dynamism.dtype());
288   TensorShape tensor_shape({});
289   if (!temp.CopyFrom(dynamism, tensor_shape)) {
290     return errors::InvalidArgument(
291         context_->op_kernel().name(), " input ", index, " has shape ",
292         dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
293   }
294 
295   TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
296   *out = literal.Get<bool>({});
297   return OkStatus();
298 }
299 
ResolveInputDynamismIntoPredVector(absl::string_view name,std::vector<bool> * out)300 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
301     absl::string_view name, std::vector<bool>* out) {
302   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
303   return ResolveInputDynamismIntoPredVector(index, out);
304 }
305 
ResolveInputDynamismIntoPred(absl::string_view name,bool * out)306 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
307                                                         bool* out) {
308   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
309   return ResolveInputDynamismIntoPred(index, out);
310 }
311 
ResolveInputDynamismReshaped(int index,absl::Span<const int64_t> new_dims,xla::Literal * dynamism_literal)312 Status XlaOpKernelContext::ResolveInputDynamismReshaped(
313     int index, absl::Span<const int64_t> new_dims,
314     xla::Literal* dynamism_literal) {
315   XlaExpression e = InputExpression(index);
316   auto* client = compiler() ? compiler()->client() : nullptr;
317   StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
318   if (!dynamism_or_status.ok()) {
319     xla::Literal true_literal = xla::LiteralUtil::CreateR0<bool>(true);
320     // When failed to resolve dynamism, conservatively consider the value
321     // dynamic. This could happen if the input depends on some ops like
322     // custom-call that is not supported generally for dynamism computation.
323     *dynamism_literal =
324         true_literal
325             .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {})
326             .ValueOrDie();
327 
328     return OkStatus();
329   }
330   Tensor dynamism = dynamism_or_status.ValueOrDie();
331 
332   Tensor temp(dynamism.dtype());
333   if (!temp.CopyFrom(dynamism, TensorShape(new_dims))) {
334     return errors::InvalidArgument(
335         context_->op_kernel().name(), " input ", index, " has shape ",
336         dynamism.shape().DebugString(),
337         " but was asked to be reshaped to incompatible shape ",
338         TensorShape(new_dims).DebugString());
339   }
340 
341   TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp));
342   return OkStatus();
343 }
344 
ResolveInputDynamismIntoPredVector(int index,std::vector<bool> * out)345 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
346     int index, std::vector<bool>* out) {
347   xla::Literal literal;
348   TF_RETURN_IF_ERROR(ResolveInputDynamismReshaped(
349       index, {InputShape(index).num_elements()}, &literal));
350 
351   return LiteralToPredVector(literal, out);
352 }
353 
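// Example usage (an illustrative sketch): a kernel can check whether an
// operand's value is known at compile time and take a specialized path when
// it is static; the operand index is an assumption for the example.
//
//   bool is_dynamic;
//   OP_REQUIRES_OK(ctx, ctx->ResolveInputDynamismIntoPred(0, &is_dynamic));
//   if (!is_dynamic) {
//     // The value is known statically; constant-fold instead of emitting ops.
//   }
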
// Converts an int32 or int64 1D literal to an int64 vector.
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                   std::vector<int64_t>* out) {
  if (literal.shape().rank() != 1) {
    return errors::InvalidArgument("output_shape must be rank 1, got shape ",
                                   literal.shape().DebugString());
  }
  int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
  if (literal.shape().element_type() == xla::S32) {
    for (int64_t i = 0; i < size; ++i) {
      out->push_back(literal.Get<int32>({i}));
    }
  } else if (literal.shape().element_type() == xla::S64) {
    for (int64_t i = 0; i < size; ++i) {
      out->push_back(literal.Get<int64_t>({i}));
    }
  } else {
    return errors::InvalidArgument("value must be either int32 or int64");
  }
  return OkStatus();
}

Status XlaOpKernelContext::ConstantInputAsIntVector(
    int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  return LiteralToInt64Vector(literal, out);
}

Status XlaOpKernelContext::ConstantInputAsIntVector(
    absl::string_view name, std::vector<int64_t>* out,
    xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsIntVector(index, out, mode);
}

389 
ConstantInputReshapedToIntVector(int index,std::vector<int64_t> * out,xla::ValueInferenceMode mode)390 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
391     int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
392   xla::Literal literal;
393   TF_RETURN_IF_ERROR(ConstantInputReshaped(
394       index, {InputShape(index).num_elements()}, &literal, mode));
395   return LiteralToInt64Vector(literal, out);
396 }
397 
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64_t> * out,xla::ValueInferenceMode mode)398 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
399     absl::string_view name, std::vector<int64_t>* out,
400     xla::ValueInferenceMode mode) {
401   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
402   xla::Literal literal;
403   TF_RETURN_IF_ERROR(ConstantInputReshaped(
404       index, {InputShape(index).num_elements()}, &literal, mode));
405   return LiteralToInt64Vector(literal, out);
406 }
407 
ConstantInputAsInt64Literal(int index,xla::Literal * out,xla::ValueInferenceMode mode)408 Status XlaOpKernelContext::ConstantInputAsInt64Literal(
409     int index, xla::Literal* out, xla::ValueInferenceMode mode) {
410   xla::Literal literal;
411   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
412   switch (literal.shape().element_type()) {
413     case xla::S32: {
414       *out = xla::Literal(
415           xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
416       auto src_data = literal.data<int32>();
417       for (int64_t i = 0; i < src_data.size(); ++i) {
418         out->data<int64_t>()[i] = src_data[i];
419       }
420       return OkStatus();
421     }
422     case xla::S64:
423       *out = std::move(literal);
424       return OkStatus();
425 
426     default:
427       return errors::InvalidArgument(
428           "Invalid argument to ConstantInputAsInt64Literal: ",
429           xla::ShapeUtil::HumanString(literal.shape()));
430   }
431 }
432 
ConstantInputAsInt64Literal(absl::string_view name,xla::Literal * out,xla::ValueInferenceMode mode)433 Status XlaOpKernelContext::ConstantInputAsInt64Literal(
434     absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
435   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
436   return ConstantInputAsInt64Literal(index, out, mode);
437 }
438 
439 // TODO(phawkins): validate that the dimensions form a valid shape, fail
440 // gracefully if they do not.
ConstantInputAsShape(int index,TensorShape * shape,xla::ValueInferenceMode mode)441 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
442                                                 xla::ValueInferenceMode mode) {
443   xla::Literal literal;
444   TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
445   std::vector<int64_t> dims;
446   TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
447 
448   int64_t num_elements = 1;
449   for (auto i = dims.begin(); i != dims.end(); ++i) {
450     num_elements = MultiplyWithoutOverflow(num_elements, *i);
451     if (num_elements < 0)
452       return errors::InvalidArgument(
453           "The total elements specified by orig_input_shape is too large.",
454           "Encountered overflow after multiplying", *i,
455           ", result: ", num_elements);
456   }
457   *shape = TensorShape(dims);
458   return OkStatus();
459 }
460 
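// Example usage (an illustrative sketch): a constant operand can be converted
// directly into a TensorShape, with the overflow check above applied; the
// operand index is an assumption for the example.
//
//   TensorShape shape;
//   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &shape));
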
Status XlaOpKernelContext::ConstantInputAsPartialShape(
    int index, PartialTensorShape* shape) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  // If `literal` is a scalar, its value must be -1.
  if (literal.shape().rank() == 0) {
    int64_t shape_val;
    TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
    if (shape_val != -1) {
      return errors::InvalidArgument(
          "Cannot convert value to PartialTensorShape: ", shape_val);
    }
    *shape = PartialTensorShape();  // Shape with unknown rank.
    return OkStatus();
  }
  std::vector<int64_t> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  *shape = PartialTensorShape(dims);
  return OkStatus();
}

Status XlaOpKernelContext::InputList(absl::string_view name,
                                     std::vector<xla::XlaOp>* handles,
                                     std::vector<TensorShape>* shapes) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
  handles->clear();
  shapes->clear();
  for (const Tensor& input : inputs) {
    handles->push_back(
        XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
    shapes->push_back(input.shape());
  }
  return OkStatus();
}

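// Example usage (an illustrative sketch): variadic kernels (e.g. a
// concat-style op with an N-tensor list) fetch every operand at once; the
// input name "values" is an assumption for the example.
//
//   std::vector<xla::XlaOp> values;
//   std::vector<TensorShape> value_shapes;
//   OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &value_shapes));
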
Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
                                             std::vector<xla::Literal>* outputs,
                                             xla::ValueInferenceMode mode) {
  int start, stop;
  TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
  outputs->resize(stop - start);
  for (int i = start; i < stop; ++i) {
    // Index the output vector relative to `start`; the list's operands need
    // not begin at operand 0.
    TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i - start], mode));
  }
  return OkStatus();
}

StatusOr<Tensor> XlaOpKernelContext::ConstantInputTensor(
    int index, xla::ValueInferenceMode mode) {
  XlaExpression e = InputExpression(index);
  auto* client = compiler() ? compiler()->client() : nullptr;
  StatusOr<std::optional<Tensor>> constant_or_status =
      e.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode);
  if (!constant_or_status.ok()) {
    Status status = constant_or_status.status();
    errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
                            context_->op_kernel().type_string(),
                            " operator as a compile-time constant.");
    return status;
  }
  std::optional<Tensor> constant = constant_or_status.ValueOrDie();
  if (!constant.has_value()) {
    return errors::InvalidArgument(
        "Input ", index, " to node `", context_->op_kernel().name(),
        "` with op ", context_->op_kernel().type_string(),
        " must be a compile-time constant.\n\n"
        "XLA compilation requires that operator arguments that represent "
        "shapes or dimensions be evaluated to concrete values at compile time. "
        "This error means that a shape or dimension argument could not be "
        "evaluated at compile time, usually because the value of the argument "
        "depends on a parameter to the computation, on a variable, or on a "
        "stateful operation such as a random number generator.");
  }
  return *constant;
}

namespace {

Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                               const XlaOpKernelContext* ctx,
                               TensorShape* shape, xla::XlaOp* value) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::FailedPrecondition(
        "Failed to read variable ", variable->name(),
        ". This could mean the variable is uninitialized or lives on another "
        "device.");
  }
  if (variable->type() != type) {
    return errors::InvalidArgument(
        "Trying to read variable with wrong dtype. Expected ",
        DataTypeString(type), " got ", DataTypeString(variable->type()));
  }
  if (shape) {
    *shape = variable->shape();
  }

  if (!variable->IsOverwritten() && expression->constant_value()) {
    TF_ASSIGN_OR_RETURN(xla::Literal literal,
                        HostTensorToLiteral(*expression->constant_value()));
    *value = xla::ConstantLiteral(ctx->builder(), literal);
    return OkStatus();
  }
  auto shape_determination_fns =
      ctx->compiler()->options().shape_determination_fns;
  XlaLayoutPreference layout_preference =
      shape_determination_fns.layout_preference_fn(
          variable->shape(), variable->type(), std::nullopt);
  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      shape_determination_fns.shape_representation_fn(
                          variable->shape(), variable->type(),
                          /*use_fast_memory=*/false, layout_preference));
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
  if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
    *value = variable->value();
  } else {
    *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
  }
  return OkStatus();
}

}  // namespace

Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
                                             TensorShape* shape,
                                             xla::XlaOp* value) {
  return ReadVariableInputTensor(context_->input(index), type, this, shape,
                                 value);
}

Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
                                             DataType type, TensorShape* shape,
                                             xla::XlaOp* value) {
  return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
                                 value);
}

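// Example usage (an illustrative sketch): a kernel that consumes a resource
// variable reads its current symbolic value; the dtype and operand index are
// assumptions for the example.
//
//   TensorShape var_shape;
//   xla::XlaOp var_value;
//   OP_REQUIRES_OK(
//       ctx, ctx->ReadVariableInput(0, DT_FLOAT, &var_shape, &var_value));
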
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                   TensorShape* shape) const {
  const Tensor& tensor = context_->input(index);
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::InvalidArgument(
        "Failed to read variable ", variable->name(),
        ". This could mean the variable is uninitialized or lives on another "
        "device.");
  }
  *type = variable->type();
  *shape = variable->shape();
  return OkStatus();
}

void XlaOpKernelContext::SetOutputExpression(int index,
                                             const XlaExpression& expression) {
  Status status = [&] {
    // The step's default allocator is the dummy XlaCompilationAllocator, which
    // simply allocates a metadata buffer to hold the expression to which it
    // corresponds.
    // Provide special behavior for DT_VARIANT and other types that are not
    // trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
    if (!DataTypeCanUseMemcpy(expression.dtype())) {
      // tensor_data() is not supported for tensors that cannot be copied via
      // memcpy, as the copy logic might try to inspect the stored data (e.g.
      // a std::string). This is likely to fail, as the data is invalid given
      // that it actually encodes an XlaExpression. Using a uint8 tensor is
      // always safe, so simply do that.
      // TODO(jpienaar): This should be refactored to stop masquerading
      // XlaExpressions as Tensors.
      Tensor output;
      TensorShape tensor_shape;
      TF_RETURN_IF_ERROR(
          context_->allocate_temp(DT_UINT8, tensor_shape, &output));
      context_->set_output(index, output);
    } else {
      Tensor* output = nullptr;
      TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
      TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
    }
    XlaExpression::AssignExpressionToTensor(expression,
                                            context_->mutable_output(index));
    return OkStatus();
  }();
  if (!status.ok()) {
    SetStatus(status);
  }
}

xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
  xla::PrimitiveType type;
  Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
  if (!status.ok()) {
    SetStatus(status);
    return xla::PRIMITIVE_TYPE_INVALID;
  }
  return type;
}

void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
  SetOutputExpression(
      index,
      XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
}

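// Example usage (an illustrative sketch): a Compile method usually finishes
// by binding an XLA value to each output; the computation below is an
// assumption for the example.
//
//   ctx->SetOutput(0, xla::Add(ctx->Input(0), ctx->Input(1)));
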
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
  SetOutputExpression(index, XlaExpression::Constant(constant));
}

void XlaOpKernelContext::SetTensorListOutput(int index,
                                             const xla::XlaOp& handle) {
  SetOutputExpression(index, XlaExpression::TensorList(handle));
}

void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
  SetOutputExpression(index, XlaExpression::Resource(resource));
}

Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(context_->input(index));
  TF_RET_CHECK(expression->resource() != nullptr);
  *resource = expression->resource();
  return OkStatus();
}

namespace {

Status AssignVariableTensor(const Tensor& tensor, DataType type,
                            const XlaOpKernelContext* ctx, xla::XlaOp handle,
                            xla::XlaBuilder* builder) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);

  auto shape_or_status = builder->GetShape(handle);
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  TensorShape shape;
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));

  TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));

  auto shape_determination_fns =
      ctx->compiler()->options().shape_determination_fns;
  XlaLayoutPreference layout_preference =
      shape_determination_fns.layout_preference_fn(shape, type, std::nullopt);
  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      shape_determination_fns.shape_representation_fn(
                          shape, type,
                          /*use_fast_memory=*/false, layout_preference));
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
  if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
    handle = xla::Reshape(handle, representation_shape.dimensions());
  }
  variable->SetRepresentationShape(representation_shape);
  return variable->SetValue(handle);
}

}  // namespace

Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  return AssignVariableTensor(context_->input(input_index), type, this, handle,
                              builder());
}

Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
                              builder());
}

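// Example usage (an illustrative sketch): an assign-style kernel writes a new
// symbolic value back into the variable carried by one of its operands; the
// operand indices and dtype are assumptions for the example.
//
//   xla::XlaOp new_value = ctx->Input(1);
//   OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, DT_FLOAT, new_value));
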
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() == error::INVALID_ARGUMENT) {
    return Status{s.code(),
                  absl::StrCat(s.error_message(), "\n", ctx->StackTrace())};
  }
  return s;
}

void XlaOpKernelContext::CtxFailure(const Status& s) {
  context_->CtxFailure(GetStatusWithStackTrace(s, this));
}

void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
  context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
}

void XlaOpKernelContext::CtxFailure(const char* file, int line,
                                    const Status& s) {
  context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
}

void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                               const Status& s) {
  context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
}

const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
    const DataType type) {
  return xla_context()->GetOrCreateMax(type);
}

const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
    const DataType type) {
  return xla_context()->GetOrCreateMin(type);
}

const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
    const DataType type) {
  return xla_context()->GetOrCreateAdd(type);
}

const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
    const DataType type) {
  return xla_context()->GetOrCreateMul(type);
}

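// Example usage (an illustrative sketch): the cached computations above serve
// as reducers for xla::Reduce; the operand, init value, and reduce dimension
// are assumptions for the example.
//
//   xla::XlaOp sum = xla::Reduce(
//       ctx->Input(0), XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)),
//       *ctx->GetOrCreateAdd(ctx->input_type(0)),
//       /*dimensions_to_reduce=*/{0});
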
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
  const Tensor* tensor;
  CHECK(context_->input(name, &tensor).ok());
  return *tensor;
}

XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}

void XlaOpKernel::Compute(OpKernelContext* context) {
  XlaOpKernelContext xla_context(context);
  Compile(&xla_context);
}

std::string XlaOpKernelContext::StackTrace() const {
  if (const AbstractStackTrace* stack_trace =
          xla_context()->StackTraceForNodeName(op_kernel().name())) {
    AbstractStackTrace::TracePrintingOptions opts;
    opts.show_line_contents = true;
    opts.filter_common_prefix = true;
    opts.drop_internal_frames = true;
    return absl::StrCat("\nStack trace for op definition: \n",
                        stack_trace->ToString(opts), "\n");
  } else {
    return "";
  }
}

}  // namespace tensorflow