blog_backend_api/db/db.go

600 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// 数据库工具包
package db
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
"sync"
"time"
"background/logs"
)
// Database wraps a database/sql connection pool together with a dialect hint.
type Database struct {
Type string // dialect hint used by SqlBuilder for some special-case decisions; "" or "mysql" both mean a MySQL instance
DB *sql.DB // underlying connection pool; the query/exec methods in this file delegate to it
}
// queueList is the in-memory FIFO of SQL statements that are executed
// asynchronously by a background worker (init() starts one via queue.Start).
type queueList struct {
list []*QueueItem // pending items, oldest first
sleeping chan bool // NOTE(review): not used by any code visible in this file
loop chan bool // NOTE(review): not used by any code visible in this file
lock sync.RWMutex // protects list
quit chan bool // stop signal for the worker (see Stop); NOTE(review): init() does not create it
quited bool // true once the queue has been stopped; Start checks it before each item
}
// QueueItem is one statement waiting in the asynchronous SQL queue.
type QueueItem struct {
DB *Database // database the statement is executed against
Query string // SQL statement text
Params []interface{} // arguments passed to DB.Exec together with Query
}
// cache is a small in-memory key/value store whose entries are split into
// named groups; every call that omits the group argument uses "default".
//
// NOTE(review): cache does no locking, yet the package-level Cache is shared;
// confirm it is never used from concurrent goroutines.
type cache struct {
	data map[string]map[string]interface{} // group name -> key -> value
}

// Init (re)creates an empty "default" group, discarding anything stored in
// it. The outer map is allocated when missing, so Init also works on a
// zero-value cache.
func (this *cache) Init() {
	if this.data == nil {
		this.data = make(map[string]map[string]interface{})
	}
	this.data["default"] = make(map[string]interface{})
}

// groupOf resolves the optional trailing group argument shared by
// Set/Get/Del: the first element if present, otherwise "default".
func (this *cache) groupOf(args []string) string {
	if len(args) > 0 {
		return args[0]
	}
	return "default"
}

// Set stores value under key in the given group (default "default").
// The group, and the cache's backing map, are created on demand, so Set
// never panics on a nil map even if Init was not called.
func (this *cache) Set(key string, value interface{}, args ...string) {
	if this.data == nil {
		this.data = make(map[string]map[string]interface{})
	}
	group := this.groupOf(args)
	entries := this.data[group]
	if entries == nil {
		entries = make(map[string]interface{})
		this.data[group] = entries
	}
	entries[key] = value
}

// Get returns the value stored under key in the given group (default
// "default"), or nil when the group or the key does not exist.
func (this *cache) Get(key string, args ...string) interface{} {
	// Indexing a nil/missing group map yields nil, so no existence checks
	// are needed.
	return this.data[this.groupOf(args)][key]
}

