• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function_testlib.h"
17 
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/versions.pb.h"
22 #include "tensorflow/core/lib/core/threadpool.h"
23 #include "tensorflow/core/public/version.h"
24 
25 namespace tensorflow {
26 namespace test {
27 namespace function {
28 
29 typedef FunctionDefHelper FDH;
30 
GDef(gtl::ArraySlice<NodeDef> nodes,gtl::ArraySlice<FunctionDef> funcs)31 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
32               gtl::ArraySlice<FunctionDef> funcs) {
33   GraphDef g;
34   VersionDef* versions = g.mutable_versions();
35   versions->set_producer(TF_GRAPH_DEF_VERSION);
36   versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
37   for (const auto& n : nodes) {
38     *(g.add_node()) = n;
39   }
40   auto lib = g.mutable_library();
41   for (const auto& f : funcs) {
42     *(lib->add_function()) = f;
43   }
44   return g;
45 }
46 
47 // Helper to construct a NodeDef.
NDef(StringPiece name,StringPiece op,gtl::ArraySlice<string> inputs,gtl::ArraySlice<std::pair<string,FDH::AttrValueWrapper>> attrs,const string & device)48 NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
49              gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
50              const string& device) {
51   NodeDef n;
52   n.set_name(string(name));
53   n.set_op(string(op));
54   for (const auto& in : inputs) n.add_input(in);
55   n.set_device(device);
56   for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
57   return n;
58 }
59 
NonZero()60 FunctionDef NonZero() {
61   return FDH::Define(
62       // Name
63       "NonZero",
64       // Args
65       {"x:T"},
66       // Return values
67       {"y:T"},
68       // Attr def
69       {"T:{float, double, int32, int64, string}"},
70       // Nodes
71       {
72           {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
73       });
74 }
75 
IsZero()76 FunctionDef IsZero() {
77   const Tensor kZero = test::AsScalar<int64>(0);
78   return FDH::Define(
79       // Name
80       "IsZero",
81       // Args
82       {"x: T"},
83       // Return values
84       {"equal: T"},
85       // Attr def
86       {"T:{float, double, int32, int64, string}"},
87       {
88           {{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT64}}},
89           {{"cast"}, "Cast", {"zero"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
90           {{"equal"}, "Equal", {"x", "cast"}, {{"T", "$T"}}},
91       });
92 }
93 
RandomUniform()94 FunctionDef RandomUniform() {
95   const Tensor kZero = test::AsScalar<int64>(0);
96 
97   return FDH::Define(
98       // Name
99       "RandomUniform",
100       // Args
101       {"x: T"},
102       // Return values
103       {"random_uniform: int64"},
104       // Attr def
105       {"T:{float, double, int32, int64, string}"},
106       {{{"random_uniform/shape"},
107         "Const",
108         {},
109         {{"value", kZero}, {"dtype", DT_INT64}}},
110        {{"random_uniform"},
111         "RandomUniform",
112         {"random_uniform/shape"},
113         {{"T", DT_INT32},
114          {"Tout", DT_FLOAT},
115          {"seed", 87654321},
116          {"seed2", 42}}}});
117 }
118 
XTimesTwo()119 FunctionDef XTimesTwo() {
120   const Tensor kTwo = test::AsScalar<int64>(2);
121   return FDH::Define(
122       // Name
123       "XTimesTwo",
124       // Args
125       {"x: T"},
126       // Return values
127       {"y: T"},
128       // Attr def
129       {"T: {float, double, int32, int64}"},
130       // Nodes
131       {
132           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
133           {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
134           {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
135       });
136 }
137 
TwoDeviceMult()138 FunctionDef TwoDeviceMult() {
139   const Tensor kTwo = test::AsScalar<int64>(2);
140   const Tensor kThree = test::AsScalar<int64>(3);
141   return FDH::Create(
142       // Name
143       "TwoDeviceMult",
144       // Args
145       {"x: T"},
146       // Return values
147       {"y_cpu: T", "y_gpu: T"},
148       // Attr def
149       {"T: {float, double, int32, int64}"},
150       // Nodes
151       {
152           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
153           {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_INT64}}},
154           {{"factor_2"},
155            "Cast",
156            {"num_2:output:0"},
157            {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
158           {{"factor_3"},
159            "Cast",
160            {"num_3:output:0"},
161            {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
162           {{"y_cpu"},
163            "Mul",
164            {"x", "factor_2:y:0"},
165            {{"T", "$T"}},
166            {},
167            "/device:CPU:0"},
168           {{"y_gpu"},
169            "Mul",
170            {"x", "factor_3:y:0"},
171            {{"T", "$T"}},
172            {},
173            "/device:GPU:0"},
174       },
175       {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
176 }
177 
TwoDeviceInputOutput()178 FunctionDef TwoDeviceInputOutput() {
179   const Tensor kTwo = test::AsScalar<float>(2);
180   const Tensor kThree = test::AsScalar<float>(3);
181   return FDH::Create(
182       // Name
183       "TwoDeviceInputOutput",
184       // Args
185       {"x1: T", "x2: T"},
186       // Return values
187       {"y_cpu: T", "y_gpu: T"},
188       // Attr def
189       {"T: {float}"},
190       // Nodes
191       {
192           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
193           {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_FLOAT}}},
194           {{"y_cpu"},
195            "Mul",
196            {"x1", "num_2:output:0"},
197            {{"T", "$T"}},
198            {},
199            "/device:CPU:0"},
200           {{"y_gpu"},
201            "Mul",
202            {"x2", "num_3:output:0"},
203            {{"T", "$T"}},
204            {},
205            "/device:GPU:0"},
206       },
207       {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
208 }
209 
FuncWithListInput()210 FunctionDef FuncWithListInput() {
211   const Tensor kTwo = test::AsScalar<float>(2);
212   return FDH::Create(
213       // Name
214       "FuncWithListInput",
215       // Args
216       {"x1: N * T"},
217       // Return values
218       {},
219       // Attr def
220       {"T: {float}", "N: int >= 1"},
221       // Nodes
222       {
223           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
224       },
225       {});
226 }
227 
FuncWithListOutput()228 FunctionDef FuncWithListOutput() {
229   const Tensor kTwo = test::AsScalar<float>(2);
230   return FDH::Create(
231       // Name
232       "FuncWithListOutput",
233       // Args
234       {},
235       // Return values
236       {"y: N * T"},
237       // Attr def
238       {"T: {float}", "N: int >= 1"},
239       // Nodes
240       {
241           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
242       },
243       {{"y", "num_2:output:0"}});
244 }
245 
XAddX()246 FunctionDef XAddX() {
247   return FDH::Define(
248       // Name
249       "XAddX",
250       // Args
251       {"x: T"},
252       // Return values
253       {"y: T"},
254       // Attr def
255       {"T: {float, double, int32, int64}"},
256       // Nodes
257       {
258           {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
259       });
260 }
261 
XTimesTwoInt32()262 FunctionDef XTimesTwoInt32() {
263   const Tensor kTwo = test::AsScalar<int64>(2);
264   return FDH::Define(
265       // Name
266       "XTimesTwoInt32",
267       // Args
268       {"x: int32"},
269       // Return values
270       {"y: int32"}, {},
271       // Nodes
272       {
273           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
274           {{"scale"},
275            "Cast",
276            {"two"},
277            {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
278           {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
279       });
280 }
281 
XTimesFour()282 FunctionDef XTimesFour() {
283   return FDH::Create(
284       // Name
285       "XTimesFour",
286       // Args
287       {"x: T"},
288       // Return values
289       {"y: T"},
290       // Attr def
291       {"T: {float, double, int32, int64}"},
292       // Nodes
293       {
294           {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
295           {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
296       },
297       {{"y", "y:y:0"}});
298 }
299 
XTimes16()300 FunctionDef XTimes16() {
301   return FDH::Create(
302       // Name
303       "XTimes16",
304       // Args
305       {"x: T"},
306       // Return values
307       {"y: T"},
308       // Attr def
309       {"T: {float, double, int32, int64}"},
310       // Nodes
311       {
312           {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
313           {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
314       },
315       {{"y", "y:y:0"}});
316 }
317 
WXPlusB()318 FunctionDef WXPlusB() {
319   return FDH::Define(
320       // Name
321       "WXPlusB",
322       // Args
323       {"w: T", "x: T", "b: T"},
324       // Return values
325       {"y: T"},
326       // Attr def
327       {"T: {float, double}"},
328       // Nodes
329       {{{"mm"},
330         "MatMul",
331         {"w", "x"},
332         {{"T", "$T"},
333          {"transpose_a", false},
334          {"transpose_b", false},
335          {"_kernel", "eigen"}}},
336        {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
337 }
338 
Swap()339 FunctionDef Swap() {
340   return FDH::Define(
341       // Name
342       "Swap",
343       // Args
344       {"i0: T", "i1: T"},
345       // Return values
346       {"o0: T", "o1: T"},
347       // Attr def
348       {"T: {float, double}"},
349       // Nodes
350       {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
351        {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
352 }
353 
EmptyBodySwap()354 FunctionDef EmptyBodySwap() {
355   return FDH::Create(
356       // Name
357       "EmptyBodySwap",
358       // Args
359       {"i0: T", "i1: T"},
360       // Return values
361       {"o0: T", "o1: T"},
362       // Attr def
363       {"T: {float, double}"},
364       // Nodes
365       {},
366       // Output mapping
367       {{"o0", "i1"}, {"o1", "i0"}});
368 }
369 
ResourceOutput()370 FunctionDef ResourceOutput() {
371   const Tensor kTwo = test::AsScalar<float>(2);
372   return FDH::Create(
373       // Name
374       "ResourceOutput",
375       // Args
376       {"x: float", "y: resource"},
377       // Return values
378       {"y_out: resource", "two_x: float"},
379       // Attr def
380       {},
381       // Nodes
382       {
383           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
384           {{"mul"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}, {}},
385       },
386       {{"y_out", "y"}, {"two_x", "mul:z:0"}});
387 }
388 
ReadResourceVariable()389 FunctionDef ReadResourceVariable() {
390   return FDH::Create(
391       // Name
392       "ReadResourceVariable",
393       // Args
394       {"x: resource"},
395       // Return values
396       {"y: float"},
397       // Attr def
398       {},
399       // Nodes
400       {
401           {{"read"}, "ReadVariableOp", {"x"}, {{"dtype", DT_FLOAT}}, {}},
402       },
403       {{"y", "read:value:0"}});
404 }
405 
InvalidControlFlow()406 FunctionDef InvalidControlFlow() {
407   return FDH::Create(
408       // Name
409       "InvalidControlFlow",
410       // Args
411       {"i: int32"},
412       // Return values
413       {"o: int32"},
414       // Attr def
415       {},
416       // Nodes
417       {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}},
418        {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}},
419       // Output mapping
420       {{"o", "add:z"}});
421 }
422 
LessThanOrEqualToN(int64 N)423 FunctionDef LessThanOrEqualToN(int64 N) {
424   const Tensor kN = test::AsScalar<int64>(N);
425   return FDH::Define(
426       // Name
427       "LessThanOrEqualToN",
428       // Args
429       {"x: T"},
430       // Return values
431       {"z: bool"},
432       // Attr def
433       {"T: {float, double, int32, int64}"},
434       // Nodes
435       {
436           {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
437           {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
438           {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
439       });
440 }
441 
XPlusOneXTimesY()442 FunctionDef XPlusOneXTimesY() {
443   const Tensor kOne = test::AsScalar<int64>(1);
444   return FDH::Define(
445       // Name
446       "XPlusOneXTimesY",
447       // Args
448       {"x: T", "y: T"},
449       // Return values
450       {"s: T", "t: T"},
451       // Attr def
452       {"T: {float, double, int32, int64}"},
453       // Nodes
454       {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
455        {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
456        {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
457        {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
458 }
459 
XYXLessThanOrEqualToN(int64 N)460 FunctionDef XYXLessThanOrEqualToN(int64 N) {
461   const Tensor kN = test::AsScalar<int64>(N);
462   return FDH::Define(
463       // Name
464       "XYXLessThanOrEqualToN",
465       // Args
466       {"x: T", "y: T"},
467       // Return values
468       {"z: bool"},
469       // Attr def
470       {"T: {float, double, int32, int64}"},
471       // Nodes
472       {
473           {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
474           {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
475           {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
476       });
477 }
478 
FunctionTestSchedClosure(std::function<void ()> fn)479 void FunctionTestSchedClosure(std::function<void()> fn) {
480   static thread::ThreadPool* w =
481       new thread::ThreadPool(Env::Default(), "Test", 8);
482   w->Schedule(std::move(fn));
483 }
484 
485 }  // end namespace function
486 }  // end namespace test
487 }  // end namespace tensorflow
488