1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/service/hlo_module.h" 23 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 24 #include "tensorflow/compiler/xla/statusor.h" 25 26 namespace xla { 27 28 // Propagates sharding information around the graph. HLOs that have shardings 29 // are kept as-is, those that do not have shardings are given shardings based on 30 // a simple local greedy heuristic. 31 class ShardingPropagation : public HloModulePass { 32 public: 33 explicit ShardingPropagation(bool is_spmd = false, 34 bool propagate_metadata = false) is_spmd_(is_spmd)35 : is_spmd_(is_spmd), propagate_metadata_(propagate_metadata) {} name()36 absl::string_view name() const override { return "sharding-propagation"; } 37 StatusOr<bool> Run(HloModule* module) override; 38 39 // Function which can be used to apply a spatially partitioned sharding onto a 40 // given domain. It will apply the sharding into the exit edges of the domain 41 // and then rely on the rest of sharding propagation to ensure that the 42 // intermediate nodes get the correct sharding. 43 static Status NormalizeDomain(const DomainMetadata::Domain& domain, 44 const DomainMetadata* metadata); 45 46 private: 47 bool is_spmd_; 48 bool propagate_metadata_; 49 }; 50 51 } // namespace xla 52 53 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_ 54