diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 411f4d0..d049f1c 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,58 +1,98 @@ package proxy_test import ( + "bytes" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "fmt" "io/ioutil" + "math/big" + "net" "net/http" + "os" + "path" "strings" "testing" "time" "github.com/lightninglabs/kirin/auth" "github.com/lightninglabs/kirin/proxy" + proxytest "github.com/lightninglabs/kirin/proxy/testdata" + "github.com/lightningnetwork/lnd/macaroons" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/status" + "gopkg.in/macaroon.v2" ) const ( - testAddr = "localhost:10019" + testProxyAddr = "localhost:10019" testHostRegexp = "^localhost:.*$" - testPathRegexp = "^/grpc/.*$" + testPathRegexpHTTP = "^/http/.*$" + testPathRegexpGRPC = "^/proxy_test.*$" testTargetServiceAddress = "localhost:8082" testHTTPResponseBody = "HTTP Hello" ) -func TestProxy(t *testing.T) { +var ( + serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 128) + tlsCipherSuites = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + } +) + +// helloServer is a simple server that implements the GreeterServer interface. +type helloServer struct{} + +// SayHello returns a simple string that also contains a string from the +// request. +func (s *helloServer) SayHello(ctx context.Context, + req *proxytest.HelloRequest) (*proxytest.HelloReply, error) { + + return &proxytest.HelloReply{ + Message: fmt.Sprintf("Hello %s", req.Name), + }, nil +} + +// TestProxyHTTP tests that the proxy can forward HTTP requests to a backend +// service and handle LSAT authentication correctly. +func TestProxyHTTP(t *testing.T) { // Create a list of services to proxy between. services := []*proxy.Service{{ Address: testTargetServiceAddress, HostRegexp: testHostRegexp, - PathRegexp: testPathRegexp, + PathRegexp: testPathRegexpHTTP, Protocol: "http", }} - auth := auth.NewMockAuthenticator() - proxy, err := proxy.New(auth, services, "static") + mockAuth := auth.NewMockAuthenticator() + p, err := proxy.New(mockAuth, services, "static") if err != nil { t.Fatalf("failed to create new proxy: %v", err) } // Start server that gives requests to the proxy. server := &http.Server{ - Addr: testAddr, - Handler: http.HandlerFunc(proxy.ServeHTTP), + Addr: testProxyAddr, + Handler: http.HandlerFunc(p.ServeHTTP), } - - go func() { - if err := server.ListenAndServe(); err != nil { - t.Fatalf("failed to serve to proxy: %v", err) - } - }() + go server.ListenAndServe() + defer server.Close() // Start the target backend service. - go func() { - if err := startHTTPHello(); err != nil { - t.Fatalf("failed to start backend service: %v", err) - } - }() + backendService := &http.Server{Addr: testTargetServiceAddress} + go startBackendHTTP(backendService) + defer backendService.Close() // Wait for servers to start. time.Sleep(100 * time.Millisecond) @@ -60,7 +100,7 @@ func TestProxy(t *testing.T) { // Test making a request to the backend service without the // Authorization header set. client := &http.Client{} - url := fmt.Sprintf("http://%s/grpc/test", testAddr) + url := fmt.Sprintf("http://%s/http/test", testProxyAddr) resp, err := client.Get(url) if err != nil { t.Fatalf("errored making http request: %v", err) @@ -79,6 +119,9 @@ func TestProxy(t *testing.T) { // Make sure that if the Auth header is set, the client's request is // proxied to the backend service. req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("error creating request: %v", err) + } req.Header.Add("Authorization", "foobar") resp, err = client.Do(req) @@ -103,10 +146,287 @@ func TestProxy(t *testing.T) { } } -func startHTTPHello() error { - sayHello := func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(testHTTPResponseBody)) +// TestProxyHTTP tests that the proxy can forward gRPC requests to a backend +// service and handle LSAT authentication correctly. +func TestProxyGRPC(t *testing.T) { + // Since gRPC only really works over TLS, we need to generate a + // certificate and key pair first. + tempDirName, err := ioutil.TempDir("", "proxytest") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + certFile := path.Join(tempDirName, "proxy.cert") + keyFile := path.Join(tempDirName, "proxy.key") + certPool, creds, cert, err := genCertPair(certFile, keyFile) + if err != nil { + t.Fatalf("unable to create cert pair: %v", err) + } + + // Create a list of services to proxy between. + services := []*proxy.Service{{ + Address: testTargetServiceAddress, + HostRegexp: testHostRegexp, + PathRegexp: testPathRegexpGRPC, + Protocol: "https", + TLSCertPath: certFile, + }} + + // Create the proxy server and start serving on TLS. + mockAuth := auth.NewMockAuthenticator() + p, err := proxy.New(mockAuth, services, "static") + if err != nil { + t.Fatalf("failed to create new proxy: %v", err) + } + server := &http.Server{ + Addr: testProxyAddr, + Handler: http.HandlerFunc(p.ServeHTTP), + TLSConfig: &tls.Config{ + RootCAs: certPool, + InsecureSkipVerify: true, + }, + } + go server.ListenAndServeTLS(certFile, keyFile) + defer server.Close() + + // Start the target backend service also on TLS. + tlsConf := &tls.Config{ + Certificates: []tls.Certificate{cert}, + CipherSuites: tlsCipherSuites, + MinVersion: tls.VersionTLS12, + } + serverOpts := []grpc.ServerOption{ + grpc.Creds(credentials.NewTLS(tlsConf)), + } + backendService := grpc.NewServer(serverOpts...) + go startBackendGRPC(backendService) + defer backendService.Stop() + + // Dial to the proxy now, without any authentication. + opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} + conn, err := grpc.Dial(testProxyAddr, opts...) + if err != nil { + t.Fatalf("unable to connect to RPC server: %v", err) + } + client := proxytest.NewGreeterClient(conn) + + // Make request without authentication. We expect an error that can + // be parsed by gRPC. + req := &proxytest.HelloRequest{Name: "foo"} + res, err := client.SayHello( + context.Background(), req, grpc.WaitForReady(true), + ) + if err == nil { + t.Fatalf("expected error to be returned without auth") + } + statusErr, ok := status.FromError(err) + if !ok { + t.Fatalf("expected error to be status.Status") + } + if statusErr.Code() != codes.Internal { + t.Fatalf("unexpected code. wanted %d, got %d", + codes.Internal, statusErr.Code()) + } + if statusErr.Message() != "payment required" { + t.Fatalf("invalid error. expected [%s] got [%s]", + "payment required", err.Error()) + } + + // Dial to the proxy again, this time with a dummy macaroon. + dummyMac, err := macaroon.New( + []byte("key"), []byte("id"), "loc", macaroon.LatestVersion, + ) + opts = []grpc.DialOption{ + grpc.WithTransportCredentials(creds), + grpc.WithPerRPCCredentials(macaroons.NewMacaroonCredential( + dummyMac, + )), + } + conn, err = grpc.Dial(testProxyAddr, opts...) + if err != nil { + t.Fatalf("unable to connect to RPC server: %v", err) + } + client = proxytest.NewGreeterClient(conn) + + // Make the request. This time no error should be returned. + req = &proxytest.HelloRequest{Name: "foo"} + res, err = client.SayHello(context.Background(), req) + if err != nil { + t.Fatalf("unable to call service: %v", err) + } + if res.Message != "Hello foo" { + t.Fatalf("unexpected reply, wanted %s, got %s", + "Hello foo", res.Message) } - http.HandleFunc("/", sayHello) - return http.ListenAndServe(testTargetServiceAddress, nil) +} + +// startBackendHTTP starts the given HTTP server and blocks until the server +// is shut down. +func startBackendHTTP(server *http.Server) error { + sayHello := func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(testHTTPResponseBody)) + if err != nil { + panic(err) + } + } + server.Handler = http.HandlerFunc(sayHello) + return server.ListenAndServe() +} + +// startBackendGRPC starts the given RPC server and blocks until the server is +// shut down. +func startBackendGRPC(grpcServer *grpc.Server) error { + server := helloServer{} + proxytest.RegisterGreeterServer(grpcServer, &server) + grpcListener, err := net.Listen("tcp", testTargetServiceAddress) + if err != nil { + return fmt.Errorf("RPC server unable to listen on %s", + testTargetServiceAddress) + + } + return grpcServer.Serve(grpcListener) +} + +// genCertPair generates a pair of private key and certificate and returns them +// in different formats needed to spin up test servers and clients. +func genCertPair(certFile, keyFile string) (*x509.CertPool, + credentials.TransportCredentials, tls.Certificate, error) { + + org := "kirin autogenerated cert" + cert := tls.Certificate{} + now := time.Now() + validUntil := now.Add(1 * time.Hour) + + // Generate a serial number that's below the serialNumberLimit. + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, cert, fmt.Errorf("failed to generate serial "+ + "number: %s", err) + } + + // Collect the host's IP addresses, including loopback, in a slice. + ipAddresses := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} + + // addIP appends an IP address only if it isn't already in the slice. + addIP := func(ipAddr net.IP) { + for _, ip := range ipAddresses { + if ip.Equal(ipAddr) { + return + } + } + ipAddresses = append(ipAddresses, ipAddr) + } + + // Add all the interface IPs that aren't already in the slice. + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, nil, cert, err + } + for _, a := range addrs { + ipAddr, _, err := net.ParseCIDR(a.String()) + if err == nil { + addIP(ipAddr) + } + } + + // Collect the host's names into a slice. + host, err := os.Hostname() + if err != nil { + return nil, nil, cert, err + } + dnsNames := []string{host} + if host != "localhost" { + dnsNames = append(dnsNames, "localhost") + } + + // Generate a private key for the certificate. + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, cert, err + } + + // Construct the certificate template. + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{org}, + CommonName: host, + }, + NotBefore: now.Add(-time.Hour * 24), + NotAfter: validUntil, + + KeyUsage: x509.KeyUsageKeyEncipherment | + x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + IsCA: true, // so can sign self. + BasicConstraintsValid: true, + + DNSNames: dnsNames, + IPAddresses: ipAddresses, + } + + derBytes, err := x509.CreateCertificate( + rand.Reader, &template, &template, &priv.PublicKey, priv, + ) + if err != nil { + return nil, nil, cert, fmt.Errorf("failed to create "+ + "certificate: %v", err) + } + + certBuf := &bytes.Buffer{} + err = pem.Encode( + certBuf, + &pem.Block{Type: "CERTIFICATE", + Bytes: derBytes, + }, + ) + if err != nil { + return nil, nil, cert, fmt.Errorf("failed to encode "+ + "certificate: %v", err) + } + + keybytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, nil, cert, fmt.Errorf("unable to encode privkey: "+ + "%v", err) + } + keyBuf := &bytes.Buffer{} + err = pem.Encode( + keyBuf, + &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keybytes, + }, + ) + if err != nil { + return nil, nil, cert, fmt.Errorf("failed to encode private "+ + "key: %v", err) + } + cert, err = tls.X509KeyPair(certBuf.Bytes(), keyBuf.Bytes()) + if err != nil { + return nil, nil, cert, fmt.Errorf("failed to create key pair: "+ + "%v", err) + } + + // Write cert and key files. + if err = ioutil.WriteFile(certFile, certBuf.Bytes(), 0644); err != nil { + return nil, nil, cert, fmt.Errorf("unable to write cert file "+ + "at %v: %v", certFile, err) + } + if err = ioutil.WriteFile(keyFile, keyBuf.Bytes(), 0600); err != nil { + os.Remove(certFile) + return nil, nil, cert, fmt.Errorf("unable to write key file "+ + "at %v: %v", keyFile, err) + } + + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(certBuf.Bytes()) { + return nil, nil, cert, fmt.Errorf("credentials: failed to " + + "append certificate") + } + + creds, err := credentials.NewClientTLSFromFile(certFile, "") + if err != nil { + return nil, nil, cert, fmt.Errorf("unable to load cert file: "+ + "%v", err) + } + return cp, creds, cert, nil } diff --git a/proxy/testdata/gen_protos.sh b/proxy/testdata/gen_protos.sh new file mode 100755 index 0000000..1e9032f --- /dev/null +++ b/proxy/testdata/gen_protos.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +set -e + +protoc -I/usr/local/include -I. \ + -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ + --go_out=plugins=grpc,paths=source_relative:. \ + hello.proto diff --git a/proxy/testdata/hello.pb.go b/proxy/testdata/hello.pb.go new file mode 100644 index 0000000..2ea4116 --- /dev/null +++ b/proxy/testdata/hello.pb.go @@ -0,0 +1,194 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: hello.proto + +package proxy_test + +import ( + context "context" + fmt "fmt" + proto "github.com/golang/protobuf/proto" + grpc "google.golang.org/grpc" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type HelloRequest struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *HelloRequest) Reset() { *m = HelloRequest{} } +func (m *HelloRequest) String() string { return proto.CompactTextString(m) } +func (*HelloRequest) ProtoMessage() {} +func (*HelloRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_61ef911816e0a8ce, []int{0} +} + +func (m *HelloRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_HelloRequest.Unmarshal(m, b) +} +func (m *HelloRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_HelloRequest.Marshal(b, m, deterministic) +} +func (m *HelloRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_HelloRequest.Merge(m, src) +} +func (m *HelloRequest) XXX_Size() int { + return xxx_messageInfo_HelloRequest.Size(m) +} +func (m *HelloRequest) XXX_DiscardUnknown() { + xxx_messageInfo_HelloRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_HelloRequest proto.InternalMessageInfo + +func (m *HelloRequest) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type HelloReply struct { + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *HelloReply) Reset() { *m = HelloReply{} } +func (m *HelloReply) String() string { return proto.CompactTextString(m) } +func (*HelloReply) ProtoMessage() {} +func (*HelloReply) Descriptor() ([]byte, []int) { + return fileDescriptor_61ef911816e0a8ce, []int{1} +} + +func (m *HelloReply) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_HelloReply.Unmarshal(m, b) +} +func (m *HelloReply) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_HelloReply.Marshal(b, m, deterministic) +} +func (m *HelloReply) XXX_Merge(src proto.Message) { + xxx_messageInfo_HelloReply.Merge(m, src) +} +func (m *HelloReply) XXX_Size() int { + return xxx_messageInfo_HelloReply.Size(m) +} +func (m *HelloReply) XXX_DiscardUnknown() { + xxx_messageInfo_HelloReply.DiscardUnknown(m) +} + +var xxx_messageInfo_HelloReply proto.InternalMessageInfo + +func (m *HelloReply) GetMessage() string { + if m != nil { + return m.Message + } + return "" +} + +func init() { + proto.RegisterType((*HelloRequest)(nil), "proxy_test.HelloRequest") + proto.RegisterType((*HelloReply)(nil), "proxy_test.HelloReply") +} + +func init() { proto.RegisterFile("hello.proto", fileDescriptor_61ef911816e0a8ce) } + +var fileDescriptor_61ef911816e0a8ce = []byte{ + // 145 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xce, 0x48, 0xcd, 0xc9, + 0xc9, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x2a, 0x28, 0xca, 0xaf, 0xa8, 0x8c, 0x2f, + 0x49, 0x2d, 0x2e, 0x51, 0x52, 0xe2, 0xe2, 0xf1, 0x00, 0x49, 0x05, 0xa5, 0x16, 0x96, 0xa6, 0x16, + 0x97, 0x08, 0x09, 0x71, 0xb1, 0xe4, 0x25, 0xe6, 0xa6, 0x4a, 0x30, 0x2a, 0x30, 0x6a, 0x70, 0x06, + 0x81, 0xd9, 0x4a, 0x6a, 0x5c, 0x5c, 0x50, 0x35, 0x05, 0x39, 0x95, 0x42, 0x12, 0x5c, 0xec, 0xb9, + 0xa9, 0xc5, 0xc5, 0x89, 0xe9, 0x30, 0x45, 0x30, 0xae, 0x91, 0x27, 0x17, 0xbb, 0x7b, 0x51, 0x6a, + 0x6a, 0x49, 0x6a, 0x91, 0x90, 0x1d, 0x17, 0x47, 0x70, 0x62, 0x25, 0x58, 0x97, 0x90, 0x84, 0x1e, + 0xc2, 0x3e, 0x3d, 0x64, 0xcb, 0xa4, 0xc4, 0xb0, 0xc8, 0x14, 0xe4, 0x54, 0x2a, 0x31, 0x24, 0xb1, + 0x81, 0x5d, 0x6a, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0xe3, 0x7f, 0xe6, 0xbe, 0xb8, 0x00, 0x00, + 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// GreeterClient is the client API for Greeter service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type GreeterClient interface { + SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) +} + +type greeterClient struct { + cc *grpc.ClientConn +} + +func NewGreeterClient(cc *grpc.ClientConn) GreeterClient { + return &greeterClient{cc} +} + +func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) { + out := new(HelloReply) + err := c.cc.Invoke(ctx, "/proxy_test.Greeter/SayHello", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// GreeterServer is the server API for Greeter service. +type GreeterServer interface { + SayHello(context.Context, *HelloRequest) (*HelloReply, error) +} + +func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) { + s.RegisterService(&_Greeter_serviceDesc, srv) +} + +func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HelloRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GreeterServer).SayHello(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proxy_test.Greeter/SayHello", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Greeter_serviceDesc = grpc.ServiceDesc{ + ServiceName: "proxy_test.Greeter", + HandlerType: (*GreeterServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "SayHello", + Handler: _Greeter_SayHello_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "hello.proto", +} diff --git a/proxy/testdata/hello.proto b/proxy/testdata/hello.proto new file mode 100644 index 0000000..4ee0771 --- /dev/null +++ b/proxy/testdata/hello.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package proxy_test; + +service Greeter { + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +message HelloRequest { + string name = 1; +} + +message HelloReply { + string message = 1; +}