/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_

#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// When the HLO graph contains a cross-module AllReduce (CMAR), followed by
// some simple linear operations, followed by a cross-replica AllReduce (CRAR,
// also known as cross-replica sum, or CRS), we can combine the CMAR and the
// CRAR to use an efficient AllReduce implementation that fully utilizes the
// interconnect bandwidth.
// Such sequences appear in spatially partitioned models.
// This pass must run right after spatial partitioning, when the code is still
// in a single HLO module.
//
// The steps are:
// 1) Find CMARs followed by simple ops followed by CRARs.
// 2) Group CMARs by all_reduce_id. They must all be rewritten.
// 3) Prove that the CMAR patterns in each core produce the same result.
// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
//    other operand by the number of spatial partitions.
// 5) Turn the CRAR into an all-core AllReduce.
//
// The pass also handles the case where multiple CMARs lead to the same CRAR,
// and eliminates all CMARs. This graph:
//
//        Y
//        |
//  X   CMAR_2   Z
//  |      \    /
// CMAR_1     +
//    \      /
//       +
//       |
//     CRAR
//
// gets rewritten to:
//
//           Z   num_partitions
//            \  /
//       Y    div
//        \   /
//    X     +
//     \   /
//       +
//       |
//  all-core AR
//
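// Illustrative usage sketch (not part of this header's API): the pass is
// normally scheduled through an HloPassPipeline. The pipeline name and the
// partition count below are placeholder values, and the snippet assumes the
// standard HloPassPipeline interface:
//
//   HloPassPipeline pipeline("spatial-partitioning-cleanup");  // hypothetical
//   pipeline.AddPass<ArCrsCombiner>(/*num_spatial_partitions=*/2);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
//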
class ArCrsCombiner : public HloModulePass {
 public:
  ArCrsCombiner(int num_spatial_partitions)
      : num_spatial_partitions_(num_spatial_partitions) {}
  absl::string_view name() const override { return "ar-crs-combiner"; }
  StatusOr<bool> Run(HloModule* module) override;

  // Helper method to allow testing of InstructionsComputeSameValue.
  static bool TestInstructionsComputeSameValue(HloInstruction* i1,
                                               HloInstruction* i2);

 private:
  // We use this struct because multiple ARs could be paired with the same CRS.
  // In this case, we want to select the AR that is furthest from the CRS,
  // because it makes it easier to eliminate all ARs during RewriteGraph.
  struct ArCrsPair {
    HloInstruction* ar;
    HloInstruction* crs;
    // The length of the path from AR to CRS in the HLO graph.
    int64 distance;

    ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum,
              int64 dist)
        : ar(all_reduce), crs(cross_replica_sum), distance(dist) {}

    string ToString() {
      return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(),
                          ", distance: ", distance, ")");
    }
  };

  absl::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
      HloInstruction* instruction);

  // If the passed instruction is a while parameter, and the while body is only
  // called by a single while instruction, return the while instruction.
  absl::optional<HloInstruction*> WhileFromBodyParameter(
      HloInstruction* instruction);

  // Returns a vector of tuple instructions.
  // If all instructions that flow to "instruction" are tuples, return them.
  // Otherwise, return an empty vector.
  std::vector<HloInstruction*> GetAllTuples(HloInstruction* instruction);

  // Checks whether two different elements in the same tuple compute the same
  // value.
  bool TupleElementsComputeSameValue(
      HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
      absl::flat_hash_map<int64, int64>* visited_pairs);

  // Returns whether the instructions i1 and i2 can be shown to evaluate to the
  // same value. Handling WHILE requires recursion, which may cause us to visit
  // the same instruction again. To avoid infinite loops, we pass a cache of
  // visited instruction pairs.
  bool InstructionsComputeSameValue(
      HloInstruction* i1, HloInstruction* i2,
      absl::flat_hash_map<int64, int64>* visited_pairs);

  // Populates all_reduce_map_.
  void GroupAllReducesById(HloModule* module);

  // Looks at each AllReduce group in all_reduce_map_, and keeps only the
  // groups for which it's safe to move the AllReduce later in the HLO graph.
  void KeepProvablyEqualInstructionGroups();

  // Performs the graph rewrite that eliminates the early AllReduce and turns
  // the later CRS into an AllReduce.
  StatusOr<bool> RewriteGraph();

  int num_spatial_partitions_;

  // Map from all-reduce ids to the AR/CRS pairs.
  absl::flat_hash_map<int64, std::vector<ArCrsPair>> all_reduce_map_;

  // Map from a CRS instruction to the all-reduce ID of the AR paired with the
  // CRS. Sometimes, several ARs in the code could be paired with the same CRS.
  // We use this map to pick a single AR/CRS path to rewrite.
  absl::flat_hash_map<HloInstruction*, int64> crs_reserved_map_;

  std::unique_ptr<CallGraph> call_graph_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_