1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/cc/client/client_session.h"
22 #include "tensorflow/cc/framework/ops.h"
23 #include "tensorflow/cc/framework/scope.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/tf2xla/literal_util.h"
26 #include "tensorflow/compiler/tf2xla/shape_util.h"
27 #include "tensorflow/compiler/xla/client/client_library.h"
28 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
29 #include "tensorflow/compiler/xla/client/lib/constants.h"
30 #include "tensorflow/compiler/xla/client/local_client.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/literal.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/service/platform_util.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
39 #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
40 #include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
41 #include "tensorflow/compiler/xrt/xrt.pb.h"
42 #include "tensorflow/core/framework/tensor.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/platform/types.h"
47 #include "tensorflow/core/util/command_line_flags.h"
48
49 namespace tensorflow {
50 namespace {
51
ReturnDynamicR1()52 xla::XlaComputation ReturnDynamicR1() {
53 xla::XlaBuilder builder("ReturnDynamicR1");
54 auto p0 = xla::Parameter(&builder, 0,
55 xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
56 auto p1 = xla::Parameter(&builder, 1,
57 xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
58 auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
59 "P2");
60 auto sum = xla::Add(p0, p1);
61 auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
62 return builder.Build(pad_sum).ValueOrDie();
63 }
64
ReturnDynamicR2()65 xla::XlaComputation ReturnDynamicR2() {
66 xla::XlaBuilder builder("ReturnDynamicR2");
67 auto p0 = xla::Parameter(&builder, 0,
68 xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P0");
69 auto p1 = xla::Parameter(&builder, 1,
70 xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P1");
71 auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
72 "P2");
73 auto sum = xla::Add(p0, p1);
74 auto pad_sum_dim0 = xla::SetDimensionSize(sum, p2, 0);
75 auto pad_sum_dim1 = xla::SetDimensionSize(pad_sum_dim0, p2, 1);
76 return builder.Build(pad_sum_dim1).ValueOrDie();
77 }
78
AcceptDynamicR1()79 xla::XlaComputation AcceptDynamicR1() {
80 xla::XlaBuilder builder("AcceptDynamicR1");
81 xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
82 dyn_shape.set_dynamic_dimension(0, true);
83 auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
84 auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1");
85 auto sum = xla::Add(p0, p1);
86 return builder.Build(sum).ValueOrDie();
87 }
88
AcceptDynamicR2()89 xla::XlaComputation AcceptDynamicR2() {
90 xla::XlaBuilder builder("AcceptDynamicR2");
91 xla::Shape dyn_shape;
92 dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
93 dyn_shape.set_dynamic_dimension(1, true);
94 auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
95 auto negate = xla::Neg(p0);
96 return builder.Build(negate).ValueOrDie();
97 }
98
ReturnDynamicR1Tuple()99 xla::XlaComputation ReturnDynamicR1Tuple() {
100 xla::XlaBuilder builder("ReturnDynamicR1Tuple");
101 auto p0 = xla::Parameter(&builder, 0,
102 xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
103 auto p1 = xla::Parameter(&builder, 1,
104 xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
105 auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
106 "P2");
107 auto sum = xla::Add(p0, p1);
108 auto sub = xla::Sub(p0, p1);
109 auto one = xla::One(&builder, xla::S32);
110 auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
111 auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0);
112 auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub});
113 return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie();
114 }
115
AcceptDynamicR1Tuple()116 xla::XlaComputation AcceptDynamicR1Tuple() {
117 xla::XlaBuilder builder("AcceptDynamicR1");
118 xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
119 dyn_shape.set_dynamic_dimension(0, true);
120 xla::Shape tuple_shape =
121 xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
122 xla::Shape nest_tuple_shape =
123 xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
124 auto p = xla::Parameter(&builder, 0, tuple_shape, "P0");
125 auto p0 = xla::GetTupleElement(p, 0);
126 auto p1 = xla::GetTupleElement(p, 1);
127 auto sum = xla::Add(p0, p1);
128 return builder.Build(sum).ValueOrDie();
129 }
130
131 template <typename T>
CreateR0(T v)132 xla::LiteralProto CreateR0(T v) {
133 auto array = xla::LiteralUtil::CreateR0<T>(v);
134 return array.ToProto();
135 }
136
// A ClientSession that starts from a clean slate: on construction it runs
// XRTReleaseAllAllocations so allocations leaked by earlier tests on the
// same device do not affect this one.
class XrtClientSession : public ClientSession {
 public:
  explicit XrtClientSession(const Scope& scope) : ClientSession(scope) {
    // Drop every outstanding XRT allocation before the test body runs.
    auto clear_all = ops::XRTReleaseAllAllocations(scope);
    std::vector<Tensor> outputs;
    TF_CHECK_OK(Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
  }
};
145
// Test configuration read from command-line flags; both pointers are
// assigned before the tests run.
string* xla_test_device_ptr;  // initial value set in main()
string* xla_platform_ptr;     // initial value set in main()
148
// Returns the full TF device string ("/device:<type>:0") for the device
// selected via the --xla_test_device flag.
string DeviceFromFlag() {
  return absl::StrCat("/device:", *xla_test_device_ptr, ":0");
}
153
GetAttrLayout(absl::Span<const int64_t> minor_to_mayor)154 std::vector<int> GetAttrLayout(absl::Span<const int64_t> minor_to_mayor) {
155 std::vector<int> layout;
156 for (auto dim : minor_to_mayor) {
157 layout.push_back(static_cast<int>(dim));
158 }
159 return layout;
160 }
161
TwoElementTuple()162 xla::LiteralProto TwoElementTuple() {
163 auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
164 auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
165 auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
166 return tuple.ToProto();
167 }
168
BasedTwoElementTuple(float base)169 xla::LiteralProto BasedTwoElementTuple(float base) {
170 auto array = xla::LiteralUtil::CreateR1<float>({base, base + 1});
171 auto matrix = xla::LiteralUtil::CreateR2<float>(
172 {{base + 2, base + 3}, {base + 4, base + 5}});
173 auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
174 return tuple.ToProto();
175 }
176
ScalarLiteral()177 xla::LiteralProto ScalarLiteral() {
178 auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
179 return scalar.ToProto();
180 }
181
NestedTuple()182 xla::LiteralProto NestedTuple() {
183 auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
184 auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
185 auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
186 auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
187 auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
188 return nested.ToProto();
189 }
190
MakeTuple0()191 xla::LiteralProto MakeTuple0() {
192 auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
193 auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
194 auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
195 auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
196 auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
197 auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
198 return nested1.ToProto();
199 }
200
FloatVector(absl::Span<const float> v)201 xla::LiteralProto FloatVector(absl::Span<const float> v) {
202 auto array = xla::LiteralUtil::CreateR1<float>(v);
203 return array.ToProto();
204 }
205
FloatMatrix(std::initializer_list<std::initializer_list<float>> v,const xla::Layout & layout)206 xla::LiteralProto FloatMatrix(
207 std::initializer_list<std::initializer_list<float>> v,
208 const xla::Layout& layout) {
209 auto array = xla::LiteralUtil::CreateR2WithLayout<float>(v, layout);
210 return array.ToProto();
211 }
212
ReadOutputLiteral(const std::vector<Tensor> & outputs,size_t idx)213 xla::Literal ReadOutputLiteral(const std::vector<Tensor>& outputs, size_t idx) {
214 xla::LiteralProto response;
215 CHECK(ParseFromTString(outputs[idx].scalar<tstring>()(), &response));
216 return xla::Literal::CreateFromProto(response).ValueOrDie();
217 }
218
CompareLiteralProtos(const xla::LiteralProto & a,const xla::LiteralProto & b)219 bool CompareLiteralProtos(const xla::LiteralProto& a,
220 const xla::LiteralProto& b) {
221 auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
222 auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
223 bool equal = l_a == l_b;
224 if (!equal) {
225 LOG(INFO) << "LiteralProtos don't match:\n"
226 << a.DebugString() << "\n!=\n"
227 << b.DebugString();
228 }
229 return equal;
230 }
231
CompareLiteralToLiteralProto(const xla::Literal & a,const xla::LiteralProto & b)232 bool CompareLiteralToLiteralProto(const xla::Literal& a,
233 const xla::LiteralProto& b) {
234 auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
235 bool equal = a == l_b;
236 if (!equal) {
237 LOG(INFO) << "Literal and LiteralProto don't match:\n"
238 << a.ToProto().DebugString() << "\n!=\n"
239 << b.DebugString();
240 }
241 return equal;
242 }
243
CompareLiterals(const xla::Literal & a,const xla::Literal & b)244 bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) {
245 bool equal = a == b;
246 if (!equal) {
247 LOG(INFO) << "Literals don't match:\n"
248 << a.ToProto().DebugString() << "\n!=\n"
249 << b.ToProto().DebugString();
250 }
251 return equal;
252 }
253
OnePlusTwo()254 xla::XlaComputation OnePlusTwo() {
255 xla::XlaBuilder builder("OnePlusTwo");
256 auto c0 = xla::ConstantR0(&builder, 1.0f);
257 auto c1 = xla::ConstantR0(&builder, 2.0f);
258 xla::Add(c0, c1);
259 return builder.Build().ValueOrDie();
260 }
261
AddAndScale()262 xla::XlaComputation AddAndScale() {
263 xla::XlaBuilder builder("AddAndScale");
264 auto p0 = xla::Parameter(&builder, 0,
265 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
266 auto p1 = xla::Parameter(&builder, 1,
267 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
268 auto sum = xla::Add(p0, p1);
269 auto c = xla::ConstantR0<float>(&builder, 3.0f);
270 xla::Mul(sum, c);
271 return builder.Build().ValueOrDie();
272 }
273
SubAndScale()274 xla::XlaComputation SubAndScale() {
275 xla::XlaBuilder builder("SubAndScale");
276 auto p0 = xla::Parameter(&builder, 0,
277 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
278 auto p1 = xla::Parameter(&builder, 1,
279 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
280 auto sum = xla::Sub(p0, p1);
281 auto c = xla::ConstantR0<float>(&builder, 11.0f);
282 xla::Mul(sum, c);
283 return builder.Build().ValueOrDie();
284 }
285
Dot()286 xla::XlaComputation Dot() {
287 xla::XlaBuilder builder("Dot");
288 auto p0 = xla::Parameter(
289 &builder, 0,
290 xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
291 auto p1 = xla::Parameter(
292 &builder, 1,
293 xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
294 xla::DotDimensionNumbers ddn;
295 ddn.add_lhs_contracting_dimensions(1);
296 ddn.add_rhs_contracting_dimensions(0);
297 xla::DotGeneral(p0, p1, ddn);
298 return builder.Build().ValueOrDie();
299 }
300
AddS64()301 xla::XlaComputation AddS64() {
302 xla::XlaBuilder builder("AddS64");
303 auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
304 "P0");
305 auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}),
306 "P1");
307 xla::Add(p0, p1);
308 return builder.Build().ValueOrDie();
309 }
310
AddAndTuple()311 xla::XlaComputation AddAndTuple() {
312 xla::XlaBuilder builder("AddAndTuple");
313 auto p0 = xla::Parameter(&builder, 0,
314 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
315 auto p1 = xla::Parameter(&builder, 1,
316 xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
317 auto sum = xla::Add(p0, p1);
318 xla::Tuple(&builder, {sum});
319 return builder.Build().ValueOrDie();
320 }
321
AddAndSubTuple()322 xla::XlaComputation AddAndSubTuple() {
323 xla::XlaBuilder builder("AddAndSubTuple");
324 auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}),
325 "P0");
326 auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}),
327 "P1");
328 auto sum = xla::Add(p0, p1);
329 auto sub = xla::Sub(p0, p1);
330 xla::Tuple(&builder, {sum, sub});
331 return builder.Build().ValueOrDie();
332 }
333
// Builds a computation broadcasting its single parameter of `shape` along
// the given leading `dimensions`.
xla::XlaComputation BroadcastComputation(
    const xla::Shape& shape, absl::Span<const int64_t> dimensions) {
  xla::XlaBuilder builder("BroadcastComputation");
  xla::Broadcast(xla::Parameter(&builder, 0, shape, "P0"), dimensions);
  return builder.Build().ValueOrDie();
}
341
// Builds a computation that reduces to 0 iff P0 and P1 are element-wise
// equal: counts positions where (P0 - P1) != 0.
xla::XlaComputation IsEqualComputation(const xla::Shape& shape) {
  xla::XlaBuilder builder("IsEqualComputation");
  auto lhs = xla::Parameter(&builder, 0, shape, "P0");
  auto rhs = xla::Parameter(&builder, 1, shape, "P1");
  auto differs =
      xla::Ne(xla::Sub(lhs, rhs), xla::Zero(&builder, shape.element_type()));
  xla::ReduceAll(xla::ConvertElementType(differs, xla::S32),
                 xla::Zero(&builder, xla::S32),
                 xla::CreateScalarAddComputation(xla::S32, &builder));
  return builder.Build().ValueOrDie();
}
353
StoreComputationSnapshot(const xla::XlaComputation & computation,xla::HloSnapshot * dst)354 void StoreComputationSnapshot(const xla::XlaComputation& computation,
355 xla::HloSnapshot* dst) {
356 auto snapshot = computation.Snapshot().ValueOrDie();
357 *dst = *snapshot;
358 }
359
// Compiles `computation` with the local XLA client, constraining the
// parameter and result layouts to `input_program_shape`, and returns the
// program shape of the compiled entry computation.
xla::ProgramShape XlaCompiledProgramShape(
    const xla::XlaComputation& computation,
    const xla::ProgramShape& input_program_shape) {
  se::Platform* platform =
      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
  xla::LocalClient* client =
      xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  xla::ExecutableBuildOptions exec_options;
  exec_options.set_result_layout(input_program_shape.result());
  std::vector<const xla::Shape*> parameters_shapes;
  parameters_shapes.reserve(input_program_shape.parameters_size());
  for (int64_t i = 0; i < input_program_shape.parameters_size(); ++i) {
    parameters_shapes.push_back(&input_program_shape.parameters(i));
  }
  // ValueOrDie() for consistency with every other StatusOr access in this
  // file (the original mixed in a lone .value() here).
  std::vector<std::unique_ptr<xla::LocalExecutable>> local_executables =
      client->Compile(computation, parameters_shapes, exec_options)
          .ValueOrDie();
  EXPECT_EQ(local_executables.size(), 1);
  std::unique_ptr<xla::LocalExecutable> local_executable =
      std::move(local_executables[0]);
  return local_executable->executable()
      ->module()
      .entry_computation()
      ->ComputeProgramShape();
}
383
// Round-trips a host literal through XRTAllocateFromTensor and
// XRTReadLiteralAndRelease, verifying the device copy equals the input.
TEST(RawApiTest, AllocFromTensor) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Use the literal's own layout so no relayout happens on allocation.
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
409
// Allocates an uninitialized device buffer, verifies its shape, then
// writes a literal into it and reads it back.
TEST(RawApiTest, AllocUninitialized) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // (An unused `layout` local computed here in the original was removed.)

  auto allocate_op =
      ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());

  Tensor handle;
  std::vector<Tensor> outputs;
  XrtClientSession session(root);
  // Allocate the tensor
  {
    TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
    handle = outputs[0];
  }

  // Make sure it has the expected shape
  {
    auto read_back_op = ops::XRTReadLiteral(root, handle);
    TF_ASSERT_OK(root.status());

    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    xla::LiteralProto read_back_literal;
    EXPECT_TRUE(
        ParseFromTString(outputs[0].scalar<tstring>()(), &read_back_literal));
    Tensor read_back_tensor;
    TF_ASSERT_OK(LiteralToHostTensor(
        xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
        &read_back_tensor));

    // The shape should be the same as 'tensor', but we don't have any
    // expectation about the value of the tensors yet since it is uninitialized
    EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
  }

  // Make sure we can write to it
  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
  {
    auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                new_literal.SerializeAsString());
    auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({write_op}, &outputs));
  }

  // Now read it back
  {
    auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
    EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
  }
}
475
// Allocates a 2-tuple from two host tensors and verifies the read-back
// literal equals the corresponding tuple literal.
TEST(RawApiTest, AllocFromTensorTuple) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal1 =
      xla::LiteralUtil::CreateR2<float>({{14.0f, -5.0f}, {16.0f, 17.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));
  Tensor tensor1;
  TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1},
                                           {tensor0.shape(), tensor1.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
506
// Allocates a single-element tuple via the MakeTuple(true) attribute and
// verifies the read-back matches the tuple literal.
TEST(RawApiTest, AllocFromTensorTupleSingle) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  // MakeTuple(true) forces a tuple result even for one input tensor.
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
532
// Allocates with a layout attribute that transposes the tensor's data
// and verifies the read-back literal is the transposed matrix.
TEST(RawApiTest, AllocFromTensorRelayout) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Use inverse array layout with the tensor data above.
  std::vector<int> layout({0, 1});
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  // We have sent literal's data (in array layout) with a attribute layout
  // {0,1}, so the expected literal read from device needs to be changed
  // accordingly.
  xla::Literal expected_literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 6.0f}, {5.0f, 7.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response));
}
563
// Allocates a literal, overwrites it in place via XRTWriteLiteral using
// the raw allocation handle, reads it back, then releases the handle.
TEST(RawApiTest, AllocAndRewrite) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  // Keep the raw handle so later ops can address the same allocation.
  int64_t allocation_handle = outputs[1].scalar<int64_t>()();
  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));

  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op =
      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
  TF_ASSERT_OK(root.status());
  TF_EXPECT_OK(session.Run({write_op}, &outputs));
  EXPECT_EQ(outputs.size(), 1);
  // XRTWriteLiteral returns the handle it wrote to.
  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64_t>()());

  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
  TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto new_response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &new_response));
  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));

  Tensor release_tensor(DT_INT64, TensorShape({1}));
  release_tensor.flat<int64_t>()(0) = allocation_handle;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
