1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/cc/client/client_session.h"
22 #include "tensorflow/cc/framework/ops.h"
23 #include "tensorflow/cc/framework/scope.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/tf2xla/literal_util.h"
26 #include "tensorflow/compiler/tf2xla/shape_util.h"
27 #include "tensorflow/compiler/xla/client/client_library.h"
28 #include "tensorflow/compiler/xla/client/local_client.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/client/xla_computation.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/service/platform_util.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
37 #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
38 #include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
39 #include "tensorflow/compiler/xrt/xrt.pb.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/lib/core/status_test_util.h"
43 #include "tensorflow/core/lib/gtl/array_slice.h"
44 #include "tensorflow/core/platform/types.h"
45 #include "tensorflow/core/util/command_line_flags.h"
46
47 namespace tensorflow {
48 namespace {
49
// Flag-backed test configuration, e.g. "XLA_CPU"/"XLA_GPU" for the device
// and the matching XLA platform name. Both pointers are assigned in main()
// before any test runs; DeviceFromFlag()/XlaCompiledProgramShape() read them.
string* xla_test_device_ptr;  // initial value set in main()
string* xla_platform_ptr;     // initial value set in main()
52
DeviceFromFlag()53 string DeviceFromFlag() {
54 string xla_test_device = *xla_test_device_ptr;
55 return absl::StrCat("/device:", xla_test_device, ":0");
56 }
57
GetAttrLayout(absl::Span<const int64> minor_to_mayor)58 std::vector<int> GetAttrLayout(absl::Span<const int64> minor_to_mayor) {
59 std::vector<int> layout;
60 for (auto dim : minor_to_mayor) {
61 layout.push_back(static_cast<int>(dim));
62 }
63 return layout;
64 }
65
TwoElementTuple()66 xla::LiteralProto TwoElementTuple() {
67 auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
68 auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
69 auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
70 return tuple.ToProto();
71 }
72
ScalarLiteral()73 xla::LiteralProto ScalarLiteral() {
74 auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
75 return scalar.ToProto();
76 }
77
NestedTuple()78 xla::LiteralProto NestedTuple() {
79 auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
80 auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
81 auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
82 auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
83 auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
84 return nested.ToProto();
85 }
86
MakeTuple0()87 xla::LiteralProto MakeTuple0() {
88 auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
89 auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
90 auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
91 auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
92 auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
93 auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
94 return nested1.ToProto();
95 }
96
FloatVector(absl::Span<const float> v)97 xla::LiteralProto FloatVector(absl::Span<const float> v) {
98 auto array = xla::LiteralUtil::CreateR1<float>(v);
99 return array.ToProto();
100 }
101
// Builds the proto for an f32 rank-2 literal holding `v`, stored with the
// given explicit layout.
xla::LiteralProto FloatMatrix(
    std::initializer_list<std::initializer_list<float>> v,
    const xla::Layout& layout) {
  return xla::LiteralUtil::CreateR2WithLayout<float>(v, layout).ToProto();
}
108
ReadOutputLiteral(const std::vector<Tensor> & outputs,size_t idx)109 xla::Literal ReadOutputLiteral(const std::vector<Tensor>& outputs, size_t idx) {
110 xla::LiteralProto response;
111 CHECK(response.ParseFromString(outputs[idx].scalar<string>()()));
112 return xla::Literal::CreateFromProto(response).ValueOrDie();
113 }
114
CompareLiteralProtos(const xla::LiteralProto & a,const xla::LiteralProto & b)115 bool CompareLiteralProtos(const xla::LiteralProto& a,
116 const xla::LiteralProto& b) {
117 auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
118 auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
119 bool equal = l_a == l_b;
120 if (!equal) {
121 LOG(INFO) << "LiteralProtos don't match:\n"
122 << a.DebugString() << "\n!=\n"
123 << b.DebugString();
124 }
125 return equal;
126 }
127
CompareLiteralToLiteralProto(const xla::Literal & a,const xla::LiteralProto & b)128 bool CompareLiteralToLiteralProto(const xla::Literal& a,
129 const xla::LiteralProto& b) {
130 auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
131 bool equal = a == l_b;
132 if (!equal) {
133 LOG(INFO) << "Literal and LiteralProto don't match:\n"
134 << a.ToProto().DebugString() << "\n!=\n"
135 << b.DebugString();
136 }
137 return equal;
138 }
139
CompareLiterals(const xla::Literal & a,const xla::Literal & b)140 bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) {
141 bool equal = a == b;
142 if (!equal) {
143 LOG(INFO) << "Literals don't match:\n"
144 << a.ToProto().DebugString() << "\n!=\n"
145 << b.ToProto().DebugString();
146 }
147 return equal;
148 }
149
OnePlusTwo()150 xla::XlaComputation OnePlusTwo() {
151 xla::XlaBuilder builder("OnePlusTwo");
152 auto c0 = xla::ConstantR0(&builder, 1.0f);
153 auto c1 = xla::ConstantR0(&builder, 2.0f);
154 xla::Add(c0, c1);
155 return builder.Build().ValueOrDie();
156 }
157
AddAndScale()158 xla::XlaComputation AddAndScale() {
159 xla::XlaBuilder builder("AddAndScale");
160 auto p0 = xla::Parameter(&builder, 0,
161 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
162 auto p1 = xla::Parameter(&builder, 1,
163 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
164 auto sum = xla::Add(p0, p1);
165 auto c = xla::ConstantR0<float>(&builder, 3.0f);
166 xla::Mul(sum, c);
167 return builder.Build().ValueOrDie();
168 }
169
Dot()170 xla::XlaComputation Dot() {
171 xla::XlaBuilder builder("Dot");
172 auto p0 = xla::Parameter(
173 &builder, 0,
174 xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
175 auto p1 = xla::Parameter(
176 &builder, 1,
177 xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
178 xla::DotDimensionNumbers ddn;
179 ddn.add_lhs_contracting_dimensions(1);
180 ddn.add_rhs_contracting_dimensions(0);
181 xla::DotGeneral(p0, p1, ddn);
182 return builder.Build().ValueOrDie();
183 }
184
AddS64()185 xla::XlaComputation AddS64() {
186 xla::XlaBuilder builder("AddS64");
187 auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
188 "P0");
189 auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}),
190 "P1");
191 xla::Add(p0, p1);
192 return builder.Build().ValueOrDie();
193 }
194
AddAndTuple()195 xla::XlaComputation AddAndTuple() {
196 xla::XlaBuilder builder("AddAndTuple");
197 auto p0 = xla::Parameter(&builder, 0,
198 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
199 auto p1 = xla::Parameter(&builder, 1,
200 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
201 auto sum = xla::Add(p0, p1);
202 xla::Tuple(&builder, {sum});
203 return builder.Build().ValueOrDie();
204 }
205
AddAndSubTuple()206 xla::XlaComputation AddAndSubTuple() {
207 xla::XlaBuilder builder("AddAndSubTuple");
208 auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}),
209 "P0");
210 auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}),
211 "P1");
212 auto sum = xla::Add(p0, p1);
213 auto sub = xla::Sub(p0, p1);
214 xla::Tuple(&builder, {sum, sub});
215 return builder.Build().ValueOrDie();
216 }
217
StoreComputationSnapshot(const xla::XlaComputation & computation,xla::HloSnapshot * dst)218 void StoreComputationSnapshot(const xla::XlaComputation& computation,
219 xla::HloSnapshot* dst) {
220 auto snapshot = computation.Snapshot().ValueOrDie();
221 *dst = *snapshot;
222 }
223
// Compiles `computation` with the local XLA client on the platform named by
// the --xla_platform flag, constraining the result layout to the one in
// `input_program_shape`, and returns the program shape of the compiled
// executable's entry computation (i.e. the layouts the compiler settled on).
xla::ProgramShape XlaCompiledProgramShape(
    const xla::XlaComputation& computation,
    const xla::ProgramShape& input_program_shape) {
  se::Platform* platform =
      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
  xla::LocalClient* client =
      xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  xla::ExecutableBuildOptions exec_options;
  // Pin the result layout requested by the caller.
  exec_options.set_result_layout(input_program_shape.result());
  // Compile() takes parameter shapes by pointer; collect them from the input
  // program shape.
  std::vector<const xla::Shape*> parameters_shapes;
  for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) {
    parameters_shapes.push_back(&input_program_shape.parameters(i));
  }
  auto local_executable =
      client->Compile(computation, parameters_shapes, exec_options)
          .ValueOrDie();
  return local_executable->executable()
      ->module()
      .entry_computation()
      ->ComputeProgramShape();
}
245
// Round-trips a host tensor through XRTAllocateFromTensor and
// XRTReadLiteralAndRelease, verifying the literal read back from the device
// matches the original.
TEST(RawApiTest, AllocFromTensor) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Allocate with the same layout the literal already carries, so the
  // read-back should be byte-identical.
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
271
// Allocates a two-element tuple from two separate host tensors and verifies
// the tuple literal read back from the device.
TEST(RawApiTest, AllocFromTensorTuple) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal1 =
      xla::LiteralUtil::CreateR2<float>({{14.0f, -5.0f}, {16.0f, 17.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));
  Tensor tensor1;
  TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Layout attribute covers the flattened tuple elements.
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1},
                                           {tensor0.shape(), tensor1.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
302
// Same as AllocFromTensorTuple but with a single-element tuple: the
// MakeTuple(true) attribute forces tuple wrapping of the lone input tensor.
TEST(RawApiTest, AllocFromTensorTupleSingle) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
328
// Allocates a row-major host tensor while declaring a {0,1} (transposed)
// layout attribute; the device then interprets the same bytes with swapped
// dimensions, so the read-back literal is the transpose of the input.
TEST(RawApiTest, AllocFromTensorRelayout) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Use inverse array layout with the tensor data above.
  std::vector<int> layout({0, 1});
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  // We have sent literal's data (in array layout) with a attribute layout
  // {0,1}, so the expected literal read from device needs to be changed
  // accordingly.
  xla::Literal expected_literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 6.0f}, {5.0f, 7.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response));
}
359
// Allocates a device literal, verifies a read-back, overwrites the same
// allocation in place with XRTWriteLiteral, verifies the new contents, and
// finally releases the handle explicitly.
TEST(RawApiTest, AllocAndRewrite) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  // Second fetch is the numeric allocation handle, reused for the writes and
  // reads below.
  int64 allocation_handle = outputs[1].scalar<int64>()();
  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
  outputs.clear();

  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op =
      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
  TF_ASSERT_OK(root.status());
  TF_EXPECT_OK(session.Run({write_op}, &outputs));
  EXPECT_EQ(outputs.size(), 1);
  // XRTWriteLiteral returns the handle it wrote into.
  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());
  outputs.clear();

  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
  TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto new_response;
  EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));

  Tensor release_tensor(DT_INT64, TensorShape({1}));
  release_tensor.flat<int64>()(0) = allocation_handle;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
                           &outputs));
}
410
// Verifies XRTReleaseAllocationHandle accepts a vector of handles and
// releases several allocations in one op invocation.
TEST(RawApiTest, AllocReleaseMany) {
  xrt::XLAAllocation alloc1;
  *alloc1.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
  xrt::XLAAllocation alloc2;
  *alloc2.mutable_value() =
      xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString());
  auto value2 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString());
  auto handle1 = ops::XRTAllocate(root, value1);
  auto handle2 = ops::XRTAllocate(root, value2);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 allocation_handle1 = outputs[0].scalar<int64>()();
  int64 allocation_handle2 = outputs[1].scalar<int64>()();

  // Pack both handles into a single rank-1 tensor for the release op.
  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = allocation_handle1;
  release_tensor.flat<int64>()(1) = allocation_handle2;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
                           &outputs));
}
445
// Compiles two different computations and verifies that
// XRTReleaseCompilationHandle can release both compilation handles in a
// single op invocation.
TEST(RawApiTest, CompileAndReleaseMany) {
  xrt::XLAComputation c1;
  auto config1 = c1.mutable_config();
  auto shapes1 = config1->mutable_program_shape();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot());

  // Second computation returns a tuple, exercising a distinct result shape.
  xrt::XLAComputation c2;
  auto config2 = c2.mutable_config();
  auto shapes2 = config2->mutable_program_shape();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot());

  // Keep compilation handles alive so we can release them explicitly below.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(false);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation1 =
      ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString());
  auto c_handle1 = ops::XRTCompile(root, computation1);
  auto computation2 =
      ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString());
  auto c_handle2 = ops::XRTCompile(root, computation2);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 compilation_handle1 = outputs[0].scalar<int64>()();
  int64 compilation_handle2 = outputs[1].scalar<int64>()();

  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = compilation_handle1;
  release_tensor.flat<int64>()(1) = compilation_handle2;

  auto release = ops::XRTReleaseCompilationHandle(root, release_tensor);
  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
                           &outputs));
}
502
// Verifies XRTReleaseAllAllocations drops every live allocation: a read
// through a previously valid handle must then fail with NOT_FOUND.
TEST(RawApiTest, AllocAndClearAll) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  int64 allocation_handle = outputs[0].scalar<int64>()();

  auto clear_all = ops::XRTReleaseAllAllocations(root);

  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {},
                           {clear_all}, &outputs));
  EXPECT_EQ(outputs.size(), 0);

  // The handle was invalidated by the clear-all above.
  auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle));
  EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(),
            tensorflow::error::Code::NOT_FOUND);
}
532
// Allocates a tuple literal, reads it back, and releases the handle via an
// explicit XRTReleaseAllocationHandle sequenced after the read through a
// control dependency.
TEST(RawApiTest, ReadAndWriteState) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  // Ensure the release only runs after the read has consumed the handle.
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), handle);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back},
                           {release}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
