prudp: Code cleanup

This commit is contained in:
Exzap 2024-04-18 19:22:28 +02:00
parent ee36992bd6
commit e2f9725719
3 changed files with 410 additions and 422 deletions

View File

@ -106,7 +106,7 @@ nexService::nexService()
nexService::nexService(prudpClient* con) : nexService() nexService::nexService(prudpClient* con) : nexService()
{ {
if (con->isConnected() == false) if (con->IsConnected() == false)
cemu_assert_suspicious(); cemu_assert_suspicious();
this->conNexService = con; this->conNexService = con;
bufferReceive = std::vector<uint8>(1024 * 4); bufferReceive = std::vector<uint8>(1024 * 4);
@ -191,7 +191,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
uint32 callId = _currentCallId; uint32 callId = _currentCallId;
_currentCallId++; _currentCallId++;
// check state of connection // check state of connection
if (conNexService->getConnectionState() != prudpClient::STATE_CONNECTED) if (conNexService->GetConnectionState() != prudpClient::ConnectionState::Connected)
{ {
nexServiceResponse_t response = { 0 }; nexServiceResponse_t response = { 0 };
response.isSuccessful = false; response.isSuccessful = false;
@ -214,7 +214,7 @@ void nexService::processQueuedRequest(queuedRequest_t* queuedRequest)
assert_dbg(); assert_dbg();
memcpy((packetBuffer + 0x0D), &queuedRequest->parameterData.front(), queuedRequest->parameterData.size()); memcpy((packetBuffer + 0x0D), &queuedRequest->parameterData.front(), queuedRequest->parameterData.size());
sint32 length = 0xD + (sint32)queuedRequest->parameterData.size(); sint32 length = 0xD + (sint32)queuedRequest->parameterData.size();
conNexService->sendDatagram(packetBuffer, length, true); conNexService->SendDatagram(packetBuffer, length, true);
// remember request // remember request
nexActiveRequestInfo_t requestInfo = { 0 }; nexActiveRequestInfo_t requestInfo = { 0 };
requestInfo.callId = callId; requestInfo.callId = callId;
@ -299,13 +299,13 @@ void nexService::registerForAsyncProcessing()
void nexService::updateTemporaryConnections() void nexService::updateTemporaryConnections()
{ {
// check for connection // check for connection
conNexService->update(); conNexService->Update();
if (conNexService->isConnected()) if (conNexService->IsConnected())
{ {
if (connectionState == STATE_CONNECTING) if (connectionState == STATE_CONNECTING)
connectionState = STATE_CONNECTED; connectionState = STATE_CONNECTED;
} }
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED) if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
connectionState = STATE_DISCONNECTED; connectionState = STATE_DISCONNECTED;
} }
@ -356,18 +356,18 @@ void nexService::sendRequestResponse(nexServiceRequest_t* request, uint32 errorC
// update length field // update length field
*(uint32*)response.getDataPtr() = response.getWriteIndex()-4; *(uint32*)response.getDataPtr() = response.getWriteIndex()-4;
if(request->nex->conNexService) if(request->nex->conNexService)
request->nex->conNexService->sendDatagram(response.getDataPtr(), response.getWriteIndex(), true); request->nex->conNexService->SendDatagram(response.getDataPtr(), response.getWriteIndex(), true);
} }
void nexService::updateNexServiceConnection() void nexService::updateNexServiceConnection()
{ {
if (conNexService->getConnectionState() == prudpClient::STATE_DISCONNECTED) if (conNexService->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{ {
this->connectionState = STATE_DISCONNECTED; this->connectionState = STATE_DISCONNECTED;
return; return;
} }
conNexService->update(); conNexService->Update();
sint32 datagramLen = conNexService->receiveDatagram(bufferReceive); sint32 datagramLen = conNexService->ReceiveDatagram(bufferReceive);
if (datagramLen > 0) if (datagramLen > 0)
{ {
if (nexIsRequest(&bufferReceive[0], datagramLen)) if (nexIsRequest(&bufferReceive[0], datagramLen))
@ -454,12 +454,12 @@ bool _extractStationUrlParamValue(const char* urlStr, const char* paramName, cha
return false; return false;
} }
void nexServiceAuthentication_parseStationURL(char* urlStr, stationUrl_t* stationUrl) void nexServiceAuthentication_parseStationURL(char* urlStr, prudpStationUrl* stationUrl)
{ {
// example: // example:
// prudps:/address=34.210.xxx.xxx;port=60181;CID=1;PID=2;sid=1;stream=10;type=2 // prudps:/address=34.210.xxx.xxx;port=60181;CID=1;PID=2;sid=1;stream=10;type=2
memset(stationUrl, 0, sizeof(stationUrl_t)); memset(stationUrl, 0, sizeof(prudpStationUrl));
char optionValue[128]; char optionValue[128];
if (_extractStationUrlParamValue(urlStr, "address", optionValue, sizeof(optionValue))) if (_extractStationUrlParamValue(urlStr, "address", optionValue, sizeof(optionValue)))
@ -499,7 +499,7 @@ typedef struct
sint32 kerberosTicketSize; sint32 kerberosTicketSize;
uint8 kerberosTicket2[4096]; uint8 kerberosTicket2[4096];
sint32 kerberosTicket2Size; sint32 kerberosTicket2Size;
stationUrl_t server; prudpStationUrl server;
// progress info // progress info
bool hasError; bool hasError;
bool done; bool done;
@ -611,18 +611,18 @@ void nexServiceSecure_handleResponse_RegisterEx(nexService* nex, nexServiceRespo
return; return;
} }
nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* accessKey, const char* nexToken) nexService* nex_secureLogin(prudpAuthServerInfo* authServerInfo, const char* accessKey, const char* nexToken)
{ {
prudpClient* prudpSecureSock = new prudpClient(authServerInfo->server.ip, authServerInfo->server.port, accessKey, authServerInfo); prudpClient* prudpSecureSock = new prudpClient(authServerInfo->server.ip, authServerInfo->server.port, accessKey, authServerInfo);
// wait until connected // wait until connected
while (true) while (true)
{ {
prudpSecureSock->update(); prudpSecureSock->Update();
if (prudpSecureSock->isConnected()) if (prudpSecureSock->IsConnected())
{ {
break; break;
} }
if (prudpSecureSock->getConnectionState() == prudpClient::STATE_DISCONNECTED) if (prudpSecureSock->GetConnectionState() == prudpClient::ConnectionState::Disconnected)
{ {
// timeout or disconnected // timeout or disconnected
cemuLog_log(LogType::Force, "NEX: Secure login connection time-out"); cemuLog_log(LogType::Force, "NEX: Secure login connection time-out");
@ -638,7 +638,7 @@ nexService* nex_secureLogin(authServerInfo_t* authServerInfo, const char* access
nexPacketBuffer packetBuffer(tempNexBufferArray, sizeof(tempNexBufferArray), true); nexPacketBuffer packetBuffer(tempNexBufferArray, sizeof(tempNexBufferArray), true);
char clientStationUrl[256]; char clientStationUrl[256];
sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->getSourcePort()); sprintf(clientStationUrl, "prudp:/port=%u;natf=0;natm=0;pmp=0;sid=15;type=2;upnp=0", (uint32)nex->getPRUDPConnection()->GetSourcePort());
// station url list // station url list
packetBuffer.writeU32(1); packetBuffer.writeU32(1);
packetBuffer.writeString(clientStationUrl); packetBuffer.writeString(clientStationUrl);
@ -737,9 +737,9 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
return nullptr; return nullptr;
} }
// auth info // auth info
auto authServerInfo = std::make_unique<authServerInfo_t>(); auto authServerInfo = std::make_unique<prudpAuthServerInfo>();
// decrypt ticket // decrypt ticket
RC4Ctx_t rc4Ticket; RC4Ctx rc4Ticket;
RC4_initCtx(&rc4Ticket, kerberosKey, 16); RC4_initCtx(&rc4Ticket, kerberosKey, 16);
RC4_transform(&rc4Ticket, nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, nexAuthService.kerberosTicket2); RC4_transform(&rc4Ticket, nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, nexAuthService.kerberosTicket2);
nexPacketBuffer packetKerberosTicket(nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, false); nexPacketBuffer packetKerberosTicket(nexAuthService.kerberosTicket2, nexAuthService.kerberosTicket2Size - 16, false);
@ -756,7 +756,7 @@ nexService* nex_establishSecureConnection(uint32 authServerIp, uint16 authServer
memcpy(authServerInfo->kerberosKey, kerberosKey, 16); memcpy(authServerInfo->kerberosKey, kerberosKey, 16);
memcpy(authServerInfo->secureKey, secureKey, 16); memcpy(authServerInfo->secureKey, secureKey, 16);
memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(stationUrl_t)); memcpy(&authServerInfo->server, &nexAuthService.server, sizeof(prudpStationUrl));
authServerInfo->userPid = pid; authServerInfo->userPid = pid;
return nex_secureLogin(authServerInfo.get(), accessKey, nexToken); return nex_secureLogin(authServerInfo.get(), accessKey, nexToken);

View File

@ -6,67 +6,52 @@
#include <boost/random/uniform_int.hpp> #include <boost/random/uniform_int.hpp>
void swap(unsigned char *a, unsigned char *b) static void KSA(unsigned char* key, int keyLen, unsigned char* S)
{ {
int tmp = *a;
*a = *b;
*b = tmp;
}
void KSA(unsigned char *key, int keyLen, unsigned char *S)
{
int j = 0;
for (int i = 0; i < RC4_N; i++) for (int i = 0; i < RC4_N; i++)
S[i] = i; S[i] = i;
int j = 0;
for (int i = 0; i < RC4_N; i++) for (int i = 0; i < RC4_N; i++)
{ {
j = (j + S[i] + key[i % keyLen]) % RC4_N; j = (j + S[i] + key[i % keyLen]) % RC4_N;
std::swap(S[i], S[j]);
swap(&S[i], &S[j]);
} }
} }
void PRGA(unsigned char *S, unsigned char* input, int len, unsigned char* output) static void PRGA(unsigned char* S, unsigned char* input, int len, unsigned char* output)
{ {
int i = 0;
int j = 0;
for (size_t n = 0; n < len; n++) for (size_t n = 0; n < len; n++)
{ {
i = (i + 1) % RC4_N; int i = (i + 1) % RC4_N;
j = (j + S[i]) % RC4_N; int j = (j + S[i]) % RC4_N;
std::swap(S[i], S[j]);
swap(&S[i], &S[j]);
int rnd = S[(S[i] + S[j]) % RC4_N]; int rnd = S[(S[i] + S[j]) % RC4_N];
output[n] = rnd ^ input[n]; output[n] = rnd ^ input[n];
} }
} }
void RC4(char* key, unsigned char* input, int len, unsigned char* output) static void RC4(char* key, unsigned char* input, int len, unsigned char* output)
{ {
unsigned char S[RC4_N]; unsigned char S[RC4_N];
KSA((unsigned char*)key, (int)strlen(key), S); KSA((unsigned char*)key, (int)strlen(key), S);
PRGA(S, input, len, output); PRGA(S, input, len, output);
} }
void RC4_initCtx(RC4Ctx_t* rc4Ctx, const char* key) void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key)
{ {
rc4Ctx->i = 0; rc4Ctx->i = 0;
rc4Ctx->j = 0; rc4Ctx->j = 0;
KSA((unsigned char*)key, (int)strlen(key), rc4Ctx->S); KSA((unsigned char*)key, (int)strlen(key), rc4Ctx->S);
} }
void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen) void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen)
{ {
rc4Ctx->i = 0; rc4Ctx->i = 0;
rc4Ctx->j = 0; rc4Ctx->j = 0;
KSA(key, keyLen, rc4Ctx->S); KSA(key, keyLen, rc4Ctx->S);
} }
void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output) void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output)
{ {
int i = rc4Ctx->i; int i = rc4Ctx->i;
int j = rc4Ctx->j; int j = rc4Ctx->j;
@ -75,13 +60,10 @@ void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned cha
{ {
i = (i + 1) % RC4_N; i = (i + 1) % RC4_N;
j = (j + rc4Ctx->S[i]) % RC4_N; j = (j + rc4Ctx->S[i]) % RC4_N;
std::swap(rc4Ctx->S[i], rc4Ctx->S[j]);
swap(&rc4Ctx->S[i], &rc4Ctx->S[j]);
int rnd = rc4Ctx->S[(rc4Ctx->S[i] + rc4Ctx->S[j]) % RC4_N]; int rnd = rc4Ctx->S[(rc4Ctx->S[i] + rc4Ctx->S[j]) % RC4_N];
output[n] = rnd ^ input[n]; output[n] = rnd ^ input[n];
} }
rc4Ctx->i = i; rc4Ctx->i = i;
rc4Ctx->j = j; rc4Ctx->j = j;
} }
@ -91,34 +73,14 @@ uint32 prudpGetMSTimestamp()
return GetTickCount(); return GetTickCount();
} }
std::bitset<10000> _portUsageMask;
uint16 getRandomSrcPRUDPPort()
{
while (true)
{
sint32 p = rand() % 10000;
if (_portUsageMask.test(p))
continue;
_portUsageMask.set(p);
return 40000 + p;
}
return 0;
}
void releasePRUDPPort(uint16 port)
{
uint32 bitIndex = port - 40000;
_portUsageMask.reset(bitIndex);
}
std::mt19937_64 prudpRG(GetTickCount()); std::mt19937_64 prudpRG(GetTickCount());
// workaround for static asserts when using uniform_int_distribution // workaround for static asserts when using uniform_int_distribution (see https://github.com/cemu-project/Cemu/issues/48)
boost::random::uniform_int_distribution<int> prudpDis8(0, 0xFF); boost::random::uniform_int_distribution<int> prudpRandomDistribution8(0, 0xFF);
boost::random::uniform_int_distribution<int> prudpRandomDistributionPortGen(0, 10000);
uint8 prudp_generateRandomU8() uint8 prudp_generateRandomU8()
{ {
return prudpDis8(prudpRG); return prudpRandomDistribution8(prudpRG);
} }
uint32 prudp_generateRandomU32() uint32 prudp_generateRandomU32()
@ -133,7 +95,29 @@ uint32 prudp_generateRandomU32()
return v; return v;
} }
uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length) std::bitset<10000> _portUsageMask;
static uint16 AllocateRandomSrcPRUDPPort()
{
while (true)
{
sint32 p = prudpRandomDistributionPortGen(prudpRG);
if (_portUsageMask.test(p))
continue;
_portUsageMask.set(p);
return 40000 + p;
}
}
static void ReleasePRUDPSrcPort(uint16 port)
{
cemu_assert_debug(port >= 40000);
uint32 bitIndex = port - 40000;
cemu_assert_debug(_portUsageMask.test(bitIndex));
_portUsageMask.reset(bitIndex);
}
static uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length)
{ {
uint32 checksum32 = 0; uint32 checksum32 = 0;
for (sint32 i = 0; i < length / 4; i++) for (sint32 i = 0; i < length / 4; i++)
@ -212,7 +196,7 @@ sint32 prudpPacket::calculateSizeFromPacketData(uint8* data, sint32 length)
return length; return length;
} }
prudpPacket::prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature) prudpPacket::prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature)
{ {
this->src = src; this->src = src;
this->dst = dst; this->dst = dst;
@ -352,7 +336,8 @@ prudpIncomingPacket::prudpIncomingPacket()
streamSettings = nullptr; streamSettings = nullptr;
} }
prudpIncomingPacket::prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length) : prudpIncomingPacket() prudpIncomingPacket::prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length)
: prudpIncomingPacket()
{ {
if (length < 0xB + 1) if (length < 0xB + 1)
{ {
@ -479,53 +464,41 @@ void prudpIncomingPacket::decrypt()
prudpClient::prudpClient() prudpClient::prudpClient()
{ {
currentConnectionState = STATE_CONNECTING; m_currentConnectionState = ConnectionState::Connecting;
serverConnectionSignature = 0; m_serverConnectionSignature = 0;
clientConnectionSignature = 0; m_clientConnectionSignature = 0;
hasSentCon = false; m_incomingSequenceId = 0;
outgoingSequenceId = 0;
incomingSequenceId = 0;
clientSessionId = 0; m_clientSessionId = 0;
serverSessionId = 0; m_serverSessionId = 0;
isSecureConnection = false;
} }
prudpClient::~prudpClient() prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key)
: prudpClient()
{ {
if (srcPort != 0) m_dstIp = dstIp;
{ m_dstPort = dstPort;
releasePRUDPPort(srcPort);
closesocket(socketUdp);
}
}
prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key) : prudpClient()
{
this->dstIp = dstIp;
this->dstPort = dstPort;
// get unused random source port // get unused random source port
for (sint32 tries = 0; tries < 5; tries++) for (sint32 tries = 0; tries < 5; tries++)
{ {
srcPort = getRandomSrcPRUDPPort(); m_srcPort = AllocateRandomSrcPRUDPPort();
// create and bind udp socket // create and bind udp socket
socketUdp = socket(AF_INET, SOCK_DGRAM, 0); m_socketUdp = socket(AF_INET, SOCK_DGRAM, 0);
struct sockaddr_in udpServer; struct sockaddr_in udpServer;
udpServer.sin_family = AF_INET; udpServer.sin_family = AF_INET;
udpServer.sin_addr.s_addr = INADDR_ANY; udpServer.sin_addr.s_addr = INADDR_ANY;
udpServer.sin_port = htons(srcPort); udpServer.sin_port = htons(m_srcPort);
if (bind(socketUdp, (struct sockaddr *)&udpServer, sizeof(udpServer)) == SOCKET_ERROR) if (bind(m_socketUdp, (struct sockaddr*)&udpServer, sizeof(udpServer)) == SOCKET_ERROR)
{ {
ReleasePRUDPSrcPort(m_srcPort);
m_srcPort = 0;
if (tries == 4) if (tries == 4)
{ {
cemuLog_log(LogType::Force, "PRUDP: Failed to bind UDP socket"); cemuLog_log(LogType::Force, "PRUDP: Failed to bind UDP socket");
currentConnectionState = STATE_DISCONNECTED; m_currentConnectionState = ConnectionState::Disconnected;
srcPort = 0;
return; return;
} }
releasePRUDPPort(srcPort); closesocket(m_socketUdp);
closesocket(socketUdp);
continue; continue;
} }
else else
@ -534,77 +507,75 @@ prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key) : prudpC
// set socket to non-blocking mode // set socket to non-blocking mode
#if BOOST_OS_WINDOWS #if BOOST_OS_WINDOWS
u_long nonBlockingMode = 1; // 1 to enable non-blocking socket u_long nonBlockingMode = 1; // 1 to enable non-blocking socket
ioctlsocket(socketUdp, FIONBIO, &nonBlockingMode); ioctlsocket(m_socketUdp, FIONBIO, &nonBlockingMode);
#else #else
int flags = fcntl(socketUdp, F_GETFL); int flags = fcntl(socketUdp, F_GETFL);
fcntl(socketUdp, F_SETFL, flags | O_NONBLOCK); fcntl(socketUdp, F_SETFL, flags | O_NONBLOCK);
#endif #endif
// generate frequently used parameters // generate frequently used parameters
this->vport_src = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0xF); this->m_srcVPort = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0xF);
this->vport_dst = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0x1); this->m_dstVPort = PRUDP_VPORT(prudpPacket::STREAM_TYPE_SECURE, 0x1);
// set stream settings // set stream settings
uint8 checksumBase = 0; uint8 checksumBase = 0;
for (sint32 i = 0; key[i] != '\0'; i++) for (sint32 i = 0; key[i] != '\0'; i++)
{ {
checksumBase += key[i]; checksumBase += key[i];
} }
streamSettings.checksumBase = checksumBase; m_streamSettings.checksumBase = checksumBase;
MD5_CTX md5Ctx; MD5_CTX md5Ctx;
MD5_Init(&md5Ctx); MD5_Init(&md5Ctx);
MD5_Update(&md5Ctx, key, (int)strlen(key)); MD5_Update(&md5Ctx, key, (int)strlen(key));
MD5_Final(streamSettings.accessKeyDigest, &md5Ctx); MD5_Final(m_streamSettings.accessKeyDigest, &md5Ctx);
// init stream ciphers // init stream ciphers
RC4_initCtx(&streamSettings.rc4Server, "CD&ML"); RC4_initCtx(&m_streamSettings.rc4Server, "CD&ML");
RC4_initCtx(&streamSettings.rc4Client, "CD&ML"); RC4_initCtx(&m_streamSettings.rc4Client, "CD&ML");
// send syn packet // send syn packet
prudpPacket* synPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_SYN, prudpPacket::FLAG_NEED_ACK, 0, 0, 0); SendCurrentHandshakePacket();
queuePacket(synPacket, dstIp, dstPort);
outgoingSequenceId++;
// set incoming sequence id to 1 // set incoming sequence id to 1
incomingSequenceId = 1; m_incomingSequenceId = 1;
} }
prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo) : prudpClient(dstIp, dstPort, key) prudpClient::prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo)
: prudpClient(dstIp, dstPort, key)
{ {
RC4_initCtx(&streamSettings.rc4Server, authInfo->secureKey, 16); RC4_initCtx(&m_streamSettings.rc4Server, authInfo->secureKey, 16);
RC4_initCtx(&streamSettings.rc4Client, authInfo->secureKey, 16); RC4_initCtx(&m_streamSettings.rc4Client, authInfo->secureKey, 16);
this->isSecureConnection = true; this->m_isSecureConnection = true;
memcpy(&this->authInfo, authInfo, sizeof(authServerInfo_t)); memcpy(&this->m_authInfo, authInfo, sizeof(prudpAuthServerInfo));
} }
bool prudpClient::isConnected() prudpClient::~prudpClient()
{ {
return currentConnectionState == STATE_CONNECTED; if (m_srcPort != 0)
{
ReleasePRUDPSrcPort(m_srcPort);
closesocket(m_socketUdp);
}
} }
uint8 prudpClient::getConnectionState() void prudpClient::AcknowledgePacket(uint16 sequenceId)
{ {
return currentConnectionState; auto it = std::begin(m_dataPacketsWithAckReq);
} while (it != std::end(m_dataPacketsWithAckReq))
void prudpClient::acknowledgePacket(uint16 sequenceId)
{
auto it = std::begin(list_packetsWithAckReq);
while (it != std::end(list_packetsWithAckReq))
{ {
if (it->packet->GetSequenceId() == sequenceId) if (it->packet->GetSequenceId() == sequenceId)
{ {
delete it->packet; delete it->packet;
list_packetsWithAckReq.erase(it); m_dataPacketsWithAckReq.erase(it);
return; return;
} }
it++; it++;
} }
} }
void prudpClient::sortIncomingDataPacket(prudpIncomingPacket* incomingPacket) void prudpClient::SortIncomingDataPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket)
{ {
uint16 sequenceIdIncomingPacket = incomingPacket->sequenceId; uint16 sequenceIdIncomingPacket = incomingPacket->sequenceId;
// find insert index // find insert index
sint32 insertIndex = 0; sint32 insertIndex = 0;
while (insertIndex < queue_incomingPackets.size() ) while (insertIndex < m_incomingPacketQueue.size())
{ {
uint16 seqDif = sequenceIdIncomingPacket - queue_incomingPackets[insertIndex]->sequenceId; uint16 seqDif = sequenceIdIncomingPacket - m_incomingPacketQueue[insertIndex]->sequenceId;
if (seqDif & 0x8000) if (seqDif & 0x8000)
break; // negative seqDif -> insert before current element break; // negative seqDif -> insert before current element
#ifdef CEMU_DEBUG_ASSERT #ifdef CEMU_DEBUG_ASSERT
@ -613,28 +584,72 @@ void prudpClient::sortIncomingDataPacket(prudpIncomingPacket* incomingPacket)
#endif #endif
insertIndex++; insertIndex++;
} }
// insert m_incomingPacketQueue.insert(m_incomingPacketQueue.begin() + insertIndex, std::move(incomingPacket));
sint32 currentSize = (sint32)queue_incomingPackets.size(); // debug check if packets are really ordered by sequence id
queue_incomingPackets.resize(currentSize+1); #ifdef CEMU_DEBUG_ASSERT
for(sint32 i=currentSize; i>insertIndex; i--) for (sint32 i = 1; i < m_incomingPacketQueue.size(); i++)
{ {
queue_incomingPackets[i] = queue_incomingPackets[i - 1]; uint16 seqDif = m_incomingPacketQueue[i]->sequenceId - m_incomingPacketQueue[i - 1]->sequenceId;
if (seqDif & 0x8000)
seqDif = -seqDif;
if (seqDif >= 0x8000)
assert_dbg();
} }
queue_incomingPackets[insertIndex] = incomingPacket; #endif
} }
sint32 prudpClient::kerberosEncryptData(uint8* input, sint32 length, uint8* output) sint32 prudpClient::KerberosEncryptData(uint8* input, sint32 length, uint8* output)
{ {
RC4Ctx_t rc4Kerberos; RC4Ctx rc4Kerberos;
RC4_initCtx(&rc4Kerberos, this->authInfo.secureKey, 16); RC4_initCtx(&rc4Kerberos, this->m_authInfo.secureKey, 16);
memcpy(output, input, length); memcpy(output, input, length);
RC4_transform(&rc4Kerberos, output, length, output); RC4_transform(&rc4Kerberos, output, length, output);
// calculate and append hmac // calculate and append hmac
hmacMD5(this->authInfo.secureKey, 16, output, length, output+length); hmacMD5(this->m_authInfo.secureKey, 16, output, length, output + length);
return length + 16; return length + 16;
} }
void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket) // (re)sends either CON or SYN based on what stage of the login we are at
// the sequenceId for both is hardcoded for both because we'll never send anything in between
void prudpClient::SendCurrentHandshakePacket()
{
if (!m_hasSynAck)
{
// send syn (with a fixed sequenceId of 0)
prudpPacket synPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_SYN, prudpPacket::FLAG_NEED_ACK, 0, 0, 0);
DirectSendPacket(&synPacket);
}
else
{
// send con (with a fixed sequenceId of 1)
prudpPacket conPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_CON, prudpPacket::FLAG_NEED_ACK | prudpPacket::FLAG_RELIABLE, this->m_clientSessionId, 1, m_serverConnectionSignature);
if (this->m_isSecureConnection)
{
uint8 tempBuffer[512];
nexPacketBuffer conData(tempBuffer, sizeof(tempBuffer), true);
conData.writeU32(this->m_clientConnectionSignature);
conData.writeBuffer(m_authInfo.secureTicket, m_authInfo.secureTicketLength);
// encrypted request data
uint8 requestData[4 * 3];
uint8 requestDataEncrypted[4 * 3 + 0x10];
*(uint32*)(requestData + 0x0) = m_authInfo.userPid;
*(uint32*)(requestData + 0x4) = m_authInfo.server.cid;
*(uint32*)(requestData + 0x8) = prudp_generateRandomU32(); // todo - check value
sint32 encryptedSize = KerberosEncryptData(requestData, sizeof(requestData), requestDataEncrypted);
conData.writeBuffer(requestDataEncrypted, encryptedSize);
conPacket.setData(conData.getDataPtr(), conData.getWriteIndex());
}
else
{
conPacket.setData((uint8*)&this->m_clientConnectionSignature, sizeof(uint32));
}
DirectSendPacket(&conPacket);
}
m_lastHandshakeTimestamp = prudpGetMSTimestamp();
m_handshakeRetryCount++;
}
void prudpClient::HandleIncomingPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket)
{ {
if (incomingPacket->type == prudpPacket::TYPE_PING) if (incomingPacket->type == prudpPacket::TYPE_PING)
{ {
@ -664,127 +679,114 @@ void prudpClient::handleIncomingPacket(prudpIncomingPacket* incomingPacket)
{ {
// other side is asking for ping ack // other side is asking for ping ack
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet with NEED_ACK set. Sending ACK back"); cemuLog_log(LogType::PRUDP, "[PRUDP] Received ping packet with NEED_ACK set. Sending ACK back");
cemu_assert_debug(incomingPacket->packetData.empty()); // todo - echo data? prudpPacket ackPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_ACK, this->m_clientSessionId, incomingPacket->sequenceId, 0);
prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0); if(!incomingPacket->packetData.empty())
directSendPacket(&ackPacket, dstIp, dstPort); ackPacket.setData(incomingPacket->packetData.data(), incomingPacket->packetData.size());
DirectSendPacket(&ackPacket);
} }
delete incomingPacket;
return; return;
} }
// handle general packet ACK else if (incomingPacket->type == prudpPacket::TYPE_SYN)
if (incomingPacket->flags&prudpPacket::FLAG_ACK)
{ {
acknowledgePacket(incomingPacket->sequenceId); // syn packet from server is expected to have ACK set
if (!(incomingPacket->flags & prudpPacket::FLAG_ACK))
{
cemuLog_log(LogType::Force, "[PRUDP] Received SYN packet without ACK flag set"); // always log this
return;
} }
// special cases if (m_hasSynAck || !incomingPacket->hasData || incomingPacket->packetData.size() != 4)
if (incomingPacket->type == prudpPacket::TYPE_SYN)
{ {
if (hasSentCon == false && incomingPacket->hasData && incomingPacket->packetData.size() == 4) // syn already acked or not a valid syn packet
{ cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected SYN packet");
this->serverConnectionSignature = *(uint32*)&incomingPacket->packetData.front(); return;
this->clientSessionId = prudp_generateRandomU8(); }
// generate client session id m_hasSynAck = true;
this->clientConnectionSignature = prudp_generateRandomU32(); this->m_serverConnectionSignature = *(uint32*)&incomingPacket->packetData.front();
// generate client session id and connection signature
this->m_clientSessionId = prudp_generateRandomU8();
this->m_clientConnectionSignature = prudp_generateRandomU32();
// send con packet // send con packet
prudpPacket* conPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_CON, prudpPacket::FLAG_NEED_ACK|prudpPacket::FLAG_RELIABLE, this->clientSessionId, outgoingSequenceId, serverConnectionSignature); m_handshakeRetryCount = 0;
outgoingSequenceId++; SendCurrentHandshakePacket();
if (this->isSecureConnection)
{
// set packet specific data (client connection signature)
uint8 tempBuffer[512];
nexPacketBuffer conData(tempBuffer, sizeof(tempBuffer), true);
conData.writeU32(this->clientConnectionSignature);
conData.writeBuffer(authInfo.secureTicket, authInfo.secureTicketLength);
// encrypted request data
uint8 requestData[4 * 3];
uint8 requestDataEncrypted[4 * 3 + 0x10];
*(uint32*)(requestData + 0x0) = authInfo.userPid;
*(uint32*)(requestData + 0x4) = authInfo.server.cid;
*(uint32*)(requestData + 0x8) = prudp_generateRandomU32(); // todo - check value
sint32 encryptedSize = kerberosEncryptData(requestData, sizeof(requestData), requestDataEncrypted);
conData.writeBuffer(requestDataEncrypted, encryptedSize);
conPacket->setData(conData.getDataPtr(), conData.getWriteIndex());
}
else
{
// set packet specific data (client connection signature)
conPacket->setData((uint8*)&this->clientConnectionSignature, sizeof(uint32));
}
// send packet
queuePacket(conPacket, dstIp, dstPort);
// remember con packet as sent
hasSentCon = true;
}
delete incomingPacket;
return; return;
} }
else if (incomingPacket->type == prudpPacket::TYPE_CON) else if (incomingPacket->type == prudpPacket::TYPE_CON)
{ {
// connected! if (!m_hasSynAck || m_hasConAck)
if (currentConnectionState == STATE_CONNECTING)
{ {
lastPingTimestamp = prudpGetMSTimestamp(); cemuLog_log(LogType::PRUDP, "[PRUDP] Received unexpected CON packet");
cemu_assert_debug(serverSessionId == 0); return;
serverSessionId = incomingPacket->sessionId;
currentConnectionState = STATE_CONNECTED;
cemuLog_log(LogType::PRUDP, "[PRUDP] Connection established. ClientSession {:02x} ServerSession {:02x}", clientSessionId, serverSessionId);
} }
delete incomingPacket; // make sure the packet has the ACK flag set
if (!(incomingPacket->flags & prudpPacket::FLAG_ACK))
{
cemuLog_log(LogType::Force, "[PRUDP] Received CON packet without ACK flag set");
return;
}
m_hasConAck = true;
m_handshakeRetryCount = 0;
cemu_assert_debug(m_currentConnectionState == ConnectionState::Connecting);
// connected!
m_lastPingTimestamp = prudpGetMSTimestamp();
cemu_assert_debug(m_serverSessionId == 0);
m_serverSessionId = incomingPacket->sessionId;
m_currentConnectionState = ConnectionState::Connected;
cemuLog_log(LogType::PRUDP, "[PRUDP] Connection established. ClientSession {:02x} ServerSession {:02x}", m_clientSessionId, m_serverSessionId);
return; return;
} }
else if (incomingPacket->type == prudpPacket::TYPE_DATA) else if (incomingPacket->type == prudpPacket::TYPE_DATA)
{ {
// handle ACK
if (incomingPacket->flags & prudpPacket::FLAG_ACK)
{
AcknowledgePacket(incomingPacket->sequenceId);
if(!incomingPacket->packetData.empty())
cemuLog_log(LogType::PRUDP, "[PRUDP] Received ACK data packet with payload");
return;
}
// send ack back if requested // send ack back if requested
if (incomingPacket->flags & prudpPacket::FLAG_NEED_ACK) if (incomingPacket->flags & prudpPacket::FLAG_NEED_ACK)
{ {
prudpPacket ackPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->clientSessionId, incomingPacket->sequenceId, 0); prudpPacket ackPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_DATA, prudpPacket::FLAG_ACK, this->m_clientSessionId, incomingPacket->sequenceId, 0);
directSendPacket(&ackPacket, dstIp, dstPort); DirectSendPacket(&ackPacket);
} }
// skip data packets without payload // skip data packets without payload
if (incomingPacket->packetData.empty()) if (incomingPacket->packetData.empty())
{
delete incomingPacket;
return; return;
} // verify sequence id
// verify some values uint16 seqDist = incomingPacket->sequenceId - m_incomingSequenceId;
uint16 seqDist = incomingPacket->sequenceId - incomingSequenceId;
if (seqDist >= 0xC000) if (seqDist >= 0xC000)
{ {
// outdated // outdated
delete incomingPacket;
return; return;
} }
// check if packet is already queued // check if packet is already queued
for (auto& it : queue_incomingPackets) for (auto& it : m_incomingPacketQueue)
{ {
if (it->sequenceId == incomingPacket->sequenceId) if (it->sequenceId == incomingPacket->sequenceId)
{ {
// already queued (should check other values too, like packet type?) // already queued (should check other values too, like packet type?)
cemuLog_log(LogType::PRUDP, "Duplicate PRUDP packet received"); cemuLog_log(LogType::PRUDP, "Duplicate PRUDP packet received");
delete incomingPacket;
return; return;
} }
} }
// put into ordered receive queue // put into ordered receive queue
sortIncomingDataPacket(incomingPacket); SortIncomingDataPacket(std::move(incomingPacket));
} }
else if (incomingPacket->type == prudpPacket::TYPE_DISCONNECT) else if (incomingPacket->type == prudpPacket::TYPE_DISCONNECT)
{ {
currentConnectionState = STATE_DISCONNECTED; m_currentConnectionState = ConnectionState::Disconnected;
return; return;
} }
else else
{ {
// ignore unknown packet cemuLog_log(LogType::PRUDP, "[PRUDP] Received unknown packet type");
delete incomingPacket;
return;
} }
} }
bool prudpClient::update() bool prudpClient::Update()
{ {
if (currentConnectionState == STATE_DISCONNECTED) if (m_currentConnectionState == ConnectionState::Disconnected)
return false; return false;
uint32 currentTimestamp = prudpGetMSTimestamp(); uint32 currentTimestamp = prudpGetMSTimestamp();
// check for incoming packets // check for incoming packets
@ -793,7 +795,7 @@ bool prudpClient::update()
{ {
sockaddr receiveFrom = {0}; sockaddr receiveFrom = {0};
socklen_t receiveFromLen = sizeof(receiveFrom); socklen_t receiveFromLen = sizeof(receiveFrom);
sint32 r = recvfrom(socketUdp, (char*)receiveBuffer, sizeof(receiveBuffer), 0, &receiveFrom, &receiveFromLen); sint32 r = recvfrom(m_socketUdp, (char*)receiveBuffer, sizeof(receiveBuffer), 0, &receiveFrom, &receiveFromLen);
if (r >= 0) if (r >= 0)
{ {
// todo: Verify sender (receiveFrom) // todo: Verify sender (receiveFrom)
@ -807,198 +809,190 @@ bool prudpClient::update()
cemuLog_log(LogType::Force, "[PRUDP] Invalid packet length"); cemuLog_log(LogType::Force, "[PRUDP] Invalid packet length");
break; break;
} }
prudpIncomingPacket* incomingPacket = new prudpIncomingPacket(&streamSettings, receiveBuffer + pIdx, packetLength); auto incomingPacket = std::make_unique<prudpIncomingPacket>(&m_streamSettings, receiveBuffer + pIdx, packetLength);
pIdx += packetLength; pIdx += packetLength;
if (incomingPacket->hasError()) if (incomingPacket->hasError())
{ {
cemuLog_log(LogType::Force, "[PRUDP] Packet error"); cemuLog_log(LogType::Force, "[PRUDP] Packet error");
delete incomingPacket;
break; break;
} }
if (incomingPacket->type != prudpPacket::TYPE_CON && incomingPacket->sessionId != serverSessionId) if (incomingPacket->type != prudpPacket::TYPE_CON && incomingPacket->sessionId != m_serverSessionId)
{ {
cemuLog_log(LogType::PRUDP, "[PRUDP] Invalid session id"); cemuLog_log(LogType::PRUDP, "[PRUDP] Invalid session id");
delete incomingPacket;
continue; // different session continue; // different session
} }
handleIncomingPacket(incomingPacket); HandleIncomingPacket(std::move(incomingPacket));
} }
} }
else else
break; break;
} }
// check for ack timeouts // check for ack timeouts
for (auto &it : list_packetsWithAckReq) for (auto& it : m_dataPacketsWithAckReq)
{ {
if ((currentTimestamp - it.lastRetryTimestamp) >= 2300) if ((currentTimestamp - it.lastRetryTimestamp) >= 2300)
{ {
if (it.retryCount >= 7) if (it.retryCount >= 7)
{ {
// after too many retries consider the connection dead // after too many retries consider the connection dead
currentConnectionState = STATE_DISCONNECTED; m_currentConnectionState = ConnectionState::Disconnected;
} }
// resend // resend
directSendPacket(it.packet, dstIp, dstPort); DirectSendPacket(it.packet);
it.lastRetryTimestamp = currentTimestamp; it.lastRetryTimestamp = currentTimestamp;
it.retryCount++; it.retryCount++;
} }
} }
// check if we need to send another ping if (m_currentConnectionState == ConnectionState::Connecting)
if (currentConnectionState == STATE_CONNECTED)
{ {
// check if we need to resend SYN or CON
uint32 timeSinceLastHandshake = currentTimestamp - m_lastHandshakeTimestamp;
if (timeSinceLastHandshake >= 1200)
{
if (m_handshakeRetryCount >= 5)
{
// too many retries, assume the other side doesn't listen
m_currentConnectionState = ConnectionState::Disconnected;
cemuLog_log(LogType::PRUDP, "PRUDP: Failed to connect");
return false;
}
SendCurrentHandshakePacket();
}
}
else if (m_currentConnectionState == ConnectionState::Connected)
{
// handle pings
if (m_unacknowledgedPingCount != 0) // counts how many times we sent a ping packet (for the current sequenceId) without receiving an ack if (m_unacknowledgedPingCount != 0) // counts how many times we sent a ping packet (for the current sequenceId) without receiving an ack
{ {
// we are waiting for the ack of the previous ping, but it hasn't arrived yet so send another ping packet // we are waiting for the ack of the previous ping, but it hasn't arrived yet so send another ping packet
if((currentTimestamp - lastPingTimestamp) >= 1500) if ((currentTimestamp - m_lastPingTimestamp) >= 1500)
{ {
cemuLog_log(LogType::PRUDP, "[PRUDP] Resending ping packet (no ack received)"); cemuLog_log(LogType::PRUDP, "[PRUDP] Resending ping packet (no ack received)");
if (m_unacknowledgedPingCount >= 10) if (m_unacknowledgedPingCount >= 10)
{ {
// too many unacknowledged pings, assume the connection is dead // too many unacknowledged pings, assume the connection is dead
currentConnectionState = STATE_DISCONNECTED; m_currentConnectionState = ConnectionState::Disconnected;
cemuLog_log(LogType::PRUDP, "PRUDP: Connection did not receive a ping response in a while. Assuming disconnect"); cemuLog_log(LogType::PRUDP, "PRUDP: Connection did not receive a ping response in a while. Assuming disconnect");
return false; return false;
} }
// resend the ping packet // resend the ping packet
prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature); prudpPacket pingPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->m_clientSessionId, this->m_outgoingSequenceId_ping, m_serverConnectionSignature);
directSendPacket(pingPacket, dstIp, dstPort); DirectSendPacket(&pingPacket);
m_unacknowledgedPingCount++; m_unacknowledgedPingCount++;
delete pingPacket; m_lastPingTimestamp = currentTimestamp;
lastPingTimestamp = currentTimestamp;
} }
} }
else else
{ {
if((currentTimestamp - lastPingTimestamp) >= 20000) if ((currentTimestamp - m_lastPingTimestamp) >= 20000)
{ {
cemuLog_log(LogType::PRUDP, "[PRUDP] Sending new ping packet with sequenceId {}", this->m_outgoingSequenceId_ping + 1); cemuLog_log(LogType::PRUDP, "[PRUDP] Sending new ping packet with sequenceId {}", this->m_outgoingSequenceId_ping + 1);
// start a new ping packet with a new sequenceId. Note that ping packets have their own sequenceId and acknowledgement happens by manually comparing the incoming ping ACK against the last sent sequenceId // start a new ping packet with a new sequenceId. Note that ping packets have their own sequenceId and acknowledgement happens by manually comparing the incoming ping ACK against the last sent sequenceId
// only one unacknowledged ping packet can be in flight at a time. We will resend the same ping packet until we receive an ack // only one unacknowledged ping packet can be in flight at a time. We will resend the same ping packet until we receive an ack
this->m_outgoingSequenceId_ping++; // increment before sending. The first ping has a sequenceId of 1 this->m_outgoingSequenceId_ping++; // increment before sending. The first ping has a sequenceId of 1
prudpPacket* pingPacket = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->clientSessionId, this->m_outgoingSequenceId_ping, serverConnectionSignature); prudpPacket pingPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_PING, prudpPacket::FLAG_NEED_ACK, this->m_clientSessionId, this->m_outgoingSequenceId_ping, m_serverConnectionSignature);
directSendPacket(pingPacket, dstIp, dstPort); DirectSendPacket(&pingPacket);
m_unacknowledgedPingCount++; m_unacknowledgedPingCount++;
delete pingPacket; m_lastPingTimestamp = currentTimestamp;
lastPingTimestamp = currentTimestamp;
} }
} }
} }
return false; return false;
} }
void prudpClient::directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort) void prudpClient::DirectSendPacket(prudpPacket* packet)
{ {
uint8 packetBuffer[prudpPacket::PACKET_RAW_SIZE_MAX]; uint8 packetBuffer[prudpPacket::PACKET_RAW_SIZE_MAX];
sint32 len = packet->buildData(packetBuffer, prudpPacket::PACKET_RAW_SIZE_MAX); sint32 len = packet->buildData(packetBuffer, prudpPacket::PACKET_RAW_SIZE_MAX);
sockaddr_in destAddr; sockaddr_in destAddr;
destAddr.sin_family = AF_INET; destAddr.sin_family = AF_INET;
destAddr.sin_port = htons(dstPort); destAddr.sin_port = htons(m_dstPort);
destAddr.sin_addr.s_addr = dstIp; destAddr.sin_addr.s_addr = m_dstIp;
sendto(socketUdp, (const char*)packetBuffer, len, 0, (const sockaddr*)&destAddr, sizeof(destAddr)); sendto(m_socketUdp, (const char*)packetBuffer, len, 0, (const sockaddr*)&destAddr, sizeof(destAddr));
} }
void prudpClient::queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort) void prudpClient::QueuePacket(prudpPacket* packet)
{ {
cemu_assert_debug(packet->GetType() == prudpPacket::TYPE_DATA); // only data packets should be queued
if (packet->requiresAck()) if (packet->requiresAck())
{ {
cemu_assert_debug(packet->GetType() != prudpPacket::TYPE_PING); // ping packets use their own logic for acks, dont queue them
// remember this packet until we receive the ack // remember this packet until we receive the ack
prudpAckRequired_t ackRequired = { 0 }; m_dataPacketsWithAckReq.emplace_back(packet, prudpGetMSTimestamp());
ackRequired.packet = packet; DirectSendPacket(packet);
ackRequired.initialSendTimestamp = prudpGetMSTimestamp();
ackRequired.lastRetryTimestamp = ackRequired.initialSendTimestamp;
list_packetsWithAckReq.push_back(ackRequired);
directSendPacket(packet, dstIp, dstPort);
} }
else else
{ {
directSendPacket(packet, dstIp, dstPort); DirectSendPacket(packet);
delete packet; delete packet;
} }
} }
void prudpClient::sendDatagram(uint8* input, sint32 length, bool reliable) void prudpClient::SendDatagram(uint8* input, sint32 length, bool reliable)
{ {
cemu_assert_debug(reliable); // non-reliable packets require testing cemu_assert_debug(reliable); // non-reliable packets require correct sequenceId handling and testing
cemu_assert_debug(m_hasSynAck && m_hasConAck); // cant send data packets before we are connected
if (length >= 0x300) if (length >= 0x300)
{ {
cemuLog_logOnce(LogType::Force, "PRUDP: Datagram too long"); cemuLog_logOnce(LogType::Force, "PRUDP: Datagram too long. Fragmentation not implemented yet");
} }
// single fragment data packet // single fragment data packet
uint16 flags = prudpPacket::FLAG_NEED_ACK; uint16 flags = prudpPacket::FLAG_NEED_ACK;
if (reliable) if (reliable)
flags |= prudpPacket::FLAG_RELIABLE; flags |= prudpPacket::FLAG_RELIABLE;
prudpPacket* packet = new prudpPacket(&streamSettings, vport_src, vport_dst, prudpPacket::TYPE_DATA, flags, clientSessionId, outgoingSequenceId, 0); prudpPacket* packet = new prudpPacket(&m_streamSettings, m_srcVPort, m_dstVPort, prudpPacket::TYPE_DATA, flags, m_clientSessionId, m_outgoingReliableSequenceId, 0);
if (reliable) if (reliable)
outgoingSequenceId++; m_outgoingReliableSequenceId++;
packet->setFragmentIndex(0); packet->setFragmentIndex(0);
packet->setData(input, length); packet->setData(input, length);
queuePacket(packet, dstIp, dstPort); QueuePacket(packet);
} }
uint16 prudpClient::getSourcePort() sint32 prudpClient::ReceiveDatagram(std::vector<uint8>& outputBuffer)
{ {
return this->srcPort; outputBuffer.clear();
} if (m_incomingPacketQueue.empty())
SOCKET prudpClient::getSocket()
{
if (currentConnectionState == STATE_DISCONNECTED)
{
return INVALID_SOCKET;
}
return this->socketUdp;
}
sint32 prudpClient::receiveDatagram(std::vector<uint8>& outputBuffer)
{
if (queue_incomingPackets.empty())
return -1; return -1;
prudpIncomingPacket* incomingPacket = queue_incomingPackets[0]; prudpIncomingPacket* frontPacket = m_incomingPacketQueue[0].get();
if (incomingPacket->sequenceId != this->incomingSequenceId) if (frontPacket->sequenceId != this->m_incomingSequenceId)
return -1; return -1;
if (frontPacket->fragmentIndex == 0)
if (incomingPacket->fragmentIndex == 0)
{ {
// single-fragment packet // single-fragment packet
// decrypt // decrypt
incomingPacket->decrypt(); frontPacket->decrypt();
// read data // read data
sint32 datagramLen = (sint32)incomingPacket->packetData.size(); if (!frontPacket->packetData.empty())
if (datagramLen > 0) {
// to conserve memory we will also shrink the buffer if it was previously extended beyond 32KB
constexpr size_t BUFFER_TARGET_SIZE = 1024 * 32;
if (frontPacket->packetData.size() < BUFFER_TARGET_SIZE && outputBuffer.capacity() > BUFFER_TARGET_SIZE)
{ {
// resize buffer if necessary
if (datagramLen > outputBuffer.size())
outputBuffer.resize(datagramLen);
// to conserve memory we will also shrink the buffer if it was previously extended beyond 64KB
constexpr size_t BUFFER_TARGET_SIZE = 1024 * 64;
if (datagramLen < BUFFER_TARGET_SIZE && outputBuffer.size() > BUFFER_TARGET_SIZE)
outputBuffer.resize(BUFFER_TARGET_SIZE); outputBuffer.resize(BUFFER_TARGET_SIZE);
// copy datagram to buffer outputBuffer.shrink_to_fit();
memcpy(outputBuffer.data(), &incomingPacket->packetData.front(), datagramLen); outputBuffer.clear();
} }
delete incomingPacket; // write packet data to output buffer
// remove packet from queue cemu_assert_debug(outputBuffer.empty());
queue_incomingPackets.erase(queue_incomingPackets.begin()); outputBuffer.insert(outputBuffer.end(), frontPacket->packetData.begin(), frontPacket->packetData.end());
}
m_incomingPacketQueue.erase(m_incomingPacketQueue.begin());
// advance expected sequence id // advance expected sequence id
this->incomingSequenceId++; this->m_incomingSequenceId++;
return datagramLen; return (sint32)outputBuffer.size();
} }
else else
{ {
// multi-fragment packet // multi-fragment packet
if (incomingPacket->fragmentIndex != 1) if (frontPacket->fragmentIndex != 1)
return -1; // first packet of the chain not received yet return -1; // first packet of the chain not received yet
// verify chain // verify chain
sint32 packetIndex = 1; sint32 packetIndex = 1;
sint32 chainLength = -1; // if full chain found, set to count of packets sint32 chainLength = -1; // if full chain found, set to count of packets
for(sint32 i=1; i<queue_incomingPackets.size(); i++) for (sint32 i = 1; i < m_incomingPacketQueue.size(); i++)
{ {
uint8 itFragmentIndex = queue_incomingPackets[packetIndex]->fragmentIndex; uint8 itFragmentIndex = m_incomingPacketQueue[packetIndex]->fragmentIndex;
// sequence id must increase by 1 for every packet // sequence id must increase by 1 for every packet
if (queue_incomingPackets[packetIndex]->sequenceId != (this->incomingSequenceId+i) ) if (m_incomingPacketQueue[packetIndex]->sequenceId != (m_incomingSequenceId + i))
return -1; // missing packets return -1; // missing packets
// last fragment in chain is marked by fragment index 0 // last fragment in chain is marked by fragment index 0
if (itFragmentIndex == 0) if (itFragmentIndex == 0)
@ -1011,29 +1005,17 @@ sint32 prudpClient::receiveDatagram(std::vector<uint8>& outputBuffer)
if (chainLength < 1) if (chainLength < 1)
return -1; // chain not complete return -1; // chain not complete
// extract data from packet chain // extract data from packet chain
sint32 writeIndex = 0; cemu_assert_debug(outputBuffer.empty());
for (sint32 i = 0; i < chainLength; i++) for (sint32 i = 0; i < chainLength; i++)
{ {
incomingPacket = queue_incomingPackets[i]; prudpIncomingPacket* incomingPacket = m_incomingPacketQueue[i].get();
// decrypt
incomingPacket->decrypt(); incomingPacket->decrypt();
// extract data outputBuffer.insert(outputBuffer.end(), incomingPacket->packetData.begin(), incomingPacket->packetData.end());
sint32 datagramLen = (sint32)incomingPacket->packetData.size();
if (datagramLen > 0)
{
// make sure output buffer can fit the data
if ((writeIndex + datagramLen) > outputBuffer.size())
outputBuffer.resize(writeIndex + datagramLen + 4 * 1024);
memcpy(outputBuffer.data()+writeIndex, &incomingPacket->packetData.front(), datagramLen);
writeIndex += datagramLen;
}
// free packet memory
delete incomingPacket;
} }
// remove packets from queue // remove packets from queue
queue_incomingPackets.erase(queue_incomingPackets.begin(), queue_incomingPackets.begin() + chainLength); m_incomingPacketQueue.erase(m_incomingPacketQueue.begin(), m_incomingPacketQueue.begin() + chainLength);
this->incomingSequenceId += chainLength; m_incomingSequenceId += chainLength;
return writeIndex; return (sint32)outputBuffer.size();
} }
return -1; return -1;
} }

View File

@ -4,26 +4,26 @@
#define RC4_N 256 #define RC4_N 256
typedef struct struct RC4Ctx
{ {
unsigned char S[RC4_N]; unsigned char S[RC4_N];
int i; int i;
int j; int j;
}RC4Ctx_t; };
void RC4_initCtx(RC4Ctx_t* rc4Ctx, char *key); void RC4_initCtx(RC4Ctx* rc4Ctx, const char* key);
void RC4_initCtx(RC4Ctx_t* rc4Ctx, unsigned char* key, int keyLen); void RC4_initCtx(RC4Ctx* rc4Ctx, unsigned char* key, int keyLen);
void RC4_transform(RC4Ctx_t* rc4Ctx, unsigned char* input, int len, unsigned char* output); void RC4_transform(RC4Ctx* rc4Ctx, unsigned char* input, int len, unsigned char* output);
typedef struct struct prudpStreamSettings
{ {
uint8 checksumBase; // calculated from key uint8 checksumBase; // calculated from key
uint8 accessKeyDigest[16]; // MD5 hash of key uint8 accessKeyDigest[16]; // MD5 hash of key
RC4Ctx_t rc4Client; RC4Ctx rc4Client;
RC4Ctx_t rc4Server; RC4Ctx rc4Server;
}prudpStreamSettings_t; };
typedef struct struct prudpStationUrl
{ {
uint32 ip; uint32 ip;
uint16 port; uint16 port;
@ -32,19 +32,17 @@ typedef struct
sint32 sid; sint32 sid;
sint32 stream; sint32 stream;
sint32 type; sint32 type;
}stationUrl_t; };
typedef struct struct prudpAuthServerInfo
{ {
uint32 userPid; uint32 userPid;
uint8 secureKey[16]; uint8 secureKey[16];
uint8 kerberosKey[16]; uint8 kerberosKey[16];
uint8 secureTicket[1024]; uint8 secureTicket[1024];
sint32 secureTicketLength; sint32 secureTicketLength;
stationUrl_t server; prudpStationUrl server;
}authServerInfo_t; };
uint8 prudp_calculateChecksum(uint8 checksumBase, uint8* data, sint32 length);
class prudpPacket class prudpPacket
{ {
@ -66,7 +64,7 @@ public:
static sint32 calculateSizeFromPacketData(uint8* data, sint32 length); static sint32 calculateSizeFromPacketData(uint8* data, sint32 length);
prudpPacket(prudpStreamSettings_t* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature); prudpPacket(prudpStreamSettings* streamSettings, uint8 src, uint8 dst, uint8 type, uint16 flags, uint8 sessionId, uint16 sequenceId, uint32 packetSignature);
bool requiresAck(); bool requiresAck();
void setData(uint8* data, sint32 length); void setData(uint8* data, sint32 length);
void setFragmentIndex(uint8 fragmentIndex); void setFragmentIndex(uint8 fragmentIndex);
@ -87,7 +85,7 @@ private:
uint16 flags; uint16 flags;
uint8 sessionId; uint8 sessionId;
uint32 specifiedPacketSignature; uint32 specifiedPacketSignature;
prudpStreamSettings_t* streamSettings; prudpStreamSettings* streamSettings;
std::vector<uint8> packetData; std::vector<uint8> packetData;
bool isEncrypted; bool isEncrypted;
uint16 m_sequenceId{0}; uint16 m_sequenceId{0};
@ -97,7 +95,7 @@ private:
class prudpIncomingPacket class prudpIncomingPacket
{ {
public: public:
prudpIncomingPacket(prudpStreamSettings_t* streamSettings, uint8* data, sint32 length); prudpIncomingPacket(prudpStreamSettings* streamSettings, uint8* data, sint32 length);
bool hasError(); bool hasError();
@ -122,83 +120,91 @@ public:
private: private:
bool isInvalid = false; bool isInvalid = false;
prudpStreamSettings_t* streamSettings = nullptr; prudpStreamSettings* streamSettings = nullptr;
}; };
typedef struct
{
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount;
}prudpAckRequired_t;
class prudpClient class prudpClient
{ {
struct PacketWithAckRequired
{
PacketWithAckRequired(prudpPacket* packet, uint32 initialSendTimestamp) :
packet(packet), initialSendTimestamp(initialSendTimestamp), lastRetryTimestamp(initialSendTimestamp) { }
prudpPacket* packet;
uint32 initialSendTimestamp;
uint32 lastRetryTimestamp;
sint32 retryCount{0};
};
public: public:
static const int STATE_CONNECTING = 0; enum class ConnectionState : uint8
static const int STATE_CONNECTED = 1; {
static const int STATE_DISCONNECTED = 2; Connecting,
Connected,
Disconnected
};
public:
prudpClient(uint32 dstIp, uint16 dstPort, const char* key); prudpClient(uint32 dstIp, uint16 dstPort, const char* key);
prudpClient(uint32 dstIp, uint16 dstPort, const char* key, authServerInfo_t* authInfo); prudpClient(uint32 dstIp, uint16 dstPort, const char* key, prudpAuthServerInfo* authInfo);
~prudpClient(); ~prudpClient();
bool isConnected(); bool IsConnected() const { return m_currentConnectionState == ConnectionState::Connected; }
ConnectionState GetConnectionState() const { return m_currentConnectionState; }
uint16 GetSourcePort() const { return m_srcPort; }
uint8 getConnectionState(); bool Update(); // update connection state and check for incoming packets. Returns true if ReceiveDatagram() should be called
void acknowledgePacket(uint16 sequenceId);
void sortIncomingDataPacket(prudpIncomingPacket* incomingPacket);
void handleIncomingPacket(prudpIncomingPacket* incomingPacket);
bool update(); // check for new incoming packets, returns true if receiveDatagram() should be called
sint32 receiveDatagram(std::vector<uint8>& outputBuffer); sint32 ReceiveDatagram(std::vector<uint8>& outputBuffer);
void sendDatagram(uint8* input, sint32 length, bool reliable = true); void SendDatagram(uint8* input, sint32 length, bool reliable = true);
uint16 getSourcePort();
SOCKET getSocket();
private: private:
prudpClient(); prudpClient();
void directSendPacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort);
sint32 kerberosEncryptData(uint8* input, sint32 length, uint8* output); void HandleIncomingPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void queuePacket(prudpPacket* packet, uint32 dstIp, uint16 dstPort); void DirectSendPacket(prudpPacket* packet);
sint32 KerberosEncryptData(uint8* input, sint32 length, uint8* output);
void QueuePacket(prudpPacket* packet);
void AcknowledgePacket(uint16 sequenceId);
void SortIncomingDataPacket(std::unique_ptr<prudpIncomingPacket> incomingPacket);
void SendCurrentHandshakePacket();
private: private:
uint16 srcPort; uint16 m_srcPort;
uint32 dstIp; uint32 m_dstIp;
uint16 dstPort; uint16 m_dstPort;
uint8 vport_src; uint8 m_srcVPort;
uint8 vport_dst; uint8 m_dstVPort;
prudpStreamSettings_t streamSettings; prudpStreamSettings m_streamSettings;
std::vector<prudpAckRequired_t> list_packetsWithAckReq; std::vector<PacketWithAckRequired> m_dataPacketsWithAckReq;
std::vector<prudpIncomingPacket*> queue_incomingPackets; std::vector<std::unique_ptr<prudpIncomingPacket>> m_incomingPacketQueue;
// connection handshake state
bool m_hasSynAck{false};
bool m_hasConAck{false};
uint32 m_lastHandshakeTimestamp{0};
uint8 m_handshakeRetryCount{0};
// connection // connection
uint8 currentConnectionState; ConnectionState m_currentConnectionState;
uint32 serverConnectionSignature; uint32 m_serverConnectionSignature;
uint32 clientConnectionSignature; uint32 m_clientConnectionSignature;
bool hasSentCon; uint32 m_lastPingTimestamp;
uint32 lastPingTimestamp;
uint16 outgoingSequenceId; uint16 m_outgoingReliableSequenceId{2}; // 1 is reserved for CON
uint16 incomingSequenceId; uint16 m_incomingSequenceId;
uint16 m_outgoingSequenceId_ping{0}; uint16 m_outgoingSequenceId_ping{0};
uint8 m_unacknowledgedPingCount{0}; uint8 m_unacknowledgedPingCount{0};
uint8 clientSessionId; uint8 m_clientSessionId;
uint8 serverSessionId; uint8 m_serverSessionId;
// secure // secure
bool isSecureConnection; bool m_isSecureConnection{false};
authServerInfo_t authInfo; prudpAuthServerInfo m_authInfo;
// socket // socket
SOCKET socketUdp; SOCKET m_socketUdp;
}; };
uint32 prudpGetMSTimestamp(); uint32 prudpGetMSTimestamp();