diff --git a/app/src/main/cpp/skyline/kernel/svc.cpp b/app/src/main/cpp/skyline/kernel/svc.cpp index 15182800..d061d1da 100644 --- a/app/src/main/cpp/skyline/kernel/svc.cpp +++ b/app/src/main/cpp/skyline/kernel/svc.cpp @@ -1205,11 +1205,8 @@ namespace skyline::kernel::svc { return; } - enum class ArbitrationType : u32 { - WaitIfLessThan = 0, - DecrementAndWaitIfLessThan = 1, - WaitIfEqual = 2, - } arbitrationType{static_cast(static_cast(state.ctx->gpr.w1))}; + using ArbitrationType = type::KProcess::ArbitrationType; + auto arbitrationType{static_cast(static_cast(state.ctx->gpr.w1))}; u32 value{state.ctx->gpr.w2}; i64 timeout{static_cast(state.ctx->gpr.x3)}; @@ -1217,28 +1214,17 @@ namespace skyline::kernel::svc { switch (arbitrationType) { case ArbitrationType::WaitIfLessThan: Logger::Debug("Waiting on 0x{:X} if less than {} for {}ns", address, value, timeout); - result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) { - return *address < value; - }); + result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::WaitIfLessThan); break; case ArbitrationType::DecrementAndWaitIfLessThan: Logger::Debug("Waiting on and decrementing 0x{:X} if less than {} for {}ns", address, value, timeout); - result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) { - u32 userValue{__atomic_load_n(address, __ATOMIC_SEQ_CST)}; - do { - if (value <= userValue) [[unlikely]] // We want to explicitly decrement **after** the check - return false; - } while (!__atomic_compare_exchange_n(address, &userValue, userValue - 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)); - return true; - }); + result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::DecrementAndWaitIfLessThan); break; case ArbitrationType::WaitIfEqual: Logger::Debug("Waiting on 0x{:X} if equal to {} for {}ns", address, value, timeout); - result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) { - return *address == value; - }); + result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::WaitIfEqual); break; default: @@ -1267,11 +1253,8 @@ namespace skyline::kernel::svc { return; } - enum class SignalType : u32 { - Signal = 0, - SignalAndIncrementIfEqual = 1, - SignalAndModifyBasedOnWaitingThreadCountIfEqual = 2, - } signalType{static_cast(static_cast(state.ctx->gpr.w1))}; + using SignalType = type::KProcess::SignalType; + auto signalType{static_cast(static_cast(state.ctx->gpr.w1))}; u32 value{state.ctx->gpr.w2}; i32 count{static_cast(state.ctx->gpr.w3)}; @@ -1279,21 +1262,17 @@ namespace skyline::kernel::svc { switch (signalType) { case SignalType::Signal: Logger::Debug("Signalling 0x{:X} for {} waiters", address, count); - result = state.process->SignalToAddress(address, value, count); + result = state.process->SignalToAddress(address, value, count, SignalType::Signal); break; case SignalType::SignalAndIncrementIfEqual: Logger::Debug("Signalling 0x{:X} and incrementing if equal to {} for {} waiters", address, value, count); - result = state.process->SignalToAddress(address, value, count, [](u32 *address, u32 value, u32) { - return __atomic_compare_exchange_n(address, &value, value + 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); - }); + result = state.process->SignalToAddress(address, value, count, SignalType::SignalAndIncrementIfEqual); break; case SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual: Logger::Debug("Signalling 0x{:X} and setting to waiting thread count if equal to {} for {} waiters", address, value, count); - result = state.process->SignalToAddress(address, value, count, [](u32 *address, u32 value, u32 waiterCount) { - return __atomic_compare_exchange_n(address, &value, waiterCount, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); - }); + result = state.process->SignalToAddress(address, value, count, SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual); break; default: diff --git a/app/src/main/cpp/skyline/kernel/types/KProcess.cpp b/app/src/main/cpp/skyline/kernel/types/KProcess.cpp index 88f2a33c..d73d2565 100644 --- a/app/src/main/cpp/skyline/kernel/types/KProcess.cpp +++ b/app/src/main/cpp/skyline/kernel/types/KProcess.cpp @@ -54,7 +54,7 @@ namespace skyline::kernel::type { u8 *KProcess::AllocateTlsSlot() { std::scoped_lock lock{tlsMutex}; u8 *slot; - for (auto &tlsPage: tlsPages) + for (auto &tlsPage : tlsPages) if ((slot = tlsPage->ReserveSlot())) return slot; @@ -268,13 +268,32 @@ namespace skyline::kernel::type { __atomic_store_n(key, false, __ATOMIC_SEQ_CST); // We need to update the boolean flag denoting that there are no more threads waiting on this conditional variable } - Result KProcess::WaitForAddress(u32 *address, u32 value, i64 timeout, bool (*arbitrationFunction)(u32 *, u32)) { + Result KProcess::WaitForAddress(u32 *address, u32 value, i64 timeout, ArbitrationType type) { TRACE_EVENT_FMT("kernel", "WaitForAddress 0x{:X}", address); { std::scoped_lock lock{syncWaiterMutex}; - if (!arbitrationFunction(address, value)) [[unlikely]] - return result::InvalidState; + + switch (type) { + case ArbitrationType::WaitIfLessThan: + if (*address >= value) [[unlikely]] + return result::InvalidState; + break; + + case ArbitrationType::DecrementAndWaitIfLessThan: { + u32 userValue{__atomic_load_n(address, __ATOMIC_SEQ_CST)}; + do { + if (value <= userValue) [[unlikely]] // We want to explicitly decrement **after** the check + return result::InvalidState; + } while (!__atomic_compare_exchange_n(address, &userValue, userValue - 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)); + break; + } + + case ArbitrationType::WaitIfEqual: + if (*address != value) [[unlikely]] + return result::InvalidState; + break; + } auto queue{syncWaiters.equal_range(address)}; syncWaiters.insert(std::upper_bound(queue.first, queue.second, state.thread->priority.load(), [](const i8 priority, const SyncWaiters::value_type &it) { return it.second->priority > priority; }), {address, state.thread}); @@ -303,15 +322,36 @@ namespace skyline::kernel::type { return {}; } - Result KProcess::SignalToAddress(u32 *address, u32 value, i32 amount, bool(*mutateFunction)(u32 *address, u32 value, u32 waiterCount)) { + Result KProcess::SignalToAddress(u32 *address, u32 value, i32 amount, SignalType type) { TRACE_EVENT_FMT("kernel", "SignalToAddress 0x{:X}", address); std::scoped_lock lock{syncWaiterMutex}; auto queue{syncWaiters.equal_range(address)}; - if (mutateFunction) - if (!mutateFunction(address, value, (amount <= 0) ? 0 : std::min(static_cast(std::distance(queue.first, queue.second) - amount), 0U))) [[unlikely]] + if (type != SignalType::Signal) { + u32 newValue{value}; + if (type == SignalType::SignalAndIncrementIfEqual) { + newValue++; + } else if (type == SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual) { + if (amount <= 0) { + if (queue.first != queue.second) + newValue -= 2; + else + newValue++; + } else { + if (queue.first != queue.second) { + i32 waiterCount{static_cast(std::distance(queue.first, queue.second))}; + if (waiterCount < amount) + newValue--; + } else { + newValue++; + } + } + } + + if (!__atomic_compare_exchange_n(address, &value, newValue, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)) [[unlikely]] return result::InvalidState; + } i32 waiterCount{amount}; for (auto it{queue.first}; it != queue.second && (amount <= 0 || waiterCount); it = syncWaiters.erase(it), waiterCount--) diff --git a/app/src/main/cpp/skyline/kernel/types/KProcess.h b/app/src/main/cpp/skyline/kernel/types/KProcess.h index 893da215..049812af 100644 --- a/app/src/main/cpp/skyline/kernel/types/KProcess.h +++ b/app/src/main/cpp/skyline/kernel/types/KProcess.h @@ -230,15 +230,27 @@ namespace skyline { */ void ConditionalVariableSignal(u32 *key, i32 amount); - /** - * @brief Waits on the supplied address with the specified arbitration function - */ - Result WaitForAddress(u32 *address, u32 value, i64 timeout, bool(*arbitrationFunction)(u32 *address, u32 value)); + enum class ArbitrationType : u32 { + WaitIfLessThan = 0, + DecrementAndWaitIfLessThan = 1, + WaitIfEqual = 2, + }; /** - * @brief Signals a variable amount of waiters at the supplied address + * @brief Waits on the supplied address with the specified arbitration type */ - Result SignalToAddress(u32 *address, u32 value, i32 amount, bool(*mutateFunction)(u32 *address, u32 value, u32 waiterCount) = nullptr); + Result WaitForAddress(u32 *address, u32 value, i64 timeout, ArbitrationType type); + + enum class SignalType : u32 { + Signal = 0, + SignalAndIncrementIfEqual = 1, + SignalAndModifyBasedOnWaitingThreadCountIfEqual = 2, + }; + + /** + * @brief Signals a variable for amount of waiters at the supplied address with the specified signal type + */ + Result SignalToAddress(u32 *address, u32 value, i32 amount, SignalType type); }; } }