Make sure all streams are really closed via try/finally blocks. Added a "closeAll" function to StreamUtils which helps to close a list of closeables

This commit is contained in:
Maschell 2019-04-19 11:41:12 +02:00
parent fccd8f8bf4
commit 9f6f9aaabe
6 changed files with 299 additions and 266 deletions

View File

@ -28,6 +28,7 @@ import java.util.Arrays;
import de.mas.wiiu.jnus.implementations.wud.WUDImage;
import de.mas.wiiu.jnus.utils.PipedInputStreamWithException;
import de.mas.wiiu.jnus.utils.StreamUtils;
import de.mas.wiiu.jnus.utils.cryptography.AESDecryption;
import lombok.Getter;
import lombok.extern.java.Log;
@ -114,41 +115,44 @@ public abstract class WUDDiscReader {
final int BLOCK_SIZE = 0x10000;
long totalread = 0;
do {
long blockNumber = (usedFileOffset / BLOCK_SIZE);
long blockOffset = (usedFileOffset % BLOCK_SIZE);
try {
do {
long blockNumber = (usedFileOffset / BLOCK_SIZE);
long blockOffset = (usedFileOffset % BLOCK_SIZE);
readOffset = clusterOffset + (blockNumber * BLOCK_SIZE);
// (long)WiiUDisc.WIIU_DECRYPTED_AREA_OFFSET + volumeOffset + clusterOffset + (blockStructure.getBlockNumber() * 0x8000);
readOffset = clusterOffset + (blockNumber * BLOCK_SIZE);
// (long)WiiUDisc.WIIU_DECRYPTED_AREA_OFFSET + volumeOffset + clusterOffset + (blockStructure.getBlockNumber() * 0x8000);
if (!useFixedIV) {
ByteBuffer byteBuffer = ByteBuffer.allocate(0x10);
byteBuffer.position(0x08);
usedIV = byteBuffer.putLong(usedFileOffset >> 16).array();
}
buffer = readDecryptedChunk(readOffset, key, usedIV);
maxCopySize = BLOCK_SIZE - blockOffset;
copySize = (usedSize > maxCopySize) ? maxCopySize : usedSize;
try {
outputStream.write(Arrays.copyOfRange(buffer, (int) blockOffset, (int) (blockOffset + copySize)));
} catch (IOException e) {
if (e.getMessage().equals("Pipe closed")) {
break;
} else {
throw e;
if (!useFixedIV) {
ByteBuffer byteBuffer = ByteBuffer.allocate(0x10);
byteBuffer.position(0x08);
usedIV = byteBuffer.putLong(usedFileOffset >> 16).array();
}
}
totalread += copySize;
buffer = readDecryptedChunk(readOffset, key, usedIV);
maxCopySize = BLOCK_SIZE - blockOffset;
copySize = (usedSize > maxCopySize) ? maxCopySize : usedSize;
// update counters
usedSize -= copySize;
usedFileOffset += copySize;
} while (totalread < size);
try {
outputStream.write(Arrays.copyOfRange(buffer, (int) blockOffset, (int) (blockOffset + copySize)));
} catch (IOException e) {
if (e.getMessage().equals("Pipe closed")) {
break;
} else {
throw e;
}
}
totalread += copySize;
// update counters
usedSize -= copySize;
usedFileOffset += copySize;
} while (totalread < size);
} finally {
StreamUtils.closeAll(outputStream);
}
outputStream.close();
return totalread >= size;
}

View File

@ -23,6 +23,7 @@ import java.util.Arrays;
import de.mas.wiiu.jnus.implementations.wud.WUDImage;
import de.mas.wiiu.jnus.implementations.wud.WUDImageCompressedInfo;
import de.mas.wiiu.jnus.utils.StreamUtils;
public class WUDDiscReaderCompressed extends WUDDiscReader {
@ -55,40 +56,40 @@ public class WUDDiscReaderCompressed extends WUDDiscReader {
byte[] buffer = new byte[bufferSize];
RandomAccessFile input = getRandomAccessFileStream();
synchronized (input) {
while (usedSize > 0) {
long sectorOffset = (usedOffset % info.getSectorSize());
long remainingSectorBytes = info.getSectorSize() - sectorOffset;
long sectorIndex = (usedOffset / info.getSectorSize());
int bytesToRead = (int) ((remainingSectorBytes < usedSize) ? remainingSectorBytes : usedSize); // read only up to the end of the current sector
// look up real sector index
long realSectorIndex = info.getSectorIndex((int) sectorIndex);
long offset2 = info.getOffsetSectorArray() + realSectorIndex * info.getSectorSize() + sectorOffset;
try {
synchronized (input) {
while (usedSize > 0) {
long sectorOffset = (usedOffset % info.getSectorSize());
long remainingSectorBytes = info.getSectorSize() - sectorOffset;
long sectorIndex = (usedOffset / info.getSectorSize());
int bytesToRead = (int) ((remainingSectorBytes < usedSize) ? remainingSectorBytes : usedSize); // read only up to the end of the current
// sector
// look up real sector index
long realSectorIndex = info.getSectorIndex((int) sectorIndex);
long offset2 = info.getOffsetSectorArray() + realSectorIndex * info.getSectorSize() + sectorOffset;
input.seek(offset2);
int read = input.read(buffer);
input.seek(offset2);
int read = input.read(buffer);
if (read < 0) {
break;
}
try {
out.write(Arrays.copyOfRange(buffer, 0, bytesToRead));
} catch (IOException e) {
if (e.getMessage().equals("Pipe closed")) {
if (read < 0) {
break;
} else {
input.close();
throw e;
}
}
try {
out.write(Arrays.copyOfRange(buffer, 0, bytesToRead));
} catch (IOException e) {
if (e.getMessage().equals("Pipe closed")) {
break;
} else {
throw e;
}
}
usedSize -= bytesToRead;
usedOffset += bytesToRead;
usedSize -= bytesToRead;
usedOffset += bytesToRead;
}
}
input.close();
}
synchronized (out) {
out.close();
} finally {
StreamUtils.closeAll(input, out);
}
return usedSize == 0;
}

View File

@ -53,8 +53,11 @@ public final class FileUtils {
*/
/**
 * Writes the given byte array to the given file, replacing any existing content.
 *
 * @param output the target file; must not be null
 * @param data   the bytes to write
 * @return {@code true} when the data was written successfully
 * @throws IOException if the file cannot be opened or written
 */
public static boolean saveByteArrayToFile(@NonNull File output, byte[] data) throws IOException {
    // try-with-resources guarantees the stream is closed on both success and failure,
    // replacing the manual try/finally and the duplicated write/close left over from the old version.
    try (FileOutputStream out = new FileOutputStream(output)) {
        out.write(data);
    }
    return true;
}
@ -81,12 +84,14 @@ public final class FileUtils {
tempFile.createNewFile();
RandomAccessFile outStream = new RandomAccessFile(tempFilePath, "rw");
outStream.setLength(filesize);
outStream.seek(0L);
try {
outStream.setLength(filesize);
outStream.seek(0L);
action.apply(new RandomFileOutputStream(outStream));
outStream.close();
action.apply(new RandomFileOutputStream(outStream));
} finally {
outStream.close();
}
// Rename temp file.
if (outputFile.exists()) {

View File

@ -120,23 +120,25 @@ public final class HashUtil {
int inBlockBufferRead = 0;
byte[] blockBuffer = new byte[bufferSize];
ByteArrayBuffer overflow = new ByteArrayBuffer(bufferSize);
do {
inBlockBufferRead = StreamUtils.getChunkFromStream(in, blockBuffer, overflow, bufferSize);
try {
do {
inBlockBufferRead = StreamUtils.getChunkFromStream(in, blockBuffer, overflow, bufferSize);
if (inBlockBufferRead <= 0) break;
if (inBlockBufferRead <= 0) break;
digest.update(blockBuffer, 0, inBlockBufferRead);
cur_position += inBlockBufferRead;
digest.update(blockBuffer, 0, inBlockBufferRead);
cur_position += inBlockBufferRead;
} while (cur_position < target_size);
long missing_bytes = target_size - cur_position;
if (missing_bytes > 0) {
byte[] missing = new byte[(int) missing_bytes];
digest.update(missing, 0, (int) missing_bytes);
} while (cur_position < target_size);
long missing_bytes = target_size - cur_position;
if (missing_bytes > 0) {
byte[] missing = new byte[(int) missing_bytes];
digest.update(missing, 0, (int) missing_bytes);
}
} finally {
in.close();
}
in.close();
return digest.digest();
}

View File

@ -16,6 +16,7 @@
****************************************************************************/
package de.mas.wiiu.jnus.utils;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
@ -32,29 +33,40 @@ public final class StreamUtils {
// Utility class
}
/**
* Tries to read a given number of bytes from a stream and returns them as
* a byte array. Closes the input stream on success AND failure.
* @param in
* @param size
* @return
* @throws IOException
*/
public static byte[] getBytesFromStream(InputStream in, int size) throws IOException {
synchronized (in) {
byte[] result = new byte[size];
byte[] buffer = null;
if (size < 0x8000) {
buffer = new byte[size];
} else {
buffer = new byte[0x8000];
}
int toRead = size;
int curReadChunk = buffer.length;
do {
if (toRead < curReadChunk) {
curReadChunk = toRead;
try {
synchronized (in) {
byte[] result = new byte[size];
byte[] buffer = null;
if (size < 0x8000) {
buffer = new byte[size];
} else {
buffer = new byte[0x8000];
}
int read = in.read(buffer, 0, curReadChunk);
StreamUtils.checkForException(in);
if (read < 0) break;
System.arraycopy(buffer, 0, result, size - toRead, read);
toRead -= read;
} while (toRead > 0);
in.close();
return result;
int toRead = size;
int curReadChunk = buffer.length;
do {
if (toRead < curReadChunk) {
curReadChunk = toRead;
}
int read = in.read(buffer, 0, curReadChunk);
StreamUtils.checkForException(in);
if (read < 0) break;
System.arraycopy(buffer, 0, result, size - toRead, read);
toRead -= read;
} while (toRead > 0);
return result;
}
} finally {
StreamUtils.closeAll(in);
}
}
@ -141,43 +153,44 @@ public final class StreamUtils {
int read = 0;
long totalRead = 0;
long written = 0;
do {
read = inputStream.read(buffer);
StreamUtils.checkForException(inputStream);
if (read < 0) {
break;
}
totalRead += read;
if (totalRead > filesize) {
read = (int) (read - (totalRead - filesize));
try {
do {
read = inputStream.read(buffer);
StreamUtils.checkForException(inputStream);
if (read < 0) {
break;
}
totalRead += read;
if (totalRead > filesize) {
read = (int) (read - (totalRead - filesize));
}
outputStream.write(buffer, 0, read);
written += read;
if (sha1 != null) {
sha1.update(buffer, 0, read);
}
} while (written < filesize);
if (sha1 != null && hash != null) {
long missingInHash = expectedSizeForHash - written;
if (missingInHash > 0) {
sha1.update(new byte[(int) missingInHash]);
}
byte[] calculated_hash = sha1.digest();
byte[] expected_hash = hash;
if (!Arrays.equals(calculated_hash, expected_hash)) {
throw new CheckSumWrongException("Hash doesn't match saves output stream.", calculated_hash, expected_hash);
}
}
outputStream.write(buffer, 0, read);
written += read;
if (sha1 != null) {
sha1.update(buffer, 0, read);
}
} while (written < filesize);
if (sha1 != null && hash != null) {
long missingInHash = expectedSizeForHash - written;
if (missingInHash > 0) {
sha1.update(new byte[(int) missingInHash]);
}
byte[] calculated_hash = sha1.digest();
byte[] expected_hash = hash;
if (!Arrays.equals(calculated_hash, expected_hash)) {
outputStream.close();
inputStream.close();
throw new CheckSumWrongException("Hash doesn't match saves output stream.", calculated_hash, expected_hash);
}
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
outputStream.close();
inputStream.close();
}
}
@ -206,4 +219,17 @@ public final class StreamUtils {
}
}
/**
 * Closes every given {@link Closeable}, continuing past individual failures so
 * that one failing close does not leave the remaining streams open.
 *
 * @param stream the closeables to close; {@code null} entries are skipped
 * @throws IOException the first close failure encountered; any further
 *                     failures are attached to it as suppressed exceptions
 */
public static void closeAll(Closeable... stream) throws IOException {
    IOException exception = null;
    for (Closeable cur : stream) {
        if (cur == null) {
            // tolerate null entries so callers can pass possibly-unset streams
            continue;
        }
        try {
            cur.close();
        } catch (IOException e) {
            if (exception == null) {
                exception = e;
            } else {
                // keep the first failure as primary instead of silently
                // overwriting it; later failures stay visible as suppressed
                exception.addSuppressed(e);
            }
        }
    }
    if (exception != null) {
        throw exception;
    }
}
}

View File

@ -31,7 +31,6 @@ import de.mas.wiiu.jnus.entities.content.Content;
import de.mas.wiiu.jnus.utils.ByteArrayBuffer;
import de.mas.wiiu.jnus.utils.CheckSumWrongException;
import de.mas.wiiu.jnus.utils.HashUtil;
import de.mas.wiiu.jnus.utils.PipedInputStreamWithException;
import de.mas.wiiu.jnus.utils.StreamUtils;
import de.mas.wiiu.jnus.utils.Utils;
import lombok.extern.java.Log;
@ -91,106 +90,105 @@ public class NUSDecryption extends AESDecryption {
int skipoffset = (int) (fileOffset % 0x8000);
// If we are at the beginning of a block, but it's not the first one,
// we need to get the IV from the last 16 bytes of the previous block.
// while being careful to read exactly 16 bytes but not more. Reading more
// would destroy our input stream.
// The input stream has been prepared to start 16 bytes earlier on this case.
if (fileOffset >= 0x8000 && fileOffset % 0x8000 == 0) {
int toRead = 16;
byte[] data = new byte[toRead];
int readTotal = 0;
while (readTotal < toRead) {
int res = inputStream.read(data, readTotal, toRead - readTotal);
StreamUtils.checkForException(inputStream);
if (res < 0) {
// This should NEVER happen.
throw new IOException();
try {
// If we are at the beginning of a block, but it's not the first one,
// we need to get the IV from the last 16 bytes of the previous block.
// while being careful to read exactly 16 bytes but not more. Reading more
// would destroy our input stream.
// The input stream has been prepared to start 16 bytes earlier on this case.
if (fileOffset >= 0x8000 && fileOffset % 0x8000 == 0) {
int toRead = 16;
byte[] data = new byte[toRead];
int readTotal = 0;
while (readTotal < toRead) {
int res = inputStream.read(data, readTotal, toRead - readTotal);
StreamUtils.checkForException(inputStream);
if (res < 0) {
// This should NEVER happen.
throw new IOException();
}
readTotal += res;
}
readTotal += res;
IV = Arrays.copyOfRange(data, 0, toRead);
}
IV = Arrays.copyOfRange(data, 0, toRead);
}
ByteArrayBuffer overflow = new ByteArrayBuffer(BLOCKSIZE);
ByteArrayBuffer overflow = new ByteArrayBuffer(BLOCKSIZE);
// We can only decrypt multiples of 16. So we need to align it.
long toRead = Utils.align(filesize + 15, 16);
// We can only decrypt multiples of 16. So we need to align it.
long toRead = Utils.align(filesize + 15, 16);
do {
// In case we start on the middle of a block we need to consume the "garbage" and save the
// current IV.
if (skipoffset > 0) {
int skippedBytes = StreamUtils.getChunkFromStream(inputStream, blockBuffer, overflow, skipoffset);
if (skippedBytes >= 16) {
IV = Arrays.copyOfRange(blockBuffer, skippedBytes - 16, skippedBytes);
do {
// In case we start on the middle of a block we need to consume the "garbage" and save the
// current IV.
if (skipoffset > 0) {
int skippedBytes = StreamUtils.getChunkFromStream(inputStream, blockBuffer, overflow, skipoffset);
if (skippedBytes >= 16) {
IV = Arrays.copyOfRange(blockBuffer, skippedBytes - 16, skippedBytes);
}
skipoffset = 0;
}
skipoffset = 0;
}
int curReadSize = BLOCKSIZE;
if (toRead < BLOCKSIZE) {
curReadSize = (int) toRead;
}
int curReadSize = BLOCKSIZE;
if (toRead < BLOCKSIZE) {
curReadSize = (int) toRead;
}
inBlockBuffer = StreamUtils.getChunkFromStream(inputStream, blockBuffer, overflow, curReadSize);
inBlockBuffer = StreamUtils.getChunkFromStream(inputStream, blockBuffer, overflow, curReadSize);
byte[] output = decryptFileChunk(blockBuffer, (int) Utils.align(inBlockBuffer, 16), IV);
byte[] output = decryptFileChunk(blockBuffer, (int) Utils.align(inBlockBuffer, 16), IV);
if (inBlockBuffer == BLOCKSIZE) {
IV = Arrays.copyOfRange(blockBuffer, BLOCKSIZE - 16, BLOCKSIZE);
}
if (inBlockBuffer == BLOCKSIZE) {
IV = Arrays.copyOfRange(blockBuffer, BLOCKSIZE - 16, BLOCKSIZE);
}
int toWrite = inBlockBuffer;
int toWrite = inBlockBuffer;
if ((written + inBlockBuffer) > filesize) {
toWrite = (int) (filesize - written);
}
if ((written + inBlockBuffer) > filesize) {
toWrite = (int) (filesize - written);
}
written += toWrite;
toRead -= toWrite;
written += toWrite;
toRead -= toWrite;
outputStream.write(output, 0, toWrite);
outputStream.write(output, 0, toWrite);
if (sha1 != null && sha1fallback != null) {
sha1.update(output, 0, toWrite);
// In some cases it's using the hash of the whole .app file instead of the part
// that's been actually used.
long toFallback = inBlockBuffer;
if (writtenFallback + toFallback > expectedSizeForHash) {
toFallback = expectedSizeForHash - writtenFallback;
}
sha1fallback.update(output, 0, (int) toFallback);
writtenFallback += toFallback;
}
if (written >= filesize && h3hash == null) {
break;
}
} while (inBlockBuffer == BLOCKSIZE);
if (sha1 != null && sha1fallback != null) {
sha1.update(output, 0, toWrite);
// In some cases it's using the hash of the whole .app file instead of the part
// that's been actually used.
long toFallback = inBlockBuffer;
if (writtenFallback + toFallback > expectedSizeForHash) {
toFallback = expectedSizeForHash - writtenFallback;
long missingInHash = expectedSizeForHash - writtenFallback;
if (missingInHash > 0) {
sha1fallback.update(new byte[(int) missingInHash]);
}
sha1fallback.update(output, 0, (int) toFallback);
writtenFallback += toFallback;
}
if (written >= filesize && h3hash == null) {
break;
}
} while (inBlockBuffer == BLOCKSIZE);
if (sha1 != null && sha1fallback != null) {
long missingInHash = expectedSizeForHash - writtenFallback;
if (missingInHash > 0) {
sha1fallback.update(new byte[(int) missingInHash]);
}
byte[] calculated_hash1 = sha1.digest();
byte[] calculated_hash2 = sha1fallback.digest();
byte[] expected_hash = h3hash;
if (!Arrays.equals(calculated_hash1, expected_hash) && !Arrays.equals(calculated_hash2, expected_hash)) {
inputStream.close();
outputStream.close();
throw new CheckSumWrongException("hash checksum failed", calculated_hash1, expected_hash);
} else {
log.finest("Hash DOES match saves output stream.");
byte[] calculated_hash1 = sha1.digest();
byte[] calculated_hash2 = sha1fallback.digest();
byte[] expected_hash = h3hash;
if (!Arrays.equals(calculated_hash1, expected_hash) && !Arrays.equals(calculated_hash2, expected_hash)) {
throw new CheckSumWrongException("hash checksum failed", calculated_hash1, expected_hash);
} else {
log.finest("Hash DOES match saves output stream.");
}
}
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
inputStream.close();
outputStream.close();
}
public void decryptFileStreamHashed(InputStream inputStream, OutputStream outputStream, long filesize, long fileoffset, short contentIndex, byte[] h3Hash)
@ -209,45 +207,44 @@ public class NUSDecryption extends AESDecryption {
ByteArrayBuffer overflow = new ByteArrayBuffer(BLOCKSIZE);
long wrote = 0;
int inBlockBuffer = 0;
do {
inBlockBuffer = StreamUtils.getChunkFromStream(inputStream, encryptedBlockBuffer, overflow, BLOCKSIZE);
try {
do {
inBlockBuffer = StreamUtils.getChunkFromStream(inputStream, encryptedBlockBuffer, overflow, BLOCKSIZE);
if (writeSize > filesize) writeSize = filesize;
if (writeSize > filesize) writeSize = filesize;
byte[] output;
try {
output = decryptFileChunkHash(encryptedBlockBuffer, (int) block, contentIndex, h3Hash);
} catch (CheckSumWrongException e) {
outputStream.close();
inputStream.close();
throw e;
}
if ((wrote + writeSize) > filesize) {
writeSize = (int) (filesize - wrote);
}
try {
outputStream.write(output, (int) (0 + soffset), (int) writeSize);
} catch (IOException e) {
if (e.getMessage().equals("Pipe closed")) {
break;
byte[] output;
try {
output = decryptFileChunkHash(encryptedBlockBuffer, (int) block, contentIndex, h3Hash);
} catch (CheckSumWrongException e) {
throw e;
}
e.printStackTrace();
throw e;
}
wrote += writeSize;
block++;
if ((wrote + writeSize) > filesize) {
writeSize = (int) (filesize - wrote);
}
if (soffset > 0) {
writeSize = HASHBLOCKSIZE;
soffset = 0;
}
} while (wrote < filesize && (inBlockBuffer == BLOCKSIZE));
log.finest("Decryption okay");
outputStream.close();
inputStream.close();
try {
outputStream.write(output, (int) (0 + soffset), (int) writeSize);
} catch (IOException e) {
if (e.getMessage().equals("Pipe closed")) {
break;
}
e.printStackTrace();
throw e;
}
wrote += writeSize;
block++;
if (soffset > 0) {
writeSize = HASHBLOCKSIZE;
soffset = 0;
}
} while (wrote < filesize && (inBlockBuffer == BLOCKSIZE));
log.finest("Decryption okay");
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
}
private byte[] decryptFileChunkHash(byte[] blockBuffer, int block, int contentIndex, byte[] h3_hashes) throws CheckSumWrongException {
@ -277,32 +274,30 @@ public class NUSDecryption extends AESDecryption {
long encryptedFileSize = content.getEncryptedFileSize();
if (content.isEncrypted()) {
if (content.isHashed()) {
byte[] h3 = h3HashHashed.orElseThrow(() -> new FileNotFoundException("h3 hash not found."));
try {
if (content.isEncrypted()) {
if (content.isHashed()) {
byte[] h3 = h3HashHashed.orElseThrow(() -> new FileNotFoundException("h3 hash not found."));
decryptFileStreamHashed(inputStream, outputStream, size, offset, (short) contentIndex, h3);
} else {
byte[] h3Hash = content.getSHA2Hash();
// We want to check if we read the whole file or just a part of it.
// There should be only one actual file inside a non-hashed content.
// But it could also contain a directory, so we need to filter.
long fstFileSize = content.getEntries().stream().filter(f -> !f.isDir()).findFirst().map(f -> f.getFileSize()).orElse(0L);
if (size > 0 && size < fstFileSize) {
h3Hash = null;
decryptFileStreamHashed(inputStream, outputStream, size, offset, (short) contentIndex, h3);
} else {
byte[] h3Hash = content.getSHA2Hash();
// We want to check if we read the whole file or just a part of it.
// There should be only one actual file inside a non-hashed content.
// But it could also contain a directory, so we need to filter.
long fstFileSize = content.getEntries().stream().filter(f -> !f.isDir()).findFirst().map(f -> f.getFileSize()).orElse(0L);
if (size > 0 && size < fstFileSize) {
h3Hash = null;
}
decryptFileStream(inputStream, outputStream, size, offset, (short) contentIndex, h3Hash, encryptedFileSize);
}
decryptFileStream(inputStream, outputStream, size, offset, (short) contentIndex, h3Hash, encryptedFileSize);
} else {
StreamUtils.saveInputStreamToOutputStreamWithHash(inputStream, outputStream, size, content.getSHA2Hash(), encryptedFileSize);
}
} else {
StreamUtils.saveInputStreamToOutputStreamWithHash(inputStream, outputStream, size, content.getSHA2Hash(), encryptedFileSize);
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
synchronized (inputStream) {
inputStream.close();
}
synchronized (outputStream) {
outputStream.close();
}
return true;
}
}