• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/sharding_remover.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
19 #include "tensorflow/compiler/xla/service/hlo_parser.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 
23 namespace op = xla::testing::opcode_matchers;
24 
25 namespace xla {
26 namespace {
27 
28 using ShardingRemoverTest = HloTestBase;
29 
// Runs ShardingRemover over a module whose only non-trivial instruction is a
// "Sharding" custom-call sitting between the entry parameter and the root
// reshape. The pass must report a change, and afterwards the reshape must read
// the parameter directly (the custom-call is bypassed).
TEST_F(ShardingRemoverTest, RemoveSharding) {
  const char* const kModuleText = R"(
HloModule module

ENTRY entry {
 %parameter.3379 = f32[1,1]{1,0} parameter(0)
 %custom-call.3380 = f32[1,1]{1,0} custom-call(f32[1,1]{1,0} %parameter.3379),
   custom_call_target="Sharding", sharding={replicated}
 ROOT %reshape.6032 = f32[] reshape(f32[1,1]{1,0} %custom-call.3380)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleText));

  ShardingRemover remover;
  TF_ASSERT_OK_AND_ASSIGN(bool modified, remover.Run(hlo_module.get()));
  EXPECT_TRUE(modified);

  // With the custom-call bypassed, the reshape's operand is the parameter.
  const HloInstruction* new_root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Reshape(op::Parameter()));
}
47 
// Same shape of graph as RemoveSharding, but the custom-call target is
// "SPMDShardToFullShape". The pass must still report a change and rewire the
// root reshape onto the entry parameter.
// NOTE: the test name says "Sharding" while the target is "Shard"; the name is
// kept as-is so existing test filters continue to match.
TEST_F(ShardingRemoverTest, RemoveSPMDShardingToFullShape) {
  const char* const module_text = R"(
HloModule module

ENTRY entry {
 %parameter.3379 = f32[1,1]{1,0} parameter(0)
 %custom-call.3380 = f32[1,1]{1,0} custom-call(f32[1,1]{1,0} %parameter.3379),
   custom_call_target="SPMDShardToFullShape", sharding={replicated}
 ROOT %reshape.6032 = f32[] reshape(f32[1,1]{1,0} %custom-call.3380)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed,
                          ParseAndReturnVerifiedModule(module_text));
  TF_ASSERT_OK_AND_ASSIGN(bool did_change,
                          ShardingRemover().Run(parsed.get()));
  EXPECT_TRUE(did_change);

  const auto expected_root = op::Reshape(op::Parameter());
  EXPECT_THAT(parsed->entry_computation()->root_instruction(), expected_root);
}
65 
// Covers the "SPMDFullToShardShape" custom-call target: after the pass, the
// root reshape should no longer go through the custom-call but consume the
// entry parameter, and the pass must report that it changed the module.
TEST_F(ShardingRemoverTest, RemoveSPMDFullToShardShape) {
  const char* const text = R"(
HloModule module

ENTRY entry {
 %parameter.3379 = f32[1,1]{1,0} parameter(0)
 %custom-call.3380 = f32[1,1]{1,0} custom-call(f32[1,1]{1,0} %parameter.3379),
   custom_call_target="SPMDFullToShardShape", sharding={replicated}
 ROOT %reshape.6032 = f32[] reshape(f32[1,1]{1,0} %custom-call.3380)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(text));

  ShardingRemover pass;
  TF_ASSERT_OK_AND_ASSIGN(bool pass_changed, pass.Run(m.get()));
  EXPECT_TRUE(pass_changed);

  HloComputation* entry = m->entry_computation();
  EXPECT_THAT(entry->root_instruction(), op::Reshape(op::Parameter()));
}
83 
// A custom-call whose target is not one of the sharding markers (here "TopK")
// must be left alone: the pass reports no change, and the graph still routes
// both tuple elements through the TopK custom-call. A sharding={replicated}
// attribute on an ordinary instruction (%get-tuple-element.1) is expected not
// to trigger the pass either.
TEST_F(ShardingRemoverTest, NoChangeForOtherCustomCall) {
  // The ENTRY signature declares the same parameter type as the
  // %arg_tuple.1 parameter instruction below.
  const char* const hlo_string = R"(
HloModule cluster_2013453984438090939__.47

ENTRY %cluster_2013453984438090939__.47
  (arg_tuple.1: bf16[2,209664]) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
    custom-call(bf16[2,209664]{1,0} %arg_tuple.1), custom_call_target="TopK"
  %get-tuple-element = bf16[2,2000]{1,0}
    get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call),
    index=0
  %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0},
    s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated}
  ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
    tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0}
    %get-tuple-element.1),
    metadata={op_name="XLA_Retvals"}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingRemover().Run(module.get()));
  EXPECT_FALSE(changed);

  // Beyond the returned flag, confirm the graph itself is intact: both tuple
  // elements are still produced by get-tuple-elements of the custom-call.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::GetTupleElement(op::CustomCall(), 0),
                        op::GetTupleElement(op::CustomCall(), 1)));
}
108 
109 }  // namespace
110 }  // namespace xla
111