556
// Same round-trip as ReadAndWriteState, but using
// XRTReadLiteralAndRelease so the allocation is freed as part of the read.
TEST(RawApiTest, ReadAndWriteStateAutoFree) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
576
// Allocates a nested tuple and extracts sub-buffers with XRTSubTuple at
// indices {0}, {1} and the nested {0,0}, checking each against the
// corresponding element of the original literal.
TEST(RawApiTest, SubBuffer) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = NestedTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto base_handle = ops::XRTAllocate(root, value);
  auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
  auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
  auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
  auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
  auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
  // The AndRelease variant frees base_handle, so it must run after the two
  // plain sub-tuple extractions — hence the control dependencies.
  auto sub_00 = ops::XRTSubTupleAndRelease(
      root.WithControlDependencies(
          {sub_0.output_handle.op(), sub_1.output_handle.op()}),
      base_handle, index_00);
  auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
  auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
  auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));

  // Decompose the original literal to obtain the expected sub-elements.
  auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
  auto base_elements = base_literal.DecomposeTuple();
  auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
  xla::LiteralProto response_0;
  EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
  xla::LiteralProto response_00;
  EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
}
616
// Exercises XRTMakeTuple with three tuple descriptors: a trivial forwarding
// node, a nested composition of two inputs, and a re-composition that
// releases its inputs. Checks both resulting tuples against the literals
// built by MakeTuple0()/NestedTuple().
TEST(RawApiTest, MakeTuple) {
  xrt::XLAAllocation alloc_0;
  *alloc_0.mutable_value() = TwoElementTuple();
  xrt::XLAAllocation alloc_1;
  *alloc_1.mutable_value() = ScalarLiteral();

  // The trivial tuple that just forwards its input and releases it.
  xrt::XLATupleNode desc_0;
  desc_0.set_input_index(0);
  desc_0.set_release_input_handle(true);

  // desc_1 builds (input0, (input0, input1)) from two inputs.
  xrt::XLATupleNode desc_1;
  auto subdesc_10 = desc_1.add_tuples();
  auto subdesc_11 = desc_1.add_tuples();
  subdesc_10->set_input_index(0);
  auto subdesc_110 = subdesc_11->add_tuples();
  subdesc_110->set_input_index(0);
  auto subdesc_111 = subdesc_11->add_tuples();
  subdesc_111->set_input_index(1);

  // desc_2 builds (input1, input0), releasing both inputs.
  xrt::XLATupleNode desc_2;
  auto subdesc_20 = desc_2.add_tuples();
  auto subdesc_21 = desc_2.add_tuples();
  subdesc_20->set_input_index(1);
  subdesc_20->set_release_input_handle(true);
  subdesc_21->set_input_index(0);
  subdesc_21->set_release_input_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
  auto handle_0 = ops::XRTAllocate(root, value_0);
  auto value_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
  auto handle_1 = ops::XRTAllocate(root, value_1);
  auto tuple_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
  auto handle_2 =
      ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
  // handle_0 has now been released.
  auto tuple_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
  auto handle_3 = ops::XRTMakeTuple(
      root, tuple_1,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  auto tuple_2 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
  // Make sure this runs after handle_3 has completed, since it will free
  // handle_1 and handle_2.
  auto handle_4 = ops::XRTMakeTuple(
      root.WithControlDependencies(handle_3), tuple_2,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  // handle_1 and handle_2 have now been released.

  auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
  auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
  xla::LiteralProto response_0;
  EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
  xla::LiteralProto response_1;
  EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));

  auto expected_0 = MakeTuple0();
  EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
  auto expected_1 = NestedTuple();
  EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
}
688
// End-to-end compile/execute: compiles AddAndScale ((P0+P1)*3), runs it on
// two allocated f32[2] inputs, and checks both the numeric result and the
// program shape reported by XRTCompile.
TEST(RawApiTest, CompileAndExecute) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  // Free both input handles and the compilation handle once executed.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // (1+8)*3 = 27, (2+5)*3 = 21.
  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
