1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "ops/random_standard_normal.h"
17 #include <memory>
18 #include <set>
19 #include <string>
20 #include "mindapi/src/helper.h"
21 #include "mindspore/core/ops/random_ops.h"
22 #include "ops/op_utils.h"
23 #include "ops/standard_normal.h"
24 #include "utils/check_convert_utils.h"
25
26 namespace mindspore {
27 namespace ops {
Init(const int64_t seed,const int64_t seed2)28 void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
29 this->set_seed(seed);
30 this->set_seed2(seed2);
31 }
32
set_seed(int64_t seed)33 void RandomStandardNormal::set_seed(int64_t seed) { (void)this->AddAttr(kSeed, api::MakeValue(seed)); }
34
set_seed2(int64_t seed2)35 void RandomStandardNormal::set_seed2(int64_t seed2) { (void)this->AddAttr(kSeed2, api::MakeValue(seed2)); }
36
get_seed() const37 int64_t RandomStandardNormal::get_seed() const {
38 auto value_ptr = GetAttr(kSeed);
39 return GetValue<int64_t>(value_ptr);
40 }
41
get_seed2() const42 int64_t RandomStandardNormal::get_seed2() const {
43 auto value_ptr = GetAttr(kSeed2);
44 return GetValue<int64_t>(value_ptr);
45 }
46
Init(const int64_t seed,const int64_t seed2)47 void StandardNormal::Init(const int64_t seed, const int64_t seed2) {
48 this->set_seed(seed);
49 this->set_seed2(seed2);
50 }
51
set_seed(int64_t seed)52 void StandardNormal::set_seed(int64_t seed) { (void)this->AddAttr(kSeed, api::MakeValue(seed)); }
53
set_seed2(int64_t seed2)54 void StandardNormal::set_seed2(int64_t seed2) { (void)this->AddAttr(kSeed2, api::MakeValue(seed2)); }
55
get_seed() const56 int64_t StandardNormal::get_seed() const {
57 auto value_ptr = GetAttr(kSeed);
58 return GetValue<int64_t>(value_ptr);
59 }
60
get_seed2() const61 int64_t StandardNormal::get_seed2() const {
62 auto value_ptr = GetAttr(kSeed2);
63 return GetValue<int64_t>(value_ptr);
64 }
65
66 namespace {
RandomStandardNormalInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)67 abstract::ShapePtr RandomStandardNormalInferShape(const PrimitivePtr &primitive,
68 const std::vector<AbstractBasePtr> &input_args) {
69 MS_EXCEPTION_IF_NULL(primitive);
70 auto prim_name = primitive->name();
71 MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
72 auto shape_value = input_args[kInputIndex0]->GetValue();
73 MS_EXCEPTION_IF_NULL(shape_value);
74 if (CheckAndConvertUtils::IsTuple(input_args[kInputIndex0])) {
75 if (IsValueKnown(shape_value)) {
76 // Static Shape.
77 std::vector<int64_t> out_shape =
78 CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", input_args[kInputIndex0], prim_name);
79 (void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
80 return std::make_shared<abstract::Shape>(out_shape);
81 }
82 auto shape_value_opt = ops::GetArrayValue<ShapeValueDType>(input_args[kInputIndex0]);
83 // Dynamic rank.
84 if (!shape_value_opt.has_value()) {
85 return std::make_shared<abstract::TensorShape>(ShapeVector{abstract::TensorShape::kShapeRankAny});
86 }
87 // Dynamic shape.
88 auto array_value = shape_value_opt.value();
89 ShapeVector shape;
90 for (size_t i = 0; i < array_value.size(); ++i) {
91 if (array_value.IsValueUnknown(i)) {
92 shape.push_back(abstract::TensorShape::kShapeDimAny);
93 } else {
94 shape.push_back(array_value[i]);
95 }
96 }
97 return std::make_shared<abstract::Shape>(shape);
98 } else if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex0])) {
99 if (IsValueKnown(shape_value)) {
100 auto shape_ptr = input_args[kInputIndex0]->GetShape();
101 auto shape_vec = shape_ptr->GetShapeVector();
102 auto rank = shape_vec.size();
103 MS_CHECK_VALUE(
104 rank == 1 || rank == 0,
105 CheckAndConvertUtils::FormatCommMsg(
106 "For op[", prim_name, "], if input [shape] is a tensor, its rank must be 1 or 0, but got: ", rank));
107 ShapeVector input_shape = CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name,
108 input_args[kInputIndex0]->GetType());
109 (void)CheckAndConvertUtils::CheckPositiveVector("shape", input_shape, prim_name);
110 return std::make_shared<abstract::Shape>(input_shape);
111 } else {
112 constexpr int dynamic_rank_value = -2;
113 ShapeVector shape = {dynamic_rank_value};
114 return std::make_shared<abstract::Shape>(shape);
115 }
116 } else {
117 MS_EXCEPTION(TypeError) << "For '" << prim_name
118 << "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
119 << input_args[kInputIndex0]->ToString() << ".";
120 }
121 }
122
RandomStandardNormalInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)123 TypePtr RandomStandardNormalInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
124 MS_EXCEPTION_IF_NULL(primitive);
125 auto prim_name = primitive->name();
126 MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
127 if (CheckAndConvertUtils::IsTuple(input_args[kInputIndex0])) {
128 auto elements_type = input_args[kInputIndex0]->GetType()->cast<TuplePtr>();
129 MS_EXCEPTION_IF_NULL(elements_type);
130 const std::set<TypePtr> valid_shape_types = {kInt32, kInt64};
131 for (const auto &input_dtype : elements_type->elements()) {
132 (void)CheckAndConvertUtils::CheckSubClass("shape", input_dtype, valid_shape_types, prim_name);
133 }
134 } else if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex0])) {
135 const std::set<TypePtr> valid_shape_types = {kInt32, kInt64};
136 auto input_dtype = input_args[kInputIndex0]->GetType();
137 (void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_dtype, valid_shape_types, prim_name);
138 } else {
139 MS_EXCEPTION(TypeError) << "For '" << prim_name
140 << "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
141 << input_args[kInputIndex0]->ToString() << ".";
142 }
143 return std::make_shared<TensorType>(kFloat32);
144 }
145 } // namespace
146
RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)147 AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
148 const std::vector<AbstractBasePtr> &input_args) {
149 MS_EXCEPTION_IF_NULL(primitive);
150 auto prim_name = primitive->name();
151 for (const auto &item : input_args) {
152 MS_EXCEPTION_IF_NULL(item);
153 }
154 const int64_t kMinInputNum = 1;
155 const int64_t kMaxInputNum = 3;
156 (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, kMinInputNum,
157 prim_name);
158 (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kLessEqual, kMaxInputNum,
159 prim_name);
160 auto type = RandomStandardNormalInferType(primitive, input_args);
161 auto shape = RandomStandardNormalInferShape(primitive, input_args);
162 return abstract::MakeAbstract(shape, type);
163 }
164
// Operator implementation boilerplate for both operator classes defined in this file, each with
// BaseOperator as the second macro argument.
// NOTE(review): the exact expansion lives in the macro definition (not visible here) - confirm there.
MIND_API_OPERATOR_IMPL(RandomStandardNormal, BaseOperator);
MIND_API_OPERATOR_IMPL(StandardNormal, BaseOperator);
167
168 // AG means auto generated
169 class MIND_API AGRandomStandardNormalInfer : public abstract::OpInferBase {
170 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const171 BaseShapePtr InferShape(const PrimitivePtr &primitive,
172 const std::vector<AbstractBasePtr> &input_args) const override {
173 return RandomStandardNormalInferShape(primitive, input_args);
174 }
175
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const176 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
177 return RandomStandardNormalInferType(primitive, input_args);
178 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const179 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
180 const std::vector<AbstractBasePtr> &input_args) const override {
181 return RandomStandardNormalInfer(engine, primitive, input_args);
182 }
183
GetValueDependArgIndices() const184 std::set<int64_t> GetValueDependArgIndices() const override { return {0}; }
185 };
186
// Registers AGRandomStandardNormalInfer as the infer implementation under both operator names.
// Both registrations use the same primitive, prim::kPrimStandardNormal.
// NOTE(review): the meaning of the trailing `false` argument is defined by the macro (not visible
// here) - confirm there.
REGISTER_PRIMITIVE_OP_INFER_IMPL(RandomStandardNormal, prim::kPrimStandardNormal, AGRandomStandardNormalInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(StandardNormal, prim::kPrimStandardNormal, AGRandomStandardNormalInfer, false);
189 } // namespace ops
190 } // namespace mindspore
191