This post was originally written on Codeforces; relevant discussion can be found here.
TL;DR
Use the following template (C++20) for efficient and near-optimal binary search (in terms of number of queries) on floating point numbers.
Template
#include <bit>
#include <climits>
#include <cmath>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <numeric>
#include <type_traits>

// smallest unsigned integer type with at least N_BITS bits
// (__uint128_t is a GCC/Clang extension; void if nothing fits)
template <std::size_t N_BITS>
using int_least_t = std::conditional_t<
    N_BITS <= 8, std::uint8_t,
    std::conditional_t<
        N_BITS <= 16, std::uint16_t,
        std::conditional_t<
            N_BITS <= 32, std::uint32_t,
            std::conditional_t<
                N_BITS <= 64, std::uint64_t,
                std::conditional_t<N_BITS <= 128, __uint128_t, void>>>>>;

// This works for float and double. It is not reliable for long double: on x86 the
// 80-bit extended format is padded to 12 or 16 bytes, so std::bit_cast either fails
// to compile (no integer of matching size) or reads padding bits, and the sign bit
// is not the top bit of the integer. Consider using double instead in that case.

/*
 * Returns the first x in [a, b] such that predicate(x) is false, assuming
 * logical_predicate(a) && !logical_predicate(b) && logical_predicate(-inf) &&
 * !logical_predicate(inf).
 * Here logical_predicate is the mathematical value of the predicate, not the
 * machine value of the predicate.
 * Only non-NaN, non-inf inputs are passed to the predicate (NaN arguments
 * require deal_with_nans_and_infs = true).
 * If NaNs or infinities are passed to this function as arguments, then the
 * inputs to the predicate will start from the smallest/largest representable
 * finite floating point numbers of the input type; this can be a source of
 * errors if, for example, the predicate multiplies its input by something > 1.
 * Strictly speaking, the predicate should also be perfectly monotonic, but if
 * it gives out-of-order booleans in some small range [a, a + eps] (and the
 * correct order elsewhere), then the answer will be somewhere in that range.
 * The same holds for how denormals are handled by this code.
 */
template <bool check_infinities = false,
          bool distinguish_plus_minus_zero = false,
          bool deal_with_nans_and_infs = false, std::floating_point T>
T partition_point_fp(T a, T b, auto&& predicate) {
    static constexpr std::size_t T_WIDTH = sizeof(T) * CHAR_BIT;
    using Int = int_least_t<T_WIDTH>;
    // the sign bit is the most significant bit of the representation
    static constexpr auto is_negative = [](T x) {
        return static_cast<bool>((std::bit_cast<Int>(x) >> (T_WIDTH - 1)) & 1);
    };
    if constexpr (distinguish_plus_minus_zero) {
        if (a == T(0.0) && b == T(0.0) && is_negative(a) && !is_negative(b)) {
            if (!predicate(-T(0.0))) {
                return -T(0.0);
            } else {
                // predicate(+0.0) is assumed to be false because b = +0.0
                return T(0.0);
            }
        }
    }
    // empty or invalid range (comparisons with NaN are false, so NaNs pass this check)
    if (a >= b) return std::numeric_limits<T>::quiet_NaN();
    if constexpr (deal_with_nans_and_infs) {
        // get rid of NaNs as soon as possible
        if (std::isnan(a)) a = -std::numeric_limits<T>::infinity();
        if (std::isnan(b)) b = std::numeric_limits<T>::infinity();
        // deal with infinities
        if (a == -std::numeric_limits<T>::infinity()) {
            if constexpr (check_infinities) {
                if (predicate(-std::numeric_limits<T>::max())) {
                    a = -std::numeric_limits<T>::max();
                } else {
                    return -std::numeric_limits<T>::max();
                }
            } else {
                a = -std::numeric_limits<T>::max();
            }
        }
        if (b == std::numeric_limits<T>::infinity()) {
            if constexpr (check_infinities) {
                if (!predicate(std::numeric_limits<T>::max())) {
                    b = std::numeric_limits<T>::max();
                } else {
                    return std::numeric_limits<T>::infinity();
                }
            } else {
                b = std::numeric_limits<T>::max();
            }
        }
    }
    // deal with differently signed a and b
    if (is_negative(a) && !is_negative(b)) {
        // check 0 once
        if constexpr (distinguish_plus_minus_zero) {
            if (!predicate(-T(0.0))) {
                b = -T(0.0);
            } else if (predicate(T(0.0))) {
                a = T(0.0);
            } else {
                return T(0.0);
            }
        } else {
            if (!predicate(T(0.0))) {
                b = -T(0.0);
            } else {
                a = T(0.0);
            }
        }
    }
    // the range may have collapsed to a single point (e.g. both ends are 0,
    // or [-inf, -max] after clamping); b is then the answer
    if (a == b) return b;
    // start actual binary search on the bit patterns, which are ordered like
    // the values for non-negative inputs
    auto get_int = [](T x) { return std::bit_cast<Int>(x); };
    auto get_float = [](Int x) { return std::bit_cast<T>(x); };
    if (b > 0) {
        // 0 <= a < b
        while (get_int(a) + 1 < get_int(b)) {
            auto m = std::midpoint(get_int(a), get_int(b));
            if (predicate(get_float(m))) {
                a = get_float(m);
            } else {
                b = get_float(m);
            }
        }
    } else {
        // a < b <= 0: search on the negated range, with the roles of a and b swapped
        while (get_int(-b) + 1 < get_int(-a)) {
            auto m = std::midpoint(get_int(-b), get_int(-a));
            if (predicate(-get_float(m))) {
                a = -get_float(m);
            } else {
                b = -get_float(m);
            }
        }
    }
    return b;
}
It is also possible to extend this to breaking early when a custom closeness predicate is true (for example, min(absolute error, relative error) < 1e-9 and so on), but for the sake of simplicity, this template does not do so.
Introduction
It is well known that writing the condition for continuation of the binary search as anything like while (r - l > 1) or while (l < r) is bug-prone on floating point numbers (for example, once l and r are adjacent representable numbers, (l + r) / 2 rounds to one of them, so the loop can stop making progress without ever terminating), so people tend to write a binary search with a fixed number of iterations instead. However, this is usually not the best method: you need \(O(\log L)\) iterations, where \(L\) is the length of the range, and the range of floating point numbers is very large.
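A small sketch of both points, assuming IEEE 754 doubles with round-to-nearest (the function names are made up for illustration): the first helper checks whether the arithmetic midpoint lands strictly inside \((l, r)\), which fails once the endpoints are adjacent, and the second is the usual fixed-iteration workaround.

```cpp
#include <cassert>
#include <cmath>

// True if the arithmetic midpoint (l + r) / 2 lies strictly inside (l, r).
// For adjacent doubles it rounds to one of the endpoints, so this is false there.
bool midpoint_makes_progress(double l, double r) {
    double m = (l + r) / 2;
    return l < m && m < r;
}

// The common workaround: a hard-coded iteration count instead of a loop
// condition on l and r. Assumes predicate(l) is true and predicate(r) is false.
template <class Pred>
double fixed_iteration_search(double l, double r, Pred&& predicate, int iterations = 100) {
    assert(l < r);
    for (int i = 0; i < iterations; ++i) {
        double m = (l + r) / 2;
        if (predicate(m)) {
            l = m;
        } else {
            r = m;
        }
    }
    return r;
}
```

The iteration count (100 here) has to be large enough for the whole input range, which is exactly the guesswork the method below avoids.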
So, I came up with a way to binary search on (IEEE754, i.e., practically all implementations of) floating point numbers that takes bit_width(floating_type) + O(1) calls to the predicate. I am pretty sure that this method has been explored before (feel free to drop references in the comments if you find them). Regardless of whether this is a novel algorithm or not, I wanted to share it since it teaches you a lot about floating point numbers, and it also has the following advantages:
- No need to hard-code the number of iterations.
- It avoids issues that you get in other implementations (for example, using sqrt, you end up with a ton of cases with 0, negatives and so on).
- It is efficient (sqrt is expensive, and so is dealing with a lot of cases).
Arriving at the algorithm
I started out by thinking about this: floating point numbers have a very large range. Can we do better than \(O(\log L)\), where \(L\) is the length of the range? Recall that floating point numbers have a fixed width that is much smaller than the log of the range they represent, and that for two consecutive representable floating point numbers it is their ratio, not their difference, that is nicely bounded (ignoring the boundaries at 0 and infinity, and NaNs). It is also well known that to reach a given relative error, it is optimal to split at \(\sqrt{lr}\) instead of \(\frac{l + r}{2}\) in binary search.
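A rough worked example of why this matters, with illustrative numbers of my own choosing: the answer may lie anywhere in \([1, 10^{18}]\), the target is relative error \(\varepsilon = 10^{-9}\), and \(k\) is the number of iterations. With \(\frac{l + r}{2}\) the bracket width halves each step, and in the worst case the answer is near 1, so the width must drop to about \(\varepsilon\). With \(\sqrt{lr}\) it is \(\ln(r / l)\) that halves each step, and it must drop to about \(\ln(1 + \varepsilon) \approx \varepsilon\):

\[
\frac{10^{18}}{2^{k}} \le 10^{-9} \iff k \ge 27 \log_2 10 \approx 89.7,
\qquad
\frac{\ln 10^{18}}{2^{k}} \le 10^{-9} \iff k \ge \log_2\!\left(4.14 \times 10^{10}\right) \approx 35.3.
\]

So in this setting the arithmetic midpoint needs 90 iterations and the geometric midpoint 36.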
So the first idea that comes to mind is to use sqrt instead of the midpoint in the usual binary search. However, it has a lot of issues. For example, if \(l = 0\), then you never make progress in the binary search (this is fixable by doing a midpoint search on \([0, 1]\) and a sqrt search on \([1, r]\)). Another issue is how to deal with negatives (again doable by splitting the input range into multiple ranges where sqrt works) and with overflow/underflow of the product \(lr\). And the main issue with this approach is that sqrt is expensive: if the predicate is fast and you need to do a lot of binary searches, most of your computation time will be spent in sqrt.
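To make these issues concrete, here is a hedged sketch of the naive geometric midpoint (hypothetical helper names, IEEE 754 doubles assumed): with \(l = 0\) it returns \(l\) itself, and the product \(lr\) overflows or underflows long before \(\sqrt{lr}\) would.

```cpp
#include <cassert>
#include <cmath>

// Naive geometric midpoint, as in the first idea above.
// l * r can overflow to inf or underflow to 0 even when sqrt(l * r) is representable.
double naive_geometric_mid(double l, double r) {
    assert(0 <= l && l <= r);
    return std::sqrt(l * r);
}

// Splitting the square root avoids the overflow/underflow of l * r,
// but it still returns 0 (no progress) when l == 0.
double safer_geometric_mid(double l, double r) {
    assert(0 <= l && l <= r);
    return std::sqrt(l) * std::sqrt(r);
}
```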
Inspired by these, we try to approximate sqrt in a way that preserves monotonicity. Here comes the main point: the IEEE754 representation of floating point numbers separates the mantissa from the exponent, and the exponent part of \(\sqrt{lr}\) is roughly the mean of those of \(l\) and \(r\). Since the exponent occupies the top few bits (excluding the sign bit), we can just do the following (for positive floating point numbers at least): read the floating point representation as an integer (this can be done in a type-safe manner using std::bit_cast), take the midpoint of these integers, and convert it back to a floating point number. This preserves monotonicity: for non-negative floating point numbers the biased exponent sits above the mantissa and both are stored as unsigned fields, so comparing the bit patterns as unsigned integers gives the same order as comparing the values, and the integer midpoint lies strictly between the endpoints unless they are adjacent. The same thing works when both range endpoints are negative: we simply invert the sign bit (i.e., negate) and swap the roles of the endpoints. The case where the endpoints have opposite signs is also simple (if you disregard the fact that +0 and -0 compare equal but have distinct representations and reciprocals): check the predicate on \(0\) (on both zeroes if it matters for your predicate).
Now if we want to do some more handling (infinities, NaNs, denormals), we can do it however we wish. In the implementation in the TL;DR above, we replace NaNs with the infinity in the corresponding direction, and since the predicate should only be called on finite inputs, we then bring infinite range endpoints down to the largest representable finite floating point number instead (optionally checking the predicate there first, via check_infinities).
Analysis
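In all, there are at most \(w + O(1)\) calls to the predicate, where \(w\) is the number of bits remaining after the longest common prefix of the bit representations of the two endpoints (so \(w\) is at most the bit width of the type, e.g. 64 for double): the bit patterns in between form an integer range of length less than \(2^w\), and each call halves it.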
This implementation is also robust to predicates that are noisy near the boundary (i.e., there is a small range near the boundary where some true values of the predicate come after false values): in this case the algorithm returns something in this range near the boundary.
Note that we did not need to hardcode the number of iterations, nor did we require some carefully-written predicate for loop termination.
Also, since std::bit_cast is practically free compared to std::sqrt, it is much faster in cases where you need to do a lot of binary searches.
As a usage example, this submission uses the usual binary search with a fixed number of iterations, and this submission uses the template above.
Some problems
The following are some problems that use binary search on floating points, and should be solvable using this template. If you encounter any bugs in the template while solving these problems, do let me know in the comments below. Thanks to jeroenodb and PurpleCrayon for problem suggestions.