1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
17
18 #include <numeric>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/tf2xla/literal_util.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
25 #include "tensorflow/compiler/tf2xla/xla_context.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/xla/client/value_inference.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/core/common_runtime/dma_helper.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/util/overflow.h"
34
35 namespace tensorflow {
36
XlaOpKernelContext(OpKernelContext * context)37 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
38 : context_(context),
39 dynamic_dimension_is_minus_one_(false),
40 value_inference_(xla_context()->builder()) {}
41
ValidateInputsAreSameShape(OpKernel * op)42 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
43 return context_->ValidateInputsAreSameShape(op);
44 }
45
xla_context() const46 XlaContext* XlaOpKernelContext::xla_context() const {
47 return &XlaContext::Get(context_);
48 }
49
builder() const50 xla::XlaBuilder* XlaOpKernelContext::builder() const {
51 return xla_context()->builder();
52 }
53
value_inference()54 xla::ValueInference& XlaOpKernelContext::value_inference() {
55 return value_inference_;
56 }
57
compiler() const58 XlaCompiler* XlaOpKernelContext::compiler() const {
59 return xla_context()->compiler();
60 }
61
InputExpression(int index)62 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
63 return *XlaExpression::CastExpressionFromTensor(context_->input(index));
64 }
65
InputExpression(absl::string_view name)66 const XlaExpression& XlaOpKernelContext::InputExpression(
67 absl::string_view name) {
68 return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
69 }
70
Input(int index)71 xla::XlaOp XlaOpKernelContext::Input(int index) {
72 return InputExpression(index).AsXlaOp(builder());
73 }
74
Input(absl::string_view name)75 xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
76 return InputExpression(name).AsXlaOp(builder());
77 }
78
InputShape(int index)79 TensorShape XlaOpKernelContext::InputShape(int index) {
80 return context_->input(index).shape();
81 }
82
InputShape(absl::string_view name)83 TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
84 return GetInputTensorByName(name).shape();
85 }
86
InputXlaShape(int index)87 StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
88 return builder()->GetShape(Input(index));
89 }
90
InputXlaShape(absl::string_view name)91 StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(absl::string_view name) {
92 return builder()->GetShape(Input(name));
93 }
94
input_type(int index) const95 DataType XlaOpKernelContext::input_type(int index) const {
96 DataType type = context_->input_dtype(index);
97 if (type == DT_UINT8) {
98 // Masqueraded XlaExpression could have different type. See
99 // XlaOpKernelContext::SetOutputExpression for details.
100 auto expression =
101 XlaExpression::CastExpressionFromTensor(context_->input(index));
102 type = expression->dtype();
103 }
104 return type;
105 }
106
InputType(absl::string_view name)107 DataType XlaOpKernelContext::InputType(absl::string_view name) {
108 const Tensor& tensor = GetInputTensorByName(name);
109 DataType type = tensor.dtype();
110 if (type == DT_UINT8) {
111 // Masqueraded XlaExpression could have different type. See
112 // XlaOpKernelContext::SetOutputExpression for details.
113 auto expression = XlaExpression::CastExpressionFromTensor(tensor);
114 type = expression->dtype();
115 }
116 return type;
117 }
118
input_xla_type(int index)119 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
120 xla::PrimitiveType type;
121 Status status = DataTypeToPrimitiveType(input_type(index), &type);
122 if (!status.ok()) {
123 SetStatus(status);
124 return xla::PRIMITIVE_TYPE_INVALID;
125 }
126 return type;
127 }
128
InputXlaType(absl::string_view name)129 xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
130 xla::PrimitiveType type;
131 Status status = DataTypeToPrimitiveType(InputType(name), &type);
132 if (!status.ok()) {
133 SetStatus(status);
134 return xla::PRIMITIVE_TYPE_INVALID;
135 }
136 return type;
137 }
138
ConstantInput(int index,xla::Literal * constant_literal,xla::ValueInferenceMode mode)139 Status XlaOpKernelContext::ConstantInput(int index,
140 xla::Literal* constant_literal,
141 xla::ValueInferenceMode mode) {
142 if (this->InputXlaShape(index)->is_dynamic()) {
143 return errors::InvalidArgument(
144 "Reading input as constant from a dynamic tensor is not yet supported. "
145 "Xla shape: ",
146 this->InputXlaShape(index)->ToString());
147 }
148 return ConstantInputReshaped(index,
149 context_->input(index).shape().dim_sizes(),
150 constant_literal, mode);
151 }
152
InputIndex(XlaOpKernelContext * context,absl::string_view name)153 static StatusOr<int> InputIndex(XlaOpKernelContext* context,
154 absl::string_view name) {
155 int start, stop;
156 TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
157 if (stop != start + 1) {
158 return errors::InvalidArgument("OpKernel used list-valued input name '",
159 name,
160 "' when single-valued input was "
161 "expected");
162 }
163 return start;
164 }
165
ResolveInputDynamism(int index,xla::Literal * dynamism_literal)166 Status XlaOpKernelContext::ResolveInputDynamism(
167 int index, xla::Literal* dynamism_literal) {
168 return ResolveInputDynamismReshaped(
169 index, context_->input(index).shape().dim_sizes(), dynamism_literal);
170 }
171
ResolveInputDynamism(absl::string_view name,xla::Literal * dynamism_literal)172 Status XlaOpKernelContext::ResolveInputDynamism(
173 absl::string_view name, xla::Literal* dynamism_literal) {
174 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
175 return ResolveInputDynamism(index, dynamism_literal);
176 }
177
ConstantInput(absl::string_view name,xla::Literal * constant_literal,xla::ValueInferenceMode mode)178 Status XlaOpKernelContext::ConstantInput(absl::string_view name,
179 xla::Literal* constant_literal,
180 xla::ValueInferenceMode mode) {
181 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
182 return ConstantInput(index, constant_literal, mode);
183 }
184
ConstantInputReshaped(int index,absl::Span<const int64_t> new_dims,xla::Literal * constant_literal,xla::ValueInferenceMode mode)185 Status XlaOpKernelContext::ConstantInputReshaped(
186 int index, absl::Span<const int64_t> new_dims,
187 xla::Literal* constant_literal, xla::ValueInferenceMode mode) {
188 TF_ASSIGN_OR_RETURN(Tensor constant, ConstantInputTensor(index, mode));
189 Tensor temp(constant.dtype());
190 if (!temp.CopyFrom(constant, TensorShape(new_dims))) {
191 return errors::InvalidArgument(
192 context_->op_kernel().name(), " input ", index, " has shape ",
193 constant.shape().DebugString(),
194 " but was asked to be reshaped to incompatible shape ",
195 TensorShape(new_dims).DebugString());
196 }
197
198 TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
199 return OkStatus();
200 }
201
202 // Converts an int32 or int64 scalar literal to an int64.
LiteralToInt64Scalar(const xla::LiteralSlice & literal,int64_t * out)203 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
204 int64_t* out) {
205 if (literal.shape().rank() != 0) {
206 return errors::InvalidArgument("value is not a scalar");
207 }
208 if (literal.shape().element_type() == xla::S32) {
209 *out = literal.Get<int32>({});
210 } else if (literal.shape().element_type() == xla::S64) {
211 *out = literal.Get<int64_t>({});
212 } else {
213 return errors::InvalidArgument("value must be either int32 or int64");
214 }
215 return OkStatus();
216 }
217
218 // Converts an float32 or float64 scalar literal to a float64.
LiteralToFloat64Scalar(const xla::LiteralSlice & literal,double * out)219 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
220 double* out) {
221 if (literal.shape().rank() != 0) {
222 return errors::InvalidArgument("value is not a scalar");
223 }
224 if (literal.shape().element_type() == xla::F32) {
225 *out = literal.Get<float>({});
226 } else if (literal.shape().element_type() == xla::F64) {
227 *out = literal.Get<double>({});
228 } else {
229 return errors::InvalidArgument("value must be either float32 or float64");
230 }
231 return OkStatus();
232 }
233
ConstantInputAsIntScalar(int index,int64_t * out,xla::ValueInferenceMode mode)234 Status XlaOpKernelContext::ConstantInputAsIntScalar(
235 int index, int64_t* out, xla::ValueInferenceMode mode) {
236 xla::Literal literal;
237 TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
238 return LiteralToInt64Scalar(literal, out);
239 }
240
ConstantInputAsIntScalar(absl::string_view name,int64_t * out,xla::ValueInferenceMode mode)241 Status XlaOpKernelContext::ConstantInputAsIntScalar(
242 absl::string_view name, int64_t* out, xla::ValueInferenceMode mode) {
243 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
244 return ConstantInputAsIntScalar(index, out, mode);
245 }
246
ConstantInputAsFloatScalar(int index,double * out,xla::ValueInferenceMode mode)247 Status XlaOpKernelContext::ConstantInputAsFloatScalar(
248 int index, double* out, xla::ValueInferenceMode mode) {
249 xla::Literal literal;
250 TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
251 return LiteralToFloat64Scalar(literal, out);
252 }
253
LiteralToPredVector(const xla::LiteralSlice & literal,std::vector<bool> * out)254 static Status LiteralToPredVector(const xla::LiteralSlice& literal,
255 std::vector<bool>* out) {
256 if (literal.shape().rank() != 1) {
257 return errors::InvalidArgument("output_shape must be rank 1, got shape ",
258 literal.shape().DebugString());
259 }
260 int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
261 if (literal.shape().element_type() != xla::PRED) {
262 return errors::InvalidArgument("value is not PRED");
263 }
264 for (int64_t i = 0; i < size; ++i) {
265 out->push_back(literal.Get<bool>({i}));
266 }
267 return OkStatus();
268 }
269
ResolveInputDynamismIntoPred(int index,bool * out)270 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
271 xla::Literal literal;
272 XlaExpression e = InputExpression(index);
273 auto* client = compiler() ? compiler()->client() : nullptr;
274 StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
275 if (!dynamism_or_status.ok()) {
276 // When failed to resolve dynamism, conservatively consider the value
277 // dynamic. This could happen if the input depends on some ops like
278 // custom-call that is not supported generally for dynamism computation.
279 //
280 // TODO(b/176993339): Support resolving dynamism across computations so
281 // resolving dynamism will not fail in those cases.
282 *out = true;
283 return OkStatus();
284 }
285 Tensor dynamism = dynamism_or_status.ValueOrDie();
286
287 Tensor temp(dynamism.dtype());
288 TensorShape tensor_shape({});
289 if (!temp.CopyFrom(dynamism, tensor_shape)) {
290 return errors::InvalidArgument(
291 context_->op_kernel().name(), " input ", index, " has shape ",
292 dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
293 }
294
295 TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp));
296 *out = literal.Get<bool>({});
297 return OkStatus();
298 }
299
ResolveInputDynamismIntoPredVector(absl::string_view name,std::vector<bool> * out)300 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
301 absl::string_view name, std::vector<bool>* out) {
302 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
303 return ResolveInputDynamismIntoPredVector(index, out);
304 }
305
ResolveInputDynamismIntoPred(absl::string_view name,bool * out)306 Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
307 bool* out) {
308 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
309 return ResolveInputDynamismIntoPred(index, out);
310 }
311
ResolveInputDynamismReshaped(int index,absl::Span<const int64_t> new_dims,xla::Literal * dynamism_literal)312 Status XlaOpKernelContext::ResolveInputDynamismReshaped(
313 int index, absl::Span<const int64_t> new_dims,
314 xla::Literal* dynamism_literal) {
315 XlaExpression e = InputExpression(index);
316 auto* client = compiler() ? compiler()->client() : nullptr;
317 StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism(client);
318 if (!dynamism_or_status.ok()) {
319 xla::Literal true_literal = xla::LiteralUtil::CreateR0<bool>(true);
320 // When failed to resolve dynamism, conservatively consider the value
321 // dynamic. This could happen if the input depends on some ops like
322 // custom-call that is not supported generally for dynamism computation.
323 *dynamism_literal =
324 true_literal
325 .Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims), {})
326 .ValueOrDie();
327
328 return OkStatus();
329 }
330 Tensor dynamism = dynamism_or_status.ValueOrDie();
331
332 Tensor temp(dynamism.dtype());
333 if (!temp.CopyFrom(dynamism, TensorShape(new_dims))) {
334 return errors::InvalidArgument(
335 context_->op_kernel().name(), " input ", index, " has shape ",
336 dynamism.shape().DebugString(),
337 " but was asked to be reshaped to incompatible shape ",
338 TensorShape(new_dims).DebugString());
339 }
340
341 TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp));
342 return OkStatus();
343 }
344
ResolveInputDynamismIntoPredVector(int index,std::vector<bool> * out)345 Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
346 int index, std::vector<bool>* out) {
347 xla::Literal literal;
348 TF_RETURN_IF_ERROR(ResolveInputDynamismReshaped(
349 index, {InputShape(index).num_elements()}, &literal));
350
351 return LiteralToPredVector(literal, out);
352 }
353
354 // Converts an int32 or int64 1D literal to an int64 vector.
LiteralToInt64Vector(const xla::LiteralSlice & literal,std::vector<int64_t> * out)355 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
356 std::vector<int64_t>* out) {
357 if (literal.shape().rank() != 1) {
358 return errors::InvalidArgument("output_shape must be rank 1, got shape ",
359 literal.shape().DebugString());
360 }
361 int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
362 if (literal.shape().element_type() == xla::S32) {
363 for (int64_t i = 0; i < size; ++i) {
364 out->push_back(literal.Get<int32>({i}));
365 }
366 } else if (literal.shape().element_type() == xla::S64) {
367 for (int64_t i = 0; i < size; ++i) {
368 out->push_back(literal.Get<int64_t>({i}));
369 }
370 } else {
371 return errors::InvalidArgument("value must be either int32 or int64");
372 }
373 return OkStatus();
374 }
375
ConstantInputAsIntVector(int index,std::vector<int64_t> * out,xla::ValueInferenceMode mode)376 Status XlaOpKernelContext::ConstantInputAsIntVector(
377 int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
378 xla::Literal literal;
379 TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
380 return LiteralToInt64Vector(literal, out);
381 }
382
ConstantInputAsIntVector(absl::string_view name,std::vector<int64_t> * out,xla::ValueInferenceMode mode)383 Status XlaOpKernelContext::ConstantInputAsIntVector(
384 absl::string_view name, std::vector<int64_t>* out,
385 xla::ValueInferenceMode mode) {
386 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
387 return ConstantInputAsIntVector(index, out, mode);
388 }
389
ConstantInputReshapedToIntVector(int index,std::vector<int64_t> * out,xla::ValueInferenceMode mode)390 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
391 int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
392 xla::Literal literal;
393 TF_RETURN_IF_ERROR(ConstantInputReshaped(
394 index, {InputShape(index).num_elements()}, &literal, mode));
395 return LiteralToInt64Vector(literal, out);
396 }
397
ConstantInputReshapedToIntVector(absl::string_view name,std::vector<int64_t> * out,xla::ValueInferenceMode mode)398 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
399 absl::string_view name, std::vector<int64_t>* out,
400 xla::ValueInferenceMode mode) {
401 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
402 xla::Literal literal;
403 TF_RETURN_IF_ERROR(ConstantInputReshaped(
404 index, {InputShape(index).num_elements()}, &literal, mode));
405 return LiteralToInt64Vector(literal, out);
406 }
407
ConstantInputAsInt64Literal(int index,xla::Literal * out,xla::ValueInferenceMode mode)408 Status XlaOpKernelContext::ConstantInputAsInt64Literal(
409 int index, xla::Literal* out, xla::ValueInferenceMode mode) {
410 xla::Literal literal;
411 TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
412 switch (literal.shape().element_type()) {
413 case xla::S32: {
414 *out = xla::Literal(
415 xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
416 auto src_data = literal.data<int32>();
417 for (int64_t i = 0; i < src_data.size(); ++i) {
418 out->data<int64_t>()[i] = src_data[i];
419 }
420 return OkStatus();
421 }
422 case xla::S64:
423 *out = std::move(literal);
424 return OkStatus();
425
426 default:
427 return errors::InvalidArgument(
428 "Invalid argument to ConstantInputAsInt64Literal: ",
429 xla::ShapeUtil::HumanString(literal.shape()));
430 }
431 }
432
ConstantInputAsInt64Literal(absl::string_view name,xla::Literal * out,xla::ValueInferenceMode mode)433 Status XlaOpKernelContext::ConstantInputAsInt64Literal(
434 absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
435 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
436 return ConstantInputAsInt64Literal(index, out, mode);
437 }
438
439 // TODO(phawkins): validate that the dimensions form a valid shape, fail
440 // gracefully if they do not.
ConstantInputAsShape(int index,TensorShape * shape,xla::ValueInferenceMode mode)441 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
442 xla::ValueInferenceMode mode) {
443 xla::Literal literal;
444 TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
445 std::vector<int64_t> dims;
446 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
447
448 int64_t num_elements = 1;
449 for (auto i = dims.begin(); i != dims.end(); ++i) {
450 num_elements = MultiplyWithoutOverflow(num_elements, *i);
451 if (num_elements < 0)
452 return errors::InvalidArgument(
453 "The total elements specified by orig_input_shape is too large.",
454 "Encountered overflow after multiplying", *i,
455 ", result: ", num_elements);
456 }
457 *shape = TensorShape(dims);
458 return OkStatus();
459 }
460
ConstantInputAsPartialShape(int index,PartialTensorShape * shape)461 Status XlaOpKernelContext::ConstantInputAsPartialShape(
462 int index, PartialTensorShape* shape) {
463 xla::Literal literal;
464 TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
465 // If `literal` is a scalar it's value must be -1.
466 if (literal.shape().rank() == 0) {
467 int64_t shape_val;
468 TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
469 if (shape_val != -1) {
470 return errors::InvalidArgument(
471 "Cannot convert value to PartialTensorShape: ", shape_val);
472 }
473 *shape = PartialTensorShape(); // Shape with unknown rank.
474 return OkStatus();
475 }
476 std::vector<int64_t> dims;
477 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
478 *shape = PartialTensorShape(dims);
479 return OkStatus();
480 }
481
InputList(absl::string_view name,std::vector<xla::XlaOp> * handles,std::vector<TensorShape> * shapes)482 Status XlaOpKernelContext::InputList(absl::string_view name,
483 std::vector<xla::XlaOp>* handles,
484 std::vector<TensorShape>* shapes) {
485 OpInputList inputs;
486 TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
487 handles->clear();
488 shapes->clear();
489 for (const Tensor& input : inputs) {
490 handles->push_back(
491 XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
492 shapes->push_back(input.shape());
493 }
494 return OkStatus();
495 }
496
ConstantInputList(absl::string_view name,std::vector<xla::Literal> * outputs,xla::ValueInferenceMode mode)497 Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
498 std::vector<xla::Literal>* outputs,
499 xla::ValueInferenceMode mode) {
500 int start, stop;
501 TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
502 outputs->resize(stop - start);
503 for (int i = start; i < stop; ++i) {
504 TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i], mode));
505 }
506 return OkStatus();
507 }
508
ConstantInputTensor(int index,xla::ValueInferenceMode mode)509 StatusOr<Tensor> XlaOpKernelContext::ConstantInputTensor(
510 int index, xla::ValueInferenceMode mode) {
511 XlaExpression e = InputExpression(index);
512 auto* client = compiler() ? compiler()->client() : nullptr;
513 StatusOr<std::optional<Tensor>> constant_or_status =
514 e.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode);
515 if (!constant_or_status.ok()) {
516 Status status = constant_or_status.status();
517 errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
518 context_->op_kernel().type_string(),
519 " operator as a compile-time constant.");
520 return status;
521 }
522 std::optional<Tensor> constant = constant_or_status.ValueOrDie();
523 if (!constant.has_value()) {
524 return errors::InvalidArgument(
525 "Input ", index, " to node `", context_->op_kernel().name(),
526 "` with op ", context_->op_kernel().type_string(),
527 " must be a compile-time constant.\n\n"
528 "XLA compilation requires that operator arguments that represent "
529 "shapes or dimensions be evaluated to concrete values at compile time. "
530 "This error means that a shape or dimension argument could not be "
531 "evaluated at compile time, usually because the value of the argument "
532 "depends on a parameter to the computation, on a variable, or on a "
533 "stateful operation such as a random number generator.");
534 }
535 return *constant;
536 }
537
namespace {

// Reads the current value of the variable resource carried by `tensor`.
//
// `tensor` must wrap an XlaExpression whose resource is an initialized
// XlaResource::kVariable with dtype `type`. Otherwise the function fails via
// TF_RET_CHECK, FailedPrecondition (uninitialized), or InvalidArgument (dtype
// mismatch). On success, `*value` receives an XlaOp for the variable's value.
// If `shape` is non-null, `*shape` receives the variable's TensorShape.
//
// NOTE(review): unlike ConstantInputTensor, this dereferences ctx->compiler()
// without a null check; it presumably relies on a compiler always being
// present when variables are read. Confirm.
Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                               const XlaOpKernelContext* ctx,
                               TensorShape* shape, xla::XlaOp* value) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::FailedPrecondition(
        "Read variable failure ", variable->name(),
        ". It could mean the variable is uninitialized or the variable is on "
        "another device ");
  }
  if (variable->type() != type) {
    return errors::InvalidArgument(
        "Trying to read variable with wrong dtype. Expected ",
        DataTypeString(type), " got ", DataTypeString(variable->type()));
  }
  if (shape) {
    *shape = variable->shape();
  }

  // If the variable has not been overwritten and the expression carries a
  // constant value, emit that constant directly instead of reading the
  // variable's current XlaOp.
  if (!variable->IsOverwritten() && expression->constant_value()) {
    TF_ASSIGN_OR_RETURN(xla::Literal literal,
                        HostTensorToLiteral(*expression->constant_value()));
    *value = xla::ConstantLiteral(ctx->builder(), literal);
    return OkStatus();
  }
  // Ask the compiler's shape-determination functions for the representation
  // shape of this variable. AssignVariableTensor stores values reshaped to
  // that representation when it differs from the logical XLA shape, so in
  // that case the value is reshaped back to the variable's logical dimensions.
  auto shape_determination_fns =
      ctx->compiler()->options().shape_determination_fns;
  XlaLayoutPreference layout_preference =
      shape_determination_fns.layout_preference_fn(
          variable->shape(), variable->type(), std::nullopt);
  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      shape_determination_fns.shape_representation_fn(
                          variable->shape(), variable->type(),
                          /*use_fast_memory=*/false, layout_preference));
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
  if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
    *value = variable->value();
  } else {
    *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
  }
  return OkStatus();
}

}  // namespace
590
ReadVariableInput(int index,DataType type,TensorShape * shape,xla::XlaOp * value)591 Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
592 TensorShape* shape,
593 xla::XlaOp* value) {
594 return ReadVariableInputTensor(context_->input(index), type, this, shape,
595 value);
596 }
597
ReadVariableInput(absl::string_view name,DataType type,TensorShape * shape,xla::XlaOp * value)598 Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
599 DataType type, TensorShape* shape,
600 xla::XlaOp* value) {
601 return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
602 value);
603 }
604
GetVariableTypeAndShape(int index,DataType * type,TensorShape * shape) const605 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
606 TensorShape* shape) const {
607 const Tensor& tensor = context_->input(index);
608 const XlaExpression* expression =
609 XlaExpression::CastExpressionFromTensor(tensor);
610 XlaResource* variable = expression->resource();
611 TF_RET_CHECK(variable != nullptr);
612 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
613 if (!variable->initialized()) {
614 return errors::InvalidArgument(
615 "Read variable failure ", variable->name(),
616 ". It could mean the variable is uninitialized or the variable is on "
617 "another device ");
618 }
619 *type = variable->type();
620 *shape = variable->shape();
621 return OkStatus();
622 }
623
// Publishes `expression` as output `index` of the op.
//
// During compilation, outputs do not hold real data. An output tensor is
// allocated and the XlaExpression is assigned into it with
// XlaExpression::AssignExpressionToTensor. Dtypes that cannot use memcpy are
// masqueraded as DT_UINT8 tensors; input_type()/InputType() undo this by
// reading the dtype back from the expression. Failures are not returned; they
// are recorded via SetStatus().
void XlaOpKernelContext::SetOutputExpression(int index,
                                             const XlaExpression& expression) {
  Status status = [&] {
    // The step's default allocator is the dummy XlaCompilationAllocator which
    // simply allocates a metadata buffer to hold the expression to which it
    // corresponds.
    // Provides a special behavior for DT_VARIANT and other types that are not
    // trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
    if (!DataTypeCanUseMemcpy(expression.dtype())) {
      // tensor_data() is not supported for tensors that cannot be copied via
      // memcpy, as the copy logic might try to inspect the stored data (e.g.
      // a std::string). This is likely to fail, as the data is invalid given
      // that it actually encodes an XlaExpression. Using a uint8 tensor is
      // always safe, so simply do that.
      // TODO(jpienaar): This should be refactored to stop masquerading
      // XlaExpressions as Tensors.
      Tensor output;
      TensorShape tensor_shape;
      TF_RETURN_IF_ERROR(
          context_->allocate_temp(DT_UINT8, tensor_shape, &output));
      context_->set_output(index, output);
    } else {
      // Memcpy-able dtypes get a real output allocation with the shape the
      // expression reports.
      Tensor* output = nullptr;
      TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
      TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
    }
    // Both branches leave output `index` populated; store the expression in it.
    XlaExpression::AssignExpressionToTensor(expression,
                                            context_->mutable_output(index));
    return OkStatus();
  }();
  if (!status.ok()) {
    SetStatus(status);
  }
}
658
output_xla_type(int index)659 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
660 xla::PrimitiveType type;
661 Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
662 if (!status.ok()) {
663 SetStatus(status);
664 return xla::PRIMITIVE_TYPE_INVALID;
665 }
666 return type;
667 }
668
SetOutput(int index,const xla::XlaOp & handle)669 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
670 SetOutputExpression(
671 index,
672 XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
673 }
674
SetConstantOutput(int index,const Tensor & constant)675 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
676 SetOutputExpression(index, XlaExpression::Constant(constant));
677 }
678
SetTensorListOutput(int index,const xla::XlaOp & handle)679 void XlaOpKernelContext::SetTensorListOutput(int index,
680 const xla::XlaOp& handle) {
681 SetOutputExpression(index, XlaExpression::TensorList(handle));
682 }
683
SetResourceOutput(int index,XlaResource * resource)684 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
685 SetOutputExpression(index, XlaExpression::Resource(resource));
686 }
687
GetResourceInput(int index,XlaResource ** resource)688 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
689 const XlaExpression* expression =
690 XlaExpression::CastExpressionFromTensor(context_->input(index));
691 TF_RET_CHECK(expression->resource() != nullptr);
692 *resource = expression->resource();
693 return OkStatus();
694 }
695
696 namespace {
697
AssignVariableTensor(const Tensor & tensor,DataType type,const XlaOpKernelContext * ctx,xla::XlaOp handle,xla::XlaBuilder * builder)698 Status AssignVariableTensor(const Tensor& tensor, DataType type,
699 const XlaOpKernelContext* ctx, xla::XlaOp handle,
700 xla::XlaBuilder* builder) {
701 const XlaExpression* expression =
702 XlaExpression::CastExpressionFromTensor(tensor);
703 XlaResource* variable = expression->resource();
704 TF_RET_CHECK(variable != nullptr);
705 TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
706
707 auto shape_or_status = builder->GetShape(handle);
708 if (!shape_or_status.ok()) {
709 return shape_or_status.status();
710 }
711 TensorShape shape;
712 TF_RETURN_IF_ERROR(
713 XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
714
715 TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
716
717 auto shape_determination_fns =
718 ctx->compiler()->options().shape_determination_fns;
719 XlaLayoutPreference layout_preference =
720 shape_determination_fns.layout_preference_fn(shape, type, std::nullopt);
721 TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
722 shape_determination_fns.shape_representation_fn(
723 shape, type,
724 /*use_fast_memory=*/false, layout_preference));
725 xla::Shape xla_shape;
726 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
727 if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
728 handle = xla::Reshape(handle, representation_shape.dimensions());
729 }
730 variable->SetRepresentationShape(representation_shape);
731 return variable->SetValue(handle);
732 }
733
734 } // namespace
735
AssignVariable(int input_index,DataType type,xla::XlaOp handle)736 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
737 xla::XlaOp handle) {
738 TF_RET_CHECK(handle.valid());
739 return AssignVariableTensor(context_->input(input_index), type, this, handle,
740 builder());
741 }
742
AssignVariable(absl::string_view name,DataType type,xla::XlaOp handle)743 Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
744 xla::XlaOp handle) {
745 TF_RET_CHECK(handle.valid());
746 return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
747 builder());
748 }
749
// For INVALID_ARGUMENT statuses, returns a copy whose message has the op's
// definition stack trace (from ctx->StackTrace()) appended. Other statuses
// are returned unchanged.
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() != error::INVALID_ARGUMENT) {
    return s;
  }
  std::string annotated_message =
      absl::StrCat(s.error_message(), "\n", ctx->StackTrace());
  return Status{s.code(), annotated_message};
}
758
CtxFailure(const Status & s)759 void XlaOpKernelContext::CtxFailure(const Status& s) {
760 context_->CtxFailure(GetStatusWithStackTrace(s, this));
761 }
CtxFailureWithWarning(const Status & s)762 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
763 context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
764 }
765
CtxFailure(const char * file,int line,const Status & s)766 void XlaOpKernelContext::CtxFailure(const char* file, int line,
767 const Status& s) {
768 context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
769 }
CtxFailureWithWarning(const char * file,int line,const Status & s)770 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
771 const Status& s) {
772 context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
773 }
774
GetOrCreateMax(const DataType type)775 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
776 const DataType type) {
777 return xla_context()->GetOrCreateMax(type);
778 }
779
GetOrCreateMin(const DataType type)780 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
781 const DataType type) {
782 return xla_context()->GetOrCreateMin(type);
783 }
784
GetOrCreateAdd(const DataType type)785 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
786 const DataType type) {
787 return xla_context()->GetOrCreateAdd(type);
788 }
789
GetOrCreateMul(const DataType type)790 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
791 const DataType type) {
792 return xla_context()->GetOrCreateMul(type);
793 }
794
GetInputTensorByName(absl::string_view name)795 const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
796 const Tensor* tensor;
797 CHECK(context_->input(name, &tensor).ok());
798 return *tensor;
799 }
800
XlaOpKernel(OpKernelConstruction * context)801 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
802
Compute(OpKernelContext * context)803 void XlaOpKernel::Compute(OpKernelContext* context) {
804 XlaOpKernelContext xla_context(context);
805 Compile(&xla_context);
806 }
807
StackTrace() const808 std::string XlaOpKernelContext::StackTrace() const {
809 if (const AbstractStackTrace* stack_trace =
810 xla_context()->StackTraceForNodeName(op_kernel().name())) {
811 AbstractStackTrace::TracePrintingOptions opts;
812 opts.show_line_contents = true;
813 opts.filter_common_prefix = true;
814 opts.drop_internal_frames = true;
815 return absl::StrCat("\nStack trace for op definition: \n",
816 stack_trace->ToString(opts), "\n");
817 } else {
818 return "";
819 }
820 }
821
822 } // namespace tensorflow
823