• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "debug/data_dump/tensor_statistic.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "debug/data_dump/statistic_kernel.h"
#include "debug/debugger/debugger_utils.h"
#include "include/backend/debug/common/csv_writer.h"
#include "include/backend/debug/data_dump/dump_utils.h"
#include "include/common/debug/anf_dump_utils.h"
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "debug/utils.h"
27 
28 namespace mindspore {
29 
30 namespace {
31 using TensorPtr = tensor::TensorPtr;
32 
33 constexpr auto kInput = "input";
34 constexpr auto kOutput = "output";
35 constexpr auto kCsvFileName = "statistic.csv";
ShapeToString(const ShapeVector & shape)36 string ShapeToString(const ShapeVector &shape) {
37   std::ostringstream sstr;
38   sstr << "\"(";
39   for (size_t i = 0; i < shape.size(); i++) {
40     sstr << (i > 0 ? "," : "") << shape[i];
41   }
42   sstr << ")\"";
43   return string{sstr.str()};
44 }
TensorToString(TensorPtr tensor)45 string TensorToString(TensorPtr tensor) {
46   if (!tensor) {
47     return "null";
48   }
49   return tensor->data().ToString(tensor->data_type(), tensor->shape(), false);
50 }
51 }  // namespace
52 
53 namespace datadump {
54 
GetKernelTensorStats(const DumpTensorInfo & tensor_info,const std::vector<string> & stat_name_list)55 TensorStat GetKernelTensorStats(const DumpTensorInfo &tensor_info, const std::vector<string> &stat_name_list) {
56   auto tensor = tensor_info.tensor;
57   if (tensor == nullptr) {
58     MS_LOG(WARNING) << "Tensor is nullptr, returning empty tensor statistics.";
59     return TensorStat();
60   }
61 
62   const auto &shape_vec = tensor->GetShapeVector();
63   string shape = ShapeToString(shape_vec);
64   size_t data_count = SizeOf(shape_vec);
65   size_t data_size = tensor->size();
66   string data_type = TypeIdToString(tensor->dtype_id(), true);
67   MS_LOG(DEBUG) << "Tensor shape is " << shape << ", size is " << data_size << ", type is " << data_type;
68   auto is_calc_stat = [&stat_name_list](std::string name) {
69     return (std::find(stat_name_list.begin(), stat_name_list.end(), name) != stat_name_list.end());
70   };
71   std::string max_value =
72     is_calc_stat("max") ? TensorToString(CalStatistic("max", tensor_info.device_context, tensor)) : "0";
73   std::string min_value =
74     is_calc_stat("min") ? TensorToString(CalStatistic("min", tensor_info.device_context, tensor)) : "0";
75   std::string mean_value =
76     is_calc_stat("avg") ? TensorToString(CalStatistic("avg", tensor_info.device_context, tensor)) : "0";
77   std::string norm_value =
78     is_calc_stat("l2norm") ? TensorToString(CalStatistic("l2norm", tensor_info.device_context, tensor)) : "0";
79 
80   size_t task_id = 0;  // Under the kbyk, there is no concept of task_id. The default setting is 0.
81   uint64_t timestamp = Common::GetTimeStamp();
82   auto stream_id = tensor->stream_id();
83   string io = (tensor_info.is_input ? kInput : kOutput);
84   TensorStat stat(tensor_info.op_type, tensor_info.op_name, task_id, stream_id, timestamp, io, tensor_info.slot,
85                   data_size, data_type, shape, max_value, min_value, mean_value, norm_value, data_count);
86   return stat;
87 }
88 
DumpKernelTensorStats(const DeviceContext * device_context,vector<device::DeviceAddress * > tensors,bool is_input,const CNodePtr & node,uint32_t graph_id)89 void DumpKernelTensorStats(const DeviceContext *device_context, vector<device::DeviceAddress *> tensors, bool is_input,
90                            const CNodePtr &node, uint32_t graph_id) {
91   string node_name = GetKernelNodeName(node);
92   GetFileKernelName(NOT_NULL(&node_name));
93   string node_type = common::AnfAlgo::GetCNodeName(node);
94   MS_LOG(DEBUG) << "Start calc " << node_name << " node statistics.";
95   const string csv_header = CsvHeaderUtil::GetInstance().GetStatCsvHeader();
96   const std::vector<string> &stat_name_list = DumpJsonParser::GetInstance().statistic_category();
97   uint32_t rank_id = GetRankId();
98   string filename = GenerateDumpPath(graph_id, rank_id) + "/" + kCsvFileName;
99   CsvWriter csv;
100   auto valid_index = GetValidDumpIndex(node, tensors.size(), is_input);
101   if (!valid_index.empty()) {
102     if (!csv.OpenFile(filename, csv_header)) {
103       MS_LOG(WARNING) << "filename is " << filename;
104       MS_LOG(WARNING) << "Open statistic dump file failed, skipping current statistics";
105       return;
106     }
107   }
108   for (auto i : valid_index) {
109     auto tensor = tensors[i]->kernel_tensor().get();
110     DumpTensorInfo tensor_info(device_context, tensor, is_input, i, node_name, node_type);
111     auto stat = GetKernelTensorStats(tensor_info, stat_name_list);
112     stat.UpdateHeaderItemMap();
113 
114     csv.WriteToCsv(stat.type_);
115     csv.WriteToCsv(stat.name_);
116     csv.WriteToCsv(stat.task_id_);
117     csv.WriteToCsv(stat.stream_id_);
118     csv.WriteToCsv(stat.timestamp_);
119     csv.WriteToCsv(stat.io_);
120     csv.WriteToCsv(stat.slot_);
121     csv.WriteToCsv(stat.data_size_);
122     csv.WriteToCsv(stat.data_type_);
123     csv.WriteToCsv(stat.shape_);
124 
125     for (const auto &name : stat_name_list) {
126       // DumpJsonParse guarantee names are valid.
127       auto stat_val = stat.header_item_map[name];
128       csv.WriteToCsv(stat_val);
129     }
130     csv.WriteToCsv("", true);
131   }
132   csv.CloseFile();
133 }
134 
135 }  // namespace datadump
136 }  // namespace mindspore
137