1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SHAPE_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/types/optional.h" 23 #include "tensorflow/compiler/xla/layout.h" 24 #include "tensorflow/compiler/xla/primitive_util.h" 25 #include "tensorflow/compiler/xla/types.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace xla { 30 31 // A shape describes the number of dimensions in a array, the bounds of each 32 // dimension, and the primitive component type. For tuples, shape describes the 33 // structure (number of elements and nesting). 34 class Shape { 35 public: 36 Shape() = default; 37 38 // Construct a shape from a ShapeProto. 39 explicit Shape(const ShapeProto& shape_proto); 40 41 // Returns a ShapeProto representation of the Shape. 42 ShapeProto ToProto() const; 43 44 // Returns a human-readable string that represents the given shape, with or 45 // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". 46 string ToString(bool print_layout = false) const; 47 48 // Returns the rank (number of dimensions) of the given shape. Shape must be 49 // an array. rank()50 int64 rank() const { 51 CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); 52 return dimensions_.size(); 53 } 54 55 // Returns whether the shape is of the specified type (array, tuple, etc). IsArray()56 bool IsArray() const { return primitive_util::IsArrayType(element_type()); } IsTuple()57 bool IsTuple() const { return element_type() == TUPLE; } IsToken()58 bool IsToken() const { return element_type() == TOKEN; } IsOpaque()59 bool IsOpaque() const { return element_type() == OPAQUE; } 60 61 // Returns true if no array dimension in the shape is dynamically sized. Tuple 62 // shapes are traversed recursively. 63 bool is_static() const; 64 65 // Returns true if the given dimension is dynamically-sized. is_dynamic_dimension(int dimension)66 bool is_dynamic_dimension(int dimension) const { 67 return dynamic_dimensions_.at(dimension); 68 } 69 70 // Sets whether or not the given dimension is dynamically-sized. set_dynamic_dimension(int dimension,bool is_dynamic)71 void set_dynamic_dimension(int dimension, bool is_dynamic) { 72 dynamic_dimensions_[dimension] = is_dynamic; 73 } 74 dynamic_dimensions()75 const std::vector<bool>& dynamic_dimensions() const { 76 return dynamic_dimensions_; 77 } 78 79 // Add dimension_upper_bound(). 80 81 // Removes the given dimension form the shape. Layout, if it exists, is 82 // adjusted to match the modified shape. 83 void DeleteDimension(int64 dim_to_delete); 84 85 // The following methods mirror the protobuf generated code interface for the 86 // message ShapeProto. This enabled easy migration of this data structure 87 // from a proto to a proper C++ class. 88 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 89 // interface. 90 91 // Methods for accessing the primitive type. element_type()92 PrimitiveType element_type() const { return element_type_; } set_element_type(PrimitiveType value)93 void set_element_type(PrimitiveType value) { element_type_ = value; } 94 95 // Methods for accessing the dimensions array. dimensions_size()96 int dimensions_size() const { return dimensions_.size(); } dimensions(int index)97 int64 dimensions(int index) const { return dimensions_.at(index); } set_dimensions(int index,int64 value)98 void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } add_dimensions(int64 value)99 void add_dimensions(int64 value) { 100 dimensions_.push_back(value); 101 dynamic_dimensions_.push_back(false); 102 } clear_dimensions()103 void clear_dimensions() { 104 dimensions_.clear(); 105 dynamic_dimensions_.clear(); 106 } dimensions()107 const std::vector<int64>& dimensions() const { return dimensions_; } mutable_dimensions()108 absl::Span<int64> mutable_dimensions() { return absl::MakeSpan(dimensions_); } 109 110 // Methods for accessing the tuple subshapes. This field only non-empty for 111 // tuple shapes. tuple_shapes_size()112 int tuple_shapes_size() const { return tuple_shapes_.size(); } tuple_shapes(int index)113 const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } mutable_tuple_shapes(int index)114 Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } add_tuple_shapes()115 Shape* add_tuple_shapes() { 116 tuple_shapes_.push_back(Shape()); 117 return &tuple_shapes_.back(); 118 } clear_tuple_shapes()119 void clear_tuple_shapes() { tuple_shapes_.clear(); } tuple_shapes()120 const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; } mutable_tuple_shapes()121 std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; } 122 123 // Methods for accessing the layout field. has_layout()124 bool has_layout() const { return layout_.format() != INVALID_FORMAT; } layout()125 const Layout& layout() const { return layout_; } mutable_layout()126 Layout* mutable_layout() { return &layout_; } clear_layout()127 void clear_layout() { layout_.Clear(); } 128 Swap(Shape * other)129 void Swap(Shape* other) { 130 using std::swap; 131 swap(*this, *other); 132 } 133 Clear()134 void Clear() { 135 element_type_ = PRIMITIVE_TYPE_INVALID; 136 dimensions_.clear(); 137 tuple_shapes_.clear(); 138 clear_layout(); 139 } 140 SerializeAsString()141 string SerializeAsString() const { return ToProto().SerializeAsString(); } ShortDebugString()142 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()143 string DebugString() const { return ToProto().DebugString(); } 144 145 // Equal is a configurable functor to check the equality of two shapes. 146 // 147 // Examples: 148 // 149 // - Comparing two shapes ignoring their layout difference: 150 // Equal().IgnoreLayout()(shape1, shape2); 151 // 152 // - Comparing two shapes ignoring their layout and element type difference: 153 // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); 154 class Equal { 155 public: 156 Equal() = default; 157 158 bool operator()(const Shape& lhs, const Shape& rhs); 159 IgnoreLayout()160 Equal& IgnoreLayout() { 161 ignore_layout_ = true; 162 return *this; 163 } IgnoreTilesInLayout()164 Equal& IgnoreTilesInLayout() { 165 ignore_tiles_in_layout_ = true; 166 return *this; 167 } IgnoreElementSizeInLayout()168 Equal& IgnoreElementSizeInLayout() { 169 ignore_element_size_in_layout_ = true; 170 return *this; 171 } IgnoreElementType()172 Equal& IgnoreElementType() { 173 ignore_element_type_ = true; 174 return *this; 175 } IgnoreFpPrecision()176 Equal& IgnoreFpPrecision() { 177 ignore_fp_precision_ = true; 178 return *this; 179 } IgnoreDynamicDimension()180 Equal& IgnoreDynamicDimension() { 181 ignore_dynamic_dimension_ = true; 182 return *this; 183 } 184 185 private: 186 bool ignore_layout_ = false; 187 bool ignore_tiles_in_layout_ = false; 188 bool ignore_element_size_in_layout_ = false; 189 bool ignore_element_type_ = false; 190 bool ignore_fp_precision_ = false; 191 bool ignore_dynamic_dimension_ = false; 192 }; 193 194 // Test that all fields of the shape are the same, equivalent to Equal(). 195 bool operator==(const Shape& other) const { return Equal()(*this, other); } 196 bool operator!=(const Shape& other) const { return !(*this == other); } 197 198 private: 199 // The element type of this shape (tuple, array, etc). 200 PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; 201 202 // The array bounds of the dimensions. This is nonempty only for array 203 // shapes. For a dynamically-sized dimension, the respective value in this 204 // vector is an inclusive upper limit of the array bound. 205 std::vector<int64> dimensions_; 206 207 // This vector is the same size as 'dimensions_' and indicates whether the 208 // respective dimension is dynamically sized. 209 std::vector<bool> dynamic_dimensions_; 210 211 // The tuple element subshapes. This is nonempty only for tuple shapes. 212 std::vector<Shape> tuple_shapes_; 213 214 // The layout of the shape. Only relevant for arrays. 215 Layout layout_; 216 }; 217 218 // Shape of the parameters and output of an XLA computation. This is analogous 219 // to a traditional function signature. 220 class ProgramShape { 221 public: 222 ProgramShape() = default; 223 224 // Creates a ProgramShape from a ProgramShapeProto protobuf. 225 explicit ProgramShape(const ProgramShapeProto& program_shape_proto); 226 227 // Returns a proto representation of the object. 228 ProgramShapeProto ToProto() const; 229 230 string ToString() const; 231 232 // The following methods mirror the protobuf generated code interface for the 233 // message ProgramShapeProto. This enabled easy migration of this data 234 // structure from a proto to a proper C++ class. 235 // TODO(b/29771030): Replace or augment these methods with a more ergonomic 236 // interface. 237 238 // Methods for accessing and manipulating the Shape of the parameters. parameters_size()239 int parameters_size() const { return parameters_.size(); } parameters(int index)240 const Shape& parameters(int index) const { return parameters_.at(index); } mutable_parameters(int index)241 Shape* mutable_parameters(int index) { return ¶meters_.at(index); } add_parameters()242 Shape* add_parameters() { 243 parameters_.emplace_back(); 244 return ¶meters_.back(); 245 } clear_parameters()246 void clear_parameters() { parameters_.clear(); } parameters()247 const std::vector<Shape>& parameters() const { return parameters_; } mutable_parameters()248 std::vector<Shape>* mutable_parameters() { return ¶meters_; } 249 250 // Methods for accessing and manipulating the Shape of the result. result()251 const Shape& result() const { return result_; } mutable_result()252 Shape* mutable_result() { return &result_; } 253 254 // Methods for accessing and manipulating the names of the parameters. parameter_names_size()255 int parameter_names_size() const { return parameter_names_.size(); } parameter_names(int index)256 const string& parameter_names(int index) const { 257 return parameter_names_.at(index); 258 } set_parameter_names(int index,const string & value)259 void set_parameter_names(int index, const string& value) { 260 parameter_names_.at(index) = value; 261 } mutable_parameter_names(int index)262 string* mutable_parameter_names(int index) { 263 return ¶meter_names_.at(index); 264 } add_parameter_names(const string & value)265 void add_parameter_names(const string& value) { 266 parameter_names_.push_back(value); 267 } add_parameter_names()268 string* add_parameter_names() { 269 parameter_names_.push_back(""); 270 return ¶meter_names_.back(); 271 } clear_parameter_names()272 void clear_parameter_names() { parameter_names_.clear(); } parameter_names()273 const std::vector<string>& parameter_names() const { 274 return parameter_names_; 275 } mutable_parameter_names()276 std::vector<string>* mutable_parameter_names() { return ¶meter_names_; } 277 ShortDebugString()278 string ShortDebugString() const { return ToProto().ShortDebugString(); } DebugString()279 string DebugString() const { return ToProto().DebugString(); } 280 281 private: 282 // The shapes of the parameters of the computation represented by this object. 283 std::vector<Shape> parameters_; 284 285 // The names of the parameters of the computation represented by this object. 286 std::vector<string> parameter_names_; 287 288 // The shape of the result of the computation represented by this object. 289 Shape result_; 290 }; 291 292 std::ostream& operator<<(std::ostream& out, const Shape& shape); 293 std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); 294 295 } // namespace xla 296 297 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ 298