1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/xla/client/xla_builder.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/dynamic_annotations.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/platform/test.h"
36
37 namespace {
R0F32Add2(float * out,float ** in)38 void R0F32Add2(float* out, float** in) {
39 TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*));
40 *out = **in + 2.0f;
41 }
42
R2F32ReduceSum(float * out,float ** in)43 void R2F32ReduceSum(float* out, float** in) {
44 TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
45 float* array = in[0];
46 *out = array[0] + array[1] + array[2] + array[3];
47 }
48
Add1ToValues(float * out,float ** in)49 void Add1ToValues(float* out, float** in) {
50 TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
51 float* array = in[0];
52 out[0] = array[0] + 1;
53 out[1] = array[1] + 1;
54 out[2] = array[2] + 1;
55 out[3] = array[3] + 1;
56 }
57
F32TupleSwap(float ** out,float ** in)58 void F32TupleSwap(float** out, float** in) {
59 TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float));
60 TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float));
61 *out[0] = *in[1];
62 *out[1] = *in[0];
63 }
64
65 } // namespace
66
// Register the functions defined above as CPU custom-call targets. The tests
// in this file refer to them by passing the function name as the custom-call
// target string.
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
71
72 namespace xla {
73 namespace {
74
// Fixture for the custom-call tests that build or parse HLO directly. It
// provides the shapes those tests share.
class CustomCallTest : public HloTestBase {
 protected:
  // f32 scalar shape (no dimensions).
  Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
  // f32 shape with dimensions [2, 2].
  Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
};
80
// Feeds the scalar constant 42 through the "R0F32Add2" custom call and expects
// 44 (within error_spec_) as the result.
XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder computation(TestName());

  HloInstruction* operand = computation.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  std::unique_ptr<HloInstruction> call =
      HloInstruction::CreateCustomCall(r0f32_, {operand}, "R0F32Add2");
  computation.AddInstruction(std::move(call));

  module->AddEntryComputation(computation.Build());

  Literal actual = ExecuteAndTransfer(std::move(module), {});
  LiteralTestUtil::ExpectR0Near<float>(44.0f, actual, error_spec_);
}
95
// Passes the constant [[1, 2], [3, 4]] to the "R2F32ReduceSum" custom call and
// expects the scalar 10 (within error_spec_).
XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder computation(TestName());

  const Array2D<float> values{{1.0f, 2.0f}, {3.0f, 4.0f}};
  HloInstruction* operand = computation.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(values)));
  std::unique_ptr<HloInstruction> call =
      HloInstruction::CreateCustomCall(r0f32_, {operand}, "R2F32ReduceSum");
  computation.AddInstruction(std::move(call));

  module->AddEntryComputation(computation.Build());

  Literal actual = ExecuteAndTransfer(std::move(module), {});
  LiteralTestUtil::ExpectR0Near<float>(10.0f, actual, error_spec_);
}
116
// Chains two "Add1ToValues" custom calls and concatenates both of their
// results, so each custom-call output is consumed by another instruction.
XLA_TEST_F(CustomCallTest, UsedInOtherComputations) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder computation(TestName());

  const Array2D<float> values{{1.0f, 2.0f}, {3.0f, 4.0f}};
  HloInstruction* input = computation.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(values)));

  // Both custom calls declare a [1, 2, 2] result so the two results can be
  // stacked along the leading dimension below.
  const Shape slab_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  HloInstruction* plus_one = computation.AddInstruction(
      HloInstruction::CreateCustomCall(slab_shape, {input}, "Add1ToValues"));
  HloInstruction* plus_two = computation.AddInstruction(
      HloInstruction::CreateCustomCall(slab_shape, {plus_one}, "Add1ToValues"));

  // Concatenate the two results along dimension 0.
  const Shape stacked_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  computation.AddInstruction(HloInstruction::CreateConcatenate(
      stacked_shape, {plus_one, plus_two}, 0));

  module->AddEntryComputation(computation.Build());

  const Array3D<float> expected{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}};
  Literal actual = ExecuteAndTransfer(std::move(module), {});
  LiteralTestUtil::ExpectR3EqualArray3D<float>(expected, actual);
}
140
// Runs "Add1ToValues" on a parameter whose forced layout differs from the
// layout forced on the computation result.
XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder computation(TestName());

  HloInstruction* param = computation.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32_, "p"));
  std::unique_ptr<HloInstruction> call =
      HloInstruction::CreateCustomCall(r2f32_, {param}, "Add1ToValues");
  computation.AddInstruction(std::move(call));

  module->AddEntryComputation(computation.Build());

  const Layout param_layout = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(module.get(), 0, param_layout);
  const Layout result_layout = LayoutUtil::MakeLayout({0, 1});
  ForceResultLayout(module.get(), result_layout);

  Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});

  // The expected value is the transpose of "input + 1": the custom call's
  // operand and result layouts differ, and the target simply increments the
  // elements in buffer order without regard to layout.
  Literal actual = ExecuteAndTransfer(std::move(module), {&argument});
  LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, actual);
}
162
// The computation's parameter and result are forced to different layouts, but
// the custom call itself is constrained to one fixed operand and result
// layout. The values must therefore come out untransposed.
XLA_TEST_F(CustomCallTest, LayoutConstrained) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder computation(TestName());

  HloInstruction* param = computation.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32_, "p"));

  // Shape carrying the explicit {1, 0} layout that both the operand and the
  // result of the custom call are constrained to.
  const Shape r2f32_dim0_major =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  HloInstruction* first_call =
      computation.AddInstruction(HloInstruction::CreateCustomCall(
          r2f32_dim0_major, {param}, "Add1ToValues", {r2f32_dim0_major}));
  // Apply the same constrained custom call a second time, to its own output.
  computation.AddInstruction(
      first_call->CloneWithNewOperands(r2f32_dim0_major, {first_call}));

  module->AddEntryComputation(computation.Build());

  const Layout param_layout = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(module.get(), 0, param_layout);
  const Layout result_layout = LayoutUtil::MakeLayout({0, 1});
  ForceResultLayout(module.get(), result_layout);

  Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});

  Literal actual = ExecuteAndTransfer(std::move(module), {&argument});
  LiteralTestUtil::ExpectR2Equal<float>({{3.f, 4.f}, {5.f, 6.f}}, actual);
}
189
// Parses an HLO module whose root is a tuple-shaped custom call to
// "F32TupleSwap" and checks that the two scalar arguments come back swapped.
XLA_TEST_F(CustomCallTest, TupleOutput) {
  const char* const kHloText = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  Literal first = LiteralUtil::CreateR0<float>(7.f);
  Literal second = LiteralUtil::CreateR0<float>(42.f);

  // The expected tuple lists the arguments in reverse order.
  Literal expected = LiteralUtil::MakeTuple({&second, &first});
  Literal result = ExecuteAndTransfer(std::move(module), {&first, &second});
  EXPECT_EQ(result, expected);
}
209
// Fixture for custom-call tests that build their computation through the
// XlaBuilder client API instead of constructing HLO directly.
class CustomCallClientAPITest : public ClientLibraryTestBase {};
211
212 // When using the client API, CustomCall targets can't begin with '$' -- these
213 // are reserved for internal use.
XLA_TEST_F(CustomCallClientAPITest,IllegalCustomCallTarget)214 XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) {
215 XlaBuilder builder(TestName());
216 CustomCall(&builder, "$illegal", /*operands=*/{},
217 ShapeUtil::MakeShape(F32, {1}));
218
219 StatusOr<std::unique_ptr<GlobalData>> result =
220 Execute(&builder, /*arguments=*/{});
221 EXPECT_FALSE(result.ok());
222 }
223
224 } // namespace
225 } // namespace xla
226