Move retry logic into LogClient

This allows retry logic to be used for all requests, not just get-entries

Also add context arguments
This commit is contained in:
Andrew Ayer
2022-06-02 10:02:32 -04:00
parent f7f79f2600
commit 039339154f
3 changed files with 138 additions and 76 deletions

View File

@@ -15,6 +15,7 @@ package certspotter
import (
// "container/list"
"bytes"
"context"
"crypto"
"errors"
"fmt"
@@ -30,11 +31,6 @@ import (
type ProcessCallback func(*Scanner, *ct.LogEntry)
const (
FETCH_RETRIES = 10
FETCH_RETRY_WAIT = 1
)
// ScannerOptions holds configuration options for the Scanner
type ScannerOptions struct {
// Number of entries to request in one batch from the Log
@@ -90,26 +86,12 @@ func (s *Scanner) processerJob(id int, certsProcessed *int64, entries <-chan ct.
}
func (s *Scanner) fetch(r fetchRange, entries chan<- ct.LogEntry, tree *CollapsedMerkleTree) error {
success := false
retries := FETCH_RETRIES
retryWait := FETCH_RETRY_WAIT
for !success {
for r.start <= r.end {
s.Log(fmt.Sprintf("Fetching entries %d to %d", r.start, r.end))
logEntries, err := s.logClient.GetEntries(r.start, r.end)
logEntries, err := s.logClient.GetEntries(context.Background(), r.start, r.end)
if err != nil {
if retries == 0 {
s.Warn(fmt.Sprintf("Problem fetching entries %d to %d from log: %s", r.start, r.end, err.Error()))
return err
} else {
s.Log(fmt.Sprintf("Problem fetching entries %d to %d from log (will retry): %s", r.start, r.end, err.Error()))
time.Sleep(time.Duration(retryWait) * time.Second)
retries--
retryWait *= 2
continue
}
return err
}
retries = FETCH_RETRIES
retryWait = FETCH_RETRY_WAIT
for _, logEntry := range logEntries {
if tree != nil {
tree.Add(hashLeaf(logEntry.LeafBytes))
@@ -118,12 +100,6 @@ func (s *Scanner) fetch(r fetchRange, entries chan<- ct.LogEntry, tree *Collapse
entries <- logEntry
r.start++
}
if r.start > r.end {
// Only complete if we actually got all the leaves we were
// expecting -- Logs MAY return fewer than the number of
// leaves requested.
success = true
}
}
return nil
}
@@ -194,7 +170,7 @@ func (s Scanner) Warn(msg string) {
}
func (s *Scanner) GetSTH() (*ct.SignedTreeHead, error) {
latestSth, err := s.logClient.GetSTH()
latestSth, err := s.logClient.GetSTH(context.Background())
if err != nil {
return nil, err
}
@@ -218,13 +194,13 @@ func (s *Scanner) CheckConsistency(first *ct.SignedTreeHead, second *ct.SignedTr
// return a 400 error if we ask for such a proof.
return true, nil
} else if first.TreeSize < second.TreeSize {
proof, err := s.logClient.GetConsistencyProof(int64(first.TreeSize), int64(second.TreeSize))
proof, err := s.logClient.GetConsistencyProof(context.Background(), int64(first.TreeSize), int64(second.TreeSize))
if err != nil {
return false, err
}
return VerifyConsistencyProof(proof, first, second), nil
} else if first.TreeSize > second.TreeSize {
proof, err := s.logClient.GetConsistencyProof(int64(second.TreeSize), int64(first.TreeSize))
proof, err := s.logClient.GetConsistencyProof(context.Background(), int64(second.TreeSize), int64(first.TreeSize))
if err != nil {
return false, err
}
@@ -241,7 +217,7 @@ func (s *Scanner) MakeCollapsedMerkleTree(sth *ct.SignedTreeHead) (*CollapsedMer
return &CollapsedMerkleTree{}, nil
}
entries, err := s.logClient.GetEntries(int64(sth.TreeSize-1), int64(sth.TreeSize-1))
entries, err := s.logClient.GetEntries(context.Background(), int64(sth.TreeSize-1), int64(sth.TreeSize-1))
if err != nil {
return nil, err
}
@@ -252,7 +228,7 @@ func (s *Scanner) MakeCollapsedMerkleTree(sth *ct.SignedTreeHead) (*CollapsedMer
var tree *CollapsedMerkleTree
if sth.TreeSize > 1 {
auditPath, _, err := s.logClient.GetAuditProof(leafHash, sth.TreeSize)
auditPath, _, err := s.logClient.GetAuditProof(context.Background(), leafHash, sth.TreeSize)
if err != nil {
return nil, err
}