// Del removes key from the given group (default "default"). Missing groups
// and keys are ignored.
func (this *cache) Del(key string, args ...string) {
	// delete is a no-op on a nil map or an absent key.
	delete(this.data[this.groupOf(args)], key)
}
var (
lastError error // error reported by LastErr; NOTE(review): nothing in the visible code assigns it
Cache *cache // package-wide cache, allocated and initialised in init()
queue *queueList // asynchronous SQL queue; init() creates it and starts its worker goroutine
Obj *Database // NOTE(review): not referenced in the visible code; presumably the shared Database instance assigned elsewhere — confirm
)
// init prepares the package-level state: the shared cache (with its
// "default" group) and the asynchronous SQL queue, whose worker loop is
// launched on its own goroutine.
func init() {
	Cache = &cache{data: map[string]map[string]interface{}{}}
	Cache.Init()

	queue = new(queueList)
	go queue.Start()
}
// Close shuts down the underlying connection pool.
// The error returned by (*sql.DB).Close is discarded.
func (this *Database) Close() {
	pool := this.DB
	_ = pool.Close()
}
// LastErr returns the message of the package-level lastError, or an empty
// string when no error has been recorded.
// NOTE(review): no code visible in this file assigns lastError.
func LastErr() string {
	if err := lastError; err != nil {
		return err.Error()
	}
	return ""
}
// Exec runs a statement that returns no rows (INSERT, UPDATE, DELETE, DDL,
// ...) on the connection pool and returns the driver's result.
func (this *Database) Exec(query string, args ...interface{}) (sql.Result, error) {
	result, err := this.DB.Exec(query, args...)
	return result, err
}
// Query runs a statement that returns rows and hands back the whole result
// set (not just one record). As with database/sql, the caller must Close the
// returned *sql.Rows.
func (this *Database) Query(query string, args ...interface{}) (*sql.Rows, error) {
	rows, err := this.DB.Query(query, args...)
	return rows, err
}
// QueryRow runs a query that is expected to return at most one row; errors
// are deferred until Scan is called on the returned *sql.Row.
func (this *Database) QueryRow(query string, args ...interface{}) *sql.Row {
	row := this.DB.QueryRow(query, args...)
	return row
}
// Query2 runs query and decodes every result row into a slice of structs.
//
// obj must be a non-nil pointer to a slice of structs (*[]T with T a struct).
// Each struct field tagged `sql:"column"` receives the value of the result
// column with that name; see queryAndReflect for the conversion rules.
// On success the slice obj points to is replaced by the decoded rows.
//
// Errors: "is not pointer" (including a nil obj), "is nil pointer",
// "is not slice pointer", "is not struct slice pointer", or whatever the
// query itself returns.
func (this *Database) Query2(query string, obj interface{}, args ...interface{}) error {
	// reflect.TypeOf(nil) is a nil Type, so guard before calling Kind on it.
	if obj == nil {
		return errors.New("is not pointer")
	}
	ptrType := reflect.TypeOf(obj)
	if ptrType.Kind() != reflect.Ptr {
		return errors.New("is not pointer")
	}
	sliceType := ptrType.Elem()
	if sliceType.Kind() != reflect.Slice {
		return errors.New("is not slice pointer")
	}
	structType := sliceType.Elem()
	if structType.Kind() != reflect.Struct {
		return errors.New("is not struct slice pointer")
	}
	// A typed nil pointer would make the final Elem().Set panic.
	target := reflect.ValueOf(obj)
	if target.IsNil() {
		return errors.New("is nil pointer")
	}

	// queryAndReflect expects 1-based field indexes so that a column with no
	// tag (missing key, zero value) can be told apart from field 0.
	numField := structType.NumField()
	tagMap := make(map[string]int, numField)
	for i := 0; i < numField; i++ {
		if tag := structType.Field(i).Tag.Get("sql"); tag != "" {
			tagMap[tag] = i + 1
		}
	}

	ret, err := this.queryAndReflect(query, tagMap, sliceType, args...)
	if err != nil {
		return err
	}
	target.Elem().Set(ret)
	return nil
}
// queryAndReflect runs query and decodes every result row into a new element
// of a slice of type tpSlice (a slice of structs), returning that slice.
//
// tagMap maps a result-column name to the 1-based index of the struct field
// that receives it; a missing key (0) means "no field". Query2 builds it from
// the fields' `sql` tags. Columns without a field, fields of unsupported
// kinds, and unexported fields keep their zero value. So do fields whose
// column is NULL or holds text that cannot be parsed into the field's type
// (see setColumnValue).
//
// An error is returned if the query, a row scan, or the row iteration fails.
func (this *Database) queryAndReflect(query string,
	tagMap map[string]int,
	tpSlice reflect.Type, args ...interface{}) (reflect.Value, error) {
	rows, err := this.DB.Query(query, args...)
	if err != nil {
		return reflect.Value{}, err
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return reflect.Value{}, err
	}

	// Resolve column -> field index once rather than once per row.
	elemType := tpSlice.Elem()
	numField := elemType.NumField()
	fieldOf := make([]int, len(cols))
	for i, name := range cols {
		idx := tagMap[name] - 1
		if idx >= numField {
			idx = -1
		}
		fieldOf[i] = idx
	}

	// Scan every column into an interface{} so the driver's native value is
	// kept; its Go type varies by driver and by text vs. binary protocol.
	raw := make([]interface{}, len(cols))
	dest := make([]interface{}, len(cols))
	for i := range raw {
		dest[i] = &raw[i]
	}

	ret := reflect.MakeSlice(tpSlice, 0, 50)
	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			return reflect.Value{}, err
		}
		elem := reflect.New(elemType).Elem()
		for i, n := range fieldOf {
			if n < 0 {
				continue
			}
			// Unexported fields cannot be set through reflection.
			if field := elem.Field(n); field.CanSet() {
				setColumnValue(field, raw[i])
			}
		}
		ret = reflect.Append(ret, elem)
	}
	// Next also returns false on a mid-iteration failure; without this check
	// a truncated result would be reported as complete.
	if err := rows.Err(); err != nil {
		return reflect.Value{}, err
	}
	return ret, nil
}

