/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 
26 namespace xla {
27 
28 // Propagates sharding information around the graph. HLOs that have shardings
29 // are kept as-is, those that do not have shardings are given shardings based on
30 // a simple local greedy heuristic.
31 class ShardingPropagation : public HloModulePass {
32  public:
33   explicit ShardingPropagation(bool is_spmd = false,
34                                bool propagate_metadata = false)
is_spmd_(is_spmd)35       : is_spmd_(is_spmd), propagate_metadata_(propagate_metadata) {}
name()36   absl::string_view name() const override { return "sharding-propagation"; }
37   StatusOr<bool> Run(HloModule* module) override;
38 
39   // Function which can be used to apply a spatially partitioned sharding onto a
40   // given domain. It will apply the sharding into the exit edges of the domain
41   // and then rely on the rest of sharding propagation to ensure that the
42   // intermediate nodes get the correct sharding.
43   static Status NormalizeDomain(const DomainMetadata::Domain& domain,
44                                 const DomainMetadata* metadata);
45 
46  private:
47   bool is_spmd_;
48   bool propagate_metadata_;
49 };
50 
51 }  // namespace xla
52 
53 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
54