• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/js/ops/ts_op_gen.h"
17 
18 #include "tensorflow/core/framework/op_def.pb.h"
19 #include "tensorflow/core/framework/op_gen_lib.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/lib/io/path.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/platform/env.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 namespace tensorflow {
27 namespace {
28 
ExpectContainsStr(StringPiece s,StringPiece expected)29 void ExpectContainsStr(StringPiece s, StringPiece expected) {
30   EXPECT_TRUE(absl::StrContains(s, expected))
31       << "'" << s << "' does not contain '" << expected << "'";
32 }
33 
ExpectDoesNotContainStr(StringPiece s,StringPiece expected)34 void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
35   EXPECT_FALSE(absl::StrContains(s, expected))
36       << "'" << s << "' does not contain '" << expected << "'";
37 }
38 
// Default OpList, in protobuf text format, used by GenerateTsOpFileText when
// the caller passes an empty op-def string. It declares a single op "Foo":
//   - input "images": a list of tensors whose element type is attr "T" and
//     whose length is attr "N" (has_minimum, minimum 1);
//   - input "dim": a single DT_FLOAT tensor;
//   - output "output": a single DT_FLOAT tensor.
// Attr "T" is restricted to DT_UINT8 / DT_INT8.
// NOTE(review): "T" is a "type" attr but its default_value sets the int field
// (`i: 1`) rather than `type: ...`; confirm this is intentional.
constexpr char kBaseOpDef[] = R"(
op {
  name: "Foo"
  input_arg {
    name: "images"
    type_attr: "T"
    number_attr: "N"
    description: "Images to process."
  }
  input_arg {
    name: "dim"
    description: "Description for dim."
    type: DT_FLOAT
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for images"
    allowed_values {
      list {
        type: DT_UINT8
        type: DT_INT8
      }
    }
    default_value {
      i: 1
    }
  }
  attr {
    name: "N"
    type: "int"
    has_minimum: true
    minimum: 1
  }
  summary: "Summary for op Foo."
  description: "Description for op Foo."
}
)";
82 
83 // Generate TypeScript code
GenerateTsOpFileText(const string & op_def_str,const string & api_def_str,string * ts_file_text)84 void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
85                           string* ts_file_text) {
86   Env* env = Env::Default();
87   OpList op_defs;
88   protobuf::TextFormat::ParseFromString(
89       op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
90   ApiDefMap api_def_map(op_defs);
91 
92   if (!api_def_str.empty()) {
93     TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
94   }
95 
96   const string& tmpdir = testing::TmpDir();
97   const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
98 
99   WriteTSOps(op_defs, api_def_map, ts_file_path);
100   TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
101 }
102 
// The generated file starts with the tfjs-core and op_utils imports.
TEST(TsOpGenTest, TestImports) {
  string ts_file_text;
  // Propagate fatal failures from the helper so we never inspect a missing or
  // partial file.
  ASSERT_NO_FATAL_FAILURE(GenerateTsOpFileText("", "", &ts_file_text));

  const string expected = R"(
import * as tfc from '@tensorflow/tfjs-core';
import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";
  ExpectContainsStr(ts_file_text, expected);
}
113 
// An ApiDef arg_order override reorders the generated function's parameters;
// the single tensor input ("dim") and the list input ("images") get
// tfc.Tensor and tfc.Tensor[] types respectively.
TEST(TsOpGenTest, InputSingleAndList) {
  const string api_def = R"pb(
    op { graph_op_name: "Foo" arg_order: "dim" arg_order: "images" }
  )pb";

  string ts_file_text;
  // Propagate fatal failures (e.g. ApiDef load errors) from the helper.
  ASSERT_NO_FATAL_FAILURE(GenerateTsOpFileText("", api_def, &ts_file_text));

  const string expected = R"(
export function Foo(dim: tfc.Tensor, images: tfc.Tensor[]): tfc.Tensor {
)";
  ExpectContainsStr(ts_file_text, expected);
}
127 
// Ops whose ApiDef visibility is HIDDEN must not be exported.
TEST(TsOpGenTest, TestVisibility) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  visibility: HIDDEN
}
)";

  string ts_file_text;
  ASSERT_NO_FATAL_FAILURE(GenerateTsOpFileText("", api_def, &ts_file_text));

  // Sanity check that a file was actually generated; otherwise the negative
  // check below would pass trivially.
  ExpectContainsStr(ts_file_text,
                    "import * as tfc from '@tensorflow/tfjs-core';");
  // Check for any exported Foo, not one exact signature string, so the test
  // cannot pass just because parameter order or formatting differs.
  ExpectDoesNotContainStr(ts_file_text, "export function Foo(");
}
144 
// Ops carrying a `deprecation` entry are omitted from the generated file.
TEST(TsOpGenTest, SkipDeprecated) {
  const string op_def = R"(
op {
  name: "DeprecatedFoo"
  input_arg {
    name: "input"
    type_attr: "T"
    description: "Description for input."
  }
  output_arg {
    name: "output"
    description: "Description for output."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for input"
    allowed_values {
      list {
        type: DT_FLOAT
      }
    }
  }
  deprecation {
    explanation: "Deprecated."
  }
}
)";

  string ts_file_text;
  ASSERT_NO_FATAL_FAILURE(GenerateTsOpFileText(op_def, "", &ts_file_text));

  // Sanity check that a file was actually generated; otherwise the negative
  // check below would pass trivially.
  ExpectContainsStr(ts_file_text,
                    "import * as tfc from '@tensorflow/tfjs-core';");
  ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
}
180 
// An op with more than one output_arg returns tfc.Tensor[] from the generated
// function.
TEST(TsOpGenTest, MultiOutput) {
  const string op_def = R"(
op {
  name: "MultiOutputFoo"
  input_arg {
    name: "input"
    description: "Description for input."
    type_attr: "T"
  }
  output_arg {
    name: "output1"
    description: "Description for output 1."
    type: DT_FLOAT
  }
  output_arg {
    name: "output2"
    description: "Description for output 2."
    type: DT_FLOAT
  }
  attr {
    name: "T"
    type: "type"
    description: "Type for input"
    allowed_values {
      list {
        type: DT_FLOAT
      }
    }
  }
  summary: "Summary for op MultiOutputFoo."
  description: "Description for op MultiOutputFoo."
}
)";

  string ts_file_text;
  // Propagate fatal failures from the helper so we never inspect a missing or
  // partial file.
  ASSERT_NO_FATAL_FAILURE(GenerateTsOpFileText(op_def, "", &ts_file_text));

  const string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
)";
  ExpectContainsStr(ts_file_text, expected);
}
223 
// Attr "T" is inferred from the `images` tensors and attr "N" from the
// length of the `images` list.
TEST(TsOpGenTest, OpAttrs) {
  string ts_file_text;
  // Propagate fatal failures from the helper so we never inspect a missing or
  // partial file.
  ASSERT_NO_FATAL_FAILURE(GenerateTsOpFileText("", "", &ts_file_text));

  const string expectedFooAttrs = R"(
  const opAttrs = [
    createTensorsTypeOpAttr('T', images),
    {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
  ];
)";

  ExpectContainsStr(ts_file_text, expectedFooAttrs);
}
237 
238 }  // namespace
239 }  // namespace tensorflow
240