1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_ 16 #define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_ 17 18 #include <string> 19 20 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 21 #include "tensorflow/lite/string_util.h" 22 23 namespace tflite { 24 namespace resource { 25 namespace internal { 26 27 /// Helper class for accessing TFLite tensor data. 28 template <typename T> 29 class TensorReader { 30 public: TensorReader(const TfLiteTensor * input)31 explicit TensorReader(const TfLiteTensor* input) { 32 input_data_ = GetTensorData<T>(input); 33 } 34 35 // Returns the corresponding scalar data at the given index position. 36 // In here, it does not check the validity of the index should be guaranteed 37 // in order not to harm the performance. Caller should take care of it. GetData(int index)38 T GetData(int index) { return input_data_[index]; } 39 40 private: 41 const T* input_data_; 42 }; 43 44 /// Helper class for accessing TFLite tensor data. This specialized class is for 45 /// std::string type. 46 template <> 47 class TensorReader<std::string> { 48 public: TensorReader(const TfLiteTensor * input)49 explicit TensorReader(const TfLiteTensor* input) : input_(input) {} 50 51 // Returns the corresponding string data at the given index position. 52 // In here, it does not check the validity of the index should be guaranteed 53 // in order not to harm the performance. Caller should take care of it. GetData(int index)54 std::string GetData(int index) { 55 auto string_ref = GetString(input_, index); 56 return std::string(string_ref.str, string_ref.len); 57 } 58 59 private: 60 const TfLiteTensor* input_; 61 }; 62 63 /// WARNING: Experimental interface, subject to change. 64 /// Helper class for writing TFLite tensor data. 65 template <typename ValueType> 66 class TensorWriter { 67 public: TensorWriter(TfLiteTensor * values)68 explicit TensorWriter(TfLiteTensor* values) { 69 output_data_ = GetTensorData<ValueType>(values); 70 } 71 72 // Sets the given value to the given index position of the tensor storage. 73 // In here, it does not check the validity of the index should be guaranteed 74 // in order not to harm the performance. Caller should take care of it. SetData(int index,ValueType & value)75 void SetData(int index, ValueType& value) { output_data_[index] = value; } 76 77 // Commit updates. In this case, it does nothing since the SetData method 78 // writes data directly. Commit()79 void Commit() { 80 // Noop. 81 } 82 83 private: 84 ValueType* output_data_; 85 }; 86 87 /// WARNING: Experimental interface, subject to change. 88 /// Helper class for writing TFLite tensor data. This specialized class is for 89 /// std::string type. 90 template <> 91 class TensorWriter<std::string> { 92 public: TensorWriter(TfLiteTensor * values)93 explicit TensorWriter(TfLiteTensor* values) : values_(values) {} 94 95 // Queues the given string value to the buffer regardless of the provided 96 // index. 97 // In here, it does not check the validity of the index should be guaranteed 98 // in order not to harm the performance. Caller should take care of it. SetData(int index,const std::string & value)99 void SetData(int index, const std::string& value) { 100 buf_.AddString(value.data(), value.length()); 101 } 102 103 // Commit updates. The stored data in DynamicBuffer will be written into the 104 // tensor storage. Commit()105 void Commit() { buf_.WriteToTensor(values_, nullptr); } 106 107 private: 108 TfLiteTensor* values_; 109 DynamicBuffer buf_; 110 }; 111 112 } // namespace internal 113 } // namespace resource 114 } // namespace tflite 115 116 #endif // TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_LOOKUP_UTIL_H_ 117