• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "ops/print.h"

#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "abstract/dshape.h"
#include "abstract/ops/op_infer.h"
#include "abstract/ops/primitive_infer_map.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype/number.h"
#include "ir/dtype/tensor_type.h"
#include "mindapi/base/shape_vector.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ops/primitive_c.h"
#include "utils/compile_config.h"
#include "utils/log_adapter.h"
36 
37 namespace mindspore {
38 namespace ops {
// Attribute keys under which the Print setters/getters in this file store
// and read the operator's metadata.
constexpr const char *kStringValue = "string_value";
constexpr const char *kStringPos = "string_pos";
constexpr const char *kValueType = "value_type";
constexpr const char *kValueTypePos = "value_type_pos";
43 
set_string_value(const std::vector<std::string> & string_value)44 void Print::set_string_value(const std::vector<std::string> &string_value) {
45   (void)this->AddAttr(kStringValue, api::MakeValue(string_value));
46 }
47 
set_string_pos(const std::vector<int64_t> & string_pos)48 void Print::set_string_pos(const std::vector<int64_t> &string_pos) {
49   (void)this->AddAttr(kStringPos, api::MakeValue(string_pos));
50 }
51 
set_value_type(const std::vector<int64_t> & value_type)52 void Print::set_value_type(const std::vector<int64_t> &value_type) {
53   (void)this->AddAttr(kValueType, api::MakeValue(value_type));
54 }
55 
set_value_type_pos(const std::vector<int64_t> & value_type_pos)56 void Print::set_value_type_pos(const std::vector<int64_t> &value_type_pos) {
57   (void)this->AddAttr(kValueTypePos, api::MakeValue(value_type_pos));
58 }
59 
get_string_value() const60 std::vector<std::string> Print::get_string_value() const {
61   auto value_ptr = this->GetAttr(kStringValue);
62   return GetValue<std::vector<std::string>>(value_ptr);
63 }
64 
get_string_pos() const65 std::vector<int64_t> Print::get_string_pos() const {
66   auto value_ptr = this->GetAttr(kStringPos);
67   return GetValue<std::vector<int64_t>>(value_ptr);
68 }
69 
get_value_type() const70 std::vector<int64_t> Print::get_value_type() const {
71   auto value_ptr = this->GetAttr(kValueType);
72   return GetValue<std::vector<int64_t>>(value_ptr);
73 }
74 
get_value_type_pos() const75 std::vector<int64_t> Print::get_value_type_pos() const {
76   auto value_ptr = this->GetAttr(kValueTypePos);
77   return GetValue<std::vector<int64_t>>(value_ptr);
78 }
79 
PrintValueToString(const ValuePtr & value)80 std::string PrintValueToString(const ValuePtr &value) {
81   if (value == nullptr) {
82     return "UnknownValue";
83   }
84   if (value->ContainsValueAny()) {
85     return "UnknownValue";
86   }
87   if (value->isa<StringImm>()) {
88     return value->ToString();
89   }
90   if (value->isa<Scalar>()) {
91     std::ostringstream buffer;
92     buffer << "Scalar(" << value->ToString() << ")";
93     return buffer.str();
94   }
95   return value->ToString();
96 }
97 
PrintAbstractToString(const AbstractBasePtr & abstract)98 std::string PrintAbstractToString(const AbstractBasePtr &abstract) {
99   if (abstract->isa<abstract::AbstractScalar>()) {
100     auto value = abstract->GetValue();
101     return PrintValueToString(value);
102   }
103   if (abstract->isa<abstract::AbstractTensor>()) {
104     std::ostringstream buffer;
105     auto abs_tensor = abstract->cast<abstract::AbstractTensorPtr>();
106     buffer << "Tensor(shape:" << abs_tensor->GetShape()->ToString()
107            << ", dtype:" << abs_tensor->GetTypeTrack()->ToString()
108            << ", value:" << PrintValueToString(abs_tensor->GetValue()) << ")";
109     return buffer.str();
110   }
111   if (abstract->isa<abstract::AbstractSequence>()) {
112     auto abs_list = abstract->cast<abstract::AbstractSequencePtr>();
113     std::ostringstream buffer;
114     buffer << (abstract->isa<abstract::AbstractList>() ? "List[" : "Tuple(");
115     if (abs_list->dynamic_len()) {
116       buffer << PrintAbstractToString(abs_list->dynamic_len_element_abs());
117       buffer << "......";
118     } else {
119       for (const auto &element : abs_list->elements()) {
120         buffer << PrintAbstractToString(element) << ", ";
121       }
122     }
123     buffer << (abstract->isa<abstract::AbstractList>() ? "]" : ")");
124     return buffer.str();
125   }
126   return abstract->ToString();
127 }
128 
129 MIND_API_OPERATOR_IMPL(Print, BaseOperator);
130 
131 class PrintInfer : public abstract::OpInferBase {
132  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const133   BaseShapePtr InferShape(const PrimitivePtr &primitive,
134                           const std::vector<AbstractBasePtr> &input_args) const override {
135     ShapeVector shape = {1};
136     return std::make_shared<abstract::Shape>(shape);
137   }
138 
InferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args) const139   TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
140     return std::make_shared<TensorType>(kInt32);
141   }
142 
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const143   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
144                                     const std::vector<AbstractBasePtr> &input_args) const override {
145     auto shape = InferShape(primitive, input_args);
146     auto type = InferType(primitive, input_args);
147     std::ostringstream buffer;
148     if (common::GetCompileConfig("COMPILE_PRINT") == "1") {
149       for (const auto &input_arg : input_args) {
150         buffer << PrintAbstractToString(input_arg);
151       }
152       std::cout << buffer.str() << std::endl;
153     }
154     return abstract::MakeAbstract(shape, type);
155   }
156 };
157 
158 REGISTER_PRIMITIVE_OP_INFER_IMPL(Print, prim::kPrimPrint, PrintInfer, false);
159 }  // namespace ops
160 }  // namespace mindspore
161