• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 REGISTER_OP("RandomUniform")
27     .Input("shape: T")
28     .SetIsStateful()
29     .Output("output: dtype")
30     .Attr("seed: int = 0")
31     .Attr("seed2: int = 0")
32     .Attr("dtype: {half,bfloat16,float,double}")
33     .Attr("T: {int32, int64}")
34     .SetShapeFn(shape_inference::RandomShape);
35 
36 REGISTER_OP("RandomUniformInt")
37     .Input("shape: T")
38     .Input("minval: Tout")
39     .Input("maxval: Tout")
40     .SetIsStateful()
41     .Output("output: Tout")
42     .Attr("seed: int = 0")
43     .Attr("seed2: int = 0")
44     .Attr("Tout: {int32, int64}")
45     .Attr("T: {int32, int64}")
__anon979c5d7d0102(InferenceContext* c) 46     .SetShapeFn([](InferenceContext* c) {
47       ShapeHandle unused;
48       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
49       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
50       return shape_inference::RandomShape(c);
51     });
52 
53 REGISTER_OP("RandomStandardNormal")
54     .Input("shape: T")
55     .SetIsStateful()
56     .Output("output: dtype")
57     .Attr("seed: int = 0")
58     .Attr("seed2: int = 0")
59     .Attr("dtype: {half,bfloat16,float,double}")
60     .Attr("T: {int32, int64}")
61     .SetShapeFn(shape_inference::RandomShape);
62 
63 REGISTER_OP("ParameterizedTruncatedNormal")
64     .Input("shape: T")
65     .Input("means: dtype")
66     .Input("stdevs: dtype")
67     .Input("minvals: dtype")
68     .Input("maxvals: dtype")
69     .SetIsStateful()
70     .Output("output: dtype")
71     .Attr("seed: int = 0")
72     .Attr("seed2: int = 0")
73     .Attr("dtype: {half,bfloat16,float,double}")
74     .Attr("T: {int32, int64}")
__anon979c5d7d0202(InferenceContext* c) 75     .SetShapeFn([](InferenceContext* c) {
76       ShapeHandle unused;
77       // Parameters must be 0-d or 1-d.
78       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused));
79       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
80       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused));
81       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
82       return shape_inference::RandomShape(c);
83     });
84 
85 REGISTER_OP("TruncatedNormal")
86     .Input("shape: T")
87     .SetIsStateful()
88     .Output("output: dtype")
89     .Attr("seed: int = 0")
90     .Attr("seed2: int = 0")
91     .Attr("dtype: {half,bfloat16,float,double}")
92     .Attr("T: {int32, int64}")
93     .SetShapeFn(shape_inference::RandomShape);
94 
95 REGISTER_OP("RandomShuffle")
96     .Input("value: T")
97     .SetIsStateful()
98     .Output("output: T")
99     .Attr("seed: int = 0")
100     .Attr("seed2: int = 0")
101     .Attr("T: type")
102     .SetShapeFn(shape_inference::UnchangedShape);
103 
104 REGISTER_OP("Multinomial")
105     .SetIsStateful()
106     .Input("logits: T")
107     .Input("num_samples: int32")
108     .Output("output: output_dtype")
109     .Attr("seed: int = 0")
110     .Attr("seed2: int = 0")
111     .Attr("T: realnumbertype")
112     .Attr("output_dtype: {int32, int64} = DT_INT64")
__anon979c5d7d0302(InferenceContext* c) 113     .SetShapeFn([](InferenceContext* c) {
114       ShapeHandle logits_shape;
115       ShapeHandle unused;
116       DimensionHandle num_samples;
117       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape));
118       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
119       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples));
120       c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples));
121       return Status::OK();
122     });
123 
124 REGISTER_OP("RandomGamma")
125     .SetIsStateful()
126     .Input("shape: S")
127     .Input("alpha: T")
128     .Output("output: T")
129     .Attr("seed: int = 0")
130     .Attr("seed2: int = 0")
131     .Attr("S: {int32, int64}")
132     .Attr("T: {half, float, double}")
__anon979c5d7d0402(InferenceContext* c) 133     .SetShapeFn([](InferenceContext* c) {
134       ShapeHandle out;
135       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
136       TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
137       c->set_output(0, out);
138       return Status::OK();
139     });
140 
141 REGISTER_OP("RandomGammaGrad")
142     .Input("alpha: T")
143     .Input("sample: T")
144     .Output("output: T")
145     .Attr("T: {float, double}")
146     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
147 
148 REGISTER_OP("RandomPoisson")
149     .SetIsStateful()
150     .Input("shape: S")
151     .Input("rate: dtype")
152     .Output("output: dtype")
153     .Attr("seed: int = 0")
154     .Attr("seed2: int = 0")
155     .Attr("S: {int32, int64}")
156     .Attr("dtype: {half, float, double}")
__anon979c5d7d0502(InferenceContext* c) 157     .SetShapeFn([](InferenceContext* c) {
158       ShapeHandle out;
159       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
160       TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
161       c->set_output(0, out);
162       return Status::OK();
163     })
164     .Deprecated(25, "Replaced by RandomPoissonV2");
165 
166 REGISTER_OP("RandomPoissonV2")
167     .SetIsStateful()
168     .Input("shape: S")
169     .Input("rate: R")
170     .Output("output: dtype")
171     .Attr("seed: int = 0")
172     .Attr("seed2: int = 0")
173     .Attr("S: {int32, int64}")
174     .Attr("R: {half, float, double, int32, int64} = DT_DOUBLE")
175     .Attr("dtype: {half, float, double, int32, int64} = DT_INT64")
__anon979c5d7d0602(InferenceContext* c) 176     .SetShapeFn([](InferenceContext* c) {
177       ShapeHandle out;
178       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
179       TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out));
180       c->set_output(0, out);
181       return Status::OK();
182     });
183 
184 }  // namespace tensorflow
185