/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/compiler/xla/service/dot_decomposer.h"

#include <cstdint>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
30
31 namespace xla {
32
33 namespace {
34
35 // Convert a dot into a canonical form;
36 // * Non-contracting dimensions are reshaped together,
37 // * Contracting dimensions are reshaped together,
38 // * Batch dimensions are the most major dimensions.
39 // This requires transposing and reshaping of the lhs and rhs, and reshaping the
40 // output batch to the original shape.
CanonicalizeDot(HloInstruction * original_dot)41 Status CanonicalizeDot(HloInstruction* original_dot) {
42 auto computation = original_dot->parent();
43 const auto& original_dnums = original_dot->dot_dimension_numbers();
44 const int64_t num_batch_dims = original_dnums.lhs_batch_dimensions_size();
45 const int64_t num_contracting_dims =
46 original_dnums.lhs_contracting_dimensions_size();
47
48 const auto& lhs_shape = original_dot->operand(0)->shape();
49 const int64_t lhs_rank = lhs_shape.rank();
50 const int64_t num_lhs_non_contracting_dims =
51 lhs_rank - num_batch_dims - num_contracting_dims;
52
53 std::vector<int64_t> lhs_non_contracting_dims;
54 lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
55 int64_t lhs_contracting_size = 1;
56 int64_t lhs_non_contracting_size = 1;
57 std::vector<int64_t> batch_dim_sizes;
58 batch_dim_sizes.reserve(num_batch_dims);
59 for (int64_t i = 0; i < lhs_rank; ++i) {
60 if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) {
61 lhs_contracting_size *= lhs_shape.dimensions(i);
62 } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
63 i)) {
64 batch_dim_sizes.push_back(lhs_shape.dimensions(i));
65 } else {
66 lhs_non_contracting_dims.push_back(i);
67 lhs_non_contracting_size *= lhs_shape.dimensions(i);
68 }
69 }
70 // The canonical form of the lhs is
71 // [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct]
72 // If NonContractingDimsProduct is 1, it is omitted.
73 std::vector<int64_t> lhs_transpose;
74 lhs_transpose.reserve(lhs_rank);
75 lhs_transpose.insert(lhs_transpose.end(),
76 original_dnums.lhs_batch_dimensions().begin(),
77 original_dnums.lhs_batch_dimensions().end());
78 lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(),
79 lhs_non_contracting_dims.end());
80 lhs_transpose.insert(lhs_transpose.end(),
81 original_dnums.lhs_contracting_dimensions().begin(),
82 original_dnums.lhs_contracting_dimensions().end());
83 HloInstruction* lhs_operand = original_dot->mutable_operand(0);
84 HloInstruction* transposed_lhs = computation->AddInstruction(
85 HloInstruction::CreateTranspose(
86 ShapeUtil::PermuteDimensions(lhs_transpose, lhs_shape), lhs_operand,
87 lhs_transpose),
88 &lhs_operand->metadata());
89
90 std::vector<int64_t> lhs_reshape_dims = batch_dim_sizes;
91 if (lhs_non_contracting_size > 1) {
92 lhs_reshape_dims.push_back(lhs_non_contracting_size);
93 }
94 lhs_reshape_dims.push_back(lhs_contracting_size);
95 // Reshape the contracting and non-contracting dimensions together.
96 HloInstruction* reshaped_lhs = computation->AddInstruction(
97 HloInstruction::CreateReshape(
98 ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
99 transposed_lhs),
100 &transposed_lhs->metadata());
101
102 const auto& rhs_shape = original_dot->operand(1)->shape();
103 const int64_t rhs_rank = rhs_shape.rank();
104 const int64_t num_rhs_non_contracting_dims =
105 rhs_rank - num_batch_dims - num_contracting_dims;
106 std::vector<int64_t> rhs_non_contracting_dims;
107 rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
108 int64_t rhs_non_contracting_size = 1;
109 int64_t rhs_contracting_size = 1;
110 for (int64_t i = 0; i < rhs_rank; ++i) {
111 if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) {
112 rhs_contracting_size *= rhs_shape.dimensions(i);
113 } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
114 i)) {
115 rhs_non_contracting_dims.push_back(i);
116 rhs_non_contracting_size *= rhs_shape.dimensions(i);
117 }
118 }
119
120 // The canonical form of the rhs is
121 // [BatchDims, ContractingsDimsProduct, NonContractingDimsProduct]
122 // If NonContractingDimsProduct is 1, it is omitted.
123 std::vector<int64_t> rhs_transpose;
124 rhs_transpose.reserve(rhs_rank);
125 rhs_transpose.insert(rhs_transpose.end(),
126 original_dnums.rhs_batch_dimensions().begin(),
127 original_dnums.rhs_batch_dimensions().end());
128 rhs_transpose.insert(rhs_transpose.end(),
129 original_dnums.rhs_contracting_dimensions().begin(),
130 original_dnums.rhs_contracting_dimensions().end());
131 rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(),
132 rhs_non_contracting_dims.end());
133 HloInstruction* rhs_operand = original_dot->mutable_operand(1);
134 HloInstruction* transposed_rhs = computation->AddInstruction(
135 HloInstruction::CreateTranspose(
136 ShapeUtil::PermuteDimensions(rhs_transpose, rhs_shape), rhs_operand,
137 rhs_transpose),
138 &rhs_operand->metadata());
139
140 std::vector<int64_t> rhs_reshape_dims = batch_dim_sizes;
141 rhs_reshape_dims.push_back(rhs_contracting_size);
142 if (rhs_non_contracting_size > 1) {
143 rhs_reshape_dims.push_back(rhs_non_contracting_size);
144 }
145 // Reshape the contracting and non-contracting dimensions together.
146 HloInstruction* reshaped_rhs = computation->AddInstruction(
147 HloInstruction::CreateReshape(
148 ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
149 transposed_rhs),
150 &transposed_rhs->metadata());
151
152 std::vector<int64_t> dot_dims = batch_dim_sizes;
153 if (lhs_non_contracting_size > 1) {
154 dot_dims.push_back(lhs_non_contracting_size);
155 }
156 if (rhs_non_contracting_size > 1) {
157 dot_dims.push_back(rhs_non_contracting_size);
158 }
159
160 DotDimensionNumbers dot_dnums;
161 for (int64_t i = 0; i < num_batch_dims; ++i) {
162 dot_dnums.add_lhs_batch_dimensions(i);
163 dot_dnums.add_rhs_batch_dimensions(i);
164 }
165 dot_dnums.add_lhs_contracting_dimensions(
166 num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0));
167 dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
168
169 HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
170 ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
171 reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));
172 original_dot->SetupDerivedInstruction(dot);
173
174 std::unique_ptr<HloInstruction> replacement =
175 HloInstruction::CreateReshape(original_dot->shape(), dot);
176 VLOG(3) << "Canonicalizing dot:\n"
177 << "\t old: " << original_dot->ToString() << "\n"
178 << "\t new: " << dot->ToString() << "\n"
179 << "\t -> " << replacement->ToString();
180 return computation->ReplaceWithNewInstruction(original_dot,
181 std::move(replacement));
182 }
183
184 } // namespace
185
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)186 StatusOr<bool> DotDecomposer::Run(
187 HloModule* module,
188 const absl::flat_hash_set<absl::string_view>& execution_threads) {
189 // Gather all Non-canonical Dot operations.
190 std::vector<HloInstruction*> non_canonical_dots;
191 for (auto* computation :
192 module->MakeNonfusionComputations(execution_threads)) {
193 for (auto* instruction : computation->instructions()) {
194 if (instruction->opcode() != HloOpcode::kDot) {
195 continue;
196 }
197 const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
198 // A dot it not canonical if there is more than one contracting dimension.
199 if (dnums.lhs_contracting_dimensions_size() != 1) {
200 non_canonical_dots.push_back(instruction);
201 continue;
202 }
203 // A dot is not canonical if it has more than one non-contracting
204 // dimension.
205 if (dnums.lhs_batch_dimensions_size() + 2 <
206 instruction->operand(0)->shape().rank() ||
207 dnums.rhs_batch_dimensions_size() + 2 <
208 instruction->operand(1)->shape().rank()) {
209 non_canonical_dots.push_back(instruction);
210 continue;
211 }
212 if (dnums.lhs_batch_dimensions().empty() &&
213 dnums.lhs_contracting_dimensions().empty()) {
214 non_canonical_dots.push_back(instruction);
215 continue;
216 }
217 // Check that batch dims, if present, are canonical.
218 std::vector<int64_t> canonical_batch_dims(
219 dnums.lhs_batch_dimensions_size());
220 absl::c_iota(canonical_batch_dims, 0);
221 if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) ||
222 !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) {
223 non_canonical_dots.push_back(instruction);
224 }
225 }
226 }
227 bool changed = false;
228 for (auto* dot : non_canonical_dots) {
229 TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
230 changed = true;
231 }
232 return changed;
233 }
234
235 } // namespace xla
236