Files
ark/common/bitcointree/musig2.go
Pietralberto Mazza 7f937e8418 Vars and fields renaming (#387)
* Rename asp > server

* Rename pool > round

* Consolidate naming for pubkey/prvkey vars and types

* Fix

* Fix

* Fix wasm

* Rename congestionTree > vtxoTree

* Fix wasm

* Rename payment > request

* Rename congestionTree > vtxoTree after syncing with master

* Fix Send API in SDK

* Fix wasm

* Fix wasm

* Fixes

* Fixes after review

* Fix

* Fix naming

* Fix

* Fix e2e tests
2024-11-26 15:57:16 +01:00

644 lines
14 KiB
Go

package bitcointree
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/common/tree"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil/psbt"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
)
// ErrMissingVtxoTree is returned by signer-session methods that need a vtxo
// tree when the session was created without one.
var ErrMissingVtxoTree = errors.New("missing vtxo tree")

// ErrMissingAggregateKey is returned by Sign when the cosigner public keys
// have not been provided through SetKeys yet.
var ErrMissingAggregateKey = errors.New("missing aggregate key")
// Musig2Nonce holds a serialized 66-byte MuSig2 public nonce so it can be
// written to and read from the tree nonce matrix format.
type Musig2Nonce struct {
	PubNonce [66]byte
}

// Encode writes the raw 66-byte public nonce to w.
func (n *Musig2Nonce) Encode(w io.Writer) error {
	_, err := w.Write(n.PubNonce[:])
	return err
}