611
// Allocates two buffers and releases both with one
// XRTReleaseAllocationHandle call taking a vector of handles.
TEST(RawApiTest, AllocReleaseMany) {
  xrt::XLAAllocation alloc1;
  *alloc1.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
  xrt::XLAAllocation alloc2;
  *alloc2.mutable_value() =
      xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString());
  auto value2 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString());
  auto handle1 = ops::XRTAllocate(root, value1);
  auto handle2 = ops::XRTAllocate(root, value2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t allocation_handle1 = outputs[0].scalar<int64_t>()();
  int64_t allocation_handle2 = outputs[1].scalar<int64_t>()();

  // Both handles are packed into one release tensor.
  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64_t>()(0) = allocation_handle1;
  release_tensor.flat<int64_t>()(1) = allocation_handle2;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
644
// Compiles two computations and releases both compilation handles with a
// single XRTReleaseCompilationHandle call.
TEST(RawApiTest, CompileAndReleaseMany) {
  // First computation: f32[2] x f32[2] -> f32[2].
  xrt::XLAComputation c1;
  auto config1 = c1.mutable_config();
  auto shapes1 = config1->mutable_program_shape();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot());

  // Second computation: f32[2] x f32[2] -> (f32[2]).
  xrt::XLAComputation c2;
  auto config2 = c2.mutable_config();
  auto shapes2 = config2->mutable_program_shape();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation1 =
      ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString());
  auto c_handle1 = ops::XRTCompile(root, computation1);
  auto computation2 =
      ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString());
  auto c_handle2 = ops::XRTCompile(root, computation2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t compilation_handle1 = outputs[0].scalar<int64_t>()();
  int64_t compilation_handle2 = outputs[1].scalar<int64_t>()();

  // Release both compilations through one vector-valued op.
  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64_t>()(0) = compilation_handle1;
  release_tensor.flat<int64_t>()(1) = compilation_handle2;

  auto release = ops::XRTReleaseCompilationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
693
// Allocates a buffer, clears all allocations on the device, and verifies
// the stale handle can no longer be read (NOT_FOUND).
TEST(RawApiTest, AllocAndClearAll) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  int64_t allocation_handle = outputs[0].scalar<int64_t>()();

  auto clear_all = ops::XRTReleaseAllAllocations(root);

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
  EXPECT_EQ(outputs.size(), 0);

  // Reading through the released handle must fail with NOT_FOUND.
  auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle));
  EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(),
            error::Code::NOT_FOUND);
}
722
// Uploads a tuple literal, reads it back, and releases the handle via a
// control dependency so the release runs only after the read.
TEST(RawApiTest, ReadAndWriteState) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  // Order the release after the read via a control edge.
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
746
// Same round-trip as ReadAndWriteState, but the handle is released by the
// combined XRTReadLiteralAndRelease op instead of an explicit release.
TEST(RawApiTest, ReadAndWriteStateAutoFree) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
766
// Extracts sub-tuples from a nested tuple allocation (indices {0}, {1},
// and nested {0,0}) and checks each against the decomposed host literal.
TEST(RawApiTest, SubBuffer) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = NestedTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto base_handle = ops::XRTAllocate(root, value);
  auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
  auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
  auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
  auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
  auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
  // The base handle is released only after both plain sub-tuples exist,
  // hence the control dependencies.
  auto sub_00 = ops::XRTSubTupleAndRelease(
      root.WithControlDependencies(
          {sub_0.output_handle.op(), sub_1.output_handle.op()}),
      base_handle, index_00);
  auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
  auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
  auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));

  // Decompose the host-side literal to obtain the expected sub-values.
  auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
  auto base_elements = base_literal.DecomposeTuple();
  auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
  xla::LiteralProto response_0;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
  xla::LiteralProto response_00;
  EXPECT_TRUE(ParseFromTString(outputs[2].scalar<tstring>()(), &response_00));
  EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
}
806
TEST(RawApiTest, MakeTuple) {
  // Exercises XRTMakeTuple: building device tuples out of existing allocation
  // handles, forwarding inputs, and releasing the input handles as requested
  // by the XLATupleNode descriptors.
  xrt::XLAAllocation alloc_0;
  *alloc_0.mutable_value() = TwoElementTuple();
  xrt::XLAAllocation alloc_1;
  *alloc_1.mutable_value() = ScalarLiteral();

  // The trivial tuple that just forwards its input and releases it.
  xrt::XLATupleNode desc_0;
  desc_0.set_input_index(0);
  desc_0.set_release_input_handle(true);

  // desc_1 describes the nested tuple (input0, (input0, input1)); no inputs
  // are released here.
  xrt::XLATupleNode desc_1;
  auto subdesc_10 = desc_1.add_tuples();
  auto subdesc_11 = desc_1.add_tuples();
  subdesc_10->set_input_index(0);
  auto subdesc_110 = subdesc_11->add_tuples();
  subdesc_110->set_input_index(0);
  auto subdesc_111 = subdesc_11->add_tuples();
  subdesc_111->set_input_index(1);

  // desc_2 describes the swapped pair (input1, input0) and releases both of
  // its input handles.
  xrt::XLATupleNode desc_2;
  auto subdesc_20 = desc_2.add_tuples();
  auto subdesc_21 = desc_2.add_tuples();
  subdesc_20->set_input_index(1);
  subdesc_20->set_release_input_handle(true);
  subdesc_21->set_input_index(0);
  subdesc_21->set_release_input_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
  auto handle_0 = ops::XRTAllocate(root, value_0);
  auto value_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
  auto handle_1 = ops::XRTAllocate(root, value_1);
  auto tuple_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
  auto handle_2 =
      ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
  // handle_0 has now been released.
  auto tuple_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
  auto handle_3 = ops::XRTMakeTuple(
      root, tuple_1,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  auto tuple_2 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
  // Make sure this runs after handle_3 has completed, since it will free
  // handle_1 and handle_2.
  auto handle_4 = ops::XRTMakeTuple(
      root.WithControlDependencies(handle_3), tuple_2,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  // handle_1 and handle_2 have now been released.

  auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
  auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
  xla::LiteralProto response_0;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));

  // Expected values are built by the file-local helpers mirroring the tuple
  // descriptors above.
  auto expected_0 = MakeTuple0();
  EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
  auto expected_1 = NestedTuple();
  EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
}
878
TEST(RawApiTest, ExecuteChainedOpByOp) {
  // Runs a three-step computation chain (add, sub, then add of the two
  // intermediate results) by issuing one XRTExecute per step and feeding the
  // output handles of earlier steps into the final one.
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  // Serializes an XLAComputation whose program shape is
  // (f32[2], f32[2]) -> f32[2], with the HLO snapshot produced by fn().
  auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
    xrt::XLAComputation c;
    auto config = c.mutable_config();
    auto shapes = config->mutable_program_shape();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
    return c.SerializeAsString();
  };

  auto c_add_scale = make_computation(AddAndScale);
  auto c_sub_scale = make_computation(SubAndScale);

  auto c_add_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
  auto c_sub_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
  TF_ASSERT_OK(root.status());

  // Compile both computations in a first session run so that their integer
  // handles can be fed back as graph constants below.
  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t c_add_scale_handle = outputs[0].scalar<int64_t>()();
  int64_t c_sub_scale_handle = outputs[1].scalar<int64_t>()();

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  auto p0_handle = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
  auto p1_handle = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));

  // Keep input and compilation handles alive: each is consumed by more than
  // one XRTExecute node in this graph.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(false);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto result0 = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
                                 {Output(p0_handle), Output(p1_handle)});
  auto result1 = ops::XRTExecute(root, Input(c_sub_scale_handle), e_config,
                                 {Output(p0_handle), Output(p1_handle)});
  auto result = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
                                {result0.output_handle, result1.output_handle});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
