Merge pull request #10700 from sepalani/ssl-handshake

Socket: Fix some non-blocking connect edge cases
This commit is contained in:
JMC47 2022-06-27 21:39:36 -04:00 committed by GitHub
commit e50e45f400
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 223 additions and 62 deletions

View File

@ -186,4 +186,22 @@ u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value)
checksum = (checksum >> 16) + (checksum & 0xFFFF); checksum = (checksum >> 16) + (checksum & 0xFFFF);
return ~static_cast<u16>(checksum); return ~static_cast<u16>(checksum);
} }
NetworkErrorState SaveNetworkErrorState()
{
return {
errno,
#ifdef _WIN32
WSAGetLastError(),
#endif
};
}
void RestoreNetworkErrorState(const NetworkErrorState& state)
{
errno = state.error;
#ifdef _WIN32
WSASetLastError(state.wsa_error);
#endif
}
} // namespace Common } // namespace Common

View File

@ -99,8 +99,18 @@ struct UDPHeader
}; };
static_assert(sizeof(UDPHeader) == UDPHeader::SIZE); static_assert(sizeof(UDPHeader) == UDPHeader::SIZE);
struct NetworkErrorState
{
int error;
#ifdef _WIN32
int wsa_error;
#endif
};
MACAddress GenerateMacAddress(MACConsumer type); MACAddress GenerateMacAddress(MACConsumer type);
std::string MacAddressToString(const MACAddress& mac); std::string MacAddressToString(const MACAddress& mac);
std::optional<MACAddress> StringToMacAddress(std::string_view mac_string); std::optional<MACAddress> StringToMacAddress(std::string_view mac_string);
u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value = 0); u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value = 0);
NetworkErrorState SaveNetworkErrorState();
void RestoreNetworkErrorState(const NetworkErrorState& state);
} // namespace Common } // namespace Common

View File

