mirror of
https://github.com/BrianPugh/game-and-watch-patch.git
synced 2025-12-16 07:16:26 +01:00
444 lines
14 KiB
Python
444 lines
14 KiB
Python
import hashlib
|
|
|
|
from colorama import Fore, Style
|
|
from Crypto.Cipher import AES
|
|
from elftools.elf.elffile import ELFFile
|
|
|
|
from .compression import lz77_decompress, lzma_compress
|
|
from .exception import InvalidStockRomError, MissingSymbolError, NotEnoughSpaceError
|
|
from .patch import DevicePatchMixin, FirmwarePatchMixin
|
|
|
|
|
|
def _val_to_color(val):
|
|
if 0x9010_0000 > val >= 0x9000_0000:
|
|
return Fore.YELLOW
|
|
elif 0x0804_0000 > val >= 0x0800_0000:
|
|
return Fore.MAGENTA
|
|
else:
|
|
return ""
|
|
|
|
|
|
class Lookup(dict):
    """A ``dict`` whose ``repr`` lists entries one per line, sorted by key.

    Keys and values are rendered as 8-digit hex, so they are expected to be
    integers; each is colorized via ``_val_to_color``.
    """

    def __repr__(self):
        def _paint(number):
            return f"{_val_to_color(number)}0x{number:08X}{Style.RESET_ALL}"

        entries = (f" {_paint(k)}: {_paint(v)}," for k, v in sorted(self.items()))
        return "\n".join(["{", *entries, "}"])
