1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/ops_func_impl/tuple_to_tensor.h"
18
19 #include <utility>
20 #include <memory>
21 #include "ops/ops_frontend_func_impl.h"
22 #include "ops/auto_generate/gen_ops_name.h"
23 #include "ops/op_utils.h"
24 #include "ir/dtype.h"
25 #include "ir/dtype/number.h"
26 #include "ir/dtype/type.h"
27 #include "ir/anf.h"
28 #include "ir/primitive.h"
29 #include "ops/op_name.h"
30 #include "kernel/kernel.h"
31 #include "utils/ms_context.h"
32 #include "abstract/abstract_value.h"
33 #include "abstract/dshape.h"
34 #include "abstract/ops/op_infer.h"
35 #include "abstract/ops/primitive_infer_map.h"
36 #include "abstract/param_validator.h"
37 #include "base/base.h"
38 #include "mindapi/base/shape_vector.h"
39 #include "mindapi/src/helper.h"
40 #include "mindspore/core/ops/math_ops.h"
41 #include "mindspore/core/ops/sequence_ops.h"
42 #include "ops/base_operator.h"
43 #include "ops/list_to_tensor.h"
44 #include "ops/primitive_c.h"
45 #include "utils/check_convert_utils.h"
46 #include "utils/convert_utils_base.h"
47 #include "utils/log_adapter.h"
48
49 namespace mindspore {
50 namespace ops {
// Name string passed to the frontend-function registration of the TupleToTensor primitive at the end of this file.
constexpr const char *kTupleToTensor = "TupleToTensor";
CreateEmptyTupleTensorByType(const TypePtr & data_type)52 tensor::TensorPtr CreateEmptyTupleTensorByType(const TypePtr &data_type) {
53 std::vector<int64_t> tensor_shape = {0};
54 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(data_type->type_id(), tensor_shape);
55 MS_EXCEPTION_IF_NULL(tensor);
56 return tensor;
57 }
58 template <typename T, typename S>
CreateTensorByTupleCast(const std::vector<T> & values,const TypePtr & type_ptr,const size_t data_len)59 tensor::TensorPtr CreateTensorByTupleCast(const std::vector<T> &values, const TypePtr &type_ptr,
60 const size_t data_len) {
61 std::vector<S> new_values;
62 (void)std::transform(values.begin(), values.end(), std::back_inserter(new_values),
63 [&](T value) -> S { return static_cast<S>(value); });
64 std::vector<int64_t> tensor_shape = {SizeToLong(new_values.size())};
65 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
66 MS_EXCEPTION_IF_NULL(tensor);
67 auto data_ptr = tensor->data_c();
68 MS_EXCEPTION_IF_NULL(data_ptr);
69 auto elem_num = new_values.size() * data_len;
70 auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), new_values.data(), elem_num);
71 if (ret_code != EOK) {
72 MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
73 }
74 return tensor;
75 }
76 template <typename T>
CreateTensorWithValueTuple(const ValueSequencePtr & value_tuple,const TypePtr & type_ptr,const size_t data_len)77 tensor::TensorPtr CreateTensorWithValueTuple(const ValueSequencePtr &value_tuple, const TypePtr &type_ptr,
78 const size_t data_len) {
79 MS_EXCEPTION_IF_NULL(value_tuple);
80 MS_EXCEPTION_IF_NULL(type_ptr);
81 std::vector<T> values;
82 auto first_type = value_tuple->value()[0]->type()->type_id();
83 for (const auto &v : value_tuple->value()) {
84 MS_EXCEPTION_IF_NULL(v);
85 if (v->isa<Scalar>()) {
86 ScalarPtr scalar = v->cast<ScalarPtr>();
87 auto cur_type = scalar->type()->type_id();
88 if (cur_type != first_type) {
89 MS_EXCEPTION(TypeError) << "the tuple elements type must be same, first element type = " << first_type
90 << " cur_type = " << cur_type;
91 }
92 values.push_back(GetValue<T>(scalar));
93 } else {
94 MS_EXCEPTION(TypeError) << "The value " << v << "of tuple is not a scalar";
95 }
96 }
97 if (type_ptr->type_id() == kNumberTypeInt32) {
98 return CreateTensorByTupleCast<T, int32_t>(values, type_ptr, data_len);
99 } else if (type_ptr->type_id() == kNumberTypeInt64) {
100 return CreateTensorByTupleCast<T, int64_t>(values, type_ptr, data_len);
101 } else if (type_ptr->type_id() == kNumberTypeFloat32) {
102 return CreateTensorByTupleCast<T, float>(values, type_ptr, data_len);
103 } else if (type_ptr->type_id() == kNumberTypeFloat64) {
104 return CreateTensorByTupleCast<T, double>(values, type_ptr, data_len);
105 } else if (type_ptr->type_id() == kNumberTypeBool) {
106 // std::vector<bool> is not a valid container, so use std::vector<int8_t> to hold the values
107 return CreateTensorByTupleCast<T, int8_t>(values, type_ptr, data_len);
108 } else {
109 MS_EXCEPTION(TypeError) << "Invalid scalar type: " << type_ptr->ToString();
110 }
111 }
112
SeqToTensorByType(const ValueSequencePtr & value_tuple,const TypePtr & data_type)113 tensor::TensorPtr SeqToTensorByType(const ValueSequencePtr &value_tuple, const TypePtr &data_type) {
114 tensor::TensorPtr tensor = nullptr;
115 if (value_tuple->value().empty()) {
116 tensor = CreateEmptyTupleTensorByType(data_type);
117 return tensor;
118 }
119 ValuePtr v = *(value_tuple->value().begin());
120 MS_EXCEPTION_IF_NULL(v);
121 // Currently we only deal with the scalar tuple
122 if (!v->isa<Scalar>()) {
123 MS_EXCEPTION(TypeError) << "The value " << v << "of tuple is not a scalar";
124 }
125 ScalarPtr scalar = v->cast<ScalarPtr>();
126 MS_EXCEPTION_IF_NULL(scalar);
127 size_t data_len = GetTypeByte(data_type);
128 if (scalar->isa<Int32Imm>()) {
129 tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, data_type, data_len);
130 } else if (scalar->isa<Int64Imm>()) {
131 tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, data_type, data_len);
132 } else if (scalar->isa<FP32Imm>()) {
133 tensor = CreateTensorWithValueTuple<float>(value_tuple, data_type, data_len);
134 } else if (scalar->isa<FP64Imm>()) {
135 tensor = CreateTensorWithValueTuple<double>(value_tuple, data_type, data_len);
136 } else if (scalar->isa<BoolImm>()) {
137 tensor = CreateTensorWithValueTuple<bool>(value_tuple, data_type, data_len);
138 } else {
139 auto type = scalar->type();
140 auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
141 MS_EXCEPTION(TypeError) << "Invalid scalar type: " << type_str;
142 }
143 return tensor;
144 }
145
146 class TupleToTensorFrontendFuncImpl : public OpFrontendFuncImpl {
147 public:
InferValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const148 ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
149 MS_EXCEPTION_IF_NULL(primitive);
150 auto prim_name = primitive->name();
151 constexpr int64_t input_len = 2;
152 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len,
153 prim_name);
154 auto elem = abstract::CheckArg<abstract::AbstractSequence>(prim_name, input_args, 0);
155 auto elem_value = elem->GetValue();
156 if (elem_value->ContainsValueAny()) {
157 return nullptr;
158 }
159 auto value_tuple = elem_value->cast<ValueSequencePtr>();
160 MS_EXCEPTION_IF_NULL(value_tuple);
161 auto dtype_value = GetScalarValue<int64_t>(input_args[kInputIndex1]->GetValue());
162 MS_CHECK_VALUE(dtype_value.has_value(),
163 CheckAndConvertUtils::FormatCommMsg("For primitive[", prim_name,
164 "], the `dtype` should has valid value for static type."));
165 auto dst_type = TypeIdToType(static_cast<TypeId>(dtype_value.value()));
166 MS_EXCEPTION_IF_NULL(dst_type);
167 return SeqToTensorByType(value_tuple, dst_type);
168 }
169 };
170 REGISTER_PRIMITIVE_FUNCTION_FRONTEND_FUNC_IMPL(kTupleToTensor, TupleToTensorFrontendFuncImpl);
171 } // namespace ops
172 } // namespace mindspore
173