Expression Templates

Preface

在研究 CRTP 时发现了一种经典应用——表达式模板。今天学一学,顺便水一篇博客。🐶

Motivation and example

表达式模板是一种模板元编程,它在编译期展开向量的一系列复合运算,从而将辅助空间从 $O(n)$ 降到 $O(1)$。相比于朴素算法,表达式模板属于惰性求值,求值时间点是赋值运算。

举个例子,有这样一个 vector 运算:

1
tot = (a + b) * c;

朴素算法的语义是:

1
2
3
4
5
{
auto tmp0 = a + b; // allocating O(n) space
auto tmp1 = tmp0 * c;
tot = tmp1;
}

即使你知道「引用限定成员函数」,使你可以抹掉 tmp1 的空间分配,但是 tmp0 的辅助空间总是逃不掉。

能帮助你的只有表达式模板,它可以表达出这样的语义:

1
2
3
for (size_t i=0, n=a.size(); i<n; ++i) {
tot[i] = (a[i] + b[i]) * c[i];
}

如此一来需要的辅助空间就是寄存器级别的!去掉了函数调用也更方便编译器做进一步的优化。

My first attempt

我的首次尝试

忍住不看教程,我先试着实现出来,看看差距有多大。🤔

我的想法:

  1. 利用类型系统,将懒惰信息记录在模板参数中
  2. 定义惰性二元运算符,它的操作数可以是 vector 或者惰性二元运算符。
  3. 魔改 vector,在赋值运算符中触发求值;重载加减乘除运算。

Draft 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <iostream>
#include <vector>
#include <functional>
#include <cassert>

template<typename BinaryOperator, typename LHS, typename RHS>
struct DeferredOperator {
const LHS &lhs;
const RHS &rhs;
constexpr DeferredOperator(
BinaryOperator, const LHS &lhs, const RHS &rhs) noexcept
: lhs(lhs), rhs(rhs) {}
constexpr auto operator[](size_t pos) const {
return BinaryOperator{}(lhs[pos], rhs[pos]);
}
constexpr size_t size() const { return lhs.size(); }
};

template<typename T, class Container = std::vector<T>>
struct Vec : Container {
template<typename Op>
Vec &operator=(Op &&op) {
const size_t n = op.size();
this->resize(n);
for (size_t i = 0; i < n; ++i) this->operator[](i) = op[i];
return *this;
}
};

int main() {
Vec<int> vec;
std::vector<int> a(5, 1), b(5, 3), c(5, 6);
vec = DeferredOperator(
std::multiplies<int>{}, DeferredOperator(std::plus<int>{}, a, b), c);

for (auto x : vec) {
std::cout << x << ' ';
assert(x == (1 + 3) * 6);
}

return 0;
}

What's wrong

这段程序是正确的!😁但也有一些犯蠢的地方:

  1. 没有重载运算符,用起来太丑;
  2. 显式模板参数太多,比如 int 出现太多次。理想中应该只在 Vec 定义时表明T = int,后面的通通自动推导;
  3. 魔改的 Vec 类没有沿用 std::vector 的非默认构造函数;
  4. 当调用 vec = vec 时会错误地调用惰性求值版本;
  5. 没有显式 CRTP,证明不够熟练;
  6. 在计算前调用 vector::resize() 真的好吗?有点违反零开销抽象的感觉。

The second attempt

查文档再改进

针对上一章节列出的缺点,我查阅文档做了一些改进:

Inheriting constructors

cppreference

1
2
3
4
5
template<typename T, class Container = std::vector<T>>
struct Vec : Container {
using Container::Container;
// ...
};

这样就能使用 std::vector 的构造接口了,我们的 C++ 真的太好用啦!

Template Inner Class

我们可以将 DeferredOperator 类定义在 Vec 内部,这样就不用到处传递基础数据类型啦!

Who needs operator+()?

  1. Vec 类需要实现 operator+(),并由这个函数返回一个 DeferredOperator
  2. DeferredOperator 类需要实现 operator+(),这个函数同样返回一个 DeferredOperator

Ambiguity of operator=()

为了区分 Vec = VecVec = DeferredOperator,我们需要更精细地定义参数。

Draft 2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <iostream>
#include <vector>
#include <functional>
#include <cassert>

