You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

526 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package model
import (
"fmt"
"moredoc/util"
"strings"
"time"
"github.com/alexandrevicenzi/unchained"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
UserStatusNormal = 0
UserStatusDisabled = 1 // 禁用账户:禁止登录、禁止评论、禁止上传、禁止下载、禁止收藏
UserStatusCommentLimited = 2 // 禁止评论
UserStatusUploadLimited = 3 // 禁止上传
UserStatusDownloadLimited = 4 // 禁止下载
UserStatusFavoriteLimited = 5 // 禁止收藏
)
// 用户的公开信息字段
var UserPublicFields = []string{
"id", "username", "signature", "status", "avatar", "realname",
"doc_count", "follow_count", "fans_count", "favorite_count", "comment_count",
"created_at", "updated_at", "login_at", "credit_count",
}
type User struct {
Id int64 `form:"id" json:"id,omitempty" gorm:"primaryKey;autoIncrement;column:id;comment:用户 id;"`
Username string `form:"username" json:"username,omitempty" gorm:"column:username;type:varchar(64);size:64;index:username,unique;comment:用户名;"`
Password string `form:"password" json:"password,omitempty" gorm:"column:password;type:varchar(128);size:128;comment:密码;"`
Mobile string `form:"mobile" json:"mobile,omitempty" gorm:"column:mobile;type:varchar(20);size:20;index:mobile;comment:手机号;"`
Email string `form:"email" json:"email,omitempty" gorm:"column:email;type:varchar(64);size:64;index:idx_email;comment:联系邮箱;"`
Address string `form:"address" json:"address,omitempty" gorm:"column:address;type:varchar(255);size:255;comment:联系地址;"`
Signature string `form:"signature" json:"signature,omitempty" gorm:"column:signature;type:varchar(255);size:255;comment:签名;"`
LastLoginIp string `form:"last_login_ip" json:"last_login_ip,omitempty" gorm:"column:last_login_ip;type:varchar(64);size:64;comment:最后登录 ip 地址;"`
RegisterIp string `form:"register_ip" json:"register_ip,omitempty" gorm:"column:register_ip;type:varchar(64);size:64;comment:注册ip;"`
DocCount int `form:"doc_count" json:"doc_count,omitempty" gorm:"column:doc_count;type:int(10);default:0;comment:上传的文档数;"`
FollowCount int `form:"follow_count" json:"follow_count,omitempty" gorm:"column:follow_count;type:int(10);default:0;comment:关注数;"`
FansCount int `form:"fans_count" json:"fans_count,omitempty" gorm:"column:fans_count;type:int(10);default:0;comment:粉丝数;"`
FavoriteCount int `form:"favorite_count" json:"favorite_count,omitempty" gorm:"column:favorite_count;type:int(10);default:0;comment:收藏数;"`
CommentCount int `form:"comment_count" json:"comment_count,omitempty" gorm:"column:comment_count;type:int(11);size:11;default:0;comment:评论数;"`
CreditCount int `form:"credit_count" json:"credit_count,omitempty" gorm:"column:credit_count;type:int(11);size:11;default:0;comment:积分数量;"`
Avatar string `form:"avatar" json:"avatar,omitempty" gorm:"column:avatar;type:varchar(255);size:255;comment:头像;"`
Identity string `form:"identity" json:"identity,omitempty" gorm:"column:identity;type:char(18);size:18;comment:身份证号码;"`
Realname string `form:"realname" json:"realname,omitempty" gorm:"column:realname;type:varchar(20);size:20;comment:身份证姓名;"`
LoginAt *time.Time `form:"login_at" json:"login_at,omitempty" gorm:"column:login_at;type:datetime;comment:最后登录时间;"`
CreatedAt *time.Time `form:"created_at" json:"created_at,omitempty" gorm:"column:created_at;type:datetime;comment:创建时间;"`
UpdatedAt *time.Time `form:"updated_at" json:"updated_at,omitempty" gorm:"column:updated_at;type:datetime;comment:更新时间;"`
// Status int8 `form:"status" json:"status,omitempty" gorm:"column:status;type:tinyint(4);size:4;default:0;index:status;comment:用户状态0表示正常其他状态表示被惩罚;"`
}
func (User) TableName() string {
return tablePrefix + "user"
}
// GetUserPublicFields 获取用户公开字段
func (m *DBModel) GetUserPublicFields() []string {
return []string{"id", "username", "avatar", "signature", "doc_count", "follow_count", "fans_count", "favorite_count", "comment_count", "credit_count"}
}
// CreateUser 创建User
func (m *DBModel) CreateUser(user *User, groupIds ...int64) (err error) {
user.Password, _ = unchained.MakePassword(user.Password, unchained.GetRandomString(4), "md5")
sess := m.db.Begin()
defer func() {
if err != nil {
sess.Rollback()
} else {
sess.Commit()
}
}()
// 1. 添加用户
err = sess.Create(user).Error
if err != nil {
m.logger.Error("CreateUser", zap.Error(err))
return
}
// 2. 添加用户组
for _, groupId := range groupIds {
group := &UserGroup{
UserId: user.Id,
GroupId: groupId,
}
err = sess.Create(group).Error
if err != nil {
m.logger.Error("CreateUser", zap.Error(err))
return
}
// 3. 添加用户统计
err = sess.Model(&Group{}).Where("id = ?", groupId).Update("user_count", gorm.Expr("user_count + ?", 1)).Error
if err != nil {
m.logger.Error("CreateUser", zap.Error(err))
return
}
}
return
}
// UpdateUserPassword 更新User密码
func (m *DBModel) UpdateUserPassword(id interface{}, newPassword string, tx ...*gorm.DB) (err error) {
newPassword, _ = unchained.MakePassword(newPassword, unchained.GetRandomString(4), "md5")
user := &User{}
sess := m.db.Model(user)
if len(tx) > 0 {
sess = tx[0].Model(user)
}
err = sess.Where("id = ?", id).Update("password", newPassword).Error
if err != nil {
m.logger.Error("UpdateUserPassword", zap.Error(err))
}
return
}
// UpdateUser 更新User如果需要更新指定字段则请指定updateFields参数
func (m *DBModel) UpdateUser(user *User, updateFields ...string) (err error) {
db := m.db.Model(user)
updateFields = m.FilterValidFields(User{}.TableName(), updateFields...)
if len(updateFields) > 0 { // 更新指定字段
db = db.Select(updateFields)
}
err = db.Where("id = ?", user.Id).Updates(user).Error
if err != nil {
m.logger.Error("UpdateUser", zap.Error(err))
}
return
}
// GetUser 根据id获取User
func (m *DBModel) GetUser(id int64, fields ...string) (user User, err error) {
if id <= 0 {
return user, gorm.ErrRecordNotFound
}
db := m.db
fields = m.FilterValidFields(User{}.TableName(), fields...)
if len(fields) > 0 {
db = db.Select(fields)
}
err = db.Where("id = ?", id).First(&user).Error
return
}
// GetUserByUsername(username string, fields ...string) 根据唯一索引获取User
func (m *DBModel) GetUserByUsername(username string, fields ...string) (user User, err error) {
db := m.db
fields = m.FilterValidFields(User{}.TableName(), fields...)
if len(fields) > 0 {
db = db.Select(fields)
}
db = db.Where("username = ?", username)
err = db.First(&user).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserByUsername", zap.Error(err))
return
}
return
}
func (m *DBModel) GetUserByEmail(email string, fields ...string) (user User, err error) {
db := m.db
fields = m.FilterValidFields(User{}.TableName(), fields...)
if len(fields) > 0 {
db = db.Select(fields)
}
db = db.Where("email = ?", email)
err = db.First(&user).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserByEmail", zap.Error(err))
return
}
return
}
type OptionGetUserList struct {
Page int
Size int
WithCount bool // 是否返回总数
Ids []int64 // id列表
SelectFields []string // 查询字段
QueryRange map[string][2]interface{} // map[field][]{min,max}
QueryIn map[string][]interface{} // map[field][]{value1,value2,...}
QueryLike map[string][]interface{} // map[field][]{value1,value2,...}
Sort []string
}
// GetUserList 获取User列表
func (m *DBModel) GetUserList(opt *OptionGetUserList) (userList []User, total int64, err error) {
db := m.db.Model(&User{})
for field, rangeValue := range opt.QueryRange {
fields := m.FilterValidFields(User{}.TableName(), field)
if len(fields) == 0 {
continue
}
if rangeValue[0] != nil {
db = db.Where(fmt.Sprintf("%s >= ?", field), rangeValue[0])
}
if rangeValue[1] != nil {
db = db.Where(fmt.Sprintf("%s <= ?", field), rangeValue[1])
}
}
for field, values := range opt.QueryIn {
if field == "group_id" {
db = db.Joins(fmt.Sprintf("left JOIN %s ug ON ug.user_id = %s.id", UserGroup{}.TableName(), User{}.TableName())).Where("ug.group_id in (?)", values)
continue
}
fields := m.FilterValidFields(User{}.TableName(), field)
if len(fields) == 0 {
continue
}
db = db.Where(fmt.Sprintf("%s in (?)", field), values)
}
db = m.generateQueryLike(db, User{}.TableName(), opt.QueryLike)
if len(opt.Ids) > 0 {
db = db.Where("id in (?)", opt.Ids)
}
if opt.WithCount {
err = db.Count(&total).Error
if err != nil {
m.logger.Error("GetUserList", zap.Error(err))
return
}
}
opt.SelectFields = m.FilterValidFields(User{}.TableName(), opt.SelectFields...)
if len(opt.SelectFields) > 0 {
db = db.Select(opt.SelectFields)
}
var sorts []string
if len(opt.Sort) > 0 {
db = m.generateQuerySort(db, User{}.TableName(), opt.Sort)
} else {
sorts = append(sorts, "id desc")
}
if len(sorts) > 0 {
db = db.Order(strings.Join(sorts, ","))
}
opt.Page = util.LimitMin(opt.Page, 1)
opt.Size = util.LimitRange(opt.Size, 10, 1000)
db = db.Offset((opt.Page - 1) * opt.Size).Limit(opt.Size)
err = db.Find(&userList).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserList", zap.Error(err))
}
return
}
// DeleteUser 删除数据
// TODO: 删除数据之后,存在 user_id 的关联表,需要删除对应数据,同时相关表的统计数值,也要随着减少
// TODO: 删除关联表数据,以及关联表的关联表数据,同时相关文件也一并删除掉
func (m *DBModel) DeleteUser(ids []int64) (err error) {
sess := m.db.Begin()
defer func() {
if err != nil {
sess.Rollback()
} else {
sess.Commit()
}
}()
// id==1的用户不允许删除
err = sess.Where("id in (?) and id != ?", ids, 1).Delete(&User{}).Error
if err != nil {
m.logger.Error("DeleteUser", zap.Error(err))
}
return
}
func (m *DBModel) initUser() (err error) {
// 如果不存在任意用户,则初始化一个用户作为管理员
var existUser User
m.db.Select("id").First(&existUser)
if existUser.Id > 0 {
return
}
// 初始化一个用户
user := &User{Username: "admin", Password: "mnt.ltd"}
var groupId int64 = 1 // ID==1的用户组为管理员组
err = m.CreateUser(user, groupId)
if err != nil {
m.logger.Error("initUser", zap.Error(err))
}
return
}
// GetUserPermissinsByUserId 根据用户ID获取用户权限
func (m *DBModel) GetUserPermissinsByUserId(userId int64) (permissions []*Permission, err error) {
if userId == 1 {
// id==1的用户拥有所有权限
err = m.db.Find(&permissions).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserPermissinsByUserId", zap.Error(err))
}
return
}
sql := `SELECT
p.*
FROM
%s p
LEFT JOIN
%s gp
ON
p.id = gp.permission_id
LEFT JOIN
%s ug
ON
ug.group_id=gp.group_id
WHERE
ug.user_id=?
group by p.id
`
sql = fmt.Sprintf(sql, Permission{}.TableName(), GroupPermission{}.TableName(), UserGroup{}.TableName())
err = m.db.Raw(sql, userId).Find(&permissions).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserPermissinsByUserId", zap.Error(err))
return
}
err = nil
return
}
// SetUserGroupAndPassword 设置用户组和密码
func (m *DBModel) SetUserGroupAndPassword(userId int64, groupId []int64, password ...string) (err error) {
tx := m.db.Begin()
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
var (
existUsersGroups []UserGroup
userGroups []UserGroup
)
tx.Where("user_id = ?", userId).Find(&existUsersGroups)
// 删除旧的关联用户组
err = tx.Where("user_id = ?", userId).Delete(&UserGroup{}).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
// 设置用户组统计数据
for _, existUsersGroup := range existUsersGroups {
err = tx.Model(&Group{}).Where("id = ?", existUsersGroup.GroupId).Update("user_count", gorm.Expr("user_count - ?", 1)).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
}
// 设置用户组统计数据
for _, groupId := range groupId {
userGroups = append(userGroups, UserGroup{UserId: userId, GroupId: groupId})
err = tx.Model(&Group{}).Where("id = ?", groupId).Update("user_count", gorm.Expr("user_count + ?", 1)).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
}
if len(userGroups) > 0 {
err = tx.Create(&userGroups).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
}
if len(password) > 0 && password[0] != "" {
err = m.UpdateUserPassword(userId, password[0], tx)
if err != nil {
m.logger.Error("UpdateUserPassword", zap.Error(err))
return
}
}
return
}
// CanIUploadDocument 判断用户是否有上传文档的权限
// 1. 用户是否被禁用或被处罚禁止上传文档
// 2. 用户所在的用户组是否允许上传文档
func (m *DBModel) CanIAccessUploadDocument(userId int64) (yes bool) {
if inPunishing, _ := m.isInPunishing(userId, []int{PunishmentTypeDisabled, PunishmentTypeUploadLimited}); inPunishing {
return false
}
var (
tableGroup = Group{}.TableName()
tableUserGroup = UserGroup{}.TableName()
group Group
)
err := m.db.Select("g.id").Table(tableGroup+" g").Joins(
"left join "+tableUserGroup+" ug on g.id=ug.group_id",
).Where("ug.user_id = ? and g.enable_upload = ?", userId, true).Find(&group).Error
if err != nil {
m.logger.Error("CanIUploadDocument", zap.Error(err))
return
}
return group.Id > 0
}
// 用户是否可以下载文档:被禁用的账号或被禁止下载的账户不能下载
func (m *DBModel) CanIAccessDownload(userId int64) (yes bool, err error) {
yes, err = m.isInPunishing(userId, []int{PunishmentTypeDownloadLimited, PunishmentTypeDisabled})
yes = !yes
if err != nil {
m.logger.Error("CanIAccessDownload", zap.Error(err))
return
}
return
}
// 用户是否可以评论
func (m *DBModel) CanIAccessComment(userId int64) (yes bool, err error) {
yes, err = m.isInPunishing(userId, []int{PunishmentTypeCommentLimited, PunishmentTypeDisabled})
yes = !yes
if err != nil {
m.logger.Error("CanIAccessComment", zap.Error(err))
return
}
return
}
// 用户是否可以收藏文档
func (m *DBModel) CanIAccessFavorite(userId int64) (yes bool, err error) {
yes, err = m.isInPunishing(userId, []int{PunishmentTypeFavoriteLimited, PunishmentTypeDisabled})
yes = !yes
if err != nil {
m.logger.Error("CanIAccessFavorite", zap.Error(err))
return
}
return
}
// 用户是否发表评论
func (m *DBModel) CanIPublishComment(userId int64) (defaultCommentStatus int8, err error) {
if userId <= 0 {
return
}
var (
group Group
userGroup UserGroup
comment Comment
)
m.db.Select("g.id").Table(group.TableName()+" g").Joins(
"left join "+userGroup.TableName()+" ug on g.id=ug.group_id",
).Where("ug.user_id = ? and g.enable_comment_approval = ?", userId, false).Find(&group)
// 评论不需要审核
if group.Id > 0 {
defaultCommentStatus = CommentStatusApproved
} else {
defaultCommentStatus = CommentStatusPending
}
// 评论时间间隔
commentInterval := m.GetConfigOfSecurity(ConfigSecurityCommentInterval).CommentInterval
if commentInterval <= 0 {
return
}
// 获取用户最新的一条评论信息
m.db.Select("id", "created_at").Where("user_id = ?", userId).Order("id desc").Find(&comment)
if comment.Id <= 0 {
return
}
seconds := int32(time.Since(*comment.CreatedAt).Seconds())
left := commentInterval - seconds
if left > 0 {
err = fmt.Errorf("您的评论太快了,请等待 %d 秒后再试", left)
return
}
return
}
func (s *DBModel) CountUser(status ...int) (count int64, err error) {
db := s.db.Model(&User{})
if len(status) > 0 {
db = db.Where("status in (?)", status)
}
err = db.Count(&count).Error
if err != nil {
s.logger.Error("CountUser", zap.Error(err))
return
}
return
}