• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/aot/codegen.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/match.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/Support/TargetSelect.h"
25 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/io/path.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/resource_loader.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace tensorflow {
37 namespace tfcompile {
38 namespace {
39 
40 using ::xla::cpu_function_runtime::BufferInfo;
41 
ExpectErrorContains(const Status & status,absl::string_view str)42 void ExpectErrorContains(const Status& status, absl::string_view str) {
43   EXPECT_NE(OkStatus(), status);
44   EXPECT_TRUE(absl::StrContains(status.error_message(), str))
45       << "expected error: " << status.error_message() << " to contain: " << str;
46 }
47 
// Checks that ValidateCppIdent accepts well-formed C++ identifiers and rejects
// empty identifiers, identifiers with an illegal first character, and
// identifiers with an illegal character after the first position.
TEST(ValidateCppIdent, Simple) {
  TF_EXPECT_OK(ValidateCppIdent("a", ""));
  TF_EXPECT_OK(ValidateCppIdent("abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
  TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
  // Make sure we didn't skip a valid letter or digit
  string ident;
  for (char c = 'a'; c <= 'z'; c++) {
    ident.append(1, c);
  }
  for (char c = 'A'; c <= 'Z'; c++) {
    ident.append(1, c);
  }
  for (char c = '0'; c <= '9'; c++) {
    ident.append(1, c);
  }
  ident += "_";
  TF_EXPECT_OK(ValidateCppIdent(ident, ""));

  ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
  ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
  ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
  // Each illegal-char case below uses a distinct character or position, so
  // every assertion adds coverage rather than repeating an earlier one.
  ExpectErrorContains(ValidateCppIdent("a ", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("a-", ""), "illegal char");
  ExpectErrorContains(ValidateCppIdent("ab.c", ""), "illegal char");
}
76 
77 class ParseCppClassTest : public ::testing::Test {
78  protected:
ExpectOK(const string & cpp_class,const string & want_class_name,const std::vector<string> & want_namespaces)79   void ExpectOK(const string& cpp_class, const string& want_class_name,
80                 const std::vector<string>& want_namespaces) {
81     string class_name;
82     std::vector<string> namespaces;
83     TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
84     EXPECT_EQ(class_name, want_class_name);
85     EXPECT_EQ(namespaces, want_namespaces);
86   }
87 
ExpectFail(const string & cpp_class)88   void ExpectFail(const string& cpp_class) {
89     string class_name;
90     std::vector<string> namespaces;
91     EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), OkStatus())
92         << cpp_class;
93   }
94 };
95 
// Well-formed class specifications, with and without namespaces and with an
// optional leading "::", parse into the expected class name and namespaces.
TEST_F(ParseCppClassTest, ParseOK) {
  ExpectOK("MyClass", "MyClass", {});
  ExpectOK("_MyClass", "_MyClass", {});
  ExpectOK("a::MyClass", "MyClass", {"a"});
  ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"});
  ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"});
  ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"});
  ExpectOK("foo::MyClass", "MyClass", {"foo"});
  ExpectOK("_foo::MyClass", "MyClass", {"_foo"});
  ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"});
  ExpectOK("::foo::bar::MyClass", "MyClass", {"foo", "bar"});
  ExpectOK("::_foo::MyClass", "MyClass", {"_foo"});
  ExpectOK("::_foo::_MyClass", "_MyClass", {"_foo"});

  // An identifier made of every legal identifier character, to make sure no
  // valid letter, digit, or underscore is rejected in any position.
  const string ident =
      "abcdefghijklmnopqrstuvwxyz"
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "0123456789"
      "_";
  ExpectOK(ident, ident, {});
  ExpectOK(ident + "::" + ident, ident, {ident});
  ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident});
}
125 
// Malformed class specifications (empty names, stray or doubled ':'
// separators, a trailing "::", whitespace, and names starting with a digit)
// must all be rejected.
TEST_F(ParseCppClassTest, ParseFail) {
  const char* const kMalformedClasses[] = {
      "",           "::",         "0",           "a.b",
      "a:b",        ":foo::bar",  "good::.bad",  "good:::bad",
      "good::bad::", "good::::bad", "::::bad",   "good:: bad",
      "good::0bad",
  };
  for (const char* malformed : kMalformedClasses) {
    ExpectFail(malformed);
  }
}
141 
CompareWithGoldenFile(const string & tensorflow_relative_golden_file_name,const string & expected_contents,bool ignore_cr)142 static void CompareWithGoldenFile(
143     const string& tensorflow_relative_golden_file_name,
144     const string& expected_contents, bool ignore_cr) {
145   // Get rid of all CR characters, we may be running under windows.
146   string sanitized_expected_contents(expected_contents);
147   if (ignore_cr) {
148     sanitized_expected_contents.erase(
149         std::remove(sanitized_expected_contents.begin(),
150                     sanitized_expected_contents.end(), '\r'),
151         sanitized_expected_contents.end());
152   }
153 
154   // To update the golden file, flip update_golden to true and run the
155   // following:
156   // bazel test --test_strategy=local \
157   //   "third_party/tensorflow/compiler/aot:codegen_test"
158   const bool update_golden = false;
159   string golden_file_name =
160       GetDataDependencyFilepath(tensorflow_relative_golden_file_name);
161 
162   if (update_golden) {
163     TF_EXPECT_OK(
164         WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
165   }
166 
167   string golden_file_contents;
168   TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
169                                 &golden_file_contents));
170   if (ignore_cr) {
171     golden_file_contents.erase(std::remove(golden_file_contents.begin(),
172                                            golden_file_contents.end(), '\r'),
173                                golden_file_contents.end());
174   }
175   EXPECT_EQ(golden_file_contents, expected_contents);
176 }
177 
TEST(CodegenTest,Golden)178 TEST(CodegenTest, Golden) {
179   // Normally CpuCompiler::CpuCompiler does this, but in this test we've
180   // bypassed the Cpu compiler so we have to do this manually.
181   LLVMInitializeX86Target();
182   LLVMInitializeX86TargetInfo();
183   LLVMInitializeX86TargetMC();
184   LLVMInitializeX86AsmPrinter();
185 
186   CodegenOpts opts;
187   opts.class_name = "MyClass";
188   opts.target_triple = "x86_64-pc-linux";
189   opts.namespaces = {"foo", "bar"};
190   opts.gen_name_to_index = true;
191   opts.gen_program_shape = true;
192   tf2xla::Config config;
193   tf2xla::Feed* feed = config.add_feed();
194   feed->mutable_id()->set_node_name("feed0");
195   feed->set_name("myfeed");
196   feed = config.add_feed();
197   feed->mutable_id()->set_node_name("feed1");
198   tf2xla::Fetch* fetch = config.add_fetch();
199   fetch->mutable_id()->set_node_name("fetch0");
200   fetch->set_name("myfetch");
201   tf2xla::Variable* variable = config.add_variable();
202   variable->set_node_name("myvar_readonly");
203   variable->mutable_shape()->add_dim()->set_size(1);
204   variable->set_type(DT_FLOAT);
205   variable->set_readonly(true);
206   tf2xla::Variable* variable2 = config.add_variable();
207   variable2->set_node_name("myvar");
208   variable2->mutable_shape()->add_dim()->set_size(1);
209   variable2->set_type(DT_FLOAT);
210   tf2xla::Variable* variable3 = config.add_variable();
211   variable3->set_node_name("my/var");
212   variable3->set_name("myvar2");
213   variable3->mutable_shape()->add_dim()->set_size(5);
214   variable3->set_type(DT_INT32);
215   CompileResult compile_result;
216   compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
217       {},
218       {BufferInfo::MakeTempBuffer(1),
219        BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
220        BufferInfo::MakeTempBuffer(1),
221        BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
222        BufferInfo::MakeTempBuffer(1),
223        BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/2),
224        BufferInfo::MakeTempBuffer(1),
225        BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/3),
226        BufferInfo::MakeTempBuffer(1),
227        BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/4),
228        BufferInfo::MakeTempBuffer(1), BufferInfo::MakeTempBuffer(120)},
229       11, {}));
230   compile_result.program_shape =
231       xla::ShapeUtil::MakeProgramShape(
232           {
233               xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
234               xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
235               xla::ShapeUtil::MakeShape(xla::F32, {1}),
236               xla::ShapeUtil::MakeShape(xla::F32, {1}),
237               xla::ShapeUtil::MakeShape(xla::S32, {5}),
238           },
239           xla::ShapeUtil::MakeTupleShape({
240               xla::ShapeUtil::MakeShape(xla::U32, {5, 6}),
241               xla::ShapeUtil::MakeShape(xla::F32, {1}),
242               xla::ShapeUtil::MakeShape(xla::S32, {5}),
243           }))
244           .ToProto();
245   compile_result.entry_point = "entry_point";
246   compile_result.pointer_size = 8;
247 
248   MetadataResult metadata_result;
249   TF_ASSERT_OK(GenerateMetadata(opts, compile_result, &metadata_result));
250 
251   // The other fields in metadata_result are tested as part of the generated
252   // header test.
253 
254   // This specific golden test checks a binary file. It can potentially run into
255   // issues due to ABIs not being stable, but has not so far.
256   // If we see any ABI issues, we should reconsider this specific test case.
257   CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_o.golden",
258                         metadata_result.object_file_data, false);
259 
260   string header;
261   TF_ASSERT_OK(
262       GenerateHeader(opts, config, compile_result, metadata_result, &header));
263 
264   CompareWithGoldenFile("tensorflow/compiler/aot/codegen_test_h.golden", header,
265                         true);
266 }
267 }  // namespace
268 }  // namespace tfcompile
269 }  // namespace tensorflow
270