Merge pull request #11191 from JosJuice/jitarm64-no-checked-entry

JitArm64: Never check downcount on block entry
This commit is contained in:
JosJuice 2023-08-26 17:00:08 +02:00 committed by GitHub
commit cd31da97d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 156 additions and 68 deletions

View File

@ -311,7 +311,6 @@ void CachedInterpreter::Jit(u32 address)
js.numFloatingPointInst = 0; js.numFloatingPointInst = 0;
js.curBlock = b; js.curBlock = b;
b->checkedEntry = GetCodePtr();
b->normalEntry = GetCodePtr(); b->normalEntry = GetCodePtr();
for (u32 i = 0; i < code_block.m_num_instructions; i++) for (u32 i = 0; i < code_block.m_num_instructions; i++)
@ -374,7 +373,7 @@ void CachedInterpreter::Jit(u32 address)
} }
m_code.emplace_back(); m_code.emplace_back();
b->codeSize = (u32)(GetCodePtr() - b->checkedEntry); b->codeSize = static_cast<u32>(GetCodePtr() - b->normalEntry);
b->originalSize = code_block.m_num_instructions; b->originalSize = code_block.m_num_instructions;
m_block_cache.FinalizeBlock(*b, jo.enableBlocklink, code_block.m_physical_addresses); m_block_cache.FinalizeBlock(*b, jo.enableBlocklink, code_block.m_physical_addresses);

View File

@ -851,9 +851,7 @@ bool Jit64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
js.numFloatingPointInst = 0; js.numFloatingPointInst = 0;
// TODO: Test if this or AlignCode16 make a difference from GetCodePtr // TODO: Test if this or AlignCode16 make a difference from GetCodePtr
u8* const start = AlignCode4(); b->normalEntry = AlignCode4();
b->checkedEntry = start;
b->normalEntry = start;
// Used to get a trace of the last few blocks before a crash, sometimes VERY useful // Used to get a trace of the last few blocks before a crash, sometimes VERY useful
if (m_im_here_debug) if (m_im_here_debug)
@ -1161,7 +1159,7 @@ bool Jit64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
return false; return false;
} }
b->codeSize = (u32)(GetCodePtr() - start); b->codeSize = static_cast<u32>(GetCodePtr() - b->normalEntry);
b->originalSize = code_block.m_num_instructions; b->originalSize = code_block.m_num_instructions;
#ifdef JIT_LOG_GENERATED_CODE #ifdef JIT_LOG_GENERATED_CODE

View File

