From e2f972571906b909ad19cb922b1fa5549e3522da Mon Sep 17 00:00:00 2001 From: Exzap <13877693+Exzap@users.noreply.github.com> Date: Thu, 18 Apr 2024 19:22:28 +0200 Subject: [PATCH] prudp: Code cleanup --- src/Cemu/nex/nex.cpp | 42 +-- src/Cemu/nex/prudp.cpp | 642 ++++++++++++++++++++--------------------- src/Cemu/nex/prudp.h | 148 +++++----- 3 files changed, 410 insertions(+), 422 deletions(-) diff --git a/src/Cemu/nex/nex.cpp b/src/Cemu/nex/nex.cpp index d0857507..973a4395 100644 --- a/src/Cemu/nex/nex.cpp +++ b/src/Cemu/nex/nex.cpp @@ -106,7 +106,7 @@ nexService::nexService() nexService::nexService(prudpClient* con) : nexService() { - if (con->isConnected() == false) + if (con->IsConnected() == false) cemu_assert_suspicious(); this->conNexService = con; bufferReceive = std::vector(1024 * 4); @@ -191,7 +191,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest) uint32 callId = _currentCallId; _currentCallId++; // check state of connection - if (conNexService->getConnectionState() != prudpClient::STATE_CONNECTED) + if (conNexService->GetConnectionState() != prudpClient::ConnectionState::Connected) { nexServiceResponse_t response = { 0 }; response.isSuccessful = false; @@ -214,7 +214,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest) assert_dbg(); memcpy((packetBuffer + 0x0D), &queuedRequest->parameterData.front(), queuedRequest->parameterData.size()); sint32 length = 0xD + (sint32)queuedRequest->parameterData.size(); - conNexService->sendDatagram(packetBuffer, length, true); + conNexService->SendDatagram(packetBuffer, length, true); // remember request nexActiveRequestInfo_t requestInfo = { 0 }; requestInfo.callId = callId; @@ -299,13 +299,13 @@ void nexService::registerForAsyncProcessing() void nexService::updateTemporaryConnections() { // check for connection - conNexService->update(); - if (conNexService->isConnected()) + conNexService->Update(); + if (conNexService->IsConnected()) { if (connectionState == STATE_CONNECTING) connectionState = STATE_CONNECTED; } - if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED) + if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected) connectionState = STATE_DISCONNECTED; } @@ -356,18 +356,18 @@ void nexService::sendRequestResponse(nexServiceRequest_t* request, uint32 errorC // update length field *(uint32*)response.getDataPtr() = response.getWriteIndex()-4; if(request->nex->conNexService) - request->nex->conNexService->sendDatagram(response.getDataPtr(), response.getWriteIndex(), true); + request->nex->conNexService->SendDatagram(response.getDataPtr(), response.getWriteIndex(), true); } void nexService::updateNexServiceConnection() { - if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED) + if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected) { this->connectionState = STATE_DISCONNECTED; return; } - conNexService->update(); - sint32 datagramLen = conNexService->receiveDatagram(bufferReceive); + conNexService->Update(); + sint32 datagramLen = conNexService->ReceiveDatagram(bufferReceive); if (datagramLen > 0) { if (nexIsRequest(&bufferReceive[0], datagramLen)) @@ -454,12 +454,12 @@ bool _extractStationUrlParamValue(const char* urlStr, const char* paramName, cha return false; } -void nexServiceAuthentication_parseStationURL(char* urlStr, stationUrl_t* stationUrl) +void nexServiceAuthentication_parseStationURL(char* urlStr, prudpStationUrl* stationUrl) { // example: // prudps:/address=34.210.xxx.xxx;port=60181;CID=1;PID=2;sid=1;stream=10;type=2 - memset(stationUrl, 0, sizeof(stationUrl_t)); + memset(stationUrl, 0, sizeof(prudpStationUrl)); char optionValue[128]; if (_extractStationUrlParamValue(urlStr, "address", optionValue, sizeof(optionValue))) @@ -499,7 +499,7 @@ typedef struct sint32 kerberosTicketSize; uint8 kerberosTicket2[4096]; sint32 kerberosTicket2Size; - stationUrl_t server; + prudpStationUrl server; // progress info bool hasError; bool done; @@ -611,18 +611,18 @@ void nexServiceSecure_handleResponse_RegisterEx(nexService* nex, nexServiceRespo return; } -nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* accessKey, const char* nexToken) +nexService* nex_secureLogin(prudpAuthServerInfo* authServerInfo, const char* accessKey, const char* nexToken) { prudpClient* prudpSecureSock = new prudpClient(authServerInfo->server.ip, authServerInfo->server.port, accessKey, authServerInfo); // wait until connected while (true) { - prudpSecureSock->update(); - if (prudpSecureSock->isConnected()) + prudpSecureSock->Update(); + if (prudpSecureSock->IsConnected()) { break; } - if (prudpSecureSock->getConnectionState() == prudpClient::STATE_DISCONNECTED) + if (prudpSecureSock->GetConnectionState() == prudpClient::ConnectionState::Disconnected) { // timeout or disconnected cemuLog_log(LogType::Force, "NEX: Secure login connection time-out"); @@ -638,7 +638,7 @@ nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* access nexPacketBuffer packetBuffer(tempNexBufferArray, sizeof(tempNexBufferArray), true); char clientStationUrl[256]; - sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->getSourcePort()); + sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->GetSourcePort()); // station url list packetBuffer.writeU32(1); packetBuffer.writeString(clientStationUrl); @@ -737,9 +737,9 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer return nullptr; } // auth info - auto authServerInfo = std::make_unique(); + auto authServerInfo = std::make_unique(); // decrypt ticket - RC4Ctx_t rc4Ticket; + RC4Ctx rc4Ticket; RC4_initCtx(&rc4Ticket, kerberosKey, 16); RC4_transform(&rc4Ticket, nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, nexAuthService.kerberosTicket2); nexPacketBuffer packetKerberosTicket(nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, false); @@ -756,7 +756,7 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer memcpy(authServerInfo->kerberosKey, kerberosKey, 16); memcpy(authServerInfo->secureKey, secureKey, 16); - memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(stationUrl_t)); + memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(prudpStationUrl)); authServerInfo->userPid = pid; return nex_secureLogin(authServerInfo.get(), accessKey, nexToken); diff --git a/src/Cemu/nex/prudp.cpp b/src/Cemu/nex/prudp.cpp index 7c01bec7..5c773fe7 100644 --- a/src/Cemu/nex/prudp.cpp +++ b/src/Cemu/nex/prudp.cpp @@ -1,72 +1,57 @@ #include "prudp.h" #include "util/crypto/md5.h" -#include -#include +#include +#include #include -void swap(unsigned char *a, unsigned char *b) +static void KSA(unsigned char* key, int keyLen, unsigned char* S) { - int tmp = *a; - *a = *b; - *b = tmp; -} - -void KSA(unsigned char *key, int keyLen, unsigned char *S) -{ - int j = 0; - for (int i = 0; i < RC4_N; i++) S[i] = i; - - for (int i = 0; i < RC4_N; i++) + int j = 0; + for (int i = 0; i < RC4_N; i++) { j = (j + S[i] + key[i % keyLen]) % RC4_N; - - swap(&S[i], &S[j]); + std::swap(S[i], S[j]); } } -void PRGA(unsigned char *S, unsigned char* input, int len, unsigned char* output) +static void PRGA(unsigned char* S, unsigned char* input, int len, unsigned char* output) { - int i = 0; - int j = 0; - - for (size_t n = 0; n < len; n++) + for (size_t n = 0; n < len; n++) { - i = (i + 1) % RC4_N; - j = (j + S[i]) % RC4_N; - - swap(&S[i], &S[j]); + int i = (i + 1) % RC4_N; + int j = (j + S[i]) % RC4_N; + std::swap(S[i], S[j]); int rnd = S[(S[i] + S[j]) % RC4_N]; - output[n] = rnd ^ input[n]; } } -void RC4(char* key, unsigned char* input, int len, unsigned char* output) +static void RC4(char* key, unsigned char* input, int len, unsigned char* output) { unsigned char S[RC4_N]; KSA((unsigned char*)key, (int)strlen(key), S); PRGA(S, input, len, output); } -void RC4_initCtx(RC4Ctx_t* rc4Ctx, const char* key) +void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key) { rc4Ctx->i = 0; rc4Ctx->j = 0; KSA((unsigned char*)key, (int)strlen(key), rc4Ctx->S); } -void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen) +void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen) { rc4Ctx->i = 0; rc4Ctx->j = 0; KSA(key, keyLen, rc4Ctx->S); } -void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output) +void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output) { int i = rc4Ctx->i; int j = rc4Ctx->j; @@ -75,13 +60,10 @@ void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned cha { i = (i + 1) % RC4_N; j = (j + rc4Ctx->S[i]) % RC4_N; - - swap(&rc4Ctx->S[i], &rc4Ctx->S[j]); + std::swap(rc4Ctx->S[i], rc4Ctx->S[j]); int rnd = rc4Ctx->S[(rc4Ctx->S[i] + rc4Ctx->S[j]) % RC4_N]; - output[n] = rnd ^ input[n]; } - rc4Ctx->i = i; rc4Ctx->j = j; } @@ -91,34 +73,14 @@ uint32 prudpGetMSTimestamp() return GetTickCount(); } -std::bitset<10000> _portUsageMask; - -uint16 getRandomSrcPRUDPPort() -{ - while (true) - { - sint32 p = rand() % 10000; - if (_portUsageMask.test(p)) - continue; - _portUsageMask.set(p); - return 40000 + p; - } - return 0; -} - -void releasePRUDPPort(uint16 port) -{ - uint32 bitIndex = port - 40000; - _portUsageMask.reset(bitIndex); -} - std::mt19937_64 prudpRG(GetTickCount()); -// workaround for static asserts when using uniform_int_distribution -boost::random::uniform_int_distribution prudpDis8(0, 0xFF); +// workaround for static asserts when using uniform_int_distribution (see https://github.com/cemu-project/Cemu/issues/48) +boost::random::uniform_int_distribution prudpRandomDistribution8(0, 0xFF); +boost::random::uniform_int_distribution prudpRandomDistributionPortGen(0, 10000); uint8 prudp_generateRandomU8() { - return prudpDis8(prudpRG); + return prudpRandomDistribution8(prudpRG); } uint32 prudp_generateRandomU32() @@ -133,7 +95,29 @@ uint32 prudp_generateRandomU32() return v; } -uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length) +std::bitset<10000> _portUsageMask; + +static uint16 AllocateRandomSrcPRUDPPort() +{ + while (true) + { + sint32 p = prudpRandomDistributionPortGen(prudpRG); + if (_portUsageMask.test(p)) + continue; + _portUsageMask.set(p); + return 40000 + p; + } +} + +static void ReleasePRUDPSrcPort(uint16 port) +{ + cemu_assert_debug(port >= 40000); + uint32 bitIndex = port - 40000; + cemu_assert_debug(_portUsageMask.test(bitIndex)); + _portUsageMask.reset(bitIndex); +} + +static uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length) { uint32 checksum32 = 0; for (sint32 i = 0; i < length / 4; i++) @@ -141,7 +125,7 @@ uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length) checksum32 += *(uint32*)(data + i * 4); } uint8 checksum = checksumBase; - for (sint32 i = length&(~3); i < length; i++) + for (sint32 i = length & (~3); i < length; i++) { checksum += data[i]; } @@ -161,16 +145,16 @@ sint32 prudpPacket::calculateSizeFromPacketData(uint8* data, sint32 length) return 0; // get flags fields uint16 typeAndFlags = *(uint16*)(data + 0x02); - uint16 type = (typeAndFlags&0xF); + uint16 type = (typeAndFlags & 0xF); uint16 flags = (typeAndFlags >> 4); - if ((flags&FLAG_HAS_SIZE) == 0) + if ((flags & FLAG_HAS_SIZE) == 0) return length; // without a size field, we cant calculate the length sint32 calculatedSize; if (type == TYPE_SYN) { if (length < (0xB + 0x4 + 2)) return 0; - uint16 payloadSize = *(uint16*)(data+0xB+0x4); + uint16 payloadSize = *(uint16*)(data + 0xB + 0x4); calculatedSize = 0xB + 0x4 + 2 + (sint32)payloadSize + 1; // base header + connection signature (SYN param) + payloadSize field + checksum after payload if (calculatedSize > length) return 0; @@ -212,7 +196,7 @@ sint32 prudpPacket::calculateSizeFromPacketData(uint8* data, sint32 length) return length; } -prudpPacket::prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature) +prudpPacket::prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature) { this->src = src; this->dst = dst; @@ -228,7 +212,7 @@ prudpPacket::prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 bool prudpPacket::requiresAck() { - return (flags&FLAG_NEED_ACK) != 0; + return (flags & FLAG_NEED_ACK) != 0; } sint32 prudpPacket::buildData(uint8* output, sint32 maxLength) @@ -352,7 +336,8 @@ prudpIncomingPacket::prudpIncomingPacket() streamSettings = nullptr; } -prudpIncomingPacket::prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length) : prudpIncomingPacket() +prudpIncomingPacket::prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length) + : prudpIncomingPacket() { if (length < 0xB + 1) { @@ -418,7 +403,7 @@ prudpIncomingPacket::prudpIncomingPacket(prudpStreamSettings_t* streamSettings, bool hasPayloadSize = (this->flags & prudpPacket::FLAG_HAS_SIZE) != 0; // verify length - if ((length-readIndex) < 1+(hasPayloadSize?2:0)) + if ((length - readIndex) < 1 + (hasPayloadSize ? 2 : 0)) { // too short isInvalid = true; @@ -475,57 +460,45 @@ void prudpIncomingPacket::decrypt() RC4_transform(&streamSettings->rc4Server, &packetData.front(), (int)packetData.size(), &packetData.front()); } -#define PRUDP_VPORT(__streamType, __port) (((__streamType)<<4) | (__port)) +#define PRUDP_VPORT(__streamType, __port) (((__streamType) << 4) | (__port)) prudpClient::prudpClient() { - currentConnectionState = STATE_CONNECTING; - serverConnectionSignature = 0; - clientConnectionSignature = 0; - hasSentCon = false; - outgoingSequenceId = 0; - incomingSequenceId = 0; + m_currentConnectionState = ConnectionState::Connecting; + m_serverConnectionSignature = 0; + m_clientConnectionSignature = 0; + m_incomingSequenceId = 0; - clientSessionId = 0; - serverSessionId = 0; - - isSecureConnection = false; + m_clientSessionId = 0; + m_serverSessionId = 0; } -prudpClient::~prudpClient() +prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key) + : prudpClient() { - if (srcPort != 0) - { - releasePRUDPPort(srcPort); - closesocket(socketUdp); - } -} - -prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key) : prudpClient() -{ - this->dstIp = dstIp; - this->dstPort = dstPort; + m_dstIp = dstIp; + m_dstPort = dstPort; // get unused random source port for (sint32 tries = 0; tries < 5; tries++) { - srcPort = getRandomSrcPRUDPPort(); + m_srcPort = AllocateRandomSrcPRUDPPort(); // create and bind udp socket - socketUdp = socket(AF_INET, SOCK_DGRAM, 0); + m_socketUdp = socket(AF_INET, SOCK_DGRAM, 0); struct sockaddr_in udpServer; udpServer.sin_family = AF_INET; udpServer.sin_addr.s_addr = INADDR_ANY; - udpServer.sin_port = htons(srcPort); - if (bind(socketUdp, (struct sockaddr *)&udpServer, sizeof(udpServer)) == SOCKET_ERROR) + udpServer.sin_port = htons(m_srcPort); + if (bind(m_socketUdp, (struct sockaddr*)&udpServer, sizeof(udpServer)) == SOCKET_ERROR) { + ReleasePRUDPSrcPort(m_srcPort); + m_srcPort = 0; if (tries == 4) { cemuLog_log(LogType::Force, "PRUDP: Failed to bind UDP socket"); - currentConnectionState = STATE_DISCONNECTED; - srcPort = 0; + m_currentConnectionState = ConnectionState::Disconnected; return; } - releasePRUDPPort(srcPort); - closesocket(socketUdp); + closesocket(m_socketUdp); continue; } else @@ -533,79 +506,77 @@ prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key) : prudpC } // set socket to non-blocking mode #if BOOST_OS_WINDOWS - u_long nonBlockingMode = 1; // 1 to enable non-blocking socket - ioctlsocket(socketUdp, FIONBIO, &nonBlockingMode); + u_long nonBlockingMode = 1; // 1 to enable non-blocking socket + ioctlsocket(m_socketUdp, FIONBIO, &nonBlockingMode); #else int flags = fcntl(socketUdp, F_GETFL); fcntl(socketUdp, F_SETFL, flags | O_NONBLOCK); #endif // generate frequently used parameters - this->vport_src = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0xF); - this->vport_dst = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0x1); + this->m_srcVPort = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0xF); + this->m_dstVPort = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0x1); // set stream settings uint8 checksumBase = 0; for (sint32 i = 0; key[i] != '\0'; i++) { checksumBase += key[i]; } - streamSettings.checksumBase = checksumBase; + m_streamSettings.checksumBase = checksumBase; MD5_CTX md5Ctx; MD5_Init(&md5Ctx); MD5_Update(&md5Ctx, key, (int)strlen(key)); - MD5_Final(streamSettings.accessKeyDigest, &md5Ctx); + MD5_Final(m_streamSettings.accessKeyDigest, &md5Ctx); // init stream ciphers - RC4_initCtx(&streamSettings.rc4Server, "CD&ML"); - RC4_initCtx(&streamSettings.rc4Client, "CD&ML"); + RC4_initCtx(&m_streamSettings.rc4Server, "CD&ML"); + RC4_initCtx(&m_streamSettings.rc4Client, "CD&ML"); // send syn packet - prudpPacket* synPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_SYN, prudpPacket::FLAG_NEED_ACK, 0, 0, 0); - queuePacket(synPacket, dstIp, dstPort); - outgoingSequenceId++; + SendCurrentHandshakePacket(); // set incoming sequence id to 1 - incomingSequenceId = 1; + m_incomingSequenceId = 1; } -prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo) : prudpClient(dstIp, dstPort, key) +prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo) + : prudpClient(dstIp, dstPort, key) { - RC4_initCtx(&streamSettings.rc4Server, authInfo->secureKey, 16); - RC4_initCtx(&streamSettings.rc4Client, authInfo->secureKey, 16); - this->isSecureConnection = true; - memcpy(&this->authInfo, authInfo, sizeof(authServerInfo_t)); + RC4_initCtx(&m_streamSettings.rc4Server, authInfo->secureKey, 16); + RC4_initCtx(&m_streamSettings.rc4Client, authInfo->secureKey, 16); + this->m_isSecureConnection = true; + memcpy(&this->m_authInfo, authInfo, sizeof(prudpAuthServerInfo)); } -bool prudpClient::isConnected() +prudpClient::~prudpClient() { - return currentConnectionState == STATE_CONNECTED; + if (m_srcPort != 0) + { + ReleasePRUDPSrcPort(m_srcPort); + closesocket(m_socketUdp); + } } -uint8 prudpClient::getConnectionState() +void prudpClient::AcknowledgePacket(uint16 sequenceId) { - return currentConnectionState; -} - -void prudpClient::acknowledgePacket(uint16 sequenceId) -{ - auto it = std::begin(list_packetsWithAckReq); - while (it != std::end(list_packetsWithAckReq)) + auto it = std::begin(m_dataPacketsWithAckReq); + while (it != std::end(m_dataPacketsWithAckReq)) { if (it->packet->GetSequenceId() == sequenceId) { delete it->packet; - list_packetsWithAckReq.erase(it); + m_dataPacketsWithAckReq.erase(it); return; } it++; } } -void prudpClient::sortIncomingDataPacket(prudpIncomingPacket* incomingPacket) +void prudpClient::SortIncomingDataPacket(std::unique_ptr incomingPacket) { uint16 sequenceIdIncomingPacket = incomingPacket->sequenceId; // find insert index sint32 insertIndex = 0; - while (insertIndex < queue_incomingPackets.size() ) + while (insertIndex < m_incomingPacketQueue.size()) { - uint16 seqDif = sequenceIdIncomingPacket - queue_incomingPackets[insertIndex]->sequenceId; - if (seqDif&0x8000) + uint16 seqDif = sequenceIdIncomingPacket - m_incomingPacketQueue[insertIndex]->sequenceId; + if (seqDif & 0x8000) break; // negative seqDif -> insert before current element #ifdef CEMU_DEBUG_ASSERT if (seqDif == 0) @@ -613,39 +584,83 @@ void prudpClient::sortIncomingDataPacket(prudpIncomingPacket* incomingPacket) #endif insertIndex++; } - // insert - sint32 currentSize = (sint32)queue_incomingPackets.size(); - queue_incomingPackets.resize(currentSize+1); - for(sint32 i=currentSize; i>insertIndex; i--) + m_incomingPacketQueue.insert(m_incomingPacketQueue.begin() + insertIndex, std::move(incomingPacket)); + // debug check if packets are really ordered by sequence id +#ifdef CEMU_DEBUG_ASSERT + for (sint32 i = 1; i < m_incomingPacketQueue.size(); i++) { - queue_incomingPackets[i] = queue_incomingPackets[i - 1]; + uint16 seqDif = m_incomingPacketQueue[i]->sequenceId - m_incomingPacketQueue[i - 1]->sequenceId; + if (seqDif & 0x8000) + seqDif = -seqDif; + if (seqDif >= 0x8000) + assert_dbg(); } - queue_incomingPackets[insertIndex] = incomingPacket; +#endif } -sint32 prudpClient::kerberosEncryptData(uint8* input, sint32 length, uint8* output) +sint32 prudpClient::KerberosEncryptData(uint8* input, sint32 length, uint8* output) { - RC4Ctx_t rc4Kerberos; - RC4_initCtx(&rc4Kerberos, this->authInfo.secureKey, 16); + RC4Ctx rc4Kerberos; + RC4_initCtx(&rc4Kerberos, this->m_authInfo.secureKey, 16); memcpy(output, input, length); RC4_transform(&rc4Kerberos, output, length, output); // calculate and append hmac - hmacMD5(this->authInfo.secureKey, 16, output, length, output+length); + hmacMD5(this->m_authInfo.secureKey, 16, output, length, output + length); return length + 16; } -void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket) +// (re)sends either CON or SYN based on what stage of the login we are at +// the sequenceId for both is hardcoded for both because we'll never send anything in between +void prudpClient::SendCurrentHandshakePacket() { - if(incomingPacket->type == prudpPacket::TYPE_PING) + if (!m_hasSynAck) { - if (incomingPacket->flags&prudpPacket::FLAG_ACK) + // send syn (with a fixed sequenceId of 0) + prudpPacket synPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_SYN, prudpPacket::FLAG_NEED_ACK, 0, 0, 0); + DirectSendPacket(&synPacket); + } + else + { + // send con (with a fixed sequenceId of 1) + prudpPacket conPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_CON, prudpPacket::FLAG_NEED_ACK | prudpPacket::FLAG_RELIABLE, this->m_clientSessionId, 1, m_serverConnectionSignature); + if (this->m_isSecureConnection) + { + uint8 tempBuffer[512]; + nexPacketBuffer conData(tempBuffer, sizeof(tempBuffer), true); + conData.writeU32(this->m_clientConnectionSignature); + conData.writeBuffer(m_authInfo.secureTicket, m_authInfo.secureTicketLength); + // encrypted request data + uint8 requestData[4 * 3]; + uint8 requestDataEncrypted[4 * 3 + 0x10]; + *(uint32*)(requestData + 0x0) = m_authInfo.userPid; + *(uint32*)(requestData + 0x4) = m_authInfo.server.cid; + *(uint32*)(requestData + 0x8) = prudp_generateRandomU32(); // todo - check value + sint32 encryptedSize = KerberosEncryptData(requestData, sizeof(requestData), requestDataEncrypted); + conData.writeBuffer(requestDataEncrypted, encryptedSize); + conPacket.setData(conData.getDataPtr(), conData.getWriteIndex()); + } + else + { + conPacket.setData((uint8*)&this->m_clientConnectionSignature, sizeof(uint32)); + } + DirectSendPacket(&conPacket); + } + m_lastHandshakeTimestamp = prudpGetMSTimestamp(); + m_handshakeRetryCount++; +} + +void prudpClient::HandleIncomingPacket(std::unique_ptr incomingPacket) +{ + if (incomingPacket->type == prudpPacket::TYPE_PING) + { + if (incomingPacket->flags & prudpPacket::FLAG_ACK) { // ack for our ping packet - if(incomingPacket->flags&prudpPacket::FLAG_NEED_ACK) + if (incomingPacket->flags & prudpPacket::FLAG_NEED_ACK) cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected ping packet with both ACK and NEED_ACK set"); - if(m_unacknowledgedPingCount > 0) + if (m_unacknowledgedPingCount > 0) { - if(incomingPacket->sequenceId == m_outgoingSequenceId_ping) + if (incomingPacket->sequenceId == m_outgoingSequenceId_ping) { cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet ACK (unacknowledged count: {})", m_unacknowledgedPingCount); m_unacknowledgedPingCount = 0; @@ -660,140 +675,127 @@ void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket) cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet ACK which we dont need"); } } - else if (incomingPacket->flags&prudpPacket::FLAG_NEED_ACK) + else if (incomingPacket->flags & prudpPacket::FLAG_NEED_ACK) { // other side is asking for ping ack cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet with NEED_ACK set. Sending ACK back"); - cemu_assert_debug(incomingPacket->packetData.empty()); // todo - echo data? - prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0); - directSendPacket(&ackPacket, dstIp, dstPort); + prudpPacket ackPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_ACK, this->m_clientSessionId, incomingPacket->sequenceId, 0); + if(!incomingPacket->packetData.empty()) + ackPacket.setData(incomingPacket->packetData.data(), incomingPacket->packetData.size()); + DirectSendPacket(&ackPacket); } - delete incomingPacket; return; } - // handle general packet ACK - if (incomingPacket->flags&prudpPacket::FLAG_ACK) + else if (incomingPacket->type == prudpPacket::TYPE_SYN) { - acknowledgePacket(incomingPacket->sequenceId); - } - // special cases - if (incomingPacket->type == prudpPacket::TYPE_SYN) - { - if (hasSentCon == false && incomingPacket->hasData && incomingPacket->packetData.size() == 4) + // syn packet from server is expected to have ACK set + if (!(incomingPacket->flags & prudpPacket::FLAG_ACK)) { - this->serverConnectionSignature = *(uint32*)&incomingPacket->packetData.front(); - this->clientSessionId = prudp_generateRandomU8(); - // generate client session id - this->clientConnectionSignature = prudp_generateRandomU32(); - // send con packet - prudpPacket* conPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_CON, prudpPacket::FLAG_NEED_ACK|prudpPacket::FLAG_RELIABLE, this->clientSessionId, outgoingSequenceId, serverConnectionSignature); - outgoingSequenceId++; - - if (this->isSecureConnection) - { - // set packet specific data (client connection signature) - uint8 tempBuffer[512]; - nexPacketBuffer conData(tempBuffer, sizeof(tempBuffer), true); - conData.writeU32(this->clientConnectionSignature); - conData.writeBuffer(authInfo.secureTicket, authInfo.secureTicketLength); - // encrypted request data - uint8 requestData[4 * 3]; - uint8 requestDataEncrypted[4 * 3 + 0x10]; - *(uint32*)(requestData + 0x0) = authInfo.userPid; - *(uint32*)(requestData + 0x4) = authInfo.server.cid; - *(uint32*)(requestData + 0x8) = prudp_generateRandomU32(); // todo - check value - sint32 encryptedSize = kerberosEncryptData(requestData, sizeof(requestData), requestDataEncrypted); - conData.writeBuffer(requestDataEncrypted, encryptedSize); - conPacket->setData(conData.getDataPtr(), conData.getWriteIndex()); - } - else - { - // set packet specific data (client connection signature) - conPacket->setData((uint8*)&this->clientConnectionSignature, sizeof(uint32)); - } - // send packet - queuePacket(conPacket, dstIp, dstPort); - // remember con packet as sent - hasSentCon = true; + cemuLog_log(LogType::Force, "[PRUDP] Received SYN packet without ACK flag set"); // always log this + return; } - delete incomingPacket; + if (m_hasSynAck || !incomingPacket->hasData || incomingPacket->packetData.size() != 4) + { + // syn already acked or not a valid syn packet + cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected SYN packet"); + return; + } + m_hasSynAck = true; + this->m_serverConnectionSignature = *(uint32*)&incomingPacket->packetData.front(); + // generate client session id and connection signature + this->m_clientSessionId = prudp_generateRandomU8(); + this->m_clientConnectionSignature = prudp_generateRandomU32(); + // send con packet + m_handshakeRetryCount = 0; + SendCurrentHandshakePacket(); return; } else if (incomingPacket->type == prudpPacket::TYPE_CON) { - // connected! - if (currentConnectionState == STATE_CONNECTING) + if (!m_hasSynAck || m_hasConAck) { - lastPingTimestamp = prudpGetMSTimestamp(); - cemu_assert_debug(serverSessionId == 0); - serverSessionId = incomingPacket->sessionId; - currentConnectionState = STATE_CONNECTED; - cemuLog_log(LogType::PRUDP, "[PRUDP] Connection established. ClientSession {:02x} ServerSession {:02x}", clientSessionId, serverSessionId); + cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected CON packet"); + return; } - delete incomingPacket; + // make sure the packet has the ACK flag set + if (!(incomingPacket->flags & prudpPacket::FLAG_ACK)) + { + cemuLog_log(LogType::Force, "[PRUDP] Received CON packet without ACK flag set"); + return; + } + m_hasConAck = true; + m_handshakeRetryCount = 0; + cemu_assert_debug(m_currentConnectionState == ConnectionState::Connecting); + // connected! + m_lastPingTimestamp = prudpGetMSTimestamp(); + cemu_assert_debug(m_serverSessionId == 0); + m_serverSessionId = incomingPacket->sessionId; + m_currentConnectionState = ConnectionState::Connected; + cemuLog_log(LogType::PRUDP, "[PRUDP] Connection established. ClientSession {:02x} ServerSession {:02x}", m_clientSessionId, m_serverSessionId); return; } else if (incomingPacket->type == prudpPacket::TYPE_DATA) { - // send ack back if requested - if (incomingPacket->flags&prudpPacket::FLAG_NEED_ACK) + // handle ACK + if (incomingPacket->flags & prudpPacket::FLAG_ACK) { - prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0); - directSendPacket(&ackPacket, dstIp, dstPort); + AcknowledgePacket(incomingPacket->sequenceId); + if(!incomingPacket->packetData.empty()) + cemuLog_log(LogType::PRUDP, "[PRUDP] Received ACK data packet with payload"); + return; + } + // send ack back if requested + if (incomingPacket->flags & prudpPacket::FLAG_NEED_ACK) + { + prudpPacket ackPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->m_clientSessionId, incomingPacket->sequenceId, 0); + DirectSendPacket(&ackPacket); } // skip data packets without payload if (incomingPacket->packetData.empty()) - { - delete incomingPacket; return; - } - // verify some values - uint16 seqDist = incomingPacket->sequenceId - incomingSequenceId; + // verify sequence id + uint16 seqDist = incomingPacket->sequenceId - m_incomingSequenceId; if (seqDist >= 0xC000) { // outdated - delete incomingPacket; return; } // check if packet is already queued - for (auto& it : queue_incomingPackets) + for (auto& it : m_incomingPacketQueue) { if (it->sequenceId == incomingPacket->sequenceId) { // already queued (should check other values too, like packet type?) cemuLog_log(LogType::PRUDP, "Duplicate PRUDP packet received"); - delete incomingPacket; return; } } // put into ordered receive queue - sortIncomingDataPacket(incomingPacket); + SortIncomingDataPacket(std::move(incomingPacket)); } else if (incomingPacket->type == prudpPacket::TYPE_DISCONNECT) { - currentConnectionState = STATE_DISCONNECTED; + m_currentConnectionState = ConnectionState::Disconnected; return; } else { - // ignore unknown packet - delete incomingPacket; - return; + cemuLog_log(LogType::PRUDP, "[PRUDP] Received unknown packet type"); } } -bool prudpClient::update() +bool prudpClient::Update() { - if (currentConnectionState == STATE_DISCONNECTED) + if (m_currentConnectionState == ConnectionState::Disconnected) return false; uint32 currentTimestamp = prudpGetMSTimestamp(); // check for incoming packets uint8 receiveBuffer[4096]; while (true) { - sockaddr receiveFrom = { 0 }; + sockaddr receiveFrom = {0}; socklen_t receiveFromLen = sizeof(receiveFrom); - sint32 r = recvfrom(socketUdp, (char*)receiveBuffer, sizeof(receiveBuffer), 0, &receiveFrom, &receiveFromLen); + sint32 r = recvfrom(m_socketUdp, (char*)receiveBuffer, sizeof(receiveBuffer), 0, &receiveFrom, &receiveFromLen); if (r >= 0) { // todo: Verify sender (receiveFrom) @@ -807,203 +809,195 @@ bool prudpClient::update() cemuLog_log(LogType::Force, "[PRUDP] Invalid packet length"); break; } - prudpIncomingPacket* incomingPacket = new prudpIncomingPacket(&streamSettings, receiveBuffer + pIdx, packetLength); + auto incomingPacket = std::make_unique(&m_streamSettings, receiveBuffer + pIdx, packetLength); pIdx += packetLength; if (incomingPacket->hasError()) { cemuLog_log(LogType::Force, "[PRUDP] Packet error"); - delete incomingPacket; break; } - if (incomingPacket->type != prudpPacket::TYPE_CON && incomingPacket->sessionId != serverSessionId) + if (incomingPacket->type != prudpPacket::TYPE_CON && incomingPacket->sessionId != m_serverSessionId) { cemuLog_log(LogType::PRUDP, "[PRUDP] Invalid session id"); - delete incomingPacket; continue; // different session } - handleIncomingPacket(incomingPacket); + HandleIncomingPacket(std::move(incomingPacket)); } } else break; } // check for ack timeouts - for (auto &it : list_packetsWithAckReq) + for (auto& it : m_dataPacketsWithAckReq) { if ((currentTimestamp - it.lastRetryTimestamp) >= 2300) { if (it.retryCount >= 7) { // after too many retries consider the connection dead - currentConnectionState = STATE_DISCONNECTED; + m_currentConnectionState = ConnectionState::Disconnected; } // resend - directSendPacket(it.packet, dstIp, dstPort); + DirectSendPacket(it.packet); it.lastRetryTimestamp = currentTimestamp; it.retryCount++; } } - // check if we need to send another ping - if (currentConnectionState == STATE_CONNECTED) + if (m_currentConnectionState == ConnectionState::Connecting) { - if(m_unacknowledgedPingCount != 0) // counts how many times we sent a ping packet (for the current sequenceId) without receiving an ack + // check if we need to resend SYN or CON + uint32 timeSinceLastHandshake = currentTimestamp - m_lastHandshakeTimestamp; + if (timeSinceLastHandshake >= 1200) + { + if (m_handshakeRetryCount >= 5) + { + // too many retries, assume the other side doesn't listen + m_currentConnectionState = ConnectionState::Disconnected; + cemuLog_log(LogType::PRUDP, "PRUDP: Failed to connect"); + return false; + } + SendCurrentHandshakePacket(); + } + } + else if (m_currentConnectionState == ConnectionState::Connected) + { + // handle pings + if (m_unacknowledgedPingCount != 0) // counts how many times we sent a ping packet (for the current sequenceId) without receiving an ack { // we are waiting for the ack of the previous ping, but it hasn't arrived yet so send another ping packet - if((currentTimestamp - lastPingTimestamp) >= 1500) + if ((currentTimestamp - m_lastPingTimestamp) >= 1500) { cemuLog_log(LogType::PRUDP, "[PRUDP] Resending ping packet (no ack received)"); - if(m_unacknowledgedPingCount >= 10) + if (m_unacknowledgedPingCount >= 10) { // too many unacknowledged pings, assume the connection is dead - currentConnectionState = STATE_DISCONNECTED; + m_currentConnectionState = ConnectionState::Disconnected; cemuLog_log(LogType::PRUDP, "PRUDP: Connection did not receive a ping response in a while. Assuming disconnect"); return false; } // resend the ping packet - prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature); - directSendPacket(pingPacket, dstIp, dstPort); + prudpPacket pingPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->m_clientSessionId, this->m_outgoingSequenceId_ping, m_serverConnectionSignature); + DirectSendPacket(&pingPacket); m_unacknowledgedPingCount++; - delete pingPacket; - lastPingTimestamp = currentTimestamp; + m_lastPingTimestamp = currentTimestamp; } } else { - if((currentTimestamp - lastPingTimestamp) >= 20000) + if ((currentTimestamp - m_lastPingTimestamp) >= 20000) { - cemuLog_log(LogType::PRUDP, "[PRUDP] Sending new ping packet with sequenceId {}", this->m_outgoingSequenceId_ping+1); + cemuLog_log(LogType::PRUDP, "[PRUDP] Sending new ping packet with sequenceId {}", this->m_outgoingSequenceId_ping + 1); // start a new ping packet with a new sequenceId. Note that ping packets have their own sequenceId and acknowledgement happens by manually comparing the incoming ping ACK against the last sent sequenceId // only one unacknowledged ping packet can be in flight at a time. We will resend the same ping packet until we receive an ack this->m_outgoingSequenceId_ping++; // increment before sending. The first ping has a sequenceId of 1 - prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature); - directSendPacket(pingPacket, dstIp, dstPort); + prudpPacket pingPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->m_clientSessionId, this->m_outgoingSequenceId_ping, m_serverConnectionSignature); + DirectSendPacket(&pingPacket); m_unacknowledgedPingCount++; - delete pingPacket; - lastPingTimestamp = currentTimestamp; + m_lastPingTimestamp = currentTimestamp; } } } return false; } -void prudpClient::directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort) +void prudpClient::DirectSendPacket(prudpPacket* packet) { uint8 packetBuffer[prudpPacket::PACKET_RAW_SIZE_MAX]; - sint32 len = packet->buildData(packetBuffer, prudpPacket::PACKET_RAW_SIZE_MAX); - sockaddr_in destAddr; destAddr.sin_family = AF_INET; - destAddr.sin_port = htons(dstPort); - destAddr.sin_addr.s_addr = dstIp; - sendto(socketUdp, (const char*)packetBuffer, len, 0, (const sockaddr*)&destAddr, sizeof(destAddr)); + destAddr.sin_port = htons(m_dstPort); + destAddr.sin_addr.s_addr = m_dstIp; + sendto(m_socketUdp, (const char*)packetBuffer, len, 0, (const sockaddr*)&destAddr, sizeof(destAddr)); } -void prudpClient::queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort) +void prudpClient::QueuePacket(prudpPacket* packet) { + cemu_assert_debug(packet->GetType() == prudpPacket::TYPE_DATA); // only data packets should be queued if (packet->requiresAck()) { - cemu_assert_debug(packet->GetType() != prudpPacket::TYPE_PING); // ping packets use their own logic for acks, dont queue them // remember this packet until we receive the ack - prudpAckRequired_t ackRequired = { 0 }; - ackRequired.packet = packet; - ackRequired.initialSendTimestamp = prudpGetMSTimestamp(); - ackRequired.lastRetryTimestamp = ackRequired.initialSendTimestamp; - list_packetsWithAckReq.push_back(ackRequired); - directSendPacket(packet, dstIp, dstPort); + m_dataPacketsWithAckReq.emplace_back(packet, prudpGetMSTimestamp()); + DirectSendPacket(packet); } else { - directSendPacket(packet, dstIp, dstPort); + DirectSendPacket(packet); delete packet; } } -void prudpClient::sendDatagram(uint8* input, sint32 length, bool reliable) +void prudpClient::SendDatagram(uint8* input, sint32 length, bool reliable) { - cemu_assert_debug(reliable); // non-reliable packets require testing - if(length >= 0x300) + cemu_assert_debug(reliable); // non-reliable packets require correct sequenceId handling and testing + cemu_assert_debug(m_hasSynAck && m_hasConAck); // cant send data packets before we are connected + if (length >= 0x300) { - cemuLog_logOnce(LogType::Force, "PRUDP: Datagram too long"); + cemuLog_logOnce(LogType::Force, "PRUDP: Datagram too long. Fragmentation not implemented yet"); } // single fragment data packet uint16 flags = prudpPacket::FLAG_NEED_ACK; - if(reliable) + if (reliable) flags |= prudpPacket::FLAG_RELIABLE; - prudpPacket* packet = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, flags, clientSessionId, outgoingSequenceId, 0); - if(reliable) - outgoingSequenceId++; + prudpPacket* packet = new prudpPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_DATA, flags, m_clientSessionId, m_outgoingReliableSequenceId, 0); + if (reliable) + m_outgoingReliableSequenceId++; packet->setFragmentIndex(0); packet->setData(input, length); - queuePacket(packet, dstIp, dstPort); + QueuePacket(packet); } -uint16 prudpClient::getSourcePort() +sint32 prudpClient::ReceiveDatagram(std::vector& outputBuffer) { - return this->srcPort; -} - -SOCKET prudpClient::getSocket() -{ - if (currentConnectionState == STATE_DISCONNECTED) - { - return INVALID_SOCKET; - } - return this->socketUdp; -} - -sint32 prudpClient::receiveDatagram(std::vector& outputBuffer) -{ - if (queue_incomingPackets.empty()) + outputBuffer.clear(); + if (m_incomingPacketQueue.empty()) return -1; - prudpIncomingPacket* incomingPacket = queue_incomingPackets[0]; - if (incomingPacket->sequenceId != this->incomingSequenceId) + prudpIncomingPacket* frontPacket = m_incomingPacketQueue[0].get(); + if (frontPacket->sequenceId != this->m_incomingSequenceId) return -1; - - if (incomingPacket->fragmentIndex == 0) + if (frontPacket->fragmentIndex == 0) { // single-fragment packet // decrypt - incomingPacket->decrypt(); + frontPacket->decrypt(); // read data - sint32 datagramLen = (sint32)incomingPacket->packetData.size(); - if (datagramLen > 0) + if (!frontPacket->packetData.empty()) { - // resize buffer if necessary - if (datagramLen > outputBuffer.size()) - outputBuffer.resize(datagramLen); - // to conserve memory we will also shrink the buffer if it was previously extended beyond 64KB - constexpr size_t BUFFER_TARGET_SIZE = 1024 * 64; - if (datagramLen < BUFFER_TARGET_SIZE && outputBuffer.size() > BUFFER_TARGET_SIZE) + // to conserve memory we will also shrink the buffer if it was previously extended beyond 32KB + constexpr size_t BUFFER_TARGET_SIZE = 1024 * 32; + if (frontPacket->packetData.size() < BUFFER_TARGET_SIZE && outputBuffer.capacity() > BUFFER_TARGET_SIZE) + { outputBuffer.resize(BUFFER_TARGET_SIZE); - // copy datagram to buffer - memcpy(outputBuffer.data(), &incomingPacket->packetData.front(), datagramLen); + outputBuffer.shrink_to_fit(); + outputBuffer.clear(); + } + // write packet data to output buffer + cemu_assert_debug(outputBuffer.empty()); + outputBuffer.insert(outputBuffer.end(), frontPacket->packetData.begin(), frontPacket->packetData.end()); } - delete incomingPacket; - // remove packet from queue - queue_incomingPackets.erase(queue_incomingPackets.begin()); + m_incomingPacketQueue.erase(m_incomingPacketQueue.begin()); // advance expected sequence id - this->incomingSequenceId++; - return datagramLen; + this->m_incomingSequenceId++; + return (sint32)outputBuffer.size(); } else { // multi-fragment packet - if (incomingPacket->fragmentIndex != 1) + if (frontPacket->fragmentIndex != 1) return -1; // first packet of the chain not received yet // verify chain sint32 packetIndex = 1; sint32 chainLength = -1; // if full chain found, set to count of packets - for(sint32 i=1; ifragmentIndex; + uint8 itFragmentIndex = m_incomingPacketQueue[packetIndex]->fragmentIndex; // sequence id must increase by 1 for every packet - if (queue_incomingPackets[packetIndex]->sequenceId != (this->incomingSequenceId+i) ) + if (m_incomingPacketQueue[packetIndex]->sequenceId != (m_incomingSequenceId + i)) return -1; // missing packets // last fragment in chain is marked by fragment index 0 if (itFragmentIndex == 0) { - chainLength = i+1; + chainLength = i + 1; break; } packetIndex++; @@ -1011,29 +1005,17 @@ sint32 prudpClient::receiveDatagram(std::vector& outputBuffer) if (chainLength < 1) return -1; // chain not complete // extract data from packet chain - sint32 writeIndex = 0; + cemu_assert_debug(outputBuffer.empty()); for (sint32 i = 0; i < chainLength; i++) { - incomingPacket = queue_incomingPackets[i]; - // decrypt + prudpIncomingPacket* incomingPacket = m_incomingPacketQueue[i].get(); incomingPacket->decrypt(); - // extract data - sint32 datagramLen = (sint32)incomingPacket->packetData.size(); - if (datagramLen > 0) - { - // make sure output buffer can fit the data - if ((writeIndex + datagramLen) > outputBuffer.size()) - outputBuffer.resize(writeIndex + datagramLen + 4 * 1024); - memcpy(outputBuffer.data()+writeIndex, &incomingPacket->packetData.front(), datagramLen); - writeIndex += datagramLen; - } - // free packet memory - delete incomingPacket; + outputBuffer.insert(outputBuffer.end(), incomingPacket->packetData.begin(), incomingPacket->packetData.end()); } // remove packets from queue - queue_incomingPackets.erase(queue_incomingPackets.begin(), queue_incomingPackets.begin() + chainLength); - this->incomingSequenceId += chainLength; - return writeIndex; + m_incomingPacketQueue.erase(m_incomingPacketQueue.begin(), m_incomingPacketQueue.begin() + chainLength); + m_incomingSequenceId += chainLength; + return (sint32)outputBuffer.size(); } return -1; } diff --git a/src/Cemu/nex/prudp.h b/src/Cemu/nex/prudp.h index 5ed5bcb1..3192c833 100644 --- a/src/Cemu/nex/prudp.h +++ b/src/Cemu/nex/prudp.h @@ -4,26 +4,26 @@ #define RC4_N 256 -typedef struct +struct RC4Ctx { unsigned char S[RC4_N]; int i; int j; -}RC4Ctx_t; +}; -void RC4_initCtx(RC4Ctx_t* rc4Ctx, char *key); -void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen); -void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output); +void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key); +void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen); +void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output); -typedef struct +struct prudpStreamSettings { uint8 checksumBase; // calculated from key uint8 accessKeyDigest[16]; // MD5 hash of key - RC4Ctx_t rc4Client; - RC4Ctx_t rc4Server; -}prudpStreamSettings_t; + RC4Ctx rc4Client; + RC4Ctx rc4Server; +}; -typedef struct +struct prudpStationUrl { uint32 ip; uint16 port; @@ -32,19 +32,17 @@ typedef struct sint32 sid; sint32 stream; sint32 type; -}stationUrl_t; +}; -typedef struct +struct prudpAuthServerInfo { uint32 userPid; uint8 secureKey[16]; uint8 kerberosKey[16]; uint8 secureTicket[1024]; sint32 secureTicketLength; - stationUrl_t server; -}authServerInfo_t; - -uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length); + prudpStationUrl server; +}; class prudpPacket { @@ -66,7 +64,7 @@ public: static sint32 calculateSizeFromPacketData(uint8* data, sint32 length); - prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature); + prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature); bool requiresAck(); void setData(uint8* data, sint32 length); void setFragmentIndex(uint8 fragmentIndex); @@ -87,7 +85,7 @@ private: uint16 flags; uint8 sessionId; uint32 specifiedPacketSignature; - prudpStreamSettings_t* streamSettings; + prudpStreamSettings* streamSettings; std::vector packetData; bool isEncrypted; uint16 m_sequenceId{0}; @@ -97,7 +95,7 @@ private: class prudpIncomingPacket { public: - prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length); + prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length); bool hasError(); @@ -122,83 +120,91 @@ public: private: bool isInvalid = false; - prudpStreamSettings_t* streamSettings = nullptr; - + prudpStreamSettings* streamSettings = nullptr; }; -typedef struct -{ - prudpPacket* packet; - uint32 initialSendTimestamp; - uint32 lastRetryTimestamp; - sint32 retryCount; -}prudpAckRequired_t; - class prudpClient { + struct PacketWithAckRequired + { + PacketWithAckRequired(prudpPacket* packet, uint32 initialSendTimestamp) : + packet(packet), initialSendTimestamp(initialSendTimestamp), lastRetryTimestamp(initialSendTimestamp) { } + prudpPacket* packet; + uint32 initialSendTimestamp; + uint32 lastRetryTimestamp; + sint32 retryCount{0}; + }; public: - static const int STATE_CONNECTING = 0; - static const int STATE_CONNECTED = 1; - static const int STATE_DISCONNECTED = 2; + enum class ConnectionState : uint8 + { + Connecting, + Connected, + Disconnected + }; -public: prudpClient(uint32 dstIp, uint16 dstPort, const char* key); - prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo); + prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo); ~prudpClient(); - bool isConnected(); + bool IsConnected() const { return m_currentConnectionState == ConnectionState::Connected; } + ConnectionState GetConnectionState() const { return m_currentConnectionState; } + uint16 GetSourcePort() const { return m_srcPort; } - uint8 getConnectionState(); - void acknowledgePacket(uint16 sequenceId); - void sortIncomingDataPacket(prudpIncomingPacket* incomingPacket); - void handleIncomingPacket(prudpIncomingPacket* incomingPacket); - bool update(); // check for new incoming packets, returns true if receiveDatagram() should be called + bool Update(); // update connection state and check for incoming packets. Returns true if ReceiveDatagram() should be called - sint32 receiveDatagram(std::vector& outputBuffer); - void sendDatagram(uint8* input, sint32 length, bool reliable = true); - - uint16 getSourcePort(); - - SOCKET getSocket(); + sint32 ReceiveDatagram(std::vector& outputBuffer); + void SendDatagram(uint8* input, sint32 length, bool reliable = true); private: prudpClient(); - void directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort); - sint32 kerberosEncryptData(uint8* input, sint32 length, uint8* output); - void queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort); + + void HandleIncomingPacket(std::unique_ptr incomingPacket); + void DirectSendPacket(prudpPacket* packet); + sint32 KerberosEncryptData(uint8* input, sint32 length, uint8* output); + void QueuePacket(prudpPacket* packet); + + void AcknowledgePacket(uint16 sequenceId); + void SortIncomingDataPacket(std::unique_ptr incomingPacket); + + void SendCurrentHandshakePacket(); private: - uint16 srcPort; - uint32 dstIp; - uint16 dstPort; - uint8 vport_src; - uint8 vport_dst; - prudpStreamSettings_t streamSettings; - std::vector list_packetsWithAckReq; - std::vector queue_incomingPackets; - - // connection - uint8 currentConnectionState; - uint32 serverConnectionSignature; - uint32 clientConnectionSignature; - bool hasSentCon; - uint32 lastPingTimestamp; + uint16 m_srcPort; + uint32 m_dstIp; + uint16 m_dstPort; + uint8 m_srcVPort; + uint8 m_dstVPort; + prudpStreamSettings m_streamSettings; + std::vector m_dataPacketsWithAckReq; + std::vector> m_incomingPacketQueue; - uint16 outgoingSequenceId; - uint16 incomingSequenceId; + // connection handshake state + bool m_hasSynAck{false}; + bool m_hasConAck{false}; + uint32 m_lastHandshakeTimestamp{0}; + uint8 m_handshakeRetryCount{0}; + + // connection + ConnectionState m_currentConnectionState; + uint32 m_serverConnectionSignature; + uint32 m_clientConnectionSignature; + uint32 m_lastPingTimestamp; + + uint16 m_outgoingReliableSequenceId{2}; // 1 is reserved for CON + uint16 m_incomingSequenceId; uint16 m_outgoingSequenceId_ping{0}; uint8 m_unacknowledgedPingCount{0}; - uint8 clientSessionId; - uint8 serverSessionId; + uint8 m_clientSessionId; + uint8 m_serverSessionId; // secure - bool isSecureConnection; - authServerInfo_t authInfo; + bool m_isSecureConnection{false}; + prudpAuthServerInfo m_authInfo; // socket - SOCKET socketUdp; + SOCKET m_socketUdp; }; uint32 prudpGetMSTimestamp(); \ No newline at end of file