Make sure all streams are really closed via try/finally block. Added a "closeAll" function in StreamUtils which helps to close a list of closeables

This commit is contained in:
Maschell 2019-04-19 11:41:12 +02:00
parent fccd8f8bf4
commit 9f6f9aaabe
6 changed files with 299 additions and 266 deletions

View File

@ -28,6 +28,7 @@ import java.util.Arrays;
import de.mas.wiiu.jnus.implementations.wud.WUDImage;
import de.mas.wiiu.jnus.utils.PipedInputStreamWithException;
import de.mas.wiiu.jnus.utils.StreamUtils;
import de.mas.wiiu.jnus.utils.cryptography.AESDecryption;
import lombok.Getter;
import lombok.extern.java.Log;
@ -114,6 +115,7 @@ public abstract class WUDDiscReader {
final int BLOCK_SIZE = 0x10000;
long totalread = 0;
try {
do {
long blockNumber = (usedFileOffset / BLOCK_SIZE);
long blockOffset = (usedFileOffset % BLOCK_SIZE);
@ -147,8 +149,10 @@ public abstract class WUDDiscReader {
usedSize -= copySize;
usedFileOffset += copySize;
} while (totalread < size);
} finally {
StreamUtils.closeAll(outputStream);
}
outputStream.close();
return totalread >= size;
}

View File

@ -23,6 +23,7 @@ import java.util.Arrays;
import de.mas.wiiu.jnus.implementations.wud.WUDImage;
import de.mas.wiiu.jnus.implementations.wud.WUDImageCompressedInfo;
import de.mas.wiiu.jnus.utils.StreamUtils;
public class WUDDiscReaderCompressed extends WUDDiscReader {
@ -55,12 +56,14 @@ public class WUDDiscReaderCompressed extends WUDDiscReader {
byte[] buffer = new byte[bufferSize];
RandomAccessFile input = getRandomAccessFileStream();
try {
synchronized (input) {
while (usedSize > 0) {
long sectorOffset = (usedOffset % info.getSectorSize());
long remainingSectorBytes = info.getSectorSize() - sectorOffset;
long sectorIndex = (usedOffset / info.getSectorSize());
int bytesToRead = (int) ((remainingSectorBytes < usedSize) ? remainingSectorBytes : usedSize); // read only up to the end of the current sector
int bytesToRead = (int) ((remainingSectorBytes < usedSize) ? remainingSectorBytes : usedSize); // read only up to the end of the current
// sector
// look up real sector index
long realSectorIndex = info.getSectorIndex((int) sectorIndex);
long offset2 = info.getOffsetSectorArray() + realSectorIndex * info.getSectorSize() + sectorOffset;
@ -77,7 +80,6 @@ public class WUDDiscReaderCompressed extends WUDDiscReader {
if (e.getMessage().equals("Pipe closed")) {
break;
} else {
input.close();
throw e;
}
}
@ -85,10 +87,9 @@ public class WUDDiscReaderCompressed extends WUDDiscReader {
usedSize -= bytesToRead;
usedOffset += bytesToRead;
}
input.close();
}
synchronized (out) {
out.close();
} finally {
StreamUtils.closeAll(input, out);
}
return usedSize == 0;
}

View File

@ -53,8 +53,11 @@ public final class FileUtils {
*/
public static boolean saveByteArrayToFile(@NonNull File output, byte[] data) throws IOException {
FileOutputStream out = new FileOutputStream(output);
try {
out.write(data);
} finally {
out.close();
}
return true;
}
@ -81,12 +84,14 @@ public final class FileUtils {
tempFile.createNewFile();
RandomAccessFile outStream = new RandomAccessFile(tempFilePath, "rw");
try {
outStream.setLength(filesize);
outStream.seek(0L);
action.apply(new RandomFileOutputStream(outStream));
} finally {
outStream.close();
}
// Rename temp file.
if (outputFile.exists()) {

View File

@ -120,6 +120,7 @@ public final class HashUtil {
int inBlockBufferRead = 0;
byte[] blockBuffer = new byte[bufferSize];
ByteArrayBuffer overflow = new ByteArrayBuffer(bufferSize);
try {
do {
inBlockBufferRead = StreamUtils.getChunkFromStream(in, blockBuffer, overflow, bufferSize);
@ -134,8 +135,9 @@ public final class HashUtil {
byte[] missing = new byte[(int) missing_bytes];
digest.update(missing, 0, (int) missing_bytes);
}
} finally {
in.close();
}
return digest.digest();
}

View File

@ -16,6 +16,7 @@
****************************************************************************/
package de.mas.wiiu.jnus.utils;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
@ -32,7 +33,16 @@ public final class StreamUtils {
// Utility class
}
/**
* Tries to read a given amount of bytes from a stream and return them as
* a byte array. Closes the inputs stream on success AND failure.
* @param in
* @param size
* @return
* @throws IOException
*/
public static byte[] getBytesFromStream(InputStream in, int size) throws IOException {
try {
synchronized (in) {
byte[] result = new byte[size];
byte[] buffer = null;
@ -53,9 +63,11 @@ public final class StreamUtils {
System.arraycopy(buffer, 0, result, size - toRead, read);
toRead -= read;
} while (toRead > 0);
in.close();
return result;
}
} finally {
StreamUtils.closeAll(in);
}
}
public static int getChunkFromStream(InputStream inputStream, byte[] output, ByteArrayBuffer overflowbuffer, int BLOCKSIZE) throws IOException {
@ -141,6 +153,8 @@ public final class StreamUtils {
int read = 0;
long totalRead = 0;
long written = 0;
try {
do {
read = inputStream.read(buffer);
StreamUtils.checkForException(inputStream);
@ -170,14 +184,13 @@ public final class StreamUtils {
byte[] calculated_hash = sha1.digest();
byte[] expected_hash = hash;
if (!Arrays.equals(calculated_hash, expected_hash)) {
outputStream.close();
inputStream.close();
throw new CheckSumWrongException("Hash doesn't match saves output stream.", calculated_hash, expected_hash);
}
}
outputStream.close();
inputStream.close();
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
}
}
@ -206,4 +219,17 @@ public final class StreamUtils {
}
}
public static void closeAll(Closeable... stream) throws IOException {
IOException exception = null;
for (Closeable cur : stream) {
try {
cur.close();
} catch (IOException e) {
exception = e;
}
}
if (exception != null) {
throw exception;
}
}
}

View File

@ -31,7 +31,6 @@ import de.mas.wiiu.jnus.entities.content.Content;
import de.mas.wiiu.jnus.utils.ByteArrayBuffer;
import de.mas.wiiu.jnus.utils.CheckSumWrongException;
import de.mas.wiiu.jnus.utils.HashUtil;
import de.mas.wiiu.jnus.utils.PipedInputStreamWithException;
import de.mas.wiiu.jnus.utils.StreamUtils;
import de.mas.wiiu.jnus.utils.Utils;
import lombok.extern.java.Log;
@ -91,6 +90,8 @@ public class NUSDecryption extends AESDecryption {
int skipoffset = (int) (fileOffset % 0x8000);
try {
// If we are at the beginning of a block, but it's not the first one,
// we need to get the IV from the last 16 bytes of the previous block.
// while beeing paranoid to exactly read 16 bytes but not more. Reading more
@ -180,17 +181,14 @@ public class NUSDecryption extends AESDecryption {
byte[] calculated_hash2 = sha1fallback.digest();
byte[] expected_hash = h3hash;
if (!Arrays.equals(calculated_hash1, expected_hash) && !Arrays.equals(calculated_hash2, expected_hash)) {
inputStream.close();
outputStream.close();
throw new CheckSumWrongException("hash checksum failed", calculated_hash1, expected_hash);
} else {
log.finest("Hash DOES match saves output stream.");
}
}
inputStream.close();
outputStream.close();
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
}
public void decryptFileStreamHashed(InputStream inputStream, OutputStream outputStream, long filesize, long fileoffset, short contentIndex, byte[] h3Hash)
@ -209,17 +207,15 @@ public class NUSDecryption extends AESDecryption {
ByteArrayBuffer overflow = new ByteArrayBuffer(BLOCKSIZE);
long wrote = 0;
int inBlockBuffer = 0;
try {
do {
inBlockBuffer = StreamUtils.getChunkFromStream(inputStream, encryptedBlockBuffer, overflow, BLOCKSIZE);
if (writeSize > filesize) writeSize = filesize;
byte[] output;
try {
output = decryptFileChunkHash(encryptedBlockBuffer, (int) block, contentIndex, h3Hash);
} catch (CheckSumWrongException e) {
outputStream.close();
inputStream.close();
throw e;
}
@ -246,8 +242,9 @@ public class NUSDecryption extends AESDecryption {
}
} while (wrote < filesize && (inBlockBuffer == BLOCKSIZE));
log.finest("Decryption okay");
outputStream.close();
inputStream.close();
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
}
private byte[] decryptFileChunkHash(byte[] blockBuffer, int block, int contentIndex, byte[] h3_hashes) throws CheckSumWrongException {
@ -277,6 +274,7 @@ public class NUSDecryption extends AESDecryption {
long encryptedFileSize = content.getEncryptedFileSize();
try {
if (content.isEncrypted()) {
if (content.isHashed()) {
byte[] h3 = h3HashHashed.orElseThrow(() -> new FileNotFoundException("h3 hash not found."));
@ -296,13 +294,10 @@ public class NUSDecryption extends AESDecryption {
} else {
StreamUtils.saveInputStreamToOutputStreamWithHash(inputStream, outputStream, size, content.getSHA2Hash(), encryptedFileSize);
}
} finally {
StreamUtils.closeAll(inputStream, outputStream);
}
synchronized (inputStream) {
inputStream.close();
}
synchronized (outputStream) {
outputStream.close();
}
return true;
}
}