260 lines
6.3 KiB
Go
260 lines
6.3 KiB
Go
|
package packets
|
|||
|
|
|||
|
import (
|
|||
|
"bytes"
|
|||
|
"fmt"
|
|||
|
"io"
|
|||
|
|
|||
|
"github.com/winc-link/hummingbird/internal/pkg/codes"
|
|||
|
)
|
|||
|
|
|||
|
// Connect represents the MQTT Connect packet
|
|||
|
type Connect struct {
|
|||
|
Version Version
|
|||
|
FixHeader *FixHeader
|
|||
|
//Variable header
|
|||
|
ProtocolLevel byte
|
|||
|
//Connect Flags
|
|||
|
UsernameFlag bool
|
|||
|
ProtocolName []byte
|
|||
|
PasswordFlag bool
|
|||
|
WillRetain bool
|
|||
|
WillQos uint8
|
|||
|
WillFlag bool
|
|||
|
WillTopic []byte
|
|||
|
WillMsg []byte
|
|||
|
CleanStart bool
|
|||
|
KeepAlive uint16 //如果非零,1.5倍时间没收到则断开连接[MQTT-3.1.2-24]
|
|||
|
//if set
|
|||
|
ClientID []byte
|
|||
|
Username []byte
|
|||
|
Password []byte
|
|||
|
|
|||
|
Properties *Properties
|
|||
|
WillProperties *Properties
|
|||
|
}
|
|||
|
|
|||
|
func (c *Connect) String() string {
|
|||
|
return fmt.Sprintf("Connect, Version: %v,"+"ProtocolLevel: %v, UsernameFlag: %v, PasswordFlag: %v, ProtocolName: %s, CleanStart: %v, KeepAlive: %v, ClientID: %s, Username: %s, Password: %s"+
|
|||
|
", WillFlag: %v, WillRetain: %v, WillQos: %v, WillMsg: %s, Properties: %s, WillProperties: %s",
|
|||
|
c.Version, c.ProtocolLevel, c.UsernameFlag, c.PasswordFlag, c.ProtocolName, c.CleanStart, c.KeepAlive, c.ClientID, c.Username, c.Password, c.WillFlag, c.WillRetain, c.WillQos, c.WillMsg, c.Properties, c.WillProperties)
|
|||
|
}
|
|||
|
|
|||
|
// Pack encodes the packet struct into bytes and writes it into io.Writer.
|
|||
|
func (c *Connect) Pack(w io.Writer) error {
|
|||
|
var err error
|
|||
|
c.FixHeader = &FixHeader{PacketType: CONNECT, Flags: FlagReserved}
|
|||
|
|
|||
|
bufw := &bytes.Buffer{}
|
|||
|
bufw.Write([]byte{0x00, 0x04})
|
|||
|
bufw.Write(c.ProtocolName)
|
|||
|
bufw.WriteByte(c.ProtocolLevel)
|
|||
|
// write flag
|
|||
|
var (
|
|||
|
usenameFlag = 0
|
|||
|
passwordFlag = 0
|
|||
|
willRetain = 0
|
|||
|
willFlag = 0
|
|||
|
willQos = 0
|
|||
|
CleanStart = 0
|
|||
|
reserved = 0
|
|||
|
)
|
|||
|
if c.UsernameFlag {
|
|||
|
usenameFlag = 128
|
|||
|
}
|
|||
|
if c.PasswordFlag {
|
|||
|
passwordFlag = 64
|
|||
|
}
|
|||
|
if c.WillRetain {
|
|||
|
willRetain = 32
|
|||
|
}
|
|||
|
if c.WillQos == 1 {
|
|||
|
willQos = 8
|
|||
|
} else if c.WillQos == 2 {
|
|||
|
willQos = 16
|
|||
|
}
|
|||
|
if c.WillFlag {
|
|||
|
willFlag = 4
|
|||
|
}
|
|||
|
if c.CleanStart {
|
|||
|
CleanStart = 2
|
|||
|
}
|
|||
|
connFlag := usenameFlag | passwordFlag | willRetain | willFlag | willQos | CleanStart | reserved
|
|||
|
bufw.Write([]byte{uint8(connFlag)})
|
|||
|
writeUint16(bufw, c.KeepAlive)
|
|||
|
|
|||
|
if c.Version == Version5 {
|
|||
|
c.Properties.Pack(bufw, CONNECT)
|
|||
|
}
|
|||
|
clienIDByte, _, err := EncodeUTF8String(c.ClientID)
|
|||
|
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
bufw.Write(clienIDByte)
|
|||
|
|
|||
|
if c.WillFlag {
|
|||
|
if c.Version == Version5 {
|
|||
|
c.WillProperties.PackWillProperties(bufw)
|
|||
|
}
|
|||
|
willTopicByte, _, err := EncodeUTF8String(c.WillTopic)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
bufw.Write(willTopicByte)
|
|||
|
willMsgByte, _, err := EncodeUTF8String(c.WillMsg)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
bufw.Write(willMsgByte)
|
|||
|
}
|
|||
|
if c.UsernameFlag {
|
|||
|
usernameByte, _, err := EncodeUTF8String(c.Username)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
bufw.Write(usernameByte)
|
|||
|
}
|
|||
|
if c.PasswordFlag {
|
|||
|
passwordByte, _, err := EncodeUTF8String(c.Password)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
bufw.Write(passwordByte)
|
|||
|
}
|
|||
|
|
|||
|
c.FixHeader.RemainLength = bufw.Len()
|
|||
|
err = c.FixHeader.Pack(w)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
_, err = bufw.WriteTo(w)
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
// Unpack read the packet bytes from io.Reader and decodes it into the packet struct.
|
|||
|
func (c *Connect) Unpack(r io.Reader) (err error) {
|
|||
|
restBuffer := make([]byte, c.FixHeader.RemainLength)
|
|||
|
_, err = io.ReadFull(r, restBuffer)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
bufr := bytes.NewBuffer(restBuffer)
|
|||
|
c.ProtocolName, err = readUTF8String(false, bufr)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
c.ProtocolLevel, err = bufr.ReadByte()
|
|||
|
if err != nil {
|
|||
|
return codes.ErrMalformed
|
|||
|
}
|
|||
|
c.Version = c.ProtocolLevel
|
|||
|
if name, ok := version2protoName[c.ProtocolLevel]; !ok {
|
|||
|
return codes.NewError(codes.V3UnacceptableProtocolVersion)
|
|||
|
} else if !bytes.Equal(c.ProtocolName, name) {
|
|||
|
return codes.NewError(codes.UnsupportedProtocolVersion)
|
|||
|
}
|
|||
|
|
|||
|
connectFlags, err := bufr.ReadByte()
|
|||
|
if err != nil {
|
|||
|
return codes.ErrMalformed
|
|||
|
}
|
|||
|
reserved := 1 & connectFlags
|
|||
|
if reserved != 0 { //[MQTT-3.1.2-3]
|
|||
|
return codes.ErrMalformed
|
|||
|
}
|
|||
|
c.CleanStart = (1 & (connectFlags >> 1)) > 0
|
|||
|
c.WillFlag = (1 & (connectFlags >> 2)) > 0
|
|||
|
c.WillQos = 3 & (connectFlags >> 3)
|
|||
|
if !c.WillFlag && c.WillQos != 0 { //[MQTT-3.1.2-11]
|
|||
|
return codes.ErrMalformed
|
|||
|
}
|
|||
|
c.WillRetain = (1 & (connectFlags >> 5)) > 0
|
|||
|
if !c.WillFlag && c.WillRetain { //[MQTT-3.1.2-11]
|
|||
|
return codes.ErrMalformed
|
|||
|
}
|
|||
|
c.PasswordFlag = (1 & (connectFlags >> 6)) > 0
|
|||
|
c.UsernameFlag = (1 & (connectFlags >> 7)) > 0
|
|||
|
c.KeepAlive, err = readUint16(bufr)
|
|||
|
if err != nil {
|
|||
|
return codes.ErrMalformed
|
|||
|
}
|
|||
|
|
|||
|
if c.Version == Version5 {
|
|||
|
// resolve properties
|
|||
|
c.Properties = &Properties{}
|
|||
|
c.WillProperties = &Properties{}
|
|||
|
if err := c.Properties.Unpack(bufr, CONNECT); err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
}
|
|||
|
return c.unpackPayload(bufr)
|
|||
|
}
|
|||
|
|
|||
|
func (c *Connect) unpackPayload(bufr *bytes.Buffer) error {
|
|||
|
var err error
|
|||
|
c.ClientID, err = readUTF8String(true, bufr)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
if IsVersion3X(c.Version) && len(c.ClientID) == 0 && !c.CleanStart { // v311 [MQTT-3.1.3-7]
|
|||
|
return codes.NewError(codes.V3IdentifierRejected) // v311 //[MQTT-3.1.3-8]
|
|||
|
}
|
|||
|
|
|||
|
if c.WillFlag {
|
|||
|
if c.Version == Version5 {
|
|||
|
err := c.WillProperties.UnpackWillProperties(bufr)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
}
|
|||
|
c.WillTopic, err = readUTF8String(true, bufr)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
c.WillMsg, err = readUTF8String(true, bufr)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
if c.UsernameFlag {
|
|||
|
c.Username, err = readUTF8String(true, bufr)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
}
|
|||
|
if c.PasswordFlag {
|
|||
|
c.Password, err = readUTF8String(true, bufr)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
}
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
// NewConnectPacket returns a Connect instance by the given FixHeader and io.Reader
|
|||
|
func NewConnectPacket(fh *FixHeader, version Version, r io.Reader) (*Connect, error) {
|
|||
|
//b1 := buffer[0] //一定是16
|
|||
|
p := &Connect{FixHeader: fh, Version: version}
|
|||
|
//判断 标志位 flags 是否合法[MQTT-2.2.2-2]
|
|||
|
if fh.Flags != FlagReserved {
|
|||
|
return nil, codes.ErrMalformed
|
|||
|
}
|
|||
|
err := p.Unpack(r)
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
return p, err
|
|||
|
}
|
|||
|
|
|||
|
// NewConnackPacket returns the Connack struct which is the ack packet of the Connect packet.
|
|||
|
func (c *Connect) NewConnackPacket(code codes.Code, sessionReuse bool) *Connack {
|
|||
|
ack := &Connack{Code: code, Version: c.Version}
|
|||
|
if !c.CleanStart && sessionReuse && code == codes.Success {
|
|||
|
ack.SessionPresent = true //[MQTT-3.2.2-2]
|
|||
|
}
|
|||
|
return ack
|
|||
|
}
|