• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace test {
31 namespace function {
32 
33 // A helper class to make AttrSlice from initializer lists
34 class Attrs {
35  public:
Attrs(const std::initializer_list<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)36   Attrs(const std::initializer_list<  // NOLINT(runtime/explicit)
37         std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
38     for (const auto& aval : attrs) {
39       map_.insert({aval.first, aval.second.proto});
40     }
41   }
42 
Attrs(const std::vector<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)43   Attrs(
44       const std::vector<std::pair<string, FunctionDefHelper::AttrValueWrapper>>&
45           attrs) {
46     for (const auto& aval : attrs) {
47       map_.insert({aval.first, aval.second.proto});
48     }
49   }
50 
AttrSlice()51   operator AttrSlice() { return AttrSlice(&map_); }  // NOLINT(runtime/explicit)
52 
53  private:
54   AttrValueMap map_;
55 };
56 
57 // Helper to construct a NodeDef.
//
// Arguments, as named in the signature: `name`, `op`, `inputs`, an optional
// list of (attribute name, attribute value wrapper) pairs in `attrs` (empty by
// default), and an optional `device` string (empty by default).
// NOTE(review): how each argument is applied to the returned NodeDef is decided
// by the definition, which is not part of this header -- confirm there.
58 NodeDef NDef(
59     StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
60     gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
61         attrs = {},
62     const string& device = "");
63 
64 // Helper to construct a GraphDef proto.
//
// `nodes` supplies NodeDefs and `funcs` (empty by default) supplies
// FunctionDefs.
// NOTE(review): presumably these become the graph's nodes and its function
// library; the definition is not part of this header -- confirm there.
65 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
66               gtl::ArraySlice<FunctionDef> funcs = {});
67 
68 // For testing convenience, we provide a few simple functions that can
69 // be easily executed and tested.
//
// Each comment below summarizes a function as `inputs -> outputs`, written as
// `name: type` pairs. Only declarations appear in this header; the function
// bodies are defined elsewhere.
70 
71 // x: T -> x * 2.
72 FunctionDef XTimesTwo();
73 
74 // x: T -> cpu(x * 2) + cpu(x * 3).
75 FunctionDef TwoDeviceTimesFive();
76 
77 // x: T -> cpu(x * 2), gpu(x * 3).
78 FunctionDef TwoDeviceMult();
79 
80 // cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3).
81 FunctionDef TwoDeviceInputOutput();
82 
83 // Function taking a list of Tensors as input.
84 FunctionDef FuncWithListInput();
85 
86 // Function returning a list of Tensors as output.
87 FunctionDef FuncWithListOutput();
88 
89 // x: T -> x + x.
90 FunctionDef XAddX();
91 
92 // x: T, y: T -> x + y.
93 FunctionDef XAddY();
94 
95 // x: T -> x * 2, where x is int32.
96 FunctionDef XTimesTwoInt32();
97 
98 // x: T -> (x * 2) * 2.
99 FunctionDef XTimesFour();
100 
101 // x: T -> ((x * 2) * 2) * 2.
102 FunctionDef XTimes16();
103 
104 // w: T, x: T, b: T -> MatMul(w, x) + b
105 FunctionDef WXPlusB();
106 
107 // x: T -> x: T, where T is a type that is automatically converted to a bool.
108 FunctionDef NonZero();
109 
110 // x: T -> bool.
111 FunctionDef IsZero();
112 
113 // x: T -> int64
114 FunctionDef RandomUniform();
115 
116 // x: T, y: T -> y: T, x: T
117 FunctionDef Swap();
118 
119 // x: T, y: T -> y: T, x: T, the body has no nodes.
120 FunctionDef EmptyBodySwap();
121 
122 // x: float, y: resource -> y: resource, 2*x: float.
123 FunctionDef ResourceOutput();
124 
125 // x: resource -> x: resource
126 FunctionDef ResourceIdentity();
127 
128 // x: resource -> y: float.
129 FunctionDef ReadResourceVariable();
130 
131 // Contains malformed control flow which can't be run by the executor.
132 FunctionDef InvalidControlFlow();
133 
134 // x: T -> x <= N.
135 FunctionDef LessThanOrEqualToN(int64 N);
136 
137 // x: T, y: T -> x + 1, x * y
138 FunctionDef XPlusOneXTimesY();
139 
140 // x: T, y: T -> x <= N
141 FunctionDef XYXLessThanOrEqualToN(int64 N);
142 
143 // x: T -> bool
144 FunctionDef RandomUniformLess();
145 
146 // start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset
147 FunctionDef MakeRangeDataset();
148 
149 // input_dataset: variant, batch_size: int64, drop_remainder: bool
150 // -> y: BatchDatasetV2::Dataset
151 FunctionDef MakeBatchDataset();
152 
153 // input_dataset: variant, other_arguments: Targuments, f: func,
154 // Targuments: list(type), output_types: list(type), output_shapes: list(shape),
155 // use_inter_op_parallelism: bool, preserve_cardinality: bool
156 // -> y: MapDatasetOp::Dataset
157 FunctionDef MakeMapDataset(bool has_other_args);
158 
159 // input_dataset: variant, count: int64 -> y: TakeDataset::Dataset
160 FunctionDef MakeTakeDataset();
161 
162 // x: T -> y: TensorSliceDatasetOp::Dataset
163 FunctionDef MakeTensorSliceDataset();
164 
165 // x: T -> y: T, idx: out_idx
166 FunctionDef Unique();
167 
// Takes a callable `fn` that accepts no arguments and returns nothing.
// NOTE(review): the name suggests this schedules `fn` for execution in tests;
// the definition is not part of this header, so how and where `fn` is run
// needs to be confirmed there.
168 void FunctionTestSchedClosure(std::function<void()> fn);
169 
170 }  // end namespace function
171 }  // end namespace test
172 }  // end namespace tensorflow
173 
174 #endif  // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_
175