add user sessions

This commit is contained in:
nochill 2023-09-21 21:38:41 +07:00
parent bee8f6e5b8
commit bcb242c0aa
15 changed files with 332 additions and 78 deletions

View File

@ -1,7 +1,12 @@
package api package api
import ( import (
"database/sql"
"fmt"
"net/http"
"time" "time"
"github.com/gin-gonic/gin"
) )
type renewAccessRequest struct { type renewAccessRequest struct {
@ -13,84 +18,78 @@ type renewAccessResponse struct {
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
} }
// func (server *Server) renewAccessToken(ctx *gin.Context) { func (server *Server) renewAccessToken(ctx *gin.Context) {
// var req renewAccessRequest var req renewAccessRequest
// if err := ctx.ShouldBindJSON(&req); err != nil { if err := ctx.ShouldBindJSON(&req); err != nil {
// ctx.JSON(http.StatusBadRequest, errorResponse(err, "")) ctx.JSON(http.StatusBadRequest, ErrorResponse(err, ""))
// return return
// } }
// refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken) refreshPayload, err := server.TokenMaker.VerifyToken(req.RefreshToken)
// if err != nil { if err != nil {
// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, ""))
// return return
// } }
// session, err := server.store.GetSession(ctx, refreshPayload.ID) session, err := server.Store.GetSession(ctx, int32(refreshPayload.UserID))
// if err != nil { if err != nil {
// if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// ctx.JSON(http.StatusNotFound, errorResponse(err, "")) ctx.JSON(http.StatusNotFound, ErrorResponse(err, ""))
// return return
// } }
// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, ""))
// return return
// } }
// if session.IsBlocked { if session.IsBlocked {
// err := fmt.Errorf("blocked session") err := fmt.Errorf("blocked session")
// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, ""))
// return return
// } }
// if session.Email != refreshPayload.Email { if session.Username != refreshPayload.Username {
// err := fmt.Errorf("incorrect session user") err := fmt.Errorf("incorrect session user")
// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, ""))
// return return
// } }
// if session.RefreshToken != req.RefreshToken { if session.RefreshToken != req.RefreshToken {
// err := fmt.Errorf("mismatched session token") err := fmt.Errorf("mismatched session token")
// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, ""))
// return return
// } }
// if time.Now().After(refreshPayload.ExpiredAt) { if time.Now().After(refreshPayload.ExpiredAt) {
// err := fmt.Errorf("Expired session") err := fmt.Errorf("expired session")
// ctx.JSON(http.StatusUnauthorized, errorResponse(err, "")) ctx.JSON(http.StatusUnauthorized, ErrorResponse(err, ""))
// return return
// } }
// user, err := server.store.GetUserByEmail(ctx, refreshPayload.Email) user, err := server.Store.GetUser(ctx, refreshPayload.Username)
// if err != nil { if err != nil {
// if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// ctx.JSON(http.StatusNotFound, errorResponse(err, "")) ctx.JSON(http.StatusNotFound, ErrorResponse(err, ""))
// return return
// } }
// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, ""))
// return return
// } }
// merchant, err := server.store.GetMerchantByUserId(ctx, user.ID) accessToken, accessPayload, err := server.TokenMaker.CreateToken(
// if err != nil { refreshPayload.Username,
// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) int(user.ID),
// return server.Config.TokenDuration,
// } )
// accessToken, accessPayload, err := server.tokenMaker.CreateToken( if err != nil {
// refreshPayload.Email, ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, ""))
// merchant.ID.String(), return
// server.config.TokenDuration, }
// )
// if err != nil { res := renewAccessResponse{
// ctx.JSON(http.StatusInternalServerError, errorResponse(err, "")) AccesToken: accessToken,
// return AccessTokenExpiresAt: accessPayload.ExpiredAt,
// } }
// res := renewAccessResponse{ ctx.JSON(http.StatusOK, res)
// AccesToken: accessToken, }
// AccessTokenExpiresAt: accessPayload.ExpiredAt,
// }
// ctx.JSON(http.StatusOK, res)
// }

View File

@ -31,6 +31,14 @@ type createUserResponse struct {
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
type userTokenResponse struct {
SessionID int32 `json:"session_id"`
AccesToken string `json:"access_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
RefreshToken string `json:"refresh_token"`
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
}
func (server *Server) createUser(ctx *gin.Context) { func (server *Server) createUser(ctx *gin.Context) {
var req createUserRequest var req createUserRequest
if err := ctx.ShouldBindJSON(&req); err != nil { if err := ctx.ShouldBindJSON(&req); err != nil {
@ -65,6 +73,45 @@ func (server *Server) createUser(ctx *gin.Context) {
return return
} }
accessToken, accessPayload, err := server.TokenMaker.CreateToken(
user.Username,
int(user.ID),
server.Config.TokenDuration,
)
if err != nil {
ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, "Something went wrong while creating token"))
return
}
refreshToken, refreshTokenPayload, err := server.TokenMaker.CreateToken(
user.Username,
int(user.ID),
server.Config.RefreshTokenDuration,
)
session, err := server.Store.CreateSession(ctx, db.CreateSessionParams{
Username: user.Username,
RefreshToken: refreshToken,
UserAgent: ctx.Request.UserAgent(),
ClientIp: ctx.ClientIP(),
IsBlocked: false,
ExpiresAt: refreshTokenPayload.ExpiredAt,
})
if err != nil {
ctx.JSON(http.StatusInternalServerError, ErrorResponse(err, "Something went wrong while saving sessions"))
return
}
tokenResponse := userTokenResponse{
SessionID: session.ID,
AccesToken: accessToken,
AccessTokenExpiresAt: accessPayload.ExpiredAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshTokenPayload.ExpiredAt,
}
res := createUserResponse{ res := createUserResponse{
ID: user.ID, ID: user.ID,
Username: user.Username, Username: user.Username,
@ -80,5 +127,8 @@ func (server *Server) createUser(ctx *gin.Context) {
UpdatedAt: user.UpdatedAt.Time, UpdatedAt: user.UpdatedAt.Time,
} }
ctx.JSON(http.StatusOK, res) ctx.JSON(http.StatusOK, gin.H{
"token": tokenResponse,
"user": res,
})
} }

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS user_sessions;

View File

@ -0,0 +1,13 @@
CREATE TABLE user_sessions(
"id" serial primary key not null,
"index_id" bigserial not null,
"username" varchar not null references "users"("username") not null,
"refresh_token" varchar not null,
"user_agent" varchar not null,
"client_ip" varchar not null,
"is_blocked" boolean not null default false,
"expires_at" timestamp not null,
"created_at" timestamp default(now())
);
CREATE INDEX ON "user_sessions"("index_id");

View File

@ -49,6 +49,21 @@ func (mr *MockStoreMockRecorder) CreateLocation(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateLocation", reflect.TypeOf((*MockStore)(nil).CreateLocation), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateLocation", reflect.TypeOf((*MockStore)(nil).CreateLocation), arg0, arg1)
} }
// CreateSession mocks base method.
func (m *MockStore) CreateSession(arg0 context.Context, arg1 db.CreateSessionParams) (db.UserSession, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSession", arg0, arg1)
ret0, _ := ret[0].(db.UserSession)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSession indicates an expected call of CreateSession.
func (mr *MockStoreMockRecorder) CreateSession(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockStore)(nil).CreateSession), arg0, arg1)
}
// CreateUser mocks base method. // CreateUser mocks base method.
func (m *MockStore) CreateUser(arg0 context.Context, arg1 db.CreateUserParams) (db.User, error) { func (m *MockStore) CreateUser(arg0 context.Context, arg1 db.CreateUserParams) (db.User, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -154,6 +169,21 @@ func (mr *MockStoreMockRecorder) GetLocationTag(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLocationTag", reflect.TypeOf((*MockStore)(nil).GetLocationTag), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLocationTag", reflect.TypeOf((*MockStore)(nil).GetLocationTag), arg0, arg1)
} }
// GetSession mocks base method.
func (m *MockStore) GetSession(arg0 context.Context, arg1 int32) (db.UserSession, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSession", arg0, arg1)
ret0, _ := ret[0].(db.UserSession)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSession indicates an expected call of GetSession.
func (mr *MockStoreMockRecorder) GetSession(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockStore)(nil).GetSession), arg0, arg1)
}
// GetTopListLocations mocks base method. // GetTopListLocations mocks base method.
func (m *MockStore) GetTopListLocations(arg0 context.Context, arg1 db.GetTopListLocationsParams) ([]db.GetTopListLocationsRow, error) { func (m *MockStore) GetTopListLocations(arg0 context.Context, arg1 db.GetTopListLocationsParams) ([]db.GetTopListLocationsRow, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -169,6 +199,21 @@ func (mr *MockStoreMockRecorder) GetTopListLocations(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTopListLocations", reflect.TypeOf((*MockStore)(nil).GetTopListLocations), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTopListLocations", reflect.TypeOf((*MockStore)(nil).GetTopListLocations), arg0, arg1)
} }
// GetUser mocks base method.
func (m *MockStore) GetUser(arg0 context.Context, arg1 string) (db.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUser", arg0, arg1)
ret0, _ := ret[0].(db.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUser indicates an expected call of GetUser.
func (mr *MockStoreMockRecorder) GetUser(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUser", reflect.TypeOf((*MockStore)(nil).GetUser), arg0, arg1)
}
// UpdatePassword mocks base method. // UpdatePassword mocks base method.
func (m *MockStore) UpdatePassword(arg0 context.Context, arg1 db.UpdatePasswordParams) error { func (m *MockStore) UpdatePassword(arg0 context.Context, arg1 db.UpdatePasswordParams) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

16
db/queries/sessions.sql Normal file
View File

@ -0,0 +1,16 @@
-- name: CreateSession :one
INSERT INTO user_sessions (
username,
refresh_token,
user_agent,
client_ip,
is_blocked,
expires_at
) VALUES (
$1, $2, $3, $4, $5, $6
) RETURNING *;
-- name: GetSession :one
SELECT * FROM user_sessions
WHERE id = $1
LIMIT 1;

View File

@ -16,6 +16,11 @@ WHERE
RETURNING *; RETURNING *;
-- name: GetUser :one
SELECT * FROM USERS
WHERE username = $1;
-- name: UpdatePassword :exec -- name: UpdatePassword :exec
UPDATE users UPDATE users
SET password = $1 SET password = $1

View File

@ -244,3 +244,15 @@ type UserReport struct {
CreatedAt sql.NullTime `json:"created_at"` CreatedAt sql.NullTime `json:"created_at"`
UpdatedAt sql.NullTime `json:"updated_at"` UpdatedAt sql.NullTime `json:"updated_at"`
} }
type UserSession struct {
ID int32 `json:"id"`
IndexID int64 `json:"index_id"`
Username string `json:"username"`
RefreshToken string `json:"refresh_token"`
UserAgent string `json:"user_agent"`
ClientIp string `json:"client_ip"`
IsBlocked bool `json:"is_blocked"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt sql.NullTime `json:"created_at"`
}

View File

@ -10,12 +10,15 @@ import (
type Querier interface { type Querier interface {
CreateLocation(ctx context.Context, arg CreateLocationParams) error CreateLocation(ctx context.Context, arg CreateLocationParams) error
CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error)
CreateUser(ctx context.Context, arg CreateUserParams) (User, error) CreateUser(ctx context.Context, arg CreateUserParams) (User, error)
GetCountImageByLocation(ctx context.Context, imageOf int32) (int64, error) GetCountImageByLocation(ctx context.Context, imageOf int32) (int64, error)
GetListLocations(ctx context.Context) ([]Location, error) GetListLocations(ctx context.Context) ([]Location, error)
GetListRecentLocationsWithRatings(ctx context.Context, limit int32) ([]GetListRecentLocationsWithRatingsRow, error) GetListRecentLocationsWithRatings(ctx context.Context, limit int32) ([]GetListRecentLocationsWithRatingsRow, error)
GetLocation(ctx context.Context, id int32) (GetLocationRow, error) GetLocation(ctx context.Context, id int32) (GetLocationRow, error)
GetLocationTag(ctx context.Context, targetID int32) ([]string, error) GetLocationTag(ctx context.Context, targetID int32) ([]string, error)
GetSession(ctx context.Context, id int32) (UserSession, error)
GetUser(ctx context.Context, username string) (User, error)
UpdatePassword(ctx context.Context, arg UpdatePasswordParams) error UpdatePassword(ctx context.Context, arg UpdatePasswordParams) error
UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error)
} }

80
db/sqlc/sessions.sql.go Normal file
View File

@ -0,0 +1,80 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.20.0
// source: sessions.sql
package db
import (
"context"
"time"
)
const createSession = `-- name: CreateSession :one
INSERT INTO user_sessions (
username,
refresh_token,
user_agent,
client_ip,
is_blocked,
expires_at
) VALUES (
$1, $2, $3, $4, $5, $6
) RETURNING id, index_id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at
`
type CreateSessionParams struct {
Username string `json:"username"`
RefreshToken string `json:"refresh_token"`
UserAgent string `json:"user_agent"`
ClientIp string `json:"client_ip"`
IsBlocked bool `json:"is_blocked"`
ExpiresAt time.Time `json:"expires_at"`
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (UserSession, error) {
row := q.db.QueryRowContext(ctx, createSession,
arg.Username,
arg.RefreshToken,
arg.UserAgent,
arg.ClientIp,
arg.IsBlocked,
arg.ExpiresAt,
)
var i UserSession
err := row.Scan(
&i.ID,
&i.IndexID,
&i.Username,
&i.RefreshToken,
&i.UserAgent,
&i.ClientIp,
&i.IsBlocked,
&i.ExpiresAt,
&i.CreatedAt,
)
return i, err
}
const getSession = `-- name: GetSession :one
SELECT id, index_id, username, refresh_token, user_agent, client_ip, is_blocked, expires_at, created_at FROM user_sessions
WHERE id = $1
LIMIT 1
`
func (q *Queries) GetSession(ctx context.Context, id int32) (UserSession, error) {
row := q.db.QueryRowContext(ctx, getSession, id)
var i UserSession
err := row.Scan(
&i.ID,
&i.IndexID,
&i.Username,
&i.RefreshToken,
&i.UserAgent,
&i.ClientIp,
&i.IsBlocked,
&i.ExpiresAt,
&i.CreatedAt,
)
return i, err
}

View File

@ -48,6 +48,36 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e
return i, err return i, err
} }
const getUser = `-- name: GetUser :one
SELECT id, email, username, password, avatar_picture, google_sign_in_payload, banned_at, banned_until, ban_reason, is_permaban, is_admin, is_critics, is_verified, is_active, social_media, created_at, updated_at FROM USERS
WHERE username = $1
`
func (q *Queries) GetUser(ctx context.Context, username string) (User, error) {
row := q.db.QueryRowContext(ctx, getUser, username)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.Username,
&i.Password,
&i.AvatarPicture,
&i.GoogleSignInPayload,
&i.BannedAt,
&i.BannedUntil,
&i.BanReason,
&i.IsPermaban,
&i.IsAdmin,
&i.IsCritics,
&i.IsVerified,
&i.IsActive,
&i.SocialMedia,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updatePassword = `-- name: UpdatePassword :exec const updatePassword = `-- name: UpdatePassword :exec
UPDATE users UPDATE users
SET password = $1 SET password = $1

View File

@ -3,6 +3,6 @@ package token
import "time" import "time"
type Maker interface { type Maker interface {
CreateToken(email string, userID int, duration time.Duration) (string, *Payload, error) CreateToken(username string, userID int, duration time.Duration) (string, *Payload, error)
VerifyToken(token string) (*Payload, error) VerifyToken(token string) (*Payload, error)
} }

View File

@ -26,8 +26,8 @@ func NewPasetoMaker(symmetricKey string) (Maker, error) {
return maker, nil return maker, nil
} }
func (maker *PasetoMaker) CreateToken(email string, UserID int, duration time.Duration) (string, *Payload, error) { func (maker *PasetoMaker) CreateToken(Username string, UserID int, duration time.Duration) (string, *Payload, error) {
payload, err := NewPayload(email, UserID, duration) payload, err := NewPayload(Username, UserID, duration)
if err != nil { if err != nil {
return "", payload, err return "", payload, err
} }

View File

@ -31,7 +31,7 @@ func TestPasetoMaker(t *testing.T) {
require.NotEmpty(t, payload) require.NotEmpty(t, payload)
// require.NotZero(t, payload.ID) // require.NotZero(t, payload.ID)
require.Equal(t, email, payload.Email) require.Equal(t, email, payload.Username)
require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second)
require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second)
} }

View File

@ -11,15 +11,15 @@ var (
) )
type Payload struct { type Payload struct {
Email string `json:"email"` Username string `json:"email"`
UserID int `json:"user_id"` UserID int `json:"user_id"`
IssuedAt time.Time `json:"issued_at"` IssuedAt time.Time `json:"issued_at"`
ExpiredAt time.Time `json:"expired_at"` ExpiredAt time.Time `json:"expired_at"`
} }
func NewPayload(email string, user_id int, duration time.Duration) (*Payload, error) { func NewPayload(username string, user_id int, duration time.Duration) (*Payload, error) {
payload := &Payload{ payload := &Payload{
Email: email, Username: username,
UserID: user_id, UserID: user_id,
IssuedAt: time.Now(), IssuedAt: time.Now(),
ExpiredAt: time.Now().Add(duration), ExpiredAt: time.Now().Add(duration),