Fix accumulator issues (#72)

* add weighted_mean::sum_of_weights_squared
* the sum of two default-constructed mean or weighted_mean accumulators is now equal to a default-constructed object
* thread_safe operators now return a reference to self, just like the other accumulators
This commit is contained in:
Hans Dembinski 2019-10-26 11:12:29 +02:00 committed by GitHub
parent ebabd550a0
commit c8b8c4d502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 28 deletions

View File

@ -33,7 +33,6 @@ public:
void operator()(const RealType& x) noexcept { void operator()(const RealType& x) noexcept {
sum_ += static_cast<RealType>(1); sum_ += static_cast<RealType>(1);
const auto delta = x - mean_; const auto delta = x - mean_;
BOOST_ASSERT(sum_ != 0);
mean_ += delta / sum_; mean_ += delta / sum_;
sum_of_deltas_squared_ += delta * (x - mean_); sum_of_deltas_squared_ += delta * (x - mean_);
} }
@ -41,17 +40,17 @@ public:
void operator()(const RealType& w, const RealType& x) noexcept { void operator()(const RealType& w, const RealType& x) noexcept {
sum_ += w; sum_ += w;
const auto delta = x - mean_; const auto delta = x - mean_;
BOOST_ASSERT(sum_ != 0);
mean_ += w * delta / sum_; mean_ += w * delta / sum_;
sum_of_deltas_squared_ += w * delta * (x - mean_); sum_of_deltas_squared_ += w * delta * (x - mean_);
} }
template <class T> template <class T>
mean& operator+=(const mean<T>& rhs) noexcept { mean& operator+=(const mean<T>& rhs) noexcept {
if (sum_ != 0 || rhs.sum_ != 0) {
const auto tmp = mean_ * sum_ + static_cast<RealType>(rhs.mean_ * rhs.sum_); const auto tmp = mean_ * sum_ + static_cast<RealType>(rhs.mean_ * rhs.sum_);
sum_ += rhs.sum_; sum_ += rhs.sum_;
BOOST_ASSERT(sum_ != 0);
mean_ = tmp / sum_; mean_ = tmp / sum_;
}
sum_of_deltas_squared_ += static_cast<RealType>(rhs.sum_of_deltas_squared_); sum_of_deltas_squared_ += static_cast<RealType>(rhs.sum_of_deltas_squared_);
return *this; return *this;
} }
@ -80,8 +79,7 @@ public:
template <class Archive> template <class Archive>
void serialize(Archive& ar, unsigned version) { void serialize(Archive& ar, unsigned version) {
if (version == 0) { if (version == 0) {
if (Archive::is_saving::value) // read only
BOOST_THROW_EXCEPTION(std::runtime_error("save must not use version 0"));
std::size_t sum; std::size_t sum;
ar& make_nvp("sum", sum); ar& make_nvp("sum", sum);
sum_ = static_cast<RealType>(sum); sum_ = static_cast<RealType>(sum);

View File

@ -31,6 +31,7 @@ namespace accumulators {
template <class T> template <class T>
class thread_safe : public std::atomic<T> { class thread_safe : public std::atomic<T> {
public: public:
using value_type = T;
using super_t = std::atomic<T>; using super_t = std::atomic<T>;
thread_safe() noexcept : super_t(static_cast<T>(0)) {} thread_safe() noexcept : super_t(static_cast<T>(0)) {}
@ -41,14 +42,24 @@ public:
return *this; return *this;
} }
thread_safe(T arg) : super_t(arg) {} thread_safe(value_type arg) : super_t(arg) {}
thread_safe& operator=(T arg) { thread_safe& operator=(value_type arg) {
super_t::store(arg); super_t::store(arg);
return *this; return *this;
} }
void operator+=(T arg) { super_t::fetch_add(arg, std::memory_order_relaxed); } thread_safe& operator+=(const thread_safe& arg) {
void operator++() { operator+=(static_cast<T>(1)); } operator+=(arg.load());
return *this;
}
thread_safe& operator+=(value_type arg) {
super_t::fetch_add(arg, std::memory_order_relaxed);
return *this;
}
thread_safe& operator++() {
operator+=(static_cast<value_type>(1));
return *this;
}
template <class Archive> template <class Archive>
void serialize(Archive& ar, unsigned /* version */) { void serialize(Archive& ar, unsigned /* version */) {

View File

@ -40,19 +40,19 @@ public:
sum_of_weights_ += w; sum_of_weights_ += w;
sum_of_weights_squared_ += w * w; sum_of_weights_squared_ += w * w;
const auto delta = x - weighted_mean_; const auto delta = x - weighted_mean_;
BOOST_ASSERT(sum_of_weights_ != 0);
weighted_mean_ += w * delta / sum_of_weights_; weighted_mean_ += w * delta / sum_of_weights_;
sum_of_weighted_deltas_squared_ += w * delta * (x - weighted_mean_); sum_of_weighted_deltas_squared_ += w * delta * (x - weighted_mean_);
} }
template <typename T> template <typename T>
weighted_mean& operator+=(const weighted_mean<T>& rhs) { weighted_mean& operator+=(const weighted_mean<T>& rhs) {
if (sum_of_weights_ != 0 || rhs.sum_of_weights_ != 0) {
const auto tmp = weighted_mean_ * sum_of_weights_ + const auto tmp = weighted_mean_ * sum_of_weights_ +
static_cast<RealType>(rhs.weighted_mean_ * rhs.sum_of_weights_); static_cast<RealType>(rhs.weighted_mean_ * rhs.sum_of_weights_);
sum_of_weights_ += static_cast<RealType>(rhs.sum_of_weights_); sum_of_weights_ += static_cast<RealType>(rhs.sum_of_weights_);
sum_of_weights_squared_ += static_cast<RealType>(rhs.sum_of_weights_squared_); sum_of_weights_squared_ += static_cast<RealType>(rhs.sum_of_weights_squared_);
BOOST_ASSERT(sum_of_weights_ != 0);
weighted_mean_ = tmp / sum_of_weights_; weighted_mean_ = tmp / sum_of_weights_;
}
sum_of_weighted_deltas_squared_ += sum_of_weighted_deltas_squared_ +=
static_cast<RealType>(rhs.sum_of_weighted_deltas_squared_); static_cast<RealType>(rhs.sum_of_weighted_deltas_squared_);
return *this; return *this;
@ -78,6 +78,9 @@ public:
} }
const RealType& sum_of_weights() const noexcept { return sum_of_weights_; } const RealType& sum_of_weights() const noexcept { return sum_of_weights_; }
const RealType& sum_of_weights_squared() const noexcept {
return sum_of_weights_squared_;
}
const RealType& value() const noexcept { return weighted_mean_; } const RealType& value() const noexcept { return weighted_mean_; }
RealType variance() const { RealType variance() const {
return sum_of_weighted_deltas_squared_ / return sum_of_weighted_deltas_squared_ /

View File

@ -114,8 +114,8 @@ alias accumulators : [ run boost_accumulators_support_test.cpp ] : <warnings>off
alias range : [ run boost_range_support_test.cpp ] : <warnings>off ; alias range : [ run boost_range_support_test.cpp ] : <warnings>off ;
alias units : [ run boost_units_support_test.cpp ] : <warnings>off ; alias units : [ run boost_units_support_test.cpp ] : <warnings>off ;
alias serialization : alias serialization :
[ run detail_array_wrapper_serialization_test.cpp libserial ]
[ run accumulators_serialization_test.cpp libserial : $(THIS_PATH) ] [ run accumulators_serialization_test.cpp libserial : $(THIS_PATH) ]
[ run detail_array_wrapper_serialization_test.cpp libserial ]
[ run axis_variant_serialization_test.cpp libserial : $(THIS_PATH) ] [ run axis_variant_serialization_test.cpp libserial : $(THIS_PATH) ]
[ run histogram_serialization_test.cpp libserial : $(THIS_PATH) ] [ run histogram_serialization_test.cpp libserial : $(THIS_PATH) ]
[ run storage_adaptor_serialization_test.cpp libserial : $(THIS_PATH) ] [ run storage_adaptor_serialization_test.cpp libserial : $(THIS_PATH) ]

View File

@ -4,7 +4,6 @@
// (See accompanying file LICENSE_1_0.txt // (See accompanying file LICENSE_1_0.txt
// or copy at http://www.boost.org/LICENSE_1_0.txt) // or copy at http://www.boost.org/LICENSE_1_0.txt)
#include <boost/archive/xml_oarchive.hpp>
#include <boost/assert.hpp> #include <boost/assert.hpp>
#include <boost/core/lightweight_test.hpp> #include <boost/core/lightweight_test.hpp>
#include <boost/histogram/accumulators.hpp> #include <boost/histogram/accumulators.hpp>
@ -25,10 +24,6 @@ int main(int argc, char** argv) {
BOOST_TEST_EQ(a.count(), 3); BOOST_TEST_EQ(a.count(), 3);
BOOST_TEST_EQ(a.value(), 2); BOOST_TEST_EQ(a.value(), 2);
BOOST_TEST_EQ(a.variance(), 0.5); BOOST_TEST_EQ(a.variance(), 0.5);
#ifndef BOOST_NO_EXCEPTIONS
boost::archive::xml_oarchive oa(std::cout);
BOOST_TEST_THROWS(a.serialize(oa, 0), std::runtime_error);
#endif
} }
// mean // mean

View File

@ -72,6 +72,8 @@ int main() {
w_t y(1, 2); w_t y(1, 2);
BOOST_TEST_NE(y, 1); BOOST_TEST_NE(y, 1);
BOOST_TEST_EQ(static_cast<double>(y), 1); BOOST_TEST_EQ(static_cast<double>(y), 1);
BOOST_TEST_EQ(w_t() += w_t(), w_t());
} }
{ {
@ -118,6 +120,10 @@ int main() {
d(2, 16); d(2, 16);
BOOST_TEST_EQ(d, c); BOOST_TEST_EQ(d, c);
BOOST_TEST_EQ(m_t() += m_t(), m_t());
BOOST_TEST_EQ(m_t(1, 2, 3) += m_t(), m_t(1, 2, 3));
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3), m_t(1, 2, 3));
} }
{ {
@ -131,6 +137,7 @@ int main() {
a(0.5, 3); a(0.5, 3);
BOOST_TEST_EQ(a.sum_of_weights(), 2); BOOST_TEST_EQ(a.sum_of_weights(), 2);
BOOST_TEST_EQ(a.sum_of_weights_squared(), 1.5);
BOOST_TEST_EQ(a.value(), 2); BOOST_TEST_EQ(a.value(), 2);
BOOST_TEST_IS_CLOSE(a.variance(), 0.8, 1e-3); BOOST_TEST_IS_CLOSE(a.variance(), 0.8, 1e-3);
@ -144,6 +151,10 @@ int main() {
BOOST_TEST_EQ(b.sum_of_weights(), 4); BOOST_TEST_EQ(b.sum_of_weights(), 4);
BOOST_TEST_EQ(b.value(), 2); BOOST_TEST_EQ(b.value(), 2);
BOOST_TEST_IS_CLOSE(b.variance(), 0.615, 1e-3); BOOST_TEST_IS_CLOSE(b.variance(), 0.615, 1e-3);
BOOST_TEST_EQ(m_t() += m_t(), m_t());
BOOST_TEST_EQ(m_t(1, 2, 3, 4) += m_t(), m_t(1, 2, 3, 4));
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3, 4), m_t(1, 2, 3, 4));
} }
{ {
@ -154,7 +165,8 @@ int main() {
bad_sum += -1e100; bad_sum += -1e100;
BOOST_TEST_EQ(bad_sum, 0); // instead of 2 BOOST_TEST_EQ(bad_sum, 0); // instead of 2
accumulators::sum<double> sum; using s_t = accumulators::sum<double>;
s_t sum;
++sum; ++sum;
BOOST_TEST_EQ(sum.large(), 1); BOOST_TEST_EQ(sum.large(), 1);
BOOST_TEST_EQ(sum.small(), 0); BOOST_TEST_EQ(sum.small(), 0);
@ -179,10 +191,13 @@ int main() {
BOOST_TEST_GT(a, b); BOOST_TEST_GT(a, b);
BOOST_TEST_GE(a, b); BOOST_TEST_GE(a, b);
BOOST_TEST_GE(a, c); BOOST_TEST_GE(a, c);
BOOST_TEST_EQ(s_t() += s_t(), s_t());
} }
{ {
accumulators::weighted_sum<accumulators::sum<double>> w; using s_t = accumulators::weighted_sum<accumulators::sum<double>>;
s_t w;
++w; ++w;
w += 1e100; w += 1e100;
@ -191,15 +206,20 @@ int main() {
BOOST_TEST_EQ(w.value(), 2); BOOST_TEST_EQ(w.value(), 2);
BOOST_TEST_EQ(w.variance(), 2e200); BOOST_TEST_EQ(w.variance(), 2e200);
BOOST_TEST_EQ(s_t() += s_t(), s_t());
} }
{ {
accumulators::thread_safe<int> i; using ts_t = accumulators::thread_safe<int>;
ts_t i;
++i; ++i;
i += 1000; i += 1000;
BOOST_TEST_EQ(i, 1001); BOOST_TEST_EQ(i, 1001);
BOOST_TEST_EQ(str(i), "1001"s); BOOST_TEST_EQ(str(i), "1001"s);
BOOST_TEST_EQ(ts_t() += ts_t(), ts_t());
} }
return boost::report_errors(); return boost::report_errors();