@ -16,8 +16,11 @@
#include <sys/select.h> #include <sys/select.h>
#endif #endif
#include "Common/BitUtils.h"
#include "Common/FileUtil.h" #include "Common/FileUtil.h"
#include "Common/IOFile.h" #include "Common/IOFile.h"
#include "Common/Network.h"
#include "Common/ScopeGuard.h"
#include "Core/Config/MainSettings.h" #include "Core/Config/MainSettings.h"
#include "Core/Core.h" #include "Core/Core.h"
#include "Core/IOS/Device.h" #include "Core/IOS/Device.h"
@ -224,6 +227,7 @@ s32 WiiSocket::CloseFd()
GetIOS()->EnqueueIPCReply(it->request, -SO_ENOTCONN); GetIOS()->EnqueueIPCReply(it->request, -SO_ENOTCONN);
it = pending_sockops.erase(it); it = pending_sockops.erase(it);
} }
connecting_state = ConnectingState::None;
return ReturnValue; return ReturnValue;
} }
@ -278,8 +282,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
case IOCTL_SO_BIND: case IOCTL_SO_BIND:
{ {
sockaddr_in local_name; sockaddr_in local_name;
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8); const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8);
WiiSockMan::Convert(*wii_name, local_name); WiiSockMan::ToNativeAddrIn(addr, &local_name);
int ret = bind(fd, (sockaddr*)&local_name, sizeof(local_name)); int ret = bind(fd, (sockaddr*)&local_name, sizeof(local_name));
ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_BIND", false); ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_BIND", false);
@ -291,11 +295,12 @@ void WiiSocket::Update(bool read, bool write, bool except)
case IOCTL_SO_CONNECT: case IOCTL_SO_CONNECT:
{ {
sockaddr_in local_name; sockaddr_in local_name;
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8); const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8);
WiiSockMan::Convert(*wii_name, local_name); WiiSockMan::ToNativeAddrIn(addr, &local_name);
int ret = connect(fd, (sockaddr*)&local_name, sizeof(local_name)); int ret = connect(fd, (sockaddr*)&local_name, sizeof(local_name));
ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_CONNECT", false); ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_CONNECT", false);
UpdateConnectingState(ReturnValue);
INFO_LOG_FMT(IOS_NET, "IOCTL_SO_CONNECT ({:08x}, {}:{}) = {}", wii_fd, INFO_LOG_FMT(IOS_NET, "IOCTL_SO_CONNECT ({:08x}, {}:{}) = {}", wii_fd,
inet_ntoa(local_name.sin_addr), Common::swap16(local_name.sin_port), ret); inet_ntoa(local_name.sin_addr), Common::swap16(local_name.sin_port), ret);
@ -307,13 +312,13 @@ void WiiSocket::Update(bool read, bool write, bool except)
if (ioctl.buffer_out_size > 0) if (ioctl.buffer_out_size > 0)
{ {
sockaddr_in local_name; sockaddr_in local_name;
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_out); u8* addr = Memory::GetPointer(ioctl.buffer_out);
WiiSockMan::Convert(*wii_name, local_name); WiiSockMan::ToNativeAddrIn(addr, &local_name);
socklen_t addrlen = sizeof(sockaddr_in); socklen_t addrlen = sizeof(sockaddr_in);
ret = static_cast<s32>(accept(fd, (sockaddr*)&local_name, &addrlen)); ret = static_cast<s32>(accept(fd, (sockaddr*)&local_name, &addrlen));
WiiSockMan::Convert(local_name, *wii_name, addrlen); WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen);
} }
else else
{ {
@ -341,10 +346,12 @@ void WiiSocket::Update(bool read, bool write, bool except)
{ {
ReturnValue = -SO_ENETUNREACH; ReturnValue = -SO_ENETUNREACH;
ResetTimeout(); ResetTimeout();
connecting_state = ConnectingState::Error;
} }
break; break;
case -SO_EISCONN: case -SO_EISCONN:
ReturnValue = SO_SUCCESS; ReturnValue = SO_SUCCESS;
connecting_state = ConnectingState::Connected;
[[fallthrough]]; [[fallthrough]];
default: default:
ResetTimeout(); ResetTimeout();
@ -392,6 +399,24 @@ void WiiSocket::Update(bool read, bool write, bool except)
{ {
case IOCTLV_NET_SSL_DOHANDSHAKE: case IOCTLV_NET_SSL_DOHANDSHAKE:
{ {
// The Wii allows a socket with an in-progress connection to
// perform the SSL handshake. MbedTLS doesn't support it so
// we have to check it manually.
connecting_state = GetConnectingState();
if (connecting_state == ConnectingState::Connecting)
{
WriteReturnValue(SSL_ERR_RAGAIN, BufferIn);
ReturnValue = SSL_ERR_RAGAIN;
break;
}
else if (connecting_state == ConnectingState::None ||
connecting_state == ConnectingState::Error)
{
WriteReturnValue(SSL_ERR_SYSCALL, BufferIn);
ReturnValue = SSL_ERR_SYSCALL;
break;
}
mbedtls_ssl_context* ctx = &NetSSLDevice::_SSL[sslID].ctx; mbedtls_ssl_context* ctx = &NetSSLDevice::_SSL[sslID].ctx;
const int ret = mbedtls_ssl_handshake(ctx); const int ret = mbedtls_ssl_handshake(ctx);
if (ret != 0) if (ret != 0)
@ -550,6 +575,16 @@ void WiiSocket::Update(bool read, bool write, bool except)
{ {
case IOCTLV_SO_SENDTO: case IOCTLV_SO_SENDTO:
{ {
// The Wii allows a socket with a connection in progress to use
// sendto(). This might not be supported by the operating system.
// We have to enforce it manually.
connecting_state = GetConnectingState();
if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting)
{
ReturnValue = -SO_EAGAIN;
break;
}
u32 flags = Memory::Read_U32(BufferIn2 + 0x04); u32 flags = Memory::Read_U32(BufferIn2 + 0x04);
u32 has_destaddr = Memory::Read_U32(BufferIn2 + 0x08); u32 has_destaddr = Memory::Read_U32(BufferIn2 + 0x08);
@ -564,8 +599,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
sockaddr_in local_name = {0}; sockaddr_in local_name = {0};
if (has_destaddr) if (has_destaddr)
{ {
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferIn2 + 0x0C); const u8* addr = Memory::GetPointer(BufferIn2 + 0x0C);
WiiSockMan::Convert(*wii_name, local_name); WiiSockMan::ToNativeAddrIn(addr, &local_name);
} }
auto* to = has_destaddr ? reinterpret_cast<sockaddr*>(&local_name) : nullptr; auto* to = has_destaddr ? reinterpret_cast<sockaddr*>(&local_name) : nullptr;
@ -587,6 +622,16 @@ void WiiSocket::Update(bool read, bool write, bool except)
} }
case IOCTLV_SO_RECVFROM: case IOCTLV_SO_RECVFROM:
{ {
// The Wii allows a socket with a connection in progress to use
// recvfrom(). This might not be supported by the operating system.
// We have to enforce it manually.
connecting_state = GetConnectingState();
if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting)
{
ReturnValue = -SO_EAGAIN;
break;
}
u32 flags = Memory::Read_U32(BufferIn + 0x04); u32 flags = Memory::Read_U32(BufferIn + 0x04);
// Not a string, Windows requires a char* for recvfrom // Not a string, Windows requires a char* for recvfrom
char* data = (char*)Memory::GetPointer(BufferOut); char* data = (char*)Memory::GetPointer(BufferOut);
@ -597,8 +642,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
if (BufferOutSize2 != 0) if (BufferOutSize2 != 0)
{ {
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2); const u8* addr = Memory::GetPointer(BufferOut2);
WiiSockMan::Convert(*wii_name, local_name); WiiSockMan::ToNativeAddrIn(addr, &local_name);
} }
// Act as non blocking when SO_MSG_NONBLOCK is specified // Act as non blocking when SO_MSG_NONBLOCK is specified
@ -634,8 +679,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
if (BufferOutSize2 != 0) if (BufferOutSize2 != 0)
{ {
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2); u8* addr = Memory::GetPointer(BufferOut2);
WiiSockMan::Convert(local_name, *wii_name, addrlen); WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen);
} }
break; break;
} }
@ -672,6 +717,112 @@ void WiiSocket::Update(bool read, bool write, bool except)
} }
} }
void WiiSocket::UpdateConnectingState(s32 connect_rv)
{
if (connect_rv == -SO_EAGAIN || connect_rv == -SO_EALREADY || connect_rv == -SO_EINPROGRESS)
{
connecting_state = ConnectingState::Connecting;
}
else if (connect_rv >= 0)
{
connecting_state = ConnectingState::Connected;
}
else
{
connecting_state = ConnectingState::Error;
}
}
WiiSocket::ConnectingState WiiSocket::GetConnectingState() const
{
const auto state = Common::SaveNetworkErrorState();
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });
#ifdef _WIN32
constexpr int (*get_errno)() = &WSAGetLastError;
#else
constexpr int (*get_errno)() = []() { return errno; };
#endif
switch (connecting_state)
{
case ConnectingState::Error:
case ConnectingState::Connected:
case ConnectingState::None:
break;
case ConnectingState::Connecting:
{
const s32 nfds = fd + 1;
fd_set read_fds;
fd_set write_fds;
fd_set except_fds;
struct timeval t = {0, 0};
FD_ZERO(&read_fds);
FD_ZERO(&write_fds);
FD_ZERO(&except_fds);
FD_SET(fd, &write_fds);
FD_SET(fd, &except_fds);
auto& sm = WiiSockMan::GetInstance();
if (select(nfds, &read_fds, &write_fds, &except_fds, &t) < 0)
{
const s32 error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) connection state (err={}): {}", wii_fd,
error, sm.DecodeError(error));
return ConnectingState::Error;
}
if (FD_ISSET(fd, &write_fds) == 0 && FD_ISSET(fd, &except_fds) == 0)
break;
s32 error = 0;
socklen_t len = sizeof(error);
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0)
{
error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) error state (err={}): {}", wii_fd, error,
sm.DecodeError(error));
return ConnectingState::Error;
}
if (error != 0)
{
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed (err={}): {}", wii_fd, error,
sm.DecodeError(error));
return ConnectingState::Error;
}
// Get peername to ensure the socket is connected
sockaddr_in peer;
socklen_t peer_len = sizeof(peer);
if (getpeername(fd, reinterpret_cast<sockaddr*>(&peer), &peer_len) != 0)
{
error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed to get peername (err={}): {}",
wii_fd, error, sm.DecodeError(error));
return ConnectingState::Error;
}
INFO_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) succeeded", wii_fd);
return ConnectingState::Connected;
}
}
return connecting_state;
}
bool WiiSocket::IsTCP() const
{
const auto state = Common::SaveNetworkErrorState();
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });
int socket_type;
socklen_t option_length = sizeof(socket_type);
return getsockopt(fd, SOL_SOCKET, SO_TYPE, reinterpret_cast<char*>(&socket_type),
&option_length) == 0 &&
socket_type == SOCK_STREAM;
}
const WiiSocket::Timeout& WiiSocket::GetTimeout() const WiiSocket::Timeout& WiiSocket::GetTimeout()
{ {
if (!timeout.has_value()) if (!timeout.has_value())
@ -937,11 +1088,12 @@ void WiiSockMan::UpdatePollCommands()
pending_polls.end()); pending_polls.end());
} }
void WiiSockMan::Convert(WiiSockAddrIn const& from, sockaddr_in& to) void WiiSockMan::ToNativeAddrIn(const u8* addr, sockaddr_in* to)
{ {
to.sin_addr.s_addr = from.addr.addr; const WiiSockAddrIn from = Common::BitCastPtr<WiiSockAddrIn>(addr);
to.sin_family = from.family; to->sin_addr.s_addr = from.addr.addr;
to.sin_port = from.port; to->sin_family = from.family;
to->sin_port = from.port;
} }
s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir) s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir)
@ -981,15 +1133,15 @@ s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir)
return converted_events; return converted_events;
} }
void WiiSockMan::Convert(sockaddr_in const& from, WiiSockAddrIn& to, s32 addrlen) void WiiSockMan::ToWiiAddrIn(const sockaddr_in& from, u8* to, socklen_t addrlen)
{ {
to.addr.addr = from.sin_addr.s_addr; to[offsetof(WiiSockAddrIn, len)] =
to.family = from.sin_family & 0xFF; u8(addrlen > sizeof(WiiSockAddrIn) ? sizeof(WiiSockAddrIn) : addrlen);
to.port = from.sin_port; to[offsetof(WiiSockAddrIn, family)] = u8(from.sin_family & 0xFF);
if (addrlen < 0 || addrlen > static_cast<s32>(sizeof(WiiSockAddrIn))) const u16& from_port = from.sin_port;
to.len = sizeof(WiiSockAddrIn); memcpy(to + offsetof(WiiSockAddrIn, port), &from_port, sizeof(from_port));
else const u32& from_addr = from.sin_addr.s_addr;
to.len = addrlen; memcpy(to + offsetof(WiiSockAddrIn, addr.addr), &from_addr, sizeof(from_addr));
} }
void WiiSockMan::DoState(PointerWrap& p) void WiiSockMan::DoState(PointerWrap& p)

