From edd5a83d3d2647a00b9ad9ab594dfc1d7186d65f Mon Sep 17 00:00:00 2001 From: 18650180552 Date: Fri, 25 Jan 2019 17:11:15 +0800 Subject: [PATCH] first commit --- config/config.go | 75 ++++++ controller/hanlder.go | 550 ++++++++++++++++++++++++++++++++++++++ db/README.md | 0 db/db.go | 598 ++++++++++++++++++++++++++++++++++++++++++ db/mssql.go | 76 ++++++ db/sqlManager.go | 33 +++ db/sql_builder.go | 542 ++++++++++++++++++++++++++++++++++++++ db/utils.go | 66 +++++ logs/logs.go | 148 +++++++++++ main.go | 46 ++++ model/model.go | 13 + redis/redis.go | 426 ++++++++++++++++++++++++++++++ user.yaml | 19 ++ 13 files changed, 2592 insertions(+) create mode 100644 config/config.go create mode 100644 controller/hanlder.go create mode 100644 db/README.md create mode 100644 db/db.go create mode 100644 db/mssql.go create mode 100644 db/sqlManager.go create mode 100644 db/sql_builder.go create mode 100644 db/utils.go create mode 100644 logs/logs.go create mode 100644 main.go create mode 100644 model/model.go create mode 100644 redis/redis.go create mode 100644 user.yaml diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..7d75092 --- /dev/null +++ b/config/config.go @@ -0,0 +1,75 @@ +package config + +import ( + "log" + "gopkg.in/yaml.v2" + "os" + "runtime" +) + +var ostype = runtime.GOOS + +var conf ConfAPI + +type ConfAPI struct { + ListenSvr int `yaml:"listen_svr"` // 服务监听端口 + ListenApi int `yaml:"listen_api"` // 服务监听端口 + RunMode string `yaml:"runmode"` // 服务运行模式 + MaxConn int `yaml:"max_conn"` + Logs LogConfig `yaml:"logs"` // 日志 + Redis RedisConfig `yaml:"redis"` + Mysql MysqlConfig `yaml:"mysql"` // 认证配置 + init bool +} + +type RedisConfig struct { + Addr string `yaml:"addr"` + Pwd string `yaml:"password"` + DB int64 `yaml:"db"` +} +type LogConfig struct { + Dir string `yaml:"dir"` + File string `yaml:"file"` + Level int `yaml:"level"` + SaveFile bool `yaml:"savefile"` +} +type MysqlConfig struct { + Addr string `yaml:"addr"` + UserName string `yaml:"user"` + Password string `yaml:"password"` + Db string `yaml:"db"` + MaxOpen int `yaml:"max_open"` + MaxIdle int `yaml:"max_idle"` +} +var gConf ConfAPI +func Init(path string) error { + file,e := os.Open(path) + if nil != e{ + log.Println(e.Error()) + return e + } + stat,_ := file.Stat() + filec := make([]byte, stat.Size()) + file.Read(filec) + e = yaml.Unmarshal(filec,&gConf) + if nil != e{ + log.Println(e.Error()) + } + gConf.init = true + return nil +} + +func GetPort() int { + if gConf.init{ + return gConf.ListenApi + }else { + return 8001 + } +} +func GetMysqlConfig() *MysqlConfig{ + if gConf.init{ + return &gConf.Mysql + }else { + return nil + } +} \ No newline at end of file diff --git a/controller/hanlder.go b/controller/hanlder.go new file mode 100644 index 0000000..012c96b --- /dev/null +++ b/controller/hanlder.go @@ -0,0 +1,550 @@ +package controller +import ( + "bytes" + "crypto/md5" + "encoding/json" + "errors" + "fmt" + "github.com/fatih/structs" + "github.com/gin-gonic/gin" + _ "github.com/go-sql-driver/mysql" + "github.com/tommy351/gin-sessions" + "io" + "strconv" + "time" + "user/logs" + "log" + "math/rand" + "net/http" + "net/smtp" + "regexp" + "strings" + "user/redis" + "user/model" + "user/db" +) + +type ReqSendEmailCode struct { + EmailAdress string `json:"email_address"` +} + +type SetUserGroupReq struct { + Id int64 `json:"id,omitempty"` + Description string `json:"description"` + GroupName string `json:"group_name"` + UserIds []int `json:"user_ids"` +} +type RespBase struct { + Msg string + Status int + Data interface{} +} + +func Auth(c *gin.Context) { + var resp RespBase + var statuscode int + + statuscode = 200 + + var userinfo map[string] interface{} + //var userSockToken map[string] interface{} + + defer func() { + c.JSON(statuscode, resp) + }() + + socketToken := c.Query("socketToken") + struserinfo ,e := redis.Get(socketToken) + + if e != nil{ + logs.Error(e.Error()) + return + } + + e = json.Unmarshal([]byte(struserinfo),userinfo) + if nil != e{ + logs.Error(e.Error()) + return + } +} +// SetUser godoc +// @Summary SetUser +// @Description set userinfo +// @Accept json +// @Produce json +// @Param q query string false "name search by q" +// @Success 200 {array} util.RespBase +// @Failure 400 {object} util.RespBase +// @Failure 404 {object} util.RespBase +// @Failure 500 {object} util.RespBase +// @Router /accounts [get] +func SetUser(c *gin.Context){ + +} +func DelUser(c *gin.Context){ + +} + +func GetUser(c *gin.Context) { + var resp RespBase + resp.Msg = "操作失败" + resp.Status = 20 + defer func() { + c.JSON(200,resp) + }() + session := sessions.Get(c) + userinfo := session.Get("") + if userinfo == nil{ + logs.Error("error could not find key") + return + } + var users map[string] interface{} + e := json.Unmarshal([]byte(userinfo.(string)),&users) + if nil != e { + logs.Error(e.Error()) + } + delete(users,"socketToken") + resp.Status = 0 + resp.Msg = "操作成功" + resp.Data = users + +} +// GetUsers godoc +// @Summary GetUsers +// @Description Get all user with query +// @Accept json +// @Produce json +// @Param page query int 1 "分页的页数" +// @Param pageSize query int 10 "name search by q" +// @Param displayname query string false "name search by q" +// @Param department_id query string false "name search by q" +// @Param permission_type query string false "name search by q" +// @Success 200 {array} util.RespBase +// @Failure 400 {object} util.RespBase +// @Failure 404 {object} util.RespBase +// @Failure 500 {object} util.RespBase +// @Router /api/users [get] +func GetUsers(c *gin.Context) { + var statuscode int + var resp RespBase + + resp.Msg = "获取失败" + resp.Status = 0 + statuscode = 200 + + defer func() { + c.JSON(statuscode,resp) + }() + //获取用户组信息 + var page int + var pageSize int + var displayname string + var department_id string + var permission_type string + + displayname = c.Query("displayname") + department_id = c.Query("department_id") + permission_type = c.Query("permission_type") + + if c.Query("page") == ""{ + page = 0 + }else { + var err error + page,err = strconv.Atoi(c.Query("page")) + if err != nil{ + logs.Error("error ato i ") + } + } + log.Println(pageSize,page) + if c.Query("pageSize") == ""{ + pageSize = 10 + }else { + var err error + pageSize,err = strconv.Atoi(c.Query("pageSize")) + if err != nil{ + logs.Error("error ato i ") + } + } + session := sessions.Get(c) + userinfo := session.Get("") + var users map[string] interface{} + e := json.Unmarshal([]byte(userinfo.(string)),&users) + if nil != e { + logs.Error(e.Error()) + } + permission,ok := users["PermissionType"] + if !ok{ + logs.Error("error could not find permission_type") + return + } + // 部门组长只允许查看自己所在部门成员列表,且要过滤超级管理员 + if permission == 1 && department_id != ""{ + d ,_ := users["DepartmentId"] + department_id = d.(string) + } + + respdata := make(map[string]interface{},1) + var usersinfo []model.Users + + if permission != 0{ + query := "select * from users " + if displayname != "" || department_id != "" ||permission_type != ""{ + query += " where " + } + if displayname != ""{ + query += fmt.Sprintf("display_name like '%%%s%%'",displayname) + } + if department_id != ""{ + query += fmt.Sprintf("and department_id = %s",department_id) + } + if permission_type != ""{ + query += fmt.Sprintf("and permission_type = %s ",permission_type) + } + query += "order by id ASC " + query += fmt.Sprintf("limit %d ",pageSize) + query += fmt.Sprintf(" offset %d ",pageSize*page) + + e := db.GetMysqlClient().Query2(query,&usersinfo) + if e != nil{ + log.Println(e.Error()) + } + respdata["rows"] = usersinfo + respdata["total"] = len(usersinfo) + respdata["page"] = page + respdata["pageSize"] = pageSize + } + resp.Msg = "OK" + resp.Data = respdata +} + +func CreateVerify(length int32) string{ + strAry := []byte{ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}; + result := string("") + for i := int32(0); i < length; i++ { + x := rand.Intn(len(strAry)) + result += string(strAry[x]) + } + return result +} +/** + * 取随机Token + * @param {Number} length 取Token的长度 + * @return {string} 获取的Token + */ +func createToken(length int32) string{ + strAry := []byte{ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '_' }; + result := string("") + for i := int32(0); i < length; i++ { + x := rand.Intn(len(strAry)) + result += string(strAry[x]) + } + return result +} + +func ByteSliceToString(b []byte) string { + var ret string + + for i := 0;i < len(b) ;i++{ + s := fmt.Sprintf("%02x",b[i]) + ret += string(s) + } + return ret +} + +func DefaultOption(c *gin.Context) { + var resp RespBase + + defer func() { + c.JSON(204, resp) + }() +} +// Login godoc +// @Summary Login +// @Description login +// @Accept json +// @Produce json +// @Param logininfo query {object} LoginReq "登录请求参数" +// @Success 200 {array} util.RespBase +// @Failure 400 {object} util.RespBase +// @Failure 404 {object} util.RespBase +// @Failure 500 {object} util.RespBase +// @Router /api/login [post] +func Login(c *gin.Context) { + type LoginReq struct { + RememberMe int32 `json:"remember_me"` + UserName string `json:"user_name"` + UserPwd string `json:"user_pwd"` + } + var req LoginReq + statusCode := 200 + var resp RespBase + + defer func() { + c.JSON(statusCode, resp) + }() + e := c.Bind(&req) + if e!= nil{ + log.Println(e.Error()) + return + } + + h := md5.New() + h.Write([]byte(req.UserPwd)) // 需要加密的字符串为 123456 + passmd5 := h.Sum(nil) + var result []model.Users + + er := db.GetMysqlClient().Query2("select * from users where user_name = ?", + &result,req.UserName) + if nil != er{ + log.Println(er.Error()) + } + strpassmd5 := ByteSliceToString(passmd5) + if len(result) == 1 { + if result[0].UserPwd == strpassmd5 { + socketToken := md5.New() + socketToken.Write([]byte(createToken(6))) // 需要加密的字符串为 123456 + socketTokenMd5 := socketToken.Sum(nil) + m := structs.Map(result[0]) + m["socketToken"] = ByteSliceToString(socketTokenMd5); + sessionInfo,err := json.Marshal(m) + if err != nil{ + log.Println(err.Error()) + } + if req.RememberMe == 1{ + redis.Set(string(socketTokenMd5),string(sessionInfo),time.Second *2 * 24 * 3600 * 1000 ) + + }else { + redis.Set(string(socketTokenMd5),string(sessionInfo),time.Second *8 * 3600 * 1000 ) + + } + //存储session + session := sessions.Get(c) + session.Set("",string(sessionInfo)) + //session. + var Options *sessions.Options + if req.RememberMe == 1{ + Options = &sessions.Options{ + MaxAge: 2 * 24 * 3600 * 1000, + } + }else { + Options = &sessions.Options{ + MaxAge: 8 * 3600 * 1000, + } + } + session.Options(*Options) + session.Save() + + resp.Msg = "登录成功" + resp.Status = 0 + resp.Data = string(sessionInfo) + }else { + statusCode = 422 + resp.Msg = "用户密码不正确" + } + }else { + statusCode = 422 + resp.Msg = "登录账号不存在,请重新输入" + } +} + +func Register(c *gin.Context) { + type RegisterReq struct { + DisplayName string `json:"display_name"` + EmailAdress string `json:"email_address"` + EmailCode string `json:"email_code"` + UserName string `json:"user_name"` + UserPwd string `json:"user_pwd"` + } + var req RegisterReq + statusCode := 200 + var resp RespBase + var user model.Users + + resp.Msg = "失败" + resp.Status = 1 + + defer func() { + c.JSON(statusCode, resp) + }() + + e := c.Bind(&req) + if e!= nil{ + log.Println(e.Error()) + return + } + + session := sessions.Get(c) + email_code := session.Get(req.EmailAdress) + + if email_code != req.EmailCode{ + resp.Msg = "验证码错误" + resp.Status = 20 + return + } + + user.UserName = req.UserName + user.EmailAddress = req.EmailAdress + user.DisplayName = req.DisplayName + + h := md5.New() + h.Write([]byte(req.UserPwd)) + passwdmd5 := h.Sum(nil) + strpassmd5 := ByteSliceToString(passwdmd5) + user.UserPwd = strpassmd5 + user.UpdatedDate = time.Now().String() + user.CreatedDate = time.Now().String() + + n,er := db.GetMysqlClient().Insert("insert into users(user_name,user_pwd,created_date," + + "updated_date,display_name,email_address) values (?,?,?,?,?,?)", + user.UserName,user.UserPwd,user.CreatedDate,user.UpdatedDate, + user.DisplayName,user.EmailAddress) + if n == 0 || n < 0{ + statusCode = 422 + logs.Error(er.Error()) + resp.Msg = "失败,账号已经存在" + resp.Status = 20 + return + } + if nil != er{ + statusCode = 422 + logs.Error(er.Error()) + resp.Msg = "失败" + resp.Status = 20 + return + } + resp.Msg = "成功" + resp.Status = 0 +} + +func Logout(c *gin.Context) { + var resp RespBase + + resp.Msg = "退出成功" + resp.Status = 0 + defer func() { + c.JSON(200,resp) + }() + session := sessions.Get(c) + session.Delete("") + session.Save() +} + + +func Reader2Json(r io.ReadCloser) string{ + var ret string + for i := 0;;i++{ + s := make([]byte,10) + _,e := r.Read(s) + ret += string(s) + if e != nil{ + break + } + } + return ret +} + +func SendExternalEmail(msg interface{}) error{ + req := make(map[string] interface{},1) + req["type"] = "text" + req["action"] = "smtp-sys" + req["apiType"] = "send" + + content ,err := json.Marshal(msg) + if err != nil{ + log.Println(err.Error()) + return errors.New("Json marshal error") + } + req["content"] = string(content) + + var buffer bytes.Buffer + b,e := json.Marshal(req) + if e != nil{ + log.Println(e.Error()) + } + buffer.Write(b) + resp,err := http.Post("http://47.93.230.163:8091/msg/v1/send","application/json",&buffer) + + if resp.StatusCode != 200{ + return errors.New("error send emain") + } + if err != nil{ + logs.Error("error send email") + return err + } + return nil +} + +func SendToMail(title,user string, password string, host string, to string, content string, + ifgenerate bool) error { + var content_type string + + hp := strings.Split(host, ":") + auth := smtp.PlainAuth("", user, password, hp[0]) + + content_type = "Content-Type: text/plain" + "; charset=UTF-8" + + msg := []byte("To: " + to + "\r\nFrom: " + user + "\r\nSubject: " + title + "\r\n" + + content_type + "\r\n\r\n"+ content + "\r\n" ) + send_to := strings.Split(to, ";") + + //检测是否是邮件地址 + for k,_ := range send_to{ + match, _ := regexp.MatchString("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[\\w](?:[\\w-]*[\\w])?", send_to[k]) + if !match{ + return errors.New("Format Error") + } + } + err := smtp.SendMail(host, auth, user,send_to, msg) + if err !=nil{ + return err + } + return err +} + +func SendEmailCode(c *gin.Context) { + var req ReqSendEmailCode + var resp RespBase = RespBase{Msg:"邮件已经存在",Status:0} + statusCode := 200 + + defer func() { + c.JSON(statusCode, resp) + }() + + e := c.Bind(&req) + if nil != e{ + log.Println(e.Error()) + resp.Msg = "请求参数错误" + return + } + //判断邮箱是否存在 + var users []model.Users + db.GetMysqlClient().Query2("select * from users where email_adress = ?",&users,req.EmailAdress) + if len(users) != 0{ + statusCode = 422 + return + } + //产生验证码 + verify := CreateVerify(6) + session := sessions.Get(c) + session.Set(req.EmailAdress,verify) + /* + var Options *sessions.Options + + Options = &sessions.Options{ + MaxAge: 60, + } + session.Options(*Options)*/ + session.Save() + sendcontent := make( map[string] interface{},1) + sendcontent["subject"] = "邮箱验证码,请注意查收" + sendcontent["receivers"] = req.EmailAdress + sendcontent["content"] = string("您本次注册的验证码为:") + verify + string(",工作人员不会向您索取,请勿泄露。请尽快完成操作。") + + e = SendExternalEmail(sendcontent) + if e != nil{ + log.Println(e.Error()) + return + } + //成功 + resp.Msg = "发送成功" +} \ No newline at end of file diff --git a/db/README.md b/db/README.md new file mode 100644 index 0000000..e69de29 diff --git a/db/db.go b/db/db.go new file mode 100644 index 0000000..6311506 --- /dev/null +++ b/db/db.go @@ -0,0 +1,598 @@ +// 数据库工具包 +package db + +import ( + "database/sql" + "document/logs" + "errors" + "fmt" + "reflect" + "strconv" + "sync" + "time" +) + +// 数据容器抽象对象定义 +type Database struct { + Type string // 用来给SqlBuilder进行一些特殊的判断 (空值或mysql 皆表示这是一个MySQL实例) + DB *sql.DB +} + +// SQL异步执行队列定义 +type queueList struct { + list []*QueueItem //队列列表 + sleeping chan bool + loop chan bool + lock sync.RWMutex + quit chan bool + quited bool +} + +// SQL异步执行队列子元素定义 +type QueueItem struct { + DB *Database //数据库对象 + Query string //SQL语句字符串 + Params []interface{} //参数列表 +} + +// 缓存数据对象定义 +type cache struct { + data map[string]map[string]interface{} +} + +func (this *cache) Init() { + this.data["default"] = make(map[string]interface{}) +} + +// 设置缓存 +func (this *cache) Set(key string, value interface{}, args ...string) { + var group string + if len(args) > 0 { + group = args[0] + if _, exist := this.data[group]; !exist { + this.data[group] = make(map[string]interface{}) + } + } else { + group = "default" + } + this.data[group][key] = value +} + +// 获取缓存数据 +func (this *cache) Get(key string, args ...string) interface{} { + var group string + if len(args) > 0 { + group = args[0] + } else { + group = "default" + } + if g, exist := this.data[group]; exist { + if v, ok := g[key]; ok { + return v + } + } + return nil +} + +// 删除缓存数据 +func (this *cache) Del(key string, args ...string) { + var group string + if len(args) > 0 { + group = args[0] + } else { + group = "default" + } + if g, exist := this.data[group]; exist { + if _, ok := g[key]; ok { + delete(this.data[group], key) + } + } +} + +var ( + lastError error + Cache *cache + queue *queueList + Obj *Database +) + +func init() { + Cache = &cache{data: make(map[string]map[string]interface{})} + Cache.Init() + queue = &queueList{} + go queue.Start() +} + +// 关闭数据库连接 +func (this *Database) Close() { + this.DB.Close() +} + +// 获取最后发生的错误字符串 +func LastErr() string { + if lastError != nil { + return lastError.Error() + } + return "" +} + +// 执行语句 +func (this *Database) Exec(query string, args ...interface{}) (sql.Result, error) { + return this.DB.Exec(query, args...) +} + +// 查询单条记录 +func (this *Database) Query(query string, args ...interface{}) (*sql.Rows, error) { + return this.DB.Query(query, args...) +} + +// 查询单条记录 +func (this *Database) QueryRow(query string, args ...interface{}) *sql.Row { + return this.DB.QueryRow(query, args...) +} + +// Query2 查询实体集合 +// obj 为接收数据的实体指针 +func (this *Database) Query2(sql string, obj interface{}, args ...interface{}) error { + var tagMap map[string]int + var tp, tps reflect.Type + var n, i int + var err error + var ret reflect.Value + // 检测val参数是否为我们所想要的参数 + tp = reflect.TypeOf(obj) + if reflect.Ptr != tp.Kind() { + return errors.New("is not pointer") + } + + if reflect.Slice != tp.Elem().Kind() { + return errors.New("is not slice pointer") + } + + tp = tp.Elem() + tps = tp.Elem() + if reflect.Struct != tps.Kind() { + return errors.New("is not struct slice pointer") + } + + tagMap = make(map[string]int) + n = tps.NumField() + for i = 0; i < n; i++ { + tag := tps.Field(i).Tag.Get("sql") + if len(tag) > 0 { + tagMap[tag] = i + 1 + } + } + + // 执行查询 + ret, err = this.queryAndReflect(sql, tagMap, tp, args...) + if nil != err { + return err + } + + // 返回结果 + reflect.ValueOf(obj).Elem().Set(ret) + + return nil +} + +// queryAndReflect 查询并将结果反射成实体集合 +func (this *Database) queryAndReflect(sql string, + tagMap map[string]int, + tpSlice reflect.Type, args ...interface{}) (reflect.Value, error) { + + var ret reflect.Value + + // 执行sql语句 + rows, err := this.DB.Query(sql, args...) + if nil != err { + return reflect.Value{}, err + } + + defer rows.Close() + // 开始枚举结果 + cols, err := rows.Columns() + if nil != err { + return reflect.Value{}, err + } + + ret = reflect.MakeSlice(tpSlice, 0, 50) + // 构建接收队列 + scan := make([]interface{}, len(cols)) + row := make([]interface{}, len(cols)) + for r := range row { + scan[r] = &row[r] + } + + for rows.Next() { + feild := reflect.New(tpSlice.Elem()).Elem() + // 取得结果 + err = rows.Scan(scan...) + // 开始遍历结果 + for i := 0; i < len(cols); i++ { + n := tagMap[cols[i]] - 1 + + if n < 0 { + continue + } + switch feild.Type().Field(n).Type.Kind() { + case reflect.Bool: + if nil != row[i] { + feild.Field(n).SetBool("false" != string(row[i].([]byte))) + } else { + feild.Field(n).SetBool(false) + } + case reflect.String: + if nil != row[i] { + feild.Field(n).SetString(string(row[i].([]byte))) + } else { + feild.Field(n).SetString("") + } + case reflect.Float32: + if nil != row[i] { + //log.Println(row[i].(float32)) + switch reflect.TypeOf(row[i]).Kind(){ + case reflect.Slice: + v, e := strconv.ParseFloat(string(row[i].([]byte)), 0) + if nil == e { + feild.Field(n).SetFloat(float64(v)) + //feild.Field(n).SetFloat(float64(row[i].(float32))) + } + break + case reflect.Float64: + feild.Field(n).SetFloat(float64(row[i].(float32))) + } + + } else { + feild.Field(n).SetFloat(0) + } + case reflect.Float64: + if nil != row[i] { + //log.Println(row[i].(float32)) + //v, e := strconv.ParseFloat(string(row[i].([]byte)), 0) + //if nil == e { + feild.Field(n).SetFloat(row[i].(float64)) + //} + } else { + feild.Field(n).SetFloat(0) + } + case reflect.Int8: + fallthrough + case reflect.Int16: + fallthrough + case reflect.Int32: + fallthrough + case reflect.Int64: + fallthrough + case reflect.Int: + if nil != row[i] { + byRow, ok := row[i].([]byte) + if ok { + v, e := strconv.ParseInt(string(byRow), 10, 64) + if nil == e { + feild.Field(n).SetInt(v) + } + } else { + v, e := strconv.ParseInt(fmt.Sprint(row[i]), 10, 64) + if nil == e { + feild.Field(n).SetInt(v) + } + } + } else { + feild.Field(n).SetInt(0) + } + } + } + + ret = reflect.Append(ret, feild) + } + + return ret, nil +} + +// 执行UPDATE语句并返回受影响的行数 +// 返回0表示没有出错, 但没有被更新的行 +// 返回-1表示出错 +func (this *Database) Update(query string, args ...interface{}) (int64, error) { + ret, err := this.Exec(query, args...) + if err != nil { + return -1, err + } + aff, err := ret.RowsAffected() + if err != nil { + return -1, err + } + return aff, nil +} + +// 执行DELETE语句并返回受影响的行数 +// 返回0表示没有出错, 但没有被删除的行 +// 返回-1表示出错 +func (this *Database) Delete(query string, args ...interface{}) (int64, error) { + return this.Update(query, args...) +} + +func GenSql(obj interface{}) (string,error) { + ret := "" + typ := reflect.TypeOf(obj).Kind() + if typ != reflect.Struct{ + return (""),errors.New("not a struct") + } + value := obj.(reflect.Value) + num := value.NumField() + for i := 0;i < num;i++{ + if i == 0{ + ret += "(" + } + switch (value.Field(i).Type().Kind()){ + case reflect.String: + str := value.Field(i).Interface().(string) + if str[0] != '"'{ + ret += "\"" + str += "\"" + ret += str + + }else{ + ret += value.Field(i).Interface().(string) + } + case reflect.Int: + ret += fmt.Sprintf("%d",value.Field(i).Interface().(int)) + case reflect.Int8: + ret += fmt.Sprintf("%d",value.Field(i).Interface().(int8)) + case reflect.Int32: + ret += fmt.Sprintf("%d",value.Field(i).Interface().(int32)) + case reflect.Int64: + ret += fmt.Sprintf("%d",value.Field(i).Interface().(int64)) + case reflect.Int16: + ret += fmt.Sprintf("%d",value.Field(i).Interface().(int16)) + case reflect.Bool: + if value.Field(i).Interface().(bool) { + ret += fmt.Sprintf("true",) + }else { + ret += fmt.Sprintf("false",) + } + case reflect.Float32: + ret += fmt.Sprintf("%x",value.Field(i).Interface().(float32)) + case reflect.Float64: + ret += fmt.Sprintf("true",value.Field(i).Interface().(float64)) + } + if i == num - 1{ + ret += ")" + }else { + ret += "," + } + } + return ret,nil +} +func (this *Database) InsertObejct(tb_name string,obj interface{}) (int64,error) { + var tagMap map[int]string + var tp, tps reflect.Type + var n, i int + + // 检测val参数是否为我们所想要的参数 + tp = reflect.TypeOf(obj) + if reflect.Ptr != tp.Kind() { + return 0,errors.New("is not pointer") + } + + if reflect.Slice != tp.Elem().Kind() { + return 0,errors.New("is not slice pointer") + } + + tp = tp.Elem() + tps = tp.Elem() + value := reflect.ValueOf(obj).Elem() + + if reflect.Struct != tps.Kind() { + return 0,errors.New("is not struct slice pointer") + } + for z := 0; z < value.Len();z ++ { + tagMap = make(map[int]string) + n = tps.NumField() + var query_struct string + for i = 0; i < n; i++ { + tag := tps.Field(i).Tag.Get("sql") + if len(tag) > 0 { + tagMap[i] = tag + } + if i == 0{ + query_struct += "(" + } + query_struct += tagMap[i] + if i == n -1{ + query_struct += ")" + }else { + query_struct += "," + } + } + vs ,e := GenSql(value.Index(z)) + if nil != e{ + logs.Error(e.Error()) + } + query := "insert into " + tb_name + query_struct + "values " + vs + _, e = this.Insert(query) + if e != nil{ + logs.Error(e.Error()) + } + } + return 0,nil +} +// 执行INSERT语句并返回最后生成的自增ID +// 返回0表示没有出错, 但没生成自增ID +// 返回-1表示出错 +func (this *Database) Insert(query string, args ...interface{}) (int64, error) { + ret, err := this.Exec(query, args...) + if err != nil { + return -1, err + } + last, err := ret.LastInsertId() + if err != nil { + return -1, err + + } + return last, nil +} + +type OneRow map[string]string +type Results []OneRow + +// 判断字段是否存在 +func (row OneRow) Exist(field string) bool { + if _, ok := row[field]; ok { + return true + } + return false +} + +// 获取指定字段的值 +func (row OneRow) Get(field string) string { + if v, ok := row[field]; ok { + return v + } + return "" +} + +// 获取指定字段的整数值, 注意, 如果该字段不存在则会返回0 +func (row OneRow) GetInt(field string) int { + if v, ok := row[field]; ok { + return Atoi(v) + } + return 0 +} + +// 获取指定字段的整数值, 注意, 如果该字段不存在则会返回0 +func (row OneRow) GetInt64(field string) int64 { + if v, ok := row[field]; ok { + return Atoi64(v) + } + return 0 +} + +// 设置值 +func (row OneRow) Set(key, val string) { + row[key] = val +} + +// 查询不定字段的结果集 +func (this *Database) Select(query string, args ...interface{}) (Results, error) { + rows, err := this.DB.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return nil, err + } + + colNum := len(cols) + rawValues := make([][]byte, colNum) + scans := make([]interface{}, len(cols)) //query.Scan的参数,因为每次查询出来的列是不定长的,所以传入长度固定当次查询的长度 + + // 将每行数据填充到[][]byte里 + for i := range rawValues { + scans[i] = &rawValues[i] + } + + results := make(Results, 0) + for rows.Next() { + err = rows.Scan(scans...) + if err != nil { + return nil, err + } + + row := make(map[string]string) + + for k, raw := range rawValues { + key := cols[k] + /*if raw == nil { + row[key] = "\\N" + } else {*/ + row[key] = string(raw) + //} + } + results = append(results, row) + } + return results, nil +} + +// 查询一行不定字段的结果 +func (this *Database) SelectOne(query string, args ...interface{}) (OneRow, error) { + ret, err := this.Select(query, args...) + if err != nil { + return nil, err + } + if len(ret) > 0 { + return ret[0], nil + } + return make(OneRow), nil +} + +// 队列入栈 +func (this *queueList) Push(item *QueueItem) { + this.lock.Lock() + this.list = append(this.list, item) + this.lock.Unlock() +} + +// 队列出栈 +func (this *queueList) Pop() chan *QueueItem { + item := make(chan *QueueItem) + go func() { + defer close(item) + for { + switch { + case len(this.list) == 0: + timeout := time.After(time.Second * 2) + select { + case <-this.quit: + this.quited = true + return + case <-timeout: + //log.Println("SQL Queue polling") + } + default: + this.lock.Lock() + i := this.list[0] + this.list = this.list[1:] + this.lock.Unlock() + select { + case item <- i: + return + case <-this.quit: + this.quited = true + return + } + } + } + }() + return item +} + +// 执行开始执行 +func (this *queueList) Start() { + for { + if this.quited { + return + } + c := this.Pop() + item := <-c + item.DB.Exec(item.Query, item.Params...) + } +} + +// 停止队列 +func (this *queueList) Stop() { + this.quit <- true +} + +// 向Sql队列中插入一条执行语句 +func (this *Database) Queue(query string, args ...interface{}) { + item := &QueueItem{ + DB: this, + Query: query, + Params: args, + } + queue.Push(item) +} diff --git a/db/mssql.go b/db/mssql.go new file mode 100644 index 0000000..bacbf4f --- /dev/null +++ b/db/mssql.go @@ -0,0 +1,76 @@ +package db + +import ( + "database/sql" + _ "github.com/denisenkom/go-mssqldb" +) + +// ProcExec 执行存储过程, 返回受影响的行数 +func (this *Database) ExecProc(procname string, params ...interface{}) (int64, error) { + result, err := this.Exec("EXEC " + procname + " " + this.GetProcPlaceholder(len(params)), params...) + if err != nil { + return 0, err + } + affected, err := result.RowsAffected() + if err != nil { + return 0, err + } + lastinsertid, err := result.LastInsertId() + if err != nil { + return affected, nil + } + return lastinsertid, nil +} + +// GetExecProcErr 执行存储过程, 返回是否在执行过程中出现错误 +func (this *Database) GetExecProcErr(procname string, params ...interface{}) error { + _, err := this.ExecProc(procname, params...) + if err != nil { + return err + } + return nil +} + +// ProcQuery 通过存储过程查询记录 +func (this *Database) ProcQuery(procname string, params ...interface{}) (rows *sql.Rows, err error) { + rows, err = this.Query("EXEC " + procname + " " + this.GetProcPlaceholder(len(params)), params...) + return +} + +// ProcQueryRow 通过存储过程查询单条记录 +func (this *Database) ProcQueryRow(procname string, params ...interface{}) *sql.Row { + return this.QueryRow("EXEC " + procname + " " + this.GetProcPlaceholder(len(params)), params...) +} + +// ProcStatus 调用存储过程并获取最终的执行状态码和提示信息 +func (this *Database) ProcStatus(procname string, params ...interface{}) (int, string) { + var status int + var msg string + err := this.QueryRow("EXEC " + procname + " " + this.GetProcPlaceholder(len(params)), params...).Scan(&status, &msg) + if err != nil { + return -99, err.Error() + } + return status, msg +} + +// ProcSelect 通过存储过程查询结果集 +func (this *Database) ProcSelect(procname string, params ...interface{}) (Results, error) { + return this.Select("EXEC " + procname + " " + this.GetProcPlaceholder(len(params)), params...) +} + +// ProcSelectOne 通过存储查询一行不定字段的结果 +func (this *Database) ProcSelectOne(procname string, params ...interface{}) (OneRow, error) { + return this.SelectOne("EXEC " + procname + " " + this.GetProcPlaceholder(len(params)), params...) +} + +// GetProcPlaceholder 按照指定数量生成调用存储过程时所用的参数占位符 +func (this *Database) GetProcPlaceholder(count int) (placeholder string) { + placeholder = "" + for i := 0; i < count; i++ { + if i > 0 { + placeholder += "," + } + placeholder += "?" + } + return +} \ No newline at end of file diff --git a/db/sqlManager.go b/db/sqlManager.go new file mode 100644 index 0000000..0b04394 --- /dev/null +++ b/db/sqlManager.go @@ -0,0 +1,33 @@ +package db + +import ( + "database/sql" + "user/config" + "fmt" + _ "github.com/go-sql-driver/mysql" + "log" +) + +var gDb Database + +func Init() { + mysqlconf := config.GetMysqlConfig() + log.Println(mysqlconf) + cnn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8",mysqlconf.UserName,mysqlconf.Password, + mysqlconf.Addr,mysqlconf.Db) + _db,err := sql.Open("mysql",cnn) + if err != nil{ + fmt.Println("connect sql server ",err.Error()) + } + e := _db.Ping() + if nil != e{ + fmt.Println(e.Error()) + } + gDb = Database{Type:string(""),DB:_db} +} + +func GetMysqlClient() *Database { + return &gDb +} + + diff --git a/db/sql_builder.go b/db/sql_builder.go new file mode 100644 index 0000000..aae95eb --- /dev/null +++ b/db/sql_builder.go @@ -0,0 +1,542 @@ +package db + +import ( + "database/sql" + "errors" + "strconv" + "strings" + "log" + "fmt" + "reflect" + "math/big" + + "git.jiaxianghudong.com/go/utils" +) + +const ( + _ = iota + TYPE_INSERT + TYPE_DELETE + TYPE_UPDATE + TYPE_SELECT + TYPE_INSERTUPDATE +) + +var ( + WrapSymbol = "`" + DBType = "mysql" +) + +// SQL语句构造结构 +type SB struct { + db *Database + t int + field, table, where, group, order, limit string + values SBValues + values2 SBValues + ignore bool + fullsql bool + debug bool + unsafe bool //是否进行安全检查, 专门针对无限定的UPDATE和DELETE进行二次验证 + args []interface{} +} + +// Exec返回结果 +type SBResult struct { + Success bool //语句是否执行成功 + Code int //错误代码 + Msg string //错误提示信息 + LastID int64 //最后产生的ID + Affected int64 //受影响的行数 + Sql string //最后执行的SQL +} + +// 值对象 +type SBValues map[string]interface{} + +// 增量值 +type IncVal struct { + Val int64 + BaseField string // 为空表示对当前字段累加 +} + +// 向值对象中加入值 +func (v SBValues) Add(key string, val interface{}) { + v[key] = val +} + +// 删除值对象中的某个值 +func (v SBValues) Del(key string) { + delete(v, key) +} + +// 判断指定键是否存在 +func (v SBValues) IsExist(key string) bool { + if _, exist := v[key]; exist { + return true + } + return false +} + +// 获取键的整形值 +func (v SBValues) Get(key string) interface{} { + if val, exist := v[key]; exist { + return val + } + return nil +} + +// 获取键的字符串值 +func (v SBValues) GetString(key string) string { + if val, exist := v[key]; exist { + if trueVal, ok := val.(string); ok { + return trueVal + } + } + return "" +} + +// 获取键的整形值 +func (v SBValues) GetInt(key string) int { + if val, exist := v[key]; exist { + if trueVal, ok := val.(int); ok { + return trueVal + } + } + return 0 +} + +// 获取键的无符号整形值 +func (v SBValues) GetUint(key string) uint { + if val, exist := v[key]; exist { + if trueVal, ok := val.(uint); ok { + return trueVal + } + } + return 0 +} + +// 获取键的64位整形值 +func (v SBValues) GetInt64(key string) int64 { + if val, exist := v[key]; exist { + if trueVal, ok := val.(int64); ok { + return trueVal + } + } + return 0 +} + +// 返回绑定完参数的完整的SQL语句 +func FullSql(str string, args ...interface{}) (string, error) { + if !strings.Contains(str, "?") { + return str, nil + } + sons := strings.Split(str, "?") + + var ret string + var argIndex int + var maxArgIndex = len(args) + + for _, son := range sons { + ret += son + + if argIndex < maxArgIndex { + switch v := args[argIndex].(type) { + case int: + ret += strconv.Itoa(v) + case int8: + ret += strconv.Itoa(int(v)) + case int16: + ret += strconv.Itoa(int(v)) + case int32: + ret += utils.I64toA(int64(v)) + case int64: + ret += utils.I64toA(v) + case uint: + ret += utils.UitoA(v) + case uint8: + ret += utils.UitoA(uint(v)) + case uint16: + ret += utils.UitoA(uint(v)) + case uint32: + ret += utils.Ui32toA(v) + case uint64: + ret += utils.Ui64toA(v) + case float32: + ret += utils.F32toA(v) + case float64: + ret += utils.F64toA(v) + case *big.Int: + ret += v.String() + case bool: + if v { + ret += "true" + } else { + ret += "false" + } + case string: + ret += "'" + strings.Replace(strings.Replace(v, "'", "", -1), `\`, `\\`, -1) + "'" + case nil: + ret += "NULL" + default: + return "", errors.New(fmt.Sprintf("invalid sql argument type: %v => %v (sql: %s)", reflect.TypeOf(v).String(), v, str)) + } + + argIndex++ + } + } + + return ret, nil +} + +// 构建SQL语句 +// param: returnFullSql 是否返回完整的sql语句(即:绑定参数之后的语句) +func (q *SB) ToSql(returnFullSql ...bool) (str string, err error) { + q.args = make([]interface{}, 0) + + switch q.t { + case TYPE_INSERT: + if q.table == "" { + err = errors.New("table cannot be empty.") + return + } + if len(q.values) == 0 { + err = errors.New("values cannot be empty.") + return + } + if q.ignore { + str = "INSERT IGNORE INTO " + q.table + } else { + str = "INSERT INTO " + q.table + } + var fields, placeholder string + for k, v := range q.values { + fields += "," + WrapSymbol + k + WrapSymbol + placeholder += ",?" + q.args = append(q.args, v) + } + str += " (" + utils.Substr(fields, 1) + ") VALUES (" + utils.Substr(placeholder, 1) + ")" + case TYPE_DELETE: + if q.table != "" { + if q.where == "" && !q.unsafe { + err = errors.New("deleting all data is not safe.") + return + } + str = "DELETE " + q.table + if q.table != "" { + str += " FROM " + q.table + } + if q.where != "" { + str += " WHERE " + q.where + } + } + case TYPE_UPDATE: + if q.table != "" { + if q.where == "" && !q.unsafe { + err = errors.New("updating all data is not safe.") + return + } + str = "UPDATE " + q.table + str += " SET " + utils.Substr(q.buildUpdateParams(q.values), 1) + if q.where != "" { + str += " WHERE " + q.where + } + } + case TYPE_INSERTUPDATE: + if q.table != "" { + str = "INSERT INTO " + q.table + var fields, placeholder string + for k, v := range q.values { + fields += "," + WrapSymbol + k + WrapSymbol + placeholder += ",?" + q.args = append(q.args, v) + } + str += " (" + utils.Substr(fields, 1) + ") VALUES (" + utils.Substr(placeholder, 1) + ") ON DUPLICATE KEY UPDATE " + placeholder = q.buildUpdateParams(q.values2) + str += utils.Substr(placeholder, 1) + } + case TYPE_SELECT: + str = "SELECT " + q.field + if q.table != "" { + str += " FROM " + q.table + } + if q.where != "" { + str += " WHERE " + q.where + } + if q.group != "" { + str += " GROUP BY " + q.group + } + if q.order != "" { + str += " ORDER BY " + q.order + } + if q.limit != "" && (q.db.Type == "" || q.db.Type == "mysql") { + str += " LIMIT " + q.limit + } + } + + if len(returnFullSql) == 1 && returnFullSql[0] { + str, err = FullSql(str, q.args...) + } + + return +} + +// 构造Update更新参数 +func (q *SB) buildUpdateParams(vals SBValues) string { + var placeholder string + for k, v := range vals { + if iv, ok := v.(IncVal); ok { + placeholder += "," + WrapSymbol + k + WrapSymbol + "=" + utils.Ternary(iv.BaseField == "", k, iv.BaseField).(string) + if iv.Val >= 0 { + placeholder += "+" + utils.I64toA(iv.Val) + } else { + placeholder += utils.I64toA(iv.Val) + } + } else { + placeholder += "," + WrapSymbol + k + WrapSymbol + "=?" + q.args = append(q.args, v) + } + } + return placeholder +} + +// 设置数据库对象 +func (q *SB) DB(db *Database) *SB { + q.db = db + return q +} + +// 设置FROM字句 +func (q *SB) From(str string) *SB { + q.table = str + return q +} + +// 设置表名 +func (q *SB) Table(str string) *SB { + return q.From(str) +} + +// 设置WHERE字句 +func (q *SB) Where(str string) *SB { + q.where = str + return q +} + +// 设置GROUP字句 +func (q *SB) Group(str string) *SB { + q.group = str + return q +} + +// 设置GROUP字句 +func (q *SB) Order(str string) *SB { + q.order = str + return q +} + +// 设置LIMIT字句 +func (q *SB) Limit(count int, offset ...int) *SB { + if len(offset) > 0 { + q.limit = utils.Itoa(offset[0]) + "," + utils.Itoa(count) + } else { + q.limit = "0," + utils.Itoa(count) + } + return q +} + +// 设置安全检查开关 +func (q *SB) Unsafe(unsefe ...bool) *SB { + if len(unsefe) == 1 && !unsefe[0] { + q.unsafe = false + } else { + q.unsafe = true + } + return q +} + +// 是否Debug +func (q *SB) Debug(debug ...bool) *SB { + if len(debug) == 1 && !debug[0] { + q.debug = false + } else { + q.debug = true + } + return q +} + +// 设置值 +func (q *SB) Value(m SBValues) *SB { + q.values = m + return q +} + +// 设置值2 +func (q *SB) Value2(m SBValues) *SB { + q.values2 = m + return q +} + +// 添加值 +func (q *SB) AddValue(key string, val interface{}) *SB { + q.values.Add(key, val) + return q +} + +// 添加值2 +func (q *SB) AddValue2(key string, val interface{}) *SB { + q.values2.Add(key, val) + return q +} + +// 获取一个值对象 +func NewValues() SBValues { + return SBValues{} +} + +// 构建INSERT语句 +func Insert(ignore ...bool) *SB { + var i bool + if len(ignore) == 1 && ignore[0] { + i = true + } + return &SB{t: TYPE_INSERT, db: Obj, ignore: i, values: SBValues{}, args: make([]interface{}, 0)} +} + +// 构建DELETE语句 +func Delete() *SB { + return &SB{t: TYPE_DELETE, db: Obj} +} + +// 构建UPDATE语句 +func Update() *SB { + return &SB{t: TYPE_UPDATE, db: Obj, values: SBValues{}, args: make([]interface{}, 0)} +} + +// 构建InsertUpdate语句, 仅针对MySQL有效, 内部使用ON DUPLICATE KEY UPDATE方式实现 +func InsertUpdate() *SB { + return &SB{t: TYPE_INSERTUPDATE, db: Obj, values: SBValues{}, values2: SBValues{}, args: make([]interface{}, 0)} +} + +// 构建SELECT语句 +func Select(str ...string) *SB { + fields := "*" + if len(str) == 1 { + fields = str[0] + } + return &SB{t: TYPE_SELECT, db: Obj, field: fields} +} + +// 获取构造SQL后的参数 +func (q *SB) GetArgs() []interface{} { + return q.args +} + +// +func (q *SB) FullSql(yes ...bool) *SB { + if len(yes) == 1 { + q.fullsql = yes[0] + } else { + q.fullsql = true + } + return q +} + +// 执行INSERT、DELETE、UPDATE语句 +func (q *SB) Exec(args ...interface{}) *SBResult { + var err error + sbRet := &SBResult{} + sbRet.Sql, err = q.ToSql() + if err != nil { + sbRet.Msg = err.Error() + } else { + if q.debug { + log.Println("\n\tSQL prepare statement:\n\t", sbRet.Sql, "\n\tMap args:\n\t", q.args, "\n\tParams:\n\t", args) + } + + var ret sql.Result + var err error + if q.fullsql { + var sqlStr string + sqlStr, err = FullSql(sbRet.Sql, append(q.args, args...)...) + if err == nil { + ret, err = q.db.Exec(sqlStr) + } + } else { + ret, err = q.db.Exec(sbRet.Sql, append(q.args, args...)...) + } + if err != nil { + sbRet.Msg = err.Error() + } else { + sbRet.Success = true + switch q.t { + case TYPE_INSERT: + if DBType == "mysql" { + last, err := ret.LastInsertId() + if (err == nil) { + sbRet.LastID = last; + } + } + case TYPE_DELETE: + fallthrough + case TYPE_UPDATE: + fallthrough + case TYPE_INSERTUPDATE: + aff, err := ret.RowsAffected() + if (err == nil) { + sbRet.Affected = aff + } + } + } + } + return sbRet +} + +// 查询记录集 +func (q *SB) Query(args ...interface{}) (Results, error) { + s, e := q.ToSql() + if e != nil { + return nil, e + } + if q.debug { + log.Println("\n\tSQL prepare statement:\n\t", s, "\n\tParams:\n\t", args) + } + return q.db.Select(s, args...) +} + +// 查询单行数据 +func (q *SB) QueryOne(args ...interface{}) (OneRow, error) { + q.Limit(1, 0) + s, e := q.ToSql() + if e != nil { + return nil, e + } + if q.debug { + log.Println("\n\tSQL prepare statement:\n\t", s, "\n\tParams:\n\t", args) + } + return q.db.SelectOne(s, args...) +} + +// 查询记录集 +func (q *SB) QueryAllRow(args ...interface{}) (*sql.Rows, error) { + s, e := q.ToSql() + if e != nil { + return nil, e + } + if q.debug { + log.Println("\n\tSQL prepare statement:\n\t", s, "\n\tParams:\n\t", args) + } + return q.db.Query(s, args...) +} + +// 查询单行数据 +func (q *SB) QueryRow(args ...interface{}) *sql.Row { + s, e := q.ToSql() + if e != nil { + return nil + } + if q.debug { + log.Println("\n\tSQL prepare statement:\n\t", s, "\n\tParams:\n\t", args) + } + return q.db.QueryRow(s, args...) +} \ No newline at end of file diff --git a/db/utils.go b/db/utils.go new file mode 100644 index 0000000..996edb1 --- /dev/null +++ b/db/utils.go @@ -0,0 +1,66 @@ +package db + +import ( + "strings" + "strconv" + "database/sql" +) + +// 根据传入的字段列表生成相符数量的占位符 +func GetPlaceholderByFields(fileds string) string { + fileds = strings.Replace(fileds, " ", "", -1) + fileds = strings.Trim(fileds, ",") + count := len(strings.Split(fileds, ",")) + ret := make([]string, count) + for i := 0; i < count; i++ { + ret[i] = "?" + } + return strings.Join(ret, ",") +} + +// Atoi 转换成整型 +func Atoi(s string, d ...int) int { + i, err := strconv.Atoi(s) + if err != nil { + if len(d) > 0 { + return d[0] + } else { + return 0 + } + } + + return i +} + +// Atoi64 转换成整型int64 +func Atoi64(s string, d ...int64) int64 { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + if len(d) > 0 { + return d[0] + } else { + return 0 + } + } + + return i +} + +// 返回一个带有Null值的数据库字符串 +func NewNullString(s string) sql.NullString { + if len(s) == 0 { + return sql.NullString{} + } + return sql.NullString{ + String: s, + Valid: true, + } +} + +// 返回一个带有Null值的数据库整形 +func NewNullInt64(s int64, isNull bool) sql.NullInt64 { + return sql.NullInt64{ + Int64: s, + Valid: !isNull, + } +} \ No newline at end of file diff --git a/logs/logs.go b/logs/logs.go new file mode 100644 index 0000000..ee29b8f --- /dev/null +++ b/logs/logs.go @@ -0,0 +1,148 @@ +package logs + +import ( + "fmt" + "os" + "runtime" + "time" +) + +const ( + LOG_ERROR = iota + LOG_WARING + LOG_INFO + LOG_DEBUG +) + +var log *mylog + +/* + * 初始化 + */ +func init() { + log = newMylog() +} + +func Init(dir string, file string, level int, savefile bool) { + log.setDir(dir) + log.setFile(file) + log.setLevel(level) + log.setSavefile(savefile) +} + +func Error(err ...interface{}) { + log.write(LOG_ERROR, fmt.Sprint(err...)) +} + +func Waring(war ...interface{}) { + log.write(LOG_WARING, fmt.Sprint(war...)) +} +func SetLevel(level int) { + log.setLevel(level) +} +func Info(info ...interface{}) { + log.write(LOG_INFO, fmt.Sprint(info...)) +} + +func Debug(deb ...interface{}) { + log.write(LOG_DEBUG, fmt.Sprint(deb...)) +} + +/* + * 日志执行函数 + */ +type mylog struct { + log chan string // 日志chan + dir string // 日志存放目录 + file string // 日志文件名 + savefile bool // 是否保存到文件 + level int // 日志级别 +} + +func newMylog() *mylog { + log := &mylog{} + + log.log = make(chan string, 100) + log.dir = "/opt/logs" + log.file = "out" + log.savefile = false + + go log.run() + return log +} + +func (l *mylog) setDir(dir string) { + l.dir = dir +} + +func (l *mylog) setFile(file string) { + l.file = file +} + +func (l *mylog) setSavefile(b bool) { + l.savefile = b +} + +func (l *mylog) setLevel(level int) { + l.level = level +} + +func (l *mylog) getLevelString(level int) string { + switch level { + case LOG_ERROR: + return "ERROR" + case LOG_WARING: + return "WARING" + case LOG_INFO: + return "INFO" + case LOG_DEBUG: + return "DEBUG" + } + + return "unknown" +} + +func (l *mylog) write(level int, str string) { + // 判断级别 + if level > l.level { + return + } + + // 输出日志 + pc, _, line, _ := runtime.Caller(2) + p := runtime.FuncForPC(pc) + t := time.Now() + str = fmt.Sprintf("[%04d-%02d-%02d %02d:%02d:%02d] [%s] %s(%d): %s\n", + t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), + l.getLevelString(level), p.Name(), line, str) + // 输出到控制台 + if false == l.savefile { + fmt.Print(str) + return + } + + // 输出到文件 + l.log <- str +} + +func (l *mylog) run() { + for { + str := <-l.log + + // 判断文件夹是否存在 + _, err := os.Stat(l.dir) + if nil != err { + os.MkdirAll(l.dir, os.ModePerm) + } + + // 获取时间 + t := time.Now() + path := fmt.Sprintf("%s/%s-%04d-%02d-%02d.log", l.dir, l.file, + t.Year(), t.Month(), t.Day()) + fp, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.ModePerm) + if nil == err { + fp.WriteString(str) + fp.Close() + } + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..f5b1b18 --- /dev/null +++ b/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "github.com/gin-gonic/gin" + "log" + "strconv" + "user/controller" + "user/config" + "user/db" + "user/logs" +) + +func InitMysql() { + c := config.GetMysqlConfig() + if c == nil{ + logs.Error("cannnot connect mysql server") + }else { + db.Init() + } +} + +func main() { + e := config.Init("user.yaml") + if nil != e{ + log.Println(e.Error()) + } + db.Init() + r := gin.Default() + { + /** 添加或修改用户 **/ + r.POST("/api/user", controller.SetUser) + /** 删除用户 **/ + r.DELETE("/api/user", controller.DelUser) + /** 获取单独用户详情信息 methods(id) **/ + r.GET("/api/user", controller.GetUser) + /** 获取所有用户 **/ + r.GET("/api/users", controller.GetUsers) + /** 用户登录 **/ + r.POST("/api/login", controller.Login) + /** 用户注册 **/ + r.POST("/api/register", controller.Register) + /** 用户退出登陆 **/ + r.GET("/api/logout", controller.Logout) + } + r.Run(":" + strconv.Itoa(config.GetPort())) +} diff --git a/model/model.go b/model/model.go new file mode 100644 index 0000000..1f0eb70 --- /dev/null +++ b/model/model.go @@ -0,0 +1,13 @@ +package model + +type Users struct { + ID int64 `sql:"id" json:"id"` + UserName string `sql:"user_name" json:"UserName"` + UserPwd string `sql:"user_pwd" json:"UserPwd"` + CreatedDate string `sql:"created_date" json:"CreatedDate"` + UpdatedDate string `sql:"updated_date" json:"UpdatedDate"` + DisplayName string `sql:"display_name" json:"DisplayName"` + EmailAddress string `sql:"email_address" json:"EmailAddress"` + Tel string `sql:"tel" json:"tel"` + Avatar string `sql:"avatar" json:"Avatar"` +} diff --git a/redis/redis.go b/redis/redis.go new file mode 100644 index 0000000..29ad59d --- /dev/null +++ b/redis/redis.go @@ -0,0 +1,426 @@ +package redis + +import ( + "errors" + "fmt" + // "errors" + "strconv" + "time" + + "gopkg.in/redis.v3" +) + +const maxConn = 100 + +var redisChan chan *redis.Client + +var option *redis.Options + +func Init(addr string, pwd string, db int64) error { + option = &redis.Options{ + Addr: addr, + Password: pwd, + DB: db, + } + redisChan = make(chan *redis.Client, maxConn) + for i := 0; i < maxConn; i++ { + client, err := creatRedisClient(option) + if err != nil { + return err + } + redisChan <- client + } + return nil +} + +// 创建redis对象 +func creatRedisClient(option *redis.Options) (*redis.Client, error) { + + client := redis.NewClient(option) + // 检测client有效性 + if nil != client { + _, err := client.Ping().Result() + if nil != err { + client.Close() + return nil, errors.New(fmt.Sprintf("fail to ping redis-svr,addr :%s , pwd :%s ,DB :%d", option.Addr, option.Password, option.DB)) + } + } else { + return nil, errors.New(fmt.Sprintf("fail to connect redis-svr,,addr :%s , pwd :%s ,DB :%d", option.Addr, option.Password, option.DB)) + } + return client, nil +} + +// 获取redis +func getRedis() (*redis.Client, error) { + var client *redis.Client + + select { + case <-time.After(time.Second * 10): + case client = <-redisChan: + } + + // 检测client有效性 + if nil != client { + _, err := client.Ping().Result() + if nil != err { + client.Close() + // 尝试3次重连 + for i := 0; i < 3; i++ { + client, err = creatRedisClient(option) + if client != nil { + return client, err + } + } + + return nil, err + } + } + return client, nil +} + +// 将redis链接放回连接池 +func relaseRedis(client *redis.Client) { + select { + case <-time.After(time.Second * 10): + client.Close() + case redisChan <- client: + } +} + +func Get(key string) (string, error) { + client, err := getRedis() + if err != nil { + return "", err + } + if client == nil { + return "", errors.New("failed to get rds client") + } + defer relaseRedis(client) + + val, err := client.Get(key).Result() + if nil != err { + if err.Error() == "redis: nil" { + return "", nil + } + return "", err + } + + return val, nil +} + +// redis查询 +func Keys(key string) ([]string, error) { + client, err := getRedis() + if err != nil { + return nil, err + } + if client == nil { + return nil, errors.New("failed to get rds client") + } + defer relaseRedis(client) + + val, err := client.Keys(key).Result() + if nil != err { + var nullResult = []string{""} + return nullResult, err + } + + return val, nil +} + +func Set(key string, val string, expire ...time.Duration) error { + client, err := getRedis() + if err != nil { + return err + } + if client == nil { + return errors.New("failed to get rds client") + } + defer relaseRedis(client) + + var t time.Duration = 0 + + if len(expire) == 1 { + t = expire[0] + } + + _, err = client.Set(key, val, t).Result() + + return err +} +func HSet(key string, filed, val string) error { + client, err := getRedis() + if err != nil { + return err + } + if client == nil { + return errors.New("failed to get rds client") + } + defer relaseRedis(client) + _, err = client.HSet(key, filed, val).Result() + return err +} +func SIsMember(key string, val int32) (bool, error) { + client, err := getRedis() + if err != nil { + return false, err + } + if client == nil { + return false, errors.New("failed to get rds client") + } + defer relaseRedis(client) + isExist, err := client.SIsMember(key, val).Result() + return isExist, err +} +func SAdd(key string, members ...string) error { + client, err := getRedis() + if err != nil { + return err + } + if client == nil { + return errors.New("failed to get rds client") + } + defer relaseRedis(client) + _, err = client.SAdd(key, members...).Result() + return err +} +func SRem(key string, members ...string) error { + client, err := getRedis() + if err != nil { + return err + } + if client == nil { + return errors.New("failed to get rds client") + } + defer relaseRedis(client) + _, err = client.SRem(key, members...).Result() + return err +} +func SMembers(key string) ([]string, error) { + client, err := getRedis() + if err != nil { + return nil, err + } + if client == nil { + return nil, errors.New("failed to get rds client") + } + defer relaseRedis(client) + members, err := client.SMembers(key).Result() + return members, err +} +func HIncrBy(key, filed string, val int64) error { + client, err := getRedis() + if err != nil { + return err + } + if client == nil { + return errors.New("failed to get rds client") + } + defer relaseRedis(client) + _, err = client.HIncrBy(key, filed, val).Result() + return err +} +func LPush(key string, val ...string) error { + client, err := getRedis() + if err != nil { + return err + } + if client == nil { + return errors.New("failed to get rds client") + } + defer relaseRedis(client) + _, err = client.LPush(key, val...).Result() + return err +} +func BRPopLPush(source string, dest string, time time.Duration) (string, error) { + client, err := getRedis() + if err != nil { + return "", err + } + if client == nil { + return "", errors.New("failed to get rds client") + } + defer relaseRedis(client) + val, err := client.BRPopLPush(source, dest, time).Result() + return val, err +} +func RPop(key string) (string, error) { + client, err := getRedis() + if err != nil { + return "", err + } + if client == nil { + return "", errors.New("failed to get rds client") + } + defer relaseRedis(client) + val, err := client.RPop(key).Result() + return val, err +} +func LLen(key string) (int64, error) { + client, err := getRedis() + if err != nil { + return 0, err + } + if client == nil { + return 0, errors.New("failed to get rds client") + } + defer relaseRedis(client) + val, err := client.LLen(key).Result() + return val, err +} +func LTrim(key string, start, stop int64) (string, error) { + client, err := getRedis() + if err != nil { + return "", err + } + if client == nil { + return "", errors.New("failed to get rds client") + } + defer relaseRedis(client) + val, err := client.LTrim(key, start, stop).Result() + return val, err +} +func LRange(key string, start, stop int64) ([]string, error) { + client, err := getRedis() + if err != nil { + return nil, err + } + if client == nil { + return nil, errors.New("failed to get rds client") + } + defer relaseRedis(client) + + val, err := client.LRange(key, start, stop).Result() + if nil != err { + return nil, err + } + + return val, nil +} +func UpdateExpire(key string, expire time.Duration) error { + client, err := getRedis() + if err != nil { + return err + } + if client == nil { + return errors.New("failed to get rds client") + } + defer relaseRedis(client) + + _, err = client.Expire(key, expire).Result() + + return err +} + +func HGet(key string, hash string) (string, error) { + client, err := getRedis() + if err != nil { + return "", err + } + if client == nil { + return "", errors.New("failed to get rds client") + } + defer relaseRedis(client) + val, err := client.HGet(key, hash).Result() + if nil != err { + return "", err + } + + return val, nil +} +func HGetInt64(key string, hash string) (int64, error) { + client, err := getRedis() + if err != nil { + return 0, err + } + if client == nil { + return 0, errors.New("failed to get rds client") + } + defer relaseRedis(client) + val, err := client.HGet(key, hash).Int64() + if nil != err { + return 0, err + } + + return val, nil +} +func SCard(key string) (int64, error) { + client, err := getRedis() + if err != nil { + return 0, err + } + if client == nil { + return 0, errors.New("failed to get rds client") + } + defer relaseRedis(client) + val, err := client.SCard(key).Result() + if nil != err { + return 0, err + } + + return val, nil +} + +// 获取时间 +func Time() (int64, error) { + client, err := getRedis() + if err != nil { + return 0, err + } + if client == nil { + return 0, errors.New("failed to get rds client") + } + defer relaseRedis(client) + // 读取redis时间 + r := client.Time() + if nil == r { + return 0, errors.New("read redis error") + } + + if nil != r.Err() { + return 0, r.Err() + } + + return strconv.ParseInt(r.Val()[0], 10, 0) +} + +// 删除指定key的redis记录 +func Del(key string) (int64, error) { + client, err := getRedis() + if err != nil { + return 0, err + } + if client == nil { + return 0, errors.New("failed to get rds client") + } + defer relaseRedis(client) + + var iResult int64 + cmdResult := client.Del(key) + if nil == cmdResult.Err() { + iResult, _ = cmdResult.Result() + } else { + iResult = 0 + } + + return iResult, nil +} + +// 清空rendis_db +func TruncateDB() (string, error) { + client, err := getRedis() + if err != nil { + return "", err + } + if client == nil { + return "", errors.New("failed to get rds client") + } + defer relaseRedis(client) + + cmdResult := client.FlushDb() + if cmdResult.Err() != nil { + return "", cmdResult.Err() + } + + return cmdResult.Result() +} diff --git a/user.yaml b/user.yaml new file mode 100644 index 0000000..2fe7c41 --- /dev/null +++ b/user.yaml @@ -0,0 +1,19 @@ +listen_api: 4596 +runmode: debug +max_conn: 1500 +logs: + dir: "/var/log/user" + file: "user.log" + level: 1 + savefile: true +redis: + addr: 118.24.238.198 + password: 6379 + db: 1 +mysql: + addr: 118.24.238.198 + user: server + password: 123 + db: background + max_open: 100 + MaxIdle: 99