@ -14,8 +14,7 @@ JitBlockCache::JitBlockCache(JitBase& jit) : JitBaseBlockCache{jit}
void JitBlockCache::WriteLinkBlock(const JitBlock::LinkData& source, const JitBlock* dest) void JitBlockCache::WriteLinkBlock(const JitBlock::LinkData& source, const JitBlock* dest)
{ {
u8* location = source.exitPtrs; u8* location = source.exitPtrs;
const u8* address = const u8* address = dest ? dest->normalEntry : m_jit.GetAsmRoutines()->dispatcher_no_timing_check;
dest ? dest->checkedEntry : m_jit.GetAsmRoutines()->dispatcher_no_timing_check;
if (source.call) if (source.call)
{ {
Gen::XEmitter emit(location, location + 5); Gen::XEmitter emit(location, location + 5);
@ -42,11 +41,9 @@ void JitBlockCache::WriteLinkBlock(const JitBlock::LinkData& source, const JitBl
void JitBlockCache::WriteDestroyBlock(const JitBlock& block) void JitBlockCache::WriteDestroyBlock(const JitBlock& block)
{ {
// Only clear the entry points as we might still be within this block. // Only clear the entry point as we might still be within this block.
Gen::XEmitter emit(block.checkedEntry, block.checkedEntry + 1); Gen::XEmitter emit(block.normalEntry, block.normalEntry + 1);
emit.INT3(); emit.INT3();
Gen::XEmitter emit2(block.normalEntry, block.normalEntry + 1);
emit2.INT3();
} }
void JitBlockCache::Init() void JitBlockCache::Init()

View File

@ -382,35 +382,79 @@ void JitArm64::WriteExit(u32 destination, bool LK, u32 exit_address_after_return
LK &= m_enable_blr_optimization; LK &= m_enable_blr_optimization;
const u8* host_address_after_return;
if (LK) if (LK)
{ {
// Push {ARM_PC+20; PPC_PC} on the stack // Push {ARM_PC; PPC_PC} on the stack
MOVI2R(ARM64Reg::X1, exit_address_after_return); MOVI2R(ARM64Reg::X1, exit_address_after_return);
ADR(ARM64Reg::X0, 20); constexpr s32 adr_offset = JitArm64BlockCache::BLOCK_LINK_SIZE + sizeof(u32) * 2;
host_address_after_return = GetCodePtr() + adr_offset;
ADR(ARM64Reg::X0, adr_offset);
STP(IndexType::Pre, ARM64Reg::X0, ARM64Reg::X1, ARM64Reg::SP, -16); STP(IndexType::Pre, ARM64Reg::X0, ARM64Reg::X1, ARM64Reg::SP, -16);
} }
constexpr size_t primary_farcode_size = 3 * sizeof(u32);
const bool switch_to_far_code = !IsInFarCode();
const u8* primary_farcode_addr;
if (switch_to_far_code)
{
SwitchToFarCode();
primary_farcode_addr = GetCodePtr();
SwitchToNearCode();
}
else
{
primary_farcode_addr = GetCodePtr() + JitArm64BlockCache::BLOCK_LINK_SIZE +
(LK ? JitArm64BlockCache::BLOCK_LINK_SIZE : 0);
}
const u8* return_farcode_addr = primary_farcode_addr + primary_farcode_size;
JitBlock* b = js.curBlock; JitBlock* b = js.curBlock;
JitBlock::LinkData linkData; JitBlock::LinkData linkData;
linkData.exitAddress = destination; linkData.exitAddress = destination;
linkData.exitPtrs = GetWritableCodePtr(); linkData.exitPtrs = GetWritableCodePtr();
linkData.linkStatus = false; linkData.linkStatus = false;
linkData.call = LK; linkData.call = LK;
linkData.exitFarcode = primary_farcode_addr;
b->linkData.push_back(linkData); b->linkData.push_back(linkData);
blocks.WriteLinkBlock(*this, linkData); blocks.WriteLinkBlock(*this, linkData);
if (LK) if (LK)
{ {
DEBUG_ASSERT(GetCodePtr() == host_address_after_return || HasWriteFailed());
// Write the regular exit node after the return. // Write the regular exit node after the return.
linkData.exitAddress = exit_address_after_return; linkData.exitAddress = exit_address_after_return;
linkData.exitPtrs = GetWritableCodePtr(); linkData.exitPtrs = GetWritableCodePtr();
linkData.linkStatus = false; linkData.linkStatus = false;
linkData.call = false; linkData.call = false;
linkData.exitFarcode = return_farcode_addr;
b->linkData.push_back(linkData); b->linkData.push_back(linkData);
blocks.WriteLinkBlock(*this, linkData); blocks.WriteLinkBlock(*this, linkData);
} }
if (switch_to_far_code)
SwitchToFarCode();
DEBUG_ASSERT(GetCodePtr() == primary_farcode_addr || HasWriteFailed());
MOVI2R(DISPATCHER_PC, destination);
if (LK)
BL(GetAsmRoutines()->do_timing);
else
B(GetAsmRoutines()->do_timing);
if (LK)
{
if (GetCodePtr() == return_farcode_addr - sizeof(u32))
BRK(101);
DEBUG_ASSERT(GetCodePtr() == return_farcode_addr || HasWriteFailed());
MOVI2R(DISPATCHER_PC, exit_address_after_return);
B(GetAsmRoutines()->do_timing);
}
if (switch_to_far_code)
SwitchToNearCode();
} }
void JitArm64::WriteExit(Arm64Gen::ARM64Reg dest, bool LK, u32 exit_address_after_return) void JitArm64::WriteExit(Arm64Gen::ARM64Reg dest, bool LK, u32 exit_address_after_return)
@ -432,10 +476,13 @@ void JitArm64::WriteExit(Arm64Gen::ARM64Reg dest, bool LK, u32 exit_address_afte
{ {
// Push {ARM_PC, PPC_PC} on the stack // Push {ARM_PC, PPC_PC} on the stack
MOVI2R(ARM64Reg::X1, exit_address_after_return); MOVI2R(ARM64Reg::X1, exit_address_after_return);
ADR(ARM64Reg::X0, 12); constexpr s32 adr_offset = sizeof(u32) * 3;
const u8* host_address_after_return = GetCodePtr() + adr_offset;
ADR(ARM64Reg::X0, adr_offset);
STP(IndexType::Pre, ARM64Reg::X0, ARM64Reg::X1, ARM64Reg::SP, -16); STP(IndexType::Pre, ARM64Reg::X0, ARM64Reg::X1, ARM64Reg::SP, -16);
BL(dispatcher); BL(dispatcher);
DEBUG_ASSERT(GetCodePtr() == host_address_after_return || HasWriteFailed());
// Write the regular exit node after the return. // Write the regular exit node after the return.
JitBlock* b = js.curBlock; JitBlock* b = js.curBlock;
@ -444,9 +491,27 @@ void JitArm64::WriteExit(Arm64Gen::ARM64Reg dest, bool LK, u32 exit_address_afte
linkData.exitPtrs = GetWritableCodePtr(); linkData.exitPtrs = GetWritableCodePtr();
linkData.linkStatus = false; linkData.linkStatus = false;
linkData.call = false; linkData.call = false;
const bool switch_to_far_code = !IsInFarCode();
if (switch_to_far_code)
{
SwitchToFarCode();
linkData.exitFarcode = GetCodePtr();
SwitchToNearCode();
}
else
{
linkData.exitFarcode = GetCodePtr() + JitArm64BlockCache::BLOCK_LINK_SIZE;
}
b->linkData.push_back(linkData); b->linkData.push_back(linkData);
blocks.WriteLinkBlock(*this, linkData); blocks.WriteLinkBlock(*this, linkData);
if (switch_to_far_code)
SwitchToFarCode();
MOVI2R(DISPATCHER_PC, exit_address_after_return);
B(GetAsmRoutines()->do_timing);
if (switch_to_far_code)
SwitchToNearCode();
} }
} }
@ -461,11 +526,14 @@ void JitArm64::FakeLKExit(u32 exit_address_after_return)
ARM64Reg after_reg = gpr.GetReg(); ARM64Reg after_reg = gpr.GetReg();
ARM64Reg code_reg = gpr.GetReg(); ARM64Reg code_reg = gpr.GetReg();
MOVI2R(after_reg, exit_address_after_return); MOVI2R(after_reg, exit_address_after_return);
ADR(EncodeRegTo64(code_reg), 12); constexpr s32 adr_offset = sizeof(u32) * 3;
const u8* host_address_after_return = GetCodePtr() + adr_offset;
ADR(EncodeRegTo64(code_reg), adr_offset);
STP(IndexType::Pre, EncodeRegTo64(code_reg), EncodeRegTo64(after_reg), ARM64Reg::SP, -16); STP(IndexType::Pre, EncodeRegTo64(code_reg), EncodeRegTo64(after_reg), ARM64Reg::SP, -16);
gpr.Unlock(after_reg, code_reg); gpr.Unlock(after_reg, code_reg);
FixupBranch skip_exit = BL(); FixupBranch skip_exit = BL();
DEBUG_ASSERT(GetCodePtr() == host_address_after_return || HasWriteFailed());
gpr.Unlock(ARM64Reg::W30); gpr.Unlock(ARM64Reg::W30);
// Write the regular exit node after the return. // Write the regular exit node after the return.
@ -475,10 +543,28 @@ void JitArm64::FakeLKExit(u32 exit_address_after_return)
linkData.exitPtrs = GetWritableCodePtr(); linkData.exitPtrs = GetWritableCodePtr();
linkData.linkStatus = false; linkData.linkStatus = false;
linkData.call = false; linkData.call = false;
const bool switch_to_far_code = !IsInFarCode();
if (switch_to_far_code)
{
SwitchToFarCode();
linkData.exitFarcode = GetCodePtr();
SwitchToNearCode();
}
else
{
linkData.exitFarcode = GetCodePtr() + JitArm64BlockCache::BLOCK_LINK_SIZE;
}
b->linkData.push_back(linkData); b->linkData.push_back(linkData);
blocks.WriteLinkBlock(*this, linkData); blocks.WriteLinkBlock(*this, linkData);
if (switch_to_far_code)
SwitchToFarCode();
MOVI2R(DISPATCHER_PC, exit_address_after_return);
B(GetAsmRoutines()->do_timing);
if (switch_to_far_code)
SwitchToNearCode();
SetJumpTarget(skip_exit); SetJumpTarget(skip_exit);
} }
@ -862,18 +948,6 @@ bool JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
js.numLoadStoreInst = 0; js.numLoadStoreInst = 0;
js.numFloatingPointInst = 0; js.numFloatingPointInst = 0;
u8* const start = GetWritableCodePtr();
b->checkedEntry = start;
// Downcount flag check, Only valid for linked blocks
{
FixupBranch bail = B(CC_PL);
MOVI2R(DISPATCHER_PC, js.blockStart);
B(do_timing);
SetJumpTarget(bail);
}
// Normal entry doesn't need to check for downcount.
b->normalEntry = GetWritableCodePtr(); b->normalEntry = GetWritableCodePtr();
// Conditionally add profiling code. // Conditionally add profiling code.
@ -1141,7 +1215,7 @@ bool JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
return false; return false;
} }
b->codeSize = (u32)(GetCodePtr() - start); b->codeSize = static_cast<u32>(GetCodePtr() - b->normalEntry);
b->originalSize = code_block.m_num_instructions; b->originalSize = code_block.m_num_instructions;
FlushIcache(); FlushIcache();