View File

@ -199,6 +199,14 @@ private:
void Abort(s32 value); void Abort(s32 value);
}; };
enum class ConnectingState
{
None,
Connecting,
Connected,
Error
};
friend class WiiSockMan; friend class WiiSockMan;
void SetFd(s32 s); void SetFd(s32 s);
void SetWiiFd(s32 s); void SetWiiFd(s32 s);
@ -212,11 +220,15 @@ private:
void DoSock(Request request, NET_IOCTL type); void DoSock(Request request, NET_IOCTL type);
void DoSock(Request request, SSL_IOCTL type); void DoSock(Request request, SSL_IOCTL type);
void Update(bool read, bool write, bool except); void Update(bool read, bool write, bool except);
void UpdateConnectingState(s32 connect_rv);
ConnectingState GetConnectingState() const;
bool IsValid() const { return fd >= 0; } bool IsValid() const { return fd >= 0; }
bool IsTCP() const;
s32 fd = -1; s32 fd = -1;
s32 wii_fd = -1; s32 wii_fd = -1;
bool nonBlock = false; bool nonBlock = false;
ConnectingState connecting_state = ConnectingState::None;
std::list<sockop> pending_sockops; std::list<sockop> pending_sockops;
std::optional<Timeout> timeout; std::optional<Timeout> timeout;
@ -248,8 +260,9 @@ public:
return instance; // Instantiated on first use. return instance; // Instantiated on first use.
} }
void Update(); void Update();
static void Convert(WiiSockAddrIn const& from, sockaddr_in& to); static void ToNativeAddrIn(const u8* from, sockaddr_in* to);
static void Convert(sockaddr_in const& from, WiiSockAddrIn& to, s32 addrlen = -1); static void ToWiiAddrIn(const sockaddr_in& from, u8* to,
socklen_t addrlen = sizeof(WiiSockAddrIn));
static s32 ConvertEvents(s32 events, ConvertDirection dir); static s32 ConvertEvents(s32 events, ConvertDirection dir);
void DoState(PointerWrap& p); void DoState(PointerWrap& p);

