diff --git a/itest/lspd_node.go b/itest/lspd_node.go index fa6beab..61482b8 100644 --- a/itest/lspd_node.go +++ b/itest/lspd_node.go @@ -21,7 +21,6 @@ import ( "github.com/decred/dcrd/dcrec/secp256k1/v4" ecies "github.com/ecies/go/v2" "github.com/golang/protobuf/proto" - "github.com/jackc/pgx/v5/pgxpool" "google.golang.org/grpc/metadata" ) @@ -193,14 +192,7 @@ func (l *lspBase) Initialize() error { return err } - pgxPool, err := pgxpool.New(l.harness.Ctx, l.postgresBackend.ConnectionString()) - if err != nil { - lntest.PerformCleanup(cleanups) - return fmt.Errorf("failed to connect to postgres: %w", err) - } - defer pgxPool.Close() - - _, err = pgxPool.Exec( + _, err = l.postgresBackend.Pool().Exec( l.harness.Ctx, `DELETE FROM new_channel_params`, ) @@ -209,7 +201,7 @@ func (l *lspBase) Initialize() error { return fmt.Errorf("failed to delete new_channel_params: %w", err) } - _, err = pgxPool.Exec( + _, err = l.postgresBackend.Pool().Exec( l.harness.Ctx, `INSERT INTO new_channel_params (validity, params, token) VALUES @@ -301,13 +293,7 @@ type FeeParamSetting struct { } func SetFeeParams(l LspNode, settings []*FeeParamSetting) error { - pgxPool, err := pgxpool.New(l.Harness().Ctx, l.PostgresBackend().ConnectionString()) - if err != nil { - return fmt.Errorf("failed to connect to postgres: %w", err) - } - defer pgxPool.Close() - - _, err = pgxPool.Exec(l.Harness().Ctx, "DELETE FROM new_channel_params") + _, err := l.PostgresBackend().Pool().Exec(l.Harness().Ctx, "DELETE FROM new_channel_params") if err != nil { return fmt.Errorf("failed to delete new_channel_params: %w", err) } @@ -333,7 +319,7 @@ func SetFeeParams(l LspNode, settings []*FeeParamSetting) error { first = false } query += `;` - _, err = pgxPool.Exec(l.Harness().Ctx, query) + _, err = l.PostgresBackend().Pool().Exec(l.Harness().Ctx, query) if err != nil { return fmt.Errorf("failed to insert new_channel_params: %w", err) } diff --git a/itest/lsps2_buy_test.go b/itest/lsps2_buy_test.go index eaf9841..10be44f 100644 --- a/itest/lsps2_buy_test.go +++ b/itest/lsps2_buy_test.go @@ -9,7 +9,6 @@ import ( "github.com/breez/lntest" "github.com/breez/lspd/lightning" "github.com/breez/lspd/lsps0" - "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" ) @@ -89,16 +88,10 @@ func testLsps2Buy(p *testParams) { err = json.Unmarshal(buyResp.Data, b) lntest.CheckError(p.t, err) - pgxPool, err := pgxpool.New(p.h.Ctx, p.lsp.PostgresBackend().ConnectionString()) - if err != nil { - p.h.T.Fatalf("Failed to connect to postgres backend: %v", err) - } - defer pgxPool.Close() - scid, err := lightning.NewShortChannelIDFromString(b.Result.Jit_channel_scid) lntest.CheckError(p.t, err) - rows, err := pgxPool.Query( + rows, err := p.lsp.PostgresBackend().Pool().Query( p.h.Ctx, `SELECT token FROM lsps2.buy_registrations diff --git a/itest/postgres.go b/itest/postgres.go index d0092cb..ac999ce 100644 --- a/itest/postgres.go +++ b/itest/postgres.go @@ -29,9 +29,14 @@ type PostgresContainer struct { logfile string isInitialized bool isStarted bool + pool *pgxpool.Pool mtx sync.Mutex } +func (p *PostgresContainer) Pool() *pgxpool.Pool { + return p.pool +} + func NewPostgresContainer(logfile string) (*PostgresContainer, error) { port, err := lntest.GetPort() if err != nil { @@ -91,9 +96,16 @@ HealthCheck: return fmt.Errorf("container '%s' unhealthy", c.id) case "healthy": for { - pgxPool, err := pgxpool.New(ctx, c.ConnectionString()) + if c.pool == nil { + c.pool, err = pgxpool.New(ctx, c.ConnectionString()) + if err != nil { + <-time.After(50 * time.Millisecond) + continue + } + } + + _, err = c.pool.Exec(ctx, "SELECT 1;") if err == nil { - pgxPool.Close() break HealthCheck } @@ -175,6 +187,11 @@ func (c *PostgresContainer) Stop(ctx context.Context) error { return nil } + if c.pool != nil { + c.pool.Close() + c.pool = nil + } + defer c.cli.Close() err := c.cli.ContainerStop(ctx, c.id, nil) c.isStarted = false @@ -246,19 +263,13 @@ func (c *PostgresContainer) RunMigrations(ctx context.Context, migrationDir stri sort.Strings(filenames) - pgxPool, err := pgxpool.New(ctx, c.ConnectionString()) - if err != nil { - return fmt.Errorf("failed to connect to postgres: %w", err) - } - defer pgxPool.Close() - for _, filename := range filenames { data, err := os.ReadFile(filename) if err != nil { return fmt.Errorf("failed to read migration file '%s': %w", filename, err) } - _, err = pgxPool.Exec(ctx, string(data)) + _, err = c.pool.Exec(ctx, string(data)) if err != nil { return fmt.Errorf("failed to execute migration file '%s': %w", filename, err) } diff --git a/itest/tag_test.go b/itest/tag_test.go index dc9f4ff..c4c0387 100644 --- a/itest/tag_test.go +++ b/itest/tag_test.go @@ -5,7 +5,6 @@ import ( "github.com/breez/lntest" lspd "github.com/breez/lspd/rpc" - "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" ) @@ -25,13 +24,7 @@ func registerPaymentWithTag(p *testParams) { Tag: expected, }, false) - pgxPool, err := pgxpool.New(p.h.Ctx, p.lsp.PostgresBackend().ConnectionString()) - if err != nil { - p.h.T.Fatalf("Failed to connect to postgres backend: %v", err) - } - defer pgxPool.Close() - - rows, err := pgxPool.Query( + rows, err := p.lsp.PostgresBackend().Pool().Query( p.h.Ctx, "SELECT tag FROM public.payments WHERE payment_hash=$1", i.PaymentHash,