add middlweare and authorization

This commit is contained in:
nochill 2023-03-15 15:00:36 +07:00
parent 4738c8c590
commit 61ed16163d
18 changed files with 267 additions and 163 deletions

51
api/middleware.go Normal file
View File

@ -0,0 +1,51 @@
package api
import (
"errors"
"fmt"
"net/http"
"strings"
"git.nochill.in/nochill/naice_pos/token"
"github.com/gin-gonic/gin"
)
const (
authorizationHeaderKey = "authorization"
authorizationTypeBearer = "bearer"
authorizationPayloadKey = "authorization_payload"
)
func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc {
return func(ctx *gin.Context) {
authorizationHeader := ctx.GetHeader(authorizationHeaderKey)
if len(authorizationHeader) == 0 {
err := errors.New("authorization header is not provided")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}
fields := strings.Fields(authorizationHeader)
if len(fields) < 2 {
err := errors.New("Invalid authorization header format")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}
authorizationType := strings.ToLower(fields[0])
if authorizationType != authorizationTypeBearer {
err := fmt.Errorf("Authorization only accept bearer type")
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
}
accessToken := fields[1]
payload, err := tokenMaker.VerifyToken(accessToken)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err))
return
}
ctx.Set(authorizationPayloadKey, payload)
ctx.Next()
}
}

80
api/middleware_test.go Normal file
View File

@ -0,0 +1,80 @@
package api
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.nochill.in/nochill/naice_pos/token"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func addAuthorization(
t *testing.T,
request *http.Request,
tokenMaker token.Maker,
authorizationType string,
email string,
merchantID string,
duration time.Duration,
) {
token, err := tokenMaker.CreateToken(email, merchantID, duration)
require.NoError(t, err)
require.NotEmpty(t, token)
authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
request.Header.Set(authorizationHeaderKey, authorizationHeader)
}
func TestAuthMiddleware(t *testing.T) {
testCases := []struct {
name string
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
name: "OK",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", uuid.New().String(), time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, recorder.Code)
},
},
{
name: "NoAuthorization",
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
}
for i := range testCases {
tc := testCases[i]
t.Run(tc.name, func(t *testing.T) {
server := newTestServer(t, nil)
authPath := "/user/login"
server.router.GET(
authPath,
authMiddleware(server.tokenMaker),
func(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{})
},
)
recorder := httptest.NewRecorder()
request, err := http.NewRequest(http.MethodGet, authPath, nil)
require.NoError(t, err)
tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(t, recorder)
})
}
}

View File

