1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H 12 13 namespace Eigen { 14 15 /** \class TensorDevice 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Pseudo expression providing an operator = that will evaluate its argument 19 * on the specified computing 'device' (GPU, thread pool, ...) 20 * 21 * Example: 22 * C.device(EIGEN_GPU) = A + B; 23 * 24 * Todo: operator *= and /=. 25 */ 26 27 template <typename ExpressionType, typename DeviceType> class TensorDevice { 28 public: TensorDevice(const DeviceType & device,ExpressionType & expression)29 TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {} 30 31 template<typename OtherDerived> 32 EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) { 33 typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign; 34 Assign assign(m_expression, other); 35 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device); 36 return *this; 37 } 38 39 template<typename OtherDerived> 40 EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) { 41 typedef typename OtherDerived::Scalar Scalar; 42 typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum; 43 Sum sum(m_expression, other); 44 typedef TensorAssignOp<ExpressionType, const Sum> Assign; 45 Assign assign(m_expression, sum); 46 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device); 47 return *this; 48 } 49 50 template<typename OtherDerived> 51 EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) { 52 typedef typename OtherDerived::Scalar Scalar; 53 typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference; 54 Difference difference(m_expression, other); 55 typedef TensorAssignOp<ExpressionType, const Difference> Assign; 56 Assign assign(m_expression, difference); 57 internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device); 58 return *this; 59 } 60 61 protected: 62 const DeviceType& m_device; 63 ExpressionType& m_expression; 64 }; 65 66 } // end namespace Eigen 67 68 #endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H 69