/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/shape_util.h"

#include <algorithm>
#include <numeric>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

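// Flattens the layouts of `shape` into `layouts`: for each non-tuple
// subshape, visited in order, appends its minor-to-major dimension order,
// or one -1 per dimension if the subshape has no layout. For example, a
// tuple shape (f32[2,3]{1,0}, s32[4]) whose second element lacks a layout
// yields {1, 0, -1}.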
Status PopulateInfeedLayoutVector(const xla::Shape& shape,
                                  std::vector<int>* layouts) {
  if (shape.IsTuple()) {
    int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape);
    for (int64 i = 0; i < tuple_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetTupleElementShape(shape, i);
      TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts));
    }
  } else if (xla::LayoutUtil::HasLayout(shape)) {
    for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) {
      layouts->push_back(dim);
    }
  } else {
    layouts->insert(layouts->end(), shape.rank(), -1);
  }
  return Status::OK();
}

// Populate the output layout unless the minor_to_major array contains all -1
// values, in which case the layout is considered missing and the API returns
// false.
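// For example, {1, 0} produces a row-major layout and returns true,
// {-1, -1} returns false, and {0, 0} fails with a repeated-dimension error.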
xla::StatusOr<bool> MakeLayout(absl::Span<const int64> minor_to_major,
                               xla::Layout* layout) {
  if (std::all_of(minor_to_major.begin(), minor_to_major.end(),
                  [](int64 dim) { return dim == -1; })) {
    return false;
  }
  std::vector<bool> dim_present(minor_to_major.size(), false);
  for (auto dim : minor_to_major) {
    if (dim < 0 || dim >= minor_to_major.size()) {
      return errors::InvalidArgument("Layout dimension out of range: dim=", dim,
                                     " rank=", minor_to_major.size());
    }
    if (dim_present[dim]) {
      return errors::InvalidArgument("Repeated layout dimension: dim=", dim);
    }
    dim_present[dim] = true;
  }
  *layout = xla::LayoutUtil::MakeLayout(minor_to_major);
  return true;
}

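// Assigns a layout to `shape`: from `minor_to_major` when it encodes one,
// otherwise from `layout_func` if provided; failing both, the shape gets a
// default-constructed layout.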
Status AssignLayout(
    absl::Span<const int64> minor_to_major,
    const std::function<xla::Layout(const xla::Shape&)>& layout_func,
    xla::Shape* shape) {
  xla::Layout layout;
  TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout));
  if (!has_layout && layout_func) {
    layout = layout_func(*shape);
  }
  *shape->mutable_layout() = layout;
  return Status::OK();
}

}  // namespace

// Convert an XLA Shape into the equivalent TensorFlow shape.
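// For example, f32[2,3]{1,0} converts to the rank-2 TensorShape [2,3];
// tuple shapes are rejected with an InvalidArgument error.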
Status XLAShapeToTensorShape(const xla::Shape& shape,
                             TensorShape* tensor_shape) {
  if (shape.IsTuple()) {
    return errors::InvalidArgument("XLA shape ",
                                   xla::ShapeUtil::HumanString(shape),
                                   " cannot be converted to a TensorShape");
  }
  *tensor_shape = TensorShape();
  for (int i = 0; i < shape.rank(); ++i) {
    tensor_shape->AddDim(shape.dimensions(i));
  }
  return Status::OK();
}

// Convert a TensorShape into the equivalent XLA Shape proto.
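// For example, DT_FLOAT with TensorShape [2,3] becomes f32[2,3], with the
// row-major layout {1, 0} assigned by the overload below.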
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
                             xla::Shape* shape) {
  xla::PrimitiveType type;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
  *shape = TensorShapeToXLAShape(type, tensor_shape);
  return Status::OK();
}

xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
                                 const TensorShape& tensor_shape) {
  int rank = tensor_shape.dims();
  std::vector<int64> dimensions(rank);
  std::vector<int64> layout(rank);
  for (int d = 0; d < rank; ++d) {
    dimensions[d] = tensor_shape.dim_size(d);
  }
  // XLA uses minor-to-major; TensorFlow uses major-to-minor. Filling the
  // layout back-to-front produces {rank-1, ..., 1, 0}, i.e. row-major.
  std::iota(layout.rbegin(), layout.rend(), 0);

  return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
}

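// Returns the flattened per-dimension layout vector of `shape`, as built by
// PopulateInfeedLayoutVector above; e.g. (f32[2,3]{1,0}, s32[4]{0}) yields
// {1, 0, 0}.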
xla::StatusOr<std::vector<int>> GetShapeLayoutVector(const xla::Shape& shape) {
  std::vector<int> layouts;
  TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts));
  return layouts;
}

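// Builds `output_shape` from `input_shape` with layouts drawn from
// `minor_to_major`, consuming one entry per dimension of each tuple element
// in order (nested tuples are rejected). For example, the tuple
// (f32[2,3], s32[4]) with minor_to_major {1, 0, 0} gives its elements the
// layouts {1, 0} and {0}.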
Status GetShapeWithLayout(
    const xla::Shape& input_shape, absl::Span<const int64> minor_to_major,
    const std::function<xla::Layout(const xla::Shape&)>& layout_func,
    xla::Shape* output_shape) {
  if (input_shape.IsTuple()) {
    int64 tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape);
    std::vector<xla::Shape> shapes;
    shapes.reserve(tuple_elements);
    size_t position = 0;
    for (int64 i = 0; i < tuple_elements; ++i) {
      const xla::Shape& shape =
          xla::ShapeUtil::GetTupleElementShape(input_shape, i);
      if (shape.IsTuple()) {
        return errors::InvalidArgument(
            "Nested tuples not supported: ",
            xla::ShapeUtil::HumanString(input_shape));
      }
      int64 rank = shape.rank();
      if (position + rank > minor_to_major.size()) {
        return errors::InvalidArgument(
            "Not enough layout attribute elements: position=", position,
            " rank=", rank, " elements=", minor_to_major.size());
      }
      shapes.push_back(shape);
      TF_RETURN_IF_ERROR(AssignLayout(
          absl::Span<const int64>(minor_to_major).subspan(position, rank),
          layout_func, &shapes.back()));
      position += rank;

      VLOG(4) << "Shape[" << i
              << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back());
    }
    if (position != minor_to_major.size()) {
      return errors::InvalidArgument(
          "Too many elements passed in the layout attribute: position=",
          position, " size=", minor_to_major.size());
    }
    *output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  } else {
    int64 rank = input_shape.rank();
    if (rank != minor_to_major.size()) {
      return errors::InvalidArgument(
          "Wrong number of layout attribute elements: rank=", rank,
          " elements=", minor_to_major.size());
    }
    *output_shape = input_shape;
    TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape));

    VLOG(4) << "Shape[] = "
            << xla::ShapeUtil::HumanStringWithLayout(*output_shape);
  }
  return Status::OK();
}

}  // namespace tensorflow