add refresh token

This commit is contained in:
nochill 2023-03-16 12:21:41 +07:00
parent 61ed16163d
commit 2c21198412
18 changed files with 378 additions and 35 deletions

View File

@ -22,9 +22,10 @@ func addAuthorization(
merchantID string,
duration time.Duration,
) {
token, err := tokenMaker.CreateToken(email, merchantID, duration)
token, payload, err := tokenMaker.CreateToken(email, merchantID, duration)
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotEmpty(t, payload)
authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
request.Header.Set(authorizationHeaderKey, authorizationHeader)

View File

@ -35,8 +35,8 @@ func (server *Server) getRoutes() {
router := gin.Default()
router.POST("/user/login", server.loginUser)
router.POST("/user/merchants", server.createUserMerchant)
router.POST("/user/renew_token", server.renewAccessToken)
apiRoutes := router.Group("/api").Use(authMiddleware(server.tokenMaker))

101
api/token.go Normal file
View File

@ -0,0 +1,101 @@
package api
import (
"database/sql"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
type renewAccessRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
type renewAccessResponse struct {
AccesToken string `json:"access_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
}
func (server *Server) renewAccessToken(ctx *gin.Context) {
var req renewAccessRequest
if err := ctx.ShouldBindJSON(&req); err != nil {
ctx.JSON(http.StatusBadRequest, errorResponse(err))
return
}
refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken)
if err != nil {
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}
session, err := server.store.GetSession(ctx, refreshPayload.ID)
if err != nil {
if err == sql.ErrNoRows {
ctx.JSON(http.StatusNotFound, errorResponse(err))
return
}
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
if session.IsBlocked {
err := fmt.Errorf("blocked session")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}
if session.Email != refreshPayload.Email {
err := fmt.Errorf("incorrect session user")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}
if session.RefreshToken != req.RefreshToken {
err := fmt.Errorf("mismatched session token")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}
if time.Now().After(refreshPayload.ExpiredAt) {
err := fmt.Errorf("Exprired session")
ctx.JSON(http.StatusUnauthorized, errorResponse(err))
return
}
user, err := server.store.GetUserByEmail(ctx, refreshPayload.Email)
if err != nil {
if err == sql.ErrNoRows {
ctx.JSON(http.StatusNotFound, errorResponse(err))
return
}
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
merchant, err := server.store.GetMerchantByUserId(ctx, user.ID)
if err != nil {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
accessToken, accessPayload, err := server.tokenMaker.CreateToken(
refreshPayload.Email,
merchant.ID.String(),
server.config.TokenDuration,
)
if err != nil {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
res := renewAccessResponse{
AccesToken: accessToken,
AccessTokenExpiresAt: accessPayload.ExpiredAt,
}
ctx.JSON(http.StatusOK, res)
}

View File

@ -3,6 +3,7 @@ package api
import (
"database/sql"
"net/http"
"time"
db "git.nochill.in/nochill/naice_pos/db/sqlc"
"git.nochill.in/nochill/naice_pos/util"
@ -92,8 +93,12 @@ type userLoginRequest struct {
}
type userLoginResponse struct {
AccesToken string `json:"access_token"`
UserMerchantResponse userMerchantResponse
SessionID uuid.UUID `json:"session_id"`
AccesToken string `json:"access_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
RefreshToken string `json:"refresh_token"`
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
UserMerchantResponse userMerchantResponse
}
func (server *Server) loginUser(ctx *gin.Context) {
@ -125,7 +130,7 @@ func (server *Server) loginUser(ctx *gin.Context) {
return
}
accessToken, err := server.tokenMaker.CreateToken(
accessToken, accessPayload, err := server.tokenMaker.CreateToken(
user.Email,
outlet.ID.String(),
server.config.TokenDuration,
@ -136,11 +141,34 @@ func (server *Server) loginUser(ctx *gin.Context) {
return
}
userMerchant := newUserMerchantResponse(user, outlet)
refreshToken, refreshTokenPayload, err := server.tokenMaker.CreateToken(
user.Email,
outlet.ID.String(),
server.config.RefreshTokenDuration,
)
session, err := server.store.CreateSession(ctx, db.CreateSessionParams{
ID: refreshTokenPayload.ID,
Email: user.Email,
RefreshToken: refreshToken,
UserAgent: ctx.Request.UserAgent(),
ClientIp: ctx.ClientIP(),
IsBlocked: false,
ExpiresAt: refreshTokenPayload.ExpiredAt,
})
if err != nil {
ctx.JSON(http.StatusInternalServerError, errorResponse(err))
return
}
res := userLoginResponse{
AccesToken: accessToken,
UserMerchantResponse: userMerchant,
SessionID: session.ID,
AccesToken: accessToken,
AccessTokenExpiresAt: accessPayload.ExpiredAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshTokenPayload.ExpiredAt,
UserMerchantResponse: newUserMerchantResponse(user, outlet),
}
ctx.JSON(http.StatusOK, res)

View File

@ -109,6 +109,68 @@ func TestCreateUserMerchantAPI(t *testing.T) {
}
}
func TestUserLoginAPI(t *testing.T) {
user, password := randomUser(t)
var userProfile db.User
var userOutletProfile db.Merchant
testCases := []struct {
name string
body gin.H
buildStubs func(store *mockdb.MockStore)
checkResponse func(recorder *httptest.ResponseRecorder)
}{
{
name: "OK",
body: gin.H{
"email": user.Email,
"password": password,
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetUserByEmail(gomock.Any(), gomock.Eq(user.Email)).
Times(1).
Return(userProfile, nil)
store.EXPECT().
GetMerchantByUserId(gomock.Any(), gomock.Any()).
Times(1).
Return(userOutletProfile, nil)
store.EXPECT().
CreateSession(gomock.Any(), gomock.Any()).
Times(1)
},
checkResponse: func(recorder *httptest.ResponseRecorder) {
require.Equal(t, http.StatusOK, recorder.Code)
},
},
}
for i := range testCases {
tc := testCases[i]
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
store := mockdb.NewMockStore(ctrl)
tc.buildStubs(store)
server := newTestServer(t, store)
recorder := httptest.NewRecorder()
data, err := json.Marshal(tc.body)
require.NoError(t, err)
url := "/user/login"
request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
require.NoError(t, err)
server.router.ServeHTTP(recorder, request)
tc.checkResponse(recorder)
})
}
}
func randomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) {
password = util.RandomString(6)
hashedPassword, err := util.HashPassword(password)

View File

@ -107,39 +107,25 @@ CREATE TABLE sale_order_detail (
CREATE INDEX ON "users"("index_id");
CREATE INDEX ON "merchants"("index_id");
CREATE INDEX ON "suppliers"("index_id");
CREATE INDEX ON "customers"("index_id");
CREATE INDEX ON "products" ("name");
CREATE INDEX ON "products" ("selling_price");
CREATE INDEX ON "products" ("index_id");
CREATE INDEX ON "products" ("purchase_price");
CREATE INDEX ON "products" ("stock");
CREATE INDEX ON "purchase_order" ("merchant_id");
CREATE INDEX ON "purchase_order" ("supplier_id");
CREATE INDEX ON "purchase_order" ("index_id");
CREATE INDEX ON "purchase_order" ("created_at");
CREATE INDEX ON "purchase_order_detail" ("index_id");
CREATE INDEX ON "sale_order" ("index_id");
CREATE INDEX ON "sale_order_detail" ("index_id");

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS user_sessions;

View File

@ -0,0 +1,13 @@
CREATE TABLE user_sessions(
"id" uuid default gen_random_uuid() primary key not null,
"index_id" bigserial not null,
"email" varchar not null references "users"("email") not null,
"refresh_token" varchar not null,
"user_agent" varchar not null,
"client_ip" varchar not null,
"is_blocked" boolean not null default false,
"expires_at" timestamp not null,
"created_at" timestamp default(now())
);
CREATE INDEX ON "user_sessions"("index_id");

View File

@ -111,6 +111,21 @@ func (mr *MockStoreMockRecorder) CreatePurchaseOrderDetail(arg0, arg1 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePurchaseOrderDetail", reflect.TypeOf((*MockStore)(nil).CreatePurchaseOrderDetail), arg0, arg1)
}
// CreateSession mocks base method.
func (m *MockStore) CreateSession(arg0 context.Context, arg1 db.CreateSessionParams) (db.UserSession, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSession", arg0, arg1)
ret0, _ := ret[0].(db.UserSession)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSession indicates an expected call of CreateSession.
func (mr *MockStoreMockRecorder) CreateSession(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockStore)(nil).CreateSession), arg0, arg1)
}
// CreateSuppliers mocks base method.
func (m *MockStore) CreateSuppliers(arg0 context.Context, arg1 db.CreateSuppliersParams) (db.Supplier, error) {
m.ctrl.T.Helper()
@ -273,6 +288,21 @@ func (mr *MockStoreMockRecorder) GetProduct(arg0, arg1 interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProduct", reflect.TypeOf((*MockStore)(nil).GetProduct), arg0, arg1)
}
// GetSession mocks base method.
func (m *MockStore) GetSession(arg0 context.Context, arg1 uuid.UUID) (db.UserSession, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSession", arg0, arg1)
ret0, _ := ret[0].(db.UserSession)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSession indicates an expected call of GetSession.
func (mr *MockStoreMockRecorder) GetSession(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockStore)(nil).GetSession), arg0, arg1)
}
// GetStockForUpdateStock mocks base method.
func (m *MockStore) GetStockForUpdateStock(arg0 context.Context, arg1 uuid.UUID) (db.Product, error) {
m.ctrl.T.Helper()

17
db/query/sessions.sql Normal file
View File

@ -0,0 +1,17 @@
-- name: CreateSession :one
INSERT INTO user_sessions (
id,
email,
refresh_token,
user_agent,
client_ip,
is_blocked,
expires_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7
) RETURNING *;
-- name: GetSession :one
SELECT * FROM user_sessions
WHERE id = $1
LIMIT 1;

View File

@ -7,6 +7,7 @@ package db
import (
"database/sql"
"encoding/json"
"time"
"github.com/google/uuid"
"github.com/tabbed/pqtype"
@ -118,3 +119,15 @@ type User struct {
CreatedAt sql.NullTime `json:"created_at"`
UpdatedAt sql.NullTime `json:"updated_at"`
}
type UserSession struct {
ID uuid.UUID `json:"id"`
IndexID int64 `json:"index_id"`
Email string `json:"email"`
RefreshToken string `json:"refresh_token"`
UserAgent string `json:"user_agent"`
ClientIp string `json:"client_ip"`
IsBlocked bool `json:"is_blocked"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt sql.NullTime `json:"created_at"`
}

View File

@ -16,6 +16,7 @@ type Querier interface {
CreateProduct(ctx context.Context, arg CreateProductParams) (Product, error)
CreatePurchaseOrder(ctx context.Context, arg CreatePurchaseOrderParams) (PurchaseOrder, error)
CreatePurchaseOrderDetail(ctx context.Context, arg CreatePurchaseOrderDetailParams) (PurchaseOrderDetail, error)
CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error)
CreateSuppliers(ctx context.Context, arg CreateSuppliersParams) (Supplier, error)
CreateUser(ctx context.Context, arg CreateUserParams) (User, error)
CustomersList(ctx context.Context, arg CustomersListParams) ([]Customer, error)
@ -26,6 +27,7 @@ type Querier interface {
GetMerchantByUserId(ctx context.Context, ownerID uuid.UUID) (Merchant, error)
GetPasswordByEmail(ctx context.Context, email string) (string, error)
GetProduct(ctx context.Context, id uuid.UUID) (Product, error)
GetSession(ctx context.Context, id uuid.UUID) (UserSession, error)
GetStockForUpdateStock(ctx context.Context, id uuid.UUID) (Product, error)
GetUserByEmail(ctx context.Context, email string) (User, error)
GetUserById(ctx context.Context, id uuid.UUID) (User, error)

85
db/sqlc/sessions.sql.go Normal file
View File

@ -0,0 +1,85 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.17.2
// source: sessions.sql
package db
import (
"context"
"time"
"github.com/google/uuid"
)
const createSession = `-- name: CreateSession :one
INSERT INTO user_sessions (
id,
email,
refresh_token,
user_agent,
client_ip,
is_blocked,
expires_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7
) RETURNING id, index_id, email, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at
`
type CreateSessionParams struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
RefreshToken string `json:"refresh_token"`
UserAgent string `json:"user_agent"`
ClientIp string `json:"client_ip"`
IsBlocked bool `json:"is_blocked"`
ExpiresAt time.Time `json:"expires_at"`
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error) {
row := q.db.QueryRowContext(ctx, createSession,
arg.ID,
arg.Email,
arg.RefreshToken,
arg.UserAgent,
arg.ClientIp,
arg.IsBlocked,
arg.ExpiresAt,
)
var i UserSession
err := row.Scan(
&i.ID,
&i.IndexID,
&i.Email,
&i.RefreshToken,
&i.UserAgent,
&i.ClientIp,
&i.IsBlocked,
&i.ExpiresAt,
&i.CreatedAt,
)
return i, err
}
const getSession = `-- name: GetSession :one
SELECT id, index_id, email, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at FROM user_sessions
WHERE id = $1
LIMIT 1
`
func (q *Queries) GetSession(ctx context.Context, id uuid.UUID) (UserSession, error) {
row := q.db.QueryRowContext(ctx, getSession, id)
var i UserSession
err := row.Scan(
&i.ID,
&i.IndexID,
&i.Email,
&i.RefreshToken,
&i.UserAgent,
&i.ClientIp,
&i.IsBlocked,
&i.ExpiresAt,
&i.CreatedAt,
)
return i, err
}

View File

@ -10,3 +10,4 @@ SERVER_ADDRESS = 0.0.0.0:8888
TOKEN_SYMMETRIC_KEY=75629266996751511372336382467976
TOKEN_DURATION = 6h
REFRESH_TOKEN_DURATION = 24h

View File

@ -5,6 +5,6 @@ import (
)
type Maker interface {
CreateToken(email string, merchantID string, duration time.Duration) (string, error)
CreateToken(email string, merchantID string, duration time.Duration) (string, *Payload, error)
VerifyToken(token string) (*Payload, error)
}

View File

@ -26,14 +26,14 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) {
return maker, nil
}
func (maker *PasetoMaker) CreateToken(email string, merchant_id string, duration time.Duration) (string, error) {
func (maker *PasetoMaker) CreateToken(email string, merchant_id string, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(email, merchant_id, duration)
if err != nil {
return "", err
return "", payload, err
}
token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil)
return token, err
return token, payload, err
}
func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) {

View File

@ -21,11 +21,12 @@ func TestPasetoMaker(t *testing.T) {
issuedAt := time.Now()
expiredAt := issuedAt.Add(duration)
token, err := maker.CreateToken(email, merchantID, duration)
token, payload, err := maker.CreateToken(email, merchantID, duration)
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotEmpty(t, payload)
payload, err := maker.VerifyToken(token)
payload, err = maker.VerifyToken(token)
require.NoError(t, err)
require.NotEmpty(t, payload)
@ -39,11 +40,12 @@ func TestExpiredPasetoToken(t *testing.T) {
maker, err := NewPasetoMaker(util.RandomString(32))
require.NoError(t, err)
token, err := maker.CreateToken(util.RandomEmail(), merchantID, -time.Minute)
token, payload, err := maker.CreateToken(util.RandomEmail(), merchantID, -time.Minute)
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotEmpty(t, payload)
payload, err := maker.VerifyToken(token)
payload, err = maker.VerifyToken(token)
require.Error(t, err)
require.EqualError(t, err, ErrExpiredToken.Error())
require.Nil(t, payload)

View File

@ -7,11 +7,12 @@ import (
)
type Config struct {
DBDriver string `mapstructure:"DB_TYPE"`
DBSource string `mapstructure:"DB_SOURCE"`
ServerAddress string `mapstructure:"SERVER_ADDRESS"`
TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"`
TokenDuration time.Duration `mapstructure:"TOKEN_DURATION"`
DBDriver string `mapstructure:"DB_TYPE"`
DBSource string `mapstructure:"DB_SOURCE"`
ServerAddress string `mapstructure:"SERVER_ADDRESS"`
TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"`
TokenDuration time.Duration `mapstructure:"TOKEN_DURATION"`
RefreshTokenDuration time.Duration `mapstructure:"REFRESH_TOKEN_DURATION"`
}
func LoadConfig(path string) (config Config, err error) {