Implement NCE Memory Trapping API

An API for trapping accesses to guest memory and performing callbacks based on those accesses alongside managing protection of the memory. This is a fundamental building block for avoiding redundant synchronization of resources from the guest and host.

Note: All accesses are treated as write accesses at the moment, support for picking up read accesses will be implemented later
This commit is contained in:
PixelyIon 2022-03-06 20:16:01 +05:30
parent 08510d75b0
commit 3e33d49faf
2 changed files with 197 additions and 4 deletions

View File

@ -75,6 +75,12 @@ namespace skyline::nce {
if (*tls) { // If TLS was restored then this occurred in guest code
auto &mctx{ctx->uc_mcontext};
const auto &state{*reinterpret_cast<ThreadContext *>(*tls)->state};
if (signal == SIGSEGV && info->si_code == SEGV_ACCERR)
// If we get a guest access violation then we want to handle any accesses that may be from a trapped region
if (state.nce->TrapHandler(reinterpret_cast<u8 *>(info->si_addr), true))
return;
if (signal != SIGINT) {
signal::StackFrame topFrame{.lr = reinterpret_cast<void *>(ctx->uc_mcontext.pc), .next = reinterpret_cast<signal::StackFrame *>(ctx->uc_mcontext.regs[29])};
std::string trace{state.loader->GetStackTrace(&topFrame)};
@ -84,7 +90,7 @@ namespace skyline::nce {
cpuContext += fmt::format("\n Fault Address: 0x{:X}", mctx.fault_address);
if (mctx.sp)
cpuContext += fmt::format("\n Stack Pointer: 0x{:X}", mctx.sp);
for (u8 index{}; index < (sizeof(mcontext_t::regs) / sizeof(u64)); index += 2)
for (size_t index{}; index < (sizeof(mcontext_t::regs) / sizeof(u64)); index += 2)
cpuContext += fmt::format("\n X{:<2}: 0x{:<16X} X{:<2}: 0x{:X}", index, mctx.regs[index], index + 1, mctx.regs[index + 1]);
Logger::Error("Thread #{} has crashed due to signal: {}\nStack Trace:{}\nCPU Context:{}", state.thread->id, strsignal(signal), trace, cpuContext);
@ -106,7 +112,7 @@ namespace skyline::nce {
static std::ifstream status("/proc/self/status");
status.seekg(0);
constexpr std::string_view TracerPidTag = "TracerPid:";
constexpr std::string_view TracerPidTag{"TracerPid:"};
for (std::string line; std::getline(status, line);) {
if (line.starts_with(TracerPidTag)) {
line = line.substr(TracerPidTag.size());
@ -355,4 +361,128 @@ namespace skyline::nce {
}
}
}
NCE::CallbackEntry::CallbackEntry(TrapProtection protection, NCE::TrapCallback readCallback, NCE::TrapCallback writeCallback) : protection(protection), readCallback(std::move(readCallback)), writeCallback(std::move(writeCallback)) {}
void NCE::ReprotectIntervals(const std::vector<TrapMap::Interval> &intervals, TrapProtection protection) {
auto reprotectIntervalsWithFunction = [&intervals](auto getProtection) {
for (auto region : intervals) {
region = region.Align(PAGE_SIZE);
mprotect(region.start, region.Size(), getProtection(region));
}
};
// We need to determine the lowest protection possible for the given interval
if (protection == TrapProtection::None) {
reprotectIntervalsWithFunction([&](auto region) {
auto entries{trapMap.GetRange(region)};
TrapProtection lowestProtection{TrapProtection::None};
for (const auto &entry : entries) {
auto entryProtection{entry.get().protection};
if (entryProtection > lowestProtection) {
lowestProtection = entryProtection;
if (entryProtection == TrapProtection::ReadWrite)
return PROT_EXEC;
}
}
switch (lowestProtection) {
case TrapProtection::None:
return PROT_READ | PROT_WRITE | PROT_EXEC;
case TrapProtection::WriteOnly:
return PROT_READ | PROT_EXEC;
case TrapProtection::ReadWrite:
return PROT_EXEC;
}
});
} else if (protection == TrapProtection::WriteOnly) {
reprotectIntervalsWithFunction([&](auto region) {
auto entries{trapMap.GetRange(region)};
for (const auto &entry : entries)
if (entry.get().protection == TrapProtection::ReadWrite)
return PROT_EXEC;
return PROT_READ | PROT_EXEC;
});
} else {
reprotectIntervalsWithFunction([&](auto region) {
return PROT_EXEC; // No checks are needed as this is already the highest level of protection
});
}
}
bool NCE::TrapHandler(u8 *address, bool write) {
std::scoped_lock lock(trapMutex);
// Check if we have a callback for this address
auto[entries, intervals]{trapMap.GetAlignedRecursiveRange<PAGE_SIZE>(address)};
if (entries.empty())
return false;
// Do callbacks for every entry in the intervals
if (write) {
for (auto entryRef : entries) {
auto &entry{entryRef.get()};
if (entry.protection == TrapProtection::None)
// We don't need to do the callback if the entry doesn't require any protection already
continue;
entry.writeCallback();
entry.protection = TrapProtection::None; // We don't need to protect this entry anymore
}
} else {
bool allNone{true}; // If all entries require no protection, we can protect to allow all accesses
for (auto entryRef : entries) {
auto &entry{entryRef.get()};
if (entry.protection < TrapProtection::ReadWrite) {
// We don't need to do the callback if the entry can already handle read accesses
allNone = allNone && entry.protection == TrapProtection::None;
continue;
}
entry.readCallback();
entry.protection = TrapProtection::WriteOnly; // We only need to trap writes to this entry
}
write = allNone;
}
int permission{PROT_READ | (write ? PROT_WRITE : 0) | PROT_EXEC};
for (const auto &interval : intervals)
// Reprotect the interval to the lowest protection level that the callbacks performed allow
mprotect(interval.start, interval.Size(), permission);
return true;
}
constexpr NCE::TrapHandle::TrapHandle(const TrapMap::GroupHandle &handle) : TrapMap::GroupHandle(handle) {}
NCE::TrapHandle NCE::TrapRegions(span<span<u8>> regions, bool writeOnly, const TrapCallback &readCallback, const TrapCallback &writeCallback) {
std::scoped_lock lock(trapMutex);
auto protection{writeOnly ? TrapProtection::WriteOnly : TrapProtection::ReadWrite};
TrapHandle handle{trapMap.Insert(regions, CallbackEntry{protection, readCallback, writeCallback})};
ReprotectIntervals(handle->intervals, protection);
return handle;
}
void NCE::RetrapRegions(TrapHandle handle, bool writeOnly) {
std::scoped_lock lock(trapMutex);
auto protection{writeOnly ? TrapProtection::WriteOnly : TrapProtection::ReadWrite};
handle->value.protection = protection;
ReprotectIntervals(handle->intervals, protection);
}
void NCE::RemoveTrap(TrapHandle handle) {
std::scoped_lock lock(trapMutex);
handle->value.protection = TrapProtection::None;
ReprotectIntervals(handle->intervals, TrapProtection::None);
}
void NCE::DeleteTrap(TrapHandle handle) {
std::scoped_lock lock(trapMutex);
handle->value.protection = TrapProtection::None;
ReprotectIntervals(handle->intervals, TrapProtection::None);
trapMap.Remove(handle);
}
}

View File

@ -3,8 +3,9 @@
#pragma once
#include "common.h"
#include <sys/wait.h>
#include "common.h"
#include "common/interval_map.h"
namespace skyline::nce {
/**
@ -14,6 +15,35 @@ namespace skyline::nce {
private:
const DeviceState &state;
/**
* @brief The level of protection that is required for a callback entry
*/
enum class TrapProtection {
None = 0, //!< No protection is required
WriteOnly = 1, //!< Only write protection is required
ReadWrite = 2, //!< Both read and write protection are required
};
using TrapCallback = std::function<void()>;
struct CallbackEntry {
TrapProtection protection; //!< The least restrictive protection that this callback needs to have
TrapCallback readCallback, writeCallback;
CallbackEntry(TrapProtection protection, NCE::TrapCallback readCallback, NCE::TrapCallback writeCallback);
};
std::mutex trapMutex; //!< Synchronizes the accesses to the trap map
using TrapMap = IntervalMap<u8*, CallbackEntry>;
TrapMap trapMap; //!< A map of all intervals and corresponding callbacks that have been registered
/**
* @brief Reprotects the intervals to the least restrictive protection given the supplied protection
*/
void ReprotectIntervals(const std::vector<TrapMap::Interval>& intervals, TrapProtection protection);
bool TrapHandler(u8* address, bool write);
static void SvcHandler(u16 svcId, ThreadContext *ctx);
public:
@ -26,7 +56,7 @@ namespace skyline::nce {
ExitException(bool killAllThreads = true);
virtual const char* what() const noexcept;
virtual const char *what() const noexcept;
};
/**
@ -48,5 +78,38 @@ namespace skyline::nce {
* @param patch A pointer to the .patch section which should be exactly patchSize in size and located before the .text section
*/
static void PatchCode(std::vector<u8> &text, u32 *patch, size_t patchSize, const std::vector<size_t> &offsets);
/**
* @brief An opaque handle to a group of trapped region
*/
class TrapHandle : private TrapMap::GroupHandle {
constexpr TrapHandle(const TrapMap::GroupHandle &handle);
friend NCE;
};
/**
* @brief Traps a region of guest memory with a callback for when an access to it has been made
* @param writeOnly If the trap is optimally for write-only accesses initially, this is not guarenteed
* @note The handle **must** be deleted using DeleteTrap before the NCE instance is destroyed
* @note It is UB to supply a region of host memory rather than guest memory
*/
TrapHandle TrapRegions(span<span<u8>> regions, bool writeOnly, const TrapCallback& readCallback, const TrapCallback& writeCallback);
/**
* @brief Re-traps a region of memory after protections were removed
* @param writeOnly If the trap is optimally for write-only accesses, this is not guarenteed
*/
void RetrapRegions(TrapHandle handle, bool writeOnly);
/**
* @brief Removes protections from a region of memory
*/
void RemoveTrap(TrapHandle handle);
/**
* @brief Deletes a trap handle and removes the protection from the region
*/
void DeleteTrap(TrapHandle handle);
};
}