1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ 18 19 #include <map> 20 #include <vector> 21 22 #include "absl/container/inlined_vector.h" 23 #include "absl/types/optional.h" 24 #include "tensorflow/compiler/xla/service/hlo_computation.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 27 #include "tensorflow/compiler/xla/service/hlo_module.h" 28 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 29 30 namespace xla { 31 namespace hlo_sharding_util { 32 33 struct GatherParallelDims { 34 absl::InlinedVector<int64, 1> indices_parallel_dims; 35 absl::InlinedVector<int64, 1> operand_parallel_dims; 36 std::vector<int64> index_parallel_in_dim; 37 }; 38 39 // Returns true if the lhs sharding is preferable over the rhs sharding. 40 // The most specific sharding is tile maximal followed by single device tile 41 // maximal and finally replicated. This order aims to primarily reduce memory 42 // usage and secondly reduce total compute. 43 // Note: This does NOT provide a total ordering as we can have 2 different 44 // sharding with same preference level. 45 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs); 46 47 // Tries to refine `to_merge` by combining with `old`. Returns if the final 48 // `to_merge` is more specific than `old`. 49 bool MergeSharding(const HloSharding& old, HloSharding* to_merge, 50 bool may_combine_partial_sharding); 51 52 // Merges `to_merge` into `dst` only if they are compatible, and the merged 53 // sharding has >= minimum_tiles tiles. Returns if merging happened. 54 bool MergeShardingIfCompatible(const HloSharding& to_merge, 55 int64_t minimum_tiles, HloSharding* dst); 56 57 // Given a map<device, occurrence_count>, selects the device with higher 58 // occurrence count (if any). If top_count in not nullptr, it will receive the 59 // count of the dominant device returned. 60 absl::optional<int64> SelectDominantDevice( 61 const std::map<int64, int64>& device_map, int64* top_count); 62 63 // Assigns all the instructions of a computation, to a given device. 64 // This API does not recurse into called computations, and does not assign 65 // instructions which already have sharding. 66 Status AssignComputationDevice(HloComputation* computation, int64_t device); 67 68 // Given an instruction container, returns the device which is most commonly 69 // occurring among the instructions. 70 absl::optional<int64> GetMostOccurringDevice( 71 absl::Span<HloInstruction* const> instructions); 72 73 // Given a set of computations, tries to extract the dominant device. A device 74 // is dominant if the combined occurrence among all the instructions of the 75 // input computations, is greater/equal than/to dominant_factor (real number 76 // from 0 to 1). 77 // This API does not recurse into called computations. 78 // If no device exists that satisfies the condition, the returned optional will 79 // hold no value. 80 StatusOr<absl::optional<int64>> GetDominantDevice( 81 absl::Span<HloComputation* const> computations, double dominant_factor); 82 83 // Returns the HloSharding with the tile dimensions and tile assignment 84 // transposed based on the specified dimension numbers. In case of a tile 85 // maximal sharding returns the original sharding. 86 HloSharding TransposeSharding(const HloSharding& sharding, 87 const std::vector<int64>& dimensions); 88 89 // Returns the HloSharding with the tile shape reshaped based on the source and 90 // target shapes and the tile assignment adjusted to correspond to the new tile 91 // shape or absl::nullopt if the resulting reshape would create an invalid 92 // sharding (non continuous or non uniformly sized tiles). In case of a tile 93 // maximal sharding returns the original sharding. 94 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape, 95 const Shape& target_shape, 96 const HloSharding& sharding); 97 98 // Returns the HloSharding with the tile dimensions and tile assignment 99 // reversed based on the specified dimension numbers. In case of a tile 100 // maximal sharding returns the original sharding. 101 HloSharding ReverseSharding(const HloSharding& sharding, 102 absl::Span<const int64> dimensions); 103 104 // Returns a sharding tiled on unique dimension dim by reshaping the tile 105 // assignment of the sharding argument. Only dimensions in the dims span 106 // argument are considered for reshaping, the others are ignored. 107 // Assumptions: sharding is tile sharded, and dim must be included in dims. 108 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim, 109 absl::Span<const int64> dims); 110 111 // Returns true if the provided module includes one or more instructions with 112 // a tile sharding. 113 bool ContainsTileSharding(const HloModule& module); 114 115 // Returns the preferred output sharding for a gather op based on the sharding 116 // of the indces. 117 HloSharding GatherOutputSharding(const HloSharding& index_sharding, 118 const HloInstruction* hlo); 119 120 // Returns the preferred index sharding for a gather op based on the sharding 121 // of the output. 122 HloSharding GatherIndexSharding(const HloSharding& output_sharding, 123 const HloInstruction* hlo); 124 125 // Returns a new HloSharding for a gather op so that only non offset dimensions 126 // are sharded. Assume "result" is returned by this function. It is ensured that 127 // "GetIndexSharding(result, hlo)" will have the same number of elements as 128 // "result". 129 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo); 130 131 // Returns the preferred index sharding for a scatter op based on the sharding 132 // of the data. 133 HloSharding ScatterIndexSharding(const HloSharding& data_sharding, 134 const HloInstruction* hlo); 135 136 // Returns the preferred data sharding for a scatter op based on the sharding 137 // of the index. 138 HloSharding ScatterDataSharding(const HloSharding& index_sharding, 139 const HloInstruction* hlo); 140 141 // Returns a new index sharding for a scatter op so that we only shard on first 142 // "number of scatter_window_dims" dimensions. Assume "result" is returned by 143 // this function. It is ensured that "ScatterDataSharding(result, hlo)" will 144 // have the same number of elements as "result". 145 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, 146 const HloInstruction& hlo); 147 148 // Returns a new data sharding for a scatter op so that we only shard on 149 // scatter_window_dims. Assume "result" is returned by this function. It is 150 // ensured that "ScatterIndexSharding(result, hlo)" will have the same number of 151 // elements as "result". 152 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, 153 const HloInstruction& hlo); 154 155 // Returns an output sharding of gather by passing through the data operand's 156 // sharding. 157 absl::optional<HloSharding> GatherOutputShardingFromDataOperand( 158 const HloSharding& data_operand_sharding, const HloInstruction& hlo, 159 const Shape& output_shape, const Shape& operand_shape); 160 161 // Returns a data operand sharding of gather by passing through the output's 162 // sharding. 163 absl::optional<HloSharding> GatherDataOperandShardingFromOutput( 164 const HloSharding& output_sharding, const HloInstruction& hlo); 165 166 // Returns an output sharding of scatter by passing through the update operand's 167 // sharding. 168 absl::optional<HloSharding> ScatterOutputShardingFromUpdate( 169 const HloSharding& update_sharding, const HloInstruction& hlo); 170 171 // Returns an update operand sharding of scatter by passing through the output's 172 // sharding. 173 absl::optional<HloSharding> ScatterUpdateShardingFromOutput( 174 const HloSharding& output_sharding, const HloInstruction& hlo); 175 176 // Returns an identity value and an HloOpcode for reduce computation of scatter 177 // instruction. 178 // - If computation is add/or, return 0/false with corresponding op code; 179 // - If computation is multiply/and, return 1/true with corresponding op code. 180 // - If computation is min/max, return max value/min value with corresponding op 181 // code. 182 // - Otherwise, return error status. 183 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>> 184 IdentityValueAndHloOpcodeForScatterReduceComputation( 185 const HloScatterInstruction& scatter); 186 187 // Given a sharding and a list of devices in the topology, return a 188 // list of the devices that `sharding` applies to. 189 std::vector<int64> DevicesForSharding( 190 const HloSharding& sharding, const std::vector<int64>& available_devices); 191 192 // Returns a sharding that replicates data across devices along the given 193 // dimensions in the original sharding. 194 HloSharding PartiallyReplicateTiledShardingOnDims( 195 const HloSharding& sharding, absl::Span<const int64> dims_to_replicate); 196 197 // Returns a sharding that replicates data across devices along all dimensions 198 // but the given ones to keep in the original sharding. 199 HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept( 200 const HloSharding& sharding, absl::Span<const int64> dims_to_keep); 201 202 // Returns a sharding the removes given tile dimensions. 203 // 204 // Precondition: if not tile maximal, the size of each tile dimension must be 1. 205 HloSharding RemoveShapeDimensions(const HloSharding& sharding, 206 const std::vector<int64>& dims_to_remove); 207 208 // Similar to TransposeSharding(), but allows removing/adding non-partitioned 209 // dimensions. In src_to_tgt and tgt_to_src, -1 represents a non-existing 210 // dimension. 211 absl::optional<HloSharding> TransposeShardingWithCollapsedDims( 212 const HloSharding& source, absl::Span<int64 const> src_to_tgt, 213 absl::Span<int64 const> tgt_to_src); 214 215 // Returns the iota dimension if maybe_iota is an kIota instruction or 216 // equivalent to kIota. 217 absl::optional<int64> GetDimensionForIota(const HloInstruction* maybe_iota); 218 219 // Returns identified parallel dimensions for Gather. 220 absl::optional<GatherParallelDims> GetGatherBatchParallelDims( 221 const HloInstruction& hlo); 222 223 // Returns the parallel dimensions of the output of a gather based on the 224 // parallel dimensions of the input. 225 absl::InlinedVector<int64, 1> GatherParallelOutputDims( 226 const HloInstruction& gather, const GatherParallelDims& parallel_dim); 227 228 // Returns the parallel dimensions of the data operand of a gather with the 229 // order of the parallel dimensions matching that of the parallel dimensions 230 // of the output. 231 absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims( 232 const HloInstruction& gather, const GatherParallelDims& parallel_dims); 233 234 } // namespace hlo_sharding_util 235 } // namespace xla 236 237 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ 238