948
TEST(RawApiTest, ExecuteChained) {
  // Same chain as ExecuteChainedOpByOp, but expressed as a single
  // XRTChainedExecutePlan executed by one XRTExecuteChained op.
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  // Serializes an XLAComputation whose program shape is
  // (f32[2], f32[2]) -> f32[2], with the HLO snapshot produced by fn().
  auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
    xrt::XLAComputation c;
    auto config = c.mutable_config();
    auto shapes = config->mutable_program_shape();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
    return c.SerializeAsString();
  };

  auto c_add_scale = make_computation(AddAndScale);
  auto c_sub_scale = make_computation(SubAndScale);

  auto c_add_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
  auto c_sub_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
  TF_ASSERT_OK(root.status());

  // Compile both computations first so their handles can be embedded in the
  // chained-execute plan below.
  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t c_add_scale_handle = outputs[0].scalar<int64_t>()();
  int64_t c_sub_scale_handle = outputs[1].scalar<int64_t>()();

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  auto p0_handle_op = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
  auto p1_handle_op = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));

  TF_EXPECT_OK(session.Run({p0_handle_op, p1_handle_op}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t p0_handle = outputs[0].scalar<int64_t>()();
  int64_t p1_handle = outputs[1].scalar<int64_t>()();

  xrt::XRTChainedExecuteConfig config;
  auto config_const =
      ops::Const(root.WithDevice("/device:CPU:0"), config.SerializeAsString());

  // The plan references ops by their index in the `ops` list; data-handle ops
  // inject existing allocations, computation ops consume earlier op outputs.
  xrt::XRTChainedExecutePlan plan;
  xrt::XRTChainedExecuteOp* op;
  xrt::XRTChainedExecuteOp::Input* input;
  xrt::XRTChainedExecuteOp::Output* output;

  // Index 0
  op = plan.add_ops();
  op->set_data_handle(p0_handle);

  // Index 1
  op = plan.add_ops();
  op->set_data_handle(p1_handle);

  // Index 2
  op = plan.add_ops();
  op->set_computation_handle(c_add_scale_handle);
  input = op->add_inputs();
  input->set_op_index(0);
  input = op->add_inputs();
  input->set_op_index(1);

  // Index 3
  op = plan.add_ops();
  op->set_computation_handle(c_sub_scale_handle);
  input = op->add_inputs();
  input->set_op_index(0);
  input = op->add_inputs();
  input->set_op_index(1);

  // Index 4
  op = plan.add_ops();
  op->set_computation_handle(c_add_scale_handle);
  input = op->add_inputs();
  input->set_op_index(2);
  input = op->add_inputs();
  input->set_op_index(3);
  // Only this op exports a result, at result index 0.
  output = op->add_outputs();
  output->set_result_index(0);

  auto plan_const =
      ops::Const(root.WithDevice("/device:CPU:0"), plan.SerializeAsString());
  auto result = ops::XRTExecuteChained(root, plan_const, config_const);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({result}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  // The chained execute returns one handle per declared output.
  auto handles_vec = outputs[0].vec<int64_t>();
  EXPECT_EQ(handles_vec.size(), 1);

  auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(0)));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1068
TEST(RawApiTest, CompileAndExecute) {
  // Compiles the AddAndScale computation, runs it over two device-allocated
  // float vectors, and checks both the result literal and the reported
  // program shape. Input and compilation handles are released by the
  // execute op itself.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  // Program shape: (f32[2], f32[2]) -> f32[2].
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  const auto f32_2 = xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() = f32_2;
  *shapes->add_parameters() = f32_2;
  *shapes->mutable_result() = f32_2;
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  // The compiler must report the two-parameter program shape back.
  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
1121
TEST(RawApiTest, DynamicR1Test) {
  // Compiles a computation whose result shape has a dynamic dimension 0 and
  // verifies the literal read back matches the dynamically-sized expectation.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f});
  // p2 is the scalar size argument consumed by ReturnDynamicR1.
  xrt::XLAAllocation p2;
  *p2.mutable_value() = CreateR0<int32_t>(2);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
  // The result is f32[<=4]: dimension 0 is marked dynamic.
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
  auto p2_handle = ops::XRTAllocate(root, p2_value);
  auto result = ops::XRTExecute(
      root, c_handle.handle, e_config,
      {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use ParseFromTString, consistent with the rest of this file: the tensor
  // holds a tstring, not a std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1173
TEST(RawApiTest, DynamicR2Test) {
  // Compiles a computation whose result shape has both dimensions dynamic and
  // verifies the literal read back matches the dynamically-sized expectation.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
                                                    {1.5f, 2.5f, 3.0f, -2.0f}})
                            .ToProto();
  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, -1.0f, 2.5f, 1.17f},
                                                    {1.2f, -1.6f, 2.8f, 1.24f}})
                            .ToProto();
  // p2 is the scalar size argument consumed by ReturnDynamicR2.
  xrt::XLAAllocation p2;
  *p2.mutable_value() = CreateR0<int32_t>(2);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
  // The result is f32[<=2, <=4]: both dimensions are marked dynamic.
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
  dyn_shape.set_dynamic_dimension(0, true);
  dyn_shape.set_dynamic_dimension(1, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(ReturnDynamicR2(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
  auto p2_handle = ops::XRTAllocate(root, p2_value);
  auto result = ops::XRTExecute(
      root, c_handle.handle, e_config,
      {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use ParseFromTString, consistent with the rest of this file: the tensor
  // holds a tstring, not a std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  // Float-suffixed literals avoid a narrowing double->float conversion in the
  // braced initializer.
  auto expected =
      xla::LiteralUtil::CreateR2<float>({{2.0f, 1.0f}, {2.7f, 0.9f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1230
TEST(RawApiTest, DynamicR1TupleTest) {
  // Compiles a computation returning a tuple that mixes dynamic and static
  // R1 elements and verifies the read-back literal.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f});
  // p2 is the scalar size argument consumed by ReturnDynamicR1Tuple.
  xrt::XLAAllocation p2;
  *p2.mutable_value() = CreateR0<int32_t>(2);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
  // Result tuple: (f32[<=4], f32[4], f32[<=4]) — first and third elements
  // have dynamic dimension 0.
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape(
          {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape})
          .ToProto();
  StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
  auto p2_handle = ops::XRTAllocate(root, p2_value);
  auto result = ops::XRTExecute(
      root, c_handle.handle, e_config,
      {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use ParseFromTString, consistent with the rest of this file: the tensor
  // holds a tstring, not a std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected0 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
  auto expected1 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f, 0.0f});
  auto expected2 = xla::LiteralUtil::CreateR1<float>({0.0f, 3.0f, 1.0f});
  auto expected =
      xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1290
TEST(RawApiTest, AcceptDynamicR1TupleTest) {
  // Verifies that a computation can accept a tuple parameter whose leaves
  // have a dynamic dimension, padding the shorter runtime inputs.
  if (*xla_test_device_ptr == "XLA_CPU" || *xla_test_device_ptr == "XLA_GPU") {
    // XLA_CPU and XLA_GPU has shape check set to kCompileTime.
    return;
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});

  // Tuple descriptor: ({input0, input1}), releasing both input handles once
  // the tuple is built.
  xrt::XLATupleNode tuple_desc;
  auto subdesc_10 = tuple_desc.add_tuples();
  auto subdesc_11 = tuple_desc.add_tuples();
  subdesc_10->set_input_index(0);
  subdesc_10->set_release_input_handle(true);
  subdesc_11->set_input_index(1);
  subdesc_11->set_release_input_handle(true);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  // Parameter: a tuple of two f32[<=4] leaves; result: f32[<=4].
  xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_input_shape.set_dynamic_dimension(0, true);
  xla::Shape dyn_tuple_shape =
      xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape});
  *shapes->add_parameters() = dyn_tuple_shape.ToProto();
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);

  auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"),
                            tuple_desc.SerializeAsString());
  auto t0_handle = ops::XRTMakeTuple(
      root, tuple_0,
      {static_cast<Output>(p0_handle), static_cast<Output>(p1_handle)});
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {static_cast<Output>(t0_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use ParseFromTString, consistent with the rest of this file: the tensor
  // holds a tstring, not a std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1356
TEST(RawApiTest, AcceptDynamicR1Test) {
  // Verifies that a computation can accept plain R1 parameters whose
  // dimension 0 is dynamic, padding the shorter runtime inputs.
  if (*xla_test_device_ptr == "XLA_CPU" || *xla_test_device_ptr == "XLA_GPU") {
    // XLA_CPU and XLA_GPU has shape check set to kCompileTime.
    return;
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  // Parameters and result are all f32[<=4]: dimension 0 is dynamic.
  xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_input_shape.set_dynamic_dimension(0, true);
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto allocate_op_0 = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto allocate_op_1 = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(allocate_op_0), Output(allocate_op_1)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use ParseFromTString, consistent with the rest of this file: the tensor
  // holds a tstring, not a std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1407
TEST(RawApiTest, AcceptDynamicR2Test) {
  // Verifies that a computation can accept an R2 parameter whose dimension 1
  // is dynamic; the actual input has only 3 of the 4 padded columns.
  xrt::XLAAllocation p0;
  *p0.mutable_value() =
      xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}})
          .ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  // Compile time expects ascending layout.
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
  dyn_shape.set_dynamic_dimension(1, true);
  *shapes->add_parameters() = dyn_shape.ToProto();

  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR2(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use ParseFromTString, consistent with the rest of this file: the tensor
  // holds a tstring, not a std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR2<float>(
      {{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1452
TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
  // Same computation as CompileAndExecute, but the two input handles are
  // packed into a single vector tensor (via ops::Stack) before being passed
  // to XRTExecute.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  // Program shape: (f32[2], f32[2]) -> f32[2].
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  const auto f32_2 = xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() = f32_2;
  *shapes->add_parameters() = f32_2;
  *shapes->mutable_result() = f32_2;
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  // Pack both handles into one rank-1 tensor argument.
  auto packed_args = ops::Stack(root.WithDevice("/device:CPU:0"),
                                {Output(p0_handle), Output(p1_handle)});
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  // The compiler must report the two-parameter program shape back.
  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
1507
// Compiles a computation whose parameter/result layouts are left unset, and
// verifies that the program shape XRT reports back carries the same layouts
// that XLA itself would pick when compiling the computation directly.
TEST(RawApiTest, CompileWithXlaReturnShapes) {
  xla::XlaBuilder builder("XrtXlaShapes");
  auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
  auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
  // Clear layouts to signal XLA we are ready to get whatever are coming out of
  // the compilation process.
  xla::LayoutUtil::ClearLayout(&input_shape);
  xla::LayoutUtil::ClearLayout(&kernel_shape);
  auto param_shape =
      xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
  auto param = xla::Parameter(&builder, 0, param_shape, "param");
  auto input = xla::GetTupleElement(param, 0);
  auto kernel = xla::GetTupleElement(param, 1);
  xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());

  auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
  // Clear the result shape layout to tell XLA we are accepting whatever are
  // coming out of the compilation process.
  xla::LayoutUtil::ClearLayout(&result_shape);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = param_shape.ToProto();
  *shapes->mutable_result() = result_shape.ToProto();
  StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  // Compile only; the handle is released as a run target below, so no
  // execution ever happens.
  auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {c_handle.program_shape},
                           {release}, &outputs));

  xla::ProgramShapeProto program_shape_proto;
  EXPECT_TRUE(
      ParseFromTString(outputs[0].vec<tstring>()(0), &program_shape_proto));
  xla::ProgramShape program_shape(program_shape_proto);
  // Single parameter: the tuple of (input, kernel).
  EXPECT_EQ(program_shape.parameters_size(), 1);

  VLOG(2) << "Param: "
          << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
  VLOG(2) << "Result: "
          << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());

  // Compile the same computation via the XLA client and compare the chosen
  // layouts (minor-to-major only) against what XRT reported.
  xla::ProgramShape xla_program_shape =
      XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
  EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
          .layout()));
  EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
          .layout()));
  EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
      program_shape.result().layout(), xla_program_shape.result().layout()));
}
1572
// Runs a 2x2 * 2x1 dot product where both operands and the result use a
// column-major ({0, 1}) layout, and checks the result literal round-trips
// with that layout intact.
TEST(RawApiTest, DotGeneralWithLayoutTest) {
  const auto layout = xla::LayoutUtil::MakeLayout({0, 1});

  xrt::XLAAllocation arg0;
  *arg0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);
  xrt::XLAAllocation arg1;
  *arg1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);

  // Program shape: (F32[2,2]{0,1}, F32[2,1]{0,1}) -> F32[2,1]{0,1}.
  xrt::XLAComputation c;
  auto* shapes = c.mutable_config()->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());

  // Release everything after the single execution.
  xrt::XRTExecutionConfig exec_config;
  exec_config.set_release_input_handles(true);
  exec_config.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config = ops::Const(root.WithDevice("/device:CPU:0"),
                             exec_config.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto arg0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), arg0.SerializeAsString());
  auto arg0_handle = ops::XRTAllocate(root, arg0_value);
  auto arg1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), arg1.SerializeAsString());
  auto arg1_handle = ops::XRTAllocate(root, arg1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(arg0_handle), Output(arg1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  // [[1,2],[3,4]] . [[8],[5]] == [[18],[44]], in the requested layout.
  auto expected =
      xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);

  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1625
// Executes a computation that takes no arguments (constant 1 + 2) and checks
// the scalar result.
TEST(RawApiTest, CompileAndExecuteZeroArg) {
  // Program shape: () -> F32[].
  xrt::XLAComputation c;
  auto* shapes = c.mutable_config()->mutable_program_shape();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig exec_config;
  exec_config.set_release_input_handles(true);
  exec_config.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config = ops::Const(root.WithDevice("/device:CPU:0"),
                             exec_config.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  // Empty input list: the computation has zero parameters.
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                std::initializer_list<Input>({}));
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  EXPECT_TRUE(CompareLiteralToLiteralProto(
      xla::LiteralUtil::CreateR0<float>(3.0f), response));
}
1658
// Executes a computation whose result is a one-element tuple wrapping the
// element-wise sum of the two inputs, and reads the tuple back as a single
// literal.
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  // Program shape: (F32[2], F32[2]) -> (F32[2]).
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());

  // Release inputs and the compilation handle after the single execution.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  // ({1,2} + {8,5}) wrapped in a tuple.
  auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
  auto expected = xla::LiteralUtil::MakeTuple({&sum});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1709