// setColumnValue stores the scanned driver value src into field, converting
// between the representations drivers return ([]byte, string, int64,
// float64, bool, time.Time, ...). A nil src (SQL NULL), a value that fails
// to parse, or one that overflows the field leaves field unchanged.
// Unsupported field kinds are ignored.
func setColumnValue(field reflect.Value, src interface{}) {
	if src == nil {
		return
	}
	switch field.Kind() {
	case reflect.Bool:
		field.SetBool(columnBool(src))
	case reflect.String:
		field.SetString(columnText(src))
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		var v int64
		var err error
		if x, ok := src.(int64); ok {
			v = x
		} else {
			v, err = strconv.ParseInt(columnText(src), 10, 64)
		}
		if err == nil && !field.OverflowInt(v) {
			field.SetInt(v)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		v, err := strconv.ParseUint(columnText(src), 10, 64)
		if err == nil && !field.OverflowUint(v) {
			field.SetUint(v)
		}
	case reflect.Float32, reflect.Float64:
		var v float64
		var err error
		switch x := src.(type) {
		case float64:
			v = x
		case float32:
			v = float64(x)
		case int64:
			v = float64(x)
		default:
			v, err = strconv.ParseFloat(columnText(src), field.Type().Bits())
		}
		if err == nil {
			field.SetFloat(v)
		}
	}
}

// columnText renders a non-nil driver value as text.
func columnText(src interface{}) string {
	switch x := src.(type) {
	case []byte:
		return string(x)
	case string:
		return x
	default:
		return fmt.Sprint(x)
	}
}

// columnBool interprets a non-nil driver value as a boolean. Native bools
// and integers (non-zero = true) are used directly; text goes through
// strconv.ParseBool ("1", "0", "true", "false", ...). Any other text is
// treated as true, i.e. only recognised false values are false.
func columnBool(src interface{}) bool {
	switch x := src.(type) {
	case bool:
		return x
	case int64:
		return x != 0
	}
	if b, err := strconv.ParseBool(columnText(src)); err == nil {
		return b
	}
	return true
}
// Update executes an UPDATE statement and reports how many rows it changed.
//
// Returns the affected-row count (0 means success with nothing updated), or
// -1 together with the error when the statement or RowsAffected fails.
func (this *Database) Update(query string, args ...interface{}) (int64, error) {
	res, err := this.Exec(query, args...)
	if err == nil {
		var affected int64
		if affected, err = res.RowsAffected(); err == nil {
			return affected, nil
		}
	}
	return -1, err
}
// Delete executes a DELETE statement and reports how many rows it removed.
// It shares Update's contract: 0 means success with nothing deleted, and -1
// is returned alongside any error.
func (this *Database) Delete(query string, args ...interface{}) (int64, error) {
	removed, err := this.Update(query, args...)
	return removed, err
}
// GenSql renders the fields of a struct as a SQL value tuple such as
// (1,"abc",true,1.5), e.g. for the VALUES clause of an INSERT.
//
// obj may be a struct, a pointer to a struct, or a reflect.Value holding
// either. Fields are emitted in declaration order:
//   - strings are double-quoted and escaped with MySQL backslash escapes; a
//     value already wrapped in double quotes is unwrapped first so it is not
//     quoted twice (its content is still escaped);
//   - signed and unsigned integers, floats and bools use their Go text form.
//
// Any other field kind returns an error instead of malformed SQL, and a
// non-struct input returns "not a struct". A struct with no fields yields "".
//
// NOTE: the escaping assumes MySQL without NO_BACKSLASH_ESCAPES. Prefer
// driver placeholders (?) with query arguments where the caller can.
func GenSql(obj interface{}) (string, error) {
	var value reflect.Value
	if v, ok := obj.(reflect.Value); ok {
		value = v
	} else {
		value = reflect.ValueOf(obj)
	}
	// Indirect follows a pointer; a nil pointer or nil obj yields an invalid
	// Value, which fails the kind check below.
	value = reflect.Indirect(value)
	if value.Kind() != reflect.Struct {
		return "", errors.New("not a struct")
	}

	num := value.NumField()
	if num == 0 {
		return "", nil
	}
	buf := make([]byte, 0, 16*num)
	buf = append(buf, '(')
	for i := 0; i < num; i++ {
		if i > 0 {
			buf = append(buf, ',')
		}
		// Kind-specific accessors (String, Int, ...) also work on unexported
		// fields, where Interface() would panic.
		field := value.Field(i)
		switch field.Kind() {
		case reflect.String:
			buf = appendQuotedSQLString(buf, field.String())
		case reflect.Bool:
			buf = strconv.AppendBool(buf, field.Bool())
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			buf = strconv.AppendInt(buf, field.Int(), 10)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			buf = strconv.AppendUint(buf, field.Uint(), 10)
		case reflect.Float32:
			buf = strconv.AppendFloat(buf, field.Float(), 'g', -1, 32)
		case reflect.Float64:
			buf = strconv.AppendFloat(buf, field.Float(), 'g', -1, 64)
		default:
			return "", fmt.Errorf("field %s has unsupported kind %s",
				value.Type().Field(i).Name, field.Kind())
		}
	}
	buf = append(buf, ')')
	return string(buf), nil
}

// appendQuotedSQLString appends s to buf as a double-quoted MySQL string
// literal. It escapes NUL, \n, \r, Ctrl-Z, backslash and both quote
// characters. A value that already starts and ends with a double quote is
// unwrapped first. Working byte-wise is UTF-8 safe because every escaped
// byte is ASCII, and ASCII bytes never occur inside multi-byte sequences.
func appendQuotedSQLString(buf []byte, s string) []byte {
	if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' {
		s = s[1 : len(s)-1]
	}
	buf = append(buf, '"')
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case 0:
			buf = append(buf, '\\', '0')
		case '\n':
			buf = append(buf, '\\', 'n')
		case '\r':
			buf = append(buf, '\\', 'r')
		case 0x1a:
			buf = append(buf, '\\', 'Z')
		case '\\', '\'', '"':
			buf = append(buf, '\\', c)
		default:
			buf = append(buf, c)
		}
	}
	return append(buf, '"')
}
// InsertObejct inserts every element of *obj, a pointer to a slice of
// structs, into table tb_name using one INSERT statement per element.
//
// Only exported fields that carry an `sql:"column"` tag are written; the tag
// names the target column. Values are passed as query arguments through "?"
// placeholders (MySQL style) instead of being spliced into the SQL text, so
// string content cannot change the statement. tb_name is inserted verbatim
// and must come from trusted code.
//
// Every element is attempted even if an earlier one fails, and each failure
// is logged. The returned count is the number of rows inserted successfully;
// the error is the first failure encountered (nil if none). Validation
// errors return 0 before anything is inserted.
func (this *Database) InsertObejct(tb_name string, obj interface{}) (int64, error) {
	if obj == nil {
		return 0, errors.New("is not pointer")
	}
	ptrType := reflect.TypeOf(obj)
	if ptrType.Kind() != reflect.Ptr {
		return 0, errors.New("is not pointer")
	}
	if ptrType.Elem().Kind() != reflect.Slice {
		return 0, errors.New("is not slice pointer")
	}
	elemType := ptrType.Elem().Elem()
	if elemType.Kind() != reflect.Struct {
		return 0, errors.New("is not struct slice pointer")
	}
	ptr := reflect.ValueOf(obj)
	if ptr.IsNil() {
		return 0, errors.New("is nil pointer")
	}

	// Build the column list and placeholder list once. Only tagged, exported
	// fields are included, so column names and values always line up.
	var columns, placeholders []byte
	fields := make([]int, 0, elemType.NumField())
	for i := 0; i < elemType.NumField(); i++ {
		f := elemType.Field(i)
		tag := f.Tag.Get("sql")
		if tag == "" || f.PkgPath != "" {
			continue // untagged or unexported: not a column
		}
		if len(fields) > 0 {
			columns = append(columns, ',')
			placeholders = append(placeholders, ',')
		}
		columns = append(columns, tag...)
		placeholders = append(placeholders, '?')
		fields = append(fields, i)
	}
	if len(fields) == 0 {
		return 0, errors.New("struct has no exported fields with an sql tag")
	}
	query := "insert into " + tb_name + " (" + string(columns) + ") values (" + string(placeholders) + ")"

	items := ptr.Elem()
	args := make([]interface{}, len(fields)) // reused: Insert completes before the next iteration
	var inserted int64
	var firstErr error
	for z := 0; z < items.Len(); z++ {
		item := items.Index(z)
		for k, idx := range fields {
			args[k] = item.Field(idx).Interface()
		}
		if _, err := this.Insert(query, args...); err != nil {
			logs.Error(err.Error())
			if firstErr == nil {
				firstErr = err
			}
			continue
		}
		inserted++
	}
	return inserted, firstErr
}
// Insert executes an INSERT statement and returns the auto-increment ID
// reported by the driver's LastInsertId.
//
// Returns the ID (0 when none was generated) with a nil error, or -1 and the
// error when the statement or LastInsertId fails.
func (this *Database) Insert(query string, args ...interface{}) (int64, error) {
	res, err := this.Exec(query, args...)
	if err == nil {
		var id int64
		if id, err = res.LastInsertId(); err == nil {
			return id, nil
		}
	}
	return -1, err
}
// OneRow holds one result row keyed by column name, with every value in text
// form (Select fills it with each column's raw text).
type OneRow map[string]string

