Eigen  3.4.90 (git rev 67eeba6e720c5745abc77ae6c92ce0a44aa7b7ae)
Visitor.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_VISITOR_H
11 #define EIGEN_VISITOR_H
12 
13 #include "./InternalHeaderCheck.h"
14 
15 namespace Eigen {
16 
17 namespace internal {
18 
19 template<typename Visitor, typename Derived, int UnrollCount, bool Vectorize=((Derived::PacketAccess!=0) && functor_traits<Visitor>::PacketAccess)>
20 struct visitor_impl;
21 
22 template<typename Visitor, typename Derived, int UnrollCount>
23 struct visitor_impl<Visitor, Derived, UnrollCount, false>
24 {
25  enum {
26  col = Derived::IsRowMajor ? (UnrollCount-1) % Derived::ColsAtCompileTime
27  : (UnrollCount-1) / Derived::RowsAtCompileTime,
28  row = Derived::IsRowMajor ? (UnrollCount-1) / Derived::ColsAtCompileTime
29  : (UnrollCount-1) % Derived::RowsAtCompileTime
30  };
31 
32  EIGEN_DEVICE_FUNC
33  static inline void run(const Derived &mat, Visitor& visitor)
34  {
35  visitor_impl<Visitor, Derived, UnrollCount-1>::run(mat, visitor);
36  visitor(mat.coeff(row, col), row, col);
37  }
38 };
39 
40 template<typename Visitor, typename Derived>
41 struct visitor_impl<Visitor, Derived, 1, false>
42 {
43  EIGEN_DEVICE_FUNC
44  static inline void run(const Derived &mat, Visitor& visitor)
45  {
46  return visitor.init(mat.coeff(0, 0), 0, 0);
47  }
48 };
49 
50 // This specialization enables visitors on empty matrices at compile-time
51 template<typename Visitor, typename Derived>
52 struct visitor_impl<Visitor, Derived, 0, false> {
53  EIGEN_DEVICE_FUNC
54  static inline void run(const Derived &/*mat*/, Visitor& /*visitor*/)
55  {}
56 };
57 
58 template<typename Visitor, typename Derived>
59 struct visitor_impl<Visitor, Derived, Dynamic, /*Vectorize=*/false>
60 {
61  EIGEN_DEVICE_FUNC
62  static inline void run(const Derived& mat, Visitor& visitor)
63  {
64  visitor.init(mat.coeff(0,0), 0, 0);
65  if (Derived::IsRowMajor) {
66  for(Index i = 1; i < mat.cols(); ++i) {
67  visitor(mat.coeff(0, i), 0, i);
68  }
69  for(Index j = 1; j < mat.rows(); ++j) {
70  for(Index i = 0; i < mat.cols(); ++i) {
71  visitor(mat.coeff(j, i), j, i);
72  }
73  }
74  } else {
75  for(Index i = 1; i < mat.rows(); ++i) {
76  visitor(mat.coeff(i, 0), i, 0);
77  }
78  for(Index j = 1; j < mat.cols(); ++j) {
79  for(Index i = 0; i < mat.rows(); ++i) {
80  visitor(mat.coeff(i, j), i, j);
81  }
82  }
83  }
84  }
85 };
86 
87 template<typename Visitor, typename Derived, int UnrollSize>
88 struct visitor_impl<Visitor, Derived, UnrollSize, /*Vectorize=*/true>
89 {
90  typedef typename Derived::Scalar Scalar;
91  typedef typename packet_traits<Scalar>::type Packet;
92 
93  EIGEN_DEVICE_FUNC
94  static inline void run(const Derived& mat, Visitor& visitor)
95  {
96  const Index PacketSize = packet_traits<Scalar>::size;
97  visitor.init(mat.coeff(0,0), 0, 0);
98  if (Derived::IsRowMajor) {
99  for(Index i = 0; i < mat.rows(); ++i) {
100  Index j = i == 0 ? 1 : 0;
101  for(; j+PacketSize-1 < mat.cols(); j += PacketSize) {
102  Packet p = mat.packet(i, j);
103  visitor.packet(p, i, j);
104  }
105  for(; j < mat.cols(); ++j)
106  visitor(mat.coeff(i, j), i, j);
107  }
108  } else {
109  for(Index j = 0; j < mat.cols(); ++j) {
110  Index i = j == 0 ? 1 : 0;
111  for(; i+PacketSize-1 < mat.rows(); i += PacketSize) {
112  Packet p = mat.packet(i, j);
113  visitor.packet(p, i, j);
114  }
115  for(; i < mat.rows(); ++i)
116  visitor(mat.coeff(i, j), i, j);
117  }
118  }
119  }
120 };
121 
122 // evaluator adaptor
123 template<typename XprType>
124 class visitor_evaluator
125 {
126 public:
127  typedef internal::evaluator<XprType> Evaluator;
128 
129  enum {
130  PacketAccess = Evaluator::Flags & PacketAccessBit,
131  IsRowMajor = XprType::IsRowMajor,
132  RowsAtCompileTime = XprType::RowsAtCompileTime,
133  ColsAtCompileTime = XprType::ColsAtCompileTime,
134  CoeffReadCost = Evaluator::CoeffReadCost
135  };
136 
137 
138  EIGEN_DEVICE_FUNC
139  explicit visitor_evaluator(const XprType &xpr) : m_evaluator(xpr), m_xpr(xpr) { }
140 
141  typedef typename XprType::Scalar Scalar;
142  typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
143  typedef std::remove_const_t<typename XprType::PacketReturnType> PacketReturnType;
144 
145  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_xpr.rows(); }
146  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_xpr.cols(); }
147  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index size() const EIGEN_NOEXCEPT { return m_xpr.size(); }
148 
149  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
150  { return m_evaluator.coeff(row, col); }
151  EIGEN_DEVICE_FUNC PacketReturnType packet(Index row, Index col) const
152  { return m_evaluator.template packet<Unaligned,PacketReturnType>(row, col); }
153 
154 protected:
155  Evaluator m_evaluator;
156  const XprType &m_xpr;
157 };
158 
159 } // end namespace internal
160 
180 template<typename Derived>
181 template<typename Visitor>
182 EIGEN_DEVICE_FUNC
183 void DenseBase<Derived>::visit(Visitor& visitor) const
184 {
185  if(size()==0)
186  return;
187 
188  typedef typename internal::visitor_evaluator<Derived> ThisEvaluator;
189  ThisEvaluator thisEval(derived());
190 
191  enum {
192  unroll = SizeAtCompileTime != Dynamic
193  && SizeAtCompileTime * int(ThisEvaluator::CoeffReadCost) + (SizeAtCompileTime-1) * int(internal::functor_traits<Visitor>::Cost) <= EIGEN_UNROLLING_LIMIT
194  };
195  return internal::visitor_impl<Visitor, ThisEvaluator, unroll ? int(SizeAtCompileTime) : Dynamic>::run(thisEval, visitor);
196 }
197 
198 namespace internal {
199 
203 template <typename Derived>
204 struct coeff_visitor
205 {
206  // default initialization to avoid countless invalid maybe-uninitialized warnings by gcc
207  EIGEN_DEVICE_FUNC
208  coeff_visitor() : row(-1), col(-1), res(0) {}
209  typedef typename Derived::Scalar Scalar;
210  Index row, col;
211  Scalar res;
212  EIGEN_DEVICE_FUNC
213  inline void init(const Scalar& value, Index i, Index j)
214  {
215  res = value;
216  row = i;
217  col = j;
218  }
219 };
220 
221 
222 template<typename Scalar, int NaNPropagation, bool is_min=true>
223 struct minmax_compare {
224  typedef typename packet_traits<Scalar>::type Packet;
225  static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a < b; }
226  static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_min<NaNPropagation>(p);}
227 };
228 
229 template<typename Scalar, int NaNPropagation>
230 struct minmax_compare<Scalar, NaNPropagation, false> {
231  typedef typename packet_traits<Scalar>::type Packet;
232  static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a > b; }
233  static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max<NaNPropagation>(p);}
234 };
235 
236 template <typename Derived, bool is_min, int NaNPropagation>
237 struct minmax_coeff_visitor : coeff_visitor<Derived>
238 {
239  using Scalar = typename Derived::Scalar;
240  using Packet = typename packet_traits<Scalar>::type;
241  using Comparator = minmax_compare<Scalar, NaNPropagation, is_min>;
242 
243  EIGEN_DEVICE_FUNC inline
244  void operator() (const Scalar& value, Index i, Index j)
245  {
246  if(Comparator::compare(value, this->res)) {
247  this->res = value;
248  this->row = i;
249  this->col = j;
250  }
251  }
252 
253  EIGEN_DEVICE_FUNC inline
254  void packet(const Packet& p, Index i, Index j) {
255  const Index PacketSize = packet_traits<Scalar>::size;
256  Scalar value = Comparator::predux(p);
257  if (Comparator::compare(value, this->res)) {
258  const Packet range = preverse(plset<Packet>(Scalar(1)));
259  Packet mask = pcmp_eq(pset1<Packet>(value), p);
260  Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
261  this->res = value;
262  this->row = Derived::IsRowMajor ? i : i + max_idx;;
263  this->col = Derived::IsRowMajor ? j + max_idx : j;
264  }
265  }
266 };
267 
268 // Suppress NaN. The only case in which we return NaN is if the matrix is all NaN, in which case,
269 // the row=0, col=0 is returned for the location.
270 template <typename Derived, bool is_min>
271 struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers> : coeff_visitor<Derived>
272 {
273  typedef typename Derived::Scalar Scalar;
274  using Packet = typename packet_traits<Scalar>::type;
275  using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
276 
277  EIGEN_DEVICE_FUNC inline
278  void operator() (const Scalar& value, Index i, Index j)
279  {
280  if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
281  this->res = value;
282  this->row = i;
283  this->col = j;
284  }
285  }
286 
287  EIGEN_DEVICE_FUNC inline
288  void packet(const Packet& p, Index i, Index j) {
289  const Index PacketSize = packet_traits<Scalar>::size;
290  Scalar value = Comparator::predux(p);
291  if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
292  const Packet range = preverse(plset<Packet>(Scalar(1)));
293  /* mask will be zero for NaNs, so they will be ignored. */
294  Packet mask = pcmp_eq(pset1<Packet>(value), p);
295  Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
296  this->res = value;
297  this->row = Derived::IsRowMajor ? i : i + max_idx;;
298  this->col = Derived::IsRowMajor ? j + max_idx : j;
299  }
300  }
301 
302 };
303 
304 // Propagate NaN. If the matrix contains NaN, the location of the first NaN will be returned in
305 // row and col.
306 template <typename Derived, bool is_min>
307 struct minmax_coeff_visitor<Derived, is_min, PropagateNaN> : coeff_visitor<Derived>
308 {
309  typedef typename Derived::Scalar Scalar;
310  using Packet = typename packet_traits<Scalar>::type;
311  using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;
312 
313  EIGEN_DEVICE_FUNC inline
314  void operator() (const Scalar& value, Index i, Index j)
315  {
316  const bool value_is_nan = (numext::isnan)(value);
317  if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
318  this->res = value;
319  this->row = i;
320  this->col = j;
321  }
322  }
323 
324  EIGEN_DEVICE_FUNC inline
325  void packet(const Packet& p, Index i, Index j) {
326  const Index PacketSize = packet_traits<Scalar>::size;
327  Scalar value = Comparator::predux(p);
328  const bool value_is_nan = (numext::isnan)(value);
329  if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
330  const Packet range = preverse(plset<Packet>(Scalar(1)));
331  // If the value is NaN, pick the first position of a NaN, otherwise pick the first extremal value.
332  Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
333  Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
334  this->res = value;
335  this->row = Derived::IsRowMajor ? i : i + max_idx;;
336  this->col = Derived::IsRowMajor ? j + max_idx : j;
337  }
338  }
339 };
340 
341 template<typename Scalar, bool is_min, int NaNPropagation>
342 struct functor_traits<minmax_coeff_visitor<Scalar, is_min, NaNPropagation> > {
343  enum {
344  Cost = NumTraits<Scalar>::AddCost,
345  PacketAccess = true
346  };
347 };
348 
349 } // end namespace internal
350 
362 template<typename Derived>
363 template<int NaNPropagation, typename IndexType>
364 EIGEN_DEVICE_FUNC
365 typename internal::traits<Derived>::Scalar
366 DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
367 {
368  eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
369 
370  internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
371  this->visit(minVisitor);
372  *rowId = minVisitor.row;
373  if (colId) *colId = minVisitor.col;
374  return minVisitor.res;
375 }
376 
387 template<typename Derived>
388 template<int NaNPropagation, typename IndexType>
389 EIGEN_DEVICE_FUNC
390 typename internal::traits<Derived>::Scalar
391 DenseBase<Derived>::minCoeff(IndexType* index) const
392 {
393  eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
394 
395  EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
396  internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
397  this->visit(minVisitor);
398  *index = IndexType((RowsAtCompileTime==1) ? minVisitor.col : minVisitor.row);
399  return minVisitor.res;
400 }
401 
413 template<typename Derived>
414 template<int NaNPropagation, typename IndexType>
415 EIGEN_DEVICE_FUNC
416 typename internal::traits<Derived>::Scalar
417 DenseBase<Derived>::maxCoeff(IndexType* rowPtr, IndexType* colPtr) const
418 {
419  eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
420 
421  internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
422  this->visit(maxVisitor);
423  *rowPtr = maxVisitor.row;
424  if (colPtr) *colPtr = maxVisitor.col;
425  return maxVisitor.res;
426 }
427 
438 template<typename Derived>
439 template<int NaNPropagation, typename IndexType>
440 EIGEN_DEVICE_FUNC
441 typename internal::traits<Derived>::Scalar
442 DenseBase<Derived>::maxCoeff(IndexType* index) const
443 {
444  eigen_assert(this->rows()>0 && this->cols()>0 && "you are using an empty matrix");
445 
446  EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
447  internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
448  this->visit(maxVisitor);
449  *index = (RowsAtCompileTime==1) ? maxVisitor.col : maxVisitor.row;
450  return maxVisitor.res;
451 }
452 
453 } // end namespace Eigen
454 
455 #endif // EIGEN_VISITOR_H
internal::traits< Derived >::Scalar minCoeff() const
Definition: Redux.h:433
void visit(Visitor &func) const
Definition: Visitor.h:183
internal::traits< Derived >::Scalar maxCoeff() const
Definition: Redux.h:448
@ PropagateNaN
Definition: Constants.h:345
@ PropagateNumbers
Definition: Constants.h:347
const unsigned int PacketAccessBit
Definition: Constants.h:96
Namespace containing all symbols from the Eigen library.
Definition: Core:139
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:59
const int Dynamic
Definition: Constants.h:24