diff --git a/proxy/proxy.go b/proxy/proxy.go index 8dca1f6..7dfb5d4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -172,6 +172,7 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Determine auth level required to access service and dispatch request // accordingly. authLevel := target.AuthRequired(r) + skipInvoiceCreation := target.SkipInvoiceCreation(r) switch { case authLevel.IsOn(): // Determine if the header contains the authentication @@ -181,6 +182,16 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // resources. acceptAuth := p.authenticator.Accept(&r.Header, resourceName) if !acceptAuth { + if skipInvoiceCreation { + addCorsHeaders(w.Header()) + sendDirectResponse( + w, r, http.StatusUnauthorized, + "unauthorized", + ) + + return + } + price, err := target.pricer.GetPrice(r.Context(), r) if err != nil { prefixLog.Errorf("error getting "+ diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 64a2372..48ab200 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -48,10 +48,11 @@ var ( ) type testCase struct { - name string - auth auth.Level - authWhitelist []string - wantBackendErr bool + name string + auth auth.Level + authWhitelist []string + authSkipInvoiceCreationPaths []string + wantBackendErr bool } // helloServer is a simple server that implements the GreeterServer interface. @@ -98,6 +99,15 @@ func TestProxyHTTP(t *testing.T) { name: "with whitelist", auth: "on", authWhitelist: []string{"^/http/white.*$"}, + }, { + name: "no whitelist with skip", + auth: "on", + authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"}, + }, { + name: "with whitelist with skip", + auth: "on", + authWhitelist: []string{"^/http/white.*$"}, + authSkipInvoiceCreationPaths: []string{"^/http/skip.*$"}, }} for _, tc := range testCases { @@ -182,12 +192,13 @@ func TestProxyHTTPBlocklist(t *testing.T) { func runHTTPTest(t *testing.T, tc *testCase, method string) { // Create a list of services to proxy between. services := []*proxy.Service{{ - Address: testTargetServiceAddress, - HostRegexp: testHostRegexp, - PathRegexp: testPathRegexpHTTP, - Protocol: "http", - Auth: tc.auth, - AuthWhitelistPaths: tc.authWhitelist, + Address: testTargetServiceAddress, + HostRegexp: testHostRegexp, + PathRegexp: testPathRegexpHTTP, + Protocol: "http", + Auth: tc.auth, + AuthWhitelistPaths: tc.authWhitelist, + AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths, }} mockAuth := auth.NewMockAuthenticator() @@ -261,8 +272,33 @@ func runHTTPTest(t *testing.T, tc *testCase, method string) { require.EqualValues(t, len(bodyBytes), resp.ContentLength) } + // Make sure that if we query a URL that is on the skip invoice + // creation list, we get a 401 if auth fails. + if len(tc.authSkipInvoiceCreationPaths) > 0 { + urlToSkip := fmt.Sprintf("http://%s/http/skip", testProxyAddr) + reqToSkip, err := http.NewRequest(method, urlToSkip, nil) + require.NoError(t, err) + + if method == "POST" { + reqToSkip.Header.Add("Content-Type", "application/json") + reqToSkip.Body = io.NopCloser(strings.NewReader(`{}`)) + } + + respSkipped, err := client.Do(reqToSkip) + require.NoError(t, err) + + require.Equal(t, http.StatusUnauthorized, respSkipped.StatusCode) + require.Equal(t, "401 Unauthorized", respSkipped.Status) + + bodySkippedContent, err := io.ReadAll(respSkipped.Body) + require.NoError(t, err) + require.Equal(t, "unauthorized\n", string(bodySkippedContent)) + require.EqualValues(t, len(bodySkippedContent), respSkipped.ContentLength) + _ = respSkipped.Body.Close() + } + // Make sure that if the Auth header is set, the client's request is - // proxied to the backend service. + // proxied to the backend service for a non-skipped, non-whitelisted path. req, err = http.NewRequest(method, url, nil) require.NoError(t, err) req.Header.Add("Authorization", "foobar") @@ -297,6 +333,12 @@ func TestProxyGRPC(t *testing.T) { authWhitelist: []string{ "^/proxy_test\\.Greeter/SayHelloNoAuth.*$", }, + }, { + name: "gRPC no whitelist with skip for SayHello", + auth: "on", + authSkipInvoiceCreationPaths: []string{ + `^/proxy_test[.]Greeter/SayHello.*$`, + }, }} for _, tc := range testCases { @@ -343,13 +385,14 @@ func runGRPCTest(t *testing.T, tc *testCase) { // Create a list of services to proxy between. services := []*proxy.Service{{ - Address: testTargetServiceAddress, - HostRegexp: testHostRegexp, - PathRegexp: testPathRegexpGRPC, - Protocol: "https", - TLSCertPath: certFile, - Auth: tc.auth, - AuthWhitelistPaths: tc.authWhitelist, + Address: testTargetServiceAddress, + HostRegexp: testHostRegexp, + PathRegexp: testPathRegexpGRPC, + Protocol: "https", + TLSCertPath: certFile, + Auth: tc.auth, + AuthWhitelistPaths: tc.authWhitelist, + AuthSkipInvoiceCreationPaths: tc.authSkipInvoiceCreationPaths, }} // Create the proxy server and start serving on TLS. @@ -393,17 +436,24 @@ func runGRPCTest(t *testing.T, tc *testCase) { grpc.Trailer(&captureMetadata), ) require.Error(t, err) - require.True(t, l402.IsPaymentRequired(err)) + if len(tc.authSkipInvoiceCreationPaths) > 0 { + statusErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Internal, statusErr.Code()) + require.Equal(t, "unauthorized", statusErr.Message()) + } else { + require.True(t, l402.IsPaymentRequired(err)) - // We expect the WWW-Authenticate header field to be set to an L402 - // auth response. - expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0) - capturedHeader := captureMetadata.Get("WWW-Authenticate") - require.Len(t, capturedHeader, 2) - require.Equal( - t, expectedHeaderContent.Values("WWW-Authenticate"), - capturedHeader, - ) + // We expect the WWW-Authenticate header field to be set to an L402 + // auth response. + expectedHeaderContent, _ := mockAuth.FreshChallengeHeader("", 0) + capturedHeader := captureMetadata.Get("WWW-Authenticate") + require.Len(t, capturedHeader, 2) + require.Equal( + t, expectedHeaderContent.Values("WWW-Authenticate"), + capturedHeader, + ) + } // Make sure that if we query an URL that is on the whitelist, we don't // get the 402 response.