1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18
19 #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
20 #include "tensorflow/compiler/xla/tests/filecheck.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace xla {
24 namespace gpu {
25 namespace {
26
// Test fixture for the cases in this file, which compile HLO scatter modules
// and check the atomic instructions present (or absent) in the generated
// LLVM IR. It adds no members of its own on top of GpuCodegenTest.
class GpuAtomicTest : public GpuCodegenTest {};
28
// A scatter whose combiner simply returns the update value (`ROOT rhs`) and
// that does not declare unique indices is expected to produce an unordered
// atomic store in the generated IR.
TEST_F(GpuAtomicTest, TestStore) {
  // Scatter of s32[2,3] updates into an s32[3,3] operand; `update_s32`
  // discards the existing element and keeps the incoming one.
  const char* const scatter_hlo = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
)";

  // FileCheck pattern: the IR must contain an atomic store.
  const char* const expected_ir = R"(
CHECK: store atomic{{.*}}unordered, align 4
)";

  CompileAndVerifyIr(scatter_hlo, expected_ir);
}
55
// Same overwrite-style scatter as TestStore, but the HLO sets
// `unique_indices=true`; the generated IR must then contain no unordered
// atomic store.
TEST_F(GpuAtomicTest, TestStoreNoAtomic) {
  // The only difference from the TestStore module is the trailing
  // `unique_indices=true` attribute on the scatter.
  const char* const scatter_hlo = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}

ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, unique_indices=true
}
)";

  // FileCheck pattern: CHECK-NOT fails the test if the atomic store appears.
  const char* const forbidden_ir = R"(
CHECK-NOT: store atomic{{.*}}unordered, align 4
)";

  CompileAndVerifyIr(scatter_hlo, forbidden_ir);
}
82
// An f32 scatter with an `add` combiner and `unique_indices=false` is expected
// to produce an `atomicrmw fadd` on float in the generated IR. The exact
// textual form that is matched depends on `is_built_with_rocm_`.
TEST_F(GpuAtomicTest, TestAddAtomicF32) {
  // Scatter-add of f32[2,3] updates into an f32[3,3] operand.
  const char* const scatter_hlo = R"(
HloModule TensorFlowScatterV1

update_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
operand = f32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = f32[2,3] parameter(2)
ROOT scatter = f32[3,3] scatter(operand, indices, updates),
to_apply=update_f32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, unique_indices=false
}
)";

  // Pick the FileCheck pattern for the current build flavour.
  const char* expected_ir = nullptr;
  if (is_built_with_rocm_) {
    expected_ir = R"(
CHECK: atomicrmw fadd float addrspace{{.*}}, float {{.*}} seq_cst, align 4
)";
  } else {
    expected_ir = R"(
CHECK: atomicrmw fadd ptr %[[ADDR:.*]], float %[[VALUE:.*]] seq_cst
)";
  }

  CompileAndVerifyIr(scatter_hlo, expected_ir);
}
113
// An f64 scatter with an `add` combiner and `unique_indices=false` is expected
// to produce an `atomicrmw fadd` on double in the generated IR.
//
// The check only runs when the default device reports a CUDA compute
// capability of at least 6; otherwise the test is reported as skipped rather
// than silently passing without verifying anything.
TEST_F(GpuAtomicTest, TestAddAtomicF64) {
  // Atomic add on f64 requires sm_60 or above.
  const bool has_f64_atomic_add = backend()
                                      .default_stream_executor()
                                      ->GetDeviceDescription()
                                      .cuda_compute_capability()
                                      .IsAtLeast(6);
  if (!has_f64_atomic_add) {
    GTEST_SKIP() << "f64 atomic add requires CUDA compute capability >= 6.0";
  }

  // Scatter-add of f64[2,3] updates into an f64[3,3] operand.
  const char* hlo_string = R"(
HloModule TensorFlowScatterV1

update_f64 (lhs: f64[], rhs: f64[]) -> f64[] {
lhs = f64[] parameter(0)
rhs = f64[] parameter(1)
ROOT add = f64[] add(lhs, rhs)
}

ENTRY main {
operand = f64[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = f64[2,3] parameter(2)
ROOT scatter = f64[3,3] scatter(operand, indices, updates),
to_apply=update_f64,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, unique_indices=false
}
)";

  // FileCheck pattern: the IR must contain an atomic floating-point add on
  // a double value.
  CompileAndVerifyIr(hlo_string, R"(
CHECK: atomicrmw fadd ptr %[[ADDR:.*]], double %[[VALUE:.*]] seq_cst
)");
}
150
151 } // namespace
152 } // namespace gpu
153 } // namespace xla
154