diff --git a/src/catch2/internal/catch_random_integer_helpers.hpp b/src/catch2/internal/catch_random_integer_helpers.hpp index fadc1e15..10d82559 100644 --- a/src/catch2/internal/catch_random_integer_helpers.hpp +++ b/src/catch2/internal/catch_random_integer_helpers.hpp @@ -14,6 +14,32 @@ #include #include +// Note: We use the usual enable-disable-autodetect dance here even though +// we do not support these in CMake configuration options (yet?). +// It is highly unlikely that we will need to make these actually +// user-configurable, but this will make it simpler if we end up needing +// it, and it provides an escape hatch to the users who need it. +#if defined( __SIZEOF_INT128__ ) +# define CATCH_CONFIG_INTERNAL_UINT128 +#elif defined( _MSC_VER ) && defined( _M_X64 ) // _umul128 is an x64-only intrinsic; ARM64 MSVC does not provide it +# define CATCH_CONFIG_INTERNAL_MSVC_UMUL128 +#endif + +#if defined( CATCH_CONFIG_INTERNAL_UINT128 ) && \ + !defined( CATCH_CONFIG_NO_UINT128 ) && \ + !defined( CATCH_CONFIG_UINT128 ) +#define CATCH_CONFIG_UINT128 +#endif + +#if defined( CATCH_CONFIG_INTERNAL_MSVC_UMUL128 ) && \ + !defined( CATCH_CONFIG_NO_MSVC_UMUL128 ) && \ + !defined( CATCH_CONFIG_MSVC_UMUL128 ) +# define CATCH_CONFIG_MSVC_UMUL128 +# include <intrin.h> +# pragma intrinsic( _umul128 ) +#endif + + namespace Catch { namespace Detail { @@ -46,59 +72,52 @@ namespace Catch { } }; - // Returns 128 bit result of multiplying lhs and rhs + /** + * Returns 128 bit result of lhs * rhs using portable C++ code + * + * This implementation is almost twice as fast as naive long multiplication, + * and unlike intrinsic-based approach, it supports constexpr evaluation. + */ constexpr ExtendedMultResult - extendedMult( std::uint64_t lhs, std::uint64_t rhs ) { - // We use the simple long multiplication approach for - // correctness, we can use platform specific builtins - // for performance later. - - // Split the lhs and rhs into two 32bit "digits", so that we can - // do 64 bit arithmetic to handle carry bits.
- // 32b 32b 32b 32b - // lhs L1 L2 - // * rhs R1 R2 - // ------------------------ - // | R2 * L2 | - // | R2 * L1 | - // | R1 * L2 | - // | R1 * L1 | - // ------------------------- - // | a | b | c | d | - + extendedMultPortable(std::uint64_t lhs, std::uint64_t rhs) { #define CarryBits( x ) ( x >> 32 ) #define Digits( x ) ( x & 0xFF'FF'FF'FF ) + std::uint64_t lhs_low = Digits( lhs ); + std::uint64_t rhs_low = Digits( rhs ); + std::uint64_t low_low = ( lhs_low * rhs_low ); + std::uint64_t high_high = CarryBits( lhs ) * CarryBits( rhs ); - auto r2l2 = Digits( rhs ) * Digits( lhs ); - auto r2l1 = Digits( rhs ) * CarryBits( lhs ); - auto r1l2 = CarryBits( rhs ) * Digits( lhs ); - auto r1l1 = CarryBits( rhs ) * CarryBits( lhs ); - - // Sum to columns first - auto d = Digits( r2l2 ); - auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 ); - auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 ); - auto a = CarryBits( r1l1 ); - - // Propagate carries between columns - c += CarryBits( d ); - b += CarryBits( c ); - a += CarryBits( b ); - - // Remove the used carries - c = Digits( c ); - b = Digits( b ); - a = Digits( a ); + // We add in carry bits from low-low already + std::uint64_t high_low = + ( CarryBits( lhs ) * rhs_low ) + CarryBits( low_low ); + // Note that we can add only low bits from high_low, to avoid + // overflow with large inputs + std::uint64_t low_high = + ( lhs_low * CarryBits( rhs ) ) + Digits( high_low ); + return { high_high + CarryBits( high_low ) + CarryBits( low_high ), + ( low_high << 32 ) | Digits( low_low ) }; #undef CarryBits #undef Digits - - return { - a << 32 | b, // upper 64 bits - c << 32 | d // lower 64 bits - }; } + //! 
Returns 128 bit result of lhs * rhs + inline ExtendedMultResult + extendedMult( std::uint64_t lhs, std::uint64_t rhs ) { +#if defined( CATCH_CONFIG_UINT128 ) + auto result = __uint128_t( lhs ) * __uint128_t( rhs ); + return { static_cast( result >> 64 ), + static_cast( result ) }; +#elif defined( CATCH_CONFIG_MSVC_UMUL128 ) + std::uint64_t high; + std::uint64_t low = _umul128( lhs, rhs, &high ); + return { high, low }; +#else + return extendedMultPortable( lhs, rhs ); +#endif + } + + template constexpr ExtendedMultResult extendedMult( UInt lhs, UInt rhs ) { static_assert( std::is_unsigned::value, diff --git a/tests/SelfTest/IntrospectiveTests/Integer.tests.cpp b/tests/SelfTest/IntrospectiveTests/Integer.tests.cpp index fd620ebb..8955f400 100644 --- a/tests/SelfTest/IntrospectiveTests/Integer.tests.cpp +++ b/tests/SelfTest/IntrospectiveTests/Integer.tests.cpp @@ -8,6 +8,7 @@ #include #include +#include namespace { template @@ -20,6 +21,58 @@ namespace { CHECK( extendedMult( b, a ) == ExtendedMultResult{ upper_result, lower_result } ); } + + // Simple (and slow) implementation of extended multiplication for tests + constexpr Catch::Detail::ExtendedMultResult + extendedMultNaive( std::uint64_t lhs, std::uint64_t rhs ) { + // This is a simple long multiplication, where we split lhs and rhs + // into two 32-bit "digits", so that we can do ops with carry in 64-bits.
+ // + // 32b 32b 32b 32b + // lhs L1 L2 + // * rhs R1 R2 + // ------------------------ + // | R2 * L2 | + // | R2 * L1 | + // | R1 * L2 | + // | R1 * L1 | + // ------------------------- + // | a | b | c | d | + +#define CarryBits( x ) ( x >> 32 ) +#define Digits( x ) ( x & 0xFF'FF'FF'FF ) + + auto r2l2 = Digits( rhs ) * Digits( lhs ); + auto r2l1 = Digits( rhs ) * CarryBits( lhs ); + auto r1l2 = CarryBits( rhs ) * Digits( lhs ); + auto r1l1 = CarryBits( rhs ) * CarryBits( lhs ); + + // Sum to columns first + auto d = Digits( r2l2 ); + auto c = CarryBits( r2l2 ) + Digits( r2l1 ) + Digits( r1l2 ); + auto b = CarryBits( r2l1 ) + CarryBits( r1l2 ) + Digits( r1l1 ); + auto a = CarryBits( r1l1 ); + + // Propagate carries between columns + c += CarryBits( d ); + b += CarryBits( c ); + a += CarryBits( b ); + + // Remove the used carries + c = Digits( c ); + b = Digits( b ); + a = Digits( a ); + +#undef CarryBits +#undef Digits + + return { + a << 32 | b, // upper 64 bits + c << 32 | d // lower 64 bits + }; + } + + } // namespace TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) { @@ -62,6 +115,27 @@ TEST_CASE( "extendedMult 64x64", "[Integer][approvals]" ) { 0xdf44'2d22'ce48'59b9 ); } +TEST_CASE("extendedMult 64x64 - all implementations", "[integer][approvals]") { + using Catch::Detail::extendedMult; + using Catch::Detail::extendedMultPortable; + using Catch::Detail::fillBitsFrom; + + std::random_device rng; + for (size_t i = 0; i < 100; ++i) { + auto a = fillBitsFrom( rng ); + auto b = fillBitsFrom( rng ); + CAPTURE( a, b ); + + auto naive_ab = extendedMultNaive( a, b ); + + REQUIRE( naive_ab == extendedMultNaive( b, a ) ); + REQUIRE( naive_ab == extendedMultPortable( a, b ) ); + REQUIRE( naive_ab == extendedMultPortable( b, a ) ); + REQUIRE( naive_ab == extendedMult( a, b ) ); + REQUIRE( naive_ab == extendedMult( b, a ) ); + } +} + TEST_CASE( "SizedUnsignedType helpers", "[integer][approvals]" ) { using Catch::Detail::SizedUnsignedType_t; using 
Catch::Detail::DoubleWidthUnsignedType_t;