template<typename T, class Container = std::vector<T>>
struct Vec : Container {
using Container::Container;

template<typename BinaryOperator, typename LHS, typename RHS>
struct DeferredOperator {
const LHS &lhs;
const RHS &rhs;
constexpr DeferredOperator(const LHS &lhs, const RHS &rhs) noexcept
: lhs(lhs), rhs(rhs) {}
constexpr auto operator[](size_t pos) const {
return BinaryOperator{}(lhs[pos], rhs[pos]);
}
template<typename OuterRHS>
constexpr auto operator+(const OuterRHS &outerRHS) const {
return DeferredOperator<std::plus<>, DeferredOperator, OuterRHS>(
*this, outerRHS);
}
constexpr size_t size() const { return lhs.size(); }
};

template<typename RHS>
constexpr auto operator+(const RHS &rhs) const {
return DeferredOperator<std::plus<>, Vec, RHS>(*this, rhs);
}

template<typename BinaryOperator, typename LHS, typename RHS>
Vec &operator=(const DeferredOperator<BinaryOperator,LHS,RHS> &op) {
std::cout << "tag\n";
const size_t n = op.size();
this->resize(n);
for (size_t i = 0; i < n; ++i) this->operator[](i) = op[i];
return *this;
}
};

int main() {
Vec<int> vec, a(5, 1), b(5, 3), c(5, 6);
vec = a; // check bug
vec = a + b + c;

for (auto x : vec) {
std::cout << x << ' ';
assert(x == (1 + 3) + 6);
}

return 0;
}

Direction of improvement

WoW! 这段代码也是正确的!但还是写得有点臭:

  1. 这里只写了重载加法。如果每种运算符都要写两个成员函数,那也太难看了。
  2. 只对 Vec 类实现了惰性加法。万一再来一个 Matrix 类呢?惰性运算能不能接口化?
  3. CRTP 在哪?

Peak of Evolution

这次依然是原创代码,改进了一些地方:

Class Diagram

多层次的 CRTP 用起来容易头晕,应该事先设计好类的关系:

CRTP relationship

橙色和红色的类才是 CRTP 的模板实参。白色的类主要用作接口,类似于抽象类。

ET interface

  • 用户类只要继承 VectorExpresstion 接口,同时定义合适的 operator[]operator= 即可实现模板表达式。
  • 用户可以在 operator= 中做性能优化,例如手工循环展开、SIMD 等等,非常灵活。
  • 用户还可以自定义运算符,但我暂时想不到有什么不啰嗦的实现方法(除了宏)。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// necessary header for Expression Template library
#include <type_traits>
#include <functional>

// header for user-define things
#include <vector>
#include <iostream>
#include <cassert>
#include <cmath>
#include <algorithm>

// helper class for CRTP. For more info, please see my previous blog
template<class T, template<typename> class Interface>
struct CRTP {
T &self() { return static_cast<T &>(*this); }
const T &self() const { return static_cast<const T &>(*this); }

private:
CRTP() = default;
friend Interface<T>;
};

// my Expression Template library implementation
namespace ET {

// base class for lazy evaluation
template<class T>
struct Expression : CRTP<T, Expression> {
// use a looong function name to avoid ambiguity
constexpr auto evaluate(size_t i) const { return this->self()[i]; }
constexpr auto selfSize() const { return this->self().size(); }
};

// class to indicate a single value (instead of an vector)
template<class T>
struct ArithmeticExpression : Expression<ArithmeticExpression<T>> {
ArithmeticExpression() = delete;

template<std::enable_if_t<std::is_arithmetic_v<T>, bool> = true>
constexpr ArithmeticExpression(const T value) noexcept : val(value) {}

constexpr T operator[](size_t) const { return val; }
constexpr void size() const {} // return void for easier metaprogramming

private:
const T val;
};

// class to indicate an vector (instead of a single arithmetic value)
template<class T>
struct VectorExpression : Expression<T> {};

template<typename BinaryOperator, typename LHS, typename RHS>
struct BinaryExpression
: VectorExpression<BinaryExpression<BinaryOperator, LHS, RHS>> {
constexpr auto operator[](size_t i) const {
// evaluation happens when operator[] is called
return BinaryOperator{}(lhs.evaluate(i), rhs.evaluate(i));
}
constexpr size_t size() const {
// return the size of an *vector*
if constexpr (!std::is_same_v<decltype(lhs.selfSize()), void>)
return lhs.selfSize();
else
return rhs.selfSize();
}

BinaryExpression(const LHS &lhs, const RHS &rhs) : lhs(lhs), rhs(rhs) {}

private:
const LHS &lhs;
const RHS &rhs;
};

template<typename LHS, typename RHS>
using AddExpression = BinaryExpression<std::plus<>, LHS, RHS>;

template<typename LHS, typename RHS>
using SubExpression = BinaryExpression<std::minus<>, LHS, RHS>;

template<typename LHS, typename RHS>
using MulExpression = BinaryExpression<std::multiplies<>, LHS, RHS>;

template<typename LHS, typename RHS>
using DivExpression = BinaryExpression<std::divides<>, LHS, RHS>;

} // namespace ET

