1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/container/inlined_vector.h" 23 #include "absl/types/optional.h" 24 #include "tensorflow/compiler/xla/layout.h" 25 #include "tensorflow/compiler/xla/primitive_util.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace xla { 31 32 // A shape describes the number of dimensions in a array, the bounds of each 33 // dimension, and the primitive component type. For tuples, shape describes the 34 // structure (number of elements and nesting). 35 class Shape { 36 public: 37 Shape() = default; 38 39 // Construct a shape from a ShapeProto. 40 explicit Shape(const ShapeProto& shape_proto); 41 Shape(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const bool> dynamic_dimensions,std::vector<Shape> tuple_shapes)42 Shape(PrimitiveType element_type, absl::Span<const int64> dimensions, 43 absl::Span<const bool> dynamic_dimensions, 44 std::vector<Shape> tuple_shapes) 45 : element_type_(element_type), 46 dimensions_(dimensions.begin(), dimensions.end()), 47 dynamic_dimensions_(dynamic_dimensions.begin(), 48 dynamic_dimensions.end()), 49 tuple_shapes_(std::move(tuple_shapes)) {} 50 51 // Returns a ShapeProto representation of the Shape. 52 ShapeProto ToProto() const; 53 54 // Returns a human-readable string that represents the given shape, with or 55 // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". 56 string ToString(bool print_layout = false) const; 57 58 // Returns the rank (number of dimensions) of the given shape. Shape must be 59 // an array. rank()60 int64 rank() const { 61 DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); 62 return dimensions_.size(); 63 } 64 65 // Returns whether the shape is of the specified type (array, tuple, etc). IsArray()66 bool IsArray() const { return primitive_util::IsArrayType(element_type()); } IsTuple()67 bool IsTuple() const { return element_type() == TUPLE; } IsToken()68 bool IsToken() const { return element_type() == TOKEN; } IsOpaque()69 bool IsOpaque() const { return element_type() == OPAQUE_TYPE; } 70 71 // Returns whether all elements in the shape are integer. 72 // A nested tuple of integers is considered as integer. 73 bool IsInteger() const; 74 75 // Returns true if no array dimension in the shape is dynamically sized. Tuple 76 // shapes are traversed recursively. 77 bool is_static() const; 78 is_dynamic()79 bool is_dynamic() const { return !is_static(); } 80 81 // Returns true if the given dimension is dynamically-sized. is_dynamic_dimension(int dimension)82 bool is_dynamic_dimension(int dimension) const { 83 return dynamic_dimensions_.at(dimension); 84 } 85 86 // Sets whether or not the given dimension is dynamically-sized. set_dynamic_dimension(int dimension,bool is_dynamic)87 void set_dynamic_dimension(int dimension, bool is_dynamic) { 88 dynamic_dimensions_[dimension] = is_dynamic; 89 } 90 dynamic_dimensions()91 absl::Span<const bool> dynamic_dimensions() const { 92 return dynamic_dimensions_; 93 } 94 mutable_dynamic_dimensions()95 absl::Span<bool> mutable_dynamic_dimensions() { 96 return absl::MakeSpan(dynamic_dimensions_); 97 } 98 99 // Add dimension_upper_bound(). 100 101 // Removes the given dimension form the shape. Layout, if it exists, is 102 // adjusted to match the modified shape. 103 void DeleteDimension(int64_t dim_to_delete); 104 105 // The following methods mirror the protobuf generated code interface for the 106 // message ShapeProto. This enabled easy migration of this data structure 107 // from a proto to a proper C++ class. 108 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 109 // interface. 110 111 // Methods for accessing the primitive type. element_type()112 PrimitiveType element_type() const { return element_type_; } set_element_type(PrimitiveType value)113 void set_element_type(PrimitiveType value) { element_type_ = value; } 114 115 // Methods for accessing the dimensions array. dimensions_size()116 int dimensions_size() const { return dimensions_.size(); } dimensions(int index)117 int64 dimensions(int index) const { return dimensions_.at(index); } set_dimensions(int index,int64_t value)118 void set_dimensions(int index, int64_t value) { 119 dimensions_.at(index) = value; 120 } add_dimensions(int64_t value)121 void add_dimensions(int64_t value) { 122 dimensions_.push_back(value); 123 dynamic_dimensions_.push_back(false); 124 } clear_dimensions()125 void clear_dimensions() { 126 dimensions_.clear(); 127 dynamic_dimensions_.clear(); 128 } dimensions()129 absl::Span<const int64> dimensions() const { return dimensions_; } mutable_dimensions()130 absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); } 131 132 // Methods for accessing the tuple subshapes. This field only non-empty for 133 // tuple shapes. tuple_shapes_size()134 int tuple_shapes_size() const { return tuple_shapes_.size(); } tuple_shapes(int index)135 const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } mutable_tuple_shapes(int index)136 Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } add_tuple_shapes()137 Shape* add_tuple_shapes() { 138 tuple_shapes_.push_back(Shape()); 139 return &tuple_shapes_.back(); 140 } clear_tuple_shapes()141 void clear_tuple_shapes() { tuple_shapes_.clear(); } tuple_shapes()142 const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; } mutable_tuple_shapes()143 std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; } 144 145 // Methods for accessing the layout field. has_layout()146 bool has_layout() const { return layout_.format() != INVALID_FORMAT; } layout()147 const Layout& layout() const { return layout_; } mutable_layout()148 Layout* mutable_layout() { return &layout_; } clear_layout()149 void clear_layout() { layout_.Clear(); } 150 151 // Recursively clear dynamic dimension of a shape. clear_dynamic_dimensions()152 void clear_dynamic_dimensions() { 153 if (!IsTuple()) { 154 for (int64_t i = 0; i < dynamic_dimensions_.size(); ++i) { 155 dynamic_dimensions_[i] = false; 156 } 157 return; 158 } 159 for (auto& subshape : tuple_shapes_) { 160 subshape.clear_dynamic_dimensions(); 161 } 162 } 163 Swap(Shape * other)164 void Swap(Shape* other) { 165 using std::swap; 166 swap(*this, *other); 167 } 168 Clear()169 void Clear() { 170 element_type_ = PRIMITIVE_TYPE_INVALID; 171 clear_dimensions(); 172 tuple_shapes_.clear(); 173 clear_layout(); 174 } 175 SerializeAsString()176 string SerializeAsString() const { return ToProto().SerializeAsString(); } ShortDebugString()177 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()178 string DebugString() const { return ToProto().DebugString(); } 179 180 // Equal is a configurable functor to check the equality of two shapes. 181 // 182 // Examples: 183 // 184 // - Comparing two shapes ignoring their layout difference: 185 // Equal().IgnoreLayout()(shape1, shape2); 186 // 187 // - Comparing two shapes ignoring their layout and element type difference: 188 // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); 189 class Equal { 190 public: 191 Equal() = default; 192 193 bool operator()(const Shape& lhs, const Shape& rhs); 194 IgnoreLayout()195 Equal& IgnoreLayout() { 196 ignore_layout_ = true; 197 return *this; 198 } IgnoreTilesInLayout()199 Equal& IgnoreTilesInLayout() { 200 ignore_tiles_in_layout_ = true; 201 return *this; 202 } IgnoreElementSizeInLayout()203 Equal& IgnoreElementSizeInLayout() { 204 ignore_element_size_in_layout_ = true; 205 return *this; 206 } IgnoreMemorySpaceInLayout()207 Equal& IgnoreMemorySpaceInLayout() { 208 ignore_memory_space_in_layout_ = true; 209 return *this; 210 } MinorToMajorOnlyInLayout()211 Equal& MinorToMajorOnlyInLayout() { 212 ignore_tiles_in_layout_ = true; 213 ignore_element_size_in_layout_ = true; 214 ignore_memory_space_in_layout_ = true; 215 return *this; 216 } IgnoreElementType()217 Equal& IgnoreElementType() { 218 ignore_element_type_ = true; 219 return *this; 220 } IgnoreFpPrecision()221 Equal& IgnoreFpPrecision() { 222 ignore_fp_precision_ = true; 223 return *this; 224 } IgnoreDynamicDimension()225 Equal& IgnoreDynamicDimension() { 226 ignore_dynamic_dimension_ = true; 227 return *this; 228 } IgnoreDimensions()229 Equal& IgnoreDimensions() { 230 ignore_dimensions_ = true; 231 return *this; 232 } 233 234 private: 235 bool ignore_layout_ = false; 236 bool ignore_tiles_in_layout_ = false; 237 bool ignore_element_size_in_layout_ = false; 238 bool ignore_memory_space_in_layout_ = false; 239 bool ignore_element_type_ = false; 240 bool ignore_fp_precision_ = false; 241 bool ignore_dynamic_dimension_ = false; 242 bool ignore_dimensions_ = false; 243 }; 244 245 // Test that all fields of the shape are the same, equivalent to Equal(). 246 bool operator==(const Shape& other) const { return Equal()(*this, other); } 247 bool operator!=(const Shape& other) const { return !(*this == other); } 248 249 template <typename H> AbslHashValue(H h,const Shape & s)250 friend H AbslHashValue(H h, const Shape& s) { 251 return H::combine(std::move(h), s.element_type_, s.dimensions_, 252 s.dynamic_dimensions_, s.tuple_shapes_, s.layout_); 253 } 254 255 private: 256 // The element type of this shape (tuple, array, etc). 257 PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; 258 259 // The array bounds of the dimensions. This is nonempty only for array 260 // shapes. For a dynamically-sized dimension, the respective value in this 261 // vector is an inclusive upper limit of the array bound. 262 absl::InlinedVector<int64, 6> dimensions_; 263 264 // This vector is the same size as 'dimensions_' and indicates whether the 265 // respective dimension is dynamically sized. 266 absl::InlinedVector<bool, 6> dynamic_dimensions_; 267 268 // The tuple element subshapes. This is nonempty only for tuple shapes. 269 std::vector<Shape> tuple_shapes_; 270 271 // The layout of the shape. Only relevant for arrays. 272 Layout layout_; 273 }; 274 275 // Shape of the parameters and output of an XLA computation. This is analogous 276 // to a traditional function signature. 277 class ProgramShape { 278 public: 279 ProgramShape() = default; 280 281 // Creates a ProgramShape from a ProgramShapeProto protobuf. 282 explicit ProgramShape(const ProgramShapeProto& program_shape_proto); 283 284 // Returns a proto representation of the object. 285 ProgramShapeProto ToProto() const; 286 287 string ToString() const; 288 289 // The following methods mirror the protobuf generated code interface for the 290 // message ProgramShapeProto. This enabled easy migration of this data 291 // structure from a proto to a proper C++ class. 292 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 293 // interface. 294 295 // Methods for accessing and manipulating the Shape of the parameters. parameters_size()296 int parameters_size() const { return parameters_.size(); } parameters(int index)297 const Shape& parameters(int index) const { return parameters_.at(index); } mutable_parameters(int index)298 Shape* mutable_parameters(int index) { return ¶meters_.at(index); } add_parameters()299 Shape* add_parameters() { 300 parameters_.emplace_back(); 301 return ¶meters_.back(); 302 } clear_parameters()303 void clear_parameters() { parameters_.clear(); } parameters()304 const std::vector<Shape>& parameters() const { return parameters_; } mutable_parameters()305 std::vector<Shape>* mutable_parameters() { return ¶meters_; } 306 307 // Methods for accessing and manipulating the Shape of the result. result()308 const Shape& result() const { return result_; } mutable_result()309 Shape* mutable_result() { return &result_; } 310 311 // Methods for accessing and manipulating the names of the parameters. parameter_names_size()312 int parameter_names_size() const { return parameter_names_.size(); } parameter_names(int index)313 const string& parameter_names(int index) const { 314 return parameter_names_.at(index); 315 } set_parameter_names(int index,const string & value)316 void set_parameter_names(int index, const string& value) { 317 parameter_names_.at(index) = value; 318 } mutable_parameter_names(int index)319 string* mutable_parameter_names(int index) { 320 return ¶meter_names_.at(index); 321 } add_parameter_names(const string & value)322 void add_parameter_names(const string& value) { 323 parameter_names_.push_back(value); 324 } add_parameter_names()325 string* add_parameter_names() { 326 parameter_names_.push_back(""); 327 return ¶meter_names_.back(); 328 } clear_parameter_names()329 void clear_parameter_names() { parameter_names_.clear(); } parameter_names()330 const std::vector<string>& parameter_names() const { 331 return parameter_names_; 332 } mutable_parameter_names()333 std::vector<string>* mutable_parameter_names() { return ¶meter_names_; } 334 ShortDebugString()335 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()336 string DebugString() const { return ToProto().DebugString(); } 337 338 private: 339 // The shapes of the parameters of the computation represented by this object. 340 std::vector<Shape> parameters_; 341 342 // The names of the parameters of the computation represented by this object. 343 std::vector<string> parameter_names_; 344 345 // The shape of the result of the computation represented by this object. 346 Shape result_; 347 }; 348 349 std::ostream& operator<<(std::ostream& out, const Shape& shape); 350 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); 351 352 } // namespace xla 353 354 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ 355