iot_server/internal/pkg/messaging/mqclient/mqtt/client.go

344 lines
11 KiB
Go
Raw Normal View History

2023-08-28 06:49:44 +00:00
/********************************************************************************
* Copyright 2019 Dell Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*******************************************************************************/
package mqtt
import (
"crypto/tls"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/winc-link/hummingbird/internal/pkg/messaging/mqclient"
"github.com/winc-link/hummingbird/internal/pkg/messaging/types"
pahoMqtt "github.com/eclipse/paho.mqtt.golang"
)
// Pluggable function types used by Client.
type (
	// ClientCreator builds the underlying paho MQTT client from the message-bus
	// configuration. handler is the on-connect callback the created client should invoke;
	// the creators in this package install it as ClientOptions.OnConnect.
	ClientCreator func(config types.MessageBusConfig, handler pahoMqtt.OnConnectHandler) (pahoMqtt.Client, error)

	// MessageMarshaller encodes a value into the bytes that are published as an MQTT payload.
	MessageMarshaller func(v interface{}) ([]byte, error)

	// MessageUnmarshaller decodes an MQTT payload into the value pointed to by v.
	MessageUnmarshaller func(data []byte, v interface{}) error
)
// Client facilitates communication to an MQTT server and provides functionality needed to send and receive MQTT
// messages.
//
// The embedded sync.Mutex guards activeSubscriptions, which onConnectHandler reads and
// addActiveSubscriptions/deleteActiveSubscriptions modify.
type Client struct {
	sync.Mutex
	// creator builds mqttClient the first time Connect is called.
	creator ClientCreator
	// configuration is handed to creator when the underlying client is built.
	configuration types.MessageBusConfig
	// mqttClient stays nil until Connect has successfully created it; the constructors do not set it.
	mqttClient pahoMqtt.Client
	// marshaller encodes outgoing MessageEnvelopes in Publish.
	marshaller MessageMarshaller
	// unmarshaller decodes incoming payloads into MessageEnvelopes for subscriptions.
	unmarshaller MessageUnmarshaller
	// activeSubscriptions records each successful Subscribe so that onConnectHandler can
	// re-create the subscriptions after a reconnect.
	activeSubscriptions []activeSubscription
}

