/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_

#include <atomic>
#include <list>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace xla {

// Describes a compilation unit at the HLO level.
//
// HloModule is the top-level unit in the HLO IR.  It corresponds to a whole
// "program".  Running a module, from beginning to end, is the only way to run
// an XLA program.
//
// A module contains one "entry computation"; this HloComputation is like main()
// in a C program.  The result of running the module is the result of running
// this computation.
//
// A module also contains some number of "nested computations".  Each nested
// computation is attached to an HloInstruction within some other computation.
// The meaning of the nested computation depends on the instruction it's
// attached to.
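//
// A minimal construction sketch (illustrative only; it assumes the usual
// HloComputation::Builder, HloInstruction, and ShapeUtil helpers are
// available to the caller):
//
//   HloModuleConfig config;
//   auto module = absl::make_unique<HloModule>("example_module", config);
//
//   HloComputation::Builder builder("entry");
//   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
//       /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
//   builder.AddInstruction(
//       HloInstruction::CreateUnary(x->shape(), HloOpcode::kNegate, x));
//   module->AddEntryComputation(builder.Build());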
class HloModule {
 public:
  // Constructor without a versioned computation handle. This constructor
  // should only be used for HloModules used outside of the XLA service (e.g.
  // tests). The versioned handle is used by the service in the compilation
  // cache. The given config is used as this module's configuration.
  explicit HloModule(const string& name, HloModuleConfig config);
  virtual ~HloModule() {}

  // Adds an entry computation to the module. A module can only have one entry
  // computation. Returns a pointer to the newly added computation.
  HloComputation* AddEntryComputation(
      std::unique_ptr<HloComputation> computation);

  // Same as the AddEntryComputation function above but the module's
  // entry_computation_layout is updated to match the layout of the new entry
  // computation.
  HloComputation* AddEntryComputationWithLayouts(
      std::unique_ptr<HloComputation> computation);

  // Replaces the current entry computation with another computation.
  // The new entry computation must be a computation that is already in the
  // module.
  void ReplaceEntryComputation(HloComputation* entry_computation);

  // Adds an embedded computation to the module.
  HloComputation* AddEmbeddedComputation(
      std::unique_ptr<HloComputation> computation);

  // Removes an embedded computation.
  Status RemoveEmbeddedComputation(HloComputation* to_remove);

  // Removes unused computations.
  Status RemoveUnusedComputations();

  // Replaces all uses of computations that are keys of 'replacements' with
  // the corresponding values in 'replacements'. Replaces the entry computation,
  // if applicable.
  //
  // This function iterates over all instructions in the module to find
  // computations to replace. We could speed it up by keeping track of users of
  // computations.
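  //
  // A usage sketch (here `old_body` and `new_body` stand for computations that
  // are assumed to already live in this module):
  //
  //   absl::flat_hash_map<HloComputation*, HloComputation*> replacements;
  //   replacements[old_body] = new_body;
  //   module->ReplaceComputations(replacements);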
  void ReplaceComputations(
      const absl::flat_hash_map<HloComputation*, HloComputation*>&
          replacements);

  const string& name() const { return name_; }
  void set_name(string name) { name_ = std::move(name); }

  // Returns a deep copy of this module including all computations.
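  //
  // For example, the following sketch produces an independent copy that can be
  // transformed without affecting the original module; the suffix is appended
  // to the cloned module's name:
  //
  //   std::unique_ptr<HloModule> copy = module->Clone(/*suffix=*/"copy");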
  std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
  std::unique_ptr<HloModule> Clone(const HloModuleConfig& config,
                                   const string& suffix = "clone") const;

  // Performs a deep clone of the computation, by recursively cloning all
  // the called computations as well. If the clone context is specified, it
  // will be populated with the cloned object mappings.
  HloComputation* DeepCloneComputation(HloComputation* computation,
                                       HloCloneContext* context = nullptr);

  // Return a pointer to the entry computation of the module.
  HloComputation* entry_computation() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation_;
  }

  bool has_entry_computation() const { return entry_computation_ != nullptr; }

  // Returns the root instruction shape of entry computation.
  //
  // Precondition: entry_computation_ is not nullptr.
  const Shape& result_shape() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation()->root_instruction()->shape();
  }

  // Creates the ComputationLayout which describes the current status of the HLO
  // module entry computation.
  ComputationLayout compute_computation_layout() const {
    return ComputationLayout(entry_computation()->ComputeProgramShape(),
                             /*ignore_layouts=*/false);
  }

  ComputationLayout* mutable_entry_computation_layout() {
    return config_.mutable_entry_computation_layout();
  }

  const ComputationLayout& entry_computation_layout() const {
    return config_.entry_computation_layout();
  }

  // Generates a hash value of an HLO module. The hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO modules,
  // with respect to the HloInstruction::Identical() method.
  uint64 Hash() const;

  // Gets the computations in this module.
  //
  // Returns a view of HloComputation*s, so you can iterate over this in the
  // natural way:
  //
  //   for (HloComputation* c : module->computations()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
  computations() const {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::iterator>>
  computations() {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }

  // Returns the computation in this module that has the name `name`.  Returns
  // null if there is no such computation.
  HloComputation* GetComputationWithName(absl::string_view name);

  // Gets the number of computations in this module.
  int64 computation_count() const { return computations_.size(); }

  // Returns the mutable computation for the given index.
  HloComputation* mutable_computation(int64_t idx) {
    CHECK(idx >= 0 && idx < computations_.size());
    return computations_[idx].get();
  }

  // Gets the number of instructions in this module.
  int64 instruction_count() const;

  // Deallocate removed instructions in each computation.
  void Cleanup() {
    for (auto& comp : computations_) {
      comp->Cleanup();
    }
  }

  // Compute and return a post order of all computations in the module. The sort
  // is defined like so: if computation A has an instruction which calls
  // computation B, then A will appear after B in the sort.
  std::vector<HloComputation*> MakeComputationPostOrder() const;

  // Same as MakeComputationPostOrder() but only returns the computations that
  // are also found in the passed-in allow_list.
  std::vector<HloComputation*> MakeComputationPostOrder(
      const absl::flat_hash_set<HloComputation*>& allow_list) const;

  // Same as MakeComputationPostOrder() but sorts the computations by their
  // contents. The order is no longer a post order.
  std::vector<HloComputation*> MakeComputationSorted() const;

  // Gets the computations in this module which aren't for fusion nodes.
  //
  // Postcondition: All computations in the returned list have
  // !IsFusionComputation().
  //
  // Note: Callers can and do rely on the return value here being a *snapshot*
  // of the module's non-fusion computations -- that is, it's OK to add or
  // remove computations from a module while iterating over
  // MakeNonfusionComputations().
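  //
  // For example, the following sketch is safe even though it removes
  // computations while iterating (`ShouldRemove` is a hypothetical
  // caller-provided predicate):
  //
  //   for (HloComputation* computation : module->MakeNonfusionComputations()) {
  //     if (ShouldRemove(computation)) {
  //       TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation));
  //     }
  //   }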
  std::vector<HloComputation*> MakeNonfusionComputations() const;

  // Same as MakeNonfusionComputations() but sorting computations by content.
  std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;

  HloModuleConfig& config() { return config_; }
  const HloModuleConfig& config() const { return config_; }
  void set_config(const HloModuleConfig& config) { config_ = config; }

  bool is_dynamic() const { return is_dynamic_; }
  void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }

  // Return a string representation of the module.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Convert an HloModule to or from a proto.
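  //
  // A round-trip sketch (CreateModuleConfigFromProto, declared below, rebuilds
  // a config from the proto's program shape; the default-constructed
  // DebugOptions is just a placeholder):
  //
  //   HloModuleProto proto = module->ToProto();
  //   TF_ASSIGN_OR_RETURN(
  //       HloModuleConfig config,
  //       HloModule::CreateModuleConfigFromProto(proto, DebugOptions()));
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> restored,
  //                       HloModule::CreateFromProto(proto, config));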
  HloModuleProto ToProto() const;
  static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
      const HloModuleProto& proto, const HloModuleConfig& module_config,
      bool prohibit_empty_literal = true);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
      const HloModuleProto& module, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for an HLO module with the given program shape.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromShape(
      const ProgramShape& program_shape, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Outlines the given expression from the given computation.
  // instructions_to_outline contains the instructions that form the expression.
  //
  // Precondition: instructions in instructions_to_outline are in topological
  // order (root of outlined instructions last). TODO(jingyue): take a set of
  // instructions and topologically sort them.
  HloInstruction* OutlineExpressionFromComputation(
      absl::Span<HloInstruction* const> instructions_to_outline,
      const string& outlined_computation_name, HloComputation* computation);

  // Returns a randomly generated uint64.
  uint64 RandomNew64() const;

  // Returns the NameUniquer for uniquing instruction names in this module.
  NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }

  // Assigns a new unique dense id for an instruction.
  int NewUniqueInstructionId() {
    int result = next_unique_id_;
    next_unique_id_++;
    return result;
  }

  // input_output_alias_config indicates the list of aliased buffers that are
  // expected from the module.
  HloInputOutputAliasConfig& input_output_alias_config() {
    return input_output_alias_config_;
  }
  const HloInputOutputAliasConfig& input_output_alias_config() const {
    return input_output_alias_config_;
  }

  // DynamicParameterBinding holds the list of bindings that indicates which
  // parameter dimensions are dynamic and which parameters represent their
  // runtime value.
  DynamicParameterBinding& dynamic_parameter_binding() {
    return dynamic_parameter_binding_;
  }
  const DynamicParameterBinding& dynamic_parameter_binding() const {
    return dynamic_parameter_binding_;
  }

  // Returns an id that is unique to this module across all modules created over
  // the lifetime of this process.
  int unique_id() const { return unique_id_; }

  // Sets the schedule of the module to the given schedule.
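  //
  // A sketch of installing a hand-built schedule (this assumes the entry
  // computation is the module's only non-fusion computation; real code usually
  // obtains the schedule from a scheduling pass instead):
  //
  //   HloSchedule schedule(module.get());
  //   schedule.set_sequence(
  //       module->entry_computation(),
  //       module->entry_computation()->MakeInstructionPostOrder());
  //   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));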
  Status set_schedule(HloSchedule schedule);

  // Clears the schedule of the module.
  void clear_schedule() { schedule_.reset(); }

  // Returns true if the module has a schedule set.
  bool has_schedule() const { return schedule_.has_value(); }

  // Returns the schedule of the module. CHECK fails if no schedule is set.
  const HloSchedule& schedule() const { return *schedule_; }
  HloSchedule& schedule() { return *schedule_; }

  HloComputation* AddComputationAndUnifyNamesAndIds(
      std::unique_ptr<HloComputation> computation, bool is_entry) {
    computation->ClearUniqueIdInternal();
    for (auto* instruction : computation->instructions()) {
      instruction->ClearUniqueIdInternal();
    }
    return AddComputationInternal(std::move(computation), is_entry,
                                  /*uniquify_identifiers=*/true,
                                  /*preserve_entry_layouts=*/true);
  }

  Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;

  // Checks if this module has a list of entry parameters' HLO shardings for
  // SPMD.
  bool has_spmd_parameters_shardings() const {
    return spmd_parameters_shardings_.has_value();
  }

  // Getter and setter for the list of entry parameters' HLO shardings for SPMD.
  const std::vector<HloSharding>& spmd_parameters_shardings() const {
    CHECK(spmd_parameters_shardings_.has_value());
    return *spmd_parameters_shardings_;
  }
  void set_spmd_parameters_shardings(
      const std::vector<HloSharding>& shardings) {
    spmd_parameters_shardings_ = shardings;
  }

  // Checks if this module has the entry computation output's HLO sharding for
  // SPMD.
  bool has_spmd_output_sharding() const {
    return spmd_output_sharding_.has_value();
  }

  // Getter and setter for the entry computation output's HLO sharding for
  // SPMD.
  const HloSharding& spmd_output_sharding() const {
    CHECK(spmd_output_sharding_.has_value());
    return *spmd_output_sharding_;
  }
  void set_spmd_output_sharding(const HloSharding& sharding) {
    spmd_output_sharding_ = sharding;
  }

  // Add a program argument to be prefetched across programs.
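  //
  // For example, to request a prefetch of the element at tuple index {1} of
  // parameter 0, and later inspect the recorded requests via
  // CrossProgramPrefetches() below:
  //
  //   module->AddCrossProgramPrefetch(/*parameter=*/0, ShapeIndex({1}));
  //   for (const std::pair<int64, ShapeIndex>& prefetch :
  //        module->CrossProgramPrefetches()) {
  //     ...
  //   }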
  void AddCrossProgramPrefetch(int64_t parameter, const ShapeIndex& index) {
    cross_program_prefetches_.emplace_back(parameter, index);
  }

  // Gets the list of program arguments to be prefetched across programs.
  const absl::Span<const std::pair<int64, ShapeIndex>> CrossProgramPrefetches()
      const {
    return cross_program_prefetches_;
  }

  const HloModuleMetadata& metadata() const { return metadata_; }
  HloModuleMetadata* metadata() { return &metadata_; }

  // Moves (not copies) metadata from this HloModule to `module`. To be used
  // in cases like HloModuleGroup::ReplaceModule when metadata should be
  // transferred out of a module before it's destroyed.
  void MoveMetadataToModule(HloModule* module) {
    module->metadata_ = std::move(metadata_);
  }

 private:
  HloComputation* AddComputationInternal(
      std::unique_ptr<HloComputation> computation, bool is_entry,
      bool uniquify_identifiers, bool preserve_entry_layouts);

  string name_;
  HloModuleConfig config_;
  HloComputation* entry_computation_ = nullptr;
  std::vector<std::unique_ptr<HloComputation>> computations_;

  // Random number generator engine to use when generating random numbers per
  // HloModule compilation.
  // TODO(b/25995601): Replace with better seed setting or dev/random for
  // where we don't need deterministic execution.
  mutable std::mt19937_64 rng_{42};
  mutable tensorflow::mutex rng_mutex_;

  // Unique name generator for computation and instruction names, which are
  // unique per module.
  NameUniquer computation_name_uniquer_{/*separator=*/"."};
  NameUniquer instruction_name_uniquer_{/*separator=*/"."};
  int next_unique_id_ = 0;

  // Used to keep track of the next unique module id that should be assigned.
  static std::atomic<int> next_unique_module_id_;
  // A unique id to label modules with.
  int unique_id_;

  // The HloSchedule of the module. The schedule, if it exists, contains a
  // sequential order of instructions for each non-fusion computation in the
  // module.
  absl::optional<HloSchedule> schedule_;

  // alias_config indicates the alias information of input/output buffers that
  // are expected from the module.
  HloInputOutputAliasConfig input_output_alias_config_;

  // Bindings for dynamic parameter mapping.
  DynamicParameterBinding dynamic_parameter_binding_;

  // The HLO shardings of the entry computation's parameters for
  // SPMD-partitioned programs.
  absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;

  // The HLO sharding of the entry computation's output (root) for
  // SPMD-partitioned programs.
  absl::optional<HloSharding> spmd_output_sharding_;

  // Arguments to be prefetched across programs.
  std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;

  // Metadata for this module, such as its canonical id and the HLO passes run.
  HloModuleMetadata metadata_;

  // True if the module contains dynamic computation.
  bool is_dynamic_ = false;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_