1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 
35 namespace xla {
36 namespace {
37 
ToElementType(HloInstruction * hlo,PrimitiveType type)38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
39   if (hlo->shape().element_type() != type) {
40     Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
41     hlo = hlo->parent()->AddInstruction(
42         HloInstruction::CreateConvert(shape, hlo));
43   }
44   CHECK_EQ(hlo->shape().element_type(), type);
45   return hlo;
46 }
47 
HasOperandType(HloInstruction * hlo,PrimitiveType type)48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
49   for (HloInstruction* operand : hlo->operands()) {
50     if (operand->shape().element_type() == type) {
51       return true;
52     }
53   }
54   return false;
55 }
56 
57 // Finds out the Tuple Shape of the new instruction after converting the element
58 // type of the operands of the original instruction from `from_type` to
59 // `to_type`.
60 //
61 // This routine assumes the resulting `shape` of the original instruction is a
62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
64 // results with tuple shapes, and this routine is only called to convert the
65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
GetConvertedTupleShape(const Shape & shape,PrimitiveType from_type,PrimitiveType to_type)66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
67                              PrimitiveType to_type) {
68   std::vector<Shape> new_tuple_subshapes;
69   for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
70     Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
71     CHECK(!subshape.IsTuple());
72     if (subshape.element_type() == from_type) {
73       subshape = ShapeUtil::ChangeElementType(subshape, to_type);
74     }
75     new_tuple_subshapes.push_back(subshape);
76   }
77   return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
78 }
79 
80 // Converts the elements of the result of `hlo` to produce a new tuple with
81 // shape `to_shape`.
82 //
83 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple
84 // as a result.
ConvertTupleElements(HloInstruction * hlo,const Shape & to_shape)85 HloInstruction* ConvertTupleElements(HloInstruction* hlo,
86                                      const Shape& to_shape) {
87   const Shape& shape = hlo->shape();
88   HloComputation* computation = hlo->parent();
89   std::vector<HloInstruction*> tuple_elements;
90   for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
91     const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
92     HloInstruction* element = computation->AddInstruction(
93         HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
94     const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
95     CHECK(!ele_shape.IsTuple());
96     if (ele_shape.element_type() != to_ele_shape.element_type()) {
97       element = computation->AddInstruction(
98           HloInstruction::CreateConvert(to_ele_shape, element));
99     }
100     tuple_elements.push_back(element);
101   }
102   return computation->AddInstruction(
103       HloInstruction::CreateTuple(tuple_elements));
104 }
105 
106 }  // namespace
107 
HloElementTypeConverter(PrimitiveType eliminate_type,PrimitiveType replace_with_type)108 HloElementTypeConverter::HloElementTypeConverter(
109     PrimitiveType eliminate_type, PrimitiveType replace_with_type)
110     : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
111 
112 // This routine converts the arithmetic operations in the given module that use
113 // eliminate_type_ to operations that use replace_with_type_.
Run(HloModule * module)114 StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
115   XLA_VLOG_LINES(
116       3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
117 
118   if (eliminate_type_ == replace_with_type_) {
119     return false;
120   }
121 
122   HloCloneContext context(module);
123   bool changed = false;
124   for (auto* computation : module->computations()) {
125     for (auto* hlo : computation->MakeInstructionPostOrder()) {
126       const auto opcode = hlo->opcode();
127       // These are ops where it does not make sense to convert them.
128       if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
129           opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
130           opcode == HloOpcode::kBitcastConvert ||
131           opcode == HloOpcode::kGetTupleElement ||
132           opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
133         continue;
134       }
135 
136       // We cannot change a CustomCall since we have no way of adjusting the
137       // called binary to expect the updated type.
138       if (opcode == HloOpcode::kCustomCall) {
139         continue;
140       }
141 
142       // These are ops with embedded computations where it suffices to convert
143       // the embedded computations instead of converting the ops themselves.
144       if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
145           opcode == HloOpcode::kAllReduce ||
146           opcode == HloOpcode::kReduceScatter ||
147           opcode == HloOpcode::kAllReduceStart ||
148           opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
149           opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
150           opcode == HloOpcode::kScatter ||
151           opcode == HloOpcode::kSelectAndScatter ||
152           opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) {
153         continue;
154       }
155       TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
156 
157       bool nullary = hlo->operands().empty();
158       bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
159       bool should_eliminate_type = (nullary && wrong_element_type) ||
160                                    HasOperandType(hlo, eliminate_type_);
161       if (!should_eliminate_type) {
162         // If this CHECK fires, then this was an instruction that does not take
163         // the elimination type as an operand but it does return it. This pass
164         // does not have a feature to change the output type in that case, so
165         // instead of silently failing to eliminate the type, it fails loudly.
166         TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
167         continue;
168       }
169 
170       // Handle instructions that perform arithmetic operations and contain
171       // operands with eliminate_type_.
172       //
173       // First, convert the operands with eliminate_type_ to operands with
174       // replace_with_type_.
175       std::vector<HloInstruction*> new_operands;
176       for (HloInstruction* operand : hlo->operands()) {
177         if (operand->shape().element_type() == eliminate_type_) {
178           operand = ToElementType(operand, replace_with_type_);
179         }
180         new_operands.push_back(operand);
181       }
182 
183       // Then find out the result type of the new instruction with the same
184       // opcode but using the converted operands, create the new instruction,
185       // and convert the result of the new instruction back to match the result
186       // type of the original instruction.
187       HloInstruction* new_hlo;
188       if (hlo->shape().element_type() == eliminate_type_) {
189         Shape shape =
190             ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
191 
192         new_hlo = computation->AddInstruction(
193             hlo->CloneWithNewOperands(shape, new_operands, &context));
194         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
195 
196         new_hlo = ToElementType(new_hlo, eliminate_type_);
197       } else if (hlo->shape().IsTuple()) {
198         Shape old_shape = hlo->shape();
199         Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
200                                                  replace_with_type_);
201 
202         new_hlo = computation->AddInstruction(
203             hlo->CloneWithNewOperands(new_shape, new_operands, &context));
204         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
205 
206         // Convert the elements of the result of `new_hlo` to produce a new
207         // tuple with shape `old_shape`.
208         new_hlo = ConvertTupleElements(new_hlo, old_shape);
209       } else {
210         new_hlo = computation->AddInstruction(
211             hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context));
212         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
213       }
214 
215       TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
216       TF_RETURN_IF_ERROR(hlo->DropAllControlDeps());
217 
218       // NB!  We want to replace and remove side effecting instructions like Rng
219       // as well so we can't rely HloComputation::ReplaceInstruction to reliably
220       // remove the replaced instruction.
221       TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo));
222       changed = true;
223     }
224   }
225   XLA_VLOG_LINES(
226       2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
227   return changed;
228 }
229 
230 }  // namespace xla
231