diff --git a/CMakeLists.txt b/CMakeLists.txt index 9eca14785..7d4f372d7 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,18 @@ option(YUZU_DOWNLOAD_TIME_ZONE_DATA "Always download time zone binaries" OFF) CMAKE_DEPENDENT_OPTION(YUZU_USE_FASTER_LD "Check if a faster linker is available" ON "NOT WIN32" OFF) +set(DEFAULT_ENABLE_OPENSSL ON) +if (ANDROID OR WIN32 OR APPLE) + # - Windows defaults to the Schannel backend. + # - macOS defaults to the SecureTransport backend. + # - Android currently has no SSL backend as the NDK doesn't include any SSL + # library; a proper 'native' backend would have to go through Java. + # But you can force builds for those platforms to use OpenSSL if you have + # your own copy of it. + set(DEFAULT_ENABLE_OPENSSL OFF) +endif() +option(ENABLE_OPENSSL "Enable OpenSSL backend for ISslConnection" ${DEFAULT_ENABLE_OPENSSL}) + # On Android, fetch and compile libcxx before doing anything else if (ANDROID) set(CMAKE_SKIP_INSTALL_RULES ON) @@ -322,6 +334,10 @@ if (MINGW) find_library(MSWSOCK_LIBRARY mswsock REQUIRED) endif() +if(ENABLE_OPENSSL) + find_package(OpenSSL 1.1.1 REQUIRED) +endif() + # Please consider this as a stub if(ENABLE_QT6 AND Qt6_LOCATION) list(APPEND CMAKE_PREFIX_PATH "${Qt6_LOCATION}") diff --git a/README.md b/README.md index b2af18baf..e3250bc3c 100755 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ yuzu emulator early access ============= -This is the source code for early-access 3734. +This is the source code for early-access 3736. ## Legal Notice diff --git a/src/audio_core/device/device_session.cpp b/src/audio_core/device/device_session.cpp index e1d66ccd6..141eb6928 100755 --- a/src/audio_core/device/device_session.cpp +++ b/src/audio_core/device/device_session.cpp @@ -92,9 +92,9 @@ void DeviceSession::AppendBuffers(std::span buffers) { if (type == Sink::StreamType::In) { stream->AppendBuffer(new_buffer, tmp_samples); } else { - system.ApplicationMemory().ReadBlockUnsafe(buffer.samples, tmp_samples.data(), - buffer.size); - stream->AppendBuffer(new_buffer, tmp_samples); + Core::Memory::CpuGuestMemory samples( + system.ApplicationMemory(), buffer.samples, buffer.size / sizeof(s16)); + stream->AppendBuffer(new_buffer, samples); } } } diff --git a/src/audio_core/renderer/command/data_source/decode.cpp b/src/audio_core/renderer/command/data_source/decode.cpp index 19bbbc313..fd35571ac 100755 --- a/src/audio_core/renderer/command/data_source/decode.cpp +++ b/src/audio_core/renderer/command/data_source/decode.cpp @@ -28,7 +28,6 @@ constexpr std::array PitchBySrcQuality = {4, 8, 4}; template static u32 DecodePcm(Core::Memory::Memory& memory, std::span out_buffer, const DecodeArg& req) { - std::array tmp_samples{}; constexpr s32 min{std::numeric_limits::min()}; constexpr s32 max{std::numeric_limits::max()}; @@ -49,19 +48,18 @@ static u32 DecodePcm(Core::Memory::Memory& memory, std::span out_buffer, const VAddr source{req.buffer + (((req.start_offset + req.offset) * channel_count) * sizeof(T))}; const u64 size{channel_count * samples_to_decode}; - const u64 size_bytes{size * sizeof(T)}; - - memory.ReadBlockUnsafe(source, tmp_samples.data(), size_bytes); + Core::Memory::CpuGuestMemory samples( + memory, source, size); if constexpr (std::is_floating_point_v) { for (u32 i = 0; i < samples_to_decode; i++) { - auto sample{static_cast(tmp_samples[i * channel_count + req.target_channel] * + auto sample{static_cast(samples[i * channel_count + req.target_channel] * std::numeric_limits::max())}; out_buffer[i] = static_cast(std::clamp(sample, min, max)); } } else { for (u32 i = 0; i < samples_to_decode; i++) { - out_buffer[i] = tmp_samples[i * channel_count + req.target_channel]; + out_buffer[i] = samples[i * channel_count + req.target_channel]; } } } break; @@ -74,16 +72,17 @@ static u32 DecodePcm(Core::Memory::Memory& memory, std::span out_buffer, } const VAddr source{req.buffer + ((req.start_offset + req.offset) * sizeof(T))}; - memory.ReadBlockUnsafe(source, tmp_samples.data(), samples_to_decode * sizeof(T)); + Core::Memory::CpuGuestMemory samples( + memory, source, samples_to_decode); if constexpr (std::is_floating_point_v) { for (u32 i = 0; i < samples_to_decode; i++) { - auto sample{static_cast(tmp_samples[i * channel_count + req.target_channel] * + auto sample{static_cast(samples[i * channel_count + req.target_channel] * std::numeric_limits::max())}; out_buffer[i] = static_cast(std::clamp(sample, min, max)); } } else { - std::memcpy(out_buffer.data(), tmp_samples.data(), samples_to_decode * sizeof(s16)); + std::memcpy(out_buffer.data(), samples.data(), samples_to_decode * sizeof(s16)); } break; } @@ -101,7 +100,6 @@ static u32 DecodePcm(Core::Memory::Memory& memory, std::span out_buffer, */ static u32 DecodeAdpcm(Core::Memory::Memory& memory, std::span out_buffer, const DecodeArg& req) { - std::array wavebuffer{}; constexpr u32 SamplesPerFrame{14}; constexpr u32 NibblesPerFrame{16}; @@ -139,7 +137,8 @@ static u32 DecodeAdpcm(Core::Memory::Memory& memory, std::span out_buffer, } const auto size{std::max((samples_to_process / 8U) * SamplesPerFrame, 8U)}; - memory.ReadBlockUnsafe(req.buffer + position_in_frame / 2, wavebuffer.data(), size); + Core::Memory::CpuGuestMemory wavebuffer( + memory, req.buffer + position_in_frame / 2, size); auto context{req.adpcm_context}; auto header{context->header}; diff --git a/src/audio_core/renderer/command/effect/aux_.cpp b/src/audio_core/renderer/command/effect/aux_.cpp index e487feae0..03f1c6b42 100755 --- a/src/audio_core/renderer/command/effect/aux_.cpp +++ b/src/audio_core/renderer/command/effect/aux_.cpp @@ -21,23 +21,13 @@ static void ResetAuxBufferDsp(Core::Memory::Memory& memory, const CpuAddr aux_in } AuxInfo::AuxInfoDsp info{}; - auto info_ptr{&info}; - bool host_safe{(aux_info & Core::Memory::YUZU_PAGEMASK) <= - (Core::Memory::YUZU_PAGESIZE - sizeof(AuxInfo::AuxInfoDsp))}; + memory.ReadBlockUnsafe(aux_info, &info, sizeof(AuxInfo::AuxInfoDsp)); - if (host_safe) [[likely]] { - info_ptr = memory.GetPointer(aux_info); - } else { - memory.ReadBlockUnsafe(aux_info, info_ptr, sizeof(AuxInfo::AuxInfoDsp)); - } + info.read_offset = 0; + info.write_offset = 0; + info.total_sample_count = 0; - info_ptr->read_offset = 0; - info_ptr->write_offset = 0; - info_ptr->total_sample_count = 0; - - if (!host_safe) [[unlikely]] { - memory.WriteBlockUnsafe(aux_info, info_ptr, sizeof(AuxInfo::AuxInfoDsp)); - } + memory.WriteBlockUnsafe(aux_info, &info, sizeof(AuxInfo::AuxInfoDsp)); } /** @@ -86,17 +76,9 @@ static u32 WriteAuxBufferDsp(Core::Memory::Memory& memory, CpuAddr send_info_, } AuxInfo::AuxInfoDsp send_info{}; - auto send_ptr = &send_info; - bool host_safe = (send_info_ & Core::Memory::YUZU_PAGEMASK) <= - (Core::Memory::YUZU_PAGESIZE - sizeof(AuxInfo::AuxInfoDsp)); + memory.ReadBlockUnsafe(send_info_, &send_info, sizeof(AuxInfo::AuxInfoDsp)); - if (host_safe) [[likely]] { - send_ptr = memory.GetPointer(send_info_); - } else { - memory.ReadBlockUnsafe(send_info_, send_ptr, sizeof(AuxInfo::AuxInfoDsp)); - } - - u32 target_write_offset{send_ptr->write_offset + write_offset}; + u32 target_write_offset{send_info.write_offset + write_offset}; if (target_write_offset > count_max) { return 0; } @@ -105,15 +87,9 @@ static u32 WriteAuxBufferDsp(Core::Memory::Memory& memory, CpuAddr send_info_, u32 read_pos{0}; while (write_count > 0) { u32 to_write{std::min(count_max - target_write_offset, write_count)}; - const auto write_addr = send_buffer + target_write_offset * sizeof(s32); - bool write_safe{(write_addr & Core::Memory::YUZU_PAGEMASK) <= - (Core::Memory::YUZU_PAGESIZE - (write_addr + to_write * sizeof(s32)))}; - if (write_safe) [[likely]] { - auto ptr = memory.GetPointer(write_addr); - std::memcpy(ptr, &input[read_pos], to_write * sizeof(s32)); - } else { - memory.WriteBlockUnsafe(send_buffer + target_write_offset * sizeof(s32), - &input[read_pos], to_write * sizeof(s32)); + if (to_write > 0) { + const auto write_addr = send_buffer + target_write_offset * sizeof(s32); + memory.WriteBlockUnsafe(write_addr, &input[read_pos], to_write * sizeof(s32)); } target_write_offset = (target_write_offset + to_write) % count_max; write_count -= to_write; @@ -121,13 +97,10 @@ static u32 WriteAuxBufferDsp(Core::Memory::Memory& memory, CpuAddr send_info_, } if (update_count) { - send_ptr->write_offset = (send_ptr->write_offset + update_count) % count_max; - } - - if (!host_safe) [[unlikely]] { - memory.WriteBlockUnsafe(send_info_, send_ptr, sizeof(AuxInfo::AuxInfoDsp)); + send_info.write_offset = (send_info.write_offset + update_count) % count_max; } + memory.WriteBlockUnsafe(send_info_, &send_info, sizeof(AuxInfo::AuxInfoDsp)); return write_count_; } @@ -174,17 +147,9 @@ static u32 ReadAuxBufferDsp(Core::Memory::Memory& memory, CpuAddr return_info_, } AuxInfo::AuxInfoDsp return_info{}; - auto return_ptr = &return_info; - bool host_safe = (return_info_ & Core::Memory::YUZU_PAGEMASK) <= - (Core::Memory::YUZU_PAGESIZE - sizeof(AuxInfo::AuxInfoDsp)); + memory.ReadBlockUnsafe(return_info_, &return_info, sizeof(AuxInfo::AuxInfoDsp)); - if (host_safe) [[likely]] { - return_ptr = memory.GetPointer(return_info_); - } else { - memory.ReadBlockUnsafe(return_info_, return_ptr, sizeof(AuxInfo::AuxInfoDsp)); - } - - u32 target_read_offset{return_ptr->read_offset + read_offset}; + u32 target_read_offset{return_info.read_offset + read_offset}; if (target_read_offset > count_max) { return 0; } @@ -193,15 +158,9 @@ static u32 ReadAuxBufferDsp(Core::Memory::Memory& memory, CpuAddr return_info_, u32 write_pos{0}; while (read_count > 0) { u32 to_read{std::min(count_max - target_read_offset, read_count)}; - const auto read_addr = return_buffer + target_read_offset * sizeof(s32); - bool read_safe{(read_addr & Core::Memory::YUZU_PAGEMASK) <= - (Core::Memory::YUZU_PAGESIZE - (read_addr + to_read * sizeof(s32)))}; - if (read_safe) [[likely]] { - auto ptr = memory.GetPointer(read_addr); - std::memcpy(&output[write_pos], ptr, to_read * sizeof(s32)); - } else { - memory.ReadBlockUnsafe(return_buffer + target_read_offset * sizeof(s32), - &output[write_pos], to_read * sizeof(s32)); + if (to_read > 0) { + const auto read_addr = return_buffer + target_read_offset * sizeof(s32); + memory.ReadBlockUnsafe(read_addr, &output[write_pos], to_read * sizeof(s32)); } target_read_offset = (target_read_offset + to_read) % count_max; read_count -= to_read; @@ -209,13 +168,10 @@ static u32 ReadAuxBufferDsp(Core::Memory::Memory& memory, CpuAddr return_info_, } if (update_count) { - return_ptr->read_offset = (return_ptr->read_offset + update_count) % count_max; - } - - if (!host_safe) [[unlikely]] { - memory.WriteBlockUnsafe(return_info_, return_ptr, sizeof(AuxInfo::AuxInfoDsp)); + return_info.read_offset = (return_info.read_offset + update_count) % count_max; } + memory.WriteBlockUnsafe(return_info_, &return_info, sizeof(AuxInfo::AuxInfoDsp)); return read_count_; } diff --git a/src/common/page_table.cpp b/src/common/page_table.cpp index 9c1fdcd4b..01fcdc5c0 100755 --- a/src/common/page_table.cpp +++ b/src/common/page_table.cpp @@ -66,6 +66,7 @@ void PageTable::Resize(std::size_t address_space_width_in_bits, std::size_t page << (address_space_width_in_bits - page_size_in_bits)}; pointers.resize(num_page_table_entries); backing_addr.resize(num_page_table_entries); + blocks.resize(num_page_table_entries); current_address_space_width_in_bits = address_space_width_in_bits; page_size = 1ULL << page_size_in_bits; } diff --git a/src/common/page_table.h b/src/common/page_table.h index 6eaa28ba2..edf1e4dcc 100755 --- a/src/common/page_table.h +++ b/src/common/page_table.h @@ -122,6 +122,7 @@ struct PageTable { * corresponding attribute element is of type `Memory`. */ VirtualBuffer pointers; + VirtualBuffer blocks; VirtualBuffer backing_addr; diff --git a/src/common/socket_types.h b/src/common/socket_types.h index 05455490b..8520dc495 100755 --- a/src/common/socket_types.h +++ b/src/common/socket_types.h @@ -5,15 +5,19 @@ #include "common/common_types.h" +#include + namespace Network { /// Address families enum class Domain : u8 { - INET, ///< Address family for IPv4 + Unspecified, ///< Represents 0, used in getaddrinfo hints + INET, ///< Address family for IPv4 }; /// Socket types enum class Type { + Unspecified, ///< Represents 0, used in getaddrinfo hints STREAM, DGRAM, RAW, @@ -22,6 +26,7 @@ enum class Type { /// Protocol values for sockets enum class Protocol : u8 { + Unspecified, ///< Represents 0, usable in various places ICMP, TCP, UDP, @@ -48,4 +53,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2; constexpr u32 FLAG_MSG_DONTWAIT = 0x80; constexpr u32 FLAG_O_NONBLOCK = 0x800; +/// Cross-platform addrinfo structure +struct AddrInfo { + Domain family; + Type socket_type; + Protocol protocol; + SockAddrIn addr; + std::optional canon_name; +}; + } // namespace Network diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index b154ba87d..254db4f7f 100755 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -722,6 +722,7 @@ add_library(core STATIC hle/service/spl/spl_types.h hle/service/ssl/ssl.cpp hle/service/ssl/ssl.h + hle/service/ssl/ssl_backend.h hle/service/time/clock_types.h hle/service/time/ephemeral_network_system_clock_context_writer.h hle/service/time/ephemeral_network_system_clock_core.h @@ -863,6 +864,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64) target_link_libraries(core PRIVATE dynarmic::dynarmic) endif() +if(ENABLE_OPENSSL) + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_openssl.cpp) + target_link_libraries(core PRIVATE OpenSSL::SSL) +elseif (APPLE) + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_securetransport.cpp) + target_link_libraries(core PRIVATE "-framework Security") +elseif (WIN32) + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_schannel.cpp) + target_link_libraries(core PRIVATE secur32) +else() + target_sources(core PRIVATE + hle/service/ssl/ssl_backend_none.cpp) +endif() + if (YUZU_USE_PRECOMPILED_HEADERS) target_precompile_headers(core PRIVATE precompiled_headers.h) endif() diff --git a/src/core/core_timing.cpp b/src/core/core_timing.cpp index c7d2bbda9..521fba414 100755 --- a/src/core/core_timing.cpp +++ b/src/core/core_timing.cpp @@ -70,7 +70,7 @@ void CoreTiming::Initialize(std::function&& on_thread_init_) { -> std::optional { return std::nullopt; }; ev_lost = CreateEvent("_lost_event", empty_timed_callback); if (is_multicore) { - timer_thread = std::make_unique(ThreadEntry, std::ref(*this)); + timer_thread = std::make_unique(ThreadEntry, std::ref(*this)); } } @@ -253,12 +253,8 @@ void CoreTiming::ThreadLoop() { auto wait_time = *next_time - GetGlobalTimeNs().count(); if (wait_time > 0) { #ifdef _WIN32 - const auto timer_resolution_ns = - Common::Windows::GetCurrentTimerResolution().count(); - while (!paused && !event.IsSet() && wait_time > 0) { wait_time = *next_time - GetGlobalTimeNs().count(); - if (wait_time >= timer_resolution_ns) { Common::Windows::SleepForOneTick(); } else { @@ -316,4 +312,10 @@ std::chrono::microseconds CoreTiming::GetGlobalTimeUs() const { return std::chrono::microseconds{Common::WallClock::CPUTickToUS(cpu_ticks)}; } +#ifdef _WIN32 +void CoreTiming::SetTimerResolutionNs(std::chrono::nanoseconds ns) { + timer_resolution_ns = ns.count(); +} +#endif + } // namespace Core::Timing diff --git a/src/core/core_timing.h b/src/core/core_timing.h index 3e58d4675..8fb8257de 100755 --- a/src/core/core_timing.h +++ b/src/core/core_timing.h @@ -131,6 +131,10 @@ public: /// Checks for events manually and returns time in nanoseconds for next event, threadsafe. std::optional Advance(); +#ifdef _WIN32 + void SetTimerResolutionNs(std::chrono::nanoseconds ns); +#endif + private: struct Event; @@ -143,6 +147,10 @@ private: s64 global_timer = 0; +#ifdef _WIN32 + s64 timer_resolution_ns; +#endif + // The queue is a min-heap using std::make_heap/push_heap/pop_heap. // We don't use std::priority_queue because we need to be able to serialize, unserialize and // erase arbitrary events (RemoveEvent()) regardless of the queue order. These aren't @@ -155,7 +163,7 @@ private: Common::Event pause_event{}; std::mutex basic_lock; std::mutex advance_lock; - std::unique_ptr timer_thread; + std::unique_ptr timer_thread; std::atomic paused{}; std::atomic paused_set{}; std::atomic wait_set{}; diff --git a/src/core/hle/service/hle_ipc.cpp b/src/core/hle/service/hle_ipc.cpp index 2290df705..f6a1e54f2 100755 --- a/src/core/hle/service/hle_ipc.cpp +++ b/src/core/hle/service/hle_ipc.cpp @@ -329,8 +329,22 @@ std::vector HLERequestContext::ReadBufferCopy(std::size_t buffer_index) cons } std::span HLERequestContext::ReadBuffer(std::size_t buffer_index) const { - static thread_local std::array, 2> read_buffer_a; - static thread_local std::array, 2> read_buffer_x; + static thread_local std::array read_buffer_a{ + Core::Memory::CpuGuestMemory(memory, 0, 0), + Core::Memory::CpuGuestMemory(memory, 0, 0), + }; + static thread_local std::array read_buffer_data_a{ + Common::ScratchBuffer(), + Common::ScratchBuffer(), + }; + static thread_local std::array read_buffer_x{ + Core::Memory::CpuGuestMemory(memory, 0, 0), + Core::Memory::CpuGuestMemory(memory, 0, 0), + }; + static thread_local std::array read_buffer_data_x{ + Common::ScratchBuffer(), + Common::ScratchBuffer(), + }; const bool is_buffer_a{BufferDescriptorA().size() > buffer_index && BufferDescriptorA()[buffer_index].Size()}; @@ -339,19 +353,17 @@ std::span HLERequestContext::ReadBuffer(std::size_t buffer_index) cons BufferDescriptorA().size() > buffer_index, { return {}; }, "BufferDescriptorA invalid buffer_index {}", buffer_index); auto& read_buffer = read_buffer_a[buffer_index]; - read_buffer.resize_destructive(BufferDescriptorA()[buffer_index].Size()); - memory.ReadBlock(BufferDescriptorA()[buffer_index].Address(), read_buffer.data(), - read_buffer.size()); - return read_buffer; + return read_buffer.Read(BufferDescriptorA()[buffer_index].Address(), + BufferDescriptorA()[buffer_index].Size(), + &read_buffer_data_a[buffer_index]); } else { ASSERT_OR_EXECUTE_MSG( BufferDescriptorX().size() > buffer_index, { return {}; }, "BufferDescriptorX invalid buffer_index {}", buffer_index); auto& read_buffer = read_buffer_x[buffer_index]; - read_buffer.resize_destructive(BufferDescriptorX()[buffer_index].Size()); - memory.ReadBlock(BufferDescriptorX()[buffer_index].Address(), read_buffer.data(), - read_buffer.size()); - return read_buffer; + return read_buffer.Read(BufferDescriptorX()[buffer_index].Address(), + BufferDescriptorX()[buffer_index].Size(), + &read_buffer_data_x[buffer_index]); } } diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp index 19b44749a..858d94e3b 100755 --- a/src/core/hle/service/sockets/bsd.cpp +++ b/src/core/hle/service/sockets/bsd.cpp @@ -20,6 +20,9 @@ #include "core/internal_network/sockets.h" #include "network/network.h" +using Common::Expected; +using Common::Unexpected; + namespace Service::Sockets { namespace { @@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) { const u32 level = rp.Pop(); const auto optname = static_cast(rp.Pop()); - LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname); - std::vector optval(ctx.GetWriteBufferSize()); + LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname, + optval.size()); + + const Errno err = GetSockOptImpl(fd, level, optname, optval); + ctx.WriteBuffer(optval); IPC::ResponseBuilder rb{ctx, 5}; rb.Push(ResultSuccess); - rb.Push(-1); - rb.PushEnum(Errno::NOTCONN); + rb.Push(err == Errno::SUCCESS ? 0 : -1); + rb.PushEnum(err); rb.Push(static_cast(optval.size())); } @@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) { BuildErrnoResponse(ctx, CloseImpl(fd)); } +void BSD::DuplicateSocket(HLERequestContext& ctx) { + struct InputParameters { + s32 fd; + u64 reserved; + }; + static_assert(sizeof(InputParameters) == 0x10); + + struct OutputParameters { + s32 ret; + Errno bsd_errno; + }; + static_assert(sizeof(OutputParameters) == 0x8); + + IPC::RequestParser rp{ctx}; + auto input = rp.PopRaw(); + + Expected res = DuplicateSocketImpl(input.fd); + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(ResultSuccess); + rb.PushRaw(OutputParameters{ + .ret = res.value_or(0), + .bsd_errno = res ? Errno::SUCCESS : res.error(), + }); +} + void BSD::EventFd(HLERequestContext& ctx) { IPC::RequestParser rp{ctx}; const u64 initval = rp.Pop(); @@ -477,12 +508,12 @@ std::pair BSD::SocketImpl(Domain domain, Type type, Protocol protoco auto room_member = room_network.GetRoomMember().lock(); if (room_member && room_member->IsConnected()) { - descriptor.socket = std::make_unique(room_network); + descriptor.socket = std::make_shared(room_network); } else { - descriptor.socket = std::make_unique(); + descriptor.socket = std::make_shared(); } - descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol)); + descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol)); descriptor.is_connection_based = IsConnectionBased(type); return {fd, Errno::SUCCESS}; @@ -538,7 +569,7 @@ std::pair BSD::PollImpl(std::vector& write_buffer, std::spansocket.get(); - result.events = TranslatePollEventsToHost(pollfd.events); + result.events = Translate(pollfd.events); result.revents = Network::PollEvents{}; return result; }); @@ -547,7 +578,7 @@ std::pair BSD::PollImpl(std::vector& write_buffer, std::span& write_buffer) { } const SockAddrIn guest_addrin = Translate(addr_in); - ASSERT(write_buffer.size() == sizeof(guest_addrin)); + ASSERT(write_buffer.size() >= sizeof(guest_addrin)); + write_buffer.resize(sizeof(guest_addrin)); std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); return Translate(bsd_errno); } @@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector& write_buffer) { } const SockAddrIn guest_addrin = Translate(addr_in); - ASSERT(write_buffer.size() == sizeof(guest_addrin)); + ASSERT(write_buffer.size() >= sizeof(guest_addrin)); + write_buffer.resize(sizeof(guest_addrin)); std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); return Translate(bsd_errno); } @@ -671,13 +704,47 @@ std::pair BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) { } } -Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { - UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET - +Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector& optval) { if (!IsFileDescriptorValid(fd)) { return Errno::BADF; } + if (level != static_cast(SocketLevel::SOCKET)) { + UNIMPLEMENTED_MSG("Unknown getsockopt level"); + return Errno::SUCCESS; + } + + Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); + + switch (optname) { + case OptName::ERROR_: { + auto [pending_err, getsockopt_err] = socket->GetPendingError(); + if (getsockopt_err == Network::Errno::SUCCESS) { + Errno translated_pending_err = Translate(pending_err); + ASSERT_OR_EXECUTE_MSG( + optval.size() == sizeof(Errno), { return Errno::INVAL; }, + "Incorrect getsockopt option size"); + optval.resize(sizeof(Errno)); + memcpy(optval.data(), &translated_pending_err, sizeof(Errno)); + } + return Translate(getsockopt_err); + } + default: + UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); + return Errno::SUCCESS; + } +} + +Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { + if (!IsFileDescriptorValid(fd)) { + return Errno::BADF; + } + + if (level != static_cast(SocketLevel::SOCKET)) { + UNIMPLEMENTED_MSG("Unknown setsockopt level"); + return Errno::SUCCESS; + } + Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); if (optname == OptName::LINGER) { @@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con return Translate(socket->SetSndTimeo(value)); case OptName::RCVTIMEO: return Translate(socket->SetRcvTimeo(value)); + case OptName::NOSIGPIPE: + LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value); + return Errno::SUCCESS; default: UNIMPLEMENTED_MSG("Unimplemented optname={}", optname); return Errno::SUCCESS; @@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) { return bsd_errno; } +Expected BSD::DuplicateSocketImpl(s32 fd) { + if (!IsFileDescriptorValid(fd)) { + return Unexpected(Errno::BADF); + } + + const s32 new_fd = FindFreeFileDescriptorHandle(); + if (new_fd < 0) { + LOG_ERROR(Service, "No more file descriptors available"); + return Unexpected(Errno::MFILE); + } + + file_descriptors[new_fd] = file_descriptors[fd]; + return new_fd; +} + +std::optional> BSD::GetSocket(s32 fd) { + if (!IsFileDescriptorValid(fd)) { + return std::nullopt; + } + return file_descriptors[fd]->socket; +} + s32 BSD::FindFreeFileDescriptorHandle() noexcept { for (s32 fd = 0; fd < static_cast(file_descriptors.size()); ++fd) { if (!file_descriptors[fd]) { @@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name) {24, &BSD::Write, "Write"}, {25, &BSD::Read, "Read"}, {26, &BSD::Close, "Close"}, - {27, nullptr, "DuplicateSocket"}, + {27, &BSD::DuplicateSocket, "DuplicateSocket"}, {28, nullptr, "GetResourceStatistics"}, {29, nullptr, "RecvMMsg"}, {30, nullptr, "SendMMsg"}, diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h index 945119199..6a9c7a334 100755 --- a/src/core/hle/service/sockets/bsd.h +++ b/src/core/hle/service/sockets/bsd.h @@ -8,6 +8,7 @@ #include #include "common/common_types.h" +#include "common/expected.h" #include "common/socket_types.h" #include "core/hle/service/service.h" #include "core/hle/service/sockets/sockets.h" @@ -29,12 +30,19 @@ public: explicit BSD(Core::System& system_, const char* name); ~BSD() override; + // These methods are called from SSL; the first two are also called from + // this class for the corresponding IPC methods. + // On the real device, the SSL service makes IPC calls to this service. + Common::Expected DuplicateSocketImpl(s32 fd); + Errno CloseImpl(s32 fd); + std::optional> GetSocket(s32 fd); + private: /// Maximum number of file descriptors static constexpr size_t MAX_FD = 128; struct FileDescriptor { - std::unique_ptr socket; + std::shared_ptr socket; s32 flags = 0; bool is_connection_based = false; }; @@ -138,6 +146,7 @@ private: void Write(HLERequestContext& ctx); void Read(HLERequestContext& ctx); void Close(HLERequestContext& ctx); + void DuplicateSocket(HLERequestContext& ctx); void EventFd(HLERequestContext& ctx); template @@ -153,6 +162,7 @@ private: Errno GetSockNameImpl(s32 fd, std::vector& write_buffer); Errno ListenImpl(s32 fd, s32 backlog); std::pair FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); + Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector& optval); Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); Errno ShutdownImpl(s32 fd, s32 how); std::pair RecvImpl(s32 fd, u32 flags, std::vector& message); @@ -161,7 +171,6 @@ private: std::pair SendImpl(s32 fd, u32 flags, std::span message); std::pair SendToImpl(s32 fd, u32 flags, std::span message, std::span addr); - Errno CloseImpl(s32 fd); s32 FindFreeFileDescriptorHandle() noexcept; bool IsFileDescriptorValid(s32 fd) const noexcept; diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp index b00e76b43..052af6d1b 100755 --- a/src/core/hle/service/sockets/nsd.cpp +++ b/src/core/hle/service/sockets/nsd.cpp @@ -1,10 +1,15 @@ // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include "core/hle/service/ipc_helpers.h" #include "core/hle/service/sockets/nsd.h" +#include "common/string_util.h" + namespace Service::Sockets { +constexpr Result ResultOverflow{ErrorModule::NSD, 6}; + NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} { // clang-format off static const FunctionInfo functions[] = { @@ -15,8 +20,8 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na {13, nullptr, "DeleteSettings"}, {14, nullptr, "ImportSettings"}, {15, nullptr, "SetChangeEnvironmentIdentifierDisabled"}, - {20, nullptr, "Resolve"}, - {21, nullptr, "ResolveEx"}, + {20, &NSD::Resolve, "Resolve"}, + {21, &NSD::ResolveEx, "ResolveEx"}, {30, nullptr, "GetNasServiceSetting"}, {31, nullptr, "GetNasServiceSettingEx"}, {40, nullptr, "GetNasRequestFqdn"}, @@ -40,6 +45,55 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na RegisterHandlers(functions); } +static ResultVal ResolveImpl(const std::string& fqdn_in) { + // The real implementation makes various substitutions. + // For now we just return the string as-is, which is good enough when not + // connecting to real Nintendo servers. + LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in); + return fqdn_in; +} + +static Result ResolveCommon(const std::string& fqdn_in, std::array& fqdn_out) { + const auto res = ResolveImpl(fqdn_in); + if (res.Failed()) { + return res.Code(); + } + if (res->size() >= fqdn_out.size()) { + return ResultOverflow; + } + std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1); + return ResultSuccess; +} + +void NSD::Resolve(HLERequestContext& ctx) { + const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); + + std::array fqdn_out{}; + const Result res = ResolveCommon(fqdn_in, fqdn_out); + + ctx.WriteBuffer(fqdn_out); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); +} + +void NSD::ResolveEx(HLERequestContext& ctx) { + const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0)); + + std::array fqdn_out; + const Result res = ResolveCommon(fqdn_in, fqdn_out); + + if (res.IsError()) { + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + return; + } + + ctx.WriteBuffer(fqdn_out); + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(ResultSuccess); + rb.Push(ResultSuccess); +} + NSD::~NSD() = default; } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/nsd.h b/src/core/hle/service/sockets/nsd.h index c6be6c414..a98b7cd84 100755 --- a/src/core/hle/service/sockets/nsd.h +++ b/src/core/hle/service/sockets/nsd.h @@ -15,6 +15,10 @@ class NSD final : public ServiceFramework { public: explicit NSD(Core::System& system_, const char* name); ~NSD() override; + +private: + void Resolve(HLERequestContext& ctx); + void ResolveEx(HLERequestContext& ctx); }; } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/sfdnsres.cpp b/src/core/hle/service/sockets/sfdnsres.cpp index fce5c0413..4d9ca8d14 100755 --- a/src/core/hle/service/sockets/sfdnsres.cpp +++ b/src/core/hle/service/sockets/sfdnsres.cpp @@ -10,27 +10,18 @@ #include "core/core.h" #include "core/hle/service/ipc_helpers.h" #include "core/hle/service/sockets/sfdnsres.h" +#include "core/hle/service/sockets/sockets.h" +#include "core/hle/service/sockets/sockets_translate.h" +#include "core/internal_network/network.h" #include "core/memory.h" -#ifdef _WIN32 -#include -#elif YUZU_UNIX -#include -#include -#include -#include -#ifndef EAI_NODATA -#define EAI_NODATA EAI_NONAME -#endif -#endif - namespace Service::Sockets { SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} { static const FunctionInfo functions[] = { {0, nullptr, "SetDnsAddressesPrivateRequest"}, {1, nullptr, "GetDnsAddressPrivateRequest"}, - {2, nullptr, "GetHostByNameRequest"}, + {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"}, {3, nullptr, "GetHostByAddrRequest"}, {4, nullptr, "GetHostStringErrorRequest"}, {5, nullptr, "GetGaiStringErrorRequest"}, @@ -38,11 +29,11 @@ SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres" {7, nullptr, "GetNameInfoRequest"}, {8, nullptr, "RequestCancelHandleRequest"}, {9, nullptr, "CancelRequest"}, - {10, nullptr, "GetHostByNameRequestWithOptions"}, + {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"}, {11, nullptr, "GetHostByAddrRequestWithOptions"}, {12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"}, {13, nullptr, "GetNameInfoRequestWithOptions"}, - {14, nullptr, "ResolverSetOptionRequest"}, + {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"}, {15, nullptr, "ResolverGetOptionRequest"}, }; RegisterHandlers(functions); @@ -59,188 +50,285 @@ enum class NetDbError : s32 { NoData = 4, }; -static NetDbError AddrInfoErrorToNetDbError(s32 result) { - // Best effort guess to map errors +static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) { + // These combinations have been verified on console (but are not + // exhaustive). switch (result) { - case 0: + case GetAddrInfoError::SUCCESS: return NetDbError::Success; - case EAI_AGAIN: + case GetAddrInfoError::AGAIN: return NetDbError::TryAgain; - case EAI_NODATA: - return NetDbError::NoData; + case GetAddrInfoError::NODATA: + return NetDbError::HostNotFound; + case GetAddrInfoError::SERVICE: + return NetDbError::Success; default: return NetDbError::HostNotFound; } } -static std::vector SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code, +static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) { + // These combinations have been verified on console (but are not + // exhaustive). + switch (result) { + case GetAddrInfoError::SUCCESS: + // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for + // some reason, but that doesn't seem useful to implement. + return Errno::SUCCESS; + case GetAddrInfoError::AGAIN: + return Errno::SUCCESS; + case GetAddrInfoError::NODATA: + return Errno::SUCCESS; + case GetAddrInfoError::SERVICE: + return Errno::INVAL; + default: + return Errno::SUCCESS; + } +} + +template +static void Append(std::vector& vec, T t) { + const size_t offset = vec.size(); + vec.resize(offset + sizeof(T)); + std::memcpy(vec.data() + offset, &t, sizeof(T)); +} + +static void AppendNulTerminated(std::vector& vec, std::string_view str) { + const size_t offset = vec.size(); + vec.resize(offset + str.size() + 1); + std::memmove(vec.data() + offset, str.data(), str.size()); +} + +// We implement gethostbyname using the host's getaddrinfo rather than the +// host's gethostbyname, because it simplifies portability: e.g., getaddrinfo +// behaves the same on Unix and Windows, unlike gethostbyname where Windows +// doesn't implement h_errno. +static std::vector SerializeAddrInfoAsHostEnt(const std::vector& vec, + std::string_view host) { + + std::vector data; + // h_name: use the input hostname (append nul-terminated) + AppendNulTerminated(data, host); + // h_aliases: leave empty + + Append(data, 0); // count of h_aliases + // (If the count were nonzero, the aliases would be appended as nul-terminated here.) + Append(data, static_cast(Domain::INET)); // h_addrtype + Append(data, sizeof(Network::IPv4Address)); // h_length + // h_addr_list: + size_t count = vec.size(); + ASSERT(count <= UINT32_MAX); + Append(data, static_cast(count)); + for (const Network::AddrInfo& addrinfo : vec) { + // On the Switch, this is passed through htonl despite already being + // big-endian, so it ends up as little-endian. + Append(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); + + LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, + Network::IPv4AddressToString(addrinfo.addr.ip)); + } + return data; +} + +static std::pair GetHostByNameRequestImpl(HLERequestContext& ctx) { + struct InputParameters { + u8 use_nsd_resolve; + u32 cancel_handle; + u64 process_id; + }; + static_assert(sizeof(InputParameters) == 0x10); + + IPC::RequestParser rp{ctx}; + const auto parameters = rp.PopRaw(); + + LOG_WARNING( + Service, + "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", + parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); + + const auto host_buffer = ctx.ReadBuffer(0); + const std::string host = Common::StringFromBuffer(host_buffer); + // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions. + + auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt); + if (!res.has_value()) { + return {0, Translate(res.error())}; + } + + const std::vector data = SerializeAddrInfoAsHostEnt(res.value(), host); + const u32 data_size = static_cast(data.size()); + ctx.WriteBuffer(data, 0); + + return {data_size, GetAddrInfoError::SUCCESS}; +} + +void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) { + auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); + + struct OutputParameters { + NetDbError netdb_error; + Errno bsd_errno; + u32 data_size; + }; + static_assert(sizeof(OutputParameters) == 0xc); + + IPC::ResponseBuilder rb{ctx, 5}; + rb.Push(ResultSuccess); + rb.PushRaw(OutputParameters{ + .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + .data_size = data_size, + }); +} + +void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) { + auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx); + + struct OutputParameters { + u32 data_size; + NetDbError netdb_error; + Errno bsd_errno; + }; + static_assert(sizeof(OutputParameters) == 0xc); + + IPC::ResponseBuilder rb{ctx, 5}; + rb.Push(ResultSuccess); + rb.PushRaw(OutputParameters{ + .data_size = data_size, + .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + }); +} + +static std::vector SerializeAddrInfo(const std::vector& vec, std::string_view host) { // Adapted from // https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190 std::vector data; - auto* current = addrinfo; - while (current != nullptr) { - struct SerializedResponseHeader { - u32 magic; - s32 flags; - s32 family; - s32 socket_type; - s32 protocol; - u32 address_length; - }; - static_assert(sizeof(SerializedResponseHeader) == 0x18, - "Response header size must be 0x18 bytes"); + for (const Network::AddrInfo& addrinfo : vec) { + // serialized addrinfo: + Append(data, 0xBEEFCAFE); // magic + Append(data, 0); // ai_flags + Append(data, static_cast(Translate(addrinfo.family))); // ai_family + Append(data, static_cast(Translate(addrinfo.socket_type))); // ai_socktype + Append(data, static_cast(Translate(addrinfo.protocol))); // ai_protocol + Append(data, sizeof(SockAddrIn)); // ai_addrlen + // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size - constexpr auto header_size = sizeof(SerializedResponseHeader); - const auto addr_size = - current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4; - const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1; + // ai_addr: + Append(data, static_cast(Translate(addrinfo.addr.family))); // sin_family + // On the Switch, the following fields are passed through htonl despite + // already being big-endian, so they end up as little-endian. + Append(data, addrinfo.addr.portno); // sin_port + Append(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr + data.resize(data.size() + 8, 0); // sin_zero - const auto last_size = data.size(); - data.resize(last_size + header_size + addr_size + canonname_size); - - // Header in network byte order - SerializedResponseHeader header{}; - - constexpr auto HEADER_MAGIC = 0xBEEFCAFE; - header.magic = htonl(HEADER_MAGIC); - header.family = htonl(current->ai_family); - header.flags = htonl(current->ai_flags); - header.socket_type = htonl(current->ai_socktype); - header.protocol = htonl(current->ai_protocol); - header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0; - - auto* header_ptr = data.data() + last_size; - std::memcpy(header_ptr, &header, header_size); - - if (header.address_length == 0) { - std::memset(header_ptr + header_size, 0, 4); + if (addrinfo.canon_name.has_value()) { + AppendNulTerminated(data, *addrinfo.canon_name); } else { - switch (current->ai_family) { - case AF_INET: { - struct SockAddrIn { - s16 sin_family; - u16 sin_port; - u32 sin_addr; - u8 sin_zero[8]; - }; - - SockAddrIn serialized_addr{}; - const auto addr = *reinterpret_cast(current->ai_addr); - serialized_addr.sin_port = htons(addr.sin_port); - serialized_addr.sin_family = htons(addr.sin_family); - serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr); - std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn)); - - char addr_string_buf[64]{}; - inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf)); - LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf); - break; - } - case AF_INET6: { - struct SockAddrIn6 { - s16 sin6_family; - u16 sin6_port; - u32 sin6_flowinfo; - u8 sin6_addr[16]; - u32 sin6_scope_id; - }; - - SockAddrIn6 serialized_addr{}; - const auto addr = *reinterpret_cast(current->ai_addr); - serialized_addr.sin6_family = htons(addr.sin6_family); - serialized_addr.sin6_port = htons(addr.sin6_port); - serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo); - serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id); - std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr, - sizeof(SockAddrIn6::sin6_addr)); - std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6)); - - char addr_string_buf[64]{}; - inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf)); - LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf); - break; - } - default: - std::memcpy(header_ptr + header_size, current->ai_addr, addr_size); - break; - } - } - if (current->ai_canonname) { - std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size); - } else { - *(header_ptr + header_size + addr_size) = 0; + data.push_back(0); } - current = current->ai_next; + LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, + Network::IPv4AddressToString(addrinfo.addr.ip)); } - // 4-byte sentinel value - data.push_back(0); - data.push_back(0); - data.push_back(0); - data.push_back(0); + data.resize(data.size() + 4, 0); // 4-byte sentinel value return data; } -static std::pair GetAddrInfoRequestImpl(HLERequestContext& ctx) { - struct Parameters { +static std::pair GetAddrInfoRequestImpl(HLERequestContext& ctx) { + struct InputParameters { u8 use_nsd_resolve; - u32 unknown; + u32 cancel_handle; u64 process_id; }; + static_assert(sizeof(InputParameters) == 0x10); IPC::RequestParser rp{ctx}; - const auto parameters = rp.PopRaw(); + const auto parameters = rp.PopRaw(); - LOG_WARNING(Service, - "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}", - parameters.use_nsd_resolve, parameters.unknown, parameters.process_id); + LOG_WARNING( + Service, + "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}", + parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id); + + // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve + // before looking up. const auto host_buffer = ctx.ReadBuffer(0); const std::string host = Common::StringFromBuffer(host_buffer); - const auto service_buffer = ctx.ReadBuffer(1); - const std::string service = Common::StringFromBuffer(service_buffer); - - addrinfo* addrinfo; - // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now - s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo); - - u32 data_size = 0; - if (result_code == 0 && addrinfo != nullptr) { - const std::vector& data = SerializeAddrInfo(addrinfo, result_code, host); - data_size = static_cast(data.size()); - freeaddrinfo(addrinfo); - - ctx.WriteBuffer(data, 0); + std::optional service = std::nullopt; + if (ctx.CanReadBuffer(1)) { + const std::span service_buffer = ctx.ReadBuffer(1); + service = Common::StringFromBuffer(service_buffer); } - return std::make_pair(data_size, result_code); + // Serialized hints are also passed in a buffer, but are ignored for now. + + auto res = Network::GetAddressInfo(host, service); + if (!res.has_value()) { + return {0, Translate(res.error())}; + } + + const std::vector data = SerializeAddrInfo(res.value(), host); + const u32 data_size = static_cast(data.size()); + ctx.WriteBuffer(data, 0); + + return {data_size, GetAddrInfoError::SUCCESS}; } void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) { - auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); + auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); - IPC::ResponseBuilder rb{ctx, 4}; + struct OutputParameters { + Errno bsd_errno; + GetAddrInfoError gai_error; + u32 data_size; + }; + static_assert(sizeof(OutputParameters) == 0xc); + + IPC::ResponseBuilder rb{ctx, 5}; rb.Push(ResultSuccess); - rb.Push(static_cast(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode - rb.Push(result_code); // errno - rb.Push(data_size); // serialized size + rb.PushRaw(OutputParameters{ + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + .gai_error = emu_gai_err, + .data_size = data_size, + }); } void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) { // Additional options are ignored - auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx); + auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx); - IPC::ResponseBuilder rb{ctx, 5}; + struct OutputParameters { + u32 data_size; + GetAddrInfoError gai_error; + NetDbError netdb_error; + Errno bsd_errno; + }; + static_assert(sizeof(OutputParameters) == 0x10); + + IPC::ResponseBuilder rb{ctx, 6}; rb.Push(ResultSuccess); - rb.Push(data_size); // serialized size - rb.Push(result_code); // errno - rb.Push(static_cast(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode - rb.Push(0); + rb.PushRaw(OutputParameters{ + .data_size = data_size, + .gai_error = emu_gai_err, + .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err), + .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err), + }); +} + +void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) { + LOG_WARNING(Service, "(STUBBED) called"); + + IPC::ResponseBuilder rb{ctx, 3}; + + rb.Push(ResultSuccess); + rb.Push(0); // bsd errno } } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/sfdnsres.h b/src/core/hle/service/sockets/sfdnsres.h index bfa0ccc22..70a327463 100755 --- a/src/core/hle/service/sockets/sfdnsres.h +++ b/src/core/hle/service/sockets/sfdnsres.h @@ -17,8 +17,11 @@ public: ~SFDNSRES() override; private: + void GetHostByNameRequest(HLERequestContext& ctx); + void GetHostByNameRequestWithOptions(HLERequestContext& ctx); void GetAddrInfoRequest(HLERequestContext& ctx); void GetAddrInfoRequestWithOptions(HLERequestContext& ctx); + void ResolverSetOptionRequest(HLERequestContext& ctx); }; } // namespace Service::Sockets diff --git a/src/core/hle/service/sockets/sockets.h b/src/core/hle/service/sockets/sockets.h index 9dcf83438..725d8bdc9 100755 --- a/src/core/hle/service/sockets/sockets.h +++ b/src/core/hle/service/sockets/sockets.h @@ -22,13 +22,35 @@ enum class Errno : u32 { CONNRESET = 104, NOTCONN = 107, TIMEDOUT = 110, + INPROGRESS = 115, +}; + +enum class GetAddrInfoError : s32 { + SUCCESS = 0, + ADDRFAMILY = 1, + AGAIN = 2, + BADFLAGS = 3, + FAIL = 4, + FAMILY = 5, + MEMORY = 6, + NODATA = 7, + NONAME = 8, + SERVICE = 9, + SOCKTYPE = 10, + SYSTEM = 11, + BADHINTS = 12, + PROTOCOL = 13, + OVERFLOW_ = 14, // avoid name collision with Windows macro + OTHER = 15, }; enum class Domain : u32 { + Unspecified = 0, INET = 2, }; enum class Type : u32 { + Unspecified = 0, STREAM = 1, DGRAM = 2, RAW = 3, @@ -36,12 +58,16 @@ enum class Type : u32 { }; enum class Protocol : u32 { - UNSPECIFIED = 0, + Unspecified = 0, ICMP = 1, TCP = 6, UDP = 17, }; +enum class SocketLevel : u32 { + SOCKET = 0xffff, // i.e. SOL_SOCKET +}; + enum class OptName : u32 { REUSEADDR = 0x4, KEEPALIVE = 0x8, @@ -51,6 +77,8 @@ enum class OptName : u32 { RCVBUF = 0x1002, SNDTIMEO = 0x1005, RCVTIMEO = 0x1006, + ERROR_ = 0x1007, // avoid name collision with Windows macro + NOSIGPIPE = 0x800, // at least according to libnx }; enum class ShutdownHow : s32 { @@ -80,6 +108,9 @@ enum class PollEvents : u16 { Err = 1 << 3, Hup = 1 << 4, Nval = 1 << 5, + RdNorm = 1 << 6, + RdBand = 1 << 7, + WrBand = 1 << 8, }; DECLARE_ENUM_FLAG_OPERATORS(PollEvents); diff --git a/src/core/hle/service/sockets/sockets_translate.cpp b/src/core/hle/service/sockets/sockets_translate.cpp index a5e0365fe..f991cd67c 100755 --- a/src/core/hle/service/sockets/sockets_translate.cpp +++ b/src/core/hle/service/sockets/sockets_translate.cpp @@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) { return Errno::TIMEDOUT; case Network::Errno::CONNRESET: return Errno::CONNRESET; + case Network::Errno::INPROGRESS: + return Errno::INPROGRESS; default: UNIMPLEMENTED_MSG("Unimplemented errno={}", value); return Errno::SUCCESS; @@ -39,8 +41,50 @@ std::pair Translate(std::pair value) { return {value.first, Translate(value.second)}; } +GetAddrInfoError Translate(Network::GetAddrInfoError error) { + switch (error) { + case Network::GetAddrInfoError::SUCCESS: + return GetAddrInfoError::SUCCESS; + case Network::GetAddrInfoError::ADDRFAMILY: + return GetAddrInfoError::ADDRFAMILY; + case Network::GetAddrInfoError::AGAIN: + return GetAddrInfoError::AGAIN; + case Network::GetAddrInfoError::BADFLAGS: + return GetAddrInfoError::BADFLAGS; + case Network::GetAddrInfoError::FAIL: + return GetAddrInfoError::FAIL; + case Network::GetAddrInfoError::FAMILY: + return GetAddrInfoError::FAMILY; + case Network::GetAddrInfoError::MEMORY: + return GetAddrInfoError::MEMORY; + case Network::GetAddrInfoError::NODATA: + return GetAddrInfoError::NODATA; + case Network::GetAddrInfoError::NONAME: + return GetAddrInfoError::NONAME; + case Network::GetAddrInfoError::SERVICE: + return GetAddrInfoError::SERVICE; + case Network::GetAddrInfoError::SOCKTYPE: + return GetAddrInfoError::SOCKTYPE; + case Network::GetAddrInfoError::SYSTEM: + return GetAddrInfoError::SYSTEM; + case Network::GetAddrInfoError::BADHINTS: + return GetAddrInfoError::BADHINTS; + case Network::GetAddrInfoError::PROTOCOL: + return GetAddrInfoError::PROTOCOL; + case Network::GetAddrInfoError::OVERFLOW_: + return GetAddrInfoError::OVERFLOW_; + case Network::GetAddrInfoError::OTHER: + return GetAddrInfoError::OTHER; + default: + UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error); + return GetAddrInfoError::OTHER; + } +} + Network::Domain Translate(Domain domain) { switch (domain) { + case Domain::Unspecified: + return Network::Domain::Unspecified; case Domain::INET: return Network::Domain::INET; default: @@ -51,6 +95,8 @@ Network::Domain Translate(Domain domain) { Domain Translate(Network::Domain domain) { switch (domain) { + case Network::Domain::Unspecified: + return Domain::Unspecified; case Network::Domain::INET: return Domain::INET; default: @@ -61,39 +107,69 @@ Domain Translate(Network::Domain domain) { Network::Type Translate(Type type) { switch (type) { + case Type::Unspecified: + return Network::Type::Unspecified; case Type::STREAM: return Network::Type::STREAM; case Type::DGRAM: return Network::Type::DGRAM; + case Type::RAW: + return Network::Type::RAW; + case Type::SEQPACKET: + return Network::Type::SEQPACKET; default: UNIMPLEMENTED_MSG("Unimplemented type={}", type); return Network::Type{}; } } -Network::Protocol Translate(Type type, Protocol protocol) { +Type Translate(Network::Type type) { + switch (type) { + case Network::Type::Unspecified: + return Type::Unspecified; + case Network::Type::STREAM: + return Type::STREAM; + case Network::Type::DGRAM: + return Type::DGRAM; + case Network::Type::RAW: + return Type::RAW; + case Network::Type::SEQPACKET: + return Type::SEQPACKET; + default: + UNIMPLEMENTED_MSG("Unimplemented type={}", type); + return Type{}; + } +} + +Network::Protocol Translate(Protocol protocol) { switch (protocol) { - case Protocol::UNSPECIFIED: - LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type"); - switch (type) { - case Type::DGRAM: - return Network::Protocol::UDP; - case Type::STREAM: - return Network::Protocol::TCP; - default: - return Network::Protocol::TCP; - } + case Protocol::Unspecified: + return Network::Protocol::Unspecified; case Protocol::TCP: return Network::Protocol::TCP; case Protocol::UDP: return Network::Protocol::UDP; default: UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); - return Network::Protocol::TCP; + return Network::Protocol::Unspecified; } } -Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { +Protocol Translate(Network::Protocol protocol) { + switch (protocol) { + case Network::Protocol::Unspecified: + return Protocol::Unspecified; + case Network::Protocol::TCP: + return Protocol::TCP; + case Network::Protocol::UDP: + return Protocol::UDP; + default: + UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); + return Protocol::Unspecified; + } +} + +Network::PollEvents Translate(PollEvents flags) { Network::PollEvents result{}; const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) { if (True(flags & from)) { @@ -107,12 +183,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) { translate(PollEvents::Err, Network::PollEvents::Err); translate(PollEvents::Hup, Network::PollEvents::Hup); translate(PollEvents::Nval, Network::PollEvents::Nval); + translate(PollEvents::RdNorm, Network::PollEvents::RdNorm); + translate(PollEvents::RdBand, Network::PollEvents::RdBand); + translate(PollEvents::WrBand, Network::PollEvents::WrBand); UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); return result; } -PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { +PollEvents Translate(Network::PollEvents flags) { PollEvents result{}; const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) { if (True(flags & from)) { @@ -127,13 +206,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) { translate(Network::PollEvents::Err, PollEvents::Err); translate(Network::PollEvents::Hup, PollEvents::Hup); translate(Network::PollEvents::Nval, PollEvents::Nval); + translate(Network::PollEvents::RdNorm, PollEvents::RdNorm); + translate(Network::PollEvents::RdBand, PollEvents::RdBand); + translate(Network::PollEvents::WrBand, PollEvents::WrBand); UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags); return result; } Network::SockAddrIn Translate(SockAddrIn value) { - ASSERT(value.len == 0 || value.len == sizeof(value)); + // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets + // sin_len to 6 when deserializing getaddrinfo results). + ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6); return { .family = Translate(static_cast(value.family)), diff --git a/src/core/hle/service/sockets/sockets_translate.h b/src/core/hle/service/sockets/sockets_translate.h index ef6a711ef..afc667786 100755 --- a/src/core/hle/service/sockets/sockets_translate.h +++ b/src/core/hle/service/sockets/sockets_translate.h @@ -17,6 +17,9 @@ Errno Translate(Network::Errno value); /// Translate abstract return value errno pair to guest return value errno pair std::pair Translate(std::pair value); +/// Translate abstract getaddrinfo error to guest getaddrinfo error +GetAddrInfoError Translate(Network::GetAddrInfoError value); + /// Translate guest domain to abstract domain Network::Domain Translate(Domain domain); @@ -26,14 +29,20 @@ Domain Translate(Network::Domain domain); /// Translate guest type to abstract type Network::Type Translate(Type type); -/// Translate guest protocol to abstract protocol -Network::Protocol Translate(Type type, Protocol protocol); +/// Translate abstract type to guest type +Type Translate(Network::Type type); -/// Translate abstract poll event flags to guest poll event flags -Network::PollEvents TranslatePollEventsToHost(PollEvents flags); +/// Translate guest protocol to abstract protocol +Network::Protocol Translate(Protocol protocol); + +/// Translate abstract protocol to guest protocol +Protocol Translate(Network::Protocol protocol); /// Translate guest poll event flags to abstract poll event flags -PollEvents TranslatePollEventsToGuest(Network::PollEvents flags); +Network::PollEvents Translate(PollEvents flags); + +/// Translate abstract poll event flags to guest poll event flags +PollEvents Translate(Network::PollEvents flags); /// Translate guest socket address structure to abstract socket address structure Network::SockAddrIn Translate(SockAddrIn value); diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 76b3f5f4e..2bbbf69a9 100755 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp @@ -1,10 +1,18 @@ // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later +#include "common/string_util.h" + +#include "core/core.h" #include "core/hle/service/ipc_helpers.h" #include "core/hle/service/server_manager.h" #include "core/hle/service/service.h" +#include "core/hle/service/sm/sm.h" +#include "core/hle/service/sockets/bsd.h" #include "core/hle/service/ssl/ssl.h" +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" namespace Service::SSL { @@ -20,6 +28,18 @@ enum class ContextOption : u32 { CrlImportDateCheckEnable = 1, }; +// This is nn::ssl::Connection::IoMode +enum class IoMode : u32 { + Blocking = 1, + NonBlocking = 2, +}; + +// This is nn::ssl::sf::OptionType +enum class OptionType : u32 { + DoNotCloseSocket = 0, + GetServerCertChain = 1, +}; + // This is nn::ssl::sf::SslVersion struct SslVersion { union { @@ -34,35 +54,42 @@ struct SslVersion { }; }; +struct SslContextSharedData { + u32 connection_count = 0; +}; + class ISslConnection final : public ServiceFramework { public: - explicit ISslConnection(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { + explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in, + std::shared_ptr& shared_data_in, + std::unique_ptr&& backend_in) + : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in}, + shared_data{shared_data_in}, backend{std::move(backend_in)} { // clang-format off static const FunctionInfo functions[] = { - {0, nullptr, "SetSocketDescriptor"}, - {1, nullptr, "SetHostName"}, - {2, nullptr, "SetVerifyOption"}, - {3, nullptr, "SetIoMode"}, + {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, + {1, &ISslConnection::SetHostName, "SetHostName"}, + {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, + {3, &ISslConnection::SetIoMode, "SetIoMode"}, {4, nullptr, "GetSocketDescriptor"}, {5, nullptr, "GetHostName"}, {6, nullptr, "GetVerifyOption"}, {7, nullptr, "GetIoMode"}, - {8, nullptr, "DoHandshake"}, - {9, nullptr, "DoHandshakeGetServerCert"}, - {10, nullptr, "Read"}, - {11, nullptr, "Write"}, - {12, nullptr, "Pending"}, + {8, &ISslConnection::DoHandshake, "DoHandshake"}, + {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, + {10, &ISslConnection::Read, "Read"}, + {11, &ISslConnection::Write, "Write"}, + {12, &ISslConnection::Pending, "Pending"}, {13, nullptr, "Peek"}, {14, nullptr, "Poll"}, {15, nullptr, "GetVerifyCertError"}, {16, nullptr, "GetNeededServerCertBufferSize"}, - {17, nullptr, "SetSessionCacheMode"}, + {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, {18, nullptr, "GetSessionCacheMode"}, {19, nullptr, "FlushSessionCache"}, {20, nullptr, "SetRenegotiationMode"}, {21, nullptr, "GetRenegotiationMode"}, - {22, nullptr, "SetOption"}, + {22, &ISslConnection::SetOption, "SetOption"}, {23, nullptr, "GetOption"}, {24, nullptr, "GetVerifyCertErrors"}, {25, nullptr, "GetCipherInfo"}, @@ -80,21 +107,299 @@ public: // clang-format on RegisterHandlers(functions); + + shared_data->connection_count++; + } + + ~ISslConnection() { + shared_data->connection_count--; + if (fd_to_close.has_value()) { + const s32 fd = *fd_to_close; + if (!do_not_close_socket) { + LOG_ERROR(Service_SSL, + "do_not_close_socket was changed after setting socket; is this right?"); + } else { + auto bsd = system.ServiceManager().GetService("bsd:u"); + if (bsd) { + auto err = bsd->CloseImpl(fd); + if (err != Service::Sockets::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err); + } + } + } + } } private: SslVersion ssl_version; + std::shared_ptr shared_data; + std::unique_ptr backend; + std::optional fd_to_close; + bool do_not_close_socket = false; + bool get_server_cert_chain = false; + std::shared_ptr socket; + bool did_set_host_name = false; + bool did_handshake = false; + + ResultVal SetSocketDescriptorImpl(s32 fd) { + LOG_DEBUG(Service_SSL, "called, fd={}", fd); + ASSERT(!did_handshake); + auto bsd = system.ServiceManager().GetService("bsd:u"); + ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); + s32 ret_fd; + // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor + if (do_not_close_socket) { + auto res = bsd->DuplicateSocketImpl(fd); + if (!res.has_value()) { + LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); + return ResultInvalidSocket; + } + fd = *res; + fd_to_close = fd; + ret_fd = fd; + } else { + ret_fd = -1; + } + std::optional> sock = bsd->GetSocket(fd); + if (!sock.has_value()) { + LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); + return ResultInvalidSocket; + } + socket = std::move(*sock); + backend->SetSocket(socket); + return ret_fd; + } + + Result SetHostNameImpl(const std::string& hostname) { + LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); + ASSERT(!did_handshake); + Result res = backend->SetHostName(hostname); + if (res == ResultSuccess) { + did_set_host_name = true; + } + return res; + } + + Result SetVerifyOptionImpl(u32 option) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); + return ResultSuccess; + } + + Result SetIoModeImpl(u32 input_mode) { + auto mode = static_cast(input_mode); + ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); + ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); + + const bool non_block = mode == IoMode::NonBlocking; + const Network::Errno error = socket->SetNonBlock(non_block); + if (error != Network::Errno::SUCCESS) { + LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); + } + return ResultSuccess; + } + + Result SetSessionCacheModeImpl(u32 mode) { + ASSERT(!did_handshake); + LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); + return ResultSuccess; + } + + Result DoHandshakeImpl() { + ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); + ASSERT_OR_EXECUTE_MSG( + did_set_host_name, { return ResultInternalError; }, + "Expected SetHostName before DoHandshake"); + Result res = backend->DoHandshake(); + did_handshake = res.IsSuccess(); + return res; + } + + std::vector SerializeServerCerts(const std::vector>& certs) { + struct Header { + u64 magic; + u32 count; + u32 pad; + }; + struct EntryHeader { + u32 size; + u32 offset; + }; + if (!get_server_cert_chain) { + // Just return the first one, unencoded. + ASSERT_OR_EXECUTE_MSG( + !certs.empty(), { return {}; }, "Should be at least one server cert"); + return certs[0]; + } + std::vector ret; + Header header{0x4E4D684374726543, static_cast(certs.size()), 0}; + ret.insert(ret.end(), reinterpret_cast(&header), reinterpret_cast(&header + 1)); + size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); + for (auto& cert : certs) { + EntryHeader entry_header{static_cast(cert.size()), static_cast(data_offset)}; + data_offset += cert.size(); + ret.insert(ret.end(), reinterpret_cast(&entry_header), + reinterpret_cast(&entry_header + 1)); + } + for (auto& cert : certs) { + ret.insert(ret.end(), cert.begin(), cert.end()); + } + return ret; + } + + ResultVal> ReadImpl(size_t size) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + std::vector res(size); + ResultVal actual = backend->Read(res); + if (actual.Failed()) { + return actual.Code(); + } + res.resize(*actual); + return res; + } + + ResultVal WriteImpl(std::span data) { + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + return backend->Write(data); + } + + ResultVal PendingImpl() { + LOG_WARNING(Service_SSL, "(STUBBED) called."); + return 0; + } + + void SetSocketDescriptor(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const s32 fd = rp.Pop(); + const ResultVal res = SetSocketDescriptorImpl(fd); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push(res.ValueOr(-1)); + } + + void SetHostName(HLERequestContext& ctx) { + const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); + const Result res = SetHostNameImpl(hostname); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetVerifyOption(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 option = rp.Pop(); + const Result res = SetVerifyOptionImpl(option); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetIoMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop(); + const Result res = SetIoModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshake(HLERequestContext& ctx) { + const Result res = DoHandshakeImpl(); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void DoHandshakeGetServerCert(HLERequestContext& ctx) { + struct OutputParameters { + u32 certs_size; + u32 certs_count; + }; + static_assert(sizeof(OutputParameters) == 0x8); + + const Result res = DoHandshakeImpl(); + OutputParameters out{}; + if (res == ResultSuccess) { + auto certs = backend->GetServerCerts(); + if (certs.Succeeded()) { + const std::vector certs_buf = SerializeServerCerts(*certs); + ctx.WriteBuffer(certs_buf); + out.certs_count = static_cast(certs->size()); + out.certs_size = static_cast(certs_buf.size()); + } + } + IPC::ResponseBuilder rb{ctx, 4}; + rb.Push(res); + rb.PushRaw(out); + } + + void Read(HLERequestContext& ctx) { + const ResultVal> res = ReadImpl(ctx.GetWriteBufferSize()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + if (res.Succeeded()) { + rb.Push(static_cast(res->size())); + ctx.WriteBuffer(*res); + } else { + rb.Push(static_cast(0)); + } + } + + void Write(HLERequestContext& ctx) { + const ResultVal res = WriteImpl(ctx.ReadBuffer()); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push(static_cast(res.ValueOr(0))); + } + + void Pending(HLERequestContext& ctx) { + const ResultVal res = PendingImpl(); + IPC::ResponseBuilder rb{ctx, 3}; + rb.Push(res.Code()); + rb.Push(res.ValueOr(0)); + } + + void SetSessionCacheMode(HLERequestContext& ctx) { + IPC::RequestParser rp{ctx}; + const u32 mode = rp.Pop(); + const Result res = SetSessionCacheModeImpl(mode); + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(res); + } + + void SetOption(HLERequestContext& ctx) { + struct Parameters { + OptionType option; + s32 value; + }; + static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); + + IPC::RequestParser rp{ctx}; + const auto parameters = rp.PopRaw(); + + switch (parameters.option) { + case OptionType::DoNotCloseSocket: + do_not_close_socket = static_cast(parameters.value); + break; + case OptionType::GetServerCertChain: + get_server_cert_chain = static_cast(parameters.value); + break; + default: + LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, + parameters.value); + } + + IPC::ResponseBuilder rb{ctx, 2}; + rb.Push(ResultSuccess); + } }; class ISslContext final : public ServiceFramework { public: explicit ISslContext(Core::System& system_, SslVersion version) - : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { + : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, + shared_data{std::make_shared()} { static const FunctionInfo functions[] = { {0, &ISslContext::SetOption, "SetOption"}, {1, nullptr, "GetOption"}, {2, &ISslContext::CreateConnection, "CreateConnection"}, - {3, nullptr, "GetConnectionCount"}, + {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, {4, &ISslContext::ImportServerPki, "ImportServerPki"}, {5, &ISslContext::ImportClientPki, "ImportClientPki"}, {6, nullptr, "RemoveServerPki"}, @@ -111,6 +416,7 @@ public: private: SslVersion ssl_version; + std::shared_ptr shared_data; void SetOption(HLERequestContext& ctx) { struct Parameters { @@ -130,11 +436,24 @@ private: } void CreateConnection(HLERequestContext& ctx) { - LOG_WARNING(Service_SSL, "(STUBBED) called"); + LOG_WARNING(Service_SSL, "called"); + + auto backend_res = CreateSSLConnectionBackend(); IPC::ResponseBuilder rb{ctx, 2, 0, 1}; + rb.Push(backend_res.Code()); + if (backend_res.Succeeded()) { + rb.PushIpcInterface(system, ssl_version, shared_data, + std::move(*backend_res)); + } + } + + void GetConnectionCount(HLERequestContext& ctx) { + LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); + + IPC::ResponseBuilder rb{ctx, 3}; rb.Push(ResultSuccess); - rb.PushIpcInterface(system, ssl_version); + rb.Push(shared_data->connection_count); } void ImportServerPki(HLERequestContext& ctx) { diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h new file mode 100755 index 000000000..25c16bcc1 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend.h @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#pragma once + +#include "core/hle/result.h" + +#include "common/common_types.h" + +#include +#include +#include +#include + +namespace Network { +class SocketBase; +} + +namespace Service::SSL { + +constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; +constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; +constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; +constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up + +// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, +// with no way in the latter case to distinguish whether the client should poll +// for read or write. The one official client I've seen handles this by always +// polling for read (with a timeout). +constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; + +class SSLConnectionBackend { +public: + virtual ~SSLConnectionBackend() {} + virtual void SetSocket(std::shared_ptr socket) = 0; + virtual Result SetHostName(const std::string& hostname) = 0; + virtual Result DoHandshake() = 0; + virtual ResultVal Read(std::span data) = 0; + virtual ResultVal Write(std::span data) = 0; + virtual ResultVal>> GetServerCerts() = 0; +}; + +ResultVal> CreateSSLConnectionBackend(); + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp new file mode 100755 index 000000000..f2f0ef706 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_none.cpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" + +#include "common/logging/log.h" + +namespace Service::SSL { + +ResultVal> CreateSSLConnectionBackend() { + LOG_ERROR(Service_SSL, + "Can't create SSL connection because no SSL backend is available on this platform"); + return ResultInternalError; +} + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp new file mode 100755 index 000000000..f69674f77 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp @@ -0,0 +1,351 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include + +#include +#include +#include +#include + +using namespace Common::FS; + +namespace Service::SSL { + +// Import OpenSSL's `SSL` type into the namespace. This is needed because the +// namespace is also named `SSL`. +using ::SSL; + +namespace { + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SSL_CTX* ssl_ctx; +IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment +BIO_METHOD* bio_meth; + +Result CheckOpenSSLErrors(); +void OneTimeInit(); +void OneTimeInitLogFile(); +bool OneTimeInitBIO(); + +} // namespace + +class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR(Service_SSL, + "Can't create SSL connection because OpenSSL one-time initialization failed"); + return ResultInternalError; + } + + ssl = SSL_new(ssl_ctx); + if (!ssl) { + LOG_ERROR(Service_SSL, "SSL_new failed"); + return CheckOpenSSLErrors(); + } + + SSL_set_connect_state(ssl); + + bio = BIO_new(bio_meth); + if (!bio) { + LOG_ERROR(Service_SSL, "BIO_new failed"); + return CheckOpenSSLErrors(); + } + + BIO_set_data(bio, this); + BIO_set_init(bio, 1); + SSL_set_bio(ssl, bio, bio); + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr socket_in) override { + socket = std::move(socket_in); + } + + Result SetHostName(const std::string& hostname) override { + if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification + LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); + return CheckOpenSSLErrors(); + } + if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI + LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); + return CheckOpenSSLErrors(); + } + return ResultSuccess; + } + + Result DoHandshake() override { + SSL_set_verify_result(ssl, X509_V_OK); + const int ret = SSL_do_handshake(ssl); + const long verify_result = SSL_get_verify_result(ssl); + if (verify_result != X509_V_OK) { + LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", + X509_verify_cert_error_string(verify_result)); + return CheckOpenSSLErrors(); + } + if (ret <= 0) { + const int ssl_err = SSL_get_error(ssl, ret); + if (ssl_err == SSL_ERROR_ZERO_RETURN || + (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + } + return HandleReturn("SSL_do_handshake", 0, ret).Code(); + } + + ResultVal Read(std::span data) override { + size_t actual; + const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); + return HandleReturn("SSL_read_ex", actual, ret); + } + + ResultVal Write(std::span data) override { + size_t actual; + const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); + return HandleReturn("SSL_write_ex", actual, ret); + } + + ResultVal HandleReturn(const char* what, size_t actual, int ret) { + const int ssl_err = SSL_get_error(ssl, ret); + CheckOpenSSLErrors(); + switch (ssl_err) { + case SSL_ERROR_NONE: + return actual; + case SSL_ERROR_ZERO_RETURN: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); + // DoHandshake special-cases this, but for Read and Write: + return size_t(0); + case SSL_ERROR_WANT_READ: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); + return ResultWouldBlock; + case SSL_ERROR_WANT_WRITE: + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); + return ResultWouldBlock; + default: + if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { + LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); + return size_t(0); + } + LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); + return ResultInternalError; + } + } + + ResultVal>> GetServerCerts() override { + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); + if (!chain) { + LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); + return ResultInternalError; + } + std::vector> ret; + int count = sk_X509_num(chain); + ASSERT(count >= 0); + for (int i = 0; i < count; i++) { + X509* x509 = sk_X509_value(chain, i); + ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); + unsigned char* buf = nullptr; + int len = i2d_X509(x509, &buf); + ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); + ret.emplace_back(buf, buf + len); + OPENSSL_free(buf); + } + return ret; + } + + ~SSLConnectionBackendOpenSSL() { + // these are null-tolerant: + SSL_free(ssl); + BIO_free(bio); + } + + static void KeyLogCallback(const SSL* ssl, const char* line) { + std::string str(line); + str.push_back('\n'); + // Do this in a single WriteString for atomicity if multiple instances + // are running on different threads (though that can't currently + // happen). + if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { + LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); + } + LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); + } + + static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { + auto self = static_cast(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket->Send({reinterpret_cast(buf), len}, 0); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + return 1; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return -1; + } + } + + static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { + auto self = static_cast(BIO_get_data(bio)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); + BIO_clear_retry_flags(bio); + auto [actual, err] = self->socket->Recv(0, {reinterpret_cast(buf), len}); + switch (err) { + case Network::Errno::SUCCESS: + *actual_p = actual; + if (actual == 0) { + self->got_read_eof = true; + } + return actual ? 1 : 0; + case Network::Errno::AGAIN: + BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); + return 0; + default: + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return -1; + } + } + + static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) { + switch (cmd) { + case BIO_CTRL_FLUSH: + // Nothing to flush. + return 1; + case BIO_CTRL_PUSH: + case BIO_CTRL_POP: +#ifdef BIO_CTRL_GET_KTLS_SEND + case BIO_CTRL_GET_KTLS_SEND: + case BIO_CTRL_GET_KTLS_RECV: +#endif + // We don't support these operations, but don't bother logging them + // as they're nothing unusual. + return 0; + default: + LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg); + return 0; + } + } + + SSL* ssl = nullptr; + BIO* bio = nullptr; + bool got_read_eof = false; + + std::shared_ptr socket; +}; + +ResultVal> CreateSSLConnectionBackend() { + auto conn = std::make_unique(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +namespace { + +Result CheckOpenSSLErrors() { + unsigned long rc; + const char* file; + int line; + const char* func; + const char* data; + int flags; +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) +#else + // Can't get function names from OpenSSL on this version, so use mine: + func = __func__; + while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags))) +#endif + { + std::string msg; + msg.resize(1024, '\0'); + ERR_error_string_n(rc, msg.data(), msg.size()); + msg.resize(strlen(msg.data()), '\0'); + if (flags & ERR_TXT_STRING) { + msg.append(" | "); + msg.append(data); + } + Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, + Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", + msg); + } + return ResultInternalError; +} + +void OneTimeInit() { + ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (!ssl_ctx) { + LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); + CheckOpenSSLErrors(); + return; + } + + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); + + if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { + LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); + CheckOpenSSLErrors(); + return; + } + + OneTimeInitLogFile(); + + if (!OneTimeInitBIO()) { + return; + } + + one_time_init_success = true; +} + +void OneTimeInitLogFile() { + const char* logfile = getenv("SSLKEYLOGFILE"); + if (logfile) { + key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, + FileShareFlag::ShareWriteOnly); + if (key_log_file.IsOpen()) { + SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); + } else { + LOG_CRITICAL(Service_SSL, + "SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); + } + } +} + +bool OneTimeInitBIO() { + bio_meth = + BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); + if (!bio_meth || + !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || + !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || + !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { + LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); + return false; + } + return true; +} + +} // namespace + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp new file mode 100755 index 000000000..a1d6a186e --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp @@ -0,0 +1,543 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +#include "common/error.h" +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include + +namespace { + +// These includes are inside the namespace to avoid a conflict on MinGW where +// the headers define an enum containing Network and Service as enumerators +// (which clash with the correspondingly named namespaces). +#define SECURITY_WIN32 +#include +#include + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SCHANNEL_CRED schannel_cred{}; +CredHandle cred_handle; + +static void OneTimeInit() { + schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; + schannel_cred.dwFlags = + SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols + SCH_CRED_AUTO_CRED_VALIDATION | // validate certs + SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate + // ^ I'm assuming that nobody would want to connect Yuzu to a + // service that requires some OS-provided corporate client + // certificate, and presenting one to some arbitrary server + // might be a privacy concern? Who knows, though. + + const SECURITY_STATUS ret = + AcquireCredentialsHandle(nullptr, const_cast(UNISP_NAME), SECPKG_CRED_OUTBOUND, + nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); + if (ret != SEC_E_OK) { + // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. + LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", + Common::NativeErrorToString(ret)); + return; + } + + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " + "keys; not logging keys!"); + // Not fatal. + } + + one_time_init_success = true; +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSchannel final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR( + Service_SSL, + "Can't create SSL connection because Schannel one-time initialization failed"); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr socket_in) override { + socket = std::move(socket_in); + } + + Result SetHostName(const std::string& hostname_in) override { + hostname = hostname_in; + return ResultSuccess; + } + + Result DoHandshake() override { + while (1) { + Result r; + switch (handshake_state) { + case HandshakeState::Initial: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::ContinueNeeded: + case HandshakeState::IncompleteMessage: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = FillCiphertextReadBuf()) != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + if ((r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::DoneAfterFlush: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { + return r; + } + handshake_state = HandshakeState::Connected; + return ResultSuccess; + case HandshakeState::Connected: + LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); + return ResultInternalError; + case HandshakeState::Error: + return ResultInternalError; + } + } + } + + Result FillCiphertextReadBuf() { + const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; + read_buf_fill_size = 0; + // This unnecessarily zeroes the buffer; oh well. + const size_t offset = ciphertext_read_buf.size(); + ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); + ciphertext_read_buf.resize(offset + fill_size, 0); + const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); + const auto [actual, err] = socket->Recv(0, read_span); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast(actual) <= fill_size); + ciphertext_read_buf.resize(offset + actual); + return ResultSuccess; + case Network::Errno::AGAIN: + ciphertext_read_buf.resize(offset); + return ResultWouldBlock; + default: + ciphertext_read_buf.resize(offset); + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return ResultInternalError; + } + } + + // Returns success if the write buffer has been completely emptied. + Result FlushCiphertextWriteBuf() { + while (!ciphertext_write_buf.empty()) { + const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast(actual) <= ciphertext_write_buf.size()); + ciphertext_write_buf.erase(ciphertext_write_buf.begin(), + ciphertext_write_buf.begin() + actual); + break; + case Network::Errno::AGAIN: + return ResultWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return ResultInternalError; + } + } + return ResultSuccess; + } + + Result CallInitializeSecurityContext() { + const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | + ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | + ISC_REQ_USE_SUPPLIED_CREDS; + unsigned long attr; + // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel + std::array input_buffers{{ + // only used if `initial_call_done` + { + // [0] + .cbBuffer = static_cast(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = ciphertext_read_buf.data(), + }, + { + // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is + // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the + // whole buffer wasn't used) + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }, + }}; + std::array output_buffers{{ + { + .cbBuffer = 0, + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = nullptr, + }, // [0] + { + .cbBuffer = 0, + .BufferType = SECBUFFER_ALERT, + .pvBuffer = nullptr, + }, // [1] + }}; + SecBufferDesc input_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast(input_buffers.size()), + .pBuffers = input_buffers.data(), + }; + SecBufferDesc output_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast(output_buffers.size()), + .pBuffers = output_buffers.data(), + }; + ASSERT_OR_EXECUTE_MSG( + input_buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + + bool initial_call_done = handshake_state != HandshakeState::Initial; + if (initial_call_done) { + LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", + ciphertext_read_buf.size()); + } + + const SECURITY_STATUS ret = + InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, + // Caller ensured we have set a hostname: + const_cast(hostname.value().c_str()), req, + 0, // Reserved1 + 0, // TargetDataRep not used with Schannel + initial_call_done ? &input_desc : nullptr, + 0, // Reserved2 + initial_call_done ? nullptr : &ctxt, &output_desc, &attr, + nullptr); // ptsExpiry + + if (output_buffers[0].pvBuffer) { + const std::span span(static_cast(output_buffers[0].pvBuffer), + output_buffers[0].cbBuffer); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); + FreeContextBuffer(output_buffers[0].pvBuffer); + } + + if (output_buffers[1].pvBuffer) { + const std::span span(static_cast(output_buffers[1].pvBuffer), + output_buffers[1].cbBuffer); + // The documentation doesn't explain what format this data is in. + LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), + Common::HexToString(span)); + } + + switch (ret) { + case SEC_I_CONTINUE_NEEDED: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); + if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { + LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); + ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - input_buffers[1].cbBuffer); + } else { + ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + handshake_state = HandshakeState::ContinueNeeded; + return ResultSuccess; + case SEC_E_INCOMPLETE_MESSAGE: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); + ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); + read_buf_fill_size = input_buffers[1].cbBuffer; + handshake_state = HandshakeState::IncompleteMessage; + return ResultSuccess; + case SEC_E_OK: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); + ciphertext_read_buf.clear(); + handshake_state = HandshakeState::DoneAfterFlush; + return GrabStreamSizes(); + default: + LOG_ERROR(Service_SSL, + "InitializeSecurityContext failed (probably certificate/protocol issue): {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + } + + Result GrabStreamSizes() { + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + return ResultSuccess; + } + + ResultVal Read(std::span data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0 || got_read_eof) { + return size_t(0); + } + while (1) { + if (!cleartext_read_buf.empty()) { + const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); + std::memcpy(data.data(), cleartext_read_buf.data(), read_size); + cleartext_read_buf.erase(cleartext_read_buf.begin(), + cleartext_read_buf.begin() + read_size); + return read_size; + } + if (!ciphertext_read_buf.empty()) { + SecBuffer empty{ + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }; + std::array buffers{{ + { + .cbBuffer = static_cast(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = ciphertext_read_buf.data(), + }, + empty, + empty, + empty, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast(buffers.size()), + .pBuffers = buffers.data(), + }; + SECURITY_STATUS ret = + DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); + switch (ret) { + case SEC_E_OK: + ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, + { return ResultInternalError; }); + cleartext_read_buf.assign(static_cast(buffers[1].pvBuffer), + static_cast(buffers[1].pvBuffer) + + buffers[1].cbBuffer); + if (buffers[3].BufferType == SECBUFFER_EXTRA) { + ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - buffers[3].cbBuffer); + } else { + ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + continue; + case SEC_E_INCOMPLETE_MESSAGE: + break; + case SEC_I_CONTEXT_EXPIRED: + // Server hung up by sending close_notify. + got_read_eof = true; + return size_t(0); + default: + LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + } + const Result r = FillCiphertextReadBuf(); + if (r != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + got_read_eof = true; + return size_t(0); + } + } + } + + ResultVal Write(std::span data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0) { + return size_t(0); + } + data = data.subspan(0, std::min(data.size(), stream_sizes.cbMaximumMessage)); + if (!cleartext_write_buf.empty()) { + // Already in the middle of a write. It wouldn't make sense to not + // finish sending the entire buffer since TLS has + // header/MAC/padding/etc. + if (data.size() != cleartext_write_buf.size() || + std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { + LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); + return ResultInternalError; + } + return WriteAlreadyEncryptedData(); + } else { + cleartext_write_buf.assign(data.begin(), data.end()); + } + + std::vector header_buf(stream_sizes.cbHeader, 0); + std::vector tmp_data_buf = cleartext_write_buf; + std::vector trailer_buf(stream_sizes.cbTrailer, 0); + + std::array buffers{{ + { + .cbBuffer = stream_sizes.cbHeader, + .BufferType = SECBUFFER_STREAM_HEADER, + .pvBuffer = header_buf.data(), + }, + { + .cbBuffer = static_cast(tmp_data_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = tmp_data_buf.data(), + }, + { + .cbBuffer = stream_sizes.cbTrailer, + .BufferType = SECBUFFER_STREAM_TRAILER, + .pvBuffer = trailer_buf.data(), + }, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, + "temp buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast(buffers.size()), + .pBuffers = buffers.data(), + }; + + const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); + return ResultInternalError; + } + ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), + header_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), + tmp_data_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), + trailer_buf.end()); + return WriteAlreadyEncryptedData(); + } + + ResultVal WriteAlreadyEncryptedData() { + const Result r = FlushCiphertextWriteBuf(); + if (r != ResultSuccess) { + return r; + } + // write buf is empty + const size_t cleartext_bytes_written = cleartext_write_buf.size(); + cleartext_write_buf.clear(); + return cleartext_bytes_written; + } + + ResultVal>> GetServerCerts() override { + PCCERT_CONTEXT returned_cert = nullptr; + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, + "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + PCCERT_CONTEXT some_cert = nullptr; + std::vector> certs; + while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { + certs.emplace_back(static_cast(some_cert->pbCertEncoded), + static_cast(some_cert->pbCertEncoded) + + some_cert->cbCertEncoded); + } + std::reverse(certs.begin(), + certs.end()); // Windows returns certs in reverse order from what we want + CertFreeCertificateContext(returned_cert); + return certs; + } + + ~SSLConnectionBackendSchannel() { + if (handshake_state != HandshakeState::Initial) { + DeleteSecurityContext(&ctxt); + } + } + + enum class HandshakeState { + // Haven't called anything yet. + Initial, + // `SEC_I_CONTINUE_NEEDED` was returned by + // `InitializeSecurityContext`; must finish sending data (if any) in + // the write buffer, then read at least one byte before calling + // `InitializeSecurityContext` again. + ContinueNeeded, + // `SEC_E_INCOMPLETE_MESSAGE` was returned by + // `InitializeSecurityContext`; hopefully the write buffer is empty; + // must read at least one byte before calling + // `InitializeSecurityContext` again. + IncompleteMessage, + // `SEC_E_OK` was returned by `InitializeSecurityContext`; must + // finish sending data in the write buffer before having `DoHandshake` + // report success. + DoneAfterFlush, + // We finished the above and are now connected. At this point, writing + // and reading are separate 'state machines' represented by the + // nonemptiness of the ciphertext and cleartext read and write buffers. + Connected, + // Another error was returned and we shouldn't allow initialization + // to continue. + Error, + } handshake_state = HandshakeState::Initial; + + CtxtHandle ctxt; + SecPkgContext_StreamSizes stream_sizes; + + std::shared_ptr socket; + std::optional hostname; + + std::vector ciphertext_read_buf; + std::vector ciphertext_write_buf; + std::vector cleartext_read_buf; + std::vector cleartext_write_buf; + + bool got_read_eof = false; + size_t read_buf_fill_size = 0; +}; + +ResultVal> CreateSSLConnectionBackend() { + auto conn = std::make_unique(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp new file mode 100755 index 000000000..be40a5aeb --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +#include + +#include + +// SecureTransport has been deprecated in its entirety in favor of +// Network.framework, but that does not allow layering TLS on top of an +// arbitrary socket. +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + +namespace { + +template +struct CFReleaser { + T ptr; + + YUZU_NON_COPYABLE(CFReleaser); + constexpr CFReleaser() : ptr(nullptr) {} + constexpr CFReleaser(T ptr) : ptr(ptr) {} + constexpr operator T() { + return ptr; + } + ~CFReleaser() { + if (ptr) { + CFRelease(ptr); + } + } +}; + +std::string CFStringToString(CFStringRef cfstr) { + CFReleaser cfdata( + CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); + ASSERT_OR_EXECUTE(cfdata, { return "???"; }); + return std::string(reinterpret_cast(CFDataGetBytePtr(cfdata)), + CFDataGetLength(cfdata)); +} + +std::string OSStatusToString(OSStatus status) { + CFReleaser cfstr(SecCopyErrorMessageString(status, nullptr)); + if (!cfstr) { + return "[unknown error]"; + } + return CFStringToString(cfstr); +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { +public: + Result Init() { + static std::once_flag once_flag; + std::call_once(once_flag, []() { + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " + "support exporting keys; not logging keys!"); + // Not fatal. + } + }); + + context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); + if (!context) { + LOG_ERROR(Service_SSL, "SSLCreateContext failed"); + return ResultInternalError; + } + + OSStatus status; + if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || + (status = SSLSetConnection(context, this))) { + LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", + OSStatusToString(status)); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr in_socket) override { + socket = std::move(in_socket); + } + + Result SetHostName(const std::string& hostname) override { + OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); + if (status) { + LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + return ResultSuccess; + } + + Result DoHandshake() override { + OSStatus status = SSLHandshake(context); + return HandleReturn("SSLHandshake", 0, status).Code(); + } + + ResultVal Read(std::span data) override { + size_t actual; + OSStatus status = SSLRead(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLRead", actual, status); + } + + ResultVal Write(std::span data) override { + size_t actual; + OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLWrite", actual, status); + } + + ResultVal HandleReturn(const char* what, size_t actual, OSStatus status) { + switch (status) { + case 0: + return actual; + case errSSLWouldBlock: + return ResultWouldBlock; + default: { + std::string reason; + if (got_read_eof) { + reason = "server hung up"; + } else { + reason = OSStatusToString(status); + } + LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); + return ResultInternalError; + } + } + } + + ResultVal>> GetServerCerts() override { + CFReleaser trust; + OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); + if (status) { + LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + std::vector> ret; + for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { + SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); + CFReleaser data(SecCertificateCopyData(cert)); + ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); + const u8* ptr = CFDataGetBytePtr(data); + ret.emplace_back(ptr, ptr + CFDataGetLength(data)); + } + return ret; + } + + static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { + return ReadOrWriteCallback(connection, data, dataLength, true); + } + + static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, + size_t* dataLength) { + return ReadOrWriteCallback(connection, const_cast(data), dataLength, false); + } + + static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, + bool is_read) { + auto self = + static_cast(const_cast(connection)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", + is_read ? "read" : "write"); + + // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are + // expected to read/write the full requested dataLength or return an + // error, so we have to add a loop ourselves. + size_t requested_len = *dataLength; + size_t offset = 0; + while (offset < requested_len) { + std::span cur(reinterpret_cast(data) + offset, requested_len - offset); + auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); + LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, + actual, cur.size(), static_cast(err)); + switch (err) { + case Network::Errno::SUCCESS: + offset += actual; + if (actual == 0) { + ASSERT(is_read); + self->got_read_eof = true; + return errSecEndOfData; + } + break; + case Network::Errno::AGAIN: + *dataLength = offset; + return errSSLWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", + is_read ? "recv" : "send", err); + return errSecIO; + } + } + ASSERT(offset == requested_len); + return 0; + } + +private: + CFReleaser context = nullptr; + bool got_read_eof = false; + + std::shared_ptr socket; +}; + +ResultVal> CreateSSLConnectionBackend() { + auto conn = std::make_unique(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp index 20d2e5aff..e2827cc26 100755 --- a/src/core/internal_network/network.cpp +++ b/src/core/internal_network/network.cpp @@ -27,6 +27,7 @@ #include "common/assert.h" #include "common/common_types.h" +#include "common/expected.h" #include "common/logging/log.h" #include "common/settings.h" #include "core/internal_network/network.h" @@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) { Errno TranslateNativeError(int e) { switch (e) { + case 0: + return Errno::SUCCESS; case WSAEBADF: return Errno::BADF; case WSAEINVAL: @@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) { return Errno::MSGSIZE; case WSAETIMEDOUT: return Errno::TIMEDOUT; + case WSAEINPROGRESS: + return Errno::INPROGRESS; default: UNIMPLEMENTED_MSG("Unimplemented errno={}", e); return Errno::OTHER; @@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) { Errno TranslateNativeError(int e) { switch (e) { + case 0: + return Errno::SUCCESS; case EBADF: return Errno::BADF; case EINVAL: @@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) { return Errno::MSGSIZE; case ETIMEDOUT: return Errno::TIMEDOUT; + case EINPROGRESS: + return Errno::INPROGRESS; default: - UNIMPLEMENTED_MSG("Unimplemented errno={}", e); + UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e)); return Errno::OTHER; } } @@ -234,15 +243,84 @@ Errno GetAndLogLastError() { int e = errno; #endif const Errno err = TranslateNativeError(e); - if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { + if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) { + // These happen during normal operation, so only log them at debug level. + LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); return err; } LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); return err; } -int TranslateDomain(Domain domain) { +GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) { + switch (gai_err) { + case 0: + return GetAddrInfoError::SUCCESS; +#ifdef EAI_ADDRFAMILY + case EAI_ADDRFAMILY: + return GetAddrInfoError::ADDRFAMILY; +#endif + case EAI_AGAIN: + return GetAddrInfoError::AGAIN; + case EAI_BADFLAGS: + return GetAddrInfoError::BADFLAGS; + case EAI_FAIL: + return GetAddrInfoError::FAIL; + case EAI_FAMILY: + return GetAddrInfoError::FAMILY; + case EAI_MEMORY: + return GetAddrInfoError::MEMORY; + case EAI_NONAME: + return GetAddrInfoError::NONAME; + case EAI_SERVICE: + return GetAddrInfoError::SERVICE; + case EAI_SOCKTYPE: + return GetAddrInfoError::SOCKTYPE; + // These codes may not be defined on all systems: +#ifdef EAI_SYSTEM + case EAI_SYSTEM: + return GetAddrInfoError::SYSTEM; +#endif +#ifdef EAI_BADHINTS + case EAI_BADHINTS: + return GetAddrInfoError::BADHINTS; +#endif +#ifdef EAI_PROTOCOL + case EAI_PROTOCOL: + return GetAddrInfoError::PROTOCOL; +#endif +#ifdef EAI_OVERFLOW + case EAI_OVERFLOW: + return GetAddrInfoError::OVERFLOW_; +#endif + default: +#ifdef EAI_NODATA + // This can't be a case statement because it would create a duplicate + // case on Windows where EAI_NODATA is an alias for EAI_NONAME. + if (gai_err == EAI_NODATA) { + return GetAddrInfoError::NODATA; + } +#endif + return GetAddrInfoError::OTHER; + } +} + +Domain TranslateDomainFromNative(int domain) { switch (domain) { + case 0: + return Domain::Unspecified; + case AF_INET: + return Domain::INET; + default: + UNIMPLEMENTED_MSG("Unhandled domain={}", domain); + return Domain::INET; + } +} + +int TranslateDomainToNative(Domain domain) { + switch (domain) { + case Domain::Unspecified: + return 0; case Domain::INET: return AF_INET; default: @@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) { } } -int TranslateType(Type type) { +Type TranslateTypeFromNative(int type) { switch (type) { + case 0: + return Type::Unspecified; + case SOCK_STREAM: + return Type::STREAM; + case SOCK_DGRAM: + return Type::DGRAM; + case SOCK_RAW: + return Type::RAW; + case SOCK_SEQPACKET: + return Type::SEQPACKET; + default: + UNIMPLEMENTED_MSG("Unimplemented type={}", type); + return Type::STREAM; + } +} + +int TranslateTypeToNative(Type type) { + switch (type) { + case Type::Unspecified: + return 0; case Type::STREAM: return SOCK_STREAM; case Type::DGRAM: return SOCK_DGRAM; + case Type::RAW: + return SOCK_RAW; default: UNIMPLEMENTED_MSG("Unimplemented type={}", type); return 0; } } -int TranslateProtocol(Protocol protocol) { +Protocol TranslateProtocolFromNative(int protocol) { switch (protocol) { + case 0: + return Protocol::Unspecified; + case IPPROTO_TCP: + return Protocol::TCP; + case IPPROTO_UDP: + return Protocol::UDP; + default: + UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol); + return Protocol::Unspecified; + } +} + +int TranslateProtocolToNative(Protocol protocol) { + switch (protocol) { + case Protocol::Unspecified: + return 0; case Protocol::TCP: return IPPROTO_TCP; case Protocol::UDP: @@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) { } } -SockAddrIn TranslateToSockAddrIn(sockaddr input_) { - sockaddr_in input; - std::memcpy(&input, &input_, sizeof(input)); - +SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) { SockAddrIn result; - switch (input.sin_family) { - case AF_INET: - result.family = Domain::INET; - break; - default: - UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family); - result.family = Domain::INET; - break; - } + result.family = TranslateDomainFromNative(input.sin_family); result.portno = ntohs(input.sin_port); @@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) { short TranslatePollEvents(PollEvents events) { short result = 0; - if (True(events & PollEvents::In)) { - events &= ~PollEvents::In; - result |= POLLIN; - } - if (True(events & PollEvents::Pri)) { - events &= ~PollEvents::Pri; + const auto translate = [&result, &events](PollEvents guest, short host) { + if (True(events & guest)) { + events &= ~guest; + result |= host; + } + }; + + translate(PollEvents::In, POLLIN); + translate(PollEvents::Pri, POLLPRI); + translate(PollEvents::Out, POLLOUT); + translate(PollEvents::Err, POLLERR); + translate(PollEvents::Hup, POLLHUP); + translate(PollEvents::Nval, POLLNVAL); + translate(PollEvents::RdNorm, POLLRDNORM); + translate(PollEvents::RdBand, POLLRDBAND); + translate(PollEvents::WrBand, POLLWRBAND); + #ifdef _WIN32 - LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); -#else - result |= POLLPRI; + short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM; + // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input. + if (result & ~allowed_events) { + LOG_DEBUG(Network, + "Removing WSAPoll input events 0x{:x} because Windows doesn't support them", + result & ~allowed_events); + } + result &= allowed_events; #endif - } - if (True(events & PollEvents::Out)) { - events &= ~PollEvents::Out; - result |= POLLOUT; - } UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); @@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) { translate(POLLOUT, PollEvents::Out); translate(POLLERR, PollEvents::Err); translate(POLLHUP, PollEvents::Hup); + translate(POLLNVAL, PollEvents::Nval); + translate(POLLRDNORM, PollEvents::RdNorm); + translate(POLLRDBAND, PollEvents::RdBand); + translate(POLLWRBAND, PollEvents::WrBand); UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); @@ -360,12 +480,51 @@ std::optional GetHostIPv4Address() { return {}; } - std::array ip_addr = {}; - ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) != - nullptr); return TranslateIPv4(network_interface->ip_address); } +std::string IPv4AddressToString(IPv4Address ip_addr) { + std::array buf = {}; + ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data()); + return std::string(buf.data()); +} + +u32 IPv4AddressToInteger(IPv4Address ip_addr) { + return static_cast(ip_addr[0]) << 24 | static_cast(ip_addr[1]) << 16 | + static_cast(ip_addr[2]) << 8 | static_cast(ip_addr[3]); +} + +Common::Expected, GetAddrInfoError> GetAddressInfo( + const std::string& host, const std::optional& service) { + addrinfo hints{}; + hints.ai_family = AF_INET; // Switch only supports IPv4. + addrinfo* addrinfo; + s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr, + &hints, &addrinfo); + if (gai_err != 0) { + return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err)); + } + std::vector ret; + for (auto* current = addrinfo; current; current = current->ai_next) { + // We should only get AF_INET results due to the hints value. + ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET && + addrinfo->ai_addrlen == sizeof(sockaddr_in), + continue;); + + AddrInfo& out = ret.emplace_back(); + out.family = TranslateDomainFromNative(current->ai_family); + out.socket_type = TranslateTypeFromNative(current->ai_socktype); + out.protocol = TranslateProtocolFromNative(current->ai_protocol); + out.addr = TranslateToSockAddrIn(*reinterpret_cast(current->ai_addr), + current->ai_addrlen); + if (current->ai_canonname != nullptr) { + out.canon_name = current->ai_canonname; + } + } + freeaddrinfo(addrinfo); + return ret; +} + std::pair Poll(std::vector& pollfds, s32 timeout) { const size_t num = pollfds.size(); @@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept { } template -Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { +std::pair Socket::GetSockOpt(SOCKET fd_so, int option) { + T value{}; + socklen_t len = sizeof(value); + const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast(&value), &len); + if (result != SOCKET_ERROR) { + ASSERT(len == sizeof(value)); + return {value, Errno::SUCCESS}; + } + return {value, GetAndLogLastError()}; +} + +template +Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) { const int result = - setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast(&value), sizeof(value)); + setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast(&value), sizeof(value)); if (result != SOCKET_ERROR) { return Errno::SUCCESS; } @@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { } Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { - fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); + fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type), + TranslateProtocolToNative(protocol)); if (fd != INVALID_SOCKET) { return Errno::SUCCESS; } @@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { } std::pair Socket::Accept() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - const SOCKET new_socket = accept(fd, &addr, &addrlen); + const SOCKET new_socket = accept(fd, reinterpret_cast(&addr), &addrlen); if (new_socket == INVALID_SOCKET) { return {AcceptResult{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - AcceptResult result{ .socket = std::make_unique(new_socket), - .sockaddr_in = TranslateToSockAddrIn(addr), + .sockaddr_in = TranslateToSockAddrIn(addr, addrlen), }; return {std::move(result), Errno::SUCCESS}; @@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) { } std::pair Socket::GetPeerName() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { + if (getpeername(fd, reinterpret_cast(&addr), &addrlen) == SOCKET_ERROR) { return {SockAddrIn{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; + return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; } std::pair Socket::GetSockName() { - sockaddr addr; + sockaddr_in addr; socklen_t addrlen = sizeof(addr); - if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { + if (getsockname(fd, reinterpret_cast(&addr), &addrlen) == SOCKET_ERROR) { return {SockAddrIn{}, GetAndLogLastError()}; } - ASSERT(addrlen == sizeof(sockaddr_in)); - return {TranslateToSockAddrIn(addr), Errno::SUCCESS}; + return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS}; } Errno Socket::Bind(SockAddrIn addr) { @@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) { return GetAndLogLastError(); } -std::pair Socket::Recv(int flags, std::vector& message) { +std::pair Socket::Recv(int flags, std::span message) { ASSERT(flags == 0); ASSERT(message.size() < static_cast(std::numeric_limits::max())); @@ -532,21 +700,20 @@ std::pair Socket::Recv(int flags, std::vector& message) { return {-1, GetAndLogLastError()}; } -std::pair Socket::RecvFrom(int flags, std::vector& message, SockAddrIn* addr) { +std::pair Socket::RecvFrom(int flags, std::span message, SockAddrIn* addr) { ASSERT(flags == 0); ASSERT(message.size() < static_cast(std::numeric_limits::max())); - sockaddr addr_in{}; + sockaddr_in addr_in{}; socklen_t addrlen = sizeof(addr_in); socklen_t* const p_addrlen = addr ? &addrlen : nullptr; - sockaddr* const p_addr_in = addr ? &addr_in : nullptr; + sockaddr* const p_addr_in = addr ? reinterpret_cast(&addr_in) : nullptr; const auto result = recvfrom(fd, reinterpret_cast(message.data()), static_cast(message.size()), 0, p_addr_in, p_addrlen); if (result != SOCKET_ERROR) { if (addr) { - ASSERT(addrlen == sizeof(addr_in)); - *addr = TranslateToSockAddrIn(addr_in); + *addr = TranslateToSockAddrIn(addr_in, addrlen); } return {static_cast(result), Errno::SUCCESS}; } @@ -597,6 +764,11 @@ Errno Socket::Close() { return Errno::SUCCESS; } +std::pair Socket::GetPendingError() { + auto [pending_err, getsockopt_err] = GetSockOpt(fd, SO_ERROR); + return {TranslateNativeError(pending_err), getsockopt_err}; +} + Errno Socket::SetLinger(bool enable, u32 linger) { return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); } diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h index 1bad3f6a0..9dcf39b7a 100755 --- a/src/core/internal_network/network.h +++ b/src/core/internal_network/network.h @@ -5,6 +5,7 @@ #include #include +#include #include "common/common_funcs.h" #include "common/common_types.h" @@ -16,6 +17,11 @@ #include #endif +namespace Common { +template +class Expected; +} + namespace Network { class SocketBase; @@ -36,6 +42,26 @@ enum class Errno { NETUNREACH, TIMEDOUT, MSGSIZE, + INPROGRESS, + OTHER, +}; + +enum class GetAddrInfoError { + SUCCESS, + ADDRFAMILY, + AGAIN, + BADFLAGS, + FAIL, + FAMILY, + MEMORY, + NODATA, + NONAME, + SERVICE, + SOCKTYPE, + SYSTEM, + BADHINTS, + PROTOCOL, + OVERFLOW_, OTHER, }; @@ -49,6 +75,9 @@ enum class PollEvents : u16 { Err = 1 << 3, Hup = 1 << 4, Nval = 1 << 5, + RdNorm = 1 << 6, + RdBand = 1 << 7, + WrBand = 1 << 8, }; DECLARE_ENUM_FLAG_OPERATORS(PollEvents); @@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) { /// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array std::optional GetHostIPv4Address(); +std::string IPv4AddressToString(IPv4Address ip_addr); +u32 IPv4AddressToInteger(IPv4Address ip_addr); + +// named to avoid name collision with Windows macro +Common::Expected, GetAddrInfoError> GetAddressInfo( + const std::string& host, const std::optional& service); + } // namespace Network diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp index cc4afe13f..02404155f 100755 --- a/src/core/internal_network/socket_proxy.cpp +++ b/src/core/internal_network/socket_proxy.cpp @@ -98,7 +98,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) { return Errno::SUCCESS; } -std::pair ProxySocket::Recv(int flags, std::vector& message) { +std::pair ProxySocket::Recv(int flags, std::span message) { LOG_WARNING(Network, "(STUBBED) called"); ASSERT(flags == 0); ASSERT(message.size() < static_cast(std::numeric_limits::max())); @@ -106,7 +106,7 @@ std::pair ProxySocket::Recv(int flags, std::vector& message) { return {static_cast(0), Errno::SUCCESS}; } -std::pair ProxySocket::RecvFrom(int flags, std::vector& message, SockAddrIn* addr) { +std::pair ProxySocket::RecvFrom(int flags, std::span message, SockAddrIn* addr) { ASSERT(flags == 0); ASSERT(message.size() < static_cast(std::numeric_limits::max())); @@ -140,8 +140,8 @@ std::pair ProxySocket::RecvFrom(int flags, std::vector& message, } } -std::pair ProxySocket::ReceivePacket(int flags, std::vector& message, - SockAddrIn* addr, std::size_t max_length) { +std::pair ProxySocket::ReceivePacket(int flags, std::span message, SockAddrIn* addr, + std::size_t max_length) { ProxyPacket& packet = received_packets.front(); if (addr) { addr->family = Domain::INET; @@ -153,10 +153,7 @@ std::pair ProxySocket::ReceivePacket(int flags, std::vector& mes std::size_t read_bytes; if (packet.data.size() > max_length) { read_bytes = max_length; - message.clear(); - std::copy(packet.data.begin(), packet.data.begin() + read_bytes, - std::back_inserter(message)); - message.resize(max_length); + memcpy(message.data(), packet.data.data(), max_length); if (protocol == Protocol::UDP) { if (!peek) { @@ -171,9 +168,7 @@ std::pair ProxySocket::ReceivePacket(int flags, std::vector& mes } } else { read_bytes = packet.data.size(); - message.clear(); - std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message)); - message.resize(max_length); + memcpy(message.data(), packet.data.data(), read_bytes); if (!peek) { received_packets.pop(); } @@ -293,6 +288,11 @@ Errno ProxySocket::SetNonBlock(bool enable) { return Errno::SUCCESS; } +std::pair ProxySocket::GetPendingError() { + LOG_DEBUG(Network, "(STUBBED) called"); + return {Errno::SUCCESS, Errno::SUCCESS}; +} + bool ProxySocket::IsOpened() const { return fd != INVALID_SOCKET; } diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h index 1c3e3dcb8..336747a57 100755 --- a/src/core/internal_network/socket_proxy.h +++ b/src/core/internal_network/socket_proxy.h @@ -39,11 +39,11 @@ public: Errno Shutdown(ShutdownHow how) override; - std::pair Recv(int flags, std::vector& message) override; + std::pair Recv(int flags, std::span message) override; - std::pair RecvFrom(int flags, std::vector& message, SockAddrIn* addr) override; + std::pair RecvFrom(int flags, std::span message, SockAddrIn* addr) override; - std::pair ReceivePacket(int flags, std::vector& message, SockAddrIn* addr, + std::pair ReceivePacket(int flags, std::span message, SockAddrIn* addr, std::size_t max_length); std::pair Send(std::span message, int flags) override; @@ -74,6 +74,8 @@ public: template Errno SetSockOpt(SOCKET fd, int option, T value); + std::pair GetPendingError() override; + bool IsOpened() const override; private: diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h index 1e0cda83f..389ac8ebd 100755 --- a/src/core/internal_network/sockets.h +++ b/src/core/internal_network/sockets.h @@ -59,10 +59,9 @@ public: virtual Errno Shutdown(ShutdownHow how) = 0; - virtual std::pair Recv(int flags, std::vector& message) = 0; + virtual std::pair Recv(int flags, std::span message) = 0; - virtual std::pair RecvFrom(int flags, std::vector& message, - SockAddrIn* addr) = 0; + virtual std::pair RecvFrom(int flags, std::span message, SockAddrIn* addr) = 0; virtual std::pair Send(std::span message, int flags) = 0; @@ -87,6 +86,8 @@ public: virtual Errno SetNonBlock(bool enable) = 0; + virtual std::pair GetPendingError() = 0; + virtual bool IsOpened() const = 0; virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; @@ -126,9 +127,9 @@ public: Errno Shutdown(ShutdownHow how) override; - std::pair Recv(int flags, std::vector& message) override; + std::pair Recv(int flags, std::span message) override; - std::pair RecvFrom(int flags, std::vector& message, SockAddrIn* addr) override; + std::pair RecvFrom(int flags, std::span message, SockAddrIn* addr) override; std::pair Send(std::span message, int flags) override; @@ -156,6 +157,11 @@ public: template Errno SetSockOpt(SOCKET fd, int option, T value); + std::pair GetPendingError() override; + + template + std::pair GetSockOpt(SOCKET fd, int option); + bool IsOpened() const override; void HandleProxyPacket(const ProxyPacket& packet) override; diff --git a/src/core/memory.cpp b/src/core/memory.cpp index b4390cd00..5bf92e9ce 100755 --- a/src/core/memory.cpp +++ b/src/core/memory.cpp @@ -266,6 +266,22 @@ struct Memory::Impl { ReadBlockImpl(*system.ApplicationProcess(), src_addr, dest_buffer, size); } + const u8* GetSpan(const VAddr src_addr, const std::size_t size) const { + if (current_page_table->blocks[src_addr >> YUZU_PAGEBITS] == + current_page_table->blocks[(src_addr + size) >> YUZU_PAGEBITS]) { + return GetPointerSilent(src_addr); + } + return nullptr; + } + + u8* GetSpan(const VAddr src_addr, const std::size_t size) { + if (current_page_table->blocks[src_addr >> YUZU_PAGEBITS] == + current_page_table->blocks[(src_addr + size) >> YUZU_PAGEBITS]) { + return GetPointerSilent(src_addr); + } + return nullptr; + } + template void WriteBlockImpl(const Kernel::KProcess& process, const Common::ProcessAddress dest_addr, const void* src_buffer, const std::size_t size) { @@ -559,7 +575,7 @@ struct Memory::Impl { } } - const Common::ProcessAddress end = base + size; + const auto end = base + size; ASSERT_MSG(end <= page_table.pointers.size(), "out of range mapping at {:016X}", base + page_table.pointers.size()); @@ -570,14 +586,18 @@ struct Memory::Impl { while (base != end) { page_table.pointers[base].Store(nullptr, type); page_table.backing_addr[base] = 0; - + page_table.blocks[base] = 0; base += 1; } } else { + auto orig_base = base; while (base != end) { - page_table.pointers[base].Store( - system.DeviceMemory().GetPointer(target) - (base << YUZU_PAGEBITS), type); - page_table.backing_addr[base] = GetInteger(target) - (base << YUZU_PAGEBITS); + auto host_ptr = + system.DeviceMemory().GetPointer(target) - (base << YUZU_PAGEBITS); + auto backing = GetInteger(target) - (base << YUZU_PAGEBITS); + page_table.pointers[base].Store(host_ptr, type); + page_table.backing_addr[base] = backing; + page_table.blocks[base] = orig_base << YUZU_PAGEBITS; ASSERT_MSG(page_table.pointers[base].Pointer(), "memory mapping base yield a nullptr within the table"); @@ -747,6 +767,14 @@ struct Memory::Impl { VAddr last_address; }; + void InvalidateRegion(Common::ProcessAddress dest_addr, size_t size) { + system.GPU().InvalidateRegion(GetInteger(dest_addr), size); + } + + void FlushRegion(Common::ProcessAddress dest_addr, size_t size) { + system.GPU().FlushRegion(GetInteger(dest_addr), size); + } + Core::System& system; Common::PageTable* current_page_table = nullptr; std::array @@ -881,6 +909,14 @@ void Memory::ReadBlockUnsafe(const Common::ProcessAddress src_addr, void* dest_b impl->ReadBlockUnsafe(src_addr, dest_buffer, size); } +const u8* Memory::GetSpan(const VAddr src_addr, const std::size_t size) const { + return impl->GetSpan(src_addr, size); +} + +u8* Memory::GetSpan(const VAddr src_addr, const std::size_t size) { + return impl->GetSpan(src_addr, size); +} + void Memory::WriteBlock(const Common::ProcessAddress dest_addr, const void* src_buffer, const std::size_t size) { impl->WriteBlock(dest_addr, src_buffer, size); @@ -924,4 +960,12 @@ void Memory::MarkRegionDebug(Common::ProcessAddress vaddr, u64 size, bool debug) impl->MarkRegionDebug(GetInteger(vaddr), size, debug); } +void Memory::InvalidateRegion(Common::ProcessAddress dest_addr, size_t size) { + impl->InvalidateRegion(dest_addr, size); +} + +void Memory::FlushRegion(Common::ProcessAddress dest_addr, size_t size) { + impl->FlushRegion(dest_addr, size); +} + } // namespace Core::Memory diff --git a/src/core/memory.h b/src/core/memory.h index 9558bda7c..183fed329 100755 --- a/src/core/memory.h +++ b/src/core/memory.h @@ -5,8 +5,12 @@ #include #include +#include #include #include +#include + +#include "common/scratch_buffer.h" #include "common/typed_address.h" #include "core/hle/result.h" @@ -24,6 +28,10 @@ class PhysicalMemory; class KProcess; } // namespace Kernel +namespace Tegra { +class MemoryManager; +} + namespace Core::Memory { /** @@ -343,6 +351,9 @@ public: */ void ReadBlockUnsafe(Common::ProcessAddress src_addr, void* dest_buffer, std::size_t size); + const u8* GetSpan(const VAddr src_addr, const std::size_t size) const; + u8* GetSpan(const VAddr src_addr, const std::size_t size); + /** * Writes a range of bytes into the current process' address space at the specified * virtual address. @@ -461,6 +472,8 @@ public: void MarkRegionDebug(Common::ProcessAddress vaddr, u64 size, bool debug); void SetGPUDirtyManagers(std::span managers); + void InvalidateRegion(Common::ProcessAddress dest_addr, size_t size); + void FlushRegion(Common::ProcessAddress dest_addr, size_t size); private: Core::System& system; @@ -469,4 +482,203 @@ private: std::unique_ptr impl; }; +enum GuestMemoryFlags : u32 { + Read = 1 << 0, + Write = 1 << 1, + Safe = 1 << 2, + Cached = 1 << 3, + + SafeRead = Read | Safe, + SafeWrite = Write | Safe, + SafeReadWrite = SafeRead | SafeWrite, + SafeReadCachedWrite = SafeReadWrite | Cached, + + UnsafeRead = Read, + UnsafeWrite = Write, + UnsafeReadWrite = UnsafeRead | UnsafeWrite, + UnsafeReadCachedWrite = UnsafeReadWrite | Cached, +}; + +namespace { +template +class GuestMemory { + using iterator = T*; + using const_iterator = const T*; + using value_type = T; + using element_type = T; + using iterator_category = std::contiguous_iterator_tag; + +public: + GuestMemory() = delete; + explicit GuestMemory(M& memory_, u64 addr_, std::size_t size_, + Common::ScratchBuffer* backup = nullptr) + : memory{memory_}, addr{addr_}, size{size_} { + static_assert(FLAGS & GuestMemoryFlags::Read || FLAGS & GuestMemoryFlags::Write); + if constexpr (FLAGS & GuestMemoryFlags::Read) { + Read(addr, size, backup); + } + } + + ~GuestMemory() = default; + + T* data() noexcept { + return data_span.data(); + } + + const T* data() const noexcept { + return data_span.data(); + } + + [[nodiscard]] T* begin() noexcept { + return data(); + } + + [[nodiscard]] const T* begin() const noexcept { + return data(); + } + + [[nodiscard]] T* end() noexcept { + return data() + size; + } + + [[nodiscard]] const T* end() const noexcept { + return data() + size; + } + + T& operator[](size_t index) noexcept { + return data_span[index]; + } + + const T& operator[](size_t index) const noexcept { + return data_span[index]; + } + + void SetAddressAndSize(u64 addr_, std::size_t size_) noexcept { + addr = addr_; + size = size_; + addr_changed = true; + } + + std::span Read(u64 addr_, std::size_t size_, + Common::ScratchBuffer* backup = nullptr) noexcept { + addr = addr_; + size = size_; + if (size == 0) { + is_data_copy = true; + return {}; + } + + if (TrySetSpan()) { + if constexpr (FLAGS & GuestMemoryFlags::Safe) { + memory.FlushRegion(addr, size * sizeof(T)); + } + } else { + if (backup) { + backup->resize_destructive(size); + data_span = *backup; + } else { + data_copy.resize(size); + data_span = std::span(data_copy); + } + is_data_copy = true; + span_valid = true; + if constexpr (FLAGS & GuestMemoryFlags::Safe) { + memory.ReadBlock(addr, data_span.data(), size * sizeof(T)); + } else { + memory.ReadBlockUnsafe(addr, data_span.data(), size * sizeof(T)); + } + } + return data_span; + } + + void Write(std::span write_data) noexcept { + if constexpr (FLAGS & GuestMemoryFlags::Cached) { + memory.WriteBlockCached(addr, write_data.data(), size * sizeof(T)); + } else if constexpr (FLAGS & GuestMemoryFlags::Safe) { + memory.WriteBlock(addr, write_data.data(), size * sizeof(T)); + } else { + memory.WriteBlockUnsafe(addr, write_data.data(), size * sizeof(T)); + } + } + + bool TrySetSpan() noexcept { + if (u8* ptr = memory.GetSpan(addr, size * sizeof(T)); ptr) { + data_span = {reinterpret_cast(ptr), size}; + span_valid = true; + return true; + } + return false; + } + +protected: + bool IsDataCopy() const noexcept { + return is_data_copy; + } + + bool AddressChanged() const noexcept { + return addr_changed; + } + + M& memory; + u64 addr; + size_t size; + std::span data_span{}; + std::vector data_copy; + bool span_valid{false}; + bool is_data_copy{false}; + bool addr_changed{false}; +}; + +template +class GuestMemoryScoped : public GuestMemory { +public: + GuestMemoryScoped() = delete; + explicit GuestMemoryScoped(M& memory_, u64 addr_, std::size_t size_, + Common::ScratchBuffer* backup = nullptr) + : GuestMemory(memory_, addr_, size_, backup) { + if constexpr (!(FLAGS & GuestMemoryFlags::Read)) { + if (!this->TrySetSpan()) { + if (backup) { + this->data_span = *backup; + this->span_valid = true; + this->is_data_copy = true; + } + } + } + } + + ~GuestMemoryScoped() { + if constexpr (FLAGS & GuestMemoryFlags::Write) { + if (this->size == 0) [[unlikely]] { + return; + } + + if (this->AddressChanged() || this->IsDataCopy()) { + ASSERT(this->span_valid); + if constexpr (FLAGS & GuestMemoryFlags::Cached) { + this->memory.WriteBlockCached(this->addr, this->data_span.data(), + this->size * sizeof(T)); + } else if constexpr (FLAGS & GuestMemoryFlags::Safe) { + this->memory.WriteBlock(this->addr, this->data_span.data(), + this->size * sizeof(T)); + } else { + this->memory.WriteBlockUnsafe(this->addr, this->data_span.data(), + this->size * sizeof(T)); + } + } else if constexpr (FLAGS & GuestMemoryFlags::Safe) { + this->memory.InvalidateRegion(this->addr, this->size * sizeof(T)); + } + } + } +}; +} // namespace + +template +using CpuGuestMemory = GuestMemory; +template +using CpuGuestMemoryScoped = GuestMemoryScoped; +template +using GpuGuestMemory = GuestMemory; +template +using GpuGuestMemoryScoped = GuestMemoryScoped; } // namespace Core::Memory diff --git a/src/video_core/buffer_cache/buffer_cache.h b/src/video_core/buffer_cache/buffer_cache.h index 7636c74b6..c2a1e7d82 100755 --- a/src/video_core/buffer_cache/buffer_cache.h +++ b/src/video_core/buffer_cache/buffer_cache.h @@ -234,9 +234,10 @@ bool BufferCache

