You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

15 KiB

38 | 折叠表达式:高效的编译期展开

你好,我是吴咏炜。

当我在写第 18 讲“应用可变模板和 tuple 的编译期技巧”时,有朋友就建议可以讨论一下折叠表达式 [1]。不过在那时候我对折叠表达式并没有什么实际应用经验觉得它用处不大于是就略过了。幸好我只是没有写而没有乱加评论否则就图欧图森破too old, too simple了。很多功能只有在你真正需要到它、又掌握如何使用的时候你才会觉得真正有用。很多 C++ 里看似用处不大的特性,都是这种情况(当然也不是所有,否则就不会有对特性的废弃和删除了)。

跟之前一样,我们还是通过具体的例子来看一下折叠表达式的作用和表达能力。

基本用法

一元折叠

第 18 讲我举过这样一个编译期递归的例子:

template <typename T>
constexpr auto sum(T x)
{
  return x;
}

template <typename T1, typename T2,
          typename... Targ>
constexpr auto sum(T1 x, T2 y,
                   Targ... args)
{
  return sum(x + y, args...);
}

而使用折叠表达式的话,我们可以把代码简化成下面这个样子:

template <typename... Args>
constexpr auto sum(Args... args)
{
  return (... + args);
}

你应该可以看到,这个改进有多大了吧?

虽然猜这个代码是什么意思不难,但要精确理解这样的代码的语义,以及自己写出正确的折叠表达式,我们还是需要讲一点点语法。

首先,折叠表达式应用在可变模板的情况下,所以,我们需要有 typename... Args 这样的模板参数,及 Args... argsconst Args&... argsArgs&&... args 之类的函数参数包 [2]。

其次,我们在折叠表达式里使用参数包有一种特殊的形式。我们一定会用到圆括号(漏括号是初学时的常见错误)、参数包名称、运算符和 ...。上面的代码用的是“一元左折叠”,形式是“(... 运算符 参数包名称)”。

最后,我们看一下折叠表达式是如何展开的。如果参数包只有一项(args0),那结果就是这个参数自身(args0);如果参数包有两项(args0, args1),那结果就是这两项用运算符拼接起来(args0 + args1);如果参数包有三项(args0, args1, args2),那结果就是这三项用运算符拼接起来(args0 + args1 + args2);以此类推。

如果你初步理解了,那我得声明一下,我上面做了简化。对于超过两项的折叠表达式展开,我们有时候需要关注一下运算符的结合问题,即我们需要的是 (args0 + args1) + args2,还是 args0 + (args1 + args2)。显而易见,对于加法,以及其他满足结合律的运算符,这两者的区别并不重要。但是,即使对于加法,实际上这两种顺序都不是完全没有区别的(如浮点数),而对于减法、除法之类的运算,那就更不用说了。所以,我们需要区分一元左折叠和一元右折叠两种不同的方式,它们分别对应于 (args0 + args1) + args2args0 + (args1 + args2)。写折叠表达式时的区别是,一元左折叠的 ... 在左边,而一元右折叠的 ... 在右边:(... + args)(args + ...)

下面用符号描述一下。令 E 代表参数包,N 代表参数包里的参数数量,则:

  • 一元左折叠是 (\\ldots\\ \\mathrm{op}\\ E),展开后成为 (((E\_1\\ \\mathrm{op}\\ E\_2) \\ \\mathrm{op}\\ \\ldots)\\ \\mathrm{op}\\ E\_N)
  • 一元右折叠是 (E \\ \\mathrm{op}\\ \\ldots),展开后成为(E\_1\\ \\mathrm{op}\\ (\\ldots\\ \\mathrm{op}\\ (E\_{N-1}\\ \\mathrm{op}\\ E\_N)))

与和或的折叠

在大部分折叠表达式的展开过程中参数包为空是一个错误。不过为了方便实际使用的场景C++ 对于使用 &&|| 的折叠表达式有特殊处理,允许参数包为空。这种情况下,&& 得到 true|| 得到 false——也就是说,相当于折叠表达式默认填充了一项不影响正常运算结果的数据:true && args0 && ...false || args0 || ...

C++ 里的“与”和“或”有短路求值行为 [3]。当你写下 a && b 时,如果 a 算出的结果是 false,编译器就不会对 b 进行求值,因为求值没有意义,不会影响结果。这是一个明确定义了的行为。类似地,如果 a || ba 的结果是 true,编译器也不会对 b 进行求值。

