You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

526 lines
16 KiB

package model
import (
"fmt"
2 years ago
"moredoc/util"
"strings"
"time"
"github.com/alexandrevicenzi/unchained"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
11 months ago
UserStatusNormal = 0
UserStatusDisabled = 1 // 禁用账户:禁止登录、禁止评论、禁止上传、禁止下载、禁止收藏
UserStatusCommentLimited = 2 // 禁止评论
UserStatusUploadLimited = 3 // 禁止上传
UserStatusDownloadLimited = 4 // 禁止下载
UserStatusFavoriteLimited = 5 // 禁止收藏
)
// 用户的公开信息字段
var UserPublicFields = []string{
"id", "username", "signature", "status", "avatar", "realname",
"doc_count", "follow_count", "fans_count", "favorite_count", "comment_count",
"created_at", "updated_at", "login_at", "credit_count",
}
type User struct {
Id int64 `form:"id" json:"id,omitempty" gorm:"primaryKey;autoIncrement;column:id;comment:用户 id;"`
Username string `form:"username" json:"username,omitempty" gorm:"column:username;type:varchar(64);size:64;index:username,unique;comment:用户名;"`
Password string `form:"password" json:"password,omitempty" gorm:"column:password;type:varchar(128);size:128;comment:密码;"`
Mobile string `form:"mobile" json:"mobile,omitempty" gorm:"column:mobile;type:varchar(20);size:20;index:mobile;comment:手机号;"`
Email string `form:"email" json:"email,omitempty" gorm:"column:email;type:varchar(64);size:64;index:idx_email;comment:联系邮箱;"`
Address string `form:"address" json:"address,omitempty" gorm:"column:address;type:varchar(255);size:255;comment:联系地址;"`
Signature string `form:"signature" json:"signature,omitempty" gorm:"column:signature;type:varchar(255);size:255;comment:签名;"`
LastLoginIp string `form:"last_login_ip" json:"last_login_ip,omitempty" gorm:"column:last_login_ip;type:varchar(64);size:64;comment:最后登录 ip 地址;"`
RegisterIp string `form:"register_ip" json:"register_ip,omitempty" gorm:"column:register_ip;type:varchar(64);size:64;comment:注册ip;"`
DocCount int `form:"doc_count" json:"doc_count,omitempty" gorm:"column:doc_count;type:int(10);default:0;comment:上传的文档数;"`
FollowCount int `form:"follow_count" json:"follow_count,omitempty" gorm:"column:follow_count;type:int(10);default:0;comment:关注数;"`
FansCount int `form:"fans_count" json:"fans_count,omitempty" gorm:"column:fans_count;type:int(10);default:0;comment:粉丝数;"`
FavoriteCount int `form:"favorite_count" json:"favorite_count,omitempty" gorm:"column:favorite_count;type:int(10);default:0;comment:收藏数;"`
CommentCount int `form:"comment_count" json:"comment_count,omitempty" gorm:"column:comment_count;type:int(11);size:11;default:0;comment:评论数;"`
CreditCount int `form:"credit_count" json:"credit_count,omitempty" gorm:"column:credit_count;type:int(11);size:11;default:0;comment:积分数量;"`
Avatar string `form:"avatar" json:"avatar,omitempty" gorm:"column:avatar;type:varchar(255);size:255;comment:头像;"`
Identity string `form:"identity" json:"identity,omitempty" gorm:"column:identity;type:char(18);size:18;comment:身份证号码;"`
Realname string `form:"realname" json:"realname,omitempty" gorm:"column:realname;type:varchar(20);size:20;comment:身份证姓名;"`
LoginAt *time.Time `form:"login_at" json:"login_at,omitempty" gorm:"column:login_at;type:datetime;comment:最后登录时间;"`
CreatedAt *time.Time `form:"created_at" json:"created_at,omitempty" gorm:"column:created_at;type:datetime;comment:创建时间;"`
UpdatedAt *time.Time `form:"updated_at" json:"updated_at,omitempty" gorm:"column:updated_at;type:datetime;comment:更新时间;"`
11 months ago
// Status int8 `form:"status" json:"status,omitempty" gorm:"column:status;type:tinyint(4);size:4;default:0;index:status;comment:用户状态0表示正常其他状态表示被惩罚;"`
}
func (User) TableName() string {
return tablePrefix + "user"
}
2 years ago
// GetUserPublicFields 获取用户公开字段
func (m *DBModel) GetUserPublicFields() []string {
return []string{"id", "username", "avatar", "signature", "doc_count", "follow_count", "fans_count", "favorite_count", "comment_count", "credit_count"}
2 years ago
}
// CreateUser 创建User
2 years ago
func (m *DBModel) CreateUser(user *User, groupIds ...int64) (err error) {
user.Password, _ = unchained.MakePassword(user.Password, unchained.GetRandomString(4), "md5")
sess := m.db.Begin()
defer func() {
if err != nil {
sess.Rollback()
} else {
sess.Commit()
}
}()
// 1. 添加用户
err = sess.Create(user).Error
if err != nil {
m.logger.Error("CreateUser", zap.Error(err))
return
}
// 2. 添加用户组
2 years ago
for _, groupId := range groupIds {
group := &UserGroup{
UserId: user.Id,
GroupId: groupId,
}
err = sess.Create(group).Error
if err != nil {
m.logger.Error("CreateUser", zap.Error(err))
return
}
2 years ago
// 3. 添加用户统计
err = sess.Model(&Group{}).Where("id = ?", groupId).Update("user_count", gorm.Expr("user_count + ?", 1)).Error
if err != nil {
m.logger.Error("CreateUser", zap.Error(err))
return
}
}
return
}
// UpdateUserPassword 更新User密码
2 years ago
func (m *DBModel) UpdateUserPassword(id interface{}, newPassword string, tx ...*gorm.DB) (err error) {
newPassword, _ = unchained.MakePassword(newPassword, unchained.GetRandomString(4), "md5")
2 years ago
user := &User{}
sess := m.db.Model(user)
if len(tx) > 0 {
sess = tx[0].Model(user)
}
err = sess.Where("id = ?", id).Update("password", newPassword).Error
if err != nil {
m.logger.Error("UpdateUserPassword", zap.Error(err))
}
return
}
// UpdateUser 更新User如果需要更新指定字段则请指定updateFields参数
func (m *DBModel) UpdateUser(user *User, updateFields ...string) (err error) {
db := m.db.Model(user)
updateFields = m.FilterValidFields(User{}.TableName(), updateFields...)
if len(updateFields) > 0 { // 更新指定字段
db = db.Select(updateFields)
}
err = db.Where("id = ?", user.Id).Updates(user).Error
if err != nil {
m.logger.Error("UpdateUser", zap.Error(err))
}
return
}
// GetUser 根据id获取User
func (m *DBModel) GetUser(id int64, fields ...string) (user User, err error) {
if id <= 0 {
return user, gorm.ErrRecordNotFound
}
db := m.db
fields = m.FilterValidFields(User{}.TableName(), fields...)
if len(fields) > 0 {
db = db.Select(fields)
}
err = db.Where("id = ?", id).First(&user).Error
return
}
// GetUserByUsername(username string, fields ...string) 根据唯一索引获取User
func (m *DBModel) GetUserByUsername(username string, fields ...string) (user User, err error) {
db := m.db
fields = m.FilterValidFields(User{}.TableName(), fields...)
if len(fields) > 0 {
db = db.Select(fields)
}
db = db.Where("username = ?", username)
err = db.First(&user).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserByUsername", zap.Error(err))
return
}
return
}
func (m *DBModel) GetUserByEmail(email string, fields ...string) (user User, err error) {
db := m.db
fields = m.FilterValidFields(User{}.TableName(), fields...)
if len(fields) > 0 {
db = db.Select(fields)
}
db = db.Where("email = ?", email)
err = db.First(&user).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserByEmail", zap.Error(err))
return
}
return
}
type OptionGetUserList struct {
Page int
Size int
WithCount bool // 是否返回总数
11 months ago
Ids []int64 // id列表
SelectFields []string // 查询字段
QueryRange map[string][2]interface{} // map[field][]{min,max}
QueryIn map[string][]interface{} // map[field][]{value1,value2,...}
QueryLike map[string][]interface{} // map[field][]{value1,value2,...}
Sort []string
}
// GetUserList 获取User列表
func (m *DBModel) GetUserList(opt *OptionGetUserList) (userList []User, total int64, err error) {
db := m.db.Model(&User{})
for field, rangeValue := range opt.QueryRange {
fields := m.FilterValidFields(User{}.TableName(), field)
if len(fields) == 0 {
continue
}
if rangeValue[0] != nil {
db = db.Where(fmt.Sprintf("%s >= ?", field), rangeValue[0])
}
if rangeValue[1] != nil {
db = db.Where(fmt.Sprintf("%s <= ?", field), rangeValue[1])
}
}
for field, values := range opt.QueryIn {
if field == "group_id" {
db = db.Joins(fmt.Sprintf("left JOIN %s ug ON ug.user_id = %s.id", UserGroup{}.TableName(), User{}.TableName())).Where("ug.group_id in (?)", values)
continue
}
fields := m.FilterValidFields(User{}.TableName(), field)
if len(fields) == 0 {
continue
}
db = db.Where(fmt.Sprintf("%s in (?)", field), values)
}
2 years ago
db = m.generateQueryLike(db, User{}.TableName(), opt.QueryLike)
if len(opt.Ids) > 0 {
db = db.Where("id in (?)", opt.Ids)
}
if opt.WithCount {
err = db.Count(&total).Error
if err != nil {
m.logger.Error("GetUserList", zap.Error(err))
return
}
}
opt.SelectFields = m.FilterValidFields(User{}.TableName(), opt.SelectFields...)
if len(opt.SelectFields) > 0 {
db = db.Select(opt.SelectFields)
}
var sorts []string
if len(opt.Sort) > 0 {
2 years ago
db = m.generateQuerySort(db, User{}.TableName(), opt.Sort)
} else {
sorts = append(sorts, "id desc")
}
if len(sorts) > 0 {
db = db.Order(strings.Join(sorts, ","))
}
2 years ago
opt.Page = util.LimitMin(opt.Page, 1)
opt.Size = util.LimitRange(opt.Size, 10, 1000)
2 years ago
db = db.Offset((opt.Page - 1) * opt.Size).Limit(opt.Size)
err = db.Find(&userList).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserList", zap.Error(err))
}
return
}
// DeleteUser 删除数据
// TODO: 删除数据之后,存在 user_id 的关联表,需要删除对应数据,同时相关表的统计数值,也要随着减少
2 years ago
// TODO: 删除关联表数据,以及关联表的关联表数据,同时相关文件也一并删除掉
func (m *DBModel) DeleteUser(ids []int64) (err error) {
sess := m.db.Begin()
defer func() {
if err != nil {
sess.Rollback()
} else {
sess.Commit()
}
}()
2 years ago
// id==1的用户不允许删除
err = sess.Where("id in (?) and id != ?", ids, 1).Delete(&User{}).Error
if err != nil {
m.logger.Error("DeleteUser", zap.Error(err))
}
return
}
func (m *DBModel) initUser() (err error) {
// 如果不存在任意用户,则初始化一个用户作为管理员
var existUser User
m.db.Select("id").First(&existUser)
if existUser.Id > 0 {
return
}
// 初始化一个用户
2 years ago
user := &User{Username: "admin", Password: "mnt.ltd"}
2 years ago
var groupId int64 = 1 // ID==1的用户组为管理员组
err = m.CreateUser(user, groupId)
if err != nil {
m.logger.Error("initUser", zap.Error(err))
}
return
}
2 years ago
// GetUserPermissinsByUserId 根据用户ID获取用户权限
func (m *DBModel) GetUserPermissinsByUserId(userId int64) (permissions []*Permission, err error) {
2 years ago
if userId == 1 {
// id==1的用户拥有所有权限
err = m.db.Find(&permissions).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserPermissinsByUserId", zap.Error(err))
}
return
}
2 years ago
sql := `SELECT
p.*
FROM
%s p
LEFT JOIN
%s gp
ON
p.id = gp.permission_id
LEFT JOIN
%s ug
ON
ug.group_id=gp.group_id
WHERE
ug.user_id=?
group by p.id
`
sql = fmt.Sprintf(sql, Permission{}.TableName(), GroupPermission{}.TableName(), UserGroup{}.TableName())
err = m.db.Raw(sql, userId).Find(&permissions).Error
if err != nil && err != gorm.ErrRecordNotFound {
m.logger.Error("GetUserPermissinsByUserId", zap.Error(err))
return
}
err = nil
return
}
2 years ago
// SetUserGroupAndPassword 设置用户组和密码
func (m *DBModel) SetUserGroupAndPassword(userId int64, groupId []int64, password ...string) (err error) {
tx := m.db.Begin()
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
var (
existUsersGroups []UserGroup
userGroups []UserGroup
)
tx.Where("user_id = ?", userId).Find(&existUsersGroups)
// 删除旧的关联用户组
err = tx.Where("user_id = ?", userId).Delete(&UserGroup{}).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
// 设置用户组统计数据
for _, existUsersGroup := range existUsersGroups {
err = tx.Model(&Group{}).Where("id = ?", existUsersGroup.GroupId).Update("user_count", gorm.Expr("user_count - ?", 1)).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
}
// 设置用户组统计数据
for _, groupId := range groupId {
userGroups = append(userGroups, UserGroup{UserId: userId, GroupId: groupId})
err = tx.Model(&Group{}).Where("id = ?", groupId).Update("user_count", gorm.Expr("user_count + ?", 1)).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
}
if len(userGroups) > 0 {
err = tx.Create(&userGroups).Error
if err != nil {
m.logger.Error("SetUserGroupAndPassword", zap.Error(err))
return
}
}
2 years ago
if len(password) > 0 && password[0] != "" {
2 years ago
err = m.UpdateUserPassword(userId, password[0], tx)
if err != nil {
m.logger.Error("UpdateUserPassword", zap.Error(err))
return
}
}
return
}
// CanIUploadDocument 判断用户是否有上传文档的权限
11 months ago
// 1. 用户是否被禁用或被处罚禁止上传文档
// 2. 用户所在的用户组是否允许上传文档
func (m *DBModel) CanIAccessUploadDocument(userId int64) (yes bool) {
if inPunishing, _ := m.isInPunishing(userId, []int{PunishmentTypeDisabled, PunishmentTypeUploadLimited}); inPunishing {
return false
}
var (
tableGroup = Group{}.TableName()
tableUserGroup = UserGroup{}.TableName()
group Group
)
err := m.db.Select("g.id").Table(tableGroup+" g").Joins(
"left join "+tableUserGroup+" ug on g.id=ug.group_id",
).Where("ug.user_id = ? and g.enable_upload = ?", userId, true).Find(&group).Error
if err != nil {
m.logger.Error("CanIUploadDocument", zap.Error(err))
return
}
return group.Id > 0
}
11 months ago
// 用户是否可以下载文档:被禁用的账号或被禁止下载的账户不能下载
func (m *DBModel) CanIAccessDownload(userId int64) (yes bool, err error) {
yes, err = m.isInPunishing(userId, []int{PunishmentTypeDownloadLimited, PunishmentTypeDisabled})
yes = !yes
if err != nil {
m.logger.Error("CanIAccessDownload", zap.Error(err))
return
}
return
}
// 用户是否可以评论
func (m *DBModel) CanIAccessComment(userId int64) (yes bool, err error) {
yes, err = m.isInPunishing(userId, []int{PunishmentTypeCommentLimited, PunishmentTypeDisabled})
yes = !yes
if err != nil {
m.logger.Error("CanIAccessComment", zap.Error(err))
return
}
return
}
// 用户是否可以收藏文档
func (m *DBModel) CanIAccessFavorite(userId int64) (yes bool, err error) {
yes, err = m.isInPunishing(userId, []int{PunishmentTypeFavoriteLimited, PunishmentTypeDisabled})
yes = !yes
if err != nil {
m.logger.Error("CanIAccessFavorite", zap.Error(err))
return
}
return
}
// 用户是否发表评论
func (m *DBModel) CanIPublishComment(userId int64) (defaultCommentStatus int8, err error) {
if userId <= 0 {
return
}
var (
group Group
userGroup UserGroup
comment Comment
)
m.db.Select("g.id").Table(group.TableName()+" g").Joins(
"left join "+userGroup.TableName()+" ug on g.id=ug.group_id",
).Where("ug.user_id = ? and g.enable_comment_approval = ?", userId, false).Find(&group)
// 评论不需要审核
if group.Id > 0 {
defaultCommentStatus = CommentStatusApproved
} else {
defaultCommentStatus = CommentStatusPending
}
// 评论时间间隔
commentInterval := m.GetConfigOfSecurity(ConfigSecurityCommentInterval).CommentInterval
if commentInterval <= 0 {
return
}
// 获取用户最新的一条评论信息
m.db.Select("id", "created_at").Where("user_id = ?", userId).Order("id desc").Find(&comment)
if comment.Id <= 0 {
return
}
seconds := int32(time.Since(*comment.CreatedAt).Seconds())
left := commentInterval - seconds
if left > 0 {
err = fmt.Errorf("您的评论太快了,请等待 %d 秒后再试", left)
return
}
return
}
1 year ago
func (s *DBModel) CountUser(status ...int) (count int64, err error) {
db := s.db.Model(&User{})
if len(status) > 0 {
db = db.Where("status in (?)", status)
}
err = db.Count(&count).Error
if err != nil {
s.logger.Error("CountUser", zap.Error(err))
return
}
return
}