1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/ops_func_impl/dropout_ext.h"
18 #include <limits>
19 #include <memory>
20 #include <string>
21 #include "ops/op_utils.h"
22
23 namespace mindspore {
24 namespace ops {
CalMaskShape(const PrimitivePtr & primitive,const ShapeVector & shape_vec)25 int64_t CalMaskShape(const PrimitivePtr &primitive, const ShapeVector &shape_vec) {
26 constexpr int64_t kDropoutGenMaskMaskConvertLen = 128;
27 int64_t count = 1;
28 for (size_t i = 0; i < shape_vec.size(); i++) {
29 auto dim_value = shape_vec[i];
30 if (dim_value <= 0) {
31 MS_LOG(EXCEPTION) << "For '" << primitive->name()
32 << "', each dim of 'shape' must be greater than 0, but got shape[" << i << "]: " << dim_value
33 << ".";
34 }
35
36 if (std::numeric_limits<int64_t>::max() / count / dim_value < 1) {
37 MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', integer multiply integer overflow.";
38 }
39 count *= shape_vec[i];
40 }
41
42 int64_t n128s = count / kDropoutGenMaskMaskConvertLen;
43 if ((count % kDropoutGenMaskMaskConvertLen) != 0) {
44 n128s++;
45 }
46 int64_t bytes_count = n128s * 16;
47
48 return bytes_count;
49 }
50
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const51 BaseShapePtr DropoutExtFuncImpl::InferShape(const PrimitivePtr &primitive,
52 const std::vector<AbstractBasePtr> &input_args) const {
53 auto x_shape_ptr = input_args[kIndex0]->GetShape();
54 auto x_shape = input_args[kIndex0]->GetShape()->GetShapeVector();
55 ShapeVector mask_shape;
56 if (x_shape_ptr->IsDynamic()) {
57 mask_shape.push_back(abstract::TensorShape::kShapeDimAny);
58 } else {
59 mask_shape.push_back(CalMaskShape(primitive, x_shape));
60 }
61 auto mask_shape_ptr = std::make_shared<abstract::TensorShape>(mask_shape);
62 return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_shape_ptr, mask_shape_ptr});
63 }
64
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const65 TypePtr DropoutExtFuncImpl::InferType(const PrimitivePtr &primitive,
66 const std::vector<AbstractBasePtr> &input_args) const {
67 auto x_type = input_args[0]->GetType();
68 return std::make_shared<Tuple>(std::vector<TypePtr>{x_type, std::make_shared<TensorType>(kUInt8)});
69 }
70
CheckValidation(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const71 int32_t DropoutExtFuncImpl::CheckValidation(const PrimitivePtr &primitive,
72 const std::vector<AbstractBasePtr> &input_args) const {
73 MS_EXCEPTION_IF_NULL(input_args[kIndex1]);
74 const auto &p_opt = GetScalarValue<float>(input_args[kIndex1]->GetValue());
75 if (MS_UNLIKELY(!p_opt.has_value())) {
76 return OP_CHECK_RETRY;
77 }
78 MS_CHECK_VALUE(p_opt.value() >= static_cast<float>(0.0) && p_opt.value() <= static_cast<float>(1.0),
79 "For 'DropoutExt', the 'p' must be in range [0, 1], but got " + std::to_string(p_opt.value()));
80 return OP_CHECK_SUCCESS;
81 }
82 } // namespace ops
83 } // namespace mindspore
84