::DMACopy(GPUVAddr src_address, GPUVAddr dest_address, u64 am if (has_new_downloads) { memory_tracker.MarkRegionAsGpuModified(*cpu_dest_address, amount); } - tmp_buffer.resize_destructive(amount); - cpu_memory.ReadBlockUnsafe(*cpu_src_address, tmp_buffer.data(), amount); - cpu_memory.WriteBlockUnsafe(*cpu_dest_address, tmp_buffer.data(), amount); + + Core::Memory::CpuGuestMemoryScoped tmp( + cpu_memory, *cpu_src_address, amount, &tmp_buffer); + tmp.SetAddressAndSize(*cpu_dest_address, amount); return true; } diff --git a/src/video_core/dma_pusher.cpp b/src/video_core/dma_pusher.cpp index a619dca76..51f836fd9 100755 --- a/src/video_core/dma_pusher.cpp +++ b/src/video_core/dma_pusher.cpp @@ -5,6 +5,7 @@ #include "common/microprofile.h" #include "common/settings.h" #include "core/core.h" +#include "core/memory.h" #include "video_core/dma_pusher.h" #include "video_core/engines/maxwell_3d.h" #include "video_core/gpu.h" @@ -12,6 +13,8 @@ namespace Tegra { +constexpr u32 MacroRegistersStart = 0xE00; + DmaPusher::DmaPusher(Core::System& system_, GPU& gpu_, MemoryManager& memory_manager_, Control::ChannelState& channel_state_) : gpu{gpu_}, system{system_}, memory_manager{memory_manager_}, puller{gpu_, memory_manager_, @@ -74,25 +77,16 @@ bool DmaPusher::Step() { } // Push buffer non-empty, read a word - command_headers.resize_destructive(command_list_header.size); - constexpr u32 MacroRegistersStart = 0xE00; - if (dma_state.method < MacroRegistersStart) { - if (Settings::IsGPULevelHigh()) { - memory_manager.ReadBlock(dma_state.dma_get, command_headers.data(), - command_list_header.size * sizeof(u32)); - } else { - memory_manager.ReadBlockUnsafe(dma_state.dma_get, command_headers.data(), - command_list_header.size * sizeof(u32)); - } - } else { - const size_t copy_size = command_list_header.size * sizeof(u32); + if (dma_state.method >= MacroRegistersStart) { if (subchannels[dma_state.subchannel]) { - subchannels[dma_state.subchannel]->current_dirty = - memory_manager.IsMemoryDirty(dma_state.dma_get, copy_size); + subchannels[dma_state.subchannel]->current_dirty = memory_manager.IsMemoryDirty( + dma_state.dma_get, command_list_header.size * sizeof(u32)); } - memory_manager.ReadBlockUnsafe(dma_state.dma_get, command_headers.data(), copy_size); } - ProcessCommands(command_headers); + Core::Memory::GpuGuestMemory + headers(memory_manager, dma_state.dma_get, command_list_header.size, &command_headers); + ProcessCommands(headers); } return true; diff --git a/src/video_core/engines/engine_upload.cpp b/src/video_core/engines/engine_upload.cpp index 545df54c4..9ed7e7327 100755 --- a/src/video_core/engines/engine_upload.cpp +++ b/src/video_core/engines/engine_upload.cpp @@ -5,6 +5,7 @@ #include "common/algorithm.h" #include "common/assert.h" +#include "core/memory.h" #include "video_core/engines/engine_upload.h" #include "video_core/memory_manager.h" #include "video_core/rasterizer_interface.h" @@ -46,15 +47,11 @@ void State::ProcessData(const u32* data, size_t num_data) { void State::ProcessData(std::span read_buffer) { const GPUVAddr address{regs.dest.Address()}; if (is_linear) { - if (regs.line_count == 1) { - rasterizer->AccelerateInlineToMemory(address, copy_size, read_buffer); - } else { - for (size_t line = 0; line < regs.line_count; ++line) { - const GPUVAddr dest_line = address + line * regs.dest.pitch; - std::span buffer(read_buffer.data() + line * regs.line_length_in, - regs.line_length_in); - rasterizer->AccelerateInlineToMemory(dest_line, regs.line_length_in, buffer); - } + for (size_t line = 0; line < regs.line_count; ++line) { + const GPUVAddr dest_line = address + line * regs.dest.pitch; + std::span buffer(read_buffer.data() + line * regs.line_length_in, + regs.line_length_in); + rasterizer->AccelerateInlineToMemory(dest_line, regs.line_length_in, buffer); } } else { u32 width = regs.dest.width; @@ -70,13 +67,14 @@ void State::ProcessData(std::span read_buffer) { const std::size_t dst_size = Tegra::Texture::CalculateSize( true, bytes_per_pixel, width, regs.dest.height, regs.dest.depth, regs.dest.BlockHeight(), regs.dest.BlockDepth()); - tmp_buffer.resize_destructive(dst_size); - memory_manager.ReadBlock(address, tmp_buffer.data(), dst_size); - Tegra::Texture::SwizzleSubrect(tmp_buffer, read_buffer, bytes_per_pixel, width, - regs.dest.height, regs.dest.depth, x_offset, regs.dest.y, - x_elements, regs.line_count, regs.dest.BlockHeight(), + + Core::Memory::GpuGuestMemoryScoped + tmp(memory_manager, address, dst_size, &tmp_buffer); + + Tegra::Texture::SwizzleSubrect(tmp, read_buffer, bytes_per_pixel, width, regs.dest.height, + regs.dest.depth, x_offset, regs.dest.y, x_elements, + regs.line_count, regs.dest.BlockHeight(), regs.dest.BlockDepth(), regs.line_length_in); - memory_manager.WriteBlockCached(address, tmp_buffer.data(), dst_size); } } diff --git a/src/video_core/engines/kepler_compute.cpp b/src/video_core/engines/kepler_compute.cpp index 7735ef1ea..e1de1042c 100755 --- a/src/video_core/engines/kepler_compute.cpp +++ b/src/video_core/engines/kepler_compute.cpp @@ -84,7 +84,6 @@ Texture::TICEntry KeplerCompute::GetTICEntry(u32 tic_index) const { Texture::TICEntry tic_entry; memory_manager.ReadBlockUnsafe(tic_address_gpu, &tic_entry, sizeof(Texture::TICEntry)); - return tic_entry; } diff --git a/src/video_core/engines/maxwell_3d.cpp b/src/video_core/engines/maxwell_3d.cpp index 3152f9aa2..0a0d1a3b0 100755 --- a/src/video_core/engines/maxwell_3d.cpp +++ b/src/video_core/engines/maxwell_3d.cpp @@ -9,6 +9,7 @@ #include "common/settings.h" #include "core/core.h" #include "core/core_timing.h" +#include "core/memory.h" #include "video_core/dirty_flags.h" #include "video_core/engines/draw_manager.h" #include "video_core/engines/maxwell_3d.h" @@ -679,17 +680,14 @@ void Maxwell3D::ProcessCBData(u32 value) { Texture::TICEntry Maxwell3D::GetTICEntry(u32 tic_index) const { const GPUVAddr tic_address_gpu{regs.tex_header.Address() + tic_index * sizeof(Texture::TICEntry)}; - Texture::TICEntry tic_entry; memory_manager.ReadBlockUnsafe(tic_address_gpu, &tic_entry, sizeof(Texture::TICEntry)); - return tic_entry; } Texture::TSCEntry Maxwell3D::GetTSCEntry(u32 tsc_index) const { const GPUVAddr tsc_address_gpu{regs.tex_sampler.Address() + tsc_index * sizeof(Texture::TSCEntry)}; - Texture::TSCEntry tsc_entry; memory_manager.ReadBlockUnsafe(tsc_address_gpu, &tsc_entry, sizeof(Texture::TSCEntry)); return tsc_entry; diff --git a/src/video_core/engines/maxwell_dma.cpp b/src/video_core/engines/maxwell_dma.cpp index 9cdff0cba..0cc78f614 100755 --- a/src/video_core/engines/maxwell_dma.cpp +++ b/src/video_core/engines/maxwell_dma.cpp @@ -7,6 +7,7 @@ #include "common/microprofile.h" #include "common/settings.h" #include "core/core.h" +#include "core/memory.h" #include "video_core/engines/maxwell_3d.h" #include "video_core/engines/maxwell_dma.h" #include "video_core/memory_manager.h" @@ -130,11 +131,12 @@ void MaxwellDMA::Launch() { UNIMPLEMENTED_IF(regs.offset_out % 16 != 0); read_buffer.resize_destructive(16); for (u32 offset = 0; offset < regs.line_length_in; offset += 16) { - memory_manager.ReadBlock( - convert_linear_2_blocklinear_addr(regs.offset_in + offset), - read_buffer.data(), read_buffer.size()); - memory_manager.WriteBlockCached(regs.offset_out + offset, read_buffer.data(), - read_buffer.size()); + Core::Memory::GpuGuestMemoryScoped< + u8, Core::Memory::GuestMemoryFlags::SafeReadCachedWrite> + tmp_write_buffer(memory_manager, + convert_linear_2_blocklinear_addr(regs.offset_in + offset), + 16, &read_buffer); + tmp_write_buffer.SetAddressAndSize(regs.offset_out + offset, 16); } } else if (is_src_pitch && !is_dst_pitch) { UNIMPLEMENTED_IF(regs.line_length_in % 16 != 0); @@ -142,20 +144,19 @@ void MaxwellDMA::Launch() { UNIMPLEMENTED_IF(regs.offset_out % 16 != 0); read_buffer.resize_destructive(16); for (u32 offset = 0; offset < regs.line_length_in; offset += 16) { - memory_manager.ReadBlock(regs.offset_in + offset, read_buffer.data(), - read_buffer.size()); - memory_manager.WriteBlockCached( - convert_linear_2_blocklinear_addr(regs.offset_out + offset), - read_buffer.data(), read_buffer.size()); + Core::Memory::GpuGuestMemoryScoped< + u8, Core::Memory::GuestMemoryFlags::SafeReadCachedWrite> + tmp_write_buffer(memory_manager, regs.offset_in + offset, 16, &read_buffer); + tmp_write_buffer.SetAddressAndSize( + convert_linear_2_blocklinear_addr(regs.offset_out + offset), 16); } } else { if (!accelerate.BufferCopy(regs.offset_in, regs.offset_out, regs.line_length_in)) { - read_buffer.resize_destructive(regs.line_length_in); - memory_manager.ReadBlock(regs.offset_in, read_buffer.data(), - regs.line_length_in, - VideoCommon::CacheType::NoBufferCache); - memory_manager.WriteBlockCached(regs.offset_out, read_buffer.data(), - regs.line_length_in); + Core::Memory::GpuGuestMemoryScoped< + u8, Core::Memory::GuestMemoryFlags::SafeReadCachedWrite> + tmp_write_buffer(memory_manager, regs.offset_in, regs.line_length_in, + &read_buffer); + tmp_write_buffer.SetAddressAndSize(regs.offset_out, regs.line_length_in); } } } @@ -222,17 +223,15 @@ void MaxwellDMA::CopyBlockLinearToPitch() { CalculateSize(true, bytes_per_pixel, width, height, depth, block_height, block_depth); const size_t dst_size = dst_operand.pitch * regs.line_count; - read_buffer.resize_destructive(src_size); - write_buffer.resize_destructive(dst_size); - memory_manager.ReadBlock(src_operand.address, read_buffer.data(), src_size); - memory_manager.ReadBlock(dst_operand.address, write_buffer.data(), dst_size); + Core::Memory::GpuGuestMemory tmp_read_buffer( + memory_manager, src_operand.address, src_size, &read_buffer); + Core::Memory::GpuGuestMemoryScoped + tmp_write_buffer(memory_manager, dst_operand.address, dst_size, &write_buffer); - UnswizzleSubrect(write_buffer, read_buffer, bytes_per_pixel, width, height, depth, x_offset, - src_params.origin.y, x_elements, regs.line_count, block_height, block_depth, - dst_operand.pitch); - - memory_manager.WriteBlockCached(regs.offset_out, write_buffer.data(), dst_size); + UnswizzleSubrect(tmp_write_buffer, tmp_read_buffer, bytes_per_pixel, width, height, depth, + x_offset, src_params.origin.y, x_elements, regs.line_count, block_height, + block_depth, dst_operand.pitch); } void MaxwellDMA::CopyPitchToBlockLinear() { @@ -287,18 +286,17 @@ void MaxwellDMA::CopyPitchToBlockLinear() { CalculateSize(true, bytes_per_pixel, width, height, depth, block_height, block_depth); const size_t src_size = static_cast(regs.pitch_in) * regs.line_count; - read_buffer.resize_destructive(src_size); - write_buffer.resize_destructive(dst_size); + GPUVAddr src_addr = regs.offset_in; + GPUVAddr dst_addr = regs.offset_out; + Core::Memory::GpuGuestMemory tmp_read_buffer( + memory_manager, src_addr, src_size, &read_buffer); + Core::Memory::GpuGuestMemoryScoped + tmp_write_buffer(memory_manager, dst_addr, dst_size, &write_buffer); - memory_manager.ReadBlock(regs.offset_in, read_buffer.data(), src_size); - memory_manager.ReadBlockUnsafe(regs.offset_out, write_buffer.data(), dst_size); - - // If the input is linear and the output is tiled, swizzle the input and copy it over. - SwizzleSubrect(write_buffer, read_buffer, bytes_per_pixel, width, height, depth, x_offset, - dst_params.origin.y, x_elements, regs.line_count, block_height, block_depth, - regs.pitch_in); - - memory_manager.WriteBlockCached(regs.offset_out, write_buffer.data(), dst_size); + // If the input is linear and the output is tiled, swizzle the input and copy it over. + SwizzleSubrect(tmp_write_buffer, tmp_read_buffer, bytes_per_pixel, width, height, depth, + x_offset, dst_params.origin.y, x_elements, regs.line_count, block_height, + block_depth, regs.pitch_in); } void MaxwellDMA::CopyBlockLinearToBlockLinear() { @@ -342,23 +340,20 @@ void MaxwellDMA::CopyBlockLinearToBlockLinear() { const u32 pitch = x_elements * bytes_per_pixel; const size_t mid_buffer_size = pitch * regs.line_count; - read_buffer.resize_destructive(src_size); - write_buffer.resize_destructive(dst_size); - intermediate_buffer.resize_destructive(mid_buffer_size); - memory_manager.ReadBlock(regs.offset_in, read_buffer.data(), src_size); - memory_manager.ReadBlock(regs.offset_out, write_buffer.data(), dst_size); + Core::Memory::GpuGuestMemory tmp_read_buffer( + memory_manager, regs.offset_in, src_size, &read_buffer); + Core::Memory::GpuGuestMemoryScoped + tmp_write_buffer(memory_manager, regs.offset_out, dst_size, &write_buffer); - UnswizzleSubrect(intermediate_buffer, read_buffer, bytes_per_pixel, src_width, src.height, + UnswizzleSubrect(intermediate_buffer, tmp_read_buffer, bytes_per_pixel, src_width, src.height, src.depth, src_x_offset, src.origin.y, x_elements, regs.line_count, src.block_size.height, src.block_size.depth, pitch); - SwizzleSubrect(write_buffer, intermediate_buffer, bytes_per_pixel, dst_width, dst.height, + SwizzleSubrect(tmp_write_buffer, intermediate_buffer, bytes_per_pixel, dst_width, dst.height, dst.depth, dst_x_offset, dst.origin.y, x_elements, regs.line_count, dst.block_size.height, dst.block_size.depth, pitch); - - memory_manager.WriteBlockCached(regs.offset_out, write_buffer.data(), dst_size); } void MaxwellDMA::ReleaseSemaphore() { diff --git a/src/video_core/engines/sw_blitter/blitter.cpp b/src/video_core/engines/sw_blitter/blitter.cpp index ff88cd03d..3a599f466 100755 --- a/src/video_core/engines/sw_blitter/blitter.cpp +++ b/src/video_core/engines/sw_blitter/blitter.cpp @@ -159,11 +159,11 @@ bool SoftwareBlitEngine::Blit(Fermi2D::Surface& src, Fermi2D::Surface& dst, const auto src_bytes_per_pixel = BytesPerBlock(PixelFormatFromRenderTargetFormat(src.format)); const auto dst_bytes_per_pixel = BytesPerBlock(PixelFormatFromRenderTargetFormat(dst.format)); const size_t src_size = get_surface_size(src, src_bytes_per_pixel); - impl->tmp_buffer.resize_destructive(src_size); - memory_manager.ReadBlock(src.Address(), impl->tmp_buffer.data(), src_size); + + Core::Memory::GpuGuestMemory tmp_buffer( + memory_manager, src.Address(), src_size, &impl->tmp_buffer); const size_t src_copy_size = src_extent_x * src_extent_y * src_bytes_per_pixel; - const size_t dst_copy_size = dst_extent_x * dst_extent_y * dst_bytes_per_pixel; impl->src_buffer.resize_destructive(src_copy_size); @@ -200,12 +200,11 @@ bool SoftwareBlitEngine::Blit(Fermi2D::Surface& src, Fermi2D::Surface& dst, impl->dst_buffer.resize_destructive(dst_copy_size); if (src.linear == Fermi2D::MemoryLayout::BlockLinear) { - UnswizzleSubrect(impl->src_buffer, impl->tmp_buffer, src_bytes_per_pixel, src.width, - src.height, src.depth, config.src_x0, config.src_y0, src_extent_x, - src_extent_y, src.block_height, src.block_depth, - src_extent_x * src_bytes_per_pixel); + UnswizzleSubrect(impl->src_buffer, tmp_buffer, src_bytes_per_pixel, src.width, src.height, + src.depth, config.src_x0, config.src_y0, src_extent_x, src_extent_y, + src.block_height, src.block_depth, src_extent_x * src_bytes_per_pixel); } else { - process_pitch_linear(false, impl->tmp_buffer, impl->src_buffer, src_extent_x, src_extent_y, + process_pitch_linear(false, tmp_buffer, impl->src_buffer, src_extent_x, src_extent_y, src.pitch, config.src_x0, config.src_y0, src_bytes_per_pixel); } @@ -221,20 +220,18 @@ bool SoftwareBlitEngine::Blit(Fermi2D::Surface& src, Fermi2D::Surface& dst, } const size_t dst_size = get_surface_size(dst, dst_bytes_per_pixel); - impl->tmp_buffer.resize_destructive(dst_size); - memory_manager.ReadBlock(dst.Address(), impl->tmp_buffer.data(), dst_size); + Core::Memory::GpuGuestMemoryScoped + tmp_buffer2(memory_manager, dst.Address(), dst_size, &impl->tmp_buffer); if (dst.linear == Fermi2D::MemoryLayout::BlockLinear) { - SwizzleSubrect(impl->tmp_buffer, impl->dst_buffer, dst_bytes_per_pixel, dst.width, - dst.height, dst.depth, config.dst_x0, config.dst_y0, dst_extent_x, - dst_extent_y, dst.block_height, dst.block_depth, - dst_extent_x * dst_bytes_per_pixel); + SwizzleSubrect(tmp_buffer2, impl->dst_buffer, dst_bytes_per_pixel, dst.width, dst.height, + dst.depth, config.dst_x0, config.dst_y0, dst_extent_x, dst_extent_y, + dst.block_height, dst.block_depth, dst_extent_x * dst_bytes_per_pixel); } else { - process_pitch_linear(true, impl->dst_buffer, impl->tmp_buffer, dst_extent_x, dst_extent_y, + process_pitch_linear(true, impl->dst_buffer, tmp_buffer2, dst_extent_x, dst_extent_y, dst.pitch, config.dst_x0, config.dst_y0, static_cast(dst_bytes_per_pixel)); } - memory_manager.WriteBlock(dst.Address(), impl->tmp_buffer.data(), dst_size); return true; } diff --git a/src/video_core/memory_manager.cpp b/src/video_core/memory_manager.cpp index 064714b9b..0cce535ca 100755 --- a/src/video_core/memory_manager.cpp +++ b/src/video_core/memory_manager.cpp @@ -10,13 +10,13 @@ #include "core/device_memory.h" #include "core/hle/kernel/k_page_table.h" #include "core/hle/kernel/k_process.h" -#include "core/memory.h" #include "video_core/invalidation_accumulator.h" #include "video_core/memory_manager.h" #include "video_core/rasterizer_interface.h" #include "video_core/renderer_base.h" namespace Tegra { +using Core::Memory::GuestMemoryFlags; std::atomic MemoryManager::unique_identifier_generator{}; @@ -587,13 +587,10 @@ void MemoryManager::InvalidateRegion(GPUVAddr gpu_addr, size_t size, void MemoryManager::CopyBlock(GPUVAddr gpu_dest_addr, GPUVAddr gpu_src_addr, std::size_t size, VideoCommon::CacheType which) { - tmp_buffer.resize_destructive(size); - ReadBlock(gpu_src_addr, tmp_buffer.data(), size, which); - - // The output block must be flushed in case it has data modified from the GPU. - // Fixes NPC geometry in Zombie Panic in Wonderland DX + Core::Memory::GpuGuestMemoryScoped data( + *this, gpu_src_addr, size); + data.SetAddressAndSize(gpu_dest_addr, size); FlushRegion(gpu_dest_addr, size, which); - WriteBlock(gpu_dest_addr, tmp_buffer.data(), size, which); } bool MemoryManager::IsGranularRange(GPUVAddr gpu_addr, std::size_t size) const { @@ -758,4 +755,23 @@ void MemoryManager::FlushCaching() { accumulator->Clear(); } +const u8* MemoryManager::GetSpan(const GPUVAddr src_addr, const std::size_t size) const { + auto cpu_addr = GpuToCpuAddress(src_addr); + if (cpu_addr) { + return memory.GetSpan(*cpu_addr, size); + } + return nullptr; +} + +u8* MemoryManager::GetSpan(const GPUVAddr src_addr, const std::size_t size) { + if (!IsContinuousRange(src_addr, size)) { + return nullptr; + } + auto cpu_addr = GpuToCpuAddress(src_addr); + if (cpu_addr) { + return memory.GetSpan(*cpu_addr, size); + } + return nullptr; +} + } // namespace Tegra diff --git a/src/video_core/memory_manager.h b/src/video_core/memory_manager.h index 51831570f..cfa9f3878 100755 --- a/src/video_core/memory_manager.h +++ b/src/video_core/memory_manager.h @@ -15,6 +15,7 @@ #include "common/range_map.h" #include "common/scratch_buffer.h" #include "common/virtual_buffer.h" +#include "core/memory.h" #include "video_core/cache_types.h" #include "video_core/pte_kind.h" @@ -62,6 +63,20 @@ public: [[nodiscard]] u8* GetPointer(GPUVAddr addr); [[nodiscard]] const u8* GetPointer(GPUVAddr addr) const; + template + [[nodiscard]] T* GetPointer(GPUVAddr addr) { + const auto address{GpuToCpuAddress(addr)}; + if (!address) { + return {}; + } + return memory.GetPointer(*address); + } + + template + [[nodiscard]] const T* GetPointer(GPUVAddr addr) const { + return GetPointer(addr); + } + /** * ReadBlock and WriteBlock are full read and write operations over virtual * GPU Memory. It's important to use these when GPU memory may not be continuous @@ -139,6 +154,9 @@ public: void FlushCaching(); + const u8* GetSpan(const GPUVAddr src_addr, const std::size_t size) const; + u8* GetSpan(const GPUVAddr src_addr, const std::size_t size); + private: template inline void MemoryOperation(GPUVAddr gpu_src_addr, std::size_t size, FuncMapped&& func_mapped, diff --git a/src/video_core/texture_cache/texture_cache.h b/src/video_core/texture_cache/texture_cache.h index 9a3cdeafa..54f2db2de 100755 --- a/src/video_core/texture_cache/texture_cache.h +++ b/src/video_core/texture_cache/texture_cache.h @@ -8,6 +8,7 @@ #include "common/alignment.h" #include "common/settings.h" +#include "core/memory.h" #include "video_core/control/channel_state.h" #include "video_core/dirty_flags.h" #include "video_core/engines/kepler_compute.h" @@ -1022,19 +1023,19 @@ void TextureCache