741
// Same computation as CompileAndExecute, but the two input handles are
// stacked into a single rank-1 tensor argument to XRTExecute, exercising the
// packed-arguments input form.
TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  // Pack both handles into one tensor and pass that as the sole argument.
  auto packed_args = ops::Stack(root.WithDevice("/device:CPU:0"),
                                {Output(p0_handle), Output(p1_handle)});
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // (1+8)*3 = 27, (2+5)*3 = 21.
  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
796
// Compiles a convolution whose parameter and result shapes carry no layouts,
// and checks that the layouts picked by the XLA compiler are reported back
// through XRTCompile's program_shape output, matching a direct XLA compile.
TEST(RawApiTest, CompileWithXlaReturnShapes) {
  xla::XlaBuilder builder("XrtXlaShapes");
  auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
  auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
  // Clear layouts to signal XLA we are ready to get whatever are coming out of
  // the compilation process.
  xla::LayoutUtil::ClearLayout(&input_shape);
  xla::LayoutUtil::ClearLayout(&kernel_shape);
  auto param_shape =
      xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
  auto param = xla::Parameter(&builder, 0, param_shape, "param");
  auto input = xla::GetTupleElement(param, 0);
  auto kernel = xla::GetTupleElement(param, 1);
  xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());

  auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
  // Clear the result shape layout to tell XLA we are accepting whatever are
  // coming out of the compilation process.
  xla::LayoutUtil::ClearLayout(&result_shape);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = param_shape.ToProto();
  *shapes->mutable_result() = result_shape.ToProto();
  StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fetch only the program shape; run the release op as a target so the
  // compilation handle does not leak.
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
                           {c_handle.program_shape}, {release}, &outputs));

  xla::ProgramShapeProto program_shape_proto;
  EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec<string>()(0)));
  xla::ProgramShape program_shape(program_shape_proto);
  EXPECT_EQ(program_shape.parameters_size(), 1);

  VLOG(2) << "Param: "
          << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
  VLOG(2) << "Result: "
          << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());

  // Compare the layouts XRT reported against the ones obtained by compiling
  // the same computation directly through XLA.
  xla::ProgramShape xla_program_shape =
      XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
  EXPECT_TRUE(xla::LayoutUtil::Equal(
      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
          .layout()));
  EXPECT_TRUE(xla::LayoutUtil::Equal(
      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
          .layout()));
  EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(),
                                     xla_program_shape.result().layout()));
}
860
// Runs the Dot() computation with an explicit {0, 1} (minor-to-major) layout
// on both inputs and the result, checking the layout is honored end to end.
TEST(RawApiTest, DotGeneralWithLayoutTest) {
  auto layout = xla::LayoutUtil::MakeLayout({0, 1});

  // 2x2 matrix and a 2x1 matrix, both stored with the {0, 1} layout.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);

  // Program shape: (f32[2,2], f32[2,1]) -> f32[2,1], all with {0, 1} layout.
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());

  // Free the inputs and the compiled program after execution.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // {{1,2},{3,4}} x {{8},{5}} = {{18},{44}}, in the same {0, 1} layout.
  auto expected =
      xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);

  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
