From bcb242c0aafd9fdf9ae53b3ba4bc083b36894ecc Mon Sep 17 00:00:00 2001 From: nochill Date: Thu, 21 Sep 2023 21:38:41 +0700 Subject: [PATCH] add user sessions --- api/token.go | 139 +++++++++--------- api/user.go | 52 ++++++- ...000006_create_user_sessions_table.down.sql | 1 + .../000006_create_user_sessions_table.up.sql | 13 ++ db/mock/store.go | 45 ++++++ db/queries/sessions.sql | 16 ++ db/queries/users.sql | 5 + db/sqlc/models.go | 12 ++ db/sqlc/querier.go | 3 + db/sqlc/sessions.sql.go | 80 ++++++++++ db/sqlc/users.sql.go | 30 ++++ util/token/maker.go | 2 +- util/token/paseto.go | 4 +- util/token/paseto_test.go | 2 +- util/token/payload.go | 6 +- 15 files changed, 332 insertions(+), 78 deletions(-) create mode 100644 db/migrations/000006_create_user_sessions_table.down.sql create mode 100644 db/migrations/000006_create_user_sessions_table.up.sql create mode 100644 db/queries/sessions.sql create mode 100644 db/sqlc/sessions.sql.go diff --git a/api/token.go b/api/token.go index a0a95d5..4f720ac 100644 --- a/api/token.go +++ b/api/token.go @@ -1,7 +1,12 @@ package api import ( + "database/sql" + "fmt" + "net/http" "time" + + "github.com/gin-gonic/gin" ) type renewAccessRequest struct { @@ -13,84 +18,78 @@ type renewAccessResponse struct { AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` } -// func (server *Server) renewAccessToken(ctx *gin.Context) { -// var req renewAccessRequest -// if err := ctx.ShouldBindJSON(&req); err != nil { -// ctx.JSON(http.StatusBadRequest, errorResponse(err, "")) -// return -// } +func (server *Server) renewAccessToken(ctx *gin.Context) { + var req renewAccessRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, ErrorResponse(err, "")) + return + } -// refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken) -// if err != nil { -// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) -// return -// } + refreshPayload, err := server.TokenMaker.VerifyToken(req.RefreshToken) + if err != nil { + ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, "")) + return + } -// session, err := server.store.GetSession(ctx, refreshPayload.ID) -// if err != nil { -// if err == sql.ErrNoRows { -// ctx.JSON(http.StatusNotFound, errorResponse(err, "")) -// return -// } -// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) -// return -// } + session, err := server.Store.GetSession(ctx, int32(refreshPayload.UserID)) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, ErrorResponse(err, "")) + return + } + ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, "")) + return + } -// if session.IsBlocked { -// err := fmt.Errorf("blocked session") -// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) -// return -// } + if session.IsBlocked { + err := fmt.Errorf("blocked session") + ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, "")) + return + } -// if session.Email != refreshPayload.Email { -// err := fmt.Errorf("incorrect session user") -// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) -// return -// } + if session.Username != refreshPayload.Username { + err := fmt.Errorf("incorrect session user") + ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, "")) + return + } -// if session.RefreshToken != req.RefreshToken { -// err := fmt.Errorf("mismatched session token") -// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) -// return -// } + if session.RefreshToken != req.RefreshToken { + err := fmt.Errorf("mismatched session token") + ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, "")) + return + } -// if time.Now().After(refreshPayload.ExpiredAt) { -// err := fmt.Errorf("Expired session") -// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) -// return -// } + if time.Now().After(refreshPayload.ExpiredAt) { + err := fmt.Errorf("expired session") + ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, "")) + return + } -// user, err := server.store.GetUserByEmail(ctx, refreshPayload.Email) -// if err != nil { -// if err == sql.ErrNoRows { -// ctx.JSON(http.StatusNotFound, errorResponse(err, "")) -// return -// } -// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) -// return -// } + user, err := server.Store.GetUser(ctx, refreshPayload.Username) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, ErrorResponse(err, "")) + return + } + ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, "")) + return + } -// merchant, err := server.store.GetMerchantByUserId(ctx, user.ID) -// if err != nil { -// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) -// return -// } + accessToken, accessPayload, err := server.TokenMaker.CreateToken( + refreshPayload.Username, + int(user.ID), + server.Config.TokenDuration, + ) -// accessToken, accessPayload, err := server.tokenMaker.CreateToken( -// refreshPayload.Email, -// merchant.ID.String(), -// server.config.TokenDuration, -// ) + if err != nil { + ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, "")) + return + } -// if err != nil { -// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) -// return -// } + res := renewAccessResponse{ + AccesToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + } -// res := renewAccessResponse{ -// AccesToken: accessToken, -// AccessTokenExpiresAt: accessPayload.ExpiredAt, -// } - -// ctx.JSON(http.StatusOK, res) -// } + ctx.JSON(http.StatusOK, res) +} diff --git a/api/user.go b/api/user.go index ffa72c8..dc140cd 100644 --- a/api/user.go +++ b/api/user.go @@ -31,6 +31,14 @@ type createUserResponse struct { UpdatedAt time.Time `json:"updated_at"` } +type userTokenResponse struct { + SessionID int32 `json:"session_id"` + AccesToken string `json:"access_token"` + AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` + RefreshToken string `json:"refresh_token"` + RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"` +} + func (server *Server) createUser(ctx *gin.Context) { var req createUserRequest if err := ctx.ShouldBindJSON(&req); err != nil { @@ -65,6 +73,45 @@ func (server *Server) createUser(ctx *gin.Context) { return } + accessToken, accessPayload, err := server.TokenMaker.CreateToken( + user.Username, + int(user.ID), + server.Config.TokenDuration, + ) + + if err != nil { + ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, "Something went wrong while creating token")) + return + } + + refreshToken, refreshTokenPayload, err := server.TokenMaker.CreateToken( + user.Username, + int(user.ID), + server.Config.RefreshTokenDuration, + ) + + session, err := server.Store.CreateSession(ctx, db.CreateSessionParams{ + Username: user.Username, + RefreshToken: refreshToken, + UserAgent: ctx.Request.UserAgent(), + ClientIp: ctx.ClientIP(), + IsBlocked: false, + ExpiresAt: refreshTokenPayload.ExpiredAt, + }) + + if err != nil { + ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, "Something went wrong while saving sessions")) + return + } + + tokenResponse := userTokenResponse{ + SessionID: session.ID, + AccesToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + RefreshToken: refreshToken, + RefreshTokenExpiresAt: refreshTokenPayload.ExpiredAt, + } + res := createUserResponse{ ID: user.ID, Username: user.Username, @@ -80,5 +127,8 @@ func (server *Server) createUser(ctx *gin.Context) { UpdatedAt: user.UpdatedAt.Time, } - ctx.JSON(http.StatusOK, res) + ctx.JSON(http.StatusOK, gin.H{ + "token": tokenResponse, + "user": res, + }) } diff --git a/db/migrations/000006_create_user_sessions_table.down.sql b/db/migrations/000006_create_user_sessions_table.down.sql new file mode 100644 index 0000000..8f9e258 --- /dev/null +++ b/db/migrations/000006_create_user_sessions_table.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS user_sessions; \ No newline at end of file diff --git a/db/migrations/000006_create_user_sessions_table.up.sql b/db/migrations/000006_create_user_sessions_table.up.sql new file mode 100644 index 0000000..1d99724 --- /dev/null +++ b/db/migrations/000006_create_user_sessions_table.up.sql @@ -0,0 +1,13 @@ +CREATE TABLE user_sessions( + "id" serial primary key not null, + "index_id" bigserial not null, + "username" varchar not null references "users"("username") not null, + "refresh_token" varchar not null, + "user_agent" varchar not null, + "client_ip" varchar not null, + "is_blocked" boolean not null default false, + "expires_at" timestamp not null, + "created_at" timestamp default(now()) +); + +CREATE INDEX ON "user_sessions"("index_id"); diff --git a/db/mock/store.go b/db/mock/store.go index 3cf6432..6850c02 100644 --- a/db/mock/store.go +++ b/db/mock/store.go @@ -49,6 +49,21 @@ func (mr *MockStoreMockRecorder) CreateLocation(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateLocation", reflect.TypeOf((*MockStore)(nil).CreateLocation), arg0, arg1) } +// CreateSession mocks base method. +func (m *MockStore) CreateSession(arg0 context.Context, arg1 db.CreateSessionParams) (db.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSession", arg0, arg1) + ret0, _ := ret[0].(db.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSession indicates an expected call of CreateSession. +func (mr *MockStoreMockRecorder) CreateSession(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockStore)(nil).CreateSession), arg0, arg1) +} + // CreateUser mocks base method. func (m *MockStore) CreateUser(arg0 context.Context, arg1 db.CreateUserParams) (db.User, error) { m.ctrl.T.Helper() @@ -154,6 +169,21 @@ func (mr *MockStoreMockRecorder) GetLocationTag(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLocationTag", reflect.TypeOf((*MockStore)(nil).GetLocationTag), arg0, arg1) } +// GetSession mocks base method. +func (m *MockStore) GetSession(arg0 context.Context, arg1 int32) (db.UserSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSession", arg0, arg1) + ret0, _ := ret[0].(db.UserSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSession indicates an expected call of GetSession. +func (mr *MockStoreMockRecorder) GetSession(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockStore)(nil).GetSession), arg0, arg1) +} + // GetTopListLocations mocks base method. func (m *MockStore) GetTopListLocations(arg0 context.Context, arg1 db.GetTopListLocationsParams) ([]db.GetTopListLocationsRow, error) { m.ctrl.T.Helper() @@ -169,6 +199,21 @@ func (mr *MockStoreMockRecorder) GetTopListLocations(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTopListLocations", reflect.TypeOf((*MockStore)(nil).GetTopListLocations), arg0, arg1) } +// GetUser mocks base method. +func (m *MockStore) GetUser(arg0 context.Context, arg1 string) (db.User, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUser", arg0, arg1) + ret0, _ := ret[0].(db.User) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUser indicates an expected call of GetUser. +func (mr *MockStoreMockRecorder) GetUser(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockStore)(nil).GetUser), arg0, arg1) +} + // UpdatePassword mocks base method. func (m *MockStore) UpdatePassword(arg0 context.Context, arg1 db.UpdatePasswordParams) error { m.ctrl.T.Helper() diff --git a/db/queries/sessions.sql b/db/queries/sessions.sql new file mode 100644 index 0000000..d03781c --- /dev/null +++ b/db/queries/sessions.sql @@ -0,0 +1,16 @@ +-- name: CreateSession :one +INSERT INTO user_sessions ( + username, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( + $1, $2, $3, $4, $5, $6 +) RETURNING *; + +-- name: GetSession :one +SELECT * FROM user_sessions +WHERE id = $1 +LIMIT 1; \ No newline at end of file diff --git a/db/queries/users.sql b/db/queries/users.sql index 6a05c1d..2b00d00 100644 --- a/db/queries/users.sql +++ b/db/queries/users.sql @@ -16,6 +16,11 @@ WHERE RETURNING *; +-- name: GetUser :one +SELECT * FROM USERS +WHERE username = $1; + + -- name: UpdatePassword :exec UPDATE users SET password = $1 diff --git a/db/sqlc/models.go b/db/sqlc/models.go index ee4f957..34b57b6 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -244,3 +244,15 @@ type UserReport struct { CreatedAt sql.NullTime `json:"created_at"` UpdatedAt sql.NullTime `json:"updated_at"` } + +type UserSession struct { + ID int32 `json:"id"` + IndexID int64 `json:"index_id"` + Username string `json:"username"` + RefreshToken string `json:"refresh_token"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt sql.NullTime `json:"created_at"` +} diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 47bea7a..570a40e 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -10,12 +10,15 @@ import ( type Querier interface { CreateLocation(ctx context.Context, arg CreateLocationParams) error + CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) GetCountImageByLocation(ctx context.Context, imageOf int32) (int64, error) GetListLocations(ctx context.Context) ([]Location, error) GetListRecentLocationsWithRatings(ctx context.Context, limit int32) ([]GetListRecentLocationsWithRatingsRow, error) GetLocation(ctx context.Context, id int32) (GetLocationRow, error) GetLocationTag(ctx context.Context, targetID int32) ([]string, error) + GetSession(ctx context.Context, id int32) (UserSession, error) + GetUser(ctx context.Context, username string) (User, error) UpdatePassword(ctx context.Context, arg UpdatePasswordParams) error UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) } diff --git a/db/sqlc/sessions.sql.go b/db/sqlc/sessions.sql.go new file mode 100644 index 0000000..9ee7fb1 --- /dev/null +++ b/db/sqlc/sessions.sql.go @@ -0,0 +1,80 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: sessions.sql + +package db + +import ( + "context" + "time" +) + +const createSession = `-- name: CreateSession :one +INSERT INTO user_sessions ( + username, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( + $1, $2, $3, $4, $5, $6 +) RETURNING id, index_id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at +` + +type CreateSessionParams struct { + Username string `json:"username"` + RefreshToken string `json:"refresh_token"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error) { + row := q.db.QueryRowContext(ctx, createSession, + arg.Username, + arg.RefreshToken, + arg.UserAgent, + arg.ClientIp, + arg.IsBlocked, + arg.ExpiresAt, + ) + var i UserSession + err := row.Scan( + &i.ID, + &i.IndexID, + &i.Username, + &i.RefreshToken, + &i.UserAgent, + &i.ClientIp, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} + +const getSession = `-- name: GetSession :one +SELECT id, index_id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at FROM user_sessions +WHERE id = $1 +LIMIT 1 +` + +func (q *Queries) GetSession(ctx context.Context, id int32) (UserSession, error) { + row := q.db.QueryRowContext(ctx, getSession, id) + var i UserSession + err := row.Scan( + &i.ID, + &i.IndexID, + &i.Username, + &i.RefreshToken, + &i.UserAgent, + &i.ClientIp, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} diff --git a/db/sqlc/users.sql.go b/db/sqlc/users.sql.go index 6dbb453..c1d05ef 100644 --- a/db/sqlc/users.sql.go +++ b/db/sqlc/users.sql.go @@ -48,6 +48,36 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e return i, err } +const getUser = `-- name: GetUser :one +SELECT id, email, username, password, avatar_picture, google_sign_in_payload, banned_at, banned_until, ban_reason, is_permaban, is_admin, is_critics, is_verified, is_active, social_media, created_at, updated_at FROM USERS +WHERE username = $1 +` + +func (q *Queries) GetUser(ctx context.Context, username string) (User, error) { + row := q.db.QueryRowContext(ctx, getUser, username) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.Username, + &i.Password, + &i.AvatarPicture, + &i.GoogleSignInPayload, + &i.BannedAt, + &i.BannedUntil, + &i.BanReason, + &i.IsPermaban, + &i.IsAdmin, + &i.IsCritics, + &i.IsVerified, + &i.IsActive, + &i.SocialMedia, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const updatePassword = `-- name: UpdatePassword :exec UPDATE users SET password = $1 diff --git a/util/token/maker.go b/util/token/maker.go index 7c9f103..aa65259 100644 --- a/util/token/maker.go +++ b/util/token/maker.go @@ -3,6 +3,6 @@ package token import "time" type Maker interface { - CreateToken(email string, userID int, duration time.Duration) (string, *Payload, error) + CreateToken(username string, userID int, duration time.Duration) (string, *Payload, error) VerifyToken(token string) (*Payload, error) } diff --git a/util/token/paseto.go b/util/token/paseto.go index 2001502..c82033c 100644 --- a/util/token/paseto.go +++ b/util/token/paseto.go @@ -26,8 +26,8 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) { return maker, nil } -func (maker *PasetoMaker) CreateToken(email string, UserID int, duration time.Duration) (string, *Payload, error) { - payload, err := NewPayload(email, UserID, duration) +func (maker *PasetoMaker) CreateToken(Username string, UserID int, duration time.Duration) (string, *Payload, error) { + payload, err := NewPayload(Username, UserID, duration) if err != nil { return "", payload, err } diff --git a/util/token/paseto_test.go b/util/token/paseto_test.go index 529d0d2..7676fdc 100644 --- a/util/token/paseto_test.go +++ b/util/token/paseto_test.go @@ -31,7 +31,7 @@ func TestPasetoMaker(t *testing.T) { require.NotEmpty(t, payload) // require.NotZero(t, payload.ID) - require.Equal(t, email, payload.Email) + require.Equal(t, email, payload.Username) require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) } diff --git a/util/token/payload.go b/util/token/payload.go index 1081c00..9ba53e6 100644 --- a/util/token/payload.go +++ b/util/token/payload.go @@ -11,15 +11,15 @@ var ( ) type Payload struct { - Email string `json:"email"` + Username string `json:"email"` UserID int `json:"user_id"` IssuedAt time.Time `json:"issued_at"` ExpiredAt time.Time `json:"expired_at"` } -func NewPayload(email string, user_id int, duration time.Duration) (*Payload, error) { +func NewPayload(username string, user_id int, duration time.Duration) (*Payload, error) { payload := &Payload{ - Email: email, + Username: username, UserID: user_id, IssuedAt: time.Now(), ExpiredAt: time.Now().Add(duration),