/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_

#include <atomic>
#include <list>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace xla {

// Describes a compilation unit at the HLO level.
//
// HloModule is the top-level unit in the HLO IR.  It corresponds to a whole
// "program".  Running a module, from beginning to end, is the only way to run
// an XLA program.
//
// A module contains one "entry computation"; this HloComputation is like main()
// in a C program.  The result of running the module is the result of running
// this computation.
//
// A module also contains some number of "nested computations".  Each nested
// computation is attached to an HloInstruction within some other computation.
// The meaning of the nested computation depends on the instruction it's
// attached to.
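//
// A minimal usage sketch (illustrative only; HloComputation::Builder and the
// instruction-building calls come from other headers and are assumptions
// here, not part of this file):
//
//   auto module = absl::make_unique<HloModule>("example", HloModuleConfig());
//   HloComputation::Builder builder("entry");
//   // ... add instructions to the builder ...
//   module->AddEntryComputation(builder.Build());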
class HloModule {
 public:
  // Constructor without a versioned computation handle. This constructor should
  // only be used for HloModules used outside of the XLA service (e.g.,
  // tests). The versioned handle is used by the service in the compilation
  // cache. A default configuration is created for this module.
  explicit HloModule(const string& name, HloModuleConfig config);
  virtual ~HloModule() {}

  // Adds an entry computation to the module. A module can only have one entry
  // computation. Returns a pointer to the newly added computation.
  HloComputation* AddEntryComputation(
      std::unique_ptr<HloComputation> computation);

  // Same as the AddEntryComputation function above but the module's
  // entry_computation_layout is updated to match the layout of the new entry
  // computation.
  HloComputation* AddEntryComputationWithLayouts(
      std::unique_ptr<HloComputation> computation);

  // Replaces the current entry computation with another computation.
  // The new entry computation must be a computation that is already in the
  // module.
  void ReplaceEntryComputation(HloComputation* entry_computation);

  // Adds an embedded computation to the module.
  HloComputation* AddEmbeddedComputation(
      std::unique_ptr<HloComputation> computation);

  // Removes an embedded computation.
  Status RemoveEmbeddedComputation(HloComputation* to_remove);

  // Removes unused computations.
  Status RemoveUnusedComputations();

  // Replaces all uses of computations that are keys of 'replacements' with
  // the corresponding values in 'replacements'. Replaces the entry computation,
  // if applicable.
  //
  // This function iterates over all instructions in the module to find
  // computations to replace. We could speed it up by keeping track of users of
  // computations.
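  //
  // A minimal sketch of a call site (illustrative only; 'old_comp' and
  // 'new_comp' are hypothetical computations already present in the module):
  //
  //   std::unordered_map<HloComputation*, HloComputation*> replacements;
  //   replacements[old_comp] = new_comp;
  //   module->ReplaceComputations(replacements);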
  void ReplaceComputations(
      const std::unordered_map<HloComputation*, HloComputation*>& replacements);

  const string& name() const { return name_; }
  void set_name(string name) { name_ = std::move(name); }

  // Returns a deep copy of this module including all computations.
  std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
  std::unique_ptr<HloModule> Clone(const HloModuleConfig& config,
                                   const string& suffix = "clone") const;

  // Performs a deep clone of the computation, by recursively cloning all
  // the called computations as well. If the clone context is specified, it
  // will be populated with the cloned object mappings.
  HloComputation* DeepCloneComputation(HloComputation* computation,
                                       HloCloneContext* context = nullptr);

  // Return a pointer to the entry computation of the module.
  HloComputation* entry_computation() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation_;
  }

  bool has_entry_computation() const { return entry_computation_ != nullptr; }

  // Returns the root instruction shape of entry computation.
  //
  // Precondition: entry_computation_ is not nullptr.
  const Shape& result_shape() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation()->root_instruction()->shape();
  }

  // Creates the ComputationLayout which describes the current status of the HLO
  // module entry computation.
  ComputationLayout compute_computation_layout() const {
    return ComputationLayout(entry_computation()->ComputeProgramShape(),
                             /*ignore_layouts=*/false);
  }

  ComputationLayout* mutable_entry_computation_layout() {
    return config_.mutable_entry_computation_layout();
  }

  const ComputationLayout& entry_computation_layout() const {
    return config_.entry_computation_layout();
  }

  // Generates a hash value of an HLO module. Hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO modules,
  // with respect to HloInstruction::Identical() method.
  uint64 Hash() const;

  // Gets the computations in this module.
  //
  // Returns a view of HloComputation*s, so you can iterate over this in the
  // natural way:
  //
  //   for (HloComputation* c : module->computations()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
  computations() const {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::iterator>>
  computations() {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }

  // Returns the computation in this module that has the name `name`.  Returns
  // null if there is no such computation.
  HloComputation* GetComputationWithName(absl::string_view name);

  // Gets the number of computations in this module.
  int64 computation_count() const { return computations_.size(); }

  // Returns the mutable computation for the given index.
  HloComputation* mutable_computation(int64 idx) {
    CHECK(idx >= 0 && idx < computations_.size());
    return computations_[idx].get();
  }

  // Gets the number of instructions in this module.
  int64 instruction_count() const;

  // Deallocate removed instructions in each computation.
  void Cleanup() {
    for (auto& comp : computations_) {
      comp->Cleanup();
    }
  }

  // Compute and return a post order of all computations in the module. The sort
  // is defined like so: if computation A has an instruction which calls
  // computation B, then A will appear after B in the sort.
  std::vector<HloComputation*> MakeComputationPostOrder() const;

  // Same as MakeComputationPostOrder() but only returns the computations
  // that are also found in the passed-in allow_list.
  std::vector<HloComputation*> MakeComputationPostOrder(
      const absl::flat_hash_set<HloComputation*>& allow_list) const;

  // Same as MakeComputationPostOrder() but sorting the computations by their
  // contents. The order is no longer a post order.
  std::vector<HloComputation*> MakeComputationSorted() const;

  // Gets the computations in this module which aren't for fusion nodes.
  //
  // Postcondition: All computations in the returned list have
  // !IsFusionComputation().
  //
  // Note: Callers can and do rely on the return value here being a *snapshot*
  // of the module's non-fusion computations -- that is, it's OK to add or
  // remove computations from a module while iterating over
  // MakeNonfusionComputations().
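  //
  // For example, the following sketch relies on that snapshot property
  // (illustrative only; 'IsDead' is a hypothetical predicate):
  //
  //   for (HloComputation* computation : module->MakeNonfusionComputations()) {
  //     if (IsDead(computation)) {
  //       TF_CHECK_OK(module->RemoveEmbeddedComputation(computation));
  //     }
  //   }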
  std::vector<HloComputation*> MakeNonfusionComputations() const;

  // Same as MakeNonfusionComputations() but sorting computations by content.
  std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;

  const HloModuleConfig& config() const { return config_; }
  void set_config(const HloModuleConfig& config) { config_ = config; }

  // Return a string representation of the module.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Convert an HloModule to or from a proto.
  HloModuleProto ToProto() const;
  static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
      const HloModuleProto& proto, const HloModuleConfig& module_config,
      bool prohibit_empty_literal = true);
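
  // As an illustration, a proto round trip might look like the following
  // sketch (StatusOr error handling elided):
  //
  //   HloModuleProto proto = module->ToProto();
  //   StatusOr<std::unique_ptr<HloModule>> restored =
  //       HloModule::CreateFromProto(proto, module->config());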

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
      const HloModuleProto& module, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromShape(
      const ProgramShape& program_shape, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Outlines the given expression from the given computation.
  // instructions_to_outline contains the instructions that form the expression.
  //
  // Precondition: instructions in instructions_to_outline are in topological
  // order (root of outlined instructions last). TODO(jingyue): takes a set of
  // instructions and topologically sorts them.
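  //
  // A minimal sketch of a call site (illustrative only; 'add' and 'mul' are
  // hypothetical instructions of 'comp', listed in topological order):
  //
  //   HloInstruction* call = module->OutlineExpressionFromComputation(
  //       {add, mul}, "outlined_add_mul", comp);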
  HloInstruction* OutlineExpressionFromComputation(
      absl::Span<HloInstruction* const> instructions_to_outline,
      const string& outlined_computation_name, HloComputation* computation);

  // Returns a randomly generated uint64.
  uint64 RandomNew64() const;

  // Returns the NameUniquer for uniquing instruction names in this module.
  NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }

  // Assign a new unique dense id for an instruction
  int NewUniqueInstructionId() {
    int result = next_unique_id_;
    next_unique_id_++;
    return result;
  }

  // input_output_alias_config indicates the list of aliased buffers that are
  // expected from the module.
  HloInputOutputAliasConfig& input_output_alias_config() {
    return input_output_alias_config_;
  }
  const HloInputOutputAliasConfig& input_output_alias_config() const {
    return input_output_alias_config_;
  }

  // DynamicParameterBinding holds the list of bindings that indicates which
  // parameter dimensions are dynamic and which parameters represent their
  // runtime value.
  DynamicParameterBinding& dynamic_parameter_binding() {
    return dynamic_parameter_binding_;
  }
  const DynamicParameterBinding& dynamic_parameter_binding() const {
    return dynamic_parameter_binding_;
  }

  // Returns an id that is unique to this module across all modules created over
  // the lifetime of this process.
  int unique_id() const { return unique_id_; }

  // Sets the schedule of the module to the given schedule.
  Status set_schedule(HloSchedule schedule);

  // Clears the schedule of the module.
  void clear_schedule() { schedule_.reset(); }

  // Returns true if the module has a schedule set.
  bool has_schedule() const { return schedule_.has_value(); }

  // Returns the schedule of the module. CHECK fails if no schedule is set.
  const HloSchedule& schedule() const { return *schedule_; }
  HloSchedule& schedule() { return *schedule_; }

  HloComputation* AddComputationAndUnifyNamesAndIds(
      std::unique_ptr<HloComputation> computation, bool is_entry) {
    computation->ClearUniqueIdInternal();
    for (auto* instruction : computation->instructions()) {
      instruction->ClearUniqueIdInternal();
    }
    return AddComputationInternal(std::move(computation), is_entry,
                                  /*uniquify_identifiers=*/true,
                                  /*preserve_entry_layouts=*/true);
  }

  Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;

  // Checks if this config has a list of entry parameters' HLO shardings for
  // SPMD.
  bool has_spmd_parameters_shardings() const {
    return spmd_parameters_shardings_.has_value();
  }

  // Getter and setter for the list of entry parameters' HLO shardings for SPMD.
  const std::vector<HloSharding>& spmd_parameters_shardings() const {
    CHECK(spmd_parameters_shardings_.has_value());
    return *spmd_parameters_shardings_;
  }
  void set_spmd_parameters_shardings(
      const std::vector<HloSharding>& shardings) {
    spmd_parameters_shardings_ = shardings;
  }

  // Checks if this config has the entry computation output's HLO sharding for
  // SPMD.
  bool has_spmd_output_sharding() const {
    return spmd_output_sharding_.has_value();
  }

  // Getter and setter for the entry computation output's HLO shardings for
  // SPMD.
  const HloSharding& spmd_output_sharding() const {
    CHECK(spmd_output_sharding_.has_value());
    return *spmd_output_sharding_;
  }
  void set_spmd_output_sharding(const HloSharding& sharding) {
    spmd_output_sharding_ = sharding;
  }

  // Add a program argument to be prefetched across programs.
  void AddCrossProgramPrefetch(int64 parameter, const ShapeIndex& index) {
    cross_program_prefetches_.emplace_back(parameter, index);
  }

  // Gets the list of program arguments to be prefetched across programs.
  const absl::Span<const std::pair<int64, ShapeIndex>> CrossProgramPrefetches()
      const {
    return cross_program_prefetches_;
  }

  const HloModuleMetadata& metadata() const { return metadata_; }
  HloModuleMetadata* metadata() { return &metadata_; }

  // Moves (not copies) metadata from this HloModule to `module`. To be used
  // in cases like HloModuleGroup::ReplaceModule when metadata should be
  // transferred out of a module before it's destroyed.
  void MoveMetadataToModule(HloModule* module) {
    module->metadata_ = std::move(metadata_);
  }

 private:
  HloComputation* AddComputationInternal(
      std::unique_ptr<HloComputation> computation, bool is_entry,
      bool uniquify_identifiers, bool preserve_entry_layouts);

  string name_;
  HloModuleConfig config_;
  HloComputation* entry_computation_ = nullptr;
  std::vector<std::unique_ptr<HloComputation>> computations_;

  // Random number generator engine to use when generating random numbers per
  // HloModule compilation.
  // TODO(b/25995601): Replace with better seed setting or dev/random for
  // where we don't need deterministic execution.
  mutable std::mt19937_64 rng_{42};
  mutable tensorflow::mutex rng_mutex_;

  // Unique name generator for computation and instruction names, which are
  // unique per module.
  NameUniquer computation_name_uniquer_{/*separator=*/"."};
  NameUniquer instruction_name_uniquer_{/*separator=*/"."};
  int next_unique_id_ = 0;

  // Used to keep track of the next unique module id that should be assigned.
  static std::atomic<int> next_unique_module_id_;
  // A unique id to label modules with.
  int unique_id_;

  // The HloSchedule of the module. The schedule, if it exists, contains a
  // sequential order of instructions for each non-fusion computation in the
  // module.
  absl::optional<HloSchedule> schedule_;

  // alias_config indicates the alias information of input/output buffers that
  // are expected from the module.
  HloInputOutputAliasConfig input_output_alias_config_;

  // Bindings for dynamic parameter mapping.
  DynamicParameterBinding dynamic_parameter_binding_;

  // The HLO shardings of the entry computation's parameters for
  // SPMD-partitioned programs.
  absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;

  // The HLO sharding of the entry computation's output (root) for
  // SPMD-partitioned programs.
  absl::optional<HloSharding> spmd_output_sharding_;

  // Arguments to be prefetched across programs.
  std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;

  // Metadata for this module, such as its canonical id and the HLO passes run.
  HloModuleMetadata metadata_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_