913
// Compiles and runs a computation that takes no arguments and returns an f32
// scalar, exercising XRTExecute with an empty input-handle list.
TEST(RawApiTest, CompileAndExecuteZeroArg) {
  // Zero-parameter program shape with a scalar f32 result; the HLO comes
  // from the OnePlusTwo() helper.
  xrt::XLAComputation computation_proto;
  auto* shapes = computation_proto.mutable_config()->mutable_program_shape();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  StoreComputationSnapshot(OnePlusTwo(), computation_proto.mutable_hlo_snapshot());

  // Release the (empty) inputs and the compiled program after running.
  xrt::XRTExecutionConfig exec_config_proto;
  exec_config_proto.set_release_input_handles(true);
  exec_config_proto.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto exec_config = ops::Const(root.WithDevice("/device:CPU:0"),
                                exec_config_proto.SerializeAsString());
  auto computation = ops::Const(root.WithDevice("/device:CPU:0"),
                                computation_proto.SerializeAsString());
  auto compile_handle = ops::XRTCompile(root, computation);
  // Execute with an explicitly empty argument list.
  auto run_op = ops::XRTExecute(root, compile_handle.handle, exec_config,
                                std::initializer_list<Input>({}));
  auto literal_out = ops::XRTReadLiteralAndRelease(root, run_op);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(session.Run({literal_out}, &fetched));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(fetched[0].scalar<string>()()));

  // OnePlusTwo() yields the scalar 3.0.
  auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
