• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_SHIM_TF_TENSOR_VIEW_H_
16 #define TENSORFLOW_LITE_KERNELS_SHIM_TF_TENSOR_VIEW_H_
17 
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"
24 
25 namespace tflite {
26 namespace shim {
27 
28 // A view over TF Tensor without taking ownership. It can be either mutable or
29 // immutable.
30 class TfTensorView : public TensorView {
31  public:
32   // Move constructor
33   TfTensorView(TfTensorView &&o) noexcept;
34   // Copy constructor
35   TfTensorView(const TfTensorView &o);
36   // Move assignment operator
37   TfTensorView &operator=(TfTensorView &&o) noexcept;
38   // Copy assignment operator
39   TfTensorView &operator=(const TfTensorView &);
40 
41  protected:
42   // Templated constructor. Since it's not possible to specify the template
43   // argument directly we place a dummy argument of that type so compiler can
44   // deduce the right template parameter
45   template <typename DType>
46   TfTensorView(const ::tensorflow::Tensor *wrapped_tensor, const DType &dtype);
47 
48   // Let the factory implementation use private constructors
49   template <typename TfTensorType>
50   friend absl::StatusOr<
51       typename MatchConstNess<TfTensorType, TfTensorView>::Type>
52   TfTensorViewTemplatizedNew(TfTensorType *wrapped_tensor);
53 
54   // Stores the shape read from the TensorShape object
55   std::vector<int> shape_data_;
56 };
57 
58 // Map ::tensorflow::Tensor -> TfTensorView
59 template <>
60 struct TensorViewSubType<::tensorflow::Tensor> {
61   using Type = TfTensorView;
62 };
63 
64 // Map const ::tensorflow::Tensor -> const TfTensorView
65 template <>
66 struct TensorViewSubType<const ::tensorflow::Tensor> {
67   using Type = const TfTensorView;
68 };
69 
70 // Specialization of New() factory
71 template <>
72 absl::StatusOr<TfTensorView> TensorView::New<::tensorflow::Tensor>(
73     ::tensorflow::Tensor *wrapped_tensor);
74 
75 // Specialization of New() factory
76 template <>
77 absl::StatusOr<const TfTensorView> TensorView::New<const ::tensorflow::Tensor>(
78     const ::tensorflow::Tensor *wrapped_tensor);
79 
80 /////////////////////// Implementation
81 ///////////////////////
82 
83 // Templated ctor
84 template <typename DType>
85 TfTensorView::TfTensorView(const ::tensorflow::Tensor *wrapped_tensor,
86                            const DType &dtype)
87     : TensorView({}, wrapped_tensor->data(),
88                  wrapped_tensor->tensor_data().size(), dtype) {
89   shape_data_.resize(wrapped_tensor->shape().dims());
90   for (int dim = 0; dim < wrapped_tensor->shape().dims(); ++dim) {
91     shape_data_[dim] = wrapped_tensor->shape().dim_size(dim);
92   }
93   shape_ = absl::Span<int>(shape_data_);
94 }
95 
96 }  // namespace shim
97 }  // namespace tflite
98 
99 #endif  // TENSORFLOW_LITE_KERNELS_SHIM_TF_TENSOR_VIEW_H_
100