1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_expression.h"
17
18 #include "tensorflow/compiler/tf2xla/literal_util.h"
19 #include "tensorflow/compiler/tf2xla/shape_util.h"
20 #include "tensorflow/compiler/xla/client/value_inference.h"
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/core/lib/core/errors.h"
23
24 namespace tensorflow {
25
26 XlaExpression::XlaExpression() = default;
27
Invalid()28 XlaExpression XlaExpression::Invalid() {
29 XlaExpression e;
30 e.kind_ = Kind::kInvalid;
31 return e;
32 }
33
Constant(Tensor value)34 XlaExpression XlaExpression::Constant(Tensor value) {
35 XlaExpression e;
36 e.kind_ = Kind::kConstant;
37 e.dtype_ = value.dtype();
38 e.constant_value_ = value;
39 return e;
40 }
41
ConstantResource(Tensor value,XlaResource * resource)42 XlaExpression XlaExpression::ConstantResource(Tensor value,
43 XlaResource* resource) {
44 XlaExpression e;
45 e.kind_ = Kind::kResource;
46 e.dtype_ = DT_RESOURCE;
47 e.resource_ = resource;
48 e.constant_value_ = value;
49 return e;
50 }
51
XlaOp(xla::XlaOp value,DataType dtype)52 XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
53 XlaExpression e;
54 e.kind_ = Kind::kXlaOp;
55 e.dtype_ = dtype;
56 e.handle_ = value;
57 return e;
58 }
59
TensorList(xla::XlaOp tensor_list)60 XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) {
61 XlaExpression e;
62 e.kind_ = Kind::kTensorList;
63 e.dtype_ = DT_VARIANT;
64 e.handle_ = tensor_list;
65 return e;
66 }
67
Resource(XlaResource * resource)68 XlaExpression XlaExpression::Resource(XlaResource* resource) {
69 XlaExpression e;
70 e.kind_ = Kind::kResource;
71 e.dtype_ = DT_RESOURCE;
72 e.resource_ = resource;
73 return e;
74 }
75
HumanString() const76 string XlaExpression::HumanString() const {
77 switch (kind_) {
78 case Kind::kInvalid:
79 return "invalid";
80 case Kind::kConstant:
81 return "constant";
82 case Kind::kXlaOp:
83 return "xla_op";
84 case Kind::kResource:
85 return "resource";
86 case Kind::kTensorList:
87 return "tensor_list";
88 }
89 }
90
AsXlaOp(xla::XlaBuilder * builder) const91 xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
92 return builder->ReportErrorOrReturn([&]() -> StatusOr<xla::XlaOp> {
93 switch (kind_) {
94 case Kind::kConstant: {
95 xla::BorrowingLiteral literal;
96 TF_RETURN_IF_ERROR(
97 HostTensorToBorrowingLiteral(*constant_value_, &literal));
98 return xla::ConstantLiteral(builder, literal);
99 }
100 case Kind::kTensorList:
101 TF_FALLTHROUGH_INTENDED;
102 case Kind::kXlaOp:
103 if (builder != handle_.builder()) {
104 return errors::InvalidArgument(
105 "Mismatched builders in XlaExpression::AsXlaOp");
106 }
107 return handle_;
108 default:
109 return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
110 HumanString());
111 }
112 });
113 }
114
ResolveDynamism(xla::Client * client) const115 StatusOr<Tensor> XlaExpression::ResolveDynamism(xla::Client* client) const {
116 switch (kind()) {
117 case Kind::kConstant: {
118 // Constant values are considered static.
119 Tensor constant_false(DT_BOOL, constant_value()->shape());
120 auto flat = constant_false.flat<bool>();
121 for (int64_t i = 0; i < flat.size(); ++i) flat(i) = false;
122 return constant_false;
123 }
124 case Kind::kXlaOp:
125 break;
126 case Kind::kTensorList:
127 TF_FALLTHROUGH_INTENDED;
128 case Kind::kResource:
129 TF_FALLTHROUGH_INTENDED;
130 case Kind::kInvalid:
131 return errors::InvalidArgument(
132 "ResolveDynamism called on unsupported XlaExpression: ",
133 HumanString());
134 }
135
136 if (!client)
137 return errors::InvalidArgument("client is required to resolve constant");
138
139 TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
140
141 // The XLA layout is specified minor to major, and TensorFlow uses a major to
142 // minor order.
143 std::vector<int64_t> layout_indices(shape.dims());
144 std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
145 xla::ValueInference value_inference(handle().builder());
146 TF_ASSIGN_OR_RETURN(xla::LiteralSlice literal,
147 value_inference.AnalyzeIsDynamic(handle()));
148 Tensor tensor(DT_BOOL);
149 TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
150 return tensor;
151 }
152
ResolveConstant(xla::Client * client,bool dynamic_dimension_is_minus_one,xla::ValueInferenceMode mode) const153 StatusOr<std::optional<Tensor>> XlaExpression::ResolveConstant(
154 xla::Client* client, bool dynamic_dimension_is_minus_one,
155 xla::ValueInferenceMode mode) const {
156 switch (kind()) {
157 case Kind::kConstant:
158 case Kind::kResource:
159 return constant_value();
160 case Kind::kXlaOp:
161 break;
162 case Kind::kTensorList:
163 TF_FALLTHROUGH_INTENDED;
164 case Kind::kInvalid:
165 return errors::InvalidArgument(
166 "ResolveConstant called on XlaExpression: ", HumanString());
167 }
168 TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
169 // The XLA layout is specified minor to major, and TensorFlow uses a major to
170 // minor order.
171 std::vector<int64_t> layout_indices(shape.dims());
172 std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
173 xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
174 if (mode == xla::ValueInferenceMode::kLowerBound ||
175 mode == xla::ValueInferenceMode::kUpperBound ||
176 mode == xla::ValueInferenceMode::kValue) {
177 std::vector<int64_t> layout_indices(shape.dims());
178 std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
179 xla::ValueInference value_inference(handle().builder());
180 TF_ASSIGN_OR_RETURN(xla::OptionalLiteral literal,
181 value_inference.AnalyzeConstant(handle(), mode));
182 if (!literal.GetValue().has_value()) {
183 return {std::nullopt};
184 }
185 Tensor tensor;
186 TF_RETURN_IF_ERROR(LiteralToHostTensor(
187 literal.GetValue().value().Relayout(layout), dtype(), &tensor));
188 return {tensor};
189 }
190
191 TF_ASSIGN_OR_RETURN(bool is_constant,
192 handle().builder()->IsConstant(handle()));
193 if (!is_constant) {
194 return {std::nullopt};
195 }
196
197 if (!client)
198 return errors::InvalidArgument("client is required to resolve constant");
199
200 TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
201 handle().builder()->BuildConstantSubGraph(
202 handle(), dynamic_dimension_is_minus_one));
203
204 TF_ASSIGN_OR_RETURN(xla::Literal literal,
205 client->ComputeConstant(constant_graph, &layout));
206 Tensor tensor;
207 TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
208 return {tensor};
209 }
210
GetShape() const211 StatusOr<TensorShape> XlaExpression::GetShape() const {
212 switch (kind_) {
213 case Kind::kConstant:
214 return constant_value()->shape();
215 case Kind::kResource:
216 if (constant_value()) {
217 return constant_value()->shape();
218 }
219 return TensorShape({});
220 case Kind::kXlaOp: {
221 TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
222 handle().builder()->GetShape(handle()));
223 TensorShape shape;
224 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
225 return shape;
226 }
227 case Kind::kTensorList:
228 return TensorShape({});
229 case Kind::kInvalid:
230 return errors::InvalidArgument(
231 "GetShape() called on invalid XlaExpression");
232 }
233 }
234
CastExpressionFromTensor(const Tensor & tensor)235 const XlaExpression* XlaExpression::CastExpressionFromTensor(
236 const Tensor& tensor) {
237 const XlaExpression* expression =
238 reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
239 CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
240 << expression->HumanString();
241 return expression;
242 }
243
244 // Assigns an XlaExpression to a tensor on an XLA compilation device.
AssignExpressionToTensor(const XlaExpression & value,Tensor * tensor)245 void XlaExpression::AssignExpressionToTensor(const XlaExpression& value,
246 Tensor* tensor) {
247 const XlaExpression* expression =
248 reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
249 CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
250 << expression->HumanString();
251 *const_cast<XlaExpression*>(expression) = value;
252 }
253
254 } // namespace tensorflow
255