1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/fill.h"
18 #include <memory>
19 #include "ops/op_utils.h"
20 #include "utils/check_convert_utils.h"
21 #include "utils/tensor_construct_utils.h"
22
23 namespace mindspore {
24 namespace ops {
FillInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)25 AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
26 const std::vector<AbstractBasePtr> &input_args) {
27 MS_EXCEPTION_IF_NULL(primitive);
28 auto prim_name = primitive->name();
29 const int64_t input_num = 3;
30 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
31 for (const auto &item : input_args) {
32 MS_EXCEPTION_IF_NULL(item);
33 }
34 auto input_dtype = input_args[kInputIndex0]->cast<abstract::AbstractTypePtr>();
35 MS_EXCEPTION_IF_NULL(input_dtype);
36 auto dtype_value = input_dtype->BuildValue();
37 MS_EXCEPTION_IF_NULL(dtype_value);
38 auto dtype = dtype_value->cast<TypePtr>();
39 MS_EXCEPTION_IF_NULL(dtype);
40 auto valid_types = common_valid_types;
41 valid_types.insert(kBool);
42 (void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name);
43 auto out_shape = GetValue<std::vector<int64_t>>(input_args[kInputIndex1]->BuildValue());
44 auto x_type = input_args[kInputIndex2]->BuildType();
45 auto x_type_id = x_type->type_id();
46 auto x_value = input_args[kInputIndex2]->BuildValue();
47 auto abs = std::make_shared<abstract::AbstractTensor>(dtype, std::make_shared<abstract::Shape>(out_shape));
48 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(x_type_id, out_shape);
49 MS_EXCEPTION_IF_NULL(tensor);
50 auto mem_size = IntToSize(tensor->ElementsNum());
51 if (x_type_id == kNumberTypeInt) {
52 auto int_value = GetValue<int>(x_value);
53 SetTensorData(tensor->data_c(), int_value, mem_size);
54 } else if (x_type_id == kNumberTypeFloat || x_type_id == kNumberTypeFloat32) {
55 auto float_value = GetValue<float>(x_value);
56 SetTensorData(tensor->data_c(), float_value, mem_size);
57 } else {
58 MS_LOG(ERROR) << " Fill not supported to flod the constant type " << input_args[kInputIndex2]->ToString();
59 }
60 abs->set_value(tensor);
61 return abs;
62 }
63 REGISTER_PRIMITIVE_C(kNameFill, Fill);
64 } // namespace ops
65 } // namespace mindspore
66