@ -2,10 +2,12 @@ package api
import (
"database/sql"
"errors"
"net/http"
"time"
db "git.nochill.in/nochill/naice_pos/db/sqlc"
"git.nochill.in/nochill/naice_pos/token"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@ -26,8 +28,9 @@ func (server *Server) createProduct(ctx *gin.Context) {
return
}
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.CreateProductParams{
MerchantID: req.MerchantID,
MerchantID: authPayload.MerchantID,
Name: req.Name,
SellingPrice: req.SellingPrice,
PurchasePrice: req.PurchasePrice,
@ -64,9 +67,44 @@ func (server *Server) getProduct(ctx *gin.Context) {
return
}
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
if product.MerchantID != authPayload.MerchantID {
err := errors.New("Product doesn't belong to the user")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}
ctx.JSON(http.StatusOK, product)
}
type listProductsRequest struct {
PageID int32 `form:"page_id" binding:"required,min=1"`
PageSize int32 `form:"page_size" binding:"required,min=5"`
}
func (server *Server) listProducts(ctx *gin.Context) {
var req listProductsRequest
if err := ctx.ShouldBindQuery(&req); err != nil {
ctx.JSON(http.StatusBadRequest, errorResponse(err))
return
}
authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload)
arg := db.ListProductsParams{
MerchantID: authPayload.MerchantID,
Limit: req.PageSize,
Offset: (req.PageID - 1) * req.PageSize,
}
products, err := server.store.ListProducts(ctx, arg)
if err != nil {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
ctx.JSON(http.StatusOK, products)
}
type updateProductRequest struct {
ProductID uuid.UUID `json:"product_id" binding:"required"`
Name string `json:"name" binding:"required"`

View File

@ -9,9 +9,11 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
mockdb "git.nochill.in/nochill/naice_pos/db/mock"
db "git.nochill.in/nochill/naice_pos/db/sqlc"
"git.nochill.in/nochill/naice_pos/token"
"git.nochill.in/nochill/naice_pos/util"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
@ -19,12 +21,14 @@ import (
)
func TestGetProductApi(t *testing.T) {
product := randomProduct()
merchantID := "f9ca13cf-8ab3-4ee3-9530-521ae505caa2"
product := randomProduct(merchantID)
testCases := []struct {
name string
productID string
buildStubs func(store *mockdb.MockStore)
setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker)
checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
@ -36,6 +40,9 @@ func TestGetProductApi(t *testing.T) {
Times(1).
Return(product, nil)
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, recorder.Code)
requireBodyMatchAccount(t, recorder.Body, product)
@ -50,10 +57,43 @@ func TestGetProductApi(t *testing.T) {
Times(1).
Return(db.Product{}, sql.ErrNoRows)
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusNotFound, recorder.Code)
},
},
{
name: "Unauthorized",
productID: product.ID.String(),
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetProduct(gomock.Any(), gomock.Eq(product.ID)).
Times(1).
Return(db.Product{}, sql.ErrNoRows)
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", uuid.New().String(), time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusNotFound, recorder.Code)
},
},
{
name: "NoAuthorization",
productID: product.ID.String(),
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetProduct(gomock.Any(), gomock.Any()).
Times(0)
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusUnauthorized, recorder.Code)
},
},
{
name: "Internal Error",
productID: product.ID.String(),
@ -63,6 +103,9 @@ func TestGetProductApi(t *testing.T) {
Times(1).
Return(db.Product{}, sql.ErrConnDone)
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusInternalServerError, recorder.Code)
},
@ -75,6 +118,9 @@ func TestGetProductApi(t *testing.T) {
GetProduct(gomock.Any(), gomock.Any()).
Times(0)
},
setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "email", merchantID, time.Minute)
},
checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusBadRequest, recorder.Code)
},
@ -94,11 +140,12 @@ func TestGetProductApi(t *testing.T) {
server := newTestServer(t, store)
recorder := httptest.NewRecorder()
url := fmt.Sprintf("/product/%s", tc.productID)
url := fmt.Sprintf("/api/product/%s", tc.productID)
request, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)
tc.setupAuth(t, request, server.tokenMaker)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(t, recorder)
})
@ -106,10 +153,10 @@ func TestGetProductApi(t *testing.T) {
}
}
func randomProduct() db.Product {
func randomProduct(merchantID string) db.Product {
return db.Product{
ID: uuid.New(),
MerchantID: uuid.MustParse("a848090f-0409-4386-9caa-929ae6874dbb"),
MerchantID: uuid.MustParse("f9ca13cf-8ab3-4ee3-9530-521ae505caa2"),
Name: util.RandomString(5),
SellingPrice: util.RandomFloat(1000, 99999),
PurchasePrice: util.RandomFloat(999, 9999),

View File

@ -38,13 +38,15 @@ func (server *Server) getRoutes() {
router.POST("/user/merchants", server.createUserMerchant)
router.POST("/products", server.createProduct)
router.PATCH("/products", server.updateProduct)
router.GET("/product/:id", server.getProduct)
apiRoutes := router.Group("/api").Use(authMiddleware(server.tokenMaker))
router.POST("/suppliers", server.createSupplier)
apiRoutes.POST("/products", server.createProduct)
apiRoutes.PATCH("/products", server.updateProduct)
apiRoutes.GET("/product/:id", server.getProduct)
router.POST("/purchase-products", server.createPurchase)
apiRoutes.POST("/suppliers", server.createSupplier)
apiRoutes.POST("/purchase-products", server.createPurchase)
server.router = router
}

View File

@ -127,6 +127,7 @@ func (server *Server) loginUser(ctx *gin.Context) {
accessToken, err := server.tokenMaker.CreateToken(
user.Email,
outlet.ID.String(),
server.config.TokenDuration,
)

View File

@ -47,7 +47,7 @@ func EqCreateUserMerchant(arg db.UserMerchantTxParams, password string) gomock.M
}
func TestCreateUserMerchantAPI(t *testing.T) {
userMerchant, password := RandomUser(t)
userMerchant, password := randomUser(t)
var userMerchantResult db.UserMerchantTxResult
// var userMerchantResponse createUserMerchantResponse
testCases := []struct {
@ -109,7 +109,7 @@ func TestCreateUserMerchantAPI(t *testing.T) {
}
}
func RandomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) {
func randomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) {
password = util.RandomString(6)
hashedPassword, err := util.HashPassword(password)
require.NoError(t, err)

View File

@ -27,9 +27,10 @@ WHERE id = $1;
-- name: ListProducts :many
SELECT * FROM products
WHERE merchant_id = $1
ORDER BY index_id
LIMIT $1
OFFSET $2;
LIMIT $2
OFFSET $3;
-- name: UpdateProduct :one
UPDATE products

View File

@ -113,18 +113,20 @@ func (q *Queries) GetStockForUpdateStock(ctx context.Context, id uuid.UUID) (Pro
const listProducts = `-- name: ListProducts :many
SELECT id, merchant_id, index_id, name, selling_price, purchase_price, stock, created_at, updated_at FROM products
WHERE merchant_id = $1
ORDER BY index_id
LIMIT $1
OFFSET $2
LIMIT $2
OFFSET $3
`
type ListProductsParams struct {
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
MerchantID uuid.UUID `json:"merchant_id"`
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
func (q *Queries) ListProducts(ctx context.Context, arg ListProductsParams) ([]Product, error) {
rows, err := q.db.QueryContext(ctx, listProducts, arg.Limit, arg.Offset)
rows, err := q.db.QueryContext(ctx, listProducts, arg.MerchantID, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}

View File

@ -106,21 +106,24 @@ func TestDeleteProduct(t *testing.T) {
}
func TestGetProducts(t *testing.T) {
var lastProduct Product
for i := 0; i < 6; i++ {
createRandomProduct(t)
lastProduct, _ = createRandomProduct(t)
}
arg := ListProductsParams{
Limit: 5,
Offset: 1,
MerchantID: lastProduct.MerchantID,
Limit: 5,
Offset: 0,
}
products, err := testQueries.ListProducts(context.Background(), arg)
require.NoError(t, err)
require.Len(t, products, 5)
require.NotEmpty(t, products)
for _, product := range products {
require.NotEmpty(t, product)
require.Equal(t, lastProduct.MerchantID, product.MerchantID)
}
}

1
go.mod
View File

@ -4,7 +4,6 @@ go 1.20
require (
github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/gin-gonic/gin v1.9.0
github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0

2
go.sum
View File

@ -62,8 +62,6 @@ github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnht
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=

View File

@ -1,58 +0,0 @@
package token
import (
"errors"
"fmt"
"time"
"github.com/dgrijalva/jwt-go"
)
const minSecretKeySize = 32
type JWTMaker struct {
secretKey string
}
func NewJWTMaker(secretKey string) (Maker, error) {
if len(secretKey) < minSecretKeySize {
return nil, fmt.Errorf("Invalid key: must be at least %d characters", minSecretKeySize)
}
return &JWTMaker{secretKey}, nil
}
func (maker *JWTMaker) CreateToken(email string, duration time.Duration) (string, error) {
payload, err := NewPayload(email, duration)
if err != nil {
return "", err
}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
return jwtToken.SignedString([]byte(maker.secretKey))
}
func (maker *JWTMaker) VerifyToken(token string) (*Payload, error) {
keyFunc := func(token *jwt.Token) (interface{}, error) {
_, ok := token.Method.(*jwt.SigningMethodHMAC)
if !ok {
return nil, ErrInvalidToken
}
return []byte(maker.secretKey), nil
}
jwtToken, err := jwt.ParseWithClaims(token, &Payload{}, keyFunc)
if err != nil {
verr, ok := err.(*jwt.ValidationError)
if ok && errors.Is(verr.Inner, ErrExpiredToken) {
return nil, ErrExpiredToken
}
return nil, ErrInvalidToken
}
payload, ok := jwtToken.Claims.(*Payload)
if !ok {
return nil, ErrInvalidToken
}
return payload, nil
}

View File

@ -1,65 +0,0 @@
package token
import (
"testing"
"time"
"git.nochill.in/nochill/naice_pos/util"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/require"
)
func TestJWTMaker(t *testing.T) {
maker, err := NewJWTMaker(util.RandomString(32))
require.NoError(t, err)
email := util.RandomEmail()
duration := time.Minute
issuedAt := time.Now()
expiredAt := issuedAt.Add(duration)
token, err := maker.CreateToken(email, duration)
require.NoError(t, err)
require.NotEmpty(t, token)
payload, err := maker.VerifyToken(token)
require.NoError(t, err)
require.NotEmpty(t, payload)
require.NotZero(t, payload.ID)
require.Equal(t, email, payload.Email)
require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second)
require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second)
}
func TestExpiredToken(t *testing.T) {
maker, err := NewJWTMaker(util.RandomString(32))
require.NoError(t, err)
token, err := maker.CreateToken(util.RandomEmail(), -time.Minute)
require.NoError(t, err)
require.NotEmpty(t, token)
payload, err := maker.VerifyToken(token)
require.Error(t, err)
require.EqualError(t, err, ErrExpiredToken.Error())
require.Nil(t, payload)
}
func TestInvalidJWTTokenAlgNone(t *testing.T) {
payload, err := NewPayload(util.RandomEmail(), time.Minute)
require.NoError(t, err)
jwtToken := jwt.NewWithClaims(jwt.SigningMethodNone, payload)
token, err := jwtToken.SignedString(jwt.UnsafeAllowNoneSignatureType)
require.NoError(t, err)
maker, err := NewJWTMaker(util.RandomString(32))
require.NoError(t, err)
payload, err = maker.VerifyToken(token)
require.Error(t, err)
require.EqualError(t, err, ErrInvalidToken.Error())
require.Nil(t, payload)
}

View File

@ -5,6 +5,6 @@ import (
)
type Maker interface {
CreateToken(email string, duration time.Duration) (string, error)
CreateToken(email string, merchantID string, duration time.Duration) (string, error)
VerifyToken(token string) (*Payload, error)
}

View File

@ -26,8 +26,8 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) {
return maker, nil
}
func (maker *PasetoMaker) CreateToken(email string, duration time.Duration) (string, error) {
payload, err := NewPayload(email, duration)
func (maker *PasetoMaker) CreateToken(email string, merchant_id string, duration time.Duration) (string, error) {
payload, err := NewPayload(email, merchant_id, duration)
if err != nil {
return "", err
}

View File

@ -5,9 +5,12 @@ import (
"time"
"git.nochill.in/nochill/naice_pos/util"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
var merchantID = uuid.New().String()
func TestPasetoMaker(t *testing.T) {
maker, err := NewPasetoMaker(util.RandomString(32))
require.NoError(t, err)
@ -18,7 +21,7 @@ func TestPasetoMaker(t *testing.T) {
issuedAt := time.Now()
expiredAt := issuedAt.Add(duration)
token, err := maker.CreateToken(email, duration)
token, err := maker.CreateToken(email, merchantID, duration)
require.NoError(t, err)
require.NotEmpty(t, token)
@ -36,7 +39,7 @@ func TestExpiredPasetoToken(t *testing.T) {
maker, err := NewPasetoMaker(util.RandomString(32))
require.NoError(t, err)
token, err := maker.CreateToken(util.RandomEmail(), -time.Minute)
token, err := maker.CreateToken(util.RandomEmail(), merchantID, -time.Minute)
require.NoError(t, err)
require.NotEmpty(t, token)

View File

@ -13,23 +13,25 @@ var (
)
type Payload struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
IssuedAt time.Time `json:"issued_at"`
ExpiredAt time.Time `json:"expired_at"`
ID uuid.UUID `json:"id"`
Email string `json:"email"`
MerchantID uuid.UUID `json:"merchant_id"`
IssuedAt time.Time `json:"issued_at"`
ExpiredAt time.Time `json:"expired_at"`
}
func NewPayload(email string, duration time.Duration) (*Payload, error) {
func NewPayload(email string, merchant_id string, duration time.Duration) (*Payload, error) {
tokenID, err := uuid.NewRandom()
if err != nil {
return nil, err
}
payload := &Payload{
ID: tokenID,
Email: email,
IssuedAt: time.Now(),
ExpiredAt: time.Now().Add(duration),
ID: tokenID,
Email: email,
MerchantID: uuid.MustParse(merchant_id),
IssuedAt: time.Now(),
ExpiredAt: time.Now().Add(duration),
}
return payload, nil
}