/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_

#include <atomic>
#include <list>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace xla {

// Describes a compilation unit at the HLO level.
//
// HloModule is the top-level unit in the HLO IR.  It corresponds to a whole
// "program".  Running a module, from beginning to end, is the only way to run
// an XLA program.
//
// A module contains one "entry computation"; this HloComputation is like main()
// in a C program.  The result of running the module is the result of running
// this computation.
//
// A module also contains some number of "nested computations".  Each nested
// computation is attached to an HloInstruction within some other computation.
// The meaning of the nested computation depends on the instruction it's
// attached to.
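//
// A minimal construction sketch (illustrative only; it assumes ShapeUtil from
// tensorflow/compiler/xla/shape_util.h and the HloComputation::Builder API
// declared in hlo_computation.h, and the names `builder`, `x`, and `y` are
// made up for this example):
//
//   HloComputation::Builder builder("add_two_scalars");
//   Shape s = ShapeUtil::MakeShape(F32, {});
//   HloInstruction* x = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, s, "x"));
//   HloInstruction* y = builder.AddInstruction(
//       HloInstruction::CreateParameter(1, s, "y"));
//   builder.AddInstruction(
//       HloInstruction::CreateBinary(s, HloOpcode::kAdd, x, y));
//   HloModule module("example_module", HloModuleConfig());
//   module.AddEntryComputation(builder.Build());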
class HloModule {
 public:
  // Constructor without a versioned computation handle. This constructor
  // should only be used for HloModules used outside of the XLA service (e.g.,
  // tests). The versioned handle is used by the service in the compilation
  // cache. A default configuration is created for this module.
  explicit HloModule(const string& name, HloModuleConfig config);
  virtual ~HloModule() {}

  // Adds an entry computation to the module. A module can only have one entry
  // computation. Returns a pointer to the newly added computation.
  HloComputation* AddEntryComputation(
      std::unique_ptr<HloComputation> computation);

  // Replaces the current entry computation with another computation.
  // The new entry computation must be a computation that is already in the
  // module.
  void ReplaceEntryComputation(HloComputation* entry_computation);

  // Adds an embedded computation to the module.
  HloComputation* AddEmbeddedComputation(
      std::unique_ptr<HloComputation> computation);

  // Removes an embedded computation.
  Status RemoveEmbeddedComputation(HloComputation* to_remove);

  // Removes unused computations.
  Status RemoveUnusedComputations();

  // Replaces all uses of computations that are keys of 'replacements' with
  // the corresponding values in 'replacements'. Replaces the entry
  // computation, if applicable.
  //
  // This function iterates over all instructions in the module to find
  // computations to replace. We could speed it up by keeping track of users of
  // computations.
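  //
  // A usage sketch (hedged; `old_reduce` and `new_reduce` are hypothetical
  // computations that already live in this module):
  //
  //   std::unordered_map<HloComputation*, HloComputation*> replacements;
  //   replacements[old_reduce] = new_reduce;
  //   module->ReplaceComputations(replacements);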
  void ReplaceComputations(
      const std::unordered_map<HloComputation*, HloComputation*>& replacements);

  const string& name() const { return name_; }
  void set_name(string name) { name_ = std::move(name); }

  // Returns a deep copy of this module including all computations.
  std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
  std::unique_ptr<HloModule> Clone(const HloModuleConfig& config,
                                   const string& suffix = "clone") const;

  // Performs a deep clone of the computation, by recursively cloning all
  // the called computations as well. If the clone context is specified, it
  // will be populated with the cloned object mappings.
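  //
  // A minimal sketch (hedged; `module` is an HloModule* and `some_computation`
  // is a hypothetical computation already owned by it; the HloCloneContext
  // constructor taking an HloModule* is assumed from hlo_clone_context.h):
  //
  //   std::unique_ptr<HloModule> copy = module->Clone(/*suffix=*/"copy");
  //   HloCloneContext context(module);
  //   HloComputation* cloned =
  //       module->DeepCloneComputation(some_computation, &context);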
  HloComputation* DeepCloneComputation(HloComputation* computation,
                                       HloCloneContext* context = nullptr);

  // Return a pointer to the entry computation of the module.
  HloComputation* entry_computation() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation_;
  }

  bool has_entry_computation() const { return entry_computation_ != nullptr; }

  // Returns the shape of the root instruction of the entry computation.
  //
  // Precondition: entry_computation_ is not nullptr.
  const Shape& result_shape() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation()->root_instruction()->shape();
  }

  // Creates the ComputationLayout which describes the current status of the
  // HLO module's entry computation.
  ComputationLayout compute_computation_layout() const {
    return ComputationLayout(entry_computation()->ComputeProgramShape(),
                             /*ignore_layouts=*/false);
  }

  ComputationLayout* mutable_entry_computation_layout() {
    return config_.mutable_entry_computation_layout();
  }

  const ComputationLayout& entry_computation_layout() const {
    return config_.entry_computation_layout();
  }

  // Generates a hash value for this HLO module. The hash considers opcode,
  // shape, and operand information, and typically the root instruction. It
  // returns the same value for HLO modules that are equivalent with respect to
  // the HloInstruction::Identical() method.
  uint64 Hash() const;

  // Gets the computations in this module.
  //
  // Returns a view of HloComputation*s, so you can iterate over this in the
  // natural way:
  //
  //   for (HloComputation* c : module->computations()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
  computations() const {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::iterator>>
  computations() {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }

  // Returns the computation in this module that has the name `name`.  Returns
  // null if there is no such computation.
  HloComputation* GetComputationWithName(absl::string_view name);

  // Gets the number of computations in this module.
  int64 computation_count() const { return computations_.size(); }

  // Returns the mutable computation for the given index.
  HloComputation* mutable_computation(int64 idx) {
    CHECK(idx >= 0 && idx < computations_.size());
    return computations_[idx].get();
  }

  // Gets the number of instructions in this module.
  int64 instruction_count() const;

  // Computes and returns a post order of all computations in the module. The
  // sort is defined like so: if computation A has an instruction which calls
  // computation B, then A will appear after B in the sort.
  std::vector<HloComputation*> MakeComputationPostOrder() const;

  // Gets the computations in this module which aren't for fusion nodes.
  //
  // Postcondition: All computations in the returned list have
  // !IsFusionComputation().
  //
  // Note: Callers can and do rely on the return value here being a *snapshot*
  // of the module's non-fusion computations -- that is, it's OK to add or
  // remove computations from a module while iterating over
  // MakeNonfusionComputations().
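  //
  // Sketch of relying on the snapshot property (hedged; IsDeadComputation is a
  // hypothetical predicate and TF_RETURN_IF_ERROR comes from the usual status
  // macros):
  //
  //   for (HloComputation* computation : module->MakeNonfusionComputations()) {
  //     if (IsDeadComputation(computation)) {
  //       TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation));
  //     }
  //   }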
  std::vector<HloComputation*> MakeNonfusionComputations() const;

  // Same as MakeNonfusionComputations() but sorts the computations by name.
  std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;

  const HloModuleConfig& config() const { return config_; }
  void set_config(const HloModuleConfig& config) { config_ = config; }

  // Return a string representation of the module.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Convert an HloModule to or from a proto.
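  //
  // Round-trip sketch (hedged; it reuses the original module's config and the
  // TF_ASSIGN_OR_RETURN macro from the usual status macros):
  //
  //   HloModuleProto proto = module->ToProto();
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> restored,
  //                       HloModule::CreateFromProto(proto, module->config()));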
  HloModuleProto ToProto() const;
  static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
      const HloModuleProto& proto, const HloModuleConfig& module_config,
      bool prohibit_empty_literal = true);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
      const HloModuleProto& module, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Creates and returns an HloModuleConfig for an HLO module with the given
  // program shape.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromShape(
      const ProgramShape& program_shape, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Outlines the given expression from the given computation.
  // instructions_to_outline contains the instructions that form the expression.
  //
  // Precondition: instructions in instructions_to_outline are in topological
  // order (root of outlined instructions last). TODO(jingyue): take a set of
  // instructions and topologically sort them.
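  //
  // Illustrative sketch (hedged; `add` and `mul` are hypothetical instructions
  // in the entry computation, listed in topological order with the root last):
  //
  //   HloInstruction* call = module->OutlineExpressionFromComputation(
  //       {add, mul}, "outlined_add_mul", module->entry_computation());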
  HloInstruction* OutlineExpressionFromComputation(
      absl::Span<HloInstruction* const> instructions_to_outline,
      const string& outlined_computation_name, HloComputation* computation);

  // Returns a randomly generated uint64.
  uint64 RandomNew64() const;

  // Returns the NameUniquer for uniquing instruction names in this module.
  NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }

  // Assigns a new unique dense id for an instruction.
  int NewUniqueInstructionId() {
    int result = next_unique_id_;
    next_unique_id_++;
    return result;
  }

  // input_output_alias_config indicates the list of aliased buffers that are
  // expected from the module.
  HloInputOutputAliasConfig& input_output_alias_config() {
    return input_output_alias_config_;
  }
  const HloInputOutputAliasConfig& input_output_alias_config() const {
    return input_output_alias_config_;
  }

  // DynamicParameterBinding holds the list of bindings that indicate which
  // parameter dimensions are dynamic and which parameters represent their
  // runtime value.
  DynamicParameterBinding& dynamic_parameter_binding() {
    return dynamic_parameter_binding_;
  }
  const DynamicParameterBinding& dynamic_parameter_binding() const {
    return dynamic_parameter_binding_;
  }

  // Returns an id that is unique to this module across all modules created
  // over the lifetime of this process.
  int unique_id() const { return unique_id_; }

  // Sets the schedule of the module to the given schedule.
  Status set_schedule(HloSchedule schedule);

  // Clears the schedule of the module.
  void clear_schedule() { schedule_.reset(); }

  // Returns true if the module has a schedule set.
  bool has_schedule() const { return schedule_.has_value(); }

  // Returns the schedule of the module. CHECK fails if no schedule is set.
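  //
  // Typical guarded access (a sketch; HloSchedule::sequence() is assumed from
  // hlo_schedule.h):
  //
  //   if (module->has_schedule()) {
  //     const HloInstructionSequence& sequence =
  //         module->schedule().sequence(module->entry_computation());
  //     ...
  //   }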
  const HloSchedule& schedule() const { return *schedule_; }
  HloSchedule& schedule() { return *schedule_; }

  HloComputation* AddComputationAndUnifyNamesAndIds(
      std::unique_ptr<HloComputation> computation, bool is_entry) {
    computation->ClearUniqueIdInternal();
    for (auto* instruction : computation->instructions()) {
      instruction->ClearUniqueIdInternal();
    }
    return AddComputationInternal(std::move(computation), is_entry,
                                  /*uniquify_identifiers=*/true);
  }

  Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;

  // Checks if this config has a list of entry parameters' HLO shardings for
  // SPMD.
  bool has_spmd_parameters_shardings() const {
    return spmd_parameters_shardings_.has_value();
  }

  // Getter and setter for the list of entry parameters' HLO shardings for SPMD.
  const std::vector<HloSharding>& spmd_parameters_shardings() const {
    CHECK(spmd_parameters_shardings_.has_value());
    return *spmd_parameters_shardings_;
  }
  void set_spmd_parameters_shardings(
      const std::vector<HloSharding>& shardings) {
    spmd_parameters_shardings_ = shardings;
  }

  // Checks if this config has the entry computation output's HLO sharding for
  // SPMD.
  bool has_spmd_output_sharding() const {
    return spmd_output_sharding_.has_value();
  }

  // Getter and setter for the entry computation output's HLO sharding for
  // SPMD.
  const HloSharding& spmd_output_sharding() const {
    CHECK(spmd_output_sharding_.has_value());
    return *spmd_output_sharding_;
  }
  void set_spmd_output_sharding(const HloSharding& sharding) {
    spmd_output_sharding_ = sharding;
  }

 private:
  HloComputation* AddComputationInternal(
      std::unique_ptr<HloComputation> computation, bool is_entry,
      bool uniquify_identifiers);

  // Same as MakeComputationPostOrder() but sorts the computations by their
  // contents.
  std::vector<HloComputation*> MakeComputationSortedByContent() const;

  string name_;
  HloModuleConfig config_;
  HloComputation* entry_computation_ = nullptr;
  std::vector<std::unique_ptr<HloComputation>> computations_;

  // Random number generator engine to use when generating random numbers per
  // HloModule compilation.
  // TODO(b/25995601): Replace with better seed setting or dev/random where we
  // don't need deterministic execution.
  mutable std::mt19937_64 rng_{42};
  mutable tensorflow::mutex rng_mutex_;

  // Unique name generator for computation and instruction names, which are
  // unique per module.
  NameUniquer computation_name_uniquer_{/*separator=*/"."};
  NameUniquer instruction_name_uniquer_{/*separator=*/"."};
  int next_unique_id_ = 0;

  // Used to keep track of the next unique module id that should be assigned.
  static std::atomic<int> next_unique_module_id_;
  // A unique id to label modules with.
  int unique_id_;

  // The HloSchedule of the module. The schedule, if it exists, contains a
  // sequential order of instructions for each non-fusion computation in the
  // module.
  absl::optional<HloSchedule> schedule_;

  // alias_config indicates the alias information of input/output buffers that
  // are expected from the module.
  HloInputOutputAliasConfig input_output_alias_config_;

  // Bindings for dynamic parameter mapping.
  DynamicParameterBinding dynamic_parameter_binding_;

  // The HLO shardings of the entry computation's parameters for
  // SPMD-partitioned programs.
  absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;

  // The HLO sharding of the entry computation's output (root) for
  // SPMD-partitioned programs.
  absl::optional<HloSharding> spmd_output_sharding_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_