From 694c9e73d6dcefab43538a62a386ee75dced8e51 Mon Sep 17 00:00:00 2001 From: Jack Garrard Date: Sun, 23 Oct 2022 16:37:47 -0700 Subject: [PATCH] Allow for a udp handshake --- include/server/Client.hpp | 2 ++ include/server/SocketClient.hpp | 2 +- source/server/Client.cpp | 51 +++++++++++++++++++++++++++++++++ source/server/SocketClient.cpp | 35 +++++++++++----------- 4 files changed, 71 insertions(+), 19 deletions(-) diff --git a/include/server/Client.hpp b/include/server/Client.hpp index 16a2ca2..51edd92 100644 --- a/include/server/Client.hpp +++ b/include/server/Client.hpp @@ -197,6 +197,8 @@ class Client { void updateTagInfo(TagInf *packet); void updateCaptureInfo(CaptureInf* packet); void sendToStage(ChangeStagePacket* packet); + void sendUdpHolePunch(); + void sendUdpInit(); void disconnectPlayer(PlayerDC *packet); PuppetInfo* findPuppetInfo(const nn::account::Uid& id, bool isFindAvailable); diff --git a/include/server/SocketClient.hpp b/include/server/SocketClient.hpp index b579cd5..47ec8a6 100644 --- a/include/server/SocketClient.hpp +++ b/include/server/SocketClient.hpp @@ -44,7 +44,7 @@ class SocketClient : public SocketBase { void printPacket(Packet* packet); bool isConnected() { return socket_log_state == SOCKET_LOG_CONNECTED; } - u16 getUdpPort(); + u16 getLocalUdpPort(); s32 setPeerUdpPort(u16 port); u32 getSendCount() { return mSendQueue.getCount(); } diff --git a/source/server/Client.cpp b/source/server/Client.cpp index e2a3df4..418b092 100644 --- a/source/server/Client.cpp +++ b/source/server/Client.cpp @@ -383,6 +383,19 @@ void Client::readFunc() { maxPuppets = initPacket->maxPlayers - 1; break; } + case PacketType::UDPINIT: { + UdpInit* initPacket = (UdpInit*)curPacket; + Logger::log("Received udp init packet from server"); + + sInstance->mSocket->setPeerUdpPort(initPacket->port); + sendUdpHolePunch(); + sendUdpInit(); + + break; + } + case PacketType::HOLEPUNCH: + sendUdpHolePunch(); + break; default: Logger::log("Discarding Unknown Packet Type.\n"); break; @@ -939,7 +952,45 @@ void Client::sendToStage(ChangeStagePacket* packet) { GameDataFunction::tryChangeNextStage(accessor, &info); } } +/** + * @brief + * Send a udp holepunch packet to the server + */ +void Client::sendUdpHolePunch() { + if (!sInstance) { + Logger::log("Static Instance is Null!\n"); + return; + } + + sead::ScopedCurrentHeapSetter setter(sInstance->mHeap); + + HolePunch *packet = new HolePunch(); + + packet->mUserID = sInstance->mUserID; + + sInstance->mSocket->queuePacket(packet); +} +/** + * @brief + * Send a udp init packet to server + */ +void Client::sendUdpInit() { + + if (!sInstance) { + Logger::log("Static Instance is Null!\n"); + return; + } + + sead::ScopedCurrentHeapSetter setter(sInstance->mHeap); + + UdpInit *packet = new UdpInit(); + + packet->mUserID = sInstance->mUserID; + packet->port = sInstance->mSocket->getLocalUdpPort(); + + sInstance->mSocket->queuePacket(packet); +} /** * @brief * diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index 0091cbd..8ceffec 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -91,10 +91,7 @@ nn::Result SocketClient::init(const char* ip, u16 port) { udpAddress.port = nn::socket::InetHtons(41553); udpAddress.family = 2; this->udp_addr = udpAddress; - - // udpAddress.address = hostAddress; - // udpAddress.port = nn::socket::InetHtons(57734); - // udpAddress.family = 2; + this->has_recv_udp = false; if((result = nn::socket::Connect(this->udp_socket, &udpAddress, sizeof(udpAddress))).isFailure()) { Logger::log("Udp Socket Connection Failed!\n"); @@ -128,7 +125,7 @@ nn::Result SocketClient::init(const char* ip, u16 port) { } -u16 SocketClient::getUdpPort() { +u16 SocketClient::getLocalUdpPort() { sockaddr udpAddress = { 0 }; u32 size = sizeof(udpAddress); @@ -235,30 +232,32 @@ bool SocketClient::recv() { Packet* header = reinterpret_cast(headerBuf); int fullSize = header->mPacketSize + sizeof(Packet); - + // Verify packet size is appropriate if (result < fullSize || result > MAXPACKSIZE || fullSize > MAXPACKSIZE){ return true; } - - char* packetBuf = (char*)malloc(fullSize); - + + // Verify type of packet if (!(header->mType > PacketType::UNKNOWN && header->mType < PacketType::End)) { Logger::log("Failed to acquire valid packet type! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, result); - free(packetBuf); return true; } - memcpy(packetBuf, headerBuf, fullSize); + this->has_recv_udp = true; + + char* packetBuf = (char*)mHeap->alloc(fullSize); + if (packetBuf){ + memcpy(packetBuf, headerBuf, fullSize); - Packet *packet = reinterpret_cast(packetBuf); + Packet *packet = reinterpret_cast(packetBuf); - if(!mRecvQueue.isFull()) { - mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); - this->has_recv_udp = true; - } else { - free(packetBuf); - } + if(!mRecvQueue.isFull()) { + mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); + } else { + mHeap->free(packetBuf); + } + } return true; }