• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/tf2xla/xla_expression.h"

#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
22 
23 namespace tensorflow {
24 
25 XlaExpression::XlaExpression() = default;
26 
Invalid()27 XlaExpression XlaExpression::Invalid() {
28   XlaExpression e;
29   e.kind_ = Kind::kInvalid;
30   return e;
31 }
32 
Constant(Tensor value)33 XlaExpression XlaExpression::Constant(Tensor value) {
34   XlaExpression e;
35   e.kind_ = Kind::kConstant;
36   e.dtype_ = value.dtype();
37   e.constant_value_ = value;
38   return e;
39 }
40 
ConstantResource(Tensor value,XlaResource * resource)41 XlaExpression XlaExpression::ConstantResource(Tensor value,
42                                               XlaResource* resource) {
43   XlaExpression e;
44   e.kind_ = Kind::kResource;
45   e.dtype_ = DT_RESOURCE;
46   e.resource_ = resource;
47   e.constant_value_ = value;
48   return e;
49 }
50 
XlaOp(xla::XlaOp value,DataType dtype)51 XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
52   XlaExpression e;
53   e.kind_ = Kind::kXlaOp;
54   e.dtype_ = dtype;
55   e.handle_ = value;
56   return e;
57 }
58 
TensorList(xla::XlaOp tensor_list)59 XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) {
60   XlaExpression e;
61   e.kind_ = Kind::kTensorList;
62   e.dtype_ = DT_VARIANT;
63   e.handle_ = tensor_list;
64   return e;
65 }
66 
Resource(XlaResource * resource)67 XlaExpression XlaExpression::Resource(XlaResource* resource) {
68   XlaExpression e;
69   e.kind_ = Kind::kResource;
70   e.dtype_ = DT_RESOURCE;
71   e.resource_ = resource;
72   return e;
73 }
74 
HumanString() const75 string XlaExpression::HumanString() const {
76   switch (kind_) {
77     case Kind::kInvalid:
78       return "invalid";
79     case Kind::kConstant:
80       return "constant";
81     case Kind::kXlaOp:
82       return "xla_op";
83     case Kind::kResource:
84       return "resource";
85     case Kind::kTensorList:
86       return "tensor_list";
87   }
88 }
89 
AsXlaOp(xla::XlaBuilder * builder) const90 xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
91   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
92     switch (kind_) {
93       case Kind::kConstant: {
94         xla::BorrowingLiteral literal;
95         TF_RETURN_IF_ERROR(
96             HostTensorToBorrowingLiteral(*constant_value_, &literal));
97         return xla::ConstantLiteral(builder, literal);
98       }
99       case Kind::kTensorList:
100         TF_FALLTHROUGH_INTENDED;
101       case Kind::kXlaOp:
102         if (builder != handle_.builder()) {
103           return errors::InvalidArgument(
104               "Mismatched builders in XlaExpression::AsXlaOp");
105         }
106         return handle_;
107       default:
108         return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
109                                        HumanString());
110     }
111   });
112 }
113 
ResolveDynamism(xla::Client * client) const114 xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
115     xla::Client* client) const {
116   switch (kind()) {
117     case Kind::kConstant: {
118       // Constant values are considered static.
119       Tensor constant_false(DT_BOOL, constant_value()->shape());
120       auto flat = constant_false.flat<bool>();
121       for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
122       return constant_false;
123     }
124     case Kind::kXlaOp:
125       break;
126     case Kind::kTensorList:
127       TF_FALLTHROUGH_INTENDED;
128     case Kind::kResource:
129       TF_FALLTHROUGH_INTENDED;
130     case Kind::kInvalid:
131       return errors::InvalidArgument(
132           "ResolveDynamism called on unsupported XlaExpression: ",
133           HumanString());
134   }
135 
136   if (!client)
137     return errors::InvalidArgument("client is required to resolve constant");
138 
139   TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
140                       handle().builder()->BuildDynamicInferenceGraph(handle()));
141 
142   TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
143 
144   // The XLA layout is specified minor to major, and TensorFlow uses a major to
145   // minor order.
146   std::vector<int64> layout_indices(shape.dims());
147   std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
148   xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
149   TF_ASSIGN_OR_RETURN(xla::Literal literal,
150                       client->ComputeConstant(constant_graph, &layout));
151   Tensor tensor(DT_BOOL);
152   TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
153   return tensor;
154 }
155 
ResolveConstant(xla::Client * client,bool dynamic_dimension_is_minus_one) const156 xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
157     xla::Client* client, bool dynamic_dimension_is_minus_one) const {
158   switch (kind()) {
159     case Kind::kConstant:
160     case Kind::kResource:
161       return constant_value();
162     case Kind::kXlaOp:
163       break;
164     case Kind::kTensorList:
165       TF_FALLTHROUGH_INTENDED;
166     case Kind::kInvalid:
167       return errors::InvalidArgument(
168           "ResolveConstant called on XlaExpression: ", HumanString());
169   }
170 
171   TF_ASSIGN_OR_RETURN(bool is_constant,
172                       handle().builder()->IsConstant(handle()));
173   if (!is_constant) {
174     return {absl::nullopt};
175   }
176 
177   if (!client)
178     return errors::InvalidArgument("client is required to resolve constant");
179 
180   TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
181                       handle().builder()->BuildConstantSubGraph(
182                           handle(), dynamic_dimension_is_minus_one));
183 
184   TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
185 
186   // The XLA layout is specified minor to major, and TensorFlow uses a major to
187   // minor order.
188   std::vector<int64> layout_indices(shape.dims());
189   std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
190   xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
191   TF_ASSIGN_OR_RETURN(xla::Literal literal,
192                       client->ComputeConstant(constant_graph, &layout));
193   Tensor tensor;
194   TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
195   return {tensor};
196 }
197 
GetShape() const198 xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
199   switch (kind_) {
200     case Kind::kConstant:
201       return constant_value()->shape();
202     case Kind::kResource:
203       if (constant_value()) {
204         return constant_value()->shape();
205       }
206       return TensorShape({});
207     case Kind::kXlaOp: {
208       TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
209                           handle().builder()->GetShape(handle()));
210       TensorShape shape;
211       TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
212       return shape;
213     }
214     case Kind::kTensorList:
215       return TensorShape({});
216     case Kind::kInvalid:
217       return errors::InvalidArgument(
218           "GetShape() called on invalid XlaExpression");
219   }
220 }
221 
CastExpressionFromTensor(const Tensor & tensor)222 const XlaExpression* XlaExpression::CastExpressionFromTensor(
223     const Tensor& tensor) {
224   const XlaExpression* expression =
225       reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
226   CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
227       << expression->HumanString();
228   return expression;
229 }
230 
231 // Assigns an XlaExpression to a tensor on an XLA compilation device.
AssignExpressionToTensor(const XlaExpression & value,Tensor * tensor)232 void XlaExpression::AssignExpressionToTensor(const XlaExpression& value,
233                                              Tensor* tensor) {
234   const XlaExpression* expression =
235       reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
236   CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
237       << expression->HumanString();
238   *const_cast<XlaExpression*>(expression) = value;
239 }
240 
241 }  // namespace tensorflow
242