如果我们看 a && b && c 这样的表达式的话,我们会发现情况也完全一样。按照运算符的结合规则,上面的表达式等同于 (a && b) && c;如果 a 结果为 false,无需对 b 求值就得到 a && bfalse,所以 c 也无需求值即得到最终结果 false。如果 atrue,编译器才会对 b 求值,并在结果为 true 时才对 c 求值……

再进一步,对于 a && (b && c) 进行分析,我们会发现,求值的顺序和结果仍将完全相同:

  • a 求值为 false,则 b && c 不求值,结果为 false
  • a 求值为 true,然后 b 求值为 false,则 c 不求值,结果为 false
  • a 求值为 true,然后 b 求值为 true,则结果为 c 求值的结果

这样的分析对于 || 也同样适用。因此,一元左折叠和右折叠的等价性不仅对于普通满足结合律的运算符是成立的,而且对于有短路规则的运算符也同样是成立的。

逗号的折叠

很多人可能没注意到,逗号“,”也是一个运算符 [4],表达式 a, b 的意思是(当然,不是在能被当作函数参数的地方;如果可能被编译器误解,多加一重括号就行),对 ab 依次进行求值,返回后一个表达式 b 的结果。在使用逗号的折叠表达式里,参数包也允许为空,此时表达式的结果相当于 void(),即没有数值。

我们后面会展示逗号折叠表达式的用法。

二元折叠

在参数包里提供了运算所需的所有参数时,一元折叠表达式就很好了。但还有很大的一类展开场景,我们没法用一元折叠表达式,因为我们需要在函数里提供某个参数。一种典型的情况就是,我们需要把一堆参数输出到某个流里:

cout << args0 << args1 << ... << argsN;

这里我们就需要用到所谓的“二元左折叠”了。如果待输出的参数组成了我们的参数包 args,我们用下面的代码就能输出:

(cout << ... << args);

类似地,我们有“二元右折叠”,道理相同,我就不展开了。

二元折叠相当于提供了一个“初值”,所以参数包允许为空。对于空参数包,(... + args) 是不合法代码,而 (0 + ... + args) 就是合法的了。

下面用符号描述一下。令 E 代表参数包,N 代表参数包里的参数数量,I 代表“初值参数”,则:

  • 二元左折叠是 (I\\ \\mathrm{op}\\ \\ldots\\ \\mathrm{op}\\ E),展开后成为 ((((I\\ \\mathrm{op}\\ E\_1)\\ \\mathrm{op}\\ E\_2) \\ \\mathrm{op}\\ \\ldots)\\ \\mathrm{op}\\ E\_N)
  • 二元右折叠是 (E \\ \\mathrm{op}\\ \\ldots\\ \\mathrm{op}\\ I),展开后成为(E\_1\\ \\mathrm{op}\\ (\\ldots\\ \\mathrm{op}\\ (E\_{N-1}\\ \\mathrm{op}\\ (E\_N\\ \\mathrm{op}\\ I))))

折叠表达式的应用场景

在对折叠表达式有了一些初步的了解之后,我们来看一下实际应用折叠表达式的一些场景。

空指针检查

作为一种编译期展开的功能,折叠表达式能够达到跟手写展开完全相同的效果,但表达上要精炼得多。比如,我们有代码需要检查给定的指针(有可能有智能指针)是否有为空的情况,我们就可以写:

if (ptr1 == nullptr ||
    ptr2 == nullptr ||
    
    ptrN == nullptr) {
  // 记录日志,出错返回,等等
}

这当然不算糟糕,但下面这样的写法是不是好上一点点?

if (is_any_null(ptr1, ptr2, , ptrN)) {
  // 记录日志,出错返回,等等
}

is_any_null 的实现非常简单:

template <typename... Args>
constexpr bool
is_any_null(const Args&... args)
{
  return (... || (args == nullptr));
}

返回值检查

比上面这种更复杂一点的,是调用多个函数,检查返回值,并在返回值表示不成功时终止代码执行。示意代码如下:

error_t result{};
result = check1();
if (result != error_t::ok) {
  return result;
}
result = check2();
if (result != error_t::ok) {
  return result;
}
result = check3();
if (result != error_t::ok) {
  return result;
}
result = check4();
if (result != error_t::ok) {
  return result;
}
return error_t::ok;

利用折叠表达式,我们也可以这样简化代码:

