diff --git a/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs b/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs index 8b739f66e..d57ca481d 100644 --- a/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs +++ b/Ryujinx.HLE/HOS/Kernel/Common/KernelTransfer.cs @@ -1,5 +1,6 @@ using Ryujinx.Cpu; using Ryujinx.HLE.HOS.Kernel.Process; +using System; namespace Ryujinx.HLE.HOS.Kernel.Common { @@ -22,7 +23,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Common return false; } - public static bool UserToKernelInt32Array(KernelContext context, ulong address, int[] values) + public static bool UserToKernelInt32Array(KernelContext context, ulong address, Span values) { KProcess currentProcess = context.Scheduler.GetCurrentProcess(); diff --git a/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs b/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs index fba22fc11..b6d2caf27 100644 --- a/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs +++ b/Ryujinx.HLE/HOS/Kernel/SupervisorCall/Syscall.cs @@ -8,6 +8,7 @@ using Ryujinx.HLE.HOS.Kernel.Ipc; using Ryujinx.HLE.HOS.Kernel.Memory; using Ryujinx.HLE.HOS.Kernel.Process; using Ryujinx.HLE.HOS.Kernel.Threading; +using System; using System.Collections.Generic; using System.Threading; @@ -2139,30 +2140,84 @@ namespace Ryujinx.HLE.HOS.Kernel.SupervisorCall { handleIndex = 0; - if ((uint)handlesCount > 0x40) + if ((uint)handlesCount > KThread.MaxWaitSyncObjects) { return KernelResult.MaximumExceeded; } - List syncObjs = new List(); + KThread currentThread = _context.Scheduler.GetCurrentThread(); - KProcess process = _context.Scheduler.GetCurrentProcess(); + var syncObjs = new Span(currentThread.WaitSyncObjects).Slice(0, handlesCount); + + if (handlesCount != 0) + { + KProcess currentProcess = _context.Scheduler.GetCurrentProcess(); + + if (currentProcess.MemoryManager.AddrSpaceStart > handlesPtr) + { + return KernelResult.UserCopyFailed; + } + + long handlesSize = handlesCount * 4; + + if (handlesPtr + (ulong)handlesSize <= handlesPtr) + { + return KernelResult.UserCopyFailed; + } + + if (handlesPtr + (ulong)handlesSize - 1 > currentProcess.MemoryManager.AddrSpaceEnd - 1) + { + return KernelResult.UserCopyFailed; + } + + Span handles = new Span(currentThread.WaitSyncHandles).Slice(0, handlesCount); + + if (!KernelTransfer.UserToKernelInt32Array(_context, handlesPtr, handles)) + { + return KernelResult.UserCopyFailed; + } + + int processedHandles = 0; + + for (; processedHandles < handlesCount; processedHandles++) + { + KSynchronizationObject syncObj = currentProcess.HandleTable.GetObject(handles[processedHandles]); + + if (syncObj == null) + { + break; + } + + syncObjs[processedHandles] = syncObj; + + syncObj.IncrementReferenceCount(); + } + + if (processedHandles != handlesCount) + { + // One or more handles are invalid. + for (int index = 0; index < processedHandles; index++) + { + currentThread.WaitSyncObjects[index].DecrementReferenceCount(); + } + + return KernelResult.InvalidHandle; + } + } + + KernelResult result = _context.Synchronization.WaitFor(syncObjs, timeout, out handleIndex); + + if (result == KernelResult.PortRemoteClosed) + { + result = KernelResult.Success; + } for (int index = 0; index < handlesCount; index++) { - int handle = process.CpuMemory.Read(handlesPtr + (ulong)index * 4); - - KSynchronizationObject syncObj = process.HandleTable.GetObject(handle); - - if (syncObj == null) - { - break; - } - - syncObjs.Add(syncObj); + currentThread.WaitSyncObjects[index].DecrementReferenceCount(); } - return _context.Synchronization.WaitFor(syncObjs.ToArray(), timeout, out handleIndex); + return result; } public KernelResult CancelSynchronization(int handle) diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs index fa9b669ea..22610b220 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs @@ -1,4 +1,5 @@ using Ryujinx.HLE.HOS.Kernel.Common; +using System; using System.Collections.Generic; namespace Ryujinx.HLE.HOS.Kernel.Threading @@ -12,7 +13,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading _context = context; } - public KernelResult WaitFor(KSynchronizationObject[] syncObjs, long timeout, out int handleIndex) + public KernelResult WaitFor(Span syncObjs, long timeout, out int handleIndex) { handleIndex = 0; diff --git a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs index 754a1e530..d4603178f 100644 --- a/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs +++ b/Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs @@ -12,6 +12,8 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading { class KThread : KSynchronizationObject, IKFutureSchedulerObject { + public const int MaxWaitSyncObjects = 64; + private int _hostThreadRunning; public Thread HostThread { get; private set; } @@ -39,6 +41,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading public ulong TlsAddress => _tlsAddress; public ulong TlsDramAddress { get; private set; } + public KSynchronizationObject[] WaitSyncObjects { get; } + public int[] WaitSyncHandles { get; } + public long LastScheduledTime { get; set; } public LinkedListNode[] SiblingsPerCore { get; private set; } @@ -96,6 +101,9 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading _scheduler = KernelContext.Scheduler; _schedulingData = KernelContext.Scheduler.SchedulingData; + WaitSyncObjects = new KSynchronizationObject[MaxWaitSyncObjects]; + WaitSyncHandles = new int[MaxWaitSyncObjects]; + SiblingsPerCore = new LinkedListNode[KScheduler.CpuCoresCount]; _mutexWaiters = new LinkedList();