• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/cc/client/client_session.h"
22 #include "tensorflow/cc/framework/ops.h"
23 #include "tensorflow/cc/framework/scope.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/tf2xla/literal_util.h"
26 #include "tensorflow/compiler/tf2xla/shape_util.h"
27 #include "tensorflow/compiler/xla/client/client_library.h"
28 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
29 #include "tensorflow/compiler/xla/client/lib/constants.h"
30 #include "tensorflow/compiler/xla/client/local_client.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/literal.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/service/platform_util.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
39 #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
40 #include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
41 #include "tensorflow/compiler/xrt/xrt.pb.h"
42 #include "tensorflow/core/framework/tensor.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/platform/types.h"
47 #include "tensorflow/core/util/command_line_flags.h"
48 
49 namespace tensorflow {
50 namespace {
51 
ReturnDynamicR1()52 xla::XlaComputation ReturnDynamicR1() {
53   xla::XlaBuilder builder("ReturnDynamicR1");
54   auto p0 = xla::Parameter(&builder, 0,
55                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
56   auto p1 = xla::Parameter(&builder, 1,
57                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
58   auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
59                            "P2");
60   auto sum = xla::Add(p0, p1);
61   auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
62   return builder.Build(pad_sum).ValueOrDie();
63 }
64 
ReturnDynamicR2()65 xla::XlaComputation ReturnDynamicR2() {
66   xla::XlaBuilder builder("ReturnDynamicR2");
67   auto p0 = xla::Parameter(&builder, 0,
68                            xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P0");
69   auto p1 = xla::Parameter(&builder, 1,
70                            xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P1");
71   auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
72                            "P2");
73   auto sum = xla::Add(p0, p1);
74   auto pad_sum_dim0 = xla::SetDimensionSize(sum, p2, 0);
75   auto pad_sum_dim1 = xla::SetDimensionSize(pad_sum_dim0, p2, 1);
76   return builder.Build(pad_sum_dim1).ValueOrDie();
77 }
78 
AcceptDynamicR1()79 xla::XlaComputation AcceptDynamicR1() {
80   xla::XlaBuilder builder("AcceptDynamicR1");
81   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
82   dyn_shape.set_dynamic_dimension(0, true);
83   auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
84   auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1");
85   auto sum = xla::Add(p0, p1);
86   return builder.Build(sum).ValueOrDie();
87 }
88 
AcceptDynamicR2()89 xla::XlaComputation AcceptDynamicR2() {
90   xla::XlaBuilder builder("AcceptDynamicR2");
91   xla::Shape dyn_shape;
92   dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
93   dyn_shape.set_dynamic_dimension(1, true);
94   auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
95   auto negate = xla::Neg(p0);
96   return builder.Build(negate).ValueOrDie();
97 }
98 
ReturnDynamicR1Tuple()99 xla::XlaComputation ReturnDynamicR1Tuple() {
100   xla::XlaBuilder builder("ReturnDynamicR1Tuple");
101   auto p0 = xla::Parameter(&builder, 0,
102                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
103   auto p1 = xla::Parameter(&builder, 1,
104                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
105   auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
106                            "P2");
107   auto sum = xla::Add(p0, p1);
108   auto sub = xla::Sub(p0, p1);
109   auto one = xla::One(&builder, xla::S32);
110   auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
111   auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0);
112   auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub});
113   return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie();
114 }
115 
AcceptDynamicR1Tuple()116 xla::XlaComputation AcceptDynamicR1Tuple() {
117   xla::XlaBuilder builder("AcceptDynamicR1");
118   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
119   dyn_shape.set_dynamic_dimension(0, true);
120   xla::Shape tuple_shape =
121       xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
122   xla::Shape nest_tuple_shape =
123       xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
124   auto p = xla::Parameter(&builder, 0, tuple_shape, "P0");
125   auto p0 = xla::GetTupleElement(p, 0);
126   auto p1 = xla::GetTupleElement(p, 1);
127   auto sum = xla::Add(p0, p1);
128   return builder.Build(sum).ValueOrDie();
129 }
130 
131 template <typename T>
CreateR0(T v)132 xla::LiteralProto CreateR0(T v) {
133   auto array = xla::LiteralUtil::CreateR0<T>(v);
134   return array.ToProto();
135 }
136 
// ClientSession wrapper which, on construction, runs XRTReleaseAllAllocations
// so every test starts from a clean device state even if a previous test
// leaked allocations.
class XrtClientSession : public ClientSession {
 public:
  explicit XrtClientSession(const Scope& scope) : ClientSession(scope) {
    auto clear_all = ops::XRTReleaseAllAllocations(scope);
    std::vector<Tensor> outputs;
    // Run the release op eagerly; TF_CHECK_OK aborts the binary on failure.
    TF_CHECK_OK(Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
  }
};
145 
// Flag-backed settings; both pointers are assigned in main() before any test
// runs, so tests may dereference them unconditionally.
string* xla_test_device_ptr;  // initial value set in main()
string* xla_platform_ptr;     // initial value set in main()
148 
DeviceFromFlag()149 string DeviceFromFlag() {
150   string xla_test_device = *xla_test_device_ptr;
151   return absl::StrCat("/device:", xla_test_device, ":0");
152 }
153 
GetAttrLayout(absl::Span<const int64_t> minor_to_mayor)154 std::vector<int> GetAttrLayout(absl::Span<const int64_t> minor_to_mayor) {
155   std::vector<int> layout;
156   for (auto dim : minor_to_mayor) {
157     layout.push_back(static_cast<int>(dim));
158   }
159   return layout;
160 }
161 
TwoElementTuple()162 xla::LiteralProto TwoElementTuple() {
163   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
164   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
165   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
166   return tuple.ToProto();
167 }
168 
BasedTwoElementTuple(float base)169 xla::LiteralProto BasedTwoElementTuple(float base) {
170   auto array = xla::LiteralUtil::CreateR1<float>({base, base + 1});
171   auto matrix = xla::LiteralUtil::CreateR2<float>(
172       {{base + 2, base + 3}, {base + 4, base + 5}});
173   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
174   return tuple.ToProto();
175 }
176 
ScalarLiteral()177 xla::LiteralProto ScalarLiteral() {
178   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
179   return scalar.ToProto();
180 }
181 
NestedTuple()182 xla::LiteralProto NestedTuple() {
183   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
184   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
185   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
186   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
187   auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
188   return nested.ToProto();
189 }
190 
MakeTuple0()191 xla::LiteralProto MakeTuple0() {
192   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
193   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
194   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
195   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
196   auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
197   auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
198   return nested1.ToProto();
199 }
200 
FloatVector(absl::Span<const float> v)201 xla::LiteralProto FloatVector(absl::Span<const float> v) {
202   auto array = xla::LiteralUtil::CreateR1<float>(v);
203   return array.ToProto();
204 }
205 
// Serializes a float matrix, stored with the given layout, into a literal
// proto.
xla::LiteralProto FloatMatrix(
    std::initializer_list<std::initializer_list<float>> v,
    const xla::Layout& layout) {
  return xla::LiteralUtil::CreateR2WithLayout<float>(v, layout).ToProto();
}
212 
ReadOutputLiteral(const std::vector<Tensor> & outputs,size_t idx)213 xla::Literal ReadOutputLiteral(const std::vector<Tensor>& outputs, size_t idx) {
214   xla::LiteralProto response;
215   CHECK(ParseFromTString(outputs[idx].scalar<tstring>()(), &response));
216   return xla::Literal::CreateFromProto(response).ValueOrDie();
217 }
218 
CompareLiteralProtos(const xla::LiteralProto & a,const xla::LiteralProto & b)219 bool CompareLiteralProtos(const xla::LiteralProto& a,
220                           const xla::LiteralProto& b) {
221   auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
222   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
223   bool equal = l_a == l_b;
224   if (!equal) {
225     LOG(INFO) << "LiteralProtos don't match:\n"
226               << a.DebugString() << "\n!=\n"
227               << b.DebugString();
228   }
229   return equal;
230 }
231 
CompareLiteralToLiteralProto(const xla::Literal & a,const xla::LiteralProto & b)232 bool CompareLiteralToLiteralProto(const xla::Literal& a,
233                                   const xla::LiteralProto& b) {
234   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
235   bool equal = a == l_b;
236   if (!equal) {
237     LOG(INFO) << "Literal and LiteralProto don't match:\n"
238               << a.ToProto().DebugString() << "\n!=\n"
239               << b.DebugString();
240   }
241   return equal;
242 }
243 
CompareLiterals(const xla::Literal & a,const xla::Literal & b)244 bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) {
245   bool equal = a == b;
246   if (!equal) {
247     LOG(INFO) << "Literals don't match:\n"
248               << a.ToProto().DebugString() << "\n!=\n"
249               << b.ToProto().DebugString();
250   }
251   return equal;
252 }
253 
OnePlusTwo()254 xla::XlaComputation OnePlusTwo() {
255   xla::XlaBuilder builder("OnePlusTwo");
256   auto c0 = xla::ConstantR0(&builder, 1.0f);
257   auto c1 = xla::ConstantR0(&builder, 2.0f);
258   xla::Add(c0, c1);
259   return builder.Build().ValueOrDie();
260 }
261 
AddAndScale()262 xla::XlaComputation AddAndScale() {
263   xla::XlaBuilder builder("AddAndScale");
264   auto p0 = xla::Parameter(&builder, 0,
265                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
266   auto p1 = xla::Parameter(&builder, 1,
267                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
268   auto sum = xla::Add(p0, p1);
269   auto c = xla::ConstantR0<float>(&builder, 3.0f);
270   xla::Mul(sum, c);
271   return builder.Build().ValueOrDie();
272 }
273 
SubAndScale()274 xla::XlaComputation SubAndScale() {
275   xla::XlaBuilder builder("SubAndScale");
276   auto p0 = xla::Parameter(&builder, 0,
277                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
278   auto p1 = xla::Parameter(&builder, 1,
279                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
280   auto sum = xla::Sub(p0, p1);
281   auto c = xla::ConstantR0<float>(&builder, 11.0f);
282   xla::Mul(sum, c);
283   return builder.Build().ValueOrDie();
284 }
285 
Dot()286 xla::XlaComputation Dot() {
287   xla::XlaBuilder builder("Dot");
288   auto p0 = xla::Parameter(
289       &builder, 0,
290       xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
291   auto p1 = xla::Parameter(
292       &builder, 1,
293       xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
294   xla::DotDimensionNumbers ddn;
295   ddn.add_lhs_contracting_dimensions(1);
296   ddn.add_rhs_contracting_dimensions(0);
297   xla::DotGeneral(p0, p1, ddn);
298   return builder.Build().ValueOrDie();
299 }
300 
AddS64()301 xla::XlaComputation AddS64() {
302   xla::XlaBuilder builder("AddS64");
303   auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
304                            "P0");
305   auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}),
306                            "P1");
307   xla::Add(p0, p1);
308   return builder.Build().ValueOrDie();
309 }
310 
AddAndTuple()311 xla::XlaComputation AddAndTuple() {
312   xla::XlaBuilder builder("AddAndTuple");
313   auto p0 = xla::Parameter(&builder, 0,
314                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
315   auto p1 = xla::Parameter(&builder, 1,
316                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
317   auto sum = xla::Add(p0, p1);
318   xla::Tuple(&builder, {sum});
319   return builder.Build().ValueOrDie();
320 }
321 
AddAndSubTuple()322 xla::XlaComputation AddAndSubTuple() {
323   xla::XlaBuilder builder("AddAndSubTuple");
324   auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}),
325                            "P0");
326   auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}),
327                            "P1");
328   auto sum = xla::Add(p0, p1);
329   auto sub = xla::Sub(p0, p1);
330   xla::Tuple(&builder, {sum, sub});
331   return builder.Build().ValueOrDie();
332 }
333 
// Builds a computation broadcasting a parameter of `shape` into the given
// leading `dimensions`.
xla::XlaComputation BroadcastComputation(const xla::Shape& shape,
                                         absl::Span<const int64_t> dimensions) {
  xla::XlaBuilder builder("BroadcastComputation");
  auto input = xla::Parameter(&builder, 0, shape, "P0");
  xla::Broadcast(input, dimensions);  // root is the last op built
  return builder.Build().ValueOrDie();
}
341 
// Builds a computation returning the number of mismatching elements between
// its two parameters as a scalar s32 (0 means equal).
xla::XlaComputation IsEqualComputation(const xla::Shape& shape) {
  xla::XlaBuilder builder("IsEqualComputation");
  auto lhs = xla::Parameter(&builder, 0, shape, "P0");
  auto rhs = xla::Parameter(&builder, 1, shape, "P1");
  // Element-wise "differs" predicate, converted to 0/1 integers.
  auto differs =
      xla::Ne(xla::Sub(lhs, rhs), xla::Zero(&builder, shape.element_type()));
  auto differs_int = xla::ConvertElementType(differs, xla::S32);
  // Sum over all elements; root is the last op built.
  xla::ReduceAll(differs_int, xla::Zero(&builder, xla::S32),
                 xla::CreateScalarAddComputation(xla::S32, &builder));
  return builder.Build().ValueOrDie();
}
353 
StoreComputationSnapshot(const xla::XlaComputation & computation,xla::HloSnapshot * dst)354 void StoreComputationSnapshot(const xla::XlaComputation& computation,
355                               xla::HloSnapshot* dst) {
356   auto snapshot = computation.Snapshot().ValueOrDie();
357   *dst = *snapshot;
358 }
359 
// Compiles `computation` with the local client for the platform named by
// --xla_platform, forcing the result layout and parameter shapes from
// `input_program_shape`, and returns the program shape of the compiled
// entry computation (i.e. the shape as seen after compilation).
xla::ProgramShape XlaCompiledProgramShape(
    const xla::XlaComputation& computation,
    const xla::ProgramShape& input_program_shape) {
  se::Platform* platform =
      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
  xla::LocalClient* client =
      xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  xla::ExecutableBuildOptions exec_options;
  exec_options.set_result_layout(input_program_shape.result());
  // Compile() takes the parameter shapes by pointer.
  std::vector<const xla::Shape*> parameters_shapes;
  for (int64_t i = 0; i < input_program_shape.parameters_size(); ++i) {
    parameters_shapes.push_back(&input_program_shape.parameters(i));
  }
  std::vector<std::unique_ptr<xla::LocalExecutable>> local_executables =
      client->Compile(computation, parameters_shapes, exec_options).value();
  // A single (non-partitioned) compilation yields exactly one executable.
  EXPECT_EQ(local_executables.size(), 1);
  std::unique_ptr<xla::LocalExecutable> local_executable =
      std::move(local_executables[0]);
  return local_executable->executable()
      ->module()
      .entry_computation()
      ->ComputeProgramShape();
}
383 
// Round-trips a host tensor through XRTAllocateFromTensor and
// XRTReadLiteralAndRelease and checks the read-back literal is unchanged.
TEST(RawApiTest, AllocFromTensor) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Allocate with the literal's own layout so the round trip is exact.
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
409 
// Allocates an uninitialized XRT buffer, verifies its shape, overwrites it
// with XRTWriteLiteral, and reads it back to confirm the written value.
TEST(RawApiTest, AllocUninitialized) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());

  auto allocate_op =
      ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());

  Tensor handle;
  std::vector<Tensor> outputs;
  XrtClientSession session(root);
  // Allocate the tensor
  {
    TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
    handle = outputs[0];
  }

  // Make sure it has the expected shape
  {
    auto read_back_op = ops::XRTReadLiteral(root, handle);
    TF_ASSERT_OK(root.status());

    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    xla::LiteralProto read_back_literal;
    EXPECT_TRUE(
        ParseFromTString(outputs[0].scalar<tstring>()(), &read_back_literal));
    Tensor read_back_tensor;
    TF_ASSERT_OK(LiteralToHostTensor(
        xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
        &read_back_tensor));

    // The shape should be the same as 'tensor', but we don't have any
    // expectation about the value of the tensors yet since it is uninitialized
    EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
  }

  // Make sure we can write to it
  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
  {
    // Serialized protos are fed from the CPU device.
    auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                new_literal.SerializeAsString());
    auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({write_op}, &outputs));
  }

  // Now read it back
  {
    auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
    TF_ASSERT_OK(root.status());
    TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
    EXPECT_EQ(outputs.size(), 1);

    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
    EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
  }
}
475 
// Allocates a two-element tuple from two host tensors and verifies the
// read-back literal equals the corresponding tuple literal.
TEST(RawApiTest, AllocFromTensorTuple) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal1 =
      xla::LiteralUtil::CreateR2<float>({{14.0f, -5.0f}, {16.0f, 17.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));
  Tensor tensor1;
  TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Layout vector covers every leaf of the tuple shape.
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1},
                                           {tensor0.shape(), tensor1.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
506 
// Allocates a single-element tuple (via the MakeTuple attribute) from one
// host tensor and verifies the read-back literal.
TEST(RawApiTest, AllocFromTensorTupleSingle) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  // MakeTuple(true) wraps the single input into a one-element tuple.
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
532 
// Allocates a tensor with a layout different from the literal's and checks
// the read-back reflects the transposed interpretation of the data.
TEST(RawApiTest, AllocFromTensorRelayout) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Use inverse array layout with the tensor data above.
  std::vector<int> layout({0, 1});
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  // We have sent literal's data (in array layout) with a attribute layout
  // {0,1}, so the expected literal read from device needs to be changed
  // accordingly.
  xla::Literal expected_literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 6.0f}, {5.0f, 7.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response));
}
563 
// Allocates a buffer, rewrites it in place via XRTWriteLiteral using the raw
// allocation handle, re-reads it, and finally releases the handle.
TEST(RawApiTest, AllocAndRewrite) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Serialized protos are fed from the CPU device.
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t allocation_handle = outputs[1].scalar<int64_t>()();
  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));

  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op =
      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
  TF_ASSERT_OK(root.status());
  TF_EXPECT_OK(session.Run({write_op}, &outputs));
  EXPECT_EQ(outputs.size(), 1);
  // XRTWriteLiteral returns the handle it wrote to.
  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64_t>()());

  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
  TF_EXPECT_OK(session.Run({read_after_write}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto new_response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &new_response));
  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));

  Tensor release_tensor(DT_INT64, TensorShape({1}));
  release_tensor.flat<int64_t>()(0) = allocation_handle;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
