From 2c21198412f95eacddc0c504b3886ad9768079e9 Mon Sep 17 00:00:00 2001 From: nochill Date: Thu, 16 Mar 2023 12:21:41 +0700 Subject: [PATCH] add refresh token --- api/middleware_test.go | 3 +- api/server.go | 2 +- api/token.go | 101 ++++++++++++++++++ api/user.go | 40 +++++-- api/user_test.go | 62 +++++++++++ db/migrations/000001_init_schema.up.sql | 14 --- .../000002_create_sessions_table.down.sql | 1 + .../000002_create_sessions_table.up.sql | 13 +++ db/mock/store.go | 30 ++++++ db/query/sessions.sql | 17 +++ db/sqlc/models.go | 13 +++ db/sqlc/querier.go | 2 + db/sqlc/sessions.sql.go | 85 +++++++++++++++ dev.env | 1 + token/maker.go | 2 +- token/paseto_maker.go | 6 +- token/paseto_maker_test.go | 10 +- util/config.go | 11 +- 18 files changed, 378 insertions(+), 35 deletions(-) create mode 100644 api/token.go create mode 100644 db/migrations/000002_create_sessions_table.down.sql create mode 100644 db/migrations/000002_create_sessions_table.up.sql create mode 100644 db/query/sessions.sql create mode 100644 db/sqlc/sessions.sql.go diff --git a/api/middleware_test.go b/api/middleware_test.go index ca6ffbf..938e6af 100644 --- a/api/middleware_test.go +++ b/api/middleware_test.go @@ -22,9 +22,10 @@ func addAuthorization( merchantID string, duration time.Duration, ) { - token, err := tokenMaker.CreateToken(email, merchantID, duration) + token, payload, err := tokenMaker.CreateToken(email, merchantID, duration) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) request.Header.Set(authorizationHeaderKey, authorizationHeader) diff --git a/api/server.go b/api/server.go index b915c73..d4a457a 100644 --- a/api/server.go +++ b/api/server.go @@ -35,8 +35,8 @@ func (server *Server) getRoutes() { router := gin.Default() router.POST("/user/login", server.loginUser) - router.POST("/user/merchants", server.createUserMerchant) + router.POST("/user/renew_token", server.renewAccessToken) apiRoutes := router.Group("/api").Use(authMiddleware(server.tokenMaker)) diff --git a/api/token.go b/api/token.go new file mode 100644 index 0000000..2e5b63e --- /dev/null +++ b/api/token.go @@ -0,0 +1,101 @@ +package api + +import ( + "database/sql" + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" +) + +type renewAccessRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` +} + +type renewAccessResponse struct { + AccesToken string `json:"access_token"` + AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` +} + +func (server *Server) renewAccessToken(ctx *gin.Context) { + var req renewAccessRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken) + if err != nil { + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + session, err := server.store.GetSession(ctx, refreshPayload.ID) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + if session.IsBlocked { + err := fmt.Errorf("blocked session") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if session.Email != refreshPayload.Email { + err := fmt.Errorf("incorrect session user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if session.RefreshToken != req.RefreshToken { + err := fmt.Errorf("mismatched session token") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if time.Now().After(refreshPayload.ExpiredAt) { + err := fmt.Errorf("Exprired session") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + user, err := server.store.GetUserByEmail(ctx, refreshPayload.Email) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + merchant, err := server.store.GetMerchantByUserId(ctx, user.ID) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + accessToken, accessPayload, err := server.tokenMaker.CreateToken( + refreshPayload.Email, + merchant.ID.String(), + server.config.TokenDuration, + ) + + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + res := renewAccessResponse{ + AccesToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + } + + ctx.JSON(http.StatusOK, res) +} diff --git a/api/user.go b/api/user.go index dd7dba1..91da2d5 100644 --- a/api/user.go +++ b/api/user.go @@ -3,6 +3,7 @@ package api import ( "database/sql" "net/http" + "time" db "git.nochill.in/nochill/naice_pos/db/sqlc" "git.nochill.in/nochill/naice_pos/util" @@ -92,8 +93,12 @@ type userLoginRequest struct { } type userLoginResponse struct { - AccesToken string `json:"access_token"` - UserMerchantResponse userMerchantResponse + SessionID uuid.UUID `json:"session_id"` + AccesToken string `json:"access_token"` + AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` + RefreshToken string `json:"refresh_token"` + RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"` + UserMerchantResponse userMerchantResponse } func (server *Server) loginUser(ctx *gin.Context) { @@ -125,7 +130,7 @@ func (server *Server) loginUser(ctx *gin.Context) { return } - accessToken, err := server.tokenMaker.CreateToken( + accessToken, accessPayload, err := server.tokenMaker.CreateToken( user.Email, outlet.ID.String(), server.config.TokenDuration, @@ -136,11 +141,34 @@ func (server *Server) loginUser(ctx *gin.Context) { return } - userMerchant := newUserMerchantResponse(user, outlet) + refreshToken, refreshTokenPayload, err := server.tokenMaker.CreateToken( + user.Email, + outlet.ID.String(), + server.config.RefreshTokenDuration, + ) + + session, err := server.store.CreateSession(ctx, db.CreateSessionParams{ + ID: refreshTokenPayload.ID, + Email: user.Email, + RefreshToken: refreshToken, + UserAgent: ctx.Request.UserAgent(), + ClientIp: ctx.ClientIP(), + IsBlocked: false, + ExpiresAt: refreshTokenPayload.ExpiredAt, + }) + + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } res := userLoginResponse{ - AccesToken: accessToken, - UserMerchantResponse: userMerchant, + SessionID: session.ID, + AccesToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + RefreshToken: refreshToken, + RefreshTokenExpiresAt: refreshTokenPayload.ExpiredAt, + UserMerchantResponse: newUserMerchantResponse(user, outlet), } ctx.JSON(http.StatusOK, res) diff --git a/api/user_test.go b/api/user_test.go index 4fff03b..8000829 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -109,6 +109,68 @@ func TestCreateUserMerchantAPI(t *testing.T) { } } +func TestUserLoginAPI(t *testing.T) { + user, password := randomUser(t) + var userProfile db.User + var userOutletProfile db.Merchant + + testCases := []struct { + name string + body gin.H + buildStubs func(store *mockdb.MockStore) + checkResponse func(recorder *httptest.ResponseRecorder) + }{ + { + name: "OK", + body: gin.H{ + "email": user.Email, + "password": password, + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetUserByEmail(gomock.Any(), gomock.Eq(user.Email)). + Times(1). + Return(userProfile, nil) + store.EXPECT(). + GetMerchantByUserId(gomock.Any(), gomock.Any()). + Times(1). + Return(userOutletProfile, nil) + store.EXPECT(). + CreateSession(gomock.Any(), gomock.Any()). + Times(1) + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + store := mockdb.NewMockStore(ctrl) + tc.buildStubs(store) + + server := newTestServer(t, store) + recorder := httptest.NewRecorder() + + data, err := json.Marshal(tc.body) + require.NoError(t, err) + + url := "/user/login" + request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) + require.NoError(t, err) + + server.router.ServeHTTP(recorder, request) + tc.checkResponse(recorder) + }) + } +} + func randomUser(t *testing.T) (userMerchant db.UserMerchantTxParams, password string) { password = util.RandomString(6) hashedPassword, err := util.HashPassword(password) diff --git a/db/migrations/000001_init_schema.up.sql b/db/migrations/000001_init_schema.up.sql index 7ae43ca..8a1eb8e 100644 --- a/db/migrations/000001_init_schema.up.sql +++ b/db/migrations/000001_init_schema.up.sql @@ -107,39 +107,25 @@ CREATE TABLE sale_order_detail ( CREATE INDEX ON "users"("index_id"); - - CREATE INDEX ON "merchants"("index_id"); - - CREATE INDEX ON "suppliers"("index_id"); - - CREATE INDEX ON "customers"("index_id"); - - CREATE INDEX ON "products" ("name"); CREATE INDEX ON "products" ("selling_price"); CREATE INDEX ON "products" ("index_id"); CREATE INDEX ON "products" ("purchase_price"); CREATE INDEX ON "products" ("stock"); - CREATE INDEX ON "purchase_order" ("merchant_id"); CREATE INDEX ON "purchase_order" ("supplier_id"); CREATE INDEX ON "purchase_order" ("index_id"); CREATE INDEX ON "purchase_order" ("created_at"); - - CREATE INDEX ON "purchase_order_detail" ("index_id"); - CREATE INDEX ON "sale_order" ("index_id"); - - CREATE INDEX ON "sale_order_detail" ("index_id"); \ No newline at end of file diff --git a/db/migrations/000002_create_sessions_table.down.sql b/db/migrations/000002_create_sessions_table.down.sql new file mode 100644 index 0000000..8f9e258 --- /dev/null +++ b/db/migrations/000002_create_sessions_table.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_sessions; \ No newline at end of file diff --git a/db/migrations/000002_create_sessions_table.up.sql b/db/migrations/000002_create_sessions_table.up.sql new file mode 100644 index 0000000..250eb82 --- /dev/null +++ b/db/migrations/000002_create_sessions_table.up.sql @@ -0,0 +1,13 @@ +CREATE TABLE user_sessions( + "id" uuid default gen_random_uuid() primary key not null, + "index_id" bigserial not null, + "email" varchar not null references "users"("email") not null, + "refresh_token" varchar not null, + "user_agent" varchar not null, + "client_ip" varchar not null, + "is_blocked" boolean not null default false, + "expires_at" timestamp not null, + "created_at" timestamp default(now()) +); + +CREATE INDEX ON "user_sessions"("index_id"); diff --git a/db/mock/store.go b/db/mock/store.go index 732456b..5fb5bb9 100644 --- a/db/mock/store.go +++ b/db/mock/store.go @@ -111,6 +111,21 @@ func (mr *MockStoreMockRecorder) CreatePurchaseOrderDetail(arg0, arg1 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePurchaseOrderDetail", reflect.TypeOf((*MockStore)(nil).CreatePurchaseOrderDetail), arg0, arg1) } +// CreateSession mocks base method. +func (m *MockStore) CreateSession(arg0 context.Context, arg1 db.CreateSessionParams) (db.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSession", arg0, arg1) + ret0, _ := ret[0].(db.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSession indicates an expected call of CreateSession. +func (mr *MockStoreMockRecorder) CreateSession(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockStore)(nil).CreateSession), arg0, arg1) +} + // CreateSuppliers mocks base method. func (m *MockStore) CreateSuppliers(arg0 context.Context, arg1 db.CreateSuppliersParams) (db.Supplier, error) { m.ctrl.T.Helper() @@ -273,6 +288,21 @@ func (mr *MockStoreMockRecorder) GetProduct(arg0, arg1 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProduct", reflect.TypeOf((*MockStore)(nil).GetProduct), arg0, arg1) } +// GetSession mocks base method. +func (m *MockStore) GetSession(arg0 context.Context, arg1 uuid.UUID) (db.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSession", arg0, arg1) + ret0, _ := ret[0].(db.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSession indicates an expected call of GetSession. +func (mr *MockStoreMockRecorder) GetSession(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockStore)(nil).GetSession), arg0, arg1) +} + // GetStockForUpdateStock mocks base method. func (m *MockStore) GetStockForUpdateStock(arg0 context.Context, arg1 uuid.UUID) (db.Product, error) { m.ctrl.T.Helper() diff --git a/db/query/sessions.sql b/db/query/sessions.sql new file mode 100644 index 0000000..adef2f8 --- /dev/null +++ b/db/query/sessions.sql @@ -0,0 +1,17 @@ +-- name: CreateSession :one +INSERT INTO user_sessions ( + id, + email, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( + $1, $2, $3, $4, $5, $6, $7 +) RETURNING *; + +-- name: GetSession :one +SELECT * FROM user_sessions +WHERE id = $1 +LIMIT 1; \ No newline at end of file diff --git a/db/sqlc/models.go b/db/sqlc/models.go index 5195fb7..734a266 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -7,6 +7,7 @@ package db import ( "database/sql" "encoding/json" + "time" "github.com/google/uuid" "github.com/tabbed/pqtype" @@ -118,3 +119,15 @@ type User struct { CreatedAt sql.NullTime `json:"created_at"` UpdatedAt sql.NullTime `json:"updated_at"` } + +type UserSession struct { + ID uuid.UUID `json:"id"` + IndexID int64 `json:"index_id"` + Email string `json:"email"` + RefreshToken string `json:"refresh_token"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt sql.NullTime `json:"created_at"` +} diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 3791739..fdc6368 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -16,6 +16,7 @@ type Querier interface { CreateProduct(ctx context.Context, arg CreateProductParams) (Product, error) CreatePurchaseOrder(ctx context.Context, arg CreatePurchaseOrderParams) (PurchaseOrder, error) CreatePurchaseOrderDetail(ctx context.Context, arg CreatePurchaseOrderDetailParams) (PurchaseOrderDetail, error) + CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error) CreateSuppliers(ctx context.Context, arg CreateSuppliersParams) (Supplier, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) CustomersList(ctx context.Context, arg CustomersListParams) ([]Customer, error) @@ -26,6 +27,7 @@ type Querier interface { GetMerchantByUserId(ctx context.Context, ownerID uuid.UUID) (Merchant, error) GetPasswordByEmail(ctx context.Context, email string) (string, error) GetProduct(ctx context.Context, id uuid.UUID) (Product, error) + GetSession(ctx context.Context, id uuid.UUID) (UserSession, error) GetStockForUpdateStock(ctx context.Context, id uuid.UUID) (Product, error) GetUserByEmail(ctx context.Context, email string) (User, error) GetUserById(ctx context.Context, id uuid.UUID) (User, error) diff --git a/db/sqlc/sessions.sql.go b/db/sqlc/sessions.sql.go new file mode 100644 index 0000000..fa1014f --- /dev/null +++ b/db/sqlc/sessions.sql.go @@ -0,0 +1,85 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: sessions.sql + +package db + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const createSession = `-- name: CreateSession :one +INSERT INTO user_sessions ( + id, + email, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( + $1, $2, $3, $4, $5, $6, $7 +) RETURNING id, index_id, email, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at +` + +type CreateSessionParams struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + RefreshToken string `json:"refresh_token"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error) { + row := q.db.QueryRowContext(ctx, createSession, + arg.ID, + arg.Email, + arg.RefreshToken, + arg.UserAgent, + arg.ClientIp, + arg.IsBlocked, + arg.ExpiresAt, + ) + var i UserSession + err := row.Scan( + &i.ID, + &i.IndexID, + &i.Email, + &i.RefreshToken, + &i.UserAgent, + &i.ClientIp, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} + +const getSession = `-- name: GetSession :one +SELECT id, index_id, email, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at FROM user_sessions +WHERE id = $1 +LIMIT 1 +` + +func (q *Queries) GetSession(ctx context.Context, id uuid.UUID) (UserSession, error) { + row := q.db.QueryRowContext(ctx, getSession, id) + var i UserSession + err := row.Scan( + &i.ID, + &i.IndexID, + &i.Email, + &i.RefreshToken, + &i.UserAgent, + &i.ClientIp, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} diff --git a/dev.env b/dev.env index cff4011..b3dd151 100644 --- a/dev.env +++ b/dev.env @@ -10,3 +10,4 @@ SERVER_ADDRESS = 0.0.0.0:8888 TOKEN_SYMMETRIC_KEY=75629266996751511372336382467976 TOKEN_DURATION = 6h +REFRESH_TOKEN_DURATION = 24h diff --git a/token/maker.go b/token/maker.go index 17ae30e..d495945 100644 --- a/token/maker.go +++ b/token/maker.go @@ -5,6 +5,6 @@ import ( ) type Maker interface { - CreateToken(email string, merchantID string, duration time.Duration) (string, error) + CreateToken(email string, merchantID string, duration time.Duration) (string, *Payload, error) VerifyToken(token string) (*Payload, error) } diff --git a/token/paseto_maker.go b/token/paseto_maker.go index 5f2faac..5ab5488 100644 --- a/token/paseto_maker.go +++ b/token/paseto_maker.go @@ -26,14 +26,14 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) { return maker, nil } -func (maker *PasetoMaker) CreateToken(email string, merchant_id string, duration time.Duration) (string, error) { +func (maker *PasetoMaker) CreateToken(email string, merchant_id string, duration time.Duration) (string, *Payload, error) { payload, err := NewPayload(email, merchant_id, duration) if err != nil { - return "", err + return "", payload, err } token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil) - return token, err + return token, payload, err } func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) { diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go index d8c6bcc..cc7c22d 100644 --- a/token/paseto_maker_test.go +++ b/token/paseto_maker_test.go @@ -21,11 +21,12 @@ func TestPasetoMaker(t *testing.T) { issuedAt := time.Now() expiredAt := issuedAt.Add(duration) - token, err := maker.CreateToken(email, merchantID, duration) + token, payload, err := maker.CreateToken(email, merchantID, duration) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) - payload, err := maker.VerifyToken(token) + payload, err = maker.VerifyToken(token) require.NoError(t, err) require.NotEmpty(t, payload) @@ -39,11 +40,12 @@ func TestExpiredPasetoToken(t *testing.T) { maker, err := NewPasetoMaker(util.RandomString(32)) require.NoError(t, err) - token, err := maker.CreateToken(util.RandomEmail(), merchantID, -time.Minute) + token, payload, err := maker.CreateToken(util.RandomEmail(), merchantID, -time.Minute) require.NoError(t, err) require.NotEmpty(t, token) + require.NotEmpty(t, payload) - payload, err := maker.VerifyToken(token) + payload, err = maker.VerifyToken(token) require.Error(t, err) require.EqualError(t, err, ErrExpiredToken.Error()) require.Nil(t, payload) diff --git a/util/config.go b/util/config.go index e454a36..5edce58 100644 --- a/util/config.go +++ b/util/config.go @@ -7,11 +7,12 @@ import ( ) type Config struct { - DBDriver string `mapstructure:"DB_TYPE"` - DBSource string `mapstructure:"DB_SOURCE"` - ServerAddress string `mapstructure:"SERVER_ADDRESS"` - TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` - TokenDuration time.Duration `mapstructure:"TOKEN_DURATION"` + DBDriver string `mapstructure:"DB_TYPE"` + DBSource string `mapstructure:"DB_SOURCE"` + ServerAddress string `mapstructure:"SERVER_ADDRESS"` + TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` + TokenDuration time.Duration `mapstructure:"TOKEN_DURATION"` + RefreshTokenDuration time.Duration `mapstructure:"REFRESH_TOKEN_DURATION"` } func LoadConfig(path string) (config Config, err error) {