• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/container/inlined_vector.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/layout.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace xla {
31 
32 // A shape describes the number of dimensions in a array, the bounds of each
33 // dimension, and the primitive component type. For tuples, shape describes the
34 // structure (number of elements and nesting).
35 class Shape {
36  public:
37   Shape() = default;
38 
39   // Construct a shape from a ShapeProto.
40   explicit Shape(const ShapeProto& shape_proto);
41 
Shape(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const bool> dynamic_dimensions,std::vector<Shape> tuple_shapes)42   Shape(PrimitiveType element_type, absl::Span<const int64> dimensions,
43         absl::Span<const bool> dynamic_dimensions,
44         std::vector<Shape> tuple_shapes)
45       : element_type_(element_type),
46         dimensions_(dimensions.begin(), dimensions.end()),
47         dynamic_dimensions_(dynamic_dimensions.begin(),
48                             dynamic_dimensions.end()),
49         tuple_shapes_(std::move(tuple_shapes)) {}
50 
51   // Returns a ShapeProto representation of the Shape.
52   ShapeProto ToProto() const;
53 
54   // Returns a human-readable string that represents the given shape, with or
55   // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
56   string ToString(bool print_layout = false) const;
57 
58   // Returns the rank (number of dimensions) of the given shape. Shape must be
59   // an array.
rank()60   int64 rank() const {
61     DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString();
62     return dimensions_.size();
63   }
64 
65   // Returns whether the shape is of the specified type (array, tuple, etc).
IsArray()66   bool IsArray() const { return primitive_util::IsArrayType(element_type()); }
IsTuple()67   bool IsTuple() const { return element_type() == TUPLE; }
IsToken()68   bool IsToken() const { return element_type() == TOKEN; }
IsOpaque()69   bool IsOpaque() const { return element_type() == OPAQUE_TYPE; }
70 
71   // Returns whether all elements in the shape are integer.
72   // A nested tuple of integers is considered as integer.
73   bool IsInteger() const;
74 
75   // Returns true if no array dimension in the shape is dynamically sized. Tuple
76   // shapes are traversed recursively.
77   bool is_static() const;
78 
is_dynamic()79   bool is_dynamic() const { return !is_static(); }
80 
81   // Returns true if the given dimension is dynamically-sized.
is_dynamic_dimension(int dimension)82   bool is_dynamic_dimension(int dimension) const {
83     return dynamic_dimensions_.at(dimension);
84   }
85 
86   // Sets whether or not the given dimension is dynamically-sized.
set_dynamic_dimension(int dimension,bool is_dynamic)87   void set_dynamic_dimension(int dimension, bool is_dynamic) {
88     dynamic_dimensions_[dimension] = is_dynamic;
89   }
90 
dynamic_dimensions()91   absl::Span<const bool> dynamic_dimensions() const {
92     return dynamic_dimensions_;
93   }
94 
mutable_dynamic_dimensions()95   absl::Span<bool> mutable_dynamic_dimensions() {
96     return absl::MakeSpan(dynamic_dimensions_);
97   }
98 
99   // Add dimension_upper_bound().
100 
101   // Removes the given dimension form the shape. Layout, if it exists, is
102   // adjusted to match the modified shape.
103   void DeleteDimension(int64_t dim_to_delete);
104 
105   // The following methods mirror the protobuf generated code interface for the
106   // message ShapeProto. This enabled easy migration of this data structure
107   // from a proto to a proper C++ class.
108   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
109   // interface.
110 
111   // Methods for accessing the primitive type.
element_type()112   PrimitiveType element_type() const { return element_type_; }
set_element_type(PrimitiveType value)113   void set_element_type(PrimitiveType value) { element_type_ = value; }
114 
115   // Methods for accessing the dimensions array.
dimensions_size()116   int dimensions_size() const { return dimensions_.size(); }
dimensions(int index)117   int64 dimensions(int index) const { return dimensions_.at(index); }
set_dimensions(int index,int64_t value)118   void set_dimensions(int index, int64_t value) {
119     dimensions_.at(index) = value;
120   }
add_dimensions(int64_t value)121   void add_dimensions(int64_t value) {
122     dimensions_.push_back(value);
123     dynamic_dimensions_.push_back(false);
124   }
clear_dimensions()125   void clear_dimensions() {
126     dimensions_.clear();
127     dynamic_dimensions_.clear();
128   }
dimensions()129   absl::Span<const int64> dimensions() const { return dimensions_; }
mutable_dimensions()130   absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); }
131 
132   // Methods for accessing the tuple subshapes. This field only non-empty for
133   // tuple shapes.
tuple_shapes_size()134   int tuple_shapes_size() const { return tuple_shapes_.size(); }
tuple_shapes(int index)135   const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
mutable_tuple_shapes(int index)136   Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
add_tuple_shapes()137   Shape* add_tuple_shapes() {
138     tuple_shapes_.push_back(Shape());
139     return &tuple_shapes_.back();
140   }
clear_tuple_shapes()141   void clear_tuple_shapes() { tuple_shapes_.clear(); }
tuple_shapes()142   const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
mutable_tuple_shapes()143   std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
144 
145   // Methods for accessing the layout field.
has_layout()146   bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
layout()147   const Layout& layout() const { return layout_; }
mutable_layout()148   Layout* mutable_layout() { return &layout_; }
clear_layout()149   void clear_layout() { layout_.Clear(); }
150 
151   // Recursively clear dynamic dimension of a shape.
clear_dynamic_dimensions()152   void clear_dynamic_dimensions() {
153     if (!IsTuple()) {
154       for (int64_t i = 0; i < dynamic_dimensions_.size(); ++i) {
155         dynamic_dimensions_[i] = false;
156       }
157       return;
158     }
159     for (auto& subshape : tuple_shapes_) {
160       subshape.clear_dynamic_dimensions();
161     }
162   }
163 
Swap(Shape * other)164   void Swap(Shape* other) {
165     using std::swap;
166     swap(*this, *other);
167   }
168 
Clear()169   void Clear() {
170     element_type_ = PRIMITIVE_TYPE_INVALID;
171     clear_dimensions();
172     tuple_shapes_.clear();
173     clear_layout();
174   }
175 
SerializeAsString()176   string SerializeAsString() const { return ToProto().SerializeAsString(); }
ShortDebugString()177   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()178   string DebugString() const { return ToProto().DebugString(); }
179 
180   // Equal is a configurable functor to check the equality of two shapes.
181   //
182   // Examples:
183   //
184   // - Comparing two shapes ignoring their layout difference:
185   //   Equal().IgnoreLayout()(shape1, shape2);
186   //
187   // - Comparing two shapes ignoring their layout and element type difference:
188   //   Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2);
189   class Equal {
190    public:
191     Equal() = default;
192 
193     bool operator()(const Shape& lhs, const Shape& rhs);
194 
IgnoreLayout()195     Equal& IgnoreLayout() {
196       ignore_layout_ = true;
197       return *this;
198     }
IgnoreTilesInLayout()199     Equal& IgnoreTilesInLayout() {
200       ignore_tiles_in_layout_ = true;
201       return *this;
202     }
IgnoreElementSizeInLayout()203     Equal& IgnoreElementSizeInLayout() {
204       ignore_element_size_in_layout_ = true;
205       return *this;
206     }
IgnoreMemorySpaceInLayout()207     Equal& IgnoreMemorySpaceInLayout() {
208       ignore_memory_space_in_layout_ = true;
209       return *this;
210     }
MinorToMajorOnlyInLayout()211     Equal& MinorToMajorOnlyInLayout() {
212       ignore_tiles_in_layout_ = true;
213       ignore_element_size_in_layout_ = true;
214       ignore_memory_space_in_layout_ = true;
215       return *this;
216     }
IgnoreElementType()217     Equal& IgnoreElementType() {
218       ignore_element_type_ = true;
219       return *this;
220     }
IgnoreFpPrecision()221     Equal& IgnoreFpPrecision() {
222       ignore_fp_precision_ = true;
223       return *this;
224     }
IgnoreDynamicDimension()225     Equal& IgnoreDynamicDimension() {
226       ignore_dynamic_dimension_ = true;
227       return *this;
228     }
IgnoreDimensions()229     Equal& IgnoreDimensions() {
230       ignore_dimensions_ = true;
231       return *this;
232     }
233 
234    private:
235     bool ignore_layout_ = false;
236     bool ignore_tiles_in_layout_ = false;
237     bool ignore_element_size_in_layout_ = false;
238     bool ignore_memory_space_in_layout_ = false;
239     bool ignore_element_type_ = false;
240     bool ignore_fp_precision_ = false;
241     bool ignore_dynamic_dimension_ = false;
242     bool ignore_dimensions_ = false;
243   };
244 
245   // Test that all fields of the shape are the same, equivalent to Equal().
246   bool operator==(const Shape& other) const { return Equal()(*this, other); }
247   bool operator!=(const Shape& other) const { return !(*this == other); }
248 
249   template <typename H>
AbslHashValue(H h,const Shape & s)250   friend H AbslHashValue(H h, const Shape& s) {
251     return H::combine(std::move(h), s.element_type_, s.dimensions_,
252                       s.dynamic_dimensions_, s.tuple_shapes_, s.layout_);
253   }
254 
255  private:
256   // The element type of this shape (tuple, array, etc).
257   PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
258 
259   // The array bounds of the dimensions. This is nonempty only for array
260   // shapes. For a dynamically-sized dimension, the respective value in this
261   // vector is an inclusive upper limit of the array bound.
262   absl::InlinedVector<int64, 6> dimensions_;
263 
264   // This vector is the same size as 'dimensions_' and indicates whether the
265   // respective dimension is dynamically sized.
266   absl::InlinedVector<bool, 6> dynamic_dimensions_;
267 
268   // The tuple element subshapes. This is nonempty only for tuple shapes.
269   std::vector<Shape> tuple_shapes_;
270 
271   // The layout of the shape. Only relevant for arrays.
272   Layout layout_;
273 };
274 
275 // Shape of the parameters and output of an XLA computation. This is analogous
276 // to a traditional function signature.
277 class ProgramShape {
278  public:
279   ProgramShape() = default;
280 
281   // Creates a ProgramShape from a ProgramShapeProto protobuf.
282   explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
283 
284   // Returns a proto representation of the object.
285   ProgramShapeProto ToProto() const;
286 
287   string ToString() const;
288 
289   // The following methods mirror the protobuf generated code interface for the
290   // message ProgramShapeProto. This enabled easy migration of this data
291   // structure from a proto to a proper C++ class.
292   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
293   // interface.
294 
295   // Methods for accessing and manipulating the Shape of the parameters.
parameters_size()296   int parameters_size() const { return parameters_.size(); }
parameters(int index)297   const Shape& parameters(int index) const { return parameters_.at(index); }
mutable_parameters(int index)298   Shape* mutable_parameters(int index) { return &parameters_.at(index); }
add_parameters()299   Shape* add_parameters() {
300     parameters_.emplace_back();
301     return &parameters_.back();
302   }
clear_parameters()303   void clear_parameters() { parameters_.clear(); }
parameters()304   const std::vector<Shape>& parameters() const { return parameters_; }
mutable_parameters()305   std::vector<Shape>* mutable_parameters() { return &parameters_; }
306 
307   // Methods for accessing and manipulating the Shape of the result.
result()308   const Shape& result() const { return result_; }
mutable_result()309   Shape* mutable_result() { return &result_; }
310 
311   // Methods for accessing and manipulating the names of the parameters.
parameter_names_size()312   int parameter_names_size() const { return parameter_names_.size(); }
parameter_names(int index)313   const string& parameter_names(int index) const {
314     return parameter_names_.at(index);
315   }
set_parameter_names(int index,const string & value)316   void set_parameter_names(int index, const string& value) {
317     parameter_names_.at(index) = value;
318   }
mutable_parameter_names(int index)319   string* mutable_parameter_names(int index) {
320     return &parameter_names_.at(index);
321   }
add_parameter_names(const string & value)322   void add_parameter_names(const string& value) {
323     parameter_names_.push_back(value);
324   }
add_parameter_names()325   string* add_parameter_names() {
326     parameter_names_.push_back("");
327     return &parameter_names_.back();
328   }
clear_parameter_names()329   void clear_parameter_names() { parameter_names_.clear(); }
parameter_names()330   const std::vector<string>& parameter_names() const {
331     return parameter_names_;
332   }
mutable_parameter_names()333   std::vector<string>* mutable_parameter_names() { return &parameter_names_; }
334 
ShortDebugString()335   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()336   string DebugString() const { return ToProto().DebugString(); }
337 
338  private:
339   // The shapes of the parameters of the computation represented by this object.
340   std::vector<Shape> parameters_;
341 
342   // The names of the parameters of the computation represented by this object.
343   std::vector<string> parameter_names_;
344 
345   // The shape of the result of the computation represented by this object.
346   Shape result_;
347 };
348 
// Stream insertion operators for Shape and ProgramShape. Only declared here;
// the definitions live outside this header.
// NOTE(review): presumably these write the ToString() form — confirm in the
// implementation file.
349 std::ostream& operator<<(std::ostream& out, const Shape& shape);
350 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
351 
352 }  // namespace xla
353 
354 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_H_
355