1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_TENSOR_SUMMARY_H 17 #define MINDSPORE_TENSOR_SUMMARY_H 18 19 #include <vector> 20 #include <tuple> 21 #include <memory> 22 #include <string> 23 24 #include "utils/hash_map.h" 25 #include "debug/debug_services.h" 26 27 namespace mindspore { 28 class RangeCountCalculator { 29 public: 30 RangeCountCalculator(); 31 ~RangeCountCalculator() = default; 32 void ProcessElement(double element); 33 double GetPercentInRange() const; set_range_start_inclusive(double value)34 void set_range_start_inclusive(double value) { range_start_inclusive = value; } set_range_end_inclusive(double value)35 void set_range_end_inclusive(double value) { range_end_inclusive = value; } 36 37 private: 38 double range_start_inclusive; 39 double range_end_inclusive; 40 int count; 41 int total; 42 }; 43 44 class AllCloseCalculator { 45 public: 46 AllCloseCalculator(); 47 ~AllCloseCalculator() = default; 48 void ProcessElement(double current, double previous); 49 bool IsAllClose() const; set_atol(double value)50 void set_atol(double value) { atol = value; } set_rtol(double value)51 void set_rtol(double value) { rtol = value; } 52 53 private: 54 double atol; 55 double rtol; 56 bool result; 57 }; 58 59 class MeanCalculator { 60 public: 61 MeanCalculator(); 62 ~MeanCalculator() = default; 63 void ProcessElement(double value); 64 double GetMean() const; 65 66 protected: 67 double mean; 68 int count; 69 }; 70 71 class VarianceAndMeanCalculator { 72 public: 73 VarianceAndMeanCalculator(); 74 ~VarianceAndMeanCalculator() = default; 75 void ProcessElement(double value); 76 double GetStandardDeviation() const; 77 double GetVariance() const; 78 double GetMean() const; 79 80 private: 81 double mean; 82 int count; 83 double m2; 84 }; 85 86 class L2Calculator { 87 public: L2Calculator()88 L2Calculator() : squre_sum(0.0) {} 89 ~L2Calculator() = default; 90 void ProcessElement(double value); 91 void ProcessElement(const L2Calculator &other); 92 double GetL2Value() const; 93 94 private: 95 // save (x^2 + y^2)/y^2, when y > x, to avoid itermidiate value overflow 96 // the true l2 value should be sqrt(squre_sum_div_max_ * max_value_^2) 97 double squre_sum; 98 }; 99 100 class ITensorSummary { 101 public: 102 enum WatchpointPos { eHitPos = 0, eErrorCodePos = 1, eParamListPos = 2 }; 103 enum ErrorCode { 104 NAN_TENSOR = 0, 105 INF_TENSOR = 2, 106 NULL_PREV_TENSOR = 4, 107 OUT_OF_MEMORY = 8, 108 HISTORY_NOT_FOUND = 16, 109 NO_VALUE = 32 110 }; 111 virtual ~ITensorSummary() = default; 112 virtual void SummarizeTensor(const std::vector<DebugServices::watchpoint_t> &wps) = 0; 113 virtual std::tuple<bool, int32_t, std::vector<DebugServices::parameter_t>> IsWatchpointHit( 114 DebugServices::watchpoint_t) = 0; 115 virtual void TensorStatistics(DbgDataType dtype_value) = 0; 116 virtual const bool is_bool() const = 0; 117 virtual const double max_value() const = 0; 118 virtual const double min_value() const = 0; 119 virtual const double avg_value() const = 0; 120 virtual const double l2_value() const = 0; 121 122 virtual const uint64_t count() const = 0; 123 virtual const uint64_t neg_zero_count() const = 0; 124 virtual const uint64_t pos_zero_count() const = 0; 125 virtual const uint64_t nan_count() const = 0; 126 virtual const uint64_t neg_inf_count() const = 0; 127 virtual const uint64_t pos_inf_count() const = 0; 128 virtual const uint64_t zero_count() const = 0; 129 }; 130 131 template <typename T> 132 class TensorSummary : public ITensorSummary { 133 public: 134 TensorSummary() = default; 135 ~TensorSummary() override = default; 136 TensorSummary(const void *current_tensor_ptr, const void *const previous_tensor_ptr, uint64_t num_elements, 137 uint64_t prev_num_elements); 138 void SummarizeTensor(const std::vector<DebugServices::watchpoint_t> &wps) override; 139 // returns hit, error_code, parameter_list 140 std::tuple<bool, int, std::vector<DebugServices::parameter_t>> IsWatchpointHit( 141 DebugServices::watchpoint_t wp) override; 142 void TensorStatistics(DbgDataType dtype_value) override; is_bool()143 const bool is_bool() const override { return is_bool_; } max_value()144 const double max_value() const override { return max_; } min_value()145 const double min_value() const override { return min_; } avg_value()146 const double avg_value() const override { return avg_; } count()147 const uint64_t count() const override { return num_elements_; } neg_zero_count()148 const uint64_t neg_zero_count() const override { return neg_zero_count_; } pos_zero_count()149 const uint64_t pos_zero_count() const override { return pos_zero_count_; } nan_count()150 const uint64_t nan_count() const override { return nan_count_; } neg_inf_count()151 const uint64_t neg_inf_count() const override { return neg_inf_count_; } pos_inf_count()152 const uint64_t pos_inf_count() const override { return pos_inf_count_; } zero_count()153 const uint64_t zero_count() const override { return zero_count_; } l2_value()154 const double l2_value() const override { return l2_calc_.GetL2Value(); } 155 156 private: 157 const T *current_tensor_ptr_; 158 const T *prev_tensor_ptr_; 159 uint64_t num_elements_; 160 uint64_t prev_num_elements_; 161 double min_; 162 double max_; 163 double avg_; 164 bool is_bool_; 165 uint64_t neg_zero_count_; 166 uint64_t pos_zero_count_; 167 uint64_t pos_inf_count_; 168 uint64_t neg_inf_count_; 169 uint64_t inf_count_; 170 uint64_t nan_count_; 171 uint64_t zero_count_; 172 double epsilon_; 173 bool mean_sd_cal_enabled_; 174 VarianceAndMeanCalculator current_mean_variance_; 175 L2Calculator l2_calc_; 176 mindspore::HashMap<std::string, std::unique_ptr<MeanCalculator>> means_; 177 mindspore::HashMap<uint32_t, std::unique_ptr<AllCloseCalculator>> all_close_; 178 mindspore::HashMap<uint32_t, std::unique_ptr<RangeCountCalculator>> range_counts_; 179 double_t StatLookup(const DebugServices::watchpoint_t &wp) const; 180 double_t StatLookup(const std::string ¶meter_name, const DebugServices::watchpoint_t &wp); 181 double_t GetZeroValPercent() const; 182 void TensorStatisticsSingleThread(); 183 void InitCalculators(const std::vector<DebugServices::watchpoint_t> &); 184 }; 185 } // namespace mindspore 186 #endif // MINDSPORE_TENSOR_SUMMARY_H 187