// Decode reads exactly 66 bytes from r into n.PubNonce.
//
// On a short read it returns the io.ReadFull error (io.EOF if nothing was
// read, io.ErrUnexpectedEOF if only part of the nonce was available) and
// leaves n unchanged.
func (n *Musig2Nonce) Decode(r io.Reader) error {
	// Read into a scratch array so a failed read never leaves n half
	// written. io.ReadFull already reports any short read as an error, so
	// no separate byte-count check is needed.
	var pub [66]byte
	if _, err := io.ReadFull(r, pub[:]); err != nil {
		return err
	}
	n.PubNonce = pub
	return nil
}
type (
	// TreeNonces holds MuSig2 public nonces arranged [level][position],
	// mirroring the layout of the vtxo tree they belong to.
	TreeNonces [][]*Musig2Nonce

	// TreePartialSigs holds MuSig2 partial signatures arranged
	// [level][position], mirroring the layout of the vtxo tree.
	TreePartialSigs [][]*musig2.PartialSignature
)
// SignerSession is the cosigner side of the MuSig2 vtxo tree signing flow.
// In the tree implementation provided here, Sign requires that SetKeys and
// SetAggregatedNonces have been called, and that GetNonces was called so
// the signer's own nonces exist.
type SignerSession interface {
	// GetNonces returns this signer's public nonces for the tree.
	GetNonces() (TreeNonces, error)
	// SetKeys sets the cosigner public keys for this session.
	SetKeys(keys []*btcec.PublicKey) error
	// SetAggregatedNonces sets the aggregated nonces for the tree.
	SetAggregatedNonces(nonces TreeNonces) error
	// Sign returns this signer's partial signatures for the tree.
	Sign() (TreePartialSigs, error)
}
// CoordinatorSession is the coordinator side of the MuSig2 vtxo tree
// signing flow: it gathers every cosigner's nonces and partial signatures,
// identified by the cosigner's public key, and produces the signed tree.
type CoordinatorSession interface {
	// AddNonce records the tree nonces submitted by the given cosigner.
	AddNonce(pubkey *btcec.PublicKey, nonces TreeNonces) error
	// AggregateNonces combines the recorded nonces node by node.
	AggregateNonces() (TreeNonces, error)
	// AddSig records the partial signatures submitted by the given cosigner.
	AddSig(pubkey *btcec.PublicKey, sigs TreePartialSigs) error
	// SignTree combines the signatures and adds them to the tree's psbts.
	SignTree() (tree.VtxoTree, error)
}
// Encode serializes the nonce matrix with encodeMatrix and writes it to w.
// Nothing is written if serialization fails.
func (n TreeNonces) Encode(w io.Writer) error {
	encoded, err := encodeMatrix(n)
	if err == nil {
		_, err = w.Write(encoded)
	}
	return err
}
// DecodeNonces reads a nonce matrix in the format produced by
// TreeNonces.Encode.
func DecodeNonces(r io.Reader) (TreeNonces, error) {
	newNonce := func() *Musig2Nonce { return &Musig2Nonce{} }
	nonces, err := decodeMatrix(newNonce, r)
	if err != nil {
		return nil, err
	}
	return nonces, nil
}
// Encode serializes the partial-signature matrix with encodeMatrix and
// writes it to w. Nothing is written if serialization fails.
func (s TreePartialSigs) Encode(w io.Writer) error {
	payload, encErr := encodeMatrix(s)
	if encErr != nil {
		return encErr
	}
	_, writeErr := w.Write(payload)
	return writeErr
}
// DecodeSignatures reads a partial-signature matrix in the format produced
// by TreePartialSigs.Encode.
func DecodeSignatures(r io.Reader) (TreePartialSigs, error) {
	newSig := func() *musig2.PartialSignature { return &musig2.PartialSignature{} }
	sigs, err := decodeMatrix(newSig, r)
	if err != nil {
		return nil, err
	}
	return sigs, nil
}
// AggregateKeys combines pubkeys into a MuSig2 aggregate key. The keys are
// sorted before aggregation, and the result is tweaked for the taproot
// script tree root scriptRoot.
func AggregateKeys(
	pubkeys []*btcec.PublicKey,
	scriptRoot []byte,
) (*musig2.AggregateKey, error) {
	const sortKeys = true
	tweak := musig2.WithTaprootKeyTweak(scriptRoot)

	aggregated, _, _, err := musig2.AggregateKeys(pubkeys, sortKeys, tweak)
	if err != nil {
		return nil, err
	}
	return aggregated, nil
}
// ValidateTreeSigs checks that every node of vtxoTree carries a taproot
// key-spend signature on its first input that verifies against
// finalAggregatedKey. The signed message is the SigHashDefault taproot
// sighash of that input.
//
// roundSharedOutputAmount is the value of the output spent by the tree
// root. scriptRoot is not used by this function; it is kept so the
// signature stays stable for callers.
func ValidateTreeSigs(
	scriptRoot []byte,
	finalAggregatedKey *btcec.PublicKey,
	roundSharedOutputAmount int64,
	vtxoTree tree.VtxoTree,
) error {
	prevoutFetcherFactory, err := prevOutFetcherFactory(finalAggregatedKey, vtxoTree, roundSharedOutputAmount)
	if err != nil {
		return err
	}

	for _, level := range vtxoTree {
		for _, node := range level {
			partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
			if err != nil {
				return err
			}
			// A malformed tree tx with no inputs would otherwise panic on
			// the Inputs[0] access below.
			if len(partialTx.Inputs) == 0 {
				return fmt.Errorf("tree tx %s has no input", node.Txid)
			}

			sig := partialTx.Inputs[0].TaprootKeySpendSig
			if len(sig) == 0 {
				return errors.New("unsigned tree input")
			}
			schnorrSig, err := schnorr.ParseSignature(sig)
			if err != nil {
				return err
			}

			prevoutFetcher, err := prevoutFetcherFactory(partialTx)
			if err != nil {
				return err
			}
			message, err := txscript.CalcTaprootSignatureHash(
				txscript.NewTxSigHashes(partialTx.UnsignedTx, prevoutFetcher),
				txscript.SigHashDefault,
				partialTx.UnsignedTx,
				0,
				prevoutFetcher,
			)
			if err != nil {
				return err
			}

			if !schnorrSig.Verify(message, finalAggregatedKey) {
				return errors.New("invalid signature")
			}
		}
	}
	return nil
}
// NewTreeSignerSession returns a SignerSession that produces signer's
// partial signatures for every node of vtxoTree.
//
// roundSharedOutputAmount is the value of the output spent by the tree
// root. scriptRoot is the taproot script tree root used to tweak the
// aggregated key.
func NewTreeSignerSession(
	signer *btcec.PrivateKey,
	roundSharedOutputAmount int64,
	vtxoTree tree.VtxoTree,
	scriptRoot []byte,
) SignerSession {
	session := new(treeSignerSession)
	session.secretKey = signer
	session.tree = vtxoTree
	session.scriptRoot = scriptRoot
	session.roundSharedOutputAmount = roundSharedOutputAmount
	return session
}
// treeSignerSession is the SignerSession returned by NewTreeSignerSession.
type treeSignerSession struct {
	// secretKey is the private key this session signs with.
	secretKey *btcec.PrivateKey
	// tree is the vtxo tree whose nodes are signed.
	tree tree.VtxoTree
	// myNonces holds this signer's nonce pairs, indexed like tree; it is
	// nil until GetNonces generates them.
	myNonces [][]*musig2.Nonces
	// keys are the cosigner public keys recorded by SetKeys.
	keys []*btcec.PublicKey
	// aggregateNonces are the aggregated nonces recorded by SetAggregatedNonces.
	aggregateNonces TreeNonces
	// scriptRoot is the taproot script tree root used for key tweaking.
	scriptRoot []byte
	// roundSharedOutputAmount is the value of the output spent by the tree root.
	roundSharedOutputAmount int64
	// prevoutFetcherFactory is built by SetKeys from the aggregated key.
	prevoutFetcherFactory func(*psbt.Packet) (txscript.PrevOutputFetcher, error)
}
// generateNonces creates a MuSig2 nonce pair, bound to this signer's public
// key, for every node of the tree and stores them in t.myNonces, indexed
// [level][position] like the tree. t.myNonces is only replaced once every
// nonce has been generated.
func (t *treeSignerSession) generateNonces() error {
	if t.tree == nil {
		return ErrMissingVtxoTree
	}

	generated := make([][]*musig2.Nonces, len(t.tree))
	for level, nodes := range t.tree {
		row := make([]*musig2.Nonces, len(nodes))
		for pos := range nodes {
			nonce, err := musig2.GenNonces(musig2.WithPublicKey(t.secretKey.PubKey()))
			if err != nil {
				return err
			}
			row[pos] = nonce
		}
		generated[level] = row
	}

	t.myNonces = generated
	return nil
}
// GetNonces returns the public halves of this signer's nonces, one per tree
// node. The nonces are generated on the first call and reused afterwards.
func (t *treeSignerSession) GetNonces() (TreeNonces, error) {
	if t.tree == nil {
		return nil, ErrMissingVtxoTree
	}
	if t.myNonces == nil {
		if err := t.generateNonces(); err != nil {
			return nil, err
		}
	}

	public := make(TreeNonces, len(t.myNonces))
	for level, pairs := range t.myNonces {
		row := make([]*Musig2Nonce, len(pairs))
		for pos, pair := range pairs {
			row[pos] = &Musig2Nonce{PubNonce: pair.PubNonce}
		}
		public[level] = row
	}
	return public, nil
}
// SetKeys records the cosigner public keys. It aggregates them, tweaked
// with the session's script root, and prepares the prevout fetcher factory
// used for sighash computation. Keys can only be set once.
func (t *treeSignerSession) SetKeys(keys []*btcec.PublicKey) error {
	if t.keys != nil {
		return errors.New("keys already set")
	}

	aggregated, err := AggregateKeys(keys, t.scriptRoot)
	if err != nil {
		return err
	}
	fetcherFactory, err := prevOutFetcherFactory(aggregated.FinalKey, t.tree, t.roundSharedOutputAmount)
	if err != nil {
		return err
	}

	t.prevoutFetcherFactory, t.keys = fetcherFactory, keys
	return nil
}
// SetAggregatedNonces records the aggregated nonces for the tree. They can
// only be set once.
func (t *treeSignerSession) SetAggregatedNonces(nonces TreeNonces) error {
	if t.aggregateNonces == nil {
		t.aggregateNonces = nonces
		return nil
	}
	return errors.New("nonces already set")
}
// Sign returns this signer's partial signature for every tree node,
// arranged [level][position] like the tree.
//
// SetKeys and SetAggregatedNonces must have been called. GetNonces must
// also have been called beforehand so that this signer's nonces exist.
func (t *treeSignerSession) Sign() (TreePartialSigs, error) {
	switch {
	case t.tree == nil:
		return nil, ErrMissingVtxoTree
	case t.keys == nil:
		return nil, ErrMissingAggregateKey
	case t.aggregateNonces == nil:
		return nil, errors.New("nonces not set")
	}

	signed := make(TreePartialSigs, len(t.tree))
	for level, nodes := range t.tree {
		row := make([]*musig2.PartialSignature, len(nodes))
		for pos, node := range nodes {
			ptx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
			if err != nil {
				return nil, err
			}
			if row[pos], err = t.signPartial(ptx, level, pos, t.secretKey); err != nil {
				return nil, err
			}
		}
		signed[level] = row
	}
	return signed, nil
}
// signPartial computes seckey's MuSig2 partial signature for the tree node
// at level posx, position posy, whose transaction is partialTx.
//
// The message is the SigHashDefault taproot sighash of input 0. The node's
// own nonce pair (from GetNonces) and aggregated nonce (from
// SetAggregatedNonces) are looked up by position. If either is missing, an
// error is returned instead of panicking on an out-of-range index. That
// happens, for example, when GetNonces was never called or the aggregated
// nonce matrix does not match the tree.
func (t *treeSignerSession) signPartial(partialTx *psbt.Packet, posx int, posy int, seckey *btcec.PrivateKey) (*musig2.PartialSignature, error) {
	prevoutFetcher, err := t.prevoutFetcherFactory(partialTx)
	if err != nil {
		return nil, err
	}

	if posx < 0 || posy < 0 ||
		posx >= len(t.myNonces) || posy >= len(t.myNonces[posx]) ||
		t.myNonces[posx][posy] == nil {
		return nil, fmt.Errorf("missing own nonce for tree node (%d, %d)", posx, posy)
	}
	if posx >= len(t.aggregateNonces) || posy >= len(t.aggregateNonces[posx]) ||
		t.aggregateNonces[posx][posy] == nil {
		return nil, fmt.Errorf("missing aggregated nonce for tree node (%d, %d)", posx, posy)
	}
	myNonce := t.myNonces[posx][posy]
	aggregatedNonce := t.aggregateNonces[posx][posy]

	message, err := txscript.CalcTaprootSignatureHash(
		txscript.NewTxSigHashes(partialTx.UnsignedTx, prevoutFetcher),
		txscript.SigHashDefault,
		partialTx.UnsignedTx,
		0,
		prevoutFetcher,
	)
	if err != nil {
		return nil, err
	}

	return musig2.Sign(
		myNonce.SecNonce, seckey, aggregatedNonce.PubNonce, t.keys, [32]byte(message),
		musig2.WithSortedKeys(), musig2.WithTaprootSignTweak(t.scriptRoot),
	)
}
// treeCoordinatorSession is the CoordinatorSession returned by
// NewTreeCoordinatorSession.
type treeCoordinatorSession struct {
	// scriptRoot is the taproot script tree root used for key tweaking.
	scriptRoot []byte
	// tree is the vtxo tree being signed; SignTree writes signatures into it.
	tree tree.VtxoTree
	// keys are the cosigner public keys; their order defines the slot
	// index used in nonces and sigs.
	keys []*btcec.PublicKey
	// nonces[i] holds the tree nonces submitted for keys[i], nil until added.
	nonces []TreeNonces
	// sigs[i] holds the partial signatures submitted for keys[i], nil until added.
	sigs []TreePartialSigs
	// prevoutFetcherFactory builds sighash prevout fetchers for tree txs.
	prevoutFetcherFactory func(*psbt.Packet) (txscript.PrevOutputFetcher, error)
}
// NewTreeCoordinatorSession returns a CoordinatorSession for vtxoTree,
// signed by the given cosigner keys.
//
// The keys are aggregated, tweaked with scriptRoot, and
// roundSharedOutputAmount is the value of the output spent by the tree
// root.
func NewTreeCoordinatorSession(
	roundSharedOutputAmount int64,
	vtxoTree tree.VtxoTree,
	scriptRoot []byte,
	keys []*btcec.PublicKey,
) (CoordinatorSession, error) {
	aggregated, err := AggregateKeys(keys, scriptRoot)
	if err != nil {
		return nil, err
	}
	fetcherFactory, err := prevOutFetcherFactory(aggregated.FinalKey, vtxoTree, roundSharedOutputAmount)
	if err != nil {
		return nil, err
	}

	session := &treeCoordinatorSession{
		scriptRoot:            scriptRoot,
		tree:                  vtxoTree,
		keys:                  keys,
		prevoutFetcherFactory: fetcherFactory,
	}
	// One slot per cosigner, filled by AddNonce / AddSig in key order.
	session.nonces = make([]TreeNonces, len(keys))
	session.sigs = make([]TreePartialSigs, len(keys))
	return session, nil
}
// getPubkeyIndex returns the position of pubkey in t.keys, or -1 if the key
// is not one of the session's cosigners.
func (t *treeCoordinatorSession) getPubkeyIndex(pubkey *btcec.PublicKey) int {
	for idx := range t.keys {
		if t.keys[idx].IsEqual(pubkey) {
			return idx
		}
	}
	return -1
}
// AddNonce stores the tree nonces of the cosigner owning pubkey, replacing
// any previously stored ones. Unknown keys are rejected.
func (t *treeCoordinatorSession) AddNonce(pubkey *btcec.PublicKey, nonce TreeNonces) error {
	if slot := t.getPubkeyIndex(pubkey); slot >= 0 {
		t.nonces[slot] = nonce
		return nil
	}
	return errors.New("public key not found")
}
// AddSig stores the partial signatures of the cosigner owning pubkey,
// replacing any previously stored ones. Unknown keys are rejected.
func (t *treeCoordinatorSession) AddSig(pubkey *btcec.PublicKey, sig TreePartialSigs) error {
	if slot := t.getPubkeyIndex(pubkey); slot >= 0 {
		t.sigs[slot] = sig
		return nil
	}
	return errors.New("public key not found")
}
// AggregateNonces combines, node by node, the public nonces recorded by
// AddNonce for every cosigner into one MuSig2 aggregated nonce. The result
// is arranged [level][position] like the tree.
//
// It fails if any cosigner has not submitted nonces yet. It also fails if a
// cosigner's nonce matrix does not cover every tree node. AddNonce stores
// matrices without checking their shape, so they are validated here rather
// than risking an index-out-of-range panic.
func (t *treeCoordinatorSession) AggregateNonces() (TreeNonces, error) {
	for _, nonce := range t.nonces {
		if nonce == nil {
			return nil, errors.New("nonces not set")
		}
	}

	aggregatedNonces := make(TreeNonces, 0, len(t.tree))
	for i, level := range t.tree {
		levelNonces := make([]*Musig2Nonce, 0, len(level))
		for j := range level {
			pubNonces := make([][66]byte, 0, len(t.nonces))
			for k, signerNonces := range t.nonces {
				if i >= len(signerNonces) || j >= len(signerNonces[i]) || signerNonces[i][j] == nil {
					return nil, fmt.Errorf(
						"cosigner %d: missing nonce for tree node (%d, %d)", k, i, j,
					)
				}
				pubNonces = append(pubNonces, signerNonces[i][j].PubNonce)
			}

			aggregatedNonce, err := musig2.AggregateNonces(pubNonces)
			if err != nil {
				return nil, err
			}
			levelNonces = append(levelNonces, &Musig2Nonce{PubNonce: aggregatedNonce})
		}
		aggregatedNonces = append(aggregatedNonces, levelNonces)
	}
	return aggregatedNonces, nil
}
// SignTree implements CoordinatorSession. For every tree node it combines
// the cosigners' partial signatures (recorded via AddSig) into one Schnorr
// signature and verifies it against the aggregated key. It then stores the
// signature as the taproot key-spend signature of the node's first input.
//
// The session's tree is updated in place and returned. If an error occurs
// midway, nodes processed before the failure keep their new signatures.
func (t *treeCoordinatorSession) SignTree() (tree.VtxoTree, error) {
	var missingSigs int
	for _, sig := range t.sigs {
		if sig == nil {
			missingSigs++
		}
	}
	if missingSigs > 0 {
		return nil, fmt.Errorf("missing %d signature(s)", missingSigs)
	}

	aggregatedKey, err := AggregateKeys(t.keys, t.scriptRoot)
	if err != nil {
		return nil, err
	}

	for i, level := range t.tree {
		for j, node := range level {
			partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
			if err != nil {
				return nil, err
			}
			// Inputs[0] is written below; reject malformed txs instead of panicking.
			if len(partialTx.Inputs) == 0 {
				return nil, fmt.Errorf("tree tx %s has no input", node.Txid)
			}

			// AddSig stores signature matrices without a shape check, so
			// validate each cosigner's entry for this node before using it.
			var combinedNonce *secp256k1.PublicKey
			sigs := make([]*musig2.PartialSignature, 0, len(t.sigs))
			for k, signerSigs := range t.sigs {
				if i >= len(signerSigs) || j >= len(signerSigs[i]) || signerSigs[i][j] == nil {
					return nil, fmt.Errorf(
						"cosigner %d: missing partial signature for tree node (%d, %d)", k, i, j,
					)
				}
				s := signerSigs[i][j]
				// The combined nonce R is taken from any partial signature that carries it.
				if s.R != nil {
					combinedNonce = s.R
				}
				sigs = append(sigs, s)
			}
			if combinedNonce == nil {
				return nil, errors.New("missing combined nonce")
			}

			prevoutFetcher, err := t.prevoutFetcherFactory(partialTx)
			if err != nil {
				return nil, err
			}
			message, err := txscript.CalcTaprootSignatureHash(
				txscript.NewTxSigHashes(partialTx.UnsignedTx, prevoutFetcher),
				txscript.SigHashDefault,
				partialTx.UnsignedTx,
				0,
				prevoutFetcher,
			)
			// Check before [32]byte(message): converting a nil/short slice panics.
			if err != nil {
				return nil, err
			}

			combinedSig := musig2.CombineSigs(
				combinedNonce, sigs,
				musig2.WithTaprootTweakedCombine([32]byte(message), t.keys, t.scriptRoot, true),
			)
			if !combinedSig.Verify(message, aggregatedKey.FinalKey) {
				return nil, errors.New("invalid signature")
			}

			partialTx.Inputs[0].TaprootKeySpendSig = combinedSig.Serialize()
			encodedSignedTx, err := partialTx.B64Encode()
			if err != nil {
				return nil, err
			}
			node.Tx = encodedSignedTx
			t.tree[i][j] = node
		}
	}
	return t.tree, nil
}
// prevOutFetcherFactory returns a function that builds the
// txscript.PrevOutputFetcher needed to compute the sighash of a tree
// transaction's first input.
//
// Every tree input is treated as spending a P2TR output paying to
// finalAggregatedKey. If the input spends the tree root's parent, the
// prevout value is roundSharedOutputAmount. Otherwise the parent node is
// looked up in vtxoTree by txid, and the value is read from the output it
// spends. Malformed inputs return errors instead of panicking: a tx with no
// inputs, or an outpoint index past the parent's outputs.
func prevOutFetcherFactory(
	finalAggregatedKey *btcec.PublicKey,
	vtxoTree tree.VtxoTree,
	roundSharedOutputAmount int64,
) (
	func(partial *psbt.Packet) (txscript.PrevOutputFetcher, error),
	error,
) {
	pkscript, err := common.P2TRScript(finalAggregatedKey)
	if err != nil {
		return nil, err
	}
	rootNode, err := vtxoTree.Root()
	if err != nil {
		return nil, err
	}

	return func(partial *psbt.Packet) (txscript.PrevOutputFetcher, error) {
		if partial == nil || partial.UnsignedTx == nil || len(partial.UnsignedTx.TxIn) == 0 {
			return nil, errors.New("tree tx has no input")
		}
		parentOutpoint := partial.UnsignedTx.TxIn[0].PreviousOutPoint
		parentTxID := parentOutpoint.Hash.String()

		// The root's parent is not a tree node: it spends the round's shared output.
		if rootNode.ParentTxid == parentTxID {
			return &treePrevOutFetcher{
				prevout: &wire.TxOut{
					Value:    roundSharedOutputAmount,
					PkScript: pkscript,
				},
			}, nil
		}

		var (
			parent tree.Node
			found  bool
		)
	search:
		for _, level := range vtxoTree {
			for _, n := range level {
				if n.Txid == parentTxID {
					parent, found = n, true
					break search
				}
			}
		}
		if !found {
			return nil, errors.New("parent tx not found")
		}

		parentTx, err := psbt.NewFromRawBytes(strings.NewReader(parent.Tx), true)
		if err != nil {
			return nil, err
		}
		outputs := parentTx.UnsignedTx.TxOut
		if uint64(parentOutpoint.Index) >= uint64(len(outputs)) {
			return nil, fmt.Errorf(
				"parent tx %s has no output at index %d", parentTxID, parentOutpoint.Index,
			)
		}

		return &treePrevOutFetcher{
			prevout: &wire.TxOut{
				Value:    outputs[parentOutpoint.Index].Value,
				PkScript: pkscript,
			},
		}, nil
	}, nil
}
// treePrevOutFetcher is a txscript.PrevOutputFetcher that answers every
// lookup with the same output, regardless of the outpoint requested.
// NOTE(review): this is only correct for sighashes that depend on a single
// previous output; callers here sign input 0 of each tree tx, which
// presumably has one input — confirm.
type treePrevOutFetcher struct {
	prevout *wire.TxOut // returned for every outpoint
}