611 
// Allocates two buffers and releases both with a single
// XRTReleaseAllocationHandle op fed a rank-1 handle tensor.
TEST(RawApiTest, AllocReleaseMany) {
  xrt::XLAAllocation alloc1;
  *alloc1.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
  xrt::XLAAllocation alloc2;
  *alloc2.mutable_value() =
      xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString());
  auto value2 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString());
  auto handle1 = ops::XRTAllocate(root, value1);
  auto handle2 = ops::XRTAllocate(root, value2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t allocation_handle1 = outputs[0].scalar<int64_t>()();
  int64_t allocation_handle2 = outputs[1].scalar<int64_t>()();

  // Release both handles in one op.
  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64_t>()(0) = allocation_handle1;
  release_tensor.flat<int64_t>()(1) = allocation_handle2;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
644 
// Compiles two computations and releases both compilation handles with a
// single XRTReleaseCompilationHandle op fed a rank-1 handle tensor.
TEST(RawApiTest, CompileAndReleaseMany) {
  xrt::XLAComputation c1;
  auto config1 = c1.mutable_config();
  auto shapes1 = config1->mutable_program_shape();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot());

  // Second computation returns a tuple result.
  xrt::XLAComputation c2;
  auto config2 = c2.mutable_config();
  auto shapes2 = config2->mutable_program_shape();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation1 =
      ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString());
  auto c_handle1 = ops::XRTCompile(root, computation1);
  auto computation2 =
      ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString());
  auto c_handle2 = ops::XRTCompile(root, computation2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t compilation_handle1 = outputs[0].scalar<int64_t>()();
  int64_t compilation_handle2 = outputs[1].scalar<int64_t>()();

  // Release both compilation handles in one op.
  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64_t>()(0) = compilation_handle1;
  release_tensor.flat<int64_t>()(1) = compilation_handle2;

  auto release = ops::XRTReleaseCompilationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
693 
// Allocates a buffer, clears all allocations via XRTReleaseAllAllocations,
// and verifies a subsequent read of the old handle fails with NOT_FOUND.
TEST(RawApiTest, AllocAndClearAll) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({handle}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  int64_t allocation_handle = outputs[0].scalar<int64_t>()();

  auto clear_all = ops::XRTReleaseAllAllocations(root);

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
  EXPECT_EQ(outputs.size(), 0);

  // The handle must be invalid after the clear-all.
  auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle));
  EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(),
            error::Code::NOT_FOUND);
}
722 
// Round-trips a two-element tuple literal through the device: allocate,
// read the literal back, then explicitly release the allocation handle.
TEST(RawApiTest, ReadAndWriteState) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto alloc_proto =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto alloc_handle = ops::XRTAllocate(root, alloc_proto);
  auto read_op = ops::XRTReadLiteral(root, alloc_handle);
  // The release must not run before the read completes; a control
  // dependency serializes the two ops.
  auto release_op = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_op), alloc_handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_op}, {release_op},
                           &outputs));

  // The literal read back from the device must match what was written.
  xla::LiteralProto read_response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &read_response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), read_response));
}
746 
// Same round-trip as ReadAndWriteState, but the read op itself releases
// the allocation (XRTReadLiteralAndRelease), so no explicit release is
// needed.
TEST(RawApiTest, ReadAndWriteStateAutoFree) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto alloc_proto =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto alloc_handle = ops::XRTAllocate(root, alloc_proto);
  auto read_op = ops::XRTReadLiteralAndRelease(root, alloc_handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_op}, &outputs));

  // The literal read back must match the one that was allocated.
  xla::LiteralProto read_response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &read_response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), read_response));
}
766 
// Allocates a nested tuple on the device, carves out sub-buffers with
// XRTSubTuple / XRTSubTupleAndRelease, and verifies each sub-buffer reads
// back as the corresponding piece of the original literal.
TEST(RawApiTest, SubBuffer) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = NestedTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto base_handle = ops::XRTAllocate(root, value);
  // Shape indices into the tuple: {0} and {1} select the two top-level
  // elements; {0, 0} selects the first element of element 0.
  auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
  auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
  auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
  auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
  auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
  // XRTSubTupleAndRelease frees the base allocation, so it must only run
  // after the two non-releasing XRTSubTuple ops; the control dependencies
  // enforce that ordering.
  auto sub_00 = ops::XRTSubTupleAndRelease(
      root.WithControlDependencies(
          {sub_0.output_handle.op(), sub_1.output_handle.op()}),
      base_handle, index_00);
  // Reading releases each sub-buffer handle as a side effect.
  auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
  auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
  auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));

  // Decompose the host-side literal the same way to get expected values.
  auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
  auto base_elements = base_literal.DecomposeTuple();
  auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
  xla::LiteralProto response_0;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
  xla::LiteralProto response_00;
  EXPECT_TRUE(ParseFromTString(outputs[2].scalar<tstring>()(), &response_00));
  EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
}
806 
// Exercises XRTMakeTuple: builds device tuples from existing allocations
// using XLATupleNode descriptors, including descriptors that release their
// input handles, and checks the resulting literals.
TEST(RawApiTest, MakeTuple) {
  xrt::XLAAllocation alloc_0;
  *alloc_0.mutable_value() = TwoElementTuple();
  xrt::XLAAllocation alloc_1;
  *alloc_1.mutable_value() = ScalarLiteral();

  // The trivial tuple that just forwards its input and releases it.
  xrt::XLATupleNode desc_0;
  desc_0.set_input_index(0);
  desc_0.set_release_input_handle(true);

  // desc_1 describes (input0, (input0, input1)); no inputs are released.
  xrt::XLATupleNode desc_1;
  auto subdesc_10 = desc_1.add_tuples();
  auto subdesc_11 = desc_1.add_tuples();
  subdesc_10->set_input_index(0);
  auto subdesc_110 = subdesc_11->add_tuples();
  subdesc_110->set_input_index(0);
  auto subdesc_111 = subdesc_11->add_tuples();
  subdesc_111->set_input_index(1);

  // desc_2 describes (input1, input0) and releases both inputs.
  xrt::XLATupleNode desc_2;
  auto subdesc_20 = desc_2.add_tuples();
  auto subdesc_21 = desc_2.add_tuples();
  subdesc_20->set_input_index(1);
  subdesc_20->set_release_input_handle(true);
  subdesc_21->set_input_index(0);
  subdesc_21->set_release_input_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
  auto handle_0 = ops::XRTAllocate(root, value_0);
  auto value_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
  auto handle_1 = ops::XRTAllocate(root, value_1);
  auto tuple_0 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
  auto handle_2 =
      ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
  // handle_0 has now been released.
  auto tuple_1 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
  auto handle_3 = ops::XRTMakeTuple(
      root, tuple_1,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  auto tuple_2 =
      ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
  // Make sure this runs after handle_3 has completed, since it will free
  // handle_1 and handle_2.
  auto handle_4 = ops::XRTMakeTuple(
      root.WithControlDependencies(handle_3), tuple_2,
      {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
  // handle_1 and handle_2 have now been released.

  // Read both tuples back, releasing their handles as a side effect.
  auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
  auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
  xla::LiteralProto response_0;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));

  // Expected literals are built host-side by file-local helpers.
  auto expected_0 = MakeTuple0();
  EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
  auto expected_1 = NestedTuple();
  EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
}
878 
// Runs a three-step chained computation "by hand": each step is its own
// XRTExecute op, with output handles of earlier steps fed directly as
// inputs to the final one.
TEST(RawApiTest, ExecuteChainedOpByOp) {
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  // Builds a serialized XLAComputation whose program shape takes two F32[2]
  // parameters and returns F32[2], wrapping the HLO produced by fn.
  auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
    xrt::XLAComputation c;
    auto config = c.mutable_config();
    auto shapes = config->mutable_program_shape();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
    return c.SerializeAsString();
  };

  auto c_add_scale = make_computation(AddAndScale);
  auto c_sub_scale = make_computation(SubAndScale);

  auto c_add_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
  auto c_sub_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  // Raw compilation handles, fed as constants into the execute ops below.
  int64_t c_add_scale_handle = outputs[0].scalar<int64_t>()();
  int64_t c_sub_scale_handle = outputs[1].scalar<int64_t>()();

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  auto p0_handle = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
  auto p1_handle = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));

  // Inputs and compilations are reused by multiple execute ops within this
  // step, so neither may be auto-released.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(false);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto result0 = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
                                 {Output(p0_handle), Output(p1_handle)});
  auto result1 = ops::XRTExecute(root, Input(c_sub_scale_handle), e_config,
                                 {Output(p0_handle), Output(p1_handle)});
  // Final step consumes the two intermediate results.
  auto result = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
                                {result0.output_handle, result1.output_handle});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  // Same expected value as the ExecuteChained test, which runs the same
  // three-step computation through XRTExecuteChained.
  auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
