/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
18 
19 #include <atomic>
20 #include <list>
21 #include <memory>
22 #include <random>
23 #include <string>
24 #include <unordered_map>
25 #include <vector>
26 
27 #include "absl/strings/string_view.h"
28 #include "absl/types/optional.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/iterator_util.h"
31 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
32 #include "tensorflow/compiler/xla/service/hlo.pb.h"
33 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
34 #include "tensorflow/compiler/xla/service/hlo_computation.h"
35 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
38 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
39 #include "tensorflow/compiler/xla/service/name_uniquer.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/core/lib/gtl/iterator_range.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/platform/mutex.h"
44 
45 namespace xla {
46 
47 // Describes a compilation unit at the HLO level.
48 //
49 // HloModule is the top-level unit in the HLO IR.  It corresponds to a whole
50 // "program".  Running a module, from beginning to end, is the only way to run
51 // an XLA program.
52 //
53 // A module contains one "entry computation"; this HloComputation is like main()
54 // in a C program.  The result of running the module is the result of running
55 // this computation.
56 //
57 // A module also contains some number of "nested computations".  Each nested
58 // computation is attached to an HloInstruction within some other computation.
59 // The meaning of the nested computation depends on the instruction it's
60 // attached to.
61 class HloModule {
62  public:
63   // Constructor without a versioned computation handle. This constructor should
64   // only be used for HloModules used outside of the XLA service (eg
65   // tests). The versioned handle is used by the service in the compilation
66   // cache. A default configuration is created for this module.
67   explicit HloModule(const string& name, const HloModuleConfig& config);
~HloModule()68   virtual ~HloModule() {}
69 
70   // Adds an entry computation to the module. A module can only have one entry
71   // computation. Returns a pointer to the newly added computation.
72   HloComputation* AddEntryComputation(
73       std::unique_ptr<HloComputation> computation);
74 
75   // Adds an embedded computation to the module.
76   HloComputation* AddEmbeddedComputation(
77       std::unique_ptr<HloComputation> computation);
78 
79   // Removes an embedded computation.
80   Status RemoveEmbeddedComputation(HloComputation* to_remove);
81 
82   // Replaces all uses of computations that are keys of 'replacements' with
83   // the corresponding values in 'replacements'. Replaces the entry computation,
84   // if applicable.
85   //
86   // This function iterates over all instructions in the module to find
87   // computations to replace. We could speed it up by keeping track of users of
88   // computations.
89   void ReplaceComputations(
90       const std::unordered_map<HloComputation*, HloComputation*>& replacements);
91 
name()92   const string& name() const { return name_; }
set_name(string name)93   void set_name(string name) { name_ = std::move(name); }
94 
95   // Returns a deep copy of this module including all computations.
96   std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
97   std::unique_ptr<HloModule> Clone(const HloModuleConfig& config,
98                                    const string& suffix = "clone") const;
99 
100   // Performs a deep clone of the computation, by recursively cloning all
101   // the called computations as well. If the clone context is specified, it
102   // will be populated with the cloned object mappings.
103   HloComputation* DeepCloneComputation(HloComputation* computation,
104                                        HloCloneContext* context = nullptr);
105 
106   // Return a pointer to the entry computation of the module.
entry_computation()107   HloComputation* entry_computation() const {
108     CHECK_NE(nullptr, entry_computation_);
109     return entry_computation_;
110   }
111 
112   // Returns the root instruction shape of entry computation.
113   //
114   // Precondition: entry_computation_ is not nullptr.
result_shape()115   const Shape& result_shape() const {
116     CHECK_NE(nullptr, entry_computation_);
117     return entry_computation()->root_instruction()->shape();
118   }
119 
120   // Creates the ComputationLayout which describes the current status of the HLO
121   // module entry computation.
compute_computation_layout()122   ComputationLayout compute_computation_layout() const {
123     return ComputationLayout(entry_computation()->ComputeProgramShape(),
124                              /*ignore_layouts=*/false);
125   }
126 
mutable_entry_computation_layout()127   ComputationLayout* mutable_entry_computation_layout() {
128     return config_.mutable_entry_computation_layout();
129   }
130 
entry_computation_layout()131   const ComputationLayout& entry_computation_layout() const {
132     return config_.entry_computation_layout();
133   }
134 
135   // Generates a hash value of an HLO module. Hash considers
136   // information on opcode, shape, operands, and typically a root instruction.
137   // This function returns the same hash value for equivalent HLO modules,
138   // with respect to HloInstruction::Identical() method.
Hash()139   uint64 Hash() const {
140     return entry_computation()->root_instruction()->Hash();
141   }
142 
143   // Gets the computations in this module.
144   //
145   // Returns a view of HloComputation*s, so you can iterate over this in the
146   // natural way:
147   //
148   //   for (HloComputation* c : module->computations()) { ... }
149   //
150   tensorflow::gtl::iterator_range<UnwrappingIterator<
151       std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
computations()152   computations() const {
153     return {MakeUnwrappingIterator(computations_.begin()),
154             MakeUnwrappingIterator(computations_.end())};
155   }
156   tensorflow::gtl::iterator_range<UnwrappingIterator<
157       std::vector<std::unique_ptr<HloComputation>>::iterator>>
computations()158   computations() {
159     return {MakeUnwrappingIterator(computations_.begin()),
160             MakeUnwrappingIterator(computations_.end())};
161   }
162 
163   // Returns the computation in this module that has the name `name`.  Returns
164   // null if there is no such computation.
165   HloComputation* GetComputationWithName(absl::string_view name);
166 
167   // Gets the number of computations in this module.
computation_count()168   int64 computation_count() const { return computations_.size(); }
169 
170   // Returns the mutable computation for the given index.
mutable_computation(int64 idx)171   HloComputation* mutable_computation(int64 idx) {
172     CHECK(idx >= 0 && idx < computations_.size());
173     return computations_[idx].get();
174   }
175 
176   // Gets the number of instructions in this module.
177   int64 instruction_count() const;
178 
179   // Compute and return a post order of all computations in the module. The sort
180   // is defined like so: if computation A has an instruction which calls
181   // computation B, then A will appear after B in the sort.
182   std::vector<HloComputation*> MakeComputationPostOrder() const;
183 
184   // Gets the computations in this module which aren't for fusion nodes.
185   //
186   // Postcondition: All computations in the returned list have
187   // !IsFusionComputation().
188   //
189   // Note: Callers can and do rely on the return value here being a *snapshot*
190   // of the module's non-fusion computations -- that is, it's OK to add or
191   // remove computations from a module while iterating over
192   // MakeNonfusionComputations().
193   std::vector<HloComputation*> MakeNonfusionComputations() const;
194 
config()195   const HloModuleConfig& config() const { return config_; }
set_config(HloModuleConfig & config)196   void set_config(HloModuleConfig& config) { config_ = config; }
197 
198   // Return a string representation of the module.
199   //
200   // (We express the default options using an overload rather than a default
201   // param because gdb ignores default params, but does resolve overloads.)
ToString()202   string ToString() const { return ToString(HloPrintOptions()); }
203   string ToString(const HloPrintOptions& options) const;
204 
205   // Convert an HloModule to or from a proto.
206   HloModuleProto ToProto() const;
207   static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
208       const HloModuleProto& proto, const HloModuleConfig& module_config);
209 
210   // Creates and returns an HloModuleConfig with an appropriate program shape
211   // for the HLO module in the given proto.
212   static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
213       const HloModuleProto& module, const DebugOptions& debug_options);
214 
215   // Outlines the given expression from the given computation.
216   // instructions_to_outline contains the instructions that form the expression.
217   //
218   // Precondition: instructions in instructions_to_outline are in topological
219   // order (root of outlined instructions last). TODO(jingyue): takes a set of
220   // instructions and topologically sorts them.
221   HloInstruction* OutlineExpressionFromComputation(
222       absl::Span<HloInstruction* const> instructions_to_outline,
223       const string& outlined_computation_name, HloComputation* computation);
224 
225   // Returns a randomly generated uint64.
226   uint64 RandomNew64() const;
227 
228   // Returns the NameUniquer for uniquing instruction names in this module.
instruction_name_uniquer()229   NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }
230 
231   // Assign a new unique dense id for an instruction
NewUniqueInstructionId()232   int NewUniqueInstructionId() {
233     int result = next_unique_id_;
234     next_unique_id_++;
235     return result;
236   }
237 
238   // input_output_alias_config indicates the list of aliased buffers that are
239   // expected from the module.
input_output_alias_config()240   HloInputOutputAliasConfig& input_output_alias_config() {
241     return input_output_alias_config_;
242   }
input_output_alias_config()243   const HloInputOutputAliasConfig& input_output_alias_config() const {
244     return input_output_alias_config_;
245   }
246 
247   // DynamicParameterBinding holds the list of bindings that indicates which
248   // parameter dimensions are dynamic and which parameters represent their
249   // runtime value.
dynamic_parameter_binding()250   DynamicParameterBinding& dynamic_parameter_binding() {
251     return dynamic_parameter_binding_;
252   }
dynamic_parameter_binding()253   const DynamicParameterBinding& dynamic_parameter_binding() const {
254     return dynamic_parameter_binding_;
255   }
256 
257   // Returns an id that is unique to this module across all modules created over
258   // the lifetime of this process.
unique_id()259   int unique_id() const { return unique_id_; }
260 
261   // Sets the schedule of the module to the given schedule.
262   Status set_schedule(HloSchedule schedule);
263 
264   // Clears the schedule of the module.
clear_schedule()265   void clear_schedule() { schedule_.reset(); }
266 
267   // Returns true if the module has a schedule set.
has_schedule()268   bool has_schedule() const { return schedule_.has_value(); }
269 
270   // Returns the schedue of the module. CHECK fails if no schedule is set.
schedule()271   const HloSchedule& schedule() const { return *schedule_; }
schedule()272   HloSchedule& schedule() { return *schedule_; }
273 
AddComputationAndUnifyNamesAndIds(std::unique_ptr<HloComputation> computation,bool is_entry)274   HloComputation* AddComputationAndUnifyNamesAndIds(
275       std::unique_ptr<HloComputation> computation, bool is_entry) {
276     computation->ClearUniqueIdInternal();
277     for (auto* instruction : computation->instructions()) {
278       instruction->ClearUniqueIdInternal();
279     }
280     return AddComputationInternal(std::move(computation), is_entry,
281                                   /*uniquify_identifiers=*/true);
282   }
283 
284   Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;
285 
286  private:
287   HloComputation* AddComputationInternal(
288       std::unique_ptr<HloComputation> computation, bool is_entry,
289       bool uniquify_identifiers);
290 
291   string name_;
292   HloModuleConfig config_;
293   HloComputation* entry_computation_ = nullptr;
294   std::vector<std::unique_ptr<HloComputation>> computations_;
295 
296   // Random number generator engine to use when generating random numbers per
297   // HloModule compilation.
298   // TODO(b/25995601): Replace with better seed setting or dev/random for
299   // where we don't need deterministic execution.
300   mutable std::mt19937_64 rng_{42};
301   mutable tensorflow::mutex rng_mutex_;
302 
303   // Unique name generator for computation and instruction names, which are
304   // unique per module.
305   NameUniquer computation_name_uniquer_{/*separator=*/"."};
306   NameUniquer instruction_name_uniquer_{/*separator=*/"."};
307   int next_unique_id_ = 0;
308 
309   // Used to keep track of the next unique module id that should be assigned.
310   static std::atomic<int> next_unique_module_id_;
311   // A unique id to label modules with.
312   int unique_id_;
313 
314   // The HloSchedule of the module. The schedule if it exists contains a
315   // sequential order of instructions for each non-fusion computation in the
316   // module.
317   absl::optional<HloSchedule> schedule_;
318 
319   // alias_config indicates the alias information of input/output buffers that
320   // are expected from the module.
321   HloInputOutputAliasConfig input_output_alias_config_;
322 
323   // Bindings for dynamic parameter mapping.
324   DynamicParameterBinding dynamic_parameter_binding_;
325 };
326 
327 }  // namespace xla
328 
329 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
330