Fix race condition in ContentManager (#884)

* Fix race condition in ContentManager

This fix a race condition happening since #791 when trying to load a
game via command line.

* Address gdk's comments

* Ensure to dispose the FileStream and not the IStorage
This commit is contained in:
Thog 2020-01-13 01:17:44 +01:00 committed by Ac_K
parent b8e3909d80
commit f0055482fd

View File

@ -28,6 +28,8 @@ namespace Ryujinx.HLE.FileSystem.Content
private Switch _device; private Switch _device;
private readonly object _lock = new object();
public ContentManager(Switch device) public ContentManager(Switch device)
{ {
_contentDictionary = new SortedDictionary<(ulong, NcaContentType), string>(); _contentDictionary = new SortedDictionary<(ulong, NcaContentType), string>();
@ -58,139 +60,151 @@ namespace Ryujinx.HLE.FileSystem.Content
public void LoadEntries(bool ignoreMissingFonts = false) public void LoadEntries(bool ignoreMissingFonts = false)
{ {
_contentDictionary = new SortedDictionary<(ulong, NcaContentType), string>(); lock (_lock)
_locationEntries = new Dictionary<StorageId, LinkedList<LocationEntry>>();
foreach (StorageId storageId in Enum.GetValues(typeof(StorageId)))
{ {
string contentDirectory = null; _contentDictionary = new SortedDictionary<(ulong, NcaContentType), string>();
string contentPathString = null; _locationEntries = new Dictionary<StorageId, LinkedList<LocationEntry>>();
string registeredDirectory = null;
try foreach (StorageId storageId in Enum.GetValues(typeof(StorageId)))
{ {
contentPathString = LocationHelper.GetContentRoot(storageId); string contentDirectory = null;
contentDirectory = LocationHelper.GetRealPath(_device.FileSystem, contentPathString); string contentPathString = null;
registeredDirectory = Path.Combine(contentDirectory, "registered"); string registeredDirectory = null;
}
catch (NotSupportedException)
{
continue;
}
Directory.CreateDirectory(registeredDirectory); try
LinkedList<LocationEntry> locationList = new LinkedList<LocationEntry>();
void AddEntry(LocationEntry entry)
{
locationList.AddLast(entry);
}
foreach (string directoryPath in Directory.EnumerateDirectories(registeredDirectory))
{
if (Directory.GetFiles(directoryPath).Length > 0)
{ {
string ncaName = new DirectoryInfo(directoryPath).Name.Replace(".nca", string.Empty); contentPathString = LocationHelper.GetContentRoot(storageId);
contentDirectory = LocationHelper.GetRealPath(_device.FileSystem, contentPathString);
registeredDirectory = Path.Combine(contentDirectory, "registered");
}
catch (NotSupportedException)
{
continue;
}
using (FileStream ncaFile = new FileStream(Directory.GetFiles(directoryPath)[0], FileMode.Open, FileAccess.Read)) Directory.CreateDirectory(registeredDirectory);
LinkedList<LocationEntry> locationList = new LinkedList<LocationEntry>();
void AddEntry(LocationEntry entry)
{
locationList.AddLast(entry);
}
foreach (string directoryPath in Directory.EnumerateDirectories(registeredDirectory))
{
if (Directory.GetFiles(directoryPath).Length > 0)
{ {
Nca nca = new Nca(_device.System.KeySet, ncaFile.AsStorage()); string ncaName = new DirectoryInfo(directoryPath).Name.Replace(".nca", string.Empty);
string switchPath = contentPathString + ":/" + ncaFile.Name.Replace(contentDirectory, string.Empty).TrimStart(Path.DirectorySeparatorChar); using (FileStream ncaFile = File.OpenRead(Directory.GetFiles(directoryPath)[0]))
{
Nca nca = new Nca(_device.System.KeySet, ncaFile.AsStorage());
// Change path format to switch's string switchPath = contentPathString + ":/" + ncaFile.Name.Replace(contentDirectory, string.Empty).TrimStart(Path.DirectorySeparatorChar);
switchPath = switchPath.Replace('\\', '/');
LocationEntry entry = new LocationEntry(switchPath, // Change path format to switch's
0, switchPath = switchPath.Replace('\\', '/');
(long)nca.Header.TitleId,
nca.Header.ContentType);
AddEntry(entry); LocationEntry entry = new LocationEntry(switchPath,
0,
(long)nca.Header.TitleId,
nca.Header.ContentType);
_contentDictionary.Add((nca.Header.TitleId, nca.Header.ContentType), ncaName); AddEntry(entry);
_contentDictionary.Add((nca.Header.TitleId, nca.Header.ContentType), ncaName);
}
} }
} }
foreach (string filePath in Directory.EnumerateFiles(contentDirectory))
{
if (Path.GetExtension(filePath) == ".nca")
{
string ncaName = Path.GetFileNameWithoutExtension(filePath);
using (FileStream ncaFile = new FileStream(filePath, FileMode.Open, FileAccess.Read))
{
Nca nca = new Nca(_device.System.KeySet, ncaFile.AsStorage());
string switchPath = contentPathString + ":/" + filePath.Replace(contentDirectory, string.Empty).TrimStart(Path.DirectorySeparatorChar);
// Change path format to switch's
switchPath = switchPath.Replace('\\', '/');
LocationEntry entry = new LocationEntry(switchPath,
0,
(long)nca.Header.TitleId,
nca.Header.ContentType);
AddEntry(entry);
_contentDictionary.Add((nca.Header.TitleId, nca.Header.ContentType), ncaName);
}
}
}
if (_locationEntries.ContainsKey(storageId) && _locationEntries[storageId]?.Count == 0)
{
_locationEntries.Remove(storageId);
}
if (!_locationEntries.ContainsKey(storageId))
{
_locationEntries.Add(storageId, locationList);
}
} }
foreach (string filePath in Directory.EnumerateFiles(contentDirectory)) TimeManager.Instance.InitializeTimeZone(_device);
{
if (Path.GetExtension(filePath) == ".nca")
{
string ncaName = Path.GetFileNameWithoutExtension(filePath);
using (FileStream ncaFile = new FileStream(filePath, FileMode.Open, FileAccess.Read)) _device.System.Font.Initialize(this, ignoreMissingFonts);
{
Nca nca = new Nca(_device.System.KeySet, ncaFile.AsStorage());
string switchPath = contentPathString + ":/" + filePath.Replace(contentDirectory, string.Empty).TrimStart(Path.DirectorySeparatorChar);
// Change path format to switch's
switchPath = switchPath.Replace('\\', '/');
LocationEntry entry = new LocationEntry(switchPath,
0,
(long)nca.Header.TitleId,
nca.Header.ContentType);
AddEntry(entry);
_contentDictionary.Add((nca.Header.TitleId, nca.Header.ContentType), ncaName);
}
}
}
if (_locationEntries.ContainsKey(storageId) && _locationEntries[storageId]?.Count == 0)
{
_locationEntries.Remove(storageId);
}
if (!_locationEntries.ContainsKey(storageId))
{
_locationEntries.Add(storageId, locationList);
}
} }
TimeManager.Instance.InitializeTimeZone(_device);
_device.System.Font.Initialize(this, ignoreMissingFonts);
} }
public void ClearEntry(long titleId, NcaContentType contentType, StorageId storageId) public void ClearEntry(long titleId, NcaContentType contentType, StorageId storageId)
{ {
RemoveLocationEntry(titleId, contentType, storageId); lock (_lock)
{
RemoveLocationEntry(titleId, contentType, storageId);
}
} }
public void RefreshEntries(StorageId storageId, int flag) public void RefreshEntries(StorageId storageId, int flag)
{ {
LinkedList<LocationEntry> locationList = _locationEntries[storageId]; lock (_lock)
LinkedListNode<LocationEntry> locationEntry = locationList.First;
while (locationEntry != null)
{ {
LinkedListNode<LocationEntry> nextLocationEntry = locationEntry.Next; LinkedList<LocationEntry> locationList = _locationEntries[storageId];
LinkedListNode<LocationEntry> locationEntry = locationList.First;
if (locationEntry.Value.Flag == flag) while (locationEntry != null)
{ {
locationList.Remove(locationEntry.Value); LinkedListNode<LocationEntry> nextLocationEntry = locationEntry.Next;
}
locationEntry = nextLocationEntry; if (locationEntry.Value.Flag == flag)
{
locationList.Remove(locationEntry.Value);
}
locationEntry = nextLocationEntry;
}
} }
} }
public bool HasNca(string ncaId, StorageId storageId) public bool HasNca(string ncaId, StorageId storageId)
{ {
if (_contentDictionary.ContainsValue(ncaId)) lock (_lock)
{ {
var content = _contentDictionary.FirstOrDefault(x => x.Value == ncaId); if (_contentDictionary.ContainsValue(ncaId))
long titleId = (long)content.Key.Item1; {
var content = _contentDictionary.FirstOrDefault(x => x.Value == ncaId);
long titleId = (long)content.Key.Item1;
NcaContentType contentType = content.Key.type; NcaContentType contentType = content.Key.type;
StorageId storage = GetInstalledStorage(titleId, contentType, storageId); StorageId storage = GetInstalledStorage(titleId, contentType, storageId);
return storage == storageId; return storage == storageId;
}
} }
return false; return false;
@ -198,9 +212,12 @@ namespace Ryujinx.HLE.FileSystem.Content
public UInt128 GetInstalledNcaId(long titleId, NcaContentType contentType) public UInt128 GetInstalledNcaId(long titleId, NcaContentType contentType)
{ {
if (_contentDictionary.ContainsKey(((ulong)titleId, contentType))) lock (_lock)
{ {
return new UInt128(_contentDictionary[((ulong)titleId, contentType)]); if (_contentDictionary.ContainsKey(((ulong)titleId, contentType)))
{
return new UInt128(_contentDictionary[((ulong)titleId, contentType)]);
}
} }
return new UInt128(); return new UInt128();
@ -208,19 +225,25 @@ namespace Ryujinx.HLE.FileSystem.Content
public StorageId GetInstalledStorage(long titleId, NcaContentType contentType, StorageId storageId) public StorageId GetInstalledStorage(long titleId, NcaContentType contentType, StorageId storageId)
{ {
LocationEntry locationEntry = GetLocation(titleId, contentType, storageId); lock (_lock)
{
LocationEntry locationEntry = GetLocation(titleId, contentType, storageId);
return locationEntry.ContentPath != null ? return locationEntry.ContentPath != null ?
LocationHelper.GetStorageId(locationEntry.ContentPath) : StorageId.None; LocationHelper.GetStorageId(locationEntry.ContentPath) : StorageId.None;
}
} }
public string GetInstalledContentPath(long titleId, StorageId storageId, NcaContentType contentType) public string GetInstalledContentPath(long titleId, StorageId storageId, NcaContentType contentType)
{ {
LocationEntry locationEntry = GetLocation(titleId, contentType, storageId); lock (_lock)
if (VerifyContentType(locationEntry, contentType))
{ {
return locationEntry.ContentPath; LocationEntry locationEntry = GetLocation(titleId, contentType, storageId);
if (VerifyContentType(locationEntry, contentType))
{
return locationEntry.ContentPath;
}
} }
return string.Empty; return string.Empty;
@ -228,14 +251,17 @@ namespace Ryujinx.HLE.FileSystem.Content
public void RedirectLocation(LocationEntry newEntry, StorageId storageId) public void RedirectLocation(LocationEntry newEntry, StorageId storageId)
{ {
LocationEntry locationEntry = GetLocation(newEntry.TitleId, newEntry.ContentType, storageId); lock (_lock)
if (locationEntry.ContentPath != null)
{ {
RemoveLocationEntry(newEntry.TitleId, newEntry.ContentType, storageId); LocationEntry locationEntry = GetLocation(newEntry.TitleId, newEntry.ContentType, storageId);
}
AddLocationEntry(newEntry, storageId); if (locationEntry.ContentPath != null)
{
RemoveLocationEntry(newEntry.TitleId, newEntry.ContentType, storageId);
}
AddLocationEntry(newEntry, storageId);
}
} }
private bool VerifyContentType(LocationEntry locationEntry, NcaContentType contentType) private bool VerifyContentType(LocationEntry locationEntry, NcaContentType contentType)
@ -827,28 +853,31 @@ namespace Ryujinx.HLE.FileSystem.Content
{ {
LoadEntries(true); LoadEntries(true);
var locationEnties = _locationEntries[StorageId.NandSystem]; lock (_lock)
foreach (var entry in locationEnties)
{ {
if (entry.ContentType == NcaContentType.Data) var locationEnties = _locationEntries[StorageId.NandSystem];
foreach (var entry in locationEnties)
{ {
var path = _device.FileSystem.SwitchPathToSystemPath(entry.ContentPath); if (entry.ContentType == NcaContentType.Data)
using (IStorage ncaStorage = File.Open(path, FileMode.Open).AsStorage())
{ {
Nca nca = new Nca(_device.System.KeySet, ncaStorage); var path = _device.FileSystem.SwitchPathToSystemPath(entry.ContentPath);
if (nca.Header.TitleId == SystemVersionTitleId && nca.Header.ContentType == NcaContentType.Data) using (FileStream fileStream = File.OpenRead(path))
{ {
var romfs = nca.OpenFileSystem(NcaSectionType.Data, _device.System.FsIntegrityCheckLevel); Nca nca = new Nca(_device.System.KeySet, fileStream.AsStorage());
if (romfs.OpenFile(out IFile systemVersionFile, "/file", OpenMode.Read).IsSuccess()) if (nca.Header.TitleId == SystemVersionTitleId && nca.Header.ContentType == NcaContentType.Data)
{ {
return new SystemVersion(systemVersionFile.AsStream()); var romfs = nca.OpenFileSystem(NcaSectionType.Data, _device.System.FsIntegrityCheckLevel);
}
}
if (romfs.OpenFile(out IFile systemVersionFile, "/file", OpenMode.Read).IsSuccess())
{
return new SystemVersion(systemVersionFile.AsStream());
}
}
}
} }
} }
} }