• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
18 
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
27 
28 namespace xla {
29 
30 // An abstraction representing a ordered set of HLO module built to run
31 // concurrently across different devices.
32 class HloModuleGroup {
33  public:
34   // Construct an empty module group.
HloModuleGroup(absl::string_view name)35   explicit HloModuleGroup(absl::string_view name) : name_(name) {}
36 
37   // Construct a module group containing a single module.
38   explicit HloModuleGroup(std::unique_ptr<HloModule> module);
39 
40   // Construct a module group containing any number of modules.
41   HloModuleGroup(absl::string_view name,
42                  absl::Span<std::unique_ptr<HloModule>> modules);
43   HloModuleGroup(absl::string_view name,
44                  std::vector<std::unique_ptr<HloModule>>&& modules);
45 
46   // Returns the modules contained in the group.
modules()47   const std::vector<HloModule*>& modules() const { return module_ptrs_; }
48 
49   // Returns a module at a particular index.
module(int index)50   HloModule& module(int index) const { return *module_ptrs_.at(index); }
51 
52   // Add a module to the back of vector of modules in the group.
53   void push_back(std::unique_ptr<HloModule> module);
54 
55   // Replaces the existing module at the given index with the given module. The
56   // existing module is discarded.
57   void ReplaceModule(int index, std::unique_ptr<HloModule> module);
58 
59   // Moves all modules from the group into the returned vector. After this
60   // method runs, the module group will be empty.
61   std::vector<std::unique_ptr<HloModule>> ConsumeModules();
62 
name()63   std::string name() const { return name_; }
64 
65   std::string ToString() const;
66 
67   // Deallocate removed instructions in each module.
Cleanup()68   void Cleanup() {
69     for (auto& module : modules_) {
70       module->Cleanup();
71     }
72   }
73 
74   template <typename H>
AbslHashValue(H h,const HloModuleGroup & group)75   friend H AbslHashValue(H h, const HloModuleGroup& group) {
76     for (auto& module : group.modules_) {
77       h = H::combine(std::move(h), *module);
78     }
79     return H::combine(std::move(h), group.modules_.size());
80   }
81 
82   // Serialize the module group to/from a proto.
83   HloModuleGroupProto ToProto() const;
84   static StatusOr<HloModuleGroup> CreateFromProto(
85       const HloModuleGroupProto& proto,
86       absl::Span<const HloModuleConfig> module_configs);
87 
88   // Returns the number of modules in the module group.
size()89   int size() const { return modules_.size(); }
90 
91   // Returns true if there are no modules in the module group.
empty()92   bool empty() const { return modules_.empty(); }
93 
cache_key()94   absl::string_view cache_key() const { return cache_key_; }
set_cache_key(absl::string_view cache_key)95   void set_cache_key(absl::string_view cache_key) {
96     cache_key_ = std::string(cache_key);
97   }
98 
99  private:
100   std::string name_;
101 
102   // Vector of modules as std::unique_ptrs.
103   std::vector<std::unique_ptr<HloModule>> modules_;
104 
105   // Vector of modules as normal pointers. This vector is kept in sync with
106   // modules_ as modules are added to the group with push_back.
107   std::vector<HloModule*> module_ptrs_;
108 
109   std::string cache_key_;
110 };
111 
112 std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group);
113 
114 }  // namespace xla
115 
116 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
117