• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <vector>
18 #include <string>
19 #include <set>
20 #include <utility>
21 #include <map>
22 #include <unordered_map>
23 #include "schema/ops_generated.h"
24 #include "base/core_ops.h"
25 #include "include/lite_types.h"
26 #ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
27 #define MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
28 
29 namespace mindspore {
30 namespace opt {
// Suffix appended to the names of nodes created by the parallel-split pass.
constexpr auto PARALLEL_NAME_SUFFIX = "_parallel";

// Index of the primitive input inside a parallel node's input list.
constexpr auto kParallelPrimitiveIndex = 0;

// Default split ratio placeholder for two devices (0 means "unspecified").
// NOTE(review): namespace-scope `const` objects in a header get a separate,
// dynamically initialized copy in every translation unit that includes it;
// `inline const` (C++17) would share one -- confirm the project's standard.
const std::vector<int64_t> kSplitDefaultRatio = {0, 0};

// Devices the user may split across: only cpu && gpu, npu is not supported.
const std::vector<std::string> kSplitDevTypes = {"cpu", "gpu"};

// Per-device -> per-input -> per-axis split ratios.
using Strategys = std::vector<std::vector<std::vector<int64_t>>>;

// Sentinel meaning "no device type assigned".
constexpr auto kDeviceTypeNone = -1;
// strategy format is NHWC-KHWC
// Axis positions within a strategy entry (feature map is NHWC, kernel KHWC):
constexpr int32_t kAxisN = 0;
constexpr int32_t kAxisCIn = 3;
constexpr int32_t kAxisCOut = 0;  // K axis of the KHWC kernel
constexpr int32_t kAxisH = 1;
constexpr int32_t kAxisW = 2;

// Batch size assumed when none is specified.
constexpr auto kDefaultBatch = 1;

// Dimension indices of an NHWC shape vector.
constexpr auto kShapeN = 0;
constexpr auto kShapeH = 1;
constexpr auto kShapeW = 2;
constexpr auto kShapeC = 3;

// Indices into (height, width) pairs such as kernel size or stride.
constexpr auto kIndexH = 0;
constexpr auto kIndexW = 1;

// Indices into a 4-element padding vector.
constexpr auto kPadUp = 0;
constexpr auto kPadDown = 1;
constexpr auto kPadLeft = 2;
constexpr auto kPadRight = 3;
// Axis along which an operator is partitioned across devices.
// NOTE(review): values are serialized/compared by callers outside this file;
// keep the explicit numbering stable.
enum SplitMode {
  NoSplit = 0,    // operator is not split
  SplitN = 1,     // split along the batch (N) axis
  SplitH = 2,     // split along the height (H) axis
  SplitCIN = 3,   // split along the input-channel axis
  SplitCOUT = 4,  // split along the output-channel axis
};
72 
// Describes how a single operator is split across the participating devices.
// NOTE(review): member naming is inconsistent (`split_mode_` carries a
// trailing underscore, the others do not); renaming would break callers, so
// it is only flagged here.
struct SplitStrategy {
  // Per-device, per-input, per-axis split ratios (see the Strategys alias).
  Strategys strategys{};
  // Device type of each slice, e.g. {"cpu", "gpu"}; see kSplitDevTypes.
  std::vector<std::string> dev_types{};
  // Number of devices taking part in the split.
  size_t dev_num{0};
  // Which axis to split along; NoSplit disables splitting.
  SplitMode split_mode_{NoSplit};
};
79 
// Maps key <primitive, is_depth_wise> to the parallel operator's name.
const std::map<std::pair<PrimitivePtr, bool>, std::string> kParallelOpNames = {
  {{prim::kPrimConv2D, false}, "Conv2D"},
  {{prim::kPrimConv2DFusion, false}, "Conv2D"},
  {{prim::kPrimConv2D, true}, "DepthwiseConv2D"},
  {{prim::kPrimConv2DFusion, true}, "DepthwiseConv2D"}};

// Maps a user-facing device string to its lite runtime device type.
// NOTE(review): "npu" is listed here although kSplitDevTypes excludes it --
// confirm whether npu entries are intentionally accepted elsewhere.
const std::map<std::string, lite::DeviceType> kSupportSplitedDevices = {
  {"cpu", lite::DeviceType::DT_CPU}, {"gpu", lite::DeviceType::DT_GPU}, {"npu", lite::DeviceType::DT_NPU}};

// Maps a primitive to its schema primitive id and compute data type.
const std::unordered_map<PrimitivePtr, std::pair<schema::PrimitiveType, TypeId>> kParallelSchemaId = {
  {prim::kPrimConv2D, {schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32}},
  {prim::kPrimConv2DFusion, {schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32}}};
94 
// This is an artificial restriction: a conv is only worth splitting when its
// total work exceeds the threshold below, i.e.
// 2 * output_H * output_W * (in_C * kW * kH + 1) * out_C >= 100
// FLOPs ~= output_H * output_W * (in_C * kW * kH) * out_C
// FLOPs ~= (input_h/stride_h) * (input_w/stride_w) * (in_C * kW * kH) * out_C
// e.g. (12/1)*(12/1)*(1*3*3)*128/1024 = 162 kFLOPs
constexpr auto kUserFLOPs = 100;   // minimum work (in kFLOPs) to justify a split
constexpr auto kPerFlops = 1024;   // divisor converting FLOPs to kFLOPs
102 
// Roughly estimates a convolution's work from its strides, input shape (NHWC)
// and weight shape (KHWC), used to decide whether splitting is worthwhile
// (compare against kUserFLOPs; presumably scaled by kPerFlops in the
// definition -- confirm there, it is not visible in this header).
// Fixed parameter-name typo: `input_shae` -> `input_shape` (declaration-only
// name, no effect on callers or the definition).
int64_t ApproximateFLOPs(const std::vector<int64_t> &strides, const std::vector<int64_t> &input_shape,
                         const std::vector<int64_t> &weight_shape);
105 
// Builds a map from parallel operator name (see kParallelOpNames) to its
// SplitStrategy, from the user-supplied compute-rate ratios, target devices
// (see kSplitDevTypes) and split mode.
// NOTE(review): "Parser" is presumably a typo for "Parse", but renaming would
// break existing callers, so the name is kept.
std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(
  const std::vector<int64_t> &parallel_compute_rates, const std::vector<std::string> &parallel_devices,
  SplitMode split_mode);
109 
110 }  // namespace opt
111 }  // namespace mindspore
112 #endif  // MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
113