// activeSubscription holds what is needed to re-create one subscription.
type activeSubscription struct {
	topic string
	qos byte
	handler pahoMqtt.MessageHandler
	// errors is the channel supplied to Subscribe; onConnectHandler reports
	// re-subscription failures on it.
	errors chan error
}
// NewMQTTClient returns a Client for config with the following setup:
//   - MessageEnvelopes are encoded and decoded as JSON.
//   - The underlying paho client is built with DefaultClientCreator.
//
// No connection is made here; call Connect before using the Client. The returned error is
// always nil.
func NewMQTTClient(config types.MessageBusConfig) (*Client, error) {
	return NewMQTTClientWithCreator(config, json.Marshal, json.Unmarshal, DefaultClientCreator())
}
// NewMQTTClientWithCreator returns a Client for config with caller-supplied components:
//   - marshaller and unmarshaller handle MessageEnvelope payloads.
//   - creator builds the underlying paho client when Connect is first called.
//
// No connection is made here. The returned error is always nil.
func NewMQTTClientWithCreator(
	config types.MessageBusConfig,
	marshaller MessageMarshaller,
	unmarshaller MessageUnmarshaller,
	creator ClientCreator) (*Client, error) {
	c := new(Client)
	c.creator = creator
	c.configuration = config
	c.marshaller = marshaller
	c.unmarshaller = unmarshaller
	return c, nil
}
// Connect establishes a connection to the MQTT server. It must be called before any other
// functionality provided by the Client.
//
// On the first call it builds the paho client through the configured ClientCreator and
// registers onConnectHandler as its on-connect callback. If the client is already
// connected, Connect returns nil immediately. Otherwise it waits at most the configured
// connect timeout for the connection attempt to complete.
func (mc *Client) Connect() error {
	if mc.mqttClient == nil {
		// The on-connect handler needs this Client's activeSubscriptions, so the paho
		// client is created here rather than in the constructors.
		created, err := mc.creator(mc.configuration, mc.onConnectHandler)
		if err != nil {
			return err
		}
		mc.mqttClient = created
	}

	if mc.mqttClient.IsConnected() {
		return nil
	}

	timeout := mc.mqttClient.OptionsReader().ConnectTimeout()
	return getTokenError(mc.mqttClient.Connect(), timeout, ConnectOperation, "Unable to connect")
}
// onConnectHandler is installed as the paho on-connect callback. It re-creates every
// subscription recorded in activeSubscriptions. activeSubscriptions is empty on the first
// connection; on a re-connect the subscriptions must be re-created.
//
// The subscription list is copied while holding the lock, and the lock is released
// before any network I/O. Each Subscribe may wait up to the connect timeout, and
// reporting a failure may block on the subscriber's error channel. Holding the lock
// across either would stall Subscribe/Unsubscribe on this Client. It could even deadlock
// if the goroutine draining the error channel is itself waiting for the lock.
//
// Failures are sent on the subscription's own error channel. A subscription registered
// with a nil channel has nowhere to report, and a send on a nil channel blocks forever,
// so its error is dropped instead.
func (mc *Client) onConnectHandler(_ pahoMqtt.Client) {
	optionsReader := mc.mqttClient.OptionsReader()

	mc.Lock()
	subscriptions := make([]activeSubscription, len(mc.activeSubscriptions))
	copy(subscriptions, mc.activeSubscriptions)
	mc.Unlock()

	for _, subscription := range subscriptions {
		token := mc.mqttClient.Subscribe(subscription.topic, subscription.qos, subscription.handler)
		message := fmt.Sprintf("Failed to re-create subscription for topic=%s", subscription.topic)
		err := getTokenError(token, optionsReader.ConnectTimeout(), SubscribeOperation, message)
		if err != nil && subscription.errors != nil {
			subscription.errors <- err
		}
	}
}
// Publish sends a message to the connected MQTT server.
//
// The message is encoded with the Client's marshaller; an encoding failure is returned as
// a PublishOperation error. The payload is published to topic using the QoS and retained
// flag from the client's Will options. Publish waits at most the connect timeout for the
// publish token to complete.
//
// NOTE(review): when the client is not currently connected, the message is silently
// dropped and nil is returned, so callers cannot tell it apart from a successful publish.
func (mc *Client) Publish(message types.MessageEnvelope, topic string) error {
	payload, err := mc.marshaller(message)
	if err != nil {
		return NewOperationErr(PublishOperation, err.Error())
	}

	// TODO: decide whether publishing while disconnected should report an error.
	if connected := mc.mqttClient.IsConnected(); !connected {
		return nil
	}

	reader := mc.mqttClient.OptionsReader()
	token := mc.mqttClient.Publish(topic, reader.WillQos(), reader.WillRetained(), payload)
	return getTokenError(token, reader.ConnectTimeout(), PublishOperation, "Unable to publish message")
}
// Subscribe creates a subscription for each of the specified topics.
//
// Each received message is decoded with the Client's unmarshaller and delivered to that
// TopicChannel's Messages channel; decode failures are reported on messageErrors. Every
// subscription uses the QoS from the client's Will options. Each successful subscription
// is recorded so it can be re-created after a reconnect.
//
// Topics are processed in order and the first failure is returned. Subscriptions created
// before that failure remain in place.
func (mc *Client) Subscribe(topics []types.TopicChannel, messageErrors chan error) error {
	reader := mc.mqttClient.OptionsReader()
	qos := reader.WillQos()

	for i := range topics {
		topicName := topics[i].Topic
		onMessage := newMessageHandler(mc.unmarshaller, topics[i].Messages, messageErrors)

		token := mc.mqttClient.Subscribe(topicName, qos, onMessage)
		if err := getTokenError(token, reader.ConnectTimeout(), SubscribeOperation, "Failed to create subscription"); err != nil {
			return err
		}

		mc.addActiveSubscriptions(activeSubscription{
			topic:   topicName,
			qos:     qos,
			handler: onMessage,
			errors:  messageErrors,
		})
	}
	return nil
}
// Unsubscribe cancels the subscriptions for topics and stops re-creating them on
// reconnect. Calling it with no topics is a no-op.
//
// The request waits at most the connect timeout. If it fails or times out, the error is
// returned and the recorded subscriptions are left untouched.
func (mc *Client) Unsubscribe(topics ...string) error {
	if len(topics) == 0 {
		return nil
	}

	token := mc.mqttClient.Unsubscribe(topics...)
	timeout := mc.mqttClient.OptionsReader().ConnectTimeout()
	if err := getTokenError(token, timeout, UnsubscribeOperation, "Failed to cancel subscription"); err != nil {
		return err
	}

	for _, t := range topics {
		mc.deleteActiveSubscriptions(t)
	}
	return nil
}
// addActiveSubscriptions records a so that onConnectHandler re-creates it after a
// reconnect. The append happens while holding the Client's mutex.
func (mc *Client) addActiveSubscriptions(a activeSubscription) {
	mc.Lock()
	mc.activeSubscriptions = append(mc.activeSubscriptions, a)
	mc.Unlock()
}
// deleteActiveSubscriptions forgets every recorded subscription for topic, so that
// onConnectHandler no longer re-creates it after a reconnect.
//
// All matching entries are removed. Subscribe records a new entry on every call, so the
// same topic can appear more than once. Leaving any copy behind would silently
// re-subscribe the topic on the next reconnect.
//
// The vacated slots at the tail of the backing array are zeroed, so the removed handlers
// and channels are no longer reachable through it.
func (mc *Client) deleteActiveSubscriptions(topic string) {
	mc.Lock()
	defer mc.Unlock()

	kept := mc.activeSubscriptions[:0]
	for _, subscription := range mc.activeSubscriptions {
		if subscription.topic != topic {
			kept = append(kept, subscription)
		}
	}
	for i := len(kept); i < len(mc.activeSubscriptions); i++ {
		mc.activeSubscriptions[i] = activeSubscription{}
	}
	mc.activeSubscriptions = kept
}
// Disconnect closes the connection to the connected MQTT server. It always returns nil.
//
// If Connect never created the underlying client (for example because the ClientCreator
// failed), Disconnect does nothing. This makes it safe to call from cleanup paths.
//
// The connect timeout is used as the quiesce period, so queued processing can complete
// before disconnecting. paho's Disconnect takes that period in milliseconds, so the
// Duration is divided by time.Millisecond. The original code multiplied by it, which
// produced a value a million-fold too large that could overflow. A negative timeout is
// clamped to zero before the unsigned conversion.
func (mc *Client) Disconnect() error {
	if mc.mqttClient == nil {
		return nil
	}
	quiesce := mc.mqttClient.OptionsReader().ConnectTimeout() / time.Millisecond
	if quiesce < 0 {
		quiesce = 0
	}
	mc.mqttClient.Disconnect(uint(quiesce))
	return nil
}
// DefaultClientCreator returns a default ClientCreator for building MQTT clients. Any TLS
// configuration uses the standard library's key-pair functions, tls.X509KeyPair and
// tls.LoadX509KeyPair.
//
// It is exactly ClientCreatorWithCertLoader with those two functions. Delegating avoids
// keeping two copies of the client-construction logic in sync.
func DefaultClientCreator() ClientCreator {
	return ClientCreatorWithCertLoader(tls.X509KeyPair, tls.LoadX509KeyPair)
}
// ClientCreatorWithCertLoader returns a ClientCreator that uses certCreator and certLoader
// for TLS setup when building an MQTT client. The returned function:
//   - converts the message-bus options into an MQTTClientConfig;
//   - turns that config into paho ClientOptions;
//   - installs handler as the on-connect callback;
//   - constructs the paho client.
//
// The first error from any step is returned with a nil client.
func ClientCreatorWithCertLoader(certCreator mqclient.X509KeyPairCreator, certLoader mqclient.X509KeyLoader) ClientCreator {
	create := func(options types.MessageBusConfig, handler pahoMqtt.OnConnectHandler) (pahoMqtt.Client, error) {
		mqttConfig, configErr := CreateMQTTClientConfiguration(options)
		if configErr != nil {
			return nil, configErr
		}

		opts, optsErr := createClientOptions(mqttConfig, certCreator, certLoader)
		if optsErr != nil {
			return nil, optsErr
		}

		opts.OnConnect = handler
		return pahoMqtt.NewClient(opts), nil
	}
	return create
}
// newMessageHandler creates a function which meets the criteria for a MessageHandler and
// propagates the received messages to the proper channel.
//
// Each payload is decoded with unmarshaler into a types.MessageEnvelope and sent on
// messageChannel. If decoding fails, the error is sent on errorChannel (when one was
// provided) and the message is dropped. Consumers therefore never receive a zero-value
// envelope for a payload that could not be decoded.
//
// NOTE(review): both sends block until a receiver is ready; presumably the subscriber
// drains these channels continuously — confirm against callers.
func newMessageHandler(
	unmarshaler MessageUnmarshaller,
	messageChannel chan<- types.MessageEnvelope,
	errorChannel chan<- error) pahoMqtt.MessageHandler {
	return func(client pahoMqtt.Client, message pahoMqtt.Message) {
		var messageEnvelope types.MessageEnvelope
		if err := unmarshaler(message.Payload(), &messageEnvelope); err != nil {
			// A send on a nil channel blocks forever; with nowhere to report the failure,
			// drop it along with the undecodable message.
			if errorChannel != nil {
				errorChannel <- err
			}
			return
		}
		messageChannel <- messageEnvelope
	}
}
// getTokenError waits up to timeout for token to complete and converts the outcome into
// this package's error types:
//   - timed out, token has an error:  NewTimeoutError(operation, <token error text>)
//   - timed out, token has no error:  NewTimeoutError(operation, defaultTimeoutMessage)
//   - completed with an error:        NewOperationErr(operation, <token error text>)
//   - completed successfully:         nil
//
// NOTE paho.mqtt.golang's recommended way of handling errors does not cover all cases.
// During manual verification with an MQTT server, Token.Error() was sometimes nil even
// though token.WaitTimeout(...) returned false, indicating the operation had timed out.
// One example is connecting without providing a ClientID. That is why a timeout is
// reported even when the token carries no error.
//
// Token.Error() is read once, so every check classifies the same snapshot even if the
// token completes concurrently.
func getTokenError(token pahoMqtt.Token, timeout time.Duration, operation string, defaultTimeoutMessage string) error {
	completed := token.WaitTimeout(timeout)
	tokenErr := token.Error()

	switch {
	case !completed && tokenErr != nil:
		return NewTimeoutError(operation, tokenErr.Error())
	case !completed:
		return NewTimeoutError(operation, defaultTimeoutMessage)
	case tokenErr != nil:
		return NewOperationErr(operation, tokenErr.Error())
	default:
		return nil
	}
}
// createClientOptions constructs paho ClientOptions from an MQTTClientConfig. It sets:
//   - broker URL, username, password and client ID;
//   - keep-alive and connect timeout, both configured as whole seconds;
//   - auto-reconnect;
//   - a TLS configuration from mqclient.GenerateTLSForClientClientOptions, using
//     certCreator and certLoader.
//
// If the TLS configuration cannot be generated, nil options are returned with the error.
// This follows the Go convention that a result is meaningless alongside a non-nil error.
func createClientOptions(
	clientConfiguration MQTTClientConfig,
	certCreator mqclient.X509KeyPairCreator,
	certLoader mqclient.X509KeyLoader) (*pahoMqtt.ClientOptions, error) {
	tlsConfiguration, err := mqclient.GenerateTLSForClientClientOptions(
		clientConfiguration.BrokerURL,
		clientConfiguration.TlsConfigurationOptions,
		certCreator,
		certLoader)
	if err != nil {
		return nil, err
	}

	clientOptions := pahoMqtt.NewClientOptions()
	clientOptions.AddBroker(clientConfiguration.BrokerURL)
	clientOptions.SetUsername(clientConfiguration.Username)
	clientOptions.SetPassword(clientConfiguration.Password)
	clientOptions.SetClientID(clientConfiguration.ClientId)
	clientOptions.SetKeepAlive(time.Duration(clientConfiguration.KeepAlive) * time.Second)
	clientOptions.SetAutoReconnect(clientConfiguration.AutoReconnect)
	clientOptions.SetConnectTimeout(time.Duration(clientConfiguration.ConnectTimeout) * time.Second)
	clientOptions.SetTLSConfig(tlsConfiguration)
	return clientOptions, nil
}