1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/container/inlined_vector.h" 23 #include "absl/types/optional.h" 24 #include "tensorflow/compiler/xla/layout.h" 25 #include "tensorflow/compiler/xla/primitive_util.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace xla { 31 32 // A shape describes the number of dimensions in a array, the bounds of each 33 // dimension, and the primitive component type. For tuples, shape describes the 34 // structure (number of elements and nesting). 35 class Shape { 36 public: 37 Shape() = default; 38 39 // Construct a shape from a ShapeProto. 40 explicit Shape(const ShapeProto& shape_proto); 41 42 // Returns a ShapeProto representation of the Shape. 43 ShapeProto ToProto() const; 44 45 // Returns a human-readable string that represents the given shape, with or 46 // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". 47 string ToString(bool print_layout = false) const; 48 49 // Returns the rank (number of dimensions) of the given shape. Shape must be 50 // an array. rank()51 int64 rank() const { 52 CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); 53 return dimensions_.size(); 54 } 55 56 // Returns whether the shape is of the specified type (array, tuple, etc). IsArray()57 bool IsArray() const { return primitive_util::IsArrayType(element_type()); } IsTuple()58 bool IsTuple() const { return element_type() == TUPLE; } IsToken()59 bool IsToken() const { return element_type() == TOKEN; } IsOpaque()60 bool IsOpaque() const { return element_type() == OPAQUE_TYPE; } 61 62 // Returns true if no array dimension in the shape is dynamically sized. Tuple 63 // shapes are traversed recursively. 64 bool is_static() const; 65 66 // Returns true if the given dimension is dynamically-sized. is_dynamic_dimension(int dimension)67 bool is_dynamic_dimension(int dimension) const { 68 return dynamic_dimensions_.at(dimension); 69 } 70 71 // Sets whether or not the given dimension is dynamically-sized. set_dynamic_dimension(int dimension,bool is_dynamic)72 void set_dynamic_dimension(int dimension, bool is_dynamic) { 73 dynamic_dimensions_[dimension] = is_dynamic; 74 } 75 dynamic_dimensions()76 absl::Span<const bool> dynamic_dimensions() const { 77 return dynamic_dimensions_; 78 } 79 mutable_dynamic_dimensions()80 absl::Span<bool> mutable_dynamic_dimensions() { 81 return absl::MakeSpan(dynamic_dimensions_); 82 } 83 84 // Add dimension_upper_bound(). 85 86 // Removes the given dimension form the shape. Layout, if it exists, is 87 // adjusted to match the modified shape. 88 void DeleteDimension(int64 dim_to_delete); 89 90 // The following methods mirror the protobuf generated code interface for the 91 // message ShapeProto. This enabled easy migration of this data structure 92 // from a proto to a proper C++ class. 93 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 94 // interface. 95 96 // Methods for accessing the primitive type. element_type()97 PrimitiveType element_type() const { return element_type_; } set_element_type(PrimitiveType value)98 void set_element_type(PrimitiveType value) { element_type_ = value; } 99 100 // Methods for accessing the dimensions array. dimensions_size()101 int dimensions_size() const { return dimensions_.size(); } dimensions(int index)102 int64 dimensions(int index) const { return dimensions_.at(index); } set_dimensions(int index,int64 value)103 void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } add_dimensions(int64 value)104 void add_dimensions(int64 value) { 105 dimensions_.push_back(value); 106 dynamic_dimensions_.push_back(false); 107 } clear_dimensions()108 void clear_dimensions() { 109 dimensions_.clear(); 110 dynamic_dimensions_.clear(); 111 } dimensions()112 absl::Span<const int64> dimensions() const { return dimensions_; } mutable_dimensions()113 absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); } 114 115 // Methods for accessing the tuple subshapes. This field only non-empty for 116 // tuple shapes. tuple_shapes_size()117 int tuple_shapes_size() const { return tuple_shapes_.size(); } tuple_shapes(int index)118 const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } mutable_tuple_shapes(int index)119 Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } add_tuple_shapes()120 Shape* add_tuple_shapes() { 121 tuple_shapes_.push_back(Shape()); 122 return &tuple_shapes_.back(); 123 } clear_tuple_shapes()124 void clear_tuple_shapes() { tuple_shapes_.clear(); } tuple_shapes()125 const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; } mutable_tuple_shapes()126 std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; } 127 128 // Methods for accessing the layout field. has_layout()129 bool has_layout() const { return layout_.format() != INVALID_FORMAT; } layout()130 const Layout& layout() const { return layout_; } mutable_layout()131 Layout* mutable_layout() { return &layout_; } clear_layout()132 void clear_layout() { layout_.Clear(); } 133 134 // Recursively clear dynamic dimension of a shape. clear_dynamic_dimensions()135 void clear_dynamic_dimensions() { 136 if (!IsTuple()) { 137 for (int64 i = 0; i < dynamic_dimensions_.size(); ++i) { 138 dynamic_dimensions_[i] = false; 139 } 140 return; 141 } 142 for (auto& subshape : tuple_shapes_) { 143 subshape.clear_dynamic_dimensions(); 144 } 145 } 146 Swap(Shape * other)147 void Swap(Shape* other) { 148 using std::swap; 149 swap(*this, *other); 150 } 151 Clear()152 void Clear() { 153 element_type_ = PRIMITIVE_TYPE_INVALID; 154 clear_dimensions(); 155 tuple_shapes_.clear(); 156 clear_layout(); 157 } 158 SerializeAsString()159 string SerializeAsString() const { return ToProto().SerializeAsString(); } ShortDebugString()160 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()161 string DebugString() const { return ToProto().DebugString(); } 162 163 // Equal is a configurable functor to check the equality of two shapes. 164 // 165 // Examples: 166 // 167 // - Comparing two shapes ignoring their layout difference: 168 // Equal().IgnoreLayout()(shape1, shape2); 169 // 170 // - Comparing two shapes ignoring their layout and element type difference: 171 // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); 172 class Equal { 173 public: 174 Equal() = default; 175 176 bool operator()(const Shape& lhs, const Shape& rhs); 177 IgnoreLayout()178 Equal& IgnoreLayout() { 179 ignore_layout_ = true; 180 return *this; 181 } IgnoreTilesInLayout()182 Equal& IgnoreTilesInLayout() { 183 ignore_tiles_in_layout_ = true; 184 return *this; 185 } IgnoreElementSizeInLayout()186 Equal& IgnoreElementSizeInLayout() { 187 ignore_element_size_in_layout_ = true; 188 return *this; 189 } IgnoreMemorySpaceInLayout()190 Equal& IgnoreMemorySpaceInLayout() { 191 ignore_memory_space_in_layout_ = true; 192 return *this; 193 } MinorToMajorOnlyInLayout()194 Equal& MinorToMajorOnlyInLayout() { 195 ignore_tiles_in_layout_ = true; 196 ignore_element_size_in_layout_ = true; 197 ignore_memory_space_in_layout_ = true; 198 return *this; 199 } IgnoreElementType()200 Equal& IgnoreElementType() { 201 ignore_element_type_ = true; 202 return *this; 203 } IgnoreFpPrecision()204 Equal& IgnoreFpPrecision() { 205 ignore_fp_precision_ = true; 206 return *this; 207 } IgnoreDynamicDimension()208 Equal& IgnoreDynamicDimension() { 209 ignore_dynamic_dimension_ = true; 210 return *this; 211 } 212 213 private: 214 bool ignore_layout_ = false; 215 bool ignore_tiles_in_layout_ = false; 216 bool ignore_element_size_in_layout_ = false; 217 bool ignore_memory_space_in_layout_ = false; 218 bool ignore_element_type_ = false; 219 bool ignore_fp_precision_ = false; 220 bool ignore_dynamic_dimension_ = false; 221 }; 222 223 // Test that all fields of the shape are the same, equivalent to Equal(). 224 bool operator==(const Shape& other) const { return Equal()(*this, other); } 225 bool operator!=(const Shape& other) const { return !(*this == other); } 226 227 template <typename H> AbslHashValue(H h,const Shape & s)228 friend H AbslHashValue(H h, const Shape& s) { 229 return H::combine(std::move(h), s.element_type_, s.dimensions_, 230 s.dynamic_dimensions_, s.tuple_shapes_, s.layout_); 231 } 232 233 private: 234 // The element type of this shape (tuple, array, etc). 235 PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; 236 237 // The array bounds of the dimensions. This is nonempty only for array 238 // shapes. For a dynamically-sized dimension, the respective value in this 239 // vector is an inclusive upper limit of the array bound. 240 absl::InlinedVector<int64, 6> dimensions_; 241 242 // This vector is the same size as 'dimensions_' and indicates whether the 243 // respective dimension is dynamically sized. 244 absl::InlinedVector<bool, 6> dynamic_dimensions_; 245 246 // The tuple element subshapes. This is nonempty only for tuple shapes. 247 std::vector<Shape> tuple_shapes_; 248 249 // The layout of the shape. Only relevant for arrays. 250 Layout layout_; 251 }; 252 253 // Shape of the parameters and output of an XLA computation. This is analogous 254 // to a traditional function signature. 255 class ProgramShape { 256 public: 257 ProgramShape() = default; 258 259 // Creates a ProgramShape from a ProgramShapeProto protobuf. 260 explicit ProgramShape(const ProgramShapeProto& program_shape_proto); 261 262 // Returns a proto representation of the object. 263 ProgramShapeProto ToProto() const; 264 265 string ToString() const; 266 267 // The following methods mirror the protobuf generated code interface for the 268 // message ProgramShapeProto. This enabled easy migration of this data 269 // structure from a proto to a proper C++ class. 270 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 271 // interface. 272 273 // Methods for accessing and manipulating the Shape of the parameters. parameters_size()274 int parameters_size() const { return parameters_.size(); } parameters(int index)275 const Shape& parameters(int index) const { return parameters_.at(index); } mutable_parameters(int index)276 Shape* mutable_parameters(int index) { return ¶meters_.at(index); } add_parameters()277 Shape* add_parameters() { 278 parameters_.emplace_back(); 279 return ¶meters_.back(); 280 } clear_parameters()281 void clear_parameters() { parameters_.clear(); } parameters()282 const std::vector<Shape>& parameters() const { return parameters_; } mutable_parameters()283 std::vector<Shape>* mutable_parameters() { return ¶meters_; } 284 285 // Methods for accessing and manipulating the Shape of the result. result()286 const Shape& result() const { return result_; } mutable_result()287 Shape* mutable_result() { return &result_; } 288 289 // Methods for accessing and manipulating the names of the parameters. parameter_names_size()290 int parameter_names_size() const { return parameter_names_.size(); } parameter_names(int index)291 const string& parameter_names(int index) const { 292 return parameter_names_.at(index); 293 } set_parameter_names(int index,const string & value)294 void set_parameter_names(int index, const string& value) { 295 parameter_names_.at(index) = value; 296 } mutable_parameter_names(int index)297 string* mutable_parameter_names(int index) { 298 return ¶meter_names_.at(index); 299 } add_parameter_names(const string & value)300 void add_parameter_names(const string& value) { 301 parameter_names_.push_back(value); 302 } add_parameter_names()303 string* add_parameter_names() { 304 parameter_names_.push_back(""); 305 return ¶meter_names_.back(); 306 } clear_parameter_names()307 void clear_parameter_names() { parameter_names_.clear(); } parameter_names()308 const std::vector<string>& parameter_names() const { 309 return parameter_names_; 310 } mutable_parameter_names()311 std::vector<string>* mutable_parameter_names() { return ¶meter_names_; } 312 ShortDebugString()313 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()314 string DebugString() const { return ToProto().DebugString(); } 315 316 private: 317 // The shapes of the parameters of the computation represented by this object. 318 std::vector<Shape> parameters_; 319 320 // The names of the parameters of the computation represented by this object. 321 std::vector<string> parameter_names_; 322 323 // The shape of the result of the computation represented by this object. 324 Shape result_; 325 }; 326 327 std::ostream& operator<<(std::ostream& out, const Shape& shape); 328 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); 329 330 } // namespace xla 331 332 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ 333