// Executes a computation returning a (sum, difference) tuple with
// return_exploded_tuple set, so XRT hands back one allocation handle per
// tuple element; each element is then read back and verified individually.
TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(12.0f).ToProto();

  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<float>(3.0f).ToProto();

  // Program shape: (F32[], F32[]) -> (F32[], F32[]).
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}),
                                      xla::ShapeUtil::MakeShape(xla::F32, {})})
          .ToProto();
  StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  // Request one allocation handle per tuple element instead of a single
  // handle for the whole tuple.
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({result}, &outputs));
  // ASSERT (not EXPECT): outputs.front() below is undefined behavior on an
  // empty vector, so the test must not continue past a size mismatch.
  ASSERT_EQ(outputs.size(), 1);

  auto handles_vec = outputs.front().vec<int64_t>();
  // ASSERT (not EXPECT): the loop below indexes kResults[i], which has
  // exactly two entries; continuing with more handles would read out of
  // bounds.
  ASSERT_EQ(handles_vec.size(), 2);

  const float kResults[2] = {15.0f, 9.0f};
  for (int64_t i = 0; i < handles_vec.size(); ++i) {
    auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i)));
    // The graph is extended inside the loop; verify the new op was added
    // cleanly before trying to run it.
    TF_ASSERT_OK(root.status());
    std::vector<Tensor> voutputs;
    TF_EXPECT_OK(session.Run({read_back}, &voutputs));
    ASSERT_EQ(voutputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(voutputs[0].scalar<tstring>()(), &response));

    // 12 + 3 = 15, 12 - 3 = 9.
    auto expected = xla::LiteralUtil::CreateR0<float>(kResults[i]);
    EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
  }
}
1771
// Compiles a computation and deliberately never releases the compilation
// handle, exercising server-side cleanup of leaked compilation references.
TEST(RawApiTest, LeakCompilationReference) {
  // Program shape: (F32[2], F32[2]) -> (F32[2]).
  xrt::XLAComputation c;
  auto* shapes = c.mutable_config()->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto serialized =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto compile_op = ops::XRTCompile(root, serialized);
  TF_ASSERT_OK(root.status());

  // Fetch the handle but do not release it.
  XrtClientSession session(root);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(session.Run({compile_op.handle}, &fetched));
}
1795
// Verifies input/output buffer aliasing (F32): the computation's output tuple
// elements are aliased onto the input tuple's buffers (with positions
// swapped), so after execution the *input* allocation handle holds the
// computation results.
TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  xla::Shape shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});
  // Computation: param (p0, p1) -> (p0 + p1, p0 - p1, p0, p1).
  xla::XlaBuilder builder("ReuseBuffer");
  auto param = xla::Parameter(&builder, 0, shape, "param");
  auto p0 = xla::GetTupleElement(param, 0);
  auto p1 = xla::GetTupleElement(param, 1);
  auto add = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  xla::Tuple(&builder, {add, sub, p0, p1});

  // Flip the tuple literals in the input handle.
  builder.SetUpAlias({1}, 0, {0});
  builder.SetUpAlias({0}, 0, {1});

  auto computation = builder.Build().ValueOrDie();

  auto literal0 = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  auto literal1 = xla::LiteralUtil::CreateR1<float>({5.0f, 9.0f});
  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});

  xrt::XLAAllocation param_alloc;
  *param_alloc.mutable_value() = literal.ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = shape.ToProto();
  *shapes->mutable_result() = return_shape.ToProto();
  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());

  // Keep the input handle alive: it will be read back after execution to
  // observe the aliased output buffers.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto c_data =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, c_data);
  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                param_alloc.SerializeAsString());
  auto param_handle = ops::XRTAllocate(root, param_value);
  TF_ASSERT_OK(root.status());

  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({param_handle}, &outputs));

  int64_t alloc_handle = outputs[0].scalar<int64_t>()();

  // Note that we release the result handle immediately, but since we aliased
  // the output buffers onto the input allocation ones (held in alloc_handle),
  // we can fetch the result from there.
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
  auto read_back = ops::XRTReadLiteral(root, result);
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
  auto exec_literal_parts = exec_literal.DecomposeTuple();
  ASSERT_EQ(exec_literal_parts.size(), 4);

  // Elements 2 and 3 pass the inputs through unchanged.
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));

  // Now we read back the original input handle values, which at this point
  // should contain the result of the XLA computation.
  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
  TF_ASSERT_OK(root.status());
  auto release_handle = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_handle), Input(alloc_handle));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
                           {release_handle}, &outputs));

  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);

  auto expected_literal0 = xla::LiteralUtil::CreateR1<float>({6.0f, 11.0f});
  auto expected_literal1 = xla::LiteralUtil::CreateR1<float>({-4.0f, -7.0f});
  // The first element of the computation returned tuple would be the add
  // (expected_literal0), but since we flipped the buffers, the sub
  // (expected_literal1) should come first.
  auto expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});

  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}
