Abstracted all socket code behind a NetworkSystem

In preparation for other forms of communication, I have abstracted all
of the socket code (which needs direct connections) behind a struct
whose calls can be swapped out for other systems if desired.
This commit is contained in:
MysterD 2020-09-12 17:56:42 -07:00
parent 388470d579
commit 6c8050a564
13 changed files with 198 additions and 100 deletions

View file

@ -87,7 +87,7 @@
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>_DEBUG;_CONSOLE;WINSOCK;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions>_DEBUG;_CONSOLE;WINSOCK;DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
</ClCompile>
<Link>

View file

@ -440,7 +440,7 @@ void join_server_as_client(void) {
keyboard_stop_text_input();
joinVersionMismatch = FALSE;
network_init(NT_CLIENT, configJoinIp, configJoinPort);
network_init(NT_CLIENT);
}
void joined_server_as_client(s16 fileIndex) {
@ -1341,7 +1341,7 @@ void load_main_menu_save_file(struct Object *fileButton, s32 fileNum) {
if (fileButton->oMenuButtonState == MENU_BUTTON_STATE_FULLSCREEN) {
sSelectedFileNum = fileNum;
configHostSaveSlot = fileNum;
network_init(NT_SERVER, "", configHostPort);
network_init(NT_SERVER);
}
}

View file

@ -13,6 +13,10 @@
printf(" [%s] ", NETWORKTYPESTR);
}
static void debuglog_print_log_type(char* logType) {
printf("[%s] ", logType);
}
static void debuglog_print_short_filename(char* filename) {
char* last = strrchr(filename, '/');
if (last != NULL) {
@ -22,13 +26,16 @@
}
}
static void debuglog_print_log(char* filename) {
static void debuglog_print_log(char* logType, char* filename) {
debuglog_print_timestamp();
debuglog_print_network_type();
debuglog_print_log_type(logType);
debuglog_print_short_filename(filename);
}
#define LOG_INFO(...) ( debuglog_print_log(__FILE__), printf(__VA_ARGS__), printf("\n") )
#define LOG_INFO(...) ( debuglog_print_log("INFO ", __FILE__), printf(__VA_ARGS__), printf("\n") )
#define LOG_ERROR(...) ( debuglog_print_log("ERROR", __FILE__), printf(__VA_ARGS__), printf("\n") )
#else
#define LOG_INFO(...)
#define LOG_ERROR(...)
#endif

View file

@ -4,13 +4,13 @@
#include "object_constants.h"
#include "socket/socket.h"
#include "pc/configfile.h"
#include "pc/debuglog.h"
// Mario 64 specific externs
extern s16 sCurrPlayMode;
enum NetworkType gNetworkType = NT_NONE;
static SOCKET gSocket = 0;
struct sockaddr_in txAddr = { 0 };
struct NetworkSystem* gNetworkSystem = &gNetworkSystemSocket;
#define LOADING_LEVEL_THRESHOLD 10
u8 networkLoadingLevel = 0;
@ -21,16 +21,22 @@ struct ServerSettings gServerSettings = {
.playerKnockbackStrength = 25,
};
void network_init(enum NetworkType inNetworkType, char* ip, unsigned int port) {
bool network_init(enum NetworkType inNetworkType) {
// sanity check network system
if (gNetworkSystem == NULL) {
LOG_ERROR("no network system attached");
return false;
}
// initialize the network system
int rc = gNetworkSystem->initialize(inNetworkType);
if (!rc) {
LOG_ERROR("failed to initialize network system");
return false;
}
// set network type
gNetworkType = inNetworkType;
if (gNetworkType == NT_NONE) { return; }
// sanity check port
if (port == 0) {
port = (gNetworkType == NT_CLIENT) ? configJoinPort : configHostPort;
if (port == 0) { port = DEFAULT_PORT; }
}
// set server settings
if (gNetworkType == NT_SERVER) {
@ -39,26 +45,17 @@ void network_init(enum NetworkType inNetworkType, char* ip, unsigned int port) {
gServerSettings.stayInLevelAfterStar = configStayInLevelAfterStar;
}
// create a receiver socket to receive datagrams
gSocket = socket_initialize();
if (gSocket == INVALID_SOCKET) { return; }
// connect
if (gNetworkType == NT_SERVER) {
// bind the socket to any address and the specified port.
int rc = socket_bind(gSocket, port);
if (rc != NO_ERROR) { return; }
} else {
// save the port to send to
txAddr.sin_family = AF_INET;
txAddr.sin_port = htons(port);
txAddr.sin_addr.s_addr = inet_addr(ip);
// exit early if we're not really initializing the network
if (gNetworkType == NT_NONE) {
return true;
}
// send connection request
if (gNetworkType == NT_CLIENT) {
network_send_save_file_request();
}
return true;
}
void network_on_init_level(void) {
@ -77,7 +74,8 @@ void network_on_loaded_level(void) {
void network_send(struct Packet* p) {
// sanity checks
if (gNetworkType == NT_NONE) { return; }
if (p->error) { printf("%s packet error!\n", NETWORKTYPESTR); return; }
if (p->error) { LOG_ERROR("packet error!"); return; }
if (gNetworkSystem == NULL) { LOG_ERROR("no network system attached"); return; }
// remember reliable packets
network_remember_reliable(p);
@ -87,11 +85,51 @@ void network_send(struct Packet* p) {
memcpy(&p->buffer[p->dataLength], &hash, sizeof(u32));
// send
int rc = socket_send(gSocket, &txAddr, p->buffer, p->cursor + sizeof(u32));
int rc = gNetworkSystem->send(p->buffer, p->cursor + sizeof(u32));
if (rc != NO_ERROR) { return; }
p->sent = true;
}
void network_receive(u8* data, u16 dataLength) {
// receive packet
struct Packet p = {
.cursor = 3,
.buffer = { 0 },
.dataLength = dataLength,
};
memcpy(p.buffer, data, dataLength);
// subtract and check hash
p.dataLength -= sizeof(u32);
if (!packet_check_hash(&p)) {
LOG_ERROR("invalid packet hash!");
return;
}
// execute packet
switch ((u8)p.buffer[0]) {
case PACKET_ACK: network_receive_ack(&p); break;
case PACKET_PLAYER: network_receive_player(&p); break;
case PACKET_OBJECT: network_receive_object(&p); break;
case PACKET_SPAWN_OBJECTS: network_receive_spawn_objects(&p); break;
case PACKET_SPAWN_STAR: network_receive_spawn_star(&p); break;
case PACKET_LEVEL_WARP: network_receive_level_warp(&p); break;
case PACKET_INSIDE_PAINTING: network_receive_inside_painting(&p); break;
case PACKET_COLLECT_STAR: network_receive_collect_star(&p); break;
case PACKET_COLLECT_COIN: network_receive_collect_coin(&p); break;
case PACKET_COLLECT_ITEM: network_receive_collect_item(&p); break;
case PACKET_RESERVATION_REQUEST: network_receive_reservation_request(&p); break;
case PACKET_RESERVATION: network_receive_reservation(&p); break;
case PACKET_SAVE_FILE_REQUEST: network_receive_save_file_request(&p); break;
case PACKET_SAVE_FILE: network_receive_save_file(&p); break;
case PACKET_CUSTOM: network_receive_custom(&p); break;
default: LOG_ERROR("received unknown packet: %d", p.buffer[0]);
}
// send an ACK if requested
network_send_ack(&p);
}
void network_update(void) {
if (gNetworkType == NT_NONE) { return; }
@ -110,50 +148,18 @@ void network_update(void) {
}
// receive packets
do {
// receive packet
struct Packet p = { .cursor = 3 };
int rc = socket_receive(gSocket, &txAddr, p.buffer, PACKET_LENGTH, &p.dataLength);
if (rc != NO_ERROR) { break; }
// subtract and check hash
p.dataLength -= sizeof(u32);
if (!packet_check_hash(&p)) {
printf("Invalid packet!\n");
continue;
}
// execute packet
switch ((u8)p.buffer[0]) {
case PACKET_ACK: network_receive_ack(&p); break;
case PACKET_PLAYER: network_receive_player(&p); break;
case PACKET_OBJECT: network_receive_object(&p); break;
case PACKET_SPAWN_OBJECTS: network_receive_spawn_objects(&p); break;
case PACKET_SPAWN_STAR: network_receive_spawn_star(&p); break;
case PACKET_LEVEL_WARP: network_receive_level_warp(&p); break;
case PACKET_INSIDE_PAINTING: network_receive_inside_painting(&p); break;
case PACKET_COLLECT_STAR: network_receive_collect_star(&p); break;
case PACKET_COLLECT_COIN: network_receive_collect_coin(&p); break;
case PACKET_COLLECT_ITEM: network_receive_collect_item(&p); break;
case PACKET_RESERVATION_REQUEST: network_receive_reservation_request(&p); break;
case PACKET_RESERVATION: network_receive_reservation(&p); break;
case PACKET_SAVE_FILE_REQUEST: network_receive_save_file_request(&p); break;
case PACKET_SAVE_FILE: network_receive_save_file(&p); break;
case PACKET_CUSTOM: network_receive_custom(&p); break;
default: printf("%s received unknown packet: %d\n", NETWORKTYPESTR, p.buffer[0]);
}
// send an ACK if requested
network_send_ack(&p);
} while (1);
if (gNetworkSystem != NULL) {
gNetworkSystem->update();
}
// update reliable packets
network_update_reliable();
}
void network_shutdown(void) {
if (gNetworkType == NT_NONE) { return; }
// close down socket
socket_close(gSocket);
gNetworkType = NT_NONE;
if (gNetworkSystem == NULL) { LOG_ERROR("no network system attached"); return; }
gNetworkSystem->shutdown();
}

View file

@ -18,6 +18,13 @@ extern struct MarioState gMarioStates[];
#define PACKET_LENGTH 1024
#define NETWORKTYPESTR (gNetworkType == NT_CLIENT ? "Client" : "Server")
struct NetworkSystem {
bool (*initialize)(enum NetworkType);
void (*update)(void);
int (*send)(u8* data, u16 dataLength);
void (*shutdown)(void);
};
enum PacketType {
PACKET_ACK,
PACKET_PLAYER,
@ -85,10 +92,11 @@ extern struct SyncObject gSyncObjects[];
extern struct ServerSettings gServerSettings;
// network.c
void network_init(enum NetworkType inNetworkType, char* ip, unsigned int port);
bool network_init(enum NetworkType inNetworkType);
void network_on_init_level(void);
void network_on_loaded_level(void);
void network_send(struct Packet* p);
void network_receive(u8* data, u16 dataLength);
void network_update(void);
void network_shutdown(void);

View file

@ -57,7 +57,7 @@ void network_receive_inside_painting(struct Packet* p) {
// two-player hack: gControlledWarp is a bool instead of an index
if (gControlledWarp) {
LOG_INFO("this should never happen, received inside_painting when gControlledWarp");
LOG_ERROR("this should never happen, received inside_painting when gControlledWarp");
return;
}

View file

@ -1,40 +1,44 @@
#include <stdio.h>
#include "../network.h"
#include "socket.h"
#include "pc/configfile.h"
#include "pc/debuglog.h"
int socket_bind(SOCKET sock, unsigned int port) {
static SOCKET curSocket = INVALID_SOCKET;
struct sockaddr_in txAddr = { 0 };
static int socket_bind(SOCKET socket, unsigned int port) {
struct sockaddr_in rxAddr;
rxAddr.sin_family = AF_INET;
rxAddr.sin_port = htons(port);
rxAddr.sin_addr.s_addr = htonl(INADDR_ANY);
int rc = bind(sock, (SOCKADDR*)&rxAddr, sizeof(rxAddr));
int rc = bind(socket, (SOCKADDR*)&rxAddr, sizeof(rxAddr));
if (rc != 0) {
printf("%s bind failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR);
LOG_ERROR("bind failed with error %d", SOCKET_LAST_ERROR);
}
return rc;
}
int socket_send(SOCKET sock, struct sockaddr_in* txAddr, u8* buffer, u16 bufferLength) {
int txAddrSize = sizeof(struct sockaddr_in);
int rc = sendto(sock, (char*)buffer, bufferLength, 0, (struct sockaddr*)txAddr, txAddrSize);
static int socket_send(SOCKET socket, struct sockaddr_in* addr, u8* buffer, u16 bufferLength) {
int addrSize = sizeof(struct sockaddr_in);
int rc = sendto(socket, (char*)buffer, bufferLength, 0, (struct sockaddr*)addr, addrSize);
if (rc == SOCKET_ERROR) {
printf("%s sendto failed with error: %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR);
LOG_ERROR("sendto failed with error: %d", SOCKET_LAST_ERROR);
}
return rc;
}
int socket_receive(SOCKET sock, struct sockaddr_in* rxAddr, u8* buffer, u16 bufferLength, u16* receiveLength) {
static int socket_receive(SOCKET socket, struct sockaddr_in* rxAddr, u8* buffer, u16 bufferLength, u16* receiveLength) {
*receiveLength = 0;
int rxAddrSize = sizeof(struct sockaddr_in);
int rc = recvfrom(sock, (char*)buffer, bufferLength, 0, (struct sockaddr*)rxAddr, &rxAddrSize);
int rc = recvfrom(socket, (char*)buffer, bufferLength, 0, (struct sockaddr*)rxAddr, &rxAddrSize);
if (rc == SOCKET_ERROR) {
int error = SOCKET_LAST_ERROR;
if (error != SOCKET_EWOULDBLOCK && error != SOCKET_ECONNRESET) {
printf("%s recvfrom failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR);
LOG_ERROR("recvfrom failed with error %d", SOCKET_LAST_ERROR);
}
return rc;
}
@ -42,3 +46,60 @@ int socket_receive(SOCKET sock, struct sockaddr_in* rxAddr, u8* buffer, u16 buff
*receiveLength = rc;
return NO_ERROR;
}
static bool ns_socket_initialize(enum NetworkType networkType) {
// sanity check port
unsigned int port = (networkType == NT_CLIENT) ? configJoinPort : configHostPort;
if (port == 0) { port = DEFAULT_PORT; }
// create a receiver socket to receive datagrams
curSocket = socket_initialize();
if (curSocket == INVALID_SOCKET) { return false; }
// connect
if (networkType == NT_SERVER) {
// bind the socket to any address and the specified port.
int rc = socket_bind(curSocket, port);
if (rc != NO_ERROR) { return false; }
LOG_INFO("bound to port %u", port);
} else {
// save the port to send to
txAddr.sin_family = AF_INET;
txAddr.sin_port = htons(port);
txAddr.sin_addr.s_addr = inet_addr(configJoinIp);
LOG_INFO("connecting to %s %u", configJoinIp, port);
}
LOG_INFO("initialized");
// success
return true;
}
static void ns_socket_update(void) {
do {
// receive packet
u8 data[PACKET_LENGTH];
u16 dataLength = 0;
int rc = socket_receive(curSocket, &txAddr, data, PACKET_LENGTH, &dataLength);
if (rc != NO_ERROR) { break; }
network_receive(data, dataLength);
} while (true);
}
static int ns_socket_send(u8* data, u16 dataLength) {
return socket_send(curSocket, &txAddr, data, dataLength);
}
static void ns_socket_shutdown(void) {
socket_shutdown(curSocket);
curSocket = INVALID_SOCKET;
LOG_INFO("shutdown");
}
struct NetworkSystem gNetworkSystemSocket = {
.initialize = ns_socket_initialize,
.update = ns_socket_update,
.send = ns_socket_send,
.shutdown = ns_socket_shutdown,
};

View file

@ -1,16 +1,17 @@
#ifndef SOCKET_H
#define SOCKET_H
#include "../network.h"
#ifdef WINSOCK
#include "socket_windows.h"
#else
#include "socket_linux.h"
#endif
extern struct NetworkSystem gNetworkSystemSocket;
SOCKET socket_initialize(void);
int socket_bind(SOCKET sock, unsigned int port);
int socket_send(SOCKET sock, struct sockaddr_in* txAddr, u8* buffer, u16 bufferLength);
int socket_receive(SOCKET sock, struct sockaddr_in* rxAddr, u8* buffer, u16 bufferLength, u16* receiveLength);
void socket_close(SOCKET sock);
void socket_shutdown(SOCKET socket);
#endif

View file

@ -1,29 +1,31 @@
#ifndef WINSOCK
#include "socket_linux.h"
#include "../network.h"
#include "pc/debuglog.h"
SOCKET socket_initialize(void) {
// initialize socket
SOCKET sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
if (sock == INVALID_SOCKET) {
printf("%s socket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR);
LOG_ERROR("socket failed with error %d", SOCKET_LAST_ERROR);
return INVALID_SOCKET;
}
// set non-blocking mode
int rc = fcntl(sock, F_SETFL, fcntl(sock, F_GETFL, 0) | O_NONBLOCK);
if (rc == INVALID_SOCKET) {
printf("%s fcntl failed with error: %d\n", NETWORKTYPESTR, rc);
LOG_ERROR("fcntl failed with error: %d", rc);
return INVALID_SOCKET;
}
return sock;
}
void socket_close(SOCKET sock) {
int rc = closesocket(sock);
void socket_shutdown(SOCKET socket) {
if (socket == INVALID_SOCKET) { return; }
int rc = closesocket(socket);
if (rc == SOCKET_ERROR) {
printf("%s closesocket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR);
LOG_ERROR("closesocket failed with error %d\n", SOCKET_LAST_ERROR);
}
}

View file

@ -1,11 +1,13 @@
#ifndef SOCKET_LINUX_H
#define SOCKET_LINUX_H
#ifndef WINSOCK
#include <errno.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <fcntl.h>
#include "socket.h"
#define SOCKET unsigned int
#define INVALID_SOCKET (unsigned int)(-1)
@ -18,3 +20,4 @@
#define SOCKET_ECONNRESET ECONNRESET
#endif
#endif

View file

@ -1,21 +1,21 @@
#ifdef WINSOCK
#include <stdio.h>
#include "socket_windows.h"
#include "../network.h"
#include "pc/debuglog.h"
SOCKET socket_initialize(void) {
// start up winsock
WSADATA wsaData;
int rc = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (rc != NO_ERROR) {
printf("%s WSAStartup failed with error %d\n", NETWORKTYPESTR, rc);
LOG_ERROR("WSAStartup failed with error %d", rc);
return INVALID_SOCKET;
}
// initialize socket
SOCKET sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
if (sock == INVALID_SOCKET) {
printf("%s socket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR);
LOG_ERROR("socket failed with error %d", SOCKET_LAST_ERROR);
return INVALID_SOCKET;
}
@ -23,17 +23,18 @@ SOCKET socket_initialize(void) {
u_long iMode = 1;
rc = ioctlsocket(sock, FIONBIO, &iMode);
if (rc != NO_ERROR) {
printf("%s ioctlsocket failed with error: %d\n", NETWORKTYPESTR, rc);
LOG_ERROR("ioctlsocket failed with error: %d", rc);
return INVALID_SOCKET;
}
return sock;
}
void socket_close(SOCKET sock) {
int rc = closesocket(sock);
void socket_shutdown(SOCKET socket) {
if (socket == INVALID_SOCKET) { return; }
int rc = closesocket(socket);
if (rc == SOCKET_ERROR) {
printf("%s closesocket failed with error %d\n", NETWORKTYPESTR, SOCKET_LAST_ERROR);
LOG_ERROR("closesocket failed with error %d", SOCKET_LAST_ERROR);
}
WSACleanup();
}

View file

@ -3,6 +3,7 @@
#include <winsock2.h>
#include <ws2tcpip.h>
#include "socket.h"
#define SOCKET_LAST_ERROR WSAGetLastError()
#define SOCKET_EWOULDBLOCK WSAEWOULDBLOCK

View file

@ -261,7 +261,15 @@ void main_func(void) {
audio_api = &audio_null;
}
network_init(gCLIOpts.Network, gCLIOpts.JoinIp, gCLIOpts.NetworkPort);
if (gCLIOpts.Network == NT_CLIENT) {
strncpy(configJoinIp, gCLIOpts.JoinIp, IP_MAX_LEN);
configJoinPort = gCLIOpts.NetworkPort;
network_init(NT_CLIENT);
} else if (gCLIOpts.Network == NT_SERVER) {
configHostPort = gCLIOpts.NetworkPort;
network_init(NT_SERVER);
}
audio_init();
sound_init();