946
// Executes a computation whose result is a one-element tuple and verifies the
// tuple literal is read back intact.
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  // Program shape: (f32[2], f32[2]) -> (f32[2]) — note the tuple result;
  // HLO comes from the AddAndTuple() helper.
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());

  // Free the inputs and the compiled program after execution.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // Elementwise sum {1+8, 2+5} wrapped in a single-element tuple.
  auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
  auto expected = xla::LiteralUtil::MakeTuple({&sum});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
997
// With return_exploded_tuple set, XRTExecute returns a vector of allocation
// handles — one per element of the result tuple — instead of a single handle
// to the whole tuple. Each element is then read back individually.
TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(12.0f).ToProto();

  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<float>(3.0f).ToProto();

  // Program shape: (f32[], f32[]) -> (f32[], f32[]); the HLO comes from the
  // AddAndSubTuple() helper.
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}),
                                      xla::ShapeUtil::MakeShape(xla::F32, {})})
          .ToProto();
  StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  // Ask for the result tuple to be exploded into per-element handles.
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({result}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  // The single output tensor holds one handle per tuple element.
  auto handles_vec = outputs.front().vec<int64>();
  EXPECT_EQ(handles_vec.size(), 2);

  // Element 0 is the sum (12 + 3), element 1 the difference (12 - 3).
  const float kResults[2] = {15.0f, 9.0f};
  for (int64 i = 0; i < handles_vec.size(); ++i) {
    // Extend the graph with a read op for this handle and run it through the
    // same session; the read also releases the element's allocation.
    auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i)));
    std::vector<Tensor> voutputs;
    TF_EXPECT_OK(session.Run({read_back}, &voutputs));
    EXPECT_EQ(voutputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar<string>()()));

    auto expected = xla::LiteralUtil::CreateR0<float>(kResults[i]);
    EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
  }
}
1059
// Compiles a computation and fetches its handle without ever releasing it;
// the compilation reference is intentionally leaked.
TEST(RawApiTest, LeakCompilationReference) {
  // (f32[2], f32[2]) -> (f32[2]) program shape with AddAndTuple() HLO.
  xrt::XLAComputation computation_proto;
  auto* shapes = computation_proto.mutable_config()->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), computation_proto.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation = ops::Const(root.WithDevice("/device:CPU:0"),
                                computation_proto.SerializeAsString());
  auto compile_handle = ops::XRTCompile(root, computation);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> fetched;
  // No XRTReleaseCompilationHandle is ever run for this handle.
  TF_EXPECT_OK(session.Run({compile_handle.handle}, &fetched));
}
1083
// Tests input/output buffer aliasing: the computation's first two outputs are
// aliased (in flipped order) onto the input tuple's buffers, so after the run
// the original input allocation holds the computation results.
TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  xla::Shape shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});
  // Computation: param (p0, p1) -> (p0+p1, p0-p1, p0, p1).
  xla::XlaBuilder builder("ReuseBuffer");
  auto param = xla::Parameter(&builder, 0, shape, "param");
  auto p0 = xla::GetTupleElement(param, 0);
  auto p1 = xla::GetTupleElement(param, 1);
  auto add = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  xla::Tuple(&builder, {add, sub, p0, p1});

  // Flip the tuple literals in the input handle.
  builder.SetUpAlias({1}, 0, {0});
  builder.SetUpAlias({0}, 0, {1});

  auto computation = builder.Build().ValueOrDie();

  auto literal0 = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  auto literal1 = xla::LiteralUtil::CreateR1<float>({5.0f, 9.0f});
  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});

  xrt::XLAAllocation param_alloc;
  *param_alloc.mutable_value() = literal.ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = shape.ToProto();
  *shapes->mutable_result() = return_shape.ToProto();
  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());

  // Keep the input allocation alive: we read the aliased results out of it
  // after the execution.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  ClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto c_data =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, c_data);
  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                param_alloc.SerializeAsString());
  auto param_handle = ops::XRTAllocate(root, param_value);
  TF_ASSERT_OK(root.status());

  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({param_handle}, &outputs));

  // Materialize the input allocation handle so it can be fed by value to the
  // ops built below.
  int64 alloc_handle = outputs[0].scalar<int64>()();

  // Note that we release the result handle immediately, but since we aliased
  // the output buffers onto the input allocation ones (held in alloc_handle),
  // we can fetch the result from there.
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
  auto read_back = ops::XRTReadLiteral(root, result);
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), result);
  TF_ASSERT_OK(root.status());

  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back},
                           {release}, &outputs));

  // The execution result tuple still carries the pass-through copies of the
  // original inputs in elements 2 and 3.
  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
  auto exec_literal_parts = exec_literal.DecomposeTuple();
  ASSERT_EQ(exec_literal_parts.size(), 4);

  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));

  // Now we read back the original input handle values, which at this point
  // should contain the result of the XLA computation.
  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
  TF_ASSERT_OK(root.status());
  auto release_handle = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_handle), Input(alloc_handle));
  TF_ASSERT_OK(root.status());

  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_handle},
                           {release_handle}, &outputs));

  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);

  auto expected_literal0 = xla::LiteralUtil::CreateR1<float>({6.0f, 11.0f});
  auto expected_literal1 = xla::LiteralUtil::CreateR1<float>({-4.0f, -7.0f});
  // The first element of the computation returned tuple would be the add
  // (expected_literal0), but since we flipped the buffers, the sub
  // (expected_literal1) should come first.
  auto expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});

  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}
