• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/container/inlined_vector.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/layout.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace xla {
31 
32 // A shape describes the number of dimensions in a array, the bounds of each
33 // dimension, and the primitive component type. For tuples, shape describes the
34 // structure (number of elements and nesting).
35 class Shape {
36  public:
37   Shape() = default;
38 
39   // Construct a shape from a ShapeProto.
40   explicit Shape(const ShapeProto& shape_proto);
41 
Shape(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const bool> dynamic_dimensions,std::vector<Shape> tuple_shapes)42   Shape(PrimitiveType element_type, absl::Span<const int64> dimensions,
43         absl::Span<const bool> dynamic_dimensions,
44         std::vector<Shape> tuple_shapes)
45       : element_type_(element_type),
46         dimensions_(dimensions.begin(), dimensions.end()),
47         dynamic_dimensions_(dynamic_dimensions.begin(),
48                             dynamic_dimensions.end()),
49         tuple_shapes_(std::move(tuple_shapes)) {}
50 
51   // Returns a ShapeProto representation of the Shape.
52   ShapeProto ToProto() const;
53 
54   // Returns a human-readable string that represents the given shape, with or
55   // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
56   string ToString(bool print_layout = false) const;
57 
58   // Returns the rank (number of dimensions) of the given shape. Shape must be
59   // an array.
rank()60   int64 rank() const {
61     DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString();
62     return dimensions_.size();
63   }
64 
65   // Returns whether the shape is of the specified type (array, tuple, etc).
IsArray()66   bool IsArray() const { return primitive_util::IsArrayType(element_type()); }
IsTuple()67   bool IsTuple() const { return element_type() == TUPLE; }
IsToken()68   bool IsToken() const { return element_type() == TOKEN; }
IsOpaque()69   bool IsOpaque() const { return element_type() == OPAQUE_TYPE; }
70 
71   // Returns true if no array dimension in the shape is dynamically sized. Tuple
72   // shapes are traversed recursively.
73   bool is_static() const;
74 
is_dynamic()75   bool is_dynamic() const { return !is_static(); }
76 
77   // Returns true if the given dimension is dynamically-sized.
is_dynamic_dimension(int dimension)78   bool is_dynamic_dimension(int dimension) const {
79     return dynamic_dimensions_.at(dimension);
80   }
81 
82   // Sets whether or not the given dimension is dynamically-sized.
set_dynamic_dimension(int dimension,bool is_dynamic)83   void set_dynamic_dimension(int dimension, bool is_dynamic) {
84     dynamic_dimensions_[dimension] = is_dynamic;
85   }
86 
dynamic_dimensions()87   absl::Span<const bool> dynamic_dimensions() const {
88     return dynamic_dimensions_;
89   }
90 
mutable_dynamic_dimensions()91   absl::Span<bool> mutable_dynamic_dimensions() {
92     return absl::MakeSpan(dynamic_dimensions_);
93   }
94 
95   // Add dimension_upper_bound().
96 
97   // Removes the given dimension form the shape. Layout, if it exists, is
98   // adjusted to match the modified shape.
99   void DeleteDimension(int64 dim_to_delete);
100 
101   // The following methods mirror the protobuf generated code interface for the
102   // message ShapeProto. This enabled easy migration of this data structure
103   // from a proto to a proper C++ class.
104   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
105   // interface.
106 
107   // Methods for accessing the primitive type.
element_type()108   PrimitiveType element_type() const { return element_type_; }
set_element_type(PrimitiveType value)109   void set_element_type(PrimitiveType value) { element_type_ = value; }
110 
111   // Methods for accessing the dimensions array.
dimensions_size()112   int dimensions_size() const { return dimensions_.size(); }
dimensions(int index)113   int64 dimensions(int index) const { return dimensions_.at(index); }
set_dimensions(int index,int64 value)114   void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; }
add_dimensions(int64 value)115   void add_dimensions(int64 value) {
116     dimensions_.push_back(value);
117     dynamic_dimensions_.push_back(false);
118   }
clear_dimensions()119   void clear_dimensions() {
120     dimensions_.clear();
121     dynamic_dimensions_.clear();
122   }
dimensions()123   absl::Span<const int64> dimensions() const { return dimensions_; }
mutable_dimensions()124   absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); }
125 
126   // Methods for accessing the tuple subshapes. This field only non-empty for
127   // tuple shapes.
tuple_shapes_size()128   int tuple_shapes_size() const { return tuple_shapes_.size(); }
tuple_shapes(int index)129   const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
mutable_tuple_shapes(int index)130   Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
add_tuple_shapes()131   Shape* add_tuple_shapes() {
132     tuple_shapes_.push_back(Shape());
133     return &tuple_shapes_.back();
134   }
clear_tuple_shapes()135   void clear_tuple_shapes() { tuple_shapes_.clear(); }
tuple_shapes()136   const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
mutable_tuple_shapes()137   std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
138 
139   // Methods for accessing the layout field.
has_layout()140   bool has_layout() const { return layout_.format() != INVALID_FORMAT; }
layout()141   const Layout& layout() const { return layout_; }
mutable_layout()142   Layout* mutable_layout() { return &layout_; }
clear_layout()143   void clear_layout() { layout_.Clear(); }
144 
145   // Recursively clear dynamic dimension of a shape.
clear_dynamic_dimensions()146   void clear_dynamic_dimensions() {
147     if (!IsTuple()) {
148       for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) {
149         dynamic_dimensions_[i] = false;
150       }
151       return;
152     }
153     for (auto& subshape : tuple_shapes_) {
154       subshape.clear_dynamic_dimensions();
155     }
156   }
157 
Swap(Shape * other)158   void Swap(Shape* other) {
159     using std::swap;
160     swap(*this, *other);
161   }
162 
Clear()163   void Clear() {
164     element_type_ = PRIMITIVE_TYPE_INVALID;
165     clear_dimensions();
166     tuple_shapes_.clear();
167     clear_layout();
168   }
169 
SerializeAsString()170   string SerializeAsString() const { return ToProto().SerializeAsString(); }
ShortDebugString()171   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()172   string DebugString() const { return ToProto().DebugString(); }
173 
174   // Equal is a configurable functor to check the equality of two shapes.
175   //
176   // Examples:
177   //
178   // - Comparing two shapes ignoring their layout difference:
179   //   Equal().IgnoreLayout()(shape1, shape2);
180   //
181   // - Comparing two shapes ignoring their layout and element type difference:
182   //   Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2);
183   class Equal {
184    public:
185     Equal() = default;
186 
187     bool operator()(const Shape& lhs, const Shape& rhs);
188 
IgnoreLayout()189     Equal& IgnoreLayout() {
190       ignore_layout_ = true;
191       return *this;
192     }
IgnoreTilesInLayout()193     Equal& IgnoreTilesInLayout() {
194       ignore_tiles_in_layout_ = true;
195       return *this;
196     }
IgnoreElementSizeInLayout()197     Equal& IgnoreElementSizeInLayout() {
198       ignore_element_size_in_layout_ = true;
199       return *this;
200     }
IgnoreMemorySpaceInLayout()201     Equal& IgnoreMemorySpaceInLayout() {
202       ignore_memory_space_in_layout_ = true;
203       return *this;
204     }
MinorToMajorOnlyInLayout()205     Equal& MinorToMajorOnlyInLayout() {
206       ignore_tiles_in_layout_ = true;
207       ignore_element_size_in_layout_ = true;
208       ignore_memory_space_in_layout_ = true;
209       return *this;
210     }
IgnoreElementType()211     Equal& IgnoreElementType() {
212       ignore_element_type_ = true;
213       return *this;
214     }
IgnoreFpPrecision()215     Equal& IgnoreFpPrecision() {
216       ignore_fp_precision_ = true;
217       return *this;
218     }
IgnoreDynamicDimension()219     Equal& IgnoreDynamicDimension() {
220       ignore_dynamic_dimension_ = true;
221       return *this;
222     }
IgnoreDimensions()223     Equal& IgnoreDimensions() {
224       ignore_dimensions_ = true;
225       return *this;
226     }
227 
228    private:
229     bool ignore_layout_ = false;
230     bool ignore_tiles_in_layout_ = false;
231     bool ignore_element_size_in_layout_ = false;
232     bool ignore_memory_space_in_layout_ = false;
233     bool ignore_element_type_ = false;
234     bool ignore_fp_precision_ = false;
235     bool ignore_dynamic_dimension_ = false;
236     bool ignore_dimensions_ = false;
237   };
238 
239   // Test that all fields of the shape are the same, equivalent to Equal().
240   bool operator==(const Shape& other) const { return Equal()(*this, other); }
241   bool operator!=(const Shape& other) const { return !(*this == other); }
242 
243   template <typename H>
AbslHashValue(H h,const Shape & s)244   friend H AbslHashValue(H h, const Shape& s) {
245     return H::combine(std::move(h), s.element_type_, s.dimensions_,
246                       s.dynamic_dimensions_, s.tuple_shapes_, s.layout_);
247   }
248 
249  private:
250   // The element type of this shape (tuple, array, etc).
251   PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
252 
253   // The array bounds of the dimensions. This is nonempty only for array
254   // shapes. For a dynamically-sized dimension, the respective value in this
255   // vector is an inclusive upper limit of the array bound.
256   absl::InlinedVector<int64, 6> dimensions_;
257 
258   // This vector is the same size as 'dimensions_' and indicates whether the
259   // respective dimension is dynamically sized.
260   absl::InlinedVector<bool, 6> dynamic_dimensions_;
261 
262   // The tuple element subshapes. This is nonempty only for tuple shapes.
263   std::vector<Shape> tuple_shapes_;
264 
265   // The layout of the shape. Only relevant for arrays.
266   Layout layout_;
267 };
268 
269 // Shape of the parameters and output of an XLA computation. This is analogous
270 // to a traditional function signature.
271 class ProgramShape {
272  public:
273   ProgramShape() = default;
274 
275   // Creates a ProgramShape from a ProgramShapeProto protobuf.
276   explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
277 
278   // Returns a proto representation of the object.
279   ProgramShapeProto ToProto() const;
280 
281   string ToString() const;
282 
283   // The following methods mirror the protobuf generated code interface for the
284   // message ProgramShapeProto. This enabled easy migration of this data
285   // structure from a proto to a proper C++ class.
286   // TODO(b/29771030): Replace or augment these methods with a more ergonomic
287   // interface.
288 
289   // Methods for accessing and manipulating the Shape of the parameters.
parameters_size()290   int parameters_size() const { return parameters_.size(); }
parameters(int index)291   const Shape& parameters(int index) const { return parameters_.at(index); }
mutable_parameters(int index)292   Shape* mutable_parameters(int index) { return &parameters_.at(index); }
add_parameters()293   Shape* add_parameters() {
294     parameters_.emplace_back();
295     return &parameters_.back();
296   }
clear_parameters()297   void clear_parameters() { parameters_.clear(); }
parameters()298   const std::vector<Shape>& parameters() const { return parameters_; }
mutable_parameters()299   std::vector<Shape>* mutable_parameters() { return &parameters_; }
300 
301   // Methods for accessing and manipulating the Shape of the result.
result()302   const Shape& result() const { return result_; }
mutable_result()303   Shape* mutable_result() { return &result_; }
304 
305   // Methods for accessing and manipulating the names of the parameters.
parameter_names_size()306   int parameter_names_size() const { return parameter_names_.size(); }
parameter_names(int index)307   const string& parameter_names(int index) const {
308     return parameter_names_.at(index);
309   }
set_parameter_names(int index,const string & value)310   void set_parameter_names(int index, const string& value) {
311     parameter_names_.at(index) = value;
312   }
mutable_parameter_names(int index)313   string* mutable_parameter_names(int index) {
314     return &parameter_names_.at(index);
315   }
add_parameter_names(const string & value)316   void add_parameter_names(const string& value) {
317     parameter_names_.push_back(value);
318   }
add_parameter_names()319   string* add_parameter_names() {
320     parameter_names_.push_back("");
321     return &parameter_names_.back();
322   }
clear_parameter_names()323   void clear_parameter_names() { parameter_names_.clear(); }
parameter_names()324   const std::vector<string>& parameter_names() const {
325     return parameter_names_;
326   }
mutable_parameter_names()327   std::vector<string>* mutable_parameter_names() { return &parameter_names_; }
328 
ShortDebugString()329   string ShortDebugString() const { return ToProto().ShortDebugString(); }
DebugString()330   string DebugString() const { return ToProto().DebugString(); }
331 
332  private:
333   // The shapes of the parameters of the computation represented by this object.
334   std::vector<Shape> parameters_;
335 
336   // The names of the parameters of the computation represented by this object.
337   std::vector<string> parameter_names_;
338 
339   // The shape of the result of the computation represented by this object.
340   Shape result_;
341 };
342 
343 std::ostream& operator<<(std::ostream& out, const Shape& shape);
344 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
345 
346 }  // namespace xla
347 
348 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_H_
349