948 
// Runs the same three-step computation as ExecuteChainedOpByOp, but as a
// single XRTExecuteChained invocation driven by an XRTChainedExecutePlan.
TEST(RawApiTest, ExecuteChained) {
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());

  // Builds a serialized XLAComputation whose program shape takes two F32[2]
  // parameters and returns F32[2], wrapping the HLO produced by fn.
  auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
    xrt::XLAComputation c;
    auto config = c.mutable_config();
    auto shapes = config->mutable_program_shape();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->add_parameters() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
    StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
    return c.SerializeAsString();
  };

  auto c_add_scale = make_computation(AddAndScale);
  auto c_sub_scale = make_computation(SubAndScale);

  auto c_add_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
  auto c_sub_scale_op = ops::XRTCompile(
      root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(
      session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  int64_t c_add_scale_handle = outputs[0].scalar<int64_t>()();
  int64_t c_sub_scale_handle = outputs[1].scalar<int64_t>()();

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  auto p0_handle_op = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
  auto p1_handle_op = ops::XRTAllocate(
      root,
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));

  TF_EXPECT_OK(session.Run({p0_handle_op, p1_handle_op}, &outputs));
  EXPECT_EQ(outputs.size(), 2);

  // Raw allocation handles, referenced by the plan's data ops below.
  int64_t p0_handle = outputs[0].scalar<int64_t>()();
  int64_t p1_handle = outputs[1].scalar<int64_t>()();

  // Default-constructed config: all chained-execute options left at their
  // proto defaults.
  xrt::XRTChainedExecuteConfig config;
  auto config_const =
      ops::Const(root.WithDevice("/device:CPU:0"), config.SerializeAsString());

  // The plan is a DAG of ops; inputs refer to earlier ops by op_index, and
  // only ops with an outputs entry surface their result to the caller.
  xrt::XRTChainedExecutePlan plan;
  xrt::XRTChainedExecuteOp* op;
  xrt::XRTChainedExecuteOp::Input* input;
  xrt::XRTChainedExecuteOp::Output* output;

  // Index 0
  op = plan.add_ops();
  op->set_data_handle(p0_handle);

  // Index 1
  op = plan.add_ops();
  op->set_data_handle(p1_handle);

  // Index 2
  op = plan.add_ops();
  op->set_computation_handle(c_add_scale_handle);
  input = op->add_inputs();
  input->set_op_index(0);
  input = op->add_inputs();
  input->set_op_index(1);

  // Index 3
  op = plan.add_ops();
  op->set_computation_handle(c_sub_scale_handle);
  input = op->add_inputs();
  input->set_op_index(0);
  input = op->add_inputs();
  input->set_op_index(1);

  // Index 4
  op = plan.add_ops();
  op->set_computation_handle(c_add_scale_handle);
  input = op->add_inputs();
  input->set_op_index(2);
  input = op->add_inputs();
  input->set_op_index(3);
  // Export this op's result as the plan's result index 0.
  output = op->add_outputs();
  output->set_result_index(0);

  auto plan_const =
      ops::Const(root.WithDevice("/device:CPU:0"), plan.SerializeAsString());
  auto result = ops::XRTExecuteChained(root, plan_const, config_const);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({result}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  // XRTExecuteChained returns a vector of result handles, one per plan
  // output; here exactly one was requested.
  auto handles_vec = outputs[0].vec<int64_t>();
  EXPECT_EQ(handles_vec.size(), 1);

  auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(0)));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run({read_back}, &outputs));
  EXPECT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  // Same expected value as ExecuteChainedOpByOp, which runs the identical
  // computation one XRTExecute at a time.
  auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1068 
