1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_DESC_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_DESC_H_ 19 20 #include <memory> 21 #include "transform/graph_ir/op_adapter.h" 22 23 namespace mindspore { 24 namespace transform { 25 class OpAdapterDesc { 26 public: OpAdapterDesc()27 OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} 28 OpAdapterDesc(const OpAdapterPtr & train,const OpAdapterPtr & infer)29 OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} 30 OpAdapterDesc(const OpAdapterPtr & common)31 explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} 32 OpAdapterDesc(const OpAdapterDesc & desc)33 OpAdapterDesc(const OpAdapterDesc &desc) { 34 this->train_ = desc.train_; 35 this->infer_ = desc.infer_; 36 } 37 OpAdapterDesc(OpAdapterDesc && desc)38 OpAdapterDesc(OpAdapterDesc &&desc) { 39 this->train_ = desc.train_; 40 this->infer_ = desc.infer_; 41 desc.train_ = nullptr; 42 desc.infer_ = nullptr; 43 } 44 45 ~OpAdapterDesc() = default; 46 Get(bool train)47 OpAdapterPtr Get(bool train) const { return train ? train_ : infer_; } 48 49 OpAdapterDesc &operator=(const OpAdapterDesc &desc) { 50 if (this != &desc) { 51 this->train_ = desc.train_; 52 this->infer_ = desc.infer_; 53 } 54 return *this; 55 } 56 57 OpAdapterDesc &operator=(OpAdapterDesc &&desc) { 58 if (this != &desc) { 59 this->train_ = desc.train_; 60 this->infer_ = desc.infer_; 61 desc.train_ = nullptr; 62 desc.infer_ = nullptr; 63 } 64 return *this; 65 } 66 67 private: 68 OpAdapterPtr train_; 69 OpAdapterPtr infer_; 70 }; 71 72 using OpAdapterDescPtr = std::shared_ptr<OpAdapterDesc>; 73 } // namespace transform 74 } // namespace mindspore 75 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_DESC_H_ 76