Fix accumulator issues (#72)

* add weighted_mean::sum_of_weights_squared
* the sum of two default-constructed mean or weighted_mean accumulators is now equal to a default-constructed object
* thread_safe operators now return a reference to self, just like the other accumulators
This commit is contained in:
Hans Dembinski 2019-10-26 11:12:29 +02:00 committed by GitHub
parent ebabd550a0
commit c8b8c4d502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 28 deletions

View File

@ -33,7 +33,6 @@ public:
// Inserts one unweighted sample x: increases sum_ by one and updates the
// running mean and the accumulated sum of squared deltas in a single pass.
void operator()(const RealType& x) noexcept {
sum_ += static_cast<RealType>(1);
// delta is measured against the mean as it was before this sample
const auto delta = x - mean_;
// NOTE(review): the hunk header (7 lines -> 6) suggests this assert is the
// line the commit removes; confirm against the raw diff.
BOOST_ASSERT(sum_ != 0);
mean_ += delta / sum_;
// the second factor deliberately uses the already-updated mean_
sum_of_deltas_squared_ += delta * (x - mean_);
}
@ -41,17 +40,17 @@ public:
// Inserts one sample x with weight w: increases sum_ by w and updates the
// running mean and the accumulated sum of squared deltas, each scaled by w.
void operator()(const RealType& w, const RealType& x) noexcept {
sum_ += w;
// delta is measured against the mean as it was before this sample
const auto delta = x - mean_;
// NOTE(review): appears to be a line this commit removes; confirm against
// the raw diff.
BOOST_ASSERT(sum_ != 0);
mean_ += w * delta / sum_;
// the second factor deliberately uses the already-updated mean_
sum_of_deltas_squared_ += w * delta * (x - mean_);
}
// Merges another mean accumulator (possibly with a different value type T)
// into this one and returns *this.
// NOTE(review): this span is a rendered diff without +/- markers; it appears
// to show the old unconditional merge followed by its guarded replacement.
// Confirm against the raw diff before reading it as a single function body.
template <class T>
mean& operator+=(const mean<T>& rhs) noexcept {
// presumably the pre-change lines: the combined mean is always divided by
// the new sum_, with only an assert guarding a zero divisor
const auto tmp = mean_ * sum_ + static_cast<RealType>(rhs.mean_ * rhs.sum_);
sum_ += rhs.sum_;
BOOST_ASSERT(sum_ != 0);
mean_ = tmp / sum_;
// presumably the post-change lines: the merge of sum_ and mean_ is skipped
// when both accumulators have a zero sum_, leaving mean_ untouched
if (sum_ != 0 || rhs.sum_ != 0) {
const auto tmp = mean_ * sum_ + static_cast<RealType>(rhs.mean_ * rhs.sum_);
sum_ += rhs.sum_;
mean_ = tmp / sum_;
}
// added unconditionally, outside the guard
sum_of_deltas_squared_ += static_cast<RealType>(rhs.sum_of_deltas_squared_);
return *this;
}
@ -80,8 +79,7 @@ public:
template <class Archive>
void serialize(Archive& ar, unsigned version) {
if (version == 0) {
if (Archive::is_saving::value)
BOOST_THROW_EXCEPTION(std::runtime_error("save must not use version 0"));
// read only
std::size_t sum;
ar& make_nvp("sum", sum);
sum_ = static_cast<RealType>(sum);

View File

@ -31,6 +31,7 @@ namespace accumulators {
template <class T>
class thread_safe : public std::atomic<T> {
public:
using value_type = T;
using super_t = std::atomic<T>;
thread_safe() noexcept : super_t(static_cast<T>(0)) {}
@ -41,14 +42,24 @@ public:
return *this;
}
thread_safe(T arg) : super_t(arg) {}
thread_safe& operator=(T arg) {
thread_safe(value_type arg) : super_t(arg) {}
thread_safe& operator=(value_type arg) {
super_t::store(arg);
return *this;
}
void operator+=(T arg) { super_t::fetch_add(arg, std::memory_order_relaxed); }
void operator++() { operator+=(static_cast<T>(1)); }
// Adds the current value of another thread_safe to this one.
// The value of arg is read with load() (default memory order) and forwarded
// to the value_type overload of operator+=. Returns *this.
thread_safe& operator+=(const thread_safe& arg) {
operator+=(arg.load());
return *this;
}
// Atomically adds arg to the stored value via fetch_add with relaxed memory
// ordering. Returns *this rather than the previous or updated value.
thread_safe& operator+=(value_type arg) {
super_t::fetch_add(arg, std::memory_order_relaxed);
return *this;
}
// Pre-increment: adds one by delegating to operator+=(value_type).
// Returns *this.
thread_safe& operator++() {
operator+=(static_cast<value_type>(1));
return *this;
}
template <class Archive>
void serialize(Archive& ar, unsigned /* version */) {

View File

@ -40,19 +40,19 @@ public:
sum_of_weights_ += w;
sum_of_weights_squared_ += w * w;
const auto delta = x - weighted_mean_;
BOOST_ASSERT(sum_of_weights_ != 0);
weighted_mean_ += w * delta / sum_of_weights_;
sum_of_weighted_deltas_squared_ += w * delta * (x - weighted_mean_);
}
template <typename T>
weighted_mean& operator+=(const weighted_mean<T>& rhs) {
const auto tmp = weighted_mean_ * sum_of_weights_ +
static_cast<RealType>(rhs.weighted_mean_ * rhs.sum_of_weights_);
sum_of_weights_ += static_cast<RealType>(rhs.sum_of_weights_);
sum_of_weights_squared_ += static_cast<RealType>(rhs.sum_of_weights_squared_);
BOOST_ASSERT(sum_of_weights_ != 0);
weighted_mean_ = tmp / sum_of_weights_;
if (sum_of_weights_ != 0 || rhs.sum_of_weights_ != 0) {
const auto tmp = weighted_mean_ * sum_of_weights_ +
static_cast<RealType>(rhs.weighted_mean_ * rhs.sum_of_weights_);
sum_of_weights_ += static_cast<RealType>(rhs.sum_of_weights_);
sum_of_weights_squared_ += static_cast<RealType>(rhs.sum_of_weights_squared_);
weighted_mean_ = tmp / sum_of_weights_;
}
sum_of_weighted_deltas_squared_ +=
static_cast<RealType>(rhs.sum_of_weighted_deltas_squared_);
return *this;
@ -78,6 +78,9 @@ public:
}
// Returns the accumulated sum of weights (sum_of_weights_) by const reference.
const RealType& sum_of_weights() const noexcept { return sum_of_weights_; }
// Returns the accumulated sum of squared weights (sum_of_weights_squared_,
// which grows by w * w per weighted sample) by const reference.
const RealType& sum_of_weights_squared() const noexcept {
return sum_of_weights_squared_;
}
// Returns the current weighted mean (weighted_mean_) by const reference.
const RealType& value() const noexcept { return weighted_mean_; }
RealType variance() const {
return sum_of_weighted_deltas_squared_ /

View File

@ -114,8 +114,8 @@ alias accumulators : [ run boost_accumulators_support_test.cpp ] : <warnings>off
alias range : [ run boost_range_support_test.cpp ] : <warnings>off ;
alias units : [ run boost_units_support_test.cpp ] : <warnings>off ;
alias serialization :
[ run detail_array_wrapper_serialization_test.cpp libserial ]
[ run accumulators_serialization_test.cpp libserial : $(THIS_PATH) ]
[ run detail_array_wrapper_serialization_test.cpp libserial ]
[ run axis_variant_serialization_test.cpp libserial : $(THIS_PATH) ]
[ run histogram_serialization_test.cpp libserial : $(THIS_PATH) ]
[ run storage_adaptor_serialization_test.cpp libserial : $(THIS_PATH) ]

View File

@ -4,7 +4,6 @@
// (See accompanying file LICENSE_1_0.txt
// or copy at http://www.boost.org/LICENSE_1_0.txt)
#include <boost/archive/xml_oarchive.hpp>
#include <boost/assert.hpp>
#include <boost/core/lightweight_test.hpp>
#include <boost/histogram/accumulators.hpp>
@ -25,10 +24,6 @@ int main(int argc, char** argv) {
BOOST_TEST_EQ(a.count(), 3);
BOOST_TEST_EQ(a.value(), 2);
BOOST_TEST_EQ(a.variance(), 0.5);
#ifndef BOOST_NO_EXCEPTIONS
boost::archive::xml_oarchive oa(std::cout);
BOOST_TEST_THROWS(a.serialize(oa, 0), std::runtime_error);
#endif
}
// mean

View File

@ -72,6 +72,8 @@ int main() {
w_t y(1, 2);
BOOST_TEST_NE(y, 1);
BOOST_TEST_EQ(static_cast<double>(y), 1);
BOOST_TEST_EQ(w_t() += w_t(), w_t());
}
{
@ -118,6 +120,10 @@ int main() {
d(2, 16);
BOOST_TEST_EQ(d, c);
BOOST_TEST_EQ(m_t() += m_t(), m_t());
BOOST_TEST_EQ(m_t(1, 2, 3) += m_t(), m_t(1, 2, 3));
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3), m_t(1, 2, 3));
}
{
@ -131,6 +137,7 @@ int main() {
a(0.5, 3);
BOOST_TEST_EQ(a.sum_of_weights(), 2);
BOOST_TEST_EQ(a.sum_of_weights_squared(), 1.5);
BOOST_TEST_EQ(a.value(), 2);
BOOST_TEST_IS_CLOSE(a.variance(), 0.8, 1e-3);
@ -144,6 +151,10 @@ int main() {
BOOST_TEST_EQ(b.sum_of_weights(), 4);
BOOST_TEST_EQ(b.value(), 2);
BOOST_TEST_IS_CLOSE(b.variance(), 0.615, 1e-3);
BOOST_TEST_EQ(m_t() += m_t(), m_t());
BOOST_TEST_EQ(m_t(1, 2, 3, 4) += m_t(), m_t(1, 2, 3, 4));
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3, 4), m_t(1, 2, 3, 4));
}
{
@ -154,7 +165,8 @@ int main() {
bad_sum += -1e100;
BOOST_TEST_EQ(bad_sum, 0); // instead of 2
accumulators::sum<double> sum;
using s_t = accumulators::sum<double>;
s_t sum;
++sum;
BOOST_TEST_EQ(sum.large(), 1);
BOOST_TEST_EQ(sum.small(), 0);
@ -179,10 +191,13 @@ int main() {
BOOST_TEST_GT(a, b);
BOOST_TEST_GE(a, b);
BOOST_TEST_GE(a, c);
BOOST_TEST_EQ(s_t() += s_t(), s_t());
}
{
accumulators::weighted_sum<accumulators::sum<double>> w;
using s_t = accumulators::weighted_sum<accumulators::sum<double>>;
s_t w;
++w;
w += 1e100;
@ -191,15 +206,20 @@ int main() {
BOOST_TEST_EQ(w.value(), 2);
BOOST_TEST_EQ(w.variance(), 2e200);
BOOST_TEST_EQ(s_t() += s_t(), s_t());
}
{
accumulators::thread_safe<int> i;
using ts_t = accumulators::thread_safe<int>;
ts_t i;
++i;
i += 1000;
BOOST_TEST_EQ(i, 1001);
BOOST_TEST_EQ(str(i), "1001"s);
BOOST_TEST_EQ(ts_t() += ts_t(), ts_t());
}
return boost::report_errors();