1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/core/lib/core/errors.h"
34
35 namespace xla {
36 namespace {
37
ToElementType(HloInstruction * hlo,PrimitiveType type)38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
39 if (hlo->shape().element_type() != type) {
40 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
41 hlo = hlo->parent()->AddInstruction(
42 HloInstruction::CreateConvert(shape, hlo));
43 }
44 CHECK_EQ(hlo->shape().element_type(), type);
45 return hlo;
46 }
47
HasOperandType(HloInstruction * hlo,PrimitiveType type)48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
49 for (HloInstruction* operand : hlo->operands()) {
50 if (operand->shape().element_type() == type) {
51 return true;
52 }
53 }
54 return false;
55 }
56
57 // Finds out the Tuple Shape of the new instruction after converting the element
58 // type of the operands of the original instruction from `from_type` to
59 // `to_type`.
60 //
61 // This routine assumes the resulting `shape` of the original instruction is a
62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
64 // results with tuple shapes, and this routine is only called to convert the
65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
GetConvertedTupleShape(const Shape & shape,PrimitiveType from_type,PrimitiveType to_type)66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
67 PrimitiveType to_type) {
68 std::vector<Shape> new_tuple_subshapes;
69 const int64_t n = ShapeUtil::TupleElementCount(shape);
70 new_tuple_subshapes.reserve(n);
71 for (int64_t i = 0; i < n; ++i) {
72 Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
73 CHECK(!subshape.IsTuple());
74 if (subshape.element_type() == from_type) {
75 subshape = ShapeUtil::ChangeElementType(subshape, to_type);
76 }
77 new_tuple_subshapes.push_back(subshape);
78 }
79 return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
80 }
81
82 // Converts the elements of the result of `hlo` to produce a new tuple with
83 // shape `to_shape`.
84 //
85 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple
86 // as a result.
ConvertTupleElements(HloInstruction * hlo,const Shape & to_shape)87 HloInstruction* ConvertTupleElements(HloInstruction* hlo,
88 const Shape& to_shape) {
89 const Shape& shape = hlo->shape();
90 HloComputation* computation = hlo->parent();
91 std::vector<HloInstruction*> tuple_elements;
92 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
93 const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
94 HloInstruction* element = computation->AddInstruction(
95 HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
96 const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
97 CHECK(!ele_shape.IsTuple());
98 if (ele_shape.element_type() != to_ele_shape.element_type()) {
99 element = computation->AddInstruction(
100 HloInstruction::CreateConvert(to_ele_shape, element));
101 }
102 tuple_elements.push_back(element);
103 }
104 return computation->AddInstruction(
105 HloInstruction::CreateTuple(tuple_elements));
106 }
107
108 } // namespace
109
HloElementTypeConverter(PrimitiveType eliminate_type,PrimitiveType replace_with_type)110 HloElementTypeConverter::HloElementTypeConverter(
111 PrimitiveType eliminate_type, PrimitiveType replace_with_type)
112 : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
113
114 // This routine converts the arithmetic operations in the given module that use
115 // eliminate_type_ to operations that use replace_with_type_.
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)116 StatusOr<bool> HloElementTypeConverter::Run(
117 HloModule* module,
118 const absl::flat_hash_set<absl::string_view>& execution_threads) {
119 XLA_VLOG_LINES(
120 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
121
122 if (eliminate_type_ == replace_with_type_) {
123 return false;
124 }
125
126 HloCloneContext context(module);
127 bool changed = false;
128 for (auto* computation : module->computations(execution_threads)) {
129 for (auto* hlo : computation->MakeInstructionPostOrder()) {
130 const auto opcode = hlo->opcode();
131 // These are ops where it does not make sense to convert them.
132 if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
133 opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
134 opcode == HloOpcode::kBitcastConvert ||
135 opcode == HloOpcode::kGetTupleElement ||
136 opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
137 continue;
138 }
139
140 // We cannot change a CustomCall since we have no way of adjusting the
141 // called binary to expect the updated type.
142 if (opcode == HloOpcode::kCustomCall) {
143 continue;
144 }
145
146 // These are ops with embedded computations where it suffices to convert
147 // the embedded computations instead of converting the ops themselves.
148 if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
149 opcode == HloOpcode::kAllReduce ||
150 opcode == HloOpcode::kReduceScatter ||
151 opcode == HloOpcode::kAllReduceStart ||
152 opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
153 opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
154 opcode == HloOpcode::kScatter ||
155 opcode == HloOpcode::kSelectAndScatter ||
156 opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) {
157 continue;
158 }
159 TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
160
161 bool nullary = hlo->operands().empty();
162 bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
163 bool should_eliminate_type = (nullary && wrong_element_type) ||
164 HasOperandType(hlo, eliminate_type_);
165 if (!should_eliminate_type) {
166 // If this CHECK fires, then this was an instruction that does not take
167 // the elimination type as an operand but it does return it. This pass
168 // does not have a feature to change the output type in that case, so
169 // instead of silently failing to eliminate the type, it fails loudly.
170 TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
171 continue;
172 }
173
174 // Handle instructions that perform arithmetic operations and contain
175 // operands with eliminate_type_.
176 //
177 // First, convert the operands with eliminate_type_ to operands with
178 // replace_with_type_.
179 std::vector<HloInstruction*> new_operands;
180 const auto& operands = hlo->operands();
181 new_operands.reserve(operands.size());
182 for (HloInstruction* operand : operands) {
183 if (operand->shape().element_type() == eliminate_type_) {
184 operand = ToElementType(operand, replace_with_type_);
185 }
186 new_operands.push_back(operand);
187 }
188
189 // Then find out the result type of the new instruction with the same
190 // opcode but using the converted operands, create the new instruction,
191 // and convert the result of the new instruction back to match the result
192 // type of the original instruction.
193 HloInstruction* new_hlo;
194 if (hlo->shape().element_type() == eliminate_type_) {
195 Shape shape =
196 ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
197
198 new_hlo = computation->AddInstruction(
199 hlo->CloneWithNewOperands(shape, new_operands, &context));
200 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
201
202 new_hlo = ToElementType(new_hlo, eliminate_type_);
203 } else if (hlo->shape().IsTuple()) {
204 Shape old_shape = hlo->shape();
205 Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
206 replace_with_type_);
207
208 new_hlo = computation->AddInstruction(
209 hlo->CloneWithNewOperands(new_shape, new_operands, &context));
210 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
211
212 // Convert the elements of the result of `new_hlo` to produce a new
213 // tuple with shape `old_shape`.
214 new_hlo = ConvertTupleElements(new_hlo, old_shape);
215 } else {
216 new_hlo = computation->AddInstruction(
217 hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context));
218 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
219 }
220
221 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
222 TF_RETURN_IF_ERROR(hlo->DropAllControlDeps());
223
224 // NB! We want to replace and remove side effecting instructions like Rng
225 // as well so we can't rely HloComputation::ReplaceInstruction to reliably
226 // remove the replaced instruction.
227 TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo));
228 changed = true;
229 }
230 }
231 XLA_VLOG_LINES(
232 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
233 return changed;
234 }
235
236 } // namespace xla
237