Adding logic to count signatures in multisig

This commit is contained in:
rockstardev
2025-03-23 17:55:46 -05:00
parent 42878f23a6
commit 6784be2ce2
4 changed files with 86 additions and 28 deletions

View File

@@ -54,6 +54,11 @@ public class PendingTransaction: IHasBlob<PendingTransactionBlob>
{
public string PSBT { get; set; }
public List<CollectedSignature> CollectedSignatures { get; set; } = new();
public int? SignaturesCollected { get; set; }
// for example: 3/5
public int? SignaturesNeeded { get; set; }
public int? SignaturesTotal { get; set; }
}
public class CollectedSignature

View File

@@ -123,14 +123,23 @@ public class MultisigTests : UnitTestBase
s.Driver.FindElement(By.Id("Outputs_0__Amount")).SendKeys(amount);
s.Driver.FindElement(By.Id("CreatePendingTransaction")).Click();
// now clicking on View to sign transaction
// validating the state of UI
Assert.Equal("0", s.Driver.FindElement(By.Id("Sigs_0__Collected")).Text);
Assert.Equal("2/3", s.Driver.FindElement(By.Id("Sigs_0__Scheme")).Text);
// now proceeding to click on sign button and sign transactions
SignPendingTransactionWithKey(s, address, derivationScheme, resp1);
Assert.Equal("1", s.Driver.FindElement(By.Id("Sigs_0__Collected")).Text);
SignPendingTransactionWithKey(s, address, derivationScheme, resp2);
Assert.Equal("2", s.Driver.FindElement(By.Id("Sigs_0__Collected")).Text);
// Broadcasting transaction and ensuring there is no longer broadcast button
// we should now have enough signatures to broadcast transaction
s.Driver.WaitForElement(By.XPath("//a[text()='Broadcast']")).Click();
s.Driver.FindElement(By.Id("BroadcastTransaction")).Click();
Assert.Contains("Transaction broadcasted successfully", s.FindAlertMessage().Text);
// now that we broadcast transaction, there shouldn't be broadcast button
s.Driver.AssertElementNotFound(By.XPath("//a[text()='Broadcast']"));
// Abort pending transaction flow

View File

