Expression templates

From Eigen
Revision as of 07:34, 25 June 2008 by Bjacob

Expression templates is a C++ coding technique, discovered in the 1990s, that can greatly improve both the performance and the cleanness and expressiveness of the API of certain kinds of C++ template libraries. It is especially useful for a vector/matrix library such as Eigen. For background, see the Wikibooks page on the topic and the links given there.

General discussion

You can safely skip this section: it is less precise and more demanding than the rest of this article. A concrete example is discussed below.

C++ method chaining is nice, but in certain use cases it suffers from limitations. Suppose we have a class as follows:

class C
{
  // ...
 
  public:
    C();
    C a();
    C b();
    C& operator=(const C&);
};

We can then chain the methods as follows:

C c1, c2;
c1 = c2.a().b().a();

However, there are two problems with that. The first problem is that every call to a() or b() creates a new object of class C and returns it by value, copying it onto the stack; this can introduce a large overhead. The second problem is a corollary of the first: since the return values of a() and b() are new temporary objects, it is impossible to use them like this:

C c1, c2;
c1.a() = c2.b();

because this would assign to the temporary object returned by a(), which has nothing to do with c1; hence it would have no effect on c1.

The usual way to overcome the second problem is to let a() and b() return "proxy objects" instead of objects of C. However, this is not optimal, as these proxy objects cannot be used as objects of C when that is what one wants; in particular, they don't allow method chaining like this:

C c1, c2;
c1.a().b() = c2;

unless, of course, the API of class C is replicated in the proxy classes. Because a proxy only refers to the original object instead of copying it, this would also address the first problem.

Now you are probably starting to think of a trick with templates and inheritance to do this automatically... and that is what expression templates are.

Let us first see what we could do with templates only:

template<typename T> class A;
template<typename T> class B;
template<typename T>
class C
{
  // ...
 
  public:
    C();
    C<A<T> > a();
    C<B<T> > b();
    template<typename U> C& operator=(const C<U>&);
};

Now one could imagine that C<int> is the actual class C that we referred to above, and that class C<A<T> > is the "proxy" returned by a(), implemented by partial template specialization.

There is, however, a big problem: every partial specialization of C<T> needs to reimplement all the methods. There is no code reuse between specializations, even though typically many methods could be shared.

Our solution is given by the Curiously Recurring Template Pattern (CRTP). The general idea is sketched in the following code snippet; read the rest of this article for more implementation details.

template<typename Derived> class A;
template<typename Derived> class B;

template<typename Derived>
class CBase
{
  public:
    // implement most (if not all) of the API here.
    // when you need some code/data that is specific to the Derived class, implement it as follows:
    Derived& derived() { return *static_cast<Derived*>(this); }
    const Derived& derived() const { return *static_cast<const Derived*>(this); }
    void someMethodSpecificToDerived() { derived()._someMethodSpecificToDerived(); }

    A<Derived> a();
    B<Derived> b();
    template<typename U> Derived& operator=(const CBase<U>&);
};

class C : public CBase<C>
{
  public:
    // only need to implement here what is specific to this derived class
    // (public, so that CBase can reach it through derived())
    void _someMethodSpecificToDerived() { /* ... */ }
};

template<typename Derived>
class A : public CBase<A<Derived> >
{
  // ...
};

template<typename Derived>
class B : public CBase<B<Derived> >
{
  // ...
};

A simple Vector class

Suppose that we want to write a Vector class. The standard approach is:

class Vector {
    float m_x, m_y, m_z;
  public:
    Vector(float x, float y, float z) : m_x(x), m_y(y), m_z(z) {}
    float x() const { return m_x; }
    float y() const { return m_y; }
    float z() const { return m_z; }
    Vector operator+(const Vector& other) const
    {
      return Vector(x()+other.x(), y()+other.y(), z()+other.z());
    }
};