10#ifndef EIGEN_VISITOR_H
11#define EIGEN_VISITOR_H
13#include "./InternalHeaderCheck.h"
// Template parameter list of a declaration whose name line is not visible in
// this listing (presumably the primary visitor_impl template -- TODO confirm).
// The trailing Vectorize parameter defaults to true only when
// Derived::PacketAccess is non-zero and functor_traits<Visitor>::PacketAccess
// is also set.
19template<
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize=((Derived::PacketAccess!=0) && functor_traits<Visitor>::PacketAccess)>
// Specialization of visitor_impl for Vectorize == false with a compile-time
// UnrollCount: coefficients are visited one at a time via template recursion.
22template<
typename Visitor,
typename Derived,
int UnrollCount>
23struct visitor_impl<Visitor, Derived, UnrollCount, false>
// Position handled at this recursion level: the linear index (UnrollCount-1)
// is split by the compile-time row count into a column (quotient) and a row
// (remainder), so consecutive linear indices run down a column first.
26 col = (UnrollCount-1) / Derived::RowsAtCompileTime,
27 row = (UnrollCount-1) % Derived::RowsAtCompileTime
// Applies `visitor` to `mat`: first recurses with UnrollCount-1 (all earlier
// coefficients), then hands the coefficient at (row, col) and its position
// to the visitor's call operator.
31 static inline void run(
const Derived &mat, Visitor& visitor)
33 visitor_impl<Visitor, Derived, UnrollCount-1>::run(mat, visitor);
34 visitor(mat.coeff(row, col), row, col);
// Specialization for UnrollCount == 1 (non-vectorized): the end of the
// recursion above. Instead of calling the visitor's operator(), it seeds the
// visitor through init() with the coefficient at (0, 0).
38template<
typename Visitor,
typename Derived>
39struct visitor_impl<Visitor, Derived, 1, false>
42 static inline void run(
const Derived &mat, Visitor& visitor)
// The result of init() is returned from this void function.
44 return visitor.init(mat.coeff(0, 0), 0, 0);
// Specialization for UnrollCount == 0 (non-vectorized).
49template<
typename Visitor,
typename Derived>
50struct visitor_impl<Visitor, Derived, 0, false> {
// Both parameters are left unnamed, i.e. neither the matrix nor the visitor is
// referenced by name. The body of run() is not visible in this listing.
52 static inline void run(
const Derived &, Visitor& )
// Specialization for a Dynamic unroll count (non-vectorized): plain runtime
// loops over the coefficients of `mat`.
56template<
typename Visitor,
typename Derived>
57struct visitor_impl<Visitor, Derived,
Dynamic, false>
// Visits every coefficient of `mat` column by column. Coefficient (0, 0) is
// read unconditionally, before any size check.
60 static inline void run(
const Derived& mat, Visitor& visitor)
// Seed the visitor with the first coefficient.
62 visitor.init(mat.coeff(0,0), 0, 0);
// Remainder of column 0; starts at row 1 because (0, 0) went to init().
63 for(
Index i = 1; i < mat.rows(); ++i)
64 visitor(mat.coeff(i, 0), i, 0);
// All remaining columns, each traversed from row 0.
65 for(
Index j = 1; j < mat.cols(); ++j)
66 for(
Index i = 0; i < mat.rows(); ++i)
67 visitor(mat.coeff(i, j), i, j);
// Specialization for Vectorize == true: feeds the visitor whole packets where
// possible and falls back to single coefficients for the leftovers.
// UnrollSize is not referenced in the visible part of the body.
71template<
typename Visitor,
typename Derived,
int UnrollSize>
72struct visitor_impl<Visitor, Derived, UnrollSize, true>
74 typedef typename Derived::Scalar Scalar;
75 typedef typename packet_traits<Scalar>::type Packet;
// Visits every coefficient of `mat`, along rows when Derived::IsRowMajor and
// along columns otherwise.
78 static inline void run(
const Derived& mat, Visitor& visitor)
80 const Index PacketSize = packet_traits<Scalar>::size;
// Seed the visitor with coefficient (0, 0), as in the scalar paths.
81 visitor.init(mat.coeff(0,0), 0, 0);
82 if (Derived::IsRowMajor) {
83 for(
Index i = 0; i < mat.rows(); ++i) {
// Row 0 starts at column 1 because (0, 0) was already consumed by init().
84 Index j = i == 0 ? 1 : 0;
// Packet loop: runs while a full packet of PacketSize columns still fits.
85 for(; j+PacketSize-1 < mat.cols(); j += PacketSize) {
86 Packet p = mat.packet(i, j);
87 visitor.packet(p, i, j);
// Scalar tail: columns that did not fill a whole packet.
89 for(; j < mat.cols(); ++j)
90 visitor(mat.coeff(i, j), i, j);
// Column-wise counterpart of the loops above. The line separating the two
// branches is not visible in this listing (presumably an `else`).
93 for(
Index j = 0; j < mat.cols(); ++j) {
// Column 0 starts at row 1 because (0, 0) was already consumed by init().
94 Index i = j == 0 ? 1 : 0;
// Packet loop: runs while a full packet of PacketSize rows still fits.
95 for(; i+PacketSize-1 < mat.rows(); i += PacketSize) {
96 Packet p = mat.packet(i, j);
97 visitor.packet(p, i, j);
// Scalar tail: rows that did not fill a whole packet.
99 for(; i < mat.rows(); ++i)
100 visitor(mat.coeff(i, j), i, j);
// Thin adapter that gives visitor_impl a matrix-like view of an expression:
// sizes are answered by the expression itself, coefficient and packet reads
// are answered by the expression's evaluator.
107template<
typename XprType>
108class visitor_evaluator
111 typedef internal::evaluator<XprType> Evaluator;
// Compile-time properties: storage order and row count come from XprType, the
// coefficient read cost comes from the evaluator.
115 IsRowMajor = XprType::IsRowMajor,
116 RowsAtCompileTime = XprType::RowsAtCompileTime,
117 CoeffReadCost = Evaluator::CoeffReadCost
// Builds the evaluator from `xpr` and keeps a reference to `xpr` (m_xpr is a
// reference member, not a copy).
122 explicit visitor_evaluator(
const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) { }
124 typedef typename XprType::Scalar Scalar;
// Return types of coeff()/packet(): XprType's types with const removed.
125 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
126 typedef typename internal::remove_const<typename XprType::PacketReturnType>::type PacketReturnType;
// rows(), cols() and size() forward to the wrapped expression.
128 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index rows() const EIGEN_NOEXCEPT {
return m_xpr.rows(); }
129 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index cols() const EIGEN_NOEXCEPT {
return m_xpr.cols(); }
130 EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR
Index size() const EIGEN_NOEXCEPT {
return m_xpr.size(); }
// Reads the single coefficient at (row, col) through the evaluator.
132 EIGEN_DEVICE_FUNC CoeffReturnType coeff(
Index row,
Index col)
const
133 {
return m_evaluator.coeff(row, col); }
// Reads the packet starting at (row, col) through the evaluator; the load is
// requested with the Unaligned mode.
134 EIGEN_DEVICE_FUNC PacketReturnType packet(
Index row,
Index col)
const
135 {
return m_evaluator.template packet<Unaligned,PacketReturnType>(row, col); }
138 Evaluator m_evaluator;
139 const XprType &m_xpr;
// Body of a member function template taking a `visitor`; its signature line is
// not visible in this listing (presumably DenseBase<Derived>::visit -- TODO
// confirm). It wraps the derived expression in a visitor_evaluator and
// dispatches to the matching visitor_impl.
163template<
typename Derived>
164template<
typename Visitor>
171 typedef typename internal::visitor_evaluator<Derived> ThisEvaluator;
172 ThisEvaluator thisEval(derived());
// Unroll only for compile-time sizes, and only if the estimated total cost
// (one coefficient read per element plus size-1 visitor applications) does
// not exceed EIGEN_UNROLLING_LIMIT.
175 unroll = SizeAtCompileTime !=
Dynamic
176 && SizeAtCompileTime * int(ThisEvaluator::CoeffReadCost) + (SizeAtCompileTime-1) *
int(internal::functor_traits<Visitor>::Cost) <= EIGEN_UNROLLING_LIMIT
// Unrolled: instantiate visitor_impl with the fixed size; otherwise use the
// Dynamic variant. The Vectorize argument is left to its default.
178 return internal::visitor_impl<Visitor, ThisEvaluator, unroll ? int(SizeAtCompileTime) :
Dynamic>::run(thisEval, visitor);
// Visitor state shared by the min/max visitors that derive from it: a result
// value `res` and the position (`row`, `col`) associated with it. The struct
// header and the member declarations are not visible in this listing.
186template <
typename Derived>
// Starts with no position recorded (row and col are -1) and res set to 0.
191 coeff_visitor() : row(-1), col(-1), res(0) {}
192 typedef typename Derived::Scalar Scalar;
// Receives the first visited coefficient and its position. The body is not
// visible in this listing (presumably it stores value, i and j -- TODO confirm).
196 inline void init(
const Scalar& value,
Index i,
Index j)
// Comparison policy used by the min/max visitors. This primary template is the
// "min" flavour (is_min defaults to true).
205template<
typename Scalar,
int NaNPropagation,
bool is_min=true>
206struct minmax_compare {
207 typedef typename packet_traits<Scalar>::type Packet;
// True when `a` is strictly smaller than `b`.
208 static EIGEN_DEVICE_FUNC
inline bool compare(Scalar a, Scalar b) {
return a < b; }
// Reduces a packet to one scalar with predux_min, forwarding NaNPropagation.
209 static EIGEN_DEVICE_FUNC
inline Scalar predux(
const Packet& p) {
return predux_min<NaNPropagation>(p);}
// "max" flavour of minmax_compare (is_min == false): same interface as the
// primary template with the comparison and the reduction reversed.
212template<
typename Scalar,
int NaNPropagation>
213struct minmax_compare<Scalar, NaNPropagation, false> {
214 typedef typename packet_traits<Scalar>::type Packet;
// True when `a` is strictly greater than `b`.
215 static EIGEN_DEVICE_FUNC
inline bool compare(Scalar a, Scalar b) {
return a > b; }
// Reduces a packet to one scalar with predux_max, forwarding NaNPropagation.
216 static EIGEN_DEVICE_FUNC
inline Scalar predux(
const Packet& p) {
return predux_max<NaNPropagation>(p);}
// Visitor that tracks the smallest (is_min) or largest coefficient seen so far
// together with its position. This generic version adds no NaN handling of
// its own beyond what Comparator does.
219template <
typename Derived,
bool is_min,
int NaNPropagation>
220struct minmax_coeff_visitor : coeff_visitor<Derived>
222 using Scalar =
typename Derived::Scalar;
223 using Packet =
typename packet_traits<Scalar>::type;
224 using Comparator = minmax_compare<Scalar, NaNPropagation, is_min>;
// Scalar step: `value` at position (i, j) is tested against the current
// result. The statements executed when the test succeeds are not visible in
// this listing.
226 EIGEN_DEVICE_FUNC
inline
227 void operator() (
const Scalar& value,
Index i,
Index j)
229 if(Comparator::compare(value, this->res)) {
// Packet step: `p` is the packet whose first lane sits at position (i, j).
236 EIGEN_DEVICE_FUNC
inline
237 void packet(
const Packet& p,
Index i,
Index j) {
238 const Index PacketSize = packet_traits<Scalar>::size;
// Collapse the packet to its best element and test it like a scalar.
239 Scalar value = Comparator::predux(p);
240 if (Comparator::compare(value, this->res)) {
// Locate the winning lane: `mask` selects the lanes equal to `value`, `range`
// supplies a per-lane number, and the largest selected number is subtracted
// from PacketSize. Presumably `range` counts down from PacketSize to 1, so
// the result is the offset of the first matching lane -- confirm against
// plset/preverse.
241 const Packet range = preverse(plset<Packet>(Scalar(1)));
242 Packet mask = pcmp_eq(pset1<Packet>(value), p);
243 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
// The lane offset is applied along the packet direction: columns for
// row-major data, rows otherwise. (The doubled ';' is an empty statement.)
245 this->row = Derived::IsRowMajor ? i : i + max_idx;;
246 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Specialization of minmax_coeff_visitor for PropagateNumbers: a non-NaN value
// always displaces a NaN result, so NaNs only survive if nothing else is seen.
253template <
typename Derived,
bool is_min>
254struct minmax_coeff_visitor<Derived, is_min,
PropagateNumbers> : coeff_visitor<Derived>
256 typedef typename Derived::Scalar Scalar;
257 using Packet =
typename packet_traits<Scalar>::type;
258 using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
// Scalar step for `value` at (i, j). The update is triggered either because
// the current result is NaN while `value` is not, or because `value` beats
// the current result. The statements inside the `if` are not visible here.
260 EIGEN_DEVICE_FUNC
inline
261 void operator() (
const Scalar& value,
Index i,
Index j)
263 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// Packet step: `p` is the packet whose first lane sits at position (i, j).
270 EIGEN_DEVICE_FUNC
inline
271 void packet(
const Packet& p,
Index i,
Index j) {
272 const Index PacketSize = packet_traits<Scalar>::size;
// Collapse the packet to one candidate, then apply the same test as the
// scalar step.
273 Scalar value = Comparator::predux(p);
274 if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// Locate the winning lane: `mask` selects lanes equal to `value`; the largest
// `range` entry among them is subtracted from PacketSize. Presumably this
// yields the offset of the first matching lane -- confirm against
// plset/preverse.
275 const Packet range = preverse(plset<Packet>(Scalar(1)));
277 Packet mask = pcmp_eq(pset1<Packet>(value), p);
278 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
// Offset applies to the column for row-major data, to the row otherwise.
// (The doubled ';' is an empty statement.)
280 this->row = Derived::IsRowMajor ? i : i + max_idx;;
281 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Specialization of minmax_coeff_visitor for PropagateNaN: a NaN value
// displaces any non-NaN result.
289template <
typename Derived,
bool is_min>
290struct minmax_coeff_visitor<Derived, is_min,
PropagateNaN> : coeff_visitor<Derived>
292 typedef typename Derived::Scalar Scalar;
293 using Packet =
typename packet_traits<Scalar>::type;
294 using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;
// Scalar step for `value` at (i, j). The update is triggered either because
// `value` is NaN while the current result is not, or because `value` beats
// the current result. The statements inside the `if` are not visible here.
296 EIGEN_DEVICE_FUNC
inline
297 void operator() (
const Scalar& value,
Index i,
Index j)
299 const bool value_is_nan = (numext::isnan)(value);
300 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// Packet step: `p` is the packet whose first lane sits at position (i, j).
307 EIGEN_DEVICE_FUNC
inline
308 void packet(
const Packet& p,
Index i,
Index j) {
309 const Index PacketSize = packet_traits<Scalar>::size;
// Collapse the packet to one candidate, then apply the same test as the
// scalar step.
310 Scalar value = Comparator::predux(p);
311 const bool value_is_nan = (numext::isnan)(value);
312 if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
313 const Packet range = preverse(plset<Packet>(Scalar(1)));
// Lane selection differs for NaN: equality with `value` cannot find a NaN, so
// the mask is instead the negation of "lane equals itself", which is expected
// to be set only for NaN lanes.
315 Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
// Largest `range` entry among the selected lanes, subtracted from PacketSize;
// presumably the offset of the first selected lane -- confirm against
// plset/preverse.
316 Index max_idx = PacketSize -
static_cast<Index>(predux_max(pand(range, mask)));
// Offset applies to the column for row-major data, to the row otherwise.
// (The doubled ';' is an empty statement.)
318 this->row = Derived::IsRowMajor ? i : i + max_idx;;
319 this->col = Derived::IsRowMajor ? j + max_idx : j;
// Cost description of minmax_coeff_visitor, read by the unrolling decision
// through functor_traits<Visitor>::Cost. Any further enumerators are not
// visible in this listing.
// NOTE(review): the first template parameter is named Scalar here, but it
// fills the slot that minmax_coeff_visitor declares as Derived, so
// NumTraits<Scalar> is instantiated with the expression type -- confirm this
// is intended.
324template<
typename Scalar,
bool is_min,
int NaNPropagation>
325struct functor_traits<minmax_coeff_visitor<Scalar, is_min, NaNPropagation> > {
327 Cost = NumTraits<Scalar>::AddCost,
// Minimum lookup that also reports where the minimum sits. The signature line
// is not visible in this listing (presumably a minCoeff overload taking
// rowId and colId pointers -- TODO confirm).
345template<
typename Derived>
346template<
int NaNPropagation,
typename IndexType>
348typename internal::traits<Derived>::Scalar
// Empty expressions are rejected by assertion.
351 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
// Run the "min" visitor (is_min == true) over all coefficients.
353 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
354 this->visit(minVisitor);
// rowId is written unconditionally and must therefore be non-null; colId is
// optional and only written when non-null.
355 *rowId = minVisitor.row;
356 if (colId) *colId = minVisitor.col;
357 return minVisitor.res;
// Minimum lookup for vectors that reports a single index. The signature line
// is not visible in this listing (presumably a minCoeff overload taking an
// `index` pointer -- TODO confirm).
370template<
typename Derived>
371template<
int NaNPropagation,
typename IndexType>
373typename internal::traits<Derived>::Scalar
// Empty expressions are rejected by assertion.
376 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
// Compile-time restriction to vector expressions.
378 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
379 internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
380 this->visit(minVisitor);
// With exactly one row at compile time the position is the column, otherwise
// the row. The value is converted to IndexType explicitly; `index` is
// dereferenced without a null check.
381 *index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row);
382 return minVisitor.res;
// Maximum lookup that also reports where the maximum sits. The signature line
// is not visible in this listing (presumably a maxCoeff overload taking
// rowPtr and colPtr pointers -- TODO confirm).
396template<
typename Derived>
397template<
int NaNPropagation,
typename IndexType>
399typename internal::traits<Derived>::Scalar
// Empty expressions are rejected by assertion.
402 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
// Run the "max" visitor (is_min == false) over all coefficients.
404 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
405 this->visit(maxVisitor);
// rowPtr is written unconditionally and must therefore be non-null; colPtr is
// optional and only written when non-null.
406 *rowPtr = maxVisitor.row;
407 if (colPtr) *colPtr = maxVisitor.col;
408 return maxVisitor.res;
// Maximum lookup for vectors that reports a single index. The signature line
// is not visible in this listing (presumably a maxCoeff overload taking an
// `index` pointer -- TODO confirm).
421template<
typename Derived>
422template<
int NaNPropagation,
typename IndexType>
424typename internal::traits<Derived>::Scalar
// Empty expressions are rejected by assertion.
427 eigen_assert(this->rows()>0 && this->cols()>0 &&
"you are using an empty matrix");
// Compile-time restriction to vector expressions.
429 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
430 internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
431 this->visit(maxVisitor);
// With exactly one row at compile time the position is the column, otherwise
// the row; `index` is dereferenced without a null check.
// NOTE(review): the single-index min lookup in this file wraps the same
// expression in an explicit IndexType(...) conversion; here the conversion
// is implicit -- consider aligning the two.
432 *index = (RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row;
433 return maxVisitor.res;
internal::traits< Derived >::Scalar minCoeff() const
Definition: Redux.h:433
void visit(Visitor &func) const
Definition: Visitor.h:166
internal::traits< Derived >::Scalar maxCoeff() const
Definition: Redux.h:448
@ PropagateNaN
Definition: Constants.h:345
@ PropagateNumbers
Definition: Constants.h:347
const unsigned int PacketAccessBit
Definition: Constants.h:96
Namespace containing all symbols from the Eigen library.
Definition: B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:59
const int Dynamic
Definition: Constants.h:24