Files
ark/pkg/client-sdk/client/rest/client.go
2024-08-09 17:46:42 +02:00

507 lines
12 KiB
Go

package restclient
import (
"context"
"fmt"
"net/url"
"strconv"
"strings"
"time"
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/pkg/client-sdk/client"
"github.com/ark-network/ark/pkg/client-sdk/client/rest/service/arkservice"
"github.com/ark-network/ark/pkg/client-sdk/client/rest/service/arkservice/ark_service"
"github.com/ark-network/ark/pkg/client-sdk/client/rest/service/models"
"github.com/ark-network/ark/pkg/client-sdk/internal/utils"
"github.com/btcsuite/btcd/btcutil/psbt"
httptransport "github.com/go-openapi/runtime/client"
"github.com/go-openapi/strfmt"
"github.com/vulpemventures/go-elements/psetv2"
)
// restClient is the REST-backed client returned by NewClient. Its methods
// issue their requests through svc.
type restClient struct {
// svc is the ark_service client used for every request.
svc ark_service.ClientService
// eventsCh is an unbuffered channel of round events created by NewClient.
eventsCh chan client.RoundEventChannel
// requestTimeout is set to 15 seconds by NewClient; GetEventStream uses it
// as the deadline for its ping loop.
requestTimeout time.Duration
// treeCache is created empty by NewClient; nothing in the visible part of
// this file reads or writes it.
treeCache *utils.Cache[tree.CongestionTree]
}
// NewClient builds an ASPClient that reaches the ASP at aspUrl over REST.
//
// It returns an error when aspUrl is empty or when the underlying service
// client cannot be created from it. The returned client uses a 15 second
// request timeout and starts with an empty congestion tree cache.
func NewClient(aspUrl string) (client.ASPClient, error) {
	if aspUrl == "" {
		return nil, fmt.Errorf("missing asp url")
	}

	service, err := newRestClient(aspUrl)
	if err != nil {
		return nil, err
	}

	return &restClient{
		svc:            service,
		eventsCh:       make(chan client.RoundEventChannel),
		requestTimeout: 15 * time.Second,
		treeCache:      utils.NewCache[tree.CongestionTree](),
	}, nil
}
// Close is a no-op.
func (a *restClient) Close() {}
// GetEventStream polls the ASP on behalf of paymentID and streams what it
// learns on the returned channel.
//
// A background goroutine pings the ASP roughly once per second until either a
// finalization event is returned or requestTimeout elapses (in which case a
// "timeout reached" error is sent). Once an event is received it is forwarded,
// and the round it refers to is then polled by ID roughly once per second until
// its stage is finalized or failed; a RoundFinalizedEvent or RoundFailedEvent
// is sent accordingly. Any request error is forwarded in the Err field and ends
// the stream. The channel is closed when the goroutine exits.
//
// Each call gets its own channel, so the method can be called more than once
// on the same client: closing a channel shared between calls would make a
// later call panic. The goroutine also stops as soon as ctx is done, so an
// abandoned consumer does not leave it blocked on a send forever.
//
// The returned error is always nil.
func (a *restClient) GetEventStream(
	ctx context.Context, paymentID string,
) (<-chan client.RoundEventChannel, error) {
	// Tolerate a nil context so existing callers keep working.
	if ctx == nil {
		ctx = context.Background()
	}

	eventsCh := make(chan client.RoundEventChannel)

	go func(payID string) {
		defer close(eventsCh)

		// send delivers msg unless ctx is cancelled first; it reports whether
		// the message was handed to the consumer.
		send := func(msg client.RoundEventChannel) bool {
			select {
			case eventsCh <- msg:
				return true
			case <-ctx.Done():
				return false
			}
		}

		// pause waits one second between polls; it reports false when ctx is
		// cancelled during the wait.
		pause := func() bool {
			timer := time.NewTimer(time.Second)
			defer timer.Stop()
			select {
			case <-timer.C:
				return true
			case <-ctx.Done():
				return false
			}
		}

		// The deadline only covers the wait for the finalization event, not
		// the subsequent polling of the round.
		timeout := time.After(a.requestTimeout)
		for {
			select {
			case <-ctx.Done():
				return
			case <-timeout:
				send(client.RoundEventChannel{
					Err: fmt.Errorf("timeout reached"),
				})
				return
			default:
			}

			event, err := a.Ping(ctx, payID)
			if err != nil {
				send(client.RoundEventChannel{Err: err})
				return
			}
			if event == nil {
				if !pause() {
					return
				}
				continue
			}

			if !send(client.RoundEventChannel{Event: *event}) {
				return
			}

			// From here on, follow the round until it reaches a final stage.
			roundID := event.ID
			for {
				round, err := a.GetRoundByID(ctx, roundID)
				if err != nil {
					send(client.RoundEventChannel{Err: err})
					return
				}

				if round.Stage == client.RoundStageFinalized {
					send(client.RoundEventChannel{
						Event: client.RoundFinalizedEvent{
							ID:   roundID,
							Txid: getTxid(round.Tx),
						},
					})
					return
				}
				if round.Stage == client.RoundStageFailed {
					send(client.RoundEventChannel{
						Event: client.RoundFailedEvent{
							ID: roundID,
						},
					})
					return
				}

				if !pause() {
					return
				}
			}
		}
	}(paymentID)

	return eventsCh, nil
}
// GetInfo fetches the ASP's info and converts it into a client.Info.
//
// RoundLifetime, UnilateralExitDelay, RoundInterval and MinRelayFee arrive as
// decimal strings; an error is returned if any of them is not a valid 64-bit
// integer. Pubkey and Network are copied through unchanged.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) GetInfo(
	ctx context.Context,
) (*client.Info, error) {
	resp, err := a.svc.ArkServiceGetInfo(ark_service.NewArkServiceGetInfoParams())
	if err != nil {
		return nil, err
	}

	// Parse straight into int64 so the result does not depend on the width of
	// the platform's int, and name the offending field in the error.
	parse := func(field, raw string) (int64, error) {
		v, err := strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid %s %q: %w", field, raw, err)
		}
		return v, nil
	}

	roundLifetime, err := parse("round lifetime", resp.Payload.RoundLifetime)
	if err != nil {
		return nil, err
	}
	unilateralExitDelay, err := parse(
		"unilateral exit delay", resp.Payload.UnilateralExitDelay,
	)
	if err != nil {
		return nil, err
	}
	roundInterval, err := parse("round interval", resp.Payload.RoundInterval)
	if err != nil {
		return nil, err
	}
	minRelayFee, err := parse("min relay fee", resp.Payload.MinRelayFee)
	if err != nil {
		return nil, err
	}

	return &client.Info{
		Pubkey:              resp.Payload.Pubkey,
		RoundLifetime:       roundLifetime,
		UnilateralExitDelay: unilateralExitDelay,
		RoundInterval:       roundInterval,
		Network:             resp.Payload.Network,
		MinRelayFee:         minRelayFee,
	}, nil
}
// ListVtxos fetches the vtxos of addr and returns them as two slices: the
// spendable ones first, the spent ones second.
//
// Each vtxo's amount and expiry arrive as decimal strings. An error is
// returned, with both slices nil, if any of them cannot be parsed.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) ListVtxos(
	ctx context.Context, addr string,
) ([]client.Vtxo, []client.Vtxo, error) {
	resp, err := a.svc.ArkServiceListVtxos(
		ark_service.NewArkServiceListVtxosParams().WithAddress(addr),
	)
	if err != nil {
		return nil, nil, err
	}

	spendableVtxos := make([]client.Vtxo, 0, len(resp.Payload.SpendableVtxos))
	for _, v := range resp.Payload.SpendableVtxos {
		expiresAt, err := parseVtxoExpiry(v.ExpireAt)
		if err != nil {
			return nil, nil, err
		}
		amount, err := parseVtxoAmount(v.Receiver.Amount)
		if err != nil {
			return nil, nil, err
		}

		spendableVtxos = append(spendableVtxos, client.Vtxo{
			VtxoKey: client.VtxoKey{
				Txid: v.Outpoint.Txid,
				VOut: uint32(v.Outpoint.Vout),
			},
			Amount:    amount,
			RoundTxid: v.PoolTxid,
			ExpiresAt: expiresAt,
		})
	}

	spentVtxos := make([]client.Vtxo, 0, len(resp.Payload.SpentVtxos))
	for _, v := range resp.Payload.SpentVtxos {
		expiresAt, err := parseVtxoExpiry(v.ExpireAt)
		if err != nil {
			return nil, nil, err
		}
		amount, err := parseVtxoAmount(v.Receiver.Amount)
		if err != nil {
			return nil, nil, err
		}

		spentVtxos = append(spentVtxos, client.Vtxo{
			VtxoKey: client.VtxoKey{
				Txid: v.Outpoint.Txid,
				VOut: uint32(v.Outpoint.Vout),
			},
			Amount:    amount,
			RoundTxid: v.PoolTxid,
			ExpiresAt: expiresAt,
		})
	}

	return spendableVtxos, spentVtxos, nil
}

