mirror of
https://github.com/boostorg/histogram.git
synced 2025-05-10 07:14:05 +00:00
Fix accumulator issues (#72)
* add weighted_mean::sum_of_weights_squared * sum of default constructed mean and weighted_mean is now equal to default constructed object * thread-safe returns a reference to self in operators just like the other accumulators
This commit is contained in:
parent
ebabd550a0
commit
c8b8c4d502
@ -33,7 +33,6 @@ public:
|
|||||||
void operator()(const RealType& x) noexcept {
|
void operator()(const RealType& x) noexcept {
|
||||||
sum_ += static_cast<RealType>(1);
|
sum_ += static_cast<RealType>(1);
|
||||||
const auto delta = x - mean_;
|
const auto delta = x - mean_;
|
||||||
BOOST_ASSERT(sum_ != 0);
|
|
||||||
mean_ += delta / sum_;
|
mean_ += delta / sum_;
|
||||||
sum_of_deltas_squared_ += delta * (x - mean_);
|
sum_of_deltas_squared_ += delta * (x - mean_);
|
||||||
}
|
}
|
||||||
@ -41,17 +40,17 @@ public:
|
|||||||
void operator()(const RealType& w, const RealType& x) noexcept {
|
void operator()(const RealType& w, const RealType& x) noexcept {
|
||||||
sum_ += w;
|
sum_ += w;
|
||||||
const auto delta = x - mean_;
|
const auto delta = x - mean_;
|
||||||
BOOST_ASSERT(sum_ != 0);
|
|
||||||
mean_ += w * delta / sum_;
|
mean_ += w * delta / sum_;
|
||||||
sum_of_deltas_squared_ += w * delta * (x - mean_);
|
sum_of_deltas_squared_ += w * delta * (x - mean_);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
mean& operator+=(const mean<T>& rhs) noexcept {
|
mean& operator+=(const mean<T>& rhs) noexcept {
|
||||||
const auto tmp = mean_ * sum_ + static_cast<RealType>(rhs.mean_ * rhs.sum_);
|
if (sum_ != 0 || rhs.sum_ != 0) {
|
||||||
sum_ += rhs.sum_;
|
const auto tmp = mean_ * sum_ + static_cast<RealType>(rhs.mean_ * rhs.sum_);
|
||||||
BOOST_ASSERT(sum_ != 0);
|
sum_ += rhs.sum_;
|
||||||
mean_ = tmp / sum_;
|
mean_ = tmp / sum_;
|
||||||
|
}
|
||||||
sum_of_deltas_squared_ += static_cast<RealType>(rhs.sum_of_deltas_squared_);
|
sum_of_deltas_squared_ += static_cast<RealType>(rhs.sum_of_deltas_squared_);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -80,8 +79,7 @@ public:
|
|||||||
template <class Archive>
|
template <class Archive>
|
||||||
void serialize(Archive& ar, unsigned version) {
|
void serialize(Archive& ar, unsigned version) {
|
||||||
if (version == 0) {
|
if (version == 0) {
|
||||||
if (Archive::is_saving::value)
|
// read only
|
||||||
BOOST_THROW_EXCEPTION(std::runtime_error("save must not use version 0"));
|
|
||||||
std::size_t sum;
|
std::size_t sum;
|
||||||
ar& make_nvp("sum", sum);
|
ar& make_nvp("sum", sum);
|
||||||
sum_ = static_cast<RealType>(sum);
|
sum_ = static_cast<RealType>(sum);
|
||||||
|
@ -31,6 +31,7 @@ namespace accumulators {
|
|||||||
template <class T>
|
template <class T>
|
||||||
class thread_safe : public std::atomic<T> {
|
class thread_safe : public std::atomic<T> {
|
||||||
public:
|
public:
|
||||||
|
using value_type = T;
|
||||||
using super_t = std::atomic<T>;
|
using super_t = std::atomic<T>;
|
||||||
|
|
||||||
thread_safe() noexcept : super_t(static_cast<T>(0)) {}
|
thread_safe() noexcept : super_t(static_cast<T>(0)) {}
|
||||||
@ -41,14 +42,24 @@ public:
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_safe(T arg) : super_t(arg) {}
|
thread_safe(value_type arg) : super_t(arg) {}
|
||||||
thread_safe& operator=(T arg) {
|
thread_safe& operator=(value_type arg) {
|
||||||
super_t::store(arg);
|
super_t::store(arg);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
void operator+=(T arg) { super_t::fetch_add(arg, std::memory_order_relaxed); }
|
thread_safe& operator+=(const thread_safe& arg) {
|
||||||
void operator++() { operator+=(static_cast<T>(1)); }
|
operator+=(arg.load());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
thread_safe& operator+=(value_type arg) {
|
||||||
|
super_t::fetch_add(arg, std::memory_order_relaxed);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
thread_safe& operator++() {
|
||||||
|
operator+=(static_cast<value_type>(1));
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
template <class Archive>
|
template <class Archive>
|
||||||
void serialize(Archive& ar, unsigned /* version */) {
|
void serialize(Archive& ar, unsigned /* version */) {
|
||||||
|
@ -40,19 +40,19 @@ public:
|
|||||||
sum_of_weights_ += w;
|
sum_of_weights_ += w;
|
||||||
sum_of_weights_squared_ += w * w;
|
sum_of_weights_squared_ += w * w;
|
||||||
const auto delta = x - weighted_mean_;
|
const auto delta = x - weighted_mean_;
|
||||||
BOOST_ASSERT(sum_of_weights_ != 0);
|
|
||||||
weighted_mean_ += w * delta / sum_of_weights_;
|
weighted_mean_ += w * delta / sum_of_weights_;
|
||||||
sum_of_weighted_deltas_squared_ += w * delta * (x - weighted_mean_);
|
sum_of_weighted_deltas_squared_ += w * delta * (x - weighted_mean_);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
weighted_mean& operator+=(const weighted_mean<T>& rhs) {
|
weighted_mean& operator+=(const weighted_mean<T>& rhs) {
|
||||||
const auto tmp = weighted_mean_ * sum_of_weights_ +
|
if (sum_of_weights_ != 0 || rhs.sum_of_weights_ != 0) {
|
||||||
static_cast<RealType>(rhs.weighted_mean_ * rhs.sum_of_weights_);
|
const auto tmp = weighted_mean_ * sum_of_weights_ +
|
||||||
sum_of_weights_ += static_cast<RealType>(rhs.sum_of_weights_);
|
static_cast<RealType>(rhs.weighted_mean_ * rhs.sum_of_weights_);
|
||||||
sum_of_weights_squared_ += static_cast<RealType>(rhs.sum_of_weights_squared_);
|
sum_of_weights_ += static_cast<RealType>(rhs.sum_of_weights_);
|
||||||
BOOST_ASSERT(sum_of_weights_ != 0);
|
sum_of_weights_squared_ += static_cast<RealType>(rhs.sum_of_weights_squared_);
|
||||||
weighted_mean_ = tmp / sum_of_weights_;
|
weighted_mean_ = tmp / sum_of_weights_;
|
||||||
|
}
|
||||||
sum_of_weighted_deltas_squared_ +=
|
sum_of_weighted_deltas_squared_ +=
|
||||||
static_cast<RealType>(rhs.sum_of_weighted_deltas_squared_);
|
static_cast<RealType>(rhs.sum_of_weighted_deltas_squared_);
|
||||||
return *this;
|
return *this;
|
||||||
@ -78,6 +78,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
const RealType& sum_of_weights() const noexcept { return sum_of_weights_; }
|
const RealType& sum_of_weights() const noexcept { return sum_of_weights_; }
|
||||||
|
const RealType& sum_of_weights_squared() const noexcept {
|
||||||
|
return sum_of_weights_squared_;
|
||||||
|
}
|
||||||
const RealType& value() const noexcept { return weighted_mean_; }
|
const RealType& value() const noexcept { return weighted_mean_; }
|
||||||
RealType variance() const {
|
RealType variance() const {
|
||||||
return sum_of_weighted_deltas_squared_ /
|
return sum_of_weighted_deltas_squared_ /
|
||||||
|
@ -114,8 +114,8 @@ alias accumulators : [ run boost_accumulators_support_test.cpp ] : <warnings>off
|
|||||||
alias range : [ run boost_range_support_test.cpp ] : <warnings>off ;
|
alias range : [ run boost_range_support_test.cpp ] : <warnings>off ;
|
||||||
alias units : [ run boost_units_support_test.cpp ] : <warnings>off ;
|
alias units : [ run boost_units_support_test.cpp ] : <warnings>off ;
|
||||||
alias serialization :
|
alias serialization :
|
||||||
[ run detail_array_wrapper_serialization_test.cpp libserial ]
|
|
||||||
[ run accumulators_serialization_test.cpp libserial : $(THIS_PATH) ]
|
[ run accumulators_serialization_test.cpp libserial : $(THIS_PATH) ]
|
||||||
|
[ run detail_array_wrapper_serialization_test.cpp libserial ]
|
||||||
[ run axis_variant_serialization_test.cpp libserial : $(THIS_PATH) ]
|
[ run axis_variant_serialization_test.cpp libserial : $(THIS_PATH) ]
|
||||||
[ run histogram_serialization_test.cpp libserial : $(THIS_PATH) ]
|
[ run histogram_serialization_test.cpp libserial : $(THIS_PATH) ]
|
||||||
[ run storage_adaptor_serialization_test.cpp libserial : $(THIS_PATH) ]
|
[ run storage_adaptor_serialization_test.cpp libserial : $(THIS_PATH) ]
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
// (See accompanying file LICENSE_1_0.txt
|
// (See accompanying file LICENSE_1_0.txt
|
||||||
// or copy at http://www.boost.org/LICENSE_1_0.txt)
|
// or copy at http://www.boost.org/LICENSE_1_0.txt)
|
||||||
|
|
||||||
#include <boost/archive/xml_oarchive.hpp>
|
|
||||||
#include <boost/assert.hpp>
|
#include <boost/assert.hpp>
|
||||||
#include <boost/core/lightweight_test.hpp>
|
#include <boost/core/lightweight_test.hpp>
|
||||||
#include <boost/histogram/accumulators.hpp>
|
#include <boost/histogram/accumulators.hpp>
|
||||||
@ -25,10 +24,6 @@ int main(int argc, char** argv) {
|
|||||||
BOOST_TEST_EQ(a.count(), 3);
|
BOOST_TEST_EQ(a.count(), 3);
|
||||||
BOOST_TEST_EQ(a.value(), 2);
|
BOOST_TEST_EQ(a.value(), 2);
|
||||||
BOOST_TEST_EQ(a.variance(), 0.5);
|
BOOST_TEST_EQ(a.variance(), 0.5);
|
||||||
#ifndef BOOST_NO_EXCEPTIONS
|
|
||||||
boost::archive::xml_oarchive oa(std::cout);
|
|
||||||
BOOST_TEST_THROWS(a.serialize(oa, 0), std::runtime_error);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// mean
|
// mean
|
||||||
|
@ -72,6 +72,8 @@ int main() {
|
|||||||
w_t y(1, 2);
|
w_t y(1, 2);
|
||||||
BOOST_TEST_NE(y, 1);
|
BOOST_TEST_NE(y, 1);
|
||||||
BOOST_TEST_EQ(static_cast<double>(y), 1);
|
BOOST_TEST_EQ(static_cast<double>(y), 1);
|
||||||
|
|
||||||
|
BOOST_TEST_EQ(w_t() += w_t(), w_t());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -118,6 +120,10 @@ int main() {
|
|||||||
d(2, 16);
|
d(2, 16);
|
||||||
|
|
||||||
BOOST_TEST_EQ(d, c);
|
BOOST_TEST_EQ(d, c);
|
||||||
|
|
||||||
|
BOOST_TEST_EQ(m_t() += m_t(), m_t());
|
||||||
|
BOOST_TEST_EQ(m_t(1, 2, 3) += m_t(), m_t(1, 2, 3));
|
||||||
|
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3), m_t(1, 2, 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -131,6 +137,7 @@ int main() {
|
|||||||
a(0.5, 3);
|
a(0.5, 3);
|
||||||
|
|
||||||
BOOST_TEST_EQ(a.sum_of_weights(), 2);
|
BOOST_TEST_EQ(a.sum_of_weights(), 2);
|
||||||
|
BOOST_TEST_EQ(a.sum_of_weights_squared(), 1.5);
|
||||||
BOOST_TEST_EQ(a.value(), 2);
|
BOOST_TEST_EQ(a.value(), 2);
|
||||||
BOOST_TEST_IS_CLOSE(a.variance(), 0.8, 1e-3);
|
BOOST_TEST_IS_CLOSE(a.variance(), 0.8, 1e-3);
|
||||||
|
|
||||||
@ -144,6 +151,10 @@ int main() {
|
|||||||
BOOST_TEST_EQ(b.sum_of_weights(), 4);
|
BOOST_TEST_EQ(b.sum_of_weights(), 4);
|
||||||
BOOST_TEST_EQ(b.value(), 2);
|
BOOST_TEST_EQ(b.value(), 2);
|
||||||
BOOST_TEST_IS_CLOSE(b.variance(), 0.615, 1e-3);
|
BOOST_TEST_IS_CLOSE(b.variance(), 0.615, 1e-3);
|
||||||
|
|
||||||
|
BOOST_TEST_EQ(m_t() += m_t(), m_t());
|
||||||
|
BOOST_TEST_EQ(m_t(1, 2, 3, 4) += m_t(), m_t(1, 2, 3, 4));
|
||||||
|
BOOST_TEST_EQ(m_t() += m_t(1, 2, 3, 4), m_t(1, 2, 3, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -154,7 +165,8 @@ int main() {
|
|||||||
bad_sum += -1e100;
|
bad_sum += -1e100;
|
||||||
BOOST_TEST_EQ(bad_sum, 0); // instead of 2
|
BOOST_TEST_EQ(bad_sum, 0); // instead of 2
|
||||||
|
|
||||||
accumulators::sum<double> sum;
|
using s_t = accumulators::sum<double>;
|
||||||
|
s_t sum;
|
||||||
++sum;
|
++sum;
|
||||||
BOOST_TEST_EQ(sum.large(), 1);
|
BOOST_TEST_EQ(sum.large(), 1);
|
||||||
BOOST_TEST_EQ(sum.small(), 0);
|
BOOST_TEST_EQ(sum.small(), 0);
|
||||||
@ -179,10 +191,13 @@ int main() {
|
|||||||
BOOST_TEST_GT(a, b);
|
BOOST_TEST_GT(a, b);
|
||||||
BOOST_TEST_GE(a, b);
|
BOOST_TEST_GE(a, b);
|
||||||
BOOST_TEST_GE(a, c);
|
BOOST_TEST_GE(a, c);
|
||||||
|
|
||||||
|
BOOST_TEST_EQ(s_t() += s_t(), s_t());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
accumulators::weighted_sum<accumulators::sum<double>> w;
|
using s_t = accumulators::weighted_sum<accumulators::sum<double>>;
|
||||||
|
s_t w;
|
||||||
|
|
||||||
++w;
|
++w;
|
||||||
w += 1e100;
|
w += 1e100;
|
||||||
@ -191,15 +206,20 @@ int main() {
|
|||||||
|
|
||||||
BOOST_TEST_EQ(w.value(), 2);
|
BOOST_TEST_EQ(w.value(), 2);
|
||||||
BOOST_TEST_EQ(w.variance(), 2e200);
|
BOOST_TEST_EQ(w.variance(), 2e200);
|
||||||
|
|
||||||
|
BOOST_TEST_EQ(s_t() += s_t(), s_t());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
accumulators::thread_safe<int> i;
|
using ts_t = accumulators::thread_safe<int>;
|
||||||
|
ts_t i;
|
||||||
++i;
|
++i;
|
||||||
i += 1000;
|
i += 1000;
|
||||||
|
|
||||||
BOOST_TEST_EQ(i, 1001);
|
BOOST_TEST_EQ(i, 1001);
|
||||||
BOOST_TEST_EQ(str(i), "1001"s);
|
BOOST_TEST_EQ(str(i), "1001"s);
|
||||||
|
|
||||||
|
BOOST_TEST_EQ(ts_t() += ts_t(), ts_t());
|
||||||
}
|
}
|
||||||
|
|
||||||
return boost::report_errors();
|
return boost::report_errors();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user