mirror of
https://github.com/aljazceru/ark.git
synced 2025-12-17 04:04:21 +01:00
* Rename asp > server * Rename pool > round * Consolidate naming for pubkey/prvkey vars and types * Fix * Fix * Fix wasm * Rename congestionTree > vtxoTree * Fix wasm * Rename payment > request * Rename congestionTree > vtxoTree after syncing with master * Fix Send API in SDK * Fix wasm * Fix wasm * Fixes * Fixes after review * Fix * Fix naming * Fix * Fix e2e tests
644 lines
14 KiB
Go
644 lines
14 KiB
Go
package bitcointree
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"github.com/ark-network/ark/common"
|
|
"github.com/ark-network/ark/common/tree"
|
|
"github.com/btcsuite/btcd/btcec/v2"
|
|
"github.com/btcsuite/btcd/btcec/v2/schnorr"
|
|
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
|
|
"github.com/btcsuite/btcd/btcutil/psbt"
|
|
"github.com/btcsuite/btcd/txscript"
|
|
"github.com/btcsuite/btcd/wire"
|
|
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
|
)
|
|
|
|
// ErrMissingVtxoTree is returned when a session operation needs the vtxo tree
// but the session was created without one.
var ErrMissingVtxoTree = errors.New("missing vtxo tree")

// ErrMissingAggregateKey is returned by a signer session asked to sign before
// the cosigner keys (from which the aggregate key is derived) have been set.
var ErrMissingAggregateKey = errors.New("missing aggregate key")
|
|
|
|
// Musig2Nonce wraps a MuSig2 public nonce so it can be serialized as one cell
// of a nonce matrix. Its wire format is the raw 66-byte public nonce, with no
// length prefix.
type Musig2Nonce struct {
	PubNonce [66]byte
}

// Encode writes the 66 raw public-nonce bytes to w.
func (n *Musig2Nonce) Encode(w io.Writer) error {
	_, err := w.Write(n.PubNonce[:])
	return err
}