// Compiles the AddAndScale computation, executes it on two R1 inputs, and
// checks both the result literal and the reported program shape. The
// execution config auto-releases the inputs and the executable.
TEST(RawApiTest, CompileAndExecute) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  // Program shape: (F32[2], F32[2]) -> F32[2].
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  // The compile op also reports the executable's program shape.
  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
1121 
// Executes a computation whose R1 result has a dynamic dimension and
// verifies the dynamically-sized literal read back from the device.
TEST(RawApiTest, DynamicR1Test) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f});
  // Scalar S32 size argument consumed by the computation.
  xrt::XLAAllocation p2;
  *p2.mutable_value() = CreateR0<int32_t>(2);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
  // Result is F32[4] with dimension 0 marked dynamic.
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
  auto p2_handle = ops::XRTAllocate(root, p2_value);
  auto result = ops::XRTExecute(
      root, c_handle.handle, e_config,
      {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use the file-local ParseFromTString helper, consistent with the other
  // tests in this file; tstring is not guaranteed to be std::string, so
  // MessageLite::ParseFromString may not accept it on all builds.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1173 
TEST(RawApiTest,DynamicR2Test)1174 TEST(RawApiTest, DynamicR2Test) {
1175   xrt::XLAAllocation p0;
1176   *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
1177                                                     {1.5f, 2.5f, 3.0f, -2.0f}})
1178                             .ToProto();
1179   xrt::XLAAllocation p1;
1180   *p1.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, -1.0f, 2.5f, 1.17f},
1181                                                     {1.2f, -1.6f, 2.8f, 1.24f}})
1182                             .ToProto();
1183   xrt::XLAAllocation p2;
1184   *p2.mutable_value() = CreateR0<int32_t>(2);
1185 
1186   xrt::XLAComputation c;
1187   auto config = c.mutable_config();
1188   auto shapes = config->mutable_program_shape();
1189   *shapes->add_parameters() =
1190       xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
1191   *shapes->add_parameters() =
1192       xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
1193   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
1194   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
1195   dyn_shape.set_dynamic_dimension(0, true);
1196   dyn_shape.set_dynamic_dimension(1, true);
1197   *shapes->mutable_result() = dyn_shape.ToProto();
1198   StoreComputationSnapshot(ReturnDynamicR2(), c.mutable_hlo_snapshot());
1199 
1200   xrt::XRTExecutionConfig e;
1201   e.set_release_input_handles(true);
1202   e.set_release_compilation_handle(true);
1203 
1204   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1205   Scope cpu_root = root.WithDevice("/device:CPU:0");
1206   auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1207   auto computation = ops::Const(cpu_root, c.SerializeAsString());
1208   auto c_handle = ops::XRTCompile(root, computation);
1209   auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1210   auto p0_handle = ops::XRTAllocate(root, p0_value);
1211   auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1212   auto p1_handle = ops::XRTAllocate(root, p1_value);
1213   auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
1214   auto p2_handle = ops::XRTAllocate(root, p2_value);
1215   auto result = ops::XRTExecute(
1216       root, c_handle.handle, e_config,
1217       {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
1218   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1219   TF_ASSERT_OK(root.status());
1220 
1221   XrtClientSession session(root);
1222   std::vector<Tensor> outputs;
1223   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1224 
1225   xla::LiteralProto response;
1226   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1227   auto expected = xla::LiteralUtil::CreateR2<float>({{2.0f, 1.0f}, {2.7, 0.9}});
1228   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1229 }
1230 
TEST(RawApiTest,DynamicR1TupleTest)1231 TEST(RawApiTest, DynamicR1TupleTest) {
1232   xrt::XLAAllocation p0;
1233   *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
1234   xrt::XLAAllocation p1;
1235   *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f});
1236   xrt::XLAAllocation p2;
1237   *p2.mutable_value() = CreateR0<int32_t>(2);
1238 
1239   xrt::XLAComputation c;
1240   auto config = c.mutable_config();
1241   auto shapes = config->mutable_program_shape();
1242   *shapes->add_parameters() =
1243       xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1244   *shapes->add_parameters() =
1245       xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1246   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
1247   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1248   dyn_shape.set_dynamic_dimension(0, true);
1249   *shapes->mutable_result() =
1250       xla::ShapeUtil::MakeTupleShape(
1251           {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape})
1252           .ToProto();
1253   StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot());
1254 
1255   xrt::XRTExecutionConfig e;
1256   e.set_release_input_handles(true);
1257   e.set_release_compilation_handle(true);
1258 
1259   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1260   Scope cpu_root = root.WithDevice("/device:CPU:0");
1261   auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1262   auto computation = ops::Const(cpu_root, c.SerializeAsString());
1263   auto c_handle = ops::XRTCompile(root, computation);
1264   auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1265   auto p0_handle = ops::XRTAllocate(root, p0_value);
1266   auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1267   auto p1_handle = ops::XRTAllocate(root, p1_value);
1268   auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
1269   auto p2_handle = ops::XRTAllocate(root, p2_value);
1270   auto result = ops::XRTExecute(
1271       root, c_handle.handle, e_config,
1272       {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
1273   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1274   TF_ASSERT_OK(root.status());
1275 
1276   XrtClientSession session(root);
1277   std::vector<Tensor> outputs;
1278   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1279 
1280   xla::LiteralProto response;
1281   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1282 
1283   auto expected0 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
1284   auto expected1 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f, 0.0f});
1285   auto expected2 = xla::LiteralUtil::CreateR1<float>({0.0f, 3.0f, 1.0f});
1286   auto expected =
1287       xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
1288   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1289 }
1290 
TEST(RawApiTest,AcceptDynamicR1TupleTest)1291 TEST(RawApiTest, AcceptDynamicR1TupleTest) {
1292   if (*xla_test_device_ptr == "XLA_CPU" || *xla_test_device_ptr == "XLA_GPU") {
1293     // XLA_CPU and XLA_GPU has shape check set to kCompileTime.
1294     return;
1295   }
1296   xrt::XLAAllocation p0;
1297   *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
1298   xrt::XLAAllocation p1;
1299   *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});
1300 
1301   xrt::XLATupleNode tuple_desc;
1302   auto subdesc_10 = tuple_desc.add_tuples();
1303   auto subdesc_11 = tuple_desc.add_tuples();
1304   subdesc_10->set_input_index(0);
1305   subdesc_10->set_release_input_handle(true);
1306   subdesc_11->set_input_index(1);
1307   subdesc_11->set_release_input_handle(true);
1308 
1309   xrt::XLAComputation c;
1310   auto config = c.mutable_config();
1311   auto shapes = config->mutable_program_shape();
1312   xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1313   dyn_input_shape.set_dynamic_dimension(0, true);
1314   xla::Shape dyn_tuple_shape =
1315       xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape});
1316   *shapes->add_parameters() = dyn_tuple_shape.ToProto();
1317   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1318   dyn_shape.set_dynamic_dimension(0, true);
1319   *shapes->mutable_result() = dyn_shape.ToProto();
1320   StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot());
1321 
1322   xrt::XRTExecutionConfig e;
1323   e.set_release_input_handles(true);
1324   e.set_release_compilation_handle(true);
1325 
1326   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1327   Scope cpu_root = root.WithDevice("/device:CPU:0");
1328   auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1329   auto computation = ops::Const(cpu_root, c.SerializeAsString());
1330   auto c_handle = ops::XRTCompile(root, computation);
1331   auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1332   auto p0_handle = ops::XRTAllocate(root, p0_value);
1333   auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1334   auto p1_handle = ops::XRTAllocate(root, p1_value);
1335 
1336   auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"),
1337                             tuple_desc.SerializeAsString());
1338   auto t0_handle = ops::XRTMakeTuple(
1339       root, tuple_0,
1340       {static_cast<Output>(p0_handle), static_cast<Output>(p1_handle)});
1341   auto result = ops::XRTExecute(root, c_handle.handle, e_config,
1342                                 {static_cast<Output>(t0_handle)});
1343   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1344   TF_ASSERT_OK(root.status());
1345 
1346   XrtClientSession session(root);
1347   std::vector<Tensor> outputs;
1348   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1349 
1350   xla::LiteralProto response;
1351   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1352 
1353   auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
1354   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1355 }
1356 
// Feeds plain (smaller-than-bound) R1 allocations into a computation whose
// parameters are declared with a dynamic dimension.
TEST(RawApiTest, AcceptDynamicR1Test) {
  if (*xla_test_device_ptr == "XLA_CPU" || *xla_test_device_ptr == "XLA_GPU") {
    // XLA_CPU and XLA_GPU has shape check set to kCompileTime.
    return;
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  // Both parameters are F32[<=4] (dynamic dimension 0), as is the result.
  xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_input_shape.set_dynamic_dimension(0, true);
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto allocate_op_0 = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto allocate_op_1 = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(allocate_op_0), Output(allocate_op_1)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use the file-local ParseFromTString helper, consistent with the other
  // tests in this file; tstring is not guaranteed to be std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1407 
// Feeds an R2 allocation into a computation whose parameter and result
// share a shape with a dynamic minor dimension.
TEST(RawApiTest, AcceptDynamicR2Test) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() =
      xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}})
          .ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  // Compile time expects ascending layout.
  // Parameter and result are F32[2, <=4] (dynamic dimension 1).
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
  dyn_shape.set_dynamic_dimension(1, true);
  *shapes->add_parameters() = dyn_shape.ToProto();

  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR2(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  // Use the file-local ParseFromTString helper, consistent with the other
  // tests in this file; tstring is not guaranteed to be std::string.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR2<float>(
      {{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1452 
TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
  // Like the plain CompileAndExecute flow, but the two input allocation
  // handles are packed into a single tensor (via Stack) and handed to
  // XRTExecute as one argument.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation comp;
  auto shapes = comp.mutable_config()->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), comp.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig run_config;
  run_config.set_release_input_handles(true);
  run_config.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, run_config.SerializeAsString());
  auto computation = ops::Const(cpu_root, comp.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  // Bundle both handles into one tensor argument.
  auto packed_args =
      ops::Stack(cpu_root, {Output(p0_handle), Output(p1_handle)});
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  // The compiled program should still report two distinct parameters.
  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
1507 
TEST(RawApiTest,CompileWithXlaReturnShapes)1508 TEST(RawApiTest, CompileWithXlaReturnShapes) {
1509   xla::XlaBuilder builder("XrtXlaShapes");
1510   auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
1511   auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
1512   // Clear layouts to signal XLA we are ready to get whatever are coming out of
1513   // the compilation process.
1514   xla::LayoutUtil::ClearLayout(&input_shape);
1515   xla::LayoutUtil::ClearLayout(&kernel_shape);
1516   auto param_shape =
1517       xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
1518   auto param = xla::Parameter(&builder, 0, param_shape, "param");
1519   auto input = xla::GetTupleElement(param, 0);
1520   auto kernel = xla::GetTupleElement(param, 1);
1521   xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
1522   TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());
1523 
1524   auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
1525   // Clear the result shape layout to tell XLA we are accepting whatever are
1526   // coming out of the compilation process.
1527   xla::LayoutUtil::ClearLayout(&result_shape);
1528 
1529   xrt::XLAComputation c;
1530   auto config = c.mutable_config();
1531   auto shapes = config->mutable_program_shape();
1532   *shapes->add_parameters() = param_shape.ToProto();
1533   *shapes->mutable_result() = result_shape.ToProto();
1534   StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
1535 
1536   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1537   auto computation =
1538       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1539   auto c_handle = ops::XRTCompile(root, computation);
1540   auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
1541   TF_ASSERT_OK(root.status());
1542 
1543   XrtClientSession session(root);
1544   std::vector<Tensor> outputs;
1545   TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {c_handle.program_shape},
1546                            {release}, &outputs));
1547 
1548   xla::ProgramShapeProto program_shape_proto;
1549   EXPECT_TRUE(
1550       ParseFromTString(outputs[0].vec<tstring>()(0), &program_shape_proto));
1551   xla::ProgramShape program_shape(program_shape_proto);
1552   EXPECT_EQ(program_shape.parameters_size(), 1);
1553 
1554   VLOG(2) << "Param: "
1555           << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
1556   VLOG(2) << "Result: "
1557           << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
1558 
1559   xla::ProgramShape xla_program_shape =
1560       XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
1561   EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
1562       xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
1563       xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
1564           .layout()));
1565   EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
1566       xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
1567       xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
1568           .layout()));
1569   EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
1570       program_shape.result().layout(), xla_program_shape.result().layout()));
1571 }
1572 
TEST(RawApiTest, DotGeneralWithLayoutTest) {
  // Runs a 2x2 * 2x1 dot where inputs and result all use a column-major
  // ({0, 1}) layout.
  auto layout = xla::LayoutUtil::MakeLayout({0, 1});

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);

  xrt::XLAComputation comp;
  auto shapes = comp.mutable_config()->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  StoreComputationSnapshot(Dot(), comp.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig run_config;
  run_config.set_release_input_handles(true);
  run_config.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, run_config.SerializeAsString());
  auto computation = ops::Const(cpu_root, comp.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  // [[1,2],[3,4]] x [[8],[5]] == [[18],[44]].
  auto expected =
      xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);

  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1625 
