From 61ed16163d772e181e84a72b6fd48595032d8145 Mon Sep 17 00:00:00 2001 From: nochill Date: Wed, 15 Mar 2023 15:00:36 +0700 Subject: [PATCH] add middlweare and authorization --- api/middleware.go | 51 ++++++++++++++++++++++++ api/middleware_test.go | 80 ++++++++++++++++++++++++++++++++++++++ api/product.go | 40 ++++++++++++++++++- api/product_test.go | 55 ++++++++++++++++++++++++-- api/server.go | 12 +++--- api/user.go | 1 + api/user_test.go | 4 +- db/query/products.sql | 5 ++- db/sqlc/products.sql.go | 12 +++--- db/sqlc/products_test.go | 11 ++++-- go.mod | 1 - go.sum | 2 - token/jwt_maker.go | 58 --------------------------- token/jwt_maker_test.go | 65 ------------------------------- token/maker.go | 2 +- token/paseto_maker.go | 4 +- token/paseto_maker_test.go | 7 +++- token/payload.go | 20 +++++----- 18 files changed, 267 insertions(+), 163 deletions(-) create mode 100644 api/middleware.go create mode 100644 api/middleware_test.go delete mode 100644 token/jwt_maker.go delete mode 100644 token/jwt_maker_test.go diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..09f0f55 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,51 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "strings" + + "git.nochill.in/nochill/naice_pos/token" + "github.com/gin-gonic/gin" +) + +const ( + authorizationHeaderKey = "authorization" + authorizationTypeBearer = "bearer" + authorizationPayloadKey = "authorization_payload" +) + +func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc { + return func(ctx *gin.Context) { + authorizationHeader := ctx.GetHeader(authorizationHeaderKey) + if len(authorizationHeader) == 0 { + err := errors.New("authorization header is not provided") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + fields := strings.Fields(authorizationHeader) + if len(fields) < 2 { + err := errors.New("Invalid authorization header format") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + authorizationType := strings.ToLower(fields[0]) + if authorizationType != authorizationTypeBearer { + err := fmt.Errorf("Authorization only accept bearer type") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + } + + accessToken := fields[1] + payload, err := tokenMaker.VerifyToken(accessToken) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + ctx.Set(authorizationPayloadKey, payload) + ctx.Next() + } +} diff --git a/api/middleware_test.go b/api/middleware_test.go new file mode 100644 index 0000000..ca6ffbf --- /dev/null +++ b/api/middleware_test.go @@ -0,0 +1,80 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.nochill.in/nochill/naice_pos/token" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func addAuthorization( + t *testing.T, + request *http.Request, + tokenMaker token.Maker, + authorizationType string, + email string, + merchantID string, + duration time.Duration, +) { + token, err := tokenMaker.CreateToken(email, merchantID, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + + authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) + request.Header.Set(authorizationHeaderKey, authorizationHeader) +} + +func TestAuthMiddleware(t *testing.T) { + testCases := []struct { + name string + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) + checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + name: "OK", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", uuid.New().String(), time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + { + name: "NoAuthorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {}, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + server := newTestServer(t, nil) + authPath := "/user/login" + server.router.GET( + authPath, + authMiddleware(server.tokenMaker), + func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{}) + }, + ) + + recorder := httptest.NewRecorder() + request, err := http.NewRequest(http.MethodGet, authPath, nil) + require.NoError(t, err) + + tc.setupAuth(t, request, server.tokenMaker) + server.router.ServeHTTP(recorder, request) + tc.checkResponse(t, recorder) + }) + } +} diff --git a/api/product.go b/api/product.go index a51e00e..66d55a5 100644 --- a/api/product.go +++ b/api/product.go @@ -2,10 +2,12 @@ package api import ( "database/sql" + "errors" "net/http" "time" db "git.nochill.in/nochill/naice_pos/db/sqlc" + "git.nochill.in/nochill/naice_pos/token" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -26,8 +28,9 @@ func (server *Server) createProduct(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) arg := db.CreateProductParams{ - MerchantID: req.MerchantID, + MerchantID: authPayload.MerchantID, Name: req.Name, SellingPrice: req.SellingPrice, PurchasePrice: req.PurchasePrice, @@ -64,9 +67,44 @@ func (server *Server) getProduct(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if product.MerchantID != authPayload.MerchantID { + err := errors.New("Product doesn't belong to the user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + ctx.JSON(http.StatusOK, product) } +type listProductsRequest struct { + PageID int32 `form:"page_id" binding:"required,min=1"` + PageSize int32 `form:"page_size" binding:"required,min=5"` +} + +func (server *Server) listProducts(ctx *gin.Context) { + var req listProductsRequest + if err := ctx.ShouldBindQuery(&req); err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + arg := db.ListProductsParams{ + MerchantID: authPayload.MerchantID, + Limit: req.PageSize, + Offset: (req.PageID - 1) * req.PageSize, + } + + products, err := server.store.ListProducts(ctx, arg) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + ctx.JSON(http.StatusOK, products) +} + type updateProductRequest struct { ProductID uuid.UUID `json:"product_id" binding:"required"` Name string `json:"name" binding:"required"` diff --git a/api/product_test.go b/api/product_test.go index 0bceb37..274e00c 100644 --- a/api/product_test.go +++ b/api/product_test.go @@ -9,9 +9,11 @@ import ( "net/http" "net/http/httptest" "testing" + "time" mockdb "git.nochill.in/nochill/naice_pos/db/mock" db "git.nochill.in/nochill/naice_pos/db/sqlc" + "git.nochill.in/nochill/naice_pos/token" "git.nochill.in/nochill/naice_pos/util" "github.com/golang/mock/gomock" "github.com/google/uuid" @@ -19,12 +21,14 @@ import ( ) func TestGetProductApi(t *testing.T) { - product := randomProduct() + merchantID := "f9ca13cf-8ab3-4ee3-9530-521ae505caa2" + product := randomProduct(merchantID) testCases := []struct { name string productID string buildStubs func(store *mockdb.MockStore) + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) }{ { @@ -36,6 +40,9 @@ func TestGetProductApi(t *testing.T) { Times(1). Return(product, nil) }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute) + }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusOK, recorder.Code) requireBodyMatchAccount(t, recorder.Body, product) @@ -50,10 +57,43 @@ func TestGetProductApi(t *testing.T) { Times(1). Return(db.Product{}, sql.ErrNoRows) }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute) + }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusNotFound, recorder.Code) }, }, + { + name: "Unauthorized", + productID: product.ID.String(), + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetProduct(gomock.Any(), gomock.Eq(product.ID)). + Times(1). + Return(db.Product{}, sql.ErrNoRows) + }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", uuid.New().String(), time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusNotFound, recorder.Code) + }, + }, + { + name: "NoAuthorization", + productID: product.ID.String(), + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetProduct(gomock.Any(), gomock.Any()). + Times(0) + }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, { name: "Internal Error", productID: product.ID.String(), @@ -63,6 +103,9 @@ func TestGetProductApi(t *testing.T) { Times(1). Return(db.Product{}, sql.ErrConnDone) }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute) + }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusInternalServerError, recorder.Code) }, @@ -75,6 +118,9 @@ func TestGetProductApi(t *testing.T) { GetProduct(gomock.Any(), gomock.Any()). Times(0) }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute) + }, checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { require.Equal(t, http.StatusBadRequest, recorder.Code) }, @@ -94,11 +140,12 @@ func TestGetProductApi(t *testing.T) { server := newTestServer(t, store) recorder := httptest.NewRecorder() - url := fmt.Sprintf("/product/%s", tc.productID) + url := fmt.Sprintf("/api/product/%s", tc.productID) request, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(t, recorder) }) @@ -106,10 +153,10 @@ func TestGetProductApi(t *testing.T) { } } -func randomProduct() db.Product { +func randomProduct(merchantID string) db.Product { return db.Product{ ID: uuid.New(), - MerchantID: uuid.MustParse("a848090f-0409-4386-9caa-929ae6874dbb"), + MerchantID: uuid.MustParse("f9ca13cf-8ab3-4ee3-9530-521ae505caa2"), Name: util.RandomString(5), SellingPrice: util.RandomFloat(1000, 99999), PurchasePrice: util.RandomFloat(999, 9999), diff --git a/api/server.go b/api/server.go index 8349dad..b915c73 100644 --- a/api/server.go +++ b/api/server.go @@ -38,13 +38,15 @@ func (server *Server) getRoutes() { router.POST("/user/merchants", server.createUserMerchant) - router.POST("/products", server.createProduct) - router.PATCH("/products", server.updateProduct) - router.GET("/product/:id", server.getProduct) + apiRoutes := router.Group("/api").Use(authMiddleware(server.tokenMaker)) - router.POST("/suppliers", server.createSupplier) + apiRoutes.POST("/products", server.createProduct) + apiRoutes.PATCH("/products", server.updateProduct) + apiRoutes.GET("/product/:id", server.getProduct) - router.POST("/purchase-products", server.createPurchase) + apiRoutes.POST("/suppliers", server.createSupplier) + + apiRoutes.POST("/purchase-products", server.createPurchase) server.router = router } diff --git a/api/user.go b/api/user.go index ffa3a67..dd7dba1 100644 --- a/api/user.go +++ b/api/user.go @@ -127,6 +127,7 @@ func (server *Server) loginUser(ctx *gin.Context) { accessToken, err := server.tokenMaker.CreateToken( user.Email, + outlet.ID.String(), server.config.TokenDuration, ) diff --git a/api/user_test.go b/api/user_test.go index 9bb5204..4fff03b 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -47,7 +47,7 @@ func EqCreateUserMerchant(arg db.UserMerchantTxParams, password string) gomock.M } func TestCreateUserMerchantAPI(t *testing.T) { - userMerchant, password := RandomUser(t) + userMerchant, password := randomUser(t) var userMerchantResult db.UserMerchantTxResult // var userMerchantResponse createUserMerchantResponse testCases := []struct { @@ -109,7 +109,7 @@ func TestCreateUserMerchantAPI(t *testing.T) { } } -func RandomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) { +func randomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) { password = util.RandomString(6) hashedPassword, err := util.HashPassword(password) require.NoError(t, err) diff --git a/db/query/products.sql b/db/query/products.sql index 959c082..11b0f01 100644 --- a/db/query/products.sql +++ b/db/query/products.sql @@ -27,9 +27,10 @@ WHERE id = $1; -- name: ListProducts :many SELECT * FROM products +WHERE merchant_id = $1 ORDER BY index_id -LIMIT $1 -OFFSET $2; +LIMIT $2 +OFFSET $3; -- name: UpdateProduct :one UPDATE products diff --git a/db/sqlc/products.sql.go b/db/sqlc/products.sql.go index a19dc97..e7ea1ce 100644 --- a/db/sqlc/products.sql.go +++ b/db/sqlc/products.sql.go @@ -113,18 +113,20 @@ func (q *Queries) GetStockForUpdateStock(ctx context.Context, id uuid.UUID) (Pro const listProducts = `-- name: ListProducts :many SELECT id, merchant_id, index_id, name, selling_price, purchase_price, stock, created_at, updated_at FROM products +WHERE merchant_id = $1 ORDER BY index_id -LIMIT $1 -OFFSET $2 +LIMIT $2 +OFFSET $3 ` type ListProductsParams struct { - Limit int32 `json:"limit"` - Offset int32 `json:"offset"` + MerchantID uuid.UUID `json:"merchant_id"` + Limit int32 `json:"limit"` + Offset int32 `json:"offset"` } func (q *Queries) ListProducts(ctx context.Context, arg ListProductsParams) ([]Product, error) { - rows, err := q.db.QueryContext(ctx, listProducts, arg.Limit, arg.Offset) + rows, err := q.db.QueryContext(ctx, listProducts, arg.MerchantID, arg.Limit, arg.Offset) if err != nil { return nil, err } diff --git a/db/sqlc/products_test.go b/db/sqlc/products_test.go index 906d4a4..6098649 100644 --- a/db/sqlc/products_test.go +++ b/db/sqlc/products_test.go @@ -106,21 +106,24 @@ func TestDeleteProduct(t *testing.T) { } func TestGetProducts(t *testing.T) { + var lastProduct Product for i := 0; i < 6; i++ { - createRandomProduct(t) + lastProduct, _ = createRandomProduct(t) } arg := ListProductsParams{ - Limit: 5, - Offset: 1, + MerchantID: lastProduct.MerchantID, + Limit: 5, + Offset: 0, } products, err := testQueries.ListProducts(context.Background(), arg) require.NoError(t, err) - require.Len(t, products, 5) + require.NotEmpty(t, products) for _, product := range products { require.NotEmpty(t, product) + require.Equal(t, lastProduct.MerchantID, product.MerchantID) } } diff --git a/go.mod b/go.mod index 63a418a..a341c0f 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.20 require ( github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29 - github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/gin-gonic/gin v1.9.0 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 diff --git a/go.sum b/go.sum index e2d6e1d..afba6a4 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,6 @@ github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnht github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= diff --git a/token/jwt_maker.go b/token/jwt_maker.go deleted file mode 100644 index c6104d4..0000000 --- a/token/jwt_maker.go +++ /dev/null @@ -1,58 +0,0 @@ -package token - -import ( - "errors" - "fmt" - "time" - - "github.com/dgrijalva/jwt-go" -) - -const minSecretKeySize = 32 - -type JWTMaker struct { - secretKey string -} - -func NewJWTMaker(secretKey string) (Maker, error) { - if len(secretKey) < minSecretKeySize { - return nil, fmt.Errorf("Invalid key: must be at least %d characters", minSecretKeySize) - } - - return &JWTMaker{secretKey}, nil -} - -func (maker *JWTMaker) CreateToken(email string, duration time.Duration) (string, error) { - payload, err := NewPayload(email, duration) - if err != nil { - return "", err - } - - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) - return jwtToken.SignedString([]byte(maker.secretKey)) -} - -func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) { - keyFunc := func(token *jwt.Token) (interface{}, error) { - _, ok := token.Method.(*jwt.SigningMethodHMAC) - if !ok { - return nil, ErrInvalidToken - } - return []byte(maker.secretKey), nil - } - jwtToken, err := jwt.ParseWithClaims(token, &Payload{}, keyFunc) - if err != nil { - verr, ok := err.(*jwt.ValidationError) - if ok && errors.Is(verr.Inner, ErrExpiredToken) { - return nil, ErrExpiredToken - } - return nil, ErrInvalidToken - } - - payload, ok := jwtToken.Claims.(*Payload) - if !ok { - return nil, ErrInvalidToken - } - - return payload, nil -} diff --git a/token/jwt_maker_test.go b/token/jwt_maker_test.go deleted file mode 100644 index 9b1faa1..0000000 --- a/token/jwt_maker_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package token - -import ( - "testing" - "time" - - "git.nochill.in/nochill/naice_pos/util" - "github.com/dgrijalva/jwt-go" - "github.com/stretchr/testify/require" -) - -func TestJWTMaker(t *testing.T) { - maker, err := NewJWTMaker(util.RandomString(32)) - require.NoError(t, err) - - email := util.RandomEmail() - duration := time.Minute - - issuedAt := time.Now() - expiredAt := issuedAt.Add(duration) - - token, err := maker.CreateToken(email, duration) - require.NoError(t, err) - require.NotEmpty(t, token) - - payload, err := maker.VerifyToken(token) - require.NoError(t, err) - require.NotEmpty(t, payload) - - require.NotZero(t, payload.ID) - require.Equal(t, email, payload.Email) - require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) - require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) -} - -func TestExpiredToken(t *testing.T) { - maker, err := NewJWTMaker(util.RandomString(32)) - require.NoError(t, err) - - token, err := maker.CreateToken(util.RandomEmail(), -time.Minute) - require.NoError(t, err) - require.NotEmpty(t, token) - - payload, err := maker.VerifyToken(token) - require.Error(t, err) - require.EqualError(t, err, ErrExpiredToken.Error()) - require.Nil(t, payload) -} - -func TestInvalidJWTTokenAlgNone(t *testing.T) { - payload, err := NewPayload(util.RandomEmail(), time.Minute) - require.NoError(t, err) - - jwtToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload) - token, err := jwtToken.SignedString(jwt.UnsafeAllowNoneSignatureType) - require.NoError(t, err) - - maker, err := NewJWTMaker(util.RandomString(32)) - require.NoError(t, err) - - payload, err = maker.VerifyToken(token) - require.Error(t, err) - require.EqualError(t, err, ErrInvalidToken.Error()) - require.Nil(t, payload) -} diff --git a/token/maker.go b/token/maker.go index 85313f5..17ae30e 100644 --- a/token/maker.go +++ b/token/maker.go @@ -5,6 +5,6 @@ import ( ) type Maker interface { - CreateToken(email string, duration time.Duration) (string, error) + CreateToken(email string, merchantID string, duration time.Duration) (string, error) VerifyToken(token string) (*Payload, error) } diff --git a/token/paseto_maker.go b/token/paseto_maker.go index 0fea221..5f2faac 100644 --- a/token/paseto_maker.go +++ b/token/paseto_maker.go @@ -26,8 +26,8 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) { return maker, nil } -func (maker *PasetoMaker) CreateToken(email string, duration time.Duration) (string, error) { - payload, err := NewPayload(email, duration) +func (maker *PasetoMaker) CreateToken(email string, merchant_id string, duration time.Duration) (string, error) { + payload, err := NewPayload(email, merchant_id, duration) if err != nil { return "", err } diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go index 7e1c267..d8c6bcc 100644 --- a/token/paseto_maker_test.go +++ b/token/paseto_maker_test.go @@ -5,9 +5,12 @@ import ( "time" "git.nochill.in/nochill/naice_pos/util" + "github.com/google/uuid" "github.com/stretchr/testify/require" ) +var merchantID = uuid.New().String() + func TestPasetoMaker(t *testing.T) { maker, err := NewPasetoMaker(util.RandomString(32)) require.NoError(t, err) @@ -18,7 +21,7 @@ func TestPasetoMaker(t *testing.T) { issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, err := maker.CreateToken(email, duration) + token, err := maker.CreateToken(email, merchantID, duration) require.NoError(t, err) require.NotEmpty(t, token) @@ -36,7 +39,7 @@ func TestExpiredPasetoToken(t *testing.T) { maker, err := NewPasetoMaker(util.RandomString(32)) require.NoError(t, err) - token, err := maker.CreateToken(util.RandomEmail(), -time.Minute) + token, err := maker.CreateToken(util.RandomEmail(), merchantID, -time.Minute) require.NoError(t, err) require.NotEmpty(t, token) diff --git a/token/payload.go b/token/payload.go index 1151fbc..e456692 100644 --- a/token/payload.go +++ b/token/payload.go @@ -13,23 +13,25 @@ var ( ) type Payload struct { - ID uuid.UUID `json:"id"` - Email string `json:"email"` - IssuedAt time.Time `json:"issued_at"` - ExpiredAt time.Time `json:"expired_at"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` + MerchantID uuid.UUID `json:"merchant_id"` + IssuedAt time.Time `json:"issued_at"` + ExpiredAt time.Time `json:"expired_at"` } -func NewPayload(email string, duration time.Duration) (*Payload, error) { +func NewPayload(email string, merchant_id string, duration time.Duration) (*Payload, error) { tokenID, err := uuid.NewRandom() if err != nil { return nil, err } payload := &Payload{ - ID: tokenID, - Email: email, - IssuedAt: time.Now(), - ExpiredAt: time.Now().Add(duration), + ID: tokenID, + Email: email, + MerchantID: uuid.MustParse(merchant_id), + IssuedAt: time.Now(), + ExpiredAt: time.Now().Add(duration), } return payload, nil }