// parseVtxoExpiry converts a vtxo expiry given as a decimal string of Unix
// seconds into a time. An empty string or "0" means no expiry and yields nil.
func parseVtxoExpiry(raw string) (*time.Time, error) {
	if raw == "" || raw == "0" {
		return nil, nil
	}
	secs, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid vtxo expiry %q: %w", raw, err)
	}
	t := time.Unix(secs, 0)
	return &t, nil
}

// parseVtxoAmount parses a vtxo amount given as a decimal string. Parsing as
// unsigned rejects negative values instead of letting them wrap around when
// converted to uint64.
func parseVtxoAmount(raw string) (uint64, error) {
	amount, err := strconv.ParseUint(raw, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid vtxo amount %q: %w", raw, err)
	}
	return amount, nil
}
// GetRound fetches the round identified by txID and converts it into a
// client.Round.
//
// The round's start and end arrive as decimal strings of Unix seconds; an
// error is returned if either is not a valid 64-bit integer. EndedAt is left
// nil when the end value is not positive. A missing stage is reported as
// client.RoundStageUndefined.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) GetRound(
	ctx context.Context, txID string,
) (*client.Round, error) {
	resp, err := a.svc.ArkServiceGetRound(
		ark_service.NewArkServiceGetRoundParams().WithTxid(txID),
	)
	if err != nil {
		return nil, err
	}
	round := resp.Payload.Round

	// Parse straight into int64 so the timestamps do not depend on the width
	// of the platform's int.
	start, err := strconv.ParseInt(round.Start, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid round start %q: %w", round.Start, err)
	}
	end, err := strconv.ParseInt(round.End, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid round end %q: %w", round.End, err)
	}

	startedAt := time.Unix(start, 0)
	var endedAt *time.Time
	if end > 0 {
		t := time.Unix(end, 0)
		endedAt = &t
	}

	// Stage is a pointer; guard it so a response without a stage does not
	// crash the caller.
	var stage client.RoundStage
	if round.Stage != nil {
		stage = toRoundStage(*round.Stage)
	} else {
		stage = client.RoundStageUndefined
	}

	return &client.Round{
		ID:         round.ID,
		StartedAt:  &startedAt,
		EndedAt:    endedAt,
		Tx:         round.PoolTx,
		Tree:       treeFromProto{round.CongestionTree}.parse(),
		ForfeitTxs: round.ForfeitTxs,
		Connectors: round.Connectors,
		Stage:      stage,
	}, nil
}
// Onboard submits the boarding transaction tx, the user's public key and the
// given congestion tree to the ASP, and returns whatever error the request
// produced.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) Onboard(
	ctx context.Context, tx, userPubkey string, congestionTree tree.CongestionTree,
) error {
	request := &models.V1OnboardRequest{
		BoardingTx:     tx,
		CongestionTree: treeToProto(congestionTree).parse(),
		UserPubkey:     userPubkey,
	}
	params := ark_service.NewArkServiceOnboardParams().WithBody(request)

	if _, err := a.svc.ArkServiceOnboard(params); err != nil {
		return err
	}
	return nil
}
// RegisterPayment registers the given vtxo keys as payment inputs with the ASP
// and returns the ID carried by the response.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) RegisterPayment(
	ctx context.Context, inputs []client.VtxoKey,
) (string, error) {
	protoInputs := make([]*models.V1Input, len(inputs))
	for idx := range inputs {
		protoInputs[idx] = &models.V1Input{
			Txid: inputs[idx].Txid,
			Vout: int64(inputs[idx].VOut),
		}
	}

	request := &models.V1RegisterPaymentRequest{Inputs: protoInputs}
	params := ark_service.NewArkServiceRegisterPaymentParams().WithBody(request)

	resp, err := a.svc.ArkServiceRegisterPayment(params)
	if err != nil {
		return "", err
	}
	return resp.Payload.ID, nil
}
// ClaimPayment sends the outputs to be claimed for the payment identified by
// paymentID and returns whatever error the request produced.
//
// Each output amount is sent as a decimal string.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) ClaimPayment(
	ctx context.Context, paymentID string, outputs []client.Output,
) error {
	outs := make([]*models.V1Output, 0, len(outputs))
	for _, o := range outputs {
		outs = append(outs, &models.V1Output{
			Address: o.Address,
			// Format through int64 rather than int so the amount is not
			// narrowed on platforms where int is 32 bits wide.
			Amount: strconv.FormatInt(int64(o.Amount), 10),
		})
	}
	body := models.V1ClaimPaymentRequest{
		ID:      paymentID,
		Outputs: outs,
	}

	_, err := a.svc.ArkServiceClaimPayment(
		ark_service.NewArkServiceClaimPaymentParams().WithBody(&body),
	)
	return err
}
// Ping pings the ASP for the payment identified by paymentID.
//
// It returns the round finalization event carried by the response, or a nil
// event (with a nil error) when the response carries none.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) Ping(
	ctx context.Context, paymentID string,
) (*client.RoundFinalizationEvent, error) {
	params := ark_service.NewArkServicePingParams()
	params.SetPaymentID(paymentID)

	resp, err := a.svc.ArkServicePing(params)
	if err != nil {
		return nil, err
	}

	ev := resp.Payload.Event
	if ev == nil {
		return nil, nil
	}

	return &client.RoundFinalizationEvent{
		ID:         ev.ID,
		Tx:         ev.PoolTx,
		ForfeitTxs: ev.ForfeitTxs,
		Tree:       treeFromProto{ev.CongestionTree}.parse(),
		Connectors: ev.Connectors,
	}, nil
}
// FinalizePayment submits the signed forfeit transactions to the ASP and
// returns whatever error the request produced.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) FinalizePayment(
	ctx context.Context, signedForfeitTxs []string,
) error {
	// The transactions are routed through the protobuf request type and read
	// back via its getter before being placed in the REST request body.
	protoReq := &arkv1.FinalizePaymentRequest{
		SignedForfeitTxs: signedForfeitTxs,
	}
	request := &models.V1FinalizePaymentRequest{
		SignedForfeitTxs: protoReq.GetSignedForfeitTxs(),
	}
	params := ark_service.NewArkServiceFinalizePaymentParams().WithBody(request)

	if _, err := a.svc.ArkServiceFinalizePayment(params); err != nil {
		return err
	}
	return nil
}
// GetRoundByID fetches the round identified by roundID and converts it into a
// client.Round.
//
// The round's start and end arrive as decimal strings of Unix seconds; an
// error is returned if either is not a valid 64-bit integer. EndedAt is left
// nil when the end value is not positive. A missing stage is reported as
// client.RoundStageUndefined.
//
// ctx is accepted for interface compatibility but is not used by this method.
func (a *restClient) GetRoundByID(
	ctx context.Context, roundID string,
) (*client.Round, error) {
	resp, err := a.svc.ArkServiceGetRoundByID(
		ark_service.NewArkServiceGetRoundByIDParams().WithID(roundID),
	)
	if err != nil {
		return nil, err
	}
	round := resp.Payload.Round

	// Parse straight into int64 so the timestamps do not depend on the width
	// of the platform's int.
	start, err := strconv.ParseInt(round.Start, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid round start %q: %w", round.Start, err)
	}
	end, err := strconv.ParseInt(round.End, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid round end %q: %w", round.End, err)
	}

	startedAt := time.Unix(start, 0)
	var endedAt *time.Time
	if end > 0 {
		t := time.Unix(end, 0)
		endedAt = &t
	}

	// Stage is a pointer; guard it so a response without a stage does not
	// crash the caller.
	var stage client.RoundStage
	if round.Stage != nil {
		stage = toRoundStage(*round.Stage)
	} else {
		stage = client.RoundStageUndefined
	}

	return &client.Round{
		ID:         round.ID,
		StartedAt:  &startedAt,
		EndedAt:    endedAt,
		Tx:         round.PoolTx,
		Tree:       treeFromProto{round.CongestionTree}.parse(),
		ForfeitTxs: round.ForfeitTxs,
		Connectors: round.Connectors,
		Stage:      stage,
	}, nil
}
// newRestClient builds the ark_service client for the ASP at serviceURL.
//
// The URL's scheme and host are handed to the HTTP transport, and its path is
// used as the base path, falling back to arkservice.DefaultBasePath when the
// path is empty. An error is returned when serviceURL cannot be parsed or has
// no host (for example "localhost:7070" or "example.com", which url.Parse
// does not read as a host).
func newRestClient(
	serviceURL string,
) (ark_service.ClientService, error) {
	parsedURL, err := url.Parse(serviceURL)
	if err != nil {
		return nil, fmt.Errorf("invalid asp url %q: %w", serviceURL, err)
	}
	// Without a host the transport has nowhere to send requests, so fail here
	// instead of on the first call.
	if parsedURL.Host == "" {
		return nil, fmt.Errorf(
			"invalid asp url %q: missing host (expected scheme://host[:port])",
			serviceURL,
		)
	}

	basePath := parsedURL.Path
	if basePath == "" {
		basePath = arkservice.DefaultBasePath
	}

	cfg := &arkservice.TransportConfig{
		Host:     parsedURL.Host,
		BasePath: basePath,
		Schemes:  []string{parsedURL.Scheme},
	}
	transport := httptransport.New(cfg.Host, cfg.BasePath, cfg.Schemes)
	svc := arkservice.New(transport, strfmt.Default)
	return svc.ArkService, nil
}
// toRoundStage maps a REST model round stage onto the matching
// client.RoundStage. Any stage other than registration, finalization,
// finalized or failed maps to client.RoundStageUndefined.
func toRoundStage(stage models.V1RoundStage) client.RoundStage {
	var mapped client.RoundStage

	switch stage {
	case models.V1RoundStageROUNDSTAGEREGISTRATION:
		mapped = client.RoundStageRegistration
	case models.V1RoundStageROUNDSTAGEFINALIZATION:
		mapped = client.RoundStageFinalization
	case models.V1RoundStageROUNDSTAGEFINALIZED:
		mapped = client.RoundStageFinalized
	case models.V1RoundStageROUNDSTAGEFAILED:
		mapped = client.RoundStageFailed
	default:
		mapped = client.RoundStageUndefined
	}

	return mapped
}
// treeFromProto wraps a REST model tree so it can be converted into a
// tree.CongestionTree with parse.
type treeFromProto struct {
	*models.V1Tree
}