|
|
|
|
|
|
class Firmware(FirmwarePatchMixin, bytearray):
    """A mutable firmware image with bounds-checked slice access.

    Subclasses describe their memory region through the class attributes
    below and may override :meth:`_verify` to validate the loaded image.
    """

    RAM_BASE = 0x02000000
    RAM_LEN = 0x00020000
    ENC_LEN = 0

    FLASH_BASE = 0x0000_0000
    FLASH_LEN = 0

    def __init__(self, firmware=None):
        """Load the image from a file, or create a zero-filled one.

        Parameters
        ----------
        firmware : path-like, optional
            File to read the image from. When falsy, the image is
            ``FLASH_LEN`` zero bytes.
        """
        if firmware:
            with open(firmware, "rb") as f:
                firmware_data = f.read()
            super().__init__(firmware_data)
        else:
            super().__init__(self.FLASH_LEN)

        self._lookup = Lookup()

        self._verify()

    def _verify(self):
        """Hook for subclasses to validate the loaded image; accepts anything here."""
        pass

    def __getitem__(self, key):
        """Properly raises index error if trying to access oob regions.

        For slices, the start and the last included index (``stop - 1``) are
        probed individually, so a slice reaching outside the image raises
        ``IndexError`` instead of being silently truncated.
        """
        if isinstance(key, slice):
            if key.start is not None:
                try:
                    self[key.start]
                except IndexError:
                    raise IndexError(
                        f"Index {key.start} ({hex(key.start)}) out of range"
                    ) from None
            if key.stop is not None:
                try:
                    self[key.stop - 1]
                except IndexError:
                    raise IndexError(
                        f"Index {key.stop - 1} ({hex(key.stop - 1)}) out of range"
                    ) from None

        return super().__getitem__(key)

    def __setitem__(self, key, new_val):
        """Properly raises index error if trying to access oob regions.

        Slice endpoints outside the image raise ``NotEnoughSpaceError`` rather
        than letting ``bytearray`` extend the buffer.
        """
        if isinstance(key, slice):
            if key.start is not None:
                try:
                    self[key.start]
                except IndexError:
                    raise NotEnoughSpaceError(
                        f"Starting index {key.start} ({hex(key.start)}) exceeds "
                        f"firmware length {len(self)} ({hex(len(self))})"
                    ) from None
            if key.stop is not None:
                try:
                    self[key.stop - 1]
                except IndexError:
                    raise NotEnoughSpaceError(
                        f"Ending index {key.stop - 1} ({hex(key.stop - 1)}) exceeds "
                        f"firmware length {len(self)} ({hex(len(self))})"
                    ) from None

        return super().__setitem__(key, new_val)

    def int(self, offset: int, size=4):
        """Decode ``size`` bytes at ``offset`` as an unsigned little-endian integer."""
        # Inside a method body, ``int`` resolves to the builtin, not this method.
        return int.from_bytes(self[offset : offset + size], "little")

    # From here on, ``int`` in this class body names the method above, so the
    # annotations below are quoted to refer to the builtin type instead.

    def set_range(self, start: "int", end: "int", val: bytes):
        """Fill ``self[start:end]`` by repeating ``val``.

        Returns
        -------
        int
            ``end - start``.

        Raises
        ------
        ValueError
            If ``val`` is empty.
        NotEnoughSpaceError
            If the range reaches outside the image.
        """
        if not val:
            raise ValueError("val must be non-empty")
        count = end - start
        # Truncate so a multi-byte pattern can never make the slice assignment
        # resize the underlying bytearray.
        self[start:end] = (val * count)[:count]
        return count

    def clear_range(self, start: "int", end: "int"):
        """Zero ``self[start:end]``; returns ``end - start``."""
        return self.set_range(start, end, val=b"\x00")

    def show(self, wrap=1024, show=True):
        """Plot which bytes of the image are non-zero.

        Parameters
        ----------
        wrap : int
            Bytes per plotted row.
        show : bool
            Call ``plt.show()`` after drawing.
        """
        import matplotlib.pyplot as plt
        import matplotlib.ticker as ticker
        import numpy as np

        def to_hex(x, pos):
            return f"0x{int(x):06X}"

        def to_hex_wrap(x, pos):
            return f"0x{int(x)*wrap:06X}"

        n_bytes = len(self)
        rows = -(-n_bytes // wrap)  # ceiling division
        occupied = np.array(self) != 0
        # Pad the final, partial row with "unoccupied" so the reshape also
        # works when the image length is not a multiple of ``wrap``
        # (e.g. SRAM3's 0xDEDC bytes).
        occupied = np.pad(occupied, (0, rows * wrap - n_bytes))
        plt.imshow(occupied.reshape(rows, wrap))
        plt.title(str(self))
        axes = plt.gca()
        axes.get_xaxis().set_major_locator(ticker.MultipleLocator(128))
        axes.get_xaxis().set_major_formatter(ticker.FuncFormatter(to_hex))
        axes.get_yaxis().set_major_locator(ticker.MultipleLocator(32))
        axes.get_yaxis().set_major_formatter(ticker.FuncFormatter(to_hex_wrap))
        if show:
            plt.show()
|
|
|
|
|
|
class RWData:
    """Rebuilds the firmware's compressed rwdata initialization table.

    On construction the existing table entries are parsed, their
    lz77-compressed payloads are decompressed into ``self.datas`` (with the
    destination addresses in ``self.dsts``), and the original payload bytes
    are cleared. :meth:`write_table_and_data` later writes the payloads back
    LZMA-compressed together with a rebuilt table.

    Table layout as read/written here: 16-byte entries of four 32-bit words
    (relative function pointer, relative data pointer, length, destination),
    followed by trailing words (one when read; ``bss_rwdata_init`` plus the
    original last function when written).

    Assumptions (which are valid for this firmware):

    1. Only compressed rwdata is after this table
    2. We are only modifying the lz_decompress stuff.
    """

    # THIS HAS TO AGREE WITH THE LINKER
    MAX_TABLE_ELEMENTS = 5

    @staticmethod
    def _to_signed32(value):
        """Reinterpret an unsigned 32-bit word as a two's-complement signed int."""
        return value - 0x1_0000_0000 if value >= 0x8000_0000 else value

    def __init__(self, firmware, table_start, table_len):
        """Parse the table at ``table_start`` and decompress its payloads.

        Parameters
        ----------
        firmware : Firmware
            Image containing the table. It is modified in place: the original
            payloads are cleared and the table area is marked reserved.
        table_start : int
            Offset of the table within ``firmware``.
        table_len : int
            Size in bytes of the existing table (16-byte entries plus one
            trailing 4-byte word).
        """
        # We want to be able to extend the table.

        self.firmware = firmware
        self.table_start = table_start
        self.__compressed_len_memo = {}

        self.datas, self.dsts = [], []

        for i in range(table_start, table_start + table_len - 4, 16):
            # First thing is pointer to executable, need to always replace this
            # to our lzma
            fn_addr = i + self._to_signed32(firmware.int(i))
            assert fn_addr == 0x18005  # lz_decompress function
            i += 4

            # The data pointer is relative to the word that holds it.
            data_addr = i + firmware.int(i)
            i += 4

            # NOTE(review): the stored length is shifted right by one here but
            # written back unshifted in write_table_and_data() for
            # rwdata_inflate -- presumably the stock decompressor's encoding;
            # confirm.
            data_len = firmware.int(i) >> 1
            i += 4

            data_dst = firmware.int(i)

            data = lz77_decompress(firmware[data_addr : data_addr + data_len])
            print(f" lz77 decompressed data {data_len} -> {len(data)}")
            firmware.clear_range(data_addr, data_addr + data_len)

            self.append(data, data_dst)

        last_element_offset = table_start + table_len - 4
        self.last_fn = last_element_offset + self._to_signed32(
            firmware.int(last_element_offset)
        )

        # Mark this area as reserved; there's nothing special about 0x77, it's
        # just not 0x00. Size it for the largest table write_table_and_data()
        # can emit (matching table_end): MAX_TABLE_ELEMENTS 16-byte entries
        # plus the bss_rwdata_init and last_fn words.
        firmware.set_range(
            table_start, table_start + 16 * self.MAX_TABLE_ELEMENTS + 4 + 4, b"\x77"
        )

    def __getitem__(self, k):
        """Return the ``k``-th (decompressed) payload."""
        return self.datas[k]

    @property
    def table_end(self):
        """Offset one past the table as written by :meth:`write_table_and_data`.

        16 bytes per entry plus the trailing ``bss_rwdata_init`` and
        ``last_fn`` words.
        """
        return self.table_start + 4 * 4 * len(self.datas) + 4 + 4

    def append(self, data, dst):
        """Add a new element to the table.

        Raises
        ------
        NotEnoughSpaceError
            If the table already holds ``MAX_TABLE_ELEMENTS`` entries.
        """
        if len(self.datas) >= self.MAX_TABLE_ELEMENTS:
            raise NotEnoughSpaceError(
                f"MAX_TABLE_ELEMENTS value {self.MAX_TABLE_ELEMENTS} exceeded"
            )

        self.datas.append(data)
        self.dsts.append(dst)

        assert len(self.datas) == len(self.dsts)

    @property
    def compressed_len(self):
        """Total LZMA-compressed size of all payloads (memoized per payload)."""
        compressed_len = 0
        for data in self.datas:
            data = bytes(data)
            if data not in self.__compressed_len_memo:
                self.__compressed_len_memo[data] = len(lzma_compress(data))
            compressed_len += self.__compressed_len_memo[data]
        return compressed_len

    def write_table_and_data(self, data_offset=None):
        """Write the LZMA-compressed payloads and the rebuilt table.

        Payloads are compressed and written back to back starting at
        ``data_offset``. The table at ``table_start`` is then rewritten with one
        entry per payload whose function word targets ``rwdata_inflate`` (via
        ``firmware.relative``), followed by the ``bss_rwdata_init`` and
        ``last_fn`` words. Finally, the loader's end-of-table pointer at
        0x17DB4 is updated.

        Parameters
        ----------
        data_offset : int, optional
            Where to write the compressed data. Defaults to ``table_end``,
            i.e. immediately after the table.

        Returns
        -------
        int
            Total number of compressed bytes written.
        """
        # Write Compressed Data
        data_addrs, data_lens = [], []
        index = self.table_end if data_offset is None else data_offset

        total_len = 0
        for data in self.datas:
            compressed_data = lzma_compress(bytes(data))
            print(
                f" compressed {len(data)}->{len(compressed_data)} bytes "
                f"(saves {len(data)-len(compressed_data)}). "
                f"Writing to 0x{index:05X}"
            )
            self.firmware[index : index + len(compressed_data)] = compressed_data

            data_addrs.append(index)
            data_lens.append(len(compressed_data))

            index += len(compressed_data)
            total_len += len(compressed_data)

        # Write Table
        index = self.table_start
        assert len(data_addrs) == len(data_lens) == len(self.dsts)
        for data_addr, data_len, data_dst in zip(data_addrs, data_lens, self.dsts):
            self.firmware.relative(index, "rwdata_inflate")
            index += 4

            # Assumes that the data will be after the table.
            rel_addr = data_addr - index
            if rel_addr < 0:
                # Encode a negative offset as an unsigned 32-bit word.
                rel_addr += 0x1_0000_0000
            self.firmware.replace(index, rel_addr, size=4)
            index += 4

            self.firmware.replace(index, data_len, size=4)
            index += 4

            self.firmware.replace(index, data_dst, size=4)
            index += 4

        self.firmware.relative(index, "bss_rwdata_init")
        index += 4

        self.firmware.relative(index, self.last_fn, size=4)
        index += 4

        assert index == self.table_end

        # Update the pointer to the end of table in the loader
        self.firmware.relative(0x17DB4, index, size=4)

        print(self)

        return total_len

    def __str__(self):
        """Returns the **written** table.

        Doesn't show unstaged changes.
        """
        substrs = []
        substrs.append("")
        substrs.append("RWData Table")
        substrs.append("------------")
        for addr in range(self.table_start, self.table_end - 4 - 4, 16):
            substrs.append(
                f"0x{addr:08X}: "
                f"0x{self.firmware.int(addr + 0):08X} "
                f"0x{self.firmware.int(addr + 4):08X} "
                f"0x{self.firmware.int(addr + 8):08X} "
                f"0x{self.firmware.int(addr + 12):08X} "
            )
        addr = self.table_end - 8
        substrs.append(f"0x{addr:08X}: 0x{self.firmware.int(addr + 0):08X}")
        addr = self.table_end - 4
        substrs.append(f"0x{addr:08X}: 0x{self.firmware.int(addr + 0):08X}")

        substrs.append("")
        return "\n".join(substrs)
|
|
|
|
|
|
class IntFirmware(Firmware):
    """Internal-flash firmware image paired with the patch build's ELF file.

    The ELF file handle opened in ``__init__`` stays open so symbols can be
    looked up later; release it with :meth:`close_elf` or by using the object
    as a context manager.
    """

    STOCK_ROM_SHA1_HASH = "efa04c387ad7b40549e15799b471a6e1cd234c76"

    FLASH_BASE = 0x08000000
    FLASH_LEN = 0x00020000

    STOCK_ROM_END = 0x00019300  # Actual stock rom end

    def __init__(self, firmware, elf):
        """Load and verify the stock image, open ``elf``, and parse the rwdata table.

        Parameters
        ----------
        firmware : path-like
            Stock internal firmware dump.
        elf : path-like
            ELF file whose ``.symtab`` is used by :meth:`address`.
        """
        super().__init__(firmware)
        # The handle is kept open because ``self.elf`` reads from it later
        # (pyelftools parses lazily -- confirm if changing this).
        self._elf_f = open(elf, "rb")
        try:
            self.elf = ELFFile(self._elf_f)
            self.symtab = self.elf.get_section_by_name(".symtab")

            self.rwdata = RWData(self, 0x1_80A4, 36)
        except BaseException:
            # Don't leak the handle if construction fails part-way.
            self._elf_f.close()
            raise

    def close_elf(self):
        """Close the ELF file handle; symbol lookups are unusable afterwards."""
        self._elf_f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close_elf()
        return False

    def __str__(self):
        return "internal"

    def _verify(self):
        """Raise InvalidStockRomError unless the whole image matches the stock SHA-1."""
        h = hashlib.sha1(self).hexdigest()
        if h != self.STOCK_ROM_SHA1_HASH:
            raise InvalidStockRomError

    def address(self, symbol_name, sub_base=False):
        """Return the value of ``symbol_name`` from the ELF symbol table.

        Parameters
        ----------
        symbol_name : str
            Symbol to look up; the first match is used.
        sub_base : bool
            If True, subtract ``FLASH_BASE`` so the result is an offset into
            this image instead of an absolute address.

        Raises
        ------
        MissingSymbolError
            If the symbol is not found or its value is 0.
        """
        symbols = self.symtab.get_symbol_by_name(symbol_name)
        if not symbols:
            raise MissingSymbolError(f'Cannot find symbol "{symbol_name}"')
        address = symbols[0]["st_value"]
        if address == 0:
            raise MissingSymbolError(f"{symbol_name} has address 0x0")
        print(f" found {symbol_name} at 0x{address:08X}")
        if sub_base:
            address -= self.FLASH_BASE
        return address

    @property
    def key(self):
        """16 bytes at image offset 0x106F4 (used as the external-flash AES key)."""
        offset = 0x106F4
        return self[offset : offset + 16]

    @property
    def nonce(self):
        """8 bytes at image offset 0x106E4 (used as the external-flash nonce)."""
        offset = 0x106E4
        return self[offset : offset + 8]
|
|
|
|
|
|
def _nonce_to_iv(nonce):
|
|
# need to convert nonce to 2
|
|
assert len(nonce) == 8
|
|
nonce = nonce[::-1]
|
|
# The lower 28bits (counter) will be updated in `crypt` method
|
|
return nonce + b"\x00\x00" + b"\x71\x23" + b"\x20\x00" + b"\x00\x00"
|
|
|
|
|
|
class ExtFirmware(Firmware):
    """External-flash firmware image whose first ``ENC_LEN`` bytes are encrypted."""

    STOCK_ROM_SHA1_HASH = "eea70bb171afece163fb4b293c5364ddb90637ae"

    FLASH_BASE = 0x9000_0000
    FLASH_LEN = 0x0010_0000
    ENC_LEN = 0xF_E000  # end address at 0x080106ec
    STOCK_ROM_END = 0x0010_0000

    def __str__(self):
        return "external"

    def _verify(self):
        """Raise InvalidStockRomError unless the image minus its last 8 KiB matches the stock SHA-1."""
        h = hashlib.sha1(self[:-8192]).hexdigest()
        if h != self.STOCK_ROM_SHA1_HASH:
            raise InvalidStockRomError

    def crypt(self, key, nonce):
        """Decrypts if encrypted; encrypts if in plain text.

        Counter-mode style: each 16-byte block of the first ``ENC_LEN`` bytes
        is XORed, in place, with the byte-reversed AES-ECB encryption of a
        counter block derived from ``nonce`` and the block's address.

        Parameters
        ----------
        key : bytes-like
            16-byte key, stored byte-reversed relative to what AES expects.
        nonce : bytes-like
            8-byte nonce (see ``_nonce_to_iv``).
        """
        key = bytes(key[::-1])
        iv = bytearray(_nonce_to_iv(nonce))

        aes = AES.new(key, AES.MODE_ECB)

        block_size = 128 // 8
        for offset in range(0, self.ENC_LEN, block_size):
            counter_block = iv.copy()

            # Counter = absolute block address in 16-byte units; only its low
            # 28 bits are stored (the high nibble of byte 12 keeps the IV's).
            counter = (self.FLASH_BASE + offset) >> 4
            counter_block[12] = ((counter >> 24) & 0x0F) | (counter_block[12] & 0xF0)
            counter_block[13] = (counter >> 16) & 0xFF
            counter_block[14] = (counter >> 8) & 0xFF
            counter_block[15] = counter & 0xFF

            # The keystream is applied byte-reversed. XOR the whole block in one
            # slice read/write instead of 16 bounds-checked single-byte updates.
            keystream = aes.encrypt(bytes(counter_block))[::-1]
            end = offset + block_size
            self[offset:end] = bytes(
                a ^ b for a, b in zip(self[offset:end], keystream)
            )
|
|
|
|
|
|
class SRAM3(Firmware):
    """Zero-initialized scratch image spanning the unused tail of RAM that ends at 0x2410_0000."""

    # This address of unused ram was found via tools/mem_observer.py
    FLASH_BASE = 0x240F_2124
    FLASH_LEN = 0x2410_0000 - FLASH_BASE

    def __str__(self):
        return "sram3"
|
|
|
|
|
|
class Device(DevicePatchMixin):
    """The internal and external firmware images plus an SRAM3 scratch image.

    All three images share a single :class:`Lookup` instance.
    """

    def __init__(self, internal, external):
        self.internal, self.external = internal, external

        self.sram3 = SRAM3()

        # One mapping object shared by every image, so an entry added through
        # any of them is visible through all of them.
        self.lookup = shared_lookup = Lookup()
        for image in (self.internal, self.external, self.sram3):
            image._lookup = shared_lookup

    def crypt(self):
        """Toggle encryption of the external image with the internal image's key and nonce."""
        internal = self.internal
        self.external.crypt(internal.key, internal.nonce)

    def show(self, show=True):
        """Plot the internal image, stacked above the external one when that is non-empty."""
        import matplotlib.pyplot as plt

        images = (self.internal, self.external) if len(self.external) else (self.internal,)
        for position, image in enumerate(images, start=1):
            if len(images) > 1:
                plt.subplot(len(images), 1, position)
            image.show(show=False)

        if show:
            plt.show()
|