return checked_exec(
  error_t::ok,
  [&] { return check1(); },
  [&] { return check2(); },
  [&] { return check3(); },
  [&] { return check4(); });

当然,我们需要提供 checked_exec 的定义:

template <typename R,
          typename... Fn>
R checked_exec(const R& expected,
               Fn&&... fn)
{
  R result = expected;
  (void)(((result = forward<Fn>(
             fn)()) == expected) &&
         ...);
  return result;
}

在参数展开和内联后,我们上面对 checked_exec 的调用就大致相当于下面的代码:

error_t result = error_t::ok;
(void)(((result = check1()) == error_t::ok) &&
       (((result = check2()) == error_t::ok) &&
        (((result = check3()) == error_t::ok) &&
         ((result = check4()) == error_t::ok))));
return result;

这里我严格按一元右折叠的形式进行了展开,但就如上面讨论过的,这里左折叠和右折叠是等价的。此时,去掉一些括号,代码会更加清晰:

(void)((result = check1()) == error_t::ok &&
       (result = check2()) == error_t::ok &&
       (result = check3()) == error_t::ok &&
       (result = check4()) == error_t::ok);

所以,我们看到了,利用折叠表达式和短路规则,我们可以实现 checked_exec 或类似的函数,来简化一些重复的检查,让代码更加清晰,并避免低级错误。

编译期遍历

利用逗号折叠表达式,我们可以实现一些编译期的遍历操作。最基本的,当然就是直接遍历所有的参数了。利用这种方式,我们可以来实现带分隔符的打印操作:

template <typename T,
          typename First,
          typename... Rest>
void print_with_separator(
  const T& sep,
  const First& first,
  const Rest&... rest)
{
  cout << first;
  ((cout << sep << rest), ...);
  cout << endl;
}

这个代码很简单,可以内联,因此我也没有必要像第 37 讲里描述的那样进一步进行传参优化了。这里的编译期展开就利用了逗号折叠表达式。比如,当我们以 print_with_separator(", ", "one", "two", "three") 来调用时,函数体展开成大致这个样子(去掉了不必要的括号):

cout << "one";
((cout << ", " << "two"),
 (cout << ", " << "three"));
cout << endl;

逗号前的那个表达式就成了我们希望在参数包 args 上反复执行的内容。

使用类似的方式,我们可以打印一个 tuple。这时,代码就稍微复杂一些了:我们需要根据需要遍历的项数预先生成编译期的整数序列,也就是第 18 讲讨论过的 make_index_sequence,然后利用折叠表达式来逐项遍历。

不过呢,我们这次会使用标准库里的一个对 make_index_sequence 的小小封装 [5]

template <class... T>
using index_sequence_for =
  make_index_sequence<sizeof...(T)>;

这个类模板会根据模板参数的项数来生成一个合适的序列。比如,如果传给 index_sequence_for 的模板参数有三项的话,那结果类型就会是 index_sequence<0, 1, 2>

然后,print_tuple 就可以这样实现:

template <typename Tup,
          size_t... Is>
void output_tuple_members(
  ostream& os,
  const Tup& tup,
  index_sequence<Is...>)
{
  ((os << (Is != 0 ? ", " : "")
       << get<Is>(tup)),
   ...);
}

template <typename... Args>
void print_tuple(const tuple<Args...>& args)
{
  cout << '(';
  output_tuple_members(
    cout, args,
    index_sequence_for<Args...>{});
  cout << ')';
}

对于一个三项的 tuple,最后展开出来的代码就差不多是这个样子:

cout << '(';
((cout << (0 != 0 ? ", " : "")
       << get<0>(args)),
 (cout << (1 != 0 ? ", " : "")
       << get<1>(args)),
 (cout << (2 != 0 ? ", " : "")
       << get<2>(args)));
cout << ')';

显然,它确实能够完成我们需要的打印任务。如果我们传它一个 make_tuple(1, "two", 3.14159),打印结果就会是:

(1, two, 3.14159)

内容小结

本讲我讨论了 C++17 提供的折叠表达式,并通过提供具体的例子,向你展示了如何使用这一特性来进行编译期展开,从而简化重复的代码。

课后思考

尝试一下不用折叠表达式去实现 checked_exec(或其他使用了折叠表达式的函数模板),体会一下折叠表达式带来的简化。

期待你的动手实践,有任何疑问我们留言区见!

参考资料

1
1a
2
2a
3
3a
4
4a
5
5a