• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
16 #define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
17 
18 #include <string>
19 
20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
21 #include "tensorflow/lite/string_util.h"
22 
23 namespace tflite {
24 namespace resource {
25 namespace internal {
26 
27 /// Helper class for accessing TFLite tensor data.
28 template <typename T>
29 class TensorReader {
30  public:
TensorReader(const TfLiteTensor * input)31   explicit TensorReader(const TfLiteTensor* input) {
32     input_data_ = GetTensorData<T>(input);
33   }
34 
35   // Returns the corresponding scalar data at the given index position.
36   // In here, it does not check the validity of the index should be guaranteed
37   // in order not to harm the performance. Caller should take care of it.
GetData(int index)38   T GetData(int index) { return input_data_[index]; }
39 
40  private:
41   const T* input_data_;
42 };
43 
44 /// Helper class for accessing TFLite tensor data. This specialized class is for
45 /// std::string type.
46 template <>
47 class TensorReader<std::string> {
48  public:
TensorReader(const TfLiteTensor * input)49   explicit TensorReader(const TfLiteTensor* input) : input_(input) {}
50 
51   // Returns the corresponding string data at the given index position.
52   // In here, it does not check the validity of the index should be guaranteed
53   // in order not to harm the performance. Caller should take care of it.
GetData(int index)54   std::string GetData(int index) {
55     auto string_ref = GetString(input_, index);
56     return std::string(string_ref.str, string_ref.len);
57   }
58 
59  private:
60   const TfLiteTensor* input_;
61 };
62 
63 /// WARNING: Experimental interface, subject to change.
64 /// Helper class for writing TFLite tensor data.
65 template <typename ValueType>
66 class TensorWriter {
67  public:
TensorWriter(TfLiteTensor * values)68   explicit TensorWriter(TfLiteTensor* values) {
69     output_data_ = GetTensorData<ValueType>(values);
70   }
71 
72   // Sets the given value to the given index position of the tensor storage.
73   // In here, it does not check the validity of the index should be guaranteed
74   // in order not to harm the performance. Caller should take care of it.
SetData(int index,ValueType & value)75   void SetData(int index, ValueType& value) { output_data_[index] = value; }
76 
77   // Commit updates. In this case, it does nothing since the SetData method
78   // writes data directly.
Commit()79   void Commit() {
80     // Noop.
81   }
82 
83  private:
84   ValueType* output_data_;
85 };
86 
87 /// WARNING: Experimental interface, subject to change.
88 /// Helper class for writing TFLite tensor data. This specialized class is for
89 /// std::string type.
90 template <>
91 class TensorWriter<std::string> {
92  public:
TensorWriter(TfLiteTensor * values)93   explicit TensorWriter(TfLiteTensor* values) : values_(values) {}
94 
95   // Queues the given string value to the buffer regardless of the provided
96   // index.
97   // In here, it does not check the validity of the index should be guaranteed
98   // in order not to harm the performance. Caller should take care of it.
SetData(int index,const std::string & value)99   void SetData(int index, const std::string& value) {
100     buf_.AddString(value.data(), value.length());
101   }
102 
103   // Commit updates. The stored data in DynamicBuffer will be written into the
104   // tensor storage.
Commit()105   void Commit() { buf_.WriteToTensor(values_, nullptr); }
106 
107  private:
108   TfLiteTensor* values_;
109   DynamicBuffer buf_;
110 };
111 
112 }  // namespace internal
113 }  // namespace resource
114 }  // namespace tflite
115 
116 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_
117