// parse converts the wrapped model tree into a tree.CongestionTree, copying
// Txid, Tx and ParentTxid of every node level by level, and then flags as Leaf
// every node for which the resulting tree reports no children.
//
// A nil wrapped tree yields an empty tree, and nil levels or nodes are
// skipped, so a response without a congestion tree does not cause a nil
// dereference.
func (t treeFromProto) parse() tree.CongestionTree {
	if t.V1Tree == nil {
		return make(tree.CongestionTree, 0)
	}

	congestionTree := make(tree.CongestionTree, 0, len(t.Levels))
	for _, l := range t.Levels {
		if l == nil {
			continue
		}
		level := make([]tree.Node, 0, len(l.Nodes))
		for _, n := range l.Nodes {
			if n == nil {
				continue
			}
			level = append(level, tree.Node{
				Txid:       n.Txid,
				Tx:         n.Tx,
				ParentTxid: n.ParentTxid,
			})
		}
		congestionTree = append(congestionTree, level)
	}

	// Leaves can only be identified once the whole tree has been built.
	for j, treeLvl := range congestionTree {
		for i, node := range treeLvl {
			if len(congestionTree.Children(node.Txid)) == 0 {
				congestionTree[j][i].Leaf = true
			}
		}
	}

	return congestionTree
}
// treeToProto is a tree.CongestionTree that can be converted into its REST
// model form with parse.
type treeToProto tree.CongestionTree

