1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONVERTOR_H 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONVERTOR_H 19 20 #include <memory> 21 #include <string> 22 #include "include/transform/graph_ir/types.h" 23 #include "include/common/utils/utils.h" 24 #include "ir/manager.h" 25 26 namespace mindspore::transform { 27 extern AnfNodePtr GetMomentumVarByAccum(const AnfNodePtr &node, const NodeUsersMap &node_users); 28 29 class StorageFormatConvertor { 30 public: 31 static bool SetupStorageFormat(const AnfGraphPtr &anf_graph, const AnfNodePtr ¶m, 32 const std::shared_ptr<GeTensorDesc> &desc, 33 const std::string &ori_format = kOpFormat_NCHW); 34 35 private: 36 static bool InitParameterKernelInfo(const AnfNodePtr ¶m, std::string *format); 37 static void UpdateParameterKernelInfo(const AnfNodePtr ¶m, const std::string &format); 38 static int32_t GetGeFormat(const AnfNodePtr &src_node, const AnfNodePtr &dst_node, const std::string &storage_format, 39 size_t origin_dim); 40 static int32_t GetGeFormat(const AnfNodePtr &src_node, const std::string &storage_format, size_t origin_dim); 41 StorageFormatConvertor() = default; 42 ~StorageFormatConvertor() = default; 43 static void UpdateTensorDesc(const std::shared_ptr<GeTensorDesc> &desc, int32_t format); 44 static void SetStorageFormatFromConfig(const AnfGraphPtr &anf_graph, const AnfNodePtr ¶m, 45 const std::shared_ptr<GeTensorDesc> &desc); 46 }; 47 } // namespace mindspore::transform 48 49 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_STROAGE_FORMAT_CONVERTOR_H 50