// Decode reads exactly 66 bytes from r into n.PubNonce.
//
// It returns io.EOF if r yields no bytes and io.ErrUnexpectedEOF if r ends
// before 66 bytes were read. n is left unmodified on any error.
func (n *Musig2Nonce) Decode(r io.Reader) error {
	var buf [66]byte
	// io.ReadFull only returns a nil error once len(buf) bytes have been read,
	// so no separate byte-count check is needed.
	if _, err := io.ReadFull(r, buf[:]); err != nil {
		return err
	}

	n.PubNonce = buf
	return nil
}
|
|
|
|
type TreeNonces [][]*Musig2Nonce // public nonces
|
|
type TreePartialSigs [][]*musig2.PartialSignature
|
|
|
|
type SignerSession interface {
|
|
GetNonces() (TreeNonces, error) // generate tree nonces for this session
|
|
SetKeys([]*btcec.PublicKey) error // set the cosigner public keys for this session
|
|
SetAggregatedNonces(TreeNonces) error // set the aggregated nonces
|
|
Sign() (TreePartialSigs, error) // sign the tree
|
|
}
|
|
|
|
type CoordinatorSession interface {
|
|
AddNonce(*btcec.PublicKey, TreeNonces) error
|
|
AggregateNonces() (TreeNonces, error)
|
|
AddSig(*btcec.PublicKey, TreePartialSigs) error
|
|
// SignTree combines the signatures and add them to the tree's psbts
|
|
SignTree() (tree.VtxoTree, error)
|
|
}
|
|
|
|
func (n TreeNonces) Encode(w io.Writer) error {
|
|
matrix, err := encodeMatrix(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = w.Write(matrix)
|
|
return err
|
|
}
|
|
|
|
func DecodeNonces(r io.Reader) (TreeNonces, error) {
|
|
return decodeMatrix(func() *Musig2Nonce { return new(Musig2Nonce) }, r)
|
|
}
|
|
|
|
func (s TreePartialSigs) Encode(w io.Writer) error {
|
|
matrix, err := encodeMatrix(s)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = w.Write(matrix)
|
|
return err
|
|
}
|
|
|
|
func DecodeSignatures(r io.Reader) (TreePartialSigs, error) {
|
|
return decodeMatrix(func() *musig2.PartialSignature { return new(musig2.PartialSignature) }, r)
|
|
}
|
|
|
|
func AggregateKeys(
|
|
pubkeys []*btcec.PublicKey,
|
|
scriptRoot []byte,
|
|
) (*musig2.AggregateKey, error) {
|
|
key, _, _, err := musig2.AggregateKeys(pubkeys, true,
|
|
musig2.WithTaprootKeyTweak(scriptRoot),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
func ValidateTreeSigs(
|
|
scriptRoot []byte,
|
|
finalAggregatedKey *btcec.PublicKey,
|
|
roundSharedOutputAmount int64,
|
|
vtxoTree tree.VtxoTree,
|
|
) error {
|
|
prevoutFetcherFactory, err := prevOutFetcherFactory(finalAggregatedKey, vtxoTree, roundSharedOutputAmount)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, level := range vtxoTree {
|
|
for _, node := range level {
|
|
partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sig := partialTx.Inputs[0].TaprootKeySpendSig
|
|
if len(sig) == 0 {
|
|
return errors.New("unsigned tree input")
|
|
}
|
|
|
|
schnorrSig, err := schnorr.ParseSignature(sig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
prevoutFetcher, err := prevoutFetcherFactory(partialTx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
message, err := txscript.CalcTaprootSignatureHash(
|
|
txscript.NewTxSigHashes(partialTx.UnsignedTx, prevoutFetcher),
|
|
txscript.SigHashDefault,
|
|
partialTx.UnsignedTx,
|
|
0,
|
|
prevoutFetcher,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !schnorrSig.Verify(message, finalAggregatedKey) {
|
|
return errors.New("invalid signature")
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewTreeSignerSession(
|
|
signer *btcec.PrivateKey,
|
|
roundSharedOutputAmount int64,
|
|
vtxoTree tree.VtxoTree,
|
|
scriptRoot []byte,
|
|
) SignerSession {
|
|
return &treeSignerSession{
|
|
secretKey: signer,
|
|
tree: vtxoTree,
|
|
scriptRoot: scriptRoot,
|
|
roundSharedOutputAmount: roundSharedOutputAmount,
|
|
}
|
|
}
|
|
|
|
type treeSignerSession struct {
|
|
secretKey *btcec.PrivateKey
|
|
tree tree.VtxoTree
|
|
myNonces [][]*musig2.Nonces
|
|
keys []*btcec.PublicKey
|
|
aggregateNonces TreeNonces
|
|
scriptRoot []byte
|
|
roundSharedOutputAmount int64
|
|
prevoutFetcherFactory func(*psbt.Packet) (txscript.PrevOutputFetcher, error)
|
|
}
|
|
|
|
func (t *treeSignerSession) generateNonces() error {
|
|
if t.tree == nil {
|
|
return ErrMissingVtxoTree
|
|
}
|
|
|
|
myNonces := make([][]*musig2.Nonces, 0)
|
|
|
|
for _, level := range t.tree {
|
|
levelNonces := make([]*musig2.Nonces, 0)
|
|
for range level {
|
|
nonce, err := musig2.GenNonces(
|
|
musig2.WithPublicKey(t.secretKey.PubKey()),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
levelNonces = append(levelNonces, nonce)
|
|
}
|
|
myNonces = append(myNonces, levelNonces)
|
|
}
|
|
|
|
t.myNonces = myNonces
|
|
return nil
|
|
}
|
|
|
|
func (t *treeSignerSession) GetNonces() (TreeNonces, error) {
|
|
if t.tree == nil {
|
|
return nil, ErrMissingVtxoTree
|
|
}
|
|
|
|
if t.myNonces == nil {
|
|
if err := t.generateNonces(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
nonces := make(TreeNonces, 0)
|
|
|
|
for _, level := range t.myNonces {
|
|
levelNonces := make([]*Musig2Nonce, 0)
|
|
for _, nonce := range level {
|
|
levelNonces = append(levelNonces, &Musig2Nonce{nonce.PubNonce})
|
|
}
|
|
nonces = append(nonces, levelNonces)
|
|
}
|
|
|
|
return nonces, nil
|
|
}
|
|
|
|
func (t *treeSignerSession) SetKeys(keys []*btcec.PublicKey) error {
|
|
if t.keys != nil {
|
|
return errors.New("keys already set")
|
|
}
|
|
|
|
aggregateKey, err := AggregateKeys(keys, t.scriptRoot)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
factory, err := prevOutFetcherFactory(aggregateKey.FinalKey, t.tree, t.roundSharedOutputAmount)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
t.prevoutFetcherFactory = factory
|
|
t.keys = keys
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *treeSignerSession) SetAggregatedNonces(nonces TreeNonces) error {
|
|
if t.aggregateNonces != nil {
|
|
return errors.New("nonces already set")
|
|
}
|
|
|
|
t.aggregateNonces = nonces
|
|
return nil
|
|
}
|
|
|
|
func (t *treeSignerSession) Sign() (TreePartialSigs, error) {
|
|
if t.tree == nil {
|
|
return nil, ErrMissingVtxoTree
|
|
}
|
|
|
|
if t.keys == nil {
|
|
return nil, ErrMissingAggregateKey
|
|
}
|
|
|
|
if t.aggregateNonces == nil {
|
|
return nil, errors.New("nonces not set")
|
|
}
|
|
|
|
sigs := make(TreePartialSigs, 0)
|
|
|
|
for i, level := range t.tree {
|
|
levelSigs := make([]*musig2.PartialSignature, 0)
|
|
|
|
for j, node := range level {
|
|
partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// sign the node
|
|
sig, err := t.signPartial(partialTx, i, j, t.secretKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
levelSigs = append(levelSigs, sig)
|
|
}
|
|
|
|
sigs = append(sigs, levelSigs)
|
|
}
|
|
|
|
return sigs, nil
|
|
}
|
|
|
|
func (t *treeSignerSession) signPartial(partialTx *psbt.Packet, posx int, posy int, seckey *btcec.PrivateKey) (*musig2.PartialSignature, error) {
|
|
prevoutFetcher, err := t.prevoutFetcherFactory(partialTx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
myNonce := t.myNonces[posx][posy]
|
|
aggregatedNonce := t.aggregateNonces[posx][posy]
|
|
|
|
message, err := txscript.CalcTaprootSignatureHash(
|
|
txscript.NewTxSigHashes(partialTx.UnsignedTx, prevoutFetcher),
|
|
txscript.SigHashDefault,
|
|
partialTx.UnsignedTx,
|
|
0,
|
|
prevoutFetcher,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return musig2.Sign(
|
|
myNonce.SecNonce, seckey, aggregatedNonce.PubNonce, t.keys, [32]byte(message),
|
|
musig2.WithSortedKeys(), musig2.WithTaprootSignTweak(t.scriptRoot),
|
|
)
|
|
}
|
|
|
|
type treeCoordinatorSession struct {
|
|
scriptRoot []byte
|
|
tree tree.VtxoTree
|
|
keys []*btcec.PublicKey
|
|
nonces []TreeNonces
|
|
sigs []TreePartialSigs
|
|
prevoutFetcherFactory func(*psbt.Packet) (txscript.PrevOutputFetcher, error)
|
|
}
|
|
|
|
func NewTreeCoordinatorSession(
|
|
roundSharedOutputAmount int64,
|
|
vtxoTree tree.VtxoTree,
|
|
scriptRoot []byte,
|
|
keys []*btcec.PublicKey,
|
|
) (CoordinatorSession, error) {
|
|
aggregateKey, err := AggregateKeys(keys, scriptRoot)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
prevoutFetcherFactory, err := prevOutFetcherFactory(aggregateKey.FinalKey, vtxoTree, roundSharedOutputAmount)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
nbOfKeys := len(keys)
|
|
|
|
return &treeCoordinatorSession{
|
|
scriptRoot: scriptRoot,
|
|
tree: vtxoTree,
|
|
keys: keys,
|
|
nonces: make([]TreeNonces, nbOfKeys),
|
|
sigs: make([]TreePartialSigs, nbOfKeys),
|
|
prevoutFetcherFactory: prevoutFetcherFactory,
|
|
}, nil
|
|
}
|
|
|
|
func (t *treeCoordinatorSession) getPubkeyIndex(pubkey *btcec.PublicKey) int {
|
|
for i, key := range t.keys {
|
|
if key.IsEqual(pubkey) {
|
|
return i
|
|
}
|
|
}
|
|
|
|
return -1
|
|
}
|
|
|
|
func (t *treeCoordinatorSession) AddNonce(pubkey *btcec.PublicKey, nonce TreeNonces) error {
|
|
index := t.getPubkeyIndex(pubkey)
|
|
if index == -1 {
|
|
return errors.New("public key not found")
|
|
}
|
|
|
|
t.nonces[index] = nonce
|
|
return nil
|
|
}
|
|
|
|
func (t *treeCoordinatorSession) AddSig(pubkey *btcec.PublicKey, sig TreePartialSigs) error {
|
|
index := t.getPubkeyIndex(pubkey)
|
|
if index == -1 {
|
|
return errors.New("public key not found")
|
|
}
|
|
|
|
t.sigs[index] = sig
|
|
return nil
|
|
}
|
|
|
|
func (t *treeCoordinatorSession) AggregateNonces() (TreeNonces, error) {
|
|
for _, nonce := range t.nonces {
|
|
if nonce == nil {
|
|
return nil, errors.New("nonces not set")
|
|
}
|
|
}
|
|
|
|
aggregatedNonces := make(TreeNonces, 0)
|
|
|
|
for i, level := range t.tree {
|
|
levelNonces := make([]*Musig2Nonce, 0)
|
|
for j := range level {
|
|
nonces := make([][66]byte, 0)
|
|
for _, n := range t.nonces {
|
|
nonces = append(nonces, n[i][j].PubNonce)
|
|
}
|
|
|
|
aggregatedNonce, err := musig2.AggregateNonces(nonces)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
levelNonces = append(levelNonces, &Musig2Nonce{aggregatedNonce})
|
|
}
|
|
|
|
aggregatedNonces = append(aggregatedNonces, levelNonces)
|
|
}
|
|
|
|
return aggregatedNonces, nil
|
|
}
|
|
|
|
// SignTree implements CoordinatorSession.
|
|
func (t *treeCoordinatorSession) SignTree() (tree.VtxoTree, error) {
|
|
var missingSigs int
|
|
for _, sig := range t.sigs {
|
|
if sig == nil {
|
|
missingSigs++
|
|
}
|
|
}
|
|
|
|
if missingSigs > 0 {
|
|
return nil, fmt.Errorf("missing %d signature(s)", missingSigs)
|
|
}
|
|
|
|
aggregatedKey, err := AggregateKeys(t.keys, t.scriptRoot)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for i, level := range t.tree {
|
|
for j, node := range level {
|
|
partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var combinedNonce *secp256k1.PublicKey
|
|
sigs := make([]*musig2.PartialSignature, 0)
|
|
for _, sig := range t.sigs {
|
|
s := sig[i][j]
|
|
if s.R != nil {
|
|
combinedNonce = s.R
|
|
}
|
|
sigs = append(sigs, s)
|
|
}
|
|
|
|
if combinedNonce == nil {
|
|
return nil, errors.New("missing combined nonce")
|
|
}
|
|
|
|
prevoutFetcher, err := t.prevoutFetcherFactory(partialTx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
message, err := txscript.CalcTaprootSignatureHash(
|
|
txscript.NewTxSigHashes(partialTx.UnsignedTx, prevoutFetcher),
|
|
txscript.SigHashDefault,
|
|
partialTx.UnsignedTx,
|
|
0,
|
|
prevoutFetcher,
|
|
)
|
|
|
|
combinedSig := musig2.CombineSigs(
|
|
combinedNonce, sigs,
|
|
musig2.WithTaprootTweakedCombine([32]byte(message), t.keys, t.scriptRoot, true),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !combinedSig.Verify(message, aggregatedKey.FinalKey) {
|
|
return nil, errors.New("invalid signature")
|
|
}
|
|
|
|
partialTx.Inputs[0].TaprootKeySpendSig = combinedSig.Serialize()
|
|
|
|
encodedSignedTx, err := partialTx.B64Encode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
node.Tx = encodedSignedTx
|
|
t.tree[i][j] = node
|
|
}
|
|
}
|
|
|
|
return t.tree, nil
|
|
}
|
|
|
|
func prevOutFetcherFactory(
|
|
finalAggregatedKey *btcec.PublicKey,
|
|
vtxoTree tree.VtxoTree,
|
|
roundSharedOutputAmount int64,
|
|
) (
|
|
func(partial *psbt.Packet) (txscript.PrevOutputFetcher, error),
|
|
error,
|
|
) {
|
|
pkscript, err := common.P2TRScript(finalAggregatedKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rootNode, err := vtxoTree.Root()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return func(partial *psbt.Packet) (txscript.PrevOutputFetcher, error) {
|
|
parentOutpoint := partial.UnsignedTx.TxIn[0].PreviousOutPoint
|
|
parentTxID := parentOutpoint.Hash.String()
|
|
if rootNode.ParentTxid == parentTxID {
|
|
return &treePrevOutFetcher{
|
|
prevout: &wire.TxOut{
|
|
Value: roundSharedOutputAmount,
|
|
PkScript: pkscript,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
var parent tree.Node
|
|
for _, level := range vtxoTree {
|
|
for _, n := range level {
|
|
if n.Txid == parentTxID {
|
|
parent = n
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if parent.Txid == "" {
|
|
return nil, errors.New("parent tx not found")
|
|
}
|
|
|
|
parentTx, err := psbt.NewFromRawBytes(strings.NewReader(parent.Tx), true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
parentValue := parentTx.UnsignedTx.TxOut[parentOutpoint.Index].Value
|
|
|
|
return &treePrevOutFetcher{
|
|
prevout: &wire.TxOut{
|
|
Value: parentValue,
|
|
PkScript: pkscript,
|
|
},
|
|
}, nil
|
|
}, nil
|
|
}
|
|
|
|
type treePrevOutFetcher struct {
|
|
prevout *wire.TxOut
|
|
}
|
|
|
|
func (f *treePrevOutFetcher) FetchPrevOutput(wire.OutPoint) *wire.TxOut {
|
|
return f.prevout
|
|
}
|
|
|
|
// writable is implemented by matrix cells that can serialize themselves.
type writable interface {
	Encode(w io.Writer) error
}

// readable is implemented by matrix cells that can deserialize themselves.
type readable interface {
	Decode(r io.Reader) error
}

// encodeMatrix serializes a matrix of writable cells into a byte slice.
//
// Layout (all counts are little-endian uint32):
//
//	rowCount
//	for each row: colCount, followed by colCount encoded cells
//
// Cells are written back to back with no per-cell framing, so each cell's
// Decode must know how many bytes to consume.
func encodeMatrix[T writable](matrix [][]T) ([]byte, error) {
	var buf bytes.Buffer

	if err := binary.Write(&buf, binary.LittleEndian, uint32(len(matrix))); err != nil {
		return nil, err
	}

	for _, row := range matrix {
		if err := binary.Write(&buf, binary.LittleEndian, uint32(len(row))); err != nil {
			return nil, err
		}
		for _, cell := range row {
			if err := cell.Encode(&buf); err != nil {
				return nil, err
			}
		}
	}

	return buf.Bytes(), nil
}

// decodeMatrix reads a matrix in the layout produced by encodeMatrix from data.
// Each cell is created with factory and filled in via its Decode method.
//
// The row and column counts are read from the stream, which may come from a
// remote party. They are therefore used only as a bounded capacity hint. The
// slices otherwise grow as cells are actually read, so memory stays
// proportional to the input instead of to the claimed lengths.
func decodeMatrix[T readable](factory func() T, data io.Reader) ([][]T, error) {
	// Upper bound on capacity pre-allocated from an untrusted length prefix.
	const maxPrealloc = 1024
	capHint := func(n uint32) int {
		if n > maxPrealloc {
			return maxPrealloc
		}
		return int(n)
	}

	var rowCount uint32
	if err := binary.Read(data, binary.LittleEndian, &rowCount); err != nil {
		return nil, err
	}

	matrix := make([][]T, 0, capHint(rowCount))
	for i := uint32(0); i < rowCount; i++ {
		var colCount uint32
		if err := binary.Read(data, binary.LittleEndian, &colCount); err != nil {
			return nil, err
		}

		row := make([]T, 0, capHint(colCount))
		for j := uint32(0); j < colCount; j++ {
			cell := factory()
			if err := cell.Decode(data); err != nil {
				return nil, err
			}
			row = append(row, cell)
		}

		matrix = append(matrix, row)
	}

	return matrix, nil
}
|