• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "ops/random_standard_normal.h"
17 #include <memory>
18 #include <set>
19 #include <string>
20 #include "mindapi/src/helper.h"
21 #include "mindspore/core/ops/random_ops.h"
22 #include "ops/op_utils.h"
23 #include "ops/standard_normal.h"
24 #include "utils/check_convert_utils.h"
25 
26 namespace mindspore {
27 namespace ops {
Init(const int64_t seed,const int64_t seed2)28 void RandomStandardNormal::Init(const int64_t seed, const int64_t seed2) {
29   this->set_seed(seed);
30   this->set_seed2(seed2);
31 }
32 
set_seed(int64_t seed)33 void RandomStandardNormal::set_seed(int64_t seed) { (void)this->AddAttr(kSeed, api::MakeValue(seed)); }
34 
set_seed2(int64_t seed2)35 void RandomStandardNormal::set_seed2(int64_t seed2) { (void)this->AddAttr(kSeed2, api::MakeValue(seed2)); }
36 
get_seed() const37 int64_t RandomStandardNormal::get_seed() const {
38   auto value_ptr = GetAttr(kSeed);
39   return GetValue<int64_t>(value_ptr);
40 }
41 
get_seed2() const42 int64_t RandomStandardNormal::get_seed2() const {
43   auto value_ptr = GetAttr(kSeed2);
44   return GetValue<int64_t>(value_ptr);
45 }
46 
Init(const int64_t seed,const int64_t seed2)47 void StandardNormal::Init(const int64_t seed, const int64_t seed2) {
48   this->set_seed(seed);
49   this->set_seed2(seed2);
50 }
51 
set_seed(int64_t seed)52 void StandardNormal::set_seed(int64_t seed) { (void)this->AddAttr(kSeed, api::MakeValue(seed)); }
53 
set_seed2(int64_t seed2)54 void StandardNormal::set_seed2(int64_t seed2) { (void)this->AddAttr(kSeed2, api::MakeValue(seed2)); }
55 
get_seed() const56 int64_t StandardNormal::get_seed() const {
57   auto value_ptr = GetAttr(kSeed);
58   return GetValue<int64_t>(value_ptr);
59 }
60 
get_seed2() const61 int64_t StandardNormal::get_seed2() const {
62   auto value_ptr = GetAttr(kSeed2);
63   return GetValue<int64_t>(value_ptr);
64 }
65 
66 namespace {
RandomStandardNormalInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)67 abstract::ShapePtr RandomStandardNormalInferShape(const PrimitivePtr &primitive,
68                                                   const std::vector<AbstractBasePtr> &input_args) {
69   MS_EXCEPTION_IF_NULL(primitive);
70   auto prim_name = primitive->name();
71   MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
72   auto shape_value = input_args[kInputIndex0]->GetValue();
73   MS_EXCEPTION_IF_NULL(shape_value);
74   if (CheckAndConvertUtils::IsTuple(input_args[kInputIndex0])) {
75     if (IsValueKnown(shape_value)) {
76       // Static Shape.
77       std::vector<int64_t> out_shape =
78         CheckAndConvertUtils::CheckIntOrTupleInt("input[shape]", input_args[kInputIndex0], prim_name);
79       (void)CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
80       return std::make_shared<abstract::Shape>(out_shape);
81     }
82     auto shape_value_opt = ops::GetArrayValue<ShapeValueDType>(input_args[kInputIndex0]);
83     // Dynamic rank.
84     if (!shape_value_opt.has_value()) {
85       return std::make_shared<abstract::TensorShape>(ShapeVector{abstract::TensorShape::kShapeRankAny});
86     }
87     // Dynamic shape.
88     auto array_value = shape_value_opt.value();
89     ShapeVector shape;
90     for (size_t i = 0; i < array_value.size(); ++i) {
91       if (array_value.IsValueUnknown(i)) {
92         shape.push_back(abstract::TensorShape::kShapeDimAny);
93       } else {
94         shape.push_back(array_value[i]);
95       }
96     }
97     return std::make_shared<abstract::Shape>(shape);
98   } else if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex0])) {
99     if (IsValueKnown(shape_value)) {
100       auto shape_ptr = input_args[kInputIndex0]->GetShape();
101       auto shape_vec = shape_ptr->GetShapeVector();
102       auto rank = shape_vec.size();
103       MS_CHECK_VALUE(
104         rank == 1 || rank == 0,
105         CheckAndConvertUtils::FormatCommMsg(
106           "For op[", prim_name, "], if input [shape] is a tensor, its rank must be 1 or 0, but got: ", rank));
107       ShapeVector input_shape = CheckAndConvertUtils::CheckTensorIntValue("input[shape]", shape_value, prim_name,
108                                                                           input_args[kInputIndex0]->GetType());
109       (void)CheckAndConvertUtils::CheckPositiveVector("shape", input_shape, prim_name);
110       return std::make_shared<abstract::Shape>(input_shape);
111     } else {
112       constexpr int dynamic_rank_value = -2;
113       ShapeVector shape = {dynamic_rank_value};
114       return std::make_shared<abstract::Shape>(shape);
115     }
116   } else {
117     MS_EXCEPTION(TypeError) << "For '" << prim_name
118                             << "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
119                             << input_args[kInputIndex0]->ToString() << ".";
120   }
121 }
122 
RandomStandardNormalInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)123 TypePtr RandomStandardNormalInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
124   MS_EXCEPTION_IF_NULL(primitive);
125   auto prim_name = primitive->name();
126   MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
127   if (CheckAndConvertUtils::IsTuple(input_args[kInputIndex0])) {
128     auto elements_type = input_args[kInputIndex0]->GetType()->cast<TuplePtr>();
129     MS_EXCEPTION_IF_NULL(elements_type);
130     const std::set<TypePtr> valid_shape_types = {kInt32, kInt64};
131     for (const auto &input_dtype : elements_type->elements()) {
132       (void)CheckAndConvertUtils::CheckSubClass("shape", input_dtype, valid_shape_types, prim_name);
133     }
134   } else if (CheckAndConvertUtils::IsTensor(input_args[kInputIndex0])) {
135     const std::set<TypePtr> valid_shape_types = {kInt32, kInt64};
136     auto input_dtype = input_args[kInputIndex0]->GetType();
137     (void)CheckAndConvertUtils::CheckTensorTypeValid("shape", input_dtype, valid_shape_types, prim_name);
138   } else {
139     MS_EXCEPTION(TypeError) << "For '" << prim_name
140                             << "', input must be a Int, a tuple, or a Tensor with all Int elements, but got: "
141                             << input_args[kInputIndex0]->ToString() << ".";
142   }
143   return std::make_shared<TensorType>(kFloat32);
144 }
145 }  // namespace
146 
RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)147 AbstractBasePtr RandomStandardNormalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
148                                           const std::vector<AbstractBasePtr> &input_args) {
149   MS_EXCEPTION_IF_NULL(primitive);
150   auto prim_name = primitive->name();
151   for (const auto &item : input_args) {
152     MS_EXCEPTION_IF_NULL(item);
153   }
154   const int64_t kMinInputNum = 1;
155   const int64_t kMaxInputNum = 3;
156   (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, kMinInputNum,
157                                            prim_name);
158   (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kLessEqual, kMaxInputNum,
159                                            prim_name);
160   auto type = RandomStandardNormalInferType(primitive, input_args);
161   auto shape = RandomStandardNormalInferShape(primitive, input_args);
162   return abstract::MakeAbstract(shape, type);
163 }
164 
// Operator implementation macros for both primitives whose attribute accessors are defined in
// this file; both derive from BaseOperator.
MIND_API_OPERATOR_IMPL(RandomStandardNormal, BaseOperator);
MIND_API_OPERATOR_IMPL(StandardNormal, BaseOperator);
167 
168 // AG means auto generated
169 class MIND_API AGRandomStandardNormalInfer : public abstract::OpInferBase {
170  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const171   BaseShapePtr InferShape(const PrimitivePtr &primitive,
172                           const std::vector<AbstractBasePtr> &input_args) const override {
173     return RandomStandardNormalInferShape(primitive, input_args);
174   }
175 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const176   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
177     return RandomStandardNormalInferType(primitive, input_args);
178   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const179   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
180                                     const std::vector<AbstractBasePtr> &input_args) const override {
181     return RandomStandardNormalInfer(engine, primitive, input_args);
182   }
183 
GetValueDependArgIndices() const184   std::set<int64_t> GetValueDependArgIndices() const override { return {0}; }
185 };
186 
// Register AGRandomStandardNormalInfer as the infer implementation.
// NOTE(review): both registrations below target prim::kPrimStandardNormal, so the RandomStandardNormal
// line registers the same primitive a second time. Confirm whether it was meant to use a dedicated
// RandomStandardNormal primitive (none is visible in this file).
REGISTER_PRIMITIVE_OP_INFER_IMPL(RandomStandardNormal, prim::kPrimStandardNormal, AGRandomStandardNormalInfer, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(StandardNormal, prim::kPrimStandardNormal, AGRandomStandardNormalInfer, false);
189 }  // namespace ops
190 }  // namespace mindspore
191