using Ryujinx.HLE.HOS.Services.Sockets.Bsd; using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl; using Ryujinx.HLE.HOS.Services.Ssl.Types; using System; using System.IO; using System.Net.Security; using System.Net.Sockets; using System.Security.Authentication; namespace Ryujinx.HLE.HOS.Services.Ssl.SslService { class SslManagedSocketConnection : ISslConnectionBase { public int SocketFd { get; } public ISocket Socket { get; } private readonly BsdContext _bsdContext; private readonly SslVersion _sslVersion; private SslStream _stream; private bool _isBlockingSocket; private int _previousReadTimeout; public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket) { _bsdContext = bsdContext; _sslVersion = sslVersion; SocketFd = socketFd; Socket = socket; } private void StartSslOperation() { // Save blocking state _isBlockingSocket = Socket.Blocking; // Force blocking for SslStream Socket.Blocking = true; } private void EndSslOperation() { // Restore blocking state Socket.Blocking = _isBlockingSocket; } private void StartSslReadOperation() { StartSslOperation(); if (!_isBlockingSocket) { _previousReadTimeout = _stream.ReadTimeout; _stream.ReadTimeout = 1; } } private void EndSslReadOperation() { if (!_isBlockingSocket) { _stream.ReadTimeout = _previousReadTimeout; } EndSslOperation(); } // NOTE: We silence warnings about TLS 1.0 and 1.1 as games will likely use it. #pragma warning disable SYSLIB0039 private SslProtocols TranslateSslVersion(SslVersion version) { return (version & SslVersion.VersionMask) switch { SslVersion.Auto => SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13, SslVersion.TlsV10 => SslProtocols.Tls, SslVersion.TlsV11 => SslProtocols.Tls11, SslVersion.TlsV12 => SslProtocols.Tls12, SslVersion.TlsV13 => SslProtocols.Tls13, _ => throw new NotImplementedException(version.ToString()), }; } #pragma warning restore SYSLIB0039 public ResultCode Handshake(string hostName) { StartSslOperation(); _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null); _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false); EndSslOperation(); return ResultCode.Success; } public ResultCode Peek(out int peekCount, Memory buffer) { // NOTE: We cannot support that on .NET SSL API. // As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive. peekCount = -1; return ResultCode.WouldBlock; } public int Pending() { // Unsupported return 0; } private bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode) { switch (error) { case WsaError.WSAETIMEDOUT: resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock; return true; case WsaError.WSAECONNABORTED: resultCode = ResultCode.ConnectionAbort; return true; case WsaError.WSAECONNRESET: resultCode = ResultCode.ConnectionReset; return true; default: resultCode = ResultCode.Success; return false; } } public ResultCode Read(out int readCount, Memory buffer) { if (!Socket.Poll(0, SelectMode.SelectRead)) { readCount = -1; return ResultCode.WouldBlock; } StartSslReadOperation(); try { readCount = _stream.Read(buffer.Span); } catch (IOException exception) { readCount = -1; if (exception.InnerException is SocketException socketException) { WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode; if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result)) { return result; } else { throw socketException; } } else { throw; } } finally { EndSslReadOperation(); } return ResultCode.Success; } public ResultCode Write(out int writtenCount, ReadOnlyMemory buffer) { if (!Socket.Poll(0, SelectMode.SelectWrite)) { writtenCount = 0; return ResultCode.WouldBlock; } StartSslOperation(); try { _stream.Write(buffer.Span); } catch (IOException exception) { writtenCount = -1; if (exception.InnerException is SocketException socketException) { WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode; if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result)) { return result; } else { throw socketException; } } else { throw; } } finally { EndSslOperation(); } // .NET API doesn't provide the size written, assume all written. writtenCount = buffer.Length; return ResultCode.Success; } public ResultCode GetServerCertificate(string hostname, Span certificates, out uint storageSize, out uint certificateCount) { byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData(); storageSize = (uint)rawCertData.Length; certificateCount = 1; if (rawCertData.Length > certificates.Length) { return ResultCode.CertBufferTooSmall; } rawCertData.CopyTo(certificates); return ResultCode.Success; } public void Dispose() { _bsdContext.CloseFileDescriptor(SocketFd); } } }