• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // DEPRECATED: Use the C++ API defined in tensorflow/cc instead.
17 
18 #ifndef TENSORFLOW_CORE_GRAPH_TESTLIB_H_
19 #define TENSORFLOW_CORE_GRAPH_TESTLIB_H_
20 
21 #include <string>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/graph/types.h"
28 #include "tensorflow/core/platform/types.h"
29 
30 namespace tensorflow {
31 namespace test {
32 namespace graph {
33 
34 // Converts "g" into its corresponding GraphDef "def".
35 ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
36 void ToGraphDef(Graph* g, GraphDef* def);
37 
38 // A few helpers to construct a graph.
39 
40 // Adds a node in "g" producing a constant "tensor".
41 Node* Constant(Graph* g, const Tensor& tensor);
42 Node* Constant(Graph* g, const Tensor& tensor, const string& name);
43 
44 // Adds a node in "g" producing a constant "tensor" on the host.
45 // The given node which, unlike the regular Constant above, always
46 // stores its output on the host.  This is necessary for use
47 // in GPU tests where the test Op in question runs on the device
48 // but requires some arguments to be pinned to the host.
49 Node* HostConstant(Graph* g, const Tensor& tensor);
50 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name);
51 
52 // Adds a variable in "g" of the given "shape" and "dtype".
53 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape);
54 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
55           const string& name);
56 
57 // Adds an assign node in "g" which assigns "val" into "var".
58 Node* Assign(Graph* g, Node* var, Node* val);
59 
60 // Adds a send node "g" sending "input" as a named "tensor" from
61 // "sender" to "receiver".
62 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
63            const uint64 sender_incarnation, const string& receiver);
64 
65 // Adds a recv node in "g" receiving a named "tensor" from "sender"
66 // to "receiver".
67 Node* Recv(Graph* g, const string& tensor, const string& type,
68            const string& sender, const uint64 sender_incarnation,
69            const string& receiver);
70 
71 // Adds a cumsum "node" in "g" doing cumsum(data, axes).
72 Node* Cumsum(Graph* g, Node* data, Node* axes, bool exclusive = false,
73              bool reverse = false);
74 
75 // Adds a reduction "node" in "g" doing sum(data, axes).  "reduce" is
76 // a reduction, e.g., Sum, Max, Min, Mean, etc.
77 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
78              bool keep_dims = false);
79 
80 // Adds a Matmul node in g doing in0.contract(in1).
81 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
82              bool transpose_b);
83 
84 // Adds a Matmul node in g doing in0.contract(in1).
85 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y);
86 
87 // Adds a Quantize node into g that quantize floats into QUINT8. The range of
88 // the input float tensor is assumed to be [-1, 1].
89 Node* QuantizeToUINT8(Graph* g, Node* data);
90 
91 // Adds a unary function "func" "node" in "g" taking "input".
92 Node* Unary(Graph* g, const string& func, Node* input, int index = 0);
93 
94 // Adds an identity node in "g" taking "input" and producing an
95 // identity copy.
96 Node* Identity(Graph* g, Node* input, int index = 0);
97 
98 // Adds a binary function "func" node in "g" taking "in0" and "in1".
99 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1);
100 
101 // Adds a function "func" node in "g" taking inputs "ins".
102 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
103 
104 // Adds a binary add node in "g" doing in0 + in1.
105 Node* Add(Graph* g, Node* in0, Node* in1);
106 
107 // Reverses <axis> dimensions of <tensor>>
108 Node* Reverse(Graph* g, Node* tensor, Node* axis);
109 
110 // Generates random unit uniform distribution of the input shape.
111 Node* RandomUniform(Graph* g, Node* input, DataType dtype);
112 
113 // Generates random unit normal distribution of the input shape.
114 Node* RandomGaussian(Graph* g, Node* input, DataType dtype);
115 
116 // Generates random gamma distribution with the given shape and alpha[s].
117 // Output dtype determined by alpha.
118 Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
119 
120 // Generates random poisson distribution with the given shape and lam[s].
121 // Output dtype determined by lam.
122 Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
123 
124 // Rolls tensor by an offset of <shift> along the corresponding
125 // <axis> dimensions.
126 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis);
127 
128 // Generates random parameters from the truncated standard normal distribution
129 // of the input shape
130 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);
131 
132 // Adds an error node in "g". The node's computation always
133 // generates an error with the given error message "errmsg".
134 Node* Error(Graph* g, Node* input, const string& errmsg,
135             bool log_error = false);
136 
137 // Adds a node that generates a invalid ref output.
138 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type);
139 
140 // Adds a node in "g". Its Compute() sleeps a while and outputs the
141 // input (i.e., same as identity).
142 Node* Delay(Graph* g, Node* input, Microseconds delay_micros);
143 
144 // Adds a no-op "node" in "g", with control inputs from all nodes in
145 // control_inputs vector.
146 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs);
147 
148 // Adds a Switch node in "g". If "in1" is true, it forwards "in0" to
149 // output 1. Otherwise, it forwards "in0" to output 0.
150 Node* Switch(Graph* g, Node* in0, Node* in1);
151 
152 // Adds an Enter node in "g", which enters a new frame.
153 Node* Enter(Graph* g, Node* input, const string& frame_name);
154 
155 // Adds an Exit node in "g", which exits a frame.
156 Node* Exit(Graph* g, Node* input);
157 
158 // Adds a Merge node in "g" with two inputs "in0" and "in1".
159 Node* Merge(Graph* g, Node* in0, Node* in1);
160 
161 // Adds a Merge node in "g". The first input is "in0", the remaining
162 // inputs are only given by their names in remaining_in.
163 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in);
164 
165 // Adds a NextIteration node in "g", which makes its input available
166 // to the next iteration.
167 Node* Next(Graph* g, const string& name, Node* input);
168 
169 // Adds a LoopCond node in "g", representing the "pivot" termination
170 // condition of a loop.
171 Node* LoopCond(Graph* g, Node* input);
172 
173 // Adds a less node in "g", which returns true iff "in0" < "in1".
174 Node* Less(Graph* g, Node* in0, Node* in1);
175 
176 // Adds a select node in "g", which outputs either "inx" or "iny"
177 // depending on the boolean value of "c".
178 Node* Select(Graph* g, Node* c, Node* inx, Node* iny);
179 
180 // Casts "in" into data type "dst".
181 Node* Cast(Graph* g, Node* in, DataType dst);
182 
183 // Perform gather op on params "in0" with indices "in1" and axis "axis".
184 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis);
185 
186 // Gets a tensor stored in the session state.
187 Node* GetSessionTensor(Graph* g, Node* in);
188 
189 // Adds a Concat node in "g". The first input is "concat_dim", the
190 // dimension to concatenate on, and the tensors to concatenate are
191 // given in "tensors".
192 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors);
193 
194 // Adds a ConcatV2 node in "g". The last input is "concat_dim", the
195 // dimension to concatenate on, and the tensors to concatenate are
196 // given in "tensors".
197 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim);
198 
199 // Add a Relu node in "g".
200 Node* Relu(Graph* g, Node* in);
201 
202 // Add a Relu6 node in "g".
203 Node* Relu6(Graph* g, Node* in);
204 
205 // Add a BiasAdd node in "g".
206 Node* BiasAdd(Graph* g, Node* value, Node* bias);
207 
208 // Add a Conv2D node in "g".
209 Node* Conv2D(Graph* g, Node* in0, Node* in1);
210 
211 // Add a Diag node in "g".
212 Node* Diag(Graph* g, Node* in, DataType type);
213 
214 // Add a DiagPart node in "g".
215 Node* DiagPart(Graph* g, Node* in, DataType type);
216 
217 // Add a CheckNumerics node in "g".
218 Node* CheckNumerics(Graph* g, Node* in, const string& message);
219 
220 // Add an _Arg node in "g".
221 Node* Arg(Graph* g, int64 index, DataType type);
222 
223 // Add a _Retval node in "g".
224 Node* Retval(Graph* g, int64 index, Node* in);
225 
226 }  // end namespace graph
227 }  // end namespace test
228 }  // end namespace tensorflow
229 
230 #endif  // TENSORFLOW_CORE_GRAPH_TESTLIB_H_
231