TEST(RawApiTest, CompileAndExecuteZeroArg) {
  // Executes a zero-argument computation and expects the scalar 3.0.
  xrt::XLAComputation comp;
  auto shapes = comp.mutable_config()->mutable_program_shape();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
  StoreComputationSnapshot(OnePlusTwo(), comp.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig run_config;
  run_config.set_release_input_handles(true);
  run_config.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, run_config.SerializeAsString());
  auto computation = ops::Const(cpu_root, comp.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  // No inputs: pass an explicitly-typed empty Input list.
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                std::initializer_list<Input>({}));
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1658 
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
  // Runs a two-input computation whose result is a single-element tuple and
  // compares it against the expected tuple literal.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation comp;
  auto shapes = comp.mutable_config()->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), comp.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig run_config;
  run_config.set_release_input_handles(true);
  run_config.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, run_config.SerializeAsString());
  auto computation = ops::Const(cpu_root, comp.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
  auto expected = xla::LiteralUtil::MakeTuple({&sum});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1709 
TEST(RawApiTest,CompileAndExecuteReturnExplodedTuple)1710 TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) {
1711   xrt::XLAAllocation p0;
1712   *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(12.0f).ToProto();
1713 
1714   xrt::XLAAllocation p1;
1715   *p1.mutable_value() = xla::LiteralUtil::CreateR0<float>(3.0f).ToProto();
1716 
1717   xrt::XLAComputation c;
1718   auto config = c.mutable_config();
1719   auto shapes = config->mutable_program_shape();
1720   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
1721   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
1722   *shapes->mutable_result() =
1723       xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}),
1724                                       xla::ShapeUtil::MakeShape(xla::F32, {})})
1725           .ToProto();
1726   StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot());
1727 
1728   xrt::XRTExecutionConfig e;
1729   e.set_release_input_handles(true);
1730   e.set_release_compilation_handle(true);
1731   e.set_return_exploded_tuple(true);
1732 
1733   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1734   auto e_config =
1735       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
1736   auto computation =
1737       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1738   auto c_handle = ops::XRTCompile(root, computation);
1739   auto p0_value =
1740       ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
1741   auto p0_handle = ops::XRTAllocate(root, p0_value);
1742   auto p1_value =
1743       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
1744   auto p1_handle = ops::XRTAllocate(root, p1_value);
1745   auto result = ops::XRTExecute(root, c_handle.handle, e_config,
1746                                 {Output(p0_handle), Output(p1_handle)});
1747   TF_ASSERT_OK(root.status());
1748 
1749   XrtClientSession session(root);
1750   std::vector<Tensor> outputs;
1751   TF_EXPECT_OK(session.Run({result}, &outputs));
1752   EXPECT_EQ(outputs.size(), 1);
1753 
1754   auto handles_vec = outputs.front().vec<int64_t>();
1755   EXPECT_EQ(handles_vec.size(), 2);
1756 
1757   const float kResults[2] = {15.0f, 9.0f};
1758   for (int64_t i = 0; i < handles_vec.size(); ++i) {
1759     auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i)));
1760     std::vector<Tensor> voutputs;
1761     TF_EXPECT_OK(session.Run({read_back}, &voutputs));
1762     EXPECT_EQ(voutputs.size(), 1);
1763 
1764     xla::LiteralProto response;
1765     EXPECT_TRUE(ParseFromTString(voutputs[0].scalar<tstring>()(), &response));
1766 
1767     auto expected = xla::LiteralUtil::CreateR0<float>(kResults[i]);
1768     EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1769   }
1770 }
1771 
TEST(RawApiTest, LeakCompilationReference) {
  // Compiles a computation and deliberately never releases the compilation
  // handle; per the test name this exercises cleanup of leaked compilation
  // references when the session goes away.
  xrt::XLAComputation comp;
  auto shapes = comp.mutable_config()->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), comp.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto computation = ops::Const(cpu_root, comp.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs));
}
1795 
TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
  // Verifies XLA input/output buffer aliasing through XRT: the computation's
  // output tuple elements {0} and {1} are aliased onto the *swapped* input
  // tuple elements, so after execution the original input allocation holds
  // the (flipped) results.
  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  xla::Shape shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});
  xla::XlaBuilder builder("ReuseBuffer");
  auto param = xla::Parameter(&builder, 0, shape, "param");
  auto p0 = xla::GetTupleElement(param, 0);
  auto p1 = xla::GetTupleElement(param, 1);
  auto add = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  // Result tuple: (add, sub, p0, p1).
  xla::Tuple(&builder, {add, sub, p0, p1});

  // Flip the tuple literals in the input handle: output {1} (sub) reuses the
  // buffer of input {0}, and output {0} (add) reuses the buffer of input {1}.
  builder.SetUpAlias({1}, 0, {0});
  builder.SetUpAlias({0}, 0, {1});

  auto computation = builder.Build().ValueOrDie();

  auto literal0 = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  auto literal1 = xla::LiteralUtil::CreateR1<float>({5.0f, 9.0f});
  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});

  xrt::XLAAllocation param_alloc;
  *param_alloc.mutable_value() = literal.ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = shape.ToProto();
  *shapes->mutable_result() = return_shape.ToProto();
  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  // Keep the input allocation alive: the aliased results live in its buffers.
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto c_data =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, c_data);
  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                param_alloc.SerializeAsString());
  auto param_handle = ops::XRTAllocate(root, param_value);
  TF_ASSERT_OK(root.status());

  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({param_handle}, &outputs));

  // Materialize the input allocation handle so it can be fed back below.
  int64_t alloc_handle = outputs[0].scalar<int64_t>()();

  // Note that we release the result handle immediately, but since we aliased
  // the output buffers onto the input allocation ones (held in alloc_handle),
  // we can fetch the result from there.
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
  auto read_back = ops::XRTReadLiteral(root, result);
  // The control dependency guarantees the read happens before the release.
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
  auto exec_literal_parts = exec_literal.DecomposeTuple();
  ASSERT_EQ(exec_literal_parts.size(), 4);

  // Elements {2} and {3} pass the inputs through unchanged.
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));

  // Now we read back the original input handle values, which at this point
  // should contain the result of the XLA computation.
  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
  TF_ASSERT_OK(root.status());
  auto release_handle = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_handle), Input(alloc_handle));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
                           {release_handle}, &outputs));

  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);

  auto expected_literal0 = xla::LiteralUtil::CreateR1<float>({6.0f, 11.0f});
  auto expected_literal1 = xla::LiteralUtil::CreateR1<float>({-4.0f, -7.0f});
  // The first element of the computation returned tuple would be the add
  // (expected_literal0), but since we flipped the buffers, the sub
  // (expected_literal1) should come first.
  auto expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});

  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}