var _ txscript.PrevOutputFetcher = (*treePrevOutFetcher)(nil)

// FetchPrevOutput ignores the outpoint and returns the stored prevout.
func (f *treePrevOutFetcher) FetchPrevOutput(_ wire.OutPoint) *wire.TxOut {
	out := f.prevout
	return out
}
// writable is implemented by values that can serialize themselves to a
// byte stream (e.g. *Musig2Nonce and *musig2.PartialSignature).
type writable interface {
	Encode(w io.Writer) error
}

// readable is implemented by values that can populate themselves from a
// byte stream.
type readable interface {
	Decode(r io.Reader) error
}

// encodeMatrix serializes a two-dimensional slice of writable values.
//
// Wire format (each count is a little-endian uint32):
//
//	rowCount
//	for each row: colCount, then each cell's Encode output back to back
//
// Cells carry no length prefix, so each cell's encoding must be
// self-delimiting for decodeMatrix to read it back.
func encodeMatrix[T writable](matrix [][]T) ([]byte, error) {
	var buf bytes.Buffer
	var prefix [4]byte
	writeCount := func(n int) {
		binary.LittleEndian.PutUint32(prefix[:], uint32(n))
		buf.Write(prefix[:]) // bytes.Buffer.Write always returns a nil error
	}

	writeCount(len(matrix))
	for _, row := range matrix {
		writeCount(len(row))
		for _, cell := range row {
			if err := cell.Encode(&buf); err != nil {
				return nil, err
			}
		}
	}
	return buf.Bytes(), nil
}

