1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/testlib.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def_builder.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/graph/node_builder.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/platform/logging.h"
30
31 namespace tensorflow {
32 namespace test {
33 namespace graph {
34
Send(Graph * g,Node * input,const string & tensor,const string & sender,const uint64 sender_incarnation,const string & receiver)35 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
36 const uint64 sender_incarnation, const string& receiver) {
37 Node* ret;
38 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
39 .Input(input, 0)
40 .Attr("tensor_name", tensor)
41 .Attr("send_device", sender)
42 .Attr("send_device_incarnation",
43 static_cast<int64>(sender_incarnation))
44 .Attr("recv_device", receiver)
45 .Finalize(g, &ret));
46 return ret;
47 }
48
Recv(Graph * g,const string & tensor,const string & type,const string & sender,const uint64 sender_incarnation,const string & receiver)49 Node* Recv(Graph* g, const string& tensor, const string& type,
50 const string& sender, const uint64 sender_incarnation,
51 const string& receiver) {
52 Node* ret;
53 DataType dtype;
54 CHECK(DataTypeFromString(type, &dtype));
55 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
56 .Attr("tensor_type", dtype)
57 .Attr("tensor_name", tensor)
58 .Attr("send_device", sender)
59 .Attr("send_device_incarnation",
60 static_cast<int64>(sender_incarnation))
61 .Attr("recv_device", receiver)
62 .Finalize(g, &ret));
63 return ret;
64 }
65
Constant(Graph * g,const Tensor & tensor)66 Node* Constant(Graph* g, const Tensor& tensor) {
67 Node* ret;
68 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
69 .Attr("dtype", tensor.dtype())
70 .Attr("value", tensor)
71 .Finalize(g, &ret));
72 return ret;
73 }
74
Constant(Graph * g,const Tensor & tensor,const string & name)75 Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
76 Node* ret;
77 TF_CHECK_OK(NodeBuilder(name, "Const")
78 .Attr("dtype", tensor.dtype())
79 .Attr("value", tensor)
80 .Finalize(g, &ret));
81 return ret;
82 }
83
HostConstant(Graph * g,const Tensor & tensor)84 Node* HostConstant(Graph* g, const Tensor& tensor) {
85 return HostConstant(g, tensor, g->NewName("n"));
86 }
87
HostConstant(Graph * g,const Tensor & tensor,const string & name)88 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
89 Node* ret;
90 TF_CHECK_OK(NodeBuilder(name, "HostConst")
91 .Attr("dtype", tensor.dtype())
92 .Attr("value", tensor)
93 .Finalize(g, &ret));
94 return ret;
95 }
96
Var(Graph * g,const DataType dtype,const TensorShape & shape)97 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
98 Node* ret;
99 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
100 .Attr("dtype", dtype)
101 .Attr("shape", shape)
102 .Finalize(g, &ret));
103 return ret;
104 }
105
Var(Graph * g,const DataType dtype,const TensorShape & shape,const string & name)106 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
107 const string& name) {
108 Node* ret;
109 TF_CHECK_OK(NodeBuilder(name, "Variable")
110 .Attr("dtype", dtype)
111 .Attr("shape", shape)
112 .Finalize(g, &ret));
113 return ret;
114 }
115
Assign(Graph * g,Node * var,Node * val)116 Node* Assign(Graph* g, Node* var, Node* val) {
117 Node* ret;
118 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
119 .Input(var)
120 .Input(val)
121 .Attr("use_locking", true)
122 .Finalize(g, &ret));
123 return ret;
124 }
125
Cumsum(Graph * g,Node * data,Node * axes,bool exclusive,bool reverse)126 Node* Cumsum(Graph* g, Node* data, Node* axes, bool exclusive, bool reverse) {
127 Node* ret;
128 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cumsum")
129 .Input(data)
130 .Input(axes)
131 .Attr("exclusive", exclusive)
132 .Attr("reverse", reverse)
133 .Finalize(g, &ret));
134 return ret;
135 }
136
Reduce(Graph * g,const string & reduce,Node * data,Node * axes,bool keep_dims)137 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
138 bool keep_dims) {
139 Node* ret;
140 TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
141 .Input(data)
142 .Input(axes)
143 .Attr("keep_dims", keep_dims)
144 .Finalize(g, &ret));
145 return ret;
146 }
147
QuantizeToUINT8(Graph * g,Node * data)148 Node* QuantizeToUINT8(Graph* g, Node* data) {
149 Node* ret;
150 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
151 .Input(data)
152 .Attr("T", DT_QUINT8)
153 .Attr("max_range", 1.0f)
154 .Attr("min_range", -1.0f)
155 .Finalize(g, &ret));
156 return ret;
157 }
158
Matmul(Graph * g,Node * in0,Node * in1,bool transpose_a,bool transpose_b)159 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
160 bool transpose_b) {
161 Node* ret;
162 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
163 .Input(in0)
164 .Input(in1)
165 .Attr("transpose_a", transpose_a)
166 .Attr("transpose_b", transpose_b)
167 .Finalize(g, &ret));
168 return ret;
169 }
170
BatchMatmul(Graph * g,Node * in0,Node * in1,bool adj_x,bool adj_y)171 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
172 Node* ret;
173 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul")
174 .Input(in0)
175 .Input(in1)
176 .Attr("adj_x", adj_x)
177 .Attr("adj_y", adj_y)
178 .Finalize(g, &ret));
179 return ret;
180 }
181
RandomNumberGenerator(const string & op,Graph * g,Node * input,DataType dtype)182 Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
183 DataType dtype) {
184 Node* ret;
185 TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
186 .Input(input)
187 .Attr("dtype", dtype)
188 .Attr("seed", 0)
189 .Finalize(g, &ret));
190 return ret;
191 }
192
RandomUniform(Graph * g,Node * input,DataType dtype)193 Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
194 return RandomNumberGenerator("RandomUniform", g, input, dtype);
195 }
196
RandomGaussian(Graph * g,Node * input,DataType dtype)197 Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
198 return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
199 }
200
TruncatedNormal(Graph * g,Node * input,DataType dtype)201 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
202 return RandomNumberGenerator("TruncatedNormal", g, input, dtype);
203 }
204
RandomGamma(Graph * g,Node * shape,Node * alpha)205 Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
206 Node* ret;
207 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma")
208 .Input(shape)
209 .Input(alpha)
210 .Attr("seed", 0)
211 .Finalize(g, &ret));
212 return ret;
213 }
214
RandomPoisson(Graph * g,Node * shape,Node * lam)215 Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
216 Node* ret;
217 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
218 .Input(shape)
219 .Input(lam)
220 .Attr("seed", 0)
221 .Finalize(g, &ret));
222 return ret;
223 }
224
Unary(Graph * g,const string & func,Node * input,int index)225 Node* Unary(Graph* g, const string& func, Node* input, int index) {
226 Node* ret;
227 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
228 .Input(input, index)
229 .Finalize(g, &ret));
230 return ret;
231 }
232
Binary(Graph * g,const string & func,Node * in0,Node * in1)233 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
234 Node* ret;
235 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
236 .Input(in0)
237 .Input(in1)
238 .Finalize(g, &ret));
239 return ret;
240 }
241
Multi(Graph * g,const string & func,gtl::ArraySlice<Node * > ins)242 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
243 Node* ret;
244 auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
245 for (Node* n : ins) b = b.Input(n);
246 TF_CHECK_OK(b.Finalize(g, &ret));
247 return ret;
248 }
249
Identity(Graph * g,Node * input,int index)250 Node* Identity(Graph* g, Node* input, int index) {
251 Node* ret;
252 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
253 .Input(input, index)
254 .Finalize(g, &ret));
255 return ret;
256 }
257
Add(Graph * g,Node * in0,Node * in1)258 Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
259
Reverse(Graph * g,Node * tensor,Node * axis)260 Node* Reverse(Graph* g, Node* tensor, Node* axis) {
261 return Binary(g, "ReverseV2", tensor, axis);
262 }
263
Roll(Graph * g,Node * input,Node * shift,Node * axis)264 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
265 Node* ret;
266 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
267 .Input(input)
268 .Input(shift)
269 .Input(axis)
270 .Finalize(g, &ret));
271 return ret;
272 }
273
Error(Graph * g,Node * input,const string & errmsg,bool log_error)274 Node* Error(Graph* g, Node* input, const string& errmsg, bool log_error) {
275 Node* ret;
276 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
277 .Input(input)
278 .Attr("message", errmsg)
279 .Attr("log_error", log_error)
280 .Finalize(g, &ret));
281 return ret;
282 }
283
InvalidRefType(Graph * g,DataType out_type,DataType invalid_type)284 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
285 DCHECK(out_type != invalid_type);
286 Node* ret;
287 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
288 .Attr("TIn", out_type)
289 .Attr("TOut", invalid_type)
290 .Finalize(g, &ret));
291 return ret;
292 }
293
Delay(Graph * g,Node * input,Microseconds delay_micros)294 Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
295 Node* ret;
296 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
297 .Input(input)
298 .Attr("micros", delay_micros.value())
299 .Finalize(g, &ret));
300 return ret;
301 }
302
NoOp(Graph * g,const std::vector<Node * > & control_inputs)303 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
304 Node* ret;
305 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
306 .ControlInputs(control_inputs)
307 .Finalize(g, &ret));
308 return ret;
309 }
310
Switch(Graph * g,Node * in0,Node * in1)311 Node* Switch(Graph* g, Node* in0, Node* in1) {
312 Node* ret;
313 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
314 .Input(in0)
315 .Input(in1)
316 .Finalize(g, &ret));
317 return ret;
318 }
319
Enter(Graph * g,Node * input,const string & frame_name)320 Node* Enter(Graph* g, Node* input, const string& frame_name) {
321 Node* ret;
322 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
323 .Input(input)
324 .Attr("frame_name", frame_name)
325 .Finalize(g, &ret));
326 return ret;
327 }
328
Exit(Graph * g,Node * input)329 Node* Exit(Graph* g, Node* input) {
330 Node* ret;
331 TF_CHECK_OK(
332 NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
333 return ret;
334 }
335
Merge(Graph * g,Node * in0,Node * in1)336 Node* Merge(Graph* g, Node* in0, Node* in1) {
337 Node* ret;
338 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
339 .Input({in0, in1})
340 .Finalize(g, &ret));
341 return ret;
342 }
343
Merge(Graph * g,Node * in0,gtl::ArraySlice<string> remaining_in)344 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
345 std::vector<NodeBuilder::NodeOut> inputs;
346 inputs.reserve(remaining_in.size() + 1);
347 inputs.emplace_back(in0);
348 for (const string& in_name : remaining_in) {
349 inputs.emplace_back(in_name, 0, inputs[0].dt);
350 }
351
352 Node* ret;
353 TF_CHECK_OK(
354 NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
355 return ret;
356 }
357
Concat(Graph * g,Node * concat_dim,gtl::ArraySlice<Node * > tensors)358 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) {
359 std::vector<NodeBuilder::NodeOut> nodeouts;
360 nodeouts.reserve(tensors.size());
361 for (auto const t : tensors) {
362 nodeouts.emplace_back(t);
363 }
364 Node* ret;
365 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
366 .Input(concat_dim)
367 .Input(nodeouts)
368 .Finalize(g, &ret));
369 return ret;
370 }
371
ConcatV2(Graph * g,gtl::ArraySlice<Node * > tensors,Node * concat_dim)372 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) {
373 std::vector<NodeBuilder::NodeOut> nodeouts;
374 nodeouts.reserve(tensors.size());
375 for (auto const t : tensors) {
376 nodeouts.emplace_back(t);
377 }
378 Node* ret;
379 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2")
380 .Input(nodeouts)
381 .Input(concat_dim)
382 .Finalize(g, &ret));
383 return ret;
384 }
385
Next(Graph * g,const string & name,Node * input)386 Node* Next(Graph* g, const string& name, Node* input) {
387 Node* ret;
388 TF_CHECK_OK(
389 NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
390 return ret;
391 }
392
LoopCond(Graph * g,Node * input)393 Node* LoopCond(Graph* g, Node* input) {
394 Node* ret;
395 TF_CHECK_OK(
396 NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
397 return ret;
398 }
399
Less(Graph * g,Node * in0,Node * in1)400 Node* Less(Graph* g, Node* in0, Node* in1) {
401 return Binary(g, "Less", in0, in1);
402 }
403
Select(Graph * g,Node * c,Node * inx,Node * iny)404 Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
405 Node* ret;
406 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
407 .Input(c)
408 .Input(inx)
409 .Input(iny)
410 .Finalize(g, &ret));
411 return ret;
412 }
413
Cast(Graph * g,Node * in,DataType dst)414 Node* Cast(Graph* g, Node* in, DataType dst) {
415 Node* ret;
416 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
417 .Input(in)
418 .Attr("DstT", dst)
419 .Finalize(g, &ret));
420 return ret;
421 }
422
Gather(Graph * g,Node * in0,Node * in1,Node * axis)423 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) {
424 Node* ret;
425 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2")
426 .Input(in0)
427 .Input(in1)
428 .Input(axis)
429 .Finalize(g, &ret));
430 return ret;
431 }
432
GetSessionTensor(Graph * g,Node * in)433 Node* GetSessionTensor(Graph* g, Node* in) {
434 Node* ret;
435 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
436 .Input(in, 0)
437 .Attr("dtype", DT_FLOAT)
438 .Finalize(g, &ret));
439 return ret;
440 }
441
Relu(Graph * g,Node * in)442 Node* Relu(Graph* g, Node* in) {
443 Node* ret;
444 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
445 .Input(in, 0)
446 .Attr("T", DT_FLOAT)
447 .Finalize(g, &ret));
448 return ret;
449 }
450
Relu6(Graph * g,Node * in)451 Node* Relu6(Graph* g, Node* in) {
452 Node* ret;
453 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6")
454 .Input(in, 0)
455 .Attr("T", DT_FLOAT)
456 .Finalize(g, &ret));
457 return ret;
458 }
459
BiasAdd(Graph * g,Node * value,Node * bias)460 Node* BiasAdd(Graph* g, Node* value, Node* bias) {
461 Node* ret;
462 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd")
463 .Input(value)
464 .Input(bias)
465 .Attr("T", DT_FLOAT)
466 .Finalize(g, &ret));
467 return ret;
468 }
469
Conv2D(Graph * g,Node * in0,Node * in1)470 Node* Conv2D(Graph* g, Node* in0, Node* in1) {
471 Node* ret;
472 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D")
473 .Input(in0)
474 .Input(in1)
475 .Attr("T", DT_FLOAT)
476 .Attr("strides", {1, 1, 1, 1})
477 .Attr("padding", "SAME")
478 .Finalize(g, &ret));
479 return ret;
480 }
481
Diag(Graph * g,Node * in,DataType type)482 Node* Diag(Graph* g, Node* in, DataType type) {
483 Node* ret;
484 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag")
485 .Input(in)
486 .Attr("T", type)
487 .Finalize(g, &ret));
488 return ret;
489 }
490
DiagPart(Graph * g,Node * in,DataType type)491 Node* DiagPart(Graph* g, Node* in, DataType type) {
492 Node* ret;
493 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart")
494 .Input(in)
495 .Attr("T", type)
496 .Finalize(g, &ret));
497 return ret;
498 }
499
CheckNumerics(Graph * g,Node * in,const string & message)500 Node* CheckNumerics(Graph* g, Node* in, const string& message) {
501 Node* ret;
502 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
503 .Input(in)
504 .Attr("message", message)
505 .Finalize(g, &ret));
506 return ret;
507 }
508
Arg(Graph * g,int64 index,DataType type)509 Node* Arg(Graph* g, int64 index, DataType type) {
510 Node* ret;
511 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
512 .Attr("T", type)
513 .Attr("index", index)
514 .Finalize(g, &ret));
515 return ret;
516 }
517
Retval(Graph * g,int64 index,Node * in)518 Node* Retval(Graph* g, int64 index, Node* in) {
519 Node* ret;
520 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
521 .Input(in)
522 .Attr("index", index)
523 .Finalize(g, &ret));
524 return ret;
525 }
526
ToGraphDef(Graph * g,GraphDef * gdef)527 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
528
529 } // end namespace graph
530 } // end namespace test
531 } // end namespace tensorflow
532