Add race protection to message queues

This commit is contained in:
Jack Garrard 2022-10-27 01:02:32 -07:00
parent 89415e6f96
commit 64b56c32da
3 changed files with 36 additions and 3 deletions

View file

@ -50,6 +50,10 @@ class SocketClient : public SocketBase {
u32 getRecvCount() { return mRecvQueue.getCount(); } u32 getRecvCount() { return mRecvQueue.getCount(); }
u32 getRecvMaxCount() { return mRecvQueue.getMaxCount(); } u32 getRecvMaxCount() { return mRecvQueue.getMaxCount(); }
void clearMessageQueues();
void setQueueOpen(bool value) { mPacketQueueOpen = value; }
void setIsFirstConn(bool value) { mIsFirstConnect = value; } void setIsFirstConn(bool value) { mIsFirstConnect = value; }
private: private:
@ -63,6 +67,8 @@ class SocketClient : public SocketBase {
int maxBufSize = 100; int maxBufSize = 100;
bool mIsFirstConnect = true; bool mIsFirstConnect = true;
bool mPacketQueueOpen = true;
/** /**
* @param str a string containing an IPv4 address or a hostname that can be resolved via DNS * @param str a string containing an IPv4 address or a hostname that can be resolved via DNS

View file

@ -122,7 +122,13 @@ void Client::restartConnection() {
playerDC->mUserID = sInstance->mUserID; playerDC->mUserID = sInstance->mUserID;
sInstance->mSocket->queuePacket(playerDC);
sInstance->mSocket->setQueueOpen(false);
sInstance->mSocket->clearMessageQueues();
sInstance->mSocket->send(playerDC);
sInstance->mHeap->free(playerDC);
if (sInstance->mSocket->closeSocket()) { if (sInstance->mSocket->closeSocket()) {
Logger::log("Sucessfully Closed Socket.\n"); Logger::log("Sucessfully Closed Socket.\n");

View file

@ -78,6 +78,8 @@ nn::Result SocketClient::init(const char* ip, u16 port) {
return result; return result;
} }
this->mPacketQueueOpen = true;
this->socket_log_state = SOCKET_LOG_CONNECTED; this->socket_log_state = SOCKET_LOG_CONNECTED;
Logger::log("Socket fd: %d\n", socket_log_socket); Logger::log("Socket fd: %d\n", socket_log_socket);
@ -205,7 +207,7 @@ bool SocketClient::recv() {
Packet *packet = reinterpret_cast<Packet*>(packetBuf); Packet *packet = reinterpret_cast<Packet*>(packetBuf);
if (!mRecvQueue.isFull()) { if (!mRecvQueue.isFull() && mPacketQueueOpen) {
mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking); mRecvQueue.push((s64)packet, sead::MessageQueue::BlockType::NonBlocking);
} else { } else {
mHeap->free(packetBuf); mHeap->free(packetBuf);
@ -257,6 +259,8 @@ bool SocketClient::closeSocket() {
Logger::log("Closing Socket.\n"); Logger::log("Closing Socket.\n");
mPacketQueueOpen = false;
bool result = false; bool result = false;
if (!(result = SocketBase::closeSocket())) { if (!(result = SocketBase::closeSocket())) {
@ -340,7 +344,7 @@ void SocketClient::recvFunc() {
} }
bool SocketClient::queuePacket(Packet* packet) { bool SocketClient::queuePacket(Packet* packet) {
if (socket_log_state == SOCKET_LOG_CONNECTED) { if (socket_log_state == SOCKET_LOG_CONNECTED && mPacketQueueOpen) {
mSendQueue.push((s64)packet, mSendQueue.push((s64)packet,
sead::MessageQueue::BlockType::NonBlocking); // as this is non-blocking, it sead::MessageQueue::BlockType::NonBlocking); // as this is non-blocking, it
// will always return true. // will always return true.
@ -363,3 +367,20 @@ void SocketClient::trySendQueue() {
Packet* SocketClient::tryGetPacket(sead::MessageQueue::BlockType blockType) { Packet* SocketClient::tryGetPacket(sead::MessageQueue::BlockType blockType) {
return socket_log_state == SOCKET_LOG_CONNECTED ? (Packet*)mRecvQueue.pop(blockType) : nullptr; return socket_log_state == SOCKET_LOG_CONNECTED ? (Packet*)mRecvQueue.pop(blockType) : nullptr;
} }
void SocketClient::clearMessageQueues() {
bool prevQueueOpenness = this->mPacketQueueOpen;
this->mPacketQueueOpen = false;
while (mSendQueue.getCount() > 0) {
Packet* curPacket = (Packet*)mSendQueue.pop(sead::MessageQueue::BlockType::Blocking);
mHeap->free(curPacket);
}
while (mRecvQueue.getCount() > 0) {
Packet* curPacket = (Packet*)mRecvQueue.pop(sead::MessageQueue::BlockType::Blocking);
mHeap->free(curPacket);
}
this->mPacketQueueOpen = prevQueueOpenness;
}