• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/ops_func_impl/tuple_to_tensor.h"
18 
19 #include <utility>
20 #include <memory>
21 #include "ops/ops_frontend_func_impl.h"
22 #include "ops/auto_generate/gen_ops_name.h"
23 #include "ops/op_utils.h"
24 #include "ir/dtype.h"
25 #include "ir/dtype/number.h"
26 #include "ir/dtype/type.h"
27 #include "ir/anf.h"
28 #include "ir/primitive.h"
29 #include "ops/op_name.h"
30 #include "kernel/kernel.h"
31 #include "utils/ms_context.h"
32 #include "abstract/abstract_value.h"
33 #include "abstract/dshape.h"
34 #include "abstract/ops/op_infer.h"
35 #include "abstract/ops/primitive_infer_map.h"
36 #include "abstract/param_validator.h"
37 #include "base/base.h"
38 #include "mindapi/base/shape_vector.h"
39 #include "mindapi/src/helper.h"
40 #include "mindspore/core/ops/math_ops.h"
41 #include "mindspore/core/ops/sequence_ops.h"
42 #include "ops/base_operator.h"
43 #include "ops/list_to_tensor.h"
44 #include "ops/primitive_c.h"
45 #include "utils/check_convert_utils.h"
46 #include "utils/convert_utils_base.h"
47 #include "utils/log_adapter.h"
48 
49 namespace mindspore {
50 namespace ops {
// Name under which the TupleToTensor frontend implementation in this file is registered.
constexpr const char *kTupleToTensor = "TupleToTensor";
CreateEmptyTupleTensorByType(const TypePtr & data_type)52 tensor::TensorPtr CreateEmptyTupleTensorByType(const TypePtr &data_type) {
53   std::vector<int64_t> tensor_shape = {0};
54   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(data_type->type_id(), tensor_shape);
55   MS_EXCEPTION_IF_NULL(tensor);
56   return tensor;
57 }
58 template <typename T, typename S>
CreateTensorByTupleCast(const std::vector<T> & values,const TypePtr & type_ptr,const size_t data_len)59 tensor::TensorPtr CreateTensorByTupleCast(const std::vector<T> &values, const TypePtr &type_ptr,
60                                           const size_t data_len) {
61   std::vector<S> new_values;
62   (void)std::transform(values.begin(), values.end(), std::back_inserter(new_values),
63                        [&](T value) -> S { return static_cast<S>(value); });
64   std::vector<int64_t> tensor_shape = {SizeToLong(new_values.size())};
65   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_ptr->type_id(), tensor_shape);
66   MS_EXCEPTION_IF_NULL(tensor);
67   auto data_ptr = tensor->data_c();
68   MS_EXCEPTION_IF_NULL(data_ptr);
69   auto elem_num = new_values.size() * data_len;
70   auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(tensor->data().nbytes()), new_values.data(), elem_num);
71   if (ret_code != EOK) {
72     MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code;
73   }
74   return tensor;
75 }
76 template <typename T>
CreateTensorWithValueTuple(const ValueSequencePtr & value_tuple,const TypePtr & type_ptr,const size_t data_len)77 tensor::TensorPtr CreateTensorWithValueTuple(const ValueSequencePtr &value_tuple, const TypePtr &type_ptr,
78                                              const size_t data_len) {
79   MS_EXCEPTION_IF_NULL(value_tuple);
80   MS_EXCEPTION_IF_NULL(type_ptr);
81   std::vector<T> values;
82   auto first_type = value_tuple->value()[0]->type()->type_id();
83   for (const auto &v : value_tuple->value()) {
84     MS_EXCEPTION_IF_NULL(v);
85     if (v->isa<Scalar>()) {
86       ScalarPtr scalar = v->cast<ScalarPtr>();
87       auto cur_type = scalar->type()->type_id();
88       if (cur_type != first_type) {
89         MS_EXCEPTION(TypeError) << "the tuple elements type must be same, first element type = " << first_type
90                                 << " cur_type = " << cur_type;
91       }
92       values.push_back(GetValue<T>(scalar));
93     } else {
94       MS_EXCEPTION(TypeError) << "The value " << v << "of tuple is not a scalar";
95     }
96   }
97   if (type_ptr->type_id() == kNumberTypeInt32) {
98     return CreateTensorByTupleCast<T, int32_t>(values, type_ptr, data_len);
99   } else if (type_ptr->type_id() == kNumberTypeInt64) {
100     return CreateTensorByTupleCast<T, int64_t>(values, type_ptr, data_len);
101   } else if (type_ptr->type_id() == kNumberTypeFloat32) {
102     return CreateTensorByTupleCast<T, float>(values, type_ptr, data_len);
103   } else if (type_ptr->type_id() == kNumberTypeFloat64) {
104     return CreateTensorByTupleCast<T, double>(values, type_ptr, data_len);
105   } else if (type_ptr->type_id() == kNumberTypeBool) {
106     // std::vector<bool> is not a valid container, so use std::vector<int8_t> to hold the values
107     return CreateTensorByTupleCast<T, int8_t>(values, type_ptr, data_len);
108   } else {
109     MS_EXCEPTION(TypeError) << "Invalid scalar type: " << type_ptr->ToString();
110   }
111 }
112 
SeqToTensorByType(const ValueSequencePtr & value_tuple,const TypePtr & data_type)113 tensor::TensorPtr SeqToTensorByType(const ValueSequencePtr &value_tuple, const TypePtr &data_type) {
114   tensor::TensorPtr tensor = nullptr;
115   if (value_tuple->value().empty()) {
116     tensor = CreateEmptyTupleTensorByType(data_type);
117     return tensor;
118   }
119   ValuePtr v = *(value_tuple->value().begin());
120   MS_EXCEPTION_IF_NULL(v);
121   // Currently we only deal with the scalar tuple
122   if (!v->isa<Scalar>()) {
123     MS_EXCEPTION(TypeError) << "The value " << v << "of tuple is not a scalar";
124   }
125   ScalarPtr scalar = v->cast<ScalarPtr>();
126   MS_EXCEPTION_IF_NULL(scalar);
127   size_t data_len = GetTypeByte(data_type);
128   if (scalar->isa<Int32Imm>()) {
129     tensor = CreateTensorWithValueTuple<int32_t>(value_tuple, data_type, data_len);
130   } else if (scalar->isa<Int64Imm>()) {
131     tensor = CreateTensorWithValueTuple<int64_t>(value_tuple, data_type, data_len);
132   } else if (scalar->isa<FP32Imm>()) {
133     tensor = CreateTensorWithValueTuple<float>(value_tuple, data_type, data_len);
134   } else if (scalar->isa<FP64Imm>()) {
135     tensor = CreateTensorWithValueTuple<double>(value_tuple, data_type, data_len);
136   } else if (scalar->isa<BoolImm>()) {
137     tensor = CreateTensorWithValueTuple<bool>(value_tuple, data_type, data_len);
138   } else {
139     auto type = scalar->type();
140     auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
141     MS_EXCEPTION(TypeError) << "Invalid scalar type: " << type_str;
142   }
143   return tensor;
144 }
145 
146 class TupleToTensorFrontendFuncImpl : public OpFrontendFuncImpl {
147  public:
InferValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const148   ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
149     MS_EXCEPTION_IF_NULL(primitive);
150     auto prim_name = primitive->name();
151     constexpr int64_t input_len = 2;
152     (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_len,
153                                              prim_name);
154     auto elem = abstract::CheckArg<abstract::AbstractSequence>(prim_name, input_args, 0);
155     auto elem_value = elem->GetValue();
156     if (elem_value->ContainsValueAny()) {
157       return nullptr;
158     }
159     auto value_tuple = elem_value->cast<ValueSequencePtr>();
160     MS_EXCEPTION_IF_NULL(value_tuple);
161     auto dtype_value = GetScalarValue<int64_t>(input_args[kInputIndex1]->GetValue());
162     MS_CHECK_VALUE(dtype_value.has_value(),
163                    CheckAndConvertUtils::FormatCommMsg("For primitive[", prim_name,
164                                                        "], the `dtype` should has valid value for static type."));
165     auto dst_type = TypeIdToType(static_cast<TypeId>(dtype_value.value()));
166     MS_EXCEPTION_IF_NULL(dst_type);
167     return SeqToTensorByType(value_tuple, dst_type);
168   }
169 };
170 REGISTER_PRIMITIVE_FUNCTION_FRONTEND_FUNC_IMPL(kTupleToTensor, TupleToTensorFrontendFuncImpl);
171 }  // namespace ops
172 }  // namespace mindspore
173