/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// When the HLO graph contains a cross-module AllReduce (CMAR), followed by
// some simple linear operations, followed by a cross-replica AllReduce
// (CRAR, also known as cross-replica sum, or CRS), we can combine the CMAR
// and the CRAR to use an efficient AllReduce implementation that fully
// utilizes the interconnect bandwidth.
// Such sequences appear in spatially partitioned models.
// This pass must run right after spatial partitioning, when the code is
// still in a single HLO module.
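//
// For example (an illustrative, simplified HLO sketch; the exact textual
// syntax is an assumption and varies across versions):
//
//   %ar  = f32[128] all-reduce(%x), all_reduce_id=1   // cross-module AR
//   %add = f32[128] add(%ar, %y)                      // simple linear op
//   %crs = f32[128] cross-replica-sum(%add)           // cross-replica sum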
//
// The steps are:
// 1) Find CMARs followed by simple ops followed by CRARs.
// 2) Group CMARs by all_reduce_id. They must all be rewritten.
// 3) Prove that the CMAR patterns in each core produce the same result.
// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
//    other operand by the number of spatial partitions (see the correctness
//    sketch below).
// 5) Turn the CRAR into an all-core AllReduce.
//
// The pass also handles the case where multiple CMARs lead to the same CRAR,
// and eliminates all CMARs. This graph:
//
//        Y
//        |
//  X   CMAR_2   Z
//  |      \    /
// CMAR_1     +
//    \     /
//       +
//       |
//     CRAR
//
// gets rewritten to:
//
//           Z   num_partitions
//            \  /
//       Y    div
//        \   /
//    X     +
//     \   /
//       +
//       |
//  all-core AR
//
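// Correctness of the division in step 4 (a sketch, with P spatial
// partitions): before the rewrite each replica computes
// CMAR(x) + y = (x_1 + ... + x_P) + y, and the CRS then sums this over
// replicas. After the rewrite every core computes x_p + y/P, and the
// all-core AllReduce sums over all P partitions of every replica, so the P
// copies of y/P per replica add back up to y, yielding the same result.
//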
class ArCrsCombiner : public HloModulePass {
 public:
  ArCrsCombiner(int num_spatial_partitions)
      : num_spatial_partitions_(num_spatial_partitions) {}
  absl::string_view name() const override { return "ar-crs-combiner"; }
  StatusOr<bool> Run(HloModule* module) override;

  // Helper method to allow testing of InstructionsComputeSameValue.
  static bool TestInstructionsComputeSameValue(HloInstruction* i1,
                                               HloInstruction* i2);

 private:
  // We use this struct because multiple ARs could be paired with the same
  // CRS. In this case, we want to select the AR that is furthest from the
  // CRS, because that makes it easier to eliminate all ARs during
  // RewriteGraph.
  struct ArCrsPair {
    HloInstruction* ar;
    HloInstruction* crs;
    // The length of the path from AR to CRS in the HLO graph.
    int64 distance;

    ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum,
              int64 dist)
        : ar(all_reduce), crs(cross_replica_sum), distance(dist) {}

    string ToString() {
      return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(),
                          ", distance: ", distance, ")");
    }
  };

  // Checks whether "instruction" begins the CMAR -> simple ops -> CRS
  // pattern described above; if so, returns the matched ArCrsPair, otherwise
  // returns nullopt.
  absl::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
      HloInstruction* instruction);

  // If the passed instruction is a while parameter, and the while body is
  // only called by a single while instruction, return the while instruction.
  absl::optional<HloInstruction*> WhileFromBodyParameter(
      HloInstruction* instruction);

  // Returns a vector of tuple instructions: if all instructions that flow to
  // "instruction" are tuples, returns them; otherwise returns an empty
  // vector.
  std::vector<HloInstruction*> GetAllTuples(HloInstruction* instruction);

  // Checks whether two different elements in the same tuple compute the same
  // value.
  bool TupleElementsComputeSameValue(
      HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
      absl::flat_hash_map<int64, int64>* visited_pairs);

  // Returns whether the instructions i1 and i2 can be shown to evaluate to
  // the same value. Handling WHILE requires recursion, which may cause us to
  // visit the same instruction again. To avoid infinite loops, we pass a
  // cache of visited instruction pairs.
  bool InstructionsComputeSameValue(
      HloInstruction* i1, HloInstruction* i2,
      absl::flat_hash_map<int64, int64>* visited_pairs);

  // Populates all_reduce_map_.
  void GroupAllReducesById(HloModule* module);

  // Looks at each AllReduce group in all_reduce_map_, and keeps only the
  // groups for which it's safe to move the AllReduce later in the HLO graph.
  void KeepProvablyEqualInstructionGroups();

  // Performs the graph rewrite that eliminates the early AllReduce and turns
  // the later CRS into an AllReduce.
  StatusOr<bool> RewriteGraph();

  int num_spatial_partitions_;

  // Map from all-reduce ids to the AR/CRS pairs.
  absl::flat_hash_map<int64, std::vector<ArCrsPair>> all_reduce_map_;

  // Map from a CRS instruction to the all-reduce ID of the AR paired with
  // the CRS. Sometimes, several ARs in the code could be paired with the
  // same CRS. We use this map to pick a single AR/CRS path to rewrite.
  absl::flat_hash_map<HloInstruction*, int64> crs_reserved_map_;

  std::unique_ptr<CallGraph> call_graph_;
};
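
// Example of running this pass (an illustrative sketch; the surrounding
// pipeline setup is an assumption, not part of this header):
//
//   HloPassPipeline pipeline("ar-crs-combiner-pipeline");
//   pipeline.AddPass<ArCrsCombiner>(/*num_spatial_partitions=*/2);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));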

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_