1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_KERNELS_SHIM_TF_TENSOR_VIEW_H_ 16 #define TENSORFLOW_LITE_KERNELS_SHIM_TF_TENSOR_VIEW_H_ 17 18 #include "absl/status/status.h" 19 #include "absl/status/statusor.h" 20 #include "absl/strings/string_view.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/types.h" 23 #include "tensorflow/lite/kernels/shim/tensor_view.h" 24 25 namespace tflite { 26 namespace shim { 27 28 // A view over TF Tensor without taking ownership. It can be either mutable or 29 // immutable. 30 class TfTensorView : public TensorView { 31 public: 32 // Move constructor 33 TfTensorView(TfTensorView &&o) noexcept; 34 // Copy constructor 35 TfTensorView(const TfTensorView &o); 36 // Move assignment operator 37 TfTensorView &operator=(TfTensorView &&o) noexcept; 38 // Copy assignment operator 39 TfTensorView &operator=(const TfTensorView &); 40 41 protected: 42 // Templated constructor. Since it's not possible to specify the template 43 // argument directly we place a dummy argument of that type so compiler can 44 // deduce the right template parameter 45 template <typename DType> 46 TfTensorView(const ::tensorflow::Tensor *wrapped_tensor, const DType &dtype); 47 48 // Let the factory implementation use private constructors 49 template <typename TfTensorType> 50 friend absl::StatusOr< 51 typename MatchConstNess<TfTensorType, TfTensorView>::Type> 52 TfTensorViewTemplatizedNew(TfTensorType *wrapped_tensor); 53 54 // Stores the shape read from the TensorShape object 55 std::vector<int> shape_data_; 56 }; 57 58 // Map ::tensorflow::Tensor -> TfTensorView 59 template <> 60 struct TensorViewSubType<::tensorflow::Tensor> { 61 using Type = TfTensorView; 62 }; 63 64 // Map const ::tensorflow::Tensor -> const TfTensorView 65 template <> 66 struct TensorViewSubType<const ::tensorflow::Tensor> { 67 using Type = const TfTensorView; 68 }; 69 70 // Specialization of New() factory 71 template <> 72 absl::StatusOr<TfTensorView> TensorView::New<::tensorflow::Tensor>( 73 ::tensorflow::Tensor *wrapped_tensor); 74 75 // Specialization of New() factory 76 template <> 77 absl::StatusOr<const TfTensorView> TensorView::New<const ::tensorflow::Tensor>( 78 const ::tensorflow::Tensor *wrapped_tensor); 79 80 /////////////////////// Implementation 81 /////////////////////// 82 83 // Templated ctor 84 template <typename DType> 85 TfTensorView::TfTensorView(const ::tensorflow::Tensor *wrapped_tensor, 86 const DType &dtype) 87 : TensorView({}, wrapped_tensor->data(), 88 wrapped_tensor->tensor_data().size(), dtype) { 89 shape_data_.resize(wrapped_tensor->shape().dims()); 90 for (int dim = 0; dim < wrapped_tensor->shape().dims(); ++dim) { 91 shape_data_[dim] = wrapped_tensor->shape().dim_size(dim); 92 } 93 shape_ = absl::Span<int>(shape_data_); 94 } 95 96 } // namespace shim 97 } // namespace tflite 98 99 #endif // TENSORFLOW_LITE_KERNELS_SHIM_TF_TENSOR_VIEW_H_ 100