// parse converts the tree into a models.V1Tree, copying Txid, Tx and
// ParentTxid of every node and keeping the level structure.
func (t treeToProto) parse() *models.V1Tree {
	protoLevels := make([]*models.V1TreeLevel, len(t))
	for lvlIdx, lvl := range t {
		protoNodes := make([]*models.V1Node, len(lvl))
		for nodeIdx := range lvl {
			protoNodes[nodeIdx] = &models.V1Node{
				Txid:       lvl[nodeIdx].Txid,
				Tx:         lvl[nodeIdx].Tx,
				ParentTxid: lvl[nodeIdx].ParentTxid,
			}
		}
		protoLevels[lvlIdx] = &models.V1TreeLevel{Nodes: protoNodes}
	}

	return &models.V1Tree{Levels: protoLevels}
}
// getTxid returns the transaction ID of the unsigned transaction inside tx.
//
// tx is first tried as a base64 PSET; if that does not parse, it is tried as
// a base64 PSBT. An empty string is returned when tx cannot be decoded either
// way or when the unsigned transaction cannot be extracted, instead of
// dereferencing a nil result.
func getTxid(tx string) string {
	if pset, err := psetv2.NewPsetFromBase64(tx); err == nil && pset != nil {
		utx, err := pset.UnsignedTx()
		if err != nil || utx == nil {
			return ""
		}
		return utx.TxHash().String()
	}

	ptx, err := psbt.NewFromRawBytes(strings.NewReader(tx), true)
	if err != nil || ptx == nil {
		return ""
	}
	return ptx.UnsignedTx.TxID()
}