1894 
TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
  // S64 variant of CompileAndExecuteWithReusedBuffers: exercises the same
  // swapped input/output buffer aliasing with 64-bit integer elements.
  xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
  xla::Shape shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});
  xla::XlaBuilder builder("ReuseBuffer");
  auto param = xla::Parameter(&builder, 0, shape, "param");
  auto p0 = xla::GetTupleElement(param, 0);
  auto p1 = xla::GetTupleElement(param, 1);
  auto add = xla::Add(p0, p1);
  auto sub = xla::Sub(p0, p1);
  // Result tuple: (add, sub, p0, p1).
  xla::Tuple(&builder, {add, sub, p0, p1});

  // Flip the tuple literals in the input handle: output {1} (sub) reuses the
  // buffer of input {0}, and output {0} (add) reuses the buffer of input {1}.
  builder.SetUpAlias({1}, 0, {0});
  builder.SetUpAlias({0}, 0, {1});

  auto computation = builder.Build().ValueOrDie();

  auto literal0 = xla::LiteralUtil::CreateR1<int64_t>({1, 2});
  auto literal1 = xla::LiteralUtil::CreateR1<int64_t>({5, 9});
  auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});

  xrt::XLAAllocation param_alloc;
  *param_alloc.mutable_value() = literal.ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = shape.ToProto();
  *shapes->mutable_result() = return_shape.ToProto();
  StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  // Keep the input allocation alive: the aliased results live in its buffers.
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto c_data =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, c_data);
  auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
                                param_alloc.SerializeAsString());
  auto param_handle = ops::XRTAllocate(root, param_value);
  TF_ASSERT_OK(root.status());

  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({param_handle}, &outputs));

  // Materialize the input allocation handle so it can be fed back below.
  int64_t alloc_handle = outputs[0].scalar<int64_t>()();

  // Note that we release the result handle immediately, but since we aliased
  // the output buffers onto the input allocation ones (held in alloc_handle),
  // we can fetch the result from there.
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
  auto read_back = ops::XRTReadLiteral(root, result);
  // The control dependency guarantees the read happens before the release.
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), result);
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));

  xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
  auto exec_literal_parts = exec_literal.DecomposeTuple();
  ASSERT_EQ(exec_literal_parts.size(), 4);

  // Elements {2} and {3} pass the inputs through unchanged.
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
  EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));

  // Now we read back the original input handle values, which at this point
  // should contain the result of the XLA computation.
  auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
  TF_ASSERT_OK(root.status());
  auto release_handle = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_handle), Input(alloc_handle));
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
                           {release_handle}, &outputs));

  xla::Literal return_literal = ReadOutputLiteral(outputs, 0);

  auto expected_literal0 = xla::LiteralUtil::CreateR1<int64_t>({6, 11});
  auto expected_literal1 = xla::LiteralUtil::CreateR1<int64_t>({-4, -7});
  // The first element of the computation returned tuple would be the add
  // (expected_literal0), but since we flipped the buffers, the sub
  // (expected_literal1) should come first.
  auto expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});

  EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
}
1993 
TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
  // Adds two S64 scalars and checks both the computed value and that the
  // program shape reported by XRTCompile has an S64 result.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64_t>(11031965).ToProto();
  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<int64_t>(4091934).ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  // NOTE: return_exploded_tuple used to be set here (copy/paste from the
  // exploded-tuple test), but the result is a scalar, not a tuple, so the
  // flag was meaningless and has been removed.

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  // 11031965 + 4091934 == 15123899.
  auto expected = xla::LiteralUtil::CreateR0<int64_t>(15123899);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
  EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
      xla::Shape(program_shape.result()), xla::S64));
}
2046 
2047 // Tests the XRT device memory compaction API (XRTCompactAllocations).
TEST(RawApiTest,TestDeviceMemoryCompaction)2048 TEST(RawApiTest, TestDeviceMemoryCompaction) {
2049   static const int kNumAllocs = 32;
2050   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
2051 
2052   std::vector<xrt::XLAAllocation> allocs(kNumAllocs);
2053   std::vector<Output> handle_outputs;
2054   for (int i = 0; i < kNumAllocs; ++i) {
2055     *allocs[i].mutable_value() = BasedTwoElementTuple(i * 4.0f);
2056     auto value = ops::Const(root.WithDevice("/device:CPU:0"),
2057                             allocs[i].SerializeAsString());
2058     handle_outputs.push_back(ops::XRTAllocate(root, value));
2059   }
2060   TF_ASSERT_OK(root.status());
2061 
2062   XrtClientSession session(root);
2063   std::vector<Tensor> outputs;
2064   TF_EXPECT_OK(session.Run(handle_outputs, &outputs));
2065   EXPECT_EQ(outputs.size(), handle_outputs.size());
2066 
2067   std::vector<int64_t> handles;
2068   for (auto& output : outputs) {
2069     handles.push_back(output.scalar<int64_t>()());
2070   }
2071   // Create holes by releasing even allocations.
2072   std::vector<Operation> handle_releases;
2073   for (size_t i = 0; i < handles.size(); i += 2) {
2074     handle_releases.push_back(
2075         ops::XRTReleaseAllocationHandle(root, Input(handles[i])));
2076   }
2077   TF_ASSERT_OK(root.status());
2078 
2079   TF_EXPECT_OK(
2080       session.Run(ClientSession::FeedType(), {}, handle_releases, &outputs));
2081 
2082   // Run the compaction API.
2083   auto compact_op = ops::XRTCompactAllocations(root);
2084   TF_EXPECT_OK(
2085       session.Run(ClientSession::FeedType(), {}, {compact_op}, &outputs));
2086 
2087   // Read back the allocation left at odd indices.
2088   std::vector<Output> read_outputs;
2089   for (size_t i = 1; i < handles.size(); i += 2) {
2090     read_outputs.push_back(ops::XRTReadLiteral(root, Input(handles[i])));
2091   }
2092   TF_ASSERT_OK(root.status());
2093 
2094   TF_EXPECT_OK(session.Run(read_outputs, &outputs));
2095   EXPECT_EQ(outputs.size(), read_outputs.size());
2096 
2097   // Verify that everything got moved correctly and the device data matches what
2098   // we have on record.
2099   for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) {
2100     xla::LiteralProto response;
2101     EXPECT_TRUE(ParseFromTString(outputs[j].scalar<tstring>()(), &response));
2102     EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response));
2103   }
2104 }
2105 
// Exercises the XRT device-memory swap-out/swap-in path: allocates more
// tensor data than the assumed device memory budget, then touches the
// accumulated handles again so swapped-out buffers must be restored and
// their contents verified on-device.
TEST(RawApiTest, TestDeviceMemorySwap) {
  const xla::Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
  // 100MB F32 tensor (5000 * 5000 * 4 bytes).
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {5000, 5000});
  const int64_t tensor_size = xla::ShapeUtil::ByteSizeOf(shape);
  // On CPU we cannot trigger OOM/swap. For TPU and GPU we select 16GB as
  // maximum memory; all other devices assume an 8GB budget.
  int64_t device_memory_size = 8LL * 1024 * 1024 * 1024;
  if (*xla_test_device_ptr == "TPU" || *xla_test_device_ptr == "XLA_GPU") {
    device_memory_size = 16LL * 1024 * 1024 * 1024;
  }

  // Scalar input value which every big tensor is broadcast from.
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(0.90434).ToProto();

  // Create a computation which broadcasts a scalar to a big tensor.
  xrt::XLAComputation c_bcast;
  {
    auto shapes = c_bcast.mutable_config()->mutable_program_shape();
    *shapes->add_parameters() = scalar_shape.ToProto();
    *shapes->mutable_result() = shape.ToProto();
    StoreComputationSnapshot(
        BroadcastComputation(scalar_shape, shape.dimensions()),
        c_bcast.mutable_hlo_snapshot());
  }

  // Create a computation which compares two tensors, returning an S32
  // scalar (expected to be 0 when the tensors are equal — see below).
  xrt::XLAComputation c_equal;
  {
    auto shapes = c_equal.mutable_config()->mutable_program_shape();
    *shapes->add_parameters() = shape.ToProto();
    *shapes->add_parameters() = shape.ToProto();
    *shapes->mutable_result() =
        xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
    StoreComputationSnapshot(IsEqualComputation(shape),
                             c_equal.mutable_hlo_snapshot());
  }

  // Keep input and compiled handles alive across the many Run() calls below
  // so the allocations accumulate (forcing swap) instead of being released.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(false);
  e.set_release_compilation_handle(false);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  XrtClientSession session(root);
  // Serialized-proto constants are always placed on the CPU device.
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto bcast_computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c_bcast.SerializeAsString());
  auto c_bcast_handle = ops::XRTCompile(root, bcast_computation);
  auto equal_computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c_equal.SerializeAsString());
  auto c_equal_handle = ops::XRTCompile(root, equal_computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  std::vector<Tensor> outputs;
  std::vector<int64_t> device_handles;

  // Create more data the device can take using the broadcast computation.
  int64_t num_tensors = 8 + device_memory_size / tensor_size;
  for (int64_t i = 0; i < num_tensors; ++i) {
    auto result = ops::XRTExecute(root, c_bcast_handle.handle, e_config,
                                  {Output(p0_handle)});
    TF_ASSERT_OK(root.status());
    TF_ASSERT_OK(session.Run({result}, &outputs));
    EXPECT_EQ(outputs.size(), 1);
    device_handles.push_back(outputs[0].scalar<int64_t>()());
  }

  // Trigger computations on XRT handles to verify the swap-out/swap-in logic,
  // by comparing sequential couple of tensors.
  auto zero_literal = xla::LiteralUtil::CreateR0<int32_t>(0);
  for (size_t i = 0; i + 1 < device_handles.size(); ++i) {
    auto exec_op = ops::XRTExecute(
        root, c_equal_handle.handle, e_config,
        {Input(device_handles[i]), Input(device_handles[i + 1])});
    auto read_back = ops::XRTReadLiteral(root, exec_op);

    TF_ASSERT_OK(root.status());
    TF_ASSERT_OK(session.Run({read_back}, &outputs));
    EXPECT_EQ(outputs.size(), 1);

    // All tensors were broadcast from the same scalar, so the comparison
    // result must equal the zero literal (tensors equal).
    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
    auto literal = xla::Literal::CreateFromProto(response).ValueOrDie();
    EXPECT_EQ(literal, zero_literal);
  }
}
2194 
// Fetches XRT metrics matching a regex via XRTMetricsCollect and verifies
// every returned metric name lives under the /tensorflow/xrt/ prefix.
TEST(RawApiTest, TestMetricsFetch) {
  xrt::XRTMetricsCollect metrics;
  metrics.add_metrics_regex("/tensorflow/xrt/.*");

  Scope root = Scope::NewRootScope().WithDevice("/device:CPU:0");
  auto serialized_request = ops::Const(root, metrics.SerializeAsString());
  Output collect_op = ops::XRTMetricsCollect(root, serialized_request);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> run_outputs;
  TF_EXPECT_OK(session.Run({collect_op}, &run_outputs));
  ASSERT_EQ(run_outputs.size(), 1);

  // The op returns a serialized MetricsReport proto; parse and check the
  // name prefix of each reported metric.
  xrt::MetricsReport report;
  EXPECT_TRUE(ParseFromTString(run_outputs[0].scalar<tstring>()(), &report));
  for (auto& metric : report.metrics()) {
    const bool has_prefix =
        metric.name().compare(0, 16, "/tensorflow/xrt/") == 0;
    EXPECT_TRUE(has_prefix) << metric.name();
  }
}
2215 
// Queries the device memory statistics through XRTMemoryInfo and checks
// that both the total and free byte counts are positive.
TEST(RawApiTest, TestMemoryInfo) {
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Output info_op = ops::XRTMemoryInfo(root);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(session.Run({info_op}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  // The op returns a serialized MemoryInfo proto.
  xrt::MemoryInfo mem_info;
  EXPECT_TRUE(ParseFromTString(fetched[0].scalar<tstring>()(), &mem_info));
  EXPECT_GT(mem_info.kb_total(), 0);
  EXPECT_GT(mem_info.kb_free(), 0);
}
2231 
2232 }  // namespace
2233 
2234 }  // namespace tensorflow
2235 
main(int argc,char ** argv)2236 int main(int argc, char** argv) {
2237   tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
2238   tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
2239   std::vector<tensorflow::Flag> flag_list = {
2240       tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
2241                        "Tensorflow device type to use for test, e.g., XLA_CPU"),
2242       tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
2243                        "The XLA platform to select for the device"),
2244   };
2245   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
2246   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
2247   if (!parse_result) {
2248     LOG(ERROR) << "\n" << usage;
2249     return 2;
2250   }
2251   testing::InitGoogleTest(&argc, argv);
2252   if (argc > 1) {
2253     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
2254     return 2;
2255   }
2256   return RUN_ALL_TESTS();
2257 }
2258