// Source path: iot_server/internal/pkg/packets/packets.go
//
// Repository web-view metadata kept for reference: 681 lines, 16 KiB, Go;
// revision dated 2023-08-28 06:49:44 +00:00.

package packets
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"unicode/utf8"
"github.com/winc-link/hummingbird/internal/pkg/codes"
)
// ErrInvalUTF8String is returned by DecodeUTF8String when the input is too
// short for its length prefix or its payload fails UTF-8 validation.
var ErrInvalUTF8String = errors.New("invalid utf-8 string")
// Version is the MQTT protocol version byte.
type Version = byte

// QoS is the quality-of-service level byte.
type QoS = byte

// Supported protocol versions and the packet size limit.
const (
	Version31  Version = 0x03
	Version311 Version = 0x04
	Version5   Version = 0x05
	// The maximum packet size of a MQTT packet
	MaximumSize = 268435456
)

// version2protoName maps each supported version to its protocol name bytes.
var version2protoName = map[Version][]byte{
	Version31:  []byte("MQIsdp"),
	Version311: []byte("MQTT"),
	Version5:   []byte("MQTT"),
}
// Packet types. The values are the 4-bit control packet type carried in the
// high nibble of the first byte of the fixed header.
const (
	RESERVED    = 0
	CONNECT     = 1
	CONNACK     = 2
	PUBLISH     = 3
	PUBACK      = 4
	PUBREC      = 5
	PUBREL      = 6
	PUBCOMP     = 7
	SUBSCRIBE   = 8
	SUBACK      = 9
	UNSUBSCRIBE = 10
	UNSUBACK    = 11
	PINGREQ     = 12
	PINGRESP    = 13
	DISCONNECT  = 14
	AUTH        = 15
)
// QoS levels.
const (
	Qos0 uint8 = iota
	Qos1
	Qos2
)

// SubscribeFailure is the subscribe failure code.
const SubscribeFailure = 0x80
// Flag values for the low 4 bits of the FixHeader's first byte.
const (
	FlagReserved    = 0x00
	FlagSubscribe   = 0x02
	FlagUnsubscribe = 0x02
	FlagPubrel      = 0x02
)
// PacketID is the type of packet identifier.
type PacketID = uint16

// Smallest and largest packet identifiers.
const (
	MinPacketID PacketID = 1
	MaxPacketID PacketID = 1<<16 - 1
)
// PayloadFormat is the payload format indicator byte.
type PayloadFormat = byte

