/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace {

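// Builds a computation that adds two F32[4] parameters and then marks
// dimension 0 of the result as dynamic via xla::SetDimensionSize, with the
// runtime size taken from the S32 scalar parameter P2.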
xla::XlaComputation ReturnDynamicR1() {
  xla::XlaBuilder builder("ReturnDynamicR1");
  auto p0 = xla::Parameter(&builder, 0,
                           xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
  auto p1 = xla::Parameter(&builder, 1,
                           xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
  auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
                           "P2");
  auto sum = xla::Add(p0, p1);
  auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
  return builder.Build(pad_sum).ValueOrDie();
}

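// Like ReturnDynamicR1, but for F32[2,4] operands: both result dimensions are
// marked dynamic with the same runtime size P2.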
xla::XlaComputation ReturnDynamicR2() {
  xla::XlaBuilder builder("ReturnDynamicR2");
  auto p0 = xla::Parameter(&builder, 0,
                           xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P0");
  auto p1 = xla::Parameter(&builder, 1,
                           xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P1");
  auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
                           "P2");
  auto sum = xla::Add(p0, p1);
  auto pad_sum_dim0 = xla::SetDimensionSize(sum, p2, 0);
  auto pad_sum_dim1 = xla::SetDimensionSize(pad_sum_dim0, p2, 1);
  return builder.Build(pad_sum_dim1).ValueOrDie();
}

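// Builds a computation whose R1 parameters are themselves declared with a
// dynamic dimension 0, so it can consume dynamically sized inputs directly.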
xla::XlaComputation AcceptDynamicR1() {
  xla::XlaBuilder builder("AcceptDynamicR1");
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
  auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1");
  auto sum = xla::Add(p0, p1);
  return builder.Build(sum).ValueOrDie();
}

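// Negates an F32[2,4] parameter whose dimension 1 is dynamic.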
xla::XlaComputation AcceptDynamicR2() {
  xla::XlaBuilder builder("AcceptDynamicR2");
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
  dyn_shape.set_dynamic_dimension(1, true);
  auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
  auto negate = xla::Neg(p0);
  return builder.Build(negate).ValueOrDie();
}

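// Returns a 3-tuple: the sum with dynamic dimension 0 of size P2, the static
// sum, and the difference with dynamic dimension 0 of size P2 + 1.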
xla::XlaComputation ReturnDynamicR1Tuple() {
  xla::XlaBuilder builder("ReturnDynamicR1Tuple");
  auto p0 = xla::Parameter(&builder, 0,
                           xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
  auto p1 = xla::Parameter(&builder, 1,
                           xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
  auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
                           "P2");
  auto sum = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  auto one = xla::One(&builder, xla::S32);
  auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
  auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0);
  auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub});
  return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie();
}

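// Accepts a 2-tuple of F32[4] operands whose dimension 0 is dynamic and
// returns their element-wise sum.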
xla::XlaComputation AcceptDynamicR1Tuple() {
  xla::XlaBuilder builder("AcceptDynamicR1Tuple");
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  xla::Shape tuple_shape =
      xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
  auto p = xla::Parameter(&builder, 0, tuple_shape, "P0");
  auto p0 = xla::GetTupleElement(p, 0);
  auto p1 = xla::GetTupleElement(p, 1);
  auto sum = xla::Add(p0, p1);
  return builder.Build(sum).ValueOrDie();
}

template <typename T>
xla::LiteralProto CreateR0(T v) {
  auto array = xla::LiteralUtil::CreateR0<T>(v);
  return array.ToProto();
}

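// A ClientSession that releases all live XRT allocations on the target device
// at construction time, so each test starts from a clean device state.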
class XrtClientSession : public ClientSession {
 public:
  explicit XrtClientSession(const Scope& scope) : ClientSession(scope) {
    auto clear_all = ops::XRTReleaseAllAllocations(scope);
    std::vector<Tensor> outputs;
    TF_CHECK_OK(Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
  }
};

string* xla_test_device_ptr;  // initial value set in main()
string* xla_platform_ptr;     // initial value set in main()

string DeviceFromFlag() {
  string xla_test_device = *xla_test_device_ptr;
  return absl::StrCat("/device:", xla_test_device, ":0");
}

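// Converts an XLA minor-to-major layout into the plain int vector expected by
// the XRTAllocateFromTensor layouts attribute.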
std::vector<int> GetAttrLayout(absl::Span<const int64> minor_to_major) {
  std::vector<int> layout;
  for (auto dim : minor_to_major) {
    layout.push_back(static_cast<int>(dim));
  }
  return layout;
}

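// Helpers producing the literals (tuples, scalars, vectors, matrices) used as
// test inputs and expected outputs below.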
xla::LiteralProto TwoElementTuple() {
  auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
  auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
  auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
  return tuple.ToProto();
}

xla::LiteralProto BasedTwoElementTuple(float base) {
  auto array = xla::LiteralUtil::CreateR1<float>({base, base + 1});
  auto matrix = xla::LiteralUtil::CreateR2<float>(
      {{base + 2, base + 3}, {base + 4, base + 5}});
  auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
  return tuple.ToProto();
}

xla::LiteralProto ScalarLiteral() {
  auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
  return scalar.ToProto();
}

xla::LiteralProto NestedTuple() {
  auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
  auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
  auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
  auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
  auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
  return nested.ToProto();
}

xla::LiteralProto MakeTuple0() {
  auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
  auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
  auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
  auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
  auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
  auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
  return nested1.ToProto();
}

xla::LiteralProto FloatVector(absl::Span<const float> v) {
  auto array = xla::LiteralUtil::CreateR1<float>(v);
  return array.ToProto();
}

xla::LiteralProto FloatMatrix(
    std::initializer_list<std::initializer_list<float>> v,
    const xla::Layout& layout) {
  auto array = xla::LiteralUtil::CreateR2WithLayout<float>(v, layout);
  return array.ToProto();
}

xla::Literal ReadOutputLiteral(const std::vector<Tensor>& outputs, size_t idx) {
  xla::LiteralProto response;
  CHECK(ParseFromTString(outputs[idx].scalar<tstring>()(), &response));
  return xla::Literal::CreateFromProto(response).ValueOrDie();
}

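// Comparison helpers that log both operands on mismatch, which makes test
// failures much easier to diagnose than a bare EXPECT_EQ.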
bool CompareLiteralProtos(const xla::LiteralProto& a,
                          const xla::LiteralProto& b) {
  auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
  auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
  bool equal = l_a == l_b;
  if (!equal) {
    LOG(INFO) << "LiteralProtos don't match:\n"
              << a.DebugString() << "\n!=\n"
              << b.DebugString();
  }
  return equal;
}

bool CompareLiteralToLiteralProto(const xla::Literal& a,
                                  const xla::LiteralProto& b) {
  auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
  bool equal = a == l_b;
  if (!equal) {
    LOG(INFO) << "Literal and LiteralProto don't match:\n"
              << a.ToProto().DebugString() << "\n!=\n"
              << b.DebugString();
  }
  return equal;
}

bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) {
  bool equal = a == b;
  if (!equal) {
    LOG(INFO) << "Literals don't match:\n"
              << a.ToProto().DebugString() << "\n!=\n"
              << b.ToProto().DebugString();
  }
  return equal;
}

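// Small XLA computations used as compilation and execution payloads by the
// tests below.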
xla::XlaComputation OnePlusTwo() {
  xla::XlaBuilder builder("OnePlusTwo");
  auto c0 = xla::ConstantR0(&builder, 1.0f);
  auto c1 = xla::ConstantR0(&builder, 2.0f);
  xla::Add(c0, c1);
  return builder.Build().ValueOrDie();
}

xla::XlaComputation AddAndScale() {
  xla::XlaBuilder builder("AddAndScale");
  auto p0 = xla::Parameter(&builder, 0,
                           xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
  auto p1 = xla::Parameter(&builder, 1,
                           xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
  auto sum = xla::Add(p0, p1);
  auto c = xla::ConstantR0<float>(&builder, 3.0f);
  xla::Mul(sum, c);
  return builder.Build().ValueOrDie();
}

xla::XlaComputation SubAndScale() {
  xla::XlaBuilder builder("SubAndScale");
  auto p0 = xla::Parameter(&builder, 0,
                           xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
  auto p1 = xla::Parameter(&builder, 1,
                           xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
  auto sum = xla::Sub(p0, p1);
  auto c = xla::ConstantR0<float>(&builder, 11.0f);
  xla::Mul(sum, c);
  return builder.Build().ValueOrDie();
}

xla::XlaComputation Dot() {
  xla::XlaBuilder builder("Dot");
  auto p0 = xla::Parameter(
      &builder, 0,
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
  auto p1 = xla::Parameter(
      &builder, 1,
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
  xla::DotDimensionNumbers ddn;
  ddn.add_lhs_contracting_dimensions(1);
  ddn.add_rhs_contracting_dimensions(0);
  xla::DotGeneral(p0, p1, ddn);
  return builder.Build().ValueOrDie();
}

xla::XlaComputation AddS64() {
  xla::XlaBuilder builder("AddS64");
  auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
                           "P0");
  auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}),
                           "P1");
  xla::Add(p0, p1);
  return builder.Build().ValueOrDie();
}

xla::XlaComputation AddAndTuple() {
  xla::XlaBuilder builder("AddAndTuple");
  auto p0 = xla::Parameter(&builder, 0,
                           xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
  auto p1 = xla::Parameter(&builder, 1,
                           xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
  auto sum = xla::Add(p0, p1);
  xla::Tuple(&builder, {sum});
  return builder.Build().ValueOrDie();
}

xla::XlaComputation AddAndSubTuple() {
  xla::XlaBuilder builder("AddAndSubTuple");
  auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}),
                           "P0");
  auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}),
                           "P1");
  auto sum = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  xla::Tuple(&builder, {sum, sub});
  return builder.Build().ValueOrDie();
}

xla::XlaComputation BroadcastComputation(
    const xla::Shape& shape, absl::Span<const xla::int64> dimensions) {
  xla::XlaBuilder builder("BroadcastComputation");
  auto p0 = xla::Parameter(&builder, 0, shape, "P0");
  xla::Broadcast(p0, dimensions);
  return builder.Build().ValueOrDie();
}

xla::XlaComputation IsEqualComputation(const xla::Shape& shape) {
  xla::XlaBuilder builder("IsEqualComputation");
  auto p0 = xla::Parameter(&builder, 0, shape, "P0");
  auto p1 = xla::Parameter(&builder, 1, shape, "P1");
  auto cmp =
      xla::Ne(xla::Sub(p0, p1), xla::Zero(&builder, shape.element_type()));
  auto icmp = xla::ConvertElementType(cmp, xla::S32);
  xla::ReduceAll(icmp, xla::Zero(&builder, xla::S32),
                 xla::CreateScalarAddComputation(xla::S32, &builder));
  return builder.Build().ValueOrDie();
}

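// Stores the serialized form of an XlaComputation into the HloSnapshot field
// of an xrt::XLAComputation message.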
void StoreComputationSnapshot(const xla::XlaComputation& computation,
                              xla::HloSnapshot* dst) {
  auto snapshot = computation.Snapshot().ValueOrDie();
  *dst = *snapshot;
}

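// Compiles the computation directly through the local XLA client and returns
// the program shape of the compiled entry computation, so tests can compare
// it against the program shape reported by XRTCompile.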
xla::ProgramShape XlaCompiledProgramShape(
    const xla::XlaComputation& computation,
    const xla::ProgramShape& input_program_shape) {
  se::Platform* platform =
      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
  xla::LocalClient* client =
      xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  xla::ExecutableBuildOptions exec_options;
  exec_options.set_result_layout(input_program_shape.result());
  std::vector<const xla::Shape*> parameters_shapes;
  for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) {
    parameters_shapes.push_back(&input_program_shape.parameters(i));
  }
  std::vector<std::unique_ptr<xla::LocalExecutable>> local_executables =
      client->Compile(computation, parameters_shapes, exec_options)
          .ConsumeValueOrDie();
  EXPECT_EQ(local_executables.size(), 1);
  std::unique_ptr<xla::LocalExecutable> local_executable =
      std::move(local_executables[0]);
  return local_executable->executable()
      ->module()
      .entry_computation()
      ->ComputeProgramShape();
}

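// Round-trips a host tensor through XRTAllocateFromTensor and
// XRTReadLiteralAndRelease, and checks that the literal comes back unchanged.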
TEST(RawApiTest, AllocFromTensor) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}

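// Allocates an uninitialized buffer, verifies only its shape (its contents
// are undefined), then writes a literal into it and reads it back.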
TEST(RawApiTest, AllocUninitialized) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());

  auto allocate_op =
      ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());

  Tensor handle;
  std::vector<Tensor> outputs;
  XrtClientSession session(root);
  // Allocate the tensor
  {
    TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
    handle = outputs[0];
  }

  // Make sure it has the expected shape
  {
    auto read_back_op = ops::XRTReadLiteral(root, handle);
    TF_ASSERT_OK(root.status());

    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    xla::LiteralProto read_back_literal;
    EXPECT_TRUE(
        ParseFromTString(outputs[0].scalar<tstring>()(), &read_back_literal));
    Tensor read_back_tensor;
    TF_ASSERT_OK(LiteralToHostTensor(
        xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
        &read_back_tensor));

    // The shape should be the same as 'tensor', but we don't have any
    // expectation about the value of the tensors yet since it is uninitialized
    EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
  }

  // Make sure we can write to it
  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
  {
    auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                new_literal.SerializeAsString());
    auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({write_op}, &outputs));
  }

  // Now read it back
  {
    auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
    EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
  }
}

TEST(RawApiTest, AllocFromTensorTuple) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal1 =
      xla::LiteralUtil::CreateR2<float>({{14.0f, -5.0f}, {16.0f, 17.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));
  Tensor tensor1;
  TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1},
                                           {tensor0.shape(), tensor1.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}

TEST(RawApiTest, AllocFromTensorTupleSingle) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}

TEST(RawApiTest, AllocFromTensorRelayout) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Use inverse array layout with the tensor data above.
  std::vector<int> layout({0, 1});
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  // We have sent the literal's data (in array layout) with an attribute layout
  // {0,1}, so the expected literal read from the device needs to be changed
  // accordingly.
  xla::Literal expected_literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 6.0f}, {5.0f, 7.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response));
}

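// Allocates a literal, rewrites the buffer in place with XRTWriteLiteral via
// the raw allocation handle, reads it back, and finally releases the handle.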
TEST(RawApiTest, AllocAndRewrite) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 allocation_handle = outputs[1].scalar<int64>()();
  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));

  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op =
      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
  TF_ASSERT_OK(root.status());
  TF_EXPECT_OK(session.Run({write_op}, &outputs));
  EXPECT_EQ(outputs.size(), 1);
  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());

  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
  TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto new_response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &new_response));
  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));

  Tensor release_tensor(DT_INT64, TensorShape({1}));
  release_tensor.flat<int64>()(0) = allocation_handle;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}

TEST(RawApiTest, AllocReleaseMany) {
  xrt::XLAAllocation alloc1;
  *alloc1.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
  xrt::XLAAllocation alloc2;
  *alloc2.mutable_value() =
      xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString());
  auto value2 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString());
  auto handle1 = ops::XRTAllocate(root, value1);
  auto handle2 = ops::XRTAllocate(root, value2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 allocation_handle1 = outputs[0].scalar<int64>()();
  int64 allocation_handle2 = outputs[1].scalar<int64>()();

  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = allocation_handle1;
  release_tensor.flat<int64>()(1) = allocation_handle2;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}

TEST(RawApiTest, CompileAndReleaseMany) {
  xrt::XLAComputation c1;
  auto config1 = c1.mutable_config();
  auto shapes1 = config1->mutable_program_shape();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot());

  xrt::XLAComputation c2;
  auto config2 = c2.mutable_config();
  auto shapes2 = config2->mutable_program_shape();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation1 =
      ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString());
  auto c_handle1 = ops::XRTCompile(root, computation1);
  auto computation2 =
      ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString());
  auto c_handle2 = ops::XRTCompile(root, computation2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 compilation_handle1 = outputs[0].scalar<int64>()();
  int64 compilation_handle2 = outputs[1].scalar<int64>()();

  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = compilation_handle1;
  release_tensor.flat<int64>()(1) = compilation_handle2;

  auto release = ops::XRTReleaseCompilationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}

TEST(RawApiTest, AllocAndClearAll) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  int64 allocation_handle = outputs[0].scalar<int64>()();

  auto clear_all = ops::XRTReleaseAllAllocations(root);

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
  EXPECT_EQ(outputs.size(), 0);

  auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle));
  EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(),
            error::Code::NOT_FOUND);
}

TEST(RawApiTest, ReadAndWriteState) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}

TEST(RawApiTest, ReadAndWriteStateAutoFree) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}

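// Extracts sub-tuples of a nested tuple allocation with XRTSubTuple and
// XRTSubTupleAndRelease, and checks each piece against the matching element
// of the original literal.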
TEST(RawApiTest, SubBuffer) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = NestedTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto base_handle = ops::XRTAllocate(root, value);
  auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
  auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
  auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
  auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
  auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
  auto sub_00 = ops::XRTSubTupleAndRelease(
      root.WithControlDependencies(
          {sub_0.output_handle.op(), sub_1.output_handle.op()}),
      base_handle, index_00);
  auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
  auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
  auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));

  auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
  auto base_elements = base_literal.DecomposeTuple();
  auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
  xla::LiteralProto response_0;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
  xla::LiteralProto response_00;
  EXPECT_TRUE(ParseFromTString(outputs[2].scalar<tstring>()(), &response_00));
  EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
}

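// Exercises XRTMakeTuple descriptors, including plain input forwarding and
// release_input_handle semantics across chained tuple constructions.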
TEST(RawApiTest, MakeTuple) {
  xrt::XLAAllocation alloc_0;
  *alloc_0.mutable_value() = TwoElementTuple();
  xrt::XLAAllocation alloc_1;
  *alloc_1.mutable_value() = ScalarLiteral();

  // The trivial tuple that just forwards its input and releases it.
  xrt::XLATupleNode desc_0;
  desc_0.set_input_index(0);
  desc_0.set_release_input_handle(true);

  xrt::XLATupleNode desc_1;
  auto subdesc_10 = desc_1.add_tuples();
  auto subdesc_11 = desc_1.add_tuples();
  subdesc_10->set_input_index(0);
  auto subdesc_110 = subdesc_11->add_tuples();
  subdesc_110->set_input_index(0);
  auto subdesc_111 = subdesc_11->add_tuples();
  subdesc_111->set_input_index(1);

  xrt::XLATupleNode desc_2;
  auto subdesc_20 = desc_2.add_tuples();
  auto subdesc_21 = desc_2.add_tuples();
  subdesc_20->set_input_index(1);
  subdesc_20->set_release_input_handle(true);
  subdesc_21->set_input_index(0);
  subdesc_21->set_release_input_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
  auto handle_0 = ops::XRTAllocate(root, value_0);
  auto value_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
  auto handle_1 = ops::XRTAllocate(root, value_1);
  auto tuple_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
  auto handle_2 =
      ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
  // handle_0 has now been released.
  auto tuple_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
  auto handle_3 = ops::XRTMakeTuple(
      root, tuple_1,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  auto tuple_2 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
  // Make sure this runs after handle_3 has completed, since it will free
  // handle_1 and handle_2.
  auto handle_4 = ops::XRTMakeTuple(
      root.WithControlDependencies(handle_3), tuple_2,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  // handle_1 and handle_2 have now been released.

  auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
  auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
  xla::LiteralProto response_0;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));

  auto expected_0 = MakeTuple0();
  EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
  auto expected_1 = NestedTuple();
  EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
}

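// Computes (p0 + p1) * 3 and (p0 - p1) * 11, then feeds both results back
// into the add-and-scale computation, issuing a separate XRTExecute per step.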
TEST(RawApiTest, ExecuteChainedOpByOp) {
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
    xrt::XLAComputation c;
    auto config = c.mutable_config();
    auto shapes = config->mutable_program_shape();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
    return c.SerializeAsString();
  };

  auto c_add_scale = make_computation(AddAndScale);
  auto c_sub_scale = make_computation(SubAndScale);

  auto c_add_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
  auto c_sub_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 c_add_scale_handle = outputs[0].scalar<int64>()();
  int64 c_sub_scale_handle = outputs[1].scalar<int64>()();

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  auto p0_handle = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
  auto p1_handle = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(false);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto result0 = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
                                 {Output(p0_handle), Output(p1_handle)});
  auto result1 = ops::XRTExecute(root, Input(c_sub_scale_handle), e_config,
                                 {Output(p0_handle), Output(p1_handle)});
  auto result = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
                                {result0.output_handle, result1.output_handle});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

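// Same dataflow as ExecuteChainedOpByOp, but expressed as a single
// XRTChainedExecutePlan and run with one XRTExecuteChained op.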
TEST(RawApiTest, ExecuteChained) {
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
    xrt::XLAComputation c;
    auto config = c.mutable_config();
    auto shapes = config->mutable_program_shape();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
    return c.SerializeAsString();
  };

  auto c_add_scale = make_computation(AddAndScale);
  auto c_sub_scale = make_computation(SubAndScale);

  auto c_add_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
  auto c_sub_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 c_add_scale_handle = outputs[0].scalar<int64>()();
  int64 c_sub_scale_handle = outputs[1].scalar<int64>()();

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  auto p0_handle_op = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
  auto p1_handle_op = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));

  TF_EXPECT_OK(session.Run({p0_handle_op, p1_handle_op}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64 p0_handle = outputs[0].scalar<int64>()();
  int64 p1_handle = outputs[1].scalar<int64>()();

  xrt::XRTChainedExecuteConfig config;
  auto config_const =
      ops::Const(root.WithDevice("/device:CPU:0"), config.SerializeAsString());

  xrt::XRTChainedExecutePlan plan;
  xrt::XRTChainedExecuteOp* op;
  xrt::XRTChainedExecuteOp::Input* input;
  xrt::XRTChainedExecuteOp::Output* output;

  // Index 0
  op = plan.add_ops();
  op->set_data_handle(p0_handle);

  // Index 1
  op = plan.add_ops();
  op->set_data_handle(p1_handle);

  // Index 2
  op = plan.add_ops();
  op->set_computation_handle(c_add_scale_handle);
  input = op->add_inputs();
  input->set_op_index(0);
  input = op->add_inputs();
  input->set_op_index(1);

  // Index 3
  op = plan.add_ops();
  op->set_computation_handle(c_sub_scale_handle);
  input = op->add_inputs();
  input->set_op_index(0);
  input = op->add_inputs();
  input->set_op_index(1);

  // Index 4
  op = plan.add_ops();
  op->set_computation_handle(c_add_scale_handle);
  input = op->add_inputs();
  input->set_op_index(2);
  input = op->add_inputs();
  input->set_op_index(3);
  output = op->add_outputs();
  output->set_result_index(0);

  auto plan_const =
      ops::Const(root.WithDevice("/device:CPU:0"), plan.SerializeAsString());
  auto result = ops::XRTExecuteChained(root, plan_const, config_const);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({result}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  auto handles_vec = outputs[0].vec<int64>();
  EXPECT_EQ(handles_vec.size(), 1);

  auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(0)));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

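// Basic compile-then-execute round trip, checking both the computed literal
// and the program shape returned by XRTCompile.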
TEST(RawApiTest, CompileAndExecute) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}

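// The tests below exercise dynamic shapes end to end: the compiled result (or
// input) shapes carry dynamic dimensions, and the literals read back are
// expected to reflect the runtime dimension sizes.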
TEST(RawApiTest,DynamicR1Test)1123 TEST(RawApiTest, DynamicR1Test) {
1124 xrt::XLAAllocation p0;
1125 *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
1126 xrt::XLAAllocation p1;
1127 *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f});
1128 xrt::XLAAllocation p2;
1129 *p2.mutable_value() = CreateR0<xla::int32>(2);
1130
1131 xrt::XLAComputation c;
1132 auto config = c.mutable_config();
1133 auto shapes = config->mutable_program_shape();
1134 *shapes->add_parameters() =
1135 xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1136 *shapes->add_parameters() =
1137 xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1138 *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
1139 xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1140 dyn_shape.set_dynamic_dimension(0, true);
1141 *shapes->mutable_result() = dyn_shape.ToProto();
1142 StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot());
1143
1144 xrt::XRTExecutionConfig e;
1145 e.set_release_input_handles(true);
1146 e.set_release_compilation_handle(true);
1147
1148 Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1149 Scope cpu_root = root.WithDevice("/device:CPU:0");
1150 auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1151 auto computation = ops::Const(cpu_root, c.SerializeAsString());
1152 auto c_handle = ops::XRTCompile(root, computation);
1153 auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1154 auto p0_handle = ops::XRTAllocate(root, p0_value);
1155 auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1156 auto p1_handle = ops::XRTAllocate(root, p1_value);
1157 auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
1158 auto p2_handle = ops::XRTAllocate(root, p2_value);
1159 auto result = ops::XRTExecute(
1160 root, c_handle.handle, e_config,
1161 {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
1162 auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1163 TF_ASSERT_OK(root.status());
1164
1165 XrtClientSession session(root);
1166 std::vector<Tensor> outputs;
1167 TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1168
1169 xla::LiteralProto response;
1170 EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1171 auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
1172 EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1173 }
1174
TEST(RawApiTest,DynamicR2Test)1175 TEST(RawApiTest, DynamicR2Test) {
1176 xrt::XLAAllocation p0;
1177 *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
1178 {1.5f, 2.5f, 3.0f, -2.0f}})
1179 .ToProto();
1180 xrt::XLAAllocation p1;
1181 *p1.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, -1.0f, 2.5f, 1.17f},
1182 {1.2f, -1.6f, 2.8f, 1.24f}})
1183 .ToProto();
1184 xrt::XLAAllocation p2;
1185 *p2.mutable_value() = CreateR0<xla::int32>(2);
1186
1187 xrt::XLAComputation c;
1188 auto config = c.mutable_config();
1189 auto shapes = config->mutable_program_shape();
1190 *shapes->add_parameters() =
1191 xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
1192 *shapes->add_parameters() =
1193 xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
1194 *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
1195 xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
1196 dyn_shape.set_dynamic_dimension(0, true);
1197 dyn_shape.set_dynamic_dimension(1, true);
1198 *shapes->mutable_result() = dyn_shape.ToProto();
1199 StoreComputationSnapshot(ReturnDynamicR2(), c.mutable_hlo_snapshot());
1200
1201 xrt::XRTExecutionConfig e;
1202 e.set_release_input_handles(true);
1203 e.set_release_compilation_handle(true);
1204
1205 Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1206 Scope cpu_root = root.WithDevice("/device:CPU:0");
1207 auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1208 auto computation = ops::Const(cpu_root, c.SerializeAsString());
1209 auto c_handle = ops::XRTCompile(root, computation);
1210 auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1211 auto p0_handle = ops::XRTAllocate(root, p0_value);
1212 auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1213 auto p1_handle = ops::XRTAllocate(root, p1_value);
1214 auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
1215 auto p2_handle = ops::XRTAllocate(root, p2_value);
1216 auto result = ops::XRTExecute(
1217 root, c_handle.handle, e_config,
1218 {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
1219 auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1220 TF_ASSERT_OK(root.status());
1221
1222 XrtClientSession session(root);
1223 std::vector<Tensor> outputs;
1224 TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1225
1226 xla::LiteralProto response;
1227 EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1228 auto expected = xla::LiteralUtil::CreateR2<float>({{2.0f, 1.0f}, {2.7, 0.9}});
1229 EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1230 }
1231
TEST(RawApiTest,DynamicR1TupleTest)1232 TEST(RawApiTest, DynamicR1TupleTest) {
1233 xrt::XLAAllocation p0;
1234 *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
1235 xrt::XLAAllocation p1;
1236 *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f});
1237 xrt::XLAAllocation p2;
1238 *p2.mutable_value() = CreateR0<xla::int32>(2);
1239
1240 xrt::XLAComputation c;
1241 auto config = c.mutable_config();
1242 auto shapes = config->mutable_program_shape();
1243 *shapes->add_parameters() =
1244 xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1245 *shapes->add_parameters() =
1246 xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1247 *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
1248 xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1249 dyn_shape.set_dynamic_dimension(0, true);
1250 *shapes->mutable_result() =
1251 xla::ShapeUtil::MakeTupleShape(
1252 {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape})
1253 .ToProto();
1254 StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot());
1255
1256 xrt::XRTExecutionConfig e;
1257 e.set_release_input_handles(true);
1258 e.set_release_compilation_handle(true);
1259
1260 Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1261 Scope cpu_root = root.WithDevice("/device:CPU:0");
1262 auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1263 auto computation = ops::Const(cpu_root, c.SerializeAsString());
1264 auto c_handle = ops::XRTCompile(root, computation);
1265 auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1266 auto p0_handle = ops::XRTAllocate(root, p0_value);
1267 auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1268 auto p1_handle = ops::XRTAllocate(root, p1_value);
1269 auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
1270 auto p2_handle = ops::XRTAllocate(root, p2_value);
1271 auto result = ops::XRTExecute(
1272 root, c_handle.handle, e_config,
1273 {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
1274 auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1275 TF_ASSERT_OK(root.status());
1276
1277 XrtClientSession session(root);
1278 std::vector<Tensor> outputs;
1279 TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1280
1281 xla::LiteralProto response;
1282 EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1283
1284 auto expected0 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
1285 auto expected1 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f, 0.0f});
1286 auto expected2 = xla::LiteralUtil::CreateR1<float>({0.0f, 3.0f, 1.0f});
1287 auto expected =
1288 xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
1289 EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1290 }
1291
TEST(RawApiTest,AcceptDynamicR1TupleTest)1292 TEST(RawApiTest, AcceptDynamicR1TupleTest) {
1293 xrt::XLAAllocation p0;
1294 *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
1295 xrt::XLAAllocation p1;
1296 *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});
1297
1298 xrt::XLATupleNode tuple_desc;
1299 auto subdesc_10 = tuple_desc.add_tuples();
1300 auto subdesc_11 = tuple_desc.add_tuples();
1301 subdesc_10->set_input_index(0);
1302 subdesc_10->set_release_input_handle(true);
1303 subdesc_11->set_input_index(1);
1304 subdesc_11->set_release_input_handle(true);
1305
1306 xrt::XLAComputation c;
1307 auto config = c.mutable_config();
1308 auto shapes = config->mutable_program_shape();
1309 xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1310 dyn_input_shape.set_dynamic_dimension(0, true);
1311 xla::Shape dyn_tuple_shape =
1312 xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape});
1313 *shapes->add_parameters() = dyn_tuple_shape.ToProto();
1314 xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1315 dyn_shape.set_dynamic_dimension(0, true);
1316 *shapes->mutable_result() = dyn_shape.ToProto();
1317 StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot());
1318
1319 xrt::XRTExecutionConfig e;
1320 e.set_release_input_handles(true);
1321 e.set_release_compilation_handle(true);
1322
1323 Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1324 Scope cpu_root = root.WithDevice("/device:CPU:0");
1325 auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1326 auto computation = ops::Const(cpu_root, c.SerializeAsString());
1327 auto c_handle = ops::XRTCompile(root, computation);
1328 auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1329 auto p0_handle = ops::XRTAllocate(root, p0_value);
1330 auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1331 auto p1_handle = ops::XRTAllocate(root, p1_value);
1332
1333 auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"),
1334 tuple_desc.SerializeAsString());
1335 auto t0_handle = ops::XRTMakeTuple(
1336 root, tuple_0,
1337 {static_cast<Output>(p0_handle), static_cast<Output>(p1_handle)});
1338 auto result = ops::XRTExecute(root, c_handle.handle, e_config,
1339 {static_cast<Output>(t0_handle)});
1340 auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1341 TF_ASSERT_OK(root.status());
1342
1343 XrtClientSession session(root);
1344 std::vector<Tensor> outputs;
1345 TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1346
1347 xla::LiteralProto response;
1348 EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1349
  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, AcceptDynamicR1Test) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
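  // As above, both parameters and the result declare a dynamic dimension 0
  // with a static bound of 4, while the actual inputs carry three elements.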
  xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_input_shape.set_dynamic_dimension(0, true);
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto allocate_op_0 = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto allocate_op_1 = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(allocate_op_0), Output(allocate_op_1)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));

  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, AcceptDynamicR2Test) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() =
      xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}})
          .ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  // Compile time expects an ascending (minor-to-major) layout.
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
  dyn_shape.set_dynamic_dimension(1, true);
  *shapes->add_parameters() = dyn_shape.ToProto();

  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR2(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));

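  // Judging by the input/output pair, AcceptDynamicR2 negates its input, so
  // the expected result is the element-wise negation of p0.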
  auto expected = xla::LiteralUtil::CreateR2<float>(
      {{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
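  // Pack the two scalar allocation handles into a single 1-D int64 tensor;
  // XRTExecute also accepts such a packed vector as its argument list, which
  // is what this test exercises.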
  auto packed_args = ops::Stack(root.WithDevice("/device:CPU:0"),
                                {Output(p0_handle), Output(p1_handle)});
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}

TEST(RawApiTest, CompileWithXlaReturnShapes) {
  xla::XlaBuilder builder("XrtXlaShapes");
  auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
  auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
  // Clear the layouts to signal XLA that we will accept whatever layouts come
  // out of the compilation process.
  xla::LayoutUtil::ClearLayout(&input_shape);
  xla::LayoutUtil::ClearLayout(&kernel_shape);
  auto param_shape =
      xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
  auto param = xla::Parameter(&builder, 0, param_shape, "param");
  auto input = xla::GetTupleElement(param, 0);
  auto kernel = xla::GetTupleElement(param, 1);
  xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());

  auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
  // Clear the result shape layout as well, to accept whatever layout comes
  // out of the compilation process.
  xla::LayoutUtil::ClearLayout(&result_shape);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = param_shape.ToProto();
  *shapes->mutable_result() = result_shape.ToProto();
  StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {c_handle.program_shape},
                           {release}, &outputs));

  xla::ProgramShapeProto program_shape_proto;
  EXPECT_TRUE(
      ParseFromTString(outputs[0].vec<tstring>()(0), &program_shape_proto));
  xla::ProgramShape program_shape(program_shape_proto);
  EXPECT_EQ(program_shape.parameters_size(), 1);

  VLOG(2) << "Param: "
          << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
  VLOG(2) << "Result: "
          << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());

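  // Compare the layouts reported through the XRTCompile program_shape output
  // with the ones XLA itself picks when compiling the same computation,
  // looking only at the minor-to-major ordering.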
  xla::ProgramShape xla_program_shape =
      XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
  EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
          .layout()));
  EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
          .layout()));
  EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
      program_shape.result().layout(), xla_program_shape.result().layout()));
}

TEST(RawApiTest, DotGeneralWithLayoutTest) {
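  // Layout {0, 1} lists dimension 0 as the minor-most one, i.e. the matrices
  // are stored column-major.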
  auto layout = xla::LayoutUtil::MakeLayout({0, 1});

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected =
      xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);

  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, CompileAndExecuteZeroArg) {
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
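  // The computation takes no parameters, so pass an explicit empty argument
  // list to XRTExecute.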
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                std::initializer_list<Input>({}));
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, CompileAndExecuteReturnTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
  auto expected = xla::LiteralUtil::MakeTuple({&sum});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(12.0f).ToProto();

  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<float>(3.0f).ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}),
                                      xla::ShapeUtil::MakeShape(xla::F32, {})})
          .ToProto();
  StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
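  // With return_exploded_tuple set, XRTExecute returns one allocation handle
  // per top-level tuple element instead of a single handle for the whole
  // tuple; the handles come back as an int64 vector.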
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({result}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  auto handles_vec = outputs.front().vec<int64>();
  EXPECT_EQ(handles_vec.size(), 2);

  const float kResults[2] = {15.0f, 9.0f};
  for (int64 i = 0; i < handles_vec.size(); ++i) {
    auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i)));
    std::vector<Tensor> voutputs;
    TF_EXPECT_OK(session.Run({read_back}, &voutputs));
    EXPECT_EQ(voutputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(voutputs[0].scalar<tstring>()(), &response));

    auto expected = xla::LiteralUtil::CreateR0<float>(kResults[i]);
    EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
  }
}

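// Compiles a computation and fetches its handle without ever releasing it;
// presumably this exercises the path where a compilation reference is
// intentionally leaked.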
TEST(RawApiTest, LeakCompilationReference) {
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs));
}

TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  xla::Shape shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});
  xla::XlaBuilder builder("ReuseBuffer");
  auto param = xla::Parameter(&builder, 0, shape, "param");
  auto p0 = xla::GetTupleElement(param, 0);
  auto p1 = xla::GetTupleElement(param, 1);
  auto add = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  xla::Tuple(&builder, {add, sub, p0, p1});

  // Flip the tuple literals in the input handle.
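  // SetUpAlias(output_index, param_number, param_index) makes the output
  // sub-buffer at output_index share device memory with parameter
  // param_number's sub-buffer at param_index: here output {1} (the sub)
  // lands in input slot {0} and output {0} (the add) in input slot {1}.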
  builder.SetUpAlias({1}, 0, {0});
  builder.SetUpAlias({0}, 0, {1});

  auto computation = builder.Build().ValueOrDie();

  auto literal0 = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  auto literal1 = xla::LiteralUtil::CreateR1<float>({5.0f, 9.0f});
  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});

  xrt::XLAAllocation param_alloc;
  *param_alloc.mutable_value() = literal.ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = shape.ToProto();
  *shapes->mutable_result() = return_shape.ToProto();
  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto c_data =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, c_data);
  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                param_alloc.SerializeAsString());
  auto param_handle = ops::XRTAllocate(root, param_value);
  TF_ASSERT_OK(root.status());

  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({param_handle}, &outputs));

  int64 alloc_handle = outputs[0].scalar<int64>()();

  // Note that we release the result handle immediately, but since we aliased
  // the output buffers onto the input allocation ones (held in alloc_handle),
  // we can fetch the result from there.
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
  auto read_back = ops::XRTReadLiteral(root, result);
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
  auto exec_literal_parts = exec_literal.DecomposeTuple();
  ASSERT_EQ(exec_literal_parts.size(), 4);

  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));

  // Now we read back the original input handle values, which at this point
  // should contain the result of the XLA computation.
  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
  TF_ASSERT_OK(root.status());
  auto release_handle = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_handle), Input(alloc_handle));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
                           {release_handle}, &outputs));

  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);

  auto expected_literal0 = xla::LiteralUtil::CreateR1<float>({6.0f, 11.0f});
  auto expected_literal1 = xla::LiteralUtil::CreateR1<float>({-4.0f, -7.0f});
  // The first element of the computation's returned tuple would be the add
  // (expected_literal0), but since we flipped the buffers, the sub
  // (expected_literal1) comes first.
  auto expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});

  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}

TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
  xla::Shape shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});
  xla::XlaBuilder builder("ReuseBuffer");
  auto param = xla::Parameter(&builder, 0, shape, "param");
  auto p0 = xla::GetTupleElement(param, 0);
  auto p1 = xla::GetTupleElement(param, 1);
  auto add = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  xla::Tuple(&builder, {add, sub, p0, p1});

  // Flip the tuple literals in the input handle.
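  // Same output/parameter aliasing as in CompileAndExecuteWithReusedBuffers
  // above, exercised here with S64 data.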
  builder.SetUpAlias({1}, 0, {0});
  builder.SetUpAlias({0}, 0, {1});

  auto computation = builder.Build().ValueOrDie();

  auto literal0 = xla::LiteralUtil::CreateR1<int64>({1, 2});
  auto literal1 = xla::LiteralUtil::CreateR1<int64>({5, 9});
  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});

  xrt::XLAAllocation param_alloc;
  *param_alloc.mutable_value() = literal.ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = shape.ToProto();
  *shapes->mutable_result() = return_shape.ToProto();
  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto c_data =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, c_data);
  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                param_alloc.SerializeAsString());
  auto param_handle = ops::XRTAllocate(root, param_value);
  TF_ASSERT_OK(root.status());

  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({param_handle}, &outputs));

  int64 alloc_handle = outputs[0].scalar<int64>()();

  // Note that we release the result handle immediately, but since we aliased
  // the output buffers onto the input allocation ones (held in alloc_handle),
  // we can fetch the result from there.
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
  auto read_back = ops::XRTReadLiteral(root, result);
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
  auto exec_literal_parts = exec_literal.DecomposeTuple();
  ASSERT_EQ(exec_literal_parts.size(), 4);

  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));

  // Now we read back the original input handle values, which at this point
  // should contain the result of the XLA computation.
  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
  TF_ASSERT_OK(root.status());
  auto release_handle = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_handle), Input(alloc_handle));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
                           {release_handle}, &outputs));

  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);

  auto expected_literal0 = xla::LiteralUtil::CreateR1<int64>({6, 11});
  auto expected_literal1 = xla::LiteralUtil::CreateR1<int64>({-4, -7});
  // The first element of the computation's returned tuple would be the add
  // (expected_literal0), but since we flipped the buffers, the sub
  // (expected_literal1) comes first.
  auto expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});

  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}

TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();
  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<int64>(4091934).ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

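  // 11031965 + 4091934 = 15123899.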
  auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
  EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
      xla::Shape(program_shape.result()), xla::S64));
}

// Tests the XRT device memory compaction API (XRTCompactAllocations).
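// The test allocates a series of buffers, frees every other one to create
// holes, runs compaction, and then verifies that the surviving allocations
// still read back their original values.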
TEST(RawApiTest, TestDeviceMemoryCompaction) {
  static const int kNumAllocs = 32;
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  std::vector<xrt::XLAAllocation> allocs(kNumAllocs);
  std::vector<Output> handle_outputs;
  for (int i = 0; i < kNumAllocs; ++i) {
    *allocs[i].mutable_value() = BasedTwoElementTuple(i * 4.0f);
    auto value = ops::Const(root.WithDevice("/device:CPU:0"),
                            allocs[i].SerializeAsString());
    handle_outputs.push_back(ops::XRTAllocate(root, value));
  }
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run(handle_outputs, &outputs));
  EXPECT_EQ(outputs.size(), handle_outputs.size());

  std::vector<int64> handles;
  for (auto& output : outputs) {
    handles.push_back(output.scalar<int64>()());
  }
  // Create holes by releasing the even-indexed allocations.
  std::vector<Operation> handle_releases;
  for (size_t i = 0; i < handles.size(); i += 2) {
    handle_releases.push_back(
        ops::XRTReleaseAllocationHandle(root, Input(handles[i])));
  }
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, handle_releases, &outputs));

  // Run the compaction API.
  auto compact_op = ops::XRTCompactAllocations(root);
  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, {compact_op}, &outputs));

  // Read back the allocations left at odd indices.
  std::vector<Output> read_outputs;
  for (size_t i = 1; i < handles.size(); i += 2) {
    read_outputs.push_back(ops::XRTReadLiteral(root, Input(handles[i])));
  }
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(read_outputs, &outputs));
  EXPECT_EQ(outputs.size(), read_outputs.size());

  // Verify that everything got moved correctly and the device data matches
  // what we have on record.
  for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) {
    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[j].scalar<tstring>()(), &response));
    EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response));
  }
}

TEST(RawApiTest, TestDeviceMemorySwap) {
  const xla::Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
  // 100MB F32 tensor.
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {5000, 5000});
  const xla::int64 tensor_size = xla::ShapeUtil::ByteSizeOf(shape);
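  // 5000 * 5000 elements * 4 bytes per F32 = 100,000,000 bytes.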
  // On CPU we cannot trigger OOM/swap. For TPU and GPU we pick 16GB as the
  // maximum device memory.
  xla::int64 device_memory_size = 8LL * 1024 * 1024 * 1024;
  if (*xla_test_device_ptr == "TPU" || *xla_test_device_ptr == "XLA_GPU") {
    device_memory_size = 16LL * 1024 * 1024 * 1024;
  }

  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(0.90434).ToProto();

  // Create a computation which broadcasts a scalar to a big tensor.
  xrt::XLAComputation c_bcast;
  {
    auto shapes = c_bcast.mutable_config()->mutable_program_shape();
    *shapes->add_parameters() = scalar_shape.ToProto();
    *shapes->mutable_result() = shape.ToProto();
    StoreComputationSnapshot(
        BroadcastComputation(scalar_shape, shape.dimensions()),
        c_bcast.mutable_hlo_snapshot());
  }

  // Create a computation which compares two tensors.
  xrt::XLAComputation c_equal;
  {
    auto shapes = c_equal.mutable_config()->mutable_program_shape();
    *shapes->add_parameters() = shape.ToProto();
    *shapes->add_parameters() = shape.ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
    StoreComputationSnapshot(IsEqualComputation(shape),
                             c_equal.mutable_hlo_snapshot());
  }

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(false);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto bcast_computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c_bcast.SerializeAsString());
  auto c_bcast_handle = ops::XRTCompile(root, bcast_computation);
  auto equal_computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c_equal.SerializeAsString());
  auto c_equal_handle = ops::XRTCompile(root, equal_computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  std::vector<Tensor> outputs;
  std::vector<xla::int64> device_handles;

  // Create more data than the device can hold using the broadcast
  // computation, forcing some allocations to be swapped out.
  xla::int64 num_tensors = 8 + device_memory_size / tensor_size;
  for (xla::int64 i = 0; i < num_tensors; ++i) {
    auto result = ops::XRTExecute(root, c_bcast_handle.handle, e_config,
                                  {Output(p0_handle)});
    TF_ASSERT_OK(root.status());
    TF_ASSERT_OK(session.Run({result}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    device_handles.push_back(outputs[0].scalar<int64>()());
  }

  // Trigger computations on the XRT handles to verify the swap-out/swap-in
  // logic, by comparing consecutive pairs of tensors.
  auto zero_literal = xla::LiteralUtil::CreateR0<xla::int32>(0);
  for (size_t i = 0; i + 1 < device_handles.size(); ++i) {
    auto exec_op = ops::XRTExecute(
        root, c_equal_handle.handle, e_config,
        {Input(device_handles[i]), Input(device_handles[i + 1])});
    auto read_back = ops::XRTReadLiteral(root, exec_op);

    TF_ASSERT_OK(root.status());
    TF_ASSERT_OK(session.Run({read_back}, &outputs));
    EXPECT_EQ(outputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
    auto literal = xla::Literal::CreateFromProto(response).ValueOrDie();
    EXPECT_EQ(literal, zero_literal);
  }
}

TEST(RawApiTest, TestMetricsFetch) {
  xrt::XRTMetricsCollect metrics;
  metrics.add_metrics_regex("/tensorflow/xrt/.*");

  Scope root = Scope::NewRootScope().WithDevice("/device:CPU:0");
  auto metrics_value = ops::Const(root, metrics.SerializeAsString());
  Output result = ops::XRTMetricsCollect(root, metrics_value);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({result}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xrt::MetricsReport report;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &report));
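  // Every returned metric should carry the /tensorflow/xrt/ prefix matched by
  // the regex requested above.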
  for (auto& metric : report.metrics()) {
    EXPECT_EQ(metric.name().compare(0, 16, "/tensorflow/xrt/"), 0);
  }
}

TEST(RawApiTest, TestMemoryInfo) {
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Output result = ops::XRTMemoryInfo(root);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({result}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xrt::MemoryInfo mem_info;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &mem_info));
  EXPECT_GT(mem_info.kb_total(), 0);
  EXPECT_GT(mem_info.kb_free(), 0);
}

}  // namespace

}  // namespace tensorflow

int main(int argc, char** argv) {
  tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
  tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
                       "Tensorflow device type to use for test, e.g., XLA_CPU"),
      tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
                       "The XLA platform to select for the device"),
  };
  tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  testing::InitGoogleTest(&argc, argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  return RUN_ALL_TESTS();
}