Files
ark/server/internal/interface/grpc/handlers/arkservice.go
Louis Singer 9e3d667b51 Fix: "tree signing session not found" error (#323)
* failing test

* fix duplicate input register

* fix btc-embedded coin selection

* rename test

* add checks in failing test case

* fixes GetEventStream

* add TODO comment in createPoolTx

* update with master changes

* fix server unit test

* increase liquidity of testing ASP

* simplify AliceSeveralPaymentsBob test
2024-09-20 12:03:12 +02:00

565 lines
15 KiB
Go

package handlers
import (
"context"
"encoding/hex"
"sync"
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
"github.com/ark-network/ark/server/internal/core/application"
"github.com/ark-network/ark/server/internal/core/domain"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// listener is one subscriber of the event stream. It is created by
// GetEventStream and written to by the goroutines spawned in listenToEvents.
type listener struct {
// id is the value removeListener matches on to drop this listener.
id string
// done is signalled by listenToEvents after a RoundFinalized or RoundFailed
// event has been handed over on ch.
done chan struct{}
// ch carries the events that GetEventStream forwards on the gRPC stream.
ch chan *arkv1.GetEventStreamResponse
}
// handler is the value returned by NewHandler as arkv1.ArkServiceServer. Its
// methods delegate to the application service and it keeps the list of
// currently registered event stream listeners.
type handler struct {
// svc is the application service every RPC delegates to.
svc application.Service
// listenersLock is held by pushListener and removeListener while they
// modify listeners.
listenersLock *sync.Mutex
// listeners holds the subscribers registered by GetEventStream.
listeners []*listener
}
// NewHandler returns the gRPC server implementation backed by the given
// application service. Before returning it starts a background goroutine
// running listenToEvents, which forwards application events to the listeners
// registered through GetEventStream.
func NewHandler(service application.Service) arkv1.ArkServiceServer {
	srv := new(handler)
	srv.svc = service
	srv.listenersLock = new(sync.Mutex)
	srv.listeners = []*listener{}

	go srv.listenToEvents()

	return srv
}
// CompletePayment validates that the request carries a signed redeem tx and at
// least one signed unconditional forfeit tx, then passes both to the
// application service's CompleteAsyncPayment. Missing fields are reported as
// InvalidArgument; service errors are returned unchanged.
func (h *handler) CompletePayment(ctx context.Context, req *arkv1.CompletePaymentRequest) (*arkv1.CompletePaymentResponse, error) {
	redeemTx := req.GetSignedRedeemTx()
	if redeemTx == "" {
		return nil, status.Error(codes.InvalidArgument, "missing signed redeem tx")
	}

	forfeitTxs := req.GetSignedUnconditionalForfeitTxs()
	if len(forfeitTxs) == 0 {
		return nil, status.Error(codes.InvalidArgument, "missing signed unconditional forfeit txs")
	}

	err := h.svc.CompleteAsyncPayment(ctx, redeemTx, forfeitTxs)
	if err != nil {
		return nil, err
	}

	return new(arkv1.CompletePaymentResponse), nil
}
// CreatePayment parses the request inputs and outputs, checks every receiver
// and asks the application service to create an async payment. A receiver is
// rejected with InvalidArgument when its amount is not positive, when it
// carries an onchain address, or when it has no descriptor. On success the
// redeem tx and the unconditional forfeit txs returned by the service are put
// in the response.
func (h *handler) CreatePayment(ctx context.Context, req *arkv1.CreatePaymentRequest) (*arkv1.CreatePaymentResponse, error) {
	inputs, err := parseInputs(req.GetInputs())
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	receivers, err := parseReceivers(req.GetOutputs())
	if err != nil {
		return nil, status.Error(codes.InvalidArgument, err.Error())
	}

	// The checks run in this exact order for each receiver; the first failing
	// one decides the error message.
	for _, out := range receivers {
		switch {
		case out.Amount <= 0:
			return nil, status.Error(codes.InvalidArgument, "output amount must be greater than 0")
		case len(out.OnchainAddress) > 0:
			return nil, status.Error(codes.InvalidArgument, "onchain address is not supported as async payment destination")
		case len(out.Descriptor) == 0:
			return nil, status.Error(codes.InvalidArgument, "missing output descriptor")
		}
	}

	redeemTx, forfeitTxs, err := h.svc.CreateAsyncPayment(ctx, inputs, receivers)
	if err != nil {
		return nil, err
	}

	resp := &arkv1.CreatePaymentResponse{}
	resp.SignedRedeemTx = redeemTx
	resp.UsignedUnconditionalForfeitTxs = forfeitTxs
	return resp, nil
}
// Ping calls UpdatePaymentStatus on the application service with the payment
// id of the request and converts the event it returns into the matching
// PingResponse variant. An empty payment id yields InvalidArgument. When the
// returned event is of none of the handled types, a nil response is returned
// together with a nil error.
func (h *handler) Ping(ctx context.Context, req *arkv1.PingRequest) (*arkv1.PingResponse, error) {
	paymentID := req.GetPaymentId()
	if paymentID == "" {
		return nil, status.Error(codes.InvalidArgument, "missing payment id")
	}

	event, err := h.svc.UpdatePaymentStatus(ctx, paymentID)
	if err != nil {
		return nil, err
	}

	switch ev := event.(type) {
	case domain.RoundFinalizationStarted:
		finalization := &arkv1.RoundFinalizationEvent{
			Id:              ev.Id,
			PoolTx:          ev.PoolTx,
			CongestionTree:  congestionTree(ev.CongestionTree).toProto(),
			Connectors:      ev.Connectors,
			MinRelayFeeRate: ev.MinRelayFeeRate,
		}
		return &arkv1.PingResponse{
			Event: &arkv1.PingResponse_RoundFinalization{RoundFinalization: finalization},
		}, nil

	case domain.RoundFinalized:
		finalized := &arkv1.RoundFinalizedEvent{
			Id:       ev.Id,
			PoolTxid: ev.Txid,
		}
		return &arkv1.PingResponse{
			Event: &arkv1.PingResponse_RoundFinalized{RoundFinalized: finalized},
		}, nil

	case domain.RoundFailed:
		failed := &arkv1.RoundFailed{
			Id:     ev.Id,
			Reason: ev.Err,
		}
		return &arkv1.PingResponse{
			Event: &arkv1.PingResponse_RoundFailed{RoundFailed: failed},
		}, nil

	case application.RoundSigningStarted:
		// Cosigner keys are sent as hex-encoded compressed public keys.
		pubkeys := make([]string, 0, len(ev.Cosigners))
		for _, cosigner := range ev.Cosigners {
			pubkeys = append(pubkeys, hex.EncodeToString(cosigner.SerializeCompressed()))
		}
		signing := &arkv1.RoundSigningEvent{
			Id:               ev.Id,
			CosignersPubkeys: pubkeys,
			UnsignedTree:     congestionTree(ev.UnsignedVtxoTree).toProto(),
			UnsignedRoundTx:  ev.UnsignedRoundTx,
		}
		return &arkv1.PingResponse{
			Event: &arkv1.PingResponse_RoundSigning{RoundSigning: signing},
		}, nil

	case application.RoundSigningNoncesGenerated:
		nonces, err := ev.SerializeNonces()
		if err != nil {
			logrus.WithError(err).Error("failed to serialize nonces")
			return nil, status.Error(codes.Internal, "failed to serialize nonces")
		}
		generated := &arkv1.RoundSigningNoncesGeneratedEvent{
			Id:         ev.Id,
			TreeNonces: nonces,
		}
		return &arkv1.PingResponse{
			Event: &arkv1.PingResponse_RoundSigningNoncesGenerated{RoundSigningNoncesGenerated: generated},
		}, nil
	}

	// No handled event type: nil response, nil error.
	return nil, nil
}
// RegisterPayment parses the request inputs and passes them to the application
// service's SpendVtxos, which returns the payment id. When the request also
// carries an ephemeral pubkey, it is registered as cosigner key for that id.
// Unparsable inputs yield InvalidArgument; service errors are returned as is.
func (h *handler) RegisterPayment(ctx context.Context, req *arkv1.RegisterPaymentRequest) (*arkv1.RegisterPaymentResponse, error) {
	inputs, parseErr := parseInputs(req.GetInputs())
	if parseErr != nil {
		return nil, status.Error(codes.InvalidArgument, parseErr.Error())
	}

	paymentID, err := h.svc.SpendVtxos(ctx, inputs)
	if err != nil {
		return nil, err
	}

	// The cosigner key is optional: only register it when provided.
	if cosigner := req.GetEphemeralPubkey(); len(cosigner) > 0 {
		err = h.svc.RegisterCosignerPubkey(ctx, paymentID, cosigner)
		if err != nil {
			return nil, err
		}
	}

	return &arkv1.RegisterPaymentResponse{Id: paymentID}, nil
}
// ClaimPayment parses the request outputs into receivers and hands them,
// together with the request id, to the application service's ClaimVtxos.
// Unparsable outputs yield InvalidArgument; service errors are returned as is.
func (h *handler) ClaimPayment(ctx context.Context, req *arkv1.ClaimPaymentRequest) (*arkv1.ClaimPaymentResponse, error) {
	receivers, parseErr := parseReceivers(req.GetOutputs())
	if parseErr != nil {
		return nil, status.Error(codes.InvalidArgument, parseErr.Error())
	}

	err := h.svc.ClaimVtxos(ctx, req.GetId(), receivers)
	if err != nil {
		return nil, err
	}

	return new(arkv1.ClaimPaymentResponse), nil
}
// FinalizePayment forwards the signed forfeit txs and/or the signed round tx
// of the request to the application service. At least one of the two must be
// present, otherwise InvalidArgument is returned. Forfeit txs are submitted
// first (SignVtxos), then the round tx (SignRoundTx); the first service error
// stops the processing and is returned as is.
func (h *handler) FinalizePayment(ctx context.Context, req *arkv1.FinalizePaymentRequest) (*arkv1.FinalizePaymentResponse, error) {
	signedForfeits := req.GetSignedForfeitTxs()
	signedRoundTx := req.GetSignedRoundTx()

	hasForfeits := len(signedForfeits) > 0
	hasRoundTx := signedRoundTx != ""

	if !hasForfeits && !hasRoundTx {
		return nil, status.Error(codes.InvalidArgument, "missing forfeit txs or round tx")
	}

	if hasForfeits {
		err := h.svc.SignVtxos(ctx, signedForfeits)
		if err != nil {
			return nil, err
		}
	}

	if hasRoundTx {
		err := h.svc.SignRoundTx(ctx, signedRoundTx)
		if err != nil {
			return nil, err
		}
	}

	return new(arkv1.FinalizePaymentResponse), nil
}
// GetRound returns the round matching the txid of the request, or the current
// round when the request carries no txid. Service errors are returned as is.
func (h *handler) GetRound(ctx context.Context, req *arkv1.GetRoundRequest) (*arkv1.GetRoundResponse, error) {
	// A txid was given: look the round up by it.
	if txid := req.GetTxid(); len(txid) > 0 {
		found, err := h.svc.GetRoundByTxid(ctx, txid)
		if err != nil {
			return nil, err
		}

		pbRound := &arkv1.Round{
			Id:             found.Id,
			Start:          found.StartingTimestamp,
			End:            found.EndingTimestamp,
			PoolTx:         found.UnsignedTx,
			CongestionTree: congestionTree(found.CongestionTree).toProto(),
			ForfeitTxs:     found.ForfeitTxs,
			Connectors:     found.Connectors,
			Stage:          stage(found.Stage).toProto(),
		}
		return &arkv1.GetRoundResponse{Round: pbRound}, nil
	}

	// No txid: fall back to the current round.
	current, err := h.svc.GetCurrentRound(ctx)
	if err != nil {
		return nil, err
	}

	pbRound := &arkv1.Round{
		Id:             current.Id,
		Start:          current.StartingTimestamp,
		End:            current.EndingTimestamp,
		PoolTx:         current.UnsignedTx,
		CongestionTree: congestionTree(current.CongestionTree).toProto(),
		ForfeitTxs:     current.ForfeitTxs,
		Connectors:     current.Connectors,
		Stage:          stage(current.Stage).toProto(),
	}
	return &arkv1.GetRoundResponse{Round: pbRound}, nil
}
// GetRoundById returns the round with the id given in the request. An empty
// id yields InvalidArgument; service errors are returned as is.
func (h *handler) GetRoundById(
	ctx context.Context, req *arkv1.GetRoundByIdRequest,
) (*arkv1.GetRoundByIdResponse, error) {
	roundID := req.GetId()
	if len(roundID) == 0 {
		return nil, status.Error(codes.InvalidArgument, "missing round id")
	}

	found, err := h.svc.GetRoundById(ctx, roundID)
	if err != nil {
		return nil, err
	}

	pbRound := &arkv1.Round{
		Id:             found.Id,
		Start:          found.StartingTimestamp,
		End:            found.EndingTimestamp,
		PoolTx:         found.UnsignedTx,
		CongestionTree: congestionTree(found.CongestionTree).toProto(),
		ForfeitTxs:     found.ForfeitTxs,
		Connectors:     found.Connectors,
		Stage:          stage(found.Stage).toProto(),
	}
	return &arkv1.GetRoundByIdResponse{Round: pbRound}, nil
}
// GetEventStream registers a new listener and forwards every event received on
// its channel to the gRPC stream. It returns nil when the stream context is
// done or when a done signal is received, and the error of stream.Send when
// forwarding an event fails. The listener is removed from the handler on
// return.
func (h *handler) GetEventStream(_ *arkv1.GetEventStreamRequest, stream arkv1.ArkService_GetEventStreamServer) error {
	l := &listener{
		id:   uuid.NewString(),
		done: make(chan struct{}),
		ch:   make(chan *arkv1.GetEventStreamResponse),
	}

	h.pushListener(l)
	defer h.removeListener(l.id)

	// The listener channels are intentionally NOT closed on return. This
	// method is only the receiver: goroutines spawned by listenToEvents may
	// still be sending on l.ch or l.done, and closing a channel that has a
	// blocked or upcoming sender makes that sender panic, which would take the
	// whole server down. Nothing receives from these channels after this
	// method returns, so closing them signalled nothing anyway. A sender that
	// arrives after this point simply stays parked on the channel.

	ctx := stream.Context()
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-l.done:
			return nil
		case ev := <-l.ch:
			if err := stream.Send(ev); err != nil {
				return err
			}
		}
	}
}
// ListVtxos parses the address of the request, asks the application service
// for the vtxos of the resulting user pubkey and returns the spendable and the
// spent ones. An unparsable address yields InvalidArgument; service errors are
// returned as is.
func (h *handler) ListVtxos(ctx context.Context, req *arkv1.ListVtxosRequest) (*arkv1.ListVtxosResponse, error) {
	_, owner, _, parseErr := parseAddress(req.GetAddress())
	if parseErr != nil {
		return nil, status.Error(codes.InvalidArgument, parseErr.Error())
	}

	spendable, spent, err := h.svc.ListVtxos(ctx, owner)
	if err != nil {
		return nil, err
	}

	resp := &arkv1.ListVtxosResponse{}
	resp.SpendableVtxos = vtxoList(spendable).toProto()
	resp.SpentVtxos = vtxoList(spent).toProto()
	return resp, nil
}
// GetInfo returns the service information exposed by the application service,
// copied field by field into the protobuf response (Dust is converted to
// int64). Service errors are returned as is.
func (h *handler) GetInfo(ctx context.Context, req *arkv1.GetInfoRequest) (*arkv1.GetInfoResponse, error) {
	svcInfo, err := h.svc.GetInfo(ctx)
	if err != nil {
		return nil, err
	}

	resp := &arkv1.GetInfoResponse{}
	resp.Pubkey = svcInfo.PubKey
	resp.RoundLifetime = svcInfo.RoundLifetime
	resp.UnilateralExitDelay = svcInfo.UnilateralExitDelay
	resp.RoundInterval = svcInfo.RoundInterval
	resp.Network = svcInfo.Network
	resp.Dust = int64(svcInfo.Dust)
	resp.BoardingDescriptorTemplate = svcInfo.BoardingDescriptorTemplate
	return resp, nil
}
// GetBoardingAddress decodes the hex pubkey of the request, parses it as a
// secp256k1 public key and returns the boarding address and descriptor the
// application service derives for it. A missing, non-hex or unparsable pubkey
// yields InvalidArgument; service errors are returned as is.
func (h *handler) GetBoardingAddress(ctx context.Context, req *arkv1.GetBoardingAddressRequest) (*arkv1.GetBoardingAddressResponse, error) {
	hexKey := req.GetPubkey()
	if len(hexKey) == 0 {
		return nil, status.Error(codes.InvalidArgument, "missing pubkey")
	}

	rawKey, decodeErr := hex.DecodeString(hexKey)
	if decodeErr != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid pubkey (invalid hex)")
	}

	owner, parseErr := secp256k1.ParsePubKey(rawKey)
	if parseErr != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid pubkey (parse error)")
	}

	address, desc, err := h.svc.GetBoardingAddress(ctx, owner)
	if err != nil {
		return nil, err
	}

	resp := &arkv1.GetBoardingAddressResponse{}
	resp.Address = address
	resp.Descriptor_ = desc
	return resp, nil
}
// SendTreeNonces validates the request and registers the encoded tree nonces
// of a cosigner for the given round through the application service. The
// cosigner public key, the tree nonces and the round id are all mandatory, and
// the key must be a hex-encoded secp256k1 public key; any violation yields
// InvalidArgument. Service errors are returned as is.
func (h *handler) SendTreeNonces(ctx context.Context, req *arkv1.SendTreeNoncesRequest) (*arkv1.SendTreeNoncesResponse, error) {
	hexKey := req.GetPublicKey()
	nonces := req.GetTreeNonces()
	round := req.GetRoundId()

	// Presence checks, first failing one wins.
	switch {
	case len(hexKey) == 0:
		return nil, status.Error(codes.InvalidArgument, "missing cosigner public key")
	case len(nonces) == 0:
		return nil, status.Error(codes.InvalidArgument, "missing tree nonces")
	case len(round) == 0:
		return nil, status.Error(codes.InvalidArgument, "missing round id")
	}

	rawKey, decodeErr := hex.DecodeString(hexKey)
	if decodeErr != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid cosigner public key")
	}

	cosigner, parseErr := secp256k1.ParsePubKey(rawKey)
	if parseErr != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid cosigner public key")
	}

	err := h.svc.RegisterCosignerNonces(ctx, round, cosigner, nonces)
	if err != nil {
		return nil, err
	}

	return new(arkv1.SendTreeNoncesResponse), nil
}
// SendTreeSignatures validates the request and registers the encoded tree
// signatures of a cosigner for the given round through the application
// service. The cosigner public key, the tree signatures and the round id are
// all mandatory, and the key must be a hex-encoded secp256k1 public key; any
// violation yields InvalidArgument. Service errors are returned as is.
func (h *handler) SendTreeSignatures(ctx context.Context, req *arkv1.SendTreeSignaturesRequest) (*arkv1.SendTreeSignaturesResponse, error) {
	round := req.GetRoundId()
	hexKey := req.GetPublicKey()
	signatures := req.GetTreeSignatures()

	// Presence checks, first failing one wins.
	switch {
	case len(hexKey) == 0:
		return nil, status.Error(codes.InvalidArgument, "missing cosigner public key")
	case len(signatures) == 0:
		return nil, status.Error(codes.InvalidArgument, "missing tree signatures")
	case len(round) == 0:
		return nil, status.Error(codes.InvalidArgument, "missing round id")
	}

	rawKey, decodeErr := hex.DecodeString(hexKey)
	if decodeErr != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid cosigner public key")
	}

	cosigner, parseErr := secp256k1.ParsePubKey(rawKey)
	if parseErr != nil {
		return nil, status.Error(codes.InvalidArgument, "invalid cosigner public key")
	}

	err := h.svc.RegisterCosignerSignatures(ctx, round, cosigner, signatures)
	if err != nil {
		return nil, err
	}

	return new(arkv1.SendTreeSignaturesResponse), nil
}
// pushListener appends l to the handler's listeners while holding
// listenersLock.
func (h *handler) pushListener(l *listener) {
	h.listenersLock.Lock()
	h.listeners = append(h.listeners, l)
	h.listenersLock.Unlock()
}
// removeListener drops the first listener whose id equals the given one, while
// holding listenersLock. It does nothing when no listener matches.
func (h *handler) removeListener(id string) {
	h.listenersLock.Lock()
	defer h.listenersLock.Unlock()

	for idx := range h.listeners {
		if h.listeners[idx].id != id {
			continue
		}
		// Shift the tail left over the removed entry.
		h.listeners = append(h.listeners[:idx], h.listeners[idx+1:]...)
		break
	}
}
// listenToEvents forwards events from the application layer to the set of
// listeners. It consumes the channel returned by the application service until
// that channel is closed, converts each handled event into its protobuf stream
// response and delivers it to every registered listener from a dedicated
// goroutine, so that one slow listener does not hold up the others. After a
// RoundFinalized or RoundFailed event has been delivered, the listener's done
// channel is signalled as well. Events of other types are ignored, and a
// RoundSigningNoncesGenerated event whose nonces cannot be serialized is
// logged and skipped.
func (h *handler) listenToEvents() {
	channel := h.svc.GetEventsChannel(context.Background())
	for event := range channel {
		var ev *arkv1.GetEventStreamResponse
		shouldClose := false

		switch e := event.(type) {
		case domain.RoundFinalizationStarted:
			ev = &arkv1.GetEventStreamResponse{
				Event: &arkv1.GetEventStreamResponse_RoundFinalization{
					RoundFinalization: &arkv1.RoundFinalizationEvent{
						Id:              e.Id,
						PoolTx:          e.PoolTx,
						CongestionTree:  congestionTree(e.CongestionTree).toProto(),
						Connectors:      e.Connectors,
						MinRelayFeeRate: e.MinRelayFeeRate,
					},
				},
			}
		case domain.RoundFinalized:
			shouldClose = true
			ev = &arkv1.GetEventStreamResponse{
				Event: &arkv1.GetEventStreamResponse_RoundFinalized{
					RoundFinalized: &arkv1.RoundFinalizedEvent{
						Id:       e.Id,
						PoolTxid: e.Txid,
					},
				},
			}
		case domain.RoundFailed:
			shouldClose = true
			ev = &arkv1.GetEventStreamResponse{
				Event: &arkv1.GetEventStreamResponse_RoundFailed{
					RoundFailed: &arkv1.RoundFailed{
						Id:     e.Id,
						Reason: e.Err,
					},
				},
			}
		case application.RoundSigningStarted:
			// Cosigner keys are sent as hex-encoded compressed public keys.
			cosignersKeys := make([]string, 0, len(e.Cosigners))
			for _, key := range e.Cosigners {
				cosignersKeys = append(cosignersKeys, hex.EncodeToString(key.SerializeCompressed()))
			}

			ev = &arkv1.GetEventStreamResponse{
				Event: &arkv1.GetEventStreamResponse_RoundSigning{
					RoundSigning: &arkv1.RoundSigningEvent{
						Id:               e.Id,
						CosignersPubkeys: cosignersKeys,
						UnsignedTree:     congestionTree(e.UnsignedVtxoTree).toProto(),
						UnsignedRoundTx:  e.UnsignedRoundTx,
					},
				},
			}
		case application.RoundSigningNoncesGenerated:
			serialized, err := e.SerializeNonces()
			if err != nil {
				logrus.WithError(err).Error("failed to serialize nonces")
				continue
			}

			ev = &arkv1.GetEventStreamResponse{
				Event: &arkv1.GetEventStreamResponse_RoundSigningNoncesGenerated{
					RoundSigningNoncesGenerated: &arkv1.RoundSigningNoncesGeneratedEvent{
						Id:         e.Id,
						TreeNonces: serialized,
					},
				},
			}
		}

		if ev == nil {
			continue
		}

		// pushListener and removeListener modify h.listeners (and, for
		// removal, its backing array in place) under listenersLock, so the
		// slice must not be read without the lock. Copy it while holding the
		// lock and fan out over the private copy, so that the lock is never
		// held during the potentially blocking sends below.
		h.listenersLock.Lock()
		targets := make([]*listener, len(h.listeners))
		copy(targets, h.listeners)
		h.listenersLock.Unlock()

		logrus.Debugf("forwarding event to %d listeners", len(targets))
		for _, l := range targets {
			go func(l *listener, ev *arkv1.GetEventStreamResponse, shouldClose bool) {
				l.ch <- ev
				if shouldClose {
					l.done <- struct{}{}
				}
			}(l, ev, shouldClose)
		}
	}
}