1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17
18 #include <numeric>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/tf2xla/literal_util.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
25 #include "tensorflow/compiler/tf2xla/xla_context.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/core/common_runtime/dma_helper.h"
30
31 namespace tensorflow {
32
XlaOpKernelContext(OpKernelContext * context)33 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
34 : context_(context), dynamic_dimension_is_minus_one_(false) {}
35
ValidateInputsAreSameShape(OpKernel * op)36 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
37 return context_->ValidateInputsAreSameShape(op);
38 }
39
xla_context() const40 XlaContext* XlaOpKernelContext::xla_context() const {
41 return &XlaContext::Get(context_);
42 }
43
builder() const44 xla::XlaBuilder* XlaOpKernelContext::builder() const {
45 return xla_context()->builder();
46 }
47
compiler() const48 XlaCompiler* XlaOpKernelContext::compiler() const {
49 return xla_context()->compiler();
50 }
51
InputExpression(int index)52 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
53 return *XlaExpression::CastExpressionFromTensor(context_->input(index));
54 }
55
InputExpression(absl::string_view name)56 const XlaExpression& XlaOpKernelContext::InputExpression(
57 absl::string_view name) {
58 return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
59 }
60
Input(int index)61 xla::XlaOp XlaOpKernelContext::Input(int index) {
62 return InputExpression(index).AsXlaOp(builder());
63 }
64
Input(absl::string_view name)65 xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
66 return InputExpression(name).AsXlaOp(builder());
67 }
68
InputShape(int index)69 TensorShape XlaOpKernelContext::InputShape(int index) {
70 return context_->input(index).shape();
71 }
72
InputShape(absl::string_view name)73 TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
74 return GetInputTensorByName(name).shape();
75 }
76
InputXlaShape(int index)77 xla::StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
78 return builder()->GetShape(Input(index));
79 }
80
InputXlaShape(absl::string_view name)81 xla::StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(
82 absl::string_view name) {
83 return builder()->GetShape(Input(name));
84 }
85
input_type(int index) const86 DataType XlaOpKernelContext::input_type(int index) const {
87 DataType type = context_->input_dtype(index);
88 if (type == DT_UINT8) {
89 // Masqueraded XlaExpression could have different type. See
90 // XlaOpKernelContext::SetOutputExpression for details.
91 auto expression =
92 XlaExpression::CastExpressionFromTensor(context_->input(index));
93 type = expression->dtype();
94 }
95 return type;
96 }
97
InputType(absl::string_view name)98 DataType XlaOpKernelContext::InputType(absl::string_view name) {
99 const Tensor& tensor = GetInputTensorByName(name);
100 DataType type = tensor.dtype();
101 if (type == DT_UINT8) {
102 // Masqueraded XlaExpression could have different type. See
103 // XlaOpKernelContext::SetOutputExpression for details.
104 auto expression = XlaExpression::CastExpressionFromTensor(tensor);
105 type = expression->dtype();
106 }
107 return type;
108 }
109
input_xla_type(int index)110 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
111 xla::PrimitiveType type;
112 Status status = DataTypeToPrimitiveType(input_type(index), &type);
113 if (!status.ok()) {
114 SetStatus(status);
115 return xla::PRIMITIVE_TYPE_INVALID;
116 }
117 return type;
118 }
119
InputXlaType(absl::string_view name)120 xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
121 xla::PrimitiveType type;
122 Status status = DataTypeToPrimitiveType(InputType(name), &type);
123 if (!status.ok()) {
124 SetStatus(status);
125 return xla::PRIMITIVE_TYPE_INVALID;
126 }
127 return type;
128 }
129
ConstantInput(int index,xla::Literal * constant_literal)130 Status XlaOpKernelContext::ConstantInput(int index,
131 xla::Literal* constant_literal) {
132 return ConstantInputReshaped(
133 index, context_->input(index).shape().dim_sizes(), constant_literal);
134 }
135
InputIndex(XlaOpKernelContext * context,absl::string_view name)136 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
137 absl::string_view name) {
138 int start, stop;
139 TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
140 if (stop != start + 1) {
141 return errors::InvalidArgument("OpKernel used list-valued input name '",
142 name,
143 "' when single-valued input was "
144 "expected");
145 }
146 return start;
147 }
148
ConstantInput(absl::string_view name,xla::Literal * constant_literal)149 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
150 xla::Literal* constant_literal) {
151 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
152 return ConstantInput(index, constant_literal);
153 }
154
ConstantInputReshaped(int index,absl::Span<const int64> new_dims,xla::Literal * constant_literal)155 Status XlaOpKernelContext::ConstantInputReshaped(
156 int index, absl::Span<const int64> new_dims,
157 xla::Literal* constant_literal) {
158 XlaExpression e = InputExpression(index);
159 auto* client = compiler() ? compiler()->client() : nullptr;
160 xla::StatusOr<absl::optional<Tensor>> constant_or_status =
161 e.ResolveConstant(client, dynamic_dimension_is_minus_one_);
162 if (!constant_or_status.ok()) {
163 Status status = constant_or_status.status();
164 errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
165 context_->op_kernel().type_string(),
166 " operator as a compile-time constant.");
167 return status;
168 }
169 absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
170 if (!constant.has_value()) {
171 return errors::InvalidArgument(
172 "Input ", index, " to node `", context_->op_kernel().name(),
173 "` with op ", context_->op_kernel().type_string(),
174 " must be a compile-time constant.\n\n"
175 "XLA compilation requires that operator arguments that represent "
176 "shapes or dimensions be evaluated to concrete values at compile time. "
177 "This error means that a shape or dimension argument could not be "
178 "evaluated at compile time, usually because the value of the argument "
179 "depends on a parameter to the computation, on a variable, or on a "
180 "stateful operation such as a random number generator.");
181 }
182
183 Tensor temp(constant->dtype());
184 if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
185 return errors::InvalidArgument(
186 context_->op_kernel().name(), " input ", index, " has shape ",
187 constant->shape().DebugString(),
188 " but was asked to be reshaped to incompatible shape ",
189 TensorShape(new_dims).DebugString());
190 }
191
192 TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
193 return Status::OK();
194 }
195
196 // Converts an int32 or int64 scalar literal to an int64.
LiteralToInt64Scalar(const xla::LiteralSlice & literal,int64 * out)197 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
198 int64* out) {
199 if (literal.shape().rank() != 0) {
200 return errors::InvalidArgument("value is not a scalar");
201 }
202 if (literal.shape().element_type() == xla::S32) {
203 *out = literal.Get<int32>({});
204 } else if (literal.shape().element_type() == xla::S64) {
205 *out = literal.Get<int64>({});
206 } else {
207 return errors::InvalidArgument("value must be either int32 or int64");
208 }
209 return Status::OK();
210 }
211
212 // Converts an float32 or float64 scalar literal to a float64.
LiteralToFloat64Scalar(const xla::LiteralSlice & literal,double * out)213 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
214 double* out) {
215 if (literal.shape().rank() != 0) {
216 return errors::InvalidArgument("value is not a scalar");
217 }
218 if (literal.shape().element_type() == xla::F32) {
219 *out = literal.Get<float>({});
220 } else if (literal.shape().element_type() == xla::F64) {
221 *out = literal.Get<double>({});
222 } else {
223 return errors::InvalidArgument("value must be either float32 or float64");
224 }
225 return Status::OK();
226 }
227
ConstantInputAsIntScalar(int index,int64 * out)228 Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
229 xla::Literal literal;
230 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
231 return LiteralToInt64Scalar(literal, out);
232 }
233
ConstantInputAsIntScalar(absl::string_view name,int64 * out)234 Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
235 int64* out) {
236 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
237 return ConstantInputAsIntScalar(index, out);
238 }
239
ConstantInputAsFloatScalar(int index,double * out)240 Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
241 xla::Literal literal;
242 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
243 return LiteralToFloat64Scalar(literal, out);
244 }
245
LiteralToPredVector(const xla::LiteralSlice & literal,std::vector<bool> * out)246 static Status LiteralToPredVector(const xla::LiteralSlice& literal,
247 std::vector<bool>* out) {
248 if (literal.shape().rank() != 1) {
249 return errors::InvalidArgument("value is not 1D, rank: ",
250 literal.shape().rank());
251 }
252 int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
253 if (literal.shape().element_type() != xla::PRED) {
254 return errors::InvalidArgument("value is not PRED");
255 }
256 for (int64 i = 0; i < size; ++i) {
257 out->push_back(literal.Get<bool>({i}));
258 }
259 return Status::OK();
260 }
261
ResolveInputDynamismIntoPred(int index,bool * out)262 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
263 xla::Literal literal;
264 XlaExpression e = InputExpression(index);
265 auto* client = compiler() ? compiler()->client() : nullptr;
266 xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
267 if (!dynamism_or_status.ok()) {
268 // When failed to resolve dynamism, conservatively consider the value
269 // dynamic.
270 //
271 // TODO(b/176993339): Support resolving dynamism across computations so
272 // resolving dynamism will not fail.
273 *out = true;
274 return Status::OK();
275 }
276 Tensor dynamism = dynamism_or_status.ValueOrDie();
277
278 Tensor temp(dynamism.dtype());
279 TensorShape tensor_shape({});
280 if (!temp.CopyFrom(dynamism, tensor_shape)) {
281 return errors::InvalidArgument(
282 context_->op_kernel().name(), " input ", index, " has shape ",
283 dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
284 }
285
286 TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
287 *out = literal.Get<bool>({});
288 return Status::OK();
289 }
290
ResolveInputDynamismIntoPredVector(int index,std::vector<bool> * out)291 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
292 int index, std::vector<bool>* out) {
293 xla::Literal literal;
294 XlaExpression e = InputExpression(index);
295 auto* client = compiler() ? compiler()->client() : nullptr;
296 xla::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
297 if (!dynamism_or_status.ok()) {
298 // When failed to resolve dynamism, conservatively consider the value
299 // dynamic.
300 //
301 // TODO(b/176993339): Support resolving dynamism across computations so
302 // resolving dynamism will not fail.
303 out->resize(InputShape(index).num_elements(), false);
304 return Status::OK();
305 }
306 Tensor dynamism = dynamism_or_status.ValueOrDie();
307
308 Tensor temp(dynamism.dtype());
309 TensorShape tensor_shape({InputShape(index).num_elements()});
310 if (!temp.CopyFrom(dynamism, tensor_shape)) {
311 return errors::InvalidArgument(
312 context_->op_kernel().name(), " input ", index, " has shape ",
313 dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape);
314 }
315
316 TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
317 return LiteralToPredVector(literal, out);
318 }
319
320 // Converts an int32 or int64 1D literal to an int64 vector.
LiteralToInt64Vector(const xla::LiteralSlice & literal,std::vector<int64> * out)321 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
322 std::vector<int64>* out) {
323 if (literal.shape().rank() != 1) {
324 return errors::InvalidArgument("value is not 1D, rank: ",
325 literal.shape().rank());
326 }
327 int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
328 if (literal.shape().element_type() == xla::S32) {
329 for (int64 i = 0; i < size; ++i) {
330 out->push_back(literal.Get<int32>({i}));
331 }
332 } else if (literal.shape().element_type() == xla::S64) {
333 for (int64 i = 0; i < size; ++i) {
334 out->push_back(literal.Get<int64>({i}));
335 }
336 } else {
337 return errors::InvalidArgument("value must be either int32 or int64");
338 }
339 return Status::OK();
340 }
341
ConstantInputAsIntVector(int index,std::vector<int64> * out)342 Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
343 std::vector<int64>* out) {
344 xla::Literal literal;
345 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
346 return LiteralToInt64Vector(literal, out);
347 }
348
ConstantInputAsIntVector(absl::string_view name,std::vector<int64> * out)349 Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
350 std::vector<int64>* out) {
351 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
352 return ConstantInputAsIntVector(index, out);
353 }
354
ConstantInputReshapedToIntVector(int index,std::vector<int64> * out)355 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
356 int index, std::vector<int64>* out) {
357 xla::Literal literal;
358 TF_RETURN_IF_ERROR(ConstantInputReshaped(
359 index, {InputShape(index).num_elements()}, &literal));
360 return LiteralToInt64Vector(literal, out);
361 }
362
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64> * out)363 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
364 absl::string_view name, std::vector<int64>* out) {
365 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
366 xla::Literal literal;
367 TF_RETURN_IF_ERROR(ConstantInputReshaped(
368 index, {InputShape(index).num_elements()}, &literal));
369 return LiteralToInt64Vector(literal, out);
370 }
371
ConstantInputAsInt64Literal(int index,xla::Literal * out)372 Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
373 xla::Literal* out) {
374 xla::Literal literal;
375 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
376 switch (literal.shape().element_type()) {
377 case xla::S32: {
378 *out = xla::Literal(
379 xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
380 auto src_data = literal.data<int32>();
381 for (int64 i = 0; i < src_data.size(); ++i) {
382 out->data<int64>()[i] = src_data[i];
383 }
384 return Status::OK();
385 }
386 case xla::S64:
387 *out = std::move(literal);
388 return Status::OK();
389
390 default:
391 return errors::InvalidArgument(
392 "Invalid argument to ConstantInputAsInt64Literal: ",
393 xla::ShapeUtil::HumanString(literal.shape()));
394 }
395 }
396
ConstantInputAsInt64Literal(absl::string_view name,xla::Literal * out)397 Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
398 xla::Literal* out) {
399 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
400 return ConstantInputAsInt64Literal(index, out);
401 }
402
403 // TODO(phawkins): validate that the dimensions form a valid shape, fail
404 // gracefully if they do not.
ConstantInputAsShape(int index,TensorShape * shape)405 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
406 xla::Literal literal;
407 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
408 std::vector<int64> dims;
409 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
410 *shape = TensorShape(dims);
411 return Status::OK();
412 }
413
ConstantInputAsPartialShape(int index,PartialTensorShape * shape)414 Status XlaOpKernelContext::ConstantInputAsPartialShape(
415 int index, PartialTensorShape* shape) {
416 xla::Literal literal;
417 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
418 // If `literal` is a scalar it's value must be -1.
419 if (literal.shape().rank() == 0) {
420 int64 shape_val;
421 TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
422 if (shape_val != -1) {
423 return errors::InvalidArgument(
424 "Cannot convert value to PartialTensorShape: ", shape_val);
425 }
426 *shape = PartialTensorShape(); // Shape with unknown rank.
427 return Status::OK();
428 }
429 std::vector<int64> dims;
430 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
431 *shape = PartialTensorShape(dims);
432 return Status::OK();
433 }
434
InputList(absl::string_view name,std::vector<xla::XlaOp> * handles,std::vector<TensorShape> * shapes)435 Status XlaOpKernelContext::InputList(absl::string_view name,
436 std::vector<xla::XlaOp>* handles,
437 std::vector<TensorShape>* shapes) {
438 OpInputList inputs;
439 TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
440 handles->clear();
441 shapes->clear();
442 for (const Tensor& input : inputs) {
443 handles->push_back(
444 XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
445 shapes->push_back(input.shape());
446 }
447 return Status::OK();
448 }
449
ConstantInputList(absl::string_view name,std::vector<xla::Literal> * outputs)450 Status XlaOpKernelContext::ConstantInputList(
451 absl::string_view name, std::vector<xla::Literal>* outputs) {
452 int start, stop;
453 TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
454 outputs->resize(stop - start);
455 for (int i = start; i < stop; ++i) {
456 TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
457 }
458 return Status::OK();
459 }
460
461 namespace {
462
ReadVariableInputTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,TensorShape * shape,xla::XlaOp * value)463 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
464 const XlaOpKernelContext* ctx,
465 TensorShape* shape, xla::XlaOp* value) {
466 const XlaExpression* expression =
467 XlaExpression::CastExpressionFromTensor(tensor);
468 XlaResource* variable = expression->resource();
469 TF_RET_CHECK(variable != nullptr);
470 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
471 if (!variable->initialized()) {
472 return errors::FailedPrecondition(
473 "Read variable failure ", variable->name(),
474 ". It could mean the variable is uninitialized or the variable is on "
475 "another device ");
476 }
477 if (variable->type() != type) {
478 return errors::InvalidArgument(
479 "Type mismatch for read of variable ", variable->name(), ". Expected ",
480 DataTypeString(type), "; got ", DataTypeString(variable->type()));
481 }
482 if (shape) {
483 *shape = variable->shape();
484 }
485
486 if (!variable->IsOverwritten() && expression->constant_value()) {
487 TF_ASSIGN_OR_RETURN(xla::Literal literal,
488 HostTensorToLiteral(*expression->constant_value()));
489 *value = xla::ConstantLiteral(ctx->builder(), literal);
490 return Status::OK();
491 }
492
493 TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
494 ctx->compiler()->options().shape_representation_fn(
495 variable->shape(), variable->type(),
496 /*use_fast_memory=*/false));
497 xla::Shape xla_shape;
498 TF_RETURN_IF_ERROR(
499 TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
500 if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
501 *value = variable->value();
502 } else {
503 *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
504 }
505 return Status::OK();
506 }
507
508 } // namespace
509
ReadVariableInput(int index,DataType type,TensorShape * shape,xla::XlaOp * value)510 Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
511 TensorShape* shape,
512 xla::XlaOp* value) {
513 return ReadVariableInputTensor(context_->input(index), type, this, shape,
514 value);
515 }
516
ReadVariableInput(absl::string_view name,DataType type,TensorShape * shape,xla::XlaOp * value)517 Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
518 DataType type, TensorShape* shape,
519 xla::XlaOp* value) {
520 return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
521 value);
522 }
523
GetVariableTypeAndShape(int index,DataType * type,TensorShape * shape) const524 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
525 TensorShape* shape) const {
526 const Tensor& tensor = context_->input(index);
527 const XlaExpression* expression =
528 XlaExpression::CastExpressionFromTensor(tensor);
529 XlaResource* variable = expression->resource();
530 TF_RET_CHECK(variable != nullptr);
531 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
532 if (!variable->initialized()) {
533 return errors::InvalidArgument(
534 "Read variable failure ", variable->name(),
535 ". It could mean the variable is uninitialized or the variable is on "
536 "another device ");
537 }
538 *type = variable->type();
539 *shape = variable->shape();
540 return Status::OK();
541 }
542
SetOutputExpression(int index,const XlaExpression & expression)543 void XlaOpKernelContext::SetOutputExpression(int index,
544 const XlaExpression& expression) {
545 Status status = [&] {
546 // The step's default allocator is the dummy XlaCompilationAllocator which
547 // simply allocates a metadata buffer to hold the expression to which it
548 // corresponds.
549 // Provides a special behavior for DT_VARIANT and other types that are not
550 // trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
551 if (!DataTypeCanUseMemcpy(expression.dtype())) {
552 // tensor_data() is not supported for tensors that cannot be copied via
553 // memcpy, as the copy logic might try to inspect the stored data (e.g.
554 // a std::string). This is likely to fail, as the data is invalid given
555 // that it actually encodes an XlaExpression. Using a uint8 tensor is
556 // always safe, so simply do that.
557 // TODO(jpienaar): This should be refactored to stop masquerading
558 // XlaExpressions as Tensors.
559 Tensor output;
560 TensorShape tensor_shape;
561 TF_RETURN_IF_ERROR(
562 context_->allocate_temp(DT_UINT8, tensor_shape, &output));
563 context_->set_output(index, output);
564 } else {
565 Tensor* output = nullptr;
566 TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
567 TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
568 }
569 XlaExpression::AssignExpressionToTensor(expression,
570 context_->mutable_output(index));
571 return Status::OK();
572 }();
573 if (!status.ok()) {
574 SetStatus(status);
575 }
576 }
577
output_xla_type(int index)578 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
579 xla::PrimitiveType type;
580 Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
581 if (!status.ok()) {
582 SetStatus(status);
583 return xla::PRIMITIVE_TYPE_INVALID;
584 }
585 return type;
586 }
587
SetOutput(int index,const xla::XlaOp & handle)588 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
589 SetOutputExpression(
590 index,
591 XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
592 }
593
SetConstantOutput(int index,const Tensor & constant)594 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
595 SetOutputExpression(index, XlaExpression::Constant(constant));
596 }
597
SetTensorListOutput(int index,const xla::XlaOp & handle)598 void XlaOpKernelContext::SetTensorListOutput(int index,
599 const xla::XlaOp& handle) {
600 SetOutputExpression(index, XlaExpression::TensorList(handle));
601 }
602
SetResourceOutput(int index,XlaResource * resource)603 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
604 SetOutputExpression(index, XlaExpression::Resource(resource));
605 }
606
GetResourceInput(int index,XlaResource ** resource)607 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
608 const XlaExpression* expression =
609 XlaExpression::CastExpressionFromTensor(context_->input(index));
610 TF_RET_CHECK(expression->resource() != nullptr);
611 *resource = expression->resource();
612 return Status::OK();
613 }
614
615 namespace {
616
AssignVariableTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,xla::XlaOp handle,xla::XlaBuilder * builder)617 Status AssignVariableTensor(const Tensor& tensor, DataType type,
618 const XlaOpKernelContext* ctx, xla::XlaOp handle,
619 xla::XlaBuilder* builder) {
620 const XlaExpression* expression =
621 XlaExpression::CastExpressionFromTensor(tensor);
622 XlaResource* variable = expression->resource();
623 TF_RET_CHECK(variable != nullptr);
624 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
625
626 auto shape_or_status = builder->GetShape(handle);
627 if (!shape_or_status.ok()) {
628 return shape_or_status.status();
629 }
630 TensorShape shape;
631 TF_RETURN_IF_ERROR(
632 XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
633
634 TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
635
636 TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
637 ctx->compiler()->options().shape_representation_fn(
638 shape, type,
639 /*use_fast_memory=*/false));
640 xla::Shape xla_shape;
641 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
642 if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
643 handle = xla::Reshape(handle,
644 xla::AsInt64Slice(representation_shape.dimensions()));
645 }
646 variable->SetRepresentationShape(representation_shape);
647 return variable->SetValue(handle);
648 }
649
650 } // namespace
651
AssignVariable(int input_index,DataType type,xla::XlaOp handle)652 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
653 xla::XlaOp handle) {
654 TF_RET_CHECK(handle.valid());
655 return AssignVariableTensor(context_->input(input_index), type, this, handle,
656 builder());
657 }
658
AssignVariable(absl::string_view name,DataType type,xla::XlaOp handle)659 Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
660 xla::XlaOp handle) {
661 TF_RET_CHECK(handle.valid());
662 return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
663 builder());
664 }
665
// Appends the op's definition stack trace to INVALID_ARGUMENT statuses; any
// other status is returned unchanged.
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() != error::INVALID_ARGUMENT) {
    return s;
  }
  return Status{s.code(),
                absl::StrCat(s.error_message(), "\n", ctx->StackTrace())};
}
674
CtxFailure(const Status & s)675 void XlaOpKernelContext::CtxFailure(const Status& s) {
676 context_->CtxFailure(GetStatusWithStackTrace(s, this));
677 }
CtxFailureWithWarning(const Status & s)678 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
679 context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
680 }
681
CtxFailure(const char * file,int line,const Status & s)682 void XlaOpKernelContext::CtxFailure(const char* file, int line,
683 const Status& s) {
684 context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
685 }
CtxFailureWithWarning(const char * file,int line,const Status & s)686 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
687 const Status& s) {
688 context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
689 }
690
GetOrCreateMax(const DataType type)691 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
692 const DataType type) {
693 return xla_context()->GetOrCreateMax(type);
694 }
695
GetOrCreateMin(const DataType type)696 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
697 const DataType type) {
698 return xla_context()->GetOrCreateMin(type);
699 }
700
GetOrCreateAdd(const DataType type)701 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
702 const DataType type) {
703 return xla_context()->GetOrCreateAdd(type);
704 }
705
GetOrCreateMul(const DataType type)706 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
707 const DataType type) {
708 return xla_context()->GetOrCreateMul(type);
709 }
710
GetInputTensorByName(absl::string_view name)711 const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
712 const Tensor* tensor;
713 CHECK(context_->input(name, &tensor).ok());
714 return *tensor;
715 }
716
XlaOpKernel(OpKernelConstruction * context)717 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
718
Compute(OpKernelContext * context)719 void XlaOpKernel::Compute(OpKernelContext* context) {
720 XlaOpKernelContext xla_context(context);
721 Compile(&xla_context);
722 }
723
StackTrace() const724 std::string XlaOpKernelContext::StackTrace() const {
725 if (const AbstractStackTrace* stack_trace =
726 xla_context()->StackTraceForNodeName(op_kernel().name())) {
727 AbstractStackTrace::TracePrintingOptions opts;
728 opts.show_line_contents = true;
729 opts.filter_common_prefix = true;
730 opts.drop_internal_frames = true;
731 return absl::StrCat("\nStack trace for op definition: \n",
732 stack_trace->ToString(opts), "\n");
733 } else {
734 return "";
735 }
736 }
737
738 } // namespace tensorflow
739