mirror of
https://github.com/boostorg/mysql.git
synced 2025-05-12 14:11:41 +00:00
338 lines
11 KiB
C++
338 lines
11 KiB
C++
//
// Copyright (c) 2019-2025 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
|
|
|
|
#ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
|
|
#define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
|
|
|
|
#include <boost/mysql/character_set.hpp>
|
|
#include <boost/mysql/diagnostics.hpp>
|
|
#include <boost/mysql/error_code.hpp>
|
|
#include <boost/mysql/handshake_params.hpp>
|
|
#include <boost/mysql/mysql_collations.hpp>
|
|
|
|
#include <boost/mysql/detail/algo_params.hpp>
|
|
#include <boost/mysql/detail/next_action.hpp>
|
|
#include <boost/mysql/detail/ok_view.hpp>
|
|
|
|
#include <boost/mysql/impl/internal/auth/auth.hpp>
|
|
#include <boost/mysql/impl/internal/coroutine.hpp>
|
|
#include <boost/mysql/impl/internal/protocol/capabilities.hpp>
|
|
#include <boost/mysql/impl/internal/protocol/db_flavor.hpp>
|
|
#include <boost/mysql/impl/internal/protocol/deserialization.hpp>
|
|
#include <boost/mysql/impl/internal/protocol/serialization.hpp>
|
|
#include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
|
|
|
|
#include <cstdint>
|
|
|
|
namespace boost {
|
|
namespace mysql {
|
|
namespace detail {
|
|
|
|
// Yields cap when condition holds, and a value-initialized (empty)
// capabilities object otherwise. Handy for OR-ing optional flags into a set.
inline capabilities conditional_capability(bool condition, capabilities cap)
{
    if (!condition)
        return capabilities{};
    return cap;
}
|
|
|
|
// Decides which capabilities the session will use, given the user-supplied
// handshake parameters and the capabilities advertised in the server hello.
//
//   params:                 user-supplied handshake options (database, multi_queries, ssl mode).
//   hello:                  the deserialized server greeting; only server_capabilities is read.
//   negotiated_caps:        output. Written only on success; left untouched if an error is returned.
//   transport_supports_ssl: when false, the ssl mode in params is ignored and
//                           treated as ssl_mode::disable.
//
// Returns a default-constructed error_code on success,
// client_errc::server_doesnt_support_ssl if SSL is required but the server lacks it,
// or client_errc::server_unsupported if any other required capability is missing.
inline error_code process_capabilities(
    const handshake_params& params,
    const server_hello& hello,
    capabilities& negotiated_caps,
    bool transport_supports_ssl
)
{
    // Capabilities we can't work without. Only extremely old servers lack these
    constexpr capabilities mandatory_capabilities =
        // The older protocol is not implemented
        capabilities::protocol_41 |

        // The hello frame can only be deserialized when this is set
        capabilities::plugin_auth |

        // Same reason as plugin_auth
        capabilities::plugin_auth_lenenc_data |

        // Simplifies processing execute responses
        capabilities::deprecate_eof |

        // MariaDB uses this to signal the 4.1 protocol. MySQL always sets it, too
        capabilities::secure_connection;

    // Capabilities we make use of if offered, but can live without
    constexpr capabilities optional_capabilities = capabilities::multi_results |
                                                   capabilities::ps_multi_results;

    // The SSL mode that actually applies, taking the transport into account
    const auto effective_ssl = transport_supports_ssl ? params.ssl() : ssl_mode::disable;
    const capabilities offered_caps = hello.server_capabilities;

    // Requirements that depend on what the user asked for
    const capabilities db_cap = conditional_capability(
        !params.database().empty(),
        capabilities::connect_with_db
    );
    const capabilities multi_stmt_cap = conditional_capability(
        params.multi_queries(),
        capabilities::multi_statements
    );
    const capabilities ssl_required_cap = conditional_capability(
        effective_ssl == ssl_mode::require,
        capabilities::ssl
    );
    const capabilities required_caps = mandatory_capabilities | db_cap | multi_stmt_cap | ssl_required_cap;

    // A server without SSL configured gets its own error code: it is far more
    // helpful for diagnosing the problem than the generic server_unsupported.
    const bool ssl_is_required = has_capabilities(required_caps, capabilities::ssl);
    if (ssl_is_required && !has_capabilities(offered_caps, capabilities::ssl))
        return make_error_code(client_errc::server_doesnt_support_ssl);

    // Any other missing requirement
    if (!has_capabilities(offered_caps, required_caps))
        return make_error_code(client_errc::server_unsupported);

    // With ssl_mode::enable, SSL is used only if the server happens to offer it
    const capabilities ssl_optional_cap = conditional_capability(
        effective_ssl == ssl_mode::enable,
        capabilities::ssl
    );
    negotiated_caps = offered_caps & (required_caps | optional_capabilities | ssl_optional_cap);
    return error_code();
}
|
|
|
|
class handshake_algo
|
|
{
|
|
int resume_point_{0};
|
|
handshake_params hparams_;
|
|
auth_response auth_resp_;
|
|
std::uint8_t sequence_number_{0};
|
|
bool secure_channel_{false};
|
|
|
|
// Attempts to map the collection_id to a character set. We try to be conservative
|
|
// here, since servers will happily accept unknown collation IDs, silently defaulting
|
|
// to the server's default character set (often latin1, which is not Unicode).
|
|
static character_set collation_id_to_charset(std::uint16_t collation_id)
|
|
{
|
|
switch (collation_id)
|
|
{
|
|
case mysql_collations::utf8mb4_bin:
|
|
case mysql_collations::utf8mb4_general_ci: return utf8mb4_charset;
|
|
case mysql_collations::ascii_general_ci:
|
|
case mysql_collations::ascii_bin: return ascii_charset;
|
|
default: return character_set{};
|
|
}
|
|
}
|
|
|
|
// Once the handshake is processed, the capabilities are stored in the connection state
|
|
bool use_ssl(const connection_state_data& st) const
|
|
{
|
|
return has_capabilities(st.current_capabilities, capabilities::ssl);
|
|
}
|
|
|
|
error_code process_handshake(
|
|
connection_state_data& st,
|
|
diagnostics& diag,
|
|
span<const std::uint8_t> buffer
|
|
)
|
|
{
|
|
// Deserialize server hello
|
|
server_hello hello{};
|
|
auto err = deserialize_server_hello(buffer, hello, diag);
|
|
if (err)
|
|
return err;
|
|
|
|
// Check capabilities
|
|
capabilities negotiated_caps{};
|
|
err = process_capabilities(hparams_, hello, negotiated_caps, st.tls_supported);
|
|
if (err)
|
|
return err;
|
|
|
|
// Set capabilities, db flavor and connection ID
|
|
st.current_capabilities = negotiated_caps;
|
|
st.flavor = hello.server;
|
|
st.connection_id = hello.connection_id;
|
|
|
|
// If we're using SSL, mark the channel as secure
|
|
secure_channel_ = secure_channel_ || use_ssl(st);
|
|
|
|
// Compute auth response
|
|
return compute_auth_response(
|
|
hello.auth_plugin_name,
|
|
hparams_.password(),
|
|
hello.auth_plugin_data.to_span(),
|
|
secure_channel_,
|
|
auth_resp_
|
|
);
|
|
}
|
|
|
|
// Response to that initial greeting
|
|
ssl_request compose_ssl_request(const connection_state_data& st)
|
|
{
|
|
return ssl_request{
|
|
st.current_capabilities,
|
|
static_cast<std::uint32_t>(max_packet_size),
|
|
hparams_.connection_collation(),
|
|
};
|
|
}
|
|
|
|
login_request compose_login_request(const connection_state_data& st)
|
|
{
|
|
return login_request{
|
|
st.current_capabilities,
|
|
static_cast<std::uint32_t>(max_packet_size),
|
|
hparams_.connection_collation(),
|
|
hparams_.username(),
|
|
auth_resp_.data,
|
|
hparams_.database(),
|
|
auth_resp_.plugin_name,
|
|
};
|
|
}
|
|
|
|
// Processes auth_switch and auth_more_data messages, and leaves the result in auth_resp_
|
|
error_code process_auth_switch(auth_switch msg)
|
|
{
|
|
return compute_auth_response(
|
|
msg.plugin_name,
|
|
hparams_.password(),
|
|
msg.auth_data,
|
|
secure_channel_,
|
|
auth_resp_
|
|
);
|
|
}
|
|
|
|
error_code process_auth_more_data(span<const std::uint8_t> data)
|
|
{
|
|
return compute_auth_response(
|
|
auth_resp_.plugin_name,
|
|
hparams_.password(),
|
|
data,
|
|
secure_channel_,
|
|
auth_resp_
|
|
);
|
|
}
|
|
|
|
// Composes an auth_switch_response message with the contents of auth_resp_
|
|
auth_switch_response compose_auth_switch_response() const
|
|
{
|
|
return auth_switch_response{auth_resp_.data};
|
|
}
|
|
|
|
void on_success(connection_state_data& st, const ok_view& ok)
|
|
{
|
|
st.status = connection_status::ready;
|
|
st.backslash_escapes = ok.backslash_escapes();
|
|
st.current_charset = collation_id_to_charset(hparams_.connection_collation());
|
|
}
|
|
|
|
error_code process_ok(connection_state_data& st)
|
|
{
|
|
ok_view res{};
|
|
auto ec = deserialize_ok_packet(st.reader.message(), res);
|
|
if (ec)
|
|
return ec;
|
|
on_success(st, res);
|
|
return error_code();
|
|
}
|
|
|
|
public:
|
|
handshake_algo(handshake_algo_params params) noexcept
|
|
: hparams_(params.hparams), secure_channel_(params.secure_channel)
|
|
{
|
|
}
|
|
|
|
next_action resume(connection_state_data& st, diagnostics& diag, error_code ec)
|
|
{
|
|
if (ec)
|
|
return ec;
|
|
|
|
handhake_server_response resp(error_code{});
|
|
|
|
switch (resume_point_)
|
|
{
|
|
case 0:
|
|
// Handshake wipes out state, so no state checks are performed.
|
|
// Setup
|
|
st.reset();
|
|
|
|
// Read server greeting
|
|
BOOST_MYSQL_YIELD(resume_point_, 1, st.read(sequence_number_))
|
|
|
|
// Process server greeting
|
|
ec = process_handshake(st, diag, st.reader.message());
|
|
if (ec)
|
|
return ec;
|
|
|
|
// SSL
|
|
if (use_ssl(st))
|
|
{
|
|
// Send SSL request
|
|
BOOST_MYSQL_YIELD(resume_point_, 2, st.write(compose_ssl_request(st), sequence_number_))
|
|
|
|
// SSL handshake
|
|
BOOST_MYSQL_YIELD(resume_point_, 3, next_action::ssl_handshake())
|
|
|
|
// Mark the connection as using ssl
|
|
st.tls_active = true;
|
|
}
|
|
|
|
// Compose and send handshake response
|
|
BOOST_MYSQL_YIELD(resume_point_, 4, st.write(compose_login_request(st), sequence_number_))
|
|
|
|
// Auth message exchange
|
|
while (true)
|
|
{
|
|
// Receive response
|
|
BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
|
|
|
|
// Process it
|
|
resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, diag);
|
|
if (resp.type == handhake_server_response::type_t::ok)
|
|
{
|
|
// Auth success, quit
|
|
on_success(st, resp.data.ok);
|
|
return next_action();
|
|
}
|
|
else if (resp.type == handhake_server_response::type_t::error)
|
|
{
|
|
// Error, quit
|
|
return resp.data.err;
|
|
}
|
|
else if (resp.type == handhake_server_response::type_t::auth_switch)
|
|
{
|
|
// Compute response
|
|
ec = process_auth_switch(resp.data.auth_sw);
|
|
if (ec)
|
|
return ec;
|
|
|
|
BOOST_MYSQL_YIELD(
|
|
resume_point_,
|
|
6,
|
|
st.write(compose_auth_switch_response(), sequence_number_)
|
|
)
|
|
}
|
|
else if (resp.type == handhake_server_response::type_t::ok_follows)
|
|
{
|
|
// The next packet must be an OK packet. Read it
|
|
BOOST_MYSQL_YIELD(resume_point_, 7, st.read(sequence_number_))
|
|
|
|
// Process it
|
|
// Regardless of whether we succeeded or not, we're done
|
|
return process_ok(st);
|
|
}
|
|
else
|
|
{
|
|
BOOST_ASSERT(resp.type == handhake_server_response::type_t::auth_more_data);
|
|
|
|
// Compute response
|
|
ec = process_auth_more_data(resp.data.more_data);
|
|
if (ec)
|
|
return ec;
|
|
|
|
// Write response
|
|
BOOST_MYSQL_YIELD(
|
|
resume_point_,
|
|
8,
|
|
st.write(compose_auth_switch_response(), sequence_number_)
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
return next_action();
|
|
}
|
|
};
|
|
|
|
} // namespace detail
|
|
} // namespace mysql
|
|
} // namespace boost
|
|
|
|
#endif
|