• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
19 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 
22 namespace xla {
23 namespace {
24 
25 // Concat(Concat(A, B), C) => Concat(A, B, C)
ConcatForwarding(HloInstruction * concat)26 StatusOr<bool> ConcatForwarding(HloInstruction* concat) {
27   if (concat->opcode() != HloOpcode::kConcatenate) {
28     return false;
29   }
30   bool changed = false;
31 
32   auto parent = concat->parent();
33   std::vector<HloInstruction*> new_operands;
34   for (HloInstruction* operand : concat->operands()) {
35     if (operand->opcode() != HloOpcode::kConcatenate ||
36         operand->concatenate_dimension() != concat->concatenate_dimension()) {
37       new_operands.push_back(operand);
38     } else {
39       changed = true;
40       for (HloInstruction* operand_operand : operand->operands()) {
41         new_operands.push_back(operand_operand);
42       }
43     }
44   }
45   if (changed) {
46     auto new_concat = parent->AddInstruction(HloInstruction::CreateConcatenate(
47         concat->shape(), new_operands, concat->concatenate_dimension()));
48     TF_RETURN_IF_ERROR(parent->ReplaceInstruction(concat, new_concat));
49   }
50   return changed;
51 }
52 
53 // Slice(Concat(A1, A2, ..., An, ...), [n:n+1]) => An
SliceConcatForwarding(HloInstruction * slice)54 StatusOr<bool> SliceConcatForwarding(HloInstruction* slice) {
55   if (slice->opcode() != HloOpcode::kSlice) {
56     return false;
57   }
58   auto concat = slice->mutable_operand(0);
59   if (concat->opcode() != HloOpcode::kConcatenate) {
60     return false;
61   }
62 
63   if (slice->shape().rank() != 1) {
64     // Slice concat forwarding only work for size 1 tensor.
65     return false;
66   }
67 
68   int64_t concat_dim = concat->concatenate_dimension();
69 
70   std::vector<HloInstruction*> new_operands;
71   int64_t size_so_far = 0;
72   int64_t slice_size = slice->shape().dimensions(concat_dim);
73   if (slice_size != slice->slice_limits(0) - slice->slice_starts(0)) {
74     return false;
75   }
76   if (slice->slice_strides(0) != 1) {
77     return false;
78   }
79   for (HloInstruction* operand : concat->operands()) {
80     if (size_so_far == slice->slice_starts(0) &&
81         operand->shape().dimensions(0) == slice_size) {
82       // Found an operand that can be forwarded.
83       TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(operand));
84       return true;
85     }
86     size_so_far += operand->shape().dimensions(concat_dim);
87   }
88 
89   return false;
90 }
91 
92 // Reshape(Broadcast(A, []->[1]), [1]->[]) ==> A
ReshapeBroadcastForwarding(HloInstruction * reshape)93 StatusOr<bool> ReshapeBroadcastForwarding(HloInstruction* reshape) {
94   if (reshape->opcode() != HloOpcode::kReshape) {
95     return false;
96   }
97   auto broadcast = reshape->mutable_operand(0);
98   if (broadcast->opcode() != HloOpcode::kBroadcast) {
99     return false;
100   }
101 
102   if (reshape->shape().rank() != 0) {
103     return false;
104   }
105 
106   if (broadcast->shape().rank() != 1) {
107     return false;
108   }
109 
110   if (broadcast->mutable_operand(0)->shape().rank() != 0) {
111     return false;
112   }
113 
114   TF_RETURN_IF_ERROR(
115       reshape->ReplaceAllUsesWith(broadcast->mutable_operand(0)));
116 
117   return true;
118 }
119 
120 // Reshape(Reshape(A, []->[1]), [1]->[]) ==> A
ReshapeReshapeForwarding(HloInstruction * reshape)121 StatusOr<bool> ReshapeReshapeForwarding(HloInstruction* reshape) {
122   if (reshape->opcode() != HloOpcode::kReshape) {
123     return false;
124   }
125   auto reshape_2 = reshape->mutable_operand(0);
126   if (reshape_2->opcode() != HloOpcode::kReshape) {
127     return false;
128   }
129 
130   if (!Shape::Equal()(reshape->shape(), reshape_2->operand(0)->shape())) {
131     return false;
132   }
133   TF_RETURN_IF_ERROR(
134       reshape->ReplaceAllUsesWith(reshape_2->mutable_operand(0)));
135 
136   return true;
137 }
138 
139 // Convert(A, T->T) ==> A
IdentityConvertRemoving(HloInstruction * convert)140 StatusOr<bool> IdentityConvertRemoving(HloInstruction* convert) {
141   if (convert->opcode() != HloOpcode::kConvert) {
142     return false;
143   }
144   auto operand = convert->mutable_operand(0);
145   if (Shape::Equal()(convert->shape(), operand->shape())) {
146     TF_RETURN_IF_ERROR(convert->ReplaceAllUsesWith(operand));
147     return true;
148   }
149   return false;
150 }
151 
152 // Reshape(A, S->S) ==> A
IdentityReshapeRemoving(HloInstruction * reshape)153 StatusOr<bool> IdentityReshapeRemoving(HloInstruction* reshape) {
154   if (reshape->opcode() != HloOpcode::kReshape) {
155     return false;
156   }
157   auto operand = reshape->mutable_operand(0);
158   if (Shape::Equal()(reshape->shape(), operand->shape())) {
159     TF_RETURN_IF_ERROR(reshape->ReplaceAllUsesWith(operand));
160     return true;
161   }
162   return false;
163 }
164 
165 }  // namespace
166 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)167 StatusOr<bool> DynamicDimensionSimplifier::Run(
168     HloModule* module,
169     const absl::flat_hash_set<absl::string_view>& execution_threads) {
170   XLA_VLOG_LINES(
171       2, "DynamicDimensionSimplifier::Run(), before:\n" + module->ToString());
172   bool changed = false;
173 
174   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
175     for (auto* inst : comp->MakeInstructionPostOrder()) {
176       TF_ASSIGN_OR_RETURN(bool local_changed, ConcatForwarding(inst));
177       changed |= local_changed;
178     }
179   }
180 
181   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
182     for (auto* inst : comp->MakeInstructionPostOrder()) {
183       TF_ASSIGN_OR_RETURN(bool local_changed, SliceConcatForwarding(inst));
184       changed |= local_changed;
185     }
186   }
187 
188   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
189     for (auto* inst : comp->MakeInstructionPostOrder()) {
190       TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeBroadcastForwarding(inst));
191       changed |= local_changed;
192     }
193   }
194   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
195     for (auto* inst : comp->MakeInstructionPostOrder()) {
196       TF_ASSIGN_OR_RETURN(bool local_changed, ReshapeReshapeForwarding(inst));
197       changed |= local_changed;
198     }
199   }
200   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
201     for (auto* inst : comp->MakeInstructionPostOrder()) {
202       TF_ASSIGN_OR_RETURN(bool local_changed, IdentityConvertRemoving(inst));
203       changed |= local_changed;
204     }
205   }
206   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
207     for (auto* inst : comp->MakeInstructionPostOrder()) {
208       TF_ASSIGN_OR_RETURN(bool local_changed, IdentityReshapeRemoving(inst));
209       changed |= local_changed;
210     }
211   }
212   XLA_VLOG_LINES(
213       2, "DynamicDimensionSimplifier::Run(), after:\n" + module->ToString());
214   return changed;
215 }
216 }  // namespace xla
217