/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"

#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace {

// Concat(Concat(A, B), C) => Concat(A, B, C)
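// Illustrative example (hypothetical HLO text; names and shapes are made up):
//   a = f32[2] parameter(0)
//   b = f32[3] parameter(1)
//   d = f32[4] parameter(2)
//   inner = f32[5] concatenate(a, b), dimensions={0}
//   outer = f32[9] concatenate(inner, d), dimensions={0}
// is rewritten so that `outer` becomes
//   outer = f32[9] concatenate(a, b, d), dimensions={0}
// Only inner concats along the same dimension are flattened.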
StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
  if (concat->opcode() != HloOpcode::kConcatenate) {
    return false;
  }
  bool changed = false;

  auto parent = concat->parent();
  std::vector<HloInstruction*> new_operands;
  for (HloInstruction* operand : concat->operands()) {
    if (operand->opcode() != HloOpcode::kConcatenate ||
        operand->concatenate_dimension() != concat->concatenate_dimension()) {
      new_operands.push_back(operand);
    } else {
      changed = true;
      for (HloInstruction* operand_operand : operand->operands()) {
        new_operands.push_back(operand_operand);
      }
    }
  }
  if (changed) {
    auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
        concat->shape(), new_operands, concat->concatenate_dimension()));
    TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
  }
  return changed;
}

// Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
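// The slice must have stride 1 and its window must line up exactly with one
// concat operand. Illustrative example (hypothetical HLO text):
//   a = f32[2] parameter(0)
//   b = f32[1] parameter(1)
//   c = f32[3] concatenate(a, b), dimensions={0}
//   s = f32[1] slice(c), slice={[2:3]}
// Uses of `s` are replaced by `b`, whose extent [2:3) within the concat
// matches the slice window.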
StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
  if (slice->opcode() != HloOpcode::kSlice) {
    return false;
  }
  auto concat = slice->mutable_operand(0);
  if (concat->opcode() != HloOpcode::kConcatenate) {
    return false;
  }

  if (slice->shape().rank() != 1) {
    // Slice-of-concat forwarding only works for rank-1 slices.
    return false;
  }

  int64_t concat_dim = concat->concatenate_dimension();

  int64_t size_so_far = 0;
  int64_t slice_size = slice->shape().dimensions(concat_dim);
  if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
    return false;
  }
  if (slice->slice_strides(0) != 1) {
    return false;
  }
  for (HloInstruction* operand : concat->operands()) {
    if (size_so_far == slice->slice_starts(0) &&
        operand->shape().dimensions(0) == slice_size) {
      // Found an operand that can be forwarded.
      TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
      return true;
    }
    size_so_far += operand->shape().dimensions(concat_dim);
  }

  return false;
}

// Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
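// Illustrative example (hypothetical HLO text): a scalar broadcast to [1]
// and reshaped back to a scalar is a no-op round trip.
//   a = f32[] parameter(0)
//   b = f32[1] broadcast(a), dimensions={}
//   r = f32[] reshape(b)
// Uses of `r` are replaced by `a`.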
StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  auto broadcast = reshape->mutable_operand(0);
  if (broadcast->opcode() != HloOpcode::kBroadcast) {
    return false;
  }

  if (reshape->shape().rank() != 0) {
    return false;
  }

  if (broadcast->shape().rank() != 1) {
    return false;
  }

  if (broadcast->mutable_operand(0)->shape().rank() != 0) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));

  return true;
}

// Reshape(Reshape(A, S1->S2), S2->S1) ==> A
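// Illustrative example (hypothetical HLO text): any reshape pair that
// round-trips back to the original shape is collapsed.
//   a = f32[2,3] parameter(0)
//   r0 = f32[6] reshape(a)
//   r1 = f32[2,3] reshape(r0)
// Uses of `r1` are replaced by `a`.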
StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  auto reshape_2 = reshape->mutable_operand(0);
  if (reshape_2->opcode() != HloOpcode::kReshape) {
    return false;
  }

  if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
    return false;
  }
  TF_RETURN_IF_ERROR(
      reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));

  return true;
}

// Convert(A, T->T) ==> A
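// Illustrative example (hypothetical HLO text): a convert whose result shape
// (including element type) equals its operand's shape is an identity.
//   a = s32[4] parameter(0)
//   c = s32[4] convert(a)
// Uses of `c` are replaced by `a`.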
StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
  if (convert->opcode() != HloOpcode::kConvert) {
    return false;
  }
  auto operand = convert->mutable_operand(0);
  if (Shape::Equal()(convert->shape(), operand->shape())) {
    TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
    return true;
  }
  return false;
}

// Reshape(A, S->S) ==> A
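// Illustrative example (hypothetical HLO text):
//   a = f32[2,3] parameter(0)
//   r = f32[2,3] reshape(a)
// Uses of `r` are replaced by `a`.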
StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
  if (reshape->opcode() != HloOpcode::kReshape) {
    return false;
  }
  auto operand = reshape->mutable_operand(0);
  if (Shape::Equal()(reshape->shape(), operand->shape())) {
    TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
    return true;
  }
  return false;
}

}  // namespace

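// Runs each simplification above as its own full post-order pass over all
// non-fusion computations, so later passes (e.g. slice-of-concat forwarding)
// see the results of earlier ones (e.g. flattened concats).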
StatusOr<bool> DynamicDimensionSimplifier::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(
      2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;

  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    for (auto* inst : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
      changed |= local_changed;
    }
  }

  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    for (auto* inst : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
      changed |= local_changed;
    }
  }

  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    for (auto* inst : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
      changed |= local_changed;
    }
  }
  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    for (auto* inst : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
      changed |= local_changed;
    }
  }
  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    for (auto* inst : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
      changed |= local_changed;
    }
  }
  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    for (auto* inst : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
      changed |= local_changed;
    }
  }
  XLA_VLOG_LINES(
      2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}
}  // namespace xla