1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/aot/codegen.h"
17
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21
22 #include "absl/strings/match.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/Support/TargetSelect.h"
25 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/io/path.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/resource_loader.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace tensorflow {
37 namespace tfcompile {
38 namespace {
39
40 using ::xla::cpu_function_runtime::BufferInfo;
41
ExpectErrorContains(const Status & status,absl::string_view str)42 void ExpectErrorContains(const Status& status, absl::string_view str) {
43 EXPECT_NE(OkStatus(), status);
44 EXPECT_TRUE(absl::StrContains(status.error_message(), str))
45 << "expected error: " << status.error_message() << " to contain: " << str;
46 }
47
// Exercises ValidateCppIdent on accepted identifiers (letters, digits after
// the first char, and underscores) and on each rejection category: empty
// input, bad leading character, and bad non-leading character.
TEST(ValidateCppIdent, Simple) {
  TF_EXPECT_OK(ValidateCppIdent("a", ""));
  TF_EXPECT_OK(ValidateCppIdent("abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
  // Make sure we didn't skip a valid letter or digit
  string ident;
  for (char c = 'a'; c <= 'z'; c++) {
    ident.append(1, c);
  }
  for (char c = 'A'; c <= 'Z'; c++) {
    ident.append(1, c);
  }
  for (char c = '0'; c <= '9'; c++) {
    ident.append(1, c);
  }
  ident += "_";
  TF_EXPECT_OK(ValidateCppIdent(ident, ""));

  ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
  ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
  // The "a:" case used to be duplicated here; cover distinct illegal
  // non-leading characters instead (trailing space, interior punctuation).
  ExpectErrorContains(ValidateCppIdent("a ", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("a-b", ""), "illegal char");
}
76
77 class ParseCppClassTest : public ::testing::Test {
78 protected:
ExpectOK(const string & cpp_class,const string & want_class_name,const std::vector<string> & want_namespaces)79 void ExpectOK(const string& cpp_class, const string& want_class_name,
80 const std::vector<string>& want_namespaces) {
81 string class_name;
82 std::vector<string> namespaces;
83 TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
84 EXPECT_EQ(class_name, want_class_name);
85 EXPECT_EQ(namespaces, want_namespaces);
86 }
87
ExpectFail(const string & cpp_class)88 void ExpectFail(const string& cpp_class) {
89 string class_name;
90 std::vector<string> namespaces;
91 EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), OkStatus())
92 << cpp_class;
93 }
94 };
95
// Valid class names, both unqualified and namespace-qualified (with and
// without a leading "::"), must parse into the expected pieces.
TEST_F(ParseCppClassTest, ParseOK) {
  // Unqualified.
  ExpectOK("MyClass", "MyClass", {});
  ExpectOK("_MyClass", "_MyClass", {});
  // Qualified, increasing nesting depth.
  ExpectOK("a::MyClass", "MyClass", {"a"});
  ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"});
  ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"});
  ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"});
  ExpectOK("foo::MyClass", "MyClass", {"foo"});
  ExpectOK("_foo::MyClass", "MyClass", {"_foo"});
  ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"});
  // A leading "::" is accepted and contributes no namespace entry.
  ExpectOK("::foo::bar::MyClass", "MyClass", {"foo", "bar"});
  ExpectOK("::_foo::MyClass", "MyClass", {"_foo"});
  ExpectOK("::_foo::_MyClass", "_MyClass", {"_foo"});

  // Build one identifier containing every legal identifier character
  // (lowercase, uppercase, digits, underscore) so none is wrongly rejected.
  string every_char;
  const auto append_span = [&every_char](char first, char last) {
    for (char ch = first; ch <= last; ++ch) {
      every_char.push_back(ch);
    }
  };
  append_span('a', 'z');
  append_span('A', 'Z');
  append_span('0', '9');
  every_char.push_back('_');

  const string two_level = every_char + "::" + every_char;
  const string three_level = two_level + "::" + every_char;
  ExpectOK(every_char, every_char, {});
  ExpectOK(two_level, every_char, {every_char});
  ExpectOK(three_level, every_char, {every_char, every_char});
}
125
// Malformed class names (empty parts, stray or single colons, trailing
// "::", illegal characters, digit-leading components) must be rejected.
TEST_F(ParseCppClassTest, ParseFail) {
  const char* const kMalformedNames[] = {
      "",
      "::",
      "0",
      "a.b",
      "a:b",
      ":foo::bar",
      "good::.bad",
      "good:::bad",
      "good::bad::",
      "good::::bad",
      "::::bad",
      "good:: bad",
      "good::0bad",
  };
  for (const char* malformed : kMalformedNames) {
    ExpectFail(malformed);
  }
}
141
CompareWithGoldenFile(const string & tensorflow_relative_golden_file_name,const string & expected_contents,bool ignore_cr)142 static void CompareWithGoldenFile(
143 const string& tensorflow_relative_golden_file_name,
144 const string& expected_contents, bool ignore_cr) {
145 // Get rid of all CR characters, we may be running under windows.
146 string sanitized_expected_contents(expected_contents);
147 if (ignore_cr) {
148 sanitized_expected_contents.erase(
149 std::remove(sanitized_expected_contents.begin(),
150 sanitized_expected_contents.end(), '\r'),
151 sanitized_expected_contents.end());
152 }
153
154 // To update the golden file, flip update_golden to true and run the
155 // following:
156 // bazel test --test_strategy=local \
157 // "third_party/tensorflow/compiler/aot:codegen_test"
158 const bool update_golden = false;
159 string golden_file_name =
160 GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
161
162 if (update_golden) {
163 TF_EXPECT_OK(
164 WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
165 }
166
167 string golden_file_contents;
168 TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
169 &golden_file_contents));
170 if (ignore_cr) {
171 golden_file_contents.erase(std::remove(golden_file_contents.begin(),
172 golden_file_contents.end(), '\r'),
173 golden_file_contents.end());
174 }
175 EXPECT_EQ(golden_file_contents, expected_contents);
176 }
177
// End-to-end golden test: builds a hand-written tf2xla::Config and
// CompileResult (no real XLA compilation), then runs GenerateMetadata and
// GenerateHeader and compares the object-file bytes and the generated header
// against checked-in golden files. NOTE(review): the generated output very
// likely depends on the exact order and contents of the feeds, fetches,
// variables, buffer infos and shapes below, so any change here probably
// requires regenerating both golden files (see CompareWithGoldenFile).
TEST(CodegenTest, Golden) {
  // Normally CpuCompiler::CpuCompiler does this, but in this test we've
  // bypassed the Cpu compiler so we have to do this manually.
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();

  // Generated class is foo::bar::MyClass for an x86-64 Linux target, with the
  // optional name-to-index and program-shape outputs enabled.
  CodegenOpts opts;
  opts.class_name = "MyClass";
  opts.target_triple = "x86_64-pc-linux";
  opts.namespaces = {"foo", "bar"};
  opts.gen_name_to_index = true;
  opts.gen_program_shape = true;
  tf2xla::Config config;
  // Two feeds: "feed0" with an explicit name, "feed1" without one.
  tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("feed0");
  feed->set_name("myfeed");
  feed = config.add_feed();
  feed->mutable_id()->set_node_name("feed1");
  // One named fetch.
  tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("fetch0");
  fetch->set_name("myfetch");
  // Three variables: a read-only float[1], a writable float[1], and a
  // writable int32[5] whose node name contains '/' and which is given an
  // explicit name.
  tf2xla::Variable* variable = config.add_variable();
  variable->set_node_name("myvar_readonly");
  variable->mutable_shape()->add_dim()->set_size(1);
  variable->set_type(DT_FLOAT);
  variable->set_readonly(true);
  tf2xla::Variable* variable2 = config.add_variable();
  variable2->set_node_name("myvar");
  variable2->mutable_shape()->add_dim()->set_size(1);
  variable2->set_type(DT_FLOAT);
  tf2xla::Variable* variable3 = config.add_variable();
  variable3->set_node_name("my/var");
  variable3->set_name("myvar2");
  variable3->mutable_shape()->add_dim()->set_size(5);
  variable3->set_type(DT_INT32);
  // Fake AOT result: temp buffers interleaved with the five entry parameters,
  // plus a trailing 120-byte temp buffer. NOTE(review): the `11` argument
  // presumably selects buffer index 11 (that TempBuffer(120), matching the
  // 5*6*4-byte U32[5,6] result element) as the result buffer — confirm
  // against CpuAotCompilationResult.
  CompileResult compile_result;
  compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
      {},
      {BufferInfo::MakeTempBuffer(1),
       BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
       BufferInfo::MakeTempBuffer(1),
       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
       BufferInfo::MakeTempBuffer(1),
       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
       BufferInfo::MakeTempBuffer(1),
       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
       BufferInfo::MakeTempBuffer(1),
       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
       BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
      11, {}));
  // Five parameters and a three-element tuple result. NOTE(review): params
  // 2-4 and result elements 1-2 line up (by dtype/shape) with the variables
  // declared above — presumably intentional; confirm before editing.
  compile_result.program_shape =
      xla::ShapeUtil::MakeProgramShape(
          {
              xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
              xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
              xla::ShapeUtil::MakeShape(xla::F32, {1}),
              xla::ShapeUtil::MakeShape(xla::F32, {1}),
              xla::ShapeUtil::MakeShape(xla::S32, {5}),
          },
          xla::ShapeUtil::MakeTupleShape({
              xla::ShapeUtil::MakeShape(xla::U32, {5, 6}),
              xla::ShapeUtil::MakeShape(xla::F32, {1}),
              xla::ShapeUtil::MakeShape(xla::S32, {5}),
          }))
          .ToProto();
  compile_result.entry_point = "entry_point";
  compile_result.pointer_size = 8;

  MetadataResult metadata_result;
  TF_ASSERT_OK(GenerateMetadata(opts, compile_result, &metadata_result));

  // The other fields in metadata_result are tested as part of the generated
  // header test.

  // This specific golden test checks a binary file. It can potentially run into
  // issues due to ABIs not being stable, but has not so far.
  // If we see any ABI issues, we should reconsider this specific test case.
  CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
                        metadata_result.object_file_data, /*ignore_cr=*/false);

  string header;
  TF_ASSERT_OK(
      GenerateHeader(opts, config, compile_result, metadata_result, &header));

  // The header is text, so line-ending differences are ignored.
  CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
                        /*ignore_cr=*/true);
}
267 } // namespace
268 } // namespace tfcompile
269 } // namespace tensorflow
270