add middlweare and authorization
This commit is contained in:
parent
4738c8c590
commit
61ed16163d
51
api/middleware.go
Normal file
51
api/middleware.go
Normal file
@ -0,0 +1,51 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.nochill.in/nochill/naice_pos/token"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationHeaderKey = "authorization"
|
||||
authorizationTypeBearer = "bearer"
|
||||
authorizationPayloadKey = "authorization_payload"
|
||||
)
|
||||
|
||||
func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc {
|
||||
return func(ctx *gin.Context) {
|
||||
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)
|
||||
if len(authorizationHeader) == 0 {
|
||||
err := errors.New("authorization header is not provided")
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
fields := strings.Fields(authorizationHeader)
|
||||
if len(fields) < 2 {
|
||||
err := errors.New("Invalid authorization header format")
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
authorizationType := strings.ToLower(fields[0])
|
||||
if authorizationType != authorizationTypeBearer {
|
||||
err := fmt.Errorf("Authorization only accept bearer type")
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
}
|
||||
|
||||
accessToken := fields[1]
|
||||
payload, err := tokenMaker.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Set(authorizationPayloadKey, payload)
|
||||
ctx.Next()
|
||||
}
|
||||
}
|
80
api/middleware_test.go
Normal file
80
api/middleware_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.nochill.in/nochill/naice_pos/token"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func addAuthorization(
|
||||
t *testing.T,
|
||||
request *http.Request,
|
||||
tokenMaker token.Maker,
|
||||
authorizationType string,
|
||||
email string,
|
||||
merchantID string,
|
||||
duration time.Duration,
|
||||
) {
|
||||
token, err := tokenMaker.CreateToken(email, merchantID, duration)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
|
||||
request.Header.Set(authorizationHeaderKey, authorizationHeader)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
|
||||
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "OK",
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
|
||||
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", uuid.New().String(), time.Minute)
|
||||
},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NoAuthorization",
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i := range testCases {
|
||||
tc := testCases[i]
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := newTestServer(t, nil)
|
||||
authPath := "/user/login"
|
||||
server.router.GET(
|
||||
authPath,
|
||||
authMiddleware(server.tokenMaker),
|
||||
func(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, gin.H{})
|
||||
},
|
||||
)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request, err := http.NewRequest(http.MethodGet, authPath, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.setupAuth(t, request, server.tokenMaker)
|
||||
server.router.ServeHTTP(recorder, request)
|
||||
tc.checkResponse(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
@ -2,10 +2,12 @@ package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
db "git.nochill.in/nochill/naice_pos/db/sqlc"
|
||||
"git.nochill.in/nochill/naice_pos/token"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -26,8 +28,9 @@ func (server *Server) createProduct(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
|
||||
arg := db.CreateProductParams{
|
||||
MerchantID: req.MerchantID,
|
||||
MerchantID: authPayload.MerchantID,
|
||||
Name: req.Name,
|
||||
SellingPrice: req.SellingPrice,
|
||||
PurchasePrice: req.PurchasePrice,
|
||||
@ -64,9 +67,44 @@ func (server *Server) getProduct(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
|
||||
if product.MerchantID != authPayload.MerchantID {
|
||||
err := errors.New("Product doesn't belong to the user")
|
||||
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, product)
|
||||
}
|
||||
|
||||
type listProductsRequest struct {
|
||||
PageID int32 `form:"page_id" binding:"required,min=1"`
|
||||
PageSize int32 `form:"page_size" binding:"required,min=5"`
|
||||
}
|
||||
|
||||
func (server *Server) listProducts(ctx *gin.Context) {
|
||||
var req listProductsRequest
|
||||
if err := ctx.ShouldBindQuery(&req); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
|
||||
arg := db.ListProductsParams{
|
||||
MerchantID: authPayload.MerchantID,
|
||||
Limit: req.PageSize,
|
||||
Offset: (req.PageID - 1) * req.PageSize,
|
||||
}
|
||||
|
||||
products, err := server.store.ListProducts(ctx, arg)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, products)
|
||||
}
|
||||
|
||||
type updateProductRequest struct {
|
||||
ProductID uuid.UUID `json:"product_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
|
@ -9,9 +9,11 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mockdb "git.nochill.in/nochill/naice_pos/db/mock"
|
||||
db "git.nochill.in/nochill/naice_pos/db/sqlc"
|
||||
"git.nochill.in/nochill/naice_pos/token"
|
||||
"git.nochill.in/nochill/naice_pos/util"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
@ -19,12 +21,14 @@ import (
|
||||
)
|
||||
|
||||
func TestGetProductApi(t *testing.T) {
|
||||
product := randomProduct()
|
||||
merchantID := "f9ca13cf-8ab3-4ee3-9530-521ae505caa2"
|
||||
product := randomProduct(merchantID)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
productID string
|
||||
buildStubs func(store *mockdb.MockStore)
|
||||
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
|
||||
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
@ -36,6 +40,9 @@ func TestGetProductApi(t *testing.T) {
|
||||
Times(1).
|
||||
Return(product, nil)
|
||||
},
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
|
||||
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
|
||||
},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
requireBodyMatchAccount(t, recorder.Body, product)
|
||||
@ -50,10 +57,43 @@ func TestGetProductApi(t *testing.T) {
|
||||
Times(1).
|
||||
Return(db.Product{}, sql.ErrNoRows)
|
||||
},
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
|
||||
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
|
||||
},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusNotFound, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unauthorized",
|
||||
productID: product.ID.String(),
|
||||
buildStubs: func(store *mockdb.MockStore) {
|
||||
store.EXPECT().
|
||||
GetProduct(gomock.Any(), gomock.Eq(product.ID)).
|
||||
Times(1).
|
||||
Return(db.Product{}, sql.ErrNoRows)
|
||||
},
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
|
||||
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", uuid.New().String(), time.Minute)
|
||||
},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusNotFound, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NoAuthorization",
|
||||
productID: product.ID.String(),
|
||||
buildStubs: func(store *mockdb.MockStore) {
|
||||
store.EXPECT().
|
||||
GetProduct(gomock.Any(), gomock.Any()).
|
||||
Times(0)
|
||||
},
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
|
||||
},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Internal Error",
|
||||
productID: product.ID.String(),
|
||||
@ -63,6 +103,9 @@ func TestGetProductApi(t *testing.T) {
|
||||
Times(1).
|
||||
Return(db.Product{}, sql.ErrConnDone)
|
||||
},
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
|
||||
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
|
||||
},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
},
|
||||
@ -75,6 +118,9 @@ func TestGetProductApi(t *testing.T) {
|
||||
GetProduct(gomock.Any(), gomock.Any()).
|
||||
Times(0)
|
||||
},
|
||||
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
|
||||
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
|
||||
},
|
||||
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
},
|
||||
@ -94,11 +140,12 @@ func TestGetProductApi(t *testing.T) {
|
||||
server := newTestServer(t, store)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
url := fmt.Sprintf("/product/%s", tc.productID)
|
||||
url := fmt.Sprintf("/api/product/%s", tc.productID)
|
||||
request, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.setupAuth(t, request, server.tokenMaker)
|
||||
server.router.ServeHTTP(recorder, request)
|
||||
tc.checkResponse(t, recorder)
|
||||
})
|
||||
@ -106,10 +153,10 @@ func TestGetProductApi(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func randomProduct() db.Product {
|
||||
func randomProduct(merchantID string) db.Product {
|
||||
return db.Product{
|
||||
ID: uuid.New(),
|
||||
MerchantID: uuid.MustParse("a848090f-0409-4386-9caa-929ae6874dbb"),
|
||||
MerchantID: uuid.MustParse("f9ca13cf-8ab3-4ee3-9530-521ae505caa2"),
|
||||
Name: util.RandomString(5),
|
||||
SellingPrice: util.RandomFloat(1000, 99999),
|
||||
PurchasePrice: util.RandomFloat(999, 9999),
|
||||
|
@ -38,13 +38,15 @@ func (server *Server) getRoutes() {
|
||||
|
||||
router.POST("/user/merchants", server.createUserMerchant)
|
||||
|
||||
router.POST("/products", server.createProduct)
|
||||
router.PATCH("/products", server.updateProduct)
|
||||
router.GET("/product/:id", server.getProduct)
|
||||
apiRoutes := router.Group("/api").Use(authMiddleware(server.tokenMaker))
|
||||
|
||||
router.POST("/suppliers", server.createSupplier)
|
||||
apiRoutes.POST("/products", server.createProduct)
|
||||
apiRoutes.PATCH("/products", server.updateProduct)
|
||||
apiRoutes.GET("/product/:id", server.getProduct)
|
||||
|
||||
router.POST("/purchase-products", server.createPurchase)
|
||||
apiRoutes.POST("/suppliers", server.createSupplier)
|
||||
|
||||
apiRoutes.POST("/purchase-products", server.createPurchase)
|
||||
|
||||
server.router = router
|
||||
}
|
||||
|
@ -127,6 +127,7 @@ func (server *Server) loginUser(ctx *gin.Context) {
|
||||
|
||||
accessToken, err := server.tokenMaker.CreateToken(
|
||||
user.Email,
|
||||
outlet.ID.String(),
|
||||
server.config.TokenDuration,
|
||||
)
|
||||
|
||||
|
@ -47,7 +47,7 @@ func EqCreateUserMerchant(arg db.UserMerchantTxParams, password string) gomock.M
|
||||
}
|
||||
|
||||
func TestCreateUserMerchantAPI(t *testing.T) {
|
||||
userMerchant, password := RandomUser(t)
|
||||
userMerchant, password := randomUser(t)
|
||||
var userMerchantResult db.UserMerchantTxResult
|
||||
// var userMerchantResponse createUserMerchantResponse
|
||||
testCases := []struct {
|
||||
@ -109,7 +109,7 @@ func TestCreateUserMerchantAPI(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func RandomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) {
|
||||
func randomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) {
|
||||
password = util.RandomString(6)
|
||||
hashedPassword, err := util.HashPassword(password)
|
||||
require.NoError(t, err)
|
||||
|
@ -27,9 +27,10 @@ WHERE id = $1;
|
||||
|
||||
-- name: ListProducts :many
|
||||
SELECT * FROM products
|
||||
WHERE merchant_id = $1
|
||||
ORDER BY index_id
|
||||
LIMIT $1
|
||||
OFFSET $2;
|
||||
LIMIT $2
|
||||
OFFSET $3;
|
||||
|
||||
-- name: UpdateProduct :one
|
||||
UPDATE products
|
||||
|
@ -113,18 +113,20 @@ func (q *Queries) GetStockForUpdateStock(ctx context.Context, id uuid.UUID) (Pro
|
||||
|
||||
const listProducts = `-- name: ListProducts :many
|
||||
SELECT id, merchant_id, index_id, name, selling_price, purchase_price, stock, created_at, updated_at FROM products
|
||||
WHERE merchant_id = $1
|
||||
ORDER BY index_id
|
||||
LIMIT $1
|
||||
OFFSET $2
|
||||
LIMIT $2
|
||||
OFFSET $3
|
||||
`
|
||||
|
||||
type ListProductsParams struct {
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
MerchantID uuid.UUID `json:"merchant_id"`
|
||||
Limit int32 `json:"limit"`
|
||||
Offset int32 `json:"offset"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListProducts(ctx context.Context, arg ListProductsParams) ([]Product, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listProducts, arg.Limit, arg.Offset)
|
||||
rows, err := q.db.QueryContext(ctx, listProducts, arg.MerchantID, arg.Limit, arg.Offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -106,21 +106,24 @@ func TestDeleteProduct(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetProducts(t *testing.T) {
|
||||
var lastProduct Product
|
||||
for i := 0; i < 6; i++ {
|
||||
createRandomProduct(t)
|
||||
lastProduct, _ = createRandomProduct(t)
|
||||
}
|
||||
|
||||
arg := ListProductsParams{
|
||||
Limit: 5,
|
||||
Offset: 1,
|
||||
MerchantID: lastProduct.MerchantID,
|
||||
Limit: 5,
|
||||
Offset: 0,
|
||||
}
|
||||
|
||||
products, err := testQueries.ListProducts(context.Background(), arg)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, products, 5)
|
||||
require.NotEmpty(t, products)
|
||||
|
||||
for _, product := range products {
|
||||
require.NotEmpty(t, product)
|
||||
require.Equal(t, lastProduct.MerchantID, product.MerchantID)
|
||||
}
|
||||
|
||||
}
|
||||
|
1
go.mod
1
go.mod
@ -4,7 +4,6 @@ go 1.20
|
||||
|
||||
require (
|
||||
github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/gin-gonic/gin v1.9.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/uuid v1.3.0
|
||||
|
2
go.sum
2
go.sum
@ -62,8 +62,6 @@ github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnht
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
|
@ -1,58 +0,0 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
const minSecretKeySize = 32
|
||||
|
||||
type JWTMaker struct {
|
||||
secretKey string
|
||||
}
|
||||
|
||||
func NewJWTMaker(secretKey string) (Maker, error) {
|
||||
if len(secretKey) < minSecretKeySize {
|
||||
return nil, fmt.Errorf("Invalid key: must be at least %d characters", minSecretKeySize)
|
||||
}
|
||||
|
||||
return &JWTMaker{secretKey}, nil
|
||||
}
|
||||
|
||||
func (maker *JWTMaker) CreateToken(email string, duration time.Duration) (string, error) {
|
||||
payload, err := NewPayload(email, duration)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
return jwtToken.SignedString([]byte(maker.secretKey))
|
||||
}
|
||||
|
||||
func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) {
|
||||
keyFunc := func(token *jwt.Token) (interface{}, error) {
|
||||
_, ok := token.Method.(*jwt.SigningMethodHMAC)
|
||||
if !ok {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
return []byte(maker.secretKey), nil
|
||||
}
|
||||
jwtToken, err := jwt.ParseWithClaims(token, &Payload{}, keyFunc)
|
||||
if err != nil {
|
||||
verr, ok := err.(*jwt.ValidationError)
|
||||
if ok && errors.Is(verr.Inner, ErrExpiredToken) {
|
||||
return nil, ErrExpiredToken
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
payload, ok := jwtToken.Claims.(*Payload)
|
||||
if !ok {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
@ -1,65 +0,0 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.nochill.in/nochill/naice_pos/util"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJWTMaker(t *testing.T) {
|
||||
maker, err := NewJWTMaker(util.RandomString(32))
|
||||
require.NoError(t, err)
|
||||
|
||||
email := util.RandomEmail()
|
||||
duration := time.Minute
|
||||
|
||||
issuedAt := time.Now()
|
||||
expiredAt := issuedAt.Add(duration)
|
||||
|
||||
token, err := maker.CreateToken(email, duration)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
payload, err := maker.VerifyToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, payload)
|
||||
|
||||
require.NotZero(t, payload.ID)
|
||||
require.Equal(t, email, payload.Email)
|
||||
require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second)
|
||||
require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second)
|
||||
}
|
||||
|
||||
func TestExpiredToken(t *testing.T) {
|
||||
maker, err := NewJWTMaker(util.RandomString(32))
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := maker.CreateToken(util.RandomEmail(), -time.Minute)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
payload, err := maker.VerifyToken(token)
|
||||
require.Error(t, err)
|
||||
require.EqualError(t, err, ErrExpiredToken.Error())
|
||||
require.Nil(t, payload)
|
||||
}
|
||||
|
||||
func TestInvalidJWTTokenAlgNone(t *testing.T) {
|
||||
payload, err := NewPayload(util.RandomEmail(), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload)
|
||||
token, err := jwtToken.SignedString(jwt.UnsafeAllowNoneSignatureType)
|
||||
require.NoError(t, err)
|
||||
|
||||
maker, err := NewJWTMaker(util.RandomString(32))
|
||||
require.NoError(t, err)
|
||||
|
||||
payload, err = maker.VerifyToken(token)
|
||||
require.Error(t, err)
|
||||
require.EqualError(t, err, ErrInvalidToken.Error())
|
||||
require.Nil(t, payload)
|
||||
}
|
@ -5,6 +5,6 @@ import (
|
||||
)
|
||||
|
||||
type Maker interface {
|
||||
CreateToken(email string, duration time.Duration) (string, error)
|
||||
CreateToken(email string, merchantID string, duration time.Duration) (string, error)
|
||||
VerifyToken(token string) (*Payload, error)
|
||||
}
|
||||
|
@ -26,8 +26,8 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) {
|
||||
return maker, nil
|
||||
}
|
||||
|
||||
func (maker *PasetoMaker) CreateToken(email string, duration time.Duration) (string, error) {
|
||||
payload, err := NewPayload(email, duration)
|
||||
func (maker *PasetoMaker) CreateToken(email string, merchant_id string, duration time.Duration) (string, error) {
|
||||
payload, err := NewPayload(email, merchant_id, duration)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -5,9 +5,12 @@ import (
|
||||
"time"
|
||||
|
||||
"git.nochill.in/nochill/naice_pos/util"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var merchantID = uuid.New().String()
|
||||
|
||||
func TestPasetoMaker(t *testing.T) {
|
||||
maker, err := NewPasetoMaker(util.RandomString(32))
|
||||
require.NoError(t, err)
|
||||
@ -18,7 +21,7 @@ func TestPasetoMaker(t *testing.T) {
|
||||
issuedAt := time.Now()
|
||||
expiredAt := issuedAt.Add(duration)
|
||||
|
||||
token, err := maker.CreateToken(email, duration)
|
||||
token, err := maker.CreateToken(email, merchantID, duration)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
@ -36,7 +39,7 @@ func TestExpiredPasetoToken(t *testing.T) {
|
||||
maker, err := NewPasetoMaker(util.RandomString(32))
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := maker.CreateToken(util.RandomEmail(), -time.Minute)
|
||||
token, err := maker.CreateToken(util.RandomEmail(), merchantID, -time.Minute)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
|
@ -13,23 +13,25 @@ var (
|
||||
)
|
||||
|
||||
type Payload struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
ExpiredAt time.Time `json:"expired_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
MerchantID uuid.UUID `json:"merchant_id"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
ExpiredAt time.Time `json:"expired_at"`
|
||||
}
|
||||
|
||||
func NewPayload(email string, duration time.Duration) (*Payload, error) {
|
||||
func NewPayload(email string, merchant_id string, duration time.Duration) (*Payload, error) {
|
||||
tokenID, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &Payload{
|
||||
ID: tokenID,
|
||||
Email: email,
|
||||
IssuedAt: time.Now(),
|
||||
ExpiredAt: time.Now().Add(duration),
|
||||
ID: tokenID,
|
||||
Email: email,
|
||||
MerchantID: uuid.MustParse(merchant_id),
|
||||
IssuedAt: time.Now(),
|
||||
ExpiredAt: time.Now().Add(duration),
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user