/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_expression.h"

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

XlaExpression::XlaExpression() = default;

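// Factory methods: each returns an XlaExpression of the corresponding Kind.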
XlaExpression XlaExpression::Invalid() {
  XlaExpression e;
  e.kind_ = Kind::kInvalid;
  return e;
}

XlaExpression XlaExpression::Constant(Tensor value) {
  XlaExpression e;
  e.kind_ = Kind::kConstant;
  e.dtype_ = value.dtype();
  e.constant_value_ = value;
  return e;
}

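// Builds a resource expression that also carries a known constant value for
// the resource's contents.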
XlaExpression XlaExpression::ConstantResource(Tensor value,
                                              XlaResource* resource) {
  XlaExpression e;
  e.kind_ = Kind::kResource;
  e.dtype_ = DT_RESOURCE;
  e.resource_ = resource;
  e.constant_value_ = value;
  return e;
}

XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
  XlaExpression e;
  e.kind_ = Kind::kXlaOp;
  e.dtype_ = dtype;
  e.handle_ = value;
  return e;
}

XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) {
  XlaExpression e;
  e.kind_ = Kind::kTensorList;
  e.dtype_ = DT_VARIANT;
  e.handle_ = tensor_list;
  return e;
}

XlaExpression XlaExpression::Resource(XlaResource* resource) {
  XlaExpression e;
  e.kind_ = Kind::kResource;
  e.dtype_ = DT_RESOURCE;
  e.resource_ = resource;
  return e;
}

string XlaExpression::HumanString() const {
  switch (kind_) {
    case Kind::kInvalid:
      return "invalid";
    case Kind::kConstant:
      return "constant";
    case Kind::kXlaOp:
      return "xla_op";
    case Kind::kResource:
      return "resource";
    case Kind::kTensorList:
      return "tensor_list";
  }
}

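// Lowers this expression to an xla::XlaOp on `builder`: constants are
// materialized as literals, kXlaOp/kTensorList return the stored handle
// (which must belong to the same builder), and any other kind is an error.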
xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<xla::XlaOp> {
    switch (kind_) {
      case Kind::kConstant: {
        xla::BorrowingLiteral literal;
        TF_RETURN_IF_ERROR(
            HostTensorToBorrowingLiteral(*constant_value_, &literal));
        return xla::ConstantLiteral(builder, literal);
      }
      case Kind::kTensorList:
        TF_FALLTHROUGH_INTENDED;
      case Kind::kXlaOp:
        if (builder != handle_.builder()) {
          return errors::InvalidArgument(
              "Mismatched builders in XlaExpression::AsXlaOp");
        }
        return handle_;
      default:
        return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
                                       HumanString());
    }
  });
}

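// Returns a DT_BOOL tensor of the expression's shape indicating, per element,
// whether the value is dynamic (not known at compile time). Constants are
// fully static; only kConstant and kXlaOp expressions are supported.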
StatusOr<Tensor> XlaExpression::ResolveDynamism(xla::Client* client) const {
  switch (kind()) {
    case Kind::kConstant: {
      // Constant values are considered static.
      Tensor constant_false(DT_BOOL, constant_value()->shape());
      auto flat = constant_false.flat<bool>();
      for (int64_t i = 0; i < flat.size(); ++i) flat(i) = false;
      return constant_false;
    }
    case Kind::kXlaOp:
      break;
    case Kind::kTensorList:
      TF_FALLTHROUGH_INTENDED;
    case Kind::kResource:
      TF_FALLTHROUGH_INTENDED;
    case Kind::kInvalid:
      return errors::InvalidArgument(
          "ResolveDynamism called on unsupported XlaExpression: ",
          HumanString());
  }

  if (!client)
    return errors::InvalidArgument("client is required to resolve constant");

  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

  // The XLA layout is specified minor to major, and TensorFlow uses a major to
  // minor order.
  std::vector<int64_t> layout_indices(shape.dims());
  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
  xla::ValueInference value_inference(handle().builder());
  TF_ASSIGN_OR_RETURN(xla::LiteralSlice literal,
                      value_inference.AnalyzeIsDynamic(handle()));
  Tensor tensor(DT_BOOL);
  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor));
  return tensor;
}

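// Attempts to evaluate the expression to a host-side constant Tensor, or
// std::nullopt if the value cannot be determined at compile time. For the
// kLowerBound/kUpperBound/kValue modes the result comes from XLA value
// inference; otherwise a constant subgraph is built and evaluated via
// `client`.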
StatusOr<std::optional<Tensor>> XlaExpression::ResolveConstant(
    xla::Client* client, bool dynamic_dimension_is_minus_one,
    xla::ValueInferenceMode mode) const {
  switch (kind()) {
    case Kind::kConstant:
    case Kind::kResource:
      return constant_value();
    case Kind::kXlaOp:
      break;
    case Kind::kTensorList:
      TF_FALLTHROUGH_INTENDED;
    case Kind::kInvalid:
      return errors::InvalidArgument(
          "ResolveConstant called on XlaExpression: ", HumanString());
  }
  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
  // The XLA layout is specified minor to major, and TensorFlow uses a major to
  // minor order.
  std::vector<int64_t> layout_indices(shape.dims());
  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
  if (mode == xla::ValueInferenceMode::kLowerBound ||
      mode == xla::ValueInferenceMode::kUpperBound ||
      mode == xla::ValueInferenceMode::kValue) {
    std::vector<int64_t> layout_indices(shape.dims());
    std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
    xla::ValueInference value_inference(handle().builder());
    TF_ASSIGN_OR_RETURN(xla::OptionalLiteral literal,
                        value_inference.AnalyzeConstant(handle(), mode));
    if (!literal.GetValue().has_value()) {
      return {std::nullopt};
    }
    Tensor tensor;
    TF_RETURN_IF_ERROR(LiteralToHostTensor(
        literal.GetValue().value().Relayout(layout), dtype(), &tensor));
    return {tensor};
  }

  TF_ASSIGN_OR_RETURN(bool is_constant,
                      handle().builder()->IsConstant(handle()));
  if (!is_constant) {
    return {std::nullopt};
  }

  if (!client)
    return errors::InvalidArgument("client is required to resolve constant");

  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
                      handle().builder()->BuildConstantSubGraph(
                          handle(), dynamic_dimension_is_minus_one));

  TF_ASSIGN_OR_RETURN(xla::Literal literal,
                      client->ComputeConstant(constant_graph, &layout));
  Tensor tensor;
  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
  return {tensor};
}

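// Returns the TensorShape of the expression: constants report their stored
// value's shape, resources report the constant value's shape if one is known
// (otherwise a scalar shape), kXlaOp shapes are queried from the builder, and
// tensor lists report a scalar shape.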
StatusOr<TensorShape> XlaExpression::GetShape() const {
  switch (kind_) {
    case Kind::kConstant:
      return constant_value()->shape();
    case Kind::kResource:
      if (constant_value()) {
        return constant_value()->shape();
      }
      return TensorShape({});
    case Kind::kXlaOp: {
      TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
                          handle().builder()->GetShape(handle()));
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
      return shape;
    }
    case Kind::kTensorList:
      return TensorShape({});
    case Kind::kInvalid:
      return errors::InvalidArgument(
          "GetShape() called on invalid XlaExpression");
  }
}

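// Reinterprets the buffer of a tensor on an XLA compilation device as the
// XlaExpression stored there; the expression must already have been assigned
// (its kind must not be kInvalid).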
const XlaExpression* XlaExpression::CastExpressionFromTensor(
    const Tensor& tensor) {
  const XlaExpression* expression =
      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
      << expression->HumanString();
  return expression;
}

// Assigns an XlaExpression to a tensor on an XLA compilation device.
void XlaExpression::AssignExpressionToTensor(const XlaExpression& value,
                                             Tensor* tensor) {
  const XlaExpression* expression =
      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
  CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
      << expression->HumanString();
  *const_cast<XlaExpression*>(expression) = value;
}

}  // namespace tensorflow