From 083116a89ce922b540ca471a98a8a1e5501491b1 Mon Sep 17 00:00:00 2001 From: Martin Michelsen Date: Sat, 14 Oct 2023 17:52:26 -0700 Subject: [PATCH] rewrite tapserver interface for better error handling --- Source/Core/Common/SocketContext.cpp | 19 +- Source/Core/Common/SocketContext.h | 5 +- Source/Core/Core/CMakeLists.txt | 16 +- Source/Core/Core/HW/EXI/BBA/TAPServer.cpp | 196 +++++++++++++++--- Source/Core/Core/HW/EXI/EXI_DeviceEthernet.h | 26 ++- .../BroadbandAdapterSettingsDialog.cpp | 18 +- 6 files changed, 223 insertions(+), 57 deletions(-) diff --git a/Source/Core/Common/SocketContext.cpp b/Source/Core/Common/SocketContext.cpp index defc333c11..15f9fd9010 100644 --- a/Source/Core/Common/SocketContext.cpp +++ b/Source/Core/Common/SocketContext.cpp @@ -8,12 +8,27 @@ namespace Common #ifdef _WIN32 SocketContext::SocketContext() { - static_cast(WSAStartup(MAKEWORD(2, 2), &m_data)); + std::lock_guard g(s_lock); + if (s_num_objects == 0) + { + static_cast(WSAStartup(MAKEWORD(2, 2), &s_data)); + } + s_num_objects++; } SocketContext::~SocketContext() { - WSACleanup(); + std::lock_guard g(s_lock); + s_num_objects--; + if (s_num_objects == 0) + { + WSACleanup(); + } } + +std::mutex SocketContext::s_lock; +size_t SocketContext::s_num_objects = 0; +WSADATA SocketContext::s_data; + #else SocketContext::SocketContext() = default; SocketContext::~SocketContext() = default; diff --git a/Source/Core/Common/SocketContext.h b/Source/Core/Common/SocketContext.h index 7e072fd8c0..0aa4929e89 100644 --- a/Source/Core/Common/SocketContext.h +++ b/Source/Core/Common/SocketContext.h @@ -5,6 +5,7 @@ #ifdef _WIN32 #include +#include #endif namespace Common @@ -23,7 +24,9 @@ public: private: #ifdef _WIN32 - WSADATA m_data; + static std::mutex s_lock; + static size_t s_num_objects; + static WSADATA s_data; #endif }; } // namespace Common diff --git a/Source/Core/Core/CMakeLists.txt b/Source/Core/Core/CMakeLists.txt index 945c57760d..10b60a1a1c 100644 --- a/Source/Core/Core/CMakeLists.txt +++ b/Source/Core/Core/CMakeLists.txt @@ -189,6 +189,10 @@ add_library(core HW/DVD/DVDThread.h HW/DVD/FileMonitor.cpp HW/DVD/FileMonitor.h + HW/EXI/BBA/TAPServer.cpp + HW/EXI/BBA/XLINK_KAI_BBA.cpp + HW/EXI/BBA/BuiltIn.cpp + HW/EXI/BBA/BuiltIn.h HW/EXI/EXI_Channel.cpp HW/EXI/EXI_Channel.h HW/EXI/EXI_Device.cpp @@ -696,10 +700,6 @@ if(WIN32) target_sources(core PRIVATE HW/EXI/BBA/TAP_Win32.cpp HW/EXI/BBA/TAP_Win32.h - HW/EXI/BBA/TAPServer.cpp - HW/EXI/BBA/XLINK_KAI_BBA.cpp - HW/EXI/BBA/BuiltIn.cpp - HW/EXI/BBA/BuiltIn.h HW/WiimoteReal/IOWin.cpp HW/WiimoteReal/IOWin.h ) @@ -713,19 +713,11 @@ if(WIN32) elseif(APPLE) target_sources(core PRIVATE HW/EXI/BBA/TAP_Apple.cpp - HW/EXI/BBA/TAPServer.cpp - HW/EXI/BBA/XLINK_KAI_BBA.cpp - HW/EXI/BBA/BuiltIn.cpp - HW/EXI/BBA/BuiltIn.h ) target_link_libraries(core PUBLIC ${IOB_LIBRARY}) elseif(UNIX) target_sources(core PRIVATE HW/EXI/BBA/TAP_Unix.cpp - HW/EXI/BBA/TAPServer.cpp - HW/EXI/BBA/XLINK_KAI_BBA.cpp - HW/EXI/BBA/BuiltIn.cpp - HW/EXI/BBA/BuiltIn.h ) if(ANDROID) target_sources(core PRIVATE diff --git a/Source/Core/Core/HW/EXI/BBA/TAPServer.cpp b/Source/Core/Core/HW/EXI/BBA/TAPServer.cpp index 3e0644a41b..f81c1f479f 100644 --- a/Source/Core/Core/HW/EXI/BBA/TAPServer.cpp +++ b/Source/Core/Core/HW/EXI/BBA/TAPServer.cpp @@ -22,23 +22,38 @@ namespace ExpansionInterface { +#ifdef _WIN32 +static constexpr auto pi_close = &closesocket; +using ws_ssize_t = int; +#else +static constexpr auto pi_close = &close; +using ws_ssize_t = ssize_t; +#endif + +#ifdef __LINUX__ +#define SEND_FLAGS MSG_NOSIGNAL +#else +#define SEND_FLAGS 0 +#endif + static int ConnectToDestination(const std::string& destination) { if (destination.empty()) { - INFO_LOG_FMT(SP1, "Cannot connect: destination is empty\n"); + ERROR_LOG_FMT(SP1, "Cannot connect: destination is empty\n"); return -1; } - size_t ss_size; + int ss_size; struct sockaddr_storage ss; memset(&ss, 0, sizeof(ss)); if (destination[0] != '/') - { // IP address or hostname + { + // IP address or hostname size_t colon_offset = destination.find(':'); if (colon_offset == std::string::npos) { - INFO_LOG_FMT(SP1, "Destination IP address does not include port\n"); + ERROR_LOG_FMT(SP1, "Destination IP address does not include port\n"); return -1; } @@ -50,11 +65,12 @@ static int ConnectToDestination(const std::string& destination) #ifndef _WIN32 } else - { // UNIX socket + { + // UNIX socket struct sockaddr_un* sun = reinterpret_cast(&ss); if (destination.size() + 1 > sizeof(sun->sun_path)) { - INFO_LOG_FMT(SP1, "Socket path is too long, unable to init BBA\n"); + ERROR_LOG_FMT(SP1, "Socket path is too long, unable to init BBA\n"); return -1; } sun->sun_family = AF_UNIX; @@ -64,7 +80,7 @@ static int ConnectToDestination(const std::string& destination) } else { - INFO_LOG_FMT(SP1, "UNIX sockets are not supported on Windows\n"); + ERROR_LOG_FMT(SP1, "UNIX sockets are not supported on Windows\n"); return -1; #endif } @@ -72,7 +88,7 @@ static int ConnectToDestination(const std::string& destination) int fd = socket(ss.ss_family, SOCK_STREAM, (ss.ss_family == AF_INET) ? IPPROTO_TCP : 0); if (fd == -1) { - INFO_LOG_FMT(SP1, "Couldn't create socket; unable to init BBA\n"); + ERROR_LOG_FMT(SP1, "Couldn't create socket; unable to init BBA\n"); return -1; } @@ -86,7 +102,7 @@ static int ConnectToDestination(const std::string& destination) { std::string s = Common::LastStrerrorString(); INFO_LOG_FMT(SP1, "Couldn't connect socket ({}), unable to init BBA\n", s.c_str()); - close(fd); + pi_close(fd); return -1; } @@ -98,12 +114,44 @@ bool CEXIETHERNET::TAPServerNetworkInterface::Activate() if (IsActivated()) return true; - fd = ConnectToDestination(m_destination); + m_fd = ConnectToDestination(m_destination); INFO_LOG_FMT(SP1, "BBA initialized."); return RecvInit(); } +void CEXIETHERNET::TAPServerNetworkInterface::Deactivate() +{ + pi_close(m_fd); + m_fd = -1; + + m_read_enabled.Clear(); + m_read_shutdown.Set(); + if (m_read_thread.joinable()) + m_read_thread.join(); +} + +bool CEXIETHERNET::TAPServerNetworkInterface::IsActivated() +{ + return (m_fd >= 0); +} + +bool CEXIETHERNET::TAPServerNetworkInterface::RecvInit() +{ + m_read_thread = std::thread(&CEXIETHERNET::TAPServerNetworkInterface::ReadThreadHandler, this); + return true; +} + +void CEXIETHERNET::TAPServerNetworkInterface::RecvStart() +{ + m_read_enabled.Set(); +} + +void CEXIETHERNET::TAPServerNetworkInterface::RecvStop() +{ + m_read_enabled.Clear(); +} + bool CEXIETHERNET::TAPServerNetworkInterface::SendFrame(const u8* frame, u32 size) { { @@ -111,13 +159,16 @@ bool CEXIETHERNET::TAPServerNetworkInterface::SendFrame(const u8* frame, u32 siz INFO_LOG_FMT(SP1, "SendFrame {}\n{}", size, s); } - auto size16 = u16(size); - if (write(fd, &size16, 2) != 2) + // On Windows, the data pointer is of type const char*; on other systems it is + // of type const void*. This is the reason for the reinterpret_cast here and + // in the other send/recv calls in this file. + u8 size_bytes[2] = {static_cast(size), static_cast(size >> 8)}; + if (send(m_fd, reinterpret_cast(size_bytes), 2, SEND_FLAGS) != 2) { ERROR_LOG_FMT(SP1, "SendFrame(): could not write size field"); return false; } - int written_bytes = write(fd, frame, size); + int written_bytes = send(m_fd, reinterpret_cast(frame), size, SEND_FLAGS); if (u32(written_bytes) != size) { ERROR_LOG_FMT(SP1, "SendFrame(): expected to write {} bytes, instead wrote {}", size, @@ -133,45 +184,122 @@ bool CEXIETHERNET::TAPServerNetworkInterface::SendFrame(const u8* frame, u32 siz void CEXIETHERNET::TAPServerNetworkInterface::ReadThreadHandler() { - while (!readThreadShutdown.IsSet()) + while (!m_read_shutdown.IsSet()) { fd_set rfds; FD_ZERO(&rfds); - FD_SET(fd, &rfds); + FD_SET(m_fd, &rfds); timeval timeout; timeout.tv_sec = 0; timeout.tv_usec = 50000; - if (select(fd + 1, &rfds, nullptr, nullptr, &timeout) <= 0) + if (select(m_fd + 1, &rfds, nullptr, nullptr, &timeout) <= 0) continue; - u16 size; - if (read(fd, &size, 2) != 2) + // The tapserver protocol is very simple: there is a 16-bit little-endian + // size field, followed by that many bytes of packet data + switch (m_read_state) { - ERROR_LOG_FMT(SP1, "Failed to read size field from BBA: {}", Common::LastStrerrorString()); + case ReadState::Size: + { + u8 size_bytes[2]; + ws_ssize_t bytes_read = recv(m_fd, reinterpret_cast(size_bytes), 2, 0); + if (bytes_read == 1) + { + m_read_state = ReadState::SizeHigh; + m_read_packet_bytes_remaining = size_bytes[0]; + } + else if (bytes_read == 2) + { + m_read_packet_bytes_remaining = size_bytes[0] | (size_bytes[1] << 8); + if (m_read_packet_bytes_remaining > BBA_RECV_SIZE) + { + ERROR_LOG_FMT(SP1, "Packet is too large ({} bytes); dropping it", + m_read_packet_bytes_remaining); + m_read_state = ReadState::Skip; + } + else + { + m_read_state = ReadState::Data; + } + } + else + { + ERROR_LOG_FMT(SP1, "Failed to read size field from BBA: {}", Common::LastStrerrorString()); + } + break; } - else + case ReadState::SizeHigh: { - int read_bytes = read(fd, m_eth_ref->mRecvBuffer.get(), size); - if (read_bytes < 0) + // This handles the annoying case where only one byte of the size field + // was available earlier. + u8 size_high = 0; + ws_ssize_t bytes_read = recv(m_fd, reinterpret_cast(&size_high), 1, 0); + if (bytes_read == 1) { - ERROR_LOG_FMT(SP1, "Failed to read packet data from BBA: {}", Common::LastStrerrorString()); + m_read_packet_bytes_remaining |= (size_high << 8); + if (m_read_packet_bytes_remaining > BBA_RECV_SIZE) + { + ERROR_LOG_FMT(SP1, "Packet is too large ({} bytes); dropping it", + m_read_packet_bytes_remaining); + m_read_state = ReadState::Skip; + } + else + { + m_read_state = ReadState::Data; + } } - else if (readEnabled.IsSet()) + else { - std::string data_string = ArrayToString(m_eth_ref->mRecvBuffer.get(), read_bytes, 0x10); - INFO_LOG_FMT(SP1, "Read data: {}", data_string); - m_eth_ref->mRecvBufferLength = read_bytes; - m_eth_ref->RecvHandlePacket(); + ERROR_LOG_FMT(SP1, "Failed to read split size field from BBA: {}", + Common::LastStrerrorString()); } + break; + } + case ReadState::Data: + { + ws_ssize_t bytes_read = + recv(m_fd, reinterpret_cast(m_eth_ref->mRecvBuffer.get() + m_read_packet_offset), + m_read_packet_bytes_remaining, 0); + if (bytes_read <= 0) + { + ERROR_LOG_FMT(SP1, "Failed to read data from BBA: {}", Common::LastStrerrorString()); + } + else + { + m_read_packet_offset += bytes_read; + m_read_packet_bytes_remaining -= bytes_read; + if (m_read_packet_bytes_remaining == 0) + { + m_eth_ref->mRecvBufferLength = m_read_packet_offset; + m_eth_ref->RecvHandlePacket(); + m_read_state = ReadState::Size; + m_read_packet_offset = 0; + } + } + break; + } + case ReadState::Skip: + { + ws_ssize_t bytes_read = recv(m_fd, reinterpret_cast(m_eth_ref->mRecvBuffer.get()), + std::min(m_read_packet_bytes_remaining, BBA_RECV_SIZE), 0); + if (bytes_read <= 0) + { + ERROR_LOG_FMT(SP1, "Failed to read data from BBA: {}", Common::LastStrerrorString()); + } + else + { + m_read_packet_bytes_remaining -= bytes_read; + if (m_read_packet_bytes_remaining == 0) + { + m_read_state = ReadState::Size; + m_read_packet_offset = 0; + } + } + break; + } } } } -bool CEXIETHERNET::TAPServerNetworkInterface::RecvInit() -{ - readThread = std::thread(&CEXIETHERNET::TAPServerNetworkInterface::ReadThreadHandler, this); - return true; -} - } // namespace ExpansionInterface diff --git a/Source/Core/Core/HW/EXI/EXI_DeviceEthernet.h b/Source/Core/Core/HW/EXI/EXI_DeviceEthernet.h index 301420fbea..067f5d59c0 100644 --- a/Source/Core/Core/HW/EXI/EXI_DeviceEthernet.h +++ b/Source/Core/Core/HW/EXI/EXI_DeviceEthernet.h @@ -17,6 +17,7 @@ #include "Common/Flag.h" #include "Common/Network.h" +#include "Common/SocketContext.h" #include "Core/HW/EXI/BBA/BuiltIn.h" #include "Core/HW/EXI/EXI_Device.h" @@ -362,21 +363,42 @@ private: #endif }; - class TAPServerNetworkInterface : public TAPNetworkInterface + class TAPServerNetworkInterface : public NetworkInterface { public: explicit TAPServerNetworkInterface(CEXIETHERNET* eth_ref, const std::string& destination) - : TAPNetworkInterface(eth_ref), m_destination(destination) + : NetworkInterface(eth_ref), m_destination(destination) { } public: bool Activate() override; + void Deactivate() override; + bool IsActivated() override; bool SendFrame(const u8* frame, u32 size) override; bool RecvInit() override; + void RecvStart() override; + void RecvStop() override; private: + enum class ReadState + { + Size, + SizeHigh, + Data, + Skip, + }; + std::string m_destination; + Common::SocketContext m_socket_context; + + int m_fd = -1; + ReadState m_read_state = ReadState::Size; + u16 m_read_packet_offset; + u16 m_read_packet_bytes_remaining; + std::thread m_read_thread; + Common::Flag m_read_enabled; + Common::Flag m_read_shutdown; void ReadThreadHandler(); }; diff --git a/Source/Core/DolphinQt/Settings/BroadbandAdapterSettingsDialog.cpp b/Source/Core/DolphinQt/Settings/BroadbandAdapterSettingsDialog.cpp index f8abbc77ab..9716d8f211 100644 --- a/Source/Core/DolphinQt/Settings/BroadbandAdapterSettingsDialog.cpp +++ b/Source/Core/DolphinQt/Settings/BroadbandAdapterSettingsDialog.cpp @@ -49,13 +49,19 @@ void BroadbandAdapterSettingsDialog::InitControls() break; case Type::TapServer: - address_label = new QLabel(tr("UNIX socket path or netloc (address:port):")); - address_placeholder = QStringLiteral("/tmp/dolphin-tap"); current_address = QString::fromStdString(Config::Get(Config::MAIN_BBA_TAPSERVER_DESTINATION)); - description = - new QLabel(tr("On macOS and Linux, the default value \"/tmp/dolphin-tap\" will work with " - "tapserver and newserv. On Windows, you must enter an IP address and port.")); - +#ifdef _WIN32 + address_label = new QLabel(tr("Destination (address:port):")); + address_placeholder = QStringLiteral(""); + description = new QLabel( + tr("Enter the IP address and port of the tapserver instance you want to connect to.")); +#else + address_label = new QLabel(tr("Destination (UNIX socket path or address:port):")); + address_placeholder = QStringLiteral("/tmp/dolphin-tap"); + description = new QLabel(tr( + "The default value \"/tmp/dolphin-tap\" will work with a local tapserver and newserv. You " + "can also enter a network location (address:port) to connect to a remote tapserver.")); +#endif window_title = tr("BBA destination address"); break;