1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/container/inlined_vector.h" 23 #include "absl/types/optional.h" 24 #include "tensorflow/compiler/xla/layout.h" 25 #include "tensorflow/compiler/xla/primitive_util.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace xla { 31 32 // A shape describes the number of dimensions in a array, the bounds of each 33 // dimension, and the primitive component type. For tuples, shape describes the 34 // structure (number of elements and nesting). 35 class Shape { 36 public: 37 Shape() = default; 38 39 // Construct a shape from a ShapeProto. 40 explicit Shape(const ShapeProto& shape_proto); 41 Shape(PrimitiveType element_type,absl::Span<const int64> dimensions,absl::Span<const bool> dynamic_dimensions,std::vector<Shape> tuple_shapes)42 Shape(PrimitiveType element_type, absl::Span<const int64> dimensions, 43 absl::Span<const bool> dynamic_dimensions, 44 std::vector<Shape> tuple_shapes) 45 : element_type_(element_type), 46 dimensions_(dimensions.begin(), dimensions.end()), 47 dynamic_dimensions_(dynamic_dimensions.begin(), 48 dynamic_dimensions.end()), 49 tuple_shapes_(std::move(tuple_shapes)) {} 50 51 // Returns a ShapeProto representation of the Shape. 52 ShapeProto ToProto() const; 53 54 // Returns a human-readable string that represents the given shape, with or 55 // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". 56 string ToString(bool print_layout = false) const; 57 58 // Returns the rank (number of dimensions) of the given shape. Shape must be 59 // an array. rank()60 int64 rank() const { 61 DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); 62 return dimensions_.size(); 63 } 64 65 // Returns whether the shape is of the specified type (array, tuple, etc). IsArray()66 bool IsArray() const { return primitive_util::IsArrayType(element_type()); } IsTuple()67 bool IsTuple() const { return element_type() == TUPLE; } IsToken()68 bool IsToken() const { return element_type() == TOKEN; } IsOpaque()69 bool IsOpaque() const { return element_type() == OPAQUE_TYPE; } 70 71 // Returns true if no array dimension in the shape is dynamically sized. Tuple 72 // shapes are traversed recursively. 73 bool is_static() const; 74 is_dynamic()75 bool is_dynamic() const { return !is_static(); } 76 77 // Returns true if the given dimension is dynamically-sized. is_dynamic_dimension(int dimension)78 bool is_dynamic_dimension(int dimension) const { 79 return dynamic_dimensions_.at(dimension); 80 } 81 82 // Sets whether or not the given dimension is dynamically-sized. set_dynamic_dimension(int dimension,bool is_dynamic)83 void set_dynamic_dimension(int dimension, bool is_dynamic) { 84 dynamic_dimensions_[dimension] = is_dynamic; 85 } 86 dynamic_dimensions()87 absl::Span<const bool> dynamic_dimensions() const { 88 return dynamic_dimensions_; 89 } 90 mutable_dynamic_dimensions()91 absl::Span<bool> mutable_dynamic_dimensions() { 92 return absl::MakeSpan(dynamic_dimensions_); 93 } 94 95 // Add dimension_upper_bound(). 96 97 // Removes the given dimension form the shape. Layout, if it exists, is 98 // adjusted to match the modified shape. 99 void DeleteDimension(int64 dim_to_delete); 100 101 // The following methods mirror the protobuf generated code interface for the 102 // message ShapeProto. This enabled easy migration of this data structure 103 // from a proto to a proper C++ class. 104 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 105 // interface. 106 107 // Methods for accessing the primitive type. element_type()108 PrimitiveType element_type() const { return element_type_; } set_element_type(PrimitiveType value)109 void set_element_type(PrimitiveType value) { element_type_ = value; } 110 111 // Methods for accessing the dimensions array. dimensions_size()112 int dimensions_size() const { return dimensions_.size(); } dimensions(int index)113 int64 dimensions(int index) const { return dimensions_.at(index); } set_dimensions(int index,int64 value)114 void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } add_dimensions(int64 value)115 void add_dimensions(int64 value) { 116 dimensions_.push_back(value); 117 dynamic_dimensions_.push_back(false); 118 } clear_dimensions()119 void clear_dimensions() { 120 dimensions_.clear(); 121 dynamic_dimensions_.clear(); 122 } dimensions()123 absl::Span<const int64> dimensions() const { return dimensions_; } mutable_dimensions()124 absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); } 125 126 // Methods for accessing the tuple subshapes. This field only non-empty for 127 // tuple shapes. tuple_shapes_size()128 int tuple_shapes_size() const { return tuple_shapes_.size(); } tuple_shapes(int index)129 const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } mutable_tuple_shapes(int index)130 Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } add_tuple_shapes()131 Shape* add_tuple_shapes() { 132 tuple_shapes_.push_back(Shape()); 133 return &tuple_shapes_.back(); 134 } clear_tuple_shapes()135 void clear_tuple_shapes() { tuple_shapes_.clear(); } tuple_shapes()136 const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; } mutable_tuple_shapes()137 std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; } 138 139 // Methods for accessing the layout field. has_layout()140 bool has_layout() const { return layout_.format() != INVALID_FORMAT; } layout()141 const Layout& layout() const { return layout_; } mutable_layout()142 Layout* mutable_layout() { return &layout_; } clear_layout()143 void clear_layout() { layout_.Clear(); } 144 145 // Recursively clear dynamic dimension of a shape. clear_dynamic_dimensions()146 void clear_dynamic_dimensions() { 147 if (!IsTuple()) { 148 for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) { 149 dynamic_dimensions_[i] = false; 150 } 151 return; 152 } 153 for (auto& subshape : tuple_shapes_) { 154 subshape.clear_dynamic_dimensions(); 155 } 156 } 157 Swap(Shape * other)158 void Swap(Shape* other) { 159 using std::swap; 160 swap(*this, *other); 161 } 162 Clear()163 void Clear() { 164 element_type_ = PRIMITIVE_TYPE_INVALID; 165 clear_dimensions(); 166 tuple_shapes_.clear(); 167 clear_layout(); 168 } 169 SerializeAsString()170 string SerializeAsString() const { return ToProto().SerializeAsString(); } ShortDebugString()171 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()172 string DebugString() const { return ToProto().DebugString(); } 173 174 // Equal is a configurable functor to check the equality of two shapes. 175 // 176 // Examples: 177 // 178 // - Comparing two shapes ignoring their layout difference: 179 // Equal().IgnoreLayout()(shape1, shape2); 180 // 181 // - Comparing two shapes ignoring their layout and element type difference: 182 // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); 183 class Equal { 184 public: 185 Equal() = default; 186 187 bool operator()(const Shape& lhs, const Shape& rhs); 188 IgnoreLayout()189 Equal& IgnoreLayout() { 190 ignore_layout_ = true; 191 return *this; 192 } IgnoreTilesInLayout()193 Equal& IgnoreTilesInLayout() { 194 ignore_tiles_in_layout_ = true; 195 return *this; 196 } IgnoreElementSizeInLayout()197 Equal& IgnoreElementSizeInLayout() { 198 ignore_element_size_in_layout_ = true; 199 return *this; 200 } IgnoreMemorySpaceInLayout()201 Equal& IgnoreMemorySpaceInLayout() { 202 ignore_memory_space_in_layout_ = true; 203 return *this; 204 } MinorToMajorOnlyInLayout()205 Equal& MinorToMajorOnlyInLayout() { 206 ignore_tiles_in_layout_ = true; 207 ignore_element_size_in_layout_ = true; 208 ignore_memory_space_in_layout_ = true; 209 return *this; 210 } IgnoreElementType()211 Equal& IgnoreElementType() { 212 ignore_element_type_ = true; 213 return *this; 214 } IgnoreFpPrecision()215 Equal& IgnoreFpPrecision() { 216 ignore_fp_precision_ = true; 217 return *this; 218 } IgnoreDynamicDimension()219 Equal& IgnoreDynamicDimension() { 220 ignore_dynamic_dimension_ = true; 221 return *this; 222 } IgnoreDimensions()223 Equal& IgnoreDimensions() { 224 ignore_dimensions_ = true; 225 return *this; 226 } 227 228 private: 229 bool ignore_layout_ = false; 230 bool ignore_tiles_in_layout_ = false; 231 bool ignore_element_size_in_layout_ = false; 232 bool ignore_memory_space_in_layout_ = false; 233 bool ignore_element_type_ = false; 234 bool ignore_fp_precision_ = false; 235 bool ignore_dynamic_dimension_ = false; 236 bool ignore_dimensions_ = false; 237 }; 238 239 // Test that all fields of the shape are the same, equivalent to Equal(). 240 bool operator==(const Shape& other) const { return Equal()(*this, other); } 241 bool operator!=(const Shape& other) const { return !(*this == other); } 242 243 template <typename H> AbslHashValue(H h,const Shape & s)244 friend H AbslHashValue(H h, const Shape& s) { 245 return H::combine(std::move(h), s.element_type_, s.dimensions_, 246 s.dynamic_dimensions_, s.tuple_shapes_, s.layout_); 247 } 248 249 private: 250 // The element type of this shape (tuple, array, etc). 251 PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; 252 253 // The array bounds of the dimensions. This is nonempty only for array 254 // shapes. For a dynamically-sized dimension, the respective value in this 255 // vector is an inclusive upper limit of the array bound. 256 absl::InlinedVector<int64, 6> dimensions_; 257 258 // This vector is the same size as 'dimensions_' and indicates whether the 259 // respective dimension is dynamically sized. 260 absl::InlinedVector<bool, 6> dynamic_dimensions_; 261 262 // The tuple element subshapes. This is nonempty only for tuple shapes. 263 std::vector<Shape> tuple_shapes_; 264 265 // The layout of the shape. Only relevant for arrays. 266 Layout layout_; 267 }; 268 269 // Shape of the parameters and output of an XLA computation. This is analogous 270 // to a traditional function signature. 271 class ProgramShape { 272 public: 273 ProgramShape() = default; 274 275 // Creates a ProgramShape from a ProgramShapeProto protobuf. 276 explicit ProgramShape(const ProgramShapeProto& program_shape_proto); 277 278 // Returns a proto representation of the object. 279 ProgramShapeProto ToProto() const; 280 281 string ToString() const; 282 283 // The following methods mirror the protobuf generated code interface for the 284 // message ProgramShapeProto. This enabled easy migration of this data 285 // structure from a proto to a proper C++ class. 286 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 287 // interface. 288 289 // Methods for accessing and manipulating the Shape of the parameters. parameters_size()290 int parameters_size() const { return parameters_.size(); } parameters(int index)291 const Shape& parameters(int index) const { return parameters_.at(index); } mutable_parameters(int index)292 Shape* mutable_parameters(int index) { return ¶meters_.at(index); } add_parameters()293 Shape* add_parameters() { 294 parameters_.emplace_back(); 295 return ¶meters_.back(); 296 } clear_parameters()297 void clear_parameters() { parameters_.clear(); } parameters()298 const std::vector<Shape>& parameters() const { return parameters_; } mutable_parameters()299 std::vector<Shape>* mutable_parameters() { return ¶meters_; } 300 301 // Methods for accessing and manipulating the Shape of the result. result()302 const Shape& result() const { return result_; } mutable_result()303 Shape* mutable_result() { return &result_; } 304 305 // Methods for accessing and manipulating the names of the parameters. parameter_names_size()306 int parameter_names_size() const { return parameter_names_.size(); } parameter_names(int index)307 const string& parameter_names(int index) const { 308 return parameter_names_.at(index); 309 } set_parameter_names(int index,const string & value)310 void set_parameter_names(int index, const string& value) { 311 parameter_names_.at(index) = value; 312 } mutable_parameter_names(int index)313 string* mutable_parameter_names(int index) { 314 return ¶meter_names_.at(index); 315 } add_parameter_names(const string & value)316 void add_parameter_names(const string& value) { 317 parameter_names_.push_back(value); 318 } add_parameter_names()319 string* add_parameter_names() { 320 parameter_names_.push_back(""); 321 return ¶meter_names_.back(); 322 } clear_parameter_names()323 void clear_parameter_names() { parameter_names_.clear(); } parameter_names()324 const std::vector<string>& parameter_names() const { 325 return parameter_names_; 326 } mutable_parameter_names()327 std::vector<string>* mutable_parameter_names() { return ¶meter_names_; } 328 ShortDebugString()329 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()330 string DebugString() const { return ToProto().DebugString(); } 331 332 private: 333 // The shapes of the parameters of the computation represented by this object. 334 std::vector<Shape> parameters_; 335 336 // The names of the parameters of the computation represented by this object. 337 std::vector<string> parameter_names_; 338 339 // The shape of the result of the computation represented by this object. 340 Shape result_; 341 }; 342 343 std::ostream& operator<<(std::ostream& out, const Shape& shape); 344 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); 345 346 } // namespace xla 347 348 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ 349