Fix and refactor SVC SignalToAddress/WaitForAddress

SVC `SignalToAddress` had a bug with the behavior of `SignalAndModifyBasedOnWaitingThreadCountIfEqual` which was entirely incorrect and led to deadlocks in titles such as ARMS that were dependent on it. This commit corrects the behavior and refactors both SVCs and moves their arbitration/waiting to inside the corresponding `KProcess` function rather than the SVC to avoid redundancies and improve code readability.
This commit is contained in:
PixelyIon 2022-05-05 17:58:50 +05:30
parent 396979e897
commit 37327f1955
3 changed files with 75 additions and 44 deletions

View File

@ -1205,11 +1205,8 @@ namespace skyline::kernel::svc {
return;
}
enum class ArbitrationType : u32 {
WaitIfLessThan = 0,
DecrementAndWaitIfLessThan = 1,
WaitIfEqual = 2,
} arbitrationType{static_cast<ArbitrationType>(static_cast<u32>(state.ctx->gpr.w1))};
using ArbitrationType = type::KProcess::ArbitrationType;
auto arbitrationType{static_cast<ArbitrationType>(static_cast<u32>(state.ctx->gpr.w1))};
u32 value{state.ctx->gpr.w2};
i64 timeout{static_cast<i64>(state.ctx->gpr.x3)};
@ -1217,28 +1214,17 @@ namespace skyline::kernel::svc {
switch (arbitrationType) {
case ArbitrationType::WaitIfLessThan:
Logger::Debug("Waiting on 0x{:X} if less than {} for {}ns", address, value, timeout);
result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) {
return *address < value;
});
result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::WaitIfLessThan);
break;
case ArbitrationType::DecrementAndWaitIfLessThan:
Logger::Debug("Waiting on and decrementing 0x{:X} if less than {} for {}ns", address, value, timeout);
result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) {
u32 userValue{__atomic_load_n(address, __ATOMIC_SEQ_CST)};
do {
if (value <= userValue) [[unlikely]] // We want to explicitly decrement **after** the check
return false;
} while (!__atomic_compare_exchange_n(address, &userValue, userValue - 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST));
return true;
});
result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::DecrementAndWaitIfLessThan);
break;
case ArbitrationType::WaitIfEqual:
Logger::Debug("Waiting on 0x{:X} if equal to {} for {}ns", address, value, timeout);
result = state.process->WaitForAddress(address, value, timeout, [](u32 *address, u32 value) {
return *address == value;
});
result = state.process->WaitForAddress(address, value, timeout, ArbitrationType::WaitIfEqual);
break;
default:
@ -1267,11 +1253,8 @@ namespace skyline::kernel::svc {
return;
}
enum class SignalType : u32 {
Signal = 0,
SignalAndIncrementIfEqual = 1,
SignalAndModifyBasedOnWaitingThreadCountIfEqual = 2,
} signalType{static_cast<SignalType>(static_cast<u32>(state.ctx->gpr.w1))};
using SignalType = type::KProcess::SignalType;
auto signalType{static_cast<SignalType>(static_cast<u32>(state.ctx->gpr.w1))};
u32 value{state.ctx->gpr.w2};
i32 count{static_cast<i32>(state.ctx->gpr.w3)};
@ -1279,21 +1262,17 @@ namespace skyline::kernel::svc {
switch (signalType) {
case SignalType::Signal:
Logger::Debug("Signalling 0x{:X} for {} waiters", address, count);
result = state.process->SignalToAddress(address, value, count);
result = state.process->SignalToAddress(address, value, count, SignalType::Signal);
break;
case SignalType::SignalAndIncrementIfEqual:
Logger::Debug("Signalling 0x{:X} and incrementing if equal to {} for {} waiters", address, value, count);
result = state.process->SignalToAddress(address, value, count, [](u32 *address, u32 value, u32) {
return __atomic_compare_exchange_n(address, &value, value + 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
});
result = state.process->SignalToAddress(address, value, count, SignalType::SignalAndIncrementIfEqual);
break;
case SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual:
Logger::Debug("Signalling 0x{:X} and setting to waiting thread count if equal to {} for {} waiters", address, value, count);
result = state.process->SignalToAddress(address, value, count, [](u32 *address, u32 value, u32 waiterCount) {
return __atomic_compare_exchange_n(address, &value, waiterCount, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST);
});
result = state.process->SignalToAddress(address, value, count, SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual);
break;
default:

View File

@ -54,7 +54,7 @@ namespace skyline::kernel::type {
u8 *KProcess::AllocateTlsSlot() {
std::scoped_lock lock{tlsMutex};
u8 *slot;
for (auto &tlsPage: tlsPages)
for (auto &tlsPage : tlsPages)
if ((slot = tlsPage->ReserveSlot()))
return slot;
@ -268,13 +268,32 @@ namespace skyline::kernel::type {
__atomic_store_n(key, false, __ATOMIC_SEQ_CST); // We need to update the boolean flag denoting that there are no more threads waiting on this conditional variable
}
Result KProcess::WaitForAddress(u32 *address, u32 value, i64 timeout, bool (*arbitrationFunction)(u32 *, u32)) {
Result KProcess::WaitForAddress(u32 *address, u32 value, i64 timeout, ArbitrationType type) {
TRACE_EVENT_FMT("kernel", "WaitForAddress 0x{:X}", address);
{
std::scoped_lock lock{syncWaiterMutex};
if (!arbitrationFunction(address, value)) [[unlikely]]
return result::InvalidState;
switch (type) {
case ArbitrationType::WaitIfLessThan:
if (*address >= value) [[unlikely]]
return result::InvalidState;
break;
case ArbitrationType::DecrementAndWaitIfLessThan: {
u32 userValue{__atomic_load_n(address, __ATOMIC_SEQ_CST)};
do {
if (value <= userValue) [[unlikely]] // We want to explicitly decrement **after** the check
return result::InvalidState;
} while (!__atomic_compare_exchange_n(address, &userValue, userValue - 1, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST));
break;
}
case ArbitrationType::WaitIfEqual:
if (*address != value) [[unlikely]]
return result::InvalidState;
break;
}
auto queue{syncWaiters.equal_range(address)};
syncWaiters.insert(std::upper_bound(queue.first, queue.second, state.thread->priority.load(), [](const i8 priority, const SyncWaiters::value_type &it) { return it.second->priority > priority; }), {address, state.thread});
@ -303,15 +322,36 @@ namespace skyline::kernel::type {
return {};
}
Result KProcess::SignalToAddress(u32 *address, u32 value, i32 amount, bool(*mutateFunction)(u32 *address, u32 value, u32 waiterCount)) {
Result KProcess::SignalToAddress(u32 *address, u32 value, i32 amount, SignalType type) {
TRACE_EVENT_FMT("kernel", "SignalToAddress 0x{:X}", address);
std::scoped_lock lock{syncWaiterMutex};
auto queue{syncWaiters.equal_range(address)};
if (mutateFunction)
if (!mutateFunction(address, value, (amount <= 0) ? 0 : std::min(static_cast<u32>(std::distance(queue.first, queue.second) - amount), 0U))) [[unlikely]]
if (type != SignalType::Signal) {
u32 newValue{value};
if (type == SignalType::SignalAndIncrementIfEqual) {
newValue++;
} else if (type == SignalType::SignalAndModifyBasedOnWaitingThreadCountIfEqual) {
if (amount <= 0) {
if (queue.first != queue.second)
newValue -= 2;
else
newValue++;
} else {
if (queue.first != queue.second) {
i32 waiterCount{static_cast<i32>(std::distance(queue.first, queue.second))};
if (waiterCount < amount)
newValue--;
} else {
newValue++;
}
}
}
if (!__atomic_compare_exchange_n(address, &value, newValue, false, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST)) [[unlikely]]
return result::InvalidState;
}
i32 waiterCount{amount};
for (auto it{queue.first}; it != queue.second && (amount <= 0 || waiterCount); it = syncWaiters.erase(it), waiterCount--)

View File

@ -230,15 +230,27 @@ namespace skyline {
*/
void ConditionalVariableSignal(u32 *key, i32 amount);
/**
* @brief Waits on the supplied address with the specified arbitration function
*/
Result WaitForAddress(u32 *address, u32 value, i64 timeout, bool(*arbitrationFunction)(u32 *address, u32 value));
enum class ArbitrationType : u32 {
WaitIfLessThan = 0,
DecrementAndWaitIfLessThan = 1,
WaitIfEqual = 2,
};
/**
* @brief Signals a variable amount of waiters at the supplied address
* @brief Waits on the supplied address with the specified arbitration type
*/
Result SignalToAddress(u32 *address, u32 value, i32 amount, bool(*mutateFunction)(u32 *address, u32 value, u32 waiterCount) = nullptr);
Result WaitForAddress(u32 *address, u32 value, i64 timeout, ArbitrationType type);
enum class SignalType : u32 {
Signal = 0,
SignalAndIncrementIfEqual = 1,
SignalAndModifyBasedOnWaitingThreadCountIfEqual = 2,
};
/**
* @brief Signals a variable for amount of waiters at the supplied address with the specified signal type
*/
Result SignalToAddress(u32 *address, u32 value, i32 amount, SignalType type);
};
}
}