View File

@ -15,6 +15,7 @@
#include "Common/IOFile.h" #include "Common/IOFile.h"
#include "Common/Network.h" #include "Common/Network.h"
#include "Common/PcapFile.h" #include "Common/PcapFile.h"
#include "Common/ScopeGuard.h"
#include "Core/Config/MainSettings.h" #include "Core/Config/MainSettings.h"
#include "Core/ConfigManager.h" #include "Core/ConfigManager.h"
@ -90,24 +91,6 @@ void PCAPSSLCaptureLogger::OnNewSocket(s32 socket)
m_write_sequence_number[socket] = 0; m_write_sequence_number[socket] = 0;
} }
PCAPSSLCaptureLogger::ErrorState PCAPSSLCaptureLogger::SaveState() const
{
return {
errno,
#ifdef _WIN32
WSAGetLastError(),
#endif
};
}
void PCAPSSLCaptureLogger::RestoreState(const PCAPSSLCaptureLogger::ErrorState& state) const
{
errno = state.error;
#ifdef _WIN32
WSASetLastError(state.wsa_error);
#endif
}
void PCAPSSLCaptureLogger::LogSSLRead(const void* data, std::size_t length, s32 socket) void PCAPSSLCaptureLogger::LogSSLRead(const void* data, std::size_t length, s32 socket)
{ {
if (!Config::Get(Config::MAIN_NETWORK_SSL_DUMP_READ)) if (!Config::Get(Config::MAIN_NETWORK_SSL_DUMP_READ))
@ -135,7 +118,8 @@ void PCAPSSLCaptureLogger::LogWrite(const void* data, std::size_t length, s32 so
void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t length, s32 socket, void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t length, s32 socket,
sockaddr* other) sockaddr* other)
{ {
const auto state = SaveState(); const auto state = Common::SaveNetworkErrorState();
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });
sockaddr_in sock; sockaddr_in sock;
sockaddr_in peer; sockaddr_in peer;
sockaddr_in* from; sockaddr_in* from;
@ -144,16 +128,10 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l
socklen_t peer_len = sizeof(sock); socklen_t peer_len = sizeof(sock);
if (getsockname(socket, reinterpret_cast<sockaddr*>(&sock), &sock_len) != 0) if (getsockname(socket, reinterpret_cast<sockaddr*>(&sock), &sock_len) != 0)
{
RestoreState(state);
return; return;
}
if (other == nullptr && getpeername(socket, reinterpret_cast<sockaddr*>(&peer), &peer_len) != 0) if (other == nullptr && getpeername(socket, reinterpret_cast<sockaddr*>(&peer), &peer_len) != 0)
{
RestoreState(state);
return; return;
}
if (log_type == LogType::Read) if (log_type == LogType::Read)
{ {
@ -168,7 +146,6 @@ void PCAPSSLCaptureLogger::Log(LogType log_type, const void* data, std::size_t l
LogIPv4(log_type, reinterpret_cast<const u8*>(data), static_cast<u16>(length), socket, *from, LogIPv4(log_type, reinterpret_cast<const u8*>(data), static_cast<u16>(length), socket, *from,
*to); *to);
RestoreState(state);
} }
void PCAPSSLCaptureLogger::LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, void PCAPSSLCaptureLogger::LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket,

View File

@ -99,15 +99,6 @@ private:
Read, Read,
Write, Write,
}; };
struct ErrorState
{
int error;
#ifdef _WIN32
int wsa_error;
#endif
};
ErrorState SaveState() const;
void RestoreState(const ErrorState& state) const;
void Log(LogType log_type, const void* data, std::size_t length, s32 socket, sockaddr* other); void Log(LogType log_type, const void* data, std::size_t length, s32 socket, sockaddr* other);
void LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, const sockaddr_in& from, void LogIPv4(LogType log_type, const u8* data, u16 length, s32 socket, const sockaddr_in& from,