|
|
package model
|
|
|
|
|
|
import (
|
|
|
"database/sql"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"moredoc/conf"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
|
|
|
"go.uber.org/zap"
|
|
|
"gorm.io/driver/mysql"
|
|
|
"gorm.io/gorm"
|
|
|
"gorm.io/gorm/logger"
|
|
|
"gorm.io/gorm/schema"
|
|
|
)
|
|
|
|
|
|
// TableColumn is one row of MySQL's `SHOW FULL COLUMNS FROM <table>` output.
//
// Each field is mapped explicitly to the result-set column of the same name
// with a `column:` tag. A bare tag such as `gorm:"Type"` or `gorm:"Default"` is
// not a column mapping: GORM parses it as a setting, which would give the
// field a TYPE/DEFAULT/COMMENT setting equal to the key name.
type TableColumn struct {
	Field      string `gorm:"column:Field"`      // column name
	Type       string `gorm:"column:Type"`       // column type, e.g. varchar(255)
	Collation  string `gorm:"column:Collation"`  // column collation
	Null       string `gorm:"column:Null"`       // nullability as reported by MySQL
	Key        string `gorm:"column:Key"`        // index membership as reported by MySQL
	Default    string `gorm:"column:Default"`    // default value
	Extra      string `gorm:"column:Extra"`      // extra attributes, e.g. auto_increment
	Privileges string `gorm:"column:Privileges"` // privileges on the column
	Comment    string `gorm:"column:Comment"`    // column comment
}
|
|
|
|
|
|
type TableIndex struct {
|
|
|
Table string `gorm:"column:Table"` // 表名
|
|
|
NonUnique int `gorm:"column:Non_unique"`
|
|
|
KeyName string `gorm:"column:Key_name"` // 索引名称
|
|
|
SeqInIndex int `gorm:"column:Seq_in_index"`
|
|
|
ColumnName string `gorm:"column:Column_name"` // 索引字段名称
|
|
|
Collation string `gorm:"column:Collation"` // 字符集
|
|
|
Cardinality int `gorm:"column:Cardinality"`
|
|
|
SubPart sql.NullInt64 `gorm:"column:Sub_part"`
|
|
|
Packed sql.NullString `gorm:"column:Packed"`
|
|
|
Null string `gorm:"column:Null"`
|
|
|
IndexType string `gorm:"column:Index_type"`
|
|
|
Comment string `gorm:"column:Comment"` // 索引备注
|
|
|
IndexComment string `gorm:"column:Index_comment"`
|
|
|
}
|
|
|
|
|
|
// Package-level model state.
var (
	// tablePrefix is the table-name prefix. "mnt_" is the default;
	// NewDBModel overwrites it with the configured prefix.
	tablePrefix = "mnt_"

	// convertDocumentRunning is a flag that starts out false. Presumably it is
	// set while the document-conversion loop is active (it is written outside
	// this file). TODO confirm.
	convertDocumentRunning bool
)
|
|
|
|
|
|
type DBModel struct {
|
|
|
db *gorm.DB
|
|
|
tablePrefix string
|
|
|
logger *zap.Logger
|
|
|
tableFields map[string][]string
|
|
|
tableFieldsMap map[string]map[string]struct{}
|
|
|
validToken sync.Map // map[tokenUUID]struct{} 有效的token uuid
|
|
|
invalidToken sync.Map // map[tokenUUID]struct{} 存在,未过期但无效token,比如读者退出登录后的token
|
|
|
}
|
|
|
|
|
|
func NewDBModel(cfg *conf.Database, lg *zap.Logger) (m *DBModel, err error) {
|
|
|
if lg == nil {
|
|
|
err = errors.New("logger cant be nil")
|
|
|
return
|
|
|
}
|
|
|
|
|
|
tablePrefix = cfg.Prefix
|
|
|
|
|
|
m = &DBModel{
|
|
|
logger: lg.Named("model"),
|
|
|
tablePrefix: cfg.Prefix,
|
|
|
tableFields: make(map[string][]string),
|
|
|
tableFieldsMap: make(map[string]map[string]struct{}),
|
|
|
}
|
|
|
|
|
|
var (
|
|
|
db *gorm.DB
|
|
|
sqlDB *sql.DB
|
|
|
)
|
|
|
|
|
|
sqlLogLevel := logger.Info
|
|
|
if !cfg.ShowSQL {
|
|
|
sqlLogLevel = logger.Silent
|
|
|
}
|
|
|
|
|
|
db, err = gorm.Open(mysql.New(mysql.Config{
|
|
|
DSN: cfg.DSN, // DSN data source name
|
|
|
DefaultStringSize: 255, // string 类型字段的默认长度
|
|
|
DisableDatetimePrecision: true, // 禁用 datetime 精度,MySQL 5.6 之前的数据库不支持
|
|
|
DontSupportRenameIndex: true, // 重命名索引时采用删除并新建的方式,MySQL 5.7 之前的数据库和 MariaDB 不支持重命名索引
|
|
|
DontSupportRenameColumn: true, // 用 `change` 重命名列,MySQL 8 之前的数据库和 MariaDB 不支持重命名列
|
|
|
SkipInitializeWithVersion: false, // 根据当前 MySQL 版本自动配置
|
|
|
}), &gorm.Config{
|
|
|
NamingStrategy: schema.NamingStrategy{
|
|
|
TablePrefix: cfg.Prefix, // 表名前缀,`User`表为`t_users`
|
|
|
SingularTable: true, // 使用单数表名,启用该选项后,`User` 表将是`user`
|
|
|
},
|
|
|
Logger: logger.Default.LogMode(sqlLogLevel),
|
|
|
})
|
|
|
if err != nil {
|
|
|
m.logger.Error("NewDBModel", zap.Error(err), zap.Any("config", cfg))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
sqlDB, err = db.DB()
|
|
|
if err != nil {
|
|
|
m.logger.Error("db.DB()", zap.Error(err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
if cfg.MaxIdle > 0 {
|
|
|
sqlDB.SetMaxIdleConns(cfg.MaxIdle)
|
|
|
}
|
|
|
|
|
|
if cfg.MaxOpen > 0 {
|
|
|
sqlDB.SetMaxIdleConns(cfg.MaxOpen)
|
|
|
}
|
|
|
|
|
|
m.db = db
|
|
|
|
|
|
// 获取所有数据库表,并把数据库表字段加入到全局map,以便根据指定字段查询数据
|
|
|
tables, err := m.ShowTables()
|
|
|
if err != nil {
|
|
|
m.logger.Error("ShowTables", zap.Error(err))
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
for _, table := range tables {
|
|
|
columns, err := m.showTableColumn(table)
|
|
|
if err != nil {
|
|
|
m.logger.Error("showTableColumn", zap.Error(err))
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
var fields []string
|
|
|
for _, col := range columns {
|
|
|
fields = append(fields, col.Field)
|
|
|
}
|
|
|
|
|
|
m.tableFields[table] = fields
|
|
|
filedsMap := make(map[string]struct{})
|
|
|
for _, field := range fields {
|
|
|
filedsMap[field] = struct{}{}
|
|
|
}
|
|
|
m.tableFieldsMap[table] = filedsMap
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
|
|
|
func (m *DBModel) SyncDB() (err error) {
|
|
|
tableModels := []interface{}{
|
|
|
&Attachment{},
|
|
|
&Banner{},
|
|
|
&Category{},
|
|
|
&Config{},
|
|
|
&Document{},
|
|
|
&DocumentCategory{},
|
|
|
&DocumentError{},
|
|
|
&DocumentScore{},
|
|
|
&DocumentRelate{},
|
|
|
&Download{},
|
|
|
&Friendlink{},
|
|
|
&User{},
|
|
|
&Group{},
|
|
|
&UserGroup{},
|
|
|
&Permission{},
|
|
|
&GroupPermission{},
|
|
|
&Logout{},
|
|
|
&Article{},
|
|
|
&Favorite{},
|
|
|
&Comment{},
|
|
|
&Dynamic{},
|
|
|
&Sign{},
|
|
|
&Report{},
|
|
|
&Navigation{},
|
|
|
&Punishment{},
|
|
|
&EmailCode{},
|
|
|
}
|
|
|
|
|
|
m.alterTableBeforeSyncDB()
|
|
|
if err = m.db.AutoMigrate(tableModels...); err != nil {
|
|
|
m.logger.Fatal("SyncDB", zap.Error(err))
|
|
|
return
|
|
|
}
|
|
|
m.alterTableAfterSyncDB()
|
|
|
|
|
|
if err = m.initDatabase(); err != nil {
|
|
|
m.logger.Fatal("SyncDB", zap.Error(err))
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
|
|
|
func (m *DBModel) RunTasks() {
|
|
|
go m.loopCovertDocument()
|
|
|
go m.cronUpdateSitemap()
|
|
|
go m.cronMarkAttachmentDeleted()
|
|
|
go m.cronCleanInvalidAttachment()
|
|
|
}
|
|
|
|
|
|
func (m *DBModel) GetDB() *gorm.DB {
|
|
|
return m.db
|
|
|
}
|
|
|
|
|
|
func (m *DBModel) ShowTables() (tables []string, err error) {
|
|
|
err = m.db.Raw("show tables").Scan(&tables).Error
|
|
|
if err != nil {
|
|
|
m.logger.Error("ShowTables", zap.Error(err))
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// alterTableBeforeSyncDB runs schema fix-ups that must happen before
// AutoMigrate.
//
// It drops every unique index on the user table that covers the `email`
// column, so that email can be changed to a regular (non-unique) index.
// Failures are logged and otherwise ignored.
func (m *DBModel) alterTableBeforeSyncDB() {
	userTable := User{}.TableName()
	userIndexes := m.ShowIndexes(userTable)
	m.logger.Debug("alterTableBeforeSyncDB", zap.String("table", userTable), zap.Any("indexes", userIndexes))

	for _, idx := range userIndexes {
		// Skip rows for other columns and for non-unique indexes.
		if idx.ColumnName != "email" || idx.NonUnique != 0 {
			continue
		}
		stmt := fmt.Sprintf("alter table %s drop index %s", userTable, idx.KeyName)
		if err := m.db.Exec(stmt).Error; err != nil {
			m.logger.Error("alterTableBeforeSyncDB", zap.Error(err))
		}
	}
}
|
|
|
|
|
|
// alterTableAfterSyncDB is the hook SyncDB calls right after AutoMigrate.
// It is currently a no-op; schema fix-ups that must run after migration
// belong here.
func (m *DBModel) alterTableAfterSyncDB() {
}
|
|
|
|
|
|
// ShowIndexes returns the rows of `SHOW INDEX FROM <table>`, one row per
// (index, column) pair.
//
// The table name is quoted as a MySQL identifier, so reserved words and names
// containing characters such as '-' work. Errors are logged, and whatever was
// scanned (usually nothing) is returned.
func (m *DBModel) ShowIndexes(table string) (indexes []TableIndex) {
	quoted := "`" + strings.ReplaceAll(table, "`", "``") + "`"
	if err := m.db.Raw("show index from " + quoted).Find(&indexes).Error; err != nil {
		m.logger.Error("ShowIndexes", zap.Error(err), zap.String("table", table))
	}
	return indexes
}
|
|
|
|
|
|
// FilterValidFields returns those fields that are real columns of tableName.
// Every other field is dropped.
//
// Each field is trimmed and lower-cased before it is looked up. Returned names
// are in that canonical form.
//
// tableName may be:
//   - "table";
//   - "table alias";
//   - "table AS alias".
//
// Surrounding or repeated whitespace is tolerated. When an alias is present,
// each returned field is qualified as "alias.field". An unknown table yields
// nil.
func (m *DBModel) FilterValidFields(tableName string, fields ...string) (validFields []string) {
	alias := ""
	parts := strings.Fields(tableName)
	switch {
	case len(parts) == 1:
		tableName = parts[0]
	case len(parts) == 2:
		tableName, alias = parts[0], parts[1]+"."
	case len(parts) == 3 && strings.EqualFold(parts[1], "as"):
		tableName, alias = parts[0], parts[2]+"."
	}

	fieldsMap, ok := m.tableFieldsMap[tableName]
	if !ok {
		return nil
	}

	for _, field := range fields {
		field = strings.ToLower(strings.TrimSpace(field))
		if _, ok := fieldsMap[field]; ok {
			validFields = append(validFields, alias+field)
		}
	}
	return validFields
}
|
|
|
|
|
|
// GetTableFields returns every column of tableName. The columns are in the
// table's own column order, as captured from SHOW FULL COLUMNS at start-up,
// so the result is deterministic.
//
// tableName may be:
//   - "table";
//   - "table alias";
//   - "table AS alias".
//
// Surrounding or repeated whitespace is tolerated. When an alias is present,
// each column is qualified as "alias.column". An unknown table yields nil.
func (m *DBModel) GetTableFields(tableName string) (fields []string) {
	alias := ""
	parts := strings.Fields(tableName)
	switch {
	case len(parts) == 1:
		tableName = parts[0]
	case len(parts) == 2:
		tableName, alias = parts[0], parts[1]+"."
	case len(parts) == 3 && strings.EqualFold(parts[1], "as"):
		tableName, alias = parts[0], parts[2]+"."
	}

	columns, ok := m.tableFields[tableName]
	if !ok || len(columns) == 0 {
		return nil
	}

	fields = make([]string, 0, len(columns))
	for _, column := range columns {
		fields = append(fields, alias+column)
	}
	return fields
}
|
|
|
|
|
|
// showTableColumn returns the rows of `SHOW FULL COLUMNS FROM <tableName>`.
//
// tableName is quoted as a MySQL identifier. SHOW TABLES can return names that
// are reserved words or contain characters such as '-'; unquoted, these would
// make the statement (and therefore NewDBModel) fail. Errors are logged and
// returned.
func (m *DBModel) showTableColumn(tableName string) (columns []TableColumn, err error) {
	quoted := "`" + strings.ReplaceAll(tableName, "`", "``") + "`"
	err = m.db.Raw("SHOW FULL COLUMNS FROM " + quoted).Find(&columns).Error
	if err != nil {
		m.logger.Error("ShowTableColumn", zap.Error(err), zap.String("table", tableName))
	}
	return columns, err
}
|
|
|
|
|
|
// initialDatabase 初始化数据库相关数据
|
|
|
func (m *DBModel) initDatabase() (err error) {
|
|
|
// 初始化用户组及其权限
|
|
|
if err = m.initGroupAndPermission(); err != nil {
|
|
|
m.logger.Error("initGroupAndPermission", zap.Error(err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// 初始化用户
|
|
|
if err = m.initUser(); err != nil {
|
|
|
m.logger.Error("initUser", zap.Error(err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// 初始化配置
|
|
|
if err = m.initConfig(); err != nil {
|
|
|
m.logger.Error("initConfig", zap.Error(err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// 初始化文章
|
|
|
m.initArticle()
|
|
|
|
|
|
// 初始化友情链接
|
|
|
if err = m.initFriendlink(); err != nil {
|
|
|
m.logger.Error("initFriendlink", zap.Error(err))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// 初始化静态页面SEO
|
|
|
m.InitSEO()
|
|
|
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// initGroupAndPermission seeds the built-in user groups and the permission
// table.
//
// Groups: the two default groups are inserted only when the group table has no
// row at all:
//   - id 1, "超级管理员": super administrator, upload enabled;
//   - id 2, "普通用户": regular user, marked as the default group.
//
// Permissions: every permission from getPermissions() is ensured to exist,
// keyed by (method, path).
//
// Group creation and permission seeding run in one transaction. db.Transaction
// commits on success, rolls back on error or panic, and also returns a failed
// Begin/Commit as an error.
func (m *DBModel) initGroupAndPermission() (err error) {
	groups := []Group{
		{Id: 1, Title: "超级管理员", IsDisplay: true, Description: "系统超级管理员", UserCount: 0, Sort: 0, EnableUpload: true},
		{Id: 2, Title: "普通用户", IsDisplay: true, Description: "普通用户", UserCount: 0, Sort: 0, IsDefault: true},
	}

	// An empty table yields ErrRecordNotFound, which simply means "seed the
	// groups". Any other error means we cannot tell, so abort instead of
	// guessing.
	var existGroup Group
	if err = m.db.First(&existGroup).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		m.logger.Error("initGroup", zap.Error(err))
		return err
	}

	return m.db.Transaction(func(tx *gorm.DB) error {
		if existGroup.Id == 0 {
			if createErr := tx.Create(&groups).Error; createErr != nil {
				m.logger.Error("initGroup", zap.Error(createErr))
				return createErr
			}
		}

		for _, permission := range getPermissions() {
			permErr := tx.Where("method = ? and path = ?", permission.Method, permission.Path).FirstOrCreate(&permission).Error
			if permErr != nil {
				m.logger.Error("initPermission", zap.Error(permErr))
				return permErr
			}
		}
		return nil
	})
}
|
|
|
|
|
|
func (m *DBModel) initFriendlink() (err error) {
|
|
|
var friendlink Friendlink
|
|
|
m.db.Find(&friendlink)
|
|
|
if friendlink.Id > 0 {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// 默认友链
|
|
|
var friendlinks = []Friendlink{
|
|
|
{Title: "摩枫网络科技", Link: "https://mnt.ltd", Enable: true},
|
|
|
{Title: "书栈网", Link: "https://www.bookstack.cn", Enable: true},
|
|
|
}
|
|
|
|
|
|
err = m.db.Create(&friendlinks).Error
|
|
|
if err != nil {
|
|
|
m.logger.Error("initFriendlink", zap.Error(err))
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
|
|
|
// generateQueryLike adds one WHERE condition to db. The condition OR-combines a
// `column like %value%` test for every (field, value) pair in queryLike.
//
// tableName may be "table" or "table alias". Fields that are not columns of the
// table are ignored. Only the validated, canonical column name returned by
// FilterValidFields (alias-qualified when an alias is given) is put into the
// SQL text; values are always bound as parameters.
//
// The OR group is wrapped in parentheses so it composes correctly with other
// AND-ed conditions on db. If no valid field remains, db is returned unchanged.
func (m *DBModel) generateQueryLike(db *gorm.DB, tableName string, queryLike map[string][]interface{}) *gorm.DB {
	var (
		likeQuery  []string
		likeValues []interface{}
	)

	for field, values := range queryLike {
		valid := m.FilterValidFields(tableName, field)
		if len(valid) == 0 {
			continue
		}
		column := valid[0]
		for _, value := range values {
			likeQuery = append(likeQuery, column+" like ?")
			likeValues = append(likeValues, "%"+fmt.Sprintf("%v", value)+"%")
		}
	}

	if len(likeQuery) > 0 {
		db = db.Where("("+strings.Join(likeQuery, " or ")+")", likeValues...)
	}
	return db
}
|
|
|
|
|
|
func (m *DBModel) generateQueryRange(db *gorm.DB, tableName string, queryRange map[string][2]interface{}) *gorm.DB {
|
|
|
alias := ""
|
|
|
slice := strings.Split(tableName, " ")
|
|
|
if len(slice) == 2 {
|
|
|
tableName = slice[0]
|
|
|
alias = slice[1] + "."
|
|
|
}
|
|
|
|
|
|
for field, rangeValue := range queryRange {
|
|
|
fields := m.FilterValidFields(tableName, field)
|
|
|
if len(fields) == 0 {
|
|
|
continue
|
|
|
}
|
|
|
if rangeValue[0] != nil {
|
|
|
db = db.Where(fmt.Sprintf("%s%s >= ?", alias, field), rangeValue[0])
|
|
|
}
|
|
|
if rangeValue[1] != nil {
|
|
|
db = db.Where(fmt.Sprintf("%s%s <= ?", alias, field), rangeValue[1])
|
|
|
}
|
|
|
}
|
|
|
return db
|
|
|
}
|
|
|
|
|
|
func (m *DBModel) generateQueryIn(db *gorm.DB, tableName string, queryIn map[string][]interface{}) *gorm.DB {
|
|
|
alias := ""
|
|
|
slice := strings.Split(tableName, " ")
|
|
|
if len(slice) == 2 {
|
|
|
tableName = slice[0]
|
|
|
alias = slice[1] + "."
|
|
|
}
|
|
|
for field, values := range queryIn {
|
|
|
fields := m.FilterValidFields(tableName, field)
|
|
|
if len(fields) == 0 {
|
|
|
continue
|
|
|
}
|
|
|
db = db.Where(fmt.Sprintf("%s%s in (?)", alias, field), values)
|
|
|
}
|
|
|
return db
|
|
|
}
|
|
|
|
|
|
// generateQuerySort applies ORDER BY clauses to db.
//
// Each querySort item has the form "field" or "field asc|desc" (direction is
// case-insensitive). Whitespace between tokens may be repeated.
//   - A bare field, or one followed by more than one token, sorts desc.
//   - An item with an unknown direction, or a field that is not a column of
//     the table, is skipped.
//
// tableName may be "table" or "table alias"; the alias prefixes each column.
// If no valid sort remains, the result is ordered by `id desc`.
func (m *DBModel) generateQuerySort(db *gorm.DB, tableName string, querySort []string) *gorm.DB {
	alias := ""
	if parts := strings.Split(tableName, " "); len(parts) == 2 {
		tableName = parts[0]
		alias = parts[1] + "."
	}

	var sorts []string
	for _, spec := range querySort {
		// strings.Fields tolerates repeated whitespace, so "id  asc" stays asc
		// instead of silently falling back to desc.
		tokens := strings.Fields(spec)
		if len(tokens) == 0 {
			continue
		}

		valid := m.FilterValidFields(tableName, tokens[0])
		if len(valid) == 0 {
			continue
		}
		// Use the validated, canonical column name in SQL, not the raw token.
		column := alias + valid[0]

		direction := "desc"
		if len(tokens) == 2 {
			direction = strings.ToLower(tokens[1])
			if direction != "asc" && direction != "desc" {
				continue
			}
		}
		sorts = append(sorts, column+" "+direction)
	}

	if len(sorts) == 0 {
		return db.Order(alias + "id desc")
	}
	return db.Order(strings.Join(sorts, ","))
}
|