11 #ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
12 #define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
14 #include "./InternalHeaderCheck.h"
26 template<
typename XprType>
27 struct traits<TensorIndexPairOp<XprType> > :
public traits<XprType>
29 typedef traits<XprType> XprTraits;
30 typedef typename XprTraits::StorageKind StorageKind;
31 typedef typename XprTraits::Index
Index;
32 typedef Pair<Index, typename XprTraits::Scalar> Scalar;
33 typedef typename XprType::Nested Nested;
34 typedef std::remove_reference_t<Nested> Nested_;
35 static constexpr
int NumDimensions = XprTraits::NumDimensions;
36 static constexpr
int Layout = XprTraits::Layout;
39 template<
typename XprType>
40 struct eval<TensorIndexPairOp<XprType>,
Eigen::Dense>
42 typedef const TensorIndexPairOp<XprType>EIGEN_DEVICE_REF type;
45 template<
typename XprType>
46 struct nested<TensorIndexPairOp<XprType>, 1,
47 typename eval<TensorIndexPairOp<XprType> >::type>
49 typedef TensorIndexPairOp<XprType> type;
54 template<
typename XprType>
55 class TensorIndexPairOp :
public TensorBase<TensorIndexPairOp<XprType>, ReadOnlyAccessors>
58 typedef typename Eigen::internal::traits<TensorIndexPairOp>::Scalar Scalar;
60 typedef typename Eigen::internal::nested<TensorIndexPairOp>::type Nested;
61 typedef typename Eigen::internal::traits<TensorIndexPairOp>::StorageKind StorageKind;
62 typedef typename Eigen::internal::traits<TensorIndexPairOp>::Index
Index;
63 typedef Pair<Index, typename XprType::CoeffReturnType> CoeffReturnType;
65 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexPairOp(
const XprType& expr)
69 const internal::remove_all_t<typename XprType::Nested>&
70 expression()
const {
return m_xpr; }
73 typename XprType::Nested m_xpr;
77 template<
typename ArgType,
typename Device>
78 struct TensorEvaluator<const TensorIndexPairOp<ArgType>, Device>
80 typedef TensorIndexPairOp<ArgType> XprType;
81 typedef typename XprType::Index Index;
82 typedef typename XprType::Scalar Scalar;
83 typedef typename XprType::CoeffReturnType CoeffReturnType;
85 typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
86 static constexpr
int NumDims = internal::array_size<Dimensions>::value;
87 typedef StorageMemory<CoeffReturnType, Device> Storage;
88 typedef typename Storage::Type EvaluatorPointerType;
94 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
98 static constexpr
int Layout = TensorEvaluator<ArgType, Device>::Layout;
101 typedef internal::TensorBlockNotImplemented TensorBlock;
104 EIGEN_STRONG_INLINE TensorEvaluator(
const XprType& op,
const Device& device)
105 : m_impl(op.expression(), device) { }
107 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
108 return m_impl.dimensions();
111 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType ) {
112 m_impl.evalSubExprsIfNeeded(NULL);
115 EIGEN_STRONG_INLINE
void cleanup() {
119 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const
121 return CoeffReturnType(index, m_impl.coeff(index));
124 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
125 costPerCoeff(
bool vectorized)
const {
126 return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
129 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
131 #ifdef EIGEN_USE_SYCL
132 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void bind(cl::sycl::handler &cgh)
const {
138 TensorEvaluator<ArgType, Device> m_impl;
149 template<
typename ReduceOp,
typename Dims,
typename XprType>
150 struct traits<TensorPairReducerOp<ReduceOp, Dims, XprType> > :
public traits<XprType>
152 typedef traits<XprType> XprTraits;
153 typedef typename XprTraits::StorageKind StorageKind;
154 typedef typename XprTraits::Index
Index;
155 typedef Index Scalar;
156 typedef typename XprType::Nested Nested;
157 typedef std::remove_reference_t<Nested> Nested_;
158 static constexpr
int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
159 static constexpr
int Layout = XprTraits::Layout;
162 template<
typename ReduceOp,
typename Dims,
typename XprType>
163 struct eval<TensorPairReducerOp<ReduceOp, Dims, XprType>,
Eigen::Dense>
165 typedef const TensorPairReducerOp<ReduceOp, Dims, XprType>EIGEN_DEVICE_REF type;
168 template<
typename ReduceOp,
typename Dims,
typename XprType>
169 struct nested<TensorPairReducerOp<ReduceOp, Dims, XprType>, 1,
170 typename eval<TensorPairReducerOp<ReduceOp, Dims, XprType> >::type>
172 typedef TensorPairReducerOp<ReduceOp, Dims, XprType> type;
177 template<
typename ReduceOp,
typename Dims,
typename XprType>
178 class TensorPairReducerOp :
public TensorBase<TensorPairReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
181 typedef typename Eigen::internal::traits<TensorPairReducerOp>::Scalar Scalar;
183 typedef typename Eigen::internal::nested<TensorPairReducerOp>::type Nested;
184 typedef typename Eigen::internal::traits<TensorPairReducerOp>::StorageKind StorageKind;
185 typedef typename Eigen::internal::traits<TensorPairReducerOp>::Index
Index;
186 typedef Index CoeffReturnType;
188 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPairReducerOp(
const XprType& expr,
189 const ReduceOp& reduce_op,
190 const Index return_dim,
191 const Dims& reduce_dims)
192 : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
195 const internal::remove_all_t<typename XprType::Nested>&
196 expression()
const {
return m_xpr; }
199 const ReduceOp& reduce_op()
const {
return m_reduce_op; }
202 const Dims& reduce_dims()
const {
return m_reduce_dims; }
205 Index return_dim()
const {
return m_return_dim; }
208 typename XprType::Nested m_xpr;
209 const ReduceOp m_reduce_op;
210 const Index m_return_dim;
211 const Dims m_reduce_dims;
215 template<
typename ReduceOp,
typename Dims,
typename ArgType,
typename Device>
216 struct TensorEvaluator<const TensorPairReducerOp<ReduceOp, Dims, ArgType>, Device>
218 typedef TensorPairReducerOp<ReduceOp, Dims, ArgType> XprType;
219 typedef typename XprType::Index
Index;
220 typedef typename XprType::Scalar Scalar;
221 typedef typename XprType::CoeffReturnType CoeffReturnType;
222 typedef typename TensorIndexPairOp<ArgType>::CoeffReturnType PairType;
223 typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType> >, Device>::Dimensions Dimensions;
224 typedef typename TensorEvaluator<const TensorIndexPairOp<ArgType> , Device>::Dimensions InputDimensions;
225 static constexpr
int NumDims = internal::array_size<InputDimensions>::value;
226 typedef array<Index, NumDims> StrideDims;
227 typedef StorageMemory<CoeffReturnType, Device> Storage;
228 typedef typename Storage::Type EvaluatorPointerType;
229 typedef StorageMemory<PairType, Device> PairStorageMem;
233 PacketAccess =
false,
235 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
239 static constexpr
int Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>, Device>::Layout;
241 typedef internal::TensorBlockNotImplemented TensorBlock;
244 EIGEN_STRONG_INLINE TensorEvaluator(
const XprType& op,
const Device& device)
245 : m_orig_impl(op.expression(), device),
246 m_impl(op.expression().index_pairs().reduce(op.reduce_dims(), op.reduce_op()), device),
247 m_return_dim(op.return_dim())
249 gen_strides(m_orig_impl.dimensions(), m_strides);
250 if (Layout ==
static_cast<int>(
ColMajor)) {
251 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
252 m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
254 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
255 m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
258 m_stride_div = ((m_return_dim >= 0) &&
259 (m_return_dim <
static_cast<Index>(m_strides.size())))
260 ? m_strides[m_return_dim] : 1;
263 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
264 return m_impl.dimensions();
267 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType ) {
268 m_impl.evalSubExprsIfNeeded(NULL);
271 EIGEN_STRONG_INLINE
void cleanup() {
275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
276 const PairType v = m_impl.coeff(index);
277 return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
280 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
281 #ifdef EIGEN_USE_SYCL
282 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void bind(cl::sycl::handler &cgh)
const {
284 m_orig_impl.bind(cgh);
288 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
289 costPerCoeff(
bool vectorized)
const {
290 const double compute_cost = 1.0 +
291 (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
292 return m_orig_impl.costPerCoeff(vectorized) +
293 m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
297 EIGEN_DEVICE_FUNC
void gen_strides(
const InputDimensions& dims, StrideDims& strides) {
298 if (m_return_dim < 0) {
301 eigen_assert(m_return_dim < NumDims &&
302 "Asking to convert index to a dimension outside of the rank");
306 if (Layout ==
static_cast<int>(
ColMajor)) {
308 for (
int i = 1; i < NumDims; ++i) {
309 strides[i] = strides[i-1] * dims[i-1];
312 strides[NumDims-1] = 1;
313 for (
int i = NumDims - 2; i >= 0; --i) {
314 strides[i] = strides[i+1] * dims[i+1];
320 TensorEvaluator<const TensorIndexPairOp<ArgType>, Device> m_orig_impl;
321 TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType> >, Device> m_impl;
322 const Index m_return_dim;
323 StrideDims m_strides;
// Eigen: namespace containing all symbols from the Eigen library.
// Note: `Index` above is Eigen's default dense index type
// (EIGEN_DEFAULT_DENSE_INDEX_TYPE, std::ptrdiff_t by default).