1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 20 namespace tensorflow { 21 22 using shape_inference::DimensionHandle; 23 using shape_inference::InferenceContext; 24 using shape_inference::ShapeHandle; 25 26 REGISTER_OP("RandomUniform") 27 .Input("shape: T") 28 .SetIsStateful() 29 .Output("output: dtype") 30 .Attr("seed: int = 0") 31 .Attr("seed2: int = 0") 32 .Attr("dtype: {half,bfloat16,float,double}") 33 .Attr("T: {int32, int64}") 34 .SetShapeFn(shape_inference::RandomShape); 35 36 REGISTER_OP("RandomUniformInt") 37 .Input("shape: T") 38 .Input("minval: Tout") 39 .Input("maxval: Tout") 40 .SetIsStateful() 41 .Output("output: Tout") 42 .Attr("seed: int = 0") 43 .Attr("seed2: int = 0") 44 .Attr("Tout: {int32, int64}") 45 .Attr("T: {int32, int64}") __anon9ce6b60a0102(InferenceContext* c) 46 .SetShapeFn([](InferenceContext* c) { 47 ShapeHandle unused; 48 Status s = c->WithRank(c->input(1), 0, &unused); 49 if (!s.ok()) { 50 return errors::InvalidArgument( 51 "minval must be a scalar; got a tensor of shape ", 52 c->DebugString(c->input(1))); 53 } 54 s = c->WithRank(c->input(2), 0, &unused); 55 if (!s.ok()) { 56 return errors::InvalidArgument( 57 "maxval must be a scalar; got a tensor of shape ", 58 c->DebugString(c->input(2))); 59 } 60 return shape_inference::RandomShape(c); 61 }); 62 63 REGISTER_OP("RandomStandardNormal") 64 .Input("shape: T") 65 .SetIsStateful() 66 .Output("output: dtype") 67 .Attr("seed: int = 0") 68 .Attr("seed2: int = 0") 69 .Attr("dtype: {half,bfloat16,float,double}") 70 .Attr("T: {int32, int64}") 71 .SetShapeFn(shape_inference::RandomShape); 72 73 REGISTER_OP("ParameterizedTruncatedNormal") 74 .Input("shape: T") 75 .Input("means: dtype") 76 .Input("stdevs: dtype") 77 .Input("minvals: dtype") 78 .Input("maxvals: dtype") 79 .SetIsStateful() 80 .Output("output: dtype") 81 .Attr("seed: int = 0") 82 .Attr("seed2: int = 0") 83 .Attr("dtype: {half,bfloat16,float,double}") 84 .Attr("T: {int32, int64}") __anon9ce6b60a0202(InferenceContext* c) 85 .SetShapeFn([](InferenceContext* c) { 86 ShapeHandle unused; 87 // Parameters must be 0-d or 1-d. 88 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &unused)); 89 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused)); 90 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused)); 91 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); 92 return shape_inference::RandomShape(c); 93 }); 94 95 REGISTER_OP("TruncatedNormal") 96 .Input("shape: T") 97 .SetIsStateful() 98 .Output("output: dtype") 99 .Attr("seed: int = 0") 100 .Attr("seed2: int = 0") 101 .Attr("dtype: {half,bfloat16,float,double}") 102 .Attr("T: {int32, int64}") 103 .SetShapeFn(shape_inference::RandomShape); 104 105 REGISTER_OP("RandomShuffle") 106 .Input("value: T") 107 .SetIsStateful() 108 .Output("output: T") 109 .Attr("seed: int = 0") 110 .Attr("seed2: int = 0") 111 .Attr("T: type") 112 .SetShapeFn(shape_inference::UnchangedShape); 113 114 REGISTER_OP("Multinomial") 115 .SetIsStateful() 116 .Input("logits: T") 117 .Input("num_samples: int32") 118 .Output("output: output_dtype") 119 .Attr("seed: int = 0") 120 .Attr("seed2: int = 0") 121 .Attr("T: realnumbertype") 122 .Attr("output_dtype: {int32, int64} = DT_INT64") __anon9ce6b60a0302(InferenceContext* c) 123 .SetShapeFn([](InferenceContext* c) { 124 ShapeHandle logits_shape; 125 ShapeHandle unused; 126 DimensionHandle num_samples; 127 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &logits_shape)); 128 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 129 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &num_samples)); 130 c->set_output(0, c->Matrix(c->Dim(logits_shape, 0), num_samples)); 131 return OkStatus(); 132 }); 133 134 REGISTER_OP("RandomGamma") 135 .SetIsStateful() 136 .Input("shape: S") 137 .Input("alpha: T") 138 .Output("output: T") 139 .Attr("seed: int = 0") 140 .Attr("seed2: int = 0") 141 .Attr("S: {int32, int64}") 142 .Attr("T: {half, float, double}") __anon9ce6b60a0402(InferenceContext* c) 143 .SetShapeFn([](InferenceContext* c) { 144 ShapeHandle out; 145 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 146 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); 147 c->set_output(0, out); 148 return OkStatus(); 149 }); 150 151 REGISTER_OP("RandomGammaGrad") 152 .Input("alpha: T") 153 .Input("sample: T") 154 .Output("output: T") 155 .Attr("T: {float, double}") 156 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); 157 158 REGISTER_OP("RandomPoisson") 159 .SetIsStateful() 160 .Input("shape: S") 161 .Input("rate: dtype") 162 .Output("output: dtype") 163 .Attr("seed: int = 0") 164 .Attr("seed2: int = 0") 165 .Attr("S: {int32, int64}") 166 .Attr("dtype: {half, float, double}") __anon9ce6b60a0502(InferenceContext* c) 167 .SetShapeFn([](InferenceContext* c) { 168 ShapeHandle out; 169 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 170 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); 171 c->set_output(0, out); 172 return OkStatus(); 173 }) 174 .Deprecated(25, "Replaced by RandomPoissonV2"); 175 176 REGISTER_OP("RandomPoissonV2") 177 .SetIsStateful() 178 .Input("shape: S") 179 .Input("rate: R") 180 .Output("output: dtype") 181 .Attr("seed: int = 0") 182 .Attr("seed2: int = 0") 183 .Attr("S: {int32, int64}") 184 .Attr("R: {half, float, double, int32, int64} = DT_DOUBLE") 185 .Attr("dtype: {half, float, double, int32, int64} = DT_INT64") __anon9ce6b60a0602(InferenceContext* c) 186 .SetShapeFn([](InferenceContext* c) { 187 ShapeHandle out; 188 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 189 TF_RETURN_IF_ERROR(c->Concatenate(out, c->input(1), &out)); 190 c->set_output(0, out); 191 return OkStatus(); 192 }); 193 194 } // namespace tensorflow 195