# Files
# game-and-watch-patch/patches/firmware.py
# 2021-09-24 09:08:22 -07:00
#
# 424 lines
# 13 KiB
# Python

import hashlib
from Crypto.Cipher import AES
from elftools.elf.elffile import ELFFile
from colorama import Fore, Back, Style
from .patch import DevicePatchMixin, FirmwarePatchMixin
from .exception import InvalidStockRomError, MissingSymbolError, NotEnoughSpaceError
from .compression import lz77_decompress, lzma_compress
def _val_to_color(val):
if 0x9010_0000 > val >= 0x9000_0000:
return Fore.YELLOW
elif 0x0804_0000 > val >= 0x0800_0000:
return Fore.MAGENTA
else:
return ""
class Lookup(dict):
    """``dict`` of integer -> integer whose repr is a colorized hex listing.

    Entries are printed one per line, sorted, as zero-padded 8-digit
    hexadecimal values; each value is prefixed with the color chosen by
    ``_val_to_color`` and followed by a style reset.
    """

    def __repr__(self):
        lines = ["{"]
        for key, value in sorted(self.items()):
            key_color = _val_to_color(key)
            value_color = _val_to_color(value)
            key_text = f"{key_color}0x{key:08X}{Style.RESET_ALL}"
            value_text = f"{value_color}0x{value:08X}{Style.RESET_ALL}"
            lines.append(f" {key_text}: {value_text},")
        lines.append("}")
        return "\n".join(lines)
class Firmware(FirmwarePatchMixin, bytearray):
    """Mutable firmware image stored as a ``bytearray``.

    A plain ``bytearray`` silently clamps out-of-range slices; this class
    checks both ends of a slice on read and on write and raises instead.
    Subclasses override the address-space constants below and ``_verify``.
    """

    # RAM window accepted when validating symbol addresses.
    RAM_BASE = 0x02000000
    RAM_LEN = 0x00020000
    # Number of leading bytes covered by encryption (0 = none).
    ENC_LEN = 0
    # Address the image is mapped at, and its size in bytes.  FLASH_LEN is
    # also the size of the zero-filled image created when no file is given.
    FLASH_BASE = 0x0000_0000
    FLASH_LEN = 0

    def __init__(self, firmware=None):
        """Load the image from the file ``firmware``, or create a blank one.

        Parameters
        ----------
        firmware : path-like or None
            File to read.  When falsy, a zero-filled image of ``FLASH_LEN``
            bytes is created instead.
        """
        if firmware:
            with open(firmware, 'rb') as f:
                firmware_data = f.read()
            super().__init__(firmware_data)
        else:
            super().__init__(self.FLASH_LEN)
        self._lookup = Lookup()
        self._verify()

    def _verify(self):
        """Hook for subclasses to validate the loaded image; no-op here."""
        pass

    def __getitem__(self, key):
        """ Properly raises index error if trying to access oob regions.

        For a slice, the first index (``start``) and the last index
        (``stop - 1``) are each probed individually so that an out-of-range
        bound raises ``IndexError`` instead of being clamped.
        """
        if isinstance(key, slice):
            if key.start is not None:
                try:
                    self[key.start]
                except IndexError:
                    raise IndexError(f"Index {key.start} ({hex(key.start)}) out of range") from None
            if key.stop is not None:
                try:
                    self[key.stop - 1]
                except IndexError:
                    raise IndexError(f"Index {key.stop - 1} ({hex(key.stop - 1)}) out of range") from None
        return super().__getitem__(key)

    def __setitem__(self, key, new_val):
        """ Properly raises index error if trying to access oob regions.

        Same bound probing as ``__getitem__``, but an out-of-range slice
        bound raises ``NotEnoughSpaceError`` because it means the write does
        not fit in the image.
        """
        if isinstance(key, slice):
            if key.start is not None:
                try:
                    self[key.start]
                except IndexError:
                    raise NotEnoughSpaceError(f"Starting index {key.start} ({hex(key.start)}) exceeds firmware length {len(self)} ({hex(len(self))})") from None
            if key.stop is not None:
                try:
                    self[key.stop - 1]
                except IndexError:
                    raise NotEnoughSpaceError(f"Ending index {key.stop - 1} ({hex(key.stop - 1)}) exceeds firmware length {len(self)} ({hex(len(self))})") from None
        return super().__setitem__(key, new_val)

    # NOTE: this method shadows the builtin ``int`` inside the class body, so
    # the ``int`` annotations on the methods defined after it resolve to this
    # function.  They are documentation only and have no runtime effect.
    def int(self, offset : int, size=4):
        """Return ``size`` bytes at ``offset`` as a little-endian unsigned int."""
        return int.from_bytes(self[offset:offset+size], 'little')

    def set_range(self, start : int, end : int, val : bytes):
        """Fill ``[start, end)`` with ``val`` repeated; return the byte count.

        ``val`` is repeated ``end - start`` times, so it is expected to be a
        single byte.
        """
        self[start:end] = val * (end - start)
        return end - start

    def clear_range(self, start : int, end : int):
        """Zero ``[start, end)`` and return the number of bytes cleared."""
        return self.set_range(start, end, val=b"\x00")

    def show(self, wrap=1024, show=True):
        """Plot which bytes of the image are non-zero as a 2D bitmap.

        Parameters
        ----------
        wrap : int
            Number of bytes per image row.
        show : bool
            Call ``plt.show()`` after drawing.  Pass ``False`` to compose
            several images into one figure.
        """
        import numpy as np
        import matplotlib.pyplot as plt
        import matplotlib.ticker as ticker

        def to_hex(x, pos):
            return f"0x{int(x):06X}"

        def to_hex_wrap(x, pos):
            return f"0x{int(x)*wrap:06X}"

        n_bytes = len(self)
        rows = int(np.ceil(n_bytes / wrap))
        # Pad the tail of the last row with "unoccupied" so the reshape also
        # works when the image length is not a multiple of ``wrap``.
        occupied = np.zeros(rows * wrap, dtype=bool)
        occupied[:n_bytes] = np.array(self) != 0
        plt.imshow(occupied.reshape(rows, wrap))
        plt.title(str(self))
        axes = plt.gca()
        axes.get_xaxis().set_major_locator(ticker.MultipleLocator(128))
        axes.get_xaxis().set_major_formatter(ticker.FuncFormatter(to_hex))
        axes.get_yaxis().set_major_locator(ticker.MultipleLocator(32))
        axes.get_yaxis().set_major_formatter(ticker.FuncFormatter(to_hex_wrap))
        if show:
            plt.show()