1894
// S64 variant of CompileAndExecuteWithReusedBuffers: same aliasing scenario
// with 64-bit integer element types.
TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
  xla::Shape shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});
  // Computation: param (p0, p1) -> (p0 + p1, p0 - p1, p0, p1).
  xla::XlaBuilder builder("ReuseBuffer");
  auto param = xla::Parameter(&builder, 0, shape, "param");
  auto p0 = xla::GetTupleElement(param, 0);
  auto p1 = xla::GetTupleElement(param, 1);
  auto add = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  xla::Tuple(&builder, {add, sub, p0, p1});

  // Flip the tuple literals in the input handle.
  builder.SetUpAlias({1}, 0, {0});
  builder.SetUpAlias({0}, 0, {1});

  auto computation = builder.Build().ValueOrDie();

  auto literal0 = xla::LiteralUtil::CreateR1<int64_t>({1, 2});
  auto literal1 = xla::LiteralUtil::CreateR1<int64_t>({5, 9});
  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});

  xrt::XLAAllocation param_alloc;
  *param_alloc.mutable_value() = literal.ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = shape.ToProto();
  *shapes->mutable_result() = return_shape.ToProto();
  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());

  // Keep the input handle alive: it will be read back after execution to
  // observe the aliased output buffers.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto c_data =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, c_data);
  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                param_alloc.SerializeAsString());
  auto param_handle = ops::XRTAllocate(root, param_value);
  TF_ASSERT_OK(root.status());

  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({param_handle}, &outputs));

  int64_t alloc_handle = outputs[0].scalar<int64_t>()();

  // Note that we release the result handle immediately, but since we aliased
  // the output buffers onto the input allocation ones (held in alloc_handle),
  // we can fetch the result from there.
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
  auto read_back = ops::XRTReadLiteral(root, result);
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
  auto exec_literal_parts = exec_literal.DecomposeTuple();
  ASSERT_EQ(exec_literal_parts.size(), 4);

  // Elements 2 and 3 pass the inputs through unchanged.
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));

  // Now we read back the original input handle values, which at this point
  // should contain the result of the XLA computation.
  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
  TF_ASSERT_OK(root.status());
  auto release_handle = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_handle), Input(alloc_handle));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
                           {release_handle}, &outputs));

  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);

  auto expected_literal0 = xla::LiteralUtil::CreateR1<int64_t>({6, 11});
  auto expected_literal1 = xla::LiteralUtil::CreateR1<int64_t>({-4, -7});
  // The first element of the computation returned tuple would be the add
  // (expected_literal0), but since we flipped the buffers, the sub
  // (expected_literal1) should come first.
  auto expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});

  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}
