1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17
18 #include <numeric>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/tf2xla/literal_util.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
25 #include "tensorflow/compiler/tf2xla/xla_context.h"
26 #include "tensorflow/compiler/xla/client/value_inference.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/client/xla_computation.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/core/common_runtime/dma_helper.h"
31 #include "tensorflow/core/platform/errors.h"
32
33 namespace tensorflow {
34
// Wraps a regular OpKernelContext for use during XLA compilation. Dynamic
// dimensions are not represented as -1 by default; value inference shares the
// XlaContext's builder.
XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
    : context_(context),
      dynamic_dimension_is_minus_one_(false),
      value_inference_(xla_context()->builder()) {}
39
// Delegates shape validation to the wrapped OpKernelContext.
bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  return context_->ValidateInputsAreSameShape(op);
}
43
// Returns the XlaContext associated with the wrapped step context.
XlaContext* XlaOpKernelContext::xla_context() const {
  return &XlaContext::Get(context_);
}
47
// Returns the XlaBuilder into which this op's HLO is being emitted.
xla::XlaBuilder* XlaOpKernelContext::builder() const {
  return xla_context()->builder();
}
51
// Accessor for the per-kernel-context value-inference engine.
xla::ValueInference& XlaOpKernelContext::value_inference() {
  return value_inference_;
}
55
// Returns the XlaCompiler driving this compilation (may be null in tests;
// callers below guard against that).
XlaCompiler* XlaOpKernelContext::compiler() const {
  return xla_context()->compiler();
}
59
// Returns the XlaExpression masquerading as the index-th input tensor.
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
  return *XlaExpression::CastExpressionFromTensor(context_->input(index));
}
63
// Name-based overload of InputExpression; CHECK-fails if `name` is unknown
// (see GetInputTensorByName).
const XlaExpression& XlaOpKernelContext::InputExpression(
    absl::string_view name) {
  return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
}
68
// Returns the index-th input as an XlaOp in this context's builder.
xla::XlaOp XlaOpKernelContext::Input(int index) {
  return InputExpression(index).AsXlaOp(builder());
}
72
// Name-based overload of Input.
xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
  return InputExpression(name).AsXlaOp(builder());
}
76
// Returns the TensorShape of the index-th input tensor.
TensorShape XlaOpKernelContext::InputShape(int index) {
  return context_->input(index).shape();
}
80
// Name-based overload of InputShape.
TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
  return GetInputTensorByName(name).shape();
}
84
// Returns the xla::Shape the builder has recorded for the index-th input.
StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
  return builder()->GetShape(Input(index));
}
88
// Name-based overload of InputXlaShape.
StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(absl::string_view name) {
  return builder()->GetShape(Input(name));
}
92
input_type(int index) const93 DataType XlaOpKernelContext::input_type(int index) const {
94 DataType type = context_->input_dtype(index);
95 if (type == DT_UINT8) {
96 // Masqueraded XlaExpression could have different type. See
97 // XlaOpKernelContext::SetOutputExpression for details.
98 auto expression =
99 XlaExpression::CastExpressionFromTensor(context_->input(index));
100 type = expression->dtype();
101 }
102 return type;
103 }
104
InputType(absl::string_view name)105 DataType XlaOpKernelContext::InputType(absl::string_view name) {
106 const Tensor& tensor = GetInputTensorByName(name);
107 DataType type = tensor.dtype();
108 if (type == DT_UINT8) {
109 // Masqueraded XlaExpression could have different type. See
110 // XlaOpKernelContext::SetOutputExpression for details.
111 auto expression = XlaExpression::CastExpressionFromTensor(tensor);
112 type = expression->dtype();
113 }
114 return type;
115 }
116
input_xla_type(int index)117 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
118 xla::PrimitiveType type;
119 Status status = DataTypeToPrimitiveType(input_type(index), &type);
120 if (!status.ok()) {
121 SetStatus(status);
122 return xla::PRIMITIVE_TYPE_INVALID;
123 }
124 return type;
125 }
126
InputXlaType(absl::string_view name)127 xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
128 xla::PrimitiveType type;
129 Status status = DataTypeToPrimitiveType(InputType(name), &type);
130 if (!status.ok()) {
131 SetStatus(status);
132 return xla::PRIMITIVE_TYPE_INVALID;
133 }
134 return type;
135 }
136
ConstantInput(int index,xla::Literal * constant_literal,xla::ValueInferenceMode mode)137 Status XlaOpKernelContext::ConstantInput(int index,
138 xla::Literal* constant_literal,
139 xla::ValueInferenceMode mode) {
140 if (this->InputXlaShape(index)->is_dynamic()) {
141 return errors::InvalidArgument(
142 "Reading input as constant from a dynamic tensor is not yet supported. "
143 "Xla shape: ",
144 this->InputXlaShape(index)->ToString());
145 }
146 return ConstantInputReshaped(index,
147 context_->input(index).shape().dim_sizes(),
148 constant_literal, mode);
149 }
150
InputIndex(XlaOpKernelContext * context,absl::string_view name)151 static StatusOr<int> InputIndex(XlaOpKernelContext* context,
152 absl::string_view name) {
153 int start, stop;
154 TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
155 if (stop != start + 1) {
156 return errors::InvalidArgument("OpKernel used list-valued input name '",
157 name,
158 "' when single-valued input was "
159 "expected");
160 }
161 return start;
162 }
163
// Computes, per element of the index-th input, whether its value is dynamic
// (PRED literal with the input's own shape).
Status XlaOpKernelContext::ResolveInputDynamism(
    int index, xla::Literal* dynamism_literal) {
  return ResolveInputDynamismReshaped(
      index, context_->input(index).shape().dim_sizes(), dynamism_literal);
}
169
// Name-based overload of ResolveInputDynamism.
Status XlaOpKernelContext::ResolveInputDynamism(
    absl::string_view name, xla::Literal* dynamism_literal) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ResolveInputDynamism(index, dynamism_literal);
}
175
// Name-based overload of ConstantInput.
Status XlaOpKernelContext::ConstantInput(absl::string_view name,
                                         xla::Literal* constant_literal,
                                         xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInput(index, constant_literal, mode);
}
182
// Evaluates the index-th input as a compile-time constant, reshapes it to
// `new_dims` (which must describe the same number of elements), and converts
// the result to a literal.
Status XlaOpKernelContext::ConstantInputReshaped(
    int index, absl::Span<const int64> new_dims, xla::Literal* constant_literal,
    xla::ValueInferenceMode mode) {
  XlaExpression e = InputExpression(index);
  // compiler() may be null (e.g. in some test setups); ResolveConstant accepts
  // a null client in that case.
  auto* client = compiler() ? compiler()->client() : nullptr;
  StatusOr<absl::optional<Tensor>> constant_or_status =
      e.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode);
  if (!constant_or_status.ok()) {
    // Annotate the failure with which input of which op could not be folded.
    Status status = constant_or_status.status();
    errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
                            context_->op_kernel().type_string(),
                            " operator as a compile-time constant.");
    return status;
  }
  absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
  if (!constant.has_value()) {
    // A resolvable-but-unknown value: the input depends on runtime data.
    return errors::InvalidArgument(
        "Input ", index, " to node `", context_->op_kernel().name(),
        "` with op ", context_->op_kernel().type_string(),
        " must be a compile-time constant.\n\n"
        "XLA compilation requires that operator arguments that represent "
        "shapes or dimensions be evaluated to concrete values at compile time. "
        "This error means that a shape or dimension argument could not be "
        "evaluated at compile time, usually because the value of the argument "
        "depends on a parameter to the computation, on a variable, or on a "
        "stateful operation such as a random number generator.");
  }

  // CopyFrom reinterprets the constant's buffer under the new shape; it fails
  // if the element counts differ.
  Tensor temp(constant->dtype());
  if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        constant->shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        TensorShape(new_dims).DebugString());
  }

  TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
  return Status::OK();
}
223
224 // Converts an int32 or int64 scalar literal to an int64.
LiteralToInt64Scalar(const xla::LiteralSlice & literal,int64 * out)225 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
226 int64* out) {
227 if (literal.shape().rank() != 0) {
228 return errors::InvalidArgument("value is not a scalar");
229 }
230 if (literal.shape().element_type() == xla::S32) {
231 *out = literal.Get<int32>({});
232 } else if (literal.shape().element_type() == xla::S64) {
233 *out = literal.Get<int64>({});
234 } else {
235 return errors::InvalidArgument("value must be either int32 or int64");
236 }
237 return Status::OK();
238 }
239
240 // Converts an float32 or float64 scalar literal to a float64.
LiteralToFloat64Scalar(const xla::LiteralSlice & literal,double * out)241 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
242 double* out) {
243 if (literal.shape().rank() != 0) {
244 return errors::InvalidArgument("value is not a scalar");
245 }
246 if (literal.shape().element_type() == xla::F32) {
247 *out = literal.Get<float>({});
248 } else if (literal.shape().element_type() == xla::F64) {
249 *out = literal.Get<double>({});
250 } else {
251 return errors::InvalidArgument("value must be either float32 or float64");
252 }
253 return Status::OK();
254 }
255
// Evaluates the index-th input as a constant int32/int64 scalar into `out`.
Status XlaOpKernelContext::ConstantInputAsIntScalar(
    int index, int64* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  return LiteralToInt64Scalar(literal, out);
}
262
// Name-based overload of ConstantInputAsIntScalar.
Status XlaOpKernelContext::ConstantInputAsIntScalar(
    absl::string_view name, int64* out, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsIntScalar(index, out, mode);
}
268
// Evaluates the index-th input as a constant float32/float64 scalar into
// `out`.
Status XlaOpKernelContext::ConstantInputAsFloatScalar(
    int index, double* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  return LiteralToFloat64Scalar(literal, out);
}
275
LiteralToPredVector(const xla::LiteralSlice & literal,std::vector<bool> * out)276 static Status LiteralToPredVector(const xla::LiteralSlice& literal,
277 std::vector<bool>* out) {
278 if (literal.shape().rank() != 1) {
279 return errors::InvalidArgument("value is not 1D, rank: ",
280 literal.shape().rank());
281 }
282 int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
283 if (literal.shape().element_type() != xla::PRED) {
284 return errors::InvalidArgument("value is not PRED");
285 }
286 for (int64_t i = 0; i < size; ++i) {
287 out->push_back(literal.Get<bool>({i}));
288 }
289 return Status::OK();
290 }
291
// Determines whether the (scalar) index-th input holds a dynamic value,
// writing the answer to `out`. If dynamism cannot be resolved, the value is
// conservatively reported as dynamic (true) with an OK status.
Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
  xla::Literal literal;
  XlaExpression e = InputExpression(index);
  // compiler() may be null; ResolveDynamism accepts a null client.
  auto* client = compiler() ? compiler()->client() : nullptr;
  StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
  if (!dynamism_or_status.ok()) {
    // When failed to resolve dynamism, conservatively consider the value
    // dynamic. This could happen if the input depends on some ops like
    // custom-call that is not supported generally for dynamism computation.
    //
    // TODO(b/176993339): Support resolving dynamism across computations so
    // resolving dynamism will not fail in those cases.
    *out = true;
    return Status::OK();
  }
  Tensor dynamism = dynamism_or_status.ValueOrDie();

  // Reinterpret the dynamism tensor as a scalar (R0); fails if it isn't one.
  Tensor temp(dynamism.dtype());
  TensorShape tensor_shape({});
  if (!temp.CopyFrom(dynamism, tensor_shape)) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
  }

  TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
  *out = literal.Get<bool>({});
  return Status::OK();
}
321
// Name-based overload of ResolveInputDynamismIntoPredVector.
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
    absl::string_view name, std::vector<bool>* out) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ResolveInputDynamismIntoPredVector(index, out);
}
327
// Name-based overload of ResolveInputDynamismIntoPred.
Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
                                                        bool* out) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ResolveInputDynamismIntoPred(index, out);
}
333
// Computes a per-element dynamism PRED literal for the index-th input,
// reshaped to `new_dims`. If dynamism resolution fails, conservatively
// reports every element as dynamic (all-true literal) with an OK status.
Status XlaOpKernelContext::ResolveInputDynamismReshaped(
    int index, absl::Span<const int64> new_dims,
    xla::Literal* dynamism_literal) {
  XlaExpression e = InputExpression(index);
  // compiler() may be null; ResolveDynamism accepts a null client.
  auto* client = compiler() ? compiler()->client() : nullptr;
  StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
  if (!dynamism_or_status.ok()) {
    xla::Literal true_literal = xla::LiteralUtil::CreateR0<bool>(true);
    // When failed to resolve dynamism, conservatively consider the value
    // dynamic. This could happen if the input depends on some ops like
    // custom-call that is not supported generally for dynamism computation.
    *dynamism_literal =
        true_literal
            .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {})
            .ValueOrDie();

    return Status::OK();
  }
  Tensor dynamism = dynamism_or_status.ValueOrDie();

  // Reinterpret the dynamism tensor under the requested shape; fails if the
  // element counts differ.
  Tensor temp(dynamism.dtype());
  if (!temp.CopyFrom(dynamism, TensorShape(new_dims))) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        dynamism.shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        TensorShape(new_dims).DebugString());
  }

  TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp));
  return Status::OK();
}
366
// Flattens the index-th input and appends one dynamism bit per element to
// `out`.
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
    int index, std::vector<bool>* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ResolveInputDynamismReshaped(
      index, {InputShape(index).num_elements()}, &literal));

  return LiteralToPredVector(literal, out);
}
375
376 // Converts an int32 or int64 1D literal to an int64 vector.
LiteralToInt64Vector(const xla::LiteralSlice & literal,std::vector<int64> * out)377 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
378 std::vector<int64>* out) {
379 if (literal.shape().rank() != 1) {
380 return errors::InvalidArgument("value is not 1D, rank: ",
381 literal.shape().rank());
382 }
383 int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
384 if (literal.shape().element_type() == xla::S32) {
385 for (int64_t i = 0; i < size; ++i) {
386 out->push_back(literal.Get<int32>({i}));
387 }
388 } else if (literal.shape().element_type() == xla::S64) {
389 for (int64_t i = 0; i < size; ++i) {
390 out->push_back(literal.Get<int64>({i}));
391 }
392 } else {
393 return errors::InvalidArgument("value must be either int32 or int64");
394 }
395 return Status::OK();
396 }
397
// Evaluates the index-th input as a constant 1D int32/int64 tensor, appending
// its values to `out`.
Status XlaOpKernelContext::ConstantInputAsIntVector(
    int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  return LiteralToInt64Vector(literal, out);
}
404
// Name-based overload of ConstantInputAsIntVector.
Status XlaOpKernelContext::ConstantInputAsIntVector(
    absl::string_view name, std::vector<int64>* out,
    xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsIntVector(index, out, mode);
}
411
// Evaluates the index-th input as a constant, flattens it to 1D, and appends
// its int32/int64 values to `out`.
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
    int index, std::vector<int64>* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInputReshaped(
      index, {InputShape(index).num_elements()}, &literal, mode));
  return LiteralToInt64Vector(literal, out);
}
419
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64> * out,xla::ValueInferenceMode mode)420 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
421 absl::string_view name, std::vector<int64>* out,
422 xla::ValueInferenceMode mode) {
423 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
424 xla::Literal literal;
425 TF_RETURN_IF_ERROR(ConstantInputReshaped(
426 index, {InputShape(index).num_elements()}, &literal, mode));
427 return LiteralToInt64Vector(literal, out);
428 }
429
// Evaluates the index-th input as a constant and widens it (if needed) to an
// S64 literal of the same shape. S32 inputs are element-wise converted; S64
// inputs are moved through unchanged; other types are an error.
Status XlaOpKernelContext::ConstantInputAsInt64Literal(
    int index, xla::Literal* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  switch (literal.shape().element_type()) {
    case xla::S32: {
      // Allocate an S64 literal with the same dimensions, then copy/widen
      // element by element.
      *out = xla::Literal(
          xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
      auto src_data = literal.data<int32>();
      for (int64_t i = 0; i < src_data.size(); ++i) {
        out->data<int64>()[i] = src_data[i];
      }
      return Status::OK();
    }
    case xla::S64:
      // Already the right type; transfer ownership without copying.
      *out = std::move(literal);
      return Status::OK();

    default:
      return errors::InvalidArgument(
          "Invalid argument to ConstantInputAsInt64Literal: ",
          xla::ShapeUtil::HumanString(literal.shape()));
  }
}
454
// Name-based overload of ConstantInputAsInt64Literal.
Status XlaOpKernelContext::ConstantInputAsInt64Literal(
    absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsInt64Literal(index, out, mode);
}
460
// Evaluates the index-th input as a constant 1D integer tensor and interprets
// its values as the dimensions of a TensorShape.
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
                                                xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  std::vector<int64> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  *shape = TensorShape(dims);
  return Status::OK();
}
472
// Evaluates the index-th input as a constant and interprets it as a
// PartialTensorShape: a scalar -1 means "unknown rank"; otherwise a 1D
// integer vector gives the (possibly partially-known) dimensions.
Status XlaOpKernelContext::ConstantInputAsPartialShape(
    int index, PartialTensorShape* shape) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  // If `literal` is a scalar it's value must be -1.
  if (literal.shape().rank() == 0) {
    int64_t shape_val;
    TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
    if (shape_val != -1) {
      return errors::InvalidArgument(
          "Cannot convert value to PartialTensorShape: ", shape_val);
    }
    *shape = PartialTensorShape();  // Shape with unknown rank.
    return Status::OK();
  }
  std::vector<int64> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  *shape = PartialTensorShape(dims);
  return Status::OK();
}
493
// Collects the XlaOps and TensorShapes of a list-valued input `name`.
// Both output vectors are cleared before being filled, in input order.
Status XlaOpKernelContext::InputList(absl::string_view name,
                                     std::vector<xla::XlaOp>* handles,
                                     std::vector<TensorShape>* shapes) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
  handles->clear();
  shapes->clear();
  for (const Tensor& input : inputs) {
    handles->push_back(
        XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
    shapes->push_back(input.shape());
  }
  return Status::OK();
}
508
ConstantInputList(absl::string_view name,std::vector<xla::Literal> * outputs,xla::ValueInferenceMode mode)509 Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
510 std::vector<xla::Literal>* outputs,
511 xla::ValueInferenceMode mode) {
512 int start, stop;
513 TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
514 outputs->resize(stop - start);
515 for (int i = start; i < stop; ++i) {
516 TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i], mode));
517 }
518 return Status::OK();
519 }
520
namespace {

// Reads the value of the resource variable held by `tensor`'s XlaExpression.
// On success, writes the variable's declared shape to `shape` (if non-null)
// and its current value (possibly reshaped to the declared TensorShape) to
// `value`. Fails if the variable is uninitialized or its type differs from
// `type`.
Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                               const XlaOpKernelContext* ctx,
                               TensorShape* shape, xla::XlaOp* value) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::FailedPrecondition(
        "Read variable failure ", variable->name(),
        ". It could mean the variable is uninitialized or the variable is on "
        "another device ");
  }
  if (variable->type() != type) {
    return errors::InvalidArgument(
        "Type mismatch for read of variable ", variable->name(), ". Expected ",
        DataTypeString(type), "; got ", DataTypeString(variable->type()));
  }
  if (shape) {
    *shape = variable->shape();
  }

  // Fast path: if the variable was never overwritten and its expression
  // carries a known constant, emit the constant directly.
  if (!variable->IsOverwritten() && expression->constant_value()) {
    TF_ASSIGN_OR_RETURN(xla::Literal literal,
                        HostTensorToLiteral(*expression->constant_value()));
    *value = xla::ConstantLiteral(ctx->builder(), literal);
    return Status::OK();
  }

  // The variable may be stored in a device-specific representation shape;
  // reshape back to the logical TensorShape when the two differ.
  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      ctx->compiler()->options().shape_representation_fn(
                          variable->shape(), variable->type(),
                          /*use_fast_memory=*/false));
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
  if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
    *value = variable->value();
  } else {
    *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
  }
  return Status::OK();
}

}  // namespace
569
// Reads the resource variable carried by the index-th input (see
// ReadVariableInputTensor above for semantics).
Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
                                             TensorShape* shape,
                                             xla::XlaOp* value) {
  return ReadVariableInputTensor(context_->input(index), type, this, shape,
                                 value);
}
576
// Name-based overload of ReadVariableInput.
Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
                                             DataType type, TensorShape* shape,
                                             xla::XlaOp* value) {
  return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
                                 value);
}
583
// Returns the declared type and shape of the resource variable carried by the
// index-th input, failing if it is uninitialized.
// NOTE(review): ReadVariableInputTensor reports the same condition as
// FailedPrecondition, whereas this returns InvalidArgument — confirm whether
// callers depend on the distinction before unifying.
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                   TensorShape* shape) const {
  const Tensor& tensor = context_->input(index);
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::InvalidArgument(
        "Read variable failure ", variable->name(),
        ". It could mean the variable is uninitialized or the variable is on "
        "another device ");
  }
  *type = variable->type();
  *shape = variable->shape();
  return Status::OK();
}
602
// Stores `expression` as the index-th output by allocating a placeholder
// tensor and attaching the expression to it. Any failure is recorded via
// SetStatus rather than returned.
void XlaOpKernelContext::SetOutputExpression(int index,
                                             const XlaExpression& expression) {
  Status status = [&] {
    // The step's default allocator is the dummy XlaCompilationAllocator which
    // simply allocates a metadata buffer to hold the expression to which it
    // corresponds.
    // Provides a special behavior for DT_VARIANT and other types that are not
    // trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
    if (!DataTypeCanUseMemcpy(expression.dtype())) {
      // tensor_data() is not supported for tensors that cannot be copied via
      // memcpy, as the copy logic might try to inspect the stored data (e.g.
      // a std::string). This is likely to fail, as the data is invalid given
      // that it actually encodes an XlaExpression. Using a uint8 tensor is
      // always safe, so simply do that.
      // TODO(jpienaar): This should be refactored to stop masquerading
      // XlaExpressions as Tensors.
      Tensor output;
      TensorShape tensor_shape;
      TF_RETURN_IF_ERROR(
          context_->allocate_temp(DT_UINT8, tensor_shape, &output));
      context_->set_output(index, output);
    } else {
      Tensor* output = nullptr;
      TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
      TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
    }
    // Attach the expression to whichever tensor now backs the output slot.
    XlaExpression::AssignExpressionToTensor(expression,
                                            context_->mutable_output(index));
    return Status::OK();
  }();
  if (!status.ok()) {
    SetStatus(status);
  }
}
637
output_xla_type(int index)638 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
639 xla::PrimitiveType type;
640 Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
641 if (!status.ok()) {
642 SetStatus(status);
643 return xla::PRIMITIVE_TYPE_INVALID;
644 }
645 return type;
646 }
647
// Sets the index-th output to the given XlaOp, typed by the op's expected
// output dtype.
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
  SetOutputExpression(
      index,
      XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
}
653
// Sets the index-th output to a compile-time constant tensor.
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
  SetOutputExpression(index, XlaExpression::Constant(constant));
}
657
// Sets the index-th output to a TensorList expression.
void XlaOpKernelContext::SetTensorListOutput(int index,
                                             const xla::XlaOp& handle) {
  SetOutputExpression(index, XlaExpression::TensorList(handle));
}
662
// Sets the index-th output to a resource expression.
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
  SetOutputExpression(index, XlaExpression::Resource(resource));
}
666
// Retrieves the XlaResource carried by the index-th input; fails (via
// TF_RET_CHECK) if the input does not hold a resource.
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(context_->input(index));
  TF_RET_CHECK(expression->resource() != nullptr);
  *resource = expression->resource();
  return Status::OK();
}
674
675 namespace {
676
AssignVariableTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,xla::XlaOp handle,xla::XlaBuilder * builder)677 Status AssignVariableTensor(const Tensor& tensor, DataType type,
678 const XlaOpKernelContext* ctx, xla::XlaOp handle,
679 xla::XlaBuilder* builder) {
680 const XlaExpression* expression =
681 XlaExpression::CastExpressionFromTensor(tensor);
682 XlaResource* variable = expression->resource();
683 TF_RET_CHECK(variable != nullptr);
684 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
685
686 auto shape_or_status = builder->GetShape(handle);
687 if (!shape_or_status.ok()) {
688 return shape_or_status.status();
689 }
690 TensorShape shape;
691 TF_RETURN_IF_ERROR(
692 XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
693
694 TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
695
696 TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
697 ctx->compiler()->options().shape_representation_fn(
698 shape, type,
699 /*use_fast_memory=*/false));
700 xla::Shape xla_shape;
701 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
702 if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
703 handle = xla::Reshape(handle,
704 xla::AsInt64Slice(representation_shape.dimensions()));
705 }
706 variable->SetRepresentationShape(representation_shape);
707 return variable->SetValue(handle);
708 }
709
710 } // namespace
711
// Assigns `handle` to the resource variable carried by the input at
// `input_index`; `handle` must be a valid XlaOp.
Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  return AssignVariableTensor(context_->input(input_index), type, this, handle,
                              builder());
}
718
// Name-based overload of AssignVariable.
Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
                              builder());
}
725
// Appends the op-definition stack trace to InvalidArgument statuses (the
// common user-facing error for "must be a compile-time constant" failures);
// all other statuses pass through untouched.
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() != error::INVALID_ARGUMENT) {
    return s;
  }
  return Status{s.code(),
                absl::StrCat(s.error_message(), "\n", ctx->StackTrace())};
}
734
// Records a failure on the underlying context, augmenting InvalidArgument
// errors with the op's stack trace.
void XlaOpKernelContext::CtxFailure(const Status& s) {
  context_->CtxFailure(GetStatusWithStackTrace(s, this));
}
// Like CtxFailure, but also logs a warning.
void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
  context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
}
741
// File/line variant of CtxFailure; same stack-trace augmentation.
void XlaOpKernelContext::CtxFailure(const char* file, int line,
                                    const Status& s) {
  context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
}
// File/line variant of CtxFailureWithWarning; same stack-trace augmentation.
void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                               const Status& s) {
  context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
}
750
// Returns (building on first use) the cached max-reduction computation for
// `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
    const DataType type) {
  return xla_context()->GetOrCreateMax(type);
}
755
// Returns (building on first use) the cached min-reduction computation for
// `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
    const DataType type) {
  return xla_context()->GetOrCreateMin(type);
}
760
// Returns (building on first use) the cached add-reduction computation for
// `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
    const DataType type) {
  return xla_context()->GetOrCreateAdd(type);
}
765
// Returns (building on first use) the cached mul-reduction computation for
// `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
    const DataType type) {
  return xla_context()->GetOrCreateMul(type);
}
770
// Looks up an input tensor by name. CHECK-fails (process abort) on an unknown
// name — the signature returns a reference, so there is no way to propagate a
// Status here.
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
  const Tensor* tensor;
  CHECK(context_->input(name, &tensor).ok());
  return *tensor;
}
776
// XlaOpKernel simply forwards construction to OpKernel.
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
778
// Bridges the regular kernel Compute entry point to the XLA Compile hook by
// wrapping the context in an XlaOpKernelContext.
void XlaOpKernel::Compute(OpKernelContext* context) {
  XlaOpKernelContext xla_context(context);
  Compile(&xla_context);
}
783
StackTrace() const784 std::string XlaOpKernelContext::StackTrace() const {
785 if (const AbstractStackTrace* stack_trace =
786 xla_context()->StackTraceForNodeName(op_kernel().name())) {
787 AbstractStackTrace::TracePrintingOptions opts;
788 opts.show_line_contents = true;
789 opts.filter_common_prefix = true;
790 opts.drop_internal_frames = true;
791 return absl::StrCat("\nStack trace for op definition: \n",
792 stack_trace->ToString(opts), "\n");
793 } else {
794 return "";
795 }
796 }
797
798 } // namespace tensorflow
799