class RWData:
    """Model of the firmware's table of compressed read/write data.

    The table is a sequence of 16-byte entries followed by one trailing
    4-byte word.  Each entry holds four little-endian 32-bit words: a
    self-relative pointer to the decompression function, a self-relative
    pointer to the compressed data, a length word, and the destination
    address.  On construction every entry is decompressed and its compressed
    bytes are cleared from the firmware; ``write_table_and_data`` later
    re-emits the table with LZMA-compressed payloads.

    Assumptions (which are valid for this firmware):
    1. Only compressed rwdata is after this table
    2. We are only modifying the lz_decompress stuff.
    """

    # THIS HAS TO AGREE WITH THE LINKER
    MAX_TABLE_ELEMENTS = 5

    def __init__(self, firmware, table_start, table_len):
        """Parse the existing table and take ownership of its contents.

        Parameters
        ----------
        firmware
            Firmware image the table lives in; it is modified in place.
        table_start : int
            Offset of the first table entry within ``firmware``.
        table_len : int
            Length of the table in bytes, including the trailing 4-byte word.
        """
        # We want to be able to extend the table.
        self.firmware = firmware
        self.table_start = table_start
        # Cache of compressed sizes keyed by the uncompressed bytes.
        self.__compressed_len_memo = {}
        self.datas, self.dsts = [], []
        for i in range(table_start, table_start + table_len - 4, 16):
            # ``i`` is advanced by hand through the four words of the entry;
            # ``range`` resets it at the start of the next iteration.

            # First thing is pointer to executable, need to always replace this
            # to our lzma
            rel_offset_to_fn = firmware.int(i)
            # Reinterpret the unsigned 32-bit word as a signed offset.
            if rel_offset_to_fn > 0x8000_0000:
                rel_offset_to_fn -= 0x1_0000_0000
            fn_addr = i + rel_offset_to_fn
            assert fn_addr == 0x18005 # lz_decompress function
            i += 4
            # Data pointer is relative to the address of this word.
            data_addr = i + firmware.int(i)
            i += 4
            # NOTE(review): the stored word is shifted right by one to get the
            # compressed byte count; the meaning of the low bit is not visible
            # here -- confirm against the loader.
            data_len = firmware.int(i) >> 1
            i += 4
            data_dst = firmware.int(i)
            i += 4
            data = lz77_decompress(firmware[data_addr : data_addr + data_len])
            print(f" lz77 decompressed data {data_len} -> {len(data)}")
            # Free the space the compressed payload occupied.
            firmware.clear_range(data_addr, data_addr + data_len)
            self.append(data, data_dst)
        # The trailing word is another self-relative pointer; resolve it to an
        # absolute offset so it can be re-encoded after the table grows.
        last_element_offset = table_start + table_len - 4
        self.last_fn = firmware.int(last_element_offset)
        if self.last_fn > 0x8000_0000:
            self.last_fn -= 0x1_0000_0000
        self.last_fn += last_element_offset
        # Mark this area as reserved; there's nothing special about 0x77, its
        # just not 0x00
        firmware.set_range(table_start, table_start + 16 * self.MAX_TABLE_ELEMENTS + 4, b"\x77")

    def __getitem__(self, k):
        """Return the uncompressed data of table element ``k``."""
        return self.datas[k]

    @property
    def table_end(self):
        """Offset just past the table: 16 bytes per element plus the trailer."""
        return self.table_start + 4 * 4 * len(self.datas) + 4

    def append(self, data, dst):
        """ Add a new element to the table

        Parameters
        ----------
        data
            Uncompressed payload.
        dst : int
            Destination address stored in the table entry.

        Raises
        ------
        NotEnoughSpaceError
            If the table already holds ``MAX_TABLE_ELEMENTS`` entries.
        """
        if len(self.datas) >= self.MAX_TABLE_ELEMENTS:
            raise NotEnoughSpaceError(f"MAX_TABLE_ELEMENTS exceeded; increase this value")
        self.datas.append(data)
        self.dsts.append(dst)
        assert len(self.datas) == len(self.dsts)

    @property
    def compressed_len(self):
        """Total LZMA-compressed size in bytes of all current elements.

        Sizes are memoized per payload, so unchanged payloads are not
        recompressed on repeated access.
        """
        compressed_len = 0
        for data in self.datas:
            # Convert to immutable bytes so the payload can be a dict key.
            data = bytes(data)
            if data not in self.__compressed_len_memo:
                compressed_data = lzma_compress(bytes(data))
                self.__compressed_len_memo[data] = len(compressed_data)
            compressed_len += self.__compressed_len_memo[data]
        return compressed_len

    def write_table_and_data(self, data_offset=None):
        """Compress every element and write the payloads and table to firmware.

        Parameters
        ----------
        data_offset : int
            Where to write the compressed data.  Defaults to the offset
            immediately after the table.

        Returns
        -------
        int
            Total number of compressed payload bytes written.
        """
        # Write Compressed Data
        data_addrs, data_lens = [], []
        if data_offset is None:
            index = self.table_end
        else:
            index = data_offset
        total_len = 0
        for data in self.datas:
            compressed_data = lzma_compress(bytes(data))
            print(f" compressed {len(data)}->{len(compressed_data)} bytes (saves {len(data)-len(compressed_data)}). Writing to 0x{index:05X}")
            self.firmware[index:index+len(compressed_data)] = compressed_data
            data_addrs.append(index)
            data_lens.append(len(compressed_data))
            index += len(compressed_data)
            total_len += len(compressed_data)
        # Write Table
        index = self.table_start
        assert len(data_addrs) == len(data_lens) == len(self.dsts)
        for data_addr, data_len, data_dst in zip(data_addrs, data_lens, self.dsts):
            # Word 0: pointer to the decompressor.
            # NOTE(review): ``relative`` comes from the firmware's patch mixin;
            # presumably it stores a self-relative pointer to the named symbol
            # -- confirm there.
            self.firmware.relative(index, "rwdata_inflate")
            index += 4
            # Assumes that the data will be after the table.
            # Word 1: data pointer relative to this word, wrapped to an
            # unsigned 32-bit value when negative.
            rel_addr = data_addr - index
            if rel_addr < 0:
                rel_addr += 0x1_0000_0000
            self.firmware.replace(index, rel_addr, size=4)
            index += 4
            # Word 2: compressed length (written unshifted, unlike the value
            # read in ``__init__``).
            self.firmware.replace(index, data_len, size=4)
            index += 4
            # Word 3: destination address.
            self.firmware.replace(index, data_dst, size=4)
            index += 4
        # Trailing word: re-encode the pointer captured in ``__init__``.
        self.firmware.relative(index, self.last_fn, size=4)
        index += 4
        assert index == self.table_end
        # Update the pointer to the end of table in the loader
        self.firmware.relative(0x17db4, index, size=4)
        print(self)
        return total_len

    def __str__(self):
        """ Returns the **written** table.
        Doesn't show unstaged changes.

        The words are read back from the firmware, one row of four words per
        element plus a final row for the trailing word.
        """
        substrs = []
        substrs.append("")
        substrs.append("RWData Table")
        substrs.append("------------")
        for addr in range(self.table_start, self.table_end - 4, 16):
            substrs.append(f"0x{addr:08X}: "
                           f"0x{self.firmware.int(addr + 0):08X} "
                           f"0x{self.firmware.int(addr + 4):08X} "
                           f"0x{self.firmware.int(addr + 8):08X} "
                           f"0x{self.firmware.int(addr + 12):08X} "
                           )
        addr = self.table_end - 4
        substrs.append(f"0x{addr:08X}: 0x{self.firmware.int(addr + 0):08X}")
        substrs.append("")
        return "\n".join(substrs)
