fix weighted_mean::operator+= and add missing noexcept to operator*= (#311)

This commit is contained in:
Hans Dembinski 2021-03-17 11:53:10 +01:00 committed by GitHub
parent 18b544c6ad
commit f9920135c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 126 additions and 44 deletions

View File

@ -79,10 +79,10 @@ public:
+ sum_of_deltas_squared_2 + n2 (mu2 - mu))^2 + sum_of_deltas_squared_2 + n2 (mu2 - mu))^2
*/ */
const auto mu1 = mean_;
const auto mu2 = rhs.mean_;
const auto n1 = sum_; const auto n1 = sum_;
const auto mu1 = mean_;
const auto n2 = rhs.sum_; const auto n2 = rhs.sum_;
const auto mu2 = rhs.mean_;
sum_ += rhs.sum_; sum_ += rhs.sum_;
mean_ = (n1 * mu1 + n2 * mu2) / sum_; mean_ = (n1 * mu1 + n2 * mu2) / sum_;

View File

@ -8,6 +8,7 @@
#define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP #define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_MEAN_HPP
#include <boost/core/nvp.hpp> #include <boost/core/nvp.hpp>
#include <boost/histogram/detail/square.hpp>
#include <boost/histogram/fwd.hpp> // for weighted_mean<> #include <boost/histogram/fwd.hpp> // for weighted_mean<>
#include <boost/histogram/weight.hpp> #include <boost/histogram/weight.hpp>
#include <cassert> #include <cassert>
@ -31,7 +32,7 @@ public:
weighted_mean() = default; weighted_mean() = default;
/// Allow implicit conversion from other weighted_means /// Allow implicit conversion from other weighted_means.
template <class T> template <class T>
weighted_mean(const weighted_mean<T>& o) weighted_mean(const weighted_mean<T>& o)
: sum_of_weights_{o.sum_of_weights_} : sum_of_weights_{o.sum_of_weights_}
@ -39,7 +40,7 @@ public:
, weighted_mean_{o.weighted_mean_} , weighted_mean_{o.weighted_mean_}
, sum_of_weighted_deltas_squared_{o.sum_of_weighted_deltas_squared_} {} , sum_of_weighted_deltas_squared_{o.sum_of_weighted_deltas_squared_} {}
/// Initialize to external sum of weights, sum of weights squared, mean, and variance /// Initialize to external sum of weights, sum of weights squared, mean, and variance.
weighted_mean(const_reference wsum, const_reference wsum2, const_reference mean, weighted_mean(const_reference wsum, const_reference wsum2, const_reference mean,
const_reference variance) const_reference variance)
: sum_of_weights_(wsum) : sum_of_weights_(wsum)
@ -48,10 +49,10 @@ public:
, sum_of_weighted_deltas_squared_( , sum_of_weighted_deltas_squared_(
variance * (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_)) {} variance * (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_)) {}
/// Insert sample x /// Insert sample x.
void operator()(const_reference x) { operator()(weight(1), x); } void operator()(const_reference x) { operator()(weight(1), x); }
/// Insert sample x with weight w /// Insert sample x with weight w.
void operator()(const weight_type<value_type>& w, const_reference x) { void operator()(const weight_type<value_type>& w, const_reference x) {
sum_of_weights_ += w.value; sum_of_weights_ += w.value;
sum_of_weights_squared_ += w.value * w.value; sum_of_weights_squared_ += w.value * w.value;
@ -60,24 +61,33 @@ public:
sum_of_weighted_deltas_squared_ += w.value * delta * (x - weighted_mean_); sum_of_weighted_deltas_squared_ += w.value * delta * (x - weighted_mean_);
} }
/// Add another weighted_mean /// Add another weighted_mean.
weighted_mean& operator+=(const weighted_mean& rhs) { weighted_mean& operator+=(const weighted_mean& rhs) {
if (sum_of_weights_ != 0 || rhs.sum_of_weights_ != 0) { if (rhs.sum_of_weights_ == 0) return *this;
const auto tmp =
weighted_mean_ * sum_of_weights_ + rhs.weighted_mean_ * rhs.sum_of_weights_; // see mean.hpp for derivation of correct formula
sum_of_weights_ += rhs.sum_of_weights_;
sum_of_weights_squared_ += rhs.sum_of_weights_squared_; const auto n1 = sum_of_weights_;
weighted_mean_ = tmp / sum_of_weights_; const auto mu1 = weighted_mean_;
} const auto n2 = rhs.sum_of_weights_;
const auto mu2 = rhs.weighted_mean_;
sum_of_weights_ += rhs.sum_of_weights_;
sum_of_weights_squared_ += rhs.sum_of_weights_squared_;
weighted_mean_ = (n1 * mu1 + n2 * mu2) / sum_of_weights_;
sum_of_weighted_deltas_squared_ += rhs.sum_of_weighted_deltas_squared_; sum_of_weighted_deltas_squared_ += rhs.sum_of_weighted_deltas_squared_;
sum_of_weighted_deltas_squared_ += n1 * detail::square(weighted_mean_ - mu1);
sum_of_weighted_deltas_squared_ += n2 * detail::square(weighted_mean_ - mu2);
return *this; return *this;
} }
/** Scale by value /** Scale by value.
This acts as if all samples were scaled by the value. This acts as if all samples were scaled by the value.
*/ */
weighted_mean& operator*=(const_reference s) { weighted_mean& operator*=(const_reference s) noexcept {
weighted_mean_ *= s; weighted_mean_ *= s;
sum_of_weighted_deltas_squared_ *= s * s; sum_of_weighted_deltas_squared_ *= s * s;
return *this; return *this;
@ -92,10 +102,10 @@ public:
bool operator!=(const weighted_mean& rhs) const noexcept { return !operator==(rhs); } bool operator!=(const weighted_mean& rhs) const noexcept { return !operator==(rhs); }
/// Return sum of weights /// Return sum of weights.
const_reference sum_of_weights() const noexcept { return sum_of_weights_; } const_reference sum_of_weights() const noexcept { return sum_of_weights_; }
/// Return sum of weights squared (variance of weight distribution) /// Return sum of weights squared (variance of weight distribution).
const_reference sum_of_weights_squared() const noexcept { const_reference sum_of_weights_squared() const noexcept {
return sum_of_weights_squared_; return sum_of_weights_squared_;
} }
@ -106,12 +116,13 @@ public:
*/ */
const_reference value() const noexcept { return weighted_mean_; } const_reference value() const noexcept { return weighted_mean_; }
/** Return variance of accumulated weighted samples /** Return variance of accumulated weighted samples.
The result is undefined, if `sum_of_weights() == 0` or The result is undefined, if `sum_of_weights() == 0` or
`sum_of_weights() == sum_of_weights_squared()`. `sum_of_weights() == sum_of_weights_squared()`.
*/ */
value_type variance() const { value_type variance() const {
// see https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
return sum_of_weighted_deltas_squared_ / return sum_of_weighted_deltas_squared_ /
(sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_); (sum_of_weights_ - sum_of_weights_squared_ / sum_of_weights_);
} }

View File

@ -18,33 +18,95 @@ using namespace std::literals;
int main() { int main() {
using m_t = accumulators::weighted_mean<double>; using m_t = accumulators::weighted_mean<double>;
m_t a; using detail::square;
BOOST_TEST_EQ(a.sum_of_weights(), 0);
BOOST_TEST_EQ(a, m_t{});
a(weight(0.5), 1); // basic interface, string conversion
a(weight(1.0), 2); {
a(weight(0.5), 3); // see https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
BOOST_TEST_EQ(a.sum_of_weights(), 2); m_t a;
BOOST_TEST_EQ(a.sum_of_weights_squared(), 1.5); BOOST_TEST_EQ(a.sum_of_weights(), 0);
BOOST_TEST_EQ(a.value(), 2); BOOST_TEST_EQ(a, m_t{});
BOOST_TEST_IS_CLOSE(a.variance(), 0.8, 1e-3);
BOOST_TEST_EQ(str(a), "weighted_mean(2, 2, 0.8)"s); a(weight(0.5), 1);
BOOST_TEST_EQ(str(a, 25, false), " weighted_mean(2, 2, 0.8)"s); a(weight(1.0), 2);
BOOST_TEST_EQ(str(a, 25, true), "weighted_mean(2, 2, 0.8) "s); a(weight(0.5), 3);
auto b = a; BOOST_TEST_EQ(a.sum_of_weights(), 1 + 2 * 0.5);
b += a; // same as feeding all samples twice BOOST_TEST_EQ(a.sum_of_weights_squared(), 1 + 2 * 0.5 * 0.5);
const auto m = a.value();
BOOST_TEST_IS_CLOSE(
a.variance(),
(0.5 * square(1 - m) + square(2 - m) + 0.5 * square(3 - m)) /
(a.sum_of_weights() - a.sum_of_weights_squared() / a.sum_of_weights()),
1e-3);
BOOST_TEST_EQ(b.sum_of_weights(), 4); BOOST_TEST_EQ(str(a), "weighted_mean(2, 2, 0.8)"s);
BOOST_TEST_EQ(b.value(), 2); BOOST_TEST_EQ(str(a, 25, false), " weighted_mean(2, 2, 0.8)"s);
BOOST_TEST_IS_CLOSE(b.variance(), 0.615, 1e-3); BOOST_TEST_EQ(str(a, 25, true), "weighted_mean(2, 2, 0.8) "s);
}
BOOST_TEST_EQ(m_t() += m_t(), m_t()); // addition of zero element
BOOST_TEST_EQ(m_t(1, 2, 3, 4) += m_t(), m_t(1, 2, 3, 4)); {
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3, 4), m_t(1, 2, 3, 4)); BOOST_TEST_EQ(m_t() += m_t(), m_t());
BOOST_TEST_EQ(m_t(1, 2, 3, 4) += m_t(), m_t(1, 2, 3, 4));
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3, 4), m_t(1, 2, 3, 4));
}
// addition
{
m_t a, b, c;
a(weight(4), 2);
a(weight(3), 3);
BOOST_TEST_EQ(a.sum_of_weights(), 4 + 3);
BOOST_TEST_EQ(a.sum_of_weights_squared(), 4 * 4 + 3 * 3);
BOOST_TEST_EQ(a.value(), (4 * 2 + 3 * 3) / 7.);
BOOST_TEST_IS_CLOSE(a.variance(), 0.5, 1e-3);
b(weight(2), 4);
b(weight(1), 6);
BOOST_TEST_EQ(b.sum_of_weights(), 3);
BOOST_TEST_EQ(b.sum_of_weights_squared(), 1 + 2 * 2);
BOOST_TEST_EQ(b.value(), (2 * 4 + 1 * 6) / (2. + 1.));
BOOST_TEST_IS_CLOSE(b.variance(), 2, 1e-3);
c(weight(4), 2);
c(weight(3), 3);
c(weight(2), 4);
c(weight(1), 6);
auto d = a;
d += b;
BOOST_TEST_EQ(c.sum_of_weights(), d.sum_of_weights());
BOOST_TEST_EQ(c.sum_of_weights_squared(), d.sum_of_weights_squared());
BOOST_TEST_EQ(c.value(), d.value());
BOOST_TEST_IS_CLOSE(c.variance(), d.variance(), 1e-3);
}
// using weights * 2, compared to adding the same weighted samples twice, must
// - give the same sum_of_weights and mean
// - give twice the sum_of_weights_squared
// - give half the effective count
// - give a larger variance (the exact relation is complicated)
{
m_t a, b;
for (int i = 0; i < 2; ++i) {
a(weight(0.5), 1);
a(weight(1.0), 2);
a(weight(0.5), 3);
}
b(weight(1), 1);
b(weight(2), 2);
b(weight(1), 3);
BOOST_TEST_EQ(a.sum_of_weights(), b.sum_of_weights());
BOOST_TEST_EQ(2 * a.sum_of_weights_squared(), b.sum_of_weights_squared());
BOOST_TEST_EQ(a.value(), b.value());
BOOST_TEST_LT(a.variance(), b.variance());
}
return boost::report_errors(); return boost::report_errors();
} }

View File

@ -232,10 +232,19 @@ int main() {
a.reset(1); a.reset(1);
a[0](/* sample */ 1); a[0](/* sample */ 1);
a[0](weight(2), /* sample */ 2); a[0](weight(2), /* sample */ 2);
a[0] += accumulators::weighted_mean<>(1, 0, 0, 0);
BOOST_TEST_EQ(a[0].sum_of_weights(), 4); accumulators::weighted_mean<double> b;
BOOST_TEST_IS_CLOSE(a[0].value(), 1.25, 1e-3); b(weight(3), 3);
BOOST_TEST_IS_CLOSE(a[0].variance(), 0.242, 1e-3); a[0] += b;
accumulators::weighted_mean<double> c;
c(weight(1), 1);
c(weight(2), 2);
c(weight(3), 3);
BOOST_TEST_EQ(a[0].sum_of_weights(), c.sum_of_weights());
BOOST_TEST_IS_CLOSE(a[0].value(), c.value(), 1e-3);
BOOST_TEST_IS_CLOSE(a[0].variance(), c.variance(), 1e-3);
} }
// exceeding array capacity // exceeding array capacity