From b18671f11337b6c0d9252d29223dbb2bcb8544f3 Mon Sep 17 00:00:00 2001 From: Jack Garrard Date: Sat, 27 Aug 2022 01:47:52 -0700 Subject: [PATCH 1/4] Move packet type check to after packet data recv --- source/server/SocketClient.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index 8abc7e6..407ac98 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -133,8 +133,7 @@ bool SocketClient::RECV() { int fullSize = header->mPacketSize + sizeof(Packet); - if (header->mType > PacketType::UNKNOWN && header->mType < PacketType::End && - fullSize <= MAXPACKSIZE && fullSize > 0 && valread == sizeof(Packet)) { + if (fullSize <= MAXPACKSIZE && fullSize > 0 && valread == sizeof(Packet)) { if (header->mType != PLAYERINF && header->mType != HACKCAPINF) { Logger::log("Received packet (from %02X%02X):", header->mUserID.data[0], @@ -171,6 +170,12 @@ bool SocketClient::RECV() { } } + if (!(header->mType > PacketType::UNKNOWN && header->mType < PacketType::End)) { + Logger::log("Failed to acquire valid packet type! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, valread); + free(packetBuf); + return true; + } + Packet *packet = reinterpret_cast(packetBuf); if(mPacketQueue.size() < maxBufSize - 1) { @@ -180,7 +185,7 @@ bool SocketClient::RECV() { } } } else { - Logger::log("Failed to aquire valid data! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, valread); + Logger::log("Failed to acquire valid data! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, valread); } return true; From 7c17db2d93d1d2f077420c25aeba59264ba2632c Mon Sep 17 00:00:00 2001 From: Jack Garrard Date: Sun, 23 Oct 2022 14:41:16 -0700 Subject: [PATCH 2/4] Fix free issue from wrong heap if unknown packet --- source/server/SocketClient.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index 310ce9c..ec89113 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -199,7 +199,7 @@ bool SocketClient::recv() { if (!(header->mType > PacketType::UNKNOWN && header->mType < PacketType::End)) { Logger::log("Failed to acquire valid packet type! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, valread); - free(packetBuf); + mHeap->free(packetBuf); return true; } From 89415e6f969ed7171cc83221b91032d3053fef73 Mon Sep 17 00:00:00 2001 From: Jack Garrard Date: Thu, 27 Oct 2022 00:26:26 -0700 Subject: [PATCH 3/4] Hopefully prevent close socket race condition --- source/server/SocketBase.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/source/server/SocketBase.cpp b/source/server/SocketBase.cpp index 4f82d95..b9c408a 100644 --- a/source/server/SocketBase.cpp +++ b/source/server/SocketBase.cpp @@ -72,12 +72,13 @@ s32 SocketBase::getFd() { bool SocketBase::closeSocket() { - this->socket_log_state = SOCKET_LOG_DISCONNECTED; // probably not safe to assume socket will be closed + if (this->socket_log_state != SOCKET_LOG_DISCONNECTED) { + nn::Result result = nn::socket::Close(this->socket_log_socket); + if (result.isSuccess()) { + this->socket_log_state = SOCKET_LOG_DISCONNECTED; + } + return result.isSuccess(); + } - nn::Result result = nn::socket::Close(this->socket_log_socket); - - return result.isSuccess(); + return true; } - - - From 64b56c32dac2d435b125932f8379d88fa7ab6748 Mon Sep 17 00:00:00 2001 From: Jack Garrard Date: Thu, 27 Oct 2022 01:02:32 -0700 Subject: [PATCH 4/4] Add race protection to message queues --- include/server/SocketClient.hpp | 6 ++++++ source/server/Client.cpp | 8 +++++++- source/server/SocketClient.cpp | 25 +++++++++++++++++++++++-- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/include/server/SocketClient.hpp b/include/server/SocketClient.hpp index 2f6dd74..fcecd97 100644 --- a/include/server/SocketClient.hpp +++ b/include/server/SocketClient.hpp @@ -50,6 +50,10 @@ class SocketClient : public SocketBase { u32 getRecvCount() { return mRecvQueue.getCount(); } u32 getRecvMaxCount() { return mRecvQueue.getMaxCount(); } + void clearMessageQueues(); + void setQueueOpen(bool value) { mPacketQueueOpen = value; } + + void setIsFirstConn(bool value) { mIsFirstConnect = value; } private: @@ -63,6 +67,8 @@ class SocketClient : public SocketBase { int maxBufSize = 100; bool mIsFirstConnect = true; + bool mPacketQueueOpen = true; + /** * @param str a string containing an IPv4 address or a hostname that can be resolved via DNS diff --git a/source/server/Client.cpp b/source/server/Client.cpp index 9f8f1c4..b77c5ca 100644 --- a/source/server/Client.cpp +++ b/source/server/Client.cpp @@ -122,7 +122,13 @@ void Client::restartConnection() { playerDC->mUserID = sInstance->mUserID; - sInstance->mSocket->queuePacket(playerDC); + + sInstance->mSocket->setQueueOpen(false); + sInstance->mSocket->clearMessageQueues(); + + sInstance->mSocket->send(playerDC); + + sInstance->mHeap->free(playerDC); if (sInstance->mSocket->closeSocket()) { Logger::log("Sucessfully Closed Socket.\n"); diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index ec89113..4bec07a 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -78,6 +78,8 @@ nn::Result SocketClient::init(const char* ip, u16 port) { return result; } + this->mPacketQueueOpen = true; + this->socket_log_state = SOCKET_LOG_CONNECTED; Logger::log("Socket fd: %d\n", socket_log_socket); @@ -205,7 +207,7 @@ bool SocketClient::recv() { Packet *packet = reinterpret_cast(packetBuf); - if (!mRecvQueue.isFull()) { + if (!mRecvQueue.isFull() && mPacketQueueOpen) { mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); } else { mHeap->free(packetBuf); @@ -257,6 +259,8 @@ bool SocketClient::closeSocket() { Logger::log("Closing Socket.\n"); + mPacketQueueOpen = false; + bool result = false; if (!(result = SocketBase::closeSocket())) { @@ -340,7 +344,7 @@ void SocketClient::recvFunc() { } bool SocketClient::queuePacket(Packet* packet) { - if (socket_log_state == SOCKET_LOG_CONNECTED) { + if (socket_log_state == SOCKET_LOG_CONNECTED && mPacketQueueOpen) { mSendQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); // as this is non-blocking, it // will always return true. @@ -363,3 +367,20 @@ void SocketClient::trySendQueue() { Packet* SocketClient::tryGetPacket(sead::MessageQueue::BlockType blockType) { return socket_log_state == SOCKET_LOG_CONNECTED ? (Packet*)mRecvQueue.pop(blockType) : nullptr; } + +void SocketClient::clearMessageQueues() { + bool prevQueueOpenness = this->mPacketQueueOpen; + this->mPacketQueueOpen = false; + + while (mSendQueue.getCount() > 0) { + Packet* curPacket = (Packet*)mSendQueue.pop(sead::MessageQueue::BlockType::Blocking); + mHeap->free(curPacket); + } + + while (mRecvQueue.getCount() > 0) { + Packet* curPacket = (Packet*)mRecvQueue.pop(sead::MessageQueue::BlockType::Blocking); + mHeap->free(curPacket); + } + + this->mPacketQueueOpen = prevQueueOpenness; +}