You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

124 lines
2.8 KiB

package auth
import (
"context"
"moredoc/conf"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
"github.com/golang-jwt/jwt"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"google.golang.org/grpc"
)
// Auth issues and verifies user JWTs and provides gRPC and Gin middleware
// that attach the authenticated user's claims to the request context.
type Auth struct {
	jwt *conf.JWT // JWT settings; this file reads its Secret and ExpireDays fields
}
// UserClaims is the JWT payload produced by CreateJWTToken and returned by
// CheckJWTToken.
//
// The fields carry no json tags. Assuming the jwt library encodes claims with
// encoding/json, their Go names (e.g. "UserId") are the payload keys. Renaming
// them would therefore invalidate tokens that were already issued.
type UserClaims struct {
	UserId int64  // id of the user the token was issued for
	UUID   string // fresh identifier generated by CreateJWTToken for each token
	jwt.StandardClaims
}
// NewAuth builds an Auth around the given JWT settings.
//
// The configuration pointer is stored as-is, not copied. The parameter is
// named jwtConf so that it does not shadow the imported jwt package.
func NewAuth(jwtConf *conf.JWT) *Auth {
	a := new(Auth)
	a.jwt = jwtConf
	return a
}
// ContextKey is the type of the keys this package stores in a context.
// A dedicated named type keeps them distinct from keys of other packages
// that happen to use the same text.
type ContextKey string

// String returns the key's underlying text, e.g. for use with gin.Context.Set.
func (k ContextKey) String() string { return string(k) }

// Context keys written by this package.
const (
	// CtxKeyUserClaims holds the *UserClaims of an authenticated caller.
	CtxKeyUserClaims ContextKey = "user"
	// CtxKeyFullMethod holds the full gRPC method name of the current call.
	CtxKeyFullMethod ContextKey = "fullMethod"
)
// AuthUnaryServerInterceptor returns a unary gRPC interceptor that runs AuthGRPC.
//
// The handler is called with the context AuthGRPC returns. If AuthGRPC
// reports an error, the handler is skipped and that error is returned.
func (p *Auth) AuthUnaryServerInterceptor() grpc.UnaryServerInterceptor {
	interceptor := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		authedCtx, authErr := p.AuthGRPC(ctx, info)
		if authErr == nil {
			return handler(authedCtx, req)
		}
		return nil, authErr
	}
	return interceptor
}
// AuthGRPC stores the called method (info.FullMethod) in ctx under
// CtxKeyFullMethod.
//
// When the incoming metadata also carries a "bearer" token that
// CheckJWTToken accepts and that has not expired, the token's *UserClaims
// are stored under CtxKeyUserClaims.
//
// A missing, malformed, invalid or expired token is never reported as an
// error. The context is returned without claims and the error is always nil,
// so anonymous calls still reach the handler.
func (p *Auth) AuthGRPC(ctx context.Context, info *grpc.UnaryServerInfo) (context.Context, error) {
	methodCtx := context.WithValue(ctx, CtxKeyFullMethod, info.FullMethod)

	rawToken, mdErr := grpc_auth.AuthFromMD(methodCtx, "bearer")
	if mdErr != nil {
		return methodCtx, nil
	}

	userClaims, checkErr := p.CheckJWTToken(rawToken)
	accepted := checkErr == nil && userClaims != nil && userClaims.ExpiresAt >= time.Now().Unix()
	if !accepted {
		return methodCtx, nil
	}

	return context.WithValue(methodCtx, CtxKeyUserClaims, userClaims), nil
}
// AuthGin returns a Gin middleware that reads a JWT from the
// "Authorization: <scheme> <token>" header.
//
// Only the "bearer" scheme is accepted, matching the scheme AuthGRPC passes
// to grpc_auth.AuthFromMD. The comparison here is case-insensitive. When the
// token passes CheckJWTToken and has not expired, its *UserClaims are stored
// in the Gin context under CtxKeyUserClaims.String().
//
// Requests are never aborted. With a missing, malformed, invalid or expired
// token the chain simply continues without claims.
func (p *Auth) AuthGin() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		// Split on the first space only. An empty header yields a single
		// element and is skipped.
		parts := strings.SplitN(ctx.Request.Header.Get("authorization"), " ", 2)
		if len(parts) < 2 || !strings.EqualFold(parts[0], "bearer") {
			ctx.Next()
			return
		}

		claims, err := p.CheckJWTToken(parts[1])
		if err != nil || claims == nil || claims.ExpiresAt < time.Now().Unix() {
			// Unusable token: continue anonymously rather than rejecting the request.
			ctx.Next()
			return
		}

		ctx.Set(CtxKeyUserClaims.String(), claims)
		ctx.Next()
	}
}
// CreateJWTToken issues an HS256-signed JWT for userId, using p.jwt.Secret
// as the key.
//
// The token's claims are:
//   - the user id;
//   - a freshly generated v4 UUID;
//   - issuer "moredoc";
//   - the issue time;
//   - an expiry p.jwt.ExpireDays days after the issue time.
//
// It returns the signed token, or an error if UUID generation or signing
// fails.
func (p *Auth) CreateJWTToken(userId int64) (string, error) {
	// Return a failing random source to the caller instead of panicking
	// via uuid.Must.
	id, err := uuid.NewV4()
	if err != nil {
		return "", err
	}

	// One timestamp keeps IssuedAt and ExpiresAt derived from the same instant.
	now := time.Now()
	claims := UserClaims{
		UserId: userId,
		UUID:   id.String(),
		StandardClaims: jwt.StandardClaims{
			ExpiresAt: now.Add(time.Duration(p.jwt.ExpireDays) * 24 * time.Hour).Unix(),
			Issuer:    "moredoc",
			IssuedAt:  now.Unix(),
		},
	}
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(p.jwt.Secret))
}
// CheckJWTToken parses token and verifies it with the configured secret
// (p.jwt.Secret). On success it returns the embedded *UserClaims.
//
// Only HMAC-signed tokens are accepted. CreateJWTToken signs with HS256, an
// HMAC method. A token whose header names any other signing method is
// rejected before the secret is handed to the verifier.
//
// On any failure the returned claims are nil and the error is non-nil.
func (p *Auth) CheckJWTToken(token string) (*UserClaims, error) {
	parsed, err := jwt.ParseWithClaims(token, &UserClaims{}, func(t *jwt.Token) (interface{}, error) {
		// Defense in depth against algorithm confusion: never verify a token
		// that claims a non-HMAC method with the HMAC secret.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.NewValidationError("unexpected signing method", jwt.ValidationErrorSignatureInvalid)
		}
		return []byte(p.jwt.Secret), nil
	})
	if err != nil {
		return nil, err
	}

	if parsed != nil {
		if claims, ok := parsed.Claims.(*UserClaims); ok && parsed.Valid {
			return claims, nil
		}
	}
	// Report this explicitly instead of returning (nil, nil) as before.
	return nil, jwt.NewValidationError("invalid token claims", jwt.ValidationErrorClaimsInvalid)
}