diff --git a/BTCPayServer.Tests/UnitTest1.cs b/BTCPayServer.Tests/UnitTest1.cs index 4599f9405..96c51a92f 100644 --- a/BTCPayServer.Tests/UnitTest1.cs +++ b/BTCPayServer.Tests/UnitTest1.cs @@ -897,6 +897,14 @@ namespace BTCPayServer.Tests { await tester.StartAsync(); var proxy = tester.PayTester.GetService(); + void AssertConnectionDropped() + { + TestUtils.Eventually(() => + { + Thread.MemoryBarrier(); + Assert.Equal(0, proxy.ConnectionCount); + }); + } var httpFactory = tester.PayTester.GetService(); var client = httpFactory.CreateClient(PayjoinClient.PayjoinOnionNamedClient); Assert.NotNull(client); @@ -905,46 +913,32 @@ namespace BTCPayServer.Tests var result = await response.Content.ReadAsStringAsync(); Assert.DoesNotContain("You are not using Tor.", result); Assert.Contains("Congratulations. This browser is configured to use Tor.", result); - Logs.Tester.LogInformation("Now we should have one connection"); - TestUtils.Eventually(() => - { - Thread.MemoryBarrier(); - Assert.Equal(1, proxy.ConnectionCount); - }); + AssertConnectionDropped(); response = await client.GetAsync("http://explorerzydxu5ecjrkwceayqybizmpjjznk5izmitf2modhcusuqlid.onion/"); response.EnsureSuccessStatusCode(); result = await response.Content.ReadAsStringAsync(); Assert.Contains("Bitcoin", result); - Logs.Tester.LogInformation("Now we should have two connections"); - TestUtils.Eventually(() => - { - Thread.MemoryBarrier(); - Assert.Equal(2, proxy.ConnectionCount); - }); + + AssertConnectionDropped(); response = await client.GetAsync("http://explorerzydxu5ecjrkwceayqybizmpjjznk5izmitf2modhcusuqlid.onion/"); response.EnsureSuccessStatusCode(); - Logs.Tester.LogInformation("Querying the same address should not create additional connection"); - TestUtils.Eventually(() => - { - Thread.MemoryBarrier(); - Assert.Equal(2, proxy.ConnectionCount); - }); + AssertConnectionDropped(); client.Dispose(); - Logs.Tester.LogInformation("Disposing a HttpClient should not proxy connection"); - TestUtils.Eventually(() => - { - Thread.MemoryBarrier(); - Assert.Equal(2, proxy.ConnectionCount); - }); + AssertConnectionDropped(); client = httpFactory.CreateClient(PayjoinClient.PayjoinOnionNamedClient); response = await client.GetAsync("http://explorerzydxu5ecjrkwceayqybizmpjjznk5izmitf2modhcusuqlid.onion/"); response.EnsureSuccessStatusCode(); - Logs.Tester.LogInformation("Querying the same address with same client should not create additional connection"); - TestUtils.Eventually(() => - { - Thread.MemoryBarrier(); - Assert.Equal(2, proxy.ConnectionCount); - }); + AssertConnectionDropped(); + + Logs.Tester.LogInformation("Querying an onion address which can't be found should send http 500"); + response = await client.GetAsync("http://dwoduwoi.onion/"); + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + AssertConnectionDropped(); + + Logs.Tester.LogInformation("Querying valid onion but unreachable should send error 502"); + response = await client.GetAsync("http://fastrcl5totos3vekjbqcmgpnias5qytxnaj7gpxtxhubdcnfrkapqad.onion/"); + Assert.Equal(HttpStatusCode.BadGateway, response.StatusCode); + AssertConnectionDropped(); } } diff --git a/BTCPayServer/HostedServices/Socks5HttpProxyServer.cs b/BTCPayServer/HostedServices/Socks5HttpProxyServer.cs index 05c27bb36..2e8713391 100644 --- a/BTCPayServer/HostedServices/Socks5HttpProxyServer.cs +++ b/BTCPayServer/HostedServices/Socks5HttpProxyServer.cs @@ -26,11 +26,13 @@ namespace BTCPayServer.HostedServices public Socket ClientSocket; public Socket SocksSocket; public CancellationToken CancellationToken; + public CancellationTokenSource CancellationTokenSource; public void Dispose() { Socks5HttpProxyServer.Dispose(ClientSocket); Socks5HttpProxyServer.Dispose(SocksSocket); + CancellationTokenSource.Dispose(); } } private readonly BTCPayServerOptions _opts; @@ -76,11 +78,13 @@ namespace BTCPayServer.HostedServices return; } var toSocksProxy = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(_Cts.Token); toSocksProxy.BeginConnect(_opts.SocksEndpoint, ConnectToSocks, new ProxyConnection() { ClientSocket = clientSocket, SocksSocket = toSocksProxy, - CancellationToken = _Cts.Token + CancellationToken = connectionCts.Token, + CancellationTokenSource = connectionCts }); _ServerSocket.BeginAccept(Accept, null); } @@ -99,14 +103,17 @@ namespace BTCPayServer.HostedServices } Interlocked.Increment(ref connectionCount); var pipe = new Pipe(PipeOptions.Default); - var reading = FillPipeAsync(connection.ClientSocket, pipe.Writer, connection.CancellationToken); - var writing = ReadPipeAsync(connection.SocksSocket, connection.ClientSocket, pipe.Reader, connection.CancellationToken); + CancellationTokenSource.CreateLinkedTokenSource(connection.CancellationToken); + var reading = FillPipeAsync(connection.ClientSocket, pipe.Writer, connection.CancellationToken) + .ContinueWith(_ => connection.CancellationTokenSource.Cancel(), TaskScheduler.Default); + var writing = ReadPipeAsync(connection.SocksSocket, connection.ClientSocket, pipe.Reader, connection.CancellationToken) + .ContinueWith(_ => connection.CancellationTokenSource.Cancel(), TaskScheduler.Default); _ = Task.WhenAll(reading, writing) .ContinueWith(_ => { connection.Dispose(); Interlocked.Decrement(ref connectionCount); - }); + }, TaskScheduler.Default); } private int connectionCount = 0; @@ -197,19 +204,23 @@ namespace BTCPayServer.HostedServices } catch (SocksException e) when (e.SocksErrorCode == SocksErrorCode.HostUnreachable || e.SocksErrorCode == SocksErrorCode.HostUnreachable) { - await SendAsync(clientSocket , $"{httpVersion} 502 Bad Gateway\r\n\r\n", cancellationToken); + await SendAsync(clientSocket , $"{httpVersion} 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n", cancellationToken); + goto done; } catch (SocksException e) { - await SendAsync(clientSocket , $"{httpVersion} 500 Internal Server Error\r\nX-Proxy-Error-Type: Socks {e.SocksErrorCode}\r\n\r\n", cancellationToken); + await SendAsync(clientSocket , $"{httpVersion} 500 Internal Server Error\r\nContent-Length: 0\r\nX-Proxy-Error-Type: Socks {e.SocksErrorCode}\r\n\r\n", cancellationToken); + goto done; } catch (SocketException e) { - await SendAsync(clientSocket , $"{httpVersion} 500 Internal Server Error\r\nX-Proxy-Error-Type: Socket {e.SocketErrorCode}\r\n\r\n", cancellationToken); + await SendAsync(clientSocket , $"{httpVersion} 500 Internal Server Error\r\nContent-Length: 0\r\nX-Proxy-Error-Type: Socket {e.SocketErrorCode}\r\n\r\n", cancellationToken); + goto done; } catch { await SendAsync(clientSocket , $"{httpVersion} 500 Internal Server Error\r\n\r\n", cancellationToken); + goto done; } } else @@ -231,6 +242,7 @@ namespace BTCPayServer.HostedServices } } + done: // Mark the PipeReader as complete reader.Complete(); } diff --git a/BTCPayServer/Payments/PayJoin/PayJoinExtensions.cs b/BTCPayServer/Payments/PayJoin/PayJoinExtensions.cs index c1dc8cbe2..f67499e7a 100644 --- a/BTCPayServer/Payments/PayJoin/PayJoinExtensions.cs +++ b/BTCPayServer/Payments/PayJoin/PayJoinExtensions.cs @@ -17,6 +17,7 @@ namespace BTCPayServer.Payments.PayJoin services.AddSingleton(); services.AddTransient(); services.AddHttpClient(PayjoinClient.PayjoinOnionNamedClient) + .ConfigureHttpClient(h => h.DefaultRequestHeaders.ConnectionClose = true ) .ConfigurePrimaryHttpMessageHandler(); } }