This post was originally written on Codeforces; relevant discussion can be found here.
Someone asked me about my template that used a “cache wrapper” for lambdas, so I decided to write a post explaining how it works. For reference, here is a submission of mine from 2021 that uses that template.
Here’s what you will find implementations for in this post:
- Generalized hashing (for tuple types, sequence types and basic types)
- Convenient aliases for policy based data structures
- Wrappers for recursive (or otherwise) lambdas that automatically do caching (memoization) for you.
Jump to the Usage section if you only care about the template, though I would strongly recommend reading the rest of the post too since it has a lot of cool/useful things in my opinion.
If you know how to implement recursive lambdas
You will notice that the implementation in this post uses the self pattern (non-standard terminology, just what I like to call it), just like the y_combinator pattern uses it.
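Here is a minimal sketch of both forms. In the bare self pattern, the lambda is handed itself explicitly. In the simplified y_combinator, the wrapper supplies self automatically. fib_self and fib_y are illustrative names. The second form has the same shape as the lambdas that the cache wrapper in this post expects.

```cpp
#include <cassert>
#include <utility>

// Simplified y_combinator: stores a lambda and passes itself as the first argument.
template <class F>
struct y_combinator {
    F f;
    template <class... Args>
    decltype(auto) operator()(Args&&... args) const {
        // *this is callable, so the lambda can recurse through it as "self".
        return f(*this, std::forward<Args>(args)...);
    }
};
template <class F>
y_combinator(F) -> y_combinator<F>;

// Illustrative: bare self pattern, the lambda receives itself on every call.
inline long long fib_self(int n) {
    auto fib = [](auto&& self, int k) -> long long {
        return k < 2 ? k : self(self, k - 1) + self(self, k - 2);
    };
    return fib(fib, n);
}

// Illustrative: same recursion, but y_combinator passes "self" for us.
inline long long fib_y(int n) {
    auto fib = y_combinator{[](auto&& self, int k) -> long long {
        return k < 2 ? k : self(k - 1) + self(k - 2);
    }};
    return fib(n);
}
```

The explicit -> long long return type matters. Without it, the generic lambda's return type would have to be deduced from a call to itself.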
If you know functional programming
This should remind you of open recursion and fixed-point combinators (the function receives the thing it recurses through as a parameter), and the cache implementation is like a monad that handles the memoization for you, specialized to this case.
Motivation
There are two major ways in which dynamic programming is implemented — recursive and iterative.
Recursive DP has the following advantages:
- It is (arguably) the cleaner way for beginners, and is often easier to reason about mathematically.
- It is needed in problems where you want to visit only the states required for computing your answer and no others.
However, many people switch to iterative DP as their default DP implementation for the following reasons:
- Recursive DP is sometimes slower than iterative DP, even when both compute the same number of states.
- It is much easier to do push DP iteratively.
- Recursive DP takes more boilerplate to write (something along the lines of “if computed already, return the stored value; otherwise, do the following”).
The template discussed in this post aims to tackle the third point by reducing boilerplate code: it manages memoization for you automatically.
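For concreteness, this is the kind of boilerplate the third point refers to. The example is a hand-written memoized recursion for binomial coefficients. binom is an illustrative name, and the modulus and the -1 "not computed" sentinel are arbitrary choices for the sketch. Every state needs its own "already computed?" check before the actual recurrence runs.

```cpp
#include <cassert>
#include <vector>

// Illustrative: C(n, k) mod a prime, memoized by hand. Requires 0 <= k <= n.
inline long long binom(int n, int k) {
    const long long MOD = 1'000'000'007;
    // memo[a][b] == -1 means "not computed yet"
    std::vector<std::vector<long long>> memo(n + 1, std::vector<long long>(k + 1, -1));
    auto go = [&](auto&& self, int a, int b) -> long long {
        if (b == 0 || b == a) return 1;
        long long& ans = memo[a][b];  // memo is never resized, so this reference stays valid
        if (ans != -1) return ans;    // boilerplate: return the stored value
        // otherwise: compute, store, return
        return ans = (self(self, a - 1, b - 1) + self(self, a - 1, b)) % MOD;
    };
    return go(go, n, k);
}
```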
Implementation
In Python, you have nice things called decorators. For example, the functools.cache decorator is precisely what we would use for such a purpose:
```python
from functools import cache

@cache
def factorial(n):
    return n * factorial(n - 1) if n else 1
```
If you call factorial(10), then the internal cache will be populated with the results of all the recursive calls that were required to compute factorial(10). As a result, if you then call factorial(5), it will just look it up in the internal cache and return the value to you without calling the function body.
So I thought, why not try to do the same thing in C++?
Let’s get a couple of things out of the way first:
- Note that Python dictionaries support heterogeneous types, but you can’t do this without type erasure (in C++ or otherwise). So I decided to keep types homogeneous, i.e., the internal cache has keys of the same type. This was done to avoid performance hits, and it made sense because in competitive programming, people don’t really use DP states that have heterogeneous types.
- Functions in C++ can take arguments by value or by reference. In Python, by contrast, arguments are passed by object reference (a parameter is just a new name bound to the caller’s object). For the purposes of DP, since it doesn’t really make sense to make states mutable, we decide that we will pass everything by value (there are a few pointers in the Caveats section if you are worried about performance implications due to copying a lot, but these are often not important since DP states tend to be integers/small types most of the time).
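To make the by-value decision concrete: the wrapper builds its key type with std::decay_t (its TiedArgs alias). So even a parameter declared as const std::string& yields a std::string key that owns a copy. The sketch below checks this. It also shows that mutating the caller’s variable afterwards leaves the stored key untouched, which a reference-based key could not guarantee. Key and key_survives_mutation are illustrative names.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <type_traits>

// Illustrative alias mirroring std::tuple<std::decay_t<Args>...> from the wrapper:
// references and cv-qualifiers are stripped, so the memo owns copies.
template <class... Args>
using Key = std::tuple<std::decay_t<Args>...>;

static_assert(std::is_same_v<Key<const std::string&, int&&>, std::tuple<std::string, int>>);

// Stores a result under a by-value key, then mutates the caller's variable.
inline bool key_survives_mutation() {
    std::map<Key<const std::string&>, int> memo;
    std::string s = "abc";
    memo[Key<const std::string&>(s)] = 3;  // the key is a copy of s
    s = "xyz";                             // does not affect the stored key
    return memo.count(Key<const std::string&>("abc")) == 1 &&
           memo.count(Key<const std::string&>(s)) == 0;
}
```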
So let’s start out by trying to design it.
We need the following two parts:
- A wrapper around a function that is able to store calls to the function and their results.
- Generalized hashing, to be able to store results.
Let’s solve the second problem (hashing) first. Hashing arbitrary structs in C++ is quite hard (not impossible, though). But what you can always do (for structs whose intent is to just store elements) is to supply a function that makes them tuples (i.e., provide a type conversion operator). For our problem, we identify the most important types that we will ever need to hash:
- Native integral types (including characters)
- Sequence types (vectors, sets, maps)
- Tuple types (tuples, pairs)
The last two are uniformly recursive on their inputs (except for vector<bool>, but our implementation doesn’t need to treat it separately), so we can do recursive metaprogramming to be able to hash such types.
All in all, we need a hashing function that hashes the “base case” and a hash combination function that aggregates over the sequence/tuple types.
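As an example of the conversion trick mentioned above, here is a hypothetical State struct that converts to a std::tuple, with hashing and equality both routed through that conversion. TupleHash is a self-contained stand-in: it applies std::hash to each element and combines the results the same way as hash_combine does in the template, so the sketch compiles on its own. With the template, you would hand the tuple to hashing::custom_hash instead.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <tuple>
#include <unordered_set>

using StateTuple = std::tuple<int, int, bool>;

// Hypothetical DP state that is "just a bag of fields".
struct State {
    int pos;
    int mask;
    bool tight;
    // The conversion operator: a State can now be viewed as a tuple.
    operator StateTuple() const { return {pos, mask, tight}; }
};

// Stand-in tuple hash: std::hash per element + boost-style combine.
struct TupleHash {
    template <class... T>
    std::size_t operator()(const std::tuple<T...>& t) const {
        std::size_t seed = 0;
        std::apply([&seed](const T&... xs) {
            ((seed ^= std::hash<T>{}(xs) + 0x9e3779b97f4a7c15ULL + (seed << 12) + (seed >> 4)), ...);
        }, t);
        return seed;
    }
};

// Hashing and equality for State both go through the tuple conversion.
struct StateHash {
    std::size_t operator()(const State& s) const { return TupleHash{}(static_cast<StateTuple>(s)); }
};
struct StateEq {
    bool operator()(const State& a, const State& b) const {
        return static_cast<StateTuple>(a) == static_cast<StateTuple>(b);
    }
};

// Illustrative: counts distinct states using the tuple-based hash and equality.
inline std::size_t count_distinct(std::initializer_list<State> states) {
    std::unordered_set<State, StateHash, StateEq> seen(states.begin(), states.end());
    return seen.size();
}
```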
The implementation becomes something like this:
```cpp
#include <bits/stdc++.h>  // GCC catch-all header (chrono, random, tuple, type_traits, ...)
#if USE_AES
#include <immintrin.h>  // _mm_aesenc_si128; AES must be enabled, e.g. with -maes
#endif

namespace hashing {
using i64 = std::int64_t;
using u64 = std::uint64_t;
static const u64 FIXED_RANDOM = std::chrono::steady_clock::now().time_since_epoch().count();

#if USE_AES
std::mt19937 rd(FIXED_RANDOM);
const __m128i KEY1{(i64)rd(), (i64)rd()};
const __m128i KEY2{(i64)rd(), (i64)rd()};
#endif

template <class T, class D = void>
struct custom_hash {};

// https://www.boost.org/doc/libs/1_55_0/doc/html/hash/combine.html
template <class T>
inline void hash_combine(u64& seed, const T& v) {
    custom_hash<T> hasher;
    seed ^= hasher(v) + 0x9e3779b97f4a7c15 + (seed << 12) + (seed >> 4);
}

// http://xorshift.di.unimi.it/splitmix64.c
template <class T>
struct custom_hash<T, typename std::enable_if<std::is_integral<T>::value>::type> {
    u64 operator()(T _x) const {
        u64 x = _x;
#if USE_AES
        // implementation defined till C++17, defined from C++20
        __m128i m{i64(u64(x) * 0xbf58476d1ce4e5b9ULL), (i64)FIXED_RANDOM};
        __m128i y = _mm_aesenc_si128(m, KEY1);
        __m128i z = _mm_aesenc_si128(y, KEY2);
        return z[0];
#else
        x += 0x9e3779b97f4a7c15 + FIXED_RANDOM;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
#endif
    }
};

// sequence types: anything that std::begin accepts
template <class T>
struct custom_hash<T, std::void_t<decltype(std::begin(std::declval<T>()))>> {
    u64 operator()(const T& a) const {
        u64 value = FIXED_RANDOM;
        // const auto& (not auto&) so that vector<bool>'s by-value elements bind too
        for (const auto& x : a) hash_combine(value, x);
        return value;
    }
};

// tuple types
template <class... T>
struct custom_hash<std::tuple<T...>> {
    u64 operator()(const std::tuple<T...>& a) const {
        u64 value = FIXED_RANDOM;
        std::apply([&value](T const&... args) { (hash_combine(value, args), ...); }, a);
        return value;
    }
};

template <class T, class U>
struct custom_hash<std::pair<T, U>> {
    u64 operator()(const std::pair<T, U>& a) const {
        u64 value = FIXED_RANDOM;
        hash_combine(value, a.first);
        hash_combine(value, a.second);
        return value;
    }
};
}  // namespace hashing
```
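To try the generalized hashing on its own, here is a sketch that copies the non-AES path of the namespace above, so the block compiles by itself. It plugs custom_hash into std::unordered_map with key types that std::hash does not support out of the box. Grid, Key3 and demo_custom_hash are illustrative names.

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>
#include <iterator>
#include <map>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

// Non-AES copy of the hashing namespace, so this sketch compiles by itself.
namespace hashing {
using u64 = std::uint64_t;
static const u64 FIXED_RANDOM = std::chrono::steady_clock::now().time_since_epoch().count();

template <class T, class D = void>
struct custom_hash {};

// Boost-style combination of one more element into the running seed.
template <class T>
inline void hash_combine(u64& seed, const T& v) {
    custom_hash<T> hasher;
    seed ^= hasher(v) + 0x9e3779b97f4a7c15 + (seed << 12) + (seed >> 4);
}

// Base case: integral types, mixed with splitmix64.
template <class T>
struct custom_hash<T, std::enable_if_t<std::is_integral_v<T>>> {
    u64 operator()(T x_) const {
        u64 x = x_;
        x += 0x9e3779b97f4a7c15 + FIXED_RANDOM;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
    }
};

// Sequences: anything std::begin accepts (vector, string, set, map, ...).
template <class T>
struct custom_hash<T, std::void_t<decltype(std::begin(std::declval<T>()))>> {
    u64 operator()(const T& a) const {
        u64 value = FIXED_RANDOM;
        for (const auto& x : a) hash_combine(value, x);
        return value;
    }
};

// Tuples and pairs: combine element by element.
template <class... T>
struct custom_hash<std::tuple<T...>> {
    u64 operator()(const std::tuple<T...>& a) const {
        u64 value = FIXED_RANDOM;
        std::apply([&value](const T&... args) { (hash_combine(value, args), ...); }, a);
        return value;
    }
};

template <class T, class U>
struct custom_hash<std::pair<T, U>> {
    u64 operator()(const std::pair<T, U>& a) const {
        u64 value = FIXED_RANDOM;
        hash_combine(value, a.first);
        hash_combine(value, a.second);
        return value;
    }
};
}  // namespace hashing

// Illustrative key types that std::hash does not support out of the box.
using Grid = std::vector<std::pair<int, int>>;
using Key3 = std::tuple<int, std::string, std::vector<int>>;

inline bool demo_custom_hash() {
    std::unordered_map<Grid, int, hashing::custom_hash<Grid>> a;
    a[{{0, 0}, {1, 2}}] = 7;
    std::unordered_map<Key3, int, hashing::custom_hash<Key3>> b;
    b[{1, "dp", {3, 1, 4}}] = 42;
    std::unordered_map<std::map<int, int>, int, hashing::custom_hash<std::map<int, int>>> c;
    c[{{1, 10}, {2, 20}}] = 5;
    // vector<bool> yields its elements by value; `const auto&` above accepts them.
    hashing::custom_hash<std::vector<bool>> hb;
    std::vector<bool> bits{true, false, true};
    return a.at({{0, 0}, {1, 2}}) == 7 && b.at({1, "dp", {3, 1, 4}}) == 42 &&
           c.at({{1, 10}, {2, 20}}) == 5 && hb(bits) == hb(bits);
}
```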
Now for the hash table, we could just use std::unordered_map. However, it often turns out to be slower than the GNU policy-based data structure gp_hash_table, so we also define some helpful aliases.
```cpp
#include "ext/pb_ds/assoc_container.hpp"
#include "ext/pb_ds/tree_policy.hpp"

namespace pbds {
using namespace __gnu_pbds;
#ifdef PB_DS_ASSOC_CNTNR_HPP
template <class Key, class Value, class Hash>
using unordered_map = gp_hash_table<Key, Value, Hash, std::equal_to<Key>, direct_mask_range_hashing<>, linear_probe_fn<>,
                                    hash_standard_resize_policy<hash_exponential_size_policy<>, hash_load_check_resize_trigger<>, true>>;
template <class Key, class Hash>
using unordered_set = pbds::unordered_map<Key, null_type, Hash>;
#endif
#ifdef PB_DS_TREE_POLICY_HPP
template <typename T>
using ordered_set = tree<T, null_type, std::less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// with std::less_equal, two equal elements never compare equivalent, so lookups and
// erase by key do not find them; erase through an iterator instead
template <typename T>
using ordered_multiset = tree<T, null_type, std::less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>;
template <class Key, class Value, class Compare = std::less<Key>>
using ordered_map = tree<Key, Value, Compare, rb_tree_tag, tree_order_statistics_node_update>;
#endif
}  // namespace pbds
```
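For readers who have not used these containers, here is a small sketch written directly against __gnu_pbds (GCC only). It uses the same tree type as the ordered_set alias, plus a gp_hash_table with std::hash<int> as a stand-in for hashing::custom_hash. The pbds_demo namespace and both functions are illustrative.

```cpp
#include <cassert>
#include <functional>
#include <initializer_list>
#include <utility>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

namespace pbds_demo {
using namespace __gnu_pbds;
// Same shape as the ordered_set alias above.
template <typename T>
using ordered_set = tree<T, null_type, std::less<T>, rb_tree_tag, tree_order_statistics_node_update>;

// Returns {k-th smallest element (0-indexed), number of elements strictly less than x}.
inline std::pair<int, int> order_stats(int k, int x) {
    ordered_set<int> s;
    for (int v : {5, 1, 9, 3, 7}) s.insert(v);
    return {*s.find_by_order(k), static_cast<int>(s.order_of_key(x))};
}

// gp_hash_table supports map-like insert and lookup.
inline int gp_lookup() {
    gp_hash_table<int, int, std::hash<int>> table;
    table[42] = 7;
    auto it = table.find(42);
    return it == table.end() ? -1 : it->second;
}
}  // namespace pbds_demo
```

For the set {1, 3, 5, 7, 9}, order_stats(2, 6) gives 5 as the element at index 2, and 3 elements (1, 3, 5) smaller than 6.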
Now the only thing that remains is to actually decide the design of the wrapper.
We make the following choices:
- As mentioned above, we won’t support reference semantics for now.
- We can choose to use either lambdas or plain old functions. Functions can’t “capture” any state, and thus you need to make things global for using them with functions (or pass them as references, which we already disallowed). Hence, we will support lambdas and not functions for the sake of generality.
- We will choose to implement the wrapper as a struct that stores the lambda as well as the cache (a hash table mapping tuples of input arguments to return values) inside it, and has its own operator() (call operator) that looks into the cache and returns the answer if found, else calls the lambda recursively and updates the cache (both recursively and for the current call).
This is done as follows:
```cpp
template <typename Signature, typename Lambda>
struct Cache;

template <typename ReturnType, typename... Args, typename Lambda>
struct Cache<ReturnType(Args...), Lambda> {
    template <typename... DummyArgs>
    ReturnType operator()(DummyArgs&&... args) {
        // Copy the arguments into the key before calling f: forwarding rvalues into f
        // and keying on std::tie(args...) would store moved-from values in the cache.
        TiedArgs key(args...);
        auto it = memo.find(key);
        if (it == memo.end()) {
            // Pass args as lvalues so f's by-value parameters receive their own copies.
            ReturnType ans = f(*this, args...);
            memo[key] = ans;
            return ans;
        } else {
            return it->second;
        }
    }
    template <class _Lambda>
    Cache(std::tuple<>, _Lambda&& _f) : f(std::forward<_Lambda>(_f)) {}
    Lambda f;
    using TiedArgs = std::tuple<std::decay_t<Args>...>;
    pbds::unordered_map<TiedArgs, ReturnType, hashing::custom_hash<TiedArgs>> memo;
};

template <class Signature, class Lambda>
auto use_cache(Lambda&& f) {
    return Cache<Signature, Lambda>(std::tuple{}, std::forward<Lambda>(f));
}
```
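The text does not explain the std::tuple<> first parameter of the constructor. My reading, which is an assumption since the post does not say, is that it keeps the forwarding constructor from competing with the copy constructor. When copying a non-const lvalue, an unconstrained template<class L> Cache(L&&) is a better match than Cache(const Cache&). The toy types Greedy and Tagged below are hypothetical and only demonstrate that overload-resolution rule.

```cpp
#include <cassert>
#include <tuple>
#include <utility>

// Hypothetical: a wrapper with an unconstrained forwarding constructor.
struct Greedy {
    bool from_template = false;
    Greedy() = default;
    Greedy(const Greedy&) = default;
    template <class L>
    Greedy(L&&) : from_template(true) {}  // also matches Greedy& (a non-const lvalue)
};

// Hypothetical: same wrapper, but the forwarding constructor needs a std::tuple<> tag.
struct Tagged {
    bool from_template = false;
    Tagged() = default;
    Tagged(const Tagged&) = default;
    template <class L>
    Tagged(std::tuple<>, L&&) : from_template(true) {}
};

// Copies a non-const lvalue of each type and reports whether the template constructor ran.
inline std::pair<bool, bool> which_ctor() {
    Greedy g;
    Greedy g2 = g;  // template wins: L = Greedy& binds without adding const
    Tagged t;
    Tagged t2 = t;  // template is not viable with one argument, so the copy constructor runs
    return {g2.from_template, t2.from_template};
}
```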
Usage
The usage is very simple.
Here’s the whole template for reference:
```cpp
#include <bits/stdc++.h>  // GCC catch-all header (chrono, random, tuple, type_traits, ...)
#if USE_AES
#include <immintrin.h>  // _mm_aesenc_si128; AES must be enabled, e.g. with -maes
#endif

namespace hashing {
using i64 = std::int64_t;
using u64 = std::uint64_t;
static const u64 FIXED_RANDOM = std::chrono::steady_clock::now().time_since_epoch().count();

#if USE_AES
std::mt19937 rd(FIXED_RANDOM);
const __m128i KEY1{(i64)rd(), (i64)rd()};
const __m128i KEY2{(i64)rd(), (i64)rd()};
#endif

template <class T, class D = void>
struct custom_hash {};

// https://www.boost.org/doc/libs/1_55_0/doc/html/hash/combine.html
template <class T>
inline void hash_combine(u64& seed, const T& v) {
    custom_hash<T> hasher;
    seed ^= hasher(v) + 0x9e3779b97f4a7c15 + (seed << 12) + (seed >> 4);
}

// http://xorshift.di.unimi.it/splitmix64.c
template <class T>
struct custom_hash<T, typename std::enable_if<std::is_integral<T>::value>::type> {
    u64 operator()(T _x) const {
        u64 x = _x;
#if USE_AES
        // implementation defined till C++17, defined from C++20
        __m128i m{i64(u64(x) * 0xbf58476d1ce4e5b9ULL), (i64)FIXED_RANDOM};
        __m128i y = _mm_aesenc_si128(m, KEY1);
        __m128i z = _mm_aesenc_si128(y, KEY2);
        return z[0];
#else
        x += 0x9e3779b97f4a7c15 + FIXED_RANDOM;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
#endif
    }
};

// sequence types: anything that std::begin accepts
template <class T>
struct custom_hash<T, std::void_t<decltype(std::begin(std::declval<T>()))>> {
    u64 operator()(const T& a) const {
        u64 value = FIXED_RANDOM;
        // const auto& (not auto&) so that vector<bool>'s by-value elements bind too
        for (const auto& x : a) hash_combine(value, x);
        return value;
    }
};

// tuple types
template <class... T>
struct custom_hash<std::tuple<T...>> {
    u64 operator()(const std::tuple<T...>& a) const {
        u64 value = FIXED_RANDOM;
        std::apply([&value](T const&... args) { (hash_combine(value, args), ...); }, a);
        return value;
    }
};

template <class T, class U>
struct custom_hash<std::pair<T, U>> {
    u64 operator()(const std::pair<T, U>& a) const {
        u64 value = FIXED_RANDOM;
        hash_combine(value, a.first);
        hash_combine(value, a.second);
        return value;
    }
};
}  // namespace hashing

#include "ext/pb_ds/assoc_container.hpp"
#include "ext/pb_ds/tree_policy.hpp"

namespace pbds {
using namespace __gnu_pbds;
#ifdef PB_DS_ASSOC_CNTNR_HPP
template <class Key, class Value, class Hash>
using unordered_map = gp_hash_table<Key, Value, Hash, std::equal_to<Key>, direct_mask_range_hashing<>, linear_probe_fn<>,
                                    hash_standard_resize_policy<hash_exponential_size_policy<>, hash_load_check_resize_trigger<>, true>>;
template <class Key, class Hash>
using unordered_set = pbds::unordered_map<Key, null_type, Hash>;
#endif
#ifdef PB_DS_TREE_POLICY_HPP
template <typename T>
using ordered_set = tree<T, null_type, std::less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// with std::less_equal, two equal elements never compare equivalent, so lookups and
// erase by key do not find them; erase through an iterator instead
template <typename T>
using ordered_multiset = tree<T, null_type, std::less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>;
template <class Key, class Value, class Compare = std::less<Key>>
using ordered_map = tree<Key, Value, Compare, rb_tree_tag, tree_order_statistics_node_update>;
#endif
}  // namespace pbds

template <typename Signature, typename Lambda>
struct Cache;

template <typename ReturnType, typename... Args, typename Lambda>
struct Cache<ReturnType(Args...), Lambda> {
    template <typename... DummyArgs>
    ReturnType operator()(DummyArgs&&... args) {
        // Copy the arguments into the key before calling f: forwarding rvalues into f
        // and keying on std::tie(args...) would store moved-from values in the cache.
        TiedArgs key(args...);
        auto it = memo.find(key);
        if (it == memo.end()) {
            // Pass args as lvalues so f's by-value parameters receive their own copies.
            ReturnType ans = f(*this, args...);
            memo[key] = ans;
            return ans;
        } else {
            return it->second;
        }
    }
    template <class _Lambda>
    Cache(std::tuple<>, _Lambda&& _f) : f(std::forward<_Lambda>(_f)) {}
    Lambda f;
    using TiedArgs = std::tuple<std::decay_t<Args>...>;
    pbds::unordered_map<TiedArgs, ReturnType, hashing::custom_hash<TiedArgs>> memo;
};

template <class Signature, class Lambda>
auto use_cache(Lambda&& f) {
    return Cache<Signature, Lambda>(std::tuple{}, std::forward<Lambda>(f));
}
```
Let’s first try to replicate the Python example from above:
```cpp
auto factorial = use_cache<int(int)>([](auto&& self, int n) -> int {
    if (n) return n * self(n - 1);
    else return 1;
});
std::cout << factorial(10) << '\n';
```
You can do more complicated things with this template too. Let’s say you want a recursive function that takes a char and a pair<int, string> as its DP state and returns a bool. It also needs values from some array you have read as input, so you need to capture that array as well.
You write it like this:
```cpp
vector<int> v;
// read v
auto solve = use_cache<bool(char, pair<int, string>)>([&](auto&& self, char c, pair<int, string> p) -> bool {
    // note that v was captured by &
    // use c, p, v as needed
    // let's say you have c1 and p1 as arguments for a recursive call
    auto x = self(c1, p1);  // recursive call - do as many of them as needed
    // use c, p, v as needed
    return y;  // return something
});
std::cout << solve('a', pair{1, "hello"s}) << '\n';
```
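Here is a concrete, made-up instance of the skeleton above, so that the bool(char, pair<int, string>) shape with a captured container can be run end to end. To keep it short and self-contained, MapCache and use_map_cache are hypothetical stand-ins for the post’s Cache and use_cache. They keep the same self-passing interface but store results in a std::map instead of gp_hash_table with custom_hash, and the key is copied before the lambda runs. The problem itself is only an illustration: split a string into dictionary words so that consecutive words start with different letters.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-in for Cache: same interface, backed by std::map.
template <typename Signature, typename Lambda>
struct MapCache;

template <typename ReturnType, typename... Args, typename Lambda>
struct MapCache<ReturnType(Args...), Lambda> {
    using Key = std::tuple<std::decay_t<Args>...>;
    Lambda f;
    std::map<Key, ReturnType> memo;

    template <typename... CallArgs>
    ReturnType operator()(CallArgs&&... args) {
        Key key(args...);  // copy the arguments before the lambda can touch them
        if (auto it = memo.find(key); it != memo.end()) return it->second;
        ReturnType ans = f(*this, args...);  // lvalues: the lambda gets its own copies
        memo.emplace(std::move(key), ans);
        return ans;
    }
};

template <class Signature, class Lambda>
auto use_map_cache(Lambda&& f) {
    return MapCache<Signature, std::decay_t<Lambda>>{std::forward<Lambda>(f), {}};
}

// Illustrative problem: can `text` be split into (non-empty) words from `dict` so that
// no two consecutive words start with the same letter? `prev` is the first letter of
// the previous word ('#' at the start).
inline bool can_split(const std::vector<std::string>& dict, const std::string& text) {
    auto solve = use_map_cache<bool(char, std::pair<int, std::string>)>(
        [&](auto&& self, char prev, std::pair<int, std::string> p) -> bool {
            auto [i, s] = p;
            if (i == static_cast<int>(s.size())) return true;
            for (const std::string& w : dict) {  // dict is captured by reference
                if (w[0] == prev || s.compare(i, w.size(), w) != 0) continue;
                if (self(w[0], std::pair{i + static_cast<int>(w.size()), s})) return true;
            }
            return false;
        });
    return solve('#', std::pair{0, text});
}
```

For example, can_split({"a", "ab", "b"}, "abab") is true (a + b + a + b), while can_split({"ab"}, "abab") is false because both pieces would start with 'a'.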
Caveats
The most important caveat is that since it uses a hash table under the hood, it will often be slower than just using multidimensional arrays. You can adapt the template to do things like that, but I didn’t, for the sake of clarity and ease of implementation.
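As a sketch of what such an adaptation could look like, assume a single int argument with a known range [0, n). The hypothetical ArrayCache keeps the same self-passing interface but stores results in a std::vector<std::optional<R>>, so a lookup is a plain index instead of a hash computation. Supporting several dimensions would mean flattening the indices, which is left out here.

```cpp
#include <cassert>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical array-backed variant: memoizes R(int) for arguments in [0, n).
template <class R, class Lambda>
struct ArrayCache {
    Lambda f;
    std::vector<std::optional<R>> memo;  // memo[i] stays empty until i is computed

    R operator()(int i) {
        if (memo[i]) return *memo[i];  // cache hit: plain indexing, no hashing
        R ans = f(*this, i);           // recursive calls go back through *this
        memo[i] = ans;
        return ans;
    }
};

template <class R, class Lambda>
auto use_array_cache(int n, Lambda&& f) {
    return ArrayCache<R, std::decay_t<Lambda>>{std::forward<Lambda>(f), std::vector<std::optional<R>>(n)};
}

// Illustrative: ways to climb n stairs taking 1 or 2 steps at a time.
inline long long stairs(int n) {
    auto ways = use_array_cache<long long>(n + 1, [](auto&& self, int i) -> long long {
        return i <= 1 ? 1 : self(i - 1) + self(i - 2);
    });
    return ways(n);
}
```

For example, stairs(10) returns 89, and stairs(50) finishes quickly because each state is computed only once.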
Since it doesn’t make a lot of sense to store function calls with arguments that can change over time, the template doesn’t support lambdas that take things by reference either. This is done in order to avoid any possible bugs that might come with references being modified on the fly somewhere in your code.
However, if you really want to use references for some reason (for example, const references to large structs/vectors instead of copies, for performance), you can do the following:
- Wrap your arguments in std::reference_wrapper (for example, rather than int&, use std::reference_wrapper<int>; std::ref and std::cref create such wrappers). This makes a reference behave like a value (for example, it is copyable).
- Provide a hashing implementation for std::reference_wrapper<T>.
- Avoid reference invalidation, i.e., the hash table should store values and not references, because references to temporary variables get invalidated during the run of the algorithm.
- Handle any other issues that might crop up on your own.
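Here is a minimal sketch of the first three points. It is built around a hypothetical memoized function that receives a large std::string as std::reference_wrapper<const std::string> (created with std::cref) instead of as const std::string&. ref_hash and DistinctChars are illustrative names. The memo is keyed by std::string values, so a key never refers to a string that has gone out of scope. ref_hash is only needed if the wrapper itself becomes part of a hashed key.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <type_traits>
#include <unordered_map>

// Point 2 (hypothetical helper): hash a reference_wrapper by the object it refers to.
template <class T>
struct ref_hash {
    std::size_t operator()(std::reference_wrapper<T> r) const {
        return std::hash<std::remove_const_t<T>>{}(r.get());
    }
};

// Hypothetical memoized function: number of distinct characters in s.
struct DistinctChars {
    // Point 3: keys are std::string values, so they cannot dangle.
    std::unordered_map<std::string, int> memo;
    int calls = 0;  // how many times the body actually ran

    // Point 1: the parameter is a copyable handle to the caller's string, not a copy of it.
    int operator()(std::reference_wrapper<const std::string> s) {
        auto it = memo.find(s.get());  // lookup goes through the reference, no copy
        if (it != memo.end()) return it->second;
        ++calls;
        bool seen[256] = {};
        int ans = 0;
        for (unsigned char ch : s.get()) {
            if (!seen[ch]) {
                seen[ch] = true;
                ++ans;
            }
        }
        memo.emplace(s.get(), ans);  // the string is copied only when it becomes a key
        return ans;
    }
};

// Calls the function twice on a large string; the second call should be a cache hit.
inline bool demo_ref() {
    DistinctChars f;
    std::string big(100000, 'x');
    big[5] = 'y';
    int first = f(std::cref(big));
    int second = f(std::cref(big));
    bool same_hash = ref_hash<const std::string>{}(std::cref(big)) == std::hash<std::string>{}(big);
    return first == 2 && second == 2 && f.calls == 1 && same_hash;
}
```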
Do let me know if you know of a better way of doing these things, and if there are any errors that might have crept into this post!