• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests that we call into Eigen for dot operations as needed.
17 
18 #include <algorithm>
19 #include <string>
20 
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
23 #include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h"
24 #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/tests/test_utils.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace xla {
30 namespace cpu {
31 namespace {
32 
33 struct DotTestSpec {
34   PrimitiveType primitive_type;
35   std::string filecheck_lines;
36 };
37 
DotTestSpecToString(const::testing::TestParamInfo<DotTestSpec> & info)38 std::string DotTestSpecToString(
39     const ::testing::TestParamInfo<DotTestSpec>& info) {
40   return PrimitiveType_Name(info.param.primitive_type);
41 }
42 
43 class CpuEigenDotOperationTest
44     : public CpuCodegenTest,
45       public ::testing::WithParamInterface<DotTestSpec> {
46  protected:
CompileAndCheck(std::unique_ptr<HloComputation> entry_computation,const std::string & filecheck_lines)47   void CompileAndCheck(std::unique_ptr<HloComputation> entry_computation,
48                        const std::string& filecheck_lines) {
49     CpuAotCompilationOptions options{
50         /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost,
51         /*features=*/"",
52         /*entry_point_name=*/"entry",
53         /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
54 
55     auto hlo_module = CreateNewVerifiedModule();
56     hlo_module->AddEntryComputation(std::move(entry_computation));
57 
58     CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options,
59                                   filecheck_lines,
60                                   /*match_optimized_ir=*/true);
61   }
62 };
63 
// Builds a canonical dot of two 128x128 parameters of the parameterized
// element type and checks the compiled IR against the spec's FileCheck lines.
TEST_P(CpuEigenDotOperationTest, SimpleDotOp) {
  const DotTestSpec spec = GetParam();
  const auto matrix_shape =
      ShapeUtil::MakeShape(spec.primitive_type, {128, 128});

  HloComputation::Builder builder(TestName());

  HloInstruction* const lhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "input"));
  HloInstruction* const rhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "input"));

  // The dot is the last instruction added to the builder.
  builder.AddInstruction(
      CreateCanonicalDot(matrix_shape, lhs_param, rhs_param));

  CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
78 
// Same as the simple dot test, except the left operand is first passed
// through a {1, 0} transpose before being fed to the canonical dot.
TEST_P(CpuEigenDotOperationTest, DotTransposeOp) {
  const DotTestSpec spec = GetParam();
  const auto matrix_shape =
      ShapeUtil::MakeShape(spec.primitive_type, {128, 128});

  HloComputation::Builder builder(TestName());

  HloInstruction* const lhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "input"));
  HloInstruction* const rhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "input"));

  // Swap the two dimensions of the left operand; the shape stays 128x128.
  HloInstruction* const lhs_swapped = builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, lhs_param, {1, 0}));

  builder.AddInstruction(
      CreateCanonicalDot(matrix_shape, lhs_swapped, rhs_param));

  CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
95 
GetDotTestCases()96 std::vector<DotTestSpec> GetDotTestCases() {
97   std::vector<DotTestSpec> result;
98   // The fp16 test runs a 32-bit matmul because we promote fp16 gemms to fp32
99   // (they run much faster).
100   result.push_back(
101       {F16, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF32)"});
102   result.push_back(
103       {F32, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF32)"});
104   result.push_back(
105       {F64, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF64)"});
106   return result;
107 }
108 
109 INSTANTIATE_TEST_SUITE_P(CpuEigenDotOperationTestInstantiation,
110                          CpuEigenDotOperationTest,
111                          ::testing::ValuesIn(GetDotTestCases()),
112                          DotTestSpecToString);
113 
114 }  // namespace
115 }  // namespace cpu
116 }  // namespace xla
117