1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/js/ops/ts_op_gen.h"
17
18 #include "tensorflow/core/framework/op_def.pb.h"
19 #include "tensorflow/core/framework/op_gen_lib.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/lib/io/path.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27 namespace {
28
ExpectContainsStr(StringPiece s,StringPiece expected)29 void ExpectContainsStr(StringPiece s, StringPiece expected) {
30 EXPECT_TRUE(absl::StrContains(s, expected))
31 << "'" << s << "' does not contain '" << expected << "'";
32 }
33
ExpectDoesNotContainStr(StringPiece s,StringPiece expected)34 void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
35 EXPECT_FALSE(absl::StrContains(s, expected))
36 << "'" << s << "' does not contain '" << expected << "'";
37 }
38
// Default OpList, in protobuf text format, that GenerateTsOpFileText parses
// when it is given an empty op def string. It declares a single op, "Foo":
//   - input "images": a list of tensors typed by attr T, sized by attr N;
//   - input "dim": a DT_FLOAT tensor;
//   - output "output": a DT_FLOAT tensor;
//   - attr T restricted to DT_UINT8 / DT_INT8, attr N an int with minimum 1.
// NOTE(review): attr "T" has type "type", but its default_value sets the
// integer field `i: 1` rather than a `type:` value. Confirm this is
// intentional before any test starts depending on T's default.
constexpr char kBaseOpDef[] = R"(
op {
  name: "Foo"
  input_arg {
    name: "images"
    type_attr: "T"
    number_attr: "N"
    description: "Images to process."
  }
  input_arg {
    name: "dim"
    description: "Description for dim."
    type: DT_FLOAT
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for images"
    allowed_values {
      list {
        type: DT_UINT8
        type: DT_INT8
      }
    }
    default_value {
      i: 1
    }
  }
  attr {
    name: "N"
    type: "int"
    has_minimum: true
    minimum: 1
  }
  summary: "Summary for op Foo."
  description: "Description for op Foo."
}
)";
82
83 // Generate TypeScript code
GenerateTsOpFileText(const string & op_def_str,const string & api_def_str,string * ts_file_text)84 void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
85 string* ts_file_text) {
86 Env* env = Env::Default();
87 OpList op_defs;
88 protobuf::TextFormat::ParseFromString(
89 op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
90 ApiDefMap api_def_map(op_defs);
91
92 if (!api_def_str.empty()) {
93 TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
94 }
95
96 const string& tmpdir = testing::TmpDir();
97 const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
98
99 WriteTSOps(op_defs, api_def_map, ts_file_path);
100 TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
101 }
102
// The generated file must contain the tfjs-core and op_utils import lines.
TEST(TsOpGenTest, TestImports) {
  const string expected_imports = R"(
import * as tfc from '@tensorflow/tfjs-core';
import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";

  string generated;
  GenerateTsOpFileText(/*op_def_str=*/"", /*api_def_str=*/"", &generated);
  ExpectContainsStr(generated, expected_imports);
}
113
// An ApiDef arg_order override puts the single tensor `dim` ahead of the
// tensor list `images` in the generated function signature.
TEST(TsOpGenTest, InputSingleAndList) {
  string generated;
  GenerateTsOpFileText(
      /*op_def_str=*/"",
      /*api_def_str=*/R"pb(
        op { graph_op_name: "Foo" arg_order: "dim" arg_order: "images" }
      )pb",
      &generated);

  ExpectContainsStr(generated, R"(
export function Foo(dim: tfc.Tensor, images: tfc.Tensor[]): tfc.Tensor {
)");
}
127
// An op whose ApiDef visibility is HIDDEN must not get a generated function.
// The check matches any `export function Foo(` declaration, not one exact
// signature. That way it cannot pass just because the emitted signature's
// argument order or formatting differs.
TEST(TsOpGenTest, TestVisibility) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  visibility: HIDDEN
}
)";

  string ts_file_text;
  GenerateTsOpFileText("", api_def, &ts_file_text);

  ExpectDoesNotContainStr(ts_file_text, "export function Foo(");
}
144
// An op carrying a `deprecation` entry is expected to be left out of the
// generated file, so its name must not appear anywhere in the output.
TEST(TsOpGenTest, SkipDeprecated) {
  const string deprecated_op_def = R"pb(
op {
  name: "DeprecatedFoo"
  input_arg {
    name: "input"
    type_attr: "T"
    description: "Description for input."
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for input"
    allowed_values {
      list {
        type: DT_FLOAT
      }
    }
  }
  deprecation {
    explanation: "Deprecated."
  }
}
)pb";

  string generated;
  GenerateTsOpFileText(deprecated_op_def, /*api_def_str=*/"", &generated);
  ExpectDoesNotContainStr(generated, "DeprecatedFoo");
}
180
// An op declaring two outputs must generate a function that returns an array
// of tensors (`tfc.Tensor[]`) rather than a single tensor.
TEST(TsOpGenTest, MultiOutput) {
  const string multi_output_op_def = R"pb(
op {
  name: "MultiOutputFoo"
  input_arg {
    name: "input"
    description: "Description for input."
    type_attr: "T"
  }
  output_arg {
    name: "output1"
    description: "Description for output 1."
    type: DT_FLOAT
  }
  output_arg {
    name: "output2"
    description: "Description for output 2."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for input"
    allowed_values {
      list {
        type: DT_FLOAT
      }
    }
  }
  summary: "Summary for op MultiOutputFoo."
  description: "Description for op MultiOutputFoo."
}
)pb";
  const string expected_signature = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
)";

  string generated;
  GenerateTsOpFileText(multi_output_op_def, /*api_def_str=*/"", &generated);
  ExpectContainsStr(generated, expected_signature);
}
223
// In the generated opAttrs array for Foo, attr T comes from the `images`
// tensors and attr N from `images.length`.
TEST(TsOpGenTest, OpAttrs) {
  constexpr char kExpectedFooAttrs[] = R"(
  const opAttrs = [
    createTensorsTypeOpAttr('T', images),
    {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
  ];
)";

  string generated;
  GenerateTsOpFileText(/*op_def_str=*/"", /*api_def_str=*/"", &generated);
  ExpectContainsStr(generated, kExpectedFooAttrs);
}
237
238 } // namespace
239 } // namespace tensorflow
240