// Results is an ordered list of rows, as returned by Select.
type Results []OneRow

// Exist reports whether the row contains a column named field.
func (row OneRow) Exist(field string) bool {
	_, found := row[field]
	return found
}

// Get returns the value of field, or "" when the row has no such column.
func (row OneRow) Get(field string) string {
	return row[field]
}

// GetInt converts the value of field with Atoi. A missing column yields 0
// without calling Atoi.
func (row OneRow) GetInt(field string) int {
	text, found := row[field]
	if !found {
		return 0
	}
	return Atoi(text)
}

// GetInt64 converts the value of field with Atoi64. A missing column yields 0
// without calling Atoi64.
func (row OneRow) GetInt64(field string) int64 {
	text, found := row[field]
	if !found {
		return 0
	}
	return Atoi64(text)
}

// Set stores val under key, adding the column or overwriting it.
func (row OneRow) Set(key, val string) {
	row[key] = val
}
// Select runs query and returns every row as a map from column name to the
// column's raw text, for result sets whose columns are not known in advance.
//
// SQL NULL becomes "" (NULL cannot be distinguished from an empty string).
// An error is returned if the query, a row scan, or the row iteration fails;
// partially read rows are never returned as a complete result. A query with
// no rows yields an empty, non-nil Results.
func (this *Database) Select(query string, args ...interface{}) (Results, error) {
	rows, err := this.DB.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// One scan buffer reused for every row; the column count is only known
	// at run time. string(raw[i]) below copies, so rows never alias it.
	raw := make([][]byte, len(cols))
	dest := make([]interface{}, len(cols))
	for i := range raw {
		dest[i] = &raw[i]
	}

	results := make(Results, 0)
	for rows.Next() {
		if err := rows.Scan(dest...); err != nil {
			return nil, err
		}
		row := make(OneRow, len(cols))
		for i, name := range cols {
			row[name] = string(raw[i])
		}
		results = append(results, row)
	}
	// Next also returns false on a mid-iteration failure (e.g. a dropped
	// connection); without this check a truncated result looked complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
// SelectOne runs query through Select and returns only the first row.
// When the query yields no rows, an empty (non-nil) OneRow is returned with a
// nil error; on failure the row is nil and the error is passed through.
func (this *Database) SelectOne(query string, args ...interface{}) (OneRow, error) {
	all, err := this.Select(query, args...)
	switch {
	case err != nil:
		return nil, err
	case len(all) == 0:
		return make(OneRow), nil
	default:
		return all[0], nil
	}
}
// Push appends item to the tail of the queue. A nil item is ignored because
// the worker would have nothing to execute.
func (this *queueList) Push(item *QueueItem) {
	if item == nil {
		return
	}
	this.lock.Lock()
	this.list = append(this.list, item)
	this.lock.Unlock()
}

// Pop returns a channel that delivers the next queued item.
//
// A helper goroutine waits for an item (re-checking an empty queue every two
// seconds), sends it on the channel and closes the channel. If the queue is
// stopped first, the channel is closed without a value. Any item already
// taken off the list is put back at the head so it is not lost.
func (this *queueList) Pop() chan *QueueItem {
	out := make(chan *QueueItem)
	quit := this.quitChan()
	go func() {
		defer close(out)
		for {
			next := this.takeHead()
			if next == nil {
				select {
				case <-quit:
					return
				case <-time.After(time.Second * 2):
				}
				continue
			}
			select {
			case out <- next:
				return
			case <-quit:
				this.putHead(next)
				return
			}
		}
	}()
	return out
}

// Start runs the worker loop: it executes queued statements one at a time
// until Stop is called. It blocks, so run it on its own goroutine.
func (this *queueList) Start() {
	for !this.isQuited() {
		item, ok := <-this.Pop()
		if !ok {
			// Pop closed its channel without an item: the queue was stopped.
			return
		}
		if item.DB == nil {
			continue
		}
		// NOTE(review): Exec's result and error are discarded, so a failing
		// queued statement is lost silently; consider logging it.
		item.DB.Exec(item.Query, item.Params...)
	}
}

// Stop tells the worker to exit. It never blocks and may be called more than
// once. A statement already executing is allowed to finish; items still in
// the list stay there unexecuted.
func (this *queueList) Stop() {
	this.lock.Lock()
	defer this.lock.Unlock()
	if this.quited {
		return
	}
	this.quited = true
	if this.quit == nil {
		this.quit = make(chan bool)
	}
	// Closing (rather than sending) wakes every waiting Pop helper and never
	// blocks, even if no worker is currently listening.
	close(this.quit)
}

// quitChan returns the stop channel, creating it on first use. The queue is
// built as &queueList{} without one. Send and receive on a nil channel block
// forever, so the stop signal could never be delivered or observed.
func (this *queueList) quitChan() chan bool {
	this.lock.Lock()
	defer this.lock.Unlock()
	if this.quit == nil {
		this.quit = make(chan bool)
	}
	return this.quit
}

// isQuited reports, under the lock, whether Stop has been called.
func (this *queueList) isQuited() bool {
	this.lock.RLock()
	defer this.lock.RUnlock()
	return this.quited
}

// takeHead removes and returns the oldest item, or nil if the queue is empty.
func (this *queueList) takeHead() *QueueItem {
	this.lock.Lock()
	defer this.lock.Unlock()
	if len(this.list) == 0 {
		return nil
	}
	head := this.list[0]
	this.list[0] = nil // drop the reference so the item can be collected
	this.list = this.list[1:]
	return head
}

// putHead returns an item to the front of the queue.
func (this *queueList) putHead(item *QueueItem) {
	this.lock.Lock()
	this.list = append([]*QueueItem{item}, this.list...)
	this.lock.Unlock()
}
// Queue schedules query for asynchronous execution by the package's
// background SQL worker (started in init) and returns immediately. The
// eventual Exec's result and error are not reported back to the caller.
//
// args is copied, because the statement runs later on another goroutine.
// The caller may therefore reuse or modify the slice it spread into args
// after Queue returns without affecting the queued parameters.
func (this *Database) Queue(query string, args ...interface{}) {
	params := append([]interface{}(nil), args...)
	queue.Push(&QueueItem{
		DB:     this,
		Query:  query,
		Params: params,
	})
}