/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_
#define TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_

#include <cstddef>
#include <iostream>
#include <iterator>

namespace tensorflow {

// Output iterator adaptor that applies `ConversionOp` to every value written
// through it before storing the result.
//
// Assigning `*it = x` (or `it[n] = x`) stores `conversion_op(x)` into the
// `StoreType` slot the iterator currently addresses. The iterator is
// write-only: `value_type` and `pointer` are `void`, and dereferencing yields
// a proxy whose only operation is assignment from `InputType`.
//
// Template parameters:
//   StoreType    - element type of the wrapped output buffer.
//   InputType    - type of the values assigned through the iterator.
//   ConversionOp - callable invoked as `conversion_op(InputType)`; its result
//                  is assigned to a `StoreType` lvalue.
//   OffsetT      - type used for iterator differences.
//
// Most members are annotated `__host__ __device__ __forceinline__`; those
// macros must be defined wherever this header is included (presumably by the
// CUDA toolchain — confirm for host-only builds).
template <typename StoreType, typename InputType, typename ConversionOp,
          typename OffsetT = std::ptrdiff_t>
class TransformOutputIterator {
 protected:
  // Write-only proxy returned by dereferencing: assigning an InputType to it
  // stores the converted value through `ptr`.
  struct Reference {
    StoreType* ptr;
    ConversionOp conversion_op;

    /// Constructor
    __host__ __device__ __forceinline__ Reference(StoreType* ptr,
                                                  ConversionOp conversion_op)
        : ptr(ptr), conversion_op(conversion_op) {}

    /// Assignment: stores `conversion_op(val)` into `*ptr` and returns the
    /// unconverted input `val`.
    __host__ __device__ __forceinline__ InputType operator=(InputType val) {
      *ptr = conversion_op(val);
      return val;
    }
  };

 public:
  // Required iterator traits
  typedef TransformOutputIterator self_type;  ///< My own type
  typedef OffsetT difference_type;  ///< Type to express the result of
                                    ///< subtracting one iterator from another
  typedef void value_type;  ///< void: elements cannot be read through this
                            ///< iterator
  typedef void pointer;     ///< void: there is no addressable element
  typedef Reference reference;  ///< Write-only proxy yielded by operator*
  typedef std::random_access_iterator_tag
      iterator_category;  ///< The iterator category

  // Current output position and the conversion applied on every write.
  // Both data members are public.
  StoreType* ptr;
  ConversionOp conversion_op;

  /// Constructor.
  /// @param ptr Native pointer to wrap; must be convertible to `StoreType*`.
  /// @param conversionOp Functor applied to every value written through the
  ///        iterator.
  template <typename QualifiedStoreType>
  __host__ __device__ __forceinline__ TransformOutputIterator(
      QualifiedStoreType* ptr, ConversionOp conversionOp)
      : ptr(ptr), conversion_op(conversionOp) {}

  /// Postfix increment: advances one element, returns the previous position.
  __host__ __device__ __forceinline__ self_type operator++(int) {
    self_type retval = *this;
    ptr++;
    return retval;
  }

  /// Prefix increment: advances one element and returns a reference to
  /// *this. Returning a reference (not a copy) is required by the iterator
  /// requirements, so that e.g. `++(++it)` advances `it` itself.
  __host__ __device__ __forceinline__ self_type& operator++() {
    ++ptr;
    return *this;
  }

  /// Postfix decrement: steps back one element, returns the previous
  /// position.
  __host__ __device__ __forceinline__ self_type operator--(int) {
    self_type retval = *this;
    ptr--;
    return retval;
  }

  /// Prefix decrement: steps back one element and returns a reference to
  /// *this.
  __host__ __device__ __forceinline__ self_type& operator--() {
    --ptr;
    return *this;
  }

  /// Indirection: write-only proxy for the current element.
  __host__ __device__ __forceinline__ reference operator*() const {
    return Reference(ptr, conversion_op);
  }

  /// Addition: iterator `n` elements ahead of this one.
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type operator+(Distance n) const {
    self_type retval(ptr + n, conversion_op);
    return retval;
  }

  /// Addition assignment
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type& operator+=(Distance n) {
    ptr += n;
    return *this;
  }

  /// Subtraction: iterator `n` elements behind this one.
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type operator-(Distance n) const {
    self_type retval(ptr - n, conversion_op);
    return retval;
  }

  /// Subtraction assignment
  template <typename Distance>
  __host__ __device__ __forceinline__ self_type& operator-=(Distance n) {
    ptr -= n;
    return *this;
  }

  /// Distance: number of elements between `other` and this iterator.
  __host__ __device__ __forceinline__ difference_type
  operator-(self_type other) const {
    return ptr - other.ptr;
  }

  /// Array subscript: write-only proxy for the element `n` positions ahead.
  template <typename Distance>
  __host__ __device__ __forceinline__ reference operator[](Distance n) const {
    return Reference(ptr + n, conversion_op);
  }

  /// Equal to (compares positions only; const so const iterators compare).
  __host__ __device__ __forceinline__ bool operator==(
      const self_type& rhs) const {
    return (ptr == rhs.ptr);
  }

  /// Not equal to
  __host__ __device__ __forceinline__ bool operator!=(
      const self_type& rhs) const {
    return (ptr != rhs.ptr);
  }

  /// Less than (position ordering)
  __host__ __device__ __forceinline__ bool operator<(
      const self_type& rhs) const {
    return (ptr < rhs.ptr);
  }

  /// Greater than
  __host__ __device__ __forceinline__ bool operator>(
      const self_type& rhs) const {
    return (ptr > rhs.ptr);
  }

  /// Less than or equal to
  __host__ __device__ __forceinline__ bool operator<=(
      const self_type& rhs) const {
    return (ptr <= rhs.ptr);
  }

  /// Greater than or equal to
  __host__ __device__ __forceinline__ bool operator>=(
      const self_type& rhs) const {
    return (ptr >= rhs.ptr);
  }

  /// ostream operator: emits nothing and returns `os` unchanged.
  friend std::ostream& operator<<(std::ostream& os,
                                  const self_type& /*itr*/) {
    return os;
  }
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_