@@ -94,11 +94,28 @@ public class PendingTransactionService(
{
var network = networkProvider.GetNetwork<BTCPayNetwork>(cryptoCode);
if (network is null)
{
throw new NotSupportedException("CryptoCode not supported");
}
var txId = psbt.GetGlobalTransaction().GetHash();
int signaturesNeeded = 0;
int signaturesTotal = 0;
foreach (var input in psbt.Inputs)
{
var script = input.WitnessScript ?? input.RedeemScript;
if (script is null)
continue;
var multisigParams = PayToMultiSigTemplate.Instance.ExtractScriptPubKeyParameters(script);
if (multisigParams != null)
{
signaturesNeeded = multisigParams.SignatureCount;
signaturesTotal = multisigParams.PubKeys.Length;
break; // assume consistent multisig scheme across all inputs
}
}
await using var ctx = dbContextFactory.CreateContext();
var pendingTransaction = new PendingTransaction
{
@@ -109,14 +126,24 @@ public class PendingTransactionService(
Expiry = expiry,
StoreId = storeId,
};
pendingTransaction.SetBlob(new PendingTransactionBlob { PSBT = psbt.ToBase64() });
pendingTransaction.SetBlob(new PendingTransactionBlob
{
PSBT = psbt.ToBase64(),
SignaturesCollected = 0,
SignaturesNeeded = signaturesNeeded,
SignaturesTotal = signaturesTotal
});
ctx.PendingTransactions.Add(pendingTransaction);
await ctx.SaveChangesAsync(cancellationToken);
EventAggregator.Publish(new PendingTransactionEvent
{
Data = pendingTransaction,
Type = PendingTransactionEvent.Created
});
return pendingTransaction;
}
@@ -143,7 +170,7 @@ public class PendingTransactionService(
return null;
}
var originalPsbtWorkingCopy = PSBT.Parse(blob.PSBT, network.NBitcoinNetwork);
var dbPsbt = PSBT.Parse(blob.PSBT, network.NBitcoinNetwork);
// Deduplicate: Check if this exact PSBT (Base64) was already collected
var newPsbtBase64 = psbt.ToBase64();
@@ -155,26 +182,29 @@ public class PendingTransactionService(
foreach (var collectedSignature in blob.CollectedSignatures)
{
var collectedPsbt = PSBT.Parse(collectedSignature.ReceivedPSBT, network.NBitcoinNetwork);
originalPsbtWorkingCopy.Combine(collectedPsbt); // combine changes the object
dbPsbt.Combine(collectedPsbt); // combine changes the object
}
var originalPsbtWorkingCopyWithNewPsbt = originalPsbtWorkingCopy.Clone(); // Clone before modifying
originalPsbtWorkingCopyWithNewPsbt.Combine(psbt);
var newWorkingCopyPsbt = dbPsbt.Clone(); // Clone before modifying
newWorkingCopyPsbt.Combine(psbt);
// Check if new signatures were actually added
bool newSignaturesCollected = false;
for (int i = 0; i < originalPsbtWorkingCopy.Inputs.Count; i++)
{
if (originalPsbtWorkingCopyWithNewPsbt.Inputs[i].PartialSigs.Count >
originalPsbtWorkingCopy.Inputs[i].PartialSigs.Count)
{
newSignaturesCollected = true;
break;
}
}
var oldPubKeys = dbPsbt.Inputs
.SelectMany(input => input.PartialSigs.Keys)
.ToHashSet();
if (newSignaturesCollected)
var newPubKeys = newWorkingCopyPsbt.Inputs
.SelectMany(input => input.PartialSigs.Keys)
.ToHashSet();
newPubKeys.ExceptWith(oldPubKeys);
var newSignatures = newPubKeys.Count;
if (newSignatures > 0)
{
// TODO: For now we're going with estimation of how many signatures were collected until we find better way
// so for example if we have 4 new signatures and only 2 inputs - number of collected signatures will be 2
blob.SignaturesCollected += newSignatures / newWorkingCopyPsbt.Inputs.Count();
blob.CollectedSignatures.Add(new CollectedSignature
{
ReceivedPSBT = newPsbtBase64,
@@ -183,8 +213,12 @@ public class PendingTransactionService(
pendingTransaction.SetBlob(blob);
}
if (originalPsbtWorkingCopyWithNewPsbt.TryFinalize(out _))
if (newWorkingCopyPsbt.TryFinalize(out _))
{
// TODO: Better logic here
if (blob.SignaturesCollected < blob.SignaturesNeeded)
blob.SignaturesCollected = blob.SignaturesNeeded;
pendingTransaction.State = PendingTransactionState.Signed;
}
@@ -224,6 +258,11 @@ public class PendingTransactionService(
if (pt is null) return;
pt.State = PendingTransactionState.Cancelled;
await ctx.SaveChangesAsync();
EventAggregator.Publish(new PendingTransactionEvent
{
Data = pt,
Type = PendingTransactionEvent.Cancelled
});
}
public async Task Broadcasted(string cryptoCode, string storeId, string transactionId)
@@ -247,6 +286,7 @@ public class PendingTransactionService(
public const string Created = nameof(Created);
public const string SignatureCollected = nameof(SignatureCollected);
public const string Broadcast = nameof(Broadcast);
public const string Cancelled = nameof(Cancelled);
public PendingTransaction Data { get; set; } = null!;
public string Type { get; set; } = null!;

View File

@@ -176,21 +176,25 @@
<thead>
<th>Id</th>
<th>State</th>
<th>Signature count</th>
<th>Signatures</th>
<th>Scheme</th>
<th>Actions</th>
</thead>
@foreach (var pendingTransaction in Model.PendingTransactions)
@for (var index = 0; index < Model.PendingTransactions.Length; index++)
{
var pendingTransaction = Model.PendingTransactions[index];
var ptblob = @pendingTransaction.GetBlob();
<tr>
<td>@pendingTransaction.TransactionId</td>
<td>@pendingTransaction.State</td>
<td>@pendingTransaction.GetBlob().CollectedSignatures.Count</td>
<td><span id="Sigs_@(index)__Collected">@ptblob?.SignaturesCollected</span></td>
<td><span id="Sigs_@(index)__Scheme">@ptblob?.SignaturesNeeded/@ptblob?.SignaturesTotal</span></td>
<td>
<a asp-action="ViewPendingTransaction" asp-route-walletId="@walletId" asp-route-transactionId="@pendingTransaction.TransactionId"
>@(pendingTransaction.State == PendingTransactionState.Signed ? "Broadcast" : "View")</a>
<a asp-action="ViewPendingTransaction" asp-route-walletId="@walletId"
asp-route-transactionId="@pendingTransaction.TransactionId">@(pendingTransaction.State == PendingTransactionState.Signed ? "Broadcast" : "View")</a>
-
<a asp-action="CancelPendingTransaction"
asp-route-walletId="@walletId" asp-route-transactionId="@pendingTransaction.TransactionId">Abort</a>
<a asp-action="CancelPendingTransaction" asp-route-walletId="@walletId"
asp-route-transactionId="@pendingTransaction.TransactionId">Abort</a>
</td>
</tr>
}