1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 20 namespace tensorflow { 21 22 using shape_inference::DimensionHandle; 23 using shape_inference::InferenceContext; 24 using shape_inference::ShapeHandle; 25 26 REGISTER_OP("RandomUniform") 27 .Input("shape: T") 28 .SetIsStateful() 29 .Output("output: dtype") 30 .Attr("seed: int = 0") 31 .Attr("seed2: int = 0") 32 .Attr("dtype: {half,bfloat16,float,double}") 33 .Attr("T: {int32, int64}") 34 .SetShapeFn(shape_inference::RandomShape); 35 36 REGISTER_OP("RandomUniformInt") 37 .Input("shape: T") 38 .Input("minval: Tout") 39 .Input("maxval: Tout") 40 .SetIsStateful() 41 .Output("output: Tout") 42 .Attr("seed: int = 0") 43 .Attr("seed2: int = 0") 44 .Attr("Tout: {int32, int64}") 45 .Attr("T: {int32, int64}") __anon979c5d7d0102(InferenceContext* c) 46 .SetShapeFn([](InferenceContext* c) { 47 ShapeHandle unused; 48 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 49 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 50 return shape_inference::RandomShape(c); 51 }); 52 53 REGISTER_OP("RandomStandardNormal") 54 .Input("shape: T") 55 .SetIsStateful() 56 .Output("output: dtype") 57 .Attr("seed: int = 0") 58 .Attr("seed2: int = 0") 59 .Attr("dtype: {half,bfloat16,float,double}") 60 .Attr("T: {int32, int64}") 61 .SetShapeFn(shape_inference::RandomShape); 62 63 REGISTER_OP("ParameterizedTruncatedNormal") 64 .Input("shape: T") 65 .Input("means: dtype") 66 .Input("stdevs: dtype") 67 .Input("minvals: dtype") 68 .Input("maxvals: dtype") 69 .SetIsStateful() 70 .Output("output: dtype") 71 .Attr("seed: int = 0") 72 .Attr("seed2: int = 0") 73 .Attr("dtype: {half,bfloat16,float,double}") 74 .Attr("T: {int32, int64}") __anon979c5d7d0202(InferenceContext* c) 75 .SetShapeFn([](InferenceContext* c) { 76 ShapeHandle unused; 77 // Parameters must be 0-d or 1-d. 78 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused)); 79 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); 80 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused)); 81 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); 82 return shape_inference::RandomShape(c); 83 }); 84 85 REGISTER_OP("TruncatedNormal") 86 .Input("shape: T") 87 .SetIsStateful() 88 .Output("output: dtype") 89 .Attr("seed: int = 0") 90 .Attr("seed2: int = 0") 91 .Attr("dtype: {half,bfloat16,float,double}") 92 .Attr("T: {int32, int64}") 93 .SetShapeFn(shape_inference::RandomShape); 94 95 REGISTER_OP("RandomShuffle") 96 .Input("value: T") 97 .SetIsStateful() 98 .Output("output: T") 99 .Attr("seed: int = 0") 100 .Attr("seed2: int = 0") 101 .Attr("T: type") 102 .SetShapeFn(shape_inference::UnchangedShape); 103 104 REGISTER_OP("Multinomial") 105 .SetIsStateful() 106 .Input("logits: T") 107 .Input("num_samples: int32") 108 .Output("output: output_dtype") 109 .Attr("seed: int = 0") 110 .Attr("seed2: int = 0") 111 .Attr("T: realnumbertype") 112 .Attr("output_dtype: {int32, int64} = DT_INT64") __anon979c5d7d0302(InferenceContext* c) 113 .SetShapeFn([](InferenceContext* c) { 114 ShapeHandle logits_shape; 115 ShapeHandle unused; 116 DimensionHandle num_samples; 117 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape)); 118 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 119 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples)); 120 c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples)); 121 return Status::OK(); 122 }); 123 124 REGISTER_OP("RandomGamma") 125 .SetIsStateful() 126 .Input("shape: S") 127 .Input("alpha: T") 128 .Output("output: T") 129 .Attr("seed: int = 0") 130 .Attr("seed2: int = 0") 131 .Attr("S: {int32, int64}") 132 .Attr("T: {half, float, double}") __anon979c5d7d0402(InferenceContext* c) 133 .SetShapeFn([](InferenceContext* c) { 134 ShapeHandle out; 135 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 136 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); 137 c->set_output(0, out); 138 return Status::OK(); 139 }); 140 141 REGISTER_OP("RandomGammaGrad") 142 .Input("alpha: T") 143 .Input("sample: T") 144 .Output("output: T") 145 .Attr("T: {float, double}") 146 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 147 148 REGISTER_OP("RandomPoisson") 149 .SetIsStateful() 150 .Input("shape: S") 151 .Input("rate: dtype") 152 .Output("output: dtype") 153 .Attr("seed: int = 0") 154 .Attr("seed2: int = 0") 155 .Attr("S: {int32, int64}") 156 .Attr("dtype: {half, float, double}") __anon979c5d7d0502(InferenceContext* c) 157 .SetShapeFn([](InferenceContext* c) { 158 ShapeHandle out; 159 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 160 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); 161 c->set_output(0, out); 162 return Status::OK(); 163 }) 164 .Deprecated(25, "Replaced by RandomPoissonV2"); 165 166 REGISTER_OP("RandomPoissonV2") 167 .SetIsStateful() 168 .Input("shape: S") 169 .Input("rate: R") 170 .Output("output: dtype") 171 .Attr("seed: int = 0") 172 .Attr("seed2: int = 0") 173 .Attr("S: {int32, int64}") 174 .Attr("R: {half, float, double, int32, int64} = DT_DOUBLE") 175 .Attr("dtype: {half, float, double, int32, int64} = DT_INT64") __anon979c5d7d0602(InferenceContext* c) 176 .SetShapeFn([](InferenceContext* c) { 177 ShapeHandle out; 178 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 179 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); 180 c->set_output(0, out); 181 return Status::OK(); 182 }); 183 184 } // namespace tensorflow 185