mirror of
https://github.com/lightninglabs/aperture.git
synced 2026-01-31 15:14:26 +01:00
proxy: add skipping invoice creation on request
This commit is contained in:
@@ -172,6 +172,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Determine auth level required to access service and dispatch request
|
||||
// accordingly.
|
||||
authLevel := target.AuthRequired(r)
|
||||
skipInvoiceCreation := target.SkipInvoiceCreation(r)
|
||||
switch {
|
||||
case authLevel.IsOn():
|
||||
// Determine if the header contains the authentication
|
||||
@@ -181,6 +182,16 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// resources.
|
||||
acceptAuth := p.authenticator.Accept(&r.Header, resourceName)
|
||||
if !acceptAuth {
|
||||
if skipInvoiceCreation {
|
||||
addCorsHeaders(w.Header())
|
||||
sendDirectResponse(
|
||||
w, r, http.StatusUnauthorized,
|
||||
"unauthorized",
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
price, err := target.pricer.GetPrice(r.Context(), r)
|
||||
if err != nil {
|
||||
prefixLog.Errorf("error getting "+
|
||||
|
||||
@@ -48,10 +48,11 @@ var (
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
auth auth.Level
|
||||
authWhitelist []string
|
||||
wantBackendErr bool
|
||||
name string
|
||||
auth auth.Level
|
||||
authWhitelist []string
|
||||
authSkipInvoiceCreationPaths []string
|
||||
wantBackendErr bool
|
||||
}
|
||||
|
||||
// helloServer is a simple server that implements the GreeterServer interface.
|
||||
@@ -98,6 +99,15 @@ func TestProxyHTTP(t *testing.T) {
|
||||
name: "with whitelist",
|
||||
auth: "on",
|
||||
authWhitelist: []string{"^/http/white.*$"},
|
||||
}, {
|
||||
name: "no whitelist with skip",
|
||||
auth: "on",
|
||||
authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"},
|
||||
}, {
|
||||
name: "with whitelist with skip",
|
||||
auth: "on",
|
||||
authWhitelist: []string{"^/http/white.*$"},
|
||||
authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -182,12 +192,13 @@ func TestProxyHTTPBlocklist(t *testing.T) {
|
||||
func runHTTPTest(t *testing.T, tc *testCase, method string) {
|
||||
// Create a list of services to proxy between.
|
||||
services := []*proxy.Service{{
|
||||
Address: testTargetServiceAddress,
|
||||
HostRegexp: testHostRegexp,
|
||||
PathRegexp: testPathRegexpHTTP,
|
||||
Protocol: "http",
|
||||
Auth: tc.auth,
|
||||
AuthWhitelistPaths: tc.authWhitelist,
|
||||
Address: testTargetServiceAddress,
|
||||
HostRegexp: testHostRegexp,
|
||||
PathRegexp: testPathRegexpHTTP,
|
||||
Protocol: "http",
|
||||
Auth: tc.auth,
|
||||
AuthWhitelistPaths: tc.authWhitelist,
|
||||
AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths,
|
||||
}}
|
||||
|
||||
mockAuth := auth.NewMockAuthenticator()
|
||||
@@ -261,8 +272,33 @@ func runHTTPTest(t *testing.T, tc *testCase, method string) {
|
||||
require.EqualValues(t, len(bodyBytes), resp.ContentLength)
|
||||
}
|
||||
|
||||
// Make sure that if we query a URL that is on the skip invoice
|
||||
// creation list, we get a 401 if auth fails.
|
||||
if len(tc.authSkipInvoiceCreationPaths) > 0 {
|
||||
urlToSkip := fmt.Sprintf("http://%s/http/skip", testProxyAddr)
|
||||
reqToSkip, err := http.NewRequest(method, urlToSkip, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
if method == "POST" {
|
||||
reqToSkip.Header.Add("Content-Type", "application/json")
|
||||
reqToSkip.Body = io.NopCloser(strings.NewReader(`{}`))
|
||||
}
|
||||
|
||||
respSkipped, err := client.Do(reqToSkip)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, respSkipped.StatusCode)
|
||||
require.Equal(t, "401 Unauthorized", respSkipped.Status)
|
||||
|
||||
bodySkippedContent, err := io.ReadAll(respSkipped.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "unauthorized\n", string(bodySkippedContent))
|
||||
require.EqualValues(t, len(bodySkippedContent), respSkipped.ContentLength)
|
||||
_ = respSkipped.Body.Close()
|
||||
}
|
||||
|
||||
// Make sure that if the Auth header is set, the client's request is
|
||||
// proxied to the backend service.
|
||||
// proxied to the backend service for a non-skipped, non-whitelisted path.
|
||||
req, err = http.NewRequest(method, url, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Add("Authorization", "foobar")
|
||||
@@ -297,6 +333,12 @@ func TestProxyGRPC(t *testing.T) {
|
||||
authWhitelist: []string{
|
||||
"^/proxy_test\\.Greeter/SayHelloNoAuth.*$",
|
||||
},
|
||||
}, {
|
||||
name: "gRPC no whitelist with skip for SayHello",
|
||||
auth: "on",
|
||||
authSkipInvoiceCreationPaths: []string{
|
||||
`^/proxy_test[.]Greeter/SayHello.*$`,
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -343,13 +385,14 @@ func runGRPCTest(t *testing.T, tc *testCase) {
|
||||
|
||||
// Create a list of services to proxy between.
|
||||
services := []*proxy.Service{{
|
||||
Address: testTargetServiceAddress,
|
||||
HostRegexp: testHostRegexp,
|
||||
PathRegexp: testPathRegexpGRPC,
|
||||
Protocol: "https",
|
||||
TLSCertPath: certFile,
|
||||
Auth: tc.auth,
|
||||
AuthWhitelistPaths: tc.authWhitelist,
|
||||
Address: testTargetServiceAddress,
|
||||
HostRegexp: testHostRegexp,
|
||||
PathRegexp: testPathRegexpGRPC,
|
||||
Protocol: "https",
|
||||
TLSCertPath: certFile,
|
||||
Auth: tc.auth,
|
||||
AuthWhitelistPaths: tc.authWhitelist,
|
||||
AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths,
|
||||
}}
|
||||
|
||||
// Create the proxy server and start serving on TLS.
|
||||
@@ -393,17 +436,24 @@ func runGRPCTest(t *testing.T, tc *testCase) {
|
||||
grpc.Trailer(&captureMetadata),
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.True(t, l402.IsPaymentRequired(err))
|
||||
if len(tc.authSkipInvoiceCreationPaths) > 0 {
|
||||
statusErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, codes.Internal, statusErr.Code())
|
||||
require.Equal(t, "unauthorized", statusErr.Message())
|
||||
} else {
|
||||
require.True(t, l402.IsPaymentRequired(err))
|
||||
|
||||
// We expect the WWW-Authenticate header field to be set to an L402
|
||||
// auth response.
|
||||
expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0)
|
||||
capturedHeader := captureMetadata.Get("WWW-Authenticate")
|
||||
require.Len(t, capturedHeader, 2)
|
||||
require.Equal(
|
||||
t, expectedHeaderContent.Values("WWW-Authenticate"),
|
||||
capturedHeader,
|
||||
)
|
||||
// We expect the WWW-Authenticate header field to be set to an L402
|
||||
// auth response.
|
||||
expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0)
|
||||
capturedHeader := captureMetadata.Get("WWW-Authenticate")
|
||||
require.Len(t, capturedHeader, 2)
|
||||
require.Equal(
|
||||
t, expectedHeaderContent.Values("WWW-Authenticate"),
|
||||
capturedHeader,
|
||||
)
|
||||
}
|
||||
|
||||
// Make sure that if we query an URL that is on the whitelist, we don't
|
||||
// get the 402 response.
|
||||
|
||||
Reference in New Issue
Block a user