You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

170 lines
4.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package model
import (
"database/sql"
"errors"
"moredoc/conf"
"strings"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// TableColumn holds one row of the result of MySQL's `SHOW FULL COLUMNS FROM <table>`,
// as scanned by showTableColumn. Only Field (the column name) is read by the
// code in this file; the remaining fields mirror the other result columns.
//
// NOTE(review): the struct tags are written as `gorm:"Field"` rather than GORM's
// `gorm:"column:Field"` form; mapping presumably relies on GORM matching result
// column names to struct field names. Confirm before relying on the tags.
// NOTE(review): MySQL can report NULL for Default (and possibly other columns);
// confirm how scanning NULL into these string fields behaves.
type TableColumn struct {
	Field      string `gorm:"Field"`      // column name
	Type       string `gorm:"Type"`       // column type as reported by MySQL
	Collation  string `gorm:"Collation"`  // collation, if any
	Null       string `gorm:"Null"`       // nullability as reported by MySQL
	Key        string `gorm:"Key"`        // index/key indicator
	Default    string `gorm:"Default"`    // default value
	Extra      string `gorm:"Extra"`      // extra attributes
	Privileges string `gorm:"Privileges"` // privileges on the column
	Comment    string `gorm:"Comment"`    // column comment
}
// tablePrefix is the table-name prefix taken from the database config
// (cfg.Prefix). It is assigned by NewDBModel.
// NOTE(review): presumably read by package-level helpers elsewhere that have no
// *DBModel at hand — confirm; it duplicates DBModel.tablePrefix.
var tablePrefix string
// DBModel wraps a GORM database handle together with per-table column
// metadata collected by NewDBModel when the model is created.
type DBModel struct {
	db          *gorm.DB    // GORM handle opened by NewDBModel
	tablePrefix string      // table-name prefix from the database config
	logger      *zap.Logger // logger named "model" by NewDBModel
	// tableFields maps a table name to its column names, in the order
	// returned by showTableColumn.
	tableFields map[string][]string
	// tableFieldsMap maps a table name to the set of its column names; it is
	// the lookup structure used by FilterValidFields.
	tableFieldsMap map[string]map[string]struct{}
}
// NewDBModel opens a MySQL connection through GORM using cfg, applies the
// connection-pool limits from cfg, and caches the column names of every table
// in the connected database so that FilterValidFields can validate field names.
//
// On success it also stores cfg.Prefix in the package-level tablePrefix.
// On failure it returns a nil *DBModel together with the error, and closes any
// connection pool it had already opened so the pool is not leaked.
func NewDBModel(cfg *conf.Database, lg *zap.Logger) (m *DBModel, err error) {
	if lg == nil {
		err = errors.New("logger cant be nil")
		return
	}
	if cfg == nil {
		err = errors.New("database config cant be nil")
		return
	}

	m = &DBModel{
		logger:         lg.Named("model"),
		tablePrefix:    cfg.Prefix,
		tableFields:    make(map[string][]string),
		tableFieldsMap: make(map[string]map[string]struct{}),
	}

	var (
		db    *gorm.DB
		sqlDB *sql.DB
	)

	sqlLogLevel := logger.Info
	if !cfg.ShowSQL {
		sqlLogLevel = logger.Silent
	}

	db, err = gorm.Open(mysql.New(mysql.Config{
		DSN:                       cfg.DSN, // data source name
		DefaultStringSize:         255,     // default length of string columns
		DisableDatetimePrecision:  true,    // disable datetime precision; unsupported before MySQL 5.6
		DontSupportRenameIndex:    true,    // rename indexes by drop + create; MySQL < 5.7 and MariaDB cannot rename indexes
		DontSupportRenameColumn:   true,    // rename columns with `change`; MySQL < 8 and MariaDB cannot rename columns
		SkipInitializeWithVersion: false,   // auto-configure according to the server's MySQL version
	}), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			TablePrefix:   cfg.Prefix, // table-name prefix, e.g. with "t_" the `User` model maps to `t_user`
			SingularTable: true,       // singular table names: `User` maps to `user` instead of `users`
		},
		Logger: logger.Default.LogMode(sqlLogLevel),
	})
	if err != nil {
		// Do not log the full config: its DSN carries the database credentials.
		m.logger.Error("NewDBModel", zap.Error(err), zap.String("prefix", cfg.Prefix))
		return nil, err
	}

	sqlDB, err = db.DB()
	if err != nil {
		m.logger.Error("db.DB()", zap.Error(err))
		return nil, err
	}
	// From here on a pool is open; release it if any later step fails, since
	// the caller only receives an error and has no handle to close it with.
	defer func() {
		if err != nil {
			_ = sqlDB.Close() // best effort: the original err is what matters
		}
	}()

	if cfg.MaxIdle > 0 {
		sqlDB.SetMaxIdleConns(cfg.MaxIdle)
	}
	if cfg.MaxOpen > 0 {
		sqlDB.SetMaxOpenConns(cfg.MaxOpen)
	}
	m.db = db

	// Cache every table's column names so callers can filter requested
	// field names against the real schema.
	var tables []string
	if tables, err = m.ShowTables(); err != nil {
		m.logger.Error("ShowTables", zap.Error(err))
		return nil, err
	}
	for _, table := range tables {
		var columns []TableColumn
		if columns, err = m.showTableColumn(table); err != nil {
			m.logger.Error("showTableColumn", zap.Error(err))
			return nil, err
		}
		fields := make([]string, 0, len(columns))
		fieldSet := make(map[string]struct{}, len(columns))
		for _, col := range columns {
			fields = append(fields, col.Field)
			fieldSet[col.Field] = struct{}{}
		}
		m.tableFields[table] = fields
		m.tableFieldsMap[table] = fieldSet
	}

	// Only publish the prefix globally once construction has fully succeeded.
	tablePrefix = cfg.Prefix
	return m, nil
}
// SyncDB passes the registered model types to GORM's AutoMigrate. The model
// list is currently empty (entries are left commented out). Any error from
// AutoMigrate is logged and returned.
func (m *DBModel) SyncDB() (err error) {
	var models []interface{}
	// models = append(models, &User{})

	err = m.db.AutoMigrate(models...)
	if err == nil {
		return nil
	}
	m.logger.Error("SyncDB", zap.Error(err))
	return err
}
// GetDB exposes the *gorm.DB handle held by the model so callers can build
// queries that DBModel does not wrap itself.
func (m *DBModel) GetDB() *gorm.DB {
	handle := m.db
	return handle
}
// ShowTables lists the tables of the connected database by running
// `show tables`. A query error is logged and returned.
func (m *DBModel) ShowTables() (tables []string, err error) {
	if err = m.db.Raw("show tables").Scan(&tables).Error; err != nil {
		m.logger.Error("ShowTables", zap.Error(err))
	}
	return tables, err
}
// FilterValidFields drops every entry of fields that is not a column of
// tableName, according to the column cache built by NewDBModel.
//
// Each entry is trimmed of surrounding whitespace and lower-cased before the
// lookup, and the lower-cased name is what is returned. If the lower-cased name
// is not a known column but the trimmed name matches one exactly (a mixed-case
// column name), the trimmed name is kept instead. Input order and duplicates
// are preserved. An unknown table yields nil.
func (m *DBModel) FilterValidFields(tableName string, fields ...string) (validFields []string) {
	fieldsMap, ok := m.tableFieldsMap[tableName]
	if !ok {
		return nil
	}
	for _, field := range fields {
		trimmed := strings.TrimSpace(field)
		lower := strings.ToLower(trimmed)
		if _, ok := fieldsMap[lower]; ok {
			validFields = append(validFields, lower)
			continue
		}
		// Column names are cached with their original case, so a mixed-case
		// column can only be found by its exact (trimmed) spelling.
		if _, ok := fieldsMap[trimmed]; ok {
			validFields = append(validFields, trimmed)
		}
	}
	return validFields
}
// showTableColumn returns the rows of `SHOW FULL COLUMNS FROM <tableName>`.
// The table name is quoted as a MySQL identifier (backticks, with embedded
// backticks doubled). ShowTables can return any table in the database, and
// names that are reserved words or contain special characters would otherwise
// produce a syntax error. A query error is logged and returned.
func (m *DBModel) showTableColumn(tableName string) (columns []TableColumn, err error) {
	quoted := "`" + strings.ReplaceAll(tableName, "`", "``") + "`"
	err = m.db.Raw("SHOW FULL COLUMNS FROM " + quoted).Find(&columns).Error
	if err != nil {
		m.logger.Error("ShowTableColumn", zap.Error(err))
	}
	return
}