::UploadImageContents(Image& image, StagingBuffer& staging) runtime.AccelerateImageUpload(image, staging, uploads); return; } - const size_t guest_size_bytes = image.guest_size_bytes; - swizzle_data_buffer.resize_destructive(guest_size_bytes); - gpu_memory->ReadBlockUnsafe(gpu_addr, swizzle_data_buffer.data(), guest_size_bytes); + + Core::Memory::GpuGuestMemory swizzle_data( + *gpu_memory, gpu_addr, image.guest_size_bytes, &swizzle_data_buffer); if (True(image.flags & ImageFlagBits::Converted)) { unswizzle_data_buffer.resize_destructive(image.unswizzled_size_bytes); - auto copies = UnswizzleImage(*gpu_memory, gpu_addr, image.info, swizzle_data_buffer, - unswizzle_data_buffer); + auto copies = + UnswizzleImage(*gpu_memory, gpu_addr, image.info, swizzle_data, unswizzle_data_buffer); ConvertImage(unswizzle_data_buffer, image.info, mapped_span, copies); image.UploadMemory(staging, copies); } else { const auto copies = - UnswizzleImage(*gpu_memory, gpu_addr, image.info, swizzle_data_buffer, mapped_span); + UnswizzleImage(*gpu_memory, gpu_addr, image.info, swizzle_data, mapped_span); image.UploadMemory(staging, copies); } } @@ -1227,11 +1228,12 @@ void TextureCache

