• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
18 
19 #define EIGEN_USE_THREADS
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/kernels/cwise_ops.h"
22 
23 namespace Eigen {
24 namespace internal {
25 
26 // Gradient for the tanh function
27 template <typename T>
28 struct scalar_tanh_gradient_op {
EIGEN_EMPTY_STRUCT_CTORscalar_tanh_gradient_op29   EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_gradient_op)
30   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
31   operator()(const T& output, const T& output_gradient) const {
32     return output_gradient * (T(1) - output * output);
33   }
34   template <typename Packet>
35   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
packetOpscalar_tanh_gradient_op36   packetOp(const Packet& output, const Packet& output_gradient) const {
37     return pmul(output_gradient,
38                 psub(pset1<Packet>(T(1)), pmul(output, output)));
39   }
40 };
41 template <typename T>
42 struct functor_traits<scalar_tanh_gradient_op<T>> {
43   enum {
44     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
45     PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
46   };
47 };
48 
49 // Gradient for the sigmoid function
50 template <typename T>
51 struct scalar_sigmoid_gradient_op {
52   EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_gradient_op)
53   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
54   operator()(const T& output, const T& output_gradient) const {
55     return output_gradient * output * (T(1) - output);
56   }
57   template <typename Packet>
58   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
59   packetOp(const Packet& output, const Packet& output_gradient) const {
60     return pmul(output_gradient,
61                 pmul(output, psub(pset1<Packet>(T(1)), output)));
62   }
63 };
64 template <typename T>
65 struct functor_traits<scalar_sigmoid_gradient_op<T>> {
66   enum {
67     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
68     PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
69   };
70 };
71 
72 // Gradient for the inverse function
73 template <typename T>
74 struct scalar_inverse_gradient_op {
75   EIGEN_EMPTY_STRUCT_CTOR(scalar_inverse_gradient_op)
76   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
77   operator()(const T& output, const T& output_gradient) const {
78     if (output_gradient == T(0)) {
79       return T(0);
80     } else {
81       const T out_conj = numext::conj(output);
82       return -out_conj * out_conj * output_gradient;
83     }
84   }
85   template <typename Packet>
86   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
87   packetOp(const Packet& output, const Packet& output_gradient) const {
88     const Packet out_conj = pconj(output);
89     return mul_no_nan_op<T>().packetOp(pnegate(pmul(out_conj, out_conj)),
90                                        output_gradient);
91   }
92 };
93 template <typename T>
94 struct functor_traits<scalar_inverse_gradient_op<T>> {
95   enum {
96     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
97     PacketAccess = packet_traits<T>::HasMul,
98   };
99 };
100 
101 // Gradient for the sqrt function
102 template <typename T>
103 struct scalar_sqrt_gradient_op {
104   EIGEN_EMPTY_STRUCT_CTOR(scalar_sqrt_gradient_op)
105   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
106   operator()(const T& output, const T& output_gradient) const {
107     if (output_gradient == T(0)) {
108       return T(0);
109     } else {
110       const T out_conj = numext::conj(output);
111       return (static_cast<T>(0.5) * output_gradient) / out_conj;
112     }
113   }
114   template <typename Packet>
115   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
116   packetOp(const Packet& output, const Packet& output_gradient) const {
117     const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
118     const Packet out_conj = pconj(output);
119     return mul_no_nan_op<T>().packetOp(pdiv(const_half, out_conj),
120                                        output_gradient);
121   }
122 };
123 template <typename T>
124 struct functor_traits<scalar_sqrt_gradient_op<T>> {
125   enum {
126     PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv,
127     Cost = NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value,
128   };
129 };
130 
131 // Gradient for the rsqrt function
132 template <typename T>
133 struct scalar_rsqrt_gradient_op {
134   EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_gradient_op)
135   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
136   operator()(const T& output, const T& output_gradient) const {
137     if (output_gradient == T(0)) {
138       return T(0);
139     } else {
140       const T out_conj = numext::conj(output);
141       return static_cast<T>(-0.5) * (output_gradient * out_conj) *
142              (out_conj * out_conj);
143     }
144   }
145   template <typename Packet>
146   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
147   packetOp(const Packet& output, const Packet& output_gradient) const {
148     const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
149     const Packet out_conj = pconj(output);
150     auto safe_pmul = [](const Packet& a, const Packet& b) {
151       return mul_no_nan_op<T>().packetOp(a, b);
152     };
153     return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)),
154                      safe_pmul(out_conj, output_gradient));
155   }
156 };
157 template <typename T>
158 struct functor_traits<scalar_rsqrt_gradient_op<T>> {
159   enum {
160     Cost = 4 * NumTraits<T>::MulCost,
161     PacketAccess = packet_traits<T>::HasMul,
162   };
163 };
164 
165 }  // end namespace internal
166 }  // end namespace Eigen
167 
168 namespace tensorflow {
169 
170 namespace functor {
171 
// Generic dispatcher for applying the element-wise binary `Functor` to two
// input tensors, writing the result into `out` on device `d`.  Only
// declared here; the device-specific partial specializations below provide
// the definitions.
template <typename Device, typename Functor>
struct SimpleBinaryFunctor {
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1);
};
178 
// Partial specialization of BinaryFunctor for CPU devices
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Functor>
struct SimpleBinaryFunctor<CPUDevice, Functor> {
  // Evaluates out = Functor::func()(in0, in1) element-wise via Eigen's
  // binaryExpr on the thread-pool device.
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1) {
    out.device(d) = in0.binaryExpr(in1, typename Functor::func());
  }
};
190 
#ifdef TENSORFLOW_USE_SYCL
// Partial specialization of BinaryFunctor for SYCL devices
typedef Eigen::SyclDevice SYCLDevice;
template <typename Functor>
struct SimpleBinaryFunctor<SYCLDevice, Functor> {
  // Evaluates out = Functor::func()(in0, in1) element-wise via Eigen's
  // binaryExpr on the SYCL device.
  void operator()(const SYCLDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1) {
    out.device(d) = in0.binaryExpr(in1, typename Functor::func());
  }
};

#endif  // TENSORFLOW_USE_SYCL
204 
// Adapters binding each Eigen gradient op above into the `base` functor
// wrapper used by the cwise kernels.

// Gradient of tanh, computed from the forward output and incoming gradient.
template <typename T>
struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};

// Gradient of sigmoid.
template <typename T>
struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
};

// Gradient of the reciprocal (1/x).
template <typename T>
struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> {
};

// Gradient of sqrt.
template <typename T>
struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {};

// Gradient of rsqrt (x^(-1/2)).
template <typename T>
struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};

// Presumably the derivative of igamma w.r.t. parameter `a`;
// scalar_igamma_der_a_op is declared elsewhere (not visible in this file).
template <typename T>
struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
224 
225 }  // end namespace functor
226 
227 }  // end namespace tensorflow
228 #endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
229