• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
18 
19 #include <unordered_map>
20 
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
27 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 
30 namespace xla {
31 
32 class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
33  public:
34   using HloToElementGeneratorMap =
35       std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
36 
ElementalIrEmitter(const HloModuleConfig & hlo_module_config,llvm::Module * module,llvm::IRBuilder<> * b)37   ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
38                      llvm::Module* module, llvm::IRBuilder<>* b)
39       : b_(b), module_(module), hlo_module_config_(hlo_module_config) {}
40 
41   virtual ~ElementalIrEmitter() = default;
42 
43   virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
44                                              llvm::Value* operand_value);
45 
46   virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
47                                               llvm::Value* lhs_value,
48                                               llvm::Value* rhs_value);
49 
50   // Returns a function to generate an element of the output of `hlo`, given a
51   // map of functions to generate elements of its operands.
52   virtual llvm_ir::ElementGenerator MakeElementGenerator(
53       const HloInstruction* hlo,
54       const HloToElementGeneratorMap& operand_to_generator);
55 
b()56   llvm::IRBuilder<>* b() { return b_; }
57 
58   // builder() is for IrBuilderMixin.
builder()59   llvm::IRBuilder<>* builder() { return b_; }
60 
module()61   llvm::Module* module() { return module_; }
62 
63  protected:
64   virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
65                                                     llvm::Value* operand_value);
66 
67   virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op,
68                                                   llvm::Value* operand_value);
69 
70   virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op,
71                                                     llvm::Value* operand_value);
72 
73   llvm::Value* IsZero(llvm::Value* v);
74   llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs);
75   llvm::Value* GetZero(llvm::Type* type);
76   llvm::Value* GetOne(llvm::Type* type);
77   llvm::Value* GetIntSMin(llvm::Type* type);
78   llvm::Value* GetMinusOne(llvm::Type* type);
79 
80   llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs,
81                                  bool is_signed);
82   llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
83                                     bool is_signed);
84 
85   virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
86                                                      llvm::Value* lhs_value,
87                                                      llvm::Value* rhs_value,
88                                                      bool is_signed);
89 
90   virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
91                                                    llvm::Value* lhs_value,
92                                                    llvm::Value* rhs_value);
93 
94   virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
95                                                      llvm::Value* lhs_value,
96                                                      llvm::Value* rhs_value);
97 
98   virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
99                                     llvm::Value* rhs_value);
100 
101   virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
102                                     llvm::Value* rhs_value);
103 
104   llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
105                                bool is_signed);
106 
107   llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
108                                bool is_signed);
109 
110   virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
111                                             llvm::Value* value);
112 
113   virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
114                                              llvm::Value* value);
115 
116   virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
117                                            llvm::Value* lhs, llvm::Value* rhs);
118 
119   virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
120                                          llvm::Value* value);
121 
122   virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type,
123                                           llvm::Value* value);
124 
125   virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type,
126                                            llvm::Value* value);
127 
128   virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
129                                            llvm::Value* value);
130 
131   virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
132                                          llvm::Value* value);
133 
134   virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
135                                          llvm::Value* value);
136 
137   virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
138                                          llvm::Value* value);
139 
140   virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
141                                            llvm::Value* value);
142 
143   virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
144                                          llvm::Value* lhs, llvm::Value* rhs);
145 
146   virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
147                                           llvm::Value* value);
148 
149   virtual StatusOr<llvm::Value*> EmitRoundNearestAfz(PrimitiveType prim_type,
150                                                      llvm::Value* value);
151 
152   virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
153                                                      llvm::Value* x);
154 
155   virtual llvm::Value* EmitExtractReal(llvm::Value* value);
156   virtual llvm::Value* EmitExtractImag(llvm::Value* value);
157 
158   // Composes a complex struct. imag may be nullptr for simple cast operations.
159   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
160                                   llvm::Value* imag);
161 
162   // Identifier of the thread unique among all threads on the device
EmitThreadId()163   virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
164 
165   StatusOr<llvm::Value*> EmitElementalSelect(
166       const HloInstruction* hlo,
167       const HloToElementGeneratorMap& operand_to_generator,
168       const llvm_ir::IrArray::Index& index);
169 
170   StatusOr<llvm::Value*> EmitElementalClamp(
171       const HloInstruction* hlo,
172       const HloToElementGeneratorMap& operand_to_generator,
173       const llvm_ir::IrArray::Index& index);
174 
175   StatusOr<llvm::Value*> EmitElementalConcatenate(
176       const HloInstruction* hlo,
177       const HloToElementGeneratorMap& operand_to_generator,
178       const llvm_ir::IrArray::Index& target_index);
179 
180   StatusOr<llvm::Value*> EmitElementalDynamicSlice(
181       const HloInstruction* hlo,
182       const HloToElementGeneratorMap& operand_to_generator,
183       const llvm_ir::IrArray::Index& index);
184 
185   StatusOr<llvm::Value*> EmitElementalGather(
186       const HloInstruction* hlo,
187       const HloToElementGeneratorMap& operand_to_generator,
188       const llvm_ir::IrArray::Index& index);
189 
190   StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice(
191       const HloInstruction* hlo,
192       const HloToElementGeneratorMap& operand_to_generator,
193       const llvm_ir::IrArray::Index& index);
194 
195   StatusOr<llvm::Value*> EmitElementalPad(
196       const HloInstruction* hlo,
197       const HloToElementGeneratorMap& operand_to_generator,
198       const llvm_ir::IrArray::Index& padded_index);
199 
200   StatusOr<llvm::Value*> EmitElementalDot(
201       const HloInstruction* hlo,
202       const HloToElementGeneratorMap& operand_to_generator,
203       const llvm_ir::IrArray::Index& dot_result_index);
204 
205   llvm::IRBuilder<>* const b_;
206 
207   llvm::Module* module_;
208 
209   // The HloModuleConfig which gathers all settings and values which affect the
210   // compiled executable outside of the HLO code itself.
211   const HloModuleConfig& hlo_module_config_;
212 
213  private:
214   // Computes the complex power function, returns (a + i*b)^(c + i*d).
215   StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
216                                           llvm::Value* a, llvm::Value* b,
217                                           llvm::Value* c, llvm::Value* d);
218 
219   // Returns a ElementGenerator for an RNG HloInstruction using the Philox
220   // random number generation algorithm.
221   llvm_ir::ElementGenerator MakePhiloxRngElementGenerator(
222       const HloInstruction* hlo,
223       const HloToElementGeneratorMap& operand_to_generator);
224 
225   // Converts the raw value generated by a random number generation algorithm
226   // to the distribution requested by the RNG HloInstruction.
227   //
228   // Precondition: raw_value has at least as many bits as hlo's element type.
229   StatusOr<llvm::Value*> ConvertValueForDistribution(
230       const HloInstruction* hlo,
231       const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
232       const llvm_ir::IrArray::Index& index, llvm::Value* raw_value);
233 };
234 
235 }  // namespace xla
236 
237 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
238