1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/function_testlib.h"
17
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/versions.pb.h"
22 #include "tensorflow/core/lib/core/threadpool.h"
23 #include "tensorflow/core/public/version.h"
24
25 namespace tensorflow {
26 namespace test {
27 namespace function {
28
29 typedef FunctionDefHelper FDH;
30
GDef(gtl::ArraySlice<NodeDef> nodes,gtl::ArraySlice<FunctionDef> funcs)31 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
32 gtl::ArraySlice<FunctionDef> funcs) {
33 GraphDef g;
34 VersionDef* versions = g.mutable_versions();
35 versions->set_producer(TF_GRAPH_DEF_VERSION);
36 versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
37 for (const auto& n : nodes) {
38 *(g.add_node()) = n;
39 }
40 auto lib = g.mutable_library();
41 for (const auto& f : funcs) {
42 *(lib->add_function()) = f;
43 }
44 return g;
45 }
46
47 // Helper to construct a NodeDef.
NDef(StringPiece name,StringPiece op,gtl::ArraySlice<string> inputs,gtl::ArraySlice<std::pair<string,FDH::AttrValueWrapper>> attrs,const string & device)48 NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
49 gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
50 const string& device) {
51 NodeDef n;
52 n.set_name(string(name));
53 n.set_op(string(op));
54 for (const auto& in : inputs) n.add_input(in);
55 n.set_device(device);
56 for (const auto& na : attrs)
57 n.mutable_attr()->insert({na.first, na.second.proto});
58 return n;
59 }
60
NonZero()61 FunctionDef NonZero() {
62 return FDH::Define(
63 // Name
64 "NonZero",
65 // Args
66 {"x:T"},
67 // Return values
68 {"y:T"},
69 // Attr def
70 {"T:{float, double, int32, int64, string}"},
71 // Nodes
72 {
73 {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
74 });
75 }
76
IsZero()77 FunctionDef IsZero() {
78 const Tensor kZero = test::AsScalar<int64>(0);
79 return FDH::Define(
80 // Name
81 "IsZero",
82 // Args
83 {"x: T"},
84 // Return values
85 {"equal: bool"},
86 // Attr def
87 {"T:{float, double, int32, int64, string}"},
88 {
89 {{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT64}}},
90 {{"cast"}, "Cast", {"zero"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
91 {{"equal"}, "Equal", {"x", "cast"}, {{"T", "$T"}}},
92 });
93 }
94
RandomUniform()95 FunctionDef RandomUniform() {
96 const Tensor kZero = test::AsScalar<int64>(0);
97
98 return FDH::Define(
99 // Name
100 "RandomUniform",
101 // Args
102 {"x: T"},
103 // Return values
104 {"random_uniform: int64"},
105 // Attr def
106 {"T:{float, double, int32, int64, string}"},
107 {{{"random_uniform/shape"},
108 "Const",
109 {},
110 {{"value", kZero}, {"dtype", DT_INT64}}},
111 {{"random_uniform"},
112 "RandomUniform",
113 {"random_uniform/shape"},
114 {{"T", DT_INT32},
115 {"Tout", DT_FLOAT},
116 {"seed", 87654321},
117 {"seed2", 42}}}});
118 }
119
XTimesTwo()120 FunctionDef XTimesTwo() {
121 const Tensor kTwo = test::AsScalar<int64>(2);
122 return FDH::Define(
123 // Name
124 "XTimesTwo",
125 // Args
126 {"x: T"},
127 // Return values
128 {"y: T"},
129 // Attr def
130 {"T: {float, double, int32, int64}"},
131 // Nodes
132 {
133 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
134 {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
135 {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
136 });
137 }
138
TwoDeviceMult()139 FunctionDef TwoDeviceMult() {
140 const Tensor kTwo = test::AsScalar<int64>(2);
141 const Tensor kThree = test::AsScalar<int64>(3);
142 return FDH::Create(
143 // Name
144 "TwoDeviceMult",
145 // Args
146 {"x: T"},
147 // Return values
148 {"y_cpu: T", "y_gpu: T"},
149 // Attr def
150 {"T: {float, double, int32, int64}"},
151 // Nodes
152 {
153 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
154 {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_INT64}}},
155 {{"factor_2"},
156 "Cast",
157 {"num_2:output:0"},
158 {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
159 {{"factor_3"},
160 "Cast",
161 {"num_3:output:0"},
162 {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
163 {{"y_cpu"},
164 "Mul",
165 {"x", "factor_2:y:0"},
166 {{"T", "$T"}},
167 {},
168 "/device:CPU:0"},
169 {{"y_gpu"},
170 "Mul",
171 {"x", "factor_3:y:0"},
172 {{"T", "$T"}},
173 {},
174 "/device:GPU:0"},
175 },
176 {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
177 }
178
TwoDeviceInputOutput()179 FunctionDef TwoDeviceInputOutput() {
180 const Tensor kTwo = test::AsScalar<float>(2);
181 const Tensor kThree = test::AsScalar<float>(3);
182 return FDH::Create(
183 // Name
184 "TwoDeviceInputOutput",
185 // Args
186 {"x1: T", "x2: T"},
187 // Return values
188 {"y_cpu: T", "y_gpu: T"},
189 // Attr def
190 {"T: {float}"},
191 // Nodes
192 {
193 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
194 {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_FLOAT}}},
195 {{"y_cpu"},
196 "Mul",
197 {"x1", "num_2:output:0"},
198 {{"T", "$T"}},
199 {},
200 "/device:CPU:0"},
201 {{"y_gpu"},
202 "Mul",
203 {"x2", "num_3:output:0"},
204 {{"T", "$T"}},
205 {},
206 "/device:GPU:0"},
207 },
208 {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
209 }
210
FuncWithListInput()211 FunctionDef FuncWithListInput() {
212 const Tensor kTwo = test::AsScalar<float>(2);
213 return FDH::Create(
214 // Name
215 "FuncWithListInput",
216 // Args
217 {"x1: N * T"},
218 // Return values
219 {},
220 // Attr def
221 {"T: {float}", "N: int >= 1"},
222 // Nodes
223 {
224 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
225 },
226 {});
227 }
228
FuncWithListOutput()229 FunctionDef FuncWithListOutput() {
230 const Tensor kTwo = test::AsScalar<float>(2);
231 return FDH::Create(
232 // Name
233 "FuncWithListOutput",
234 // Args
235 {},
236 // Return values
237 {"y: N * T"},
238 // Attr def
239 {"T: {float}", "N: int >= 1"},
240 // Nodes
241 {
242 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
243 },
244 {{"y", "num_2:output:0"}});
245 }
246
XAddX()247 FunctionDef XAddX() {
248 return FDH::Define(
249 // Name
250 "XAddX",
251 // Args
252 {"x: T"},
253 // Return values
254 {"y: T"},
255 // Attr def
256 {"T: {float, double, int32, int64}"},
257 // Nodes
258 {
259 {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
260 });
261 }
262
XAddY()263 FunctionDef XAddY() {
264 return FDH::Define(
265 // Name
266 "XAddY",
267 // Args
268 {"x: T", "y: T"},
269 // Return values
270 {"z: T"},
271 // Attr def
272 {"T: {float, double, int32, int64}"},
273 // Nodes
274 {
275 {{"z"}, "Add", {"x", "y"}, {{"T", "$T"}}},
276 });
277 }
278
XTimesTwoInt32()279 FunctionDef XTimesTwoInt32() {
280 const Tensor kTwo = test::AsScalar<int64>(2);
281 return FDH::Define(
282 // Name
283 "XTimesTwoInt32",
284 // Args
285 {"x: int32"},
286 // Return values
287 {"y: int32"}, {},
288 // Nodes
289 {
290 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
291 {{"scale"},
292 "Cast",
293 {"two"},
294 {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
295 {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
296 });
297 }
298
XTimesFour()299 FunctionDef XTimesFour() {
300 return FDH::Create(
301 // Name
302 "XTimesFour",
303 // Args
304 {"x: T"},
305 // Return values
306 {"y: T"},
307 // Attr def
308 {"T: {float, double, int32, int64}"},
309 // Nodes
310 {
311 {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
312 {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
313 },
314 {{"y", "y:y:0"}});
315 }
316
XTimes16()317 FunctionDef XTimes16() {
318 return FDH::Create(
319 // Name
320 "XTimes16",
321 // Args
322 {"x: T"},
323 // Return values
324 {"y: T"},
325 // Attr def
326 {"T: {float, double, int32, int64}"},
327 // Nodes
328 {
329 {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
330 {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
331 },
332 {{"y", "y:y:0"}});
333 }
334
WXPlusB()335 FunctionDef WXPlusB() {
336 return FDH::Define(
337 // Name
338 "WXPlusB",
339 // Args
340 {"w: T", "x: T", "b: T"},
341 // Return values
342 {"y: T"},
343 // Attr def
344 {"T: {float, double}"},
345 // Nodes
346 {{{"mm"},
347 "MatMul",
348 {"w", "x"},
349 {{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false}}},
350 {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
351 }
352
Swap()353 FunctionDef Swap() {
354 return FDH::Define(
355 // Name
356 "Swap",
357 // Args
358 {"i0: T", "i1: T"},
359 // Return values
360 {"o0: T", "o1: T"},
361 // Attr def
362 {"T: {float, double, resource}"},
363 // Nodes
364 {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
365 {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
366 }
367
EmptyBodySwap()368 FunctionDef EmptyBodySwap() {
369 return FDH::Create(
370 // Name
371 "EmptyBodySwap",
372 // Args
373 {"i0: T", "i1: T"},
374 // Return values
375 {"o0: T", "o1: T"},
376 // Attr def
377 {"T: {float, double, resource}"},
378 // Nodes
379 {},
380 // Output mapping
381 {{"o0", "i1"}, {"o1", "i0"}});
382 }
383
ResourceOutput()384 FunctionDef ResourceOutput() {
385 const Tensor kTwo = test::AsScalar<float>(2);
386 return FDH::Create(
387 // Name
388 "ResourceOutput",
389 // Args
390 {"x: float", "y: resource"},
391 // Return values
392 {"y_out: resource", "two_x: float"},
393 // Attr def
394 {},
395 // Nodes
396 {
397 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
398 {{"mul"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}, {}},
399 },
400 {{"y_out", "y"}, {"two_x", "mul:z:0"}});
401 }
402
ResourceIdentity()403 FunctionDef ResourceIdentity() {
404 return FDH::Create(
405 // Name
406 "ResourceIdentity",
407 // Args
408 {"x: resource"},
409 // Return values
410 {"y: resource"},
411 // Attr def
412 {},
413 // Nodes
414 {},
415 // Output mapping
416 {{"y", "x"}});
417 }
418
ReadResourceVariable()419 FunctionDef ReadResourceVariable() {
420 return FDH::Create(
421 // Name
422 "ReadResourceVariable",
423 // Args
424 {"x: resource"},
425 // Return values
426 {"y: float"},
427 // Attr def
428 {},
429 // Nodes
430 {
431 {{"read"}, "ReadVariableOp", {"x"}, {{"dtype", DT_FLOAT}}, {}},
432 },
433 {{"y", "read:value:0"}});
434 }
435
InvalidControlFlow()436 FunctionDef InvalidControlFlow() {
437 return FDH::Create(
438 // Name
439 "InvalidControlFlow",
440 // Args
441 {"i: int32"},
442 // Return values
443 {"o: int32"},
444 // Attr def
445 {},
446 // Nodes
447 {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}},
448 {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}},
449 // Output mapping
450 {{"o", "add:z"}});
451 }
452
LessThanOrEqualToN(int64 N)453 FunctionDef LessThanOrEqualToN(int64 N) {
454 const Tensor kN = test::AsScalar<int64>(N);
455 return FDH::Define(
456 // Name
457 "LessThanOrEqualToN",
458 // Args
459 {"x: T"},
460 // Return values
461 {"z: bool"},
462 // Attr def
463 {"T: {float, double, int32, int64}"},
464 // Nodes
465 {
466 {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
467 {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
468 {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
469 });
470 }
471
XPlusOneXTimesY()472 FunctionDef XPlusOneXTimesY() {
473 const Tensor kOne = test::AsScalar<int64>(1);
474 return FDH::Define(
475 // Name
476 "XPlusOneXTimesY",
477 // Args
478 {"x: T", "y: T"},
479 // Return values
480 {"s: T", "t: T"},
481 // Attr def
482 {"T: {float, double, int32, int64}"},
483 // Nodes
484 {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
485 {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
486 {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
487 {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
488 }
489
XYXLessThanOrEqualToN(int64 N)490 FunctionDef XYXLessThanOrEqualToN(int64 N) {
491 const Tensor kN = test::AsScalar<int64>(N);
492 return FDH::Define(
493 // Name
494 "XYXLessThanOrEqualToN",
495 // Args
496 {"x: T", "y: T"},
497 // Return values
498 {"z: bool"},
499 // Attr def
500 {"T: {float, double, int32, int64}"},
501 // Nodes
502 {
503 {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
504 {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
505 {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
506 });
507 }
508
RandomUniformLess()509 FunctionDef RandomUniformLess() {
510 const Tensor kZero = test::AsScalar<int32>(0);
511 const Tensor kOne = test::AsScalar<int32>(1);
512 const Tensor k005 = test::AsScalar<float>(0.05);
513
514 return FDH::Define(
515 // Name
516 "RandomUniformLess",
517 // Args
518 {"arg0: int64"},
519 // Return values
520 {"strided_slice: bool"},
521 // Attr def
522 {"T:{float, double, int32, int64, string}"},
523 {{{"random_uniform/shape"},
524 "Const",
525 {},
526 {{"value", kZero}, {"dtype", DT_INT32}}},
527
528 {{"random_uniform/RandomUniform"},
529 "RandomUniform",
530 {"random_uniform/shape"},
531 {{"T", DT_INT32}, {"Tout", DT_FLOAT}, {"seed", 0}, {"seed2", 0}}},
532
533 {{"Less/y"}, "Const", {}, {{"value", k005}, {"dtype", DT_FLOAT}}},
534
535 {{"Less"},
536 "Less",
537 {"random_uniform/RandomUniform", "Less/y"},
538 {{"T", DT_FLOAT}}},
539
540 {{"strided_slice/stack"},
541 "Const",
542 {},
543 {{"value", kZero}, {"dtype", DT_INT32}}},
544
545 {{"strided_slice/stack_1"},
546 "Const",
547 {},
548 {{"value", kOne}, {"dtype", DT_INT32}}},
549
550 {{"strided_slice/stack_2"},
551 "Const",
552 {},
553 {{"value", kOne}, {"dtype", DT_INT32}}},
554
555 {{"strided_slice"},
556 "StridedSlice",
557 {"Less", "strided_slice/stack", "strided_slice/stack_1",
558 "strided_slice/stack_2"},
559 {{"Index", DT_INT32},
560 {"T", DT_BOOL},
561 {"begin_mask", 0},
562 {"ellipsis_mask", 0},
563 {"end_mask", 0},
564 {"new_axis_mask", 0},
565 {"shrink_axis_mask", 0}}}});
566 }
567
MakeRangeDataset()568 FunctionDef MakeRangeDataset() {
569 return FDH::Define(
570 /*name=*/"MakeRangeDataset",
571 /*arg_def=*/{"start: int64", "stop: int64", "step: int64"},
572 /*ret_def=*/{"y:variant"},
573 /*attr_def=*/
574 {"output_types: list(type) >= 1", "output_shapes: list(shape) >= 1"},
575 /*node_def=*/
576 {{/*ret=*/{"y"},
577 /*op=*/"RangeDataset",
578 /*arg=*/{"start", "stop", "step"},
579 /*attr=*/
580 {{"output_types", "$output_types"},
581 {"output_shapes", "$output_shapes"}}}});
582 }
583
MakeBatchDataset()584 FunctionDef MakeBatchDataset() {
585 return FDH::Define(
586 /*name=*/"MakeBatchDataset",
587 /*arg_def=*/
588 {"input_dataset: variant", "batch_size: int64", "drop_remainder: bool"},
589 /*ret_def=*/{"y: variant"},
590 /*attr_def=*/
591 {"parallel_copy: bool = false", "output_types: list(type) >= 1",
592 "output_shapes: list(shape) >= 1"},
593 /*node_def=*/
594 {{/*ret=*/{"y"},
595 /*op=*/"BatchDatasetV2",
596 /*arg=*/{"input_dataset", "batch_size", "drop_remainder"},
597 /*attr=*/
598 {{"parallel_copy", "$parallel_copy"},
599 {"output_types", "$output_types"},
600 {"output_shapes", "$output_shapes"}}}});
601 }
602
MakeMapDataset(bool has_other_args)603 FunctionDef MakeMapDataset(bool has_other_args) {
604 std::vector<string> args = {"input_dataset: variant"};
605 std::vector<string> inputs = {"input_dataset"};
606 if (has_other_args) {
607 args.emplace_back("other_arguments: Targuments");
608 inputs.emplace_back("other_arguments");
609 }
610
611 return FDH::Define(
612 /*name=*/"MakeMapDataset",
613 /*arg_def=*/args,
614 /*ret_def=*/
615 {"y: variant"},
616 /*attr_def=*/
617 {"f: func", "Targuments: list(type) >= 0",
618 "output_types: list(type) >= 1", "output_shapes: list(shape) >= 1",
619 "use_inter_op_parallelism: bool = true",
620 "preserve_cardinality: bool = false"},
621 /*node_def=*/
622 {{/*ret=*/{"y"},
623 /*op=*/"MapDataset",
624 /*arg=*/inputs,
625 /*attr=*/
626 {{"f", "$f"},
627 {"Targuments", "$Targuments"},
628 {"output_types", "$output_types"},
629 {"output_shapes", "$output_shapes"},
630 {"use_inter_op_parallelism", "$use_inter_op_parallelism"},
631 {"preserve_cardinality", "$preserve_cardinality"}}}});
632 }
633
MakeTakeDataset()634 FunctionDef MakeTakeDataset() {
635 return FDH::Define(
636 // Name
637 "TakeDataset",
638 // Args
639 {"input_dataset: variant", "count: int64"},
640 // Return values
641 {"y:variant"},
642 // Attr def
643 {"output_types: list(type) >= 1", "output_shapes: list(shape) >= 1"},
644 // Nodes
645 {{{"y"},
646 "TakeDataset",
647 {"input_dataset", "count"},
648 {{"output_types", "$output_types"},
649 {"output_shapes", "$output_shapes"}}}});
650 }
651
MakeTensorSliceDataset()652 FunctionDef MakeTensorSliceDataset() {
653 return FDH::Define(
654 // Name
655 "MakeTensorSliceDataset",
656 // Args
657 {"x: Toutput_types"},
658 // Return values
659 {"y: variant"},
660 // Attr def
661 {"Toutput_types: list(type) >= 1", "output_shapes: list(shape) >= 1"},
662 // Nodes
663 {{{"y"},
664 "TensorSliceDataset",
665 {"x"},
666 {{"Toutput_types", "$Toutput_types"},
667 {"output_shapes", "$output_shapes"}}}});
668 }
669
Unique()670 FunctionDef Unique() {
671 return FDH::Create(
672 // Name
673 "GetUnique",
674 // Args
675 {"x:T"},
676 // Return values
677 {"y:T", "idx: out_idx"},
678 // Attr def
679 {"T: type", "out_idx: {int32, int64} = DT_INT32"},
680 // Nodes
681 {
682 {{"result"}, "Unique", {"x"}, {{"T", "$T"}, {"out_idx", "$out_idx"}}},
683 },
684 {{"y", "result:y:0"}, {"idx", "result:idx:0"}});
685 }
686
FunctionTestSchedClosure(std::function<void ()> fn)687 void FunctionTestSchedClosure(std::function<void()> fn) {
688 static thread::ThreadPool* w =
689 new thread::ThreadPool(Env::Default(), "Test", 8);
690 w->Schedule(std::move(fn));
691 }
692
693 } // end namespace function
694 } // end namespace test
695 } // end namespace tensorflow
696