• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/test_helpers.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/tests/test_utils.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/platform/test.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace xla {
36 namespace {
37 
38 class ClientTest : public ClientLibraryTestBase {};
39 
// Runs a constant 2x2 S32 addition for every pairing of requested output
// layout and transfer layout, and checks that the transferred literal carries
// the transfer layout and the expected element values.
XLA_TEST_F(ClientTest, ExecuteWithLayout) {
  XlaBuilder b(TestName());

  // Both minor-to-major orders of a rank-2 array.
  const std::vector<std::vector<int64>> all_layouts = {{0, 1}, {1, 0}};
  for (const auto& output_minor_to_major : all_layouts) {
    for (const auto& transfer_minor_to_major : all_layouts) {
      // The same builder is reused on every pass: add the sum, then build it.
      Add(ConstantR2<int32>(&b, {{1, 2}, {3, 4}}),
          ConstantR2<int32>(&b, {{10, 20}, {30, 40}}));
      TF_ASSERT_OK_AND_ASSIGN(XlaComputation sum, b.Build());

      // Request the execution result in the "execute" layout.
      const Shape requested_shape = ShapeUtil::MakeShapeWithLayout(
          S32, /*dimensions=*/{2, 2}, output_minor_to_major);
      ExecutionOptions options = execution_options_;
      *options.mutable_shape_with_output_layout() = requested_shape.ToProto();
      TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> handle,
                              client_->Execute(sum, {}, &options));

      // Expected values, laid out the way the transfer is asked to deliver
      // them.
      const Literal want = LiteralUtil::CreateR2WithLayout<int32>(
          {{11, 22}, {33, 44}},
          LayoutUtil::MakeLayout(transfer_minor_to_major));
      const Shape& transfer_shape = want.shape();

      TF_ASSERT_OK_AND_ASSIGN(Literal got,
                              client_->Transfer(*handle, &transfer_shape));

      // Layout mismatch is fatal for this iteration; values are checked after.
      ASSERT_TRUE(
          LiteralTestUtil::EqualShapesAndLayouts(transfer_shape, got.shape()));
      EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
    }
  }
}
71 
// Executes a computation that returns a tuple of two constant 2x2 S32 arrays,
// requesting a different layout for each tuple element, and checks both the
// element layouts and the element values of the transferred result.
XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
  XlaBuilder b(TestName());

  Tuple(&b, {ConstantR2<int32>(&b, {{1, 2}, {3, 4}}),
             ConstantR2<int32>(&b, {{10, 20}, {30, 40}})});

  TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());

  // Expected per-element shapes, defined once so the request and the
  // verification below cannot drift apart.
  const Shape column_major_shape =
      ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
                                     /*minor_to_major=*/{0, 1});
  const Shape row_major_shape =
      ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
                                     /*minor_to_major=*/{1, 0});

  ExecutionOptions execution_options = execution_options_;
  // Create a result shape with one element column major and the other row
  // major.
  *execution_options.mutable_shape_with_output_layout() =
      ShapeUtil::MakeTupleShape({column_major_shape, row_major_shape})
          .ToProto();

  TF_ASSERT_OK_AND_ASSIGN(
      auto result,
      client_->ExecuteAndTransfer(computation, {}, &execution_options));

  // Validate the tuple structure before indexing into it. These are fatal
  // assertions because the element accesses below presuppose a two-element
  // tuple.
  ASSERT_TRUE(result.shape().IsTuple());
  ASSERT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));

  // Each element must come back with the layout that was requested for it.
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::GetTupleElementShape(result.shape(), 0), column_major_shape));
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::GetTupleElementShape(result.shape(), 1), row_major_shape));

  // The element values must match the constants the computation was built
  // from.
  LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
                                        LiteralSlice(result, {0}));
  LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
                                        LiteralSlice(result, {1}));
}
111 
112 // Disabled for interpreter since ExecuteAsyncOnStream is not implemented on
113 // interpreter backend.
XLA_TEST_F(ClientTest,DISABLED_ON_INTERPRETER (DISABLED_ON_GPU (ExecuteParallel)))114 XLA_TEST_F(ClientTest,
115            DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) {
116   XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
117   Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
118 
119   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
120                           client_->TransferToServer(
121                               LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
122 
123   XlaBuilder b(TestName() + ".add");
124   Add(Parameter(&b, 0, shape, "param_0"),
125       ConstantR2<int32>(&b, {{1, 2}, {3, 4}}));
126   TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build());
127 
128   // We can't really test parallel execution on CPU since all of the cores in a
129   // CPU are presented as a single device.  So for now we test "parallel"
130   // execution on a single device.
131   std::vector<Client::XlaComputationInstance> computation_instances;
132   TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
133                           client_->GetDeviceHandles(1));
134   ASSERT_EQ(devices.size(), 1);
135 
136   ExecutionOptions options = execution_options_;
137   *options.add_device_handles() = devices[0];
138   computation_instances.push_back(Client::XlaComputationInstance(
139       add_with_one_arg, {const_arg.get()}, options, nullptr));
140 
141   TF_ASSERT_OK_AND_ASSIGN(auto results,
142                           client_->ExecuteParallel(computation_instances));
143   auto expected_result = LiteralUtil::CreateR2<int32>({{6, 8}, {10, 12}});
144 
145   TF_ASSERT_OK_AND_ASSIGN(
146       auto result_literal,
147       client_->Transfer(*results[0], &expected_result.shape()));
148 
149   EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
150 }
151 
152 }  // namespace
153 }  // namespace xla
154