1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "base/base_ref_utils.h" 17 #include <vector> 18 #include <memory> 19 #include "ir/tensor.h" 20 21 namespace mindspore { IterateFindTensor(std::vector<tensor::TensorPtr> * msTensors,const VectorRef & ref_list)22void IterateFindTensor(std::vector<tensor::TensorPtr> *msTensors, const VectorRef &ref_list) { 23 for (size_t i = 0; i < ref_list.size(); ++i) { 24 if (utils::isa<tensor::TensorPtr>(ref_list[i])) { 25 auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]); 26 MS_EXCEPTION_IF_NULL(tensor_ptr); 27 msTensors->emplace_back(tensor_ptr); 28 } else if (utils::isa<VectorRef>(ref_list[i])) { 29 auto ref_iter = utils::cast<VectorRef>(ref_list[i]); 30 IterateFindTensor(msTensors, ref_iter); 31 } else { 32 MS_LOG(EXCEPTION) << "The output is not a tensor"; 33 } 34 } 35 } 36 TransformVectorRefToMultiTensor(const VectorRef & base_ref)37std::vector<tensor::TensorPtr> TransformVectorRefToMultiTensor(const VectorRef &base_ref) { 38 std::vector<tensor::TensorPtr> msTensors; 39 if (utils::isa<VectorRef>(base_ref)) { 40 auto ref_list = utils::cast<VectorRef>(base_ref); 41 IterateFindTensor(&msTensors, ref_list); 42 } else if (utils::isa<tensor::Tensor>(base_ref)) { 43 auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref); 44 MS_EXCEPTION_IF_NULL(tensor_ptr); 45 msTensors.emplace_back(tensor_ptr); 46 } else { 47 MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; 48 } 49 return msTensors; 50 } 51 } // namespace mindspore 52