// decodeMatrix reads a matrix in the format written by encodeMatrix,
// creating each cell with factory and filling it via Decode.
//
// Any read or cell decode error is returned as-is with a nil matrix.
func decodeMatrix[T readable](factory func() T, data io.Reader) ([][]T, error) {
	// The row and column counts come from the stream itself, which callers
	// may feed with data from other parties. Sizing allocations directly
	// from them would let a 4-byte header request gigabytes. Cap the
	// preallocation and let append grow the slices only as cells are
	// actually decoded.
	const maxPrealloc = 1024
	capFor := func(n uint32) int {
		if n > maxPrealloc {
			return maxPrealloc
		}
		return int(n)
	}
	readCount := func() (uint32, error) {
		var n uint32
		err := binary.Read(data, binary.LittleEndian, &n)
		return n, err
	}

	rowCount, err := readCount()
	if err != nil {
		return nil, err
	}

	matrix := make([][]T, 0, capFor(rowCount))
	for i := uint32(0); i < rowCount; i++ {
		colCount, err := readCount()
		if err != nil {
			return nil, err
		}

		row := make([]T, 0, capFor(colCount))
		for j := uint32(0); j < colCount; j++ {
			cell := factory()
			if err := cell.Decode(data); err != nil {
				return nil, err
			}
			row = append(row, cell)
		}
		matrix = append(matrix, row)
	}
	return matrix, nil
}