::QueueAsyncDecode(Image& image, ImageId image_id) { decode->image_id = image_id; async_decodes.push_back(std::move(decode)); - Common::ScratchBuffer local_unswizzle_data_buffer(image.unswizzled_size_bytes); - const size_t guest_size_bytes = image.guest_size_bytes; - swizzle_data_buffer.resize_destructive(guest_size_bytes); - gpu_memory->ReadBlockUnsafe(image.gpu_addr, swizzle_data_buffer.data(), guest_size_bytes); - auto copies = UnswizzleImage(*gpu_memory, image.gpu_addr, image.info, swizzle_data_buffer, + static Common::ScratchBuffer local_unswizzle_data_buffer; + local_unswizzle_data_buffer.resize_destructive(image.unswizzled_size_bytes); + Core::Memory::GpuGuestMemory swizzle_data( + *gpu_memory, image.gpu_addr, image.guest_size_bytes, &swizzle_data_buffer); + + auto copies = UnswizzleImage(*gpu_memory, image.gpu_addr, image.info, swizzle_data, local_unswizzle_data_buffer); const size_t out_size = MapSizeBytes(image); diff --git a/src/video_core/texture_cache/util.cpp b/src/video_core/texture_cache/util.cpp index d230a38a2..45daeee97 100755 --- a/src/video_core/texture_cache/util.cpp +++ b/src/video_core/texture_cache/util.cpp @@ -20,6 +20,7 @@ #include "common/div_ceil.h" #include "common/scratch_buffer.h" #include "common/settings.h" +#include "core/memory.h" #include "video_core/compatible_formats.h" #include "video_core/engines/maxwell_3d.h" #include "video_core/memory_manager.h" @@ -544,17 +545,15 @@ void SwizzleBlockLinearImage(Tegra::MemoryManager& gpu_memory, GPUVAddr gpu_addr tile_size.height, info.tile_width_spacing); const size_t subresource_size = sizes[level]; - tmp_buffer.resize_destructive(subresource_size); - const std::span dst(tmp_buffer); - for (s32 layer = 0; layer < info.resources.layers; ++layer) { const std::span src = input.subspan(host_offset); - gpu_memory.ReadBlockUnsafe(gpu_addr + guest_offset, dst.data(), dst.size_bytes()); + { + Core::Memory::GpuGuestMemoryScoped + dst(gpu_memory, gpu_addr + guest_offset, subresource_size, &tmp_buffer); - SwizzleTexture(dst, src, bytes_per_block, num_tiles.width, num_tiles.height, - num_tiles.depth, block.height, block.depth); - - gpu_memory.WriteBlockUnsafe(gpu_addr + guest_offset, dst.data(), dst.size_bytes()); + SwizzleTexture(dst, src, bytes_per_block, num_tiles.width, num_tiles.height, + num_tiles.depth, block.height, block.depth); + } host_offset += host_bytes_per_layer; guest_offset += layer_stride; @@ -837,6 +836,7 @@ boost::container::small_vector UnswizzleImage(Tegra::Memory const Extent3D size = info.size; if (info.type == ImageType::Linear) { + ASSERT(output.size_bytes() >= guest_size_bytes); gpu_memory.ReadBlockUnsafe(gpu_addr, output.data(), guest_size_bytes); ASSERT((info.pitch >> bpp_log2) << bpp_log2 == info.pitch); @@ -904,16 +904,6 @@ boost::container::small_vector UnswizzleImage(Tegra::Memory return copies; } -BufferCopy UploadBufferCopy(Tegra::MemoryManager& gpu_memory, GPUVAddr gpu_addr, - const ImageBase& image, std::span output) { - gpu_memory.ReadBlockUnsafe(gpu_addr, output.data(), image.guest_size_bytes); - return BufferCopy{ - .src_offset = 0, - .dst_offset = 0, - .size = image.guest_size_bytes, - }; -} - void ConvertImage(std::span input, const ImageInfo& info, std::span output, std::span copies) { u32 output_offset = 0; diff --git a/src/video_core/texture_cache/util.h b/src/video_core/texture_cache/util.h index a7315196c..a0332387f 100755 --- a/src/video_core/texture_cache/util.h +++ b/src/video_core/texture_cache/util.h @@ -66,9 +66,6 @@ struct OverlapResult { Tegra::MemoryManager& gpu_memory, GPUVAddr gpu_addr, const ImageInfo& info, std::span input, std::span output); -[[nodiscard]] BufferCopy UploadBufferCopy(Tegra::MemoryManager& gpu_memory, GPUVAddr gpu_addr, - const ImageBase& image, std::span output); - void ConvertImage(std::span input, const ImageInfo& info, std::span output, std::span copies); diff --git a/src/yuzu/main.cpp b/src/yuzu/main.cpp index d52876c60..ebc3a614f 100755 --- a/src/yuzu/main.cpp +++ b/src/yuzu/main.cpp @@ -101,6 +101,7 @@ static FileSys::VirtualFile VfsDirectoryCreateFileWrapper(const FileSys::Virtual #include "common/settings.h" #include "common/telemetry.h" #include "core/core.h" +#include "core/core_timing.h" #include "core/crypto/key_manager.h" #include "core/file_sys/card_image.h" #include "core/file_sys/common_funcs.h" @@ -389,6 +390,7 @@ GMainWindow::GMainWindow(std::unique_ptr config_, bool has_broken_vulkan std::chrono::duration_cast>( Common::Windows::SetCurrentTimerResolutionToMaximum()) .count()); + system->CoreTiming().SetTimerResolutionNs(Common::Windows::GetCurrentTimerResolution()); #endif UpdateWindowTitle(); diff --git a/src/yuzu_cmd/yuzu.cpp b/src/yuzu_cmd/yuzu.cpp index 04c2fd0ef..fb2e8e01a 100755 --- a/src/yuzu_cmd/yuzu.cpp +++ b/src/yuzu_cmd/yuzu.cpp @@ -21,6 +21,7 @@ #include "common/string_util.h" #include "common/telemetry.h" #include "core/core.h" +#include "core/core_timing.h" #include "core/cpu_manager.h" #include "core/crypto/key_manager.h" #include "core/file_sys/registered_cache.h" @@ -316,8 +317,6 @@ int main(int argc, char** argv) { #ifdef _WIN32 LocalFree(argv_w); - - Common::Windows::SetCurrentTimerResolutionToMaximum(); #endif MicroProfileOnThreadCreate("EmuThread"); @@ -351,6 +350,11 @@ int main(int argc, char** argv) { break; } +#ifdef _WIN32 + Common::Windows::SetCurrentTimerResolutionToMaximum(); + system.CoreTiming().SetTimerResolutionNs(Common::Windows::GetCurrentTimerResolution()); +#endif + system.SetContentProvider(std::make_unique()); system.SetFilesystem(std::make_shared()); system.GetFileSystemController().CreateFactories(*system.GetFilesystem());