Update plugin (un)loading logic to keep plugins loaded if possible

This commit is contained in:
Maschell 2024-11-27 20:44:20 +01:00
parent e41718836d
commit 1524f0a6a9
14 changed files with 193 additions and 115 deletions

View File

@ -13,17 +13,18 @@
#include <memory> #include <memory>
#include <ranges> #include <ranges>
static uint32_t sTrampolineID = 0;
std::vector<PluginContainer> std::vector<PluginContainer>
PluginManagement::loadPlugins(const std::set<std::shared_ptr<PluginData>> &pluginDataList, std::vector<relocation_trampoline_entry_t> &trampolineData) { PluginManagement::loadPlugins(const std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> &pluginDataList, std::vector<relocation_trampoline_entry_t> &trampolineData) {
std::vector<PluginContainer> plugins; std::vector<PluginContainer> plugins;
uint32_t trampolineID = 0;
for (const auto &pluginData : pluginDataList) { for (const auto &pluginData : pluginDataList) {
PluginParseErrors error = PLUGIN_PARSE_ERROR_UNKNOWN; PluginParseErrors error = PLUGIN_PARSE_ERROR_UNKNOWN;
auto metaInfo = PluginMetaInformationFactory::loadPlugin(*pluginData, error); auto metaInfo = PluginMetaInformationFactory::loadPlugin(*pluginData, error);
if (metaInfo && error == PLUGIN_PARSE_ERROR_NONE) { if (metaInfo && error == PLUGIN_PARSE_ERROR_NONE) {
auto info = PluginInformationFactory::load(*pluginData, trampolineData, trampolineID++); auto info = PluginInformationFactory::load(*pluginData, trampolineData, sTrampolineID++);
if (!info) { if (!info) {
auto errMsg = string_format("Failed to load plugin: %s", pluginData->getSource().c_str()); auto errMsg = string_format("Failed to load plugin: %s", pluginData->getSource().c_str());
DEBUG_FUNCTION_LINE_ERR("%s", errMsg.c_str()); DEBUG_FUNCTION_LINE_ERR("%s", errMsg.c_str());
@ -79,8 +80,7 @@ bool PluginManagement::doRelocation(const std::vector<RelocationData> &relocData
if (!usedRPls.contains(rplName)) { if (!usedRPls.contains(rplName)) {
DEBUG_FUNCTION_LINE_VERBOSE("Acquire %s", rplName.c_str()); DEBUG_FUNCTION_LINE_VERBOSE("Acquire %s", rplName.c_str());
// Always acquire to increase refcount and make sure it won't get unloaded while we're using it. // Always acquire to increase refcount and make sure it won't get unloaded while we're using it.
OSDynLoad_Error err = OSDynLoad_Acquire(rplName.c_str(), &rplHandle); if (const OSDynLoad_Error err = OSDynLoad_Acquire(rplName.c_str(), &rplHandle); err != OS_DYNLOAD_OK) {
if (err != OS_DYNLOAD_OK) {
DEBUG_FUNCTION_LINE_ERR("Failed to acquire %s", rplName.c_str()); DEBUG_FUNCTION_LINE_ERR("Failed to acquire %s", rplName.c_str());
return false; return false;
} }
@ -91,7 +91,7 @@ bool PluginManagement::doRelocation(const std::vector<RelocationData> &relocData
rplHandle = usedRPls[rplName]; rplHandle = usedRPls[rplName];
} }
OSDynLoad_FindExport(rplHandle, (OSDynLoad_ExportType) isData, functionName.c_str(), (void **) &functionAddress); OSDynLoad_FindExport(rplHandle, static_cast<OSDynLoad_ExportType>(isData), functionName.c_str(), reinterpret_cast<void **>(&functionAddress));
} }
if (functionAddress == 0) { if (functionAddress == 0) {
@ -101,7 +101,7 @@ bool PluginManagement::doRelocation(const std::vector<RelocationData> &relocData
//DEBUG_FUNCTION_LINE("Found export for %s %s", rplName.c_str(), functionName.c_str()); //DEBUG_FUNCTION_LINE("Found export for %s %s", rplName.c_str(), functionName.c_str());
} }
if (!ElfUtils::elfLinkOne(cur.getType(), cur.getOffset(), cur.getAddend(), (uint32_t) cur.getDestination(), functionAddress, trampData, RELOC_TYPE_IMPORT, trampolineID)) { if (!ElfUtils::elfLinkOne(cur.getType(), cur.getOffset(), cur.getAddend(), reinterpret_cast<uint32_t>(cur.getDestination()), functionAddress, trampData, RELOC_TYPE_IMPORT, trampolineID)) {
DEBUG_FUNCTION_LINE_ERR("elfLinkOne failed"); DEBUG_FUNCTION_LINE_ERR("elfLinkOne failed");
return false; return false;
} }
@ -156,6 +156,7 @@ bool PluginManagement::RestoreFunctionPatches(std::vector<PluginContainer> &plug
for (auto &cur : std::ranges::reverse_view(plugins)) { for (auto &cur : std::ranges::reverse_view(plugins)) {
for (auto &curFunction : std::ranges::reverse_view(cur.getPluginInformation().getFunctionDataList())) { for (auto &curFunction : std::ranges::reverse_view(cur.getPluginInformation().getFunctionDataList())) {
if (!curFunction.RemovePatch()) { if (!curFunction.RemovePatch()) {
DEBUG_FUNCTION_LINE_ERR("Failed to remove function patch for: plugin %s", cur.getMetaInformation().getName().c_str());
return false; return false;
} }
} }
@ -175,10 +176,10 @@ bool PluginManagement::DoFunctionPatches(std::vector<PluginContainer> &plugins)
return true; return true;
} }
void PluginManagement::callInitHooks(const std::vector<PluginContainer> &plugins) { void PluginManagement::callInitHooks(const std::vector<PluginContainer> &plugins, const std::function<bool(const PluginContainer &)> &pred) {
CallHook(plugins, WUPS_LOADER_HOOK_INIT_CONFIG); CallHook(plugins, WUPS_LOADER_HOOK_INIT_CONFIG, pred);
CallHook(plugins, WUPS_LOADER_HOOK_INIT_STORAGE_DEPRECATED); CallHook(plugins, WUPS_LOADER_HOOK_INIT_STORAGE_DEPRECATED, pred);
CallHook(plugins, WUPS_LOADER_HOOK_INIT_STORAGE); CallHook(plugins, WUPS_LOADER_HOOK_INIT_STORAGE, pred);
CallHook(plugins, WUPS_LOADER_HOOK_INIT_PLUGIN); CallHook(plugins, WUPS_LOADER_HOOK_INIT_PLUGIN, pred);
DEBUG_FUNCTION_LINE_VERBOSE("Done calling init hooks"); DEBUG_FUNCTION_LINE_VERBOSE("Done calling init hooks");
} }

View File

@ -1,7 +1,9 @@
#pragma once #pragma once
#include "plugin/PluginContainer.h" #include "plugin/PluginContainer.h"
#include <coreinit/dynload.h> #include <coreinit/dynload.h>
#include <functional>
#include <map> #include <map>
#include <memory> #include <memory>
#include <set> #include <set>
@ -10,10 +12,10 @@
class PluginManagement { class PluginManagement {
public: public:
static std::vector<PluginContainer> loadPlugins( static std::vector<PluginContainer> loadPlugins(
const std::set<std::shared_ptr<PluginData>> &pluginDataList, const std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> &pluginDataList,
std::vector<relocation_trampoline_entry_t> &trampolineData); std::vector<relocation_trampoline_entry_t> &trampolineData);
static void callInitHooks(const std::vector<PluginContainer> &plugins); static void callInitHooks(const std::vector<PluginContainer> &plugins, const std::function<bool(const PluginContainer &)> &pred);
static bool doRelocations(const std::vector<PluginContainer> &plugins, static bool doRelocations(const std::vector<PluginContainer> &plugins,
std::vector<relocation_trampoline_entry_t> &trampData, std::vector<relocation_trampoline_entry_t> &trampData,

View File

@ -8,8 +8,8 @@ StoredBuffer gStoredDRCBuffer = {};
std::vector<PluginContainer> gLoadedPlugins; std::vector<PluginContainer> gLoadedPlugins;
std::vector<relocation_trampoline_entry_t> gTrampData; std::vector<relocation_trampoline_entry_t> gTrampData;
std::set<std::shared_ptr<PluginData>> gLoadedData; std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> gLoadedData;
std::set<std::shared_ptr<PluginData>> gLoadOnNextLaunch; std::vector<std::shared_ptr<PluginData>> gLoadOnNextLaunch;
std::mutex gLoadedDataMutex; std::mutex gLoadedDataMutex;
std::map<std::string, OSDynLoad_Module> gUsedRPLs; std::map<std::string, OSDynLoad_Module> gUsedRPLs;
std::vector<void *> gAllocatedAddresses; std::vector<void *> gAllocatedAddresses;

View File

@ -13,6 +13,7 @@
#define MODULE_VERSION "v0.3.4" #define MODULE_VERSION "v0.3.4"
#define MODULE_VERSION_FULL MODULE_VERSION MODULE_VERSION_EXTRA #define MODULE_VERSION_FULL MODULE_VERSION MODULE_VERSION_EXTRA
class PluginDataSharedPtrComparator;
class PluginData; class PluginData;
class PluginContainer; class PluginContainer;
@ -23,8 +24,8 @@ extern StoredBuffer gStoredDRCBuffer;
extern std::vector<relocation_trampoline_entry_t> gTrampData; extern std::vector<relocation_trampoline_entry_t> gTrampData;
extern std::vector<PluginContainer> gLoadedPlugins; extern std::vector<PluginContainer> gLoadedPlugins;
extern std::set<std::shared_ptr<PluginData>> gLoadedData; extern std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> gLoadedData;
extern std::set<std::shared_ptr<PluginData>> gLoadOnNextLaunch; extern std::vector<std::shared_ptr<PluginData>> gLoadOnNextLaunch;
extern std::mutex gLoadedDataMutex; extern std::mutex gLoadedDataMutex;
extern std::map<std::string, OSDynLoad_Module> gUsedRPLs; extern std::map<std::string, OSDynLoad_Module> gUsedRPLs;
extern std::vector<void *> gAllocatedAddresses; extern std::vector<void *> gAllocatedAddresses;

View File

@ -3,6 +3,8 @@
#include "utils/StorageUtilsDeprecated.h" #include "utils/StorageUtilsDeprecated.h"
#include "utils/logger.h" #include "utils/logger.h"
#include "utils/storage/StorageUtils.h" #include "utils/storage/StorageUtils.h"
#include <functional>
#include <wups/storage.h> #include <wups/storage.h>
static const char **hook_names = (const char *[]){ static const char **hook_names = (const char *[]){
@ -35,10 +37,16 @@ static const char **hook_names = (const char *[]){
"WUPS_LOADER_HOOK_INIT_STORAGE", "WUPS_LOADER_HOOK_INIT_STORAGE",
"WUPS_LOADER_HOOK_INIT_CONFIG"}; "WUPS_LOADER_HOOK_INIT_CONFIG"};
void CallHook(const std::vector<PluginContainer> &plugins, wups_loader_hook_type_t hook_type) { void CallHook(const std::vector<PluginContainer> &plugins, const wups_loader_hook_type_t hook_type) {
CallHook(plugins, hook_type, [](const auto &) { return true; });
}
void CallHook(const std::vector<PluginContainer> &plugins, const wups_loader_hook_type_t hook_type, const std::function<bool(const PluginContainer &)> &pred) {
DEBUG_FUNCTION_LINE_VERBOSE("Calling hook of type %s [%d]", hook_names[hook_type], hook_type); DEBUG_FUNCTION_LINE_VERBOSE("Calling hook of type %s [%d]", hook_names[hook_type], hook_type);
for (const auto &plugin : plugins) { for (const auto &plugin : plugins) {
CallHook(plugin, hook_type); if (pred(plugin)) {
CallHook(plugin, hook_type);
}
} }
} }

View File

@ -1,9 +1,13 @@
#pragma once #pragma once
#include "plugin/PluginContainer.h" #include "plugin/PluginContainer.h"
#include <functional>
#include <vector> #include <vector>
#include <wups/hooks.h> #include <wups/hooks.h>
void CallHook(const std::vector<PluginContainer> &plugins, wups_loader_hook_type_t hook_type, const std::function<bool(const PluginContainer &)> &pred);
void CallHook(const std::vector<PluginContainer> &plugins, wups_loader_hook_type_t hook_type); void CallHook(const std::vector<PluginContainer> &plugins, wups_loader_hook_type_t hook_type);
void CallHook(const PluginContainer &plugin, wups_loader_hook_type_t hook_type); void CallHook(const PluginContainer &plugin, wups_loader_hook_type_t hook_type);

View File

@ -7,10 +7,11 @@
#include "plugin/PluginDataFactory.h" #include "plugin/PluginDataFactory.h"
#include "utils/logger.h" #include "utils/logger.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "version.h"
#include <coreinit/debug.h> #include <coreinit/debug.h>
#include <notifications/notifications.h> #include <notifications/notifications.h>
#include <version.h> #include <ranges>
#include <wums.h> #include <wums.h>
WUMS_MODULE_EXPORT_NAME("homebrew_wupsbackend"); WUMS_MODULE_EXPORT_NAME("homebrew_wupsbackend");
@ -26,8 +27,7 @@ WUMS_INITIALIZE() {
OSFatal("homebrew_wupsbackend: FunctionPatcher_InitLibrary failed"); OSFatal("homebrew_wupsbackend: FunctionPatcher_InitLibrary failed");
} }
NotificationModuleStatus res; if (const NotificationModuleStatus res = NotificationModule_InitLibrary(); res != NOTIFICATION_MODULE_RESULT_SUCCESS) {
if ((res = NotificationModule_InitLibrary()) != NOTIFICATION_MODULE_RESULT_SUCCESS) {
DEBUG_FUNCTION_LINE_ERR("Failed to init NotificationModule: %s (%d)", NotificationModule_GetStatusStr(res), res); DEBUG_FUNCTION_LINE_ERR("Failed to init NotificationModule: %s (%d)", NotificationModule_GetStatusStr(res), res);
gNotificationModuleLoaded = false; gNotificationModuleLoaded = false;
} else { } else {
@ -45,7 +45,7 @@ WUMS_INITIALIZE() {
} }
WUMS_APPLICATION_REQUESTS_EXIT() { WUMS_APPLICATION_REQUESTS_EXIT() {
uint32_t upid = OSGetUPID(); const uint32_t upid = OSGetUPID();
if (upid != 2 && upid != 15) { if (upid != 2 && upid != 15) {
return; return;
} }
@ -53,7 +53,7 @@ WUMS_APPLICATION_REQUESTS_EXIT() {
} }
WUMS_APPLICATION_ENDS() { WUMS_APPLICATION_ENDS() {
uint32_t upid = OSGetUPID(); const uint32_t upid = OSGetUPID();
if (upid != 2 && upid != 15) { if (upid != 2 && upid != 15) {
return; return;
} }
@ -62,8 +62,8 @@ WUMS_APPLICATION_ENDS() {
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_FINI_WUT_SOCKETS); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_FINI_WUT_SOCKETS);
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_FINI_WUT_DEVOPTAB); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_FINI_WUT_DEVOPTAB);
for (const auto &pair : gUsedRPLs) { for (const auto &val : gUsedRPLs | std::views::values) {
OSDynLoad_Release(pair.second); OSDynLoad_Release(val);
} }
gUsedRPLs.clear(); gUsedRPLs.clear();
@ -71,9 +71,11 @@ WUMS_APPLICATION_ENDS() {
} }
void CheckCleanupCallbackUsage(const std::vector<PluginContainer> &plugins); void CheckCleanupCallbackUsage(const std::vector<PluginContainer> &plugins);
void CleanupPlugins(std::vector<PluginContainer> &&pluginsToDeinit);
WUMS_APPLICATION_STARTS() { WUMS_APPLICATION_STARTS() {
uint32_t upid = OSGetUPID(); const uint32_t upid = OSGetUPID();
if (upid != 2 && upid != 15) { if (upid != 2 && upid != 15) {
return; return;
} }
@ -87,14 +89,13 @@ WUMS_APPLICATION_STARTS() {
// Let's clean this up! // Let's clean this up!
for (const auto &addr : gAllocatedAddresses) { for (const auto &addr : gAllocatedAddresses) {
DEBUG_FUNCTION_LINE_WARN("Memory allocated by OSDynload was not freed properly, let's clean it up! (%08X)", addr); DEBUG_FUNCTION_LINE_WARN("Memory allocated by OSDynload was not freed properly, let's clean it up! (%08X)", addr);
free((void *) addr); free(addr);
} }
gAllocatedAddresses.clear(); gAllocatedAddresses.clear();
initLogging(); initLogging();
bool initNeeded = false;
std::lock_guard<std::mutex> lock(gLoadedDataMutex); std::lock_guard lock(gLoadedDataMutex);
if (gTrampData.empty()) { if (gTrampData.empty()) {
gTrampData = std::vector<relocation_trampoline_entry_t>(TRAMP_DATA_SIZE); gTrampData = std::vector<relocation_trampoline_entry_t>(TRAMP_DATA_SIZE);
@ -103,109 +104,142 @@ WUMS_APPLICATION_STARTS() {
} }
} }
std::vector<PluginContainer> newLoadedPlugins;
if (gLoadedPlugins.empty()) { if (gLoadedPlugins.empty()) {
auto pluginPath = getPluginPath(); const auto pluginPath = getPluginPath();
DEBUG_FUNCTION_LINE("Load plugins from %s", pluginPath.c_str()); DEBUG_FUNCTION_LINE("Load plugins from %s", pluginPath.c_str());
auto pluginData = PluginDataFactory::loadDir(pluginPath); const auto pluginData = PluginDataFactory::loadDir(pluginPath);
gLoadedPlugins = PluginManagement::loadPlugins(pluginData, gTrampData); newLoadedPlugins = PluginManagement::loadPlugins(pluginData, gTrampData);
initNeeded = true;
} }
if (!gLoadOnNextLaunch.empty()) { if (!gLoadOnNextLaunch.empty()) {
auto *currentThread = OSGetCurrentThread(); std::vector<PluginContainer> pluginsToKeep;
auto saved_reent = currentThread->reserved[4]; std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> toBeLoaded;
auto saved_cleanupCallback = currentThread->cleanupCallback;
currentThread->reserved[4] = 0; // Check which plugins are already loaded and which needs to be
for (const auto &pluginData : gLoadOnNextLaunch) {
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_DEINIT_PLUGIN); // Check if the plugin data is already loaded
if (auto it = std::ranges::find_if(gLoadedPlugins,
CheckCleanupCallbackUsage(gLoadedPlugins); [&pluginData](const PluginContainer &container) {
return container.getPluginDataCopy()->getHandle() == pluginData->getHandle();
if (currentThread->cleanupCallback != saved_cleanupCallback) { });
DEBUG_FUNCTION_LINE_WARN("WUPS_LOADER_HOOK_DEINIT_PLUGIN overwrote the ThreadCleanupCallback, we need to restore it!\n"); it != gLoadedPlugins.end()) {
OSSetThreadCleanupCallback(OSGetCurrentThread(), saved_cleanupCallback); pluginsToKeep.push_back(std::move(*it));
} gLoadedPlugins.erase(it);
} else {
currentThread->reserved[4] = saved_reent; // Load it if it's not already loaded
toBeLoaded.insert(pluginData);
DEBUG_FUNCTION_LINE("Restore function patches of currently loaded plugins.");
PluginManagement::RestoreFunctionPatches(gLoadedPlugins);
for (auto &plugin : gLoadedPlugins) {
WUPSStorageError err = plugin.CloseStorage();
if (err != WUPS_STORAGE_ERROR_SUCCESS) {
DEBUG_FUNCTION_LINE_ERR("Failed to close storage for plugin: %s", plugin.getMetaInformation().getName().c_str());
} }
} }
DEBUG_FUNCTION_LINE("Unload existing plugins."); std::vector<PluginContainer> pluginsToDeinit = std::move(gLoadedPlugins);
gLoadedPlugins.clear(); gLoadedPlugins = std::move(pluginsToKeep);
for (auto &cur : gTrampData) {
cur.status = RELOC_TRAMP_FREE; DEBUG_FUNCTION_LINE("Deinit unused plugins");
} CleanupPlugins(std::move(pluginsToDeinit));
DEBUG_FUNCTION_LINE("Load new plugins"); DEBUG_FUNCTION_LINE("Load new plugins");
gLoadedPlugins = PluginManagement::loadPlugins(gLoadOnNextLaunch, gTrampData); newLoadedPlugins = PluginManagement::loadPlugins(toBeLoaded, gTrampData);
initNeeded = true;
} }
DEBUG_FUNCTION_LINE("Clear plugin data lists."); DEBUG_FUNCTION_LINE("Clear plugin data lists.");
gLoadOnNextLaunch.clear(); gLoadOnNextLaunch.clear();
gLoadedData.clear(); gLoadedData.clear();
if (!gLoadedPlugins.empty()) { if (!gLoadedPlugins.empty() || !newLoadedPlugins.empty()) {
for (auto &pluginContainer : newLoadedPlugins) {
pluginContainer.setInitDone(false);
}
// Move all new plugin containers into gLoadedPlugins
append_move_all_values(gLoadedPlugins, newLoadedPlugins);
if (!PluginManagement::doRelocations(gLoadedPlugins, gTrampData, gUsedRPLs)) { if (!PluginManagement::doRelocations(gLoadedPlugins, gTrampData, gUsedRPLs)) {
DEBUG_FUNCTION_LINE_ERR("Relocations failed"); DEBUG_FUNCTION_LINE_ERR("Relocations failed");
OSFatal("WiiUPluginLoaderBackend: Relocations failed.\n See crash logs for more information."); OSFatal("WiiUPluginLoaderBackend: Relocations failed.\n See crash logs for more information.");
} }
// PluginManagement::memsetBSS(plugins); // PluginManagement::memsetBSS(plugins);
if (initNeeded) { const auto &needsInitsCheck = [](const PluginContainer &container) { return !container.isInitDone(); };
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_MALLOC); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_MALLOC, needsInitsCheck);
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_NEWLIB); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_NEWLIB, needsInitsCheck);
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_STDCPP); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_STDCPP, needsInitsCheck);
}
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_DEVOPTAB); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_DEVOPTAB);
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_SOCKETS); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WUT_SOCKETS);
if (initNeeded) { CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WRAPPER, needsInitsCheck);
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_INIT_WRAPPER);
}
if (initNeeded) { for (auto &plugin : gLoadedPlugins) {
for (auto &plugin : gLoadedPlugins) { if (plugin.isInitDone()) { continue; }
WUPSStorageError err = plugin.OpenStorage(); if (const WUPSStorageError err = plugin.OpenStorage(); err != WUPS_STORAGE_ERROR_SUCCESS) {
if (err != WUPS_STORAGE_ERROR_SUCCESS) { DEBUG_FUNCTION_LINE_ERR("Failed to open storage for plugin: %s. (%s)", plugin.getMetaInformation().getName().c_str(), WUPSStorageAPI_GetStatusStr(err));
DEBUG_FUNCTION_LINE_ERR("Failed to open storage for plugin: %s. (%s)", plugin.getMetaInformation().getName().c_str(), WUPSStorageAPI_GetStatusStr(err));
}
} }
PluginManagement::callInitHooks(gLoadedPlugins);
} }
PluginManagement::callInitHooks(gLoadedPlugins, needsInitsCheck);
CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_APPLICATION_STARTS); CallHook(gLoadedPlugins, WUPS_LOADER_HOOK_APPLICATION_STARTS);
for (auto &pluginContainer : gLoadedPlugins) {
pluginContainer.setInitDone(true);
}
} }
} }
void CleanupPlugins(std::vector<PluginContainer> &&pluginsToDeinit) {
auto *currentThread = OSGetCurrentThread();
const auto saved_reent = currentThread->reserved[4];
const auto saved_cleanupCallback = currentThread->cleanupCallback;
currentThread->reserved[4] = 0;
CallHook(pluginsToDeinit, WUPS_LOADER_HOOK_DEINIT_PLUGIN);
CheckCleanupCallbackUsage(pluginsToDeinit);
if (currentThread->cleanupCallback != saved_cleanupCallback) {
DEBUG_FUNCTION_LINE_WARN("WUPS_LOADER_HOOK_DEINIT_PLUGIN overwrote the ThreadCleanupCallback, we need to restore it!\n");
OSSetThreadCleanupCallback(OSGetCurrentThread(), saved_cleanupCallback);
}
currentThread->reserved[4] = saved_reent;
DEBUG_FUNCTION_LINE("Restore function patches of plugins.");
PluginManagement::RestoreFunctionPatches(pluginsToDeinit);
for (auto &plugin : pluginsToDeinit) {
if (const WUPSStorageError err = plugin.CloseStorage(); err != WUPS_STORAGE_ERROR_SUCCESS) {
DEBUG_FUNCTION_LINE_ERR("Failed to close storage for plugin: %s", plugin.getMetaInformation().getName().c_str());
}
}
for (const auto &pluginContainer : pluginsToDeinit) {
for (auto &cur : gTrampData) {
if (cur.id != pluginContainer.getPluginInformation().getTrampolineId()) {
continue;
}
cur.status = RELOC_TRAMP_FREE;
}
}
}
void CheckCleanupCallbackUsage(const std::vector<PluginContainer> &plugins) { void CheckCleanupCallbackUsage(const std::vector<PluginContainer> &plugins) {
auto *curThread = OSGetCurrentThread(); auto *curThread = OSGetCurrentThread();
for (const auto &cur : plugins) { for (const auto &cur : plugins) {
auto textSection = cur.getPluginInformation().getSectionInfo(".text"); const auto textSection = cur.getPluginInformation().getSectionInfo(".text");
if (!textSection) { if (!textSection) {
continue; continue;
} }
uint32_t startAddress = textSection->getAddress(); const uint32_t startAddress = textSection->getAddress();
uint32_t endAddress = textSection->getAddress() + textSection->getSize(); const uint32_t endAddress = textSection->getAddress() + textSection->getSize();
auto *pluginName = cur.getMetaInformation().getName().c_str(); auto *pluginName = cur.getMetaInformation().getName().c_str();
{ {
__OSLockScheduler(curThread); __OSLockScheduler(curThread);
int state = OSDisableInterrupts(); const int state = OSDisableInterrupts();
OSThread *t = *((OSThread **) 0x100567F8); OSThread *t = *reinterpret_cast<OSThread **>(0x100567F8);
while (t) { while (t) {
auto address = reinterpret_cast<uint32_t>(t->cleanupCallback); const auto address = reinterpret_cast<uint32_t>(t->cleanupCallback);
if (address != 0 && address >= startAddress && address <= endAddress) { if (address != 0 && address >= startAddress && address <= endAddress) {
OSReport("[WARN] PluginBackend: Thread 0x%08X is using a function from plugin %s for the threadCleanupCallback\n", t, pluginName); OSReport("[WARN] PluginBackend: Thread 0x%08X is using a function from plugin %s for the threadCleanupCallback\n", t, pluginName);
} }

View File

@ -1,6 +1,5 @@
#include "PluginContainer.h" #include "PluginContainer.h"
#include "utils/storage/StorageUtils.h"
#include <utils/storage/StorageUtils.h>
PluginContainer::PluginContainer(PluginMetaInformation metaInformation, PluginInformation pluginInformation, std::shared_ptr<PluginData> pluginData) PluginContainer::PluginContainer(PluginMetaInformation metaInformation, PluginInformation pluginInformation, std::shared_ptr<PluginData> pluginData)
: mMetaInformation(std::move(metaInformation)), : mMetaInformation(std::move(metaInformation)),
@ -12,21 +11,25 @@ PluginContainer::PluginContainer(PluginContainer &&src) noexcept : mMetaInformat
mPluginInformation(std::move(src.mPluginInformation)), mPluginInformation(std::move(src.mPluginInformation)),
mPluginData(std::move(src.mPluginData)), mPluginData(std::move(src.mPluginData)),
mPluginConfigData(std::move(src.mPluginConfigData)), mPluginConfigData(std::move(src.mPluginConfigData)),
storageRootItem(src.storageRootItem) mStorageRootItem(src.mStorageRootItem),
mInitDone(src.mInitDone)
{ {
src.storageRootItem = {}; src.mStorageRootItem = {};
src.mInitDone = {};
} }
PluginContainer &PluginContainer::operator=(PluginContainer &&src) noexcept { PluginContainer &PluginContainer::operator=(PluginContainer &&src) noexcept {
if (this != &src) { if (this != &src) {
this->mMetaInformation = src.mMetaInformation; this->mMetaInformation = std::move(src.mMetaInformation);
this->mPluginInformation = std::move(src.mPluginInformation); this->mPluginInformation = std::move(src.mPluginInformation);
this->mPluginData = std::move(src.mPluginData); this->mPluginData = std::move(src.mPluginData);
this->mPluginConfigData = std::move(src.mPluginConfigData); this->mPluginConfigData = std::move(src.mPluginConfigData);
this->storageRootItem = src.storageRootItem; this->mStorageRootItem = src.mStorageRootItem;
this->mInitDone = src.mInitDone;
src.storageRootItem = nullptr; src.mStorageRootItem = nullptr;
src.mInitDone = false;
} }
return *this; return *this;
} }
@ -48,7 +51,7 @@ std::shared_ptr<PluginData> PluginContainer::getPluginDataCopy() const {
} }
uint32_t PluginContainer::getHandle() const { uint32_t PluginContainer::getHandle() const {
return (uint32_t) this; return reinterpret_cast<uint32_t>(this);
} }
const std::optional<PluginConfigData> &PluginContainer::getConfigData() const { const std::optional<PluginConfigData> &PluginContainer::getConfigData() const {
@ -67,9 +70,9 @@ WUPSStorageError PluginContainer::OpenStorage() {
if (storageId.empty()) { if (storageId.empty()) {
return WUPS_STORAGE_ERROR_SUCCESS; return WUPS_STORAGE_ERROR_SUCCESS;
} }
auto res = StorageUtils::API::Internal::OpenStorage(storageId, storageRootItem); auto res = StorageUtils::API::Internal::OpenStorage(storageId, mStorageRootItem);
if (res != WUPS_STORAGE_ERROR_SUCCESS) { if (res != WUPS_STORAGE_ERROR_SUCCESS) {
storageRootItem = nullptr; mStorageRootItem = nullptr;
} }
return res; return res;
} }
@ -78,8 +81,20 @@ WUPSStorageError PluginContainer::CloseStorage() {
if (getMetaInformation().getWUPSVersion() < WUPSVersion(0, 8, 0)) { if (getMetaInformation().getWUPSVersion() < WUPSVersion(0, 8, 0)) {
return WUPS_STORAGE_ERROR_SUCCESS; return WUPS_STORAGE_ERROR_SUCCESS;
} }
if (storageRootItem == nullptr) { if (mStorageRootItem == nullptr) {
return WUPS_STORAGE_ERROR_SUCCESS; return WUPS_STORAGE_ERROR_SUCCESS;
} }
return StorageUtils::API::Internal::CloseStorage(storageRootItem); return StorageUtils::API::Internal::CloseStorage(mStorageRootItem);
}
wups_storage_root_item PluginContainer::getStorageRootItem() const {
return mStorageRootItem;
}
void PluginContainer::setInitDone(const bool val) {
mInitDone = val;
}
bool PluginContainer::isInitDone() const {
return mInitDone;
} }

View File

@ -30,15 +30,12 @@ class PluginContainer {
public: public:
PluginContainer(PluginMetaInformation metaInformation, PluginInformation pluginInformation, std::shared_ptr<PluginData> pluginData); PluginContainer(PluginMetaInformation metaInformation, PluginInformation pluginInformation, std::shared_ptr<PluginData> pluginData);
PluginContainer(const PluginContainer &) = delete; PluginContainer(const PluginContainer &) = delete;
PluginContainer(PluginContainer &&src) noexcept; PluginContainer(PluginContainer &&src) noexcept;
PluginContainer &operator=(PluginContainer &&src) noexcept; PluginContainer &operator=(PluginContainer &&src) noexcept;
[[nodiscard]] const PluginMetaInformation &getMetaInformation() const; [[nodiscard]] const PluginMetaInformation &getMetaInformation() const;
[[nodiscard]] const PluginInformation &getPluginInformation() const; [[nodiscard]] const PluginInformation &getPluginInformation() const;
@ -56,9 +53,11 @@ public:
WUPSStorageError CloseStorage(); WUPSStorageError CloseStorage();
[[nodiscard]] wups_storage_root_item getStorageRootItem() const { [[nodiscard]] wups_storage_root_item getStorageRootItem() const;
return storageRootItem;
} void setInitDone(bool val);
[[nodiscard]] bool isInitDone() const;
private: private:
PluginMetaInformation mMetaInformation; PluginMetaInformation mMetaInformation;
@ -66,5 +65,6 @@ private:
std::shared_ptr<PluginData> mPluginData; std::shared_ptr<PluginData> mPluginData;
std::optional<PluginConfigData> mPluginConfigData; std::optional<PluginConfigData> mPluginConfigData;
wups_storage_root_item storageRootItem = nullptr; wups_storage_root_item mStorageRootItem = nullptr;
bool mInitDone = false;
}; };

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <memory>
#include <span> #include <span>
#include <string> #include <string>
#include <vector> #include <vector>
@ -44,3 +45,9 @@ private:
std::vector<uint8_t> mBuffer; std::vector<uint8_t> mBuffer;
std::string mSource; std::string mSource;
}; };
struct PluginDataSharedPtrComparator {
bool operator()(const std::shared_ptr<PluginData> &lhs, const std::shared_ptr<PluginData> &rhs) const {
return lhs->getHandle() < rhs->getHandle();
}
};

View File

@ -25,8 +25,8 @@
#include <set> #include <set>
#include <sys/dirent.h> #include <sys/dirent.h>
std::set<std::shared_ptr<PluginData>> PluginDataFactory::loadDir(const std::string_view path) { std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> PluginDataFactory::loadDir(const std::string_view path) {
std::set<std::shared_ptr<PluginData>> result; std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> result;
dirent *dp; dirent *dp;
DIR *dfd; DIR *dfd;

View File

@ -24,7 +24,7 @@
class PluginDataFactory { class PluginDataFactory {
public: public:
static std::set<std::shared_ptr<PluginData>> loadDir(std::string_view path); static std::set<std::shared_ptr<PluginData>, PluginDataSharedPtrComparator> loadDir(std::string_view path);
static std::unique_ptr<PluginData> load(std::string_view path); static std::unique_ptr<PluginData> load(std::string_view path);

View File

@ -35,7 +35,7 @@ extern "C" PluginBackendApiErrorType WUPSLoadAndLinkByDataHandle(const wups_back
for (const auto &pluginData : gLoadedData) { for (const auto &pluginData : gLoadedData) {
if (pluginData->getHandle() == handle) { if (pluginData->getHandle() == handle) {
gLoadOnNextLaunch.insert(pluginData); gLoadOnNextLaunch.push_back(pluginData);
found = true; found = true;
break; break;
} }

View File

@ -63,7 +63,7 @@ std::shared_ptr<T> make_shared_nothrow(Args &&...args) noexcept(noexcept(T(std::
} }
template<typename Container, typename Predicate> template<typename Container, typename Predicate>
typename std::enable_if<std::is_same<Container, std::forward_list<typename Container::value_type>>::value, bool>::type std::enable_if_t<std::is_same_v<Container, std::forward_list<typename Container::value_type>>, bool>
remove_first_if(Container &container, Predicate pred) { remove_first_if(Container &container, Predicate pred) {
auto it = container.before_begin(); auto it = container.before_begin();
@ -78,7 +78,7 @@ remove_first_if(Container &container, Predicate pred) {
} }
template<typename Container, typename Predicate> template<typename Container, typename Predicate>
typename std::enable_if<std::is_same<Container, std::set<typename Container::value_type>>::value, bool>::type std::enable_if_t<std::is_same_v<Container, std::set<typename Container::value_type, typename Container::key_compare>>, bool>
remove_first_if(Container &container, Predicate pred) { remove_first_if(Container &container, Predicate pred) {
auto it = container.begin(); auto it = container.begin();
while (it != container.end()) { while (it != container.end()) {
@ -92,7 +92,7 @@ remove_first_if(Container &container, Predicate pred) {
} }
template<typename Container, typename Predicate> template<typename Container, typename Predicate>
typename std::enable_if<std::is_same<Container, std::vector<typename Container::value_type>>::value, bool>::type std::enable_if_t<std::is_same_v<Container, std::vector<typename Container::value_type>>, bool>
remove_first_if(Container &container, Predicate pred) { remove_first_if(Container &container, Predicate pred) {
auto it = container.begin(); auto it = container.begin();
while (it != container.end()) { while (it != container.end()) {
@ -129,6 +129,12 @@ T pop_locked_first_if(std::mutex &mutex, std::vector<T> &container, Predicate pr
return result; return result;
} }
template<typename Container>
void append_move_all_values(Container &dest, Container &src) {
dest.insert(dest.end(), std::make_move_iterator(src.begin()), std::make_move_iterator(src.end()));
src.clear();
}
std::string getPluginPath(); std::string getPluginPath();
OSDynLoad_Error CustomDynLoadAlloc(int32_t size, int32_t align, void **outAddr); OSDynLoad_Error CustomDynLoadAlloc(int32_t size, int32_t align, void **outAddr);