• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 
35 namespace xla {
36 namespace {
37 
ToElementType(HloInstruction * hlo,PrimitiveType type)38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
39   if (hlo->shape().element_type() != type) {
40     Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
41     hlo = hlo->parent()->AddInstruction(
42         HloInstruction::CreateConvert(shape, hlo));
43   }
44   CHECK_EQ(hlo->shape().element_type(), type);
45   return hlo;
46 }
47 
HasOperandType(HloInstruction * hlo,PrimitiveType type)48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
49   for (HloInstruction* operand : hlo->operands()) {
50     if (operand->shape().element_type() == type) {
51       return true;
52     }
53   }
54   return false;
55 }
56 
57 // Finds out the Tuple Shape of the new instruction after converting the element
58 // type of the operands of the original instruction from `from_type` to
59 // `to_type`.
60 //
61 // This routine assumes the resulting `shape` of the original instruction is a
62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
64 // results with tuple shapes, and this routine is only called to convert the
65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
GetConvertedTupleShape(const Shape & shape,PrimitiveType from_type,PrimitiveType to_type)66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
67                              PrimitiveType to_type) {
68   std::vector<Shape> new_tuple_subshapes;
69   const int64_t n = ShapeUtil::TupleElementCount(shape);
70   new_tuple_subshapes.reserve(n);
71   for (int64_t i = 0; i < n; ++i) {
72     Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
73     CHECK(!subshape.IsTuple());
74     if (subshape.element_type() == from_type) {
75       subshape = ShapeUtil::ChangeElementType(subshape, to_type);
76     }
77     new_tuple_subshapes.push_back(subshape);
78   }
79   return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
80 }
81 
82 // Converts the elements of the result of `hlo` to produce a new tuple with
83 // shape `to_shape`.
84 //
85 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple
86 // as a result.
ConvertTupleElements(HloInstruction * hlo,const Shape & to_shape)87 HloInstruction* ConvertTupleElements(HloInstruction* hlo,
88                                      const Shape& to_shape) {
89   const Shape& shape = hlo->shape();
90   HloComputation* computation = hlo->parent();
91   std::vector<HloInstruction*> tuple_elements;
92   for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
93     const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
94     HloInstruction* element = computation->AddInstruction(
95         HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
96     const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
97     CHECK(!ele_shape.IsTuple());
98     if (ele_shape.element_type() != to_ele_shape.element_type()) {
99       element = computation->AddInstruction(
100           HloInstruction::CreateConvert(to_ele_shape, element));
101     }
102     tuple_elements.push_back(element);
103   }
104   return computation->AddInstruction(
105       HloInstruction::CreateTuple(tuple_elements));
106 }
107 
108 }  // namespace
109 
HloElementTypeConverter(PrimitiveType eliminate_type,PrimitiveType replace_with_type)110 HloElementTypeConverter::HloElementTypeConverter(
111     PrimitiveType eliminate_type, PrimitiveType replace_with_type)
112     : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
113 
114 // This routine converts the arithmetic operations in the given module that use
115 // eliminate_type_ to operations that use replace_with_type_.
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)116 StatusOr<bool> HloElementTypeConverter::Run(
117     HloModule* module,
118     const absl::flat_hash_set<absl::string_view>& execution_threads) {
119   XLA_VLOG_LINES(
120       3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
121 
122   if (eliminate_type_ == replace_with_type_) {
123     return false;
124   }
125 
126   HloCloneContext context(module);
127   bool changed = false;
128   for (auto* computation : module->computations(execution_threads)) {
129     for (auto* hlo : computation->MakeInstructionPostOrder()) {
130       const auto opcode = hlo->opcode();
131       // These are ops where it does not make sense to convert them.
132       if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
133           opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
134           opcode == HloOpcode::kBitcastConvert ||
135           opcode == HloOpcode::kGetTupleElement ||
136           opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
137         continue;
138       }
139 
140       // We cannot change a CustomCall since we have no way of adjusting the
141       // called binary to expect the updated type.
142       if (opcode == HloOpcode::kCustomCall) {
143         continue;
144       }
145 
146       // These are ops with embedded computations where it suffices to convert
147       // the embedded computations instead of converting the ops themselves.
148       if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
149           opcode == HloOpcode::kAllReduce ||
150           opcode == HloOpcode::kReduceScatter ||
151           opcode == HloOpcode::kAllReduceStart ||
152           opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
153           opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
154           opcode == HloOpcode::kScatter ||
155           opcode == HloOpcode::kSelectAndScatter ||
156           opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) {
157         continue;
158       }
159       TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
160 
161       bool nullary = hlo->operands().empty();
162       bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
163       bool should_eliminate_type = (nullary && wrong_element_type) ||
164                                    HasOperandType(hlo, eliminate_type_);
165       if (!should_eliminate_type) {
166         // If this CHECK fires, then this was an instruction that does not take
167         // the elimination type as an operand but it does return it. This pass
168         // does not have a feature to change the output type in that case, so
169         // instead of silently failing to eliminate the type, it fails loudly.
170         TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
171         continue;
172       }
173 
174       // Handle instructions that perform arithmetic operations and contain
175       // operands with eliminate_type_.
176       //
177       // First, convert the operands with eliminate_type_ to operands with
178       // replace_with_type_.
179       std::vector<HloInstruction*> new_operands;
180       const auto& operands = hlo->operands();
181       new_operands.reserve(operands.size());
182       for (HloInstruction* operand : operands) {
183         if (operand->shape().element_type() == eliminate_type_) {
184           operand = ToElementType(operand, replace_with_type_);
185         }
186         new_operands.push_back(operand);
187       }
188 
189       // Then find out the result type of the new instruction with the same
190       // opcode but using the converted operands, create the new instruction,
191       // and convert the result of the new instruction back to match the result
192       // type of the original instruction.
193       HloInstruction* new_hlo;
194       if (hlo->shape().element_type() == eliminate_type_) {
195         Shape shape =
196             ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
197 
198         new_hlo = computation->AddInstruction(
199             hlo->CloneWithNewOperands(shape, new_operands, &context));
200         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
201 
202         new_hlo = ToElementType(new_hlo, eliminate_type_);
203       } else if (hlo->shape().IsTuple()) {
204         Shape old_shape = hlo->shape();
205         Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
206                                                  replace_with_type_);
207 
208         new_hlo = computation->AddInstruction(
209             hlo->CloneWithNewOperands(new_shape, new_operands, &context));
210         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
211 
212         // Convert the elements of the result of `new_hlo` to produce a new
213         // tuple with shape `old_shape`.
214         new_hlo = ConvertTupleElements(new_hlo, old_shape);
215       } else {
216         new_hlo = computation->AddInstruction(
217             hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context));
218         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
219       }
220 
221       TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
222       TF_RETURN_IF_ERROR(hlo->DropAllControlDeps());
223 
224       // NB!  We want to replace and remove side effecting instructions like Rng
225       // as well so we can't rely HloComputation::ReplaceInstruction to reliably
226       // remove the replaced instruction.
227       TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo));
228       changed = true;
229     }
230   }
231   XLA_VLOG_LINES(
232       2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
233   return changed;
234 }
235 
236 }  // namespace xla
237