// ==============================

// overloaded operator to export:

template<class T>
using ETNum = ET::ArithmeticExpression<T>;

template<typename E1, typename E2>
constexpr auto
operator+(const ET::Expression<E1> &lhs, const ET::Expression<E2> &rhs) {
// notice Expression<E1> -> E1
// so "Expression" is removed
return ET::AddExpression<E1, E2>(lhs.self(), rhs.self());
}

template<typename E1, typename E2>
constexpr auto
operator-(const ET::Expression<E1> &lhs, const ET::Expression<E2> &rhs) {
return ET::SubExpression<E1, E2>(lhs.self(), rhs.self());
}

template<typename E1, typename E2>
constexpr auto
operator*(const ET::Expression<E1> &lhs, const ET::Expression<E2> &rhs) {
return ET::MulExpression<E1, E2>(lhs.self(), rhs.self());
}

template<typename E1, typename E2>
constexpr auto
operator/(const ET::Expression<E1> &lhs, const ET::Expression<E2> &rhs) {
return ET::DivExpression<E1, E2>(lhs.self(), rhs.self());
}

// ==============================

// user-define operator

// I find it hard to convert a function into an operator() overloading.
// If you have a good idea, please tell me!
struct POW {
template<typename _Tp, typename _Up>
auto operator()(_Tp &&__t, _Up &&__u) const {
return pow(std::forward<_Tp>(__t), std::forward<_Up>(__u));
}
};

template<typename E1, typename E2>
constexpr auto
Pow(const ET::Expression<E1> &lhs, const ET::Expression<E2> &rhs) {
return ET::BinaryExpression<POW, E1, E2>(lhs.self(), rhs.self());
}

// ==============================

// user-define vector

template<typename T, class Container = std::vector<T>>
struct Vec
: Container
, ET::VectorExpression<Vec<T, Container>> {
using Container::Container;

template<typename Exp>
Vec &operator=(const ET::Expression<Exp> &exp) {
const size_t n = exp.selfSize();
this->resize(n);
for (size_t i = 0; i < n; ++i) this->operator[](i) = exp.evaluate(i);
return *this;
}
};

// ==============================

int main() {
Vec<double> ans, a(5, 1), b(5, 2), c(5, 8), d(5, 7), e(5, 0.4);

ans = Pow((a + b) * c / d + ETNum<double>(1), e);

for (auto &&x : ans) {
std::cout << x << ' ';
assert(x == pow(((1 + 2) * 8 / 7.0 + 1), 0.4));
}

return 0;
}

Conclusion and Prospect

一个基本可用的表达式模板库诞生了。但是要记住:

  1. 表达式模板离开了内联优化就是负优化。经过简单测试,gcc 至少要在 -O1 条件下才有内联成员函数的能力。
  2. 使用表达式模板和其他优化手段并不冲突,可以在自定义 operator= 处做 SIMD,循环展开等等。

我不得不承认这份代码对比基础库的水准还是差了很多,至少有以下方面可以改进:

  • more type trait,描述表达式的可加性、可乘性、可被除和可除性等等。
  • 进一步简化用户自定义运算符的代码。
  • 实现一元运算符。
  • 在矢量和标量的混合运算中抹去显式构造标量。

但模板元编程暂时不是我的主要学习方向,所以还是再等等吧~

References