• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/testlib.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def_builder.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/graph/node_builder.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/platform/logging.h"
30 
31 namespace tensorflow {
32 namespace test {
33 namespace graph {
34 
Send(Graph * g,Node * input,const string & tensor,const string & sender,const uint64 sender_incarnation,const string & receiver)35 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
36            const uint64 sender_incarnation, const string& receiver) {
37   Node* ret;
38   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
39                   .Input(input, 0)
40                   .Attr("tensor_name", tensor)
41                   .Attr("send_device", sender)
42                   .Attr("send_device_incarnation",
43                         static_cast<int64>(sender_incarnation))
44                   .Attr("recv_device", receiver)
45                   .Finalize(g, &ret));
46   return ret;
47 }
48 
Recv(Graph * g,const string & tensor,const string & type,const string & sender,const uint64 sender_incarnation,const string & receiver)49 Node* Recv(Graph* g, const string& tensor, const string& type,
50            const string& sender, const uint64 sender_incarnation,
51            const string& receiver) {
52   Node* ret;
53   DataType dtype;
54   CHECK(DataTypeFromString(type, &dtype));
55   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
56                   .Attr("tensor_type", dtype)
57                   .Attr("tensor_name", tensor)
58                   .Attr("send_device", sender)
59                   .Attr("send_device_incarnation",
60                         static_cast<int64>(sender_incarnation))
61                   .Attr("recv_device", receiver)
62                   .Finalize(g, &ret));
63   return ret;
64 }
65 
Constant(Graph * g,const Tensor & tensor)66 Node* Constant(Graph* g, const Tensor& tensor) {
67   Node* ret;
68   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
69                   .Attr("dtype", tensor.dtype())
70                   .Attr("value", tensor)
71                   .Finalize(g, &ret));
72   return ret;
73 }
74 
Constant(Graph * g,const Tensor & tensor,const string & name)75 Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
76   Node* ret;
77   TF_CHECK_OK(NodeBuilder(name, "Const")
78                   .Attr("dtype", tensor.dtype())
79                   .Attr("value", tensor)
80                   .Finalize(g, &ret));
81   return ret;
82 }
83 
HostConstant(Graph * g,const Tensor & tensor)84 Node* HostConstant(Graph* g, const Tensor& tensor) {
85   return HostConstant(g, tensor, g->NewName("n"));
86 }
87 
HostConstant(Graph * g,const Tensor & tensor,const string & name)88 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
89   Node* ret;
90   TF_CHECK_OK(NodeBuilder(name, "HostConst")
91                   .Attr("dtype", tensor.dtype())
92                   .Attr("value", tensor)
93                   .Finalize(g, &ret));
94   return ret;
95 }
96 
Var(Graph * g,const DataType dtype,const TensorShape & shape)97 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
98   Node* ret;
99   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
100                   .Attr("dtype", dtype)
101                   .Attr("shape", shape)
102                   .Finalize(g, &ret));
103   return ret;
104 }
105 
Var(Graph * g,const DataType dtype,const TensorShape & shape,const string & name)106 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
107           const string& name) {
108   Node* ret;
109   TF_CHECK_OK(NodeBuilder(name, "Variable")
110                   .Attr("dtype", dtype)
111                   .Attr("shape", shape)
112                   .Finalize(g, &ret));
113   return ret;
114 }
115 
Assign(Graph * g,Node * var,Node * val)116 Node* Assign(Graph* g, Node* var, Node* val) {
117   Node* ret;
118   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
119                   .Input(var)
120                   .Input(val)
121                   .Attr("use_locking", true)
122                   .Finalize(g, &ret));
123   return ret;
124 }
125 
Cumsum(Graph * g,Node * data,Node * axes,bool exclusive,bool reverse)126 Node* Cumsum(Graph* g, Node* data, Node* axes, bool exclusive, bool reverse) {
127   Node* ret;
128   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cumsum")
129                   .Input(data)
130                   .Input(axes)
131                   .Attr("exclusive", exclusive)
132                   .Attr("reverse", reverse)
133                   .Finalize(g, &ret));
134   return ret;
135 }
136 
Reduce(Graph * g,const string & reduce,Node * data,Node * axes,bool keep_dims)137 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
138              bool keep_dims) {
139   Node* ret;
140   TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
141                   .Input(data)
142                   .Input(axes)
143                   .Attr("keep_dims", keep_dims)
144                   .Finalize(g, &ret));
145   return ret;
146 }
147 
QuantizeToUINT8(Graph * g,Node * data)148 Node* QuantizeToUINT8(Graph* g, Node* data) {
149   Node* ret;
150   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
151                   .Input(data)
152                   .Attr("T", DT_QUINT8)
153                   .Attr("max_range", 1.0f)
154                   .Attr("min_range", -1.0f)
155                   .Finalize(g, &ret));
156   return ret;
157 }
158 
Matmul(Graph * g,Node * in0,Node * in1,bool transpose_a,bool transpose_b)159 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
160              bool transpose_b) {
161   Node* ret;
162   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
163                   .Input(in0)
164                   .Input(in1)
165                   .Attr("transpose_a", transpose_a)
166                   .Attr("transpose_b", transpose_b)
167                   .Finalize(g, &ret));
168   return ret;
169 }
170 
BatchMatmul(Graph * g,Node * in0,Node * in1,bool adj_x,bool adj_y)171 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
172   Node* ret;
173   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul")
174                   .Input(in0)
175                   .Input(in1)
176                   .Attr("adj_x", adj_x)
177                   .Attr("adj_y", adj_y)
178                   .Finalize(g, &ret));
179   return ret;
180 }
181 
RandomNumberGenerator(const string & op,Graph * g,Node * input,DataType dtype)182 Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
183                             DataType dtype) {
184   Node* ret;
185   TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
186                   .Input(input)
187                   .Attr("dtype", dtype)
188                   .Attr("seed", 0)
189                   .Finalize(g, &ret));
190   return ret;
191 }
192 
RandomUniform(Graph * g,Node * input,DataType dtype)193 Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
194   return RandomNumberGenerator("RandomUniform", g, input, dtype);
195 }
196 
RandomGaussian(Graph * g,Node * input,DataType dtype)197 Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
198   return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
199 }
200 
TruncatedNormal(Graph * g,Node * input,DataType dtype)201 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
202   return RandomNumberGenerator("TruncatedNormal", g, input, dtype);
203 }
204 
RandomGamma(Graph * g,Node * shape,Node * alpha)205 Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
206   Node* ret;
207   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma")
208                   .Input(shape)
209                   .Input(alpha)
210                   .Attr("seed", 0)
211                   .Finalize(g, &ret));
212   return ret;
213 }
214 
RandomPoisson(Graph * g,Node * shape,Node * lam)215 Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
216   Node* ret;
217   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
218                   .Input(shape)
219                   .Input(lam)
220                   .Attr("seed", 0)
221                   .Finalize(g, &ret));
222   return ret;
223 }
224 
Unary(Graph * g,const string & func,Node * input,int index)225 Node* Unary(Graph* g, const string& func, Node* input, int index) {
226   Node* ret;
227   TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
228                   .Input(input, index)
229                   .Finalize(g, &ret));
230   return ret;
231 }
232 
Binary(Graph * g,const string & func,Node * in0,Node * in1)233 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
234   Node* ret;
235   TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
236                   .Input(in0)
237                   .Input(in1)
238                   .Finalize(g, &ret));
239   return ret;
240 }
241 
Multi(Graph * g,const string & func,gtl::ArraySlice<Node * > ins)242 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
243   Node* ret;
244   auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
245   for (Node* n : ins) b = b.Input(n);
246   TF_CHECK_OK(b.Finalize(g, &ret));
247   return ret;
248 }
249 
Identity(Graph * g,Node * input,int index)250 Node* Identity(Graph* g, Node* input, int index) {
251   Node* ret;
252   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
253                   .Input(input, index)
254                   .Finalize(g, &ret));
255   return ret;
256 }
257 
Add(Graph * g,Node * in0,Node * in1)258 Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
259 
Reverse(Graph * g,Node * tensor,Node * axis)260 Node* Reverse(Graph* g, Node* tensor, Node* axis) {
261   return Binary(g, "ReverseV2", tensor, axis);
262 }
263 
Roll(Graph * g,Node * input,Node * shift,Node * axis)264 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
265   Node* ret;
266   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
267                   .Input(input)
268                   .Input(shift)
269                   .Input(axis)
270                   .Finalize(g, &ret));
271   return ret;
272 }
273 
Error(Graph * g,Node * input,const string & errmsg,bool log_error)274 Node* Error(Graph* g, Node* input, const string& errmsg, bool log_error) {
275   Node* ret;
276   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
277                   .Input(input)
278                   .Attr("message", errmsg)
279                   .Attr("log_error", log_error)
280                   .Finalize(g, &ret));
281   return ret;
282 }
283 
InvalidRefType(Graph * g,DataType out_type,DataType invalid_type)284 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
285   DCHECK(out_type != invalid_type);
286   Node* ret;
287   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
288                   .Attr("TIn", out_type)
289                   .Attr("TOut", invalid_type)
290                   .Finalize(g, &ret));
291   return ret;
292 }
293 
Delay(Graph * g,Node * input,Microseconds delay_micros)294 Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
295   Node* ret;
296   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
297                   .Input(input)
298                   .Attr("micros", delay_micros.value())
299                   .Finalize(g, &ret));
300   return ret;
301 }
302 
NoOp(Graph * g,const std::vector<Node * > & control_inputs)303 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
304   Node* ret;
305   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
306                   .ControlInputs(control_inputs)
307                   .Finalize(g, &ret));
308   return ret;
309 }
310 
Switch(Graph * g,Node * in0,Node * in1)311 Node* Switch(Graph* g, Node* in0, Node* in1) {
312   Node* ret;
313   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
314                   .Input(in0)
315                   .Input(in1)
316                   .Finalize(g, &ret));
317   return ret;
318 }
319 
Enter(Graph * g,Node * input,const string & frame_name)320 Node* Enter(Graph* g, Node* input, const string& frame_name) {
321   Node* ret;
322   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
323                   .Input(input)
324                   .Attr("frame_name", frame_name)
325                   .Finalize(g, &ret));
326   return ret;
327 }
328 
Exit(Graph * g,Node * input)329 Node* Exit(Graph* g, Node* input) {
330   Node* ret;
331   TF_CHECK_OK(
332       NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
333   return ret;
334 }
335 
Merge(Graph * g,Node * in0,Node * in1)336 Node* Merge(Graph* g, Node* in0, Node* in1) {
337   Node* ret;
338   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
339                   .Input({in0, in1})
340                   .Finalize(g, &ret));
341   return ret;
342 }
343 
Merge(Graph * g,Node * in0,gtl::ArraySlice<string> remaining_in)344 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
345   std::vector<NodeBuilder::NodeOut> inputs;
346   inputs.reserve(remaining_in.size() + 1);
347   inputs.emplace_back(in0);
348   for (const string& in_name : remaining_in) {
349     inputs.emplace_back(in_name, 0, inputs[0].dt);
350   }
351 
352   Node* ret;
353   TF_CHECK_OK(
354       NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
355   return ret;
356 }
357 
Concat(Graph * g,Node * concat_dim,gtl::ArraySlice<Node * > tensors)358 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) {
359   std::vector<NodeBuilder::NodeOut> nodeouts;
360   nodeouts.reserve(tensors.size());
361   for (auto const t : tensors) {
362     nodeouts.emplace_back(t);
363   }
364   Node* ret;
365   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
366                   .Input(concat_dim)
367                   .Input(nodeouts)
368                   .Finalize(g, &ret));
369   return ret;
370 }
371 
ConcatV2(Graph * g,gtl::ArraySlice<Node * > tensors,Node * concat_dim)372 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) {
373   std::vector<NodeBuilder::NodeOut> nodeouts;
374   nodeouts.reserve(tensors.size());
375   for (auto const t : tensors) {
376     nodeouts.emplace_back(t);
377   }
378   Node* ret;
379   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2")
380                   .Input(nodeouts)
381                   .Input(concat_dim)
382                   .Finalize(g, &ret));
383   return ret;
384 }
385 
Next(Graph * g,const string & name,Node * input)386 Node* Next(Graph* g, const string& name, Node* input) {
387   Node* ret;
388   TF_CHECK_OK(
389       NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
390   return ret;
391 }
392 
LoopCond(Graph * g,Node * input)393 Node* LoopCond(Graph* g, Node* input) {
394   Node* ret;
395   TF_CHECK_OK(
396       NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
397   return ret;
398 }
399 
Less(Graph * g,Node * in0,Node * in1)400 Node* Less(Graph* g, Node* in0, Node* in1) {
401   return Binary(g, "Less", in0, in1);
402 }
403 
Select(Graph * g,Node * c,Node * inx,Node * iny)404 Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
405   Node* ret;
406   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
407                   .Input(c)
408                   .Input(inx)
409                   .Input(iny)
410                   .Finalize(g, &ret));
411   return ret;
412 }
413 
Cast(Graph * g,Node * in,DataType dst)414 Node* Cast(Graph* g, Node* in, DataType dst) {
415   Node* ret;
416   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
417                   .Input(in)
418                   .Attr("DstT", dst)
419                   .Finalize(g, &ret));
420   return ret;
421 }
422 
Gather(Graph * g,Node * in0,Node * in1,Node * axis)423 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) {
424   Node* ret;
425   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2")
426                   .Input(in0)
427                   .Input(in1)
428                   .Input(axis)
429                   .Finalize(g, &ret));
430   return ret;
431 }
432 
GetSessionTensor(Graph * g,Node * in)433 Node* GetSessionTensor(Graph* g, Node* in) {
434   Node* ret;
435   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
436                   .Input(in, 0)
437                   .Attr("dtype", DT_FLOAT)
438                   .Finalize(g, &ret));
439   return ret;
440 }
441 
Relu(Graph * g,Node * in)442 Node* Relu(Graph* g, Node* in) {
443   Node* ret;
444   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
445                   .Input(in, 0)
446                   .Attr("T", DT_FLOAT)
447                   .Finalize(g, &ret));
448   return ret;
449 }
450 
Relu6(Graph * g,Node * in)451 Node* Relu6(Graph* g, Node* in) {
452   Node* ret;
453   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6")
454                   .Input(in, 0)
455                   .Attr("T", DT_FLOAT)
456                   .Finalize(g, &ret));
457   return ret;
458 }
459 
BiasAdd(Graph * g,Node * value,Node * bias)460 Node* BiasAdd(Graph* g, Node* value, Node* bias) {
461   Node* ret;
462   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd")
463                   .Input(value)
464                   .Input(bias)
465                   .Attr("T", DT_FLOAT)
466                   .Finalize(g, &ret));
467   return ret;
468 }
469 
Conv2D(Graph * g,Node * in0,Node * in1)470 Node* Conv2D(Graph* g, Node* in0, Node* in1) {
471   Node* ret;
472   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D")
473                   .Input(in0)
474                   .Input(in1)
475                   .Attr("T", DT_FLOAT)
476                   .Attr("strides", {1, 1, 1, 1})
477                   .Attr("padding", "SAME")
478                   .Finalize(g, &ret));
479   return ret;
480 }
481 
Diag(Graph * g,Node * in,DataType type)482 Node* Diag(Graph* g, Node* in, DataType type) {
483   Node* ret;
484   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag")
485                   .Input(in)
486                   .Attr("T", type)
487                   .Finalize(g, &ret));
488   return ret;
489 }
490 
DiagPart(Graph * g,Node * in,DataType type)491 Node* DiagPart(Graph* g, Node* in, DataType type) {
492   Node* ret;
493   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart")
494                   .Input(in)
495                   .Attr("T", type)
496                   .Finalize(g, &ret));
497   return ret;
498 }
499 
CheckNumerics(Graph * g,Node * in,const string & message)500 Node* CheckNumerics(Graph* g, Node* in, const string& message) {
501   Node* ret;
502   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
503                   .Input(in)
504                   .Attr("message", message)
505                   .Finalize(g, &ret));
506   return ret;
507 }
508 
Arg(Graph * g,int64 index,DataType type)509 Node* Arg(Graph* g, int64 index, DataType type) {
510   Node* ret;
511   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
512                   .Attr("T", type)
513                   .Attr("index", index)
514                   .Finalize(g, &ret));
515   return ret;
516 }
517 
Retval(Graph * g,int64 index,Node * in)518 Node* Retval(Graph* g, int64 index, Node* in) {
519   Node* ret;
520   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
521                   .Input(in)
522                   .Attr("index", index)
523                   .Finalize(g, &ret));
524   return ret;
525 }
526 
ToGraphDef(Graph * g,GraphDef * gdef)527 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
528 
529 }  // end namespace graph
530 }  // end namespace test
531 }  // end namespace tensorflow
532