From e33693398ece3c87173f652771270603d42cffcc Mon Sep 17 00:00:00 2001 From: Michael Bumann Date: Tue, 9 Jan 2024 19:38:01 +0200 Subject: [PATCH] Soft delete users (#476) * Update Makefile * Optionally load test DB from env variable * Add option to soft-delete a user This allows users to be marked as deleted. An additional middleware checks if a user is deleted or deactivated and rejects requests for those as StatusUnauthorized. note: the middelware adds an additional DB query to load the user. --- Makefile | 5 +- cmd/server/main.go | 4 +- controllers_v2/update.ctrl.go | 5 +- ...20240103130000_add_deleted_to_users.up.sql | 1 + db/models/user.go | 1 + integration_tests/deactivated_deleted_test.go | 101 ++++++++++++++++++ integration_tests/util.go | 5 +- lib/service/service.go | 31 +++++- lib/service/user.go | 10 +- lib/tokens/jwt.go | 14 +-- 10 files changed, 163 insertions(+), 14 deletions(-) create mode 100644 db/migrations/20240103130000_add_deleted_to_users.up.sql create mode 100644 integration_tests/deactivated_deleted_test.go diff --git a/Makefile b/Makefile index 0ca7fd2..76d7d2e 100644 --- a/Makefile +++ b/Makefile @@ -2,4 +2,7 @@ cp .env_example .env build: - CGO_ENABLED=0 go build -o lndhub + CGO_ENABLED=0 go build -o lndhub ./cmd/server + +test: + go test -p 1 -v -covermode=atomic -coverprofile=coverage.out -cover -coverpkg=./... ./... diff --git a/cmd/server/main.go b/cmd/server/main.go index fa5a61a..6de7eb7 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -153,8 +153,8 @@ func main() { logMw := transport.CreateLoggingMiddleware(logger) // strict rate limit for requests for sending payments strictRateLimitMiddleware := transport.CreateRateLimitMiddleware(c.StrictRateLimit, c.BurstRateLimit) - secured := e.Group("", tokens.Middleware(c.JWTSecret), logMw) - securedWithStrictRateLimit := e.Group("", tokens.Middleware(c.JWTSecret), strictRateLimitMiddleware, logMw) + secured := e.Group("", tokens.Middleware(c.JWTSecret), svc.ValidateUserMiddleware(), logMw) + securedWithStrictRateLimit := e.Group("", tokens.Middleware(c.JWTSecret), svc.ValidateUserMiddleware(), strictRateLimitMiddleware, logMw) transport.RegisterLegacyEndpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken), logMw) transport.RegisterV2Endpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken), logMw) diff --git a/controllers_v2/update.ctrl.go b/controllers_v2/update.ctrl.go index 614eb35..4bc29c7 100644 --- a/controllers_v2/update.ctrl.go +++ b/controllers_v2/update.ctrl.go @@ -20,12 +20,14 @@ func NewUpdateUserController(svc *service.LndhubService) *UpdateUserController { type UpdateUserResponseBody struct { Login string `json:"login"` Deactivated bool `json:"deactivated"` + Deleted bool `json:"deleted"` ID int64 `json:"id"` } type UpdateUserRequestBody struct { Login *string `json:"login,omitempty"` Password *string `json:"password,omitempty"` Deactivated *bool `json:"deactivated,omitempty"` + Deleted *bool `json:"deleted,omitempty"` ID int64 `json:"id" validate:"required"` } @@ -52,7 +54,7 @@ func (controller *UpdateUserController) UpdateUser(c echo.Context) error { c.Logger().Errorf("Invalid update user request body error: %v", err) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - user, err := controller.svc.UpdateUser(c.Request().Context(), body.ID, body.Login, body.Password, body.Deactivated) + user, err := controller.svc.UpdateUser(c.Request().Context(), body.ID, body.Login, body.Password, body.Deactivated, body.Deleted) if err != nil { c.Logger().Errorf("Failed to update user: %v", err) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) @@ -61,6 +63,7 @@ func (controller *UpdateUserController) UpdateUser(c echo.Context) error { var ResponseBody UpdateUserResponseBody ResponseBody.Login = user.Login ResponseBody.Deactivated = user.Deactivated + ResponseBody.Deleted = user.Deleted ResponseBody.ID = user.ID return c.JSON(http.StatusOK, &ResponseBody) diff --git a/db/migrations/20240103130000_add_deleted_to_users.up.sql b/db/migrations/20240103130000_add_deleted_to_users.up.sql new file mode 100644 index 0000000..d8fee07 --- /dev/null +++ b/db/migrations/20240103130000_add_deleted_to_users.up.sql @@ -0,0 +1 @@ +alter table users add column deleted boolean default false; diff --git a/db/models/user.go b/db/models/user.go index 98939f0..f52d4d8 100644 --- a/db/models/user.go +++ b/db/models/user.go @@ -19,6 +19,7 @@ type User struct { Invoices []*Invoice `bun:"rel:has-many,join:id=user_id"` Accounts []*Account `bun:"rel:has-many,join:id=user_id"` Deactivated bool + Deleted bool } func (u *User) BeforeAppendModel(ctx context.Context, query bun.Query) error { diff --git a/integration_tests/deactivated_deleted_test.go b/integration_tests/deactivated_deleted_test.go new file mode 100644 index 0000000..8966a15 --- /dev/null +++ b/integration_tests/deactivated_deleted_test.go @@ -0,0 +1,101 @@ +package integration_tests + +import ( + "context" + "fmt" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/getAlby/lndhub.go/controllers" + "github.com/getAlby/lndhub.go/lib" + "github.com/getAlby/lndhub.go/lib/responses" + "github.com/getAlby/lndhub.go/lib/service" + "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/go-playground/validator/v10" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type ValidateUserSuite struct { + TestSuite + Service *service.LndhubService + userLogin ExpectedCreateUserResponseBody + userToken string + mockLND *MockLND + invoiceUpdateSubCancelFn context.CancelFunc +} + +func (suite *ValidateUserSuite) SetupSuite() { + mockLND := newDefaultMockLND() + svc, err := LndHubTestServiceInit(mockLND) + if err != nil { + log.Fatalf("Error initializing test service: %v", err) + } + users, userTokens, err := createUsers(svc, 1) + if err != nil { + log.Fatalf("Error creating test users %v", err) + } + suite.Service = svc + suite.mockLND = mockLND + e := echo.New() + + e.HTTPErrorHandler = responses.HTTPErrorHandler + e.Validator = &lib.CustomValidator{Validator: validator.New()} + suite.echo = e + suite.echo.Use(tokens.Middleware([]byte(suite.Service.Config.JWTSecret))) + suite.echo.Use(svc.ValidateUserMiddleware()) + suite.echo.GET("/gettxs", controllers.NewGetTXSController(suite.Service).GetTXS) + suite.echo.GET("/getuserinvoices", controllers.NewGetTXSController(svc).GetUserInvoices) + suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.Service).AddInvoice) + suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.Service).PayInvoice) + + assert.Equal(suite.T(), 1, len(users)) + suite.userLogin = users[0] + suite.userToken = userTokens[0] +} + +func (suite *ValidateUserSuite) TearDownSuite() { +} + +func (suite *ValidateUserSuite) TestDeletedUserValidation() { + _, err := suite.Service.DB.NewUpdate().Table("users").Set("deleted = ?", true).Where("login = ?", suite.userLogin.Login).Exec(context.TODO()) + assert.NoError(suite.T(), err) + req := httptest.NewRequest(http.MethodGet, "/gettxs", nil) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken)) + rec := httptest.NewRecorder() + suite.echo.ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusUnauthorized, rec.Code) + + _, err = suite.Service.DB.NewUpdate().Table("users").Set("deleted = ?", false).Where("login = ?", suite.userLogin.Login).Exec(context.TODO()) + assert.NoError(suite.T(), err) + req = httptest.NewRequest(http.MethodGet, "/gettxs", nil) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken)) + rec = httptest.NewRecorder() + suite.echo.ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusOK, rec.Code) +} + +func (suite *ValidateUserSuite) TestDeactivatedUserValidation() { + _, err := suite.Service.DB.NewUpdate().Table("users").Set("deactivated = ?, deleted = false", true).Where("login = ?", suite.userLogin.Login).Exec(context.TODO()) + assert.NoError(suite.T(), err) + req := httptest.NewRequest(http.MethodGet, "/gettxs", nil) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken)) + rec := httptest.NewRecorder() + suite.echo.ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusUnauthorized, rec.Code) + + _, err = suite.Service.DB.NewUpdate().Table("users").Set("deactivated = ?, deleted = false", false).Where("login = ?", suite.userLogin.Login).Exec(context.TODO()) + assert.NoError(suite.T(), err) + req = httptest.NewRequest(http.MethodGet, "/gettxs", nil) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken)) + rec = httptest.NewRecorder() + suite.echo.ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusOK, rec.Code) +} + +func TestValidateUserSuite(t *testing.T) { + suite.Run(t, new(ValidateUserSuite)) +} diff --git a/integration_tests/util.go b/integration_tests/util.go index c353abe..4d4fb48 100644 --- a/integration_tests/util.go +++ b/integration_tests/util.go @@ -47,7 +47,10 @@ const ( ) func LndHubTestServiceInit(lndClientMock lnd.LightningClientWrapper) (svc *service.LndhubService, err error) { - dbUri := "postgresql://user:password@localhost/lndhub?sslmode=disable" + dbUri, ok := os.LookupEnv("DATABASE_URI") + if !ok { + dbUri = "postgresql://user:password@localhost/lndhub?sslmode=disable" + } c := &service.Config{ DatabaseUri: dbUri, DatabaseMaxConns: 1, diff --git a/lib/service/service.go b/lib/service/service.go index b4278ff..2e0413a 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3,6 +3,7 @@ package service import ( "context" "fmt" + "net/http" "strconv" "github.com/getAlby/lndhub.go/rabbitmq" @@ -11,6 +12,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/tokens" "github.com/getAlby/lndhub.go/lnd" + "github.com/labstack/echo/v4" "github.com/labstack/gommon/random" "github.com/uptrace/bun" "github.com/ziflex/lecho/v3" @@ -58,7 +60,7 @@ func (svc *LndhubService) GenerateToken(ctx context.Context, login, password, in } } - if user.Deactivated { + if user.Deactivated || user.Deleted { return "", "", fmt.Errorf(responses.AccountDeactivatedError.Message) } @@ -88,3 +90,30 @@ func (svc *LndhubService) ParseInt(value interface{}) (int64, error) { return 0, fmt.Errorf("conversion to int from %T not supported", v) } } + +func (svc *LndhubService) ValidateUserMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + userId := c.Get("UserID").(int64) + if userId == 0 { + return echo.ErrUnauthorized + } + user, err := svc.FindUser(c.Request().Context(), userId) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, echo.Map{ + "error": true, + "code": 1, + "message": "bad auth", + }) + } + if user.Deactivated || user.Deleted { + return echo.NewHTTPError(http.StatusUnauthorized, echo.Map{ + "error": true, + "code": 1, + "message": "bad auth", + }) + } + return next(c) + } + } +} diff --git a/lib/service/user.go b/lib/service/user.go index da66d62..a41cdee 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -78,7 +78,7 @@ func (svc *LndhubService) CreateUser(ctx context.Context, login string, password return user, err } -func (svc *LndhubService) UpdateUser(ctx context.Context, userId int64, login *string, password *string, deactivated *bool) (user *models.User, err error) { +func (svc *LndhubService) UpdateUser(ctx context.Context, userId int64, login *string, password *string, deactivated *bool, deleted *bool) (user *models.User, err error) { user, err = svc.FindUser(ctx, userId) if err != nil { return nil, err @@ -99,6 +99,14 @@ func (svc *LndhubService) UpdateUser(ctx context.Context, userId int64, login *s if deactivated != nil { user.Deactivated = *deactivated } + // if a user gets deleted we mark it as deactivated and deleted + // un-deleting it is not supported currently + if deleted != nil { + if *deleted == true { + user.Deactivated = true + user.Deleted = true + } + } _, err = svc.DB.NewUpdate().Model(user).WherePK().Exec(ctx) if err != nil { return nil, err diff --git a/lib/tokens/jwt.go b/lib/tokens/jwt.go index 20709f3..b95c88c 100644 --- a/lib/tokens/jwt.go +++ b/lib/tokens/jwt.go @@ -15,13 +15,13 @@ import ( ) type jwtCustomClaims struct { - ID int64 `json:"id"` - IsRefresh bool `json:"isRefresh"` - MaxSendVolume int64 `json:"maxSendVolume"` - MaxSendAmount int64 `json:"maxSendAmount"` - MaxReceiveVolume int64 `json:"maxReceiveVolume"` - MaxReceiveAmount int64 `json:"maxReceiveAmount"` - MaxAccountBalance int64 `json:"maxAccountBalance"` + ID int64 `json:"id"` + IsRefresh bool `json:"isRefresh"` + MaxSendVolume int64 `json:"maxSendVolume"` + MaxSendAmount int64 `json:"maxSendAmount"` + MaxReceiveVolume int64 `json:"maxReceiveVolume"` + MaxReceiveAmount int64 `json:"maxReceiveAmount"` + MaxAccountBalance int64 `json:"maxAccountBalance"` jwt.StandardClaims }