// Payload format values: raw bytes (0) or string (1).
const (
	PayloadFormatBytes PayloadFormat = iota
	PayloadFormatString
)
// IsVersion3X reports whether v is Version31 or Version311.
func IsVersion3X(v Version) bool {
	switch v {
	case Version31, Version311:
		return true
	default:
		return false
	}
}
// IsVersion5 reports whether v equals Version5.
func IsVersion5(v Version) bool {
return v == Version5
}
// FixHeader represents the fixed header of an MQTT packet.
type FixHeader struct {
// PacketType is the packet type; ReadPacket takes it from the high 4 bits
// of the first byte and Pack writes it back there.
PacketType byte
// Flags holds the low 4 bits of the first byte.
Flags byte
// RemainLength is the remaining-length value that follows the first byte,
// stored as a plain integer.
RemainLength int
}
// Packet defines the interface for structs intended to hold
// decoded MQTT packets, either after being read or before being
// written.
type Packet interface {
// Pack encodes the packet struct into bytes and writes them to w.
Pack(w io.Writer) error
// Unpack reads the packet bytes from r and decodes them into the packet struct.
Unpack(r io.Reader) error
// String is mainly used in logging, debugging and testing.
String() string
}
// Topic represents an MQTT topic: a name together with the embedded
// subscription options.
type Topic struct {
// SubOptions is embedded, so its fields (Qos, RetainHandling, ...) are
// accessible directly on Topic.
SubOptions
// Name is the topic string.
Name string
}
// SubOptions holds the subscription options of a subscription.
// For details: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Subscription_Options
type SubOptions struct {
// Qos is the QoS level of the subscription.
// 0 = At most once delivery
// 1 = At least once delivery
// 2 = Exactly once delivery
Qos uint8
// RetainHandling specifies whether retained messages are sent when the subscription is established.
// 0 = Send retained messages at the time of the subscribe
// 1 = Send retained messages at subscribe only if the subscription does not currently exist
// 2 = Do not send retained messages at the time of the subscribe
RetainHandling byte
// NoLocal is the No Local option.
// If true, Application Messages MUST NOT be forwarded to a connection
// with a ClientID equal to the ClientID of the publishing connection.
NoLocal bool
// RetainAsPublished is the Retain As Published option.
// If true, Application Messages forwarded using this subscription keep
// the RETAIN flag they were published with.
// If false, Application Messages forwarded using this subscription have
// the RETAIN flag set to 0. Retained messages sent when the subscription
// is established have the RETAIN flag set to 1.
RetainAsPublished bool
}
// Reader reads data from a bufio.Reader and creates MQTT packet instances.
type Reader struct {
// bufr is the buffered source that packets are read from.
bufr *bufio.Reader
// version is the protocol version passed to NewPacket. NewReader sets it
// to Version311; SetVersion and a decoded Connect packet overwrite it.
version Version
}
// Writer encodes MQTT packets into bytes and writes them to a bufio.Writer.
type Writer struct {
// bufw buffers the written bytes until Flush is called.
bufw *bufio.Writer
}
// Flush writes any buffered data to the underlying io.Writer and returns
// the error reported by the buffered writer's Flush.
func (w *Writer) Flush() error {
return w.bufw.Flush()
}
// ReadWriter wraps a Reader and a Writer; both are embedded, so their
// methods are available directly on ReadWriter.
type ReadWriter struct {
*Reader
*Writer
}
// NewReader returns a new Reader whose version starts as Version311.
// When r is already a *bufio.Reader it is used as is; otherwise r is wrapped
// in a bufio.Reader with a 2048-byte buffer.
func NewReader(r io.Reader) *Reader {
	bufr, ok := r.(*bufio.Reader)
	if !ok {
		bufr = bufio.NewReaderSize(r, 2048)
	}
	return &Reader{bufr: bufr, version: Version311}
}
// SetVersion sets the protocol version that ReadPacket passes to NewPacket
// for subsequently read packets.
func (r *Reader) SetVersion(version Version) {
r.version = version
}
// NewWriter returns a new Writer.
// When w is already a *bufio.Writer it is used as is; otherwise w is wrapped
// in a bufio.Writer with a 2048-byte buffer.
func NewWriter(w io.Writer) *Writer {
	bufw, ok := w.(*bufio.Writer)
	if !ok {
		bufw = bufio.NewWriterSize(w, 2048)
	}
	return &Writer{bufw: bufw}
}
// ReadPacket reads one packet from the Reader and returns it.
// If any error occurs it returns nil and that error.
//
// After a Connect packet is decoded, the Reader's version is updated to the
// version carried by that packet, so later packets are decoded with it.
func (r *Reader) ReadPacket() (Packet, error) {
	head, err := r.bufr.ReadByte()
	if err != nil {
		return nil, err
	}
	// Despite its name, EncodeRemainLength reads and decodes the remaining length.
	remain, err := EncodeRemainLength(r.bufr)
	if err != nil {
		return nil, err
	}
	// High nibble of the first byte is the packet type, low nibble the flags.
	fh := &FixHeader{
		PacketType:   head >> 4,
		Flags:        head & 0x0f,
		RemainLength: remain,
	}
	pkt, err := NewPacket(fh, r.version, r.bufr)
	if err != nil {
		return nil, err
	}
	if conn, ok := pkt.(*Connect); ok {
		r.version = conn.Version
	}
	return pkt, nil
}
// WritePacket packs the packet into the Writer's buffer.
// Call Flush after WritePacket to flush buffered data to the underlying io.Writer.
func (w *Writer) WritePacket(packet Packet) error {
	return packet.Pack(w.bufw)
}
// WriteRaw writes raw bytes to the Writer's buffer.
// Call Flush after WriteRaw to flush buffered data to the underlying io.Writer.
func (w *Writer) WriteRaw(b []byte) error {
	_, err := w.bufw.Write(b)
	return err
}
// WriteAndFlush packs the packet into the Writer's buffer and then flushes
// the buffer to the underlying io.Writer. Nothing is flushed when packing fails.
func (w *Writer) WriteAndFlush(packet Packet) error {
	if err := packet.Pack(w.bufw); err != nil {
		return err
	}
	return w.Flush()
}
// Pack encodes the FixHeader into bytes and writes them to w with a single
// Write call: one byte holding the packet type (high nibble) and flags,
// followed by the encoded remaining length.
// It returns the error from encoding the remaining length or from w.Write.
func (fh *FixHeader) Pack(w io.Writer) error {
	// Despite its name, DecodeRemainLength produces the encoded bytes.
	remain, err := DecodeRemainLength(fh.RemainLength)
	if err != nil {
		return err
	}
	out := make([]byte, 0, 1+len(remain))
	out = append(out, fh.PacketType<<4|fh.Flags)
	out = append(out, remain...)
	_, err = w.Write(out)
	return err
}
// DecodeRemainLength converts the remaining length into its byte
// representation. Despite its name, this function encodes: it produces the
// variable-length form (7 value bits per byte, top bit set on every byte
// except the last), least significant group first.
//
// The result is 1 to 4 bytes long. codes.ErrMalformed is returned when length
// is negative or greater than 268435455, the largest value that fits in
// four bytes.
func DecodeRemainLength(length int) ([]byte, error) {
	const maxRemainLength = 268435455
	if length < 0 || length > maxRemainLength {
		return nil, codes.ErrMalformed
	}
	result := make([]byte, 0, 4)
	for {
		encodedByte := byte(length % 128)
		length /= 128
		// If there is more data to encode, set the top bit of this byte.
		if length > 0 {
			encodedByte |= 0x80
		}
		result = append(result, encodedByte)
		if length == 0 {
			return result, nil
		}
	}
}
// EncodeRemainLength reads the remaining-length bytes from r and returns the
// length as an int. Despite its name, this function decodes.
//
// The encoding uses 7 value bits per byte, least significant group first,
// with the top bit marking that another byte follows. At most four bytes are
// accepted: if the fourth byte still has its continuation bit set,
// codes.ErrMalformed is returned instead of reading further. This rejects
// over-long encodings and stops an endless run of continuation bytes from
// being consumed.
//
// Any read error other than io.EOF is returned as is. io.EOF is not reported:
// the missing byte is treated as a terminating zero, so an exhausted reader
// decodes as 0.
// NOTE(review): other decoders in this package may rely on that leniency for
// optional trailing fields - confirm before tightening it.
func EncodeRemainLength(r io.ByteReader) (int, error) {
	var vbi uint32
	// Four groups of 7 bits: shifts 0, 7, 14 and 21.
	for shift := uint(0); shift < 28; shift += 7 {
		digit, err := r.ReadByte()
		if err != nil && err != io.EOF {
			return 0, err
		}
		vbi |= uint32(digit&0x7f) << shift
		if digit&0x80 == 0 {
			return int(vbi), nil
		}
	}
	return 0, codes.ErrMalformed
}
// EncodeUTF8String prefixes buf with its length as a 2-byte big-endian
// integer. It returns the encoded bytes, their total size (2 + len(buf)) and
// an error. codes.ErrMalformed is returned when buf is longer than 65535
// bytes. The content of buf is copied and is not validated.
func EncodeUTF8String(buf []byte) (b []byte, size int, err error) {
	n := len(buf)
	if n > 65535 {
		return nil, 0, codes.ErrMalformed
	}
	out := make([]byte, 2+n)
	binary.BigEndian.PutUint16(out, uint16(n))
	copy(out[2:], buf)
	return out, len(out), nil
}
// readUint16 consumes the next two bytes of r and returns them as a
// big-endian uint16. When fewer than two bytes are available it returns
// codes.ErrMalformed and leaves r untouched.
func readUint16(r *bytes.Buffer) (uint16, error) {
	const width = 2
	if r.Len() >= width {
		return binary.BigEndian.Uint16(r.Next(width)), nil
	}
	return 0, codes.ErrMalformed
}
// writeUint16 appends i to w as two bytes in big-endian order.
func writeUint16(w *bytes.Buffer, i uint16) {
	w.Write([]byte{byte(i >> 8), byte(i)})
}
// writeUint32 appends i to w as four bytes in big-endian order.
func writeUint32(w *bytes.Buffer, i uint32) {
	w.Write([]byte{byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i)})
}
// readUint32 consumes the next four bytes of r and returns them as a
// big-endian uint32. When fewer than four bytes are available it returns
// codes.ErrMalformed and leaves r untouched.
func readUint32(r *bytes.Buffer) (uint32, error) {
	const width = 4
	if r.Len() >= width {
		return binary.BigEndian.Uint32(r.Next(width)), nil
	}
	return 0, codes.ErrMalformed
}
// readBinary reads a 2-byte-length-prefixed byte field from r.
// It delegates to readUTF8String with mustUTF8 set to false, so the payload
// is returned without UTF-8 validation.
func readBinary(r *bytes.Buffer) (b []byte, err error) {
return readUTF8String(false, r)
}
// readUTF8String reads a field made of a 2-byte big-endian length followed
// by that many payload bytes, and returns the payload.
//
// codes.ErrMalformed is returned when r holds fewer than two bytes, when the
// declared length exceeds what remains in r, or when mustUTF8 is true and
// the payload fails ValidUTF8.
//
// The returned slice comes from r.Next and therefore aliases the buffer's
// storage; it is only valid until the next read or write on r.
func readUTF8String(mustUTF8 bool, r *bytes.Buffer) (b []byte, err error) {
	if r.Len() < 2 {
		return nil, codes.ErrMalformed
	}
	size := int(binary.BigEndian.Uint16(r.Next(2)))
	if size > r.Len() {
		return nil, codes.ErrMalformed
	}
	data := r.Next(size)
	if mustUTF8 && !ValidUTF8(data) {
		return nil, codes.ErrMalformed
	}
	return data, nil
}
// writeUTF8String appends s to w as a length-prefixed field: the length of s
// written by writeUint16, followed by the bytes of s. No UTF-8 validation is
// performed here.
//
// The length of s cannot be greater than 65535: len(s) is converted with
// uint16(...), so a longer slice gets a wrapped-around length prefix while
// all of its bytes are still written.
func writeUTF8String(w *bytes.Buffer, s []byte) {
writeUint16(w, uint16(len(s)))
w.Write(s)
}
// writeBinary appends b to w as a length-prefixed field: the length of b
// written by writeUint16, followed by the bytes of b.
//
// The length of b must not be greater than 65535: len(b) is converted with
// uint16(...), so a longer slice gets a wrapped-around length prefix while
// all of its bytes are still written.
func writeBinary(w *bytes.Buffer, b []byte) {
writeUint16(w, uint16(len(b)))
w.Write(b)
}
// DecodeUTF8String decodes a length-prefixed UTF-8 string from the start of
// buf. It returns the payload (a sub-slice of buf, not a copy), the number of
// bytes consumed (2 + payload length) and an error.
//
// ErrInvalUTF8String is returned when buf is shorter than the 2-byte prefix,
// shorter than the declared length, or when the payload fails ValidUTF8.
func DecodeUTF8String(buf []byte) (b []byte, size int, err error) {
	if len(buf) < 2 {
		return nil, 0, ErrInvalUTF8String
	}
	end := 2 + int(binary.BigEndian.Uint16(buf[:2]))
	if end > len(buf) {
		return nil, 0, ErrInvalUTF8String
	}
	body := buf[2:end]
	if !ValidUTF8(body) {
		return nil, 0, ErrInvalUTF8String
	}
	return body, end, nil
}
// NewPacket dispatches on fh.PacketType to the matching packet constructor
// and returns the packet it builds from r, together with any error.
// An unknown packet type (including RESERVED) yields codes.ErrProtocol.
func NewPacket(fh *FixHeader, version Version, r io.Reader) (Packet, error) {
	// Cases are listed in packet type order.
	switch fh.PacketType {
	case CONNECT:
		return NewConnectPacket(fh, version, r)
	case CONNACK:
		return NewConnackPacket(fh, version, r)
	case PUBLISH:
		return NewPublishPacket(fh, version, r)
	case PUBACK:
		return NewPubackPacket(fh, version, r)
	case PUBREC:
		return NewPubrecPacket(fh, version, r)
	case PUBREL:
		return NewPubrelPacket(fh, r)
	case PUBCOMP:
		return NewPubcompPacket(fh, version, r)
	case SUBSCRIBE:
		return NewSubscribePacket(fh, version, r)
	case SUBACK:
		return NewSubackPacket(fh, version, r)
	case UNSUBSCRIBE:
		return NewUnsubscribePacket(fh, version, r)
	case UNSUBACK:
		return NewUnsubackPacket(fh, version, r)
	case PINGREQ:
		return NewPingreqPacket(fh, r)
	case PINGRESP:
		return NewPingrespPacket(fh, r)
	case DISCONNECT:
		return NewDisConnectPackets(fh, version, r)
	case AUTH:
		return NewAuthPacket(fh, r)
	}
	return nil, codes.ErrProtocol
}
// ValidUTF8 reports whether p is acceptable as the content of an MQTT UTF-8
// string. An empty p is valid.
//
// p is rejected when it contains:
//   - a byte sequence that is not well-formed UTF-8 (utf8.DecodeRune reports
//     this as RuneError with size 1; this includes encoded surrogates),
//   - a code point in U+0000..U+001F (including the null character,
//     [MQTT-1.5.3-2]),
//   - a code point in U+007F..U+009F.
//
// A correctly encoded U+FFFD replacement character is well-formed UTF-8 and
// is accepted; only a decoding failure counts as invalid.
func ValidUTF8(p []byte) bool {
	for len(p) > 0 {
		ru, size := utf8.DecodeRune(p)
		// RuneError with size 1 marks an invalid encoding; RuneError with
		// size 3 is a genuine U+FFFD present in the input.
		if ru == utf8.RuneError && size == 1 {
			return false
		}
		if ru <= '\u001f' || (ru >= '\u007f' && ru <= '\u009f') {
			return false
		}
		p = p[size:]
	}
	return true
}
// ValidTopicName reports whether p is acceptable as a topic name, that is,
// it contains no wildcard characters ('+' or '#') [MQTT-4.7.1-1].
//
// When mustUTF8 is true, p must also be well-formed UTF-8; a correctly
// encoded U+FFFD is well-formed and accepted. When mustUTF8 is false, invalid
// bytes are skipped one at a time and only the wildcard rule applies.
//
// An empty p is reported as valid; no other checks are made here.
func ValidTopicName(mustUTF8 bool, p []byte) bool {
	for len(p) > 0 {
		ru, size := utf8.DecodeRune(p)
		// RuneError with size 1 marks an invalid encoding.
		if mustUTF8 && ru == utf8.RuneError && size == 1 {
			return false
		}
		// Wildcards are not allowed in topic names.
		if size == 1 && (p[0] == '+' || p[0] == '#') {
			return false
		}
		p = p[size:]
	}
	return true
}
// ValidV5Topic reports whether p is a valid MQTT V5 topic filter, including
// the shared-subscription form "$share/{ShareName}/{filter}".
//
// An empty p is invalid. For a "$share/" filter the share name must be
// non-empty, well-formed UTF-8 and free of '+' and '#', and the part after
// the share name must satisfy ValidTopicFilter. Any other input is checked
// by ValidTopicFilter directly. UTF-8 is always required.
func ValidV5Topic(p []byte) bool {
	const sharePrefix = "$share/"
	if len(p) == 0 {
		return false
	}
	if !bytes.HasPrefix(p, []byte(sharePrefix)) {
		return ValidTopicFilter(true, p)
	}
	rest := p[len(sharePrefix):]
	// sep < 0: no filter part at all; sep == 0: empty share name.
	sep := bytes.IndexByte(rest, '/')
	if sep <= 0 {
		return false
	}
	shareName := rest[:sep]
	if !utf8.Valid(shareName) || bytes.ContainsAny(shareName, "+#") {
		return false
	}
	return ValidTopicFilter(true, rest[sep+1:])
}
// ValidTopicFilter reports whether p is a valid topic filter
// [MQTT-4.7.1-2] [MQTT-4.7.1-3].
//
// Rules enforced:
//   - p must not be empty;
//   - when mustUTF8 is true, p must be well-formed UTF-8;
//   - '#' must be the last byte and, unless it is also the first byte, must
//     directly follow a '/';
//   - '+' must occupy a whole level: it is either the first byte or follows
//     a '/', and it is either the last byte or is followed by a '/'.
//
// The wildcard rules apply to a '+' at the very start of the filter as well,
// so "+a" is rejected just like "a/+a".
func ValidTopicFilter(mustUTF8 bool, p []byte) bool {
	if len(p) == 0 {
		return false
	}
	if mustUTF8 && !utf8.Valid(p) {
		return false
	}
	// '+', '#' and '/' are ASCII and never occur inside a multi-byte UTF-8
	// sequence, so a plain byte scan is sufficient for the wildcard rules.
	last := len(p) - 1
	for i, c := range p {
		switch c {
		case '#':
			if i != last || (i > 0 && p[i-1] != '/') {
				return false
			}
		case '+':
			if (i > 0 && p[i-1] != '/') || (i < last && p[i+1] != '/') {
				return false
			}
		}
	}
	return true
}
// TopicMatch returns whether the topic and the topic filter match.
//
// Both slices are compared byte by byte. '+' in the filter consumes topic
// bytes up to the next '/', and '#' matches everything that remains. An empty
// topic or filter never matches, and a topic starting with '$' only matches a
// filter that also starts with '$' (and vice versa).
// The filter is not validated here.
func TopicMatch(topic []byte, topicFilter []byte) bool {
	filterLen, topicLen := len(topicFilter), len(topic)
	if filterLen == 0 || topicLen == 0 {
		return false
	}
	// Exactly one side starting with '$' is never a match.
	if (topicFilter[0] == '$') != (topic[0] == '$') {
		return false
	}

	fi, ti := 0, 0 // positions in topicFilter and topic
	for fi < filterLen && ti <= topicLen {
		if ti != topicLen && topicFilter[fi] == topic[ti] {
			// On the last topic byte: check for e.g. foo matching foo/#.
			if ti == topicLen-1 && fi == filterLen-3 &&
				topicFilter[fi+1] == '/' && topicFilter[fi+2] == '#' {
				return true
			}
			fi++
			ti++
			if fi == filterLen && ti == topicLen {
				// Same length, same content.
				return true
			}
			if ti == topicLen && fi == filterLen-1 && topicFilter[fi] == '+' {
				// The filter has exactly one more byte and it is '+',
				// e.g. filter foo/+ against topic foo/ ; the '+' must
				// directly follow a '/'.
				return fi == 0 || topicFilter[fi-1] == '/'
			}
			continue
		}

		// The current bytes differ, or the topic is exhausted.
		switch topicFilter[fi] {
		case '+':
			fi++
			// Skip the topic forward to its next level separator.
			for ti < topicLen && topic[ti] != '/' {
				ti++
			}
			if ti == topicLen && fi == filterLen {
				return true
			}
		case '#':
			return true
		default:
			// Check for e.g. foo/bar matching foo/+/#.
			return fi > 0 && fi+2 == filterLen && ti == topicLen &&
				topicFilter[fi-1] == '+' && topicFilter[fi] == '/' && topicFilter[fi+1] == '#'
		}
	}
	return false
}
// TotalBytes returns the total number of bytes of the packet: the fixed
// header (one byte plus 1 to 4 remaining-length bytes) plus RemainLength.
//
// It returns 0 when p is not one of the known packet types or when the
// packet's FixHeader is nil. When RemainLength is greater than 268435455 no
// header bytes are added and only the remaining length is returned.
func TotalBytes(p Packet) uint32 {
	var fh *FixHeader
	switch pkt := p.(type) {
	case *Auth:
		fh = pkt.FixHeader
	case *Connect:
		fh = pkt.FixHeader
	case *Connack:
		fh = pkt.FixHeader
	case *Disconnect:
		fh = pkt.FixHeader
	case *Pingreq:
		fh = pkt.FixHeader
	case *Pingresp:
		fh = pkt.FixHeader
	case *Puback:
		fh = pkt.FixHeader
	case *Pubcomp:
		fh = pkt.FixHeader
	case *Publish:
		fh = pkt.FixHeader
	case *Pubrec:
		fh = pkt.FixHeader
	case *Pubrel:
		fh = pkt.FixHeader
	case *Suback:
		fh = pkt.FixHeader
	case *Subscribe:
		fh = pkt.FixHeader
	case *Unsuback:
		fh = pkt.FixHeader
	case *Unsubscribe:
		fh = pkt.FixHeader
	}
	if fh == nil {
		return 0
	}

	remain := fh.RemainLength
	// One byte for type and flags plus the size of the encoded remaining length.
	var overhead uint32
	switch {
	case remain <= 127:
		overhead = 2
	case remain <= 16383:
		overhead = 3
	case remain <= 2097151:
		overhead = 4
	case remain <= 268435455:
		overhead = 5
	}
	return overhead + uint32(remain)
}