diff --git a/include/nn/socket.h b/include/nn/socket.h index c07732a..0ed1c7c 100644 --- a/include/nn/socket.h +++ b/include/nn/socket.h @@ -42,8 +42,9 @@ s32 SendTo(s32 socket, const void* data, ulong dataLen, s32 flags, const struct s32 Recv(s32 socket, void* out, ulong outLen, s32 flags); s32 RecvFrom(s32 socket, void* out, ulong outLen, s32 flags, struct sockaddr* from, u32* fromLen); -s32 GetSockName(s32 socket, struct sockaddr* name, u32 dataLen); +s32 GetSockName(s32 socket, struct sockaddr* name, u32* dataLen); u16 InetHtons(u16 val); +u16 InetNtohs(u16 val); s32 InetAton(const char* addressStr, in_addr* addressOut); struct hostent* GetHostByName(const char* name); diff --git a/include/packets/HolePunchPacket.h b/include/packets/HolePunchPacket.h new file mode 100644 index 0000000..28380f0 --- /dev/null +++ b/include/packets/HolePunchPacket.h @@ -0,0 +1,7 @@ +#pragma once + +#include "Packet.h" + +struct PACKED HolePunch : Packet { + HolePunch() : Packet() {this->mType = PacketType::HOLEPUNCH; mPacketSize = sizeof(HolePunch) - sizeof(Packet);}; +}; diff --git a/include/packets/Packet.h b/include/packets/Packet.h index b312da5..92fb053 100644 --- a/include/packets/Packet.h +++ b/include/packets/Packet.h @@ -27,6 +27,7 @@ enum PacketType : short { CHANGESTAGE, CMD, UDPINIT, + HOLEPUNCH, End // end of enum for bounds checking }; @@ -44,8 +45,9 @@ USED static const char *packetNames[] = { "Moon Collection", "Capture Info", "Change Stage", - "Server Command" - "Udp Initialization" + "Server Command", + "Udp Initialization", + "Hole punch", }; enum SenderType { @@ -87,3 +89,4 @@ struct PACKED Packet { #include "packets/ChangeStagePacket.h" #include "packets/InitPacket.h" #include "packets/UdpPacket.h" +#include "packets/HolePunchPacket.h" diff --git a/include/server/SocketClient.hpp b/include/server/SocketClient.hpp index bd68919..34a2449 100644 --- a/include/server/SocketClient.hpp +++ b/include/server/SocketClient.hpp @@ -29,6 +29,8 @@ class SocketClient : public SocketBase { void printPacket(Packet* packet); bool isConnected() {return socket_log_state == SOCKET_LOG_CONNECTED; } u16 getUdpPort(); + s32 setPeerUdpPort(u16 port); + sead::PtrArray mPacketQueue; diff --git a/source/server/Client.cpp b/source/server/Client.cpp index d68b2aa..e88370c 100644 --- a/source/server/Client.cpp +++ b/source/server/Client.cpp @@ -373,10 +373,6 @@ void Client::readFunc() { mSocket->SEND(&initPacket); // send initial packet - UdpInit udp_init = UdpInit(); - udp_init.port = mSocket->getUdpPort(); - mSocket->SEND(&udp_init) - nn::os::SleepThread(nn::TimeSpan::FromNanoSeconds(500000000)); // sleep for 0.5 seconds to let connection layout fully show (probably should find a better way to do this) mConnectionWait->tryEnd(); @@ -480,6 +476,19 @@ void Client::readFunc() { maxPuppets = initPacket->maxPlayers - 1; break; } + case PacketType::UDPINIT: { + UdpInit* udpInit = (UdpInit*)curPacket; + Logger::log("Setting udp port %u\n", udpInit->port); + this->mSocket->setPeerUdpPort(udpInit->port); + + HolePunch punch = HolePunch(); + this->mSocket->SEND(&punch); + + UdpInit udp_init = UdpInit(); + udp_init.port = mSocket->getUdpPort(); + mSocket->SEND(&udp_init); + break; + } default: break; } diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index bd0da88..1de20fb 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -67,17 +67,27 @@ nn::Result SocketClient::init(const char* ip, u16 port) { return result; } - if ((this->udp_socket = nn::socket::Socket(2, 2, 17)) <= 0) { + if ((this->udp_socket = nn::socket::Socket(2, 2, 17)) < 0) { Logger::log("Udp Socket failed to create"); this->socket_errno = nn::socket::GetLastErrno(); this->socket_log_state = SOCKET_LOG_UNAVAILABLE; return -1; } - udpAddress.address = hostAddress; udpAddress.port = nn::socket::InetHtons(41553); udpAddress.family = 2; - this->udp_addr = udpAddress; + this->udp_addr = udpAddress; + + udpAddress.address = hostAddress; + udpAddress.port = nn::socket::InetHtons(57734); + udpAddress.family = 2; + + if((result = nn::socket::Connect(this->udp_socket, &udpAddress, sizeof(udpAddress))).isFailure()) { + Logger::log("Udp Socket Connection Failed!\n"); + this->socket_errno = nn::socket::GetLastErrno(); + this->socket_log_state = SOCKET_LOG_UNAVAILABLE; + return result; + } this->socket_log_state = SOCKET_LOG_CONNECTED; @@ -88,12 +98,33 @@ nn::Result SocketClient::init(const char* ip, u16 port) { } u16 SocketClient::getUdpPort() { - sockaddr udpAddress = { 0 }; - if (nn::socket::GetSockName(this->udp_socket, &udpAddress, sizeof(udpAddress)) <= 0) { - return 0; - } + sockaddr udpAddress = { 0 }; + u32 size = sizeof(udpAddress); + + nn::Result result; + if((result = nn::socket::GetSockName(this->udp_socket, &udpAddress, &size)).isFailure()) { + this->socket_errno = nn::socket::GetLastErrno(); + return 0; + } + + return nn::socket::InetNtohs(udpAddress.port); +} + + +s32 SocketClient::setPeerUdpPort(u16 port) { + u16 net_port = nn::socket::InetHtons(port); + this->udp_addr.port = net_port; + + nn::Result result; + if((result = nn::socket::Connect(this->udp_socket, &this->udp_addr, sizeof(this->udp_addr))).isFailure()) { + Logger::log("Udp Socket Connection Failed!\n"); + this->socket_errno = nn::socket::GetLastErrno(); + this->socket_log_state = SOCKET_LOG_UNAVAILABLE; + return -1; + } + + return 0; - return udpAddress.port; } bool SocketClient::SEND(Packet *packet) { @@ -105,22 +136,17 @@ bool SocketClient::SEND(Packet *packet) { int valread = 0; - if (packet->mType != PLAYERINF && packet->mType != HACKCAPINF) { + int fd = -1; + if (packet->mType != PLAYERINF && packet->mType != HACKCAPINF && packet->mType != HOLEPUNCH) { Logger::log("Sending packet: %s\n", packetNames[packet->mType]); + fd = this->socket_log_socket; } else { - if ((valread = nn::socket::SendTo(this->udp_socket, buffer, packet->mPacketSize + sizeof(Packet), 0, this->udp_addr, sizeof(this->udp_addr)) > 0)) { - return true; - } else { - Logger::log("Failed to Fully Send Packet! Result: %d Type: %s Packet Size: %d\n", valread, packetNames[packet->mType], packet->mPacketSize); - this->socket_errno = nn::socket::GetLastErrno(); - this->closeSocket(); - return false; - } + fd = this->udp_socket; } - if ((valread = nn::socket::Send(this->socket_log_socket, buffer, packet->mPacketSize + sizeof(Packet), 0) > 0)) { + if ((valread = nn::socket::Send(fd, buffer, packet->mPacketSize + sizeof(Packet), 0) > 0)) { return true; } else { Logger::log("Failed to Fully Send Packet! Result: %d Type: %s Packet Size: %d\n", valread, packetNames[packet->mType], packet->mPacketSize); @@ -140,11 +166,11 @@ bool SocketClient::RECV() { } int headerSize = sizeof(Packet); - char headerBuf[sizeof(Packet)] = {}; + char headerBuf[MAXPACKSIZE * 2] = {}; int valread = 0; const int fd_count = 2; - struct pollfd pfds[fd_count] = {0}; + struct pollfd pfds[fd_count] = {{0}, {0}}; pfds[0].fd = this->socket_log_socket; pfds[0].events = 1; pfds[0].revents = 0; @@ -153,7 +179,7 @@ bool SocketClient::RECV() { pfds[1].revents = 0; - if (poll(pfds, fd_count, -1) <= 0) { + if (nn::socket::Poll(pfds, fd_count, -1) <= 0) { return true; } @@ -169,21 +195,45 @@ bool SocketClient::RECV() { if (fd == -1) { return true; } - sockaddr udp_addr = {0}; - u32 udp_size = 0; +if (index == 1) { + int result = nn::socket::Recv(fd, headerBuf, sizeof(headerBuf), this->sock_flags); + if (result < headerSize){ + return true; + } + + Packet* header = reinterpret_cast(headerBuf); + int fullSize = header->mPacketSize + sizeof(Packet); + + if (result < fullSize || result > MAXPACKSIZE || fullSize > MAXPACKSIZE){ + return true; + } + + char* packetBuf = (char*)malloc(fullSize); + + if (!(header->mType > PacketType::UNKNOWN && header->mType < PacketType::End)) { + Logger::log("Failed to acquire valid packet type! Packet Type: %d Full Packet Size %d valread size: %d", header->mType, fullSize, result); + free(packetBuf); + return true; + } + + memcpy(packetBuf, headerBuf, fullSize); + + + Packet *packet = reinterpret_cast(packetBuf); + + if(mPacketQueue.size() < maxBufSize - 1) { + mPacketQueue.pushBack(packet); + } else { + free(packetBuf); + } + return true; + } // read only the size of a header while(valread < headerSize) { int result = 0; - if (index == 0) { - result = nn::socket::Recv(fd, headerBuf + valread, - headerSize - valread, this->sock_flags); - } else { - result = nn::socket::RecvFrom(fd, headerBuf + valread, - headerSize - valread, this->sock_flags, - &udp_addr, &udp_size); - Logger::log("Got udp packet: %s %s %s\n",udp_addr.family, udp_addr.port, udp_addr.address.data); - } + result = nn::socket::Recv(fd, headerBuf + valread, + headerSize - valread, this->sock_flags); this->socket_errno = nn::socket::GetLastErrno();