• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
18 
19 #include <unordered_map>
20 
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 
29 namespace xla {
30 
31 class ElementalIrEmitter {
32  public:
33   using HloToElementGeneratorMap =
34       std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
35 
ElementalIrEmitter(const HloModuleConfig & hlo_module_config,llvm::Module * module,llvm::IRBuilder<> * ir_builder)36   ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
37                      llvm::Module* module, llvm::IRBuilder<>* ir_builder)
38       : ir_builder_(ir_builder),
39         module_(module),
40         hlo_module_config_(hlo_module_config) {}
41 
42   virtual ~ElementalIrEmitter() = default;
43 
44   virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
45                                              llvm::Value* operand_value) const;
46 
47   virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
48                                               llvm::Value* lhs_value,
49                                               llvm::Value* rhs_value) const;
50 
51   // Returns a function to generate an element of the output of `hlo`, given a
52   // map of functions to generate elements of its operands.
53   virtual llvm_ir::ElementGenerator MakeElementGenerator(
54       const HloInstruction* hlo,
55       const HloToElementGeneratorMap& operand_to_generator) const;
56 
ir_builder()57   llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
module()58   llvm::Module* module() const { return module_; }
59 
60  protected:
61   virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
62       const HloInstruction* op, llvm::Value* operand_value) const;
63 
64   virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
65       const HloInstruction* op, llvm::Value* operand_value) const;
66 
67   virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
68       const HloInstruction* op, llvm::Value* operand_value) const;
69 
70   virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
71                                                      llvm::Value* lhs_value,
72                                                      llvm::Value* rhs_value,
73                                                      bool is_signed) const;
74 
75   virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(
76       const HloInstruction* op, llvm::Value* lhs_value,
77       llvm::Value* rhs_value) const;
78 
79   virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
80       const HloInstruction* op, llvm::Value* lhs_value,
81       llvm::Value* rhs_value) const;
82 
83   virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
84                                     llvm::Value* rhs_value) const;
85 
86   virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
87                                     llvm::Value* rhs_value) const;
88 
89   llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
90                                bool is_signed) const;
91 
92   llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
93                                bool is_signed) const;
94 
95   virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
96                                             llvm::Value* value) const;
97 
98   virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
99                                              llvm::Value* value) const;
100 
101   virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
102                                            llvm::Value* lhs,
103                                            llvm::Value* rhs) const;
104 
105   virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
106                                          llvm::Value* value) const;
107 
108   virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
109                                          llvm::Value* value) const;
110 
111   virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
112                                          llvm::Value* value) const;
113 
114   virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
115                                          llvm::Value* value) const;
116 
117   virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
118                                          llvm::Value* lhs,
119                                          llvm::Value* rhs) const;
120 
121   virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
122                                                      llvm::Value* x) const;
123 
124   virtual llvm::Value* EmitExtractReal(llvm::Value* value) const;
125   virtual llvm::Value* EmitExtractImag(llvm::Value* value) const;
126 
127   // Composes a complex struct. imag may be nullptr for simple cast operations.
128   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
129                                   llvm::Value* imag) const;
130 
131   // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
132   // the target array index, computes the source array index of its
133   // `operand_no`-th operand.
134   //
135   // Precondition: `hlo` is an elementwise op.
136   llvm_ir::IrArray::Index ElementwiseSourceIndex(
137       const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
138       int64 operand_no) const;
139 
140   // Identifier of the thread unique among all threads on the device
EmitThreadId()141   virtual llvm::Value* EmitThreadId() const {
142     return ir_builder_->getIntN(128, 0);
143   }
144 
145   llvm::IRBuilder<>* const ir_builder_;
146 
147   llvm::Module* module_;
148 
149   // The HloModuleConfig which gathers all settings and values which affect the
150   // compiled executable outside of the HLO code itself.
151   const HloModuleConfig& hlo_module_config_;
152 
153  private:
154   // Returns a ElementGenerator for a RNG HloInstruction.
155   llvm_ir::ElementGenerator MakeRngElementGenerator(
156       const HloInstruction* hlo,
157       const HloToElementGeneratorMap& operand_to_generator) const;
158 };
159 
160 }  // namespace xla
161 
162 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
163