1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h"
16
17 #include <algorithm>
18
19 #include "tensorflow/core/lib/io/path.h"
20 #include "tensorflow/core/platform/env.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace tensorflow {
24 namespace generator {
25 namespace {
26
TEST(CppGeneratorTest,typical_usage)27 TEST(CppGeneratorTest, typical_usage) {
28 string category = "testing";
29 string name_space = "tensorflow::ops";
30 string output_dir = "tensorflow/c/experimental/ops/gen/cpp/golden";
31 string source_dir = "tensorflow";
32 string api_dirs = "";
33 std::vector<string> ops = {
34 "Neg", // Simple unary Op
35 "MatMul", // 2 inputs & attrs with default values
36 "IdentityN", // Variadic input+output
37 "SparseSoftmaxCrossEntropyWithLogits", // 2 outputs
38 "AccumulatorApplyGradient", // 0 outputs
39 "VarHandleOp", // type, shape, list(string) attrs
40 "RestoreV2", // Variadic output-only, list(type) attr
41 };
42
43 cpp::CppConfig cpp_config(category, name_space);
44 PathConfig controller_config(output_dir, source_dir, api_dirs, ops);
45 CppGenerator generator(cpp_config, controller_config);
46
47 Env *env = Env::Default();
48 string golden_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
49 controller_config.tf_output_dir);
50
51 string generated_header = generator.HeaderFileContents().Render();
52 string generated_source = generator.SourceFileContents().Render();
53 string expected_header;
54 string header_file_name = io::JoinPath(golden_dir, "testing_ops.h.golden");
55 TF_CHECK_OK(ReadFileToString(env, header_file_name, &expected_header));
56
57 string expected_source;
58 string source_file_name = io::JoinPath(golden_dir, "testing_ops.cc.golden");
59 TF_CHECK_OK(ReadFileToString(env, source_file_name, &expected_source));
60
61 // Remove carriage returns (for Windows)
62 expected_header.erase(
63 std::remove(expected_header.begin(), expected_header.end(), '\r'),
64 expected_header.end());
65 expected_source.erase(
66 std::remove(expected_source.begin(), expected_source.end(), '\r'),
67 expected_source.end());
68
69 EXPECT_EQ(expected_header, generated_header);
70 EXPECT_EQ(expected_source, generated_source);
71 }
72
73 } // namespace
74 } // namespace generator
75 } // namespace tensorflow
76