/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/dot_decomposer.h"

#include <utility>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

35 // Convert a dot into a canonical form;
36 // * Non-contracting dimensions are reshaped together,
37 // * Contracting dimensions are reshaped together,
38 // * Batch dimensions are the most major dimensions.
39 // This requires transposing and reshaping of the lhs and rhs, and reshaping the
40 // output batch to the original shape.
CanonicalizeDot(HloInstruction * original_dot)41 Status CanonicalizeDot(HloInstruction* original_dot) {
42   auto computation = original_dot->parent();
43   const auto& original_dnums = original_dot->dot_dimension_numbers();
44   const int64_t num_batch_dims = original_dnums.lhs_batch_dimensions_size();
45   const int64_t num_contracting_dims =
46       original_dnums.lhs_contracting_dimensions_size();
47 
48   const auto& lhs_shape = original_dot->operand(0)->shape();
49   const int64_t lhs_rank = lhs_shape.rank();
50   const int64_t num_lhs_non_contracting_dims =
51       lhs_rank - num_batch_dims - num_contracting_dims;
52 
53   std::vector<int64_t> lhs_non_contracting_dims;
54   lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
55   int64_t lhs_contracting_size = 1;
56   int64_t lhs_non_contracting_size = 1;
57   std::vector<int64_t> batch_dim_sizes;
58   batch_dim_sizes.reserve(num_batch_dims);
59   for (int64_t i = 0; i < lhs_rank; ++i) {
60     if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) {
61       lhs_contracting_size *= lhs_shape.dimensions(i);
62     } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
63                                      i)) {
64       batch_dim_sizes.push_back(lhs_shape.dimensions(i));
65     } else {
66       lhs_non_contracting_dims.push_back(i);
67       lhs_non_contracting_size *= lhs_shape.dimensions(i);
68     }
69   }
70   // The canonical form of the lhs is
71   // [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct]
72   // If NonContractingDimsProduct is 1, it is omitted.
73   std::vector<int64_t> lhs_transpose;
74   lhs_transpose.reserve(lhs_rank);
75   lhs_transpose.insert(lhs_transpose.end(),
76                        original_dnums.lhs_batch_dimensions().begin(),
77                        original_dnums.lhs_batch_dimensions().end());
78   lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(),
79                        lhs_non_contracting_dims.end());
80   lhs_transpose.insert(lhs_transpose.end(),
81                        original_dnums.lhs_contracting_dimensions().begin(),
82                        original_dnums.lhs_contracting_dimensions().end());
83   HloInstruction* lhs_operand = original_dot->mutable_operand(0);
84   HloInstruction* transposed_lhs = computation->AddInstruction(
85       HloInstruction::CreateTranspose(
86           ShapeUtil::PermuteDimensions(lhs_transpose, lhs_shape), lhs_operand,
87           lhs_transpose),
88       &lhs_operand->metadata());
89 
90   std::vector<int64_t> lhs_reshape_dims = batch_dim_sizes;
91   if (lhs_non_contracting_size > 1) {
92     lhs_reshape_dims.push_back(lhs_non_contracting_size);
93   }
94   lhs_reshape_dims.push_back(lhs_contracting_size);
95   // Reshape the contracting and non-contracting dimensions together.
96   HloInstruction* reshaped_lhs = computation->AddInstruction(
97       HloInstruction::CreateReshape(
98           ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
99           transposed_lhs),
100       &transposed_lhs->metadata());
101 
102   const auto& rhs_shape = original_dot->operand(1)->shape();
103   const int64_t rhs_rank = rhs_shape.rank();
104   const int64_t num_rhs_non_contracting_dims =
105       rhs_rank - num_batch_dims - num_contracting_dims;
106   std::vector<int64_t> rhs_non_contracting_dims;
107   rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
108   int64_t rhs_non_contracting_size = 1;
109   int64_t rhs_contracting_size = 1;
110   for (int64_t i = 0; i < rhs_rank; ++i) {
111     if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) {
112       rhs_contracting_size *= rhs_shape.dimensions(i);
113     } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
114                                       i)) {
115       rhs_non_contracting_dims.push_back(i);
116       rhs_non_contracting_size *= rhs_shape.dimensions(i);
117     }
118   }
119 
120   // The canonical form of the rhs is
121   // [BatchDims, ContractingsDimsProduct, NonContractingDimsProduct]
122   // If NonContractingDimsProduct is 1, it is omitted.
123   std::vector<int64_t> rhs_transpose;
124   rhs_transpose.reserve(rhs_rank);
125   rhs_transpose.insert(rhs_transpose.end(),
126                        original_dnums.rhs_batch_dimensions().begin(),
127                        original_dnums.rhs_batch_dimensions().end());
128   rhs_transpose.insert(rhs_transpose.end(),
129                        original_dnums.rhs_contracting_dimensions().begin(),
130                        original_dnums.rhs_contracting_dimensions().end());
131   rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(),
132                        rhs_non_contracting_dims.end());
133   HloInstruction* rhs_operand = original_dot->mutable_operand(1);
134   HloInstruction* transposed_rhs = computation->AddInstruction(
135       HloInstruction::CreateTranspose(
136           ShapeUtil::PermuteDimensions(rhs_transpose, rhs_shape), rhs_operand,
137           rhs_transpose),
138       &rhs_operand->metadata());
139 
140   std::vector<int64_t> rhs_reshape_dims = batch_dim_sizes;
141   rhs_reshape_dims.push_back(rhs_contracting_size);
142   if (rhs_non_contracting_size > 1) {
143     rhs_reshape_dims.push_back(rhs_non_contracting_size);
144   }
145   // Reshape the contracting and non-contracting dimensions together.
146   HloInstruction* reshaped_rhs = computation->AddInstruction(
147       HloInstruction::CreateReshape(
148           ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
149           transposed_rhs),
150       &transposed_rhs->metadata());
151 
152   std::vector<int64_t> dot_dims = batch_dim_sizes;
153   if (lhs_non_contracting_size > 1) {
154     dot_dims.push_back(lhs_non_contracting_size);
155   }
156   if (rhs_non_contracting_size > 1) {
157     dot_dims.push_back(rhs_non_contracting_size);
158   }
159 
160   DotDimensionNumbers dot_dnums;
161   for (int64_t i = 0; i < num_batch_dims; ++i) {
162     dot_dnums.add_lhs_batch_dimensions(i);
163     dot_dnums.add_rhs_batch_dimensions(i);
164   }
165   dot_dnums.add_lhs_contracting_dimensions(
166       num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0));
167   dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
168 
169   HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
170       ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
171       reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));
172   original_dot->SetupDerivedInstruction(dot);
173 
174   std::unique_ptr<HloInstruction> replacement =
175       HloInstruction::CreateReshape(original_dot->shape(), dot);
176   VLOG(3) << "Canonicalizing dot:\n"
177           << "\t old: " << original_dot->ToString() << "\n"
178           << "\t new: " << dot->ToString() << "\n"
179           << "\t   -> " << replacement->ToString();
180   return computation->ReplaceWithNewInstruction(original_dot,
181                                                 std::move(replacement));
182 }

}  // namespace

Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)186 StatusOr<bool> DotDecomposer::Run(
187     HloModule* module,
188     const absl::flat_hash_set<absl::string_view>& execution_threads) {
189   // Gather all Non-canonical Dot operations.
190   std::vector<HloInstruction*> non_canonical_dots;
191   for (auto* computation :
192        module->MakeNonfusionComputations(execution_threads)) {
193     for (auto* instruction : computation->instructions()) {
194       if (instruction->opcode() != HloOpcode::kDot) {
195         continue;
196       }
197       const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
198       // A dot it not canonical if there is more than one contracting dimension.
199       if (dnums.lhs_contracting_dimensions_size() != 1) {
200         non_canonical_dots.push_back(instruction);
201         continue;
202       }
203       // A dot is not canonical if it has more than one non-contracting
204       // dimension.
205       if (dnums.lhs_batch_dimensions_size() + 2 <
206               instruction->operand(0)->shape().rank() ||
207           dnums.rhs_batch_dimensions_size() + 2 <
208               instruction->operand(1)->shape().rank()) {
209         non_canonical_dots.push_back(instruction);
210         continue;
211       }
212       if (dnums.lhs_batch_dimensions().empty() &&
213           dnums.lhs_contracting_dimensions().empty()) {
214         non_canonical_dots.push_back(instruction);
215         continue;
216       }
217       // Check that batch dims, if present, are canonical.
218       std::vector<int64_t> canonical_batch_dims(
219           dnums.lhs_batch_dimensions_size());
220       absl::c_iota(canonical_batch_dims, 0);
221       if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) ||
222           !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) {
223         non_canonical_dots.push_back(instruction);
224       }
225     }
226   }
227   bool changed = false;
228   for (auto* dot : non_canonical_dots) {
229     TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
230     changed = true;
231   }
232   return changed;
233 }

}  // namespace xla