diff --git a/lsat/caveat.go b/lsat/caveat.go new file mode 100644 index 0000000..80aa6e0 --- /dev/null +++ b/lsat/caveat.go @@ -0,0 +1,142 @@ +package lsat + +import ( + "errors" + "fmt" + "strings" + + "gopkg.in/macaroon.v2" +) + +const ( + // PreimageKey is the key used for a payment preimage caveat. + PreimageKey = "preimage" +) + +var ( + // ErrInvalidCaveat is an error returned when we attempt to decode a + // caveat with an invalid format. + ErrInvalidCaveat = errors.New("caveat must be of the form " + + "\"condition=value\"") +) + +// Caveat is a predicate that can be applied to an LSAT in order to restrict its +// use in some form. Caveats are evaluated during LSAT verification after the +// LSAT's signature is verified. The predicate of each caveat must hold true in +// order to successfully validate an LSAT. +type Caveat struct { + // Condition serves as a way to identify a caveat and how to satisfy it. + Condition string + + // Value is what will be used to satisfy a caveat. This can be as + // flexible as needed, as long as it can be encoded into a string. + Value string +} + +// NewCaveat construct a new caveat with the given condition and value. +func NewCaveat(condition string, value string) Caveat { + return Caveat{Condition: condition, Value: value} +} + +// String returns a user-friendly view of a caveat. +func (c Caveat) String() string { + return EncodeCaveat(c) +} + +// EncodeCaveat encodes a caveat into its string representation. +func EncodeCaveat(c Caveat) string { + return fmt.Sprintf("%v=%v", c.Condition, c.Value) +} + +// DecodeCaveat decodes a caveat from its string representation. +func DecodeCaveat(s string) (Caveat, error) { + parts := strings.SplitN(s, "=", 2) + if len(parts) != 2 { + return Caveat{}, ErrInvalidCaveat + } + return Caveat{Condition: parts[0], Value: parts[1]}, nil +} + +// AddFirstPartyCaveats adds a set of caveats as first-party caveats to a +// macaroon. +func AddFirstPartyCaveats(m *macaroon.Macaroon, caveats ...Caveat) error { + for _, c := range caveats { + rawCaveat := []byte(EncodeCaveat(c)) + if err := m.AddFirstPartyCaveat(rawCaveat); err != nil { + return err + } + } + + return nil +} + +// HasCaveat checks whether the given macaroon has a caveat with the given +// condition, and if so, returns its value. If multiple caveats with the same +// condition exist, then the value of the last one is returned. +func HasCaveat(m *macaroon.Macaroon, cond string) (string, bool) { + var value *string + for _, rawCaveat := range m.Caveats() { + caveat, err := DecodeCaveat(string(rawCaveat.Id)) + if err != nil { + // Ignore any unknown caveats as we can't decode them. + continue + } + if caveat.Condition == cond { + value = &caveat.Value + } + } + + if value == nil { + return "", false + } + return *value, true +} + +// VerifyCaveats determines whether every relevant caveat of an LSAT holds true. +// A caveat is considered relevant if a satisfier is provided for it, which is +// what we'll use as their evaluation. +// +// NOTE: The caveats provided should be in the same order as in the LSAT to +// ensure the correctness of each satisfier's SatisfyPrevious. +func VerifyCaveats(caveats []Caveat, satisfiers ...Satisfier) error { + // Construct a set of our satisfiers to determine which caveats we know + // how to satisfy. + caveatSatisfiers := make(map[string]Satisfier, len(satisfiers)) + for _, satisfier := range satisfiers { + caveatSatisfiers[satisfier.Condition] = satisfier + } + relevantCaveats := make(map[string][]Caveat) + for _, caveat := range caveats { + if _, ok := caveatSatisfiers[caveat.Condition]; !ok { + continue + } + relevantCaveats[caveat.Condition] = append( + relevantCaveats[caveat.Condition], caveat, + ) + } + + for condition, caveats := range relevantCaveats { + satisfier := caveatSatisfiers[condition] + + // Since it's possible for a chain of caveat to exist for the + // same condition as a way to demote privileges, we'll ensure + // each one satisfies its previous. + for i, j := 0, 1; j < len(caveats); i, j = i+1, j+1 { + prevCaveat := caveats[i] + curCaveat := caveats[j] + err := satisfier.SatisfyPrevious(prevCaveat, curCaveat) + if err != nil { + return err + } + } + + // Once we verify the previous ones, if any, we can proceed to + // verify the final one, which is the decision maker. + err := satisfier.SatisfyFinal(caveats[len(caveats)-1]) + if err != nil { + return err + } + } + + return nil +} diff --git a/lsat/caveat_test.go b/lsat/caveat_test.go new file mode 100644 index 0000000..818a86c --- /dev/null +++ b/lsat/caveat_test.go @@ -0,0 +1,202 @@ +package lsat + +import ( + "errors" + "testing" + + "gopkg.in/macaroon.v2" +) + +var ( + testMacaroon, _ = macaroon.New(nil, nil, "", macaroon.LatestVersion) +) + +// TestCaveatSerialization ensures that we can properly encode/decode valid +// caveats and cannot do so for invalid ones. +func TestCaveatSerialization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + caveatStr string + err error + }{ + { + name: "valid caveat", + caveatStr: "expiration=1337", + err: nil, + }, + { + name: "valid caveat with separator in value", + caveatStr: "expiration=1337=", + err: nil, + }, + { + name: "invalid caveat", + caveatStr: "expiration:1337", + err: ErrInvalidCaveat, + }, + } + + for _, test := range tests { + test := test + success := t.Run(test.name, func(t *testing.T) { + caveat, err := DecodeCaveat(test.caveatStr) + if !errors.Is(err, test.err) { + t.Fatalf("expected err \"%v\", got \"%v\"", + test.err, err) + } + + if test.err != nil { + return + } + + caveatStr := EncodeCaveat(caveat) + if caveatStr != test.caveatStr { + t.Fatalf("expected encoded caveat \"%v\", "+ + "got \"%v\"", test.caveatStr, caveatStr) + } + }) + if !success { + return + } + } +} + +// TestHasCaveat ensures we can determine whether a macaroon contains a caveat +// with a specific condition. +func TestHasCaveat(t *testing.T) { + t.Parallel() + + const ( + cond = "cond" + value = "value" + ) + m := testMacaroon.Clone() + + // The macaroon doesn't have any caveats, so we shouldn't find any. + if _, ok := HasCaveat(m, cond); ok { + t.Fatal("found unexpected caveat with unknown condition") + } + + // Add two caveats, one in a valid LSAT format and another invalid. + // We'll test that we're still able to determine the macaroon contains + // the valid caveat even though there is one that is invalid. + invalidCaveat := []byte("invalid") + if err := m.AddFirstPartyCaveat(invalidCaveat); err != nil { + t.Fatalf("unable to add macaroon caveat: %v", err) + } + validCaveat1 := Caveat{Condition: cond, Value: value} + if err := AddFirstPartyCaveats(m, validCaveat1); err != nil { + t.Fatalf("unable to add macaroon caveat: %v", err) + } + + caveatValue, ok := HasCaveat(m, cond) + if !ok { + t.Fatal("expected macaroon to contain caveat") + } + if caveatValue != validCaveat1.Value { + t.Fatalf("expected caveat value \"%v\", got \"%v\"", + validCaveat1.Value, caveatValue) + } + + // If we add another caveat with the same condition, the value of the + // most recently added caveat should be returned instead. + validCaveat2 := validCaveat1 + validCaveat2.Value += value + if err := AddFirstPartyCaveats(m, validCaveat2); err != nil { + t.Fatalf("unable to add macaroon caveat: %v", err) + } + + caveatValue, ok = HasCaveat(m, cond) + if !ok { + t.Fatal("expected macaroon to contain caveat") + } + if caveatValue != validCaveat2.Value { + t.Fatalf("expected caveat value \"%v\", got \"%v\"", + validCaveat2.Value, caveatValue) + } +} + +// TestVerifyCaveats ensures caveat verification only holds true for known +// caveats. +func TestVerifyCaveats(t *testing.T) { + t.Parallel() + + caveat1 := Caveat{Condition: "1", Value: "test"} + caveat2 := Caveat{Condition: "2", Value: "test"} + satisfier := Satisfier{ + Condition: caveat1.Condition, + SatisfyPrevious: func(c Caveat, prev Caveat) error { + return nil + }, + SatisfyFinal: func(c Caveat) error { + return nil + }, + } + invalidSatisfyPrevious := func(c Caveat, prev Caveat) error { + return errors.New("no") + } + invalidSatisfyFinal := func(c Caveat) error { + return errors.New("no") + } + + tests := []struct { + name string + caveats []Caveat + satisfiers []Satisfier + shouldFail bool + }{ + { + name: "simple verification", + caveats: []Caveat{caveat1}, + satisfiers: []Satisfier{satisfier}, + shouldFail: false, + }, + { + name: "unknown caveat", + caveats: []Caveat{caveat1, caveat2}, + satisfiers: []Satisfier{satisfier}, + shouldFail: false, + }, + { + name: "one invalid", + caveats: []Caveat{caveat1, caveat2}, + satisfiers: []Satisfier{ + satisfier, + { + Condition: caveat2.Condition, + SatisfyFinal: invalidSatisfyFinal, + }, + }, + shouldFail: true, + }, + { + name: "prev invalid", + caveats: []Caveat{caveat1, caveat1}, + satisfiers: []Satisfier{ + { + Condition: caveat1.Condition, + SatisfyPrevious: invalidSatisfyPrevious, + }, + }, + shouldFail: true, + }, + } + + for _, test := range tests { + test := test + success := t.Run(test.name, func(t *testing.T) { + err := VerifyCaveats(test.caveats, test.satisfiers...) + if test.shouldFail && err == nil { + t.Fatal("expected caveat verification to fail") + } + if !test.shouldFail && err != nil { + t.Fatal("unexpected caveat verification failure") + } + }) + if !success { + return + } + } +} diff --git a/lsat/credential.go b/lsat/credential.go new file mode 100644 index 0000000..8d38c9e --- /dev/null +++ b/lsat/credential.go @@ -0,0 +1,52 @@ +package lsat + +import ( + "context" + "encoding/hex" + + "gopkg.in/macaroon.v2" +) + +// MacaroonCredential wraps a macaroon to implement the +// credentials.PerRPCCredentials interface. +type MacaroonCredential struct { + *macaroon.Macaroon + + // AllowInsecure specifies if the macaroon is allowed to be sent over + // insecure transport channels. This should only ever be set to true if + // the insecure connection is proxied through tor and the destination + // address is an onion service. + AllowInsecure bool +} + +// RequireTransportSecurity implements the PerRPCCredentials interface. +func (m MacaroonCredential) RequireTransportSecurity() bool { + return !m.AllowInsecure +} + +// GetRequestMetadata implements the PerRPCCredentials interface. This method +// is required in order to pass the wrapped macaroon into the gRPC context. +// With this, the macaroon will be available within the request handling scope +// of the ultimate gRPC server implementation. +func (m MacaroonCredential) GetRequestMetadata(_ context.Context, + _ ...string) (map[string]string, error) { + + macBytes, err := m.MarshalBinary() + if err != nil { + return nil, err + } + + md := make(map[string]string) + md["macaroon"] = hex.EncodeToString(macBytes) + return md, nil +} + +// NewMacaroonCredential returns a copy of the passed macaroon wrapped in a +// MacaroonCredential struct which implements PerRPCCredentials. +func NewMacaroonCredential(m *macaroon.Macaroon, + allowInsecure bool) MacaroonCredential { + + ms := MacaroonCredential{AllowInsecure: allowInsecure} + ms.Macaroon = m.Clone() + return ms +} diff --git a/lsat/identifier.go b/lsat/identifier.go new file mode 100644 index 0000000..252540e --- /dev/null +++ b/lsat/identifier.go @@ -0,0 +1,128 @@ +package lsat + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/lntypes" +) + +const ( + // LatestVersion is the latest version used for minting new LSATs. + LatestVersion = 0 + + // SecretSize is the size in bytes of a LSAT's secret, also known as + // the root key of the macaroon. + SecretSize = 32 + + // TokenIDSize is the size in bytes of an LSAT's ID encoded in its + // macaroon identifier. + TokenIDSize = 32 +) + +var ( + // byteOrder is the byte order used to encode/decode a macaroon's raw + // identifier. + byteOrder = binary.BigEndian + + // ErrUnknownVersion is an error returned when attempting to decode an + // LSAT identifier with an unknown version. + ErrUnknownVersion = errors.New("unknown LSAT version") +) + +// TokenID is the type that stores the token identifier of an LSAT token. +type TokenID [TokenIDSize]byte + +// String returns the hex encoded representation of the token ID as a string. +func (t *TokenID) String() string { + return hex.EncodeToString(t[:]) +} + +// MakeIDFromString parses the hex encoded string and parses it into a token ID. +func MakeIDFromString(newID string) (TokenID, error) { + if len(newID) != hex.EncodedLen(TokenIDSize) { + return TokenID{}, fmt.Errorf("invalid id string length of %v, "+ + "want %v", len(newID), hex.EncodedLen(TokenIDSize)) + } + + idBytes, err := hex.DecodeString(newID) + if err != nil { + return TokenID{}, err + } + var id TokenID + copy(id[:], idBytes) + + return id, nil +} + +// Identifier contains the static identifying details of an LSAT. This is +// intended to be used as the identifier of the macaroon within an LSAT. +type Identifier struct { + // Version is the version of an LSAT. Having a version allows us to + // introduce new fields to the identifier in a backwards-compatible + // manner. + Version uint16 + + // PaymentHash is the payment hash linked to an LSAT. Verification of + // an LSAT depends on a valid payment, which is enforced by ensuring a + // preimage is provided that hashes to our payment hash. + PaymentHash lntypes.Hash + + // TokenID is the unique identifier of an LSAT. + TokenID TokenID +} + +// EncodeIdentifier encodes an LSAT's identifier according to its version. +func EncodeIdentifier(w io.Writer, id *Identifier) error { + if err := binary.Write(w, byteOrder, id.Version); err != nil { + return err + } + + switch id.Version { + // A version 0 identifier consists of its linked payment hash, followed + // by the token ID. + case 0: + if _, err := w.Write(id.PaymentHash[:]); err != nil { + return err + } + _, err := w.Write(id.TokenID[:]) + return err + + default: + return fmt.Errorf("%w: %v", ErrUnknownVersion, id.Version) + } +} + +// DecodeIdentifier decodes an LSAT's identifier according to its version. +func DecodeIdentifier(r io.Reader) (*Identifier, error) { + var version uint16 + if err := binary.Read(r, byteOrder, &version); err != nil { + return nil, err + } + + switch version { + // A version 0 identifier consists of its linked payment hash, followed + // by the token ID. + case 0: + var paymentHash lntypes.Hash + if _, err := r.Read(paymentHash[:]); err != nil { + return nil, err + } + var tokenID TokenID + if _, err := r.Read(tokenID[:]); err != nil { + return nil, err + } + + return &Identifier{ + Version: version, + PaymentHash: paymentHash, + TokenID: tokenID, + }, nil + + default: + return nil, fmt.Errorf("%w: %v", ErrUnknownVersion, version) + } +} diff --git a/lsat/identifier_test.go b/lsat/identifier_test.go new file mode 100644 index 0000000..abda64c --- /dev/null +++ b/lsat/identifier_test.go @@ -0,0 +1,70 @@ +package lsat + +import ( + "bytes" + "errors" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" +) + +var ( + testPaymentHash lntypes.Hash + testTokenID [TokenIDSize]byte +) + +// TestIdentifierSerialization ensures proper serialization of known identifier +// versions and failures for unknown versions. +func TestIdentifierSerialization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id Identifier + err error + }{ + { + name: "valid identifier", + id: Identifier{ + Version: LatestVersion, + PaymentHash: testPaymentHash, + TokenID: testTokenID, + }, + err: nil, + }, + { + name: "unknown version", + id: Identifier{ + Version: LatestVersion + 1, + PaymentHash: testPaymentHash, + TokenID: testTokenID, + }, + err: ErrUnknownVersion, + }, + } + + for _, test := range tests { + test := test + success := t.Run(test.name, func(t *testing.T) { + var buf bytes.Buffer + err := EncodeIdentifier(&buf, &test.id) + if !errors.Is(err, test.err) { + t.Fatalf("expected err \"%v\", got \"%v\"", + test.err, err) + } + if test.err != nil { + return + } + id, err := DecodeIdentifier(&buf) + if err != nil { + t.Fatalf("unable to decode identifier: %v", err) + } + if *id != test.id { + t.Fatalf("expected id %v, got %v", test.id, *id) + } + }) + if !success { + return + } + } +} diff --git a/lsat/interceptor.go b/lsat/interceptor.go new file mode 100644 index 0000000..f4f4655 --- /dev/null +++ b/lsat/interceptor.go @@ -0,0 +1,448 @@ +package lsat + +import ( + "context" + "encoding/base64" + "fmt" + "regexp" + "sync" + "time" + + "github.com/btcsuite/btcutil" + "github.com/lightninglabs/loop/lndclient" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/zpay32" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const ( + // GRPCErrCode is the error code we receive from a gRPC call if the + // server expects a payment. + GRPCErrCode = codes.Internal + + // GRPCErrMessage is the error message we receive from a gRPC call in + // conjunction with the GRPCErrCode to signal the client that a payment + // is required to access the service. + GRPCErrMessage = "payment required" + + // AuthHeader is is the HTTP response header that contains the payment + // challenge. + AuthHeader = "WWW-Authenticate" + + // DefaultMaxCostSats is the default maximum amount in satoshis that we + // are going to pay for an LSAT automatically. Does not include routing + // fees. + DefaultMaxCostSats = 1000 + + // DefaultMaxRoutingFeeSats is the default maximum routing fee in + // satoshis that we are going to pay to acquire an LSAT token. + DefaultMaxRoutingFeeSats = 10 + + // PaymentTimeout is the maximum time we allow a payment to take before + // we stop waiting for it. + PaymentTimeout = 60 * time.Second + + // manualRetryHint is the error text we return to tell the user how a + // token payment can be retried if the payment fails. + manualRetryHint = "consider removing pending token file if error " + + "persists. use 'listauth' command to find out token file name" +) + +var ( + // authHeaderRegex is the regular expression the payment challenge must + // match for us to be able to parse the macaroon and invoice. + authHeaderRegex = regexp.MustCompile( + "LSAT macaroon=\"(.*?)\", invoice=\"(.*?)\"", + ) +) + +// Interceptor is a gRPC client interceptor that can handle LSAT authentication +// challenges with embedded payment requests. It uses a connection to lnd to +// automatically pay for an authentication token. +type Interceptor struct { + lnd *lndclient.LndServices + store Store + callTimeout time.Duration + maxCost btcutil.Amount + maxFee btcutil.Amount + lock sync.Mutex + allowInsecure bool +} + +// NewInterceptor creates a new gRPC client interceptor that uses the provided +// lnd connection to automatically acquire and pay for LSAT tokens, unless the +// indicated store already contains a usable token. +func NewInterceptor(lnd *lndclient.LndServices, store Store, + rpcCallTimeout time.Duration, maxCost, + maxFee btcutil.Amount, allowInsecure bool) *Interceptor { + + return &Interceptor{ + lnd: lnd, + store: store, + callTimeout: rpcCallTimeout, + maxCost: maxCost, + maxFee: maxFee, + allowInsecure: allowInsecure, + } +} + +// interceptContext is a struct that contains all information about a call that +// is intercepted by the interceptor. +type interceptContext struct { + mainCtx context.Context + opts []grpc.CallOption + metadata *metadata.MD + token *Token +} + +// UnaryInterceptor is an interceptor method that can be used directly by gRPC +// for unary calls. If the store contains a token, it is attached as credentials +// to every call before patching it through. The response error is also +// intercepted for every call. If there is an error returned and it is +// indicating a payment challenge, a token is acquired and paid for +// automatically. The original request is then repeated back to the server, now +// with the new token attached. +func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, + req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, + opts ...grpc.CallOption) error { + + // To avoid paying for a token twice if two parallel requests are + // happening, we require an exclusive lock here. + i.lock.Lock() + defer i.lock.Unlock() + + // Create the context that we'll use to initiate the real request. This + // contains the means to extract response headers and possibly also an + // auth token, if we already have paid for one. + iCtx, err := i.newInterceptContext(ctx, opts) + if err != nil { + return err + } + + // Try executing the call now. If anything goes wrong, we only handle + // the LSAT error message that comes in the form of a gRPC status error. + rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout) + defer cancel() + err = invoker(rpcCtx, method, req, reply, cc, iCtx.opts...) + if !isPaymentRequired(err) { + return err + } + + // Find out if we need to pay for a new token or perhaps resume + // a previously aborted payment. + err = i.handlePayment(iCtx) + if err != nil { + return err + } + + // Execute the same request again, now with the LSAT + // token added as an RPC credential. + rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout) + defer cancel2() + return invoker(rpcCtx2, method, req, reply, cc, iCtx.opts...) +} + +// StreamInterceptor is an interceptor method that can be used directly by gRPC +// for streaming calls. If the store contains a token, it is attached as +// credentials to every stream establishment call before patching it through. +// The response error is also intercepted for every initial stream initiation. +// If there is an error returned and it is indicating a payment challenge, a +// token is acquired and paid for automatically. The original request is then +// repeated back to the server, now with the new token attached. +func (i *Interceptor) StreamInterceptor(ctx context.Context, + desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, + error) { + + // To avoid paying for a token twice if two parallel requests are + // happening, we require an exclusive lock here. + i.lock.Lock() + defer i.lock.Unlock() + + // Create the context that we'll use to initiate the real request. This + // contains the means to extract response headers and possibly also an + // auth token, if we already have paid for one. + iCtx, err := i.newInterceptContext(ctx, opts) + if err != nil { + return nil, err + } + + // Try establishing the stream now. If anything goes wrong, we only + // handle the LSAT error message that comes in the form of a gRPC status + // error. The context of a stream will be used for the whole lifetime of + // it, so we can't really clamp down on the initial call with a timeout. + stream, err := streamer(ctx, desc, cc, method, iCtx.opts...) + if !isPaymentRequired(err) { + return stream, err + } + + // Find out if we need to pay for a new token or perhaps resume + // a previously aborted payment. + err = i.handlePayment(iCtx) + if err != nil { + return nil, err + } + + // Execute the same request again, now with the LSAT token added + // as an RPC credential. + return streamer(ctx, desc, cc, method, iCtx.opts...) +} + +// newInterceptContext creates the initial intercept context that can capture +// metadata from the server and sends the local token to the server if one +// already exists. +func (i *Interceptor) newInterceptContext(ctx context.Context, + opts []grpc.CallOption) (*interceptContext, error) { + + iCtx := &interceptContext{ + mainCtx: ctx, + opts: opts, + metadata: &metadata.MD{}, + } + + // Let's see if the store already contains a token and what state it + // might be in. If a previous call was aborted, we might have a pending + // token that needs to be handled separately. + var err error + iCtx.token, err = i.store.CurrentToken() + switch { + // If there is no token yet, nothing to do at this point. + case err == ErrNoToken: + + // Some other error happened that we have to surface. + case err != nil: + log.Errorf("Failed to get token from store: %v", err) + return nil, fmt.Errorf("getting token from store failed: %v", + err) + + // Only if we have a paid token append it. We don't resume a pending + // payment just yet, since we don't even know if a token is required for + // this call. We also never send a pending payment to the server since + // we know it's not valid. + case !iCtx.token.isPending(): + if err = i.addLsatCredentials(iCtx); err != nil { + log.Errorf("Adding macaroon to request failed: %v", err) + return nil, fmt.Errorf("adding macaroon failed: %v", + err) + } + } + + // We need a way to extract the response headers sent by the server. + // This can only be done through the experimental grpc.Trailer call + // option. We execute the request and inspect the error. If it's the + // LSAT specific payment required error, we might execute the same + // method again later with the paid LSAT token. + iCtx.opts = append(iCtx.opts, grpc.Trailer(iCtx.metadata)) + return iCtx, nil +} + +// handlePayment tries to obtain a valid token by either tracking the payment +// status of a pending token or paying for a new one. +func (i *Interceptor) handlePayment(iCtx *interceptContext) error { + switch { + // Resume/track a pending payment if it was interrupted for some reason. + case iCtx.token != nil && iCtx.token.isPending(): + log.Infof("Payment of LSAT token is required, resuming/" + + "tracking previous payment from pending LSAT token") + err := i.trackPayment(iCtx.mainCtx, iCtx.token) + if err != nil { + return err + } + + // We don't have a token yet, try to get a new one. + case iCtx.token == nil: + // We don't have a token yet, get a new one. + log.Infof("Payment of LSAT token is required, paying invoice") + var err error + iCtx.token, err = i.payLsatToken(iCtx.mainCtx, iCtx.metadata) + if err != nil { + return err + } + + // We have a token and it's valid, nothing more to do here. + default: + log.Debugf("Found valid LSAT token to add to request") + } + + if err := i.addLsatCredentials(iCtx); err != nil { + log.Errorf("Adding macaroon to request failed: %v", err) + return fmt.Errorf("adding macaroon failed: %v", err) + } + return nil +} + +// addLsatCredentials adds an LSAT token to the given intercept context. +func (i *Interceptor) addLsatCredentials(iCtx *interceptContext) error { + if iCtx.token == nil { + return fmt.Errorf("cannot add nil token to context") + } + + macaroon, err := iCtx.token.PaidMacaroon() + if err != nil { + return err + } + iCtx.opts = append(iCtx.opts, grpc.PerRPCCredentials( + NewMacaroonCredential(macaroon, i.allowInsecure), + )) + return nil +} + +// payLsatToken reads the payment challenge from the response metadata and tries +// to pay the invoice encoded in them, returning a paid LSAT token if +// successful. +func (i *Interceptor) payLsatToken(ctx context.Context, md *metadata.MD) ( + *Token, error) { + + // First parse the authentication header that was stored in the + // metadata. + authHeader := md.Get(AuthHeader) + if len(authHeader) == 0 { + return nil, fmt.Errorf("auth header not found in response") + } + matches := authHeaderRegex.FindStringSubmatch(authHeader[0]) + if len(matches) != 3 { + return nil, fmt.Errorf("invalid auth header "+ + "format: %s", authHeader[0]) + } + + // Decode the base64 macaroon and the invoice so we can store the + // information in our store later. + macBase64, invoiceStr := matches[1], matches[2] + macBytes, err := base64.StdEncoding.DecodeString(macBase64) + if err != nil { + return nil, fmt.Errorf("base64 decode of macaroon failed: "+ + "%v", err) + } + invoice, err := zpay32.Decode(invoiceStr, i.lnd.ChainParams) + if err != nil { + return nil, fmt.Errorf("unable to decode invoice: %v", err) + } + + // Check that the charged amount does not exceed our maximum cost. + maxCostMsat := lnwire.NewMSatFromSatoshis(i.maxCost) + if invoice.MilliSat != nil && *invoice.MilliSat > maxCostMsat { + return nil, fmt.Errorf("cannot pay for LSAT automatically, "+ + "cost of %d msat exceeds configured max cost of %d "+ + "msat", *invoice.MilliSat, maxCostMsat) + } + + // Create and store the pending token so we can resume the payment in + // case the payment is interrupted somehow. + token, err := tokenFromChallenge(macBytes, invoice.PaymentHash) + if err != nil { + return nil, fmt.Errorf("unable to create token: %v", err) + } + err = i.store.StoreToken(token) + if err != nil { + return nil, fmt.Errorf("unable to store pending token: %v", err) + } + + // Pay invoice now and wait for the result to arrive or the main context + // being canceled. + payCtx, cancel := context.WithTimeout(ctx, PaymentTimeout) + defer cancel() + respChan := i.lnd.Client.PayInvoice( + payCtx, invoiceStr, i.maxFee, nil, + ) + select { + case result := <-respChan: + if result.Err != nil { + return nil, result.Err + } + token.Preimage = result.Preimage + token.AmountPaid = lnwire.NewMSatFromSatoshis(result.PaidAmt) + token.RoutingFeePaid = lnwire.NewMSatFromSatoshis( + result.PaidFee, + ) + return token, i.store.StoreToken(token) + + case <-payCtx.Done(): + return nil, fmt.Errorf("payment timed out. try again to track "+ + "payment. %s", manualRetryHint) + + case <-ctx.Done(): + return nil, fmt.Errorf("parent context canceled. try again to"+ + "track payment. %s", manualRetryHint) + } +} + +// trackPayment tries to resume a pending payment by tracking its state and +// waiting for a conclusive result. +func (i *Interceptor) trackPayment(ctx context.Context, token *Token) error { + // Lookup state of the payment. + paymentStateCtx, cancel := context.WithCancel(ctx) + defer cancel() + payStatusChan, payErrChan, err := i.lnd.Router.TrackPayment( + paymentStateCtx, token.PaymentHash, + ) + if err != nil { + log.Errorf("Could not call TrackPayment on lnd: %v", err) + return fmt.Errorf("track payment call to lnd failed: %v", err) + } + + // We can't wait forever, so we give the payment tracking the same + // timeout as the original payment. + payCtx, cancel := context.WithTimeout(ctx, PaymentTimeout) + defer cancel() + + // We'll consume status updates until we reach a conclusive state or + // reach the timeout. + for { + select { + // If we receive a state without an error, the payment has been + // initiated. Loop until the payment + case result := <-payStatusChan: + switch result.State { + // If the payment was successful, we have all the + // information we need and we can return the fully paid + // token. + case routerrpc.PaymentState_SUCCEEDED: + extractPaymentDetails(token, result) + return i.store.StoreToken(token) + + // The payment is still in transit, we'll give it more + // time to complete. + case routerrpc.PaymentState_IN_FLIGHT: + + // Any other state means either error or timeout. + default: + return fmt.Errorf("payment tracking failed "+ + "with state %s. %s", + result.State.String(), manualRetryHint) + } + + // Abort the payment execution for any error. + case err := <-payErrChan: + return fmt.Errorf("payment tracking failed: %v. %s", + err, manualRetryHint) + + case <-payCtx.Done(): + return fmt.Errorf("payment tracking timed out. %s", + manualRetryHint) + } + } +} + +// isPaymentRequired inspects an error to find out if it's the specific gRPC +// error returned by the server to indicate a payment is required to access the +// service. +func isPaymentRequired(err error) bool { + statusErr, ok := status.FromError(err) + return ok && + statusErr.Message() == GRPCErrMessage && + statusErr.Code() == GRPCErrCode +} + +// extractPaymentDetails extracts the preimage and amounts paid for a payment +// from the payment status and stores them in the token. +func extractPaymentDetails(token *Token, status lndclient.PaymentStatus) { + token.Preimage = status.Preimage + total := status.Route.TotalAmount + fees := status.Route.TotalFees() + token.AmountPaid = total - fees + token.RoutingFeePaid = fees +} diff --git a/lsat/interceptor_test.go b/lsat/interceptor_test.go new file mode 100644 index 0000000..64d5c10 --- /dev/null +++ b/lsat/interceptor_test.go @@ -0,0 +1,416 @@ +package lsat + +import ( + "context" + "encoding/base64" + "encoding/hex" + "fmt" + "sync" + "testing" + "time" + + "github.com/lightninglabs/loop/lndclient" + "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" + "google.golang.org/grpc" + "google.golang.org/grpc/status" + "gopkg.in/macaroon.v2" +) + +type interceptTestCase struct { + name string + initialPreimage *lntypes.Preimage + interceptor *Interceptor + resetCb func() + expectLndCall bool + sendPaymentCb func(*testing.T, test.PaymentChannelMessage) + trackPaymentCb func(*testing.T, test.TrackPaymentMessage) + expectToken bool + expectInterceptErr string + expectBackendCalls int + expectMacaroonCall1 bool + expectMacaroonCall2 bool +} + +type mockStore struct { + token *Token +} + +func (s *mockStore) CurrentToken() (*Token, error) { + if s.token == nil { + return nil, ErrNoToken + } + return s.token, nil +} + +func (s *mockStore) AllTokens() (map[string]*Token, error) { + return map[string]*Token{"foo": s.token}, nil +} + +func (s *mockStore) StoreToken(token *Token) error { + s.token = token + return nil +} + +var ( + lnd = test.NewMockLnd() + store = &mockStore{} + testTimeout = 5 * time.Second + interceptor = NewInterceptor( + &lnd.LndServices, store, testTimeout, + DefaultMaxCostSats, DefaultMaxRoutingFeeSats, false, + ) + testMac = makeMac() + testMacBytes = serializeMac(testMac) + testMacHex = hex.EncodeToString(testMacBytes) + paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} + backendErr error + backendAuth = "" + callMD map[string]string + numBackendCalls = 0 + overallWg sync.WaitGroup + backendWg sync.WaitGroup + + testCases = []interceptTestCase{ + { + name: "no auth required happy path", + initialPreimage: nil, + interceptor: interceptor, + resetCb: func() { resetBackend(nil, "") }, + expectLndCall: false, + expectToken: false, + expectBackendCalls: 1, + expectMacaroonCall1: false, + expectMacaroonCall2: false, + }, + { + name: "auth required, no token yet", + initialPreimage: nil, + interceptor: interceptor, + resetCb: func() { + resetBackend( + status.New( + GRPCErrCode, GRPCErrMessage, + ).Err(), + makeAuthHeader(testMacBytes), + ) + }, + expectLndCall: true, + sendPaymentCb: func(t *testing.T, + msg test.PaymentChannelMessage) { + + if len(callMD) != 0 { + t.Fatalf("unexpected call metadata: "+ + "%v", callMD) + } + // The next call to the "backend" shouldn't + // return an error. + resetBackend(nil, "") + msg.Done <- lndclient.PaymentResult{ + Preimage: paidPreimage, + PaidAmt: 123, + PaidFee: 345, + } + }, + trackPaymentCb: func(t *testing.T, + msg test.TrackPaymentMessage) { + + t.Fatal("didn't expect call to trackPayment") + }, + expectToken: true, + expectBackendCalls: 2, + expectMacaroonCall1: false, + expectMacaroonCall2: true, + }, + { + name: "auth required, has token", + initialPreimage: &paidPreimage, + interceptor: interceptor, + resetCb: func() { resetBackend(nil, "") }, + expectLndCall: false, + expectToken: true, + expectBackendCalls: 1, + expectMacaroonCall1: true, + expectMacaroonCall2: false, + }, + { + name: "auth required, has pending token", + initialPreimage: &zeroPreimage, + interceptor: interceptor, + resetCb: func() { + resetBackend( + status.New( + GRPCErrCode, GRPCErrMessage, + ).Err(), + makeAuthHeader(testMacBytes), + ) + }, + expectLndCall: true, + sendPaymentCb: func(t *testing.T, + msg test.PaymentChannelMessage) { + + t.Fatal("didn't expect call to sendPayment") + }, + trackPaymentCb: func(t *testing.T, + msg test.TrackPaymentMessage) { + + // The next call to the "backend" shouldn't + // return an error. + resetBackend(nil, "") + msg.Updates <- lndclient.PaymentStatus{ + State: routerrpc.PaymentState_SUCCEEDED, + Preimage: paidPreimage, + Route: &route.Route{}, + } + }, + expectToken: true, + expectBackendCalls: 2, + expectMacaroonCall1: false, + expectMacaroonCall2: true, + }, + { + name: "auth required, no token yet, cost limit", + initialPreimage: nil, + interceptor: NewInterceptor( + &lnd.LndServices, store, testTimeout, + 100, DefaultMaxRoutingFeeSats, false, + ), + resetCb: func() { + resetBackend( + status.New( + GRPCErrCode, GRPCErrMessage, + ).Err(), + makeAuthHeader(testMacBytes), + ) + }, + expectLndCall: false, + expectToken: false, + expectInterceptErr: "cannot pay for LSAT " + + "automatically, cost of 500000 msat exceeds " + + "configured max cost of 100000 msat", + expectBackendCalls: 1, + expectMacaroonCall1: false, + expectMacaroonCall2: false, + }, + } +) + +// resetBackend is used by the test cases to define the behaviour of the +// simulated backend and reset its starting conditions. +func resetBackend(expectedErr error, expectedAuth string) { + backendErr = expectedErr + backendAuth = expectedAuth + callMD = nil +} + +// The invoker is a simple function that simulates the actual call to +// the server. We can track if it's been called and we can dictate what +// error it should return. +func invoker(opts []grpc.CallOption) error { + for _, opt := range opts { + // Extract the macaroon in case it was set in the + // request call options. + creds, ok := opt.(grpc.PerRPCCredsCallOption) + if ok { + callMD, _ = creds.Creds.GetRequestMetadata( + context.Background(), + ) + } + + // Should we simulate an auth header response? + trailer, ok := opt.(grpc.TrailerCallOption) + if ok && backendAuth != "" { + trailer.TrailerAddr.Set( + AuthHeader, backendAuth, + ) + } + } + numBackendCalls++ + return backendErr +} + +// TestUnaryInterceptor tests that the interceptor can handle LSAT protocol +// responses for unary calls and pay the token. +func TestUnaryInterceptor(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + unaryInvoker := func(_ context.Context, _ string, + _ interface{}, _ interface{}, _ *grpc.ClientConn, + opts ...grpc.CallOption) error { + + defer backendWg.Done() + return invoker(opts) + } + + // Run through the test cases. + for _, tc := range testCases { + tc := tc + intercept := func() error { + return tc.interceptor.UnaryInterceptor( + ctx, "", nil, nil, nil, unaryInvoker, nil, + ) + } + t.Run(tc.name, func(t *testing.T) { + testInterceptor(t, tc, intercept) + }) + } +} + +// TestStreamInterceptor tests that the interceptor can handle LSAT protocol +// responses in streams and pay the token. +func TestStreamInterceptor(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + streamInvoker := func(_ context.Context, + _ *grpc.StreamDesc, _ *grpc.ClientConn, + _ string, opts ...grpc.CallOption) ( + grpc.ClientStream, error) { // nolint: unparam + + defer backendWg.Done() + return nil, invoker(opts) + } + + // Run through the test cases. + for _, tc := range testCases { + tc := tc + intercept := func() error { + _, err := tc.interceptor.StreamInterceptor( + ctx, nil, nil, "", streamInvoker, + ) + return err + } + t.Run(tc.name, func(t *testing.T) { + testInterceptor(t, tc, intercept) + }) + } +} + +func testInterceptor(t *testing.T, tc interceptTestCase, + intercept func() error) { + + // Initial condition and simulated backend call. + store.token = makeToken(tc.initialPreimage) + tc.resetCb() + numBackendCalls = 0 + backendWg.Add(1) + overallWg.Add(1) + go func() { + defer overallWg.Done() + err := intercept() + if err != nil && tc.expectInterceptErr != "" && + err.Error() != tc.expectInterceptErr { + panic(fmt.Errorf("unexpected error '%s', "+ + "expected '%s'", err.Error(), + tc.expectInterceptErr)) + } + }() + + backendWg.Wait() + if tc.expectMacaroonCall1 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) + } + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) + } + } + + // Do we expect more calls? Then make sure we will wait for + // completion before checking any results. + if tc.expectBackendCalls > 1 { + backendWg.Add(1) + } + + // Simulate payment related calls to lnd, if there are any + // expected. + if tc.expectLndCall { + select { + case payment := <-lnd.SendPaymentChannel: + tc.sendPaymentCb(t, payment) + + case track := <-lnd.TrackPaymentChannel: + tc.trackPaymentCb(t, track) + + case <-time.After(testTimeout): + t.Fatalf("[%s]: no payment request received", + tc.name) + } + } + backendWg.Wait() + overallWg.Wait() + + if tc.expectToken { + if _, err := store.CurrentToken(); err != nil { + t.Fatalf("[%s] expected store to contain token", + tc.name) + } + storeToken, _ := store.CurrentToken() + if storeToken.Preimage != paidPreimage { + t.Fatalf("[%s] token has unexpected preimage: "+ + "%x", tc.name, storeToken.Preimage) + } + } + if tc.expectMacaroonCall2 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) + } + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) + } + } + if tc.expectBackendCalls != numBackendCalls { + t.Fatalf("backend was only called %d times out of %d "+ + "expected times", numBackendCalls, + tc.expectBackendCalls) + } +} + +func makeToken(preimage *lntypes.Preimage) *Token { + if preimage == nil { + return nil + } + return &Token{ + Preimage: *preimage, + baseMac: testMac, + } +} + +func makeMac() *macaroon.Macaroon { + dummyMac, err := macaroon.New( + []byte("aabbccddeeff00112233445566778899"), []byte("AA=="), + "LSAT", macaroon.LatestVersion, + ) + if err != nil { + panic(fmt.Errorf("unable to create macaroon: %v", err)) + } + return dummyMac +} + +func serializeMac(mac *macaroon.Macaroon) []byte { + macBytes, err := mac.MarshalBinary() + if err != nil { + panic(fmt.Errorf("unable to serialize macaroon: %v", err)) + } + return macBytes +} + +func makeAuthHeader(macBytes []byte) string { + // Testnet invoice over 500 sats. + invoice := "lntb5u1p0pskpmpp5jzw9xvdast2g5lm5tswq6n64t2epe3f4xav43dyd" + + "239qr8h3yllqdqqcqzpgsp5m8sfjqgugthk66q3tr4gsqr5rh740jrq9x4l0" + + "kvj5e77nmwqvpnq9qy9qsq72afzu7sfuppzqg3q2pn49hlh66rv7w60h2rua" + + "hx857g94s066yzxcjn4yccqc79779sd232v9ewluvu0tmusvht6r99rld8xs" + + "k287cpyac79r" + return fmt.Sprintf("LSAT macaroon=\"%s\", invoice=\"%s\"", + base64.StdEncoding.EncodeToString(macBytes), invoice) +} diff --git a/lsat/log.go b/lsat/log.go new file mode 100644 index 0000000..6e4f671 --- /dev/null +++ b/lsat/log.go @@ -0,0 +1,26 @@ +package lsat + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// Subsystem defines the sub system name of this package. +const Subsystem = "LSAT" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/lsat/satisfier.go b/lsat/satisfier.go new file mode 100644 index 0000000..5f7f7b5 --- /dev/null +++ b/lsat/satisfier.go @@ -0,0 +1,117 @@ +package lsat + +import ( + "fmt" + "strings" +) + +// Satisfier provides a generic interface to satisfy a caveat based on its +// condition. +type Satisfier struct { + // Condition is the condition of the caveat we'll attempt to satisfy. + Condition string + + // SatisfyPrevious ensures a caveat is in accordance with a previous one + // with the same condition. This is needed since caveats of the same + // condition can be used multiple times as long as they enforce more + // permissions than the previous. + // + // For example, we have a caveat that only allows us to use an LSAT for + // 7 more days. We can add another caveat that only allows for 3 more + // days of use and lend it to another party. + SatisfyPrevious func(previous Caveat, current Caveat) error + + // SatisfyFinal satisfies the final caveat of an LSAT. If multiple + // caveats with the same condition exist, this will only be executed + // once all previous caveats are also satisfied. + SatisfyFinal func(Caveat) error +} + +// NewServicesSatisfier implements a satisfier to determine whether the target +// service is authorized for a given LSAT. +// +// TODO(wilmer): Add tier verification? +func NewServicesSatisfier(targetService string) Satisfier { + return Satisfier{ + Condition: CondServices, + SatisfyPrevious: func(prev, cur Caveat) error { + // Construct a set of the services we were previously + // allowed to access. + prevServices, err := decodeServicesCaveatValue(prev.Value) + if err != nil { + return err + } + prevAllowed := make(map[string]struct{}, len(prevServices)) + for _, service := range prevServices { + prevAllowed[service.Name] = struct{}{} + } + + // The caveat should not include any new services that + // weren't previously allowed. + currentServices, err := decodeServicesCaveatValue(cur.Value) + if err != nil { + return err + } + for _, service := range currentServices { + if _, ok := prevAllowed[service.Name]; !ok { + return fmt.Errorf("service %v not "+ + "previously allowed", service) + } + } + + return nil + }, + SatisfyFinal: func(c Caveat) error { + services, err := decodeServicesCaveatValue(c.Value) + if err != nil { + return err + } + for _, service := range services { + if service.Name == targetService { + return nil + } + } + return fmt.Errorf("target service %v not authorized", + targetService) + }, + } +} + +// NewCapabilitiesSatisfier implements a satisfier to determine whether the +// target capability for a service is authorized for a given LSAT. +func NewCapabilitiesSatisfier(service string, targetCapability string) Satisfier { + return Satisfier{ + Condition: service + CondCapabilitiesSuffix, + SatisfyPrevious: func(prev, cur Caveat) error { + // Construct a set of the service's capabilities we were + // previously allowed to access. + prevCapabilities := strings.Split(prev.Value, ",") + allowed := make(map[string]struct{}, len(prevCapabilities)) + for _, capability := range prevCapabilities { + allowed[capability] = struct{}{} + } + + // The caveat should not include any new service + // capabilities that weren't previously allowed. + currentCapabilities := strings.Split(cur.Value, ",") + for _, capability := range currentCapabilities { + if _, ok := allowed[capability]; !ok { + return fmt.Errorf("capability %v not "+ + "previously allowed", capability) + } + } + + return nil + }, + SatisfyFinal: func(c Caveat) error { + capabilities := strings.Split(c.Value, ",") + for _, capability := range capabilities { + if capability == targetCapability { + return nil + } + } + return fmt.Errorf("target capability %v not authorized", + targetCapability) + }, + } +} diff --git a/lsat/service.go b/lsat/service.go new file mode 100644 index 0000000..5f10b5e --- /dev/null +++ b/lsat/service.go @@ -0,0 +1,128 @@ +package lsat + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +const ( + // CondServices is the condition used for a services caveat. + CondServices = "services" + + // CondCapabilitiesSuffix is the condition suffix used for a service's + // capabilities caveat. For example, the condition of a capabilities + // caveat for a service named `loop` would be `loop_capabilities`. + CondCapabilitiesSuffix = "_capabilities" +) + +var ( + // ErrNoServices is an error returned when we attempt to decode the + // services included in a caveat. + ErrNoServices = errors.New("no services found") + + // ErrInvalidService is an error returned when we attempt to decode a + // service with an invalid format. + ErrInvalidService = errors.New("service must be of the form " + + "\"name:tier\"") +) + +// ServiceTier represents the different possible tiers of an LSAT-enabled +// service. +type ServiceTier uint8 + +const ( + // BaseTier is the base tier of an LSAT-enabled service. This tier + // should be used for any new LSATs that are not part of a service tier + // upgrade. + BaseTier ServiceTier = iota +) + +// Service contains the details of an LSAT-enabled service. +type Service struct { + // Name is the name of the LSAT-enabled service. + Name string + + // Tier is the tier of the LSAT-enabled service. + Tier ServiceTier +} + +// NewServicesCaveat creates a new services caveat with the provided caveats. +func NewServicesCaveat(services ...Service) (Caveat, error) { + value, err := encodeServicesCaveatValue(services...) + if err != nil { + return Caveat{}, err + } + return Caveat{ + Condition: CondServices, + Value: value, + }, nil +} + +// encodeServicesCaveatValue encodes a list of services into the expected format +// of a services caveat's value. +func encodeServicesCaveatValue(services ...Service) (string, error) { + if len(services) == 0 { + return "", ErrNoServices + } + + var s strings.Builder + for i, service := range services { + if service.Name == "" { + return "", errors.New("missing service name") + } + + fmtStr := "%v:%v" + if i < len(services)-1 { + fmtStr += "," + } + + fmt.Fprintf(&s, fmtStr, service.Name, uint8(service.Tier)) + } + + return s.String(), nil +} + +// decodeServicesCaveatValue decodes a list of services from the expected format +// of a services caveat's value. +func decodeServicesCaveatValue(s string) ([]Service, error) { + if s == "" { + return nil, ErrNoServices + } + + rawServices := strings.Split(s, ",") + services := make([]Service, 0, len(rawServices)) + for _, rawService := range rawServices { + serviceInfo := strings.Split(rawService, ":") + if len(serviceInfo) != 2 { + return nil, ErrInvalidService + } + + name, tierStr := serviceInfo[0], serviceInfo[1] + if name == "" { + return nil, fmt.Errorf("%w: %v", ErrInvalidService, + "empty name") + } + tier, err := strconv.Atoi(tierStr) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidService, err) + } + + services = append(services, Service{ + Name: name, + Tier: ServiceTier(tier), + }) + } + + return services, nil +} + +// NewCapabilitiesCaveat creates a new capabilities caveat for the given +// service. +func NewCapabilitiesCaveat(serviceName string, capabilities string) Caveat { + return Caveat{ + Condition: serviceName + CondCapabilitiesSuffix, + Value: capabilities, + } +} diff --git a/lsat/service_test.go b/lsat/service_test.go new file mode 100644 index 0000000..8564f5b --- /dev/null +++ b/lsat/service_test.go @@ -0,0 +1,83 @@ +package lsat + +import ( + "errors" + "testing" +) + +// TestServicesCaveatSerialization ensures that we can properly encode/decode +// valid services from a caveat and cannot do so for invalid ones. +func TestServicesCaveatSerialization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + err error + }{ + { + name: "single service", + value: "a:0", + err: nil, + }, + { + name: "multiple services", + value: "a:0,b:1,c:0", + err: nil, + }, + { + name: "no services", + value: "", + err: ErrNoServices, + }, + { + name: "service missing name", + value: ":0", + err: ErrInvalidService, + }, + { + name: "service missing tier", + value: "a", + err: ErrInvalidService, + }, + { + name: "service empty tier", + value: "a:", + err: ErrInvalidService, + }, + { + name: "service non-numeric tier", + value: "a:b", + err: ErrInvalidService, + }, + { + name: "empty services", + value: ",,", + err: ErrInvalidService, + }, + } + + for _, test := range tests { + test := test + success := t.Run(test.name, func(t *testing.T) { + services, err := decodeServicesCaveatValue(test.value) + if !errors.Is(err, test.err) { + t.Fatalf("expected err \"%v\", got \"%v\"", + test.err, err) + } + + if test.err != nil { + return + } + + value, _ := encodeServicesCaveatValue(services...) + if value != test.value { + t.Fatalf("expected encoded services \"%v\", "+ + "got \"%v\"", test.value, value) + } + }) + if !success { + return + } + } +} diff --git a/lsat/store.go b/lsat/store.go new file mode 100644 index 0000000..3122879 --- /dev/null +++ b/lsat/store.go @@ -0,0 +1,211 @@ +package lsat + +import ( + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" +) + +var ( + // ErrNoToken is the error returned when the store doesn't contain a + // token yet. + ErrNoToken = errors.New("no token in store") + + // storeFileName is the name of the file where we store the final, + // valid, token to. + storeFileName = "lsat.token" + + // storeFileNamePending is the name of the file where we store a pending + // token until it was successfully paid for. + storeFileNamePending = "lsat.token.pending" + + // errNoReplace is the error that is returned if a new token is + // being written to a store that already contains a paid token. + errNoReplace = errors.New("won't replace existing paid token with " + + "new token. " + manualRetryHint) +) + +// Store is an interface that allows users to store and retrieve an LSAT token. +type Store interface { + // CurrentToken returns the token that is currently contained in the + // store or an error if there is none. + CurrentToken() (*Token, error) + + // AllTokens returns all tokens that the store has knowledge of, even + // if they might be expired. The tokens are mapped by their identifying + // attribute like file name or storage key. + AllTokens() (map[string]*Token, error) + + // StoreToken saves a token to the store. Old tokens should be kept for + // accounting purposes but marked as invalid somehow. + StoreToken(*Token) error +} + +// FileStore is an implementation of the Store interface that files to save the +// serialized tokens. There is always just one current token that is either +// pending or fully paid. +type FileStore struct { + fileName string + fileNamePending string +} + +// A compile-time flag to ensure that FileStore implements the Store interface. +var _ Store = (*FileStore)(nil) + +// NewFileStore creates a new file based token store, creating its file in the +// provided directory. If the directory does not exist, it will be created. +func NewFileStore(storeDir string) (*FileStore, error) { + // If the target path for the token store doesn't exist, then we'll + // create it now before we proceed. + if !fileExists(storeDir) { + if err := os.MkdirAll(storeDir, 0700); err != nil { + return nil, err + } + } + + return &FileStore{ + fileName: filepath.Join(storeDir, storeFileName), + fileNamePending: filepath.Join(storeDir, storeFileNamePending), + }, nil +} + +// CurrentToken returns the token that is currently contained in the store or an +// error if there is none. +// +// NOTE: This is part of the Store interface. +func (f *FileStore) CurrentToken() (*Token, error) { + // As this is only a wrapper for external users to make sure the store + // is locked, the actual implementation is in the non-exported method. + return f.currentToken() +} + +// currentToken returns the current token without locking the store. +func (f *FileStore) currentToken() (*Token, error) { + switch { + case fileExists(f.fileName): + return readTokenFile(f.fileName) + + case fileExists(f.fileNamePending): + return readTokenFile(f.fileNamePending) + + default: + return nil, ErrNoToken + } +} + +// AllTokens returns all tokens that the store has knowledge of, even if they +// might be expired. The tokens are mapped by their identifying attribute like +// file name or storage key. +// +// NOTE: This is part of the Store interface. +func (f *FileStore) AllTokens() (map[string]*Token, error) { + tokens := make(map[string]*Token) + + // All tokens start with the same name so we can get them by the prefix. + // As the tokens don't expire yet, there currently can't be more than + // just one token, either pending or paid. + // TODO(guggero): Update comment once tokens expire and we keep backups. + tokenDir := filepath.Dir(f.fileName) + files, err := ioutil.ReadDir(tokenDir) + if err != nil { + return nil, err + } + for _, file := range files { + name := file.Name() + if !strings.HasPrefix(name, storeFileName) { + continue + } + fileName := filepath.Join(tokenDir, name) + token, err := readTokenFile(fileName) + if err != nil { + return nil, err + } + tokens[fileName] = token + } + + return tokens, nil +} + +// StoreToken saves a token to the store, overwriting any old token if there is +// one. +// +// NOTE: This is part of the Store interface. +func (f *FileStore) StoreToken(newToken *Token) error { + // Serialize the token first, before we rename anything. + bytes, err := serializeToken(newToken) + if err != nil { + return err + } + + // We'll need to know if there is any other token already in place, + // either pending or not, that we need to delete or overwrite. + currentToken, err := f.currentToken() + + switch { + // No token in the store yet, just write it to the corresponding file. + case err == ErrNoToken: + // What's the target file name we are going to write? + newFileName := f.fileName + if newToken.isPending() { + newFileName = f.fileNamePending + } + return ioutil.WriteFile(newFileName, bytes, 0600) + + // Fail on any other error. + case err != nil: + return err + + // Replace a pending token with a paid one. + case currentToken.isPending() && !newToken.isPending(): + // Make sure we replace the the same token, just with a + // different state. + if currentToken.PaymentHash != newToken.PaymentHash { + return fmt.Errorf("new paid token doesn't match " + + "existing pending token") + } + + // Write the new token first, so we still have the pending + // around if something goes wrong. + err := ioutil.WriteFile(f.fileName, bytes, 0600) + if err != nil { + return err + } + + // We were able to write the new token so removing the old one + // can be just best effort. By default, the valid one will be + // read by the store if both exist. + _ = os.Remove(f.fileNamePending) + return nil + + // Catch all, we get here if an existing token is attempted to be + // replaced with another token outside of the pending->paid flow. The + // user should manually remove the token in that case. + // TODO(guggero): Once tokens expire, this logic has to be adapted + // accordingly. + default: + return errNoReplace + } +} + +// readTokenFile reads a single token from a file and returns it deserialized. +func readTokenFile(tokenFile string) (*Token, error) { + bytes, err := ioutil.ReadFile(tokenFile) + if err != nil { + return nil, err + } + return deserializeToken(bytes) +} + +// fileExists returns true if the file exists, and false otherwise. +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + + return true +} diff --git a/lsat/store_test.go b/lsat/store_test.go new file mode 100644 index 0000000..101021c --- /dev/null +++ b/lsat/store_test.go @@ -0,0 +1,131 @@ +package lsat + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" +) + +// TestStore tests the basic functionality of the file based store. +func TestFileStore(t *testing.T) { + t.Parallel() + + tempDirName, err := ioutil.TempDir("", "lsatstore") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDirName) + + var ( + paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} + paidToken = &Token{ + Preimage: paidPreimage, + baseMac: makeMac(), + } + pendingToken = &Token{ + Preimage: zeroPreimage, + baseMac: makeMac(), + } + ) + + store, err := NewFileStore(tempDirName) + if err != nil { + t.Fatalf("could not create test store: %v", err) + } + + // Make sure the current store is empty. + _, err = store.CurrentToken() + if err != ErrNoToken { + t.Fatalf("expected store to be empty but error was: %v", err) + } + tokens, err := store.AllTokens() + if err != nil { + t.Fatalf("unexpected error listing all tokens: %v", err) + } + if len(tokens) != 0 { + t.Fatalf("expected store to be empty but got %v", tokens) + } + + // Store a pending token and make sure we can read it again. + err = store.StoreToken(pendingToken) + if err != nil { + t.Fatalf("could not save pending token: %v", err) + } + if !fileExists(filepath.Join(tempDirName, storeFileNamePending)) { + t.Fatalf("expected file %s/%s to exist but it didn't", + tempDirName, storeFileNamePending) + } + token, err := store.CurrentToken() + if err != nil { + t.Fatalf("could not read pending token: %v", err) + } + if !token.baseMac.Equal(pendingToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + tokens, err = store.AllTokens() + if err != nil { + t.Fatalf("unexpected error listing all tokens: %v", err) + } + if len(tokens) != 1 { + t.Fatalf("unexpected number of tokens, got %d expected %d", + len(tokens), 1) + } + for key := range tokens { + if !tokens[key].baseMac.Equal(pendingToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + } + + // Replace the pending token with a final one and make sure the pending + // token was replaced. + err = store.StoreToken(paidToken) + if err != nil { + t.Fatalf("could not save pending token: %v", err) + } + if !fileExists(filepath.Join(tempDirName, storeFileName)) { + t.Fatalf("expected file %s/%s to exist but it didn't", + tempDirName, storeFileName) + } + if fileExists(filepath.Join(tempDirName, storeFileNamePending)) { + t.Fatalf("expected file %s/%s to be removed but it wasn't", + tempDirName, storeFileNamePending) + } + token, err = store.CurrentToken() + if err != nil { + t.Fatalf("could not read pending token: %v", err) + } + if !token.baseMac.Equal(paidToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + tokens, err = store.AllTokens() + if err != nil { + t.Fatalf("unexpected error listing all tokens: %v", err) + } + if len(tokens) != 1 { + t.Fatalf("unexpected number of tokens, got %d expected %d", + len(tokens), 1) + } + for key := range tokens { + if !tokens[key].baseMac.Equal(paidToken.baseMac) { + t.Fatalf("expected macaroon to match") + } + } + + // Make sure we can't replace the existing paid token with a pending. + err = store.StoreToken(pendingToken) + if err != errNoReplace { + t.Fatalf("unexpected error. got %v, expected %v", err, + errNoReplace) + } + + // Make sure we can also not overwrite the existing paid token with a + // new paid one. + err = store.StoreToken(paidToken) + if err != errNoReplace { + t.Fatalf("unexpected error. got %v, expected %v", err, + errNoReplace) + } +} diff --git a/lsat/token.go b/lsat/token.go new file mode 100644 index 0000000..1be010e --- /dev/null +++ b/lsat/token.go @@ -0,0 +1,190 @@ +package lsat + +import ( + "bytes" + "encoding/binary" + "fmt" + "time" + + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "gopkg.in/macaroon.v2" +) + +var ( + // zeroPreimage is an empty, invalid payment preimage that is used to + // initialize pending tokens with. + zeroPreimage lntypes.Preimage +) + +// Token is the main type to store an LSAT token in. +type Token struct { + // PaymentHash is the hash of the LSAT invoice that needs to be paid. + // Knowing the preimage to this hash is seen as proof of payment by the + // authentication server. + PaymentHash lntypes.Hash + + // Preimage is the proof of payment indicating that the token has been + // paid for if set. If the preimage is empty, the payment might still + // be in transit. + Preimage lntypes.Preimage + + // AmountPaid is the total amount in msat that the user paid to get the + // token. This does not include routing fees. + AmountPaid lnwire.MilliSatoshi + + // RoutingFeePaid is the total amount in msat that the user paid in + // routing fee to get the token. + RoutingFeePaid lnwire.MilliSatoshi + + // TimeCreated is the moment when this token was created. + TimeCreated time.Time + + // baseMac is the base macaroon in its original form as baked by the + // authentication server. No client side caveats have been added to it + // yet. + baseMac *macaroon.Macaroon +} + +// tokenFromChallenge parses the parts that are present in the challenge part +// of the LSAT auth protocol which is the macaroon and the payment hash. +func tokenFromChallenge(baseMac []byte, paymentHash *[32]byte) (*Token, error) { + // First, validate that the macaroon is valid and can be unmarshaled. + mac := &macaroon.Macaroon{} + err := mac.UnmarshalBinary(baseMac) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal macaroon: %v", err) + } + + token := &Token{ + TimeCreated: time.Now(), + baseMac: mac, + Preimage: zeroPreimage, + } + hash, err := lntypes.MakeHash(paymentHash[:]) + if err != nil { + return nil, err + } + token.PaymentHash = hash + return token, nil +} + +// BaseMacaroon returns the base macaroon as received from the authentication +// server. +func (t *Token) BaseMacaroon() *macaroon.Macaroon { + return t.baseMac.Clone() +} + +// PaidMacaroon returns the base macaroon with the proof of payment (preimage) +// added as a first-party-caveat. +func (t *Token) PaidMacaroon() (*macaroon.Macaroon, error) { + mac := t.BaseMacaroon() + err := AddFirstPartyCaveats( + mac, NewCaveat(PreimageKey, t.Preimage.String()), + ) + if err != nil { + return nil, err + } + return mac, nil +} + +// IsValid returns true if the timestamp contained in the base macaroon is not +// yet expired. +func (t *Token) IsValid() bool { + // TODO(guggero): Extract and validate from caveat once we add an + // expiration date to the LSAT. + return true +} + +// isPending returns true if the payment for the LSAT is still in flight and we +// haven't received the preimage yet. +func (t *Token) isPending() bool { + return t.Preimage == zeroPreimage +} + +// serializeToken returns a byte-serialized representation of the token. +func serializeToken(t *Token) ([]byte, error) { + var b bytes.Buffer + + baseMacBytes, err := t.baseMac.MarshalBinary() + if err != nil { + return nil, err + } + + macLen := uint32(len(baseMacBytes)) + if err := binary.Write(&b, byteOrder, macLen); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, baseMacBytes); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, t.PaymentHash); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, t.Preimage); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, t.AmountPaid); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, t.RoutingFeePaid); err != nil { + return nil, err + } + + timeUnix := t.TimeCreated.UnixNano() + if err := binary.Write(&b, byteOrder, timeUnix); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// deserializeToken constructs a token by reading it from a byte slice. +func deserializeToken(value []byte) (*Token, error) { + r := bytes.NewReader(value) + + var macLen uint32 + if err := binary.Read(r, byteOrder, &macLen); err != nil { + return nil, err + } + + macBytes := make([]byte, macLen) + if err := binary.Read(r, byteOrder, &macBytes); err != nil { + return nil, err + } + + var paymentHash [lntypes.HashSize]byte + if err := binary.Read(r, byteOrder, &paymentHash); err != nil { + return nil, err + } + + token, err := tokenFromChallenge(macBytes, &paymentHash) + if err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &token.Preimage); err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &token.AmountPaid); err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &token.RoutingFeePaid); err != nil { + return nil, err + } + + var unixNano int64 + if err := binary.Read(r, byteOrder, &unixNano); err != nil { + return nil, err + } + token.TimeCreated = time.Unix(0, unixNano) + + return token, nil +}