完善代码

This commit is contained in:
DESKTOP-4RNDQIC\29019 2019-04-07 12:25:07 +08:00
parent 296ec1e167
commit 93a4a224eb
10 changed files with 235 additions and 40 deletions

View File

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/redis.v4" "gopkg.in/redis.v4"
"user/logs" "background/logs"
) )
var ( var (

View File

@ -14,8 +14,8 @@ import (
"strings" "strings"
"text/template" "text/template"
"time" "time"
"user/config" "background/config"
"user/db" "background/db"
) )
type MailController struct { type MailController struct {

View File

@ -1,7 +1,34 @@
package middle package middle
import "github.com/gin-gonic/gin" import (
"background/config"
"background/controller"
"background/model"
"encoding/json"
"github.com/gin-gonic/gin"
)
func AuthMiddle(c *gin.Context) { func AuthMiddle(c *gin.Context) {
token := c.Query("token")
user := c.Query("userid")
if user == "" || token == ""{
c.JSON(200,controller.RespBase{
"auth err",20,nil,
})
}
if config.RedisOne().Exists(token).Val(){
users := model.Users{}
userInfo := config.RedisOne().Get(token).Val()
e := json.Unmarshal([]byte(userInfo),&users)
if nil != e{
c.JSON(200,controller.RespBase{
"auth err",10,nil,
})
}
}else {
c.JSON(200,controller.RespBase{
"expired",210,nil,
})
}
c.Next()
} }

View File

@ -1,6 +1,7 @@
package controller package controller
import ( import (
"background/utils"
"bytes" "bytes"
"crypto/md5" "crypto/md5"
"encoding/json" "encoding/json"
@ -19,14 +20,15 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"user/config" "background/config"
"user/db" "background/db"
"user/logs" "background/logs"
"user/model" "background/model"
"user/redis" "background/redis"
) )
type UserController struct { type UserController struct {
} }
type ReqSendEmailCode struct { type ReqSendEmailCode struct {
@ -80,9 +82,46 @@ func (this *UserController) Auth(c *gin.Context) {
// @Produce json // @Produce json
// @Param q query string false "name search by q" // @Param q query string false "name search by q"
// @Success 200 {array} util.RespBase // @Success 200 {array} util.RespBase
// @Router /accounts [get] // @Router /setUser [get]
func (this *UserController) SetUser(c *gin.Context) { func (this *UserController) SetUser(c *gin.Context) {
}
// SetUser godoc
// @Summary SetUser
// @Description set userinfo
// @Accept json
// @Produce json
// @Param q query string false "name search by q"
// @Success 200 {array} util.RespBase
// @Router /setUser [get]
func (this *UserController) ModifyPasswd(c *gin.Context) {
type ReqModifyPasswd struct{
id int `json:"id"`
UserName string `json:"user_name"`
Password string `json:"password"`
}
var req ReqModifyPasswd
var resp RespBase
resp.Status = -1
resp.Msg = "err"
defer func() {
c.JSON(200,resp)
}()
e := c.BindJSON(&req)
if nil != e{
logs.Error(e.Error())
return
}
e = model.ModyfyPassword(req.UserName,req.Password)
if nil != e{
logs.Error(e.Error())
return
}
resp.Msg = "OK"
resp.Status = 0
} }
func (this *UserController) DelUser(c *gin.Context) { func (this *UserController) DelUser(c *gin.Context) {
@ -222,16 +261,6 @@ func createToken(length int32) string {
return result return result
} }
func ByteSliceToString(b []byte) string {
var ret string
for i := 0; i < len(b); i++ {
s := fmt.Sprintf("%02x", b[i])
ret += string(s)
}
return ret
}
func DefaultOption(c *gin.Context) { func DefaultOption(c *gin.Context) {
var resp RespBase var resp RespBase
@ -280,14 +309,14 @@ func (this *UserController) Login(c *gin.Context) {
if nil != er { if nil != er {
log.Println(er.Error()) log.Println(er.Error())
} }
strpassmd5 := ByteSliceToString(passmd5) strpassmd5 := utils.ByteSliceToString(passmd5)
if len(result) == 1 { if len(result) == 1 {
if result[0].UserPwd == strpassmd5 { if result[0].UserPwd == strpassmd5 {
socketToken := md5.New() socketToken := md5.New()
socketToken.Write([]byte(createToken(6))) // 需要加密的字符串为 123456 socketToken.Write([]byte(createToken(6))) // 需要加密的字符串为 123456
socketTokenMd5 := socketToken.Sum(nil) socketTokenMd5 := socketToken.Sum(nil)
m := structs.Map(result[0]) m := structs.Map(result[0])
m["socketToken"] = ByteSliceToString(socketTokenMd5) m["socketToken"] = utils.ByteSliceToString(socketTokenMd5)
sessionInfo, err := json.Marshal(m) sessionInfo, err := json.Marshal(m)
if err != nil { if err != nil {
log.Println(err.Error()) log.Println(err.Error())
@ -368,7 +397,7 @@ func (this *UserController) Register(c *gin.Context) {
h := md5.New() h := md5.New()
h.Write([]byte(req.UserPwd)) h.Write([]byte(req.UserPwd))
passwdmd5 := h.Sum(nil) passwdmd5 := h.Sum(nil)
strpassmd5 := ByteSliceToString(passwdmd5) strpassmd5 := utils.ByteSliceToString(passwdmd5)
user.UserPwd = strpassmd5 user.UserPwd = strpassmd5
user.UpdatedDate = time.Now().Format("2006-01-02 15:04:05") user.UpdatedDate = time.Now().Format("2006-01-02 15:04:05")
user.CreatedDate = time.Now().Format("2006-01-02 15:04:05") user.CreatedDate = time.Now().Format("2006-01-02 15:04:05")
@ -386,8 +415,8 @@ func (this *UserController) Register(c *gin.Context) {
resp.Status = 20 resp.Status = 20
return return
} }
query := fmt.Sprintf("insert into users(user_name,user_pwd,created_date,"+ query := fmt.Sprintf("insert into users(user_name,user_pwd,created_date,"+
"updated_date,display_name,email_address) values ('%s','%s','%s','%s','%s','%s') ", user.UserName, user.UserPwd, user.CreatedDate, user.UpdatedDate, "updated_date,display_name,email_address) values ('%s','%s','%s','%s','%s','%s')", user.UserName, user.UserPwd, user.CreatedDate, user.UpdatedDate,
user.DisplayName, user.EmailAddress) user.DisplayName, user.EmailAddress)
n, er := db.GetMysqlClient().Insert(query) n, er := db.GetMysqlClient().Insert(query)
if n == 0 || n < 0 { if n == 0 || n < 0 {

View File

@ -9,7 +9,7 @@ import (
"strconv" "strconv"
"sync" "sync"
"time" "time"
"user/logs" "background/logs"
) )
// 数据容器抽象对象定义 // 数据容器抽象对象定义

View File

@ -4,7 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"user/config" "background/config"
) )
var gDb Database var gDb Database

20
main.go
View File

@ -1,14 +1,15 @@
package main package main
import ( import (
"background/config"
"background/controller"
"background/controller/middle"
"background/db"
"background/logs"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tommy351/gin-sessions" "github.com/tommy351/gin-sessions"
"log" "log"
"strconv" "strconv"
"user/config"
"user/controller"
"user/db"
"user/logs"
) )
var ( var (
@ -74,18 +75,21 @@ func main() {
/** 删除用户 **/ /** 删除用户 **/
api.DELETE("/user", userController.DelUser) api.DELETE("/user", userController.DelUser)
/** 获取单独用户详情信息 methods(id) **/ /** 获取单独用户详情信息 methods(id) **/
api.GET("/user", userController.GetUser) api.GET("/user", middle.AuthMiddle,userController.GetUser)
/** 获取所有用户 **/ /** 获取所有用户 **/
api.GET("/users", userController.Users) api.GET("/users", middle.AuthMiddle,userController.Users)
api.POST("/search_users", userController.SerarchUsers) api.POST("/search_users",middle.AuthMiddle,userController.SerarchUsers)
/** 用户登录 **/ /** 用户登录 **/
api.POST("/login", userController.Login) api.POST("/login", userController.Login)
/** 用户注册 **/ /** 用户注册 **/
api.POST("/register", userController.Register) api.POST("/register", userController.Register)
/** 用户退出登陆 **/ /** 用户退出登陆 **/
api.GET("/logout", userController.Logout) api.GET("/logout", middle.AuthMiddle,userController.Logout)
api.POST("/verify", mailContoller.OnSendEmailCode) api.POST("/verify", mailContoller.OnSendEmailCode)
/** 修改密码**/
api.POST("modify_pass",middle.AuthMiddle,userController.ModifyPasswd)
} }
e := r.Run(":" + strconv.Itoa(config.GetPort())) e := r.Run(":" + strconv.Itoa(config.GetPort()))
if nil != e { if nil != e {
log.Print(e.Error()) log.Print(e.Error())

View File

@ -1,16 +1,18 @@
package model package model
import ( import (
"background/utils"
"crypto/md5"
"fmt" "fmt"
"log" "log"
"user/db" "background/db"
"user/logs" "background/logs"
) )
type Users struct { type Users struct {
ID int64 `sql:"id" json:"id"` ID int64 `sql:"id" json:"id"`
UserName string `sql:"user_name" json:"user_name"` UserName string `sql:"user_name" json:"user_name"`
UserPwd string `sql:"user_pwd" json:"-"` UserPwd string `sql:"user_pwd" json:"user_pwd"`
CreatedDate string `sql:"created_date" json:"created_date"` CreatedDate string `sql:"created_date" json:"created_date"`
UpdatedDate string `sql:"updated_date" json:"updated_date"` UpdatedDate string `sql:"updated_date" json:"updated_date"`
DisplayName string `sql:"display_name" json:"display_name"` DisplayName string `sql:"display_name" json:"display_name"`
@ -24,7 +26,8 @@ func GetUsers(limit int32, offsetPage int32, name string) ([]Users, int32) {
var query string var query string
if name != "" { if name != "" {
log.Println(name) log.Println(name)
query = fmt.Sprintf("select * from users where user_name like '%s' limit %d offset %d", "%%"+name+"%%", limit, offsetPage*limit) query = fmt.Sprintf("select * from users where user_name like '%s' limit %d offset %d",
"%%"+name+"%%", limit, offsetPage*limit)
log.Printf(query) log.Printf(query)
} else { } else {
query = fmt.Sprintf("select * from users limit %d offset %d", limit, offsetPage*limit) query = fmt.Sprintf("select * from users limit %d offset %d", limit, offsetPage*limit)
@ -44,3 +47,19 @@ func GetUsers(limit int32, offsetPage int32, name string) ([]Users, int32) {
} }
return users, cnts[0].Count return users, cnts[0].Count
} }
func ModyfyPassword(UserName string ,Password string) error {
h := md5.New()
h.Write([]byte(Password))
query := fmt.Sprintf("update users set user_pwd = '%s' where user_name = '%s' ",
utils.ByteSliceToString(h.Sum(nil)),UserName)
n,err := db.GetMysqlClient().Update(query)
if nil != err {
logs.Error(err.Error())
return err
}
if n == 0{
return nil
}
return nil
}

103
utils/JWT.go Normal file
View File

@ -0,0 +1,103 @@
package utils
import (
"fmt"
"github.com/dgrijalva/jwt-go"
"strconv"
"time"
)
//创建token,
/*
uid:用户名
secret:密匙
alg:加密算法类型
exp过期时间单位是秒
*/
func CreateJwt(uid string, secret []byte, alg string, exp int64) (tokenString string, err error) {
//get SigningMethod
signingMethon := jwt.GetSigningMethod(alg)
//time.Sleep(time.Nanosecond * time.Duration(RandomInt(0, 10)))
iat := time.Now().Unix()
// Create a new token object, specifying signing method and the claims
// you would like it to contain.
token := jwt.NewWithClaims(signingMethon, jwt.MapClaims{
"iss": "Authen Center",
"iat": iat,
"exp": iat + exp,
"jti": uid,
})
// Sign and get the complete encoded token as a string using the secret
tokenString, err = token.SignedString(secret)
//fmt.Printf("get jwt:%v,%v,%v\n%s\n", iat, iat+exp, uid, tokenString)
return
}
//获取token的用户名
//tokenString token字符串
//tag 字段名 如jti
func GetUid(tokenString string, tag string) (string, error) {
tokens, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
// hmacSampleSecret is a []byte containing your secret, e.g. []byte("my_secret_key")
return []byte(""), nil
})
if tokens == nil {
return "", err
}
//fmt.Printf("%#v",tokens.Claims.(jwt.MapClaims))
uid := tokens.Claims.(jwt.MapClaims)[tag]
switch t := uid.(type) {
case int:
_ = t
return strconv.Itoa(uid.(int)), nil
case float64:
_ = t
return strconv.FormatFloat(uid.(float64), 'g', 12, 64), nil
//... etc
}
return uid.(string), nil
}
//验证token
//secret 秘钥
//tokenString token的字符串
func VerifyJwt(secret []byte, tokenString string) (state int) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}
// secret is a []byte containing your secret, e.g. []byte("my_secret_key")
return secret, nil
})
if err != nil {
state := -1 // jwt解析错误
tempError := err.(*jwt.ValidationError)
//fmt.Println("jwt error")
//fmt.Println(tempError)
//jwt过期
if tempError.Errors == jwt.ValidationErrorExpired {
//fmt.Println("jwt expired")
state = -2
}
// jwt IAT 错误
if tempError.Errors == jwt.ValidationErrorIssuedAt {
//fmt.Printf("jwt iat error")
state = -3
}
return state
}
if _, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
// fmt.Println(claims["iat"], claims["exp"])
// 验证通过
return 0
}
return 5
}

13
utils/base.go Normal file
View File

@ -0,0 +1,13 @@
package utils
import "fmt"
func ByteSliceToString(b []byte) string {
var ret string
for i := 0; i < len(b); i++ {
s := fmt.Sprintf("%02x", b[i])
ret += string(s)
}
return ret
}