diff --git a/Makefile b/Makefile index d1b8bfa..1018f35 100644 --- a/Makefile +++ b/Makefile @@ -45,15 +45,15 @@ emu: mv $(shell basename $(CURDIR))$(SMOVER).nso starlight_patch_$(SMOVER)/yuzu/subsdk1 # builds and sends project to FTP server hosted on provided IP send: all - python3.8 scripts/sendPatch.py $(IP) $(PROJNAME) + python3 scripts/sendPatch.py $(IP) $(PROJNAME) log: all - python3.8 scripts/tcpServer.py $(SERVERIP) + python3 scripts/tcpServer.py $(SERVERIP) sendlog: all - python3.8 scripts/sendPatch.py $(IP) $(PROJNAME) $(USER) $(PASS) - python3.8 scripts/tcpServer.py $(SERVERIP) + python3 scripts/sendPatch.py $(IP) $(PROJNAME) $(USER) $(PASS) + python3 scripts/tcpServer.py $(SERVERIP) clean: $(MAKE) clean -f MakefileNSO - @rm -fr starlight_patch_* \ No newline at end of file + @rm -fr starlight_patch_* diff --git a/include/game/System/GameSystemInfo.h b/include/game/System/GameSystemInfo.h index b187ac6..49aaa24 100644 --- a/include/game/System/GameSystemInfo.h +++ b/include/game/System/GameSystemInfo.h @@ -37,8 +37,8 @@ namespace al { al::GameDrawInfo *mDrawInfo; // 0x38 from Application::sInstance + 0x30 ProjectNfpDirector *mProjNfpDirector; // 0x48 al::HtmlViewer *mHtmlViewer; // 0x50 - ApplicationMessageReceiver *mMessageReciever; // 0x58 + ApplicationMessageReceiver *mMessageReceiver; // 0x58 al::WaveVibrationHolder *mWaveVibrationHolder; // 0x60 void *gap2; }; -} \ No newline at end of file +} diff --git a/include/nn/socket.h b/include/nn/socket.h index 3e8290b..592c1bf 100644 --- a/include/nn/socket.h +++ b/include/nn/socket.h @@ -2,6 +2,12 @@ #include "../types.h" +struct pollfd +{ + s32 fd; + s16 events; + s16 revents; +}; struct in_addr { @@ -38,13 +44,19 @@ s32 Connect(s32 socket, const sockaddr* address, u32 addressLen); Result Close(s32 socket); s32 Send(s32 socket, const void* data, ulong dataLen, s32 flags); +s32 SendTo(s32 socket, const void* data, ulong dataLen, s32 flags, const struct sockaddr* to, u32 toLen); s32 Recv(s32 socket, void* out, ulong outLen, s32 flags); +s32 RecvFrom(s32 socket, void* out, ulong outLen, s32 flags, struct sockaddr* from, u32* fromLen); +s32 GetSockName(s32 socket, struct sockaddr* name, u32* dataLen); u16 InetHtons(u16 val); +u16 InetNtohs(u16 val); s32 InetAton(const char* addressStr, in_addr* addressOut); struct hostent* GetHostByName(const char* name); u32 GetLastErrno(); +s32 Bind(s32 fd, sockaddr* addr, u32 addrlen); +s32 Poll(struct pollfd* fd, u64 addr, s32 timeout); } } diff --git a/include/packets/HolePunchPacket.h b/include/packets/HolePunchPacket.h new file mode 100644 index 0000000..28380f0 --- /dev/null +++ b/include/packets/HolePunchPacket.h @@ -0,0 +1,7 @@ +#pragma once + +#include "Packet.h" + +struct PACKED HolePunch : Packet { + HolePunch() : Packet() {this->mType = PacketType::HOLEPUNCH; mPacketSize = sizeof(HolePunch) - sizeof(Packet);}; +}; diff --git a/include/packets/Packet.h b/include/packets/Packet.h index 400ee46..92fb053 100644 --- a/include/packets/Packet.h +++ b/include/packets/Packet.h @@ -26,6 +26,8 @@ enum PacketType : short { CAPTUREINF, CHANGESTAGE, CMD, + UDPINIT, + HOLEPUNCH, End // end of enum for bounds checking }; @@ -43,7 +45,9 @@ USED static const char *packetNames[] = { "Moon Collection", "Capture Info", "Change Stage", - "Server Command" + "Server Command", + "Udp Initialization", + "Hole punch", }; enum SenderType { @@ -83,4 +87,6 @@ struct PACKED Packet { #include "packets/CaptureInf.h" #include "packets/HackCapInf.h" #include "packets/ChangeStagePacket.h" -#include "packets/InitPacket.h" \ No newline at end of file +#include "packets/InitPacket.h" +#include "packets/UdpPacket.h" +#include "packets/HolePunchPacket.h" diff --git a/include/packets/UdpPacket.h b/include/packets/UdpPacket.h new file mode 100644 index 0000000..b4c4b37 --- /dev/null +++ b/include/packets/UdpPacket.h @@ -0,0 +1,8 @@ +#pragma once + +#include "Packet.h" + +struct PACKED UdpInit : Packet { + UdpInit() : Packet() {this->mType = PacketType::UDPINIT; mPacketSize = sizeof(UdpInit) - sizeof(Packet);}; + u16 port = 0; +}; diff --git a/include/server/Client.hpp b/include/server/Client.hpp index fc09a68..ab1b873 100644 --- a/include/server/Client.hpp +++ b/include/server/Client.hpp @@ -194,6 +194,8 @@ class Client { void updateTagInfo(TagInf *packet); void updateCaptureInfo(CaptureInf* packet); void sendToStage(ChangeStagePacket* packet); + void sendUdpHolePunch(); + void sendUdpInit(); void disconnectPlayer(PlayerDC *packet); PuppetInfo* findPuppetInfo(const nn::account::Uid& id, bool isFindAvailable); @@ -214,7 +216,7 @@ class Client { // --- Server Syncing Members --- - // array of shine IDs for checking if multiple shines have been collected in quick sucession, all moons within the players stage that match the ID will be deleted + // array of shine IDs for checking if multiple shines have been collected in quick succession, all moons within the players stage that match the ID will be deleted sead::SafeArray curCollectedShines; int collectedShineCount = 0; diff --git a/include/server/SocketClient.hpp b/include/server/SocketClient.hpp index a35c86e..bae3af9 100644 --- a/include/server/SocketClient.hpp +++ b/include/server/SocketClient.hpp @@ -31,12 +31,13 @@ class SocketClient : public SocketBase { bool startThreads(); void endThreads(); + void waitForThreads(); bool send(Packet* packet); bool recv(); bool queuePacket(Packet *packet); - void trySendQueue(); + bool trySendQueue(); void sendFunc(); void recvFunc(); @@ -46,6 +47,10 @@ class SocketClient : public SocketBase { void printPacket(Packet* packet); bool isConnected() { return socket_log_state == SOCKET_LOG_CONNECTED; } + u16 getLocalUdpPort(); + s32 setPeerUdpPort(u16 port); + const char* getUdpStateChar(); + u32 getSendCount() { return mSendQueue.getCount(); } u32 getSendMaxCount() { return mSendQueue.getMaxCount(); } @@ -67,12 +72,21 @@ class SocketClient : public SocketBase { sead::MessageQueue mRecvQueue; sead::MessageQueue mSendQueue; + char* recvBuf = nullptr; int maxBufSize = 100; bool mIsFirstConnect = true; bool mPacketQueueOpen = true; + int pollTime = 0; + bool mHasRecvUdp; + s32 mUdpSocket; + sockaddr mUdpAddress; + + bool recvTcp(); + bool recvUdp(); + /** * @param str a string containing an IPv4 address or a hostname that can be resolved via DNS * @param out IPv4 address diff --git a/scripts/sendPatch.py b/scripts/sendPatch.py index 15d14b3..4d74074 100644 --- a/scripts/sendPatch.py +++ b/scripts/sendPatch.py @@ -35,36 +35,37 @@ if '.' not in consoleIP: print(sys.argv[0], "ERROR: Please specify with `IP=[Your console's IP]`") sys.exit(-1) -isNeedOtherSwitch = True - -altSwitchIP = sys.argv[2] -if '.' not in altSwitchIP: - isNeedOtherSwitch = False +isNeedOtherSwitch = False consolePort = 5000 -if len(sys.argv) < 4: +if len(sys.argv) < 3: projName = 'StarlightBase' else: - projName = sys.argv[3] + projName = sys.argv[2] + +if len(sys.argv) < 5: + user = 'crafty' + passwd = 'boss' +else: + user = sys.argv[3] + passwd = sys.argv[4] curDir = os.curdir ftp = FTP() -otherftp = FTP() - print(f'Connecting to {consoleIP}... ', end='') ftp.connect(consoleIP, consolePort) print('logging into server...', end='') -ftp.login('crafty','boss') +ftp.login(user,passwd) print('Connected!') if isNeedOtherSwitch: print(f'Connecting to {altSwitchIP}... ', end='') otherftp.connect(altSwitchIP, consolePort) print('logging into server...', end='') - otherftp.login('crafty','boss') + otherftp.login(user,passwd) print('Connected!') patchDirectories = [] diff --git a/source/main.cpp b/source/main.cpp index 39f2810..2ab620f 100644 --- a/source/main.cpp +++ b/source/main.cpp @@ -118,6 +118,7 @@ void drawMainHook(HakoniwaSequence *curSequence, sead::Viewport *viewport, sead: } gTextWriter->printf("Client Socket Connection Status: %s\n", Client::instance()->mSocket->getStateChar()); + gTextWriter->printf("Udp socket status: %s\n", Client::instance()->mSocket->getUdpStateChar()); //gTextWriter->printf("nn::socket::GetLastErrno: 0x%x\n", Client::instance()->mSocket->socket_errno); gTextWriter->printf("Connected Players: %d/%d\n", Client::getConnectCount() + 1, Client::getMaxPlayerCount()); diff --git a/source/puppets/PuppetHolder.cpp b/source/puppets/PuppetHolder.cpp index f730c7b..ebfb672 100644 --- a/source/puppets/PuppetHolder.cpp +++ b/source/puppets/PuppetHolder.cpp @@ -17,7 +17,7 @@ PuppetHolder::PuppetHolder(int size) { * @brief resizes puppet ptr array by creating a new ptr array and storing previous ptrs in it, before freeing the previous array * * @param size the size of the new ptr array - * @return returns true if resizing was sucessful + * @return returns true if resizing was successful */ bool PuppetHolder::resizeHolder(int size) { @@ -106,4 +106,4 @@ bool PuppetHolder::checkInfoIsInStage(PuppetInfo *info) { void PuppetHolder::setStageInfo(const char *stageName, u8 scenarioNo) { mStageName = stageName; mScenarioNo = scenarioNo; -} \ No newline at end of file +} diff --git a/source/server/Client.cpp b/source/server/Client.cpp index f38cc77..f30da5d 100644 --- a/source/server/Client.cpp +++ b/source/server/Client.cpp @@ -89,13 +89,13 @@ void Client::init(al::LayoutInitInfo const &initInfo, GameDataHolderAccessor hol /** * @brief starts client read thread * - * @return true if read thread was sucessfully started + * @return true if read thread was succesfully started * @return false if read thread was unable to start, or thread was already started. */ bool Client::startThread() { if(mReadThread->isDone() ) { mReadThread->start(); - Logger::log("Read Thread Sucessfully Started.\n"); + Logger::log("Read Thread Successfully Started.\n"); return true; }else { Logger::log("Read Thread has already started! Or other unknown reason.\n"); @@ -139,7 +139,7 @@ bool Client::startConnection() { if (mIsConnectionActive) { - Logger::log("Sucessful Connection. Waiting to recieve init packet.\n"); + Logger::log("Succesful Connection. Waiting to receive init packet.\n"); bool waitingForInitPacket = true; // wait for client init packet @@ -163,11 +163,13 @@ bool Client::startConnection() { mHeap->free(curPacket); } else { - Logger::log("Recieve failed! Stopping Connection.\n"); + Logger::log("Receive failed! Stopping Connection.\n"); mIsConnectionActive = false; waitingForInitPacket = false; } } + + } return mIsConnectionActive; @@ -313,7 +315,7 @@ void Client::readFunc() { while(mIsConnectionActive) { - Packet *curPacket = mSocket->tryGetPacket(); // will block until a packet has been recieved, or socket disconnected + Packet *curPacket = mSocket->tryGetPacket(); // will block until a packet has been received, or socket disconnected if (curPacket) { @@ -377,6 +379,19 @@ void Client::readFunc() { maxPuppets = initPacket->maxPlayers - 1; break; } + case PacketType::UDPINIT: { + UdpInit* initPacket = (UdpInit*)curPacket; + Logger::log("Received udp init packet from server\n"); + + sInstance->mSocket->setPeerUdpPort(initPacket->port); + sendUdpHolePunch(); + sendUdpInit(); + + break; + } + case PacketType::HOLEPUNCH: + sendUdpHolePunch(); + break; default: Logger::log("Discarding Unknown Packet Type.\n"); break; @@ -384,8 +399,8 @@ void Client::readFunc() { mHeap->free(curPacket); - }else { // if false, socket has errored or disconnected, so close the socket and end this thread. - Logger::log("Client Socket Encountered an Error! Errno: 0x%x\n", mSocket->socket_errno); + }else { // if false, socket has errored or disconnected, so restart the connection + Logger::log("Client Socket Encountered an Error, restarting connection! Errno: 0x%x\n", mSocket->socket_errno); } } @@ -967,7 +982,45 @@ void Client::sendToStage(ChangeStagePacket* packet) { GameDataFunction::tryChangeNextStage(accessor, &info); } } +/** + * @brief + * Send a udp holepunch packet to the server + */ +void Client::sendUdpHolePunch() { + if (!sInstance) { + Logger::log("Static Instance is Null!\n"); + return; + } + + sead::ScopedCurrentHeapSetter setter(sInstance->mHeap); + + HolePunch *packet = new HolePunch(); + + packet->mUserID = sInstance->mUserID; + + sInstance->mSocket->queuePacket(packet); +} +/** + * @brief + * Send a udp init packet to server + */ +void Client::sendUdpInit() { + + if (!sInstance) { + Logger::log("Static Instance is Null!\n"); + return; + } + + sead::ScopedCurrentHeapSetter setter(sInstance->mHeap); + + UdpInit *packet = new UdpInit(); + + packet->mUserID = sInstance->mUserID; + packet->port = sInstance->mSocket->getLocalUdpPort(); + + sInstance->mSocket->queuePacket(packet); +} /** * @brief * diff --git a/source/server/SocketBase.cpp b/source/server/SocketBase.cpp index b9c408a..9c79758 100644 --- a/source/server/SocketBase.cpp +++ b/source/server/SocketBase.cpp @@ -72,13 +72,13 @@ s32 SocketBase::getFd() { bool SocketBase::closeSocket() { - if (this->socket_log_state != SOCKET_LOG_DISCONNECTED) { - nn::Result result = nn::socket::Close(this->socket_log_socket); - if (result.isSuccess()) { - this->socket_log_state = SOCKET_LOG_DISCONNECTED; - } - return result.isSuccess(); - } + if (this->socket_log_state != SOCKET_LOG_DISCONNECTED) { + nn::Result result = nn::socket::Close(this->socket_log_socket); + if (result.isSuccess()) { + this->socket_log_state = SOCKET_LOG_DISCONNECTED; + } + return result.isSuccess(); + } - return true; + return true; } diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index 952c545..24021a7 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -9,6 +9,7 @@ #include "nn/result.h" #include "nn/socket.h" #include "packets/Packet.h" +#include "packets/UdpPacket.h" #include "server/Client.hpp" #include "thread/seadMessageQueue.h" #include "types.h" @@ -16,12 +17,18 @@ SocketClient::SocketClient(const char* name, sead::Heap* heap, Client* client) : mHeap(heap), SocketBase(name) { this->client = client; +#if EMU + this->pollTime = 0; +#else + this->pollTime = -1; +#endif mRecvThread = new al::AsyncFunctorThread("SocketRecvThread", al::FunctorV0M(this, &SocketClient::recvFunc), 0, 0x1000, {0}); mSendThread = new al::AsyncFunctorThread("SocketSendThread", al::FunctorV0M(this, &SocketClient::sendFunc), 0, 0x1000, {0}); mRecvQueue.allocate(maxBufSize, mHeap); mSendQueue.allocate(maxBufSize, mHeap); + recvBuf = (char*)mHeap->alloc(MAXPACKSIZE+1); }; nn::Result SocketClient::init(const char* ip, u16 port) { @@ -31,6 +38,7 @@ nn::Result SocketClient::init(const char* ip, u16 port) { in_addr hostAddress = { 0 }; sockaddr serverAddress = { 0 }; + sockaddr udpAddress = { 0 }; Logger::log("SocketClient::init: %s:%d sock %s\n", ip, port, getStateChar()); @@ -80,13 +88,25 @@ nn::Result SocketClient::init(const char* ip, u16 port) { return result; } + if ((this->mUdpSocket = nn::socket::Socket(2, 2, 17)) < 0) { + Logger::log("Udp Socket failed to create"); + this->socket_errno = nn::socket::GetLastErrno(); + this->socket_log_state = SOCKET_LOG_UNAVAILABLE; + return -1; + } + + udpAddress.address = hostAddress; + udpAddress.port = 0; + udpAddress.family = 2; + this->mUdpAddress = udpAddress; + this->mHasRecvUdp = false; this->mPacketQueueOpen = true; this->socket_log_state = SOCKET_LOG_CONNECTED; Logger::log("Socket fd: %d\n", socket_log_socket); - startThreads(); // start recv and send threads after sucessful connection + startThreads(); // start recv and send threads after succesful connection // send init packet to server once we connect (an issue with the server prevents this from working properly, waiting for a fix to implement) @@ -127,24 +147,77 @@ nn::Result SocketClient::init(const char* ip, u16 port) { } +u16 SocketClient::getLocalUdpPort() { + sockaddr udpAddress = { 0 }; + u32 size = sizeof(udpAddress); + + nn::Result result; + if((result = nn::socket::GetSockName(this->mUdpSocket, &udpAddress, &size)).isFailure()) { + this->socket_errno = nn::socket::GetLastErrno(); + return 0; + } + + return nn::socket::InetNtohs(udpAddress.port); +} + + +s32 SocketClient::setPeerUdpPort(u16 port) { + u16 net_port = nn::socket::InetHtons(port); + this->mUdpAddress.port = net_port; + + nn::Result result; + if((result = nn::socket::Connect(this->mUdpSocket, &this->mUdpAddress, sizeof(this->mUdpAddress))).isFailure()) { + Logger::log("Udp socket connection failed to connect to port %d!\n", port); + this->socket_errno = nn::socket::GetLastErrno(); + return -1; + } + + return 0; + +} + +const char* SocketClient::getUdpStateChar() { + if (this->mUdpAddress.port == 0) { + return "Waiting for handshake"; + } + + if (!this->mHasRecvUdp) { + return "Waiting for holepunch"; + } + + return "Utilizing UDP"; +} bool SocketClient::send(Packet *packet) { - if (this->socket_log_state != SOCKET_LOG_CONNECTED) + if (this->socket_log_state != SOCKET_LOG_CONNECTED || packet == nullptr) return false; char* buffer = reinterpret_cast(packet); int valread = 0; - if (packet->mType != PLAYERINF && packet->mType != HACKCAPINF) - Logger::log("Sending packet: %s\n", packetNames[packet->mType]); + int fd = -1; + if ((packet->mType != PLAYERINF && packet->mType != HACKCAPINF && packet->mType != HOLEPUNCH) + || (!this->mHasRecvUdp && packet->mType != HOLEPUNCH) + || this->mUdpAddress.port == 0) { - if ((valread = nn::socket::Send(this->socket_log_socket, buffer, packet->mPacketSize + sizeof(Packet), 0) > 0)) { + if (packet->mType != PLAYERINF && packet->mType != HACKCAPINF) { + Logger::log("Sending packet: %s\n", packetNames[packet->mType]); + } + + fd = this->socket_log_socket; + } else { + + fd = this->mUdpSocket; + } + + + if ((valread = nn::socket::Send(fd, buffer, packet->mPacketSize + sizeof(Packet), 0) > 0)) { return true; } else { Logger::log("Failed to Fully Send Packet! Result: %d Type: %s Packet Size: %d\n", valread, packetNames[packet->mType], packet->mPacketSize); this->socket_errno = nn::socket::GetLastErrno(); - this->tryReconnect(); + this->closeSocket(); return false; } return true; @@ -155,20 +228,68 @@ bool SocketClient::recv() { if (this->socket_log_state != SOCKET_LOG_CONNECTED) { Logger::log("Unable To Receive! Socket Not Connected.\n"); this->socket_errno = nn::socket::GetLastErrno(); - return this->tryReconnect(); + this->closeSocket(); + return false; } + const int fd_count = 2; + struct pollfd pfds[fd_count] = {{0}, {0}}; + + // TCP Connection + pfds[0].fd = this->socket_log_socket; + pfds[0].events = 1; + pfds[0].revents = 0; + + // UDP Connection + pfds[1].fd = this->mUdpSocket; + pfds[1].events = 1; + pfds[1].revents = 0; + + + int result = nn::socket::Poll(pfds, fd_count, this->pollTime); + + if (result == 0) { + return true; + } else if (result < 0) { + Logger::log("Error occurred when polling for packets\n"); + this->socket_errno = nn::socket::GetLastErrno(); + this->closeSocket(); + return false; + } + + s32 index = -1; + for (int i = 0; i < fd_count; i++){ + if (pfds[i].revents & 1) { + index = i; + break; + } + } + + switch (index) { + case 0: + return recvTcp(); + case 1: + return recvUdp(); + default: + return true; + } + + +} + +bool SocketClient::recvTcp() { int headerSize = sizeof(Packet); - char headerBuf[sizeof(Packet)] = {}; int valread = 0; + s32 fd = this->socket_log_socket; // read only the size of a header while(valread < headerSize) { - int result = nn::socket::Recv(this->socket_log_socket, headerBuf + valread, - headerSize - valread, this->sock_flags); + int result = 0; + result = nn::socket::Recv(fd, recvBuf + valread, + headerSize - valread, this->sock_flags); this->socket_errno = nn::socket::GetLastErrno(); - + if(result > 0) { valread += result; } else { @@ -176,75 +297,129 @@ bool SocketClient::recv() { return true; } else { Logger::log("Header Read Failed! Value: %d Total Read: %d\n", result, valread); - return this->tryReconnect(); // if we sucessfully reconnect, we dont want + this->closeSocket(); + return false; } } } - if(valread > 0) { - Packet* header = reinterpret_cast(headerBuf); - - int fullSize = header->mPacketSize + sizeof(Packet); - - if (fullSize <= MAXPACKSIZE && fullSize > 0 && valread == sizeof(Packet)) { - - if (header->mType != PLAYERINF && header->mType != HACKCAPINF) { - Logger::log("Received packet (from %02X%02X):", header->mUserID.data[0], - header->mUserID.data[1]); - Logger::disableName(); - Logger::log(" Size: %d", header->mPacketSize); - Logger::log(" Type: %d", header->mType); - if(packetNames[header->mType]) - Logger::log(" Type String: %s\n", packetNames[header->mType]); - Logger::enableName(); - } - - char* packetBuf = (char*)mHeap->alloc(fullSize); - - if (packetBuf) { - - memcpy(packetBuf, headerBuf, sizeof(Packet)); - - while (valread < fullSize) { - - int result = nn::socket::Recv(this->socket_log_socket, packetBuf + valread, - fullSize - valread, this->sock_flags); - - this->socket_errno = nn::socket::GetLastErrno(); - - if (result > 0) { - valread += result; - } else { - mHeap->free(packetBuf); - Logger::log("Packet Read Failed! Value: %d\nPacket Size: %d\nPacket Type: %s\n", result, header->mPacketSize, packetNames[header->mType]); - return this->tryReconnect(); - } - } - - if (!(header->mType > PacketType::UNKNOWN && header->mType < PacketType::End)) { - Logger::log("Failed to acquire valid packet type! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, valread); - mHeap->free(packetBuf); - return true; - } - - Packet *packet = reinterpret_cast(packetBuf); - - if (!mRecvQueue.isFull() && mPacketQueueOpen) { - mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); - } else { - mHeap->free(packetBuf); - } - } - } else { - Logger::log("Failed to acquire valid data! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, valread); - } - - return true; - } else { // if we error'd, close the socket + if(valread <= 0) { // if we error'd, close the socket Logger::log("valread was zero! Disconnecting.\n"); this->socket_errno = nn::socket::GetLastErrno(); - return this->tryReconnect(); + this->closeSocket(); + return false; } + + Packet* header = reinterpret_cast(recvBuf); + int fullSize = header->mPacketSize + sizeof(Packet); + + if (!(fullSize <= MAXPACKSIZE && fullSize > 0 && valread == sizeof(Packet))) { + Logger::log("Failed to acquire valid data! Packet Type: %d Full Packet Size %d valread size: %d\n", header->mType, fullSize, valread); + return true; + } + + if (header->mType != PLAYERINF && header->mType != HACKCAPINF) { + Logger::log("Received packet (from %02X%02X):", header->mUserID.data[0], + header->mUserID.data[1]); + Logger::disableName(); + Logger::log(" Size: %d", header->mPacketSize); + Logger::log(" Type: %d", header->mType); + if(packetNames[header->mType]) + Logger::log(" Type String: %s\n", packetNames[header->mType]); + Logger::enableName(); + } + + char* packetBuf = (char*)mHeap->alloc(fullSize); + + if (!packetBuf) { + return true; + } + + + memcpy(packetBuf, recvBuf, sizeof(Packet)); + + while (valread < fullSize) { + + int result = nn::socket::Recv(fd, packetBuf + valread, + fullSize - valread, this->sock_flags); + + this->socket_errno = nn::socket::GetLastErrno(); + + if (result > 0) { + valread += result; + } else { + mHeap->free(packetBuf); + Logger::log("Packet Read Failed! Value: %d\nPacket Size: %d\nPacket Type: %s\n", result, header->mPacketSize, packetNames[header->mType]); + this->closeSocket(); + return false; + } + } + + if (!(header->mType > PacketType::UNKNOWN && header->mType < PacketType::End)) { + Logger::log("Failed to acquire valid packet type! Packet Type: %d Full Packet Size %d valread size: %d\n", header->mType, fullSize, valread); + mHeap->free(packetBuf); + return true; + } + + Packet *packet = reinterpret_cast(packetBuf); + + if (!mRecvQueue.isFull() && mPacketQueueOpen) { + mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); + } else { + mHeap->free(packetBuf); + } + + return true; +} + +bool SocketClient::recvUdp() { + int headerSize = sizeof(Packet); + s32 fd = this->mUdpSocket; + + int valread = nn::socket::Recv(fd, recvBuf, MAXPACKSIZE, this->sock_flags); + + if (valread == 0) { + Logger::log("Udp connection valread was zero. Disconnecting.\n"); + this->closeSocket(); + return false; + } + + if (valread < headerSize){ + return true; + } + + Packet* header = reinterpret_cast(recvBuf); + int fullSize = header->mPacketSize + sizeof(Packet); + // Verify packet size is appropriate + if (valread < fullSize || valread > MAXPACKSIZE || fullSize > MAXPACKSIZE){ + return true; + } + + // Verify type of packet + if (!(header->mType > PacketType::UNKNOWN && header->mType < PacketType::End)) { + Logger::log("Failed to acquire valid packet type! Packet Type: %d Full Packet Size %d valread size: %d\n", header->mType, fullSize, valread); + return true; + } + + this->mHasRecvUdp = true; + + char* packetBuf = (char*)mHeap->alloc(fullSize); + if (!packetBuf) { + return true; + } + + memcpy(packetBuf, recvBuf, fullSize); + + + Packet *packet = reinterpret_cast(packetBuf); + + if(!mRecvQueue.isFull()) { + mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); + } else { + mHeap->free(packetBuf); + } + + return true; } // prints packet to debug logger @@ -281,6 +456,8 @@ bool SocketClient::closeSocket() { Logger::log("Closing Socket.\n"); + mHasRecvUdp = false; + mUdpAddress.port = 0; mPacketQueueOpen = false; bool result = false; @@ -315,7 +492,7 @@ bool SocketClient::stringToIPAddress(const char* str, in_addr* out) { /** * @brief starts client read thread * - * @return true if read thread was sucessfully started + * @return true if read thread was successfully started * @return false if read thread was unable to start, or thread was already started. */ bool SocketClient::startThreads() { @@ -326,7 +503,7 @@ bool SocketClient::startThreads() { if(this->mRecvThread->isDone() && this->mSendThread->isDone()) { this->mRecvThread->start(); this->mSendThread->start(); - Logger::log("Socket threads sucessfully started.\n"); + Logger::log("Socket threads succesfully started.\n"); return true; }else { Logger::log("Socket threads failed to start.\n"); @@ -339,14 +516,18 @@ void SocketClient::endThreads() { mSendThread->mDelegateThread->destroy(); } +void SocketClient::waitForThreads() { + while (!mRecvThread->isDone()){} + while (!mSendThread->isDone()){} +} + void SocketClient::sendFunc() { Logger::log("Starting Send Thread.\n"); - while (true) { - trySendQueue(); - } + while (trySendQueue() || socket_log_state != SOCKET_LOG_DISCONNECTED) {} + Logger::log("Sending packet failed!\n"); Logger::log("Ending Send Thread.\n"); } @@ -356,12 +537,13 @@ void SocketClient::recvFunc() { Logger::log("Starting Recv Thread.\n"); - while (true) { - if (!recv()) { - Logger::log("Receiving Packet Failed!\n"); - } - } + while (recv() || socket_log_state != SOCKET_LOG_DISCONNECTED) {} + // Free up all blocked threads + mSendQueue.push(0, sead::MessageQueue::BlockType::NonBlocking); + mRecvQueue.push(0, sead::MessageQueue::BlockType::NonBlocking); + + Logger::log("Receiving Packet Failed!\n"); Logger::log("Ending Recv Thread.\n"); } @@ -377,13 +559,15 @@ bool SocketClient::queuePacket(Packet* packet) { } } -void SocketClient::trySendQueue() { +bool SocketClient::trySendQueue() { Packet* curPacket = (Packet*)mSendQueue.pop(sead::MessageQueue::BlockType::Blocking); - send(curPacket); + bool successful = send(curPacket); mHeap->free(curPacket); + + return successful; } Packet* SocketClient::tryGetPacket(sead::MessageQueue::BlockType blockType) { diff --git a/source/server/logger.cpp b/source/server/logger.cpp index 99fd356..038c18a 100644 --- a/source/server/logger.cpp +++ b/source/server/logger.cpp @@ -111,7 +111,7 @@ void Logger::log(const char* fmt, ...) { } bool Logger::pingSocket() { - return socket_log("ping") > 0; // if value is greater than zero, than the socket recieved our message, otherwise the connection was lost. + return socket_log("ping") > 0; // if value is greater than zero, than the socket received our message, otherwise the connection was lost. } void tryInitSocket() { @@ -119,4 +119,4 @@ void tryInitSocket() { #if DEBUGLOG Logger::createInstance(); // creates a static instance for debug logger #endif -} \ No newline at end of file +}