1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/core/lib/core/errors.h"
34
35 namespace xla {
36 namespace {
37
ToElementType(HloInstruction * hlo,PrimitiveType type)38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
39 if (hlo->shape().element_type() != type) {
40 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
41 hlo = hlo->parent()->AddInstruction(
42 HloInstruction::CreateConvert(shape, hlo));
43 }
44 CHECK_EQ(hlo->shape().element_type(), type);
45 return hlo;
46 }
47
HasOperandType(HloInstruction * hlo,PrimitiveType type)48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
49 for (HloInstruction* operand : hlo->operands()) {
50 if (operand->shape().element_type() == type) {
51 return true;
52 }
53 }
54 return false;
55 }
56
57 // Finds out the Tuple Shape of the new instruction after converting the element
58 // type of the operands of the original instruction from `from_type` to
59 // `to_type`.
60 //
61 // This routine assumes the resulting `shape` of the original instruction is a
62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
64 // results with tuple shapes, and this routine is only called to convert the
65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
GetConvertedTupleShape(const Shape & shape,PrimitiveType from_type,PrimitiveType to_type)66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
67 PrimitiveType to_type) {
68 std::vector<Shape> new_tuple_subshapes;
69 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
70 Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
71 CHECK(!subshape.IsTuple());
72 if (subshape.element_type() == from_type) {
73 subshape = ShapeUtil::ChangeElementType(subshape, to_type);
74 }
75 new_tuple_subshapes.push_back(subshape);
76 }
77 return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
78 }
79
80 // Converts the elements of the result of `hlo` to produce a new tuple with
81 // shape `to_shape`.
82 //
83 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple
84 // as a result.
ConvertTupleElements(HloInstruction * hlo,const Shape & to_shape)85 HloInstruction* ConvertTupleElements(HloInstruction* hlo,
86 const Shape& to_shape) {
87 const Shape& shape = hlo->shape();
88 HloComputation* computation = hlo->parent();
89 std::vector<HloInstruction*> tuple_elements;
90 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
91 const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
92 HloInstruction* element = computation->AddInstruction(
93 HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
94 const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
95 CHECK(!ele_shape.IsTuple());
96 if (ele_shape.element_type() != to_ele_shape.element_type()) {
97 element = computation->AddInstruction(
98 HloInstruction::CreateConvert(to_ele_shape, element));
99 }
100 tuple_elements.push_back(element);
101 }
102 return computation->AddInstruction(
103 HloInstruction::CreateTuple(tuple_elements));
104 }
105
106 } // namespace
107
HloElementTypeConverter(PrimitiveType eliminate_type,PrimitiveType replace_with_type)108 HloElementTypeConverter::HloElementTypeConverter(
109 PrimitiveType eliminate_type, PrimitiveType replace_with_type)
110 : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
111
112 // This routine converts the arithmetic operations in the given module that use
113 // eliminate_type_ to operations that use replace_with_type_.
Run(HloModule * module)114 StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
115 XLA_VLOG_LINES(
116 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
117
118 if (eliminate_type_ == replace_with_type_) {
119 return false;
120 }
121
122 HloCloneContext context(module);
123 bool changed = false;
124 for (auto* computation : module->computations()) {
125 for (auto* hlo : computation->MakeInstructionPostOrder()) {
126 const auto opcode = hlo->opcode();
127 // These are ops where it does not make sense to convert them.
128 if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
129 opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
130 opcode == HloOpcode::kBitcastConvert ||
131 opcode == HloOpcode::kGetTupleElement ||
132 opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
133 continue;
134 }
135
136 // We cannot change a CustomCall since we have no way of adjusting the
137 // called binary to expect the updated type.
138 if (opcode == HloOpcode::kCustomCall) {
139 continue;
140 }
141
142 // These are ops with embedded computations where it suffices to convert
143 // the embedded computations instead of converting the ops themselves.
144 if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
145 opcode == HloOpcode::kAllReduce || opcode == HloOpcode::kFusion ||
146 opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce ||
147 opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kScatter ||
148 opcode == HloOpcode::kSelectAndScatter ||
149 opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) {
150 continue;
151 }
152 TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
153
154 bool nullary = hlo->operands().empty();
155 bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
156 bool should_eliminate_type = (nullary && wrong_element_type) ||
157 HasOperandType(hlo, eliminate_type_);
158 if (!should_eliminate_type) {
159 // If this CHECK fires, then this was an instruction that does not take
160 // the elimination type as an operand but it does return it. This pass
161 // does not have a feature to change the output type in that case, so
162 // instead of silently failing to eliminate the type, it fails loudly.
163 TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
164 continue;
165 }
166
167 // Handle instructions that perform arithmetic operations and contain
168 // operands with eliminate_type_.
169 //
170 // First, convert the operands with eliminate_type_ to operands with
171 // replace_with_type_.
172 std::vector<HloInstruction*> new_operands;
173 for (HloInstruction* operand : hlo->operands()) {
174 if (operand->shape().element_type() == eliminate_type_) {
175 operand = ToElementType(operand, replace_with_type_);
176 }
177 new_operands.push_back(operand);
178 }
179
180 // Then find out the result type of the new instruction with the same
181 // opcode but using the converted operands, create the new instruction,
182 // and convert the result of the new instruction back to match the result
183 // type of the original instruction.
184 HloInstruction* new_hlo;
185 if (hlo->shape().element_type() == eliminate_type_) {
186 Shape shape =
187 ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
188
189 new_hlo = computation->AddInstruction(
190 hlo->CloneWithNewOperands(shape, new_operands, &context));
191 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
192
193 new_hlo = ToElementType(new_hlo, eliminate_type_);
194 } else if (hlo->shape().IsTuple()) {
195 Shape old_shape = hlo->shape();
196 Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
197 replace_with_type_);
198
199 new_hlo = computation->AddInstruction(
200 hlo->CloneWithNewOperands(new_shape, new_operands, &context));
201 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
202
203 // Convert the elements of the result of `new_hlo` to produce a new
204 // tuple with shape `old_shape`.
205 new_hlo = ConvertTupleElements(new_hlo, old_shape);
206 } else {
207 new_hlo = computation->AddInstruction(
208 hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context));
209 TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
210 }
211
212 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
213 TF_RETURN_IF_ERROR(hlo->DropAllControlDeps());
214
215 // NB! We want to replace and remove side effecting instructions like Rng
216 // as well so we can't rely HloComputation::ReplaceInstruction to reliably
217 // remove the replaced instruction.
218 TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo));
219 changed = true;
220 }
221 }
222 XLA_VLOG_LINES(
223 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
224 return changed;
225 }
226
227 } // namespace xla
228