1184
// Compiles and runs an S64 scalar addition, checking that 64-bit integer
// arguments and results survive the round trip through XRT.
TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();
  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<int64>(4091934).ToProto();

  // Program shape: (s64[], s64[]) -> s64[]; HLO from the AddS64() helper.
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  // NOTE(review): exploded-tuple mode is enabled although the result is not a
  // tuple — presumably exercising that code path with a non-tuple result.
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // 11031965 + 4091934 = 15123899.
  auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  // The compiled program shape must report two parameters and an S64 result.
  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
  EXPECT_EQ(program_shape.parameters_size(), 2);
  EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
      xla::Shape(program_shape.result()), xla::S64));
}
1237
1238 } // namespace
1239
1240 } // namespace tensorflow
1241
// Test entry point. Parses the --xla_test_device and --xla_platform flags
// (defaulting to XLA_CPU / CPU) before initializing gtest, so the tests above
// can pick the device and platform via DeviceFromFlag() and friends.
int main(int argc, char** argv) {
  tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
  tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
                       "Tensorflow device type to use for test, e.g., XLA_CPU"),
      tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
                       "The XLA platform to select for the device"),
  };
  tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  testing::InitGoogleTest(&argc, argv);
  // Any argv left over after both flag parsers ran is unrecognized.
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  return RUN_ALL_TESTS();
}
1264