From 64b56c32dac2d435b125932f8379d88fa7ab6748 Mon Sep 17 00:00:00 2001 From: Jack Garrard Date: Thu, 27 Oct 2022 01:02:32 -0700 Subject: [PATCH] Add race protection to message queues --- include/server/SocketClient.hpp | 6 ++++++ source/server/Client.cpp | 8 +++++++- source/server/SocketClient.cpp | 25 +++++++++++++++++++++++-- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/include/server/SocketClient.hpp b/include/server/SocketClient.hpp index 2f6dd74..fcecd97 100644 --- a/include/server/SocketClient.hpp +++ b/include/server/SocketClient.hpp @@ -50,6 +50,10 @@ class SocketClient : public SocketBase { u32 getRecvCount() { return mRecvQueue.getCount(); } u32 getRecvMaxCount() { return mRecvQueue.getMaxCount(); } + void clearMessageQueues(); + void setQueueOpen(bool value) { mPacketQueueOpen = value; } + + void setIsFirstConn(bool value) { mIsFirstConnect = value; } private: @@ -63,6 +67,8 @@ class SocketClient : public SocketBase { int maxBufSize = 100; bool mIsFirstConnect = true; + bool mPacketQueueOpen = true; + /** * @param str a string containing an IPv4 address or a hostname that can be resolved via DNS diff --git a/source/server/Client.cpp b/source/server/Client.cpp index 9f8f1c4..b77c5ca 100644 --- a/source/server/Client.cpp +++ b/source/server/Client.cpp @@ -122,7 +122,13 @@ void Client::restartConnection() { playerDC->mUserID = sInstance->mUserID; - sInstance->mSocket->queuePacket(playerDC); + + sInstance->mSocket->setQueueOpen(false); + sInstance->mSocket->clearMessageQueues(); + + sInstance->mSocket->send(playerDC); + + sInstance->mHeap->free(playerDC); if (sInstance->mSocket->closeSocket()) { Logger::log("Sucessfully Closed Socket.\n"); diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index ec89113..4bec07a 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -78,6 +78,8 @@ nn::Result SocketClient::init(const char* ip, u16 port) { return result; } + this->mPacketQueueOpen = true; + this->socket_log_state = SOCKET_LOG_CONNECTED; Logger::log("Socket fd: %d\n", socket_log_socket); @@ -205,7 +207,7 @@ bool SocketClient::recv() { Packet *packet = reinterpret_cast(packetBuf); - if (!mRecvQueue.isFull()) { + if (!mRecvQueue.isFull() && mPacketQueueOpen) { mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); } else { mHeap->free(packetBuf); @@ -257,6 +259,8 @@ bool SocketClient::closeSocket() { Logger::log("Closing Socket.\n"); + mPacketQueueOpen = false; + bool result = false; if (!(result = SocketBase::closeSocket())) { @@ -340,7 +344,7 @@ void SocketClient::recvFunc() { } bool SocketClient::queuePacket(Packet* packet) { - if (socket_log_state == SOCKET_LOG_CONNECTED) { + if (socket_log_state == SOCKET_LOG_CONNECTED && mPacketQueueOpen) { mSendQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); // as this is non-blocking, it // will always return true. @@ -363,3 +367,20 @@ void SocketClient::trySendQueue() { Packet* SocketClient::tryGetPacket(sead::MessageQueue::BlockType blockType) { return socket_log_state == SOCKET_LOG_CONNECTED ? (Packet*)mRecvQueue.pop(blockType) : nullptr; } + +void SocketClient::clearMessageQueues() { + bool prevQueueOpenness = this->mPacketQueueOpen; + this->mPacketQueueOpen = false; + + while (mSendQueue.getCount() > 0) { + Packet* curPacket = (Packet*)mSendQueue.pop(sead::MessageQueue::BlockType::Blocking); + mHeap->free(curPacket); + } + + while (mRecvQueue.getCount() > 0) { + Packet* curPacket = (Packet*)mRecvQueue.pop(sead::MessageQueue::BlockType::Blocking); + mHeap->free(curPacket); + } + + this->mPacketQueueOpen = prevQueueOpenness; +}