• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <utility>
18 
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/xla/client/xla_builder.h"
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/dynamic_annotations.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/platform/test.h"
36 
37 namespace {
R0F32Add2(float * out,float ** in)38 void R0F32Add2(float* out, float** in) {
39   TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*));
40   *out = **in + 2.0f;
41 }
42 
R2F32ReduceSum(float * out,float ** in)43 void R2F32ReduceSum(float* out, float** in) {
44   TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
45   float* array = in[0];
46   *out = array[0] + array[1] + array[2] + array[3];
47 }
48 
Add1ToValues(float * out,float ** in)49 void Add1ToValues(float* out, float** in) {
50   TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
51   float* array = in[0];
52   out[0] = array[0] + 1;
53   out[1] = array[1] + 1;
54   out[2] = array[2] + 1;
55   out[3] = array[3] + 1;
56 }
57 
F32TupleSwap(float ** out,float ** in)58 void F32TupleSwap(float** out, float** in) {
59   TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float));
60   TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float));
61   *out[0] = *in[1];
62   *out[1] = *in[0];
63 }
64 
65 }  // namespace
66 
67 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
68 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
69 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
70 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
71 
72 namespace xla {
73 namespace {
74 
75 class CustomCallTest : public HloTestBase {
76  protected:
77   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
78   Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
79 };
80 
XLA_TEST_F(CustomCallTest,CustomCallR0F32Add2)81 XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
82   auto module = CreateNewVerifiedModule();
83   auto builder = HloComputation::Builder(TestName());
84 
85   auto constant = builder.AddInstruction(
86       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
87   builder.AddInstruction(
88       HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2"));
89 
90   module->AddEntryComputation(builder.Build());
91 
92   Literal result = ExecuteAndTransfer(std::move(module), {});
93   LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
94 }
95 
XLA_TEST_F(CustomCallTest,CustomCallR2F32Reduce)96 XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
97   auto module = CreateNewVerifiedModule();
98   auto builder = HloComputation::Builder(TestName());
99 
100   Array2D<float> array(2, 2);
101   array(0, 0) = 1.0f;
102   array(0, 1) = 2.0f;
103   array(1, 0) = 3.0f;
104   array(1, 1) = 4.0f;
105 
106   auto constant = builder.AddInstruction(
107       HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array)));
108   builder.AddInstruction(
109       HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum"));
110 
111   module->AddEntryComputation(builder.Build());
112 
113   Literal result = ExecuteAndTransfer(std::move(module), {});
114   LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
115 }
116 
XLA_TEST_F(CustomCallTest,UsedInOtherComputations)117 XLA_TEST_F(CustomCallTest, UsedInOtherComputations) {
118   auto module = CreateNewVerifiedModule();
119   auto b = HloComputation::Builder(TestName());
120 
121   auto input = b.AddInstruction(
122       HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(
123           Array2D<float>{{1.0f, 2.0f}, {3.0f, 4.0f}})));
124   auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall(
125       ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues"));
126   auto incremented_again = b.AddInstruction(HloInstruction::CreateCustomCall(
127       ShapeUtil::MakeShape(F32, {1, 2, 2}), {incremented}, "Add1ToValues"));
128 
129   // Concatenate the values along first dim.
130   b.AddInstruction(
131       HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}),
132                                         {incremented, incremented_again}, 0));
133 
134   module->AddEntryComputation(b.Build());
135 
136   Literal result = ExecuteAndTransfer(std::move(module), {});
137   LiteralTestUtil::ExpectR3EqualArray3D<float>(
138       Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
139 }
140 
XLA_TEST_F(CustomCallTest,InputAndOutputLayoutDiffer)141 XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) {
142   auto module = CreateNewVerifiedModule();
143   auto b = HloComputation::Builder(TestName());
144 
145   auto input =
146       b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
147   b.AddInstruction(
148       HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues"));
149 
150   module->AddEntryComputation(b.Build());
151   ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
152   ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
153 
154   Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
155 
156   // Note, the expected result is transposed! This is because the input and
157   // output layouts of the custom call differ and the called function just
158   // blindly adds one to each element.
159   Literal result = ExecuteAndTransfer(std::move(module), {&argument});
160   LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
161 }
162 
XLA_TEST_F(CustomCallTest,LayoutConstrained)163 XLA_TEST_F(CustomCallTest, LayoutConstrained) {
164   // The argument and result of the computation are set to different layouts,
165   // but the custom call is layout constrained to a fixed operand and result
166   // layout, so the correct result should be produced.
167   auto module = CreateNewVerifiedModule();
168   auto b = HloComputation::Builder(TestName());
169 
170   auto input =
171       b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
172 
173   const Shape& r2f32_dim0_major =
174       ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
175   auto custom_call = b.AddInstruction(HloInstruction::CreateCustomCall(
176       r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major}));
177   b.AddInstruction(
178       custom_call->CloneWithNewOperands(r2f32_dim0_major, {custom_call}));
179 
180   module->AddEntryComputation(b.Build());
181   ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
182   ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
183 
184   Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
185 
186   Literal result = ExecuteAndTransfer(std::move(module), {&argument});
187   LiteralTestUtil::ExpectR2Equal<float>({{3.f, 4.f}, {5.f, 6.f}}, result);
188 }
189 
XLA_TEST_F(CustomCallTest,TupleOutput)190 XLA_TEST_F(CustomCallTest, TupleOutput) {
191   const char* kModuleStr = R"(
192     HloModule m
193     test {
194       p0 = f32[] parameter(0)
195       p1 = f32[] parameter(1)
196       ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]}
197     }
198   )";
199   TF_ASSERT_OK_AND_ASSIGN(auto module,
200                           ParseAndReturnVerifiedModule(kModuleStr));
201 
202   Literal arg0 = LiteralUtil::CreateR0<float>(7.f);
203   Literal arg1 = LiteralUtil::CreateR0<float>(42.f);
204 
205   Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0});
206   Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1});
207   EXPECT_EQ(result, expected);
208 }
209 
210 class CustomCallClientAPITest : public ClientLibraryTestBase {};
211 
212 // When using the client API, CustomCall targets can't begin with '$' -- these
213 // are reserved for internal use.
XLA_TEST_F(CustomCallClientAPITest,IllegalCustomCallTarget)214 XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) {
215   XlaBuilder builder(TestName());
216   CustomCall(&builder, "$illegal", /*operands=*/{},
217              ShapeUtil::MakeShape(F32, {1}));
218 
219   StatusOr<std::unique_ptr<GlobalData>> result =
220       Execute(&builder, /*arguments=*/{});
221   EXPECT_FALSE(result.ok());
222 }
223 
224 }  // namespace
225 }  // namespace xla
226