1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/core/lib/core/errors.h"
34
35 namespace xla {
36 namespace {
37
ToElementType(HloInstruction * hlo,PrimitiveType type)38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
39 if (hlo->shape().element_type() != type) {
40 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
41 hlo = hlo->parent()->AddInstruction(
42 HloInstruction::CreateConvert(shape, hlo));
43 }
44 CHECK_EQ(hlo->shape().element_type(), type);
45 return hlo;
46 }
47
HasOperandType(HloInstruction * hlo,PrimitiveType type)48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
49 for (HloInstruction* operand : hlo->operands()) {
50 if (operand->shape().element_type() == type) {
51 return true;
52 }
53 }
54 return false;
55 }
56
57 // Finds out the Tuple Shape of the new instruction after converting the element
58 // type of the operands of the original instruction from `from_type` to
59 // `to_type`.
60 //
61 // This routine assumes the resulting `shape` of the original instruction is a
62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
64 // results with tuple shapes, and this routine is only called to convert the
65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
GetConvertedTupleShape(const Shape & shape,PrimitiveType from_type,PrimitiveType to_type)66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
67 PrimitiveType to_type) {
68 std::vector<Shape> new_tuple_subshapes;
69 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
70 Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
71 CHECK(!subshape.IsTuple());
72 if (subshape.element_type() == from_type) {
73 subshape = ShapeUtil::ChangeElementType(subshape, to_type);
74 }
75 new_tuple_subshapes.push_back(subshape);
76 }
77 return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
78 }
79
80 // Converts the elements of the result of `hlo` to produce a new tuple with
81 // shape `to_shape`.
82 //
83 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple
84 // as a result.
ConvertTupleElements(HloInstruction * hlo,const Shape & to_shape)85 HloInstruction* ConvertTupleElements(HloInstruction* hlo,
86 const Shape& to_shape) {
87 const Shape& shape = hlo->shape();
88 HloComputation* computation = hlo->parent();
89 std::vector<HloInstruction*> tuple_elements;
90 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
91 const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
92 HloInstruction* element = computation->AddInstruction(
93 HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
94 const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
95 CHECK(!ele_shape.IsTuple());
96 if (ele_shape.element_type() != to_ele_shape.element_type()) {
97 element = computation->AddInstruction(
98 HloInstruction::CreateConvert(to_ele_shape, element));
99 }
100 tuple_elements.push_back(element);
101 }
102 return computation->AddInstruction(
103 HloInstruction::CreateTuple(tuple_elements));
104 }
105
106 } // namespace
107
HloElementTypeConverter(PrimitiveType eliminate_type,PrimitiveType replace_with_type)108 HloElementTypeConverter::HloElementTypeConverter(
109 PrimitiveType eliminate_type, PrimitiveType replace_with_type)
110 : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
111
112 // This routine converts the arithmetic operations in the given module that use
113 // eliminate_type_ to operations that use replace_with_type_.
Run(HloModule * module)114 StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
115 XLA_VLOG_LINES(
116 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
117
118 if (eliminate_type_ == replace_with_type_) {
119 return false;
120 }
121
122 HloCloneContext context(module);
123 bool changed = false;
124 for (auto* computation : module->computations()) {
125 for (auto* hlo : computation->MakeInstructionPostOrder()) {
126 const auto opcode = hlo->opcode();
127 // These are ops where it does not make sense to convert them.
128 if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
129 opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
130 opcode == HloOpcode::kBitcastConvert ||
131 opcode == HloOpcode::kGetTupleElement ||
132 opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
133 continue;
134 }
135
136 // We cannot change a CustomCall since we have no way of adjusting the
137 // called binary to expect the updated type.
138 if (opcode == HloOpcode::kCustomCall) {
139 continue;
140 }
141
142 // These are ops with embedded computations where it suffices to convert
143 // the embedded computations instead of converting the ops themselves.
144 if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
145 opcode == HloOpcode::kAllReduce ||
146 opcode == HloOpcode::kReduceScatter ||
147 opcode == HloOpcode::kAllReduceStart ||
148 opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
149 opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
150 opcode == HloOpcode::kScatter ||
151 opcode == HloOpcode::kSelectAndScatter ||
152 opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) {
153 continue;
154 }
155 TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
156
157 bool nullary = hlo->operands().empty();
158 bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
159 bool should_eliminate_type = (nullary && wrong_element_type) ||
160 HasOperandType(hlo, eliminate_type_);
161 if (!should_eliminate_type) {
162 // If this CHECK fires, then this was an instruction that does not take
163 // the elimination type as an operand but it does return it. This pass
164 // does not have a feature to change the output type in that case, so
165 // instead of silently failing to eliminate the type, it fails loudly.
166 TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
167 continue;
168 }
169
170 // Handle instructions that perform arithmetic operations and contain
171 // operands with eliminate_type_.
172 //
173 // First, convert the operands with eliminate_type_ to operands with
174 // replace_with_type_.
175 std::vector<HloInstruction*> new_operands;
176 for (HloInstruction* operand : hlo->operands()) {
177 if (operand->shape().element_type() == eliminate_type_) {
178 operand = ToElementType(operand, replace_with_type_);
179 }
180 new_operands.push_back(operand);
181 }
182
183 // Then find out the result type of the new instruction with the same
184 // opcode but using the converted operands, create the new instruction,
185 // and convert the result of the new instruction back to match the result
186 // type of the original instruction.
187 HloInstruction* new_hlo;
188 if (hlo->shape().element_type() == eliminate_type_) {
189 Shape shape =
190 ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
191
192 new_hlo = computation->AddInstruction(
193 hlo->CloneWithNewOperands(shape, new_operands, &context));
194 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
195
196 new_hlo = ToElementType(new_hlo, eliminate_type_);
197 } else if (hlo->shape().IsTuple()) {
198 Shape old_shape = hlo->shape();
199 Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
200 replace_with_type_);
201
202 new_hlo = computation->AddInstruction(
203 hlo->CloneWithNewOperands(new_shape, new_operands, &context));
204 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
205
206 // Convert the elements of the result of `new_hlo` to produce a new
207 // tuple with shape `old_shape`.
208 new_hlo = ConvertTupleElements(new_hlo, old_shape);
209 } else {
210 new_hlo = computation->AddInstruction(
211 hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context));
212 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
213 }
214
215 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
216 TF_RETURN_IF_ERROR(hlo->DropAllControlDeps());
217
218 // NB! We want to replace and remove side effecting instructions like Rng
219 // as well so we can't rely HloComputation::ReplaceInstruction to reliably
220 // remove the replaced instruction.
221 TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo));
222 changed = true;
223 }
224 }
225 XLA_VLOG_LINES(
226 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
227 return changed;
228 }
229
230 } // namespace xla
231