• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/shape.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 
22 namespace xla {
23 
Shape(const ShapeProto & shape_proto)24 Shape::Shape(const ShapeProto& shape_proto) {
25   set_element_type(shape_proto.element_type());
26   dimensions_.reserve(shape_proto.dimensions_size());
27   for (const int64 dimension : shape_proto.dimensions()) {
28     add_dimensions(dimension);
29   }
30   // A malformed proto may have different is_dynamic_dimension_size and
31   // dimensions_size. Since C++ is evil, and we have no good way of bailing out
32   // in a constructor, conservatively trim the is_dynamic_dimension size.
33   // TODO(b/120111794): Make this a hard error when we have a factory method
34   // instead of a constructor.
35   if (shape_proto.dimensions_size() !=
36       shape_proto.is_dynamic_dimension_size()) {
37     if (shape_proto.is_dynamic_dimension_size() != 0) {
38       LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension "
39                     "fields does not match number of dimension fields";
40     } else {
41       LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty";
42     }
43   }
44   int64 num_dynamic_dimension_fields = std::min(
45       shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size());
46   for (int i = 0; i < num_dynamic_dimension_fields; i++) {
47     dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i);
48   }
49   tuple_shapes_.reserve(shape_proto.tuple_shapes_size());
50   for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) {
51     *add_tuple_shapes() = Shape(element_shape);
52   }
53   if (shape_proto.has_layout()) {
54     *mutable_layout() = Layout::CreateFromProto(shape_proto.layout());
55   }
56 }
57 
ToProto() const58 ShapeProto Shape::ToProto() const {
59   ShapeProto proto;
60   proto.set_element_type(element_type_);
61   proto.mutable_dimensions()->Reserve(dimensions_size());
62   for (const int64 dimension : dimensions()) {
63     proto.add_dimensions(dimension);
64   }
65   for (const bool dynamic : dynamic_dimensions_) {
66     proto.add_is_dynamic_dimension(dynamic);
67   }
68   proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size());
69   for (const Shape& shape : tuple_shapes()) {
70     *proto.add_tuple_shapes() = shape.ToProto();
71   }
72   if (has_layout()) {
73     *proto.mutable_layout() = layout().ToProto();
74   }
75   return proto;
76 }
77 
ToString(bool print_layout) const78 string Shape::ToString(bool print_layout) const {
79   if (print_layout) {
80     return ShapeUtil::HumanStringWithLayout(*this);
81   } else {
82     return ShapeUtil::HumanString(*this);
83   }
84 }
85 
is_static() const86 bool Shape::is_static() const {
87   if (IsTuple()) {
88     for (const Shape& subshape : tuple_shapes_) {
89       if (!subshape.is_static()) {
90         return false;
91       }
92     }
93   }
94   return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
95 }
96 
DeleteDimension(int64 dim_to_delete)97 void Shape::DeleteDimension(int64 dim_to_delete) {
98   CHECK(IsArray());
99   CHECK_GE(dim_to_delete, 0);
100   CHECK_LT(dim_to_delete, dimensions_.size());
101   dimensions_.erase(dimensions_.begin() + dim_to_delete);
102   dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete);
103   if (LayoutUtil::HasLayout(*this)) {
104     layout_.set_format(DENSE);
105     for (int64 i = 0; i < layout_.minor_to_major().size();) {
106       if (layout_.minor_to_major(i) == dim_to_delete) {
107         layout_.mutable_minor_to_major()->erase(
108             layout_.mutable_minor_to_major()->begin() + i);
109         continue;
110       }
111       if (layout_.minor_to_major(i) > dim_to_delete) {
112         (*layout_.mutable_minor_to_major())[i] -= 1;
113       }
114       ++i;
115     }
116   }
117 }
118 
operator ()(const Shape & lhs,const Shape & rhs)119 bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) {
120   if (lhs.IsTuple()) {
121     return rhs.IsTuple() &&
122            absl::c_equal(
123                lhs.tuple_shapes(), rhs.tuple_shapes(),
124                [=](const Shape& l, const Shape& r) { return (*this)(l, r); });
125   } else if (!lhs.IsArray()) {
126     // Non-tuple, non-array tupes such as opaque and token types are trivially
127     // the same.
128     return lhs.element_type() == rhs.element_type();
129   }
130 
131   if (!rhs.IsArray()) {
132     return false;
133   }
134 
135   if (!ignore_element_type_) {
136     if ((ignore_fp_precision_ &&
137          !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) ||
138         (!ignore_fp_precision_ && !ShapeUtil::SameElementType(lhs, rhs))) {
139       VLOG(3) << "CompareShapes: lhs element type != rhs element type";
140       return false;
141     }
142   }
143 
144   if (!ShapeUtil::SameDimensions(lhs, rhs)) {
145     VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
146     return false;
147   }
148 
149   if (!ignore_layout_) {
150     if (lhs.layout().format() != rhs.layout().format()) {
151       VLOG(3) << "CompareShapes: lhs layout format != rhs layout format";
152       return false;
153     }
154     if (LayoutUtil::IsDenseArray(lhs)) {
155       Layout::Equal equal;
156       if (ignore_tiles_in_layout_) {
157         equal.IgnoreTiles();
158       }
159       if (ignore_element_size_in_layout_) {
160         equal.IgnoreElementSize();
161       }
162       if (ignore_memory_space_in_layout_) {
163         equal.IgnoreMemorySpace();
164       }
165       if (!equal(lhs.layout(), rhs.layout())) {
166         VLOG(3) << "CompareShapes: lhs layout != rhs layout";
167         return false;
168       }
169     }
170   }
171 
172   if (!ignore_dynamic_dimension_) {
173     for (int i = 0; i < lhs.rank(); ++i) {
174       if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) {
175         VLOG(3)
176             << "CompareShapes: lhs and rhs have different dynamic dimensions.";
177         return false;
178       }
179     }
180   }
181   return true;
182 }
183 
operator <<(std::ostream & out,const Shape & shape)184 std::ostream& operator<<(std::ostream& out, const Shape& shape) {
185   out << shape.ToString(/*print_layout=*/true);
186   return out;
187 }
188 
ProgramShape(const ProgramShapeProto & program_shape_proto)189 ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
190   for (const ShapeProto& shape_proto : program_shape_proto.parameters()) {
191     *add_parameters() = Shape(shape_proto);
192   }
193   *mutable_result() = Shape(program_shape_proto.result());
194   for (const string& name : program_shape_proto.parameter_names()) {
195     add_parameter_names(name);
196   }
197 }
198 
ToProto() const199 ProgramShapeProto ProgramShape::ToProto() const {
200   ProgramShapeProto proto;
201   for (const Shape& shape : parameters()) {
202     *proto.add_parameters() = shape.ToProto();
203   }
204   *proto.mutable_result() = result().ToProto();
205   for (const string& name : parameter_names()) {
206     proto.add_parameter_names(name);
207   }
208   return proto;
209 }
210 
ToString() const211 string ProgramShape::ToString() const {
212   std::vector<string> parameter_strings(parameters_size());
213   for (int i = 0; i < parameters_size(); ++i) {
214     parameter_strings[i] = absl::StrCat(
215         i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ",
216         ShapeUtil::HumanString(parameters(i)));
217   }
218   return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ",
219                       ShapeUtil::HumanString(result()));
220 }
221 
operator <<(std::ostream & out,const ProgramShape & program_shape)222 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
223   out << program_shape.ToString() << "\n";
224   return out;
225 }
226 
227 }  // namespace xla
228