1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/print.h"
18
19 #include <memory>
20
21 #include "abstract/dshape.h"
22 #include "abstract/ops/op_infer.h"
23 #include "abstract/ops/primitive_infer_map.h"
24 #include "base/base.h"
25 #include "ir/anf.h"
26 #include "ir/dtype/number.h"
27 #include "ir/dtype/tensor_type.h"
28 #include "mindapi/base/shape_vector.h"
29 #include "mindapi/base/shared_ptr.h"
30 #include "mindapi/ir/value.h"
31 #include "mindapi/src/helper.h"
32 #include "mindspore/core/ops/framework_ops.h"
33 #include "ops/primitive_c.h"
34 #include "utils/log_adapter.h"
35 #include "utils/compile_config.h"
36
37 namespace mindspore {
38 namespace ops {
// Attribute keys under which the Print operator stores its configuration.
constexpr const char *kStringValue = "string_value";
constexpr const char *kStringPos = "string_pos";
constexpr const char *kValueType = "value_type";
constexpr const char *kValueTypePos = "value_type_pos";
43
set_string_value(const std::vector<std::string> & string_value)44 void Print::set_string_value(const std::vector<std::string> &string_value) {
45 (void)this->AddAttr(kStringValue, api::MakeValue(string_value));
46 }
47
set_string_pos(const std::vector<int64_t> & string_pos)48 void Print::set_string_pos(const std::vector<int64_t> &string_pos) {
49 (void)this->AddAttr(kStringPos, api::MakeValue(string_pos));
50 }
51
set_value_type(const std::vector<int64_t> & value_type)52 void Print::set_value_type(const std::vector<int64_t> &value_type) {
53 (void)this->AddAttr(kValueType, api::MakeValue(value_type));
54 }
55
set_value_type_pos(const std::vector<int64_t> & value_type_pos)56 void Print::set_value_type_pos(const std::vector<int64_t> &value_type_pos) {
57 (void)this->AddAttr(kValueTypePos, api::MakeValue(value_type_pos));
58 }
59
get_string_value() const60 std::vector<std::string> Print::get_string_value() const {
61 auto value_ptr = this->GetAttr(kStringValue);
62 return GetValue<std::vector<std::string>>(value_ptr);
63 }
64
get_string_pos() const65 std::vector<int64_t> Print::get_string_pos() const {
66 auto value_ptr = this->GetAttr(kStringPos);
67 return GetValue<std::vector<int64_t>>(value_ptr);
68 }
69
get_value_type() const70 std::vector<int64_t> Print::get_value_type() const {
71 auto value_ptr = this->GetAttr(kValueType);
72 return GetValue<std::vector<int64_t>>(value_ptr);
73 }
74
get_value_type_pos() const75 std::vector<int64_t> Print::get_value_type_pos() const {
76 auto value_ptr = this->GetAttr(kValueTypePos);
77 return GetValue<std::vector<int64_t>>(value_ptr);
78 }
79
PrintValueToString(const ValuePtr & value)80 std::string PrintValueToString(const ValuePtr &value) {
81 if (value == nullptr) {
82 return "UnknownValue";
83 }
84 if (value->ContainsValueAny()) {
85 return "UnknownValue";
86 }
87 if (value->isa<StringImm>()) {
88 return value->ToString();
89 }
90 if (value->isa<Scalar>()) {
91 std::ostringstream buffer;
92 buffer << "Scalar(" << value->ToString() << ")";
93 return buffer.str();
94 }
95 return value->ToString();
96 }
97
PrintAbstractToString(const AbstractBasePtr & abstract)98 std::string PrintAbstractToString(const AbstractBasePtr &abstract) {
99 if (abstract->isa<abstract::AbstractScalar>()) {
100 auto value = abstract->GetValue();
101 return PrintValueToString(value);
102 }
103 if (abstract->isa<abstract::AbstractTensor>()) {
104 std::ostringstream buffer;
105 auto abs_tensor = abstract->cast<abstract::AbstractTensorPtr>();
106 buffer << "Tensor(shape:" << abs_tensor->GetShape()->ToString()
107 << ", dtype:" << abs_tensor->GetTypeTrack()->ToString()
108 << ", value:" << PrintValueToString(abs_tensor->GetValue()) << ")";
109 return buffer.str();
110 }
111 if (abstract->isa<abstract::AbstractSequence>()) {
112 auto abs_list = abstract->cast<abstract::AbstractSequencePtr>();
113 std::ostringstream buffer;
114 buffer << (abstract->isa<abstract::AbstractList>() ? "List[" : "Tuple(");
115 if (abs_list->dynamic_len()) {
116 buffer << PrintAbstractToString(abs_list->dynamic_len_element_abs());
117 buffer << "......";
118 } else {
119 for (const auto &element : abs_list->elements()) {
120 buffer << PrintAbstractToString(element) << ", ";
121 }
122 }
123 buffer << (abstract->isa<abstract::AbstractList>() ? "]" : ")");
124 return buffer.str();
125 }
126 return abstract->ToString();
127 }
128
// Expands the MIND_API_OPERATOR_IMPL boilerplate for Print, passing BaseOperator as the
// second argument (presumably the parent operator class - confirm against the macro definition).
MIND_API_OPERATOR_IMPL(Print, BaseOperator);
130
131 class PrintInfer : public abstract::OpInferBase {
132 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const133 BaseShapePtr InferShape(const PrimitivePtr &primitive,
134 const std::vector<AbstractBasePtr> &input_args) const override {
135 ShapeVector shape = {1};
136 return std::make_shared<abstract::Shape>(shape);
137 }
138
InferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args) const139 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
140 return std::make_shared<TensorType>(kInt32);
141 }
142
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const143 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
144 const std::vector<AbstractBasePtr> &input_args) const override {
145 auto shape = InferShape(primitive, input_args);
146 auto type = InferType(primitive, input_args);
147 std::ostringstream buffer;
148 if (common::GetCompileConfig("COMPILE_PRINT") == "1") {
149 for (const auto &input_arg : input_args) {
150 buffer << PrintAbstractToString(input_arg);
151 }
152 std::cout << buffer.str() << std::endl;
153 }
154 return abstract::MakeAbstract(shape, type);
155 }
156 };
157
// Registers PrintInfer for prim::kPrimPrint through REGISTER_PRIMITIVE_OP_INFER_IMPL.
// NOTE(review): the meaning of the trailing `false` flag is defined by the macro - confirm there.
REGISTER_PRIMITIVE_OP_INFER_IMPL(Print, prim::kPrimPrint, PrintInfer, false);
159 } // namespace ops
160 } // namespace mindspore
161