View File

@ -22,52 +22,70 @@ void JitArm64BlockCache::Init()
void JitArm64BlockCache::WriteLinkBlock(Arm64Gen::ARM64XEmitter& emit, void JitArm64BlockCache::WriteLinkBlock(Arm64Gen::ARM64XEmitter& emit,
const JitBlock::LinkData& source, const JitBlock* dest) const JitBlock::LinkData& source, const JitBlock* dest)
{ {
const u8* start = emit.GetCodePtr();
if (!dest) if (!dest)
{ {
// Use a fixed amount of instructions, so we can assume to use 3 instructions on patching. emit.MOVI2R(DISPATCHER_PC, source.exitAddress);
emit.MOVZ(DISPATCHER_PC, source.exitAddress & 0xFFFF, ShiftAmount::Shift0);
emit.MOVK(DISPATCHER_PC, source.exitAddress >> 16, ShiftAmount::Shift16);
if (source.call) if (source.call)
{
if (emit.GetCodePtr() == start + BLOCK_LINK_FAST_BL_OFFSET - sizeof(u32))
emit.NOP();
DEBUG_ASSERT(emit.GetCodePtr() == start + BLOCK_LINK_FAST_BL_OFFSET || emit.HasWriteFailed());
emit.BL(m_jit.GetAsmRoutines()->dispatcher); emit.BL(m_jit.GetAsmRoutines()->dispatcher);
}
else else
{
emit.B(m_jit.GetAsmRoutines()->dispatcher); emit.B(m_jit.GetAsmRoutines()->dispatcher);
return; }
}
else
{
if (source.call)
{
// The "fast" BL should be the last instruction, so that the return address matches the
// address that was pushed onto the stack by the function that called WriteLinkBlock
FixupBranch fast = emit.B(CC_GT);
emit.B(source.exitFarcode);
DEBUG_ASSERT(emit.GetCodePtr() == start + BLOCK_LINK_FAST_BL_OFFSET || emit.HasWriteFailed());
emit.SetJumpTarget(fast);
emit.BL(dest->normalEntry);
}
else
{
// Are we able to jump directly to the block?
s64 block_distance = ((s64)dest->normalEntry - (s64)emit.GetCodePtr()) >> 2;
if (block_distance >= -0x40000 && block_distance <= 0x3FFFF)
{
emit.B(CC_GT, dest->normalEntry);
emit.B(source.exitFarcode);
}
else
{
FixupBranch slow = emit.B(CC_LE);
emit.B(dest->normalEntry);
emit.SetJumpTarget(slow);
emit.B(source.exitFarcode);
}
}
} }
if (source.call) // Use a fixed number of instructions so we have enough room for any patching needed later.
const u8* end = start + BLOCK_LINK_SIZE;
while (emit.GetCodePtr() < end)
{ {
// The "fast" BL must be the third instruction. So just use the former two to inline the
// downcount check here. It's better to do this near jump before the long jump to the other
// block.
FixupBranch fast_link = emit.B(CC_GT);
emit.BL(dest->checkedEntry);
emit.SetJumpTarget(fast_link);
emit.BL(dest->normalEntry);
return;
}
// Are we able to jump directly to the normal entry?
s64 distance = ((s64)dest->normalEntry - (s64)emit.GetCodePtr()) >> 2;
if (distance >= -0x40000 && distance <= 0x3FFFF)
{
emit.B(CC_GT, dest->normalEntry);
emit.B(dest->checkedEntry);
emit.BRK(101); emit.BRK(101);
return; if (emit.HasWriteFailed())
return;
} }
ASSERT(emit.GetCodePtr() == end);
FixupBranch fast_link = emit.B(CC_GT);
emit.B(dest->checkedEntry);
emit.SetJumpTarget(fast_link);
emit.B(dest->normalEntry);
} }
void JitArm64BlockCache::WriteLinkBlock(const JitBlock::LinkData& source, const JitBlock* dest) void JitArm64BlockCache::WriteLinkBlock(const JitBlock::LinkData& source, const JitBlock* dest)
{ {
const Common::ScopedJITPageWriteAndNoExecute enable_jit_page_writes; const Common::ScopedJITPageWriteAndNoExecute enable_jit_page_writes;
u8* location = source.exitPtrs; u8* location = source.exitPtrs;
ARM64XEmitter emit(location, location + 12); ARM64XEmitter emit(location, location + BLOCK_LINK_SIZE);
WriteLinkBlock(emit, source, dest); WriteLinkBlock(emit, source, dest);
emit.FlushIcache(); emit.FlushIcache();
@ -75,11 +93,10 @@ void JitArm64BlockCache::WriteLinkBlock(const JitBlock::LinkData& source, const
void JitArm64BlockCache::WriteDestroyBlock(const JitBlock& block) void JitArm64BlockCache::WriteDestroyBlock(const JitBlock& block)
{ {
// Only clear the entry points as we might still be within this block. // Only clear the entry point as we might still be within this block.
ARM64XEmitter emit(block.checkedEntry, block.normalEntry + 4); ARM64XEmitter emit(block.normalEntry, block.normalEntry + 4);
const Common::ScopedJITPageWriteAndNoExecute enable_jit_page_writes; const Common::ScopedJITPageWriteAndNoExecute enable_jit_page_writes;
while (emit.GetWritableCodePtr() <= block.normalEntry) emit.BRK(0x123);
emit.BRK(0x123);
emit.FlushIcache(); emit.FlushIcache();
} }

View File

@ -29,6 +29,9 @@ public:
void WriteLinkBlock(Arm64Gen::ARM64XEmitter& emit, const JitBlock::LinkData& source, void WriteLinkBlock(Arm64Gen::ARM64XEmitter& emit, const JitBlock::LinkData& source,
const JitBlock* dest = nullptr); const JitBlock* dest = nullptr);
static constexpr size_t BLOCK_LINK_SIZE = 3 * sizeof(u32);
static constexpr size_t BLOCK_LINK_FAST_BL_OFFSET = BLOCK_LINK_SIZE - sizeof(u32);
private: private:
void WriteLinkBlock(const JitBlock::LinkData& source, const JitBlock* dest) override; void WriteLinkBlock(const JitBlock::LinkData& source, const JitBlock* dest) override;
void WriteDestroyBlock(const JitBlock& block) override; void WriteDestroyBlock(const JitBlock& block) override;

View File

@ -163,12 +163,12 @@ void JitBaseBlockCache::FinalizeBlock(JitBlock& block, bool block_link,
if (Common::JitRegister::IsEnabled() && if (Common::JitRegister::IsEnabled() &&
(symbol = g_symbolDB.GetSymbolFromAddr(block.effectiveAddress)) != nullptr) (symbol = g_symbolDB.GetSymbolFromAddr(block.effectiveAddress)) != nullptr)
{ {
Common::JitRegister::Register(block.checkedEntry, block.codeSize, "JIT_PPC_{}_{:08x}", Common::JitRegister::Register(block.normalEntry, block.codeSize, "JIT_PPC_{}_{:08x}",
symbol->function_name.c_str(), block.physicalAddress); symbol->function_name.c_str(), block.physicalAddress);
} }
else else
{ {
Common::JitRegister::Register(block.checkedEntry, block.codeSize, "JIT_PPC_{:08x}", Common::JitRegister::Register(block.normalEntry, block.codeSize, "JIT_PPC_{:08x}",
block.physicalAddress); block.physicalAddress);
} }
} }

View File

@ -30,9 +30,6 @@ struct JitBlockData
u8* far_begin; u8* far_begin;
u8* far_end; u8* far_end;
// A special entry point for block linking; usually used to check the
// downcount.
u8* checkedEntry;
// The normal entry point for the block, returned by Dispatch(). // The normal entry point for the block, returned by Dispatch().
u8* normalEntry; u8* normalEntry;
@ -73,6 +70,9 @@ struct JitBlock : public JitBlockData
struct LinkData struct LinkData
{ {
u8* exitPtrs; // to be able to rewrite the exit jump u8* exitPtrs; // to be able to rewrite the exit jump
#ifdef _M_ARM_64
const u8* exitFarcode;
#endif
u32 exitAddress; u32 exitAddress;
bool linkStatus; // is it already linked? bool linkStatus; // is it already linked?
bool call; bool call;

View File

@ -206,7 +206,7 @@ JitInterface::GetHostCode(u32 address) const
} }
GetHostCodeResult result; GetHostCodeResult result;
result.code = block->checkedEntry; result.code = block->normalEntry;
result.code_size = block->codeSize; result.code_size = block->codeSize;
result.entry_address = block->effectiveAddress; result.entry_address = block->effectiveAddress;
return result; return result;