From 3c05d892e6a2c35b921009f1b7b75bbac365d8a1 Mon Sep 17 00:00:00 2001 From: Michael Fabian 'Xaymar' Dirks Date: Mon, 26 Apr 2021 01:21:36 +0200 Subject: [PATCH] filter/nvidia-face-tracking: Update to `nvidia::cuda::obs` --- source/filters/filter-nv-face-tracking.cpp | 42 +++++----------------- source/filters/filter-nv-face-tracking.hpp | 15 +++----- 2 files changed, 14 insertions(+), 43 deletions(-) diff --git a/source/filters/filter-nv-face-tracking.cpp b/source/filters/filter-nv-face-tracking.cpp index e806f6ef..96275186 100644 --- a/source/filters/filter-nv-face-tracking.cpp +++ b/source/filters/filter-nv-face-tracking.cpp @@ -53,8 +53,7 @@ face_tracking_instance::face_tracking_instance(obs_data_t* settings, obs_source_ _geometry(), _filters(), _values(), - _cuda(face_tracking_factory::get()->get_cuda()), _cuda_ctx(face_tracking_factory::get()->get_cuda_context()), - _cuda_stream(), + _cuda(::nvidia::cuda::obs::get()), _cuda_stream(), _ar_library(face_tracking_factory::get()->get_ar()), _ar_loaded(false), _ar_feature(), _ar_is_tracking(false), _ar_bboxes_confidence(), _ar_bboxes_data(), _ar_bboxes(), _ar_texture(), _ar_texture_cuda_fresh(false), @@ -76,7 +75,7 @@ face_tracking_instance::face_tracking_instance(obs_data_t* settings, obs_source_ auto gctx = gs::context{}; _rt = std::make_shared(GS_RGBA, GS_ZS_NONE); _geometry = std::make_shared(uint32_t(4), uint8_t(1)); - auto cctx = std::make_shared<::nvidia::cuda::context_stack>(_cuda_ctx); + auto cctx = _cuda->get_context()->enter(); _cuda_stream = std::make_shared<::nvidia::cuda::stream>(::nvidia::cuda::stream_flags::NON_BLOCKING, 0); } @@ -138,7 +137,7 @@ void face_tracking_instance::async_initialize(std::shared_ptr ptr) // Update the current CUDA context for working. gs::context gctx; - auto cctx = std::make_shared<::nvidia::cuda::context_stack>(_cuda_ctx); + auto cctx = _cuda->get_context()->enter(); // Create Face Detection feature. { @@ -266,7 +265,7 @@ void face_tracking_instance::async_track(std::shared_ptr ptr) gs::context gctx{}; // Update the current CUDA context for working. - auto cctx = std::make_shared<::nvidia::cuda::context_stack>(_cuda_ctx); + auto cctx = _cuda->get_context()->enter(); // Refresh any now broken buffers. if (!_ar_texture_cuda_fresh) { @@ -291,7 +290,7 @@ void face_tracking_instance::async_track(std::shared_ptr ptr) NVCV_INTERLEAVED, NVCV_CUDA, 0); // Synchronize Streams. - _cuda->cuStreamSynchronize(_cuda_stream->get()); + _cuda_stream->synchronize(); // Finally set the input object. if (NvCV_Status res = _ar_library->set_object(_ar_feature.get(), NvAR_Parameter_Input(Image), @@ -327,7 +326,7 @@ void face_tracking_instance::async_track(std::shared_ptr ptr) mc.width_in_bytes = static_cast(_ar_image.pitch); mc.height = _ar_image.height; - if (::nvidia::cuda::result res = _cuda->cuMemcpy2DAsync(&mc, _cuda_stream->get()); + if (::nvidia::cuda::result res = _cuda->get_cuda()->cuMemcpy2DAsync(&mc, _cuda_stream->get()); res != ::nvidia::cuda::result::SUCCESS) { DLOG_ERROR("<%s> Failed to prepare buffers for tracking.", obs_source_get_name(_self)); return; @@ -347,8 +346,8 @@ void face_tracking_instance::async_track(std::shared_ptr ptr) } // Synchronize Streams. - _cuda->cuStreamSynchronize(_cuda_stream->get()); - _cuda->cuCtxSynchronize(); + _cuda_stream->synchronize(); + _cuda->get_context()->synchronize(); } { // Track any faces. @@ -604,24 +603,11 @@ bool face_tracking_instance::button_profile(obs_properties_t* props, obs_propert face_tracking_factory::face_tracking_factory() { // Try and load CUDA. - _cuda = ::nvidia::cuda::cuda::get(); + _cuda = ::nvidia::cuda::obs::get(); // Try and load AR. _ar = std::make_shared<::nvidia::ar::ar>(); - // Initialize CUDA - { - auto gctx = gs::context{}; -#ifdef WIN32 - if (gs_get_device_type() == GS_DEVICE_DIRECT3D_11) { - _cuda_ctx = std::make_shared<::nvidia::cuda::context>(reinterpret_cast(gs_get_device_obj())); - } -#endif - if (gs_get_device_type() == GS_DEVICE_OPENGL) { - throw std::runtime_error("OpenGL not supported."); - } - } - // Info _info.id = PREFIX "filter-nvidia-face-tracking"; _info.type = OBS_SOURCE_TYPE_FILTER; @@ -693,16 +679,6 @@ obs_properties_t* face_tracking_factory::get_properties2(face_tracking_instance* return pr; } -std::shared_ptr<::nvidia::cuda::cuda> face_tracking_factory::get_cuda() -{ - return _cuda; -} - -std::shared_ptr<::nvidia::cuda::context> face_tracking_factory::get_cuda_context() -{ - return _cuda_ctx; -} - std::shared_ptr<::nvidia::ar::ar> face_tracking_factory::get_ar() { return _ar; diff --git a/source/filters/filter-nv-face-tracking.hpp b/source/filters/filter-nv-face-tracking.hpp index 77d5eafd..3b334e47 100644 --- a/source/filters/filter-nv-face-tracking.hpp +++ b/source/filters/filter-nv-face-tracking.hpp @@ -31,6 +31,7 @@ #include "nvidia/cuda/nvidia-cuda-context.hpp" #include "nvidia/cuda/nvidia-cuda-gs-texture.hpp" #include "nvidia/cuda/nvidia-cuda-memory.hpp" +#include "nvidia/cuda/nvidia-cuda-obs.hpp" #include "nvidia/cuda/nvidia-cuda-stream.hpp" #include "nvidia/cuda/nvidia-cuda.hpp" @@ -62,9 +63,8 @@ namespace streamfx::filter::nvidia { } _values; // Nvidia CUDA interop - std::shared_ptr<::nvidia::cuda::cuda> _cuda; - std::shared_ptr<::nvidia::cuda::context> _cuda_ctx; - std::shared_ptr<::nvidia::cuda::stream> _cuda_stream; + std::shared_ptr<::nvidia::cuda::obs> _cuda; + std::shared_ptr<::nvidia::cuda::stream> _cuda_stream; // Nvidia AR interop std::shared_ptr<::nvidia::ar::ar> _ar_library; @@ -129,9 +129,8 @@ namespace streamfx::filter::nvidia { class face_tracking_factory : public obs::source_factory { - std::shared_ptr<::nvidia::cuda::cuda> _cuda; - std::shared_ptr<::nvidia::cuda::context> _cuda_ctx; - std::shared_ptr<::nvidia::ar::ar> _ar; + std::shared_ptr<::nvidia::cuda::obs> _cuda; + std::shared_ptr<::nvidia::ar::ar> _ar; public: face_tracking_factory(); @@ -143,10 +142,6 @@ namespace streamfx::filter::nvidia { virtual obs_properties_t* get_properties2(filter::nvidia::face_tracking_instance* data) override; - std::shared_ptr<::nvidia::cuda::cuda> get_cuda(); - - std::shared_ptr<::nvidia::cuda::context> get_cuda_context(); - std::shared_ptr<::nvidia::ar::ar> get_ar(); public: // Singleton