class IntFirmware(Firmware):
    """Internal-flash firmware image paired with the ELF of the patch code.

    The ELF symbol table is used to resolve symbol names to addresses, and
    the image itself supplies the key and nonce used to (de)crypt the
    external flash.
    """

    STOCK_ROM_SHA1_HASH = "efa04c387ad7b40549e15799b471a6e1cd234c76"
    FLASH_BASE = 0x08000000
    FLASH_LEN = 0x00020000
    STOCK_ROM_END = 0x00019300  # Actual stock rom end

    def __init__(self, firmware, elf):
        """Load and verify the stock image, then open the ELF and rwdata table.

        Parameters
        ----------
        firmware : path-like
            Internal firmware dump; its SHA-1 is checked by ``_verify``.
        elf : path-like
            ELF file whose ``.symtab`` is used by ``address``.
        """
        super().__init__(firmware)
        # The stream must stay open for as long as ``self.elf`` is used.
        self._elf_f = open(elf, 'rb')
        try:
            self.elf = ELFFile(self._elf_f)
            self.symtab = self.elf.get_section_by_name('.symtab')
            self.rwdata = RWData(self, 0x1_80a4, 36)
        except BaseException:
            # Don't leak the handle when construction fails part-way.
            self._elf_f.close()
            raise

    def __str__(self):
        return "internal"

    def _verify(self):
        """Raise ``InvalidStockRomError`` unless the image matches the stock SHA-1."""
        h = hashlib.sha1(self).hexdigest()
        if h != self.STOCK_ROM_SHA1_HASH:
            raise InvalidStockRomError

    def address(self, symbol_name):
        """Return the address of ``symbol_name`` from the ELF symbol table.

        Raises
        ------
        MissingSymbolError
            If the ELF has no symbol table, the symbol is absent, or its
            address is zero or outside both the RAM and flash windows
            (window upper bounds are inclusive).
        """
        if self.symtab is None:
            # ``get_section_by_name`` returns None for an ELF without .symtab.
            raise MissingSymbolError(f"Cannot find symbol \"{symbol_name}\"; ELF has no .symtab section")
        symbols = self.symtab.get_symbol_by_name(symbol_name)
        if not symbols:
            raise MissingSymbolError(f"Cannot find symbol \"{symbol_name}\"")
        address = symbols[0]['st_value']
        if not address or not (
            (self.RAM_BASE <= address <= self.RAM_BASE + self.RAM_LEN) or
            (self.FLASH_BASE <= address <= self.FLASH_BASE + self.FLASH_LEN)
        ):
            raise MissingSymbolError(f"Symbol \"{symbol_name}\" has invalid address 0x{address:08X}")
        print(f" found {symbol_name} at 0x{address:08X}")
        return address

    @property
    def key(self):
        """The 16 bytes at offset 0x106f4, used as the external-flash key."""
        offset = 0x106f4
        return self[offset: offset+16]

    @property
    def nonce(self):
        """The 8 bytes at offset 0x106e4, used as the external-flash nonce."""
        offset = 0x106e4
        return self[offset:offset+8]
