/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/xla/shape.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21
22 namespace xla {
23
// Constructs a Shape from its protobuf representation: copies the element
// type, dimensions, per-dimension dynamic flags, tuple element shapes
// (recursively), and the layout when one is present.
Shape::Shape(const ShapeProto& shape_proto) {
  set_element_type(shape_proto.element_type());
  dimensions_.reserve(shape_proto.dimensions_size());
  for (const int64 dimension : shape_proto.dimensions()) {
    add_dimensions(dimension);
  }
  // A malformed proto may have different is_dynamic_dimension_size and
  // dimensions_size. Since C++ is evil, and we have no good way of bailing out
  // in a constructor, conservatively trim the is_dynamic_dimension size.
  // TODO(b/120111794): Make this a hard error when we have a factory method
  // instead of a constructor.
  if (shape_proto.dimensions_size() !=
      shape_proto.is_dynamic_dimension_size()) {
    if (shape_proto.is_dynamic_dimension_size() != 0) {
      LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension "
                    "fields does not match number of dimension fields";
    } else {
      LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty";
    }
  }
  // NOTE(review): the indexed assignment below assumes add_dimensions() also
  // grew dynamic_dimensions_ to the same size — confirm against the Shape
  // header.
  int64 num_dynamic_dimension_fields = std::min(
      shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size());
  for (int i = 0; i < num_dynamic_dimension_fields; i++) {
    dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i);
  }
  tuple_shapes_.reserve(shape_proto.tuple_shapes_size());
  for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) {
    *add_tuple_shapes() = Shape(element_shape);
  }
  if (shape_proto.has_layout()) {
    *mutable_layout() = Layout::CreateFromProto(shape_proto.layout());
  }
}
57
ToProto() const58 ShapeProto Shape::ToProto() const {
59 ShapeProto proto;
60 proto.set_element_type(element_type_);
61 proto.mutable_dimensions()->Reserve(dimensions_size());
62 for (const int64 dimension : dimensions()) {
63 proto.add_dimensions(dimension);
64 }
65 for (const bool dynamic : dynamic_dimensions_) {
66 proto.add_is_dynamic_dimension(dynamic);
67 }
68 proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size());
69 for (const Shape& shape : tuple_shapes()) {
70 *proto.add_tuple_shapes() = shape.ToProto();
71 }
72 if (has_layout()) {
73 *proto.mutable_layout() = layout().ToProto();
74 }
75 return proto;
76 }
77
ToString(bool print_layout) const78 string Shape::ToString(bool print_layout) const {
79 if (print_layout) {
80 return ShapeUtil::HumanStringWithLayout(*this);
81 } else {
82 return ShapeUtil::HumanString(*this);
83 }
84 }
85
is_static() const86 bool Shape::is_static() const {
87 if (IsTuple()) {
88 for (const Shape& subshape : tuple_shapes_) {
89 if (!subshape.is_static()) {
90 return false;
91 }
92 }
93 }
94 return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
95 }
96
// Removes dimension `dim_to_delete` from this array shape: erases its extent
// and dynamic flag, and rewrites the layout's minor-to-major order to drop
// the deleted dimension and renumber the ones above it.
void Shape::DeleteDimension(int64 dim_to_delete) {
  CHECK(IsArray());
  CHECK_GE(dim_to_delete, 0);
  CHECK_LT(dim_to_delete, dimensions_.size());
  dimensions_.erase(dimensions_.begin() + dim_to_delete);
  dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete);
  if (LayoutUtil::HasLayout(*this)) {
    layout_.set_format(DENSE);
    // Walk minor_to_major in place: erase the entry equal to dim_to_delete,
    // and shift down every dimension number greater than it. `i` advances
    // only when no entry was erased at position i.
    for (int64 i = 0; i < layout_.minor_to_major().size();) {
      if (layout_.minor_to_major(i) == dim_to_delete) {
        layout_.mutable_minor_to_major()->erase(
            layout_.mutable_minor_to_major()->begin() + i);
        continue;
      }
      if (layout_.minor_to_major(i) > dim_to_delete) {
        (*layout_.mutable_minor_to_major())[i] -= 1;
      }
      ++i;
    }
  }
}
118
// Compares two shapes for equality, honoring whatever Ignore* options were
// set on this Equal instance. Tuples compare element-wise (recursively with
// the same options); non-tuple, non-array types compare by element type only;
// arrays compare element type, dimensions, layout, and dynamic flags.
bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
  if (lhs.IsTuple()) {
    // Recurse on each element pair, forwarding this comparator's options.
    return rhs.IsTuple() &&
           absl::c_equal(
               lhs.tuple_shapes(), rhs.tuple_shapes(),
               [=](const Shape& l, const Shape& r) { return (*this)(l, r); });
  } else if (!lhs.IsArray()) {
    // Non-tuple, non-array types such as opaque and token types are trivially
    // the same.
    return lhs.element_type() == rhs.element_type();
  }

  // lhs is an array from here on; rhs must be one too.
  if (!rhs.IsArray()) {
    return false;
  }

  if (!ignore_element_type_) {
    if ((ignore_fp_precision_ &&
         !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) ||
        (!ignore_fp_precision_ && !ShapeUtil::SameElementType(lhs, rhs))) {
      VLOG(3) << "CompareShapes: lhs element type != rhs element type";
      return false;
    }
  }

  if (!ShapeUtil::SameDimensions(lhs, rhs)) {
    VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
    return false;
  }

  if (!ignore_layout_) {
    if (lhs.layout().format() != rhs.layout().format()) {
      VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
      return false;
    }
    if (LayoutUtil::IsDenseArray(lhs)) {
      // Forward the layout-specific Ignore* options to Layout::Equal.
      Layout::Equal equal;
      if (ignore_tiles_in_layout_) {
        equal.IgnoreTiles();
      }
      if (ignore_element_size_in_layout_) {
        equal.IgnoreElementSize();
      }
      if (ignore_memory_space_in_layout_) {
        equal.IgnoreMemorySpace();
      }
      if (!equal(lhs.layout(), rhs.layout())) {
        VLOG(3) << "CompareShapes: lhs layout != rhs layout";
        return false;
      }
    }
  }

  if (!ignore_dynamic_dimension_) {
    for (int i = 0; i < lhs.rank(); ++i) {
      if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {
        VLOG(3)
            << "CompareShapes: lhs and rhs have different dynamic dimensions.";
        return false;
      }
    }
  }
  return true;
}
183
operator <<(std::ostream & out,const Shape & shape)184 std::ostream& operator<<(std::ostream& out, const Shape& shape) {
185 out << shape.ToString(/*print_layout=*/true);
186 return out;
187 }
188
ProgramShape(const ProgramShapeProto & program_shape_proto)189 ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
190 for (const ShapeProto& shape_proto : program_shape_proto.parameters()) {
191 *add_parameters() = Shape(shape_proto);
192 }
193 *mutable_result() = Shape(program_shape_proto.result());
194 for (const string& name : program_shape_proto.parameter_names()) {
195 add_parameter_names(name);
196 }
197 }
198
ToProto() const199 ProgramShapeProto ProgramShape::ToProto() const {
200 ProgramShapeProto proto;
201 for (const Shape& shape : parameters()) {
202 *proto.add_parameters() = shape.ToProto();
203 }
204 *proto.mutable_result() = result().ToProto();
205 for (const string& name : parameter_names()) {
206 proto.add_parameter_names(name);
207 }
208 return proto;
209 }
210
ToString() const211 string ProgramShape::ToString() const {
212 std::vector<string> parameter_strings(parameters_size());
213 for (int i = 0; i < parameters_size(); ++i) {
214 parameter_strings[i] = absl::StrCat(
215 i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ",
216 ShapeUtil::HumanString(parameters(i)));
217 }
218 return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ",
219 ShapeUtil::HumanString(result()));
220 }
221
operator <<(std::ostream & out,const ProgramShape & program_shape)222 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
223 out << program_shape.ToString() << "\n";
224 return out;
225 }
226
227 } // namespace xla
228