diff --git a/include/nn/socket.h b/include/nn/socket.h index 3e8290b..c07732a 100644 --- a/include/nn/socket.h +++ b/include/nn/socket.h @@ -38,13 +38,18 @@ s32 Connect(s32 socket, const sockaddr* address, u32 addressLen); Result Close(s32 socket); s32 Send(s32 socket, const void* data, ulong dataLen, s32 flags); +s32 SendTo(s32 socket, const void* data, ulong dataLen, s32 flags, const struct sockaddr* to, u32 toLen); s32 Recv(s32 socket, void* out, ulong outLen, s32 flags); +s32 RecvFrom(s32 socket, void* out, ulong outLen, s32 flags, struct sockaddr* from, u32* fromLen); +s32 GetSockName(s32 socket, struct sockaddr* name, u32 dataLen); u16 InetHtons(u16 val); s32 InetAton(const char* addressStr, in_addr* addressOut); struct hostent* GetHostByName(const char* name); u32 GetLastErrno(); +s32 Bind(s32 fd, sockaddr* addr, u32 addrlen); +s32 Poll(struct pollfd* fd, u64 addr, s32 timeout); } } diff --git a/include/packets/Packet.h b/include/packets/Packet.h index 400ee46..b312da5 100644 --- a/include/packets/Packet.h +++ b/include/packets/Packet.h @@ -26,6 +26,7 @@ enum PacketType : short { CAPTUREINF, CHANGESTAGE, CMD, + UDPINIT, End // end of enum for bounds checking }; @@ -44,6 +45,7 @@ USED static const char *packetNames[] = { "Capture Info", "Change Stage", "Server Command" + "Udp Initialization" }; enum SenderType { @@ -83,4 +85,5 @@ struct PACKED Packet { #include "packets/CaptureInf.h" #include "packets/HackCapInf.h" #include "packets/ChangeStagePacket.h" -#include "packets/InitPacket.h" \ No newline at end of file +#include "packets/InitPacket.h" +#include "packets/UdpPacket.h" diff --git a/include/packets/UdpPacket.h b/include/packets/UdpPacket.h new file mode 100644 index 0000000..b4c4b37 --- /dev/null +++ b/include/packets/UdpPacket.h @@ -0,0 +1,8 @@ +#pragma once + +#include "Packet.h" + +struct PACKED UdpInit : Packet { + UdpInit() : Packet() {this->mType = PacketType::UDPINIT; mPacketSize = sizeof(UdpInit) - sizeof(Packet);}; + u16 port = 0; +}; diff --git a/include/server/SocketClient.hpp b/include/server/SocketClient.hpp index 3037009..bd68919 100644 --- a/include/server/SocketClient.hpp +++ b/include/server/SocketClient.hpp @@ -28,13 +28,15 @@ class SocketClient : public SocketBase { bool RECV(); void printPacket(Packet* packet); bool isConnected() {return socket_log_state == SOCKET_LOG_CONNECTED; } + u16 getUdpPort(); sead::PtrArray mPacketQueue; private: int maxBufSize = 100; - s32 udp_socket; + s32 udp_socket; + sockaddr udp_addr; /** diff --git a/source/server/Client.cpp b/source/server/Client.cpp index bd5dc6c..d68b2aa 100644 --- a/source/server/Client.cpp +++ b/source/server/Client.cpp @@ -243,6 +243,8 @@ bool Client::startConnection() { waitingForInitPacket = false; } } + + } return mIsConnectionActive; @@ -370,7 +372,11 @@ void Client::readFunc() { } mSocket->SEND(&initPacket); // send initial packet - + + UdpInit udp_init = UdpInit(); + udp_init.port = mSocket->getUdpPort(); + mSocket->SEND(&udp_init) + nn::os::SleepThread(nn::TimeSpan::FromNanoSeconds(500000000)); // sleep for 0.5 seconds to let connection layout fully show (probably should find a better way to do this) mConnectionWait->tryEnd(); diff --git a/source/server/SocketClient.cpp b/source/server/SocketClient.cpp index 6aefcb9..bd0da88 100644 --- a/source/server/SocketClient.cpp +++ b/source/server/SocketClient.cpp @@ -7,6 +7,7 @@ #include "nn/result.h" #include "nn/socket.h" #include "packets/Packet.h" +#include "packets/UdpPacket.h" #include "types.h" nn::Result SocketClient::init(const char* ip, u16 port) { @@ -73,25 +74,10 @@ nn::Result SocketClient::init(const char* ip, u16 port) { return -1; } - udpAddress.address = 0; - udpAddress.port = 0; + udpAddress.address = hostAddress; + udpAddress.port = nn::socket::InetHtons(41553); udpAddress.family = 2; - - if ((nn::socket::Bind(this->udp_socket, &udpAddress, sizeof(serverAddress))).isFailure()){ - Logger::log("Udp Socket failed to bind"); - this->socket_errno = nn::socket::GetLastErrno(); - this->socket_log_state = SOCKET_LOG_UNAVAILABLE; - return -1; - } - - if((result = nn::socket::Connect(this->udp_socket, &serverAddress, sizeof(serverAddress))).isFailure()) { - Logger::log("Udp Socket Connection Failed!\n"); - this->socket_errno = nn::socket::GetLastErrno(); - this->socket_log_state = SOCKET_LOG_UNAVAILABLE; - return result; - } - - + this->udp_addr = udpAddress; this->socket_log_state = SOCKET_LOG_CONNECTED; @@ -101,6 +87,15 @@ nn::Result SocketClient::init(const char* ip, u16 port) { } +u16 SocketClient::getUdpPort() { + sockaddr udpAddress = { 0 }; + if (nn::socket::GetSockName(this->udp_socket, &udpAddress, sizeof(udpAddress)) <= 0) { + return 0; + } + + return udpAddress.port; +} + bool SocketClient::SEND(Packet *packet) { if (this->socket_log_state != SOCKET_LOG_CONNECTED) @@ -114,7 +109,7 @@ bool SocketClient::SEND(Packet *packet) { Logger::log("Sending packet: %s\n", packetNames[packet->mType]); } else { - if ((valread = nn::socket::Send(this->udp_socket, buffer, packet->mPacketSize + sizeof(Packet), 0) > 0)) { + if ((valread = nn::socket::SendTo(this->udp_socket, buffer, packet->mPacketSize + sizeof(Packet), 0, this->udp_addr, sizeof(this->udp_addr)) > 0)) { return true; } else { Logger::log("Failed to Fully Send Packet! Result: %d Type: %s Packet Size: %d\n", valread, packetNames[packet->mType], packet->mPacketSize); @@ -163,20 +158,32 @@ bool SocketClient::RECV() { } s32 fd = -1; + s32 index = -1; for (int i = 0; i < fd_count; i++){ if (pfds[i].revents & 1) { fd = pfds[i].fd; + index = i; } } if (fd == -1) { return true; } + sockaddr udp_addr = {0}; + u32 udp_size = 0; // read only the size of a header while(valread < headerSize) { - int result = nn::socket::Recv(fd, headerBuf + valread, - headerSize - valread, this->sock_flags); + int result = 0; + if (index == 0) { + result = nn::socket::Recv(fd, headerBuf + valread, + headerSize - valread, this->sock_flags); + } else { + result = nn::socket::RecvFrom(fd, headerBuf + valread, + headerSize - valread, this->sock_flags, + &udp_addr, &udp_size); + Logger::log("Got udp packet: %s %s %s\n",udp_addr.family, udp_addr.port, udp_addr.address.data); + } this->socket_errno = nn::socket::GetLastErrno();