1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17
#include <algorithm>
#include <numeric>
19
20 #include "tensorflow/compiler/tf2xla/literal_util.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
24 #include "tensorflow/compiler/tf2xla/xla_context.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/core/common_runtime/dma_helper.h"
29
30 namespace tensorflow {
31
XlaOpKernelContext(OpKernelContext * context)32 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
33 : context_(context) {}
34
ValidateInputsAreSameShape(OpKernel * op)35 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
36 return context_->ValidateInputsAreSameShape(op);
37 }
38
xla_context() const39 XlaContext* XlaOpKernelContext::xla_context() const {
40 return &XlaContext::Get(context_);
41 }
42
builder() const43 xla::XlaBuilder* XlaOpKernelContext::builder() const {
44 return xla_context()->builder();
45 }
46
compiler() const47 XlaCompiler* XlaOpKernelContext::compiler() const {
48 return xla_context()->compiler();
49 }
50
51 // Retrieves an XlaExpression that was allocated by a previous Op.
CastExpressionFromTensor(const Tensor & tensor)52 static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
53 const XlaExpression* expression =
54 reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
55 CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
56 << expression->HumanString();
57 return expression;
58 }
59
60 // Assigns an XlaExpression to a tensor on an XLA compilation device.
AssignExpressionToTensor(Tensor * tensor,const XlaExpression & value)61 static void AssignExpressionToTensor(Tensor* tensor,
62 const XlaExpression& value) {
63 const XlaExpression* expression =
64 reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
65 CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
66 << expression->HumanString();
67 *const_cast<XlaExpression*>(expression) = value;
68 }
69
InputExpression(int index)70 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
71 return *CastExpressionFromTensor(context_->input(index));
72 }
73
InputExpression(absl::string_view name)74 const XlaExpression& XlaOpKernelContext::InputExpression(
75 absl::string_view name) {
76 return *CastExpressionFromTensor(GetInputTensorByName(name));
77 }
78
Input(int index)79 xla::XlaOp XlaOpKernelContext::Input(int index) {
80 return InputExpression(index).AsXlaOp(builder());
81 }
82
Input(absl::string_view name)83 xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
84 return InputExpression(name).AsXlaOp(builder());
85 }
86
InputShape(int index)87 TensorShape XlaOpKernelContext::InputShape(int index) {
88 return context_->input(index).shape();
89 }
90
InputShape(absl::string_view name)91 TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
92 return GetInputTensorByName(name).shape();
93 }
94
input_type(int index) const95 DataType XlaOpKernelContext::input_type(int index) const {
96 return context_->input_dtype(index);
97 }
98
InputType(absl::string_view name)99 DataType XlaOpKernelContext::InputType(absl::string_view name) {
100 return GetInputTensorByName(name).dtype();
101 }
102
input_xla_type(int index)103 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
104 xla::PrimitiveType type;
105 Status status = DataTypeToPrimitiveType(input_type(index), &type);
106 if (!status.ok()) {
107 SetStatus(status);
108 return xla::PRIMITIVE_TYPE_INVALID;
109 }
110 return type;
111 }
112
ConstantInput(int index,xla::Literal * constant_literal)113 Status XlaOpKernelContext::ConstantInput(int index,
114 xla::Literal* constant_literal) {
115 return ConstantInputReshaped(
116 index, context_->input(index).shape().dim_sizes(), constant_literal);
117 }
118
InputIndex(XlaOpKernelContext * context,absl::string_view name)119 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
120 absl::string_view name) {
121 int start, stop;
122 TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
123 if (stop != start + 1) {
124 return errors::InvalidArgument("OpKernel used list-valued input name '",
125 name,
126 "' when single-valued input was "
127 "expected");
128 }
129 return start;
130 }
131
ConstantInput(absl::string_view name,xla::Literal * constant_literal)132 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
133 xla::Literal* constant_literal) {
134 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
135 return ConstantInput(index, constant_literal);
136 }
137
ConstantInputReshaped(int index,absl::Span<const int64> new_dims,xla::Literal * constant_literal)138 Status XlaOpKernelContext::ConstantInputReshaped(
139 int index, absl::Span<const int64> new_dims,
140 xla::Literal* constant_literal) {
141 XlaExpression e = InputExpression(index);
142 xla::StatusOr<absl::optional<Tensor>> constant_or_status =
143 e.ResolveConstant(compiler()->client());
144 if (!constant_or_status.ok()) {
145 Status status = constant_or_status.status();
146 errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
147 context_->op_kernel().type_string(),
148 " operator as a compile-time constant.");
149 return status;
150 }
151 absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
152 if (!constant.has_value()) {
153 return errors::InvalidArgument(
154 "Input ", index, " to ", context_->op_kernel().type_string(),
155 " operator must be a compile-time constant.\n"
156 "\n"
157 "XLA compilation requires that operator arguments that represent "
158 "shapes or dimensions be evaluated to concrete values at compile time. "
159 "This error means that a shape or dimension argument could not be "
160 "evaluated at compile time, usually because the value of the argument "
161 "depends on a parameter to the computation, on a variable, or on a "
162 "stateful operation such as a random number generator.");
163 }
164
165 Tensor temp(constant->dtype());
166 if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
167 return errors::InvalidArgument(
168 context_->op_kernel().name(), " input ", index, " has shape ",
169 constant->shape().DebugString(),
170 " but was asked to be reshaped to incompatible shape ",
171 TensorShape(new_dims).DebugString());
172 }
173
174 TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
175 return Status::OK();
176 }
177
178 // Converts an int32 or int64 scalar literal to an int64.
LiteralToInt64Scalar(const xla::LiteralSlice & literal,int64 * out)179 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
180 int64* out) {
181 if (literal.shape().rank() != 0) {
182 return errors::InvalidArgument("value is not a scalar");
183 }
184 if (literal.shape().element_type() == xla::S32) {
185 *out = literal.Get<int32>({});
186 } else if (literal.shape().element_type() == xla::S64) {
187 *out = literal.Get<int64>({});
188 } else {
189 return errors::InvalidArgument("value must be either int32 or int64");
190 }
191 return Status::OK();
192 }
193
194 // Converts an float32 or float64 scalar literal to a float64.
LiteralToFloat64Scalar(const xla::LiteralSlice & literal,double * out)195 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
196 double* out) {
197 if (literal.shape().rank() != 0) {
198 return errors::InvalidArgument("value is not a scalar");
199 }
200 if (literal.shape().element_type() == xla::F32) {
201 *out = literal.Get<float>({});
202 } else if (literal.shape().element_type() == xla::F64) {
203 *out = literal.Get<double>({});
204 } else {
205 return errors::InvalidArgument("value must be either float32 or float64");
206 }
207 return Status::OK();
208 }
209
ConstantInputAsIntScalar(int index,int64 * out)210 Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
211 xla::Literal literal;
212 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
213 return LiteralToInt64Scalar(literal, out);
214 }
215
ConstantInputAsIntScalar(absl::string_view name,int64 * out)216 Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
217 int64* out) {
218 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
219 return ConstantInputAsIntScalar(index, out);
220 }
221
ConstantInputAsFloatScalar(int index,double * out)222 Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
223 xla::Literal literal;
224 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
225 return LiteralToFloat64Scalar(literal, out);
226 }
227
228 // Converts an int32 or int64 1D literal to an int64 vector.
LiteralToInt64Vector(const xla::LiteralSlice & literal,std::vector<int64> * out)229 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
230 std::vector<int64>* out) {
231 if (literal.shape().rank() != 1) {
232 return errors::InvalidArgument("value is not 1D, rank: ",
233 literal.shape().rank());
234 }
235 int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
236 if (literal.shape().element_type() == xla::S32) {
237 for (int64 i = 0; i < size; ++i) {
238 out->push_back(literal.Get<int32>({i}));
239 }
240 } else if (literal.shape().element_type() == xla::S64) {
241 for (int64 i = 0; i < size; ++i) {
242 out->push_back(literal.Get<int64>({i}));
243 }
244 } else {
245 return errors::InvalidArgument("value must be either int32 or int64");
246 }
247 return Status::OK();
248 }
249
ConstantInputAsIntVector(int index,std::vector<int64> * out)250 Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
251 std::vector<int64>* out) {
252 xla::Literal literal;
253 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
254 return LiteralToInt64Vector(literal, out);
255 }
256
ConstantInputAsIntVector(absl::string_view name,std::vector<int64> * out)257 Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
258 std::vector<int64>* out) {
259 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
260 return ConstantInputAsIntVector(index, out);
261 }
262
ConstantInputReshapedToIntVector(int index,std::vector<int64> * out)263 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
264 int index, std::vector<int64>* out) {
265 xla::Literal literal;
266 TF_RETURN_IF_ERROR(ConstantInputReshaped(
267 index, {InputShape(index).num_elements()}, &literal));
268 return LiteralToInt64Vector(literal, out);
269 }
270
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64> * out)271 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
272 absl::string_view name, std::vector<int64>* out) {
273 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
274 xla::Literal literal;
275 TF_RETURN_IF_ERROR(ConstantInputReshaped(
276 index, {InputShape(index).num_elements()}, &literal));
277 return LiteralToInt64Vector(literal, out);
278 }
279
ConstantInputAsInt64Literal(int index,xla::Literal * out)280 Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
281 xla::Literal* out) {
282 xla::Literal literal;
283 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
284 switch (literal.shape().element_type()) {
285 case xla::S32: {
286 *out = xla::Literal(
287 xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
288 auto src_data = literal.data<int32>();
289 for (int64 i = 0; i < src_data.size(); ++i) {
290 out->data<int64>()[i] = src_data[i];
291 }
292 return Status::OK();
293 }
294 case xla::S64:
295 *out = std::move(literal);
296 return Status::OK();
297
298 default:
299 return errors::InvalidArgument(
300 "Invalid argument to ConstantInputAsInt64Literal: ",
301 xla::ShapeUtil::HumanString(literal.shape()));
302 }
303 }
304
ConstantInputAsInt64Literal(absl::string_view name,xla::Literal * out)305 Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
306 xla::Literal* out) {
307 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
308 return ConstantInputAsInt64Literal(index, out);
309 }
310
311 // TODO(phawkins): validate that the dimensions form a valid shape, fail
312 // gracefully if they do not.
ConstantInputAsShape(int index,TensorShape * shape)313 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
314 xla::Literal literal;
315 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
316 std::vector<int64> dims;
317 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
318 *shape = TensorShape(dims);
319 return Status::OK();
320 }
321
ConstantInputAsPartialShape(int index,PartialTensorShape * shape)322 Status XlaOpKernelContext::ConstantInputAsPartialShape(
323 int index, PartialTensorShape* shape) {
324 xla::Literal literal;
325 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
326 // If `literal` is a scalar it's value must be -1.
327 if (literal.shape().rank() == 0) {
328 int64 shape_val;
329 TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
330 if (shape_val != -1) {
331 return errors::InvalidArgument(
332 "Cannot convert value to PartialTensorShape: ", shape_val);
333 }
334 *shape = PartialTensorShape(); // Shape with unknown rank.
335 return Status::OK();
336 }
337 std::vector<int64> dims;
338 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
339 *shape = PartialTensorShape(dims);
340 return Status::OK();
341 }
342
InputList(absl::string_view name,std::vector<xla::XlaOp> * handles,std::vector<TensorShape> * shapes)343 Status XlaOpKernelContext::InputList(absl::string_view name,
344 std::vector<xla::XlaOp>* handles,
345 std::vector<TensorShape>* shapes) {
346 OpInputList inputs;
347 TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
348 handles->clear();
349 shapes->clear();
350 for (const Tensor& input : inputs) {
351 handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
352 shapes->push_back(input.shape());
353 }
354 return Status::OK();
355 }
356
ConstantInputList(absl::string_view name,std::vector<xla::Literal> * outputs)357 Status XlaOpKernelContext::ConstantInputList(
358 absl::string_view name, std::vector<xla::Literal>* outputs) {
359 int start, stop;
360 TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
361 outputs->resize(stop - start);
362 for (int i = start; i < stop; ++i) {
363 TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
364 }
365 return Status::OK();
366 }
367
368 namespace {
369
ReadVariableInputTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,TensorShape * shape,xla::XlaOp * value)370 Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
371 const XlaOpKernelContext* ctx,
372 TensorShape* shape, xla::XlaOp* value) {
373 const XlaExpression* expression = CastExpressionFromTensor(tensor);
374 XlaResource* variable = expression->resource();
375 TF_RET_CHECK(variable != nullptr);
376 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
377 if (!variable->initialized()) {
378 return errors::FailedPrecondition("Read of uninitialized variable ",
379 variable->name());
380 }
381 if (variable->type() != type) {
382 return errors::InvalidArgument(
383 "Type mismatch for read of variable ", variable->name(), ". Expected ",
384 DataTypeString(type), "; got ", DataTypeString(variable->type()));
385 }
386 if (shape) {
387 *shape = variable->shape();
388 }
389
390 TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
391 ctx->compiler()->options().shape_representation_fn(
392 variable->shape(), variable->type()));
393 xla::Shape xla_shape;
394 TF_RETURN_IF_ERROR(
395 TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
396 if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
397 *value = variable->value();
398 } else {
399 *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
400 }
401 return Status::OK();
402 }
403
404 } // namespace
405
ReadVariableInput(int index,DataType type,TensorShape * shape,xla::XlaOp * value)406 Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
407 TensorShape* shape,
408 xla::XlaOp* value) {
409 return ReadVariableInputTensor(context_->input(index), type, this, shape,
410 value);
411 }
412
ReadVariableInput(absl::string_view name,DataType type,TensorShape * shape,xla::XlaOp * value)413 Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
414 DataType type, TensorShape* shape,
415 xla::XlaOp* value) {
416 return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
417 value);
418 }
419
GetVariableTypeAndShape(int index,DataType * type,TensorShape * shape) const420 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
421 TensorShape* shape) const {
422 const Tensor& tensor = context_->input(index);
423 const XlaExpression* expression = CastExpressionFromTensor(tensor);
424 XlaResource* variable = expression->resource();
425 TF_RET_CHECK(variable != nullptr);
426 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
427 if (!variable->initialized()) {
428 return errors::InvalidArgument("Read of uninitialized variable ",
429 variable->name());
430 }
431 *type = variable->type();
432 *shape = variable->shape();
433 return Status::OK();
434 }
435
SetOutputExpression(int index,const XlaExpression & expression)436 void XlaOpKernelContext::SetOutputExpression(int index,
437 const XlaExpression& expression) {
438 Status status = [&] {
439 // The step's default allocator is the dummy XlaCompilationAllocator which
440 // simply allocates a metadata buffer to hold the expression to which it
441 // corresponds.
442 Tensor* output = nullptr;
443 // Provides a special behavior for DT_VARIANT: a variant is treated as
444 // DT_UINT8 scalar as the type to allow mapping for variant to more generic
445 // types.
446 if (expression.dtype() == DT_VARIANT) {
447 // tensor_data() is not supported for variant Tensor (i.e.,
448 // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
449 // XlaExpression inside the Tensor's tensor_data() does not work for
450 // variant. Instead construct a uint8 tensor and store the expression in
451 // its value.
452 // TODO(jpienaar): This should be refactored to stop masquerading
453 // XlaExpressions as Tensors.
454 output = new Tensor();
455 TensorShape tensor_shape;
456 TF_RETURN_IF_ERROR(
457 context_->allocate_temp(DT_UINT8, tensor_shape, output));
458 context_->set_output(index, *output);
459 } else {
460 TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
461 TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
462 }
463 AssignExpressionToTensor(output, expression);
464 return Status::OK();
465 }();
466 if (!status.ok()) {
467 SetStatus(status);
468 }
469 }
470
output_xla_type(int index)471 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
472 xla::PrimitiveType type;
473 Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
474 if (!status.ok()) {
475 SetStatus(status);
476 return xla::PRIMITIVE_TYPE_INVALID;
477 }
478 return type;
479 }
480
SetOutput(int index,const xla::XlaOp & handle)481 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
482 SetOutputExpression(
483 index,
484 XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
485 }
486
SetConstantOutput(int index,const Tensor & constant)487 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
488 SetOutputExpression(index, XlaExpression::Constant(constant));
489 }
490
SetTensorListOutput(int index,const xla::XlaOp & handle)491 void XlaOpKernelContext::SetTensorListOutput(int index,
492 const xla::XlaOp& handle) {
493 SetOutputExpression(index, XlaExpression::TensorList(handle));
494 }
495
SetResourceOutput(int index,XlaResource * resource)496 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
497 SetOutputExpression(index, XlaExpression::Resource(resource));
498 }
499
GetResourceInput(int index,XlaResource ** resource)500 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
501 const XlaExpression* expression =
502 CastExpressionFromTensor(context_->input(index));
503 TF_RET_CHECK(expression->resource() != nullptr);
504 *resource = expression->resource();
505 return Status::OK();
506 }
507
508 namespace {
509
AssignVariableTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,xla::XlaOp handle,xla::XlaBuilder * builder)510 Status AssignVariableTensor(const Tensor& tensor, DataType type,
511 const XlaOpKernelContext* ctx, xla::XlaOp handle,
512 xla::XlaBuilder* builder) {
513 const XlaExpression* expression = CastExpressionFromTensor(tensor);
514 XlaResource* variable = expression->resource();
515 TF_RET_CHECK(variable != nullptr);
516 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
517
518 auto shape_or_status = builder->GetShape(handle);
519 if (!shape_or_status.ok()) {
520 return shape_or_status.status();
521 }
522 TensorShape shape;
523 TF_RETURN_IF_ERROR(
524 XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
525
526 TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
527
528 TF_ASSIGN_OR_RETURN(
529 xla::Shape representation_shape,
530 ctx->compiler()->options().shape_representation_fn(shape, type));
531 xla::Shape xla_shape;
532 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
533 if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
534 handle = xla::Reshape(handle,
535 xla::AsInt64Slice(representation_shape.dimensions()));
536 }
537 variable->SetRepresentationShape(representation_shape);
538 return variable->SetValue(handle);
539 }
540
541 } // namespace
542
AssignVariable(int input_index,DataType type,xla::XlaOp handle)543 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
544 xla::XlaOp handle) {
545 TF_RET_CHECK(handle.valid());
546 return AssignVariableTensor(context_->input(input_index), type, this, handle,
547 builder());
548 }
549
AssignVariable(absl::string_view name,DataType type,xla::XlaOp handle)550 Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
551 xla::XlaOp handle) {
552 TF_RET_CHECK(handle.valid());
553 return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
554 builder());
555 }
556
CtxFailure(const Status & s)557 void XlaOpKernelContext::CtxFailure(const Status& s) {
558 context_->CtxFailure(s);
559 }
CtxFailureWithWarning(const Status & s)560 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
561 context_->CtxFailureWithWarning(s);
562 }
CtxFailure(const char * file,int line,const Status & s)563 void XlaOpKernelContext::CtxFailure(const char* file, int line,
564 const Status& s) {
565 context_->CtxFailure(file, line, s);
566 }
CtxFailureWithWarning(const char * file,int line,const Status & s)567 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
568 const Status& s) {
569 context_->CtxFailureWithWarning(file, line, s);
570 }
571
GetOrCreateMax(const DataType type)572 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
573 const DataType type) {
574 return xla_context()->GetOrCreateMax(type);
575 }
576
GetOrCreateMin(const DataType type)577 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
578 const DataType type) {
579 return xla_context()->GetOrCreateMin(type);
580 }
581
GetOrCreateAdd(const DataType type)582 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
583 const DataType type) {
584 return xla_context()->GetOrCreateAdd(type);
585 }
586
GetOrCreateMul(const DataType type)587 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
588 const DataType type) {
589 return xla_context()->GetOrCreateMul(type);
590 }
591
GetInputTensorByName(absl::string_view name)592 const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
593 const Tensor* tensor;
594 CHECK(context_->input(name, &tensor).ok());
595 return *tensor;
596 }
597
XlaOpKernel(OpKernelConstruction * context)598 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
599
Compute(OpKernelContext * context)600 void XlaOpKernel::Compute(OpKernelContext* context) {
601 XlaOpKernelContext xla_context(context);
602 Compile(&xla_context);
603 }
604
605 } // namespace tensorflow
606