Initial impl

This commit is contained in:
Ruben Perez 2025-05-09 16:16:22 +02:00
parent 90405e79e4
commit 080ada8bfb
3 changed files with 188 additions and 33 deletions

View File

@ -53,6 +53,12 @@ public:
}
}
void resize(std::size_t new_size)
{
BOOST_ASSERT(new_size <= N);
size_ = new_size;
}
void clear() noexcept { size_ = 0; }
};

View File

@ -20,12 +20,23 @@
#include <boost/mysql/impl/internal/protocol/static_buffer.hpp>
#include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
#include <boost/asio/ssl/error.hpp>
#include <boost/container/small_vector.hpp>
#include <boost/core/span.hpp>
#include <boost/system/result.hpp>
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <openssl/bio.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/rsa.h>
#include <openssl/sha.h>
#include <openssl/types.h>
// Reference:
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html
@ -92,6 +103,90 @@ inline system::result<static_buffer<32>> csha2p_hash_password(
return res;
}
// TODO: we may want to take this to a separate file?
struct bio_deleter
{
void operator()(BIO* bio) const noexcept { BIO_free(bio); }
};
using unique_bio = std::unique_ptr<BIO, bio_deleter>;
struct evp_pkey_deleter
{
void operator()(EVP_PKEY* pkey) const noexcept { EVP_PKEY_free(pkey); }
};
using unique_evp_pkey = std::unique_ptr<EVP_PKEY, evp_pkey_deleter>;
struct evp_pkey_ctx_deleter
{
void operator()(EVP_PKEY_CTX* ctx) const noexcept { EVP_PKEY_CTX_free(ctx); }
};
using unique_evp_pkey_ctx = std::unique_ptr<EVP_PKEY_CTX, evp_pkey_ctx_deleter>;
inline error_code get_last_openssl_error()
{
return error_code(::ERR_get_error(), asio::error::get_ssl_category()); // TODO: is this OK?
}
using csha2p_password_buffer = container::small_vector<std::uint8_t, 256>;
inline error_code csha2p_encrypt_password(
string_view password,
span<const std::uint8_t> challenge,
span<const std::uint8_t> server_key,
csha2p_password_buffer& output
)
{
// TODO: test that these can really never happen
BOOST_ASSERT(!password.empty());
BOOST_ASSERT(!challenge.empty());
// Try to parse the private key. TODO: size check here
unique_bio bio{BIO_new_mem_buf(server_key.data(), server_key.size())};
if (!bio)
return get_last_openssl_error();
unique_evp_pkey key(PEM_read_bio_PUBKEY(bio.get(), nullptr, nullptr, nullptr));
if (!key)
return get_last_openssl_error();
// Salt the password, as a NULL-terminated string
csha2p_password_buffer salted_password(password.size() + 1u, 0);
for (std::size_t i = 0; i < password.size(); ++i)
salted_password[i] = password[i] ^ challenge[i % challenge.size()];
// Add the NULL terminator. It should be salted, too. Since 0 ^ U = U,
// the byte should be the challenge at the position we're in
salted_password[password.size()] = challenge[password.size() % challenge.size()];
// Set up the encryption context
unique_evp_pkey_ctx ctx(EVP_PKEY_CTX_new(key.get(), nullptr));
if (!ctx)
return get_last_openssl_error();
if (EVP_PKEY_encrypt_init(ctx.get()) <= 0)
return get_last_openssl_error();
if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) <= 0)
return get_last_openssl_error();
// Encrypt
int max_size = EVP_PKEY_get_size(key.get());
BOOST_ASSERT(max_size >= 0);
output.resize(max_size);
std::size_t actual_size = static_cast<std::size_t>(max_size);
if (EVP_PKEY_encrypt(
ctx.get(),
output.data(),
&actual_size,
salted_password.data(),
salted_password.size()
) <= 0)
{
return get_last_openssl_error();
}
output.resize(actual_size);
// Done
return error_code();
}
class csha2p_algo
{
int resume_point_{0};
@ -106,6 +201,24 @@ class csha2p_algo
return server_data.size() == 1u && server_data[0] == 3;
}
static next_action encrypt_password(
connection_state_data& st,
std::uint8_t& seqnum,
string_view password,
span<const std::uint8_t> challenge,
span<const std::uint8_t> server_key
)
{
csha2p_password_buffer buff;
auto ec = csha2p_encrypt_password(password, challenge, server_key, buff);
if (ec)
return ec;
return st.write(
string_eof{string_view(reinterpret_cast<const char*>(buff.data()), buff.size())},
seqnum
);
}
public:
csha2p_algo() = default;
@ -113,6 +226,7 @@ public:
connection_state_data& st,
span<const std::uint8_t> server_data,
string_view password,
span<const std::uint8_t> challenge,
bool secure_channel,
std::uint8_t& seqnum
)
@ -124,14 +238,29 @@ public:
// or told us to read again because an OK packet or error packet is coming.
if (is_perform_full_auth(server_data))
{
// At this point, we don't support full auth over insecure channels
if (!secure_channel)
if (secure_channel)
{
return make_error_code(client_errc::auth_plugin_requires_ssl);
}
// We should send a packet with just the password, as a NULL-terminated string
BOOST_MYSQL_YIELD(resume_point_, 1, st.write(string_null{password}, seqnum))
// We should send a packet with just the password, as a NULL-terminated string
BOOST_MYSQL_YIELD(resume_point_, 1, st.write(string_null{password}, seqnum))
// The server shouldn't send us any more packets
return error_code(client_errc::bad_handshake_packet_type);
}
else
{
// Request the server's public key
BOOST_MYSQL_YIELD(resume_point_, 99, st.write(int1{2}, seqnum))
// Encrypt the password with the key we were given
BOOST_MYSQL_YIELD(
resume_point_,
100,
encrypt_password(st, seqnum, password, challenge, server_data)
)
// The server shouldn't send us any more packets
return error_code(client_errc::bad_handshake_packet_type);
}
}
else if (is_fast_auth_ok(server_data))
{

View File

@ -30,6 +30,7 @@
#include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
#include <boost/mysql/impl/internal/sansio/mysql_native_password.hpp>
#include <boost/container/small_vector.hpp>
#include <boost/core/span.hpp>
#include <boost/system/result.hpp>
#include <boost/variant2/variant.hpp>
@ -58,23 +59,18 @@ class any_authentication_plugin
public:
any_authentication_plugin() = default;
// Emplaces the plugin and computes the first authentication response by hashing the password
system::result<static_buffer<32>> bootstrap_plugin(
string_view plugin_name,
string_view password,
span<const std::uint8_t> challenge
)
error_code emplace_plugin(string_view plugin_name)
{
if (plugin_name == mnp_plugin_name)
{
type_ = type_t::mnp;
return mnp_hash_password(password, challenge);
return error_code();
}
else if (plugin_name == csha2p_plugin_name)
{
type_ = type_t::csha2p;
csha2p_ = csha2p_algo(); // Reset any leftover state, just in case
return csha2p_hash_password(password, challenge);
return error_code();
}
else
{
@ -82,10 +78,21 @@ public:
}
}
system::result<static_buffer<32>> hash_password(string_view password, span<const std::uint8_t> challenge)
{
switch (type_)
{
case type_t::mnp: return mnp_hash_password(password, challenge);
case type_t::csha2p: return csha2p_hash_password(password, challenge);
default: BOOST_ASSERT(false); return client_errc::unknown_auth_plugin; // LCOV_EXCL_LINE
}
}
next_action resume(
connection_state_data& st,
boost::span<const std::uint8_t> server_data,
string_view password,
boost::span<const std::uint8_t> challenge,
bool secure_channel,
std::uint8_t& seqnum
)
@ -95,7 +102,8 @@ public:
case type_t::mnp:
// This algorithm doesn't allow more data frames
return error_code(client_errc::bad_handshake_packet_type);
case type_t::csha2p: return csha2p_.resume(st, server_data, password, secure_channel, seqnum);
case type_t::csha2p:
return csha2p_.resume(st, server_data, password, challenge, secure_channel, seqnum);
default:
BOOST_ASSERT(false);
return next_action(client_errc::bad_handshake_packet_type); // LCOV_EXCL_LINE
@ -118,7 +126,7 @@ class handshake_algo
int resume_point_{0};
handshake_params hparams_;
any_authentication_plugin plugin_;
static_buffer<32> hashed_password_;
container::small_vector<std::uint8_t, 32> challenge_; // TODO: make this a static vector
std::uint8_t sequence_number_{0};
bool secure_channel_{false};
@ -219,18 +227,12 @@ class handshake_algo
// If we're using SSL, mark the channel as secure
secure_channel_ = secure_channel_ || has_capabilities(*negotiated_caps, capabilities::ssl);
// Emplace the authentication plugin and compute the first response
auto hashed_password = plugin_.bootstrap_plugin(
hello.auth_plugin_name,
hparams_.password(),
hello.auth_plugin_data
);
if (hashed_password.has_error())
return hashed_password.error();
// Save the challenge for later
span<const std::uint8_t> auth_data(hello.auth_plugin_data);
challenge_.assign(auth_data.begin(), auth_data.end());
// Save it for later
hashed_password_ = *hashed_password;
return error_code();
// Emplace the authentication plugin
return plugin_.emplace_plugin(hello.auth_plugin_name);
}
// Response to that initial greeting
@ -243,24 +245,41 @@ class handshake_algo
};
}
login_request compose_login_request(const connection_state_data& st)
next_action compose_login_request(connection_state_data& st)
{
return login_request{
// Hash the password
auto hashed_password = plugin_.hash_password(hparams_.password(), challenge_);
if (hashed_password.has_error())
return hashed_password.error();
// Compose the message
login_request msg{
st.current_capabilities,
static_cast<std::uint32_t>(max_packet_size),
hparams_.connection_collation(),
hparams_.username(),
hashed_password_,
*hashed_password,
hparams_.database(),
plugin_.name(),
};
// Serialize it
return st.write(msg, sequence_number_);
}
// Processes auth_switch and auth_more_data messages, and leaves the result in auth_resp_
next_action process_auth_switch(connection_state_data& st, auth_switch msg)
{
// Emplace the authentication plugin and compute the first response
auto hashed_password = plugin_.bootstrap_plugin(msg.plugin_name, hparams_.password(), msg.auth_data);
// Emplace the new authentication plugin
auto ec = plugin_.emplace_plugin(msg.plugin_name);
if (ec)
return ec;
// Store the challenge for later
challenge_.assign(msg.auth_data.begin(), msg.auth_data.end());
// Hash the password
auto hashed_password = plugin_.hash_password(hparams_.password(), msg.auth_data);
if (hashed_password.has_error())
return hashed_password.error();
@ -312,7 +331,7 @@ class handshake_algo
}
// Compose and send handshake response
BOOST_MYSQL_YIELD(resume_point_, 4, st.write(compose_login_request(st), sequence_number_))
BOOST_MYSQL_YIELD(resume_point_, 4, compose_login_request(st))
// Receive the response
BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
@ -342,6 +361,7 @@ class handshake_algo
st,
resp.data.more_data,
hparams_.password(),
challenge_,
secure_channel_,
sequence_number_
);