1993
// Adds two S64 scalars and verifies both the result value and that the
// reported program shape preserves the S64 primitive type end-to-end.
TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64_t>(11031965).ToProto();
  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<int64_t>(4091934).ToProto();

  // Program shape: (S64[], S64[]) -> S64[].
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  // NOTE(review): return_exploded_tuple is enabled even though the result is
  // a scalar (non-tuple) — presumably checking the flag is a no-op for
  // non-tuple results; confirm against the XRTExecute op semantics.
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  // 11031965 + 4091934 = 15123899.
  auto expected = xla::LiteralUtil::CreateR0<int64_t>(15123899);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
  EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
      xla::Shape(program_shape.result()), xla::S64));
}
2046
2047 // Tests the XRT device memory compaction API (XRTCompactAllocations).
// Tests the XRT device memory compaction API (XRTCompactAllocations).
// Strategy: create many allocations, release every even-indexed one to
// fragment device memory, compact, then verify the surviving odd-indexed
// allocations still hold their original data.
TEST(RawApiTest, TestDeviceMemoryCompaction) {
  static const int kNumAllocs = 32;
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  std::vector<xrt::XLAAllocation> allocs(kNumAllocs);
  std::vector<Output> handle_outputs;
  for (int i = 0; i < kNumAllocs; ++i) {
    // Each allocation gets a distinct value so that post-compaction reads can
    // be matched back to their originating allocation.
    *allocs[i].mutable_value() = BasedTwoElementTuple(i * 4.0f);
    auto value = ops::Const(root.WithDevice("/device:CPU:0"),
                            allocs[i].SerializeAsString());
    handle_outputs.push_back(ops::XRTAllocate(root, value));
  }
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run(handle_outputs, &outputs));
  EXPECT_EQ(outputs.size(), handle_outputs.size());

  std::vector<int64_t> handles;
  for (auto& output : outputs) {
    handles.push_back(output.scalar<int64_t>()());
  }
  // Create holes by releasing even allocations.
  std::vector<Operation> handle_releases;
  for (size_t i = 0; i < handles.size(); i += 2) {
    handle_releases.push_back(
        ops::XRTReleaseAllocationHandle(root, Input(handles[i])));
  }
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, handle_releases, &outputs));

  // Run the compaction API.
  auto compact_op = ops::XRTCompactAllocations(root);
  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, {compact_op}, &outputs));

  // Read back the allocation left at odd indices.
  std::vector<Output> read_outputs;
  for (size_t i = 1; i < handles.size(); i += 2) {
    read_outputs.push_back(ops::XRTReadLiteral(root, Input(handles[i])));
  }
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(read_outputs, &outputs));
  EXPECT_EQ(outputs.size(), read_outputs.size());

  // Verify that everything got moved correctly and the device data matches what
  // we have on record.
  for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) {
    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[j].scalar<tstring>()(), &response));
    EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response));
  }
}
2105
// Stresses the XRT device-memory swap-out/swap-in path by allocating more
// tensor data than the assumed device memory size, then touching all handles
// again so swapped-out buffers must be brought back.
TEST(RawApiTest, TestDeviceMemorySwap) {
  const xla::Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
  // 100MB F32 tensor.
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {5000, 5000});
  const int64_t tensor_size = xla::ShapeUtil::ByteSizeOf(shape);
  // On CPU we cannot trigger OOM/swap. Assume 8GB of device memory by
  // default; for TPU and GPU we select 16GB as maximum memory.
  int64_t device_memory_size = 8LL * 1024 * 1024 * 1024;
  if (*xla_test_device_ptr == "TPU" || *xla_test_device_ptr == "XLA_GPU") {
    device_memory_size = 16LL * 1024 * 1024 * 1024;
  }

  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(0.90434).ToProto();

  // Create a computation which broadcasts a scalar to a big tensor.
  xrt::XLAComputation c_bcast;
  {
    auto shapes = c_bcast.mutable_config()->mutable_program_shape();
    *shapes->add_parameters() = scalar_shape.ToProto();
    *shapes->mutable_result() = shape.ToProto();
    StoreComputationSnapshot(
        BroadcastComputation(scalar_shape, shape.dimensions()),
        c_bcast.mutable_hlo_snapshot());
  }

  // Create a computation which compares two tensors.
  xrt::XLAComputation c_equal;
  {
    auto shapes = c_equal.mutable_config()->mutable_program_shape();
    *shapes->add_parameters() = shape.ToProto();
    *shapes->add_parameters() = shape.ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
    StoreComputationSnapshot(IsEqualComputation(shape),
                             c_equal.mutable_hlo_snapshot());
  }

  // Keep every handle alive: the test needs all allocations to stay resident
  // (or swapped) so they can be compared later.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(false);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto bcast_computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c_bcast.SerializeAsString());
  auto c_bcast_handle = ops::XRTCompile(root, bcast_computation);
  auto equal_computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c_equal.SerializeAsString());
  auto c_equal_handle = ops::XRTCompile(root, equal_computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  std::vector<Tensor> outputs;
  std::vector<int64_t> device_handles;

  // Create more data the device can take using the broadcast computation.
  int64_t num_tensors = 8 + device_memory_size / tensor_size;
  for (int64_t i = 0; i < num_tensors; ++i) {
    auto result = ops::XRTExecute(root, c_bcast_handle.handle, e_config,
                                  {Output(p0_handle)});
    TF_ASSERT_OK(root.status());
    TF_ASSERT_OK(session.Run({result}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    device_handles.push_back(outputs[0].scalar<int64_t>()());
  }

  // Trigger computations on XRT handles to verify the swap-out/swap-in logic,
  // by comparing sequential couple of tensors.
  auto zero_literal = xla::LiteralUtil::CreateR0<int32_t>(0);
  for (size_t i = 0; i + 1 < device_handles.size(); ++i) {
    auto exec_op = ops::XRTExecute(
        root, c_equal_handle.handle, e_config,
        {Input(device_handles[i]), Input(device_handles[i + 1])});
    auto read_back = ops::XRTReadLiteral(root, exec_op);

    TF_ASSERT_OK(root.status());
    TF_ASSERT_OK(session.Run({read_back}, &outputs));
    EXPECT_EQ(outputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
    auto literal = xla::Literal::CreateFromProto(response).ValueOrDie();
    // IsEqualComputation returns 0 mismatches between identical broadcasts.
    EXPECT_EQ(literal, zero_literal);
  }
}
2194
// Fetches XRT metrics filtered by regex and checks every reported metric
// name lives under the requested namespace.
TEST(RawApiTest, TestMetricsFetch) {
  xrt::XRTMetricsCollect metrics;
  metrics.add_metrics_regex("/tensorflow/xrt/.*");

  Scope root = Scope::NewRootScope().WithDevice("/device:CPU:0");
  auto request = ops::Const(root, metrics.SerializeAsString());
  Output collect = ops::XRTMetricsCollect(root, request);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(session.Run({collect}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  xrt::MetricsReport report;
  EXPECT_TRUE(ParseFromTString(fetched[0].scalar<tstring>()(), &report));
  for (const auto& metric : report.metrics()) {
    // Names must start with the 16-character prefix "/tensorflow/xrt/".
    EXPECT_EQ(metric.name().compare(0, 16, "/tensorflow/xrt/"), 0);
  }
}
2215
// Queries the XRT memory info op and sanity-checks the reported totals.
TEST(RawApiTest, TestMemoryInfo) {
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Output result = ops::XRTMemoryInfo(root);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(session.Run({result}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  xrt::MemoryInfo mem_info;
  EXPECT_TRUE(ParseFromTString(fetched[0].scalar<tstring>()(), &mem_info));
  // Both total and free memory should be reported as strictly positive.
  EXPECT_GT(mem_info.kb_total(), 0);
  EXPECT_GT(mem_info.kb_free(), 0);
}
2231
2232 } // namespace
2233
2234 } // namespace tensorflow
2235
// Test entry point: parses --xla_test_device / --xla_platform flags before
// handing the remaining arguments to googletest.
int main(int argc, char** argv) {
  // Defaults; overridable from the command line. Allocated with new and
  // intentionally never freed: they are process-lifetime globals read by the
  // tests above.
  tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
  tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
                       "Tensorflow device type to use for test, e.g., XLA_CPU"),
      tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
                       "The XLA platform to select for the device"),
  };
  tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  // InitGoogleTest strips gtest flags from argv; anything left over is an
  // unrecognized argument.
  testing::InitGoogleTest(&argc, argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  return RUN_ALL_TESTS();
}
2258