1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/tf2xla/xla_expression.h"

#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
22
23 namespace tensorflow {
24
25 XlaExpression::XlaExpression() = default;
26
Invalid()27 XlaExpression XlaExpression::Invalid() {
28 XlaExpression e;
29 e.kind_ = Kind::kInvalid;
30 return e;
31 }
32
Constant(Tensor value)33 XlaExpression XlaExpression::Constant(Tensor value) {
34 XlaExpression e;
35 e.kind_ = Kind::kConstant;
36 e.dtype_ = value.dtype();
37 e.constant_value_ = value;
38 return e;
39 }
40
ConstantResource(Tensor value,XlaResource * resource)41 XlaExpression XlaExpression::ConstantResource(Tensor value,
42 XlaResource* resource) {
43 XlaExpression e;
44 e.kind_ = Kind::kResource;
45 e.dtype_ = DT_RESOURCE;
46 e.resource_ = resource;
47 e.constant_value_ = value;
48 return e;
49 }
50
XlaOp(xla::XlaOp value,DataType dtype)51 XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
52 XlaExpression e;
53 e.kind_ = Kind::kXlaOp;
54 e.dtype_ = dtype;
55 e.handle_ = value;
56 return e;
57 }
58
TensorList(xla::XlaOp tensor_list)59 XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) {
60 XlaExpression e;
61 e.kind_ = Kind::kTensorList;
62 e.dtype_ = DT_VARIANT;
63 e.handle_ = tensor_list;
64 return e;
65 }
66
Resource(XlaResource * resource)67 XlaExpression XlaExpression::Resource(XlaResource* resource) {
68 XlaExpression e;
69 e.kind_ = Kind::kResource;
70 e.dtype_ = DT_RESOURCE;
71 e.resource_ = resource;
72 return e;
73 }
74
HumanString() const75 string XlaExpression::HumanString() const {
76 switch (kind_) {
77 case Kind::kInvalid:
78 return "invalid";
79 case Kind::kConstant:
80 return "constant";
81 case Kind::kXlaOp:
82 return "xla_op";
83 case Kind::kResource:
84 return "resource";
85 case Kind::kTensorList:
86 return "tensor_list";
87 }
88 }
89
AsXlaOp(xla::XlaBuilder * builder) const90 xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
91 return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
92 switch (kind_) {
93 case Kind::kConstant: {
94 xla::BorrowingLiteral literal;
95 TF_RETURN_IF_ERROR(
96 HostTensorToBorrowingLiteral(*constant_value_, &literal));
97 return xla::ConstantLiteral(builder, literal);
98 }
99 case Kind::kTensorList:
100 TF_FALLTHROUGH_INTENDED;
101 case Kind::kXlaOp:
102 if (builder != handle_.builder()) {
103 return errors::InvalidArgument(
104 "Mismatched builders in XlaExpression::AsXlaOp");
105 }
106 return handle_;
107 default:
108 return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
109 HumanString());
110 }
111 });
112 }
113
ResolveDynamism(xla::Client * client) const114 xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
115 xla::Client* client) const {
116 switch (kind()) {
117 case Kind::kConstant: {
118 // Constant values are considered static.
119 Tensor constant_false(DT_BOOL, constant_value()->shape());
120 auto flat = constant_false.flat<bool>();
121 for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
122 return constant_false;
123 }
124 case Kind::kXlaOp:
125 break;
126 case Kind::kTensorList:
127 TF_FALLTHROUGH_INTENDED;
128 case Kind::kResource:
129 TF_FALLTHROUGH_INTENDED;
130 case Kind::kInvalid:
131 return errors::InvalidArgument(
132 "ResolveDynamism called on unsupported XlaExpression: ",
133 HumanString());
134 }
135
136 if (!client)
137 return errors::InvalidArgument("client is required to resolve constant");
138
139 TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
140 handle().builder()->BuildDynamicInferenceGraph(handle()));
141
142 TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
143
144 // The XLA layout is specified minor to major, and TensorFlow uses a major to
145 // minor order.
146 std::vector<int64> layout_indices(shape.dims());
147 std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
148 xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
149 TF_ASSIGN_OR_RETURN(xla::Literal literal,
150 client->ComputeConstant(constant_graph, &layout));
151 Tensor tensor(DT_BOOL);
152 TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
153 return tensor;
154 }
155
ResolveConstant(xla::Client * client,bool dynamic_dimension_is_minus_one) const156 xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
157 xla::Client* client, bool dynamic_dimension_is_minus_one) const {
158 switch (kind()) {
159 case Kind::kConstant:
160 case Kind::kResource:
161 return constant_value();
162 case Kind::kXlaOp:
163 break;
164 case Kind::kTensorList:
165 TF_FALLTHROUGH_INTENDED;
166 case Kind::kInvalid:
167 return errors::InvalidArgument(
168 "ResolveConstant called on XlaExpression: ", HumanString());
169 }
170
171 TF_ASSIGN_OR_RETURN(bool is_constant,
172 handle().builder()->IsConstant(handle()));
173 if (!is_constant) {
174 return {absl::nullopt};
175 }
176
177 if (!client)
178 return errors::InvalidArgument("client is required to resolve constant");
179
180 TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
181 handle().builder()->BuildConstantSubGraph(
182 handle(), dynamic_dimension_is_minus_one));
183
184 TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
185
186 // The XLA layout is specified minor to major, and TensorFlow uses a major to
187 // minor order.
188 std::vector<int64> layout_indices(shape.dims());
189 std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
190 xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
191 TF_ASSIGN_OR_RETURN(xla::Literal literal,
192 client->ComputeConstant(constant_graph, &layout));
193 Tensor tensor;
194 TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
195 return {tensor};
196 }
197
GetShape() const198 xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
199 switch (kind_) {
200 case Kind::kConstant:
201 return constant_value()->shape();
202 case Kind::kResource:
203 if (constant_value()) {
204 return constant_value()->shape();
205 }
206 return TensorShape({});
207 case Kind::kXlaOp: {
208 TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
209 handle().builder()->GetShape(handle()));
210 TensorShape shape;
211 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
212 return shape;
213 }
214 case Kind::kTensorList:
215 return TensorShape({});
216 case Kind::kInvalid:
217 return errors::InvalidArgument(
218 "GetShape() called on invalid XlaExpression");
219 }
220 }
221
CastExpressionFromTensor(const Tensor & tensor)222 const XlaExpression* XlaExpression::CastExpressionFromTensor(
223 const Tensor& tensor) {
224 const XlaExpression* expression =
225 reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
226 CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
227 << expression->HumanString();
228 return expression;
229 }
230
231 // Assigns an XlaExpression to a tensor on an XLA compilation device.
AssignExpressionToTensor(const XlaExpression & value,Tensor * tensor)232 void XlaExpression::AssignExpressionToTensor(const XlaExpression& value,
233 Tensor* tensor) {
234 const XlaExpression* expression =
235 reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
236 CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
237 << expression->HumanString();
238 *const_cast<XlaExpression*>(expression) = value;
239 }
240
241 } // namespace tensorflow
242