Bugzilla – Attachment 501 Details for
Bug 895
Port MathFunctions.h to AVX (and possibly other architectures)
Home
|
New
|
Browse
|
Search
|
[?]
|
Reports
|
Requests
|
Help
|
Log In
[x]
|
Forgot Password
Login:
[x]
This Bugzilla service is closed. All entries have been migrated to
https://gitlab.com/libeigen/eigen
[patch]
AVX version of the pexp/plog/psqrt math functions
patch.txt (text/plain), 10.56 KB, created by
Benoit Steiner
on 2014-10-20 20:15:36 UTC
(
hide
)
Description:
AVX version of the pexp/plog/psqrt math functions
Filename:
MIME Type:
Creator:
Benoit Steiner
Created:
2014-10-20 20:15:36 UTC
Size:
10.56 KB
patch
obsolete
>diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h
>new file mode 100644
>--- /dev/null
>+++ b/Eigen/src/Core/arch/AVX/MathFunctions.h
>@@ -0,0 +1,241 @@
>+// This file is part of Eigen, a lightweight C++ template library
>+// for linear algebra.
>+//
>+// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com)
>+//
>+// This Source Code Form is subject to the terms of the Mozilla
>+// Public License v. 2.0. If a copy of the MPL was not distributed
>+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
>+
>+#ifndef EIGEN_MATH_FUNCTIONS_AVX_H
>+#define EIGEN_MATH_FUNCTIONS_AVX_H
>+
>+namespace Eigen {
>+
>+namespace internal {
>+
>+// Concatenates two 128-bit integer vectors into one 256-bit vector, with `lo`
>+// in the lower half and `hi` in the upper half. Needed on plain AVX (without
>+// AVX2), where 32-bit integer shifts only exist for 128-bit registers.
>+// Deliberately NOT a macro named _mm256_setr_m128: that identifier is a real
>+// intrinsic (taking __m128 float arguments) in newer <immintrin.h> headers,
>+// and a same-named macro would hijack user code that calls it.
>+EIGEN_STRONG_INLINE __m256i avx_concat_m128i(__m128i lo, __m128i hi) {
>+  return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
>+}
>+
>+// Natural logarithm.
>+// Computes log(x) as log(2^e * m) = e*log(2) + log(m), where m is in the
>+// range [sqrt(1/2),sqrt(2)). In this range, the logarithm can be easily
>+// approximated by a polynomial centered on m=1 for stability.
>+// Special values: log(+/-0) = -inf; negative and NaN inputs give NaN.
>+// TODO(gonnet): Further reduce the interval allowing for lower-degree
>+// polynomial interpolants -> ... -> profit!
>+template <>
>+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
>+plog<Packet8f>(const Packet8f& _x) {
>+  Packet8f x = _x;
>+  _EIGEN_DECLARE_CONST_Packet8f(1, 1.0f);
>+  _EIGEN_DECLARE_CONST_Packet8f(half, 0.5f);
>+  _EIGEN_DECLARE_CONST_Packet8f(126f, 126.0f);
>+
>+  _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(inv_mant_mask, ~0x7f800000);
>+
>+  // The smallest non-denormalized float number.
>+  _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(min_norm_pos, 0x00800000);
>+  _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(minus_inf, 0xff800000);
>+
>+  // Cephes polynomial coefficients for log(m) around m=1, plus log(2) split
>+  // into a small correction (q1) and a short-mantissa part (q2), q1+q2=log(2).
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_SQRTHF, 0.707106781186547524f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p0, 7.0376836292E-2f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p1, -1.1514610310E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p2, 1.1676998740E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p3, -1.2420140846E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p4, +1.4249322787E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p5, -1.6668057665E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p6, +2.0000714765E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p7, -2.4999993993E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_p8, +3.3333331174E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_q1, -2.12194440e-4f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_log_q2, 0.693359375f);
>+
>+  // Lanes that must return NaN: x < 0, and NaN inputs. The unordered
>+  // "not greater-or-equal" predicate is true for NaN (an ordered "less than"
>+  // is not). This matters because pmax() below (i.e. _mm256_max_ps, which
>+  // returns its second operand when either operand is NaN) would otherwise
>+  // replace a NaN lane with min_norm_pos and produce a finite result.
>+  Packet8f invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_NGE_UQ);
>+  Packet8f iszero_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_EQ_OQ);
>+
>+  // Truncate input values to the minimum positive normal.
>+  x = pmax(x, p8f_min_norm_pos);
>+
>+  // Extract the biased exponents. Plain AVX has no 256-bit integer shifts, so
>+  // without AVX2 each 128-bit half is shifted with SSE2. The bits are
>+  // reinterpreted with an intrinsic cast: C-style casts between __m256 and
>+  // __m256i are not accepted by all compilers (notably MSVC).
>+  __m256i x_bits = _mm256_castps_si256(x);
>+#ifdef EIGEN_VECTORIZE_AVX2
>+  Packet8f emm0 = _mm256_cvtepi32_ps(_mm256_srli_epi32(x_bits, 23));
>+#else
>+  __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(x_bits, 0), 23);
>+  __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(x_bits, 1), 23);
>+  Packet8f emm0 = _mm256_cvtepi32_ps(avx_concat_m128i(lo, hi));
>+#endif
>+  Packet8f e = _mm256_sub_ps(emm0, p8f_126f);
>+
>+  // Set the exponents to -1, i.e. x are in the range [0.5,1).
>+  x = _mm256_and_ps(x, p8f_inv_mant_mask);
>+  x = _mm256_or_ps(x, p8f_half);
>+
>+  // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
>+  // and shift by -1. The values are then centered around 0, which improves
>+  // the stability of the polynomial evaluation.
>+  //   if( x < SQRTHF ) {
>+  //     e -= 1;
>+  //     x = x + x - 1.0;
>+  //   } else { x = x - 1.0; }
>+  Packet8f mask = _mm256_cmp_ps(x, p8f_cephes_SQRTHF, _CMP_LT_OQ);
>+  Packet8f tmp = _mm256_and_ps(x, mask);
>+  x = psub(x, p8f_1);
>+  e = psub(e, _mm256_and_ps(p8f_1, mask));
>+  x = padd(x, tmp);
>+
>+  Packet8f x2 = pmul(x, x);
>+  Packet8f x3 = pmul(x2, x);
>+
>+  // Evaluate the polynomial approximant of degree 8 in three parts, probably
>+  // to improve instruction-level parallelism.
>+  // TODO(gonnet): Split into odd/even polynomials, avoids computing x3 and
>+  // some of the multiplications at the end.
>+  Packet8f y, y1, y2;
>+  y = pmadd(p8f_cephes_log_p0, x, p8f_cephes_log_p1);
>+  y1 = pmadd(p8f_cephes_log_p3, x, p8f_cephes_log_p4);
>+  y2 = pmadd(p8f_cephes_log_p6, x, p8f_cephes_log_p7);
>+  y = pmadd(y, x, p8f_cephes_log_p2);
>+  y1 = pmadd(y1, x, p8f_cephes_log_p5);
>+  y2 = pmadd(y2, x, p8f_cephes_log_p8);
>+  y = pmadd(y, x3, y1);
>+  y = pmadd(y, x3, y2);
>+  y = pmul(y, x3);
>+
>+  // Add the logarithm of the exponent back to the result of the interpolation.
>+  y1 = pmul(e, p8f_cephes_log_q1);
>+  tmp = pmul(x2, p8f_half);
>+  y = padd(y, y1);
>+  x = psub(x, tmp);
>+  y2 = pmul(e, p8f_cephes_log_q2);
>+  x = padd(x, y);
>+  x = padd(x, y2);
>+
>+  // Filter out invalid inputs: negative or NaN arguments yield NaN (all bits
>+  // set), and +/-0 yields -INF.
>+  return _mm256_or_ps(
>+      _mm256_andnot_ps(iszero_mask, _mm256_or_ps(x, invalid_mask)),
>+      _mm256_and_ps(iszero_mask, p8f_minus_inf));
>+}
>+
>+// Exponential function. Works by writing "x = m*log(2) + r" where
>+// "m = floor(x/log(2)+1/2)" and "r" is the remainder, which lies in
>+// [-log(2)/2, log(2)/2]. The result is then "exp(x) = 2^m*exp(r)", where
>+// exp(r) is approximated by a polynomial.
>+// Inputs are clamped to [exp_lo, exp_hi] before the range reduction; NaN
>+// inputs are propagated as NaN.
>+template <>
>+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
>+pexp<Packet8f>(const Packet8f& _x) {
>+  _EIGEN_DECLARE_CONST_Packet8f(1, 1.0f);
>+  _EIGEN_DECLARE_CONST_Packet8f(half, 0.5f);
>+  _EIGEN_DECLARE_CONST_Packet8f(127, 127.0f);
>+
>+  _EIGEN_DECLARE_CONST_Packet8f(exp_hi, 88.3762626647950f);
>+  _EIGEN_DECLARE_CONST_Packet8f(exp_lo, -88.3762626647949f);
>+
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_LOG2EF, 1.44269504088896341f);
>+
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p0, 1.9875691500E-4f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p1, 1.3981999507E-3f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p2, 8.3334519073E-3f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p3, 4.1665795894E-2f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p4, 1.6666665459E-1f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_p5, 5.0000001201E-1f);
>+
>+  // Clamp x. The min/max (_mm256_min_ps/_mm256_max_ps return their second
>+  // operand for NaN) turn NaN lanes into exp_hi; they are restored to NaN
>+  // before returning. +inf is not special-cased and is clamped like any other
>+  // large input.
>+  Packet8f x = pmax(pmin(_x, p8f_exp_hi), p8f_exp_lo);
>+
>+  // Express exp(x) as exp(m*ln(2) + r), start by extracting
>+  // m = floor(x/ln(2) + 0.5).
>+  Packet8f m = _mm256_floor_ps(pmadd(x, p8f_cephes_LOG2EF, p8f_half));
>+
>+  // Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
>+  // subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
>+  // truncation errors. Note that we don't use the "pmadd" function here to
>+  // ensure that a precision-preserving FMA instruction is used.
>+#ifdef EIGEN_VECTORIZE_FMA
>+  _EIGEN_DECLARE_CONST_Packet8f(nln2, -0.6931471805599453f);
>+  Packet8f r = _mm256_fmadd_ps(m, p8f_nln2, x);
>+#else
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C1, 0.693359375f);
>+  _EIGEN_DECLARE_CONST_Packet8f(cephes_exp_C2, -2.12194440e-4f);
>+  Packet8f r = psub(x, pmul(m, p8f_cephes_exp_C1));
>+  r = psub(r, pmul(m, p8f_cephes_exp_C2));
>+#endif
>+
>+  Packet8f r2 = pmul(r, r);
>+
>+  // TODO(gonnet): Split into odd/even polynomials and try to exploit
>+  // instruction-level parallelism.
>+  Packet8f y = p8f_cephes_exp_p0;
>+  y = pmadd(y, r, p8f_cephes_exp_p1);
>+  y = pmadd(y, r, p8f_cephes_exp_p2);
>+  y = pmadd(y, r, p8f_cephes_exp_p3);
>+  y = pmadd(y, r, p8f_cephes_exp_p4);
>+  y = pmadd(y, r, p8f_cephes_exp_p5);
>+  y = pmadd(y, r2, r);
>+  y = padd(y, p8f_1);
>+
>+  // Build emm0 = 2^m by writing (m + 127) into the float exponent field.
>+  Packet8i emm0 = _mm256_cvttps_epi32(padd(m, p8f_127));
>+#ifdef EIGEN_VECTORIZE_AVX2
>+  emm0 = _mm256_slli_epi32(emm0, 23);
>+#else
>+  __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(emm0, 0), 23);
>+  __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(emm0, 1), 23);
>+  emm0 = avx_concat_m128i(lo, hi);
>+#endif
>+
>+  // Return 2^m * exp(r), with NaN inputs propagated as NaN (all bits set).
>+  Packet8f isnan_mask = _mm256_cmp_ps(_x, _x, _CMP_UNORD_Q);
>+  return _mm256_or_ps(pmul(y, _mm256_castsi256_ps(emm0)), isnan_mask);
>+}
>+
>+// Functions for sqrt.
>+// The EIGEN_FAST_MATH version uses the _mm256_rsqrt_ps approximation and one
>+// step of Newton's method, at a cost of 1-2 bits of precision as opposed to
>+// the exact solution. The main advantage of this approach is not just speed,
>+// but also the fact that it can be inlined and pipelined with other
>+// computations, further reducing its effective latency.
>+#if EIGEN_FAST_MATH
>+// Fast approximate square root: x * rsqrt(x), with rsqrt refined by one
>+// Newton step. Special values: +/-0 and positive denormals give (+/-)0,
>+// negative inputs give NaN, +inf gives +inf, and NaN stays NaN.
>+template <>
>+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
>+psqrt<Packet8f>(const Packet8f& _x) {
>+  _EIGEN_DECLARE_CONST_Packet8f(one_point_five, 1.5f);
>+  _EIGEN_DECLARE_CONST_Packet8f(minus_half, -0.5f);
>+  _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(flt_min, 0x00800000);
>+  _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(pos_inf, 0x7f800000);
>+
>+  Packet8f neg_half = pmul(_x, p8f_minus_half);
>+
>+  // Select only the inverse sqrt of positive normal inputs (denormals are
>+  // flushed to zero and cause infs as well). All other lanes get x = 0.
>+  Packet8f non_zero_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_GE_OQ);
>+  Packet8f x = _mm256_and_ps(non_zero_mask, _mm256_rsqrt_ps(_x));
>+
>+  // Do a single step of Newton's iteration.
>+  x = pmul(x, pmadd(neg_half, pmul(x, x), p8f_one_point_five));
>+
>+  // Multiply the original _x by its reciprocal square root to extract the
>+  // square root.
>+  Packet8f result = pmul(_x, x);
>+
>+  // Patch up special lanes. rsqrt(+inf) is 0, so the Newton step produced
>+  // inf * 0 = NaN: return +inf unchanged instead. Negative inputs were masked
>+  // to x = 0 above and would yield -0: force them to NaN (all bits set), as
>+  // the exact _mm256_sqrt_ps path does.
>+  Packet8f pos_inf_mask = _mm256_cmp_ps(_x, p8f_pos_inf, _CMP_EQ_OQ);
>+  Packet8f negative_mask = _mm256_cmp_ps(_x, _mm256_setzero_ps(), _CMP_LT_OQ);
>+  result = _mm256_blendv_ps(result, _x, pos_inf_mask);
>+  return _mm256_or_ps(result, negative_mask);
>+}
>+#else
>+template <>
>+EIGEN_STRONG_INLINE Packet8f psqrt<Packet8f>(const Packet8f& x) {
>+  return _mm256_sqrt_ps(x);
>+}
>+#endif
>+template <>
>+EIGEN_STRONG_INLINE Packet4d psqrt<Packet4d>(const Packet4d& x) {
>+  return _mm256_sqrt_pd(x);
>+}
>+
>+}  // end namespace internal
>+
>+}  // end namespace Eigen
>+
>+#endif  // EIGEN_MATH_FUNCTIONS_AVX_H
>diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
>--- a/Eigen/src/Core/arch/AVX/PacketMath.h
>+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
>@@ -51,33 +51,35 @@ template<> struct packet_traits<float>
>     Vectorizable = 1,
>     AlignedOnScalar = 1,
>     size=8,
>     HasHalfPacket = 1,
>
>     HasDiv = 1,
>     HasSin = 0,
>     HasCos = 0,
>-    HasLog = 0,
>-    HasExp = 0,
>-    HasSqrt = 0
>+    HasLog = 1,
>+    HasExp = 1,
>+    HasSqrt = 1
>   };
> };
> template<> struct packet_traits<double> : default_packet_traits
> {
>   typedef Packet4d type;
>   typedef Packet2d half;
>   enum {
>     Vectorizable = 1,
>     AlignedOnScalar = 1,
>     size=4,
>     HasHalfPacket = 1,
>
>     HasDiv = 1,
>-    HasExp = 0
>+    HasLog = 0,
>+    HasExp = 0,
>+    HasSqrt = 1
>   };
> };
>
> /* Proper support
for integers is only provided by AVX2. In the meantime, we'll > use SSE instructions and packets to deal with integers. > template<> struct packet_traits<int> : default_packet_traits > { > typedef Packet8i type;
You cannot view the attachment while viewing its details because your browser does not support IFRAMEs.
View the attachment on a separate page
.
View Attachment As Diff
View Attachment As Raw
Actions:
View
|
Diff
Attachments on
bug 895
: 501 |
542