diff --git a/controllers/addinvoice.ctrl.go b/controllers/addinvoice.ctrl.go index 47c2467..395d34c 100644 --- a/controllers/addinvoice.ctrl.go +++ b/controllers/addinvoice.ctrl.go @@ -7,7 +7,6 @@ import ( "github.com/bumi/lndhub.go/db/models" "github.com/bumi/lndhub.go/lib" - "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" "github.com/labstack/gommon/random" "github.com/sirupsen/logrus" @@ -19,9 +18,7 @@ type AddInvoiceController struct{} // AddInvoice : Add invoice Controller func (AddInvoiceController) AddInvoice(c echo.Context) error { ctx := c.(*lib.LndhubContext) - user := c.Get("user").(*jwt.Token) - claims := user.Claims.(jwt.MapClaims) - userID := claims["id"].(float64) + user := ctx.User type RequestBody struct { Amt uint `json:"amt" validate:"required"` @@ -47,7 +44,7 @@ func (AddInvoiceController) AddInvoice(c echo.Context) error { invoice := models.Invoice{ Type: "", - UserID: uint(userID), + UserID: user.ID, TransactionEntryID: 0, Amount: body.Amt, Memo: body.Memo, diff --git a/controllers/auth.ctrl.go b/controllers/auth.ctrl.go index bd8b957..7bc7d96 100644 --- a/controllers/auth.ctrl.go +++ b/controllers/auth.ctrl.go @@ -12,10 +12,12 @@ import ( ) // AuthController : AuthController struct -type AuthController struct{} +type AuthController struct { + JWTSecret []byte +} // Auth : Auth Controller -func (AuthController) Auth(c echo.Context) error { +func (ctrl AuthController) Auth(c echo.Context) error { ctx := c.(*lib.LndhubContext) type RequestBody struct { Login string `json:"login"` @@ -78,12 +80,12 @@ func (AuthController) Auth(c echo.Context) error { }) } - accessToken, err := tokens.GenerateAccessToken(&user) + accessToken, err := tokens.GenerateAccessToken(ctrl.JWTSecret, &user) if err != nil { return err } - refreshToken, err := tokens.GenerateRefreshToken(&user) + refreshToken, err := tokens.GenerateRefreshToken(ctrl.JWTSecret, &user) if err != nil { return err } diff --git a/db/models/invoice.go b/db/models/invoice.go index 7b093e9..3d28c73 100644 --- a/db/models/invoice.go +++ b/db/models/invoice.go @@ -11,7 +11,7 @@ import ( type Invoice struct { ID uint `json:"id" bun:",pk,autoincrement"` Type string `json:"type"` - UserID uint `json:"user_id"` + UserID int64 `json:"user_id"` TransactionEntryID uint `json:"transaction_entry_id"` Amount uint `json:"amount"` Memo string `json:"memo"` diff --git a/lib/context.go b/lib/context.go index 2becbab..20cc943 100644 --- a/lib/context.go +++ b/lib/context.go @@ -1,6 +1,7 @@ package lib import ( + "github.com/bumi/lndhub.go/db/models" "github.com/labstack/echo/v4" "github.com/uptrace/bun" ) @@ -8,5 +9,6 @@ import ( type LndhubContext struct { echo.Context - DB *bun.DB + DB *bun.DB + User *models.User } diff --git a/lib/tokens/jwt.go b/lib/tokens/jwt.go index 2be7408..93c86b7 100644 --- a/lib/tokens/jwt.go +++ b/lib/tokens/jwt.go @@ -1,57 +1,74 @@ package tokens import ( - "time" + "context" + "database/sql" + "errors" + "net/http" "github.com/bumi/lndhub.go/db/models" "github.com/bumi/lndhub.go/lib" "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "github.com/sirupsen/logrus" + "github.com/uptrace/bun" ) -type jwtCustomClaims struct { - ID int64 `json:"id"` - - jwt.StandardClaims -} - -func Middleware() echo.MiddlewareFunc { +func Middleware(secret []byte) echo.MiddlewareFunc { config := middleware.JWTConfig{ ContextKey: "UserJwt", - SigningKey: []byte("secret"), + SigningKey: secret, SuccessHandler: func(c echo.Context) { - ctx := c.(*lib.LndhubContext) - token := ctx.Get("UserJwt").(*jwt.Token) + token := c.Get("UserJwt").(*jwt.Token) claims := token.Claims.(jwt.MapClaims) - userId := claims["id"].(float64) - ctx.Set("UserId", userId) + c.Set("UserID", claims["id"]) + }, } return middleware.JWTWithConfig(config) } -// GenerateAccessToken : Generate Access Token -func GenerateAccessToken(u *models.User) (string, error) { - claims := &jwtCustomClaims{ - u.ID, - jwt.StandardClaims{ - // one week expiration - ExpiresAt: time.Now().Add(time.Hour * 27 * 7).Unix(), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) +func UserMiddleware(db *bun.DB) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.(lib.LndhubContext) + userId := c.Get("UserID").(int64) - t, err := token.SignedString([]byte("secret")) - if err != nil { - return "", err - } + var user models.User - return t, nil + err := db.NewSelect().Model(&user).Where("id = ?", userId).Scan(context.TODO()) + switch { + case errors.Is(err, sql.ErrNoRows): + return echo.NewHTTPError(http.StatusNotFound, "user with given ID is not found") + case err != nil: + logrus.Errorf("database error: %v", err) + return echo.NewHTTPError(http.StatusInternalServerError) + } + + ctx.User = &user + + return nil + } + } } -// GenerateRefreshToken : Generate Refresh Token -func GenerateRefreshToken(u *models.User) (string, error) { +// GenerateAccessToken : Generate Access Token +func GenerateAccessToken(secret []byte, u *models.User) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "id": u.ID, + }) + + t, err := token.SignedString([]byte("secret")) + if err != nil { + return "", err + } + + return t, nil +} + +// GenerateRefreshToken : Generate Refresh Token +func GenerateRefreshToken(secret []byte, u *models.User) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "id": u.ID, }) diff --git a/main.go b/main.go index 546706a..9fbfdad 100644 --- a/main.go +++ b/main.go @@ -28,6 +28,7 @@ type Config struct { DatabaseUri string `envconfig:"DATABASE_URI" required:"true"` SentryDSN string `envconfig:"SENTRY_DSN"` LogFilePath string `envconfig:"LOG_FILE_PATH"` + JWTSecret []byte `envconfig:"JWT_SECRET" default:"secret"` } func main() { @@ -106,10 +107,10 @@ func main() { e.Use(middleware.BodyLimit("250K")) e.Use(middleware.RateLimiter(middleware.NewRateLimiterMemoryStore(20))) - e.POST("/auth", controllers.AuthController{}.Auth) + e.POST("/auth", controllers.AuthController{JWTSecret: c.JWTSecret}.Auth) e.POST("/create", controllers.CreateUserController{}.CreateUser) - secured := e.Group("", tokens.Middleware()) + secured := e.Group("", tokens.Middleware(c.JWTSecret), tokens.UserMiddleware(dbConn)) secured.POST("/addinvoice", controllers.AddInvoiceController{}.AddInvoice) secured.POST("/payinvoice", controllers.PayInvoiceController{}.PayInvoice) secured.GET("/gettxs", controllers.GetTXSController{}.GetTXS)