• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_
18 
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
29 
30 namespace xla {
31 
32 // A shape describes the number of dimensions in a array, the bounds of each
33 // dimension, and the primitive component type. For tuples, shape describes the
34 // structure (number of elements and nesting).
35 class Shape {
36  public:
37   Shape() = default;
38 
39   // Construct a shape from a ShapeProto.
40   explicit Shape(const ShapeProto& shape_proto);
41 
42   // Returns a ShapeProto representation of the Shape.
43   ShapeProto ToProto() const;
44 
45   // Returns a human-readable string that represents the given shape, with or
46   // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
47   string ToString(bool print_layout = false) const;
48 
49   // Returns the rank (number of dimensions) of the given shape. Shape must be
50   // an array.
rank()51   int64 rank() const {
52     CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString();
53     return dimensions_.size();
54   }
55 
56   // Returns whether the shape is of the specified type (array, tuple, etc).
IsArray()57   bool IsArray() const { return primitive_util::IsArrayType(element_type()); }
IsTuple()58   bool IsTuple() const { return element_type() == TUPLE; }
IsToken()59   bool IsToken() const { return element_type() == TOKEN; }
IsOpaque()60   bool IsOpaque() const { return element_type() == OPAQUE_TYPE; }
61 
62   // Returns true if no array dimension in the shape is dynamically sized. Tuple
63   // shapes are traversed recursively.
64   bool is_static() const;
65 
66   // Returns true if the given dimension is dynamically-sized.
is_dynamic_dimension(int dimension)67   bool is_dynamic_dimension(int dimension) const {
68     return dynamic_dimensions_.at(dimension);
69   }
70 
71   // Sets whether or not the given dimension is dynamically-sized.
set_dynamic_dimension(int dimension,bool is_dynamic)72   void set_dynamic_dimension(int dimension, bool is_dynamic) {
73     dynamic_dimensions_[dimension] = is_dynamic;
74   }
75 
dynamic_dimensions()76   absl::Span<const bool> dynamic_dimensions() const {
77     return dynamic_dimensions_;
78   }
79 
mutable_dynamic_dimensions()80   absl::Span<bool> mutable_dynamic_dimensions() {
81     return absl::MakeSpan(dynamic_dimensions_);
82   }
83 
84   // Add dimension_upper_bound().
85 
86   // Removes the given dimension form the shape. Layout, if it exists, is
87   // adjusted to match the modified shape.
88   void DeleteDimension(int64 dim_to_delete);
89 
90   // The following methods mirror the protobuf generated code interface for the
91   // message ShapeProto. This enabled easy migration of this data structure
92   // from a proto to a proper C++ class.
93   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
94   // interface.
95 
96   // Methods for accessing the primitive type.
element_type()97   PrimitiveType element_type() const { return element_type_; }
set_element_type(PrimitiveType value)98   void set_element_type(PrimitiveType value) { element_type_ = value; }
99 
100   // Methods for accessing the dimensions array.
dimensions_size()101   int dimensions_size() const { return dimensions_.size(); }
dimensions(int index)102   int64 dimensions(int index) const { return dimensions_.at(index); }
set_dimensions(int index,int64 value)103   void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; }
add_dimensions(int64 value)104   void add_dimensions(int64 value) {
105     dimensions_.push_back(value);
106     dynamic_dimensions_.push_back(false);
107   }
clear_dimensions()108   void clear_dimensions() {
109     dimensions_.clear();
110     dynamic_dimensions_.clear();
111   }
dimensions()112   absl::Span<const int64> dimensions() const { return dimensions_; }
mutable_dimensions()113   absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); }
114 
115   // Methods for accessing the tuple subshapes. This field only non-empty for
116   // tuple shapes.
tuple_shapes_size()117   int tuple_shapes_size() const { return tuple_shapes_.size(); }
tuple_shapes(int index)118   const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
mutable_tuple_shapes(int index)119   Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
add_tuple_shapes()120   Shape* add_tuple_shapes() {
121     tuple_shapes_.push_back(Shape());
122     return &tuple_shapes_.back();
123   }
clear_tuple_shapes()124   void clear_tuple_shapes() { tuple_shapes_.clear(); }
tuple_shapes()125   const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
mutable_tuple_shapes()126   std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
127 
128   // Methods for accessing the layout field.
has_layout()129   bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
layout()130   const Layout& layout() const { return layout_; }
mutable_layout()131   Layout* mutable_layout() { return &layout_; }
clear_layout()132   void clear_layout() { layout_.Clear(); }
133 
134   // Recursively clear dynamic dimension of a shape.
clear_dynamic_dimensions()135   void clear_dynamic_dimensions() {
136     if (!IsTuple()) {
137       for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) {
138         dynamic_dimensions_[i] = false;
139       }
140       return;
141     }
142     for (auto& subshape : tuple_shapes_) {
143       subshape.clear_dynamic_dimensions();
144     }
145   }
146 
Swap(Shape * other)147   void Swap(Shape* other) {
148     using std::swap;
149     swap(*this, *other);
150   }
151 
Clear()152   void Clear() {
153     element_type_ = PRIMITIVE_TYPE_INVALID;
154     clear_dimensions();
155     tuple_shapes_.clear();
156     clear_layout();
157   }
158 
SerializeAsString()159   string SerializeAsString() const { return ToProto().SerializeAsString(); }
ShortDebugString()160   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()161   string DebugString() const { return ToProto().DebugString(); }
162 
163   // Equal is a configurable functor to check the equality of two shapes.
164   //
165   // Examples:
166   //
167   // - Comparing two shapes ignoring their layout difference:
168   //   Equal().IgnoreLayout()(shape1, shape2);
169   //
170   // - Comparing two shapes ignoring their layout and element type difference:
171   //   Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2);
172   class Equal {
173    public:
174     Equal() = default;
175 
176     bool operator()(const Shape& lhs, const Shape& rhs);
177 
IgnoreLayout()178     Equal& IgnoreLayout() {
179       ignore_layout_ = true;
180       return *this;
181     }
IgnoreTilesInLayout()182     Equal& IgnoreTilesInLayout() {
183       ignore_tiles_in_layout_ = true;
184       return *this;
185     }
IgnoreElementSizeInLayout()186     Equal& IgnoreElementSizeInLayout() {
187       ignore_element_size_in_layout_ = true;
188       return *this;
189     }
IgnoreMemorySpaceInLayout()190     Equal& IgnoreMemorySpaceInLayout() {
191       ignore_memory_space_in_layout_ = true;
192       return *this;
193     }
MinorToMajorOnlyInLayout()194     Equal& MinorToMajorOnlyInLayout() {
195       ignore_tiles_in_layout_ = true;
196       ignore_element_size_in_layout_ = true;
197       ignore_memory_space_in_layout_ = true;
198       return *this;
199     }
IgnoreElementType()200     Equal& IgnoreElementType() {
201       ignore_element_type_ = true;
202       return *this;
203     }
IgnoreFpPrecision()204     Equal& IgnoreFpPrecision() {
205       ignore_fp_precision_ = true;
206       return *this;
207     }
IgnoreDynamicDimension()208     Equal& IgnoreDynamicDimension() {
209       ignore_dynamic_dimension_ = true;
210       return *this;
211     }
212 
213    private:
214     bool ignore_layout_ = false;
215     bool ignore_tiles_in_layout_ = false;
216     bool ignore_element_size_in_layout_ = false;
217     bool ignore_memory_space_in_layout_ = false;
218     bool ignore_element_type_ = false;
219     bool ignore_fp_precision_ = false;
220     bool ignore_dynamic_dimension_ = false;
221   };
222 
223   // Test that all fields of the shape are the same, equivalent to Equal().
224   bool operator==(const Shape& other) const { return Equal()(*this, other); }
225   bool operator!=(const Shape& other) const { return !(*this == other); }
226 
227   template <typename H>
AbslHashValue(H h,const Shape & s)228   friend H AbslHashValue(H h, const Shape& s) {
229     return H::combine(std::move(h), s.element_type_, s.dimensions_,
230                       s.dynamic_dimensions_, s.tuple_shapes_, s.layout_);
231   }
232 
233  private:
234   // The element type of this shape (tuple, array, etc).
235   PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
236 
237   // The array bounds of the dimensions. This is nonempty only for array
238   // shapes. For a dynamically-sized dimension, the respective value in this
239   // vector is an inclusive upper limit of the array bound.
240   absl::InlinedVector<int64, 6> dimensions_;
241 
242   // This vector is the same size as 'dimensions_' and indicates whether the
243   // respective dimension is dynamically sized.
244   absl::InlinedVector<bool, 6> dynamic_dimensions_;
245 
246   // The tuple element subshapes. This is nonempty only for tuple shapes.
247   std::vector<Shape> tuple_shapes_;
248 
249   // The layout of the shape. Only relevant for arrays.
250   Layout layout_;
251 };
252 
253 // Shape of the parameters and output of an XLA computation. This is analogous
254 // to a traditional function signature.
255 class ProgramShape {
256  public:
257   ProgramShape() = default;
258 
259   // Creates a ProgramShape from a ProgramShapeProto protobuf.
260   explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
261 
262   // Returns a proto representation of the object.
263   ProgramShapeProto ToProto() const;
264 
265   string ToString() const;
266 
267   // The following methods mirror the protobuf generated code interface for the
268   // message ProgramShapeProto. This enabled easy migration of this data
269   // structure from a proto to a proper C++ class.
270   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
271   // interface.
272 
273   // Methods for accessing and manipulating the Shape of the parameters.
parameters_size()274   int parameters_size() const { return parameters_.size(); }
parameters(int index)275   const Shape& parameters(int index) const { return parameters_.at(index); }
mutable_parameters(int index)276   Shape* mutable_parameters(int index) { return &parameters_.at(index); }
add_parameters()277   Shape* add_parameters() {
278     parameters_.emplace_back();
279     return &parameters_.back();
280   }
clear_parameters()281   void clear_parameters() { parameters_.clear(); }
parameters()282   const std::vector<Shape>& parameters() const { return parameters_; }
mutable_parameters()283   std::vector<Shape>* mutable_parameters() { return &parameters_; }
284 
285   // Methods for accessing and manipulating the Shape of the result.
result()286   const Shape& result() const { return result_; }
mutable_result()287   Shape* mutable_result() { return &result_; }
288 
289   // Methods for accessing and manipulating the names of the parameters.
parameter_names_size()290   int parameter_names_size() const { return parameter_names_.size(); }
parameter_names(int index)291   const string& parameter_names(int index) const {
292     return parameter_names_.at(index);
293   }
set_parameter_names(int index,const string & value)294   void set_parameter_names(int index, const string& value) {
295     parameter_names_.at(index) = value;
296   }
mutable_parameter_names(int index)297   string* mutable_parameter_names(int index) {
298     return &parameter_names_.at(index);
299   }
add_parameter_names(const string & value)300   void add_parameter_names(const string& value) {
301     parameter_names_.push_back(value);
302   }
add_parameter_names()303   string* add_parameter_names() {
304     parameter_names_.push_back("");
305     return &parameter_names_.back();
306   }
clear_parameter_names()307   void clear_parameter_names() { parameter_names_.clear(); }
parameter_names()308   const std::vector<string>& parameter_names() const {
309     return parameter_names_;
310   }
mutable_parameter_names()311   std::vector<string>* mutable_parameter_names() { return &parameter_names_; }
312 
ShortDebugString()313   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()314   string DebugString() const { return ToProto().DebugString(); }
315 
316  private:
317   // The shapes of the parameters of the computation represented by this object.
318   std::vector<Shape> parameters_;
319 
320   // The names of the parameters of the computation represented by this object.
321   std::vector<string> parameter_names_;
322 
323   // The shape of the result of the computation represented by this object.
324   Shape result_;
325 };
326 
327 std::ostream& operator<<(std::ostream& out, const Shape& shape);
328 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
329 
330 }  // namespace xla
331 
332 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_H_
333