def _nonce_to_iv(nonce):
# need to convert nonce to 2
assert len(nonce) == 8
nonce = nonce[::-1]
# The lower 28bits (counter) will be updated in `crypt` method
return nonce + b"\x00\x00" + b"\x71\x23" + b"\x20\x00" + b"\x00\x00"
class ExtFirmware(Firmware):
    """External-flash firmware image, whose first ``ENC_LEN`` bytes are encrypted."""

    STOCK_ROM_SHA1_HASH = "eea70bb171afece163fb4b293c5364ddb90637ae"
    FLASH_BASE = 0x9000_0000
    FLASH_LEN = 0x0010_0000
    ENC_LEN = 0xF_E000  # end address at 0x080106ec
    STOCK_ROM_END = 0x0010_0000

    def __str__(self):
        return "external"

    def _verify(self):
        """Check the SHA-1 of the image, excluding its final 8192 bytes."""
        digest = hashlib.sha1(self[:-8192]).hexdigest()
        if digest != self.STOCK_ROM_SHA1_HASH:
            raise InvalidStockRomError

    def crypt(self, key, nonce):
        """ Decrypts if encrypted; encrypts if in plain text.

        For every 16-byte block in the first ``ENC_LEN`` bytes, a counter
        block derived from ``nonce`` and the block's flash address is
        AES-ECB encrypted with the byte-reversed ``key``; the byte-reversed
        result is XORed into the image in place.  XOR is its own inverse,
        so the same call toggles between the two states.
        """
        block_len = 128 // 8
        key = bytes(key[::-1])
        iv_template = bytearray(_nonce_to_iv(nonce))
        cipher = AES.new(key, AES.MODE_ECB)
        for offset in range(0, self.ENC_LEN, block_len):
            # The counter is the block's flash address in 16-byte units; its
            # low 28 bits go into the low nibble of byte 12 and bytes 13..15.
            counter = (self.FLASH_BASE + offset) >> 4
            counter_block = bytearray(iv_template)
            counter_block[12] = (counter_block[12] & 0xF0) | ((counter >> 24) & 0x0F)
            counter_block[13:16] = (counter & 0xFF_FFFF).to_bytes(3, "big")
            keystream = cipher.encrypt(bytes(counter_block))[::-1]
            for i, stream_byte in enumerate(keystream):
                position = offset + i
                self[position] = self[position] ^ stream_byte
class SRAM3(Firmware):
    """Blank image covering a region of unused RAM ending at 0x2410_0000."""

    # This address of unused ram was found via tools/mem_observer.py
    FLASH_BASE = 0x240F_2124
    FLASH_LEN = 0x2410_0000 - FLASH_BASE

    def __str__(self):
        return "sram3"
class Device(DevicePatchMixin):
    """Bundle of the internal, external and SRAM3 images sharing one lookup."""

    def __init__(self, internal, external):
        """Store the two flash images, create the SRAM3 image, share a ``Lookup``.

        Every image's ``_lookup`` is replaced by the single ``Lookup``
        instance held in ``self.lookup``.
        """
        self.internal = internal
        self.external = external
        self.sram3 = SRAM3()
        self.lookup = Lookup()
        for image in (self.internal, self.external, self.sram3):
            image._lookup = self.lookup

    def crypt(self):
        """Toggle encryption of the external image using the internal key and nonce."""
        key = self.internal.key
        nonce = self.internal.nonce
        self.external.crypt(key, nonce)

    def show(self, show=True):
        """Plot occupancy of the internal image, plus the external one if non-empty.

        Parameters
        ----------
        show : bool
            Call ``plt.show()`` once the figure has been drawn.
        """
        import matplotlib.pyplot as plt

        if len(self.external) == 0:
            self.internal.show(show=False)
        else:
            # Internal image on top, external image below.
            plt.subplot(2, 1, 1)
            self.internal.show(show=False)
            plt.subplot(2, 1, 2)
            self.external.show(show=False)
        if show:
            plt.show()