diff --git a/.gitignore b/.gitignore index 9e7bf44..4342047 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,4 @@ .idea vendor logs -edge-db-data -hummingbird -mqtt-broker go.sum diff --git a/cmd/mqtt-broker/Dockerfile b/cmd/mqtt-broker/Dockerfile new file mode 100644 index 0000000..3ff27c5 --- /dev/null +++ b/cmd/mqtt-broker/Dockerfile @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------------- +# Copyright 2018 Dell Technologies, Inc. +# Copyright 2018 Cavium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ---------------------------------------------------------------------------------- + +ARG BUILDER_BASE=golang:latest +FROM ${BUILDER_BASE} AS builder + +WORKDIR /edge + +# gitlab +COPY . . + +RUN --mount=type=cache,target=/root/.cache/go-build make cmd/mqtt-broker/mqtt-broker + +#Next image - Copy built Go binary into new workspace +FROM alpine:3.16 + +RUN --mount=type=cache,target=/var/cache/apk apk add --update --no-cache dumb-init + +EXPOSE 58090 + +WORKDIR / +COPY --from=builder /edge/cmd/mqtt-broker/mqtt-broker /bin/ +COPY --from=builder /edge/cmd/mqtt-broker/res/configuration.yml.dist /etc/emqtt-broker/res/configuration.yml + +#RUN mkdir -p /logs/mqtt-broker + +CMD ["/bin/sh", "-c", "/bin/mqtt-broker start -c=/etc/emqtt-broker/res/configuration.yml"] \ No newline at end of file diff --git a/cmd/mqtt-broker/initcmd/init.go b/cmd/mqtt-broker/initcmd/init.go new file mode 100644 index 0000000..70e605f --- /dev/null +++ b/cmd/mqtt-broker/initcmd/init.go @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package initcmd + +import ( + "fmt" + "github.com/spf13/cobra" + "github.com/winc-link/hummingbird/cmd/mqtt-broker/mqttd" + "github.com/winc-link/hummingbird/cmd/mqtt-broker/mqttd/command" + "os" + "path" +) + +func must(err error) { + if err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } +} + +func Init(rootCmd *cobra.Command) { + configDir, err := mqttd.GetDefaultConfigDir() + must(err) + command.ConfigFile = path.Join(configDir, "configuration.yml") + rootCmd.PersistentFlags().StringVarP(&command.ConfigFile, "config", "c", command.ConfigFile, "The configuration file path") + rootCmd.AddCommand(command.NewStartCmd()) +} diff --git a/cmd/mqtt-broker/main.go b/cmd/mqtt-broker/main.go new file mode 100644 index 0000000..0a731e3 --- /dev/null +++ b/cmd/mqtt-broker/main.go @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package main + +import ( + "fmt" + "github.com/spf13/cobra" + "github.com/winc-link/hummingbird/cmd/mqtt-broker/initcmd" + + _ "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence" + _ "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/topicalias/fifo" + + _ "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/admin" + _ "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/aplugin" + _ "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/auth" + _ "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/federation" + + "net/http" + "os" +) + +var ( + rootCmd = &cobra.Command{ + Use: "mqttd", + Long: "This is a MQTT broker that fully implements MQTT V5.0 and V3.1.1 protocol", + Version: "", + } + enablePprof bool + pprofAddr = "127.0.0.1:60600" +) + +func main() { + if enablePprof { + go func() { + http.ListenAndServe(pprofAddr, nil) + }() + } + initcmd.Init(rootCmd) + if err := rootCmd.Execute(); err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } +} diff --git a/cmd/mqtt-broker/mqttd/command/start.go b/cmd/mqtt-broker/mqttd/command/start.go new file mode 100644 index 0000000..0a09a4b --- /dev/null +++ b/cmd/mqtt-broker/mqttd/command/start.go @@ -0,0 +1,155 @@ +package command + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "syscall" + + "github.com/spf13/cobra" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/pidfile" + "go.uber.org/zap" +) + +var ( + ConfigFile string + logger *zap.Logger +) + +func must(err error) { + if err != nil { + fmt.Fprint(os.Stderr, err) + os.Exit(1) + } +} + +func installSignal(srv server.Server) { + // reload + reloadSignalCh := make(chan os.Signal, 1) + signal.Notify(reloadSignalCh, syscall.SIGHUP) + + // stop + stopSignalCh := make(chan os.Signal, 1) + signal.Notify(stopSignalCh, os.Interrupt, syscall.SIGTERM) + + for { + select { + case <-reloadSignalCh: + var c config.Config + var err error + c, err = config.ParseConfig(ConfigFile) + if err != nil { + logger.Error("reload error", zap.Error(err)) + return + } + srv.ApplyConfig(c) + logger.Info("gmqtt reloaded") + case <-stopSignalCh: + err := srv.Stop(context.Background()) + if err != nil { + logger.Error(err.Error()) + //fmt.Fprint(os.Stderr, err.Error()) + } + } + } + +} + +func GetListeners(c config.Config) (tcpListeners []net.Listener, websockets []*server.WsServer, err error) { + for _, v := range c.Listeners { + var ln net.Listener + if v.Websocket != nil { + ws := &server.WsServer{ + Server: &http.Server{Addr: v.Address}, + Path: v.Websocket.Path, + } + if v.TLSOptions != nil { + ws.KeyFile = v.Key + ws.CertFile = v.Cert + } + websockets = append(websockets, ws) + continue + } + if v.TLSOptions != nil { + var cert tls.Certificate + cert, err = tls.LoadX509KeyPair(v.Cert, v.Key) + if err != nil { + return + } + ln, err = tls.Listen("tcp", v.Address, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + } else { + ln, err = net.Listen("tcp", v.Address) + } + tcpListeners = append(tcpListeners, ln) + } + return +} + +// NewStartCmd creates a *cobra.Command object for start command. +func NewStartCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "start", + Short: "Start gmqtt broker", + Run: func(cmd *cobra.Command, args []string) { + var err error + must(err) + c, err := config.ParseConfig(ConfigFile) + if os.IsNotExist(err) { + must(err) + } else { + must(err) + } + if c.PidFile != "" { + pid, err := pidfile.New(c.PidFile) + if err != nil { + must(fmt.Errorf("open pid file failed: %s", err)) + } + defer pid.Remove() + } + + level, l, err := c.GetLogger(c.Log) + must(err) + logger = l + + //db := mqttbroker.NewDatabase(c) + //err = db.InitDBClient(l) + //must(err) + + tcpListeners, websockets, err := GetListeners(c) + must(err) + + s := server.New( + server.WithConfig(c), + server.WithTCPListener(tcpListeners...), + server.WithWebsocketServer(websockets...), + server.WithLogger(&server.DefaultLogger{ + Level: level, + Logger: l, + }), + ) + + err = s.Init() + if err != nil { + fmt.Println(err) + os.Exit(1) + return + } + go installSignal(s) + err = s.Run() + if err != nil { + fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + return + } + }, + } + return cmd +} diff --git a/cmd/mqtt-broker/mqttd/config_unix.go b/cmd/mqtt-broker/mqttd/config_unix.go new file mode 100644 index 0000000..6f06ebc --- /dev/null +++ b/cmd/mqtt-broker/mqttd/config_unix.go @@ -0,0 +1,11 @@ +// +build !windows + +package mqttd + +var ( + DefaultConfigDir = "./res/" +) + +func GetDefaultConfigDir() (string, error) { + return DefaultConfigDir, nil +} diff --git a/cmd/mqtt-broker/plugin_generate.go b/cmd/mqtt-broker/plugin_generate.go new file mode 100644 index 0000000..faad23c --- /dev/null +++ b/cmd/mqtt-broker/plugin_generate.go @@ -0,0 +1,85 @@ +// +build ignore + +package main + +import ( + "bytes" + "go/format" + "io" + "io/ioutil" + "log" + "strings" + "text/template" + + "gopkg.in/yaml.v2" +) + +var tmpl = `//go:generate sh -c "cd ../../ && go run plugin_generate.go" +// generated by plugin_generate.go; DO NOT EDIT + +package mqttd + +import ( + {{- range $index, $element := .}} + _ "{{$element}}" + {{- end}} +) +` + +const ( + pluginFile = "./mqttd/plugins.go" + pluginCfg = "plugin_imports.yml" + importPath = "gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/plugin" +) + +type ymlCfg struct { + Packages []string `yaml:"packages"` +} + +func main() { + b, err := ioutil.ReadFile(pluginCfg) + if err != nil { + log.Fatalf("ReadFile error %s", err) + return + } + + var cfg ymlCfg + err = yaml.Unmarshal(b, &cfg) + if err != nil { + log.Fatalf("Unmarshal error: %s", err) + return + } + t, err := template.New("plugin_gen").Parse(tmpl) + if err != nil { + log.Fatalf("Parse template error: %s", err) + return + } + + for k, v := range cfg.Packages { + if !strings.Contains(v, "/") { + cfg.Packages[k] = importPath + "/" + v + } + } + + if err != nil && err != io.EOF { + log.Fatalf("read error: %s", err) + return + } + buf := &bytes.Buffer{} + err = t.Execute(buf, cfg.Packages) + if err != nil { + log.Fatalf("excute template error: %s", err) + return + } + rs, err := format.Source(buf.Bytes()) + if err != nil { + log.Fatalf("format error: %s", err) + return + } + err = ioutil.WriteFile(pluginFile, rs, 0666) + if err != nil { + log.Fatalf("writeFile error: %s", err) + return + } + return +} diff --git a/cmd/mqtt-broker/plugin_imports.yml b/cmd/mqtt-broker/plugin_imports.yml new file mode 100644 index 0000000..a66a647 --- /dev/null +++ b/cmd/mqtt-broker/plugin_imports.yml @@ -0,0 +1,6 @@ +packages: + - admin + # - federation + - aplugin + # for external plugin, use full import path + # - gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/plugin/prometheus \ No newline at end of file diff --git a/cmd/mqtt-broker/res/ca.crt b/cmd/mqtt-broker/res/ca.crt new file mode 100644 index 0000000..6fb57f7 --- /dev/null +++ b/cmd/mqtt-broker/res/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDszCCApugAwIBAgIJAJXKBu6eNV6YMA0GCSqGSIb3DQEBCwUAMHAxCzAJBgNV +BAYTAkNOMQswCQYDVQQIDAJaSjELMAkGA1UEBwwCSFoxCzAJBgNVBAoMAlRZMQsw +CQYDVQQLDAJUWTEOMAwGA1UEAwwFdGVkZ2UxHTAbBgkqhkiG9w0BCQEWDnRlZGdl +QHR1eWEuY29tMB4XDTIyMDEyNDA2NTQ0M1oXDTMyMDEyMjA2NTQ0M1owcDELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAlpKMQswCQYDVQQHDAJIWjELMAkGA1UECgwCVFkx +CzAJBgNVBAsMAlRZMQ4wDAYDVQQDDAV0ZWRnZTEdMBsGCSqGSIb3DQEJARYOdGVk +Z2VAdHV5YS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/IZfF +++ytZVIDbt5Ypz/55e0HTrq9jrpVOZAKSBbmSUryjpo8NfZoDp5QVZi4kSo1G0xV +Wf9C+5h13TFM2pDm9W9q4v8e3cB3Z+qK8nHn66xyQYnTihg8D9vyJHIQ2nirCVqW +HL2wYdakE0MojbVsQPWufYh84tWXyyUIo2W2ycoXmSfpWhb4LDEf4tcmDBNp2ydG +ef7MNbrS3t/h/iOzqjj7s+styiLyKjxE0oh1VfOOp8e9HPnh2EvaQwwTq91KRf+v +rl4DPZt93oMd9i28HuxBsWsE6eDRfYmF96ZoIXEh4ga9XWR8geuRCsTREQo7tqUX +gXFUXCe2Uo6R0uh9AgMBAAGjUDBOMB0GA1UdDgQWBBQayDeoKN44f/FV+Z6rv1vT +bsITvTAfBgNVHSMEGDAWgBQayDeoKN44f/FV+Z6rv1vTbsITvTAMBgNVHRMEBTAD +AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQB0G/AM7zU2USZj3C32zzzM6LQNx465x1i2 +XgSw9ECM7M3ct6x859L6vKXWUm6+OQO7jm9xRRyDQIrCSQT66MCei5C+nqzyPIZA +zL5cV7bRXA39KBwThyZZqWl4bttp98UZnEbX6yICVEcjsnaIA2D0vh1zZar4Ilyq +mBl4NA13HTIcQ0s4Efuhdf5RdPw3ha2cjf74aNNj2WijAY4rKEVF1Buw/PrvJ7WR +hrQlNrv214e3hbjV99oiII8OLDT0oJApUbSr6ktjF26bAu929b3QDADEK9QpBYE1 +brg3KD51xgD+HGKd3PVLqr60y7OQKKMHd8TrQK/ibQVgFdbE6/AG +-----END CERTIFICATE----- diff --git a/cmd/mqtt-broker/res/configuration.yml b/cmd/mqtt-broker/res/configuration.yml new file mode 100644 index 0000000..ea0808f --- /dev/null +++ b/cmd/mqtt-broker/res/configuration.yml @@ -0,0 +1,159 @@ +# Path to pid file. +# If not set, there will be no pid file. +# pid_file: /var/run/mqttd.pid + +listeners: + # bind address + - address: ":58090" # 58090 + +# - address: ":21883" # 21883 +# tls: +# cacert: "/etc/mqtt-broker/ca.crt" +# cert: "/etc/mqtt-broker/server.pem" +# key: "/etc/mqtt-broker/server.key" +# +# cacert: "cmd/mqtt-broker/res/ca.crt" +# cert: "cmd/mqtt-broker/res/server.pem" +# key: "cmd/mqtt-broker/res/server.key" +# +# - address: ":28883" # 28883 +# # websocket setting +# websocket: +# path: "/" + +api: + grpc: + # The gRPC server listen address. Supports unix socket and tcp socket. + - address: "tcp://127.0.0.1:57090" # 57090 + http: + # The HTTP server listen address. This is a reverse-proxy server in front of gRPC server. + - address: "tcp://127.0.0.1:57091" # 57091 + map: "tcp://127.0.0.1:57090" # The backend gRPC server endpoint, + +mqtt: + # The maximum session expiry interval in seconds. + session_expiry: 2h + # The interval time for session expiry checker to check whether there are expired sessions. + session_expiry_check_timer: 20s + # The maximum lifetime of the message in seconds. + # If a message in the queue is not sent in message_expiry time, it will be dropped, which means it will not be sent to the subscriber. + message_expiry: 2h + # The lifetime of the "inflight" message in seconds. + # If a "inflight" message is not acknowledged by a client in inflight_expiry time, it will be removed when the message queue is full. + inflight_expiry: 30s + # The maximum packet size that the server is willing to accept from the client. + max_packet_size: 268435456 + # The maximum number of QoS 1 and QoS 2 publications that the server is willing to process concurrently for the client. + server_receive_maximum: 100 + # The maximum keep alive time in seconds allows by the server. + # If the client requests a keepalive time bigger than MaxKeepalive,the server will use MaxKeepAlive as the keepalive time. + # In this case, if the client version is v5, the server will set MaxKeepalive into CONNACK to inform the client. + # But if the client version is 3.x, the server has no way to inform the client that the keepalive time has been changed. + max_keepalive: 300 + # The highest value that the server will accept as a Topic Alias sent by the client. + # No-op if the client version is MQTTv3.x . + topic_alias_maximum: 10 + # Whether the server supports Subscription Identifiers. + # No-op if the client version is MQTTv3.x . + subscription_identifier_available: true + # Whether the server supports Wildcard Subscriptions. + wildcard_subscription_available: true + # Whether the server supports Shared Subscriptions. + shared_subscription_available: true + # The highest QOS level permitted for a Publish. + maximum_qos: 2 + # Whether the server supports retained messages. + retain_available: true + # The maximum queue length of the outgoing messages. + # If the queue is full, some message will be dropped. + # The message dropping strategy is described in the document of the persistence/queue.Store interface. + max_queued_messages: 1000 + # The limits of inflight message length of the outgoing messages. + # Inflight message is also stored in the message queue, so it must be less than or equal to max_queued_messages. + # Inflight message is the QoS 1 or QoS 2 message that has been sent out to a client but not been acknowledged yet. + max_inflight: 100 + # Whether to store QoS 0 message for a offline session. + queue_qos0_messages: true + # The delivery mode. The possible value can be "overlap" or "onlyonce". + # It is possible for a client’s subscriptions to overlap so that a published message might match multiple filters. + # When set to "overlap" , the server will deliver one message for each matching subscription and respecting the subscription’s QoS in each case. + # When set to "onlyonce", the server will deliver the message to the client respecting the maximum QoS of all the matching subscriptions. + delivery_mode: onlyonce + # Whether to allow a client to connect with empty client id. + allow_zero_length_clientid: true + +persistence: + type: memory # memory | redis + # The redis configuration only take effect when type == redis. + redis: + # redis server address + addr: "127.0.0.1:56379" + # the maximum number of idle connections in the redis connection pool. + max_idle: 1000 + # the maximum number of connections allocated by the redis connection pool at a given time. + # If zero, there is no limit on the number of connections in the pool. + max_active: 0 + # the connection idle timeout, connection will be closed after remaining idle for this duration. If the value is zero, then idle connections are not closed. + idle_timeout: 240s + password: "qqwihyzjb8l2sx0c" + # the number of the redis database. + database: 0 + +# The topic alias manager setting. The topic alias feature is introduced by MQTT V5. +# This setting is used to control how the broker manage topic alias. +topic_alias_manager: + # Currently, only FIFO strategy is supported. + type: fifo + +plugins: + aplugin: + # Password hash type. (plain | md5 | sha256 | bcrypt) + # Default to MD5. + hash: md5 + # The file to store password. If it is a relative path, it locates in the same directory as the config file. + # (e.g: ./gmqtt_password => /etc/gmqtt/gmqtt_password.yml) + # Defaults to ./gmqtt_password.yml + # password_file: + federation: + # node_name is the unique identifier for the node in the federation. Defaults to hostname. + # node_name: + # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 + fed_addr: :8901 + # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. + # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. + # If the port is missing, the default federation port (8901) will be used. + advertise_fed_addr: :8901 + # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 + gossip_addr: :8902 + # advertise_gossip_addr is used to change the gossip server address that we advertise to other nodes in the cluster. + # Defaults to "GossipAddr" or the private IP address of the node if the IP in "GossipAddr" is 0.0.0.0. + # If the port is missing, the default gossip port (8902) will be used. + advertise_gossip_addr: :8902 + + # retry_join is the address of other nodes to join upon starting up. + # If port is missing, the default gossip port (8902) will be used. + #retry_join: + # - 127.0.0.1:8902 + + # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. + # It controls our interaction with the snapshot file. + # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. + # If this is set to true, we ignore the leave, and rejoin the cluster on start. + rejoin_after_leave: false + # snapshot_path will be pass to "SnapshotPath" in serf configuration. + # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one + # succeeds and will also avoid replaying old user events. + snapshot_path: + +# plugin loading orders +plugin_order: + # Uncomment auth to enable authentication. + - aplugin + #- admin + #- federation +log: + level: debug # debug | info | warn | error + file_path: "./mqtt-broker/mqtt-broker.log" + # whether to dump MQTT packet in debug level + dump_packet: false + diff --git a/cmd/mqtt-broker/res/configuration.yml.dist b/cmd/mqtt-broker/res/configuration.yml.dist new file mode 100644 index 0000000..b8a5462 --- /dev/null +++ b/cmd/mqtt-broker/res/configuration.yml.dist @@ -0,0 +1,159 @@ +# Path to pid file. +# If not set, there will be no pid file. +# pid_file: /var/run/mqttd.pid + +listeners: + # bind address + - address: ":58090" # 58090 + +# - address: ":21883" # 21883 +# tls: +# cacert: "/etc/mqtt-broker/ca.crt" +# cert: "/etc/mqtt-broker/server.pem" +# key: "/etc/mqtt-broker/server.key" +# +# cacert: "cmd/mqtt-broker/res/ca.crt" +# cert: "cmd/mqtt-broker/res/server.pem" +# key: "cmd/mqtt-broker/res/server.key" +# +# - address: ":28883" # 28883 +# # websocket setting +# websocket: +# path: "/" + +api: + grpc: + # The gRPC server listen address. Supports unix socket and tcp socket. + - address: "tcp://127.0.0.1:57090" # 57090 + http: + # The HTTP server listen address. This is a reverse-proxy server in front of gRPC server. + - address: "tcp://127.0.0.1:57091" # 57091 + map: "tcp://127.0.0.1:57090" # The backend gRPC server endpoint, + +mqtt: + # The maximum session expiry interval in seconds. + session_expiry: 2h + # The interval time for session expiry checker to check whether there are expired sessions. + session_expiry_check_timer: 20s + # The maximum lifetime of the message in seconds. + # If a message in the queue is not sent in message_expiry time, it will be dropped, which means it will not be sent to the subscriber. + message_expiry: 2h + # The lifetime of the "inflight" message in seconds. + # If a "inflight" message is not acknowledged by a client in inflight_expiry time, it will be removed when the message queue is full. + inflight_expiry: 30s + # The maximum packet size that the server is willing to accept from the client. + max_packet_size: 268435456 + # The maximum number of QoS 1 and QoS 2 publications that the server is willing to process concurrently for the client. + server_receive_maximum: 100 + # The maximum keep alive time in seconds allows by the server. + # If the client requests a keepalive time bigger than MaxKeepalive,the server will use MaxKeepAlive as the keepalive time. + # In this case, if the client version is v5, the server will set MaxKeepalive into CONNACK to inform the client. + # But if the client version is 3.x, the server has no way to inform the client that the keepalive time has been changed. + max_keepalive: 300 + # The highest value that the server will accept as a Topic Alias sent by the client. + # No-op if the client version is MQTTv3.x . + topic_alias_maximum: 10 + # Whether the server supports Subscription Identifiers. + # No-op if the client version is MQTTv3.x . + subscription_identifier_available: true + # Whether the server supports Wildcard Subscriptions. + wildcard_subscription_available: true + # Whether the server supports Shared Subscriptions. + shared_subscription_available: true + # The highest QOS level permitted for a Publish. + maximum_qos: 2 + # Whether the server supports retained messages. + retain_available: true + # The maximum queue length of the outgoing messages. + # If the queue is full, some message will be dropped. + # The message dropping strategy is described in the document of the persistence/queue.Store interface. + max_queued_messages: 1000 + # The limits of inflight message length of the outgoing messages. + # Inflight message is also stored in the message queue, so it must be less than or equal to max_queued_messages. + # Inflight message is the QoS 1 or QoS 2 message that has been sent out to a client but not been acknowledged yet. + max_inflight: 100 + # Whether to store QoS 0 message for a offline session. + queue_qos0_messages: true + # The delivery mode. The possible value can be "overlap" or "onlyonce". + # It is possible for a client’s subscriptions to overlap so that a published message might match multiple filters. + # When set to "overlap" , the server will deliver one message for each matching subscription and respecting the subscription’s QoS in each case. + # When set to "onlyonce", the server will deliver the message to the client respecting the maximum QoS of all the matching subscriptions. + delivery_mode: onlyonce + # Whether to allow a client to connect with empty client id. + allow_zero_length_clientid: true + +persistence: + type: memory # memory | redis + # The redis configuration only take effect when type == redis. + redis: + # redis server address + addr: "127.0.0.1:56379" + # the maximum number of idle connections in the redis connection pool. + max_idle: 1000 + # the maximum number of connections allocated by the redis connection pool at a given time. + # If zero, there is no limit on the number of connections in the pool. + max_active: 0 + # the connection idle timeout, connection will be closed after remaining idle for this duration. If the value is zero, then idle connections are not closed. + idle_timeout: 240s + password: "qqwihyzjb8l2sx0c" + # the number of the redis database. + database: 0 + +# The topic alias manager setting. The topic alias feature is introduced by MQTT V5. +# This setting is used to control how the broker manage topic alias. +topic_alias_manager: + # Currently, only FIFO strategy is supported. + type: fifo + +plugins: + aplugin: + # Password hash type. (plain | md5 | sha256 | bcrypt) + # Default to MD5. + hash: md5 + # The file to store password. If it is a relative path, it locates in the same directory as the config file. + # (e.g: ./gmqtt_password => /etc/gmqtt/gmqtt_password.yml) + # Defaults to ./gmqtt_password.yml + # password_file: + federation: + # node_name is the unique identifier for the node in the federation. Defaults to hostname. + # node_name: + # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 + fed_addr: :8901 + # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. + # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. + # If the port is missing, the default federation port (8901) will be used. + advertise_fed_addr: :8901 + # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 + gossip_addr: :8902 + # advertise_gossip_addr is used to change the gossip server address that we advertise to other nodes in the cluster. + # Defaults to "GossipAddr" or the private IP address of the node if the IP in "GossipAddr" is 0.0.0.0. + # If the port is missing, the default gossip port (8902) will be used. + advertise_gossip_addr: :8902 + + # retry_join is the address of other nodes to join upon starting up. + # If port is missing, the default gossip port (8902) will be used. + #retry_join: + # - 127.0.0.1:8902 + + # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. + # It controls our interaction with the snapshot file. + # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. + # If this is set to true, we ignore the leave, and rejoin the cluster on start. + rejoin_after_leave: false + # snapshot_path will be pass to "SnapshotPath" in serf configuration. + # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one + # succeeds and will also avoid replaying old user events. + snapshot_path: + +# plugin loading orders +plugin_order: + # Uncomment auth to enable authentication. + - aplugin + #- admin + #- federation +log: + level: debug # debug | info | warn | error + file_path: "/logs/mqtt-broker/mqtt-broker.log" + # whether to dump MQTT packet in debug level + dump_packet: false + diff --git a/cmd/mqtt-broker/res/gmqtt_password.yml b/cmd/mqtt-broker/res/gmqtt_password.yml new file mode 100644 index 0000000..97b2062 --- /dev/null +++ b/cmd/mqtt-broker/res/gmqtt_password.yml @@ -0,0 +1,3 @@ +# This is a sample plain password file for the auth plugin. +- username: root + password: root \ No newline at end of file diff --git a/cmd/mqtt-broker/res/server.csr b/cmd/mqtt-broker/res/server.csr new file mode 100644 index 0000000..419817b --- /dev/null +++ b/cmd/mqtt-broker/res/server.csr @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICFTCCAX4CAQAwcDELMAkGA1UEBhMCQ04xCzAJBgNVBAgMAlpKMQswCQYDVQQH +DAJIWjELMAkGA1UECgwCVFkxCzAJBgNVBAsMAlRZMQ4wDAYDVQQDDAV0ZWRnZTEd +MBsGCSqGSIb3DQEJARYOdGVkZ2VAdHV5YS5jb20wgZ8wDQYJKoZIhvcNAQEBBQAD +gY0AMIGJAoGBAMe1bzSZLfvqrBeBOgxAdaDqh8fWudqeb0wqC+ZSZg4uH+WEG4Hu +rMbt4B+b1U98ctpA/aOEgnZiV1z79w8Rm9ENvfCUOKJ8uJyVf2usAdR/HkudDOhU +KlnvXaCd5t99gi8pyBmYkaXf82ya7CN97f/Y35zNIcWTJhJmYmd3N6FRAgMBAAGg +ZTARBgkqhkiG9w0BCQIxBAwCdHkwFQYJKoZIhvcNAQkHMQgMBmJleW9uZDA5Bgkq +hkiG9w0BCQ4xLDAqMAkGA1UdEwQCMAAwCwYDVR0PBAQDAgXgMBAGA1UdEQQJMAeC +BXRlZGdlMA0GCSqGSIb3DQEBCwUAA4GBADSeemIkKCvrfOz+m0AxQN/9L8SEqdHB +l8YBaHSHgdxq6666ENPz5o2uPNnu6qYaBZUMOZ5223Sx2MJPNDAxemFnOw7YbnCV +jIPwI3O9KIFDZ+tmhEIVHSlqRFphYNIWAVVFBsdNkse1gLTLLBLKfbsCZeoD4Dz2 +mc3JPZjeo4Mq +-----END CERTIFICATE REQUEST----- diff --git a/cmd/mqtt-broker/res/server.key b/cmd/mqtt-broker/res/server.key new file mode 100644 index 0000000..2dc27f2 --- /dev/null +++ b/cmd/mqtt-broker/res/server.key @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAMe1bzSZLfvqrBeB +OgxAdaDqh8fWudqeb0wqC+ZSZg4uH+WEG4HurMbt4B+b1U98ctpA/aOEgnZiV1z7 +9w8Rm9ENvfCUOKJ8uJyVf2usAdR/HkudDOhUKlnvXaCd5t99gi8pyBmYkaXf82ya +7CN97f/Y35zNIcWTJhJmYmd3N6FRAgMBAAECgYBVCLMGIWcMCdsm0vZlexja4KHZ +/Fr8dFONiaWxd0pPJWKddofD5l2ZAnZY3yCPjLzWo6+b7XMjdzIdvIdw2h2OwGpE +kPQST1lkN5VlPIG67jwmIyJVw1LBAqknmqRFLjJ8NcJRttNjYjEkpetMOq1rM3Di +90mY3lBLT2g5lZa0pQJBAOr0lEPC3WJMq1N04wzqM6h6y8FUKhGZDawl1HGne9bs +4IDzVEhiCT9VvN3eoX+bk6av1/uZOxHY78j81q8KzY8CQQDZmKpPOZFeBK8dIOt7 +L4XB1NMVAkOy4UFZ1I9lpn9OVSQEPrnV0oyMnKIIMCzGy4nnFmrY/u4LKpuYSoVO +lvMfAkAHwhOzORf+SvHNS6rDnmgeRA++Tn0lH5yn9ofRSOp56lBvcZly2mnbwYT+ +/n7uq8BwXJYRJLoimLsyM8cS+JRZAkAp40Glzqc1OiGbseKi7BsLnTSlLrJplQNH +j6urHcoUAj/UsV6E0utLhjuK5/s2qaf6XE5lR237qFAbmPzgjB5xAkAP7H0cfdzs +X9gCe4RqQgIuJzK6Y59GkVeWVT8lScL9FyWm2JmeGh907HgnTXt6t8Hhk23JxD3x +KNcIk5xzRaOO +-----END PRIVATE KEY----- diff --git a/cmd/mqtt-broker/res/server.pem b/cmd/mqtt-broker/res/server.pem new file mode 100644 index 0000000..44f6403 --- /dev/null +++ b/cmd/mqtt-broker/res/server.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCzCCAfOgAwIBAgIJAMMZRyEj7GMLMA0GCSqGSIb3DQEBCwUAMHAxCzAJBgNV +BAYTAkNOMQswCQYDVQQIDAJaSjELMAkGA1UEBwwCSFoxCzAJBgNVBAoMAlRZMQsw +CQYDVQQLDAJUWTEOMAwGA1UEAwwFdGVkZ2UxHTAbBgkqhkiG9w0BCQEWDnRlZGdl +QHR1eWEuY29tMB4XDTIyMDEyNDA3MDAxNVoXDTMyMDEyMjA3MDAxNVowcDELMAkG +A1UEBhMCQ04xCzAJBgNVBAgMAlpKMQswCQYDVQQHDAJIWjELMAkGA1UECgwCVFkx +CzAJBgNVBAsMAlRZMQ4wDAYDVQQDDAV0ZWRnZTEdMBsGCSqGSIb3DQEJARYOdGVk +Z2VAdHV5YS5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMe1bzSZLfvq +rBeBOgxAdaDqh8fWudqeb0wqC+ZSZg4uH+WEG4HurMbt4B+b1U98ctpA/aOEgnZi +V1z79w8Rm9ENvfCUOKJ8uJyVf2usAdR/HkudDOhUKlnvXaCd5t99gi8pyBmYkaXf +82ya7CN97f/Y35zNIcWTJhJmYmd3N6FRAgMBAAGjLDAqMAkGA1UdEwQCMAAwCwYD +VR0PBAQDAgXgMBAGA1UdEQQJMAeCBXRlZGdlMA0GCSqGSIb3DQEBCwUAA4IBAQA5 +htWsbfo7XedP2DBbVRXWFhEw7RPFfmyFMgzQq3aifnNB93xpDRwauXH5k6TEsiIO +OKjQit9aiSA28sTad6k6S09SwJokeQ9l14T3vVVMdDVJCw1Hq/mEhgoGgpYM+om0 +t/gl7e4FHL0AH6vcAyO70Q4uVRGpnm6Ehp8MxW0f/uip6TLxSj3lTitkCytMSGMK +WhvTLy8gsD9sSkiZUL/jknVkSp5An3roayWZLZucPV0E2rINchRcMcrrY1UkeYu1 +HB94dGg2U7R7Qj0eJBdxJN0uCY5n02pBXabJXRtwvOReHsW6Qoo50MhWEd2sazCA +HwMRWr6g8aAun7QJfySG +-----END CERTIFICATE----- diff --git a/internal/hummingbird/core/application/alertcentreapp/alertapp.go b/internal/hummingbird/core/application/alertcentreapp/alertapp.go new file mode 100644 index 0000000..7ad46e0 --- /dev/null +++ b/internal/hummingbird/core/application/alertcentreapp/alertapp.go @@ -0,0 +1,960 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package alertcentreapp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + dingding "github.com/winc-link/hummingbird/internal/tools/notify/dingding" + feishu "github.com/winc-link/hummingbird/internal/tools/notify/feishu" + yiqiweixin "github.com/winc-link/hummingbird/internal/tools/notify/qiyeweixin" + + "github.com/winc-link/hummingbird/internal/tools/notify/webapi" + "gorm.io/gorm" + "strconv" + "strings" + "time" +) + +type alertApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func NewAlertCentreApp(ctx context.Context, dic *di.Container) interfaces.AlertRuleApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + app := &alertApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } + go app.monitor() + return app +} + +func (p alertApp) AddAlertRule(ctx context.Context, req dtos.RuleAddRequest) (string, error) { + var insertAlertRule models.AlertRule + insertAlertRule.Id = utils.RandomNum() + insertAlertRule.Name = req.Name + insertAlertRule.AlertType = constants.DeviceAlertType + //insertAlertRule.Status = constants.RuleStop + insertAlertRule.AlertLevel = req.AlertLevel + insertAlertRule.Description = req.Description + resp, err := p.dbClient.AddAlertRule(insertAlertRule) + if err != nil { + return "", err + } + return resp.Id, nil +} + +func (p alertApp) UpdateAlertField(ctx context.Context, req dtos.RuleFieldUpdate) error { + if req.Id == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update req id is required", nil) + } + alertRule, err := p.dbClient.AlertRuleById(req.Id) + if err != nil { + return err + } + dtos.ReplaceRuleFields(&alertRule, req) + err = p.dbClient.GetDBInstance().Table(alertRule.TableName()).Select("*").Updates(alertRule).Error + if err != nil { + return err + } + return nil +} + +func (p alertApp) UpdateAlertRule(ctx context.Context, req dtos.RuleUpdateRequest) error { + if req.Id == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update req id is required", nil) + } + if len(req.SubRule) != 1 { + return errors.New("") + } + device, err := p.dbClient.DeviceById(req.SubRule[0].DeviceId) + if err != nil { + return err + } + + product, err := p.dbClient.ProductById(device.ProductId) + if err != nil { + return err + } + + alertRule, err := p.dbClient.AlertRuleById(req.Id) + if err != nil { + return err + } + + if req.SubRule[0].ProductId != device.ProductId { + return errort.NewCommonEdgeX(errort.AlertRuleParamsError, "device product id not equal to req product id", nil) + } + if len(req.Notify) > 0 { + if err = checkNotifyParam(req.Notify); err != nil { + return err + } + } + + var sql string + + switch req.SubRule[0].Trigger { + case constants.DeviceDataTrigger: + var code string + if v, ok := req.SubRule[0].Option["code"]; ok { + code = v + } else { + return errort.NewCommonEdgeX(errort.AlertRuleParamsError, "update rule code is required", nil) + } + + var find bool + var productProperty models.Properties + for _, property := range product.Properties { + if property.Code == code { + find = true + productProperty = property + break + } + } + if !find { + return errort.NewCommonEdgeX(errort.ProductPropertyCodeNotExist, "product property code exist", nil) + } + + switch productProperty.TypeSpec.Type { + case constants.SpecsTypeInt, constants.SpecsTypeFloat: + if err = checkSpecsTypeIntOrFloatParam(req.SubRule[0]); err != nil { + return err + } + case constants.SpecsTypeText: + if err = checkSpecsTypeTextParam(req.SubRule[0]); err != nil { + return err + } + case constants.SpecsTypeBool: + if err = checkSpecsTypeBoolParam(req.SubRule[0]); err != nil { + return err + } + case constants.SpecsTypeEnum: + if err = checkSpecsTypeEnumParam(req.SubRule[0]); err != nil { + return err + } + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule code verify failed", nil) + } + + sql = req.BuildEkuiperSql(device.Id, productProperty.TypeSpec.Type) + + case constants.DeviceEventTrigger: + var code string + if v, ok := req.SubRule[0].Option["code"]; ok { + code = v + } else { + return errort.NewCommonEdgeX(errort.AlertRuleParamsError, "update rule code is required", nil) + } + var find bool + //var productProperty models.Properties + for _, event := range product.Events { + if event.Code == code { + find = true + //productProperty = property + break + } + } + if !find { + return errort.NewCommonEdgeX(errort.ProductPropertyCodeNotExist, "product event code exist", nil) + } + sqlTemp := `SELECT rule_id(),json_path_query(data, "$.eventTime") as report_time,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "EVENT_REPORT" and json_path_exists(data, "$.eventCode") = true and json_path_query(data, "$.eventCode") = "%s"` + sql = fmt.Sprintf(sqlTemp, device.Id, code) + case constants.DeviceStatusTrigger: + //{"code":"","device_id":"2499708","end_at":null,"start_at":null} + + var status string + deviceStatus := req.SubRule[0].Option["status"] + if deviceStatus == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required status parameter missing", nil) + return err + } + if deviceStatus == "在线" { + status = constants.DeviceOnline + } else if deviceStatus == "离线" { + status = constants.DeviceOffline + } else { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required status parameter missing", nil) + return err + } + sqlTemp := `SELECT rule_id(),json_path_query(data, "$.time") as report_time,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "DEVICE_STATUS" and json_path_exists(data, "$.status") = true and json_path_query(data, "$.status") = "%s"` + sql = fmt.Sprintf(sqlTemp, device.Id, status) + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule trigger is required", nil) + } + + if sql == "" { + return errort.NewCommonEdgeX(errort.AlertRuleParamsError, "sql is null", nil) + + } + + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + exist, err := ekuiperApp.RuleExist(ctx, alertRule.Id) + if err != nil { + return err + } + configapp := resourceContainer.ConfigurationFrom(p.dic.Get) + if !exist { + if err = ekuiperApp.CreateRule(ctx, dtos.GetRuleAlertEkuiperActions(configapp.Service.Url()), alertRule.Id, sql); err != nil { + return err + } + } else { + if err = ekuiperApp.UpdateRule(ctx, dtos.GetRuleAlertEkuiperActions(configapp.Service.Url()), alertRule.Id, sql); err != nil { + return err + } + } + + dtos.ReplaceRuleModelFields(&alertRule, req) + //alertRule.Status = constants.RuleStop + alertRule.DeviceId = device.Id + err = p.dbClient.GetDBInstance().Table(alertRule.TableName()).Select("*").Updates(alertRule).Error + if err != nil { + return err + } + + return nil +} + +func checkNotifyParam(notify []dtos.Notify) error { + for _, d := range notify { + if !utils.InStringSlice(string(d.Name), constants.GetAlertWays()) { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "notify name not in alertways", nil) + } + if d.StartEffectTime == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "startEffectTime is required", nil) + + } + if d.EndEffectTime == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "endEffectTime is required", nil) + } + if !checkEffectTimeParam(d.StartEffectTime, d.EndEffectTime) { + return errort.NewCommonEdgeX(errort.EffectTimeParamsError, "The format of the effective time is"+ + " incorrect. The end time should be greater than the start time.", nil) + } + + } + return nil +} + +func checkSpecsTypeBoolParam(req dtos.SubRule) error { + var decideCondition string + if v, ok := req.Option["decide_condition"]; ok { + decideCondition = v + } else { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition is required", nil) + } + + st := strings.Split(decideCondition, " ") + if len(st) != 2 { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + //if st[0] != "==" { + // return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + //} + if !(st[1] == "true" || st[1] == "false") { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + return nil +} + +func checkSpecsTypeEnumParam(req dtos.SubRule) error { + var decideCondition string + if v, ok := req.Option["decide_condition"]; ok { + decideCondition = v + } else { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition is required", nil) + } + + st := strings.Split(decideCondition, " ") + if len(st) != 2 { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + //if st[0] != "==" { + // return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + //} + if st[0] == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + return nil +} + +func checkSpecsTypeTextParam(req dtos.SubRule) error { + var decideCondition string + if v, ok := req.Option["decide_condition"]; ok { + decideCondition = v + } else { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition is required", nil) + } + + st := strings.Split(decideCondition, " ") + if len(st) != 2 { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + //if st[0] != "==" { + // return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + //} + if st[1] == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + return nil +} + +func checkSpecsTypeIntOrFloatParam(req dtos.SubRule) error { + var valueType, decideCondition string + + if v, ok := req.Option["value_type"]; ok { + valueType = v + } else { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule value_type is required", nil) + } + find := false + for _, s := range constants.ValueTypes { + if s == valueType { + find = true + break + } + } + if !find { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule value_type verify failed", nil) + } + + if v, ok := req.Option["decide_condition"]; ok { + decideCondition = v + } else { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition is required", nil) + } + + st := strings.Split(decideCondition, " ") + if len(st) != 2 { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + find = false + for _, condition := range constants.DecideConditions { + if condition == st[0] { + find = true + break + } + } + if !find { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update rule decide_condition verify failed", nil) + } + return nil +} + +func (p alertApp) AlertRuleById(ctx context.Context, id string) (dtos.RuleResponse, error) { + alertRule, err := p.dbClient.AlertRuleById(id) + var response dtos.RuleResponse + if err != nil { + return response, err + } + var ruleResponse dtos.RuleResponse + ruleResponse.Id = alertRule.Id + ruleResponse.Name = alertRule.Name + ruleResponse.AlertType = alertRule.AlertType + ruleResponse.AlertLevel = alertRule.AlertLevel + ruleResponse.Status = alertRule.Status + ruleResponse.Condition = alertRule.Condition + ruleResponse.SilenceTime = alertRule.SilenceTime + ruleResponse.Description = alertRule.Description + ruleResponse.Created = alertRule.Created + ruleResponse.Modified = alertRule.Modified + ruleResponse.Notify = alertRule.Notify + if len(ruleResponse.Notify) == 0 { + ruleResponse.Notify = make([]models.SubNotify, 0) + } + + var ruleSubRules dtos.RuleSubRules + for _, rule := range alertRule.SubRule { + device, err := p.dbClient.DeviceById(alertRule.DeviceId) + if err != nil { + return response, err + } + product, err := p.dbClient.ProductById(device.ProductId) + if err != nil { + return response, err + } + code := rule.Option["code"] + var ( + eventCodeName string + propertyCodeName string + ) + + for _, event := range product.Events { + if event.Code == code { + eventCodeName = event.Name + } + } + for _, property := range product.Properties { + if property.Code == code { + propertyCodeName = property.Name + } + } + var valueType string + switch rule.Option["value_type"] { + case "original": + valueType = "原始值" + case "avg": + valueType = "平均值" + case "max": + valueType = "最大值" + case "min": + valueType = "最小值" + case "sum": + valueType = "求和值" + } + var condition string + switch rule.Trigger { + case constants.DeviceDataTrigger: + if rule.Option["value_cycle"] == "" { + if valueType == "" { + valueType = "原始值" + } + condition = string(constants.DeviceDataTrigger) + ": 产品: " + product.Name + " | " + + "设备: " + device.Name + " | " + + "功能: " + propertyCodeName + " | " + + "触发条件: " + valueType + " " + rule.Option["decide_condition"] + } else { + condition = string(constants.DeviceDataTrigger) + ": 产品: " + product.Name + " | " + + "设备: " + device.Name + " | " + + "功能: " + propertyCodeName + " | " + + "触发条件: " + valueType + " " + fmt.Sprintf("(%s)", rule.Option["value_cycle"]) + " " + rule.Option["decide_condition"] + } + case constants.DeviceEventTrigger: + condition = string(constants.DeviceEventTrigger) + ": 产品: " + product.Name + " | " + + "设备: " + device.Name + " | " + + fmt.Sprintf("事件 = %s", eventCodeName) + case constants.DeviceStatusTrigger: + condition = string(constants.DeviceStatusTrigger) + ": 产品: " + product.Name + " | " + + "设备: " + device.Name + " | " + + fmt.Sprintf("设备状态 = %s", rule.Option["status"]) + default: + condition = "" + } + ruleSubRules = append(ruleSubRules, dtos.RuleSubRule{ + ProductId: rule.ProductId, + ProductName: product.Name, + DeviceId: rule.DeviceId, + DeviceName: device.Name, + Trigger: rule.Trigger, + Code: code, + Condition: condition, + Option: rule.Option, + }) + } + + ruleResponse.SubRule = ruleSubRules + return ruleResponse, nil +} + +func (p alertApp) AlertRulesSearch(ctx context.Context, req dtos.AlertRuleSearchQueryRequest) ([]dtos.AlertRuleSearchQueryResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.AlertRuleSearch(offset, limit, req) + if err != nil { + return []dtos.AlertRuleSearchQueryResponse{}, 0, err + } + alertRules := make([]dtos.AlertRuleSearchQueryResponse, len(resp)) + for i, p := range resp { + alertRules[i] = dtos.RuleSearchQueryResponseFromModel(p) + } + return alertRules, total, nil +} + +func (p alertApp) AlertSearch(ctx context.Context, req dtos.AlertSearchQueryRequest) ([]dtos.AlertSearchQueryResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.AlertListSearch(offset, limit, req) + if err != nil { + return []dtos.AlertSearchQueryResponse{}, 0, err + } + return resp, total, nil +} + +func (p alertApp) AlertPlate(ctx context.Context, beforeTime int64) ([]dtos.AlertPlateQueryResponse, error) { + data, err := p.dbClient.AlertPlate(beforeTime) + if err != nil { + return []dtos.AlertPlateQueryResponse{}, err + } + var dealData []dtos.AlertPlateQueryResponse + dealData = append(append(append(append(dealData, dtos.AlertPlateQueryResponse{ + AlertLevel: constants.Urgent, + Count: p.getAlertDataCount(data, constants.Urgent), + }), dtos.AlertPlateQueryResponse{ + AlertLevel: constants.Important, + Count: p.getAlertDataCount(data, constants.Important), + }), dtos.AlertPlateQueryResponse{ + AlertLevel: constants.LessImportant, + Count: p.getAlertDataCount(data, constants.LessImportant), + }), dtos.AlertPlateQueryResponse{ + AlertLevel: constants.Remind, + Count: p.getAlertDataCount(data, constants.Remind), + }) + return dealData, nil +} + +func (p alertApp) getAlertDataCount(data []dtos.AlertPlateQueryResponse, level constants.AlertLevel) int { + for _, datum := range data { + if datum.AlertLevel == level { + return datum.Count + } + } + return 0 +} + +func (p alertApp) AlertRulesDelete(ctx context.Context, id string) error { + _, err := p.dbClient.AlertRuleById(id) + if err != nil { + return err + } + + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.DeleteRule(ctx, id) + if err != nil { + return err + } + return p.dbClient.DeleteAlertRuleById(id) +} + +func (p alertApp) AlertRulesRestart(ctx context.Context, id string) error { + alertRule, err := p.dbClient.AlertRuleById(id) + if err != nil { + return err + } + if err = p.checkAlertRuleParam(ctx, alertRule, "restart"); err != nil { + return err + } + if alertRule.EkuiperRule() { + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.RestartRule(ctx, id) + if err != nil { + return err + } + } + return p.dbClient.AlertRuleStart(id) +} + +func (p alertApp) AlertRulesStop(ctx context.Context, id string) error { + _, err := p.dbClient.AlertRuleById(id) + if err != nil { + return err + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.StopRule(ctx, id) + if err != nil { + return err + } + return p.dbClient.AlertRuleStop(id) +} + +func (p alertApp) AlertRulesStart(ctx context.Context, id string) error { + alertRule, err := p.dbClient.AlertRuleById(id) + if err != nil { + return err + } + if err = p.checkAlertRuleParam(ctx, alertRule, "start"); err != nil { + return err + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.StartRule(ctx, id) + if err != nil { + return err + } + return p.dbClient.AlertRuleStart(id) +} + +func (p alertApp) AlertIgnore(ctx context.Context, id string) error { + return p.dbClient.AlertIgnore(id) +} + +func (p alertApp) TreatedIgnore(ctx context.Context, id, message string) error { + return p.dbClient.TreatedIgnore(id, message) +} + +func (p alertApp) AlertRuleStatus(ctx context.Context, id string) (constants.RuleStatus, error) { + alertRule, err := p.dbClient.AlertRuleById(id) + if err != nil { + return "", err + } + return alertRule.Status, nil +} + +func (p alertApp) checkAlertRuleParam(ctx context.Context, rule models.AlertRule, operate string) error { + if operate == "start" { + if rule.Status == constants.RuleStart { + return errort.NewCommonErr(errort.AlertRuleStatusStarting, fmt.Errorf("alertRule id(%s) is runing ,not allow start", rule.Id)) + } + } + + if rule.AlertType == "" || rule.AlertLevel == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) alertType or alertLevel is null", rule.Id)) + } + + if len(rule.SubRule) == 0 { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) subrule is null", rule.Id)) + } + + for _, subRule := range rule.SubRule { + if subRule.Trigger == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) subrule trigger is null", rule.Id)) + } + if subRule.ProductId == "" || subRule.DeviceId == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) device id or product id is null", rule.Id)) + } + product, err := p.dbClient.ProductById(subRule.ProductId) + if err != nil { + return errort.NewCommonErr(errort.AlertRuleProductOrDeviceUpdate, fmt.Errorf("alertRule id(%s) device id or product id is null", rule.Id)) + } + device, err := p.dbClient.DeviceById(subRule.DeviceId) + if err != nil { + return errort.NewCommonErr(errort.AlertRuleProductOrDeviceUpdate, fmt.Errorf("alertRule id(%s) product or device has been modified. Please edit the rule again", rule.Id)) + } + + if device.ProductId != product.Id { + return errort.NewCommonErr(errort.AlertRuleProductOrDeviceUpdate, fmt.Errorf("alertRule id(%s) product or device has been modified. Please edit the rule again", rule.Id)) + } + code := subRule.Option["code"] + switch subRule.Trigger { + case constants.DeviceDataTrigger: + if code == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) code is null", rule.Id)) + } + var find bool + var typeSpecType constants.SpecsType + for _, property := range product.Properties { + if property.Code == code { + find = true + typeSpecType = property.TypeSpec.Type + break + } + } + if !find { + return errort.NewCommonErr(errort.AlertRuleProductOrDeviceUpdate, fmt.Errorf("alertRule id(%s) product or device has been modified. Please edit the rule again", rule.Id)) + } + if !typeSpecType.AllowSendInEkuiper() { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) %s allowSendInEkuiper", rule.Id, typeSpecType)) + } + valueType := subRule.Option["value_type"] + if valueType == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) valueType is null", rule.Id)) + } + var valueTypeFind bool + for _, s := range constants.ValueTypes { + if s == valueType { + valueTypeFind = true + break + } + } + if !valueTypeFind { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) valueTypeFind error", rule.Id)) + } + + valueCycle := subRule.Option["value_cycle"] + if valueType != constants.Original { + if valueCycle == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) valueCycle is null", rule.Id)) + } + } + + decideCondition := subRule.Option["decide_condition"] + if decideCondition == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) decideCondition is null", rule.Id)) + } + + case constants.DeviceEventTrigger: + if code == "" { + return errort.NewCommonErr(errort.AlertRuleParamsError, fmt.Errorf("alertRule id(%s) code is null", rule.Id)) + } + var find bool + for _, event := range product.Events { + if event.Code == code { + find = true + break + } + } + if !find { + return errort.NewCommonErr(errort.AlertRuleProductOrDeviceUpdate, fmt.Errorf("alertRule id(%s) product or device has been modified. Please edit the rule again", rule.Id)) + } + case constants.DeviceStatusTrigger: + + } + + } + return nil +} + +func (p alertApp) AddAlert(ctx context.Context, req map[string]interface{}) error { + + deviceId, ok := req["deviceId"] + if !ok { + return errort.NewCommonErr(errort.DefaultReqParamsError, errors.New("")) + } + + ruleId, ok := req["rule_id"] + if !ok { + return errort.NewCommonErr(errort.DefaultReqParamsError, errors.New("")) + } + var ( + coverDeviceId string + coverRuleId string + ) + switch deviceId.(type) { + case string: + coverDeviceId = deviceId.(string) + case int: + coverDeviceId = strconv.Itoa(deviceId.(int)) + case int64: + coverDeviceId = strconv.Itoa(int(deviceId.(int64))) + case float64: + coverDeviceId = fmt.Sprintf("%f", deviceId.(float64)) + case float32: + coverDeviceId = fmt.Sprintf("%f", deviceId.(float64)) + } + if coverDeviceId == "" { + return errort.NewCommonErr(errort.DefaultReqParamsError, errors.New("")) + } + + switch ruleId.(type) { + case string: + coverRuleId = ruleId.(string) + case int: + coverRuleId = strconv.Itoa(ruleId.(int)) + case int64: + coverRuleId = strconv.Itoa(int(ruleId.(int64))) + case float64: + coverRuleId = fmt.Sprintf("%f", ruleId.(float64)) + case float32: + coverRuleId = fmt.Sprintf("%f", ruleId.(float64)) + } + + if coverRuleId == "" { + return errort.NewCommonErr(errort.DefaultReqParamsError, errors.New("")) + } + + device, err := p.dbClient.DeviceById(coverDeviceId) + if err != nil { + return err + } + product, err := p.dbClient.ProductById(device.ProductId) + if err != nil { + return err + } + alertRule, err := p.dbClient.AlertRuleById(coverRuleId) + if err != nil { + return err + } + + alertResult := make(map[string]interface{}) + alertResult["device_id"] = device.Id + alertResult["code"] = alertRule.SubRule[0].Option["code"] + if req["window_start"] != nil && req["window_end"] != nil { + p.lc.Info("msg report1:", req["window_start"]) + alertResult["start_at"] = req["window_start"] + alertResult["end_at"] = req["window_end"] + } else if req["report_time"] != nil { + reportTime := utils.InterfaceToString(req["report_time"]) + if len(reportTime) > 3 { + sa, err := strconv.Atoi(reportTime[0:len(reportTime)-3] + "000") + if err == nil { + alertResult["start_at"] = sa + } + ea, err := strconv.Atoi(reportTime[0:len(reportTime)-3] + "999") + if err == nil { + alertResult["end_at"] = ea + } + } + } + + if len(alertRule.SubRule) > 0 { + switch alertRule.SubRule[0].Trigger { + case constants.DeviceEventTrigger: + alertResult["trigger"] = string(constants.DeviceEventTrigger) + case constants.DeviceDataTrigger: + alertResult["trigger"] = string(constants.DeviceDataTrigger) + } + } + + var alertList models.AlertList + alertList.AlertRuleId = alertRule.Id + alertList.AlertResult = alertResult + alertList.TriggerTime = time.Now().UnixMilli() + + send := false + if alertRule.SilenceTime > 0 { + alertSend, err := p.dbClient.AlertListLastSend(alertRule.Id) + if err != nil { + if err == gorm.ErrRecordNotFound { + send = true + goto Jump + } + } else { + if alertSend.Created+alertRule.SilenceTime <= utils.MakeTimestamp() { + send = true + goto Jump + } + } + } + +Jump: + if send == false { + alertList.IsSend = false + _, err = p.dbClient.AddAlertList(alertList) + if err != nil { + return err + } + return nil + } + alertList.IsSend = true + alertList.Status = constants.Untreated + + _, err = p.dbClient.AddAlertList(alertList) + if err != nil { + return err + } + + for _, notify := range alertRule.Notify { + switch notify.Name { + case constants.SMS: + if !checkEffectTime(notify.StartEffectTime, notify.EndEffectTime) { + continue + } + var phoneNumber string + if v, ok := notify.Option["phoneNumber"]; ok { + phoneNumber = v + } + if phoneNumber == "" { + p.lc.Debug("phoneNumber is null") + continue + } + //templateId templateParamSet 内容请用户自行补充。 + var templateId string + var templateParamSet []string + + smsApp := resourceContainer.SmsServiceAppFrom(p.dic.Get) + go smsApp.Send(templateId, templateParamSet, []string{phoneNumber}) + case constants.PHONE: + case constants.QYweixin: + if !checkEffectTime(notify.StartEffectTime, notify.EndEffectTime) { + continue + } + weixinAlertClient := yiqiweixin.NewWeiXinClient(p.lc, p.dic) + //发送内容请用户自行完善 + text := "" + go weixinAlertClient.Send(notify.Option["webhook"], text) + case constants.DingDing: + if !checkEffectTime(notify.StartEffectTime, notify.EndEffectTime) { + continue + } + weixinAlertClient := dingding.NewDingDingClient(p.lc, p.dic) + //发送内容请用户自行完善 + text := "" + go weixinAlertClient.Send(notify.Option["webhook"], text) + case constants.FeiShu: + if !checkEffectTime(notify.StartEffectTime, notify.EndEffectTime) { + continue + } + feishuAlertClient := feishu.NewFeishuClient(p.lc, p.dic) + //发送内容请用户自行完善 + text := "" + go feishuAlertClient.Send(notify.Option["webhook"], text) + case constants.WEBAPI: + if !checkEffectTime(notify.StartEffectTime, notify.EndEffectTime) { + continue + } + webApiClient := webapi.NewWebApiClient(p.lc, p.dic) + headermap := make([]map[string]string, 0) + if header, ok := notify.Option["header"]; ok { + err := json.Unmarshal([]byte(header), &headermap) + if err != nil { + return err + } + } + go webApiClient.Send(notify.Option["webhook"], headermap, alertRule, device, product, req) + } + } + + return nil +} + +func checkEffectTime(startTime, endTime string) bool { + timeTemplate := "2006-01-02 15:04:05" + startstamp, _ := time.ParseInLocation(timeTemplate, fmt.Sprintf("%d-%02d-%02d", time.Now().Year(), time.Now().Month(), time.Now().Day())+" "+startTime, time.Local) + endstamp, _ := time.ParseInLocation(timeTemplate, fmt.Sprintf("%d-%02d-%02d", time.Now().Year(), time.Now().Month(), time.Now().Day())+" "+endTime, time.Local) + if startstamp.Unix() < time.Now().Unix() && endstamp.Unix() > time.Now().Unix() { + //发送 + return true + } else { + return false + } +} +func checkEffectTimeParam(startTime, endTime string) bool { + timeTemplate := "2006-01-02 15:04:05" + startstamp, _ := time.ParseInLocation(timeTemplate, fmt.Sprintf("%d-%02d-%02d", time.Now().Year(), time.Now().Month(), time.Now().Day())+" "+startTime, time.Local) + endstamp, _ := time.ParseInLocation(timeTemplate, fmt.Sprintf("%d-%02d-%02d", time.Now().Year(), time.Now().Month(), time.Now().Day())+" "+endTime, time.Local) + if endstamp.Unix()-startstamp.Unix() <= 0 { + return false + } + return true +} + +func (p alertApp) CheckRuleByProductId(ctx context.Context, productId string) error { + var req dtos.AlertRuleSearchQueryRequest + req.Status = string(constants.RuleStart) + alertRules, _, err := p.AlertRulesSearch(ctx, req) + if err != nil { + return err + } + for _, rule := range alertRules { + for _, subRule := range rule.SubRule { + if subRule.ProductId == productId { + return errort.NewCommonEdgeX(errort.ProductAssociationAlertRule, "This product has been bound"+ + " to alarm rules. Please stop reporting relevant alarm rules before proceeding with the operation.", nil) + } + } + } + return nil +} + +func (p alertApp) CheckRuleByDeviceId(ctx context.Context, deviceId string) error { + var req dtos.AlertRuleSearchQueryRequest + req.Status = string(constants.RuleStart) + alertRules, _, err := p.AlertRulesSearch(ctx, req) + if err != nil { + return err + } + for _, rule := range alertRules { + for _, subRule := range rule.SubRule { + if subRule.DeviceId == deviceId { + return errort.NewCommonEdgeX(errort.DeviceAssociationAlertRule, "This device has been bound to alarm"+ + " rules. Please stop reporting relevant alarm rules before proceeding with the operation", nil) + } + } + } + return nil +} diff --git a/internal/hummingbird/core/application/alertcentreapp/monitor.go b/internal/hummingbird/core/application/alertcentreapp/monitor.go new file mode 100644 index 0000000..7f43cc8 --- /dev/null +++ b/internal/hummingbird/core/application/alertcentreapp/monitor.go @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package alertcentreapp + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "time" +) + +func (p alertApp) monitor() { + tickTime := time.Second * 5 + timeTickerChan := time.Tick(tickTime) + for { + select { + case <-timeTickerChan: + p.checkRuleStatus() + } + } +} + +func (p alertApp) checkRuleStatus() { + alerts, _, err := p.dbClient.AlertRuleSearch(0, -1, dtos.AlertRuleSearchQueryRequest{}) + if err != nil { + p.lc.Errorf("get alerts err:", err) + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + for _, alert := range alerts { + if len(alert.SubRule) == 0 { + continue + } + resp, err := ekuiperApp.GetRuleStats(context.Background(), alert.Id) + if err != nil { + continue + } + status, ok := resp["status"] + if ok { + if status != string(alert.Status) { + if status == string(constants.RuleStop) { + p.dbClient.AlertRuleStop(alert.Id) + } else if status == string(constants.RuleStart) { + p.dbClient.AlertRuleStart(alert.Id) + } + } + } + } +} diff --git a/internal/hummingbird/core/application/categorytemplate/categoryapp.go b/internal/hummingbird/core/application/categorytemplate/categoryapp.go new file mode 100644 index 0000000..31de37a --- /dev/null +++ b/internal/hummingbird/core/application/categorytemplate/categoryapp.go @@ -0,0 +1,122 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package categorytemplate + +import ( + "context" + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "time" + + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" +) + +type categoryApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func NewCategoryTemplateApp(ctx context.Context, dic *di.Container) interfaces.CategoryApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + return &categoryApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} + +func (m *categoryApp) CategoryTemplateSearch(ctx context.Context, req dtos.CategoryTemplateRequest) ([]dtos.CategoryTemplateResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + categoryTemplates, total, err := m.dbClient.CategoryTemplateSearch(offset, limit, req) + if err != nil { + m.lc.Errorf("categoryTemplates Search err %v", err) + return []dtos.CategoryTemplateResponse{}, 0, err + } + + libs := make([]dtos.CategoryTemplateResponse, len(categoryTemplates)) + for i, categoryTemplate := range categoryTemplates { + libs[i] = dtos.CategoryTemplateResponseFromModel(categoryTemplate) + } + return libs, total, nil +} + +func (m *categoryApp) Sync(ctx context.Context, versionName string) (int64, error) { + filePath := versionName + "/category_template.json" + cosApp := resourceContainer.CosAppNameFrom(m.dic.Get) + bs, err := cosApp.Get(filePath) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + var cosCategoryTemplateResp []dtos.CosCategoryTemplateResponse + err = json.Unmarshal(bs, &cosCategoryTemplateResp) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + + baseQuery := dtos.BaseSearchConditionQuery{ + IsAll: true, + } + dbreq := dtos.CategoryTemplateRequest{BaseSearchConditionQuery: baseQuery} + categoryTemplateResponse, _, err := m.CategoryTemplateSearch(ctx, dbreq) + if err != nil { + return 0, err + } + + upsertCategoryTemplate := make([]models.CategoryTemplate, 0) + for _, cosCategoryTemplate := range cosCategoryTemplateResp { + var find bool + for _, localTemplateResponse := range categoryTemplateResponse { + if cosCategoryTemplate.CategoryKey == localTemplateResponse.CategoryKey { + upsertCategoryTemplate = append(upsertCategoryTemplate, models.CategoryTemplate{ + Id: localTemplateResponse.Id, + CategoryName: cosCategoryTemplate.CategoryName, + CategoryKey: cosCategoryTemplate.CategoryKey, + Scene: cosCategoryTemplate.Scene, + }) + find = true + break + } + } + if !find { + upsertCategoryTemplate = append(upsertCategoryTemplate, models.CategoryTemplate{ + Timestamps: models.Timestamps{ + Created: time.Now().Unix(), + }, + Id: utils.GenUUID(), + CategoryName: cosCategoryTemplate.CategoryName, + CategoryKey: cosCategoryTemplate.CategoryKey, + Scene: cosCategoryTemplate.Scene, + }) + } + } + rows, err := m.dbClient.BatchUpsertCategoryTemplate(upsertCategoryTemplate) + m.lc.Infof("upsert category template rows %+v", rows) + if err != nil { + return 0, err + } + return rows, nil +} diff --git a/internal/hummingbird/core/application/dataresource/dataresource.go b/internal/hummingbird/core/application/dataresource/dataresource.go new file mode 100644 index 0000000..cf5885f --- /dev/null +++ b/internal/hummingbird/core/application/dataresource/dataresource.go @@ -0,0 +1,371 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package dataresource + +import ( + "context" + "database/sql" + "fmt" + mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/mitchellh/mapstructure" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "net/url" + "time" + + _ "github.com/influxdata/influxdb1-client/v2" + client "github.com/influxdata/influxdb1-client/v2" + //_ "github.com/taosdata/driver-go/v2/taosSql" +) + +type dataResourceApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func NewDataResourceApp(ctx context.Context, dic *di.Container) interfaces.DataResourceApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + app := &dataResourceApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } + return app +} + +func (p dataResourceApp) AddDataResource(ctx context.Context, req dtos.AddDataResourceReq) (string, error) { + var insertDataResource models.DataResource + insertDataResource.Name = req.Name + insertDataResource.Type = constants.DataResourceType(req.Type) + insertDataResource.Option = req.Option + insertDataResource.Option["sendSingle"] = true + id, err := p.dbClient.AddDataResource(insertDataResource) + if err != nil { + return "", err + } + return id, nil +} + +func (p dataResourceApp) DataResourceById(ctx context.Context, id string) (models.DataResource, error) { + if id == "" { + return models.DataResource{}, errort.NewCommonEdgeX(errort.DefaultReqParamsError, "req id is required", nil) + + } + dataResource, edgeXErr := p.dbClient.DataResourceById(id) + if edgeXErr != nil { + return models.DataResource{}, edgeXErr + } + return dataResource, nil +} + +func (p dataResourceApp) UpdateDataResource(ctx context.Context, req dtos.UpdateDataResource) error { + if req.Id == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update req id is required", nil) + + } + dataResource, edgeXErr := p.dbClient.DataResourceById(req.Id) + if edgeXErr != nil { + return edgeXErr + } + ruleEngines, _, err := p.dbClient.RuleEngineSearch(0, -1, dtos.RuleEngineSearchQueryRequest{ + Status: string(constants.RuleEngineStart), + }) + if err != nil { + return err + } + + for _, engine := range ruleEngines { + if engine.DataResourceId == req.Id { + return errort.NewCommonErr(errort.RuleEngineIsStartingNotAllowUpdate, fmt.Errorf("please stop this rule engine (%s) before editing it", req.Id)) + } + } + + dtos.ReplaceDataResourceModelFields(&dataResource, req) + edgeXErr = p.dbClient.UpdateDataResource(dataResource) + if edgeXErr != nil { + return edgeXErr + } + return nil +} + +func (p dataResourceApp) DelDataResourceById(ctx context.Context, id string) error { + ruleEngines, _, err := p.dbClient.RuleEngineSearch(0, -1, dtos.RuleEngineSearchQueryRequest{ + Status: string(constants.RuleEngineStart), + }) + if err != nil { + return err + } + + for _, engine := range ruleEngines { + if engine.DataResourceId == id { + return errort.NewCommonErr(errort.RuleEngineIsStartingNotAllowUpdate, fmt.Errorf("please stop this rule engine (%s) before editing it", id)) + } + } + + err = p.dbClient.DelDataResource(id) + if err != nil { + return err + } + return nil +} + +func (p dataResourceApp) DataResourceSearch(ctx context.Context, req dtos.DataResourceSearchQueryRequest) ([]models.DataResource, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.SearchDataResource(offset, limit, req) + if err != nil { + return []models.DataResource{}, 0, err + } + return resp, total, nil +} + +func (p dataResourceApp) DataResourceType(ctx context.Context) []constants.DataResourceType { + return constants.DataResources +} + +func (p dataResourceApp) DataResourceHealth(ctx context.Context, resourceId string) error { + dataResource, err := p.dbClient.DataResourceById(resourceId) + if err != nil { + return err + } + //return p.dbClient.UpdateDataResourceHealth(dataResource.Id, true) + switch dataResource.Type { + case constants.HttpResource: + err = p.checkHttpResourceHealth(dataResource) + case constants.MQTTResource: + err = p.checkMQTTResourceHealth(dataResource) + case constants.KafkaResource: + err = p.checkKafkaResourceHealth(dataResource) + case constants.InfluxDBResource: + err = p.checkInfluxDBResourceHealth(dataResource) + case constants.TDengineResource: + //err = p.checkTdengineResourceHealth(dataResource) + default: + return errort.NewCommonErr(errort.DefaultReqParamsError, fmt.Errorf("resource type not much")) + } + if err != nil { + return err + } + return p.dbClient.UpdateDataResourceHealth(dataResource.Id, true) +} + +func (p dataResourceApp) checkHttpResourceHealth(resource models.DataResource) error { + urlAddr := resource.Option["url"].(string) + _, err := url.Parse(urlAddr) + if err != nil { + return err + } + return nil +} + +func (p dataResourceApp) checkMQTTResourceHealth(resource models.DataResource) error { + var ( + server, topic string + clientId, username, password string + //certificationPath, + //privateKeyPath, rootCaPath, insecureSkipVerify, retained, compression, connectionSelector string + ) + server = resource.Option["server"].(string) + topic = resource.Option["topic"].(string) + clientId = resource.Option["clientId"].(string) + //protocolVersion = resource.Option["protocolVersion"] + //qos = resource.Option["qos"].(int) + username = resource.Option["username"].(string) + password = resource.Option["password"].(string) + //certificationPath = resource.Option["certificationPath"] + //privateKeyPath = resource.Option["privateKeyPath"] + //rootCaPath = resource.Option["rootCaPath"] + //insecureSkipVerify = resource.Option["insecureSkipVerify"] + //retained = resource.Option["retained"] + //compression = resource.Option["compression"] + //connectionSelector = resource.Option["connectionSelector"] + + if server == "" || topic == "" || clientId == "" || username == "" || password == "" { + + } + opts := mqtt.NewClientOptions() + opts.AddBroker(server) + opts.SetUsername(username) + opts.SetPassword(password) + opts.SetClientID(clientId) + + client := mqtt.NewClient(opts) + token := client.Connect() + // 如果连接失败,则终止程序 + if token.WaitTimeout(3*time.Second) && token.Error() != nil { + return token.Error() + } + defer client.Disconnect(250) + return nil +} + +func (p dataResourceApp) checkKafkaResourceHealth(resource models.DataResource) error { + return nil +} + +func (p dataResourceApp) checkInfluxDBResourceHealth(resource models.DataResource) error { + type influxSink struct { + addr string + username string + password string + measurement string + databaseName string + tagKey string + tagValue string + fields string + cli client.Client + fieldMap map[string]interface{} + hasTransform bool + } + var m influxSink + if i, ok := resource.Option["addr"]; ok { + if i, ok := i.(string); ok { + m.addr = i + } + } + if i, ok := resource.Option["username"]; ok { + if i, ok := i.(string); ok { + m.username = i + } + } + if i, ok := resource.Option["password"]; ok { + if i, ok := i.(string); ok { + m.password = i + } + } + if i, ok := resource.Option["measurement"]; ok { + if i, ok := i.(string); ok { + m.measurement = i + } + } + if i, ok := resource.Option["databasename"]; ok { + if i, ok := i.(string); ok { + m.databaseName = i + } + } + if i, ok := resource.Option["tagkey"]; ok { + if i, ok := i.(string); ok { + m.tagKey = i + } + } + if i, ok := resource.Option["tagvalue"]; ok { + if i, ok := i.(string); ok { + m.tagValue = i + } + } + if i, ok := resource.Option["fields"]; ok { + if i, ok := i.(string); ok { + m.fields = i + } + } + if i, ok := resource.Option["dataTemplate"]; ok { + if i, ok := i.(string); ok && i != "" { + m.hasTransform = true + } + } + + _, err := client.NewHTTPClient(client.HTTPConfig{ + Addr: m.addr, + Username: m.username, + Password: m.password, + }) + if err != nil { + return err + } + return nil +} + +func (p dataResourceApp) checkTdengineResourceHealth(resource models.DataResource) error { + type taosConfig struct { + ProvideTs bool `json:"provideTs"` + Port int `json:"port"` + Ip string `json:"ip"` // To be deprecated + Host string `json:"host"` + User string `json:"user"` + Password string `json:"password"` + Database string `json:"database"` + Table string `json:"table"` + TsFieldName string `json:"tsFieldName"` + Fields []string `json:"fields"` + STable string `json:"sTable"` + TagFields []string `json:"tagFields"` + DataTemplate string `json:"dataTemplate"` + TableDataField string `json:"tableDataField"` + } + cfg := &taosConfig{ + User: "root", + Password: "taosdata", + } + err := MapToStruct(resource.Option, cfg) + if err != nil { + return fmt.Errorf("read properties %v fail with error: %v", resource.Option, err) + } + if cfg.Ip != "" { + fmt.Errorf("Deprecated: Tdengine sink ip property is deprecated, use host instead.") + if cfg.Host == "" { + cfg.Host = cfg.Ip + } + } + if cfg.Host == "" { + cfg.Host = "localhost" + } + if cfg.User == "" { + return fmt.Errorf("propert user is required.") + } + if cfg.Password == "" { + return fmt.Errorf("propert password is required.") + } + if cfg.Database == "" { + return fmt.Errorf("property database is required") + } + if cfg.Table == "" { + return fmt.Errorf("property table is required") + } + if cfg.TsFieldName == "" { + return fmt.Errorf("property TsFieldName is required") + } + if cfg.STable != "" && len(cfg.TagFields) == 0 { + return fmt.Errorf("property tagFields is required when sTable is set") + } + url := fmt.Sprintf(`%s:%s@tcp(%s:%d)/%s`, cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database) + //m.conf = cfg + _, err = sql.Open("taosSql", url) + if err != nil { + return err + } + return nil +} + +func MapToStruct(input, output interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + Result: output, + } + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} diff --git a/internal/hummingbird/core/application/deviceapp/cloudability.go b/internal/hummingbird/core/application/deviceapp/cloudability.go new file mode 100644 index 0000000..692184f --- /dev/null +++ b/internal/hummingbird/core/application/deviceapp/cloudability.go @@ -0,0 +1,118 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package deviceapp + +import ( + "context" + "github.com/docker/distribution/uuid" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "strconv" + + "github.com/winc-link/edge-driver-proto/drivercommon" + "github.com/winc-link/edge-driver-proto/driverdevice" +) + +func (p deviceApp) ConnectIotPlatform(ctx context.Context, request *driverdevice.ConnectIotPlatformRequest) *driverdevice.ConnectIotPlatformResponse { + response := new(driverdevice.ConnectIotPlatformResponse) + baseResponse := new(drivercommon.CommonResponse) + baseResponse.Success = false + deviceInfo, err := p.DeviceById(ctx, request.DeviceId) + + if err != nil { + errWrapper := errort.NewCommonEdgeXWrapper(err) + baseResponse.Code = strconv.Itoa(int(errWrapper.Code())) + baseResponse.ErrorMessage = errWrapper.Message() + response.BaseResponse = baseResponse + return response + } + //把消息投体进入消息总线 + messageApp := container.MessageItfFrom(p.dic.Get) + messageApp.DeviceStatusToMessageBus(ctx, deviceInfo.Id, constants.DeviceOnline) + + err = p.dbClient.DeviceOnlineById(request.DeviceId) + if err != nil { + errWrapper := errort.NewCommonEdgeXWrapper(err) + baseResponse.Code = strconv.Itoa(int(errWrapper.Code())) + baseResponse.ErrorMessage = errWrapper.Message() + response.BaseResponse = baseResponse + return response + } else { + baseResponse.Success = true + baseResponse.RequestId = uuid.Generate().String() + response.Data = new(driverdevice.ConnectIotPlatformResponse_Data) + response.Data.Status = driverdevice.ConnectStatus_ONLINE + response.BaseResponse = baseResponse + return response + } +} + +func (p deviceApp) DisConnectIotPlatform(ctx context.Context, request *driverdevice.DisconnectIotPlatformRequest) *driverdevice.DisconnectIotPlatformResponse { + deviceInfo, err := p.DeviceById(ctx, request.DeviceId) + response := new(driverdevice.DisconnectIotPlatformResponse) + baseResponse := new(drivercommon.CommonResponse) + baseResponse.Success = false + if err != nil { + errWrapper := errort.NewCommonEdgeXWrapper(err) + baseResponse.Code = strconv.Itoa(int(errWrapper.Code())) + baseResponse.ErrorMessage = errWrapper.Message() + response.BaseResponse = baseResponse + return response + } + messageApp := container.MessageItfFrom(p.dic.Get) + messageApp.DeviceStatusToMessageBus(ctx, deviceInfo.Id, constants.DeviceOffline) + + err = p.dbClient.DeviceOfflineById(request.DeviceId) + if err != nil { + errWrapper := errort.NewCommonEdgeXWrapper(err) + baseResponse.Code = strconv.Itoa(int(errWrapper.Code())) + baseResponse.ErrorMessage = errWrapper.Message() + response.BaseResponse = baseResponse + return response + } + baseResponse.Success = true + baseResponse.RequestId = uuid.Generate().String() + response.Data = new(driverdevice.DisconnectIotPlatformResponse_Data) + response.Data.Status = driverdevice.ConnectStatus_OFFLINE + response.BaseResponse = baseResponse + return response +} + +func (p deviceApp) GetDeviceConnectStatus(ctx context.Context, request *driverdevice.GetDeviceConnectStatusRequest) *driverdevice.GetDeviceConnectStatusResponse { + deviceInfo, err := p.DeviceById(ctx, request.DeviceId) + response := new(driverdevice.GetDeviceConnectStatusResponse) + baseResponse := new(drivercommon.CommonResponse) + baseResponse.Success = false + if err != nil { + errWrapper := errort.NewCommonEdgeXWrapper(err) + baseResponse.Code = strconv.Itoa(int(errWrapper.Code())) + baseResponse.ErrorMessage = errWrapper.Message() + response.BaseResponse = baseResponse + return response + } + + baseResponse.Success = true + baseResponse.RequestId = uuid.Generate().String() + response.Data = new(driverdevice.GetDeviceConnectStatusResponse_Data) + if deviceInfo.Status == constants.DeviceStatusOnline { + response.Data.Status = driverdevice.ConnectStatus_ONLINE + } else { + response.Data.Status = driverdevice.ConnectStatus_OFFLINE + } + response.BaseResponse = baseResponse + return response + +} diff --git a/internal/hummingbird/core/application/deviceapp/deviceaction.go b/internal/hummingbird/core/application/deviceapp/deviceaction.go new file mode 100644 index 0000000..8a81a9d --- /dev/null +++ b/internal/hummingbird/core/application/deviceapp/deviceaction.go @@ -0,0 +1,270 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package deviceapp + +import ( + "context" + "encoding/json" + "github.com/docker/distribution/uuid" + "github.com/winc-link/edge-driver-proto/thingmodel" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/tools/rpcclient" + "time" +) + +const ( + ResultSuccess = "success" + ResultFail = "fail" +) + +// ui控制台、定时任务、场景联动、云平台api 都可以调用 +func (p *deviceApp) DeviceAction(jobAction dtos.JobAction) dtos.DeviceExecRes { + defer func() { + if err := recover(); err != nil { + p.lc.Error("CreateDeviceCallBack Panic:", err) + } + }() + + device, err := p.dbClient.DeviceById(jobAction.DeviceId) + if err != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: "device not found", + } + } + deviceService, err := p.dbClient.DeviceServiceById(device.DriveInstanceId) + if err != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: "driver not found", + } + } + + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(deviceService.Id) + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(deviceService.BaseAddress, false, "", deviceService.Id, p.lc) + if errX != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: errX.Error(), + } + } + defer client.Close() + var rpcRequest thingmodel.ThingModelIssueMsg + rpcRequest.DeviceId = jobAction.DeviceId + rpcRequest.OperationType = thingmodel.OperationType_PROPERTY_SET + var data dtos.PropertySet + data.Version = "v1.0" + data.MsgId = uuid.Generate().String() + data.Time = time.Now().UnixMilli() + param := make(map[string]interface{}) + param[jobAction.Code] = jobAction.Value + data.Params = param + rpcRequest.Data = data.ToString() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, err = client.ThingModelDownServiceClient.ThingModelMsgIssue(ctx, &rpcRequest) + if err == nil { + return dtos.DeviceExecRes{ + Result: true, + Message: ResultSuccess, + } + } else { + return dtos.DeviceExecRes{ + Result: false, + Message: err.Error(), + } + } + } + return dtos.DeviceExecRes{ + Result: false, + Message: "driver status stop", + } +} + +func (p *deviceApp) DeviceInvokeThingService(invokeDeviceServiceReq dtos.InvokeDeviceServiceReq) dtos.DeviceExecRes { + defer func() { + if err := recover(); err != nil { + p.lc.Error("CreateDeviceCallBack Panic:", err) + } + }() + + device, err := p.dbClient.DeviceById(invokeDeviceServiceReq.DeviceId) + if err != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: "device not found", + } + } + + product, err := p.dbClient.ProductById(device.ProductId) + if err != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: "product not found", + } + } + var find bool + var callType constants.CallType + for _, action := range product.Actions { + if action.Code == invokeDeviceServiceReq.Code { + find = true + callType = action.CallType + } + } + + if !find { + return dtos.DeviceExecRes{ + Result: false, + Message: "code not found", + } + } + + deviceService, err := p.dbClient.DeviceServiceById(device.DriveInstanceId) + if err != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: "driver not found", + } + } + + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(deviceService.Id) + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(deviceService.BaseAddress, false, "", deviceService.Id, p.lc) + if errX != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: errX.Error(), + } + } + defer client.Close() + var rpcRequest thingmodel.ThingModelIssueMsg + rpcRequest.DeviceId = invokeDeviceServiceReq.DeviceId + rpcRequest.OperationType = thingmodel.OperationType_SERVICE_EXECUTE + var data dtos.InvokeDeviceService + data.Version = "v1.0" + data.MsgId = uuid.Generate().String() + data.Time = time.Now().UnixMilli() + data.Data.Code = invokeDeviceServiceReq.Code + data.Data.InputParams = invokeDeviceServiceReq.Items + rpcRequest.Data = data.ToString() + + if callType == constants.CallTypeAsync { + //saveServiceInfo := genSaveServiceInfo(data.MsgId, data.Time, invokeDeviceServiceReq) + var saveServiceInfo dtos.ThingModelMessage + saveServiceInfo.OpType = int32(thingmodel.OperationType_SERVICE_EXECUTE) + saveServiceInfo.Cid = device.Id + var saveData dtos.SaveServiceIssueData + saveData.MsgId = data.MsgId + saveData.Code = invokeDeviceServiceReq.Code + saveData.Time = data.Time + saveData.InputParams = invokeDeviceServiceReq.Items + saveData.OutputParams = map[string]interface{}{ + "result": true, + "message": "success", + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, err = client.ThingModelDownServiceClient.ThingModelMsgIssue(ctx, &rpcRequest) + if err == nil { + persistItf := container.PersistItfFrom(p.dic.Get) + _ = persistItf.SaveDeviceThingModelData(saveServiceInfo) + return dtos.DeviceExecRes{ + Result: true, + Message: ResultSuccess, + } + } else { + return dtos.DeviceExecRes{ + Result: false, + Message: err.Error(), + } + } + } else if callType == constants.CallTypeSync { + var saveServiceInfo dtos.ThingModelMessage + saveServiceInfo.OpType = int32(thingmodel.OperationType_SERVICE_EXECUTE) + saveServiceInfo.Cid = device.Id + var saveData dtos.SaveServiceIssueData + saveData.MsgId = data.MsgId + saveData.Code = invokeDeviceServiceReq.Code + saveData.Time = data.Time + saveData.InputParams = invokeDeviceServiceReq.Items + persistItf := container.PersistItfFrom(p.dic.Get) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, err = client.ThingModelDownServiceClient.ThingModelMsgIssue(ctx, &rpcRequest) + + if err != nil { + return dtos.DeviceExecRes{ + Result: false, + Message: err.Error(), + } + } + + messageStore := container.MessageStoreItfFrom(p.dic.Get) + ch := messageStore.GenAckChan(data.MsgId) + + select { + case <-time.After(5 * time.Second): + ch.TryCloseChan() + saveData.OutputParams = map[string]interface{}{ + "result": false, + "message": "wait response timeout", + } + s, _ := json.Marshal(saveData) + saveServiceInfo.Data = string(s) + _ = persistItf.SaveDeviceThingModelData(saveServiceInfo) + return dtos.DeviceExecRes{ + Result: false, + Message: "wait response timeout", + } + case <-ctx.Done(): + saveData.OutputParams = map[string]interface{}{ + "result": false, + "message": "wait response timeout", + } + s, _ := json.Marshal(saveData) + saveServiceInfo.Data = string(s) + _ = persistItf.SaveDeviceThingModelData(saveServiceInfo) + return dtos.DeviceExecRes{ + Result: false, + Message: "wait response timeout", + } + case resp := <-ch.DataChan: + if v, ok := resp.(map[string]interface{}); ok { + saveData.OutputParams = v + s, _ := json.Marshal(saveData) + saveServiceInfo.Data = string(s) + _ = persistItf.SaveDeviceThingModelData(saveServiceInfo) + message, _ := json.Marshal(v) + return dtos.DeviceExecRes{ + Result: true, + Message: string(message), + } + } + } + + } + } + return dtos.DeviceExecRes{ + Result: false, + Message: "driver status stop", + } +} diff --git a/internal/hummingbird/core/application/deviceapp/deviceapp.go b/internal/hummingbird/core/application/deviceapp/deviceapp.go new file mode 100644 index 0000000..de9b270 --- /dev/null +++ b/internal/hummingbird/core/application/deviceapp/deviceapp.go @@ -0,0 +1,577 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package deviceapp + +import ( + "context" + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "time" +) + +type deviceApp struct { + //*propertyTyApp + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func NewDeviceApp(ctx context.Context, dic *di.Container) interfaces.DeviceItf { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + return &deviceApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} + +func (p deviceApp) DeviceById(ctx context.Context, id string) (dtos.DeviceInfoResponse, error) { + device, err := p.dbClient.DeviceById(id) + var response dtos.DeviceInfoResponse + if err != nil { + return response, err + } + var deviceServiceName string + deviceService, err := p.dbClient.DeviceServiceById(device.DriveInstanceId) + if err != nil { + deviceServiceName = deviceService.Name + } + + response = dtos.DeviceInfoResponseFromModel(device, deviceServiceName) + return response, nil +} + +func (p deviceApp) OpenApiDeviceById(ctx context.Context, id string) (dtos.OpenApiDeviceInfoResponse, error) { + device, err := p.dbClient.DeviceById(id) + var response dtos.OpenApiDeviceInfoResponse + if err != nil { + return response, err + } + response = dtos.OpenApiDeviceInfoResponseFromModel(device) + return response, nil +} + +func (p deviceApp) OpenApiDeviceStatusById(ctx context.Context, id string) (dtos.OpenApiDeviceStatus, error) { + device, err := p.dbClient.DeviceById(id) + var response dtos.OpenApiDeviceStatus + if err != nil { + return response, err + } + response.Status = device.Status + return response, nil +} + +func (p deviceApp) DeviceByCloudId(ctx context.Context, id string) (models.Device, error) { + return p.dbClient.DeviceByCloudId(id) +} + +func (p deviceApp) DeviceModelById(ctx context.Context, id string) (models.Device, error) { + return p.dbClient.DeviceById(id) +} + +func (p *deviceApp) DevicesSearch(ctx context.Context, req dtos.DeviceSearchQueryRequest) ([]dtos.DeviceSearchQueryResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.DevicesSearch(offset, limit, req) + if err != nil { + return []dtos.DeviceSearchQueryResponse{}, 0, err + } + devices := make([]dtos.DeviceSearchQueryResponse, len(resp)) + for i, dev := range resp { + deviceService, _ := p.dbClient.DeviceServiceById(dev.DriveInstanceId) + devices[i] = dtos.DeviceResponseFromModel(dev, deviceService.Name) + } + return devices, total, nil +} + +func (p *deviceApp) OpenApiDevicesSearch(ctx context.Context, req dtos.DeviceSearchQueryRequest) ([]dtos.OpenApiDeviceInfoResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.DevicesSearch(offset, limit, req) + if err != nil { + return []dtos.OpenApiDeviceInfoResponse{}, 0, err + } + devices := make([]dtos.OpenApiDeviceInfoResponse, len(resp)) + for i, device := range resp { + devices[i] = dtos.OpenApiDeviceInfoResponseFromModel(device) + } + return devices, total, nil +} + +func (p *deviceApp) DevicesModelSearch(ctx context.Context, req dtos.DeviceSearchQueryRequest) ([]models.Device, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + return p.dbClient.DevicesSearch(offset, limit, req) +} + +func (p *deviceApp) AddDevice(ctx context.Context, req dtos.DeviceAddRequest) (string, error) { + if req.DriverInstanceId != "" { + driverInstance, err := p.dbClient.DeviceServiceById(req.DriverInstanceId) + if err != nil { + return "", err + } + if driverInstance.Platform != "" && driverInstance.Platform != constants.IotPlatform_LocalIot { + return "", errort.NewCommonErr(errort.DeviceServiceMustLocalPlatform, fmt.Errorf("please sync product data")) + } + } + + productInfo, err := p.dbClient.ProductById(req.ProductId) + + if productInfo.Status == constants.ProductUnRelease { + return "", errort.NewCommonEdgeX(errort.ProductUnRelease, "The product has not been released yet. Please release the product before adding devices", nil) + } + deviceId := utils.RandomNum() + if err != nil { + return "", err + } + + err = resourceContainer.DataDBClientFrom(p.dic.Get).CreateTable(ctx, constants.DB_PREFIX+productInfo.Id, deviceId) + if err != nil { + return "", err + } + var insertDevice models.Device + insertDevice.Id = deviceId + insertDevice.Name = req.Name + insertDevice.ProductId = req.ProductId + insertDevice.Platform = constants.IotPlatform_LocalIot + insertDevice.DriveInstanceId = req.DriverInstanceId + insertDevice.Status = constants.DeviceStatusOffline + insertDevice.Secret = utils.GenerateDeviceSecret(12) + insertDevice.Description = req.Description + id, err := p.dbClient.AddDevice(insertDevice) + if err != nil { + return "", err + } + return id, nil +} + +func (p *deviceApp) BatchDeleteDevice(ctx context.Context, ids []string) error { + var searchReq dtos.DeviceSearchQueryRequest + searchReq.BaseSearchConditionQuery.Ids = dtos.ApiParamsArrayToString(ids) + devices, _, err := p.dbClient.DevicesSearch(0, -1, searchReq) + if err != nil { + return err + } + alertApp := resourceContainer.AlertRuleAppNameFrom(p.dic.Get) + for _, device := range devices { + edgeXErr := alertApp.CheckRuleByDeviceId(ctx, device.Id) + if edgeXErr != nil { + return edgeXErr + } + } + err = p.dbClient.BatchDeleteDevice(ids) + if err != nil { + return err + } + for _, device := range devices { + delDevice := device + go func() { + p.DeleteDeviceCallBack(delDevice) + }() + } + return nil +} + +func (p *deviceApp) DeviceMqttAuthInfo(ctx context.Context, id string) (dtos.DeviceAuthInfoResponse, error) { + mqttAuth, err := p.dbClient.DeviceMqttAuthInfo(id) + var response dtos.DeviceAuthInfoResponse + if err != nil { + return response, err + } + response = dtos.DeviceAuthInfoResponseFromModel(mqttAuth) + return response, nil +} + +func (p *deviceApp) AddMqttAuth(ctx context.Context, req dtos.AddMqttAuthInfoRequest) (string, error) { + var mqttAuth models.MqttAuth + mqttAuth.ClientId = req.ClientId + mqttAuth.UserName = req.UserName + mqttAuth.Password = req.Password + mqttAuth.ResourceId = req.ResourceId + mqttAuth.ResourceType = constants.ResourceType(req.ResourceType) + return p.dbClient.AddMqttAuthInfo(mqttAuth) +} + +func (p *deviceApp) DeleteDeviceById(ctx context.Context, id string) error { + deviceInfo, err := p.dbClient.DeviceById(id) + if err != nil { + return err + } + alertApp := resourceContainer.AlertRuleAppNameFrom(p.dic.Get) + edgeXErr := alertApp.CheckRuleByDeviceId(ctx, id) + if edgeXErr != nil { + return edgeXErr + } + + sceneApp := resourceContainer.SceneAppNameFrom(p.dic.Get) + edgeXErr = sceneApp.CheckSceneByDeviceId(ctx, id) + if edgeXErr != nil { + return edgeXErr + } + + err = p.dbClient.DeleteDeviceById(id) + if err != nil { + return err + } + _ = resourceContainer.DataDBClientFrom(p.dic.Get).DropTable(ctx, id) + + go func() { + p.DeleteDeviceCallBack(models.Device{ + Id: deviceInfo.Id, + DriveInstanceId: deviceInfo.DriveInstanceId, + }) + }() + return nil +} + +func (p *deviceApp) DeviceUpdate(ctx context.Context, req dtos.DeviceUpdateRequest) error { + if req.Id == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update req id is required", nil) + + } + device, edgeXErr := p.dbClient.DeviceById(req.Id) + if edgeXErr != nil { + return edgeXErr + } + + if device.Platform != constants.IotPlatform_LocalIot { + + } + alertApp := resourceContainer.AlertRuleAppNameFrom(p.dic.Get) + edgeXErr = alertApp.CheckRuleByDeviceId(ctx, req.Id) + if edgeXErr != nil { + return edgeXErr + } + + sceneApp := resourceContainer.SceneAppNameFrom(p.dic.Get) + edgeXErr = sceneApp.CheckSceneByDeviceId(ctx, req.Id) + if edgeXErr != nil { + return edgeXErr + } + + dtos.ReplaceDeviceModelFields(&device, req) + edgeXErr = p.dbClient.UpdateDevice(device) + if edgeXErr != nil { + return edgeXErr + } + go func() { + p.UpdateDeviceCallBack(device) + }() + return nil +} + +func (p *deviceApp) DevicesUnBindDriver(ctx context.Context, req dtos.DevicesUnBindDriver) error { + var searchReq dtos.DeviceSearchQueryRequest + searchReq.BaseSearchConditionQuery.Ids = dtos.ApiParamsArrayToString(req.DeviceIds) + var devices []models.Device + var err error + var total uint32 + devices, total, err = p.dbClient.DevicesSearch(0, -1, searchReq) + if err != nil { + return err + } + if total == 0 { + return errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("devices not found")) + } + + err = p.dbClient.BatchUnBindDevice(req.DeviceIds) + if err != nil { + return err + } + for _, device := range devices { + callBackDevice := device + go func() { + p.DeleteDeviceCallBack(models.Device{ + Id: callBackDevice.Id, + DriveInstanceId: callBackDevice.DriveInstanceId, + }) + }() + } + return nil +} + +//DevicesBindProductId +func (p *deviceApp) DevicesBindProductId(ctx context.Context, req dtos.DevicesBindProductId) error { + var searchReq dtos.DeviceSearchQueryRequest + searchReq.ProductId = req.ProductId + var devices []models.Device + var err error + var total uint32 + devices, total, err = p.dbClient.DevicesSearch(0, -1, searchReq) + if err != nil { + return err + } + if total == 0 { + return errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("devices not found")) + } + var deviceIds []string + for _, device := range devices { + deviceIds = append(deviceIds, device.Id) + } + err = p.dbClient.BatchBindDevice(deviceIds, req.DriverInstanceId) + if err != nil { + return err + } + + return nil +} + +func (p *deviceApp) DevicesBindDriver(ctx context.Context, req dtos.DevicesBindDriver) error { + var searchReq dtos.DeviceSearchQueryRequest + searchReq.BaseSearchConditionQuery.Ids = dtos.ApiParamsArrayToString(req.DeviceIds) + var devices []models.Device + var err error + var total uint32 + devices, total, err = p.dbClient.DevicesSearch(0, -1, searchReq) + if err != nil { + return err + } + if total == 0 { + return errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("devices not found")) + } + driverInstance, err := p.dbClient.DeviceServiceById(req.DriverInstanceId) + if err != nil { + return err + } + + for _, device := range devices { + if device.DriveInstanceId != "" { + return errort.NewCommonErr(errort.DeviceNotUnbindDriver, fmt.Errorf("please unbind the device with the driver first")) + } + } + + for _, device := range devices { + if driverInstance.Platform != device.Platform && driverInstance.Platform != "" { + return errort.NewCommonErr(errort.DeviceAndDriverPlatformNotIdentical, fmt.Errorf("the device platform is inconsistent with the drive platform")) + } + } + + err = p.dbClient.BatchBindDevice(req.DeviceIds, req.DriverInstanceId) + if err != nil { + return err + } + + for _, device := range devices { + device.DriveInstanceId = req.DriverInstanceId + callBackDevice := device + go func() { + p.CreateDeviceCallBack(callBackDevice) + }() + } + return nil +} + +func (p *deviceApp) DeviceUpdateConnectStatus(id string, status constants.DeviceStatus) error { + device, edgeXErr := p.dbClient.DeviceById(id) + if edgeXErr != nil { + return edgeXErr + } + device.Status = status + if status == constants.DeviceStatusOnline { + device.LastOnlineTime = utils.MakeTimestamp() + } + edgeXErr = p.dbClient.UpdateDevice(device) + if edgeXErr != nil { + return edgeXErr + } + return nil +} + +func setDeviceInfoSheet(file *dtos.ExportFile, req dtos.DeviceImportTemplateRequest) error { + file.Excel.SetSheetName("Sheet1", dtos.DevicesFilename) + + file.Excel.SetCellStyle(dtos.DevicesFilename, "A1", "A1", file.GetCenterStyle()) + file.Excel.MergeCell(dtos.DevicesFilename, "A1", "B1") + file.Excel.SetCellStr(dtos.DevicesFilename, "A1", "Device Base Info") + + file.Excel.SetCellStr(dtos.DevicesFilename, "A2", "DeviceName") + file.Excel.SetCellStr(dtos.DevicesFilename, "B2", "Description") + + return nil +} + +func (p *deviceApp) DeviceImportTemplateDownload(ctx context.Context, req dtos.DeviceImportTemplateRequest) (*dtos.ExportFile, error) { + file, err := dtos.NewExportFile(dtos.DevicesFilename) + if err != nil { + return nil, err + } + if err := setDeviceInfoSheet(file, req); err != nil { + p.lc.Error(err.Error()) + return nil, err + } + return file, nil +} + +func (p *deviceApp) UploadValidated(ctx context.Context, file *dtos.ImportFile) error { + rows, err := file.Excel.Rows(dtos.DevicesFilename) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultReadExcelErrorCode, "read rows error", err) + } + idx := 0 + for rows.Next() { + idx++ + cols, err := rows.Columns() + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultReadExcelErrorCode, "read cols error", err) + } + if idx == 1 { + continue + } + if idx == 2 { + if len(cols) != 2 { + return errort.NewCommonEdgeX(errort.DefaultReadExcelErrorCode, fmt.Sprintf("read cols error need len %d,but read len %d", 2, len(cols)), err) + } + continue + } + + // 空行过滤 + if len(cols) <= 0 { + continue + } + + if cols[0] == "" { + return errort.NewCommonEdgeX(errort.DefaultReadExcelErrorParamsRequiredCode, fmt.Sprintf("read excel params required %+v", "deviceName"), nil) + } + } + return nil +} + +func (p *deviceApp) DevicesImport(ctx context.Context, file *dtos.ImportFile, productId, driverInstanceId string) (int64, error) { + productService := resourceContainer.ProductAppNameFrom(p.dic.Get) + productInfo, err := productService.ProductById(ctx, productId) + if err != nil { + return 0, err + } + + if productInfo.Status == string(constants.ProductUnRelease) { + return 0, errort.NewCommonEdgeX(errort.ProductUnRelease, "The product has not been released yet. Please release the product before adding devices", nil) + } + + if driverInstanceId != "" { + driverService := resourceContainer.DriverServiceAppFrom(p.dic.Get) + driverInfo, err := driverService.Get(ctx, driverInstanceId) + if err != nil { + return 0, err + } + if driverInfo.Platform != "" && driverInfo.Platform != constants.IotPlatform_LocalIot { + return 0, errort.NewCommonEdgeX(errort.DeviceServiceMustLocalPlatform, "driver service must local platform", err) + } + } + + rows, err := file.Excel.Rows(dtos.DevicesFilename) + if err != nil { + return 0, errort.NewCommonEdgeX(errort.DefaultReadExcelErrorCode, "read rows error", err) + } + devices := make([]models.Device, 0) + idx := 0 + for rows.Next() { + idx++ + deviceAddRequest := models.Device{ + ProductId: productId, + DriveInstanceId: driverInstanceId, + } + cols, err := rows.Columns() + if err != nil { + return 0, errort.NewCommonEdgeX(errort.DefaultReadExcelErrorCode, "read cols error", err) + } + if idx == 1 { + continue + } + if idx == 2 { + if len(cols) != 2 { + return 0, errort.NewCommonEdgeX(errort.DefaultReadExcelErrorCode, fmt.Sprintf("read cols error need len %d,but read len %d", 2, len(cols)), err) + } + continue + } + + // 空行过滤 + if len(cols) <= 0 { + continue + } + + deviceAddRequest.Id = utils.RandomNum() + deviceAddRequest.Name = cols[0] + if len(cols) >= 2 { + deviceAddRequest.Description = cols[1] + } + deviceAddRequest.Status = constants.DeviceStatusOffline + deviceAddRequest.Platform = constants.IotPlatform_LocalIot + deviceAddRequest.Created = utils.MakeTimestamp() + deviceAddRequest.Secret = utils.GenerateDeviceSecret(12) + if deviceAddRequest.Name == "" { + return 0, errort.NewCommonEdgeX(errort.DefaultReadExcelErrorParamsRequiredCode, fmt.Sprintf("read excel params required %+v", deviceAddRequest), nil) + } + devices = append(devices, deviceAddRequest) + } + + for _, device := range devices { + err = resourceContainer.DataDBClientFrom(p.dic.Get).CreateTable(ctx, constants.DB_PREFIX+productInfo.Id, device.Id) + if err != nil { + return 0, err + } + } + + total, err := p.dbClient.BatchUpsertDevice(devices) + if err != nil { + return 0, err + } + + for _, device := range devices { + addDevice := device + go func() { + p.CreateDeviceCallBack(addDevice) + }() + } + return total, nil +} + +func (p *deviceApp) DevicesReportMsgGather(ctx context.Context) error { + var count int + var err error + startTime, endTime := GetYesterdayStartTimeAndEndTime() + persistApp := resourceContainer.PersistItfFrom(p.dic.Get) + count, err = persistApp.SearchDeviceMsgCount(startTime, endTime) + if err != nil { + + } + var msgGather models.MsgGather + msgGather.Count = count + msgGather.Date = time.Now().AddDate(0, 0, -1).Format("2006-01-02") + return p.dbClient.AddMsgGather(msgGather) +} + +func GetYesterdayStartTimeAndEndTime() (int64, int64) { + NowTime := time.Now() + var startTime time.Time + if NowTime.Hour() == 0 && NowTime.Minute() == 0 && NowTime.Second() == 0 { + startTime = time.Unix(NowTime.Unix()-86399, 0) //当天的最后一秒 + } else { + startTime = time.Unix(NowTime.Unix()-86400, 0) + } + currentYear := startTime.Year() + currentMonth := startTime.Month() + currentDay := startTime.Day() + yesterdayStartTime := time.Date(currentYear, currentMonth, currentDay, 0, 0, 0, 0, time.Local).UnixMilli() + yesterdayEndTime := time.Date(currentYear, currentMonth, currentDay, 23, 59, 59, 0, time.Local).UnixMilli() + return yesterdayStartTime, yesterdayEndTime +} diff --git a/internal/hummingbird/core/application/deviceapp/devicecallback.go b/internal/hummingbird/core/application/deviceapp/devicecallback.go new file mode 100644 index 0000000..111747b --- /dev/null +++ b/internal/hummingbird/core/application/deviceapp/devicecallback.go @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package deviceapp + +import ( + "context" + "github.com/winc-link/edge-driver-proto/devicecallback" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/tools/rpcclient" + "time" +) + +func (p *deviceApp) CreateDeviceCallBack(createDevice models.Device) { + defer func() { + if err := recover(); err != nil { + p.lc.Error("CreateDeviceCallBack Panic:", err) + } + }() + deviceService, err := p.dbClient.DeviceServiceById(createDevice.DriveInstanceId) + if err != nil { + return + } + + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(deviceService.Id) + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(deviceService.BaseAddress, false, "", deviceService.Id, p.lc) + if errX != nil { + return + } + defer client.Close() + var rpcRequest devicecallback.CreateDeviceCallbackRequest + rpcRequest.Data = createDevice.TransformToDriverDevice() + rpcRequest.HappenTime = uint64(time.Now().Unix()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, _ = client.DeviceCallBackServiceClient.CreateDeviceCallback(ctx, &rpcRequest) + } +} + +func (p *deviceApp) UpdateDeviceCallBack(updateDevice models.Device) { + defer func() { + if err := recover(); err != nil { + p.lc.Error("UpdateDeviceCallBack Panic:", err) + } + }() + deviceService, err := p.dbClient.DeviceServiceById(updateDevice.DriveInstanceId) + if err != nil { + return + } + + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(deviceService.Id) + + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(deviceService.BaseAddress, false, "", deviceService.Id, p.lc) + if errX != nil { + return + } + defer client.Close() + var rpcRequest devicecallback.UpdateDeviceCallbackRequest + rpcRequest.Data = updateDevice.TransformToDriverDevice() + rpcRequest.HappenTime = uint64(time.Now().Second()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, _ = client.DeviceCallBackServiceClient.UpdateDeviceCallback(ctx, &rpcRequest) + } +} + +func (p *deviceApp) DeleteDeviceCallBack(deleteDevice models.Device) { + //查出哪些驱动和这个平台相关联,做推送通知。 + defer func() { + if err := recover(); err != nil { + p.lc.Error("DeleteDeviceCallBack Panic:", err) + } + }() + deviceService, err := p.dbClient.DeviceServiceById(deleteDevice.DriveInstanceId) + if err != nil { + return + } + + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(deviceService.Id) + + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(deviceService.BaseAddress, false, "", deviceService.Id, p.lc) + if errX != nil { + return + } + defer client.Close() + var rpcRequest devicecallback.DeleteDeviceCallbackRequest + rpcRequest.DeviceId = deleteDevice.Id + rpcRequest.HappenTime = uint64(time.Now().Second()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, _ = client.DeviceCallBackServiceClient.DeleteDeviceCallback(ctx, &rpcRequest) + } +} diff --git a/internal/hummingbird/core/application/dmi/dmi.go b/internal/hummingbird/core/application/dmi/dmi.go new file mode 100644 index 0000000..51da798 --- /dev/null +++ b/internal/hummingbird/core/application/dmi/dmi.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package dmi + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/dmi/docker" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" + "sync" +) + +func New(dic *di.Container, ctx context.Context, wg *sync.WaitGroup, dcm dtos.DriverConfigManage) (interfaces.DMI, error) { + return docker.New(dic, ctx, wg, dcm) +} diff --git a/internal/hummingbird/core/application/dmi/docker/app.go b/internal/hummingbird/core/application/dmi/docker/app.go new file mode 100644 index 0000000..897f13b --- /dev/null +++ b/internal/hummingbird/core/application/dmi/docker/app.go @@ -0,0 +1,171 @@ +package docker + +import ( + "context" + "fmt" + "github.com/docker/docker/errdefs" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "io/ioutil" + "path" + "strings" + "sync" +) + +var _ interfaces.DMI = new(dockerImpl) + +type dockerImpl struct { + dic *di.Container + lc logger.LoggingClient + dm *DockerManager + dcm *dtos.DriverConfigManage +} + +func (d *dockerImpl) GetDriverInstanceLogPath(serviceName string) string { + return d.dcm.GetHostLogFilePath(serviceName) +} +func New(dic *di.Container, ctx context.Context, wg *sync.WaitGroup, dcm dtos.DriverConfigManage) (*dockerImpl, error) { + dm, err := NewDockerManager(ctx, dic, &dcm) + if err != nil { + return nil, err + } + dil := &dockerImpl{ + dic: dic, + lc: container.LoggingClientFrom(dic.Get), + dm: dm, + dcm: &dcm, + } + dil.setDcmRootDir() + return dil, nil +} +func (d *dockerImpl) StopAllInstance() { + dbClient := resourceContainer.DBClientFrom(d.dic.Get) + + deviceService, _, err := dbClient.DeviceServicesSearch(0, -1, dtos.DeviceServiceSearchQueryRequest{}) + if err != nil { + d.lc.Errorf("search service error :", err.Error()) + } + for _, service := range deviceService { + d.lc.Info(fmt.Sprintf("stop docker instance[%s]", service.ContainerName)) + err := d.StopInstance(dtos.DeviceService{ContainerName: service.ContainerName}) + if err != nil { + d.lc.Error(fmt.Sprintf("stop docker instance[%s] error:", err.Error())) + } + } +} + +// DownApp 下载应用 +func (d *dockerImpl) DownApp(cfg dtos.DockerConfig, app dtos.DeviceLibrary, toVersion string) (string, error) { + authToken, err := d.getAuthToken(cfg.Address, cfg.Account, cfg.Password, cfg.SaltKey) + if err != nil { + return "", err + } + + // 2. pull images + return d.getApp(authToken, d.getImageUrl(cfg.Address, app.DockerRepoName, toVersion)) +} + +// getAuthToken 获取 docker 认证 token +func (d *dockerImpl) getAuthToken(address, account, pass, salt string) (string, error) { + if account == "" || pass == "" { + return "", nil + } + + // 处理docker密码 + rawPassword, err := utils.DecryptAuthPassword(pass, salt) + if err != nil { + d.lc.Errorf("3.getAuthToken docker id:%s, account:%s, password err:%n", address, account, err) + return "", err + } + return d.dm.GetAuthToken(account, rawPassword, address), nil +} + +// getApp 下载镜像 +func (d *dockerImpl) getApp(token, imageUrl string) (string, error) { + dockerImageId, dockerErr := d.dm.PullDockerImage(imageUrl, token) + if dockerErr != nil { + code := errort.DefaultSystemError + err := fmt.Errorf("driver library download error") + if errdefs.IsUnauthorized(dockerErr) || errdefs.IsForbidden(dockerErr) || + strings.Contains(dockerErr.Error(), "denied") || + strings.Contains(dockerErr.Error(), "unauthorized") { + code = errort.DeviceLibraryDockerAuthInvalid + err = fmt.Errorf("docker auth invalid") + } else if errdefs.IsNotFound(dockerErr) || strings.Contains(dockerErr.Error(), "not found") { + code = errort.DeviceLibraryDockerImagesNotFound + err = fmt.Errorf("docker images not found") + } else if strings.Contains(dockerErr.Error(), "invalid reference format") { + code = errort.DeviceLibraryDockerImagesNotFound + err = fmt.Errorf("docker images not found, url invalid") + } + d.lc.Errorf("4.getApp imageUrl %s, PullDockerImage err:%v", imageUrl, dockerErr) + + return "", errort.NewCommonErr(code, err) + } + + return dockerImageId, nil +} + +func (d *dockerImpl) getImageUrl(address, repoName, version string) string { + return path.Clean(address+"/"+repoName) + ":" + version +} + +// StateApp 驱动软件下载情况 +func (d *dockerImpl) StateApp(dockerImageId string) bool { + return d.dm.ExistImageById(dockerImageId) +} + +// GetAllApp 获取所有镜像信息 +func (d *dockerImpl) GetAllApp() []string { + return d.dm.GetAllImagesIds() +} + +func (d *dockerImpl) setDcmRootDir() { + res, err := d.dm.GetContainerInspect(d.dcm.DockerSelfName) + if err != nil { + d.lc.Errorf("GetContainerInspect:%v", err) + return + } + + for _, v := range res.Mounts { + if v.Destination == constants.DockerHummingbirdRootDir { + d.dcm.SetHostRootDir(v.Source) + break + } + } + networkName := res.HostConfig.NetworkMode.UserDefined() + d.dcm.SetNetworkName(networkName) +} + +func (d *dockerImpl) genRunServiceConfig(name string, cfgContent string, instanceType constants.InstanceType) (string, error) { + var err error + var filePath string + if instanceType == constants.CloudInstance { + + } else if instanceType == constants.DriverInstance { + filePath = d.dcm.GetRunConfigPath(name) + err = utils.CreateDirIfNotExist(constants.DockerHummingbirdRootDir + "/" + constants.DriverRunConfigDir) + if err != nil { + return "", err + } + } + err = ioutil.WriteFile(filePath, []byte(cfgContent), 0644) + + if err != nil { + d.lc.Error("err:", err) + return "", err + } + if instanceType == constants.CloudInstance { + + } else if instanceType == constants.DriverInstance { + return d.dcm.GetHostRunConfigPath(name), nil + } + return "", nil +} diff --git a/internal/hummingbird/core/application/dmi/docker/driver.go b/internal/hummingbird/core/application/dmi/docker/driver.go new file mode 100644 index 0000000..a140468 --- /dev/null +++ b/internal/hummingbird/core/application/dmi/docker/driver.go @@ -0,0 +1,102 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package docker + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "net" + "time" +) + +// RemoveApp 删除App文件 +func (d *dockerImpl) RemoveApp(app dtos.DeviceLibrary) error { + return d.dm.ImageRemove(app.DockerImageId) +} + +func (d *dockerImpl) GetSelfIp() string { + ip, err := d.dm.GetContainerIp(d.dcm.DockerSelfName) + if err != nil { + d.lc.Errorf("GetContainerIp err:%v", err) + } + return ip +} + +// instance +func (d *dockerImpl) InstanceState(ins dtos.DeviceService) bool { + // 先判断docker 是否在运行中 + stats, err := d.dm.GetContainerRunStatus(ins.ContainerName) + if err != nil { + d.lc.Errorf("GetContainerRunStatus err:%v", err) + return false + } + if stats != constants.ContainerRunStatusRunning { + return false + } + + // 在通过ping 实例服务存在否 + client, err := net.DialTimeout("tcp", ins.BaseAddress, 2*time.Second) + defer func() { + if client != nil { + _ = client.Close() + } + }() + if err != nil { + return false + } + + return true +} + +func (d *dockerImpl) StartInstance(ins dtos.DeviceService, cfg dtos.RunServiceCfg) (string, error) { + // 关闭自定义开关 + if !ins.DockerParamsSwitch { + cfg.DockerParams = "" + } + filePath, err := d.genRunServiceConfig(ins.ContainerName, cfg.RunConfig, constants.DriverInstance) + if err != nil { + return "", err + } + ip, err := d.dm.ContainerStart(cfg.ImageRepo, ins.ContainerName, filePath, cfg.DockerMountDevices, cfg.DockerParams, constants.DriverInstance) + return ip, err +} + +func (d *dockerImpl) StopInstance(ins dtos.DeviceService) error { + err := d.dm.ContainerStop(ins.ContainerName) + if err != nil { + return err + } + return nil +} + +func (d *dockerImpl) DeleteInstance(ins dtos.DeviceService) error { + // 删除容器 + err := d.dm.ContainerRemove(ins.ContainerName) + if err != nil { + return err + } + paths := []string{ + d.dcm.GetRunConfigPath(ins.ContainerName), + d.dcm.GetMntDir(ins.ContainerName), + } + for _, v := range paths { + err = utils.RemoveFileOrDir(v) + if err != nil { + d.lc.Errorf("RemoveFileOrDir [%s] err %v", v, err) + } + } + return nil +} diff --git a/internal/hummingbird/core/application/dmi/docker/manager.go b/internal/hummingbird/core/application/dmi/docker/manager.go new file mode 100644 index 0000000..400da50 --- /dev/null +++ b/internal/hummingbird/core/application/dmi/docker/manager.go @@ -0,0 +1,560 @@ +package docker + +import ( + "context" + "encoding/base64" + "fmt" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/network" + "github.com/docker/go-connections/nat" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "io" + "io/ioutil" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/docker/docker/client" + flag "github.com/spf13/pflag" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + + dicContainer "github.com/winc-link/hummingbird/internal/pkg/container" +) + +const ( + containerInternalMountPath = "/mnt" + containerInternalConfigPath = "/etc/driver/res/configuration.toml" +) + +type DockerManager struct { + // 镜像repoTags:imageinfo + ImageMap map[string]ImageInfo + // 容器name:ContainerInfo + ContainerMap map[string]ContainerInfo + cli *client.Client + ctx context.Context + timeout time.Duration + dic *di.Container + lc logger.LoggingClient + authToken string + mutex sync.RWMutex + dcm *dtos.DriverConfigManage + defaultRegistries []string +} + +type CustomParams struct { + env []string + runtime string + net string + mnt []string +} + +// 镜像信息 +type ImageInfo struct { + Id string + RepoTags []string +} + +// 容器信息 +type ContainerInfo struct { + Id string + Name string + State string +} + +// NewDockerManager 创建 +func NewDockerManager(ctx context.Context, dic *di.Container, dcm *dtos.DriverConfigManage) (*DockerManager, error) { + cli, err := client.NewClientWithOpts(client.WithVersion(dcm.DockerApiVersion)) + if err != nil { + return nil, err + } + lc := dicContainer.LoggingClientFrom(dic.Get) + dockerManager := &DockerManager{ + cli: cli, + ImageMap: make(map[string]ImageInfo), + ContainerMap: make(map[string]ContainerInfo), + ctx: context.Background(), + timeout: time.Second * 10, + dic: dic, + lc: lc, + dcm: dcm, + } + dockerManager.setDefaultRegistry() + + dockerManager.flushImageMap() + tickTime := time.Second * 10 + timeTickerChanImage := time.Tick(tickTime) + go func() { + for { + select { + case <-ctx.Done(): + lc.Info("close to flushImageMap") + return + case <-timeTickerChanImage: + dockerManager.flushImageMap() + } + } + }() + + dockerManager.flushContainerMap() + timeTickerChanContainer := time.Tick(tickTime) + go func() { + for { + select { + case <-ctx.Done(): + lc.Info("close to flushContainerMap") + return + case <-timeTickerChanContainer: + dockerManager.flushContainerMap() + } + } + }() + + return dockerManager, nil +} + +func (dm *DockerManager) setDefaultRegistry() { + info, err := dm.cli.Info(dm.ctx) + if err != nil { + dm.lc.Errorf("get docker info err:%v", err) + return + } + for _, v := range info.RegistryConfig.IndexConfigs { + dm.defaultRegistries = append(dm.defaultRegistries, v.Name) + } +} + +// 刷新docker镜像数据至内存中 +func (dm *DockerManager) flushImageMap() { + dm.mutex.Lock() + defer dm.mutex.Unlock() + images, err := dm.GetImageList() + if err == nil { + dm.ImageMap = make(map[string]ImageInfo) + for _, image := range images { + if len(image.RepoTags) > 0 { + dm.ImageMap[image.RepoTags[0]] = ImageInfo{ + Id: image.ID, + RepoTags: image.RepoTags, + } + } + } + } +} + +// 刷新docker容器数据至内存中 +func (dm *DockerManager) flushContainerMap() { + dm.mutex.Lock() + defer dm.mutex.Unlock() + containers, err := dm.GetContainerList() + if err == nil { + dm.ContainerMap = make(map[string]ContainerInfo) + for _, c := range containers { + if len(c.Names) > 0 { + dm.ContainerMap[c.Names[0][1:]] = ContainerInfo{ + Id: c.ID, + Name: c.Names[0][1:], + State: c.State, + } + } + } + } +} + +// 启动容器,这里为强制重启 先删除容器在启动新容器 containerId 可为空 +func (dm *DockerManager) ContainerStart(imageRepo string, containerName string, runConfigPath string, mountDevices []string, customParams string, instanceType constants.InstanceType) (ip string, err error) { + dm.flushImageMap() + dm.flushContainerMap() + if !dm.ExistImageById(imageRepo) { + return "", errort.NewCommonErr(errort.DeviceLibraryDockerImagesNotFound, fmt.Errorf("imageRepo is %s not exist", imageRepo)) + } + + // 1.停止容器,相同名字容器直接删除 + _ = dm.ContainerStop(containerName) + _ = dm.ContainerRemove(containerName) + + // 2.创建新容器,配置等 + //exposedPorts, portMap := dm.makeExposedPorts(exposePorts) + //resourceDevices := dm.makeMountDevices(mountDevices) + binds := make([]string, 0) + binds = append(binds, "/etc/localtime:/etc/localtime:ro") // 挂载时区 + var thisRunMode container.NetworkMode + if instanceType == constants.CloudInstance { + + } else if instanceType == constants.DriverInstance { + binds = append(binds, runConfigPath+":"+containerInternalConfigPath) //挂载启动配置文件 + binds = append(binds, dm.dcm.GetHostMntDir(containerName)+":"+containerInternalMountPath) //挂载日志 + thisRunMode = container.NetworkMode(dm.dcm.NetWorkName) + + } + dockerCustomParams, err := dm.ParseCustomParams(customParams) + if err != nil { + dm.lc.Errorf("dockerCustomParams err: %+v", err) + return "", err + } + binds = append(binds, dockerCustomParams.mnt...) + + dm.lc.Infof("dockerCustomParams: %+v,%+v", dockerCustomParams, customParams) + dm.lc.Infof("binds: %+v", binds) + dm.lc.Infof("Image:%+v", dm.ImageMap[imageRepo]) + dm.lc.Infof("thisRunMode:%+v", string(thisRunMode)) + _, cErr := dm.cli.ContainerCreate(dm.ctx, &container.Config{ + Image: imageRepo, + Env: dockerCustomParams.env, + }, &container.HostConfig{ + Binds: binds, + NetworkMode: thisRunMode, + RestartPolicy: container.RestartPolicy{ + MaximumRetryCount: 10, + }, + Runtime: dockerCustomParams.runtime, + }, &network.NetworkingConfig{}, nil, containerName) + if cErr != nil { + return "", cErr + } + + // 3.启动容器 + if err = dm.cli.ContainerStart(dm.ctx, containerName, types.ContainerStartOptions{}); err != nil { + return "", errort.NewCommonEdgeX(errort.ContainerRunFail, "Start Container Fail", err) + } + + // 启动后暂停1秒查看状态 + time.Sleep(time.Second * 1) + dm.flushContainerMap() + + // 4.查看容器信息并返回相应的数据 + status, err := dm.GetContainerRunStatus(containerName) + dm.lc.Infof("status: %+v", status) + + if err != nil { + return "", errort.NewCommonEdgeX(errort.DefaultSystemError, "GetContainerRunStatus Fail", err) + } + if status != constants.ContainerRunStatusRunning { + err = errort.NewCommonEdgeX(errort.ContainerRunFail, fmt.Sprintf("%s container status %s", containerName, status), nil) + } + if thisRunMode.IsHost() { + ip = constants.HostAddress + } else { + ip, err = dm.GetContainerIp(containerName) + } + + return +} + +// 端口导出组装 +func (dm *DockerManager) makeExposedPorts(exposePorts []int) (nat.PortSet, nat.PortMap) { + portMap := make(nat.PortMap) + exposedPorts := make(nat.PortSet, 0) + var empty struct{} + for _, p := range exposePorts { + tmpPort, _ := nat.NewPort("tcp", strconv.Itoa(p)) + portMap[tmpPort] = []nat.PortBinding{ + { + HostIP: "", + HostPort: strconv.Itoa(p), + }, + } + exposedPorts[tmpPort] = empty + } + return exposedPorts, portMap +} + +// 挂载设备组装 +func (dm *DockerManager) makeMountDevices(devices []string) container.Resources { + resourceDevices := make([]container.DeviceMapping, 0) + for _, v := range devices { + resourceDevices = append(resourceDevices, container.DeviceMapping{ + PathOnHost: v, + PathInContainer: v, + CgroupPermissions: "rwm", + }) + } + return container.Resources{Devices: resourceDevices} +} + +func (dm *DockerManager) ContainerStop(containerIdOrName string) error { + defer dm.flushContainerMap() + if err := dm.cli.ContainerStop(dm.ctx, containerIdOrName, &dm.timeout); err != nil { + dm.lc.Infof("ContainerStop fail %v", err.Error()) + killErr := dm.cli.ContainerKill(dm.ctx, containerIdOrName, "SIGKILL") + if killErr != nil { + dm.lc.Infof("ContainerKill fail %v", err.Error()) + } + return nil + } + return nil +} + +// 默认为容器强制删除 +func (dm *DockerManager) ContainerRemove(containerIdOrName string) error { + dm.flushContainerMap() + if err := dm.cli.ContainerRemove(dm.ctx, containerIdOrName, types.ContainerRemoveOptions{Force: true}); err != nil { + dm.lc.Infof("ContainerRemove fail containerId: %s, err: %v", containerIdOrName, err.Error()) + // 先不用抛出错误 + return nil + } + dm.flushContainerMap() + return nil +} + +func (dm *DockerManager) ImageRemove(imageId string) error { + if imageId == "" { + return nil + } + dm.lc.Infof("doing remove imageId %s", imageId) + // 错误只做日志,不做抛出,以免影响后续操作 + dm.flushImageMap() + if _, ok := dm.ImageMap[imageId]; !ok { + dm.lc.Infof("remove imageId %s is not exist", imageId) + return nil + } + if _, err := dm.cli.ImageRemove(dm.ctx, imageId, types.ImageRemoveOptions{}); err != nil { + dm.lc.Infof("ImageRemove imageId %s fail %v", imageId, err.Error()) + return nil + } + dm.flushImageMap() + return nil +} + +func (dm *DockerManager) GetImageList() ([]types.ImageSummary, error) { + return dm.cli.ImageList(dm.ctx, types.ImageListOptions{ + All: true, + }) +} + +func (dm *DockerManager) GetContainerList() (containers []types.Container, err error) { + return dm.cli.ContainerList(dm.ctx, types.ContainerListOptions{ + All: true, + }) +} + +// 获取容器运行状态, 目前不做任何错误处理 +func (dm *DockerManager) GetContainerRunStatus(containerName string) (status string, err error) { + if len(containerName) == 0 { + return constants.ContainerRunStatusExited, nil + } + dm.mutex.Lock() + defer dm.mutex.Unlock() + if _, ok := dm.ContainerMap[containerName]; !ok { + return constants.ContainerRunStatusExited, nil + } + return dm.ContainerMap[containerName].State, nil +} + +func (dm *DockerManager) GetContainerIp(containerId string) (ip string, err error) { + ip = constants.HostAddress + res, err := dm.cli.ContainerInspect(dm.ctx, containerId) + if err != nil { + return + } + dm.lc.Infof("Container Networks %+v", res.NetworkSettings.Networks) + if _, ok := res.NetworkSettings.Networks["bridge"]; ok { + return res.NetworkSettings.Networks["bridge"].IPAddress, nil + } + //dm.lc.Infof("GetContainerIp fail networks %v", res.NetworkSettings.Networks) + return ip, err +} + +// 获取镜像id是否存在 +func (dm *DockerManager) ExistImageById(imageId string) bool { + if len(imageId) == 0 { + return false + } + if len(dm.checkGetImageId(imageId)) >= 0 { + return true + } + return false +} + +// 获取容器id +func (dm *DockerManager) getContainerIdByName(containerName string) string { + if len(containerName) == 0 { + return "" + } + for _, v := range dm.ContainerMap { + if v.Name == "/"+containerName { + return v.Id + } + } + return "" +} + +// 获取容器挂载设备的信息,返回挂载设备路径slice ["/dev/ttyUSB0","/dev/ttyUSB0"] +func (dm *DockerManager) GetContainerMountDevices(containerId string) []string { + resDevices := make([]string, 0) + if _, ok := dm.ContainerMap[containerId]; !ok { + dm.lc.Errorf("containerId is %s not exist", containerId) + return []string{} + } + res, err := dm.cli.ContainerInspect(dm.ctx, containerId) + if err != nil { + dm.lc.Errorf("containerInspect err %v", err) + return []string{} + } + for _, v := range res.HostConfig.Devices { + resDevices = append(resDevices, v.PathOnHost) + } + return resDevices +} + +func (dm *DockerManager) PullDockerImage(imageUrl string, authToken string) (string, error) { + var resp io.ReadCloser + var err error + var dockerImageId string + + // 每次pull都去重新刷新组装token + dm.lc.Debugf("authToken len %d", len(authToken)) + resp, err = dm.cli.ImagePull(dm.ctx, imageUrl, types.ImagePullOptions{ + RegistryAuth: authToken, + }) + if err != nil { + dm.lc.Errorf("url: %s ImagePull err: %+v", imageUrl, err) + err = errort.NewCommonErr(errort.DeviceLibraryImageDownloadFail, err) + return dockerImageId, err + } + + readResp, err := ioutil.ReadAll(resp) + if err != nil { + dm.lc.Errorf("url: %s ImagePull err: %+v", imageUrl, err) + return dockerImageId, err + } + dm.lc.Infof("readResp imageUrl %s, %+v", imageUrl, string(readResp)) + re, err := regexp.Compile(`Digest: (\w+:\w+)`) + if err != nil { + dm.lc.Errorf("regexp Compile err %v", err) + return dockerImageId, err + } + strSubMatch := re.FindStringSubmatch(string(readResp)) + if len(strSubMatch) < 2 { + dm.lc.Errorf("regexp not match imagesId") + return dockerImageId, errort.NewCommonEdgeX(errort.DeviceLibraryImageDownloadFail, "regexp not match imagesId", nil) + } + + dockerImageId = dm.checkGetImageId(imageUrl) + + if dockerImageId == "" { + return "", errort.NewCommonEdgeX(errort.DeviceLibraryImageDownloadFail, "docker images is null", nil) + } + + dm.lc.Infof("images pull success imageId: %s", dockerImageId) + return dockerImageId, nil +} + +func (dm *DockerManager) checkGetImageId(imageUrl string) string { + dm.flushImageMap() + for imageId, v := range dm.ImageMap { + repoTags := v.RepoTags + if utils.InStringSlice(imageUrl, repoTags) { + return imageId + } + // 补充默认docker url前缀 如:默认dockerhub的nginx下载下来是image是 nginx:1.12.0 那么补充默认后变成 docker.io/nginx:1.12.0 + for i, _ := range dm.defaultRegistries { + for j, _ := range repoTags { + repoTags[j] = dm.defaultRegistries[i] + "/" + repoTags[j] + } + if utils.InStringSlice(imageUrl, repoTags) { + return imageId + } + } + } + return "" +} + +func (dm *DockerManager) ContainerIsExist(containerIdOrName string) bool { + _, err := dm.cli.ContainerInspect(dm.ctx, containerIdOrName) + if err != nil { + return false + } + return true +} +func (dm *DockerManager) ContainerRename(originNameOrId string, nowName string) bool { + err := dm.cli.ContainerRename(dm.ctx, originNameOrId, nowName) + if err != nil { + dm.lc.Errorf("ContainerRename %v", err) + return false + } + return true +} + +// Deprecated: docker api 的flush接口返回的值无效 +func (dm *DockerManager) FlushAuthToken(username string, password string, serverAddress string) { + dm.lc.Debugf("docker token from serverAddress %s, username %s", serverAddress, username) + // docker api登陆结果的IdentityToken为空,这里的token为明文组装后base64的值 详见https://github.com/moby/moby/issues/38830 + dm.authToken = base64.StdEncoding.EncodeToString([]byte(`{"username":"` + username + `", "password": "` + password + `", "serveraddress": "` + serverAddress + `"}`)) +} + +func (dm *DockerManager) GetAuthToken(username string, password string, serverAddress string) string { + token := base64.StdEncoding.EncodeToString([]byte(`{"username":"` + username + `", "password": "` + password + `", "serveraddress": "` + serverAddress + `"}`)) + dm.lc.Debugf("docker token from serverAddress %s, username %s token %s", serverAddress, username, token) + return token +} + +// 自定义docker启动参数解析 +func (dm *DockerManager) ParseCustomParams(cmd string) (CustomParams, error) { + runMode := dm.dcm.DockerManageConfig.DockerRunMode + if !utils.InStringSlice(runMode, []string{ + constants.NetworkModeHost, constants.NetworkModeBridge, + }) { + runMode = constants.NetworkModeHost + } + params := CustomParams{ + runtime: "", + env: []string{}, + net: runMode, + } + if cmd == "" { + return params, nil + } + cmd = strings.Replace(cmd, "\n", "", -1) + strArr := strings.Split(cmd, "\\") + args := make([]string, 0) + for _, v := range strArr { + args = append(args, strings.Split(v, " ")...) + } + f := flag.NewFlagSet("edge-flag", flag.ContinueOnError) + f.StringVarP(¶ms.runtime, "runtime", "", "", "") + f.StringVarP(¶ms.net, "net", "", runMode, "") + f.StringArrayVarP(¶ms.env, "env", "e", []string{}, "") + f.StringArrayVarP(¶ms.mnt, "mnt", "v", []string{}, "") + err := f.Parse(args) + if err != nil { + return CustomParams{}, errort.NewCommonErr(errort.DockerParamsParseErr, fmt.Errorf("parse docker params err:%v", err)) + } + + // 内部限制只支持host和bridge两种模式 + if !utils.InStringSlice(params.net, []string{ + constants.NetworkModeHost, constants.NetworkModeBridge, + }) { + params.net = runMode + } + return params, nil +} + +// 自定义docker启动参数解析 +func (dm *DockerManager) ParseCustomParamsIsRunBridge(cmd string) (bool, error) { + params, err := dm.ParseCustomParams(cmd) + if err != nil { + return false, err + } + return container.NetworkMode(params.net).IsBridge(), nil +} + +func (dm *DockerManager) GetAllImagesIds() []string { + ids := make([]string, 0) + tmpMap := dm.ImageMap + for id, _ := range tmpMap { + ids = append(ids, id) + } + return ids +} + +func (dm *DockerManager) GetContainerInspect(containerName string) (types.ContainerJSON, error) { + return dm.cli.ContainerInspect(dm.ctx, containerName) +} diff --git a/internal/hummingbird/core/application/docapp/docsapp.go b/internal/hummingbird/core/application/docapp/docsapp.go new file mode 100644 index 0000000..73d53f2 --- /dev/null +++ b/internal/hummingbird/core/application/docapp/docsapp.go @@ -0,0 +1,100 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package docapp + +import ( + "context" + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" +) + +type docApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func (m docApp) SyncDocs(ctx context.Context, versionName string) (int64, error) { + filePath := versionName + "/doc.json" + cosApp := resourceContainer.CosAppNameFrom(m.dic.Get) + bs, err := cosApp.Get(filePath) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + var cosDocTemplateResponse []dtos.CosDocTemplateResponse + err = json.Unmarshal(bs, &cosDocTemplateResponse) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + + baseQuery := dtos.BaseSearchConditionQuery{ + IsAll: true, + } + dbreq := dtos.DocsSearchQueryRequest{BaseSearchConditionQuery: baseQuery} + docs, _, err := m.dbClient.DocsSearch(0, -1, dbreq) + if err != nil { + return 0, err + } + + upsertCosTemplate := make([]models.Doc, 0) + for _, cosDocsTemplate := range cosDocTemplateResponse { + var find bool + for _, localDocsResponse := range docs { + if cosDocsTemplate.Name == localDocsResponse.Name { + upsertCosTemplate = append(upsertCosTemplate, models.Doc{ + Id: localDocsResponse.Id, + Name: cosDocsTemplate.Name, + Sort: cosDocsTemplate.Sort, + JumpLink: cosDocsTemplate.JumpLink, + }) + find = true + break + } + } + if !find { + upsertCosTemplate = append(upsertCosTemplate, models.Doc{ + Id: utils.RandomNum(), + Name: cosDocsTemplate.Name, + Sort: cosDocsTemplate.Sort, + JumpLink: cosDocsTemplate.JumpLink, + }) + } + } + rows, err := m.dbClient.BatchUpsertDocsTemplate(upsertCosTemplate) + if err != nil { + return 0, err + } + return rows, nil +} + +func NewDocsApp(ctx context.Context, dic *di.Container) interfaces.DocsApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + return &docApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} diff --git a/internal/hummingbird/core/application/driverapp/downconfig.go b/internal/hummingbird/core/application/driverapp/downconfig.go new file mode 100644 index 0000000..714d0e3 --- /dev/null +++ b/internal/hummingbird/core/application/driverapp/downconfig.go @@ -0,0 +1,118 @@ +package driverapp + +import ( + "context" + "github.com/google/uuid" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + //"gitlab.com/tedge/edgex/internal/pkg/utils" + // + //"github.com/google/uuid" + //"gitlab.com/tedge/edgex/internal/dtos" + //"gitlab.com/tedge/edgex/internal/models" + //"gitlab.com/tedge/edgex/internal/pkg/errort" + //resourceContainer "gitlab.com/tedge/edgex/internal/tedge/resource/container" +) + +// 配置镜像/驱动的账号密码仓库地址, 账号密码可为空 +func (app *driverLibApp) DownConfigAdd(ctx context.Context, req dtos.DockerConfigAddRequest) error { + dc := models.DockerConfig{ + Id: req.Id, + Address: req.Address, + } + var err error + // 账号密码为空的情况 + if req.Account == "" || req.Password == "" { + dc.Account = "" + dc.Password = "" + } else { + dc.Account = req.Account + dc.SaltKey = generateSaltKey() + dc.Password, err = utils.EncryptAuthPassword(req.Password, dc.SaltKey) + if err != nil { + return err + } + } + return app.DownConfigInternalAdd(dc) +} + +func (app *driverLibApp) DownConfigInternalAdd(dc models.DockerConfig) error { + _, err := app.dbClient.DockerConfigAdd(dc) + if err != nil { + return err + } + return nil +} + +func (app *driverLibApp) DownConfigUpdate(ctx context.Context, req dtos.DockerConfigUpdateRequest) error { + dbClient := container.DBClientFrom(app.dic.Get) + + if req.Id == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update req id is required", nil) + } + + dc, edgeXErr := dbClient.DockerConfigById(req.Id) + if edgeXErr != nil { + return edgeXErr + } + + dtos.ReplaceDockerConfigModelFieldsWithDTO(&dc, req) + + if *req.Password != "" { + var err error + dc.SaltKey = generateSaltKey() + dc.Password, err = utils.EncryptAuthPassword(dc.Password, dc.SaltKey) + if err != nil { + return err + } + } + edgeXErr = dbClient.DockerConfigUpdate(dc) + if edgeXErr != nil { + return edgeXErr + } + return nil +} + +func (app *driverLibApp) DownConfigSearch(ctx context.Context, req dtos.DockerConfigSearchQueryRequest) ([]models.DockerConfig, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + dcs, total, err := app.dbClient.DockerConfigsSearch(offset, limit, req) + if err != nil { + return dcs, 0, err + } + + return dcs, total, nil +} + +func (app *driverLibApp) DownConfigDel(ctx context.Context, id string) error { + dc, edgeXErr := app.dbClient.DockerConfigById(id) + if edgeXErr != nil { + return edgeXErr + } + + // 判断 此配置是否被 library使用 + _, total, err := app.DeviceLibrariesSearch(ctx, dtos.DeviceLibrarySearchQueryRequest{ + DockerConfigId: id, + }) + if err != nil { + return edgeXErr + } + + if total > 0 { + return errort.NewCommonEdgeX(errort.DockerConfigMustDeleteDeviceLibrary, "请先删除绑定此配置的驱动", nil) + } + + err = app.dbClient.DockerConfigDelete(dc.Id) + if err != nil { + return err + } + return nil +} + +// 生成password的salt key +func generateSaltKey() string { + return uuid.New().String() +} diff --git a/internal/hummingbird/core/application/driverapp/driverapp.go b/internal/hummingbird/core/application/driverapp/driverapp.go new file mode 100644 index 0000000..0b898b6 --- /dev/null +++ b/internal/hummingbird/core/application/driverapp/driverapp.go @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverapp + +import ( + "context" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + pkgcontainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" +) + +type driverLibApp struct { + dic *di.Container + lc logger.LoggingClient + dbClient interfaces.DBClient + + manager DeviceLibraryManager + market *driverMarket // Refactor: interface +} + +func NewDriverApp(ctx context.Context, dic *di.Container) interfaces.DriverLibApp { + return newDriverLibApp(dic) +} + +func newDriverLibApp(dic *di.Container) *driverLibApp { + app := &driverLibApp{ + dic: dic, + lc: pkgcontainer.LoggingClientFrom(dic.Get), + dbClient: container.DBClientFrom(dic.Get), + } + app.manager = newDriverLibManager(dic, app) + app.market = newDriverMarket(dic, app) + return app +} + +func (app *driverLibApp) getDriverServiceApp() interfaces.DriverServiceApp { + return container.DriverServiceAppFrom(app.dic.Get) +} diff --git a/internal/hummingbird/core/application/driverapp/driverclassify.go b/internal/hummingbird/core/application/driverapp/driverclassify.go new file mode 100644 index 0000000..a5450ed --- /dev/null +++ b/internal/hummingbird/core/application/driverapp/driverclassify.go @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverapp + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" +) + +func (app *driverLibApp) GetDriverClassify(ctx context.Context, req dtos.DriverClassifyQueryRequest) ([]dtos.DriverClassifyResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + dcs, total, err := app.dbClient.DriverClassifySearch(offset, limit, req) + + res := make([]dtos.DriverClassifyResponse, len(dcs)) + if err != nil { + return res, 0, err + } + + for i, dc := range dcs { + res[i] = dtos.DriverClassifyResponse{ + Id: dc.Id, + Name: dc.Name, + } + } + return res, total, nil +} diff --git a/internal/hummingbird/core/application/driverapp/driverlib.go b/internal/hummingbird/core/application/driverapp/driverlib.go new file mode 100644 index 0000000..462eee9 --- /dev/null +++ b/internal/hummingbird/core/application/driverapp/driverlib.go @@ -0,0 +1,212 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverapp + +import ( + "context" + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" +) + +func (app *driverLibApp) AddDriverLib(ctx context.Context, dl dtos.DeviceLibraryAddRequest) error { + + dlm := app.addOrUpdateSupportVersionConfig(dtos.FromDeviceLibraryRpcToModel(&dl)) + _, err := app.createDriverLib(dlm) + if err != nil { + return err + } + return nil +} + +func (app *driverLibApp) createDriverLib(dl models.DeviceLibrary) (models.DeviceLibrary, error) { + return app.dbClient.AddDeviceLibrary(dl) +} + +func (app *driverLibApp) addOrUpdateSupportVersionConfig(dl models.DeviceLibrary) models.DeviceLibrary { + for _, sv := range dl.SupportVersions { + if sv.Version == dl.Version { + //dl.SupportVersions[i].ConfigJson = dl.Config + return dl + } + } + if dl.Version == "" { + return dl + } + + // add + version := models.SupportVersion{ + Version: dl.Version, + } + dl.SupportVersions = []models.SupportVersion{version} + return dl +} + +func (app *driverLibApp) DeviceLibrariesSearch(ctx context.Context, req dtos.DeviceLibrarySearchQueryRequest) ([]models.DeviceLibrary, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + installingIds := app.manager.FilterState(constants.OperateStatusInstalling) + if req.DownloadStatus == constants.OperateStatusInstalling && len(installingIds) == 0 { + app.lc.Infof("deviceLibSearch install status is empty, req: %+v", req) + return []models.DeviceLibrary{}, 0, nil + } + + req = app.prepareSearch(req, installingIds) + + deviceLibraries, total, err := app.dbClient.DeviceLibrariesSearch(offset, limit, req) + if err != nil { + return deviceLibraries, 0, err + } + + dlIds := make([]string, 0) + dlMapDsExist := make(map[string]bool) + for _, v := range deviceLibraries { + dlIds = append(dlIds, v.Id) + dlMapDsExist[v.Id] = false + } + dss, _, err := app.dbClient.DeviceServicesSearch(0, -1, dtos.DeviceServiceSearchQueryRequest{DeviceLibraryIds: dtos.ApiParamsArrayToString(dlIds)}) + if err != nil { + return deviceLibraries, 0, err + } + for _, v := range dss { + dlMapDsExist[v.DeviceLibraryId] = true + } + + // 设置驱动库状态 + for i, dl := range deviceLibraries { + stats := app.manager.GetState(dl.Id) + if stats == constants.OperateStatusInstalling { + deviceLibraries[i].OperateStatus = constants.OperateStatusInstalling + } else if app.manager.ExistImage(dl.DockerImageId) && dlMapDsExist[dl.Id] { + deviceLibraries[i].OperateStatus = constants.OperateStatusInstalled + } else { + deviceLibraries[i].OperateStatus = constants.OperateStatusDefault + deviceLibraries[i].DockerImageId = "" + } + } + + return deviceLibraries, total, nil +} + +func (app *driverLibApp) prepareSearch(req dtos.DeviceLibrarySearchQueryRequest, installingIds []string) dtos.DeviceLibrarySearchQueryRequest { + // 处理驱动安装状态 + existImages := app.manager.GetAllImages() + if req.DownloadStatus == constants.OperateStatusInstalling { + req.Ids = dtos.ApiParamsArrayToString(installingIds) + } else if req.DownloadStatus == constants.OperateStatusInstalled { + req.NoInIds = dtos.ApiParamsArrayToString(installingIds) + req.ImageIds = dtos.ApiParamsArrayToString(existImages) + } else if req.DownloadStatus == constants.OperateStatusUninstall || req.DownloadStatus == constants.OperateStatusDefault { + req.NoInIds = dtos.ApiParamsArrayToString(installingIds) + req.NoInImageIds = dtos.ApiParamsArrayToString(existImages) + } + + return req +} + +func (app *driverLibApp) DeleteDeviceLibraryById(ctx context.Context, id string) error { + dl, err := app.dbClient.DeviceLibraryById(id) + if err != nil { + return err + } + + // 内置驱动市场不允许删除 + if dl.IsInternal { + return errort.NewCommonErr(errort.DeviceLibraryNotAllowDelete, fmt.Errorf("internal library not allow delete")) + } + + // 删除驱动前需要查看 驱动所属驱动实例是否存在 + _, total, edgeXErr := app.getDriverServiceApp().Search(ctx, dtos.DeviceServiceSearchQueryRequest{DeviceLibraryId: id}) + if edgeXErr != nil { + return edgeXErr + } + if total > 0 { + return errort.NewCommonErr(errort.DeviceLibraryMustDeleteDeviceService, fmt.Errorf("must delete service")) + } + + // 删除驱动前需要查看 驱动所属的产品是否存在 + //_, total, edgeXErr = app.dbClient.ProductsSearch(0, 1, dtos.ProductSearchQueryRequest{ + // DeviceLibraryId: id, + //}) + //if edgeXErr != nil { + // return edgeXErr + //} + //if total > 0 { + // return errort.NewCommonErr(errort.DeviceLibraryMustDeleteProduct, fmt.Errorf("must delete product")) + //} + + app.manager.Remove(id) + + return nil +} + +func (app *driverLibApp) DriverLibById(dlId string) (models.DeviceLibrary, error) { + dl, err := app.dbClient.DeviceLibraryById(dlId) + if err != nil { + app.lc.Errorf("DriverLibById req DeviceLibraryById(%s) err %v", dlId, err) + return models.DeviceLibrary{}, err + } + return dl, nil +} + +func (app *driverLibApp) DeviceLibraryById(ctx context.Context, id string) (models.DeviceLibrary, error) { + dl, err := app.DriverLibById(id) + if err != nil { + return models.DeviceLibrary{}, err + } + dl.OperateStatus = app.manager.GetState(id) + + return dl, nil +} + +// UpgradeDeviceLibrary 升级驱动库版本 +func (app *driverLibApp) UpgradeDeviceLibrary(ctx context.Context, req dtos.DeviceLibraryUpgradeRequest) error { + if app.manager.GetState(req.Id) == constants.OperateStatusInstalling { + return errort.NewCommonErr(errort.DeviceLibraryUpgradeIng, fmt.Errorf("is upgrading, please wait")) + } + + //检查驱动是否存在 + _, edgeXErr := app.DriverLibById(req.Id) + if edgeXErr != nil { + return edgeXErr + } + + if err := app.manager.Upgrade(req.Id, req.Version); err != nil { + return err + } + + return nil +} + +func (app *driverLibApp) UpdateDeviceLibrary(ctx context.Context, update dtos.UpdateDeviceLibrary) error { + dl, dbErr := app.dbClient.DeviceLibraryById(update.Id) + if dbErr != nil { + return dbErr + } + + dtos.ReplaceDeviceLibraryModelFieldsWithDTO(&dl, update) + + dl = app.addOrUpdateSupportVersionConfig(dl) + + if err := app.dbClient.UpdateDeviceLibrary(dl); err != nil { + return err + } + + app.lc.Infof("deviceLibrary(%v) update succ. ", dl.Id) + + return nil +} diff --git a/internal/hummingbird/core/application/driverapp/manage.go b/internal/hummingbird/core/application/driverapp/manage.go new file mode 100644 index 0000000..771153b --- /dev/null +++ b/internal/hummingbird/core/application/driverapp/manage.go @@ -0,0 +1,230 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverapp + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + pkgcontainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "sync" +) + +type DeviceLibraryManager interface { + GetState(dlId string) string + SetState(dlId, state string) + FilterState(state string) []string + Upgrade(dlId, version string) error + Remove(dlId string) error + // + GetAllImages() []string // 获取所有安装镜像ID + ExistImage(dockerImageId string) bool +} + +type deviceLibraryManager struct { + libs sync.Map + + dic *di.Container + lc logger.LoggingClient + + driverApp interfaces.DriverLibApp + appModel interfaces.DMI +} + +func newDriverLibManager(dic *di.Container, app interfaces.DriverLibApp) *deviceLibraryManager { + return &deviceLibraryManager{ + libs: sync.Map{}, + dic: dic, + lc: pkgcontainer.LoggingClientFrom(dic.Get), + appModel: interfaces.DMIFrom(dic.Get), + driverApp: app, + } +} + +func (m *deviceLibraryManager) GetState(dlId string) string { + state, ok := m.libs.Load(dlId) + if ok { + return state.(string) + } + m.libs.Store(dlId, constants.OperateStatusDefault) + return constants.OperateStatusDefault +} + +func (m *deviceLibraryManager) SetState(dlId, state string) { + m.libs.Store(dlId, state) +} + +func (m *deviceLibraryManager) FilterState(state string) []string { + var list []string + m.libs.Range(func(key, value interface{}) bool { + if value.(string) == state { + list = append(list, key.(string)) + } + return true + }) + + return list +} + +func (m *deviceLibraryManager) Remove(dlId string) error { + dbClient := container.DBClientFrom(m.dic.Get) + dl, err := dbClient.DeviceLibraryById(dlId) + if err != nil { + return err + } + + // 删除自定义驱动 + if err := dbClient.DeleteDeviceLibraryById(dlId); err != nil { + return err + } + + m.asyncRemoveImage(dl) + return nil +} + +func (m *deviceLibraryManager) GetAllImages() []string { + return m.appModel.GetAllApp() +} + +func (m *deviceLibraryManager) ExistImage(dockerImageId string) bool { + return m.appModel.StateApp(dockerImageId) +} + +// updateDL 下载新版本镜像,并更新驱动库版本信息 +func (m *deviceLibraryManager) updateDL(dlId, updateVersion string) error { + // 获取驱动库信息 Refactor: + dl, dc, err := m.driverApp.GetDeviceLibraryAndMirrorConfig(dlId) + if err != nil { + m.SetState(dlId, constants.OperateStatusDefault) + return err + } + + imageId, err := m.downloadVersion(dl, dc, updateVersion) + if err != nil { + return err + } + + old := dl + // 添加驱动库版本 + dl.DockerImageId = imageId + newVersionInfo, isExistVersion := getNewSupportVersion(dl.SupportVersions, dl.Version, updateVersion) + if !isExistVersion { + dl.SupportVersions = append(dl.SupportVersions, newVersionInfo) + } + dl = m.updateDLDefaultVersion(dl, newVersionInfo) + + dbClient := resourceContainer.DBClientFrom(m.dic.Get) + if err := dbClient.UpdateDeviceLibrary(dl); err != nil { + m.lc.Errorf("updateDeviceLibrary %s fail %+v", dl.Id, err) + return err + } + + m.cleanOldVersion(old, dl) + + return nil +} + +// cleanOldVersion 异步清理旧版本镜像 +func (m *deviceLibraryManager) cleanOldVersion(oldDL, dl models.DeviceLibrary) { + if oldDL.DockerImageId == dl.DockerImageId { + return + } + m.asyncRemoveImage(oldDL) +} + +func (m *deviceLibraryManager) asyncRemoveImage(dl models.DeviceLibrary) { + // 镜像删除 + go m.appModel.RemoveApp(dtos.DeviceLibraryFromModel(dl)) +} + +func getNewSupportVersion(versions models.SupportVersions, curVersion, newVersion string) (models.SupportVersion, bool) { + newVersionInfo := models.SupportVersion{} + for _, v := range versions { + if v.Version == newVersion { + return v, true + } + // 将老版本的配置复制到新版本中 + if v.Version == curVersion { + newVersionInfo = v + } + } + + newVersionInfo.Version = newVersion + return newVersionInfo, false +} + +func (m *deviceLibraryManager) updateDLDefaultVersion(dl models.DeviceLibrary, newVersion models.SupportVersion) models.DeviceLibrary { + // 2.驱动库配置更新为新版本的 + //dl.Config = newVersion.ConfigJson + //if dl.Config == "" { + // dl.Config = dtos.GetLibrarySimpleBaseConfig() + //} + //dl.ConfigFile = newVersion.ConfigFile + dl.Version = newVersion.Version + return dl +} + +func (m *deviceLibraryManager) downloadVersion(dl models.DeviceLibrary, dc models.DockerConfig, version string) (string, error) { + // 3.下载应用 + imageId, err := m.appModel.DownApp(dtos.DockerConfigFromModel(dc), dtos.DeviceLibraryFromModel(dl), version) + if err != nil { + return "", err + } + return imageId, nil +} + +func (m *deviceLibraryManager) Upgrade(dlId, updateVersion string) error { + if m.GetState(dlId) == constants.OperateStatusInstalling { + return errort.NewCommonErr(errort.DeviceLibraryUpgradeIng, fmt.Errorf("device library upgradeing")) + } + + // 1. 设置为升级中 + m.SetState(dlId, constants.OperateStatusInstalling) + m.lc.Infof("1.start updateDeviceLibrary %v to version %v", dlId, updateVersion) + + // 下载新版本镜像,并更新驱动库版本信息 + if err := m.updateDL(dlId, updateVersion); err != nil { + m.SetState(dlId, constants.OperateStatusDefault) + m.lc.Errorf("updateDeviceLibrary version fail", err) + return err + } + + m.SetState(dlId, constants.OperateStatusInstalled) + //m.lc.Infof("2.updateDeviceLibrary %v version %v", dlId, updateVersion) + + // 3. 升级驱动实例 + if err := m.upgradeDeviceService(dlId); err != nil { + m.lc.Errorf("3.upgradeDeviceService error %v", err) + return err + } + + return nil +} + +func (m *deviceLibraryManager) upgradeDeviceService(dlId string) error { + dl, err := m.driverApp.DriverLibById(dlId) + if err != nil { + return err + } + + return resourceContainer.DriverServiceAppFrom(m.dic.Get).Upgrade(dl) +} diff --git a/internal/hummingbird/core/application/driverapp/market.go b/internal/hummingbird/core/application/driverapp/market.go new file mode 100644 index 0000000..0fbcbc9 --- /dev/null +++ b/internal/hummingbird/core/application/driverapp/market.go @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverapp + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" +) + +type driverMarket struct { + dic *di.Container + lc logger.LoggingClient + driverApp interfaces.DriverLibApp +} + +func newDriverMarket(dic *di.Container, app interfaces.DriverLibApp) *driverMarket { + return &driverMarket{ + dic: dic, + lc: container.LoggingClientFrom(dic.Get), + driverApp: app, + } +} \ No newline at end of file diff --git a/internal/hummingbird/core/application/driverapp/mirror.go b/internal/hummingbird/core/application/driverapp/mirror.go new file mode 100644 index 0000000..749ac61 --- /dev/null +++ b/internal/hummingbird/core/application/driverapp/mirror.go @@ -0,0 +1,21 @@ +package driverapp + +import "github.com/winc-link/hummingbird/internal/models" + +func (app *driverLibApp) GetDeviceLibraryAndMirrorConfig(dlId string) (dl models.DeviceLibrary, dc models.DockerConfig, err error) { + // 1. 获取驱动库 + dl, err = app.DriverLibById(dlId) + if err != nil { + app.lc.Errorf("1.DeviceLibraryOperate device library id:%s, err:%n", dlId, err) + return + } + + // 2.获取docker仓库配置 + dc, err = app.dbClient.DockerConfigById(dl.DockerConfigId) + if err != nil { + app.lc.Errorf("2.DeviceLibraryOperate docker hub, id:%s, DockerConfigId:%s, err:%v", dlId, dl.DockerConfigId, err) + return + } + + return +} diff --git a/internal/hummingbird/core/application/driverserviceapp/driverserviceapp.go b/internal/hummingbird/core/application/driverserviceapp/driverserviceapp.go new file mode 100644 index 0000000..1d119f9 --- /dev/null +++ b/internal/hummingbird/core/application/driverserviceapp/driverserviceapp.go @@ -0,0 +1,25 @@ +package driverserviceapp + +import ( + "context" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" +) + +type driverServiceApp struct { + dic *di.Container + lc logger.LoggingClient + + *driverServiceAppM +} + +func NewDriverServiceApp(ctx context.Context, dic *di.Container) interfaces.DriverServiceApp { + return &driverServiceApp{ + dic: dic, + lc: container.LoggingClientFrom(dic.Get), + + driverServiceAppM: newDriverServiceApp(ctx, dic), + } +} diff --git a/internal/hummingbird/core/application/driverserviceapp/servicecfgbase.go b/internal/hummingbird/core/application/driverserviceapp/servicecfgbase.go new file mode 100644 index 0000000..8bd173e --- /dev/null +++ b/internal/hummingbird/core/application/driverserviceapp/servicecfgbase.go @@ -0,0 +1,40 @@ +package driverserviceapp + +import ( + //"gitlab.com/tedge/edgex/internal/models" + //"gitlab.com/tedge/edgex/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/models" +) + +// 驱动运行的配置模版 +func getDriverConfigTemplate(ds models.DeviceService) string { + + return getDefaultDriverConfig() +} + +func getDefaultDriverConfig() string { + return `[Logger] +FileName = "/mnt/logs/driver.log" +LogLevel = "INFO" # DEBUG INFO WARN ERROR + +[Clients] +[Clients.Core] +Address = "hummingbird-core:57081" +UseTLS = false +CertFilePath = "" + +[Service] +ID = "" +Name = "" +ProductList = [] +GwId = "" +LocalKey = "" +Activated = false +[Service.Server] +Address = "0.0.0.0:49991" +UseTLS = false +CertFile = "" +KeyFile = "" + +[CustomConfig]` +} diff --git a/internal/hummingbird/core/application/driverserviceapp/servicemanager.go b/internal/hummingbird/core/application/driverserviceapp/servicemanager.go new file mode 100644 index 0000000..24f1525 --- /dev/null +++ b/internal/hummingbird/core/application/driverserviceapp/servicemanager.go @@ -0,0 +1,830 @@ +package driverserviceapp + +import ( + "bytes" + "context" + "fmt" + "github.com/BurntSushi/toml" + pkgerr "github.com/pkg/errors" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + bootstrapContainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "gorm.io/gorm" + "net" + "reflect" + "strconv" + "strings" + + "sync" +) + +// 驱动实例管理 +func newDriverServiceApp(ctx context.Context, dic *di.Container) *driverServiceAppM { + dsManager := &driverServiceAppM{ + state: sync.Map{}, + dic: dic, + lc: bootstrapContainer.LoggingClientFrom(dic.Get), + dbClient: container.DBClientFrom(dic.Get), + ctx: ctx, + appModel: interfaces.DMIFrom(dic.Get), + dsMonitor: make(map[string]*DeviceServiceMonitor), + } + //dsManager.FlushStatsToAgent() + dsManager.initMonitor() + return dsManager +} + +// +type driverServiceAppM struct { + state sync.Map + dic *di.Container + lc logger.LoggingClient + ctx context.Context // Bootstrap init 启动传入的, 用来处理done数据 + + // interfaces + dbClient interfaces.DBClient + appModel interfaces.DMI + + dsMonitor map[string]*DeviceServiceMonitor +} + +func (m *driverServiceAppM) getDriverApp() interfaces.DriverLibApp { + return container.DriverAppFrom(m.dic.Get) +} + +func (m *driverServiceAppM) GetState(id string) int { + state, ok := m.state.Load(id) + if ok { + return state.(int) + } + m.state.Store(id, constants.RunStatusStopped) + + return constants.RunStatusStopped +} + +func (m *driverServiceAppM) SetState(id string, state int) { + m.state.Store(id, state) +} + +func (m *driverServiceAppM) Start(id string) error { + var err error + defer func() { + if err != nil { + m.SetState(id, constants.RunStatusStopped) + } + }() + + if m.InProgress(id) { + return fmt.Errorf("that id(%s) is staring or stopping, do not to start", id) + } + + ds, err := m.Get(context.Background(), id) + if err != nil { + return err + } + dl, err := m.getDriverApp().DriverLibById(ds.DeviceLibraryId) + if err != nil { + return err + } + + driverRunPort, err := utils.GetAvailablePort(ds.GetPort()) + if err != nil { + return errort.NewCommonErr(errort.CreateConfigFileFail, fmt.Errorf("create cofig file faild %w", err)) + } + + // 获取自身服务运行的ip,并组装运行启动的配置 + runConfig, err := m.buildServiceRunCfg(m.appModel.GetSelfIp(), driverRunPort, ds) + if err != nil { + return errort.NewCommonErr(errort.GetAvailablePortFail, fmt.Errorf("get available port fail")) + + } + dtoDs := dtos.DeviceServiceFromModel(ds) + dtoRunCfg := dtos.RunServiceCfg{ + ImageRepo: dl.DockerImageId, + RunConfig: runConfig, + DockerParams: ds.DockerParams, + DriverName: dl.Name, + } + m.SetState(id, constants.RunStatusStarting) + _, err = m.appModel.StartInstance(dtoDs, dtoRunCfg) + if err != nil { + return err + } + m.SetState(id, constants.RunStatusStarted) + + //重新刷新数据 + ds, err = m.Get(context.Background(), id) + if err != nil { + return err + } + + oldBaseAddress := ds.BaseAddress + // 更新驱动服务数据 + ds.BaseAddress = ds.ContainerName + ":" + strconv.Itoa(driverRunPort) + // 更新监控ds 如果不更新ping 定时检测会失效 + if oldBaseAddress != ds.BaseAddress { + err = m.dbClient.UpdateDeviceService(ds) + if err != nil { + return err + } + if _, ok := m.dsMonitor[ds.Id]; ok { + m.dsMonitor[ds.Id].ds = dtos.DeviceServiceFromModel(ds) + } + } + + return nil +} + +func (m *driverServiceAppM) Stop(id string) error { + ds, err := m.Get(context.Background(), id) + if err != nil { + return err + } + + m.SetState(id, constants.RunStatusStopping) + stopErr := m.appModel.StopInstance(dtos.DeviceServiceFromModel(ds)) + if stopErr != nil { + m.SetState(id, constants.RunStatusStopped) + return errort.NewCommonErr(errort.ContainerStopFail, pkgerr.WithMessage(stopErr, "stop driverService fail")) + } + m.SetState(id, constants.RunStatusStopped) + return nil +} + +func (m *driverServiceAppM) ReStart(id string) error { + err := m.Stop(id) + if err != nil { + return fmt.Errorf("dsId(%v), stop err:%v", id, err) + } + err = m.Start(id) + if err != nil { + return fmt.Errorf("dsId(%v), start err:%v", id, err) + } + return nil +} + +// +func (m *driverServiceAppM) Add(ctx context.Context, ds models.DeviceService) error { + if ds.BaseAddress == "" { + address, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:0", "0.0.0.0")) + if err != nil { + return err + } + port, _ := utils.AvailablePort(address) + ds.BaseAddress = ds.ContainerName + ":" + strconv.Itoa(port) + } + + // 处理专家模式配置 + if ds.ExpertMode && len(ds.ExpertModeContent) > 0 { + tmpKv, err := dtos.FromYamlStrToMap(ds.ExpertModeContent) + if err != nil { + return errort.NewCommonErr(errort.DefaultReqParamsError, fmt.Errorf("parse expertModeContent err:%v", err)) + } + ds.Config[constants.ConfigKeyDriver] = tmpKv + } + + ds, err := m.dbClient.AddDeviceService(ds) + if err != nil { + return err + } + + // 添加后台监控 + if _, ok := m.dsMonitor[ds.Id]; ok { + m.dsMonitor[ds.Id].ds = dtos.DeviceServiceFromModel(ds) + } else { + m.dsMonitor[ds.Id] = NewDeviceServiceMonitor(m.ctx, dtos.DeviceServiceFromModel(ds), m.dic) + } + + //go m.FlushStatsToAgent() + //go m.autoAddDevice(ds) + return nil +} + +// +func (m *driverServiceAppM) Update(ctx context.Context, dto dtos.DeviceServiceUpdateRequest) error { + deviceService, edgeXErr := m.Get(ctx, dto.Id) + if edgeXErr != nil { + return edgeXErr + } + + if m.GetState(dto.Id) == constants.RunStatusStarted { + return errort.NewCommonErr(errort.DeviceServiceMustStopService, fmt.Errorf("service(%v) is running not update", deviceService.Id)) + } + dtos.ReplaceDeviceServiceModelFieldsWithDTO(&deviceService, dto) + edgeXErr = m.dbClient.UpdateDeviceService(deviceService) + if edgeXErr != nil { + return edgeXErr + } + return nil +} + +// +// 升级实例: 如果不存在则创建数据、如果存在,但未运行,不做处理、若运行中则重启 +func (m *driverServiceAppM) Upgrade(dl models.DeviceLibrary) error { + dss, _, err := m.Search(m.ctx, dtos.DeviceServiceSearchQueryRequest{DeviceLibraryId: dl.Id}) + // 不存在则创建 + if len(dss) <= 0 { + version := models.SupportVersion{} + for _, v := range dl.SupportVersions { + if v.Version == dl.Version { + version = v + break + } + } + err = m.Add(m.ctx, models.DeviceService{ + //Id: dsCode, + Name: dl.Name, + DeviceLibraryId: dl.Id, + ExpertMode: version.ExpertMode, + ExpertModeContent: version.ExpertModeContent, + DockerParamsSwitch: version.DockerParamsSwitch, + DockerParams: version.DockerParams, + ContainerName: dl.ContainerName, + Config: make(map[string]interface{}), + }) + if err != nil { + m.lc.Errorf("add device service err:%v", err) + return err + } + return nil + } + + ds := dss[0] + // 存在则 判断是否更新 + if m.GetState(ds.Id) != constants.RunStatusStarted { + return nil + } + + // 重启 + if err = m.Stop(ds.Id); err != nil { + m.lc.Errorf("stop deviceService(%s) err:%v", ds.Id, err) + return err + } + if err = m.Start(ds.Id); err != nil { + m.lc.Errorf("start deviceService(%s) err:%v", ds.Id, err) + return err + } + + return nil +} + +func (m *driverServiceAppM) Search(ctx context.Context, req dtos.DeviceServiceSearchQueryRequest) ([]models.DeviceService, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + deviceServices, total, err := m.dbClient.DeviceServicesSearch(offset, limit, req) + if err != nil { + return deviceServices, 0, err + } + + dlIds := make([]string, 0) + for i, _ := range deviceServices { + dlIds = append(dlIds, deviceServices[i].DeviceLibraryId) + } + dls, _, err := m.getDriverApp().DeviceLibrariesSearch(m.ctx, dtos.DeviceLibrarySearchQueryRequest{ + BaseSearchConditionQuery: dtos.BaseSearchConditionQuery{Ids: dtos.ApiParamsArrayToString(dlIds)}, + }) + + if err != nil { + return deviceServices, 0, err + } + + dlIdMap := make(map[string]models.DeviceLibrary) + for i, _ := range dls { + dlIdMap[dls[i].Id] = dls[i] + } + + for i, v := range deviceServices { + deviceServices[i].RunStatus = m.GetState(v.Id) + if _, ok := dlIdMap[v.DeviceLibraryId]; ok { + deviceServices[i].ImageExist = dlIdMap[v.DeviceLibraryId].OperateStatus == constants.OperateStatusInstalled + } + } + + return deviceServices, total, nil +} + +func (m *driverServiceAppM) Del(ctx context.Context, id string) error { + ds, edgeXErr := m.dbClient.DeviceServiceById(id) + if edgeXErr != nil { + return edgeXErr + } + + // 删除驱动实例前需要查看 实例所属的设备是否存在 + //_, total, edgeXErr := m.getDeviceApp().DevicesSearch(dtos.DeviceSearchQueryRequest{ServiceId: id}) + //if edgeXErr != nil { + // return edgeXErr + //} + //if total > 0 { + // return errort.NewCommonErr(errort.DeviceServiceMustDeleteDevice, fmt.Errorf("must delete device")) + //} + + // 检查容器是否在运行中 + if m.GetState(id) != constants.RunStatusStopped { + return errort.NewCommonErr(errort.DeviceServiceMustStopService, fmt.Errorf("must stop service")) + } + m.dbClient.GetDBInstance().Transaction(func(tx *gorm.DB) error { + err := tx.Model(&models.DeviceService{}).Where("id =?", id).Delete(&models.DeviceService{}).Error + if err != nil { + return err + } + err = tx.Model(&models.Device{}).Where("drive_instance_id = ?", id).Updates(map[string]interface{}{"drive_instance_id": ""}).Error + if err != nil { + return err + } + return nil + }) + + // 删除容器、监控 + err := m.appModel.DeleteInstance(dtos.DeviceServiceFromModel(ds)) + if err != nil { + m.lc.Errorf("DeleteInstance err:%v", err) + } + + // 刷新agent 服务信息 + //go m.FlushStatsToAgent() + // 删除后台监控 + delete(m.dsMonitor, id) + m.state.Delete(id) + return nil +} + +// +func (m *driverServiceAppM) Get(ctx context.Context, id string) (models.DeviceService, error) { + if id == "" { + return models.DeviceService{}, errort.NewCommonErr(errort.DefaultReqParamsError, fmt.Errorf("id(%s) is empty", id)) + } + deviceService, err := m.dbClient.DeviceServiceById(id) + if err != nil { + return deviceService, err + } + deviceService.RunStatus = m.GetState(id) + + dl, _ := m.getDriverApp().DriverLibById(deviceService.DeviceLibraryId) + deviceService.ImageExist = dl.OperateStatus == constants.OperateStatusInstalled + return deviceService, nil +} + +// +func (m *driverServiceAppM) InProgress(id string) bool { + state, ok := m.state.Load(id) + if !ok { + return false + } + if state.(int) == constants.RunStatusStarting || state.(int) == constants.RunStatusStopping { + return true + } + return false +} + +// +//// 监控驱动运行状态 +func (m *driverServiceAppM) initMonitor() { + dbClient := container.DBClientFrom(m.dic.Get) + lc := bootstrapContainer.LoggingClientFrom(m.dic.Get) + ds, _, err := dbClient.DeviceServicesSearch(0, -1, dtos.DeviceServiceSearchQueryRequest{}) + if err != nil { + lc.Errorf("DeviceServicesSearch err %v", err) + return + } + for _, v := range ds { + m.dsMonitor[v.Id] = NewDeviceServiceMonitor(m.ctx, dtos.DeviceServiceFromModel(v), m.dic) + } +} + +// +//func (m *driverServiceAppM) UpdateAdvanceConfig(ctx context.Context, req dtos.UpdateServiceLogLevelConfigRequest) error { +// ds, err := m.Get(ctx, req.Id) +// if err != nil { +// return err +// } +// +// // 更新配置 +// ds.LogLevel = constants.LogLevel(req.LogLevel) +// +// if err = m.dbClient.UpdateDeviceService(ds); err != nil { +// return err +// } +// +// // 通知驱动 +// if !ds.IsRunning() { +// m.lc.Infof("service %s is stop", ds.Id) +// return nil +// } +// if ds.IsDriver() { +// if err = application.DeviceServiceChangeLogLevelCallback(m.ctx, m.dic, ds, m.lc); err != nil { +// return err +// } +// } else { +// if err = application.AppServiceChangeLogLevelCallback(m.ctx, m.dic, ds, m.lc); err != nil { +// return err +// } +// } +// return nil +//} +// +func (m *driverServiceAppM) UpdateRunStatus(ctx context.Context, req dtos.UpdateDeviceServiceRunStatusRequest) error { + // 1.正在处理中,返回错误 + if m.InProgress(req.Id) { + return errort.NewCommonErr(errort.DeviceServiceMustStopDoingService, fmt.Errorf("device service is processing")) + } + + // 2.请求状态和本地状态一致,无需操作 + if req.RunStatus == m.GetState(req.Id) { + m.lc.Infof("driverService state is %d", req.RunStatus) + return nil + } + + _, err := m.Get(ctx, req.Id) + if err != nil { + return err + } + + if req.RunStatus == constants.RunStatusStopped { + if err = m.Stop(req.Id); err != nil { + return err + } + } else if req.RunStatus == constants.RunStatusStarted { + if err = m.Start(req.Id); err != nil { + return err + } + } + + return nil +} + +// 将deviceService里的配置转换到配置文件中然后启动服务 +func (m *driverServiceAppM) buildServiceRunCfg(serviceIp string, runPort int, ds models.DeviceService) (string, error) { + if ds.DriverType == constants.DriverLibTypeDefault { + return m.buildDriverCfg(serviceIp, runPort, ds) + } else if ds.DriverType == constants.DriverLibTypeAppService { + //return m.buildAppCfg(serviceIp, runPort, ds) + } + return "", nil +} + +func (m *driverServiceAppM) buildDriverCfg(localDefaultIp string, runPort int, ds models.DeviceService) (string, error) { + configuration := &dtos.DriverConfig{} + sysConfig := container.ConfigurationFrom(m.dic.Get) + + // 读取模版配置 + if _, err := toml.Decode(getDriverConfigTemplate(ds), configuration); err != nil { + return "", err + } + + // 修改与核心服务通信的ip + for k, v := range configuration.Clients { + if k == "Core" { + data := v + data.Address = strings.Replace(data.Address, "127.0.0.1", localDefaultIp, -1) + data.Address = strings.Split(data.Address, ":")[0] + ":" + strings.Split(sysConfig.RpcServer.Address, ":")[1] + configuration.Clients[k] = data + } + //else if k == "MQTTBroker" { + // driverMqttInfo, err := m.dbClient.DriverMqttAuthInfo(ds.Id) + // if err != nil { + // return "", err + // } + // data := v + // data.Address = strings.Replace(data.Address, "127.0.0.1", localDefaultIp, -1) + // data.Address = strings.Split(data.Address, ":")[0] + ":" + strings.Split(data.Address, ":")[1] + ":" + "21883" + // data.ClientId = driverMqttInfo.ClientId + // data.Username = driverMqttInfo.UserName + // data.Password = driverMqttInfo.Password + // configuration.Clients[k] = data + //} + } + + configuration.Service.ID = ds.Id + configuration.Service.Name = ds.Name + // 驱动服务只开启rpc服务 + configuration.Service.Server.Address = "0.0.0.0:" + strconv.Itoa(runPort) + + if ds.ExpertMode && ds.ExpertModeContent != "" { + configuration.CustomParam = string(ds.ExpertModeContent) + } + + // set log level + configuration.Logger.LogLevel = constants.LogMap[ds.LogLevel] + configuration.Logger.FileName = "/mnt/logs/driver.log" + + var buff bytes.Buffer + e := toml.NewEncoder(&buff) + err := e.Encode(configuration) + if err != nil { + return "", err + } + + return buff.String(), nil +} + +// +////func (m *driverServiceAppM) buildAppCfg(localDefaultIp string, runPort int, ds models.DeviceService) (string, error) { +//// configuration := &dtos.AppServiceConfig{} +//// sysConfig := container.ConfigurationFrom(m.dic.Get) +//// +//// // 读取模版配置 +//// if _, err := toml.Decode(getDriverConfigTemplate(ds), configuration); err != nil { +//// return "", err +//// } +//// +//// // 修改与核心服务通信的ip +//// p, err := strconv.Atoi(strings.Split(sysConfig.RpcServer.Address, ":")[1]) +//// if err != nil { +//// return "", err +//// } +//// configuration.Tedge.Host = localDefaultIp +//// configuration.Tedge.Port = int32(p) +//// +//// configuration.Server.ID = ds.Id +//// configuration.Server.Name = ds.Name +//// // 驱动服务只开启rpc服务 +//// configuration.Server.Host = "0.0.0.0" +//// configuration.Server.Port = int32(runPort) +//// +//// dl, err := m.getDriverApp().DriverLibById(ds.DeviceLibraryId) +//// if err != nil { +//// return "", err +//// } +//// dlc, err := dl.GetConfig() +//// if err != nil { +//// return "", err +//// } +//// +//// // 如果有专家模式,直接用专家模式的yaml转换为toml,不会有小数点问题 +//// if ds.ExpertMode && ds.ExpertModeContent != "" { +//// err := yaml.Unmarshal([]byte(ds.ExpertModeContent), &configuration.CustomConfig) +//// if err != nil { +//// return "", err +//// } +//// } else { +//// // driver 模块做映射, driver通过配置文件进行强制转换 +//// if ds.Config != nil { +//// if driver, ok := ds.Config[constants.ConfigKeyDriver]; ok { +//// finalDriver := make(map[string]interface{}) +//// if _, ok := driver.(map[string]interface{}); ok { +//// for i, v := range driver.(map[string]interface{}) { +//// v = convertCfgDriverKeyType(dlc, i, v) +//// finalDriver[i] = v +//// } +//// configuration.CustomConfig = finalDriver +//// } +//// } +//// } +//// } +//// +//// // set log level +//// configuration.Log.LogLevel = constants.LogMap[ds.LogLevel] +//// +//// var buff bytes.Buffer +//// e := toml.NewEncoder(&buff) +//// err = e.Encode(configuration) +//// if err != nil { +//// return "", err +//// } +//// +//// return buff.String(), nil +////} +// +//// TODO 只针对docker版本 +////func (m *driverServiceAppM) NotifyAddDevice(d models.Device) { +//// // 目前只支持modbus-rtu协议 +//// protocolKey := constants.DriverModbusRtu +//// if _, ok := d.Protocols[protocolKey]["Address"]; !ok { +//// return +//// } +//// +//// // 如果容器没有处于运行状态,不做任何处理 +//// if m.GetState(d.ServiceId) != constants.RunStatusStarted { +//// return +//// } +//// +//// // 重启docker驱动 +//// err := m.ReStart(d.ServiceId) +//// if err != nil { +//// m.lc.Errorf("NotifyAddDevice restart serviceId(%s) err:%v", d.ServiceId, err) +//// } +////} +// +// 将配置数据强制转换为定义的类型,如果定义错误,则不转换 +func convertCfgDriverKeyType(dlc models.DeviceLibraryConfig, key string, value interface{}) interface{} { + var ok bool + if _, ok = dlc.DeviceServer[constants.ConfigKeyDriver]; !ok { + return value + } + rt := reflect.TypeOf(value) + rv := reflect.ValueOf(value) + items := dlc.DeviceServer[constants.ConfigKeyDriver] + for _, v := range items { + if v.Name == key { + switch v.Type { + case constants.DriverConfigTypeInt: + if rt.Kind() == reflect.String { + tmpV, e := strconv.Atoi(rv.String()) + if e != nil { + return value + } + return tmpV + } + if rt.Kind() == reflect.Float64 { + return int(rv.Float()) + } + return value + case constants.DriverConfigTypeFloat: + if rt.Kind() == reflect.String { + tmpV, e := strconv.ParseInt(rv.String(), 10, 64) + if e != nil { + return value + } + return tmpV + } + if rt.Kind() == reflect.Int { + return float64(rv.Int()) + } + return value + default: + // 其他类型目前先不做处理 + return value + } + } + } + return value +} + +// +//// 清理所有驱动资源:包括驱动镜像、 +//func (m *driverServiceAppM) ClearAllContainer() { +// dss, err := m.AllService() +// if err != nil { +// m.lc.Errorf("get all service err:%v", err) +// return +// } +// for _, v := range dss { +// // 停止驱动 +// _ = m.appModel.StopInstance(dtos.DeviceServiceFromModel(v)) +// } +// +// // 删除驱动路径 /var/tedge/edgex-driver-data +// err = utils.RemoveFileOrDir(constants.DriverBaseDir) +// if err != nil { +// m.lc.Errorf("remove driverBaseDir(%s) err:%v", constants.DriverBaseDir, err) +// } +//} +// +//func (m *driverServiceAppM) AllService() ([]models.DeviceService, error) { +// dss, _, err := m.dbClient.DeviceServicesSearch(0, -1, dtos.DeviceServiceSearchQueryRequest{}) +// return dss, err +//} +// +//func (m *driverServiceAppM) FlushStatsToAgent() { +// // 取出所有驱动,进行批量更新 +// dss, _, edgeXErr := m.Search(context.Background(), dtos.DeviceServiceSearchQueryRequest{}) +// if edgeXErr != nil { +// m.lc.Errorf("deviceServicesSearch err %v", edgeXErr) +// return +// } +// +// // 取出所有缓存 +// client := pkgcontainer.AgentClientNameFrom(m.dic.Get) +// ctx := context.Background() +// stats, err := client.GetAllDriverMonitor(ctx) +// if err != nil { +// m.lc.Errorf("http request get all driver monitor err: %v", err) +// return +// } +// statsIdMap := make(map[string]models.ServiceStats) +// for _, v := range stats { +// // 只处理驱动的服务 +// if v.ServiceType == models.ServiceTypeEnumDriver { +// statsIdMap[v.Id] = dtos.FromDTOServiceStatsToModel(v) +// } +// } +// deleteIds := make([]string, 0) +// +// // 对比获取需要删除的id +// for statsId, _ := range statsIdMap { +// deleteId := statsId +// for _, ds := range dss { +// if ds.Id == statsId { +// deleteId = "" +// } +// } +// if deleteId != "" { +// deleteIds = append(deleteIds, statsId) +// } +// } +// +// // 处理添加/更新 +// for _, v := range dss { +// newStats := models.ServiceStats{} +// if _, ok := statsIdMap[v.Id]; ok { +// newStats = statsIdMap[v.Id] +// } +// newStats.Id = v.Id +// newStats.Name = v.Name +// newStats.LogPath = m.appModel.GetInstanceLogPath(dtos.DeviceServiceFromModel(v)) +// newStats.ServiceType = models.ServiceTypeEnumDriver +// err = client.AddServiceMonitor(ctx, dtos.FromModelsServiceStatsToDTO(newStats)) +// if err != nil { +// m.lc.Errorf("http request add service monitor err: %v", err) +// } +// } +// //处理删除 +// for _, v := range deleteIds { +// err = client.DeleteServiceMonitor(ctx, statsIdMap[v].Id) +// if err != nil { +// m.lc.Errorf("http request delete service monitor err: %v", err) +// } +// } +//} +// +//// 异步调用,自动绑定设备、产品、驱动的关系 +//func (m *driverServiceAppM) autoAddDevice(ds models.DeviceService) { +// // app模型的不作处理 +// if ds.DriverType == constants.DriverLibTypeAppService { +// return +// } +// products, _, err := container.ProductItfFrom(m.dic.Get).ProductsSearch(dtos.ProductSearchQueryRequest{DeviceLibraryId: ds.DeviceLibraryId}) +// if err != nil { +// m.lc.Errorf("search product err: %v", err) +// return +// } +// for _, p := range products { +// m.getProductApp().ProductSyncUpdateDeviceServiceId(p) +// } +//} +// +//// 挂载的设备, 目前只支持modbus-rtu这个配置的设备挂载 +//func buildMountDevices(devices []models.Device) []string { +// mountDevices := make([]string, 0) +// for _, v := range devices { +// if address, ok := v.Protocols[constants.DriverModbusRtu]["Address"]; ok { +// mountDevices = append(mountDevices, address) +// } +// } +// return mountDevices +//} +// +//func AtopReportDriverConfigEdit(dic *di.Container, dl models.DeviceLibrary, ds models.DeviceService, lc logger.LoggingClient) { +// if !application.CanRequestAtop(dic) { +// return +// } +// +// if dl.IsInternal && dl.DriverType == constants.DriverLibTypeDefault { +// runConfig, _ := json.Marshal(ds.Config) +// err := application.AtopDataReport(constants.DataType_ConfigUpdate, dtos.AtopDataReportDriverConfigUpdate{ +// DriverCode: ds.DeviceLibraryId, +// OpenDockerEnv: ds.DockerParamsSwitch, +// DockerEnv: ds.DockerParams, +// OpenExpertMode: ds.ExpertMode, +// ExpertMode: ds.ExpertModeContent, +// RunConfig: string(runConfig), +// }) +// if err != nil { +// lc.Warnf("reportDriverConfigEdit err: %v", err) +// } +// } +// lc.Infof("atopReportDriverConfigEdit success dlId(%v)", dl.Id) +//} +// +//func AtopReportDriverRunOrStop(dic *di.Container, dl models.DeviceLibrary, status int, lc logger.LoggingClient) { +// if !application.CanRequestAtop(dic) { +// return +// } +// +// if dl.IsInternal && dl.DriverType == constants.DriverLibTypeDefault { +// // 不上报 停止中、启动中 这种中间状态,不好控制 +// err := application.AtopDataReport(constants.DataType_DriverRunOrStop, dtos.AtopDataReportDriverRunOrStop{ +// DriverCode: dl.Id, +// RunStatus: status, +// }) +// if err != nil { +// lc.Errorf("reportDriverRunOrStop dlId(%v) err: %v", dl.Id, err) +// return +// } +// } +// lc.Infof("reportDriverRunOrStop success dlId(%v)", dl.Id) +//} +// +//func AtopReportDriverDelete(dic *di.Container, dl models.DeviceLibrary, lc logger.LoggingClient) { +// if !application.CanRequestAtop(dic) { +// return +// } +// +// if dl.IsInternal && dl.DriverType == constants.DriverLibTypeDefault { +// err := application.AtopDataReport(constants.DataType_DriverDelete, dtos.AtopDataReportDriverDelete{ +// DriverCode: dl.Id, +// }) +// if err != nil { +// lc.Errorf("reportDriverDelete dlId(%v) err: %v", dl.Id, err) +// return +// } +// } +// lc.Infof("reportDriverDelete success dlId(%v)", dl.Id) +//} diff --git a/internal/hummingbird/core/application/driverserviceapp/servicemonitor.go b/internal/hummingbird/core/application/driverserviceapp/servicemonitor.go new file mode 100644 index 0000000..eee7c3f --- /dev/null +++ b/internal/hummingbird/core/application/driverserviceapp/servicemonitor.go @@ -0,0 +1,105 @@ +package driverserviceapp + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/middleware" + "sync" + "time" + + //"gitlab.com/tedge/edgex/internal/tedge/resource/interfaces" + // + //"gitlab.com/tedge/edgex/internal/dtos" + //"gitlab.com/tedge/edgex/internal/pkg/constants" + pkgContainer "github.com/winc-link/hummingbird/internal/pkg/container" + //"gitlab.com/tedge/edgex/internal/pkg/di" + //"gitlab.com/tedge/edgex/internal/pkg/logger" + //"gitlab.com/tedge/edgex/internal/pkg/middleware" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" +) + +/** +检测驱动实例运行状态,通过rpc +*/ +type DeviceServiceMonitor struct { + ds dtos.DeviceService + isRunning bool + ctx context.Context + dic *di.Container + lc logger.LoggingClient + exitChan chan struct{} + mutex sync.RWMutex +} + +func NewDeviceServiceMonitor(ctx context.Context, ds dtos.DeviceService, dic *di.Container) *DeviceServiceMonitor { + dsm := &DeviceServiceMonitor{ + ds: ds, + isRunning: false, + ctx: ctx, + dic: dic, + lc: pkgContainer.LoggingClientFrom(dic.Get), + exitChan: make(chan struct{}), + } + go dsm.monitor() + return dsm +} + +func (dsm *DeviceServiceMonitor) monitor() { + // 监控间隔 + tickTime := time.Second * 5 + timeTickerChan := time.Tick(tickTime) + for { + select { + case <-dsm.ctx.Done(): + dsm.lc.Infof("close to DeviceServiceMonitor dsId: %s", dsm.ds.Id) + return + case <-dsm.exitChan: + dsm.lc.Infof("close to DeviceServiceMonitor dsId: %s", dsm.ds.Id) + return + case <-timeTickerChan: + dsm.CheckServiceAvailable() + } + } +} + +func (dsm *DeviceServiceMonitor) CheckServiceAvailable() { + dsm.mutex.Lock() + defer dsm.mutex.Unlock() + ctx := middleware.WithCorrelationId(context.Background()) + dsApp := resourceContainer.DriverServiceAppFrom(dsm.dic.Get) + _, err := dsApp.Get(dsm.ctx, dsm.ds.Id) + if err != nil { + dsm.lc.Infof("monitor get driver instance err %+v", err.Error()) + dsm.exitChan <- struct{}{} + return + } + + isRunning := interfaces.DMIFrom(dsm.dic.Get).InstanceState(dsm.ds) + + // 驱动运行状态更改 + if dsm.isRunning != isRunning { + dsm.lc.Debugf("id [%s] before status [%v] current status [%v]: %v", dsm.ds.Id, dsm.isRunning, isRunning, middleware.FromContext(ctx)) + // 状态更改上报 + } + + // 更新管理驱动实例状态 + if !dsApp.InProgress(dsm.ds.Id) { + if isRunning { + dsApp.SetState(dsm.ds.Id, constants.RunStatusStarted) + } else { + dsApp.SetState(dsm.ds.Id, constants.RunStatusStopped) + } + } + if dsm.isRunning && !isRunning { + //OfflineDevicesByServiceId(ctx, dsm.dic, dsm.ds.Id) + } + dsm.isRunning = isRunning +} + +func (dsm *DeviceServiceMonitor) Stop() { + close(dsm.exitChan) +} diff --git a/internal/hummingbird/core/application/homepageapp/homepage.go b/internal/hummingbird/core/application/homepageapp/homepage.go new file mode 100644 index 0000000..d08f246 --- /dev/null +++ b/internal/hummingbird/core/application/homepageapp/homepage.go @@ -0,0 +1,159 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package homepageapp + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "time" +) + +const ( + HummingbridDoc = "https://doc.hummingbird.winc-link.com/" +) + +type homePageApp struct { + dic *di.Container + lc logger.LoggingClient + dbClient interfaces.DBClient +} + +func NewHomePageApp(ctx context.Context, dic *di.Container) interfaces.HomePageItf { + dbClient := resourceContainer.DBClientFrom(dic.Get) + return &homePageApp{ + dic: dic, + lc: container.LoggingClientFrom(dic.Get), + dbClient: dbClient, + } +} + +func (h homePageApp) HomePageInfo(ctx context.Context, req dtos.HomePageRequest) (response dtos.HomePageResponse, err error) { + var responseResponse dtos.HomePageResponse + devices, deviceTotal, err := h.dbClient.DevicesSearch(0, -1, dtos.DeviceSearchQueryRequest{}) + var selfDeviceTotal uint32 + for _, device := range devices { + if device.Platform == constants.IotPlatform_LocalIot { + selfDeviceTotal++ + } + } + responseResponse.PageInfo.Device.Total = deviceTotal + responseResponse.PageInfo.Device.Self = selfDeviceTotal + if deviceTotal-selfDeviceTotal < 0 { + responseResponse.PageInfo.Device.Other = 0 + } else { + responseResponse.PageInfo.Device.Other = deviceTotal - selfDeviceTotal + } + + products, productTotal, err := h.dbClient.ProductsSearch(0, -1, false, dtos.ProductSearchQueryRequest{}) + var selfProductTotal uint32 + for _, product := range products { + if product.Platform == constants.IotPlatform_LocalIot { + selfProductTotal++ + } + } + responseResponse.PageInfo.Product.Total = productTotal + responseResponse.PageInfo.Product.Self = selfProductTotal + if productTotal-selfProductTotal < 0 { + responseResponse.PageInfo.Product.Other = 0 + } else { + responseResponse.PageInfo.Product.Other = productTotal - selfProductTotal + } + + responseResponse.PageInfo.CloudInstance.StopCount = responseResponse.PageInfo.CloudInstance.Count - responseResponse.PageInfo.CloudInstance.RunCount + if responseResponse.PageInfo.CloudInstance.StopCount < 0 { + responseResponse.PageInfo.CloudInstance.StopCount = 0 + } + var searchQuickNavigationReq dtos.QuickNavigationSearchQueryRequest + searchQuickNavigationReq.OrderBy = "sort" + quickNavigations, _, _ := h.dbClient.QuickNavigationSearch(0, -1, searchQuickNavigationReq) + navigations := make([]dtos.QuickNavigation, 0) + for _, navigation := range quickNavigations { + navigations = append(navigations, dtos.QuickNavigation{ + Id: navigation.Id, + Name: navigation.Name, + Icon: navigation.Icon, + //JumpLink: navigation.JumpLink, + }) + } + responseResponse.QuickNavigation = navigations + + var searchDocsReq dtos.DocsSearchQueryRequest + searchDocsReq.OrderBy = "sort" + dbDocs, _, _ := h.dbClient.DocsSearch(0, -1, searchDocsReq) + docs := make([]dtos.Doc, 0) + for _, doc := range dbDocs { + docs = append(docs, dtos.Doc{ + Name: doc.Name, + JumpLink: doc.JumpLink, + }) + } + responseResponse.Docs.More = HummingbridDoc + responseResponse.Docs.Doc = docs + + alertRuleApp := resourceContainer.AlertRuleAppNameFrom(h.dic.Get) + alertResp, _ := alertRuleApp.AlertPlate(ctx, time.Now().AddDate(0, 0, -1).UnixMilli()) + responseResponse.AlertPlate = alertResp + + var alertTotal uint32 + for _, alert := range responseResponse.AlertPlate { + alertTotal += uint32(alert.Count) + } + responseResponse.PageInfo.Alert.Total = alertTotal + + //设备消息总数 + var msgGatherReq dtos.MsgGatherSearchQueryRequest + msgGatherReq.Date = append(append(append(append(append(append(msgGatherReq.Date, + time.Now().AddDate(0, 0, -1).Format("2006-01-02")), + time.Now().AddDate(0, 0, -2).Format("2006-01-02")), + time.Now().AddDate(0, 0, -3).Format("2006-01-02")), + time.Now().AddDate(0, 0, -4).Format("2006-01-02")), + time.Now().AddDate(0, 0, -5).Format("2006-01-02")), + time.Now().AddDate(0, 0, -6).Format("2006-01-02")) + + msgGather, _, err := h.dbClient.MsgGatherSearch(0, -1, msgGatherReq) + responseResponse.MsgGather = append(append(append(append(append(append(responseResponse.MsgGather, dtos.MsgGather{ + Date: time.Now().AddDate(0, 0, -1).Format("2006-01-02"), + Count: getMsgGatherCountByDate(msgGather, time.Now().AddDate(0, 0, -1).Format("2006-01-02")), + }), dtos.MsgGather{ + Date: time.Now().AddDate(0, 0, -2).Format("2006-01-02"), + Count: getMsgGatherCountByDate(msgGather, time.Now().AddDate(0, 0, -2).Format("2006-01-02")), + }), dtos.MsgGather{ + Date: time.Now().AddDate(0, 0, -3).Format("2006-01-02"), + Count: getMsgGatherCountByDate(msgGather, time.Now().AddDate(0, 0, -3).Format("2006-01-02")), + }), dtos.MsgGather{ + Date: time.Now().AddDate(0, 0, -4).Format("2006-01-02"), + Count: getMsgGatherCountByDate(msgGather, time.Now().AddDate(0, 0, -4).Format("2006-01-02")), + }), dtos.MsgGather{ + Date: time.Now().AddDate(0, 0, -5).Format("2006-01-02"), + Count: getMsgGatherCountByDate(msgGather, time.Now().AddDate(0, 0, -5).Format("2006-01-02")), + }), dtos.MsgGather{ + Date: time.Now().AddDate(0, 0, -6).Format("2006-01-02"), + Count: getMsgGatherCountByDate(msgGather, time.Now().AddDate(0, 0, -6).Format("2006-01-02")), + }) + return responseResponse, nil +} + +func getMsgGatherCountByDate(msgGather []models.MsgGather, data string) int { + for _, gather := range msgGather { + if gather.Date == data { + return gather.Count + } + } + return 0 +} diff --git a/internal/hummingbird/core/application/languagesdkapp/languagesdkapp.go b/internal/hummingbird/core/application/languagesdkapp/languagesdkapp.go new file mode 100644 index 0000000..c52c9fb --- /dev/null +++ b/internal/hummingbird/core/application/languagesdkapp/languagesdkapp.go @@ -0,0 +1,107 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package languagesdkapp + +import ( + "context" + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + pkgcontainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" +) + +type languageSDKApp struct { + dic *di.Container + lc logger.LoggingClient + dbClient interfaces.DBClient +} + +func (m languageSDKApp) LanguageSDKSearch(ctx context.Context, req dtos.LanguageSDKSearchQueryRequest) ([]dtos.LanguageSDKSearchResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + req.BaseSearchConditionQuery.OrderBy = "sort:asc" + languages, total, err := m.dbClient.LanguageSearch(offset, limit, req) + + if err != nil { + return nil, 0, err + } + libs := make([]dtos.LanguageSDKSearchResponse, len(languages)) + for i, language := range languages { + libs[i] = dtos.LanguageSDKSearchResponse{ + Name: language.Name, + Icon: language.Icon, + Addr: language.Addr, + Description: language.Description, + } + } + return libs, total, nil + +} + +func (m languageSDKApp) Sync(ctx context.Context, versionName string) error { + filePath := versionName + "/language_sdk.json" + cosApp := container.CosAppNameFrom(m.dic.Get) + bs, err := cosApp.Get(filePath) + if err != nil { + m.lc.Errorf(err.Error()) + } + var cosLanguageSdkResp []dtos.LanguageSDK + err = json.Unmarshal(bs, &cosLanguageSdkResp) + if err != nil { + m.lc.Errorf(err.Error()) + } + + for _, sdk := range cosLanguageSdkResp { + if languageSdk, err := m.dbClient.LanguageSdkByName(sdk.Name); err != nil { + createModel := models.LanguageSdk{ + Name: sdk.Name, + Icon: sdk.Icon, + Sort: sdk.Sort, + Addr: sdk.Addr, + Description: sdk.Description, + } + _, err := m.dbClient.AddLanguageSdk(createModel) + if err != nil { + return err + } + } else { + updateModel := models.LanguageSdk{ + Id: languageSdk.Id, + Name: sdk.Name, + Icon: sdk.Icon, + Sort: sdk.Sort, + Addr: sdk.Addr, + Description: sdk.Description, + } + err = m.dbClient.UpdateLanguageSdk(updateModel) + if err != nil { + return err + } + } + } + return nil +} + +func NewLanguageSDKApp(ctx context.Context, dic *di.Container) interfaces.LanguageSDKApp { + app := &languageSDKApp{ + dic: dic, + lc: pkgcontainer.LoggingClientFrom(dic.Get), + dbClient: container.DBClientFrom(dic.Get), + } + return app +} diff --git a/internal/hummingbird/core/application/limit.go b/internal/hummingbird/core/application/limit.go new file mode 100644 index 0000000..2eef67f --- /dev/null +++ b/internal/hummingbird/core/application/limit.go @@ -0,0 +1,25 @@ +package application + +import ( + //"gitlab.com/tedge/edgex/internal/pkg/limit" + //"gitlab.com/tedge/edgex/internal/tedge/resource/config" + "github.com/winc-link/hummingbird/internal/hummingbird/core/config" + "github.com/winc-link/hummingbird/internal/pkg/limit" +) + +type LimitMethodConf struct { + methods map[string]struct{} +} + +//TODO: 接口限流功能需要重构 +func NewLimitMethodConf(configuration config.ConfigurationStruct) limit.LimitMethodConf { + var conf = &LimitMethodConf{methods: make(map[string]struct{})} + for _, method := range configuration.Writable.LimitMethods { + conf.methods[method] = struct{}{} + } + return conf +} + +func (lmc *LimitMethodConf) GetLimitMethods() map[string]struct{} { + return lmc.methods +} diff --git a/internal/hummingbird/core/application/messageapp/ekuipermessage.go b/internal/hummingbird/core/application/messageapp/ekuipermessage.go new file mode 100644 index 0000000..6edab52 --- /dev/null +++ b/internal/hummingbird/core/application/messageapp/ekuipermessage.go @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package messageapp + +import ( + "context" + mqtt "github.com/eclipse/paho.mqtt.golang" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/errort" + pkgMQTT "github.com/winc-link/hummingbird/internal/tools/mqttclient" +) + +func (msp *MessageApp) connectMQTT() (mqttClient pkgMQTT.MQTTClient) { + lc := msp.lc + var req dtos.NewMQTTClient + var consumeCallback mqtt.MessageHandler + var err error + req, consumeCallback, err = msp.prepareMqttConnectParams() + if err != nil { + lc.Errorf("ConnectMQTT failed, err:%v", err) + return + } + connF := func(ctx context.Context) { + msp.lc.Info("ekuiper mqtt connect") + } + + disConnF := func(ctx context.Context, msg dtos.CallbackMessage) { + msp.lc.Info("ekuiper mqtt disconnect") + } + + mqttClient, err = pkgMQTT.NewMQTTClient(req, lc, consumeCallback, connF, disConnF) + if err != nil { + err = errort.NewCommonErr(errort.MqttConnFail, err) + lc.Errorf("ConnectMQTT failed, err:%v", err) + } + return mqttClient +} + +func (tmq *MessageApp) prepareMqttConnectParams() (req dtos.NewMQTTClient, consumeCallback mqtt.MessageHandler, err error) { + config := container.ConfigurationFrom(tmq.dic.Get) + + req = dtos.NewMQTTClient{ + Broker: config.MessageQueue.URL(), + ClientId: config.MessageQueue.Optional["ClientId"], + Username: config.MessageQueue.Optional["Username"], + Password: config.MessageQueue.Optional["Password"], + } + consumeCallback = tmq.ekuiperMsgHandle + return +} + +func (tmq *MessageApp) ekuiperMsgHandle(client mqtt.Client, message mqtt.Message) { + +} + +func (tmq *MessageApp) pushMsgToMessageBus(msg []byte) { + config := container.ConfigurationFrom(tmq.dic.Get) + tmq.ekuiperMqttClient.AsyncPublish(nil, config.MessageQueue.PublishTopicPrefix, msg, false) +} diff --git a/internal/hummingbird/core/application/messageapp/messageapp.go b/internal/hummingbird/core/application/messageapp/messageapp.go new file mode 100644 index 0000000..c5f4585 --- /dev/null +++ b/internal/hummingbird/core/application/messageapp/messageapp.go @@ -0,0 +1,124 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package messageapp + +import ( + "context" + "encoding/json" + "github.com/kirinlabs/HttpRequest" + "github.com/winc-link/edge-driver-proto/drivercommon" + "github.com/winc-link/hummingbird/internal/dtos" + coreContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "strconv" + "strings" + "time" + + pkgMQTT "github.com/winc-link/hummingbird/internal/tools/mqttclient" +) + +type MessageApp struct { + dic *di.Container + lc logger.LoggingClient + dbClient interfaces.DBClient + ekuiperMqttClient pkgMQTT.MQTTClient +} + +func NewMessageApp(dic *di.Container) *MessageApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := coreContainer.DBClientFrom(dic.Get) + msgApp := &MessageApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } + mqttClient := msgApp.connectMQTT() + msgApp.ekuiperMqttClient = mqttClient + msgApp.initeKuiperStreams() + return msgApp +} + +func (tmq *MessageApp) initeKuiperStreams() { + req := HttpRequest.NewRequest() + r := make(map[string]string) + r["sql"] = "CREATE STREAM mqtt_stream () WITH (DATASOURCE=\"eventbus/in\", FORMAT=\"JSON\",SHARED = \"true\")" + b, _ := json.Marshal(r) + resp, err := req.Post("http://ekuiper:9081/streams", b) + if err != nil { + tmq.lc.Errorf("init ekuiper stream failed error:%+v", err.Error()) + return + } + + if resp.StatusCode() == 201 { + body, err := resp.Body() + if err != nil { + tmq.lc.Errorf("init ekuiper stream failed error:%+v", err.Error()) + return + } + if strings.Contains(string(body), "created") { + tmq.lc.Infof("init ekuiper stream success") + return + } + } else if resp.StatusCode() == 400 { + body, err := resp.Body() + tmq.lc.Infof("init ekuiper stream body", string(body)) + if err != nil { + tmq.lc.Errorf("init ekuiper stream failed error:%+v", err.Error()) + return + } + + if strings.Contains(string(body), "already exists") { + tmq.lc.Infof("init ekuiper stream plug success") + return + } + } else { + tmq.lc.Errorf("init ekuiper stream failed resp code:%+v", resp.StatusCode()) + } +} + +func (tmq *MessageApp) DeviceStatusToMessageBus(ctx context.Context, deviceId, deviceStatus string) { + var messageBus dtos.MessageBus + messageBus.DeviceId = deviceId + messageBus.MessageType = "DEVICE_STATUS" + messageBus.Data = map[string]interface{}{ + "status": deviceStatus, + "time": time.Now().UnixMilli(), + } + b, _ := json.Marshal(messageBus) + tmq.pushMsgToMessageBus(b) + +} +func (tmq *MessageApp) ThingModelMsgReport(ctx context.Context, msg dtos.ThingModelMessage) (*drivercommon.CommonResponse, error) { + tmq.pushMsgToMessageBus(msg.TransformMessageBus()) + persistItf := coreContainer.PersistItfFrom(tmq.dic.Get) + err := persistItf.SaveDeviceThingModelData(msg) + if err != nil { + tmq.lc.Error("saveDeviceThingModelData error:", err.Error()) + } + response := new(drivercommon.CommonResponse) + if err != nil { + response.Success = false + response.Code = strconv.Itoa(errort.KindDatabaseError) + response.ErrorMessage = err.Error() + } else { + response.Code = "0" + response.Success = true + } + return response, nil +} diff --git a/internal/hummingbird/core/application/messagestore/asyncack.go b/internal/hummingbird/core/application/messagestore/asyncack.go new file mode 100644 index 0000000..76a3eb5 --- /dev/null +++ b/internal/hummingbird/core/application/messagestore/asyncack.go @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package messagestore + +import "sync" + +type MsgAckChan struct { + Mu sync.Mutex + Id string + IsClosed bool + DataChan chan interface{} +} + +func (mac *MsgAckChan) TryCloseChan() { + mac.Mu.Lock() + defer mac.Mu.Unlock() + if !mac.IsClosed { + close(mac.DataChan) + mac.IsClosed = true + } +} + +func (mac *MsgAckChan) TrySendDataAndCloseChan(data interface{}) bool { + mac.Mu.Lock() + defer mac.Mu.Unlock() + if !mac.IsClosed { + mac.DataChan <- data + close(mac.DataChan) + mac.IsClosed = true + return true + } + return false +} diff --git a/internal/hummingbird/core/application/messagestore/messagestore.go b/internal/hummingbird/core/application/messagestore/messagestore.go new file mode 100644 index 0000000..3ec548f --- /dev/null +++ b/internal/hummingbird/core/application/messagestore/messagestore.go @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package messagestore + +import ( + "context" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "sync" +) + +type MessageStores interface { + StoreMsgId(id string, ch string) + LoadMsgChan(id string) (interface{}, bool) + DeleteMsgId(id string) + GenAckChan(id string) *MsgAckChan +} + +type ( + MessageStore struct { + logger logger.LoggingClient + ctx context.Context + mutex sync.Mutex + wg *sync.WaitGroup + ackMap sync.Map + } +) + +func NewMessageStore(dic *di.Container) *MessageStore { + lc := container.LoggingClientFrom(dic.Get) + return &MessageStore{ + logger: lc, + } +} + +func (wp *MessageStore) StoreMsgId(id string, ch string) { + wp.ackMap.Store(id, ch) +} + +func (wp *MessageStore) DeleteMsgId(id string) { + wp.ackMap.Delete(id) +} + +func (wp *MessageStore) LoadMsgChan(id string) (interface{}, bool) { + return wp.ackMap.Load(id) +} + +func (wp *MessageStore) GenAckChan(id string) *MsgAckChan { + ack := &MsgAckChan{ + Id: id, + DataChan: make(chan interface{}, 1), + } + wp.ackMap.Store(id, ack) + return ack +} diff --git a/internal/hummingbird/core/application/monitor/monitor.go b/internal/hummingbird/core/application/monitor/monitor.go new file mode 100644 index 0000000..b483de3 --- /dev/null +++ b/internal/hummingbird/core/application/monitor/monitor.go @@ -0,0 +1,118 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package monitor + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + pkgcontainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "sync" + "time" +) + +type monitor struct { + dic *di.Container + lc logger.LoggingClient + serviceMonitorMap sync.Map + ctx context.Context + exitCh chan struct{} + + systemMonitor *systemMonitor +} + +func NewMonitor(ctx context.Context, dic *di.Container) *monitor { + lc := pkgcontainer.LoggingClientFrom(dic.Get) + + m := monitor{ + dic: dic, + lc: lc, + serviceMonitorMap: sync.Map{}, + ctx: context.Background(), + exitCh: make(chan struct{}), + systemMonitor: NewSystemMonitor(dic, lc), + } + + //go m.run() + + return &m +} + +func systemMetricsTypeToTime(t string) (time.Time, time.Time) { + switch t { + case constants.HourMetricsType: + end := time.Now() + start := time.Now().Add(-1 * time.Hour) + return start, end + case constants.HalfDayMetricsType: + end := time.Now() + start := time.Now().Add(-12 * time.Hour) + return start, end + case constants.DayMetricsType: + end := time.Now() + start := time.Now().Add(-24 * time.Hour) + return start, end + default: + end := time.Now() + start := time.Now().Add(-1 * time.Hour) + return start, end + } +} + +func (m *monitor) GetSystemMetrics(ctx context.Context, query dtos.SystemMetricsQuery) (dtos.SystemMetricsResponse, error) { + dbClient := container.DBClientFrom(m.dic.Get) + + start, end := systemMetricsTypeToTime(query.MetricsType) + metrics, err := dbClient.GetSystemMetrics(start.UnixMilli(), end.UnixMilli()) + if err != nil { + return dtos.SystemMetricsResponse{}, err + } + + resp := dtos.SystemMetricsResponse{ + Metrics: make([]dtos.SystemStatResponse, 0), + } + step := 1 + if query.MetricsType == constants.HalfDayMetricsType || query.MetricsType == constants.DayMetricsType { + step = 5 + } + for i := 0; i < len(metrics); i = i + step { + metric := metrics[i] + + item := dtos.SystemStatResponse{ + Timestamp: metric.Timestamp, + CpuUsedPercent: metric.CpuUsedPercent, + MemoryTotal: metric.Memory.Total, + MemoryUsed: metric.Memory.Used, + MemoryUsedPercent: metric.Memory.UsedPercent, + DiskTotal: metric.Disk.Total, + DiskUsed: metric.Disk.Used, + DiskUsedPercent: metric.Disk.UsedPercent, + Openfiles: metric.Openfiles, + } + + if iface, ok := metric.Network[query.Iface]; ok { + item.NetSentBytes = iface.BytesSentPre + item.NetRecvBytes = iface.BytesRecvPre + } + + resp.Metrics = append(resp.Metrics, item) + } + resp.Total = len(resp.Metrics) + + return resp, nil +} diff --git a/internal/hummingbird/core/application/monitor/systemmonitor.go b/internal/hummingbird/core/application/monitor/systemmonitor.go new file mode 100644 index 0000000..3fd9cd7 --- /dev/null +++ b/internal/hummingbird/core/application/monitor/systemmonitor.go @@ -0,0 +1,218 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package monitor + +import ( + "context" + "fmt" + "github.com/shirou/gopsutil/v3/cpu" + "github.com/shirou/gopsutil/v3/disk" + "github.com/shirou/gopsutil/v3/host" + "github.com/shirou/gopsutil/v3/load" + "github.com/shirou/gopsutil/v3/mem" + "github.com/shirou/gopsutil/v3/net" + "github.com/shirou/gopsutil/v3/process" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "strconv" + "time" +) + +type systemMonitor struct { + ctx context.Context + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient + exitCh chan struct{} + + ethMap map[string]*dtos.SystemNetwork + //aa *AlertApplication + diskAlert int + cpuAlert int +} + +func NewSystemMonitor(dic *di.Container, lc logger.LoggingClient) *systemMonitor { + dbClient := container.DBClientFrom(dic.Get) + ctx := context.Background() + m := systemMonitor{ + ctx: ctx, + dic: dic, + dbClient: dbClient, + lc: lc, + ethMap: make(map[string]*dtos.SystemNetwork), + exitCh: make(chan struct{}), + //aa: NewAlertApplication(ctx, dic, lc), + diskAlert: 3, // 默认告警 3 次 + cpuAlert: 3, + } + + go m.run() + + return &m +} + +func (m *systemMonitor) run() { + tick := time.Tick(time.Minute) // 1 minute + tickClear := time.Tick(24 * time.Hour) // 每24小时删除一次数据 + go func() { + for { + select { + case <-tick: + metrics := m.collect() + if err := m.dbClient.UpdateSystemMetrics(metrics); err != nil { + m.lc.Errorf("failed to UpdateSystemMetrics %v", err) + } + + //m.reportSystemMetricsAlert(metrics) + case <-tickClear: + m.clearMetrics() + case <-m.exitCh: + return + } + } + }() +} + +func (m *systemMonitor) clearMetrics() { + min := "0" + max := strconv.FormatInt(time.Now().Add(-24*time.Hour).UnixMilli(), 10) + m.lc.Infof("remove system metrics data from %v to %v", min, max) + if err := m.dbClient.RemoveRangeSystemMetrics(min, max); err != nil { + m.lc.Error("failed to clearMetrics", err) + } +} + +func (m *systemMonitor) Close() { + close(m.exitCh) +} + +func (m *systemMonitor) collect() dtos.SystemMetrics { + return dtos.SystemMetrics{ + Timestamp: time.Now().UnixMilli(), + CpuUsedPercent: getCpu(), + CpuAvg: getCpuLoad(), + Memory: getMemory(), + Network: getNetwork(m.ethMap), + Disk: getDisk(), + Openfiles: getOpenfiles(), + } +} + +func getMemory() dtos.SystemMemory { + v, _ := mem.VirtualMemory() + + return dtos.SystemMemory{ + Total: v.Total, + Used: v.Used, + UsedPercent: v.UsedPercent, + } +} + +func getCpu() float64 { + // cpu的使用率 + totalPercent, _ := cpu.Percent(0, false) + if len(totalPercent) <= 0 { + return 0 + } + return totalPercent[0] +} + +func getCpuLoad() float64 { + // cpu的使用率 + avg, _ := load.Avg() + if avg == nil { + return 0 + } + return avg.Load1 +} + +func getDisk() dtos.SystemDisk { + // 目录 / 的磁盘使用率 + usage, _ := disk.Usage("/") + return dtos.SystemDisk{ + Path: "/", + Total: usage.Total, + Used: usage.Used, + UsedPercent: usage.UsedPercent, + } +} + +func getNetwork(ethMap map[string]*dtos.SystemNetwork) map[string]dtos.SystemNetwork { + stats := make(map[string]dtos.SystemNetwork) + + info, _ := net.IOCounters(true) + for _, v := range info { + ethName := v.Name + if !utils.CheckNetIface(ethName) { + continue + } + if v.BytesSent <= 0 && v.BytesRecv <= 0 { + continue + } + + _, ok := ethMap[ethName] + if !ok { + ethMap[ethName] = &dtos.SystemNetwork{} + } + ethItem := ethMap[ethName] + + var ( + byteRecvPre uint64 + byteSentPre uint64 + ) + + now := time.Now().Unix() + if ethItem.Last == 0 { + // 第一次采集,没有初始值,不计算 + } else { + byteRecvPre = v.BytesRecv - ethItem.BytesRecv + byteSentPre = v.BytesSent - ethItem.BytesSent + } + + item := dtos.SystemNetwork{ + Name: ethName, + BytesSent: v.BytesSent, + BytesRecv: v.BytesRecv, + BytesRecvPre: byteRecvPre, + BytesSentPre: byteSentPre, + Last: now, + } + stats[ethName] = item + ethMap[ethName] = &item + } + return stats +} + +func getOpenfiles() int { + // only linux + // https://github.com/shirou/gopsutil#process-class + processes, _ := process.Processes() + var openfiles int + for _, pid := range processes { + files, _ := pid.OpenFiles() + openfiles += len(files) + } + return openfiles +} + +func getPlatform() { + // 查看平台信息 + platform, family, version, _ := host.PlatformInformation() + fmt.Printf("platform = %v ,family = %v , version = %v \n", platform, family, version) +} diff --git a/internal/hummingbird/core/application/persistence/thinkmodel.go b/internal/hummingbird/core/application/persistence/thinkmodel.go new file mode 100644 index 0000000..5cb5faf --- /dev/null +++ b/internal/hummingbird/core/application/persistence/thinkmodel.go @@ -0,0 +1,802 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package persistence + +import ( + "context" + "encoding/json" + "errors" + "github.com/winc-link/edge-driver-proto/thingmodel" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/messagestore" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "strconv" +) + +type persistApp struct { + dic *di.Container + lc logger.LoggingClient + dbClient interfaces.DBClient + dataDbClient interfaces.DataDBClient +} + +func NewPersistApp(dic *di.Container) *persistApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + dataDbClient := resourceContainer.DataDBClientFrom(dic.Get) + pstApp := &persistApp{ + lc: lc, + dic: dic, + dbClient: dbClient, + dataDbClient: dataDbClient, + } + + return pstApp +} + +func (pst *persistApp) SaveDeviceThingModelData(req dtos.ThingModelMessage) error { + switch pst.dataDbClient.GetDataDBType() { + case constants.LevelDB: + return pst.saveDeviceThingModelToLevelDB(req) + case constants.TDengine: + return pst.saveDeviceThingModelToTdengine(req) + default: + return nil + } +} + +func (pst *persistApp) saveDeviceThingModelToLevelDB(req dtos.ThingModelMessage) error { + switch req.GetOpType() { + case thingmodel.OperationType_PROPERTY_REPORT: + propertyMsg, err := req.TransformMessageDataByProperty() + if err != nil { + return err + } + kvs := make(map[string]interface{}) + for s, data := range propertyMsg.Data { + key := generatePropertyLeveldbKey(req.Cid, s, data.Time) + value, err := data.Marshal() + if err != nil { + continue + } + kvs[key] = value + } + //批量写。 + err = pst.dataDbClient.Insert(context.Background(), "", kvs) + if err != nil { + return err + } + case thingmodel.OperationType_EVENT_REPORT: + eventMsg, err := req.TransformMessageDataByEvent() + if err != nil { + return err + } + kvs := make(map[string]interface{}) + var key string + key = generateEventLeveldbKey(req.Cid, eventMsg.Data.EventCode, eventMsg.Data.EventTime) + value, _ := eventMsg.Data.Marshal() + kvs[key] = value + //批量写。 + err = pst.dataDbClient.Insert(context.Background(), "", kvs) + if err != nil { + return err + } + case thingmodel.OperationType_SERVICE_EXECUTE: + serviceMsg, err := req.TransformMessageDataByService() + kvs := make(map[string]interface{}) + var key string + key = generateActionLeveldbKey(req.Cid, serviceMsg.Code, serviceMsg.Time) + value, _ := serviceMsg.Marshal() + kvs[key] = value + err = pst.dataDbClient.Insert(context.Background(), "", kvs) + + if err != nil { + return err + } + case thingmodel.OperationType_SERVICE_EXECUTE_RESPONSE: + serviceMsg, err := req.TransformMessageDataByServiceExec() + if err != nil { + return err + } + + device, err := pst.dbClient.DeviceById(req.Cid) + if err != nil { + return err + } + + product, err := pst.dbClient.ProductById(device.ProductId) + if err != nil { + return err + } + + var find bool + var callType constants.CallType + + for _, action := range product.Actions { + if action.Code == serviceMsg.Code { + find = true + callType = action.CallType + break + } + } + + if !find { + return errors.New("") + } + + if callType == constants.CallTypeSync { + messageStore := resourceContainer.MessageStoreItfFrom(pst.dic.Get) + ack, ok := messageStore.LoadMsgChan(serviceMsg.MsgId) + if !ok { + //可能是超时了。 + return nil + } + + if v, ok := ack.(*messagestore.MsgAckChan); ok { + v.TrySendDataAndCloseChan(serviceMsg.OutputParams) + messageStore.DeleteMsgId(serviceMsg.MsgId) + } + + } else if callType == constants.CallTypeAsync { + kvs := make(map[string]interface{}) + var key string + key = generateActionLeveldbKey(req.Cid, serviceMsg.Code, serviceMsg.Time) + value, _ := serviceMsg.Marshal() + kvs[key] = value + err = pst.dataDbClient.Insert(context.Background(), "", kvs) + + if err != nil { + return err + } + } + case thingmodel.OperationType_DATA_BATCH_REPORT: + msg, err := req.TransformMessageDataByBatchReport() + if err != nil { + return err + } + t := msg.Time + kvs := make(map[string]interface{}) + + for code, property := range msg.Data.Properties { + var data dtos.ReportData + data.Value = property.Value + data.Time = t + key := generatePropertyLeveldbKey(req.Cid, code, t) + value, err := data.Marshal() + if err != nil { + continue + } + kvs[key] = value + } + + for code, event := range msg.Data.Events { + var data dtos.EventData + data.OutputParams = event.OutputParams + data.EventTime = t + data.EventCode = code + key := generateEventLeveldbKey(req.Cid, code, t) + value, _ := data.Marshal() + kvs[key] = value + + } + //批量写。 + err = pst.dataDbClient.Insert(context.Background(), "", kvs) + + if err != nil { + return err + } + return nil + } + return nil +} + +func (pst *persistApp) saveDeviceThingModelToTdengine(req dtos.ThingModelMessage) error { + switch req.GetOpType() { + case thingmodel.OperationType_PROPERTY_REPORT: + propertyMsg, err := req.TransformMessageDataByProperty() + if err != nil { + return err + } + data := make(map[string]interface{}) + for s, reportData := range propertyMsg.Data { + data[s] = reportData.Value + } + err = pst.dataDbClient.Insert(context.Background(), constants.DB_PREFIX+req.Cid, data) + if err != nil { + return err + } + + case thingmodel.OperationType_EVENT_REPORT: + eventMsg, err := req.TransformMessageDataByEvent() + if err != nil { + return err + } + data := make(map[string]interface{}) + data[eventMsg.Data.EventCode] = eventMsg.Data + err = pst.dataDbClient.Insert(context.Background(), constants.DB_PREFIX+req.Cid, data) + if err != nil { + return err + } + + case thingmodel.OperationType_SERVICE_EXECUTE: + serviceMsg, err := req.TransformMessageDataByService() + if err != nil { + return err + } + v, _ := serviceMsg.Marshal() + data := make(map[string]interface{}) + data[serviceMsg.Code] = string(v) + err = pst.dataDbClient.Insert(context.Background(), constants.DB_PREFIX+req.Cid, data) + if err != nil { + return err + } + case thingmodel.OperationType_DATA_BATCH_REPORT: + + case thingmodel.OperationType_SERVICE_EXECUTE_RESPONSE: + serviceMsg, err := req.TransformMessageDataByServiceExec() + if err != nil { + return err + } + + device, err := pst.dbClient.DeviceById(req.Cid) + if err != nil { + return err + } + + product, err := pst.dbClient.ProductById(device.ProductId) + if err != nil { + return err + } + + var find bool + var callType constants.CallType + + for _, action := range product.Actions { + if action.Code == serviceMsg.Code { + find = true + callType = action.CallType + break + } + } + + if !find { + return errors.New("") + } + + if callType == constants.CallTypeSync { + messageStore := resourceContainer.MessageStoreItfFrom(pst.dic.Get) + ack, ok := messageStore.LoadMsgChan(serviceMsg.MsgId) + if !ok { + //可能是超时了。 + return nil + } + + if v, ok := ack.(*messagestore.MsgAckChan); ok { + v.TrySendDataAndCloseChan(serviceMsg.OutputParams) + messageStore.DeleteMsgId(serviceMsg.MsgId) + } + + } else if callType == constants.CallTypeAsync { + v, _ := serviceMsg.Marshal() + data := make(map[string]interface{}) + data[serviceMsg.Code] = string(v) + err = pst.dataDbClient.Insert(context.Background(), constants.DB_PREFIX+req.Cid, data) + if err != nil { + return err + } + } + + } + return nil +} + +func generatePropertyLeveldbKey(cid, code string, reportTime int64) string { + return cid + "-" + constants.Property + "-" + code + "-" + strconv.Itoa(int(reportTime)) +} + +func generateOncePropertyLeveldbKey(cid, code string) string { + return cid + "-" + constants.Property + "-" + code +} + +func generateEventLeveldbKey(cid, code string, reportTime int64) string { + return cid + "-" + constants.Event + "-" + code + "-" + strconv.Itoa(int(reportTime)) +} + +func generateOnceEventLeveldbKey(cid, code string) string { + return cid + "-" + constants.Event + "-" + code +} + +func generateActionLeveldbKey(cid, code string, reportTime int64) string { + return cid + "-" + constants.Action + "-" + code + "-" + strconv.Itoa(int(reportTime)) +} + +func generateOnceActionLeveldbKey(cid, code string) string { + return cid + "-" + constants.Action + "-" + code +} + +func (pst *persistApp) searchDeviceThingModelPropertyDataFromLevelDB(req dtos.ThingModelPropertyDataRequest) (interface{}, error) { + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return nil, err + } + var productInfo models.Product + response := make([]dtos.ThingModelDataResponse, 0) + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return nil, err + } + if req.Code == "" { + for _, property := range productInfo.Properties { + req.Code = property.Code + ksv, _, err := pst.dataDbClient.GetDeviceProperty(req, deviceInfo) + if err != nil { + pst.lc.Errorf("GetDeviceProperty error %+v", err) + continue + } + var reportData dtos.ReportData + if len(ksv) > 0 { + reportData = ksv[0] + } + var unit string + if property.TypeSpec.Type == constants.SpecsTypeInt || property.TypeSpec.Type == constants.SpecsTypeFloat { + var typeSpecIntOrFloat models.TypeSpecIntOrFloat + _ = json.Unmarshal([]byte(property.TypeSpec.Specs), &typeSpecIntOrFloat) + unit = typeSpecIntOrFloat.Unit + } else if property.TypeSpec.Type == constants.SpecsTypeEnum { + //enum 的单位需要特殊处理一下 + enumTypeSpec := make(map[string]string) + _ = json.Unmarshal([]byte(property.TypeSpec.Specs), &enumTypeSpec) + + for key, value := range enumTypeSpec { + s := utils.InterfaceToString(reportData.Value) + if key == s { + unit = value + } + } + } + response = append(response, dtos.ThingModelDataResponse{ + ReportData: reportData, + Code: property.Code, + DataType: string(property.TypeSpec.Type), + Name: property.Name, + Unit: unit, + AccessMode: property.AccessMode, + }) + } + } + return response, nil +} + +func (pst *persistApp) searchDeviceThingModelHistoryPropertyDataFromTDengine(req dtos.ThingModelPropertyDataRequest) (interface{}, int, error) { + var count int + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return nil, count, err + } + var productInfo models.Product + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return nil, count, err + } + var response []dtos.ReportData + + for _, property := range productInfo.Properties { + if property.Code == req.Code { + req.Code = property.Code + response, count, err = pst.dataDbClient.GetDeviceProperty(req, deviceInfo) + if err != nil { + pst.lc.Errorf("GetDeviceProperty error %+v", err) + } + var typeSpecIntOrFloat models.TypeSpecIntOrFloat + if property.TypeSpec.Type == constants.SpecsTypeInt || property.TypeSpec.Type == constants.SpecsTypeFloat { + _ = json.Unmarshal([]byte(property.TypeSpec.Specs), &typeSpecIntOrFloat) + } + + if typeSpecIntOrFloat.Unit == "" { + typeSpecIntOrFloat.Unit = "-" + } + break + } + } + return response, count, nil +} + +func (pst *persistApp) searchDeviceThingModelPropertyDataFromTDengine(req dtos.ThingModelPropertyDataRequest) (interface{}, error) { + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return nil, err + } + var productInfo models.Product + response := make([]dtos.ThingModelDataResponse, 0) + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return nil, err + } + if req.Code == "" { + for _, property := range productInfo.Properties { + req.Code = property.Code + ksv, _, err := pst.dataDbClient.GetDeviceProperty(req, deviceInfo) + if err != nil { + pst.lc.Errorf("GetDeviceProperty error %+v", err) + continue + } + var reportData dtos.ReportData + if len(ksv) > 0 { + reportData = ksv[0] + } + var unit string + if property.TypeSpec.Type == constants.SpecsTypeInt || property.TypeSpec.Type == constants.SpecsTypeFloat { + var typeSpecIntOrFloat models.TypeSpecIntOrFloat + _ = json.Unmarshal([]byte(property.TypeSpec.Specs), &typeSpecIntOrFloat) + unit = typeSpecIntOrFloat.Unit + } else if property.TypeSpec.Type == constants.SpecsTypeEnum { + //enum 的单位需要特殊处理一下 + enumTypeSpec := make(map[string]string) + _ = json.Unmarshal([]byte(property.TypeSpec.Specs), &enumTypeSpec) + //pst.lc.Info("reportDataType enumTypeSpec", enumTypeSpec) + + for key, value := range enumTypeSpec { + s := utils.InterfaceToString(reportData.Value) + if key == s { + unit = value + } + } + } + + if unit == "" { + unit = "-" + } + response = append(response, dtos.ThingModelDataResponse{ + ReportData: reportData, + Code: property.Code, + DataType: string(property.TypeSpec.Type), + Name: property.Name, + Unit: unit, + AccessMode: property.AccessMode, + }) + } + } + return response, nil +} + +func (pst *persistApp) SearchDeviceThingModelPropertyData(req dtos.ThingModelPropertyDataRequest) (interface{}, error) { + + switch pst.dataDbClient.GetDataDBType() { + case constants.LevelDB: + return pst.searchDeviceThingModelPropertyDataFromLevelDB(req) + + case constants.TDengine: + return pst.searchDeviceThingModelPropertyDataFromTDengine(req) + + default: + return make([]interface{}, 0), nil + } +} + +func (pst *persistApp) searchDeviceThingModelHistoryPropertyDataFromLevelDB(req dtos.ThingModelPropertyDataRequest) (interface{}, int, error) { + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return nil, 0, err + } + var count int + var productInfo models.Product + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return nil, 0, err + } + var response []dtos.ReportData + for _, property := range productInfo.Properties { + if property.Code == req.Code { + response, count, err = pst.dataDbClient.GetDeviceProperty(req, deviceInfo) + if err != nil { + pst.lc.Errorf("GetHistoryDeviceProperty error %+v", err) + } + break + } + } + return response, count, nil +} + +func (pst *persistApp) SearchDeviceThingModelHistoryPropertyData(req dtos.ThingModelPropertyDataRequest) (interface{}, int, error) { + switch pst.dataDbClient.GetDataDBType() { + case constants.LevelDB: + return pst.searchDeviceThingModelHistoryPropertyDataFromLevelDB(req) + case constants.TDengine: + return pst.searchDeviceThingModelHistoryPropertyDataFromTDengine(req) + } + response := make([]interface{}, 0) + return response, 0, nil +} + +func (pst *persistApp) searchDeviceThingModelServiceDataFromLevelDB(req dtos.ThingModelServiceDataRequest) ([]dtos.ThingModelServiceDataResponse, int, error) { + var count int + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return nil, count, err + } + var productInfo models.Product + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return nil, count, err + } + var response dtos.ThingModelServiceDataResponseArray + if req.Code == "" { + for _, action := range productInfo.Actions { + req.Code = action.Code + var ksv []dtos.SaveServiceIssueData + ksv, count, err = pst.dataDbClient.GetDeviceService(req, deviceInfo, productInfo) + if err != nil { + continue + } + for _, data := range ksv { + response = append(response, dtos.ThingModelServiceDataResponse{ + ServiceName: action.Name, + Code: data.Code, + InputData: data.InputParams, + OutputData: data.OutputParams, + ReportTime: data.Time, + }) + } + } + } else { + var ksv []dtos.SaveServiceIssueData + ksv, count, err = pst.dataDbClient.GetDeviceService(req, deviceInfo, productInfo) + + if err != nil { + return nil, count, err + } + for _, data := range ksv { + name := getServiceName(productInfo.Actions, data.Code) + response = append(response, dtos.ThingModelServiceDataResponse{ + ServiceName: name, + Code: data.Code, + InputData: data.InputParams, + OutputData: data.OutputParams, + ReportTime: data.Time, + }) + } + } + return response, count, nil +} + +func (pst *persistApp) searchDeviceThingModelServiceDataFromTDengine(req dtos.ThingModelServiceDataRequest) ([]dtos.ThingModelServiceDataResponse, int, error) { + + var response dtos.ThingModelServiceDataResponseArray + var count int + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return response, count, err + } + var productInfo models.Product + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return response, count, err + + } + serviceData, count, err := pst.dataDbClient.GetDeviceService(req, deviceInfo, productInfo) + + if err != nil { + return response, count, err + } + + for _, data := range serviceData { + name := getServiceName(productInfo.Actions, data.Code) + response = append(response, dtos.ThingModelServiceDataResponse{ + InputData: data.InputParams, + OutputData: data.OutputParams, + Code: data.Code, + ReportTime: data.Time, + ServiceName: name, + }) + } + + return response, count, nil + +} + +func (pst *persistApp) SearchDeviceThingModelServiceData(req dtos.ThingModelServiceDataRequest) ([]dtos.ThingModelServiceDataResponse, int, error) { + var response dtos.ThingModelServiceDataResponseArray + switch pst.dataDbClient.GetDataDBType() { + case constants.LevelDB: + return pst.searchDeviceThingModelServiceDataFromLevelDB(req) + case constants.TDengine: + return pst.searchDeviceThingModelServiceDataFromTDengine(req) + } + return response, 0, nil +} + +func (pst *persistApp) searchDeviceThingModelEventDataFromLevelDB(req dtos.ThingModelEventDataRequest) (dtos.ThingModelEventDataResponseArray, int, error) { + var count int + + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return nil, count, err + } + var response dtos.ThingModelEventDataResponseArray + var productInfo models.Product + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return nil, count, err + } + var ksv []dtos.EventData + ksv, count, err = pst.dataDbClient.GetDeviceEvent(req, deviceInfo, productInfo) + if err != nil { + return nil, count, err + } + for _, data := range ksv { + var eventType string + var name string + eventType, name = getEventTypeAndName(productInfo.Events, data.EventCode) + response = append(response, dtos.ThingModelEventDataResponse{ + EventCode: data.EventCode, + EventType: eventType, + OutputData: data.OutputParams, + ReportTime: data.EventTime, + Name: name, + }) + } + return response, count, nil +} + +func (pst *persistApp) searchDeviceThingModelEventDataFromTDengine(req dtos.ThingModelEventDataRequest) (dtos.ThingModelEventDataResponseArray, int, error) { + var response dtos.ThingModelEventDataResponseArray + var count int + deviceInfo, err := pst.dbClient.DeviceById(req.DeviceId) + if err != nil { + return response, count, err + } + var productInfo models.Product + productInfo, err = pst.dbClient.ProductById(deviceInfo.ProductId) + if err != nil { + return response, count, err + + } + eventData, count, err := pst.dataDbClient.GetDeviceEvent(req, deviceInfo, productInfo) + + if err != nil { + return response, count, err + } + + for _, data := range eventData { + var eventType string + var name string + eventType, name = getEventTypeAndName(productInfo.Events, data.EventCode) + response = append(response, dtos.ThingModelEventDataResponse{ + EventCode: data.EventCode, + EventType: eventType, + Name: name, + OutputData: data.OutputParams, + ReportTime: data.EventTime, + }) + + } + return response, count, nil +} + +func (pst *persistApp) SearchDeviceThingModelEventData(req dtos.ThingModelEventDataRequest) ([]dtos.ThingModelEventDataResponse, int, error) { + var response dtos.ThingModelEventDataResponseArray + switch pst.dataDbClient.GetDataDBType() { + case constants.LevelDB: + return pst.searchDeviceThingModelEventDataFromLevelDB(req) + case constants.TDengine: + return pst.searchDeviceThingModelEventDataFromTDengine(req) + } + return response, 0, nil + +} + +func (pst *persistApp) searchDeviceMsgCountFromLevelDB(startTime, endTime int64) (int, error) { + var ( + count int + err error + ) + + devices, _, err := pst.dbClient.DevicesSearch(0, -1, dtos.DeviceSearchQueryRequest{}) + if err != nil { + return 0, err + } + + for _, device := range devices { + product, err := pst.dbClient.ProductById(device.ProductId) + if err != nil { + pst.lc.Errorf("search product:", err) + } + for _, property := range product.Properties { + var req dtos.ThingModelPropertyDataRequest + req.DeviceId = device.Id + req.Code = property.Code + req.Range = append(req.Range, startTime, endTime) + propertyCount, err := pst.dataDbClient.GetDevicePropertyCount(req) + if err != nil { + return 0, err + } + count += propertyCount + } + + for _, event := range product.Events { + var req dtos.ThingModelEventDataRequest + req.DeviceId = device.Id + req.EventCode = event.Code + req.Range = append(req.Range, startTime, endTime) + eventCount, err := pst.dataDbClient.GetDeviceEventCount(req) + if err != nil { + return 0, err + } + count += eventCount + } + } + + return count, err + +} + +func (pst *persistApp) searchDeviceMsgCountFromTDengine(startTime, endTime int64) (int, error) { + var ( + count int + err error + ) + + devices, _, err := pst.dbClient.DevicesSearch(0, -1, dtos.DeviceSearchQueryRequest{}) + if err != nil { + return 0, err + } + + for _, device := range devices { + msgCount, err := pst.dataDbClient.GetDeviceMsgCountByGiveTime(device.Id, startTime, endTime) + if err != nil { + return 0, err + } + count += msgCount + } + return 0, nil +} + +// SearchDeviceMsgCount 统计设备的消息总数(属性、事件都算在内) +func (pst *persistApp) SearchDeviceMsgCount(startTime, endTime int64) (int, error) { + + switch pst.dataDbClient.GetDataDBType() { + case constants.LevelDB: + return pst.searchDeviceMsgCountFromLevelDB(startTime, endTime) + case constants.TDengine: + return pst.searchDeviceMsgCountFromTDengine(startTime, endTime) + } + + return 0, nil +} + +func getEventTypeAndName(events []models.Events, code string) (string, string) { + for _, event := range events { + if event.Code == code { + return event.EventType, event.Name + } + } + return "", "" +} + +func getServiceName(events []models.Actions, code string) string { + for _, event := range events { + if event.Code == code { + return event.Name + } + } + return "" +} diff --git a/internal/hummingbird/core/application/productapp/productapp.go b/internal/hummingbird/core/application/productapp/productapp.go new file mode 100644 index 0000000..22f6f81 --- /dev/null +++ b/internal/hummingbird/core/application/productapp/productapp.go @@ -0,0 +1,309 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package productapp + +import ( + "context" + "errors" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" +) + +type productApp struct { + //*propertyTyApp + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func NewProductApp(ctx context.Context, dic *di.Container) interfaces.ProductItf { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + return &productApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} + +func (p *productApp) ProductsSearch(ctx context.Context, req dtos.ProductSearchQueryRequest) ([]dtos.ProductSearchQueryResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + resp, total, err := p.dbClient.ProductsSearch(offset, limit, false, req) + if err != nil { + return []dtos.ProductSearchQueryResponse{}, 0, err + } + products := make([]dtos.ProductSearchQueryResponse, len(resp)) + for i, p := range resp { + products[i] = dtos.ProductResponseFromModel(p) + } + return products, total, nil +} + +func (p *productApp) ProductsModelSearch(ctx context.Context, req dtos.ProductSearchQueryRequest) ([]models.Product, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + return p.dbClient.ProductsSearch(offset, limit, true, req) + +} + +func (p *productApp) ProductById(ctx context.Context, id string) (dtos.ProductSearchByIdResponse, error) { + resp, err := p.dbClient.ProductById(id) + if err != nil { + return dtos.ProductSearchByIdResponse{}, err + } + return dtos.ProductSearchByIdFromModel(resp), nil +} + +func (p *productApp) OpenApiProductById(ctx context.Context, id string) (dtos.ProductSearchByIdOpenApiResponse, error) { + resp, err := p.dbClient.ProductById(id) + if err != nil { + return dtos.ProductSearchByIdOpenApiResponse{}, err + } + return dtos.ProductSearchByIdOpenApiFromModel(resp), nil +} + +func (p *productApp) OpenApiProductSearch(ctx context.Context, req dtos.ProductSearchQueryRequest) ([]dtos.ProductSearchOpenApiResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + resp, total, err := p.dbClient.ProductsSearch(offset, limit, false, req) + if err != nil { + return []dtos.ProductSearchOpenApiResponse{}, 0, err + } + products := make([]dtos.ProductSearchOpenApiResponse, len(resp)) + for i, product := range resp { + products[i] = dtos.ProductSearchOpenApiFromModel(product) + } + return products, total, nil +} + +func (p *productApp) ProductModelById(ctx context.Context, id string) (models.Product, error) { + resp, err := p.dbClient.ProductById(id) + if err != nil { + return models.Product{}, err + } + return resp, nil + +} + +func (p *productApp) ProductDelete(ctx context.Context, id string) error { + productInfo, err := p.dbClient.ProductById(id) + if err != nil { + return err + } + _, total, err := p.dbClient.DevicesSearch(0, -1, dtos.DeviceSearchQueryRequest{ProductId: productInfo.Id}) + + if err != nil { + return err + } + if total > 0 { + return errort.NewCommonEdgeX(errort.ProductMustDeleteDevice, "该产品已绑定子设备,请优先删除子设备", err) + } + alertApp := resourceContainer.AlertRuleAppNameFrom(p.dic.Get) + err = alertApp.CheckRuleByProductId(ctx, id) + if err != nil { + return err + } + if err = p.dbClient.AssociationsDeleteProductObject(productInfo); err != nil { + return err + } + _ = resourceContainer.DataDBClientFrom(p.dic.Get).DropStable(ctx, productInfo.Id) + go func() { + p.DeleteProductCallBack(models.Product{ + Id: productInfo.Id, + Platform: productInfo.Platform, + }) + }() + return nil +} + +func (p *productApp) AddProduct(ctx context.Context, req dtos.ProductAddRequest) (productId string, err error) { + // 标准品类 + var properties []models.Properties + var events []models.Events + var actions []models.Actions + + if req.CategoryTemplateId != "1" { + categoryTempInfo, err := p.dbClient.CategoryTemplateById(req.CategoryTemplateId) + if err != nil { + return "", err + } + thingModelTemplateInfo, err := p.dbClient.ThingModelTemplateByCategoryKey(categoryTempInfo.CategoryKey) + if err != nil { + return "", err + } + if thingModelTemplateInfo.ThingModelJSON != "" { + properties, events, actions = dtos.GetModelPropertyEventActionByThingModelTemplate(thingModelTemplateInfo.ThingModelJSON) + } + } + + var insertProduct models.Product + insertProduct.Id = utils.RandomNum() + insertProduct.Name = req.Name + insertProduct.CloudProductId = utils.GenerateDeviceSecret(15) + insertProduct.Platform = constants.IotPlatform_LocalIot + insertProduct.Protocol = req.Protocol + insertProduct.NodeType = constants.ProductNodeType(req.NodeType) + insertProduct.NetType = constants.ProductNetType(req.NetType) + insertProduct.DataFormat = req.DataFormat + insertProduct.Factory = req.Factory + insertProduct.Description = req.Description + insertProduct.Key = req.Key + insertProduct.Status = constants.ProductUnRelease + insertProduct.Properties = properties + insertProduct.Events = events + insertProduct.Actions = actions + + ps, err := p.dbClient.AddProduct(insertProduct) + if err != nil { + return "", err + } + go func() { + p.CreateProductCallBack(insertProduct) + }() + return ps.Id, nil +} + +func (p *productApp) ProductRelease(ctx context.Context, productId string) error { + + var err error + var productInfo models.Product + + productInfo, err = p.dbClient.ProductById(productId) + + if err != nil { + return err + } + if productInfo.Status == constants.ProductRelease { + return errors.New("") + } + + err = resourceContainer.DataDBClientFrom(p.dic.Get).CreateStable(ctx, productInfo) + if err != nil { + return err + } + productInfo.Status = constants.ProductRelease + return p.dbClient.UpdateProduct(productInfo) +} + +func (p *productApp) ProductUnRelease(ctx context.Context, productId string) error { + var err error + var productInfo models.Product + + productInfo, err = p.dbClient.ProductById(productId) + + if err != nil { + return err + } + if productInfo.Status == constants.ProductUnRelease { + return errors.New("") + } + productInfo.Status = constants.ProductUnRelease + return p.dbClient.UpdateProduct(productInfo) +} + +func (p *productApp) OpenApiAddProduct(ctx context.Context, req dtos.OpenApiAddProductRequest) (productId string, err error) { + //var properties []models.Properties + //var events []models.Events + //var actions []models.Actions + //properties, events, actions = dtos.OpenApiGetModelPropertyEventActionByThingModelTemplate(req) + + //if len(properties) == 0 { + // properties = make([]models.Properties, 0) + //} + // + //if len(events) == 0 { + // events = make([]models.Events, 0) + //} + // + //if len(actions) == 0 { + // actions = make([]models.Actions, 0) + //} + var insertProduct models.Product + insertProduct.Name = req.Name + insertProduct.CloudProductId = utils.GenerateDeviceSecret(15) + insertProduct.Platform = constants.IotPlatform_LocalIot + insertProduct.Protocol = req.Protocol + insertProduct.NodeType = constants.ProductNodeType(req.NodeType) + insertProduct.NetType = constants.ProductNetType(req.NetType) + insertProduct.DataFormat = req.DataFormat + insertProduct.Factory = req.Factory + insertProduct.Description = req.Description + //insertProduct.Properties = properties + //insertProduct.Events = events + //insertProduct.Actions = actions + + ps, err := p.dbClient.AddProduct(insertProduct) + if err != nil { + return "", err + } + go func() { + p.CreateProductCallBack(insertProduct) + }() + return ps.Id, nil +} + +func (p *productApp) OpenApiUpdateProduct(ctx context.Context, req dtos.OpenApiUpdateProductRequest) error { + + product, err := p.dbClient.ProductById(req.Id) + if err != nil { + return err + } + if req.Name != nil { + product.Name = *req.Name + } + product.Platform = constants.IotPlatform_LocalIot + if req.Protocol != nil { + product.Protocol = *req.Protocol + } + + if req.NetType != nil { + product.NodeType = constants.ProductNodeType(*req.NodeType) + } + + if req.NetType != nil { + product.NetType = constants.ProductNetType(*req.NetType) + } + + if req.DataFormat != nil { + product.DataFormat = *req.DataFormat + } + + if req.Factory != nil { + product.Factory = *req.Factory + } + + if req.Description != nil { + product.Description = *req.Description + } + + err = p.dbClient.UpdateProduct(product) + if err != nil { + return err + } + go func() { + p.UpdateProductCallBack(product) + }() + return nil +} diff --git a/internal/hummingbird/core/application/productapp/productcallback.go b/internal/hummingbird/core/application/productapp/productcallback.go new file mode 100644 index 0000000..91c05a6 --- /dev/null +++ b/internal/hummingbird/core/application/productapp/productcallback.go @@ -0,0 +1,118 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package productapp + +import ( + "context" + "github.com/winc-link/edge-driver-proto/productcallback" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "time" + + //"github.com/winc-link/hummingbird/internal/hummingbird/core/application/driverapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/tools/rpcclient" +) + +func (p *productApp) CreateProductCallBack(productInfo models.Product) { + deviceServices, total, err := p.dbClient.DeviceServicesSearch(0, -1, dtos.DeviceServiceSearchQueryRequest{}) + if err != nil { + return + } + if total == 0 { + return + } + for _, service := range deviceServices { + if productInfo.Platform == service.Platform { + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(service.Id) + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(service.BaseAddress, false, "", service.Id, p.lc) + if errX != nil { + return + } + defer client.Close() + var rpcRequest productcallback.CreateProductCallbackRequest + rpcRequest.Data = productInfo.TransformToDriverProduct() + rpcRequest.HappenTime = uint64(time.Now().Unix()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, _ = client.ProductCallBackServiceClient.CreateProductCallback(ctx, &rpcRequest) + } + } + } +} + +func (p *productApp) UpdateProductCallBack(productInfo models.Product) { + //p.lc.Infof("UpdateProductCallBack Platform :%s name :%s id :%s", productInfo.Platform, productInfo.Name, productInfo.Id) + deviceServices, total, err := p.dbClient.DeviceServicesSearch(0, -1, dtos.DeviceServiceSearchQueryRequest{}) + if err != nil { + return + } + if total == 0 { + return + } + for _, service := range deviceServices { + if productInfo.Platform == service.Platform { + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(service.Id) + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(service.BaseAddress, false, "", service.Id, p.lc) + if errX != nil { + return + } + defer client.Close() + var rpcRequest productcallback.UpdateProductCallbackRequest + rpcRequest.Data = productInfo.TransformToDriverProduct() + rpcRequest.HappenTime = uint64(time.Now().Second()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, _ = client.ProductCallBackServiceClient.UpdateProductCallback(ctx, &rpcRequest) + } + } + } +} + +func (p *productApp) DeleteProductCallBack(productInfo models.Product) { + //p.lc.Infof("DeleteProductCallBack Platform :%s name :%s id :%s", productInfo.Platform, productInfo.Name, productInfo.Id) + deviceServices, total, err := p.dbClient.DeviceServicesSearch(0, -1, dtos.DeviceServiceSearchQueryRequest{}) + if total == 0 { + return + } + if err != nil { + return + } + + for _, service := range deviceServices { + if productInfo.Platform == service.Platform { + driverService := container.DriverServiceAppFrom(di.GContainer.Get) + status := driverService.GetState(service.Id) + if status == constants.RunStatusStarted { + client, errX := rpcclient.NewDriverRpcClient(service.BaseAddress, false, "", service.Id, p.lc) + if errX != nil { + return + } + defer client.Close() + var rpcRequest productcallback.DeleteProductCallbackRequest + rpcRequest.ProductId = productInfo.Id + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, _ = client.ProductCallBackServiceClient.DeleteProductCallback(ctx, &rpcRequest) + } + } + } +} diff --git a/internal/hummingbird/core/application/quicknavigationapp/quciknavigation.go b/internal/hummingbird/core/application/quicknavigationapp/quciknavigation.go new file mode 100644 index 0000000..e70dd73 --- /dev/null +++ b/internal/hummingbird/core/application/quicknavigationapp/quciknavigation.go @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package quicknavigationapp + +import ( + "context" + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" +) + +type quickNavigationApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func (m quickNavigationApp) SyncQuickNavigation(ctx context.Context, versionName string) (int64, error) { + filePath := versionName + "/quick_navigation.json" + cosApp := resourceContainer.CosAppNameFrom(m.dic.Get) + bs, err := cosApp.Get(filePath) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + var cosQuickNavigationTemplateResponse []dtos.CosQuickNavigationTemplateResponse + err = json.Unmarshal(bs, &cosQuickNavigationTemplateResponse) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + + baseQuery := dtos.BaseSearchConditionQuery{ + IsAll: true, + } + dbreq := dtos.QuickNavigationSearchQueryRequest{BaseSearchConditionQuery: baseQuery} + quickNavigations, _, err := m.dbClient.QuickNavigationSearch(0, -1, dbreq) + if err != nil { + return 0, err + } + + var cosQuickNavigationName []string + upsertQuickNavigationTemplate := make([]models.QuickNavigation, 0) + for _, cosQuickNavigationTemplate := range cosQuickNavigationTemplateResponse { + cosQuickNavigationName = append(cosQuickNavigationName, cosQuickNavigationTemplate.Name) + var find bool + for _, localQuickNavigation := range quickNavigations { + if cosQuickNavigationTemplate.Name == localQuickNavigation.Name { + upsertQuickNavigationTemplate = append(upsertQuickNavigationTemplate, models.QuickNavigation{ + Id: localQuickNavigation.Id, + Name: cosQuickNavigationTemplate.Name, + Icon: cosQuickNavigationTemplate.Icon, + Sort: cosQuickNavigationTemplate.Sort, + JumpLink: cosQuickNavigationTemplate.JumpLink, + }) + find = true + break + } + } + if !find { + upsertQuickNavigationTemplate = append(upsertQuickNavigationTemplate, models.QuickNavigation{ + Id: utils.RandomNum(), + Name: cosQuickNavigationTemplate.Name, + Icon: cosQuickNavigationTemplate.Icon, + Sort: cosQuickNavigationTemplate.Sort, + JumpLink: cosQuickNavigationTemplate.JumpLink, + }) + } + } + rows, err := m.dbClient.BatchUpsertQuickNavigationTemplate(upsertQuickNavigationTemplate) + if err != nil { + return 0, err + } + + for _, navigation := range quickNavigations { + if !utils.InStringSlice(navigation.Name, cosQuickNavigationName) { + err = m.dbClient.DeleteQuickNavigation(navigation.Id) + if err != nil { + return 0, err + } + } + } + + return rows, nil +} + +func NewQuickNavigationApp(ctx context.Context, dic *di.Container) interfaces.QuickNavigation { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + return &quickNavigationApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} diff --git a/internal/hummingbird/core/application/ruleengine/monitor.go b/internal/hummingbird/core/application/ruleengine/monitor.go new file mode 100644 index 0000000..7526b39 --- /dev/null +++ b/internal/hummingbird/core/application/ruleengine/monitor.go @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package ruleengine + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "time" +) + +func (p ruleEngineApp) monitor() { + tickTime := time.Second * 5 + timeTickerChan := time.Tick(tickTime) + for { + select { + case <-timeTickerChan: + p.checkRuleStatus() + } + } +} + +func (p ruleEngineApp) checkRuleStatus() { + ruleEngines, _, err := p.dbClient.RuleEngineSearch(0, -1, dtos.RuleEngineSearchQueryRequest{}) + if err != nil { + p.lc.Errorf("get engines err:", err) + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + for _, ruleEngine := range ruleEngines { + resp, err := ekuiperApp.GetRuleStats(context.Background(), ruleEngine.Id) + if err != nil { + p.lc.Errorf("error:", err) + continue + } + status, ok := resp["status"] + if ok { + if status != string(ruleEngine.Status) { + if status == string(constants.RuleEngineStop) { + p.dbClient.RuleEngineStop(ruleEngine.Id) + } else if status == string(constants.RuleEngineStart) { + p.dbClient.RuleEngineStart(ruleEngine.Id) + } + } + } + } +} diff --git a/internal/hummingbird/core/application/ruleengine/ruleengineapp.go b/internal/hummingbird/core/application/ruleengine/ruleengineapp.go new file mode 100644 index 0000000..19d4dd0 --- /dev/null +++ b/internal/hummingbird/core/application/ruleengine/ruleengineapp.go @@ -0,0 +1,245 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package ruleengine + +import ( + "context" + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" +) + +type ruleEngineApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func (p ruleEngineApp) AddRuleEngine(ctx context.Context, req dtos.RuleEngineRequest) (string, error) { + dataResource, err := p.dbClient.DataResourceById(req.DataResourceId) + if err != nil { + return "", err + } + randomId := utils.RandomNum() + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + + sql := req.BuildEkuiperSql() + var actions []dtos.Actions + switch dataResource.Type { + case constants.HttpResource: + actions = append(actions, dtos.Actions{ + Rest: dataResource.Option, + }) + case constants.MQTTResource: + actions = append(actions, dtos.Actions{ + MQTT: dataResource.Option, + }) + case constants.KafkaResource: + actions = append(actions, dtos.Actions{ + Kafka: dataResource.Option, + }) + case constants.InfluxDBResource: + actions = append(actions, dtos.Actions{ + Influx: dataResource.Option, + }) + case constants.TDengineResource: + actions = append(actions, dtos.Actions{ + Tdengine: dataResource.Option, + }) + default: + return "", errort.NewCommonErr(errort.DefaultReqParamsError, fmt.Errorf("rule engine action not much")) + } + if err = ekuiperApp.CreateRule(ctx, actions, randomId, sql); err != nil { + return "", err + } + + var insertRuleEngine models.RuleEngine + insertRuleEngine.Name = req.Name + insertRuleEngine.Id = randomId + insertRuleEngine.Description = req.Description + insertRuleEngine.Filter = models.Filter(req.Filter) + insertRuleEngine.DataResourceId = req.DataResourceId + insertRuleEngine.Status = constants.RuleEngineStop + id, err := p.dbClient.AddRuleEngine(insertRuleEngine) + if err != nil { + return "", err + } + return id, nil +} + +func (p ruleEngineApp) UpdateRuleEngine(ctx context.Context, req dtos.RuleEngineUpdateRequest) error { + dataResource, err := p.dbClient.DataResourceById(*req.DataResourceId) + if err != nil { + return err + } + ruleEngine, err := p.dbClient.RuleEngineById(req.Id) + if err != nil { + return err + } + sql := req.BuildEkuiperSql() + var actions []dtos.Actions + switch dataResource.Type { + case constants.HttpResource: + actions = append(actions, dtos.Actions{ + Rest: dataResource.Option, + }) + case constants.MQTTResource: + actions = append(actions, dtos.Actions{ + MQTT: dataResource.Option, + }) + case constants.KafkaResource: + actions = append(actions, dtos.Actions{ + Kafka: dataResource.Option, + }) + case constants.InfluxDBResource: + actions = append(actions, dtos.Actions{ + Influx: dataResource.Option, + }) + case constants.TDengineResource: + actions = append(actions, dtos.Actions{ + Tdengine: dataResource.Option, + }) + default: + return errort.NewCommonErr(errort.DefaultReqParamsError, fmt.Errorf("rule engine action not much")) + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + if err = ekuiperApp.UpdateRule(ctx, actions, req.Id, sql); err != nil { + return err + } + dtos.ReplaceRuleEngineModelFields(&ruleEngine, req) + err = p.dbClient.UpdateRuleEngine(ruleEngine) + if err != nil { + return err + } + return nil +} + +func (p ruleEngineApp) UpdateRuleEngineField(ctx context.Context, req dtos.RuleEngineFieldUpdateRequest) error { + //TODO implement me + panic("implement me") +} + +func (p ruleEngineApp) RuleEngineById(ctx context.Context, id string) (dtos.RuleEngineResponse, error) { + ruleEngine, err := p.dbClient.RuleEngineById(id) + var ruleEngineResponse dtos.RuleEngineResponse + if err != nil { + return ruleEngineResponse, err + } + ruleEngineResponse.Id = ruleEngine.Id + ruleEngineResponse.Name = ruleEngine.Name + ruleEngineResponse.Description = ruleEngine.Description + ruleEngineResponse.Created = ruleEngine.Created + ruleEngineResponse.Filter = dtos.Filter(ruleEngine.Filter) + ruleEngineResponse.DataResourceId = ruleEngine.DataResourceId + ruleEngineResponse.DataResource = dtos.DataResourceInfo{ + Name: ruleEngine.DataResource.Name, + Type: string(ruleEngine.DataResource.Type), + Option: ruleEngine.DataResource.Option, + } + return ruleEngineResponse, nil +} + +func (p ruleEngineApp) RuleEngineSearch(ctx context.Context, req dtos.RuleEngineSearchQueryRequest) ([]dtos.RuleEngineSearchQueryResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.RuleEngineSearch(offset, limit, req) + if err != nil { + return []dtos.RuleEngineSearchQueryResponse{}, 0, err + } + ruleEngines := make([]dtos.RuleEngineSearchQueryResponse, len(resp)) + for i, p := range resp { + ruleEngines[i] = dtos.RuleEngineSearchQueryResponseFromModel(p) + } + return ruleEngines, total, nil +} + +func (p ruleEngineApp) RuleEngineDelete(ctx context.Context, id string) error { + _, err := p.dbClient.RuleEngineById(id) + if err != nil { + return err + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.DeleteRule(ctx, id) + if err != nil { + return err + } + return p.dbClient.DeleteRuleEngineById(id) +} + +func (p ruleEngineApp) RuleEngineStop(ctx context.Context, id string) error { + _, err := p.dbClient.RuleEngineById(id) + if err != nil { + return err + } + //if alertRule.EkuiperRule() { + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.StopRule(ctx, id) + if err != nil { + return err + } + return p.dbClient.RuleEngineStop(id) +} + +func (p ruleEngineApp) RuleEngineStart(ctx context.Context, id string) error { + ruleEngine, err := p.dbClient.RuleEngineById(id) + if err != nil { + return err + } + dataResource, err := p.dbClient.DataResourceById(ruleEngine.DataResourceId) + if err != nil { + return err + } + if dataResource.Health != true { + return errort.NewCommonErr(errort.InvalidSource, fmt.Errorf("invalid resource configuration, please check the resource configuration resource id (%s)", dataResource.Id)) + } + + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.StartRule(ctx, id) + if err != nil { + return err + } + //} + return p.dbClient.RuleEngineStart(id) +} + +func (p ruleEngineApp) RuleEngineStatus(ctx context.Context, id string) (map[string]interface{}, error) { + response := make(map[string]interface{}, 0) + _, err := p.dbClient.RuleEngineById(id) + if err != nil { + return response, err + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + return ekuiperApp.GetRuleStats(ctx, id) +} + +func NewRuleEngineApp(ctx context.Context, dic *di.Container) interfaces.RuleEngineApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + app := &ruleEngineApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } + go app.monitor() + return app +} diff --git a/internal/hummingbird/core/application/scene/monitor.go b/internal/hummingbird/core/application/scene/monitor.go new file mode 100644 index 0000000..be89afa --- /dev/null +++ b/internal/hummingbird/core/application/scene/monitor.go @@ -0,0 +1,65 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package scene + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "time" +) + +func (p sceneApp) monitor() { + tickTime := time.Second * 9 + timeTickerChan := time.Tick(tickTime) + for { + select { + case <-timeTickerChan: + p.checkSceneRuleStatus() + } + } +} + +func (p sceneApp) checkSceneRuleStatus() { + scenes, _, err := p.dbClient.SceneSearch(0, -1, dtos.SceneSearchQueryRequest{}) + if err != nil { + p.lc.Errorf("get engines err:", err) + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + for _, scene := range scenes { + if len(scene.Conditions) != 1 { + continue + } + if scene.Conditions[0].ConditionType != "notify" { + continue + } + resp, err := ekuiperApp.GetRuleStats(context.Background(), scene.Id) + if err != nil { + p.lc.Errorf("error:", err) + continue + } + status, ok := resp["status"] + if ok { + if status != string(scene.Status) { + if status == string(constants.SceneStart) { + p.dbClient.SceneStart(scene.Id) + } else if status == string(constants.SceneStop) { + p.dbClient.SceneStop(scene.Id) + } + } + } + } +} diff --git a/internal/hummingbird/core/application/scene/sceneapp.go b/internal/hummingbird/core/application/scene/sceneapp.go new file mode 100644 index 0000000..13a6fcf --- /dev/null +++ b/internal/hummingbird/core/application/scene/sceneapp.go @@ -0,0 +1,628 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package scene + +import ( + "context" + "errors" + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "strconv" + "strings" +) + +type sceneApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func NewSceneApp(ctx context.Context, dic *di.Container) interfaces.SceneApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + app := &sceneApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } + go app.monitor() + return app +} + +func (p sceneApp) AddScene(ctx context.Context, req dtos.SceneAddRequest) (string, error) { + var scene models.Scene + scene.Name = req.Name + scene.Description = req.Description + resp, err := p.dbClient.AddScene(scene) + if err != nil { + return "", err + } + return resp.Id, nil +} + +func (p sceneApp) UpdateScene(ctx context.Context, req dtos.SceneUpdateRequest) error { + if req.Id == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "update req id is required", nil) + } + scene, edgeXErr := p.dbClient.SceneById(req.Id) + if edgeXErr != nil { + return edgeXErr + } + if len(req.Conditions) != 1 { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "conditions len not eq 1", nil) + } + switch req.Conditions[0].ConditionType { + case "timer": + if scene.Status == constants.SceneStart { + return errort.NewCommonEdgeX(errort.SceneTimerIsStartingNotAllowUpdate, "Please stop this scheduled"+ + " tasks before editing it.", nil) + } + case "notify": + if req.Conditions[0].Option == nil { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "condition option is null", nil) + } + actions, sql, err := p.buildEkuiperSqlAndAction(req) + if err != nil { + return err + } + p.lc.Infof("sql:", sql) + + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + + exist, err := ekuiperApp.RuleExist(ctx, scene.Id) + if err != nil { + return err + } + if exist { + err = ekuiperApp.UpdateRule(ctx, actions, scene.Id, sql) + if err != nil { + return err + } + } else { + err = ekuiperApp.CreateRule(ctx, actions, scene.Id, sql) + if err != nil { + return err + } + } + } + dtos.ReplaceSceneModelFields(&scene, req) + edgeXErr = p.dbClient.UpdateScene(scene) + if edgeXErr != nil { + return edgeXErr + } + return nil +} + +func (p sceneApp) SceneById(ctx context.Context, sceneId string) (models.Scene, error) { + return p.dbClient.SceneById(sceneId) +} + +func (p sceneApp) SceneStartById(ctx context.Context, sceneId string) error { + scene, err := p.dbClient.SceneById(sceneId) + if err != nil { + return err + } + if len(scene.Conditions) == 0 { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) conditionType param errror", scene.Id)) + } + + switch scene.Conditions[0].ConditionType { + case "timer": + tmpJob, errJob := scene.ToRuntimeJob() + if errJob != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, errJob.Error(), errJob) + } + p.lc.Infof("tmpJob: %v", tmpJob) + conJobApp := resourceContainer.ConJobAppNameFrom(p.dic.Get) + err = conJobApp.AddJobToRunQueue(tmpJob) + if err != nil { + return err + } + case "notify": + if scene.Conditions[0].Option == nil { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "condition option is null", nil) + } + if err = p.checkAlertRuleParam(ctx, scene, "start"); err != nil { + return err + } + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.StartRule(ctx, sceneId) + if err != nil { + return err + } + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "condition Type value not much", nil) + } + return p.dbClient.SceneStart(sceneId) +} + +func (p sceneApp) SceneStopById(ctx context.Context, sceneId string) error { + scene, err := p.dbClient.SceneById(sceneId) + if err != nil { + return err + } + switch scene.Conditions[0].ConditionType { + case "timer": + conJobApp := resourceContainer.ConJobAppNameFrom(p.dic.Get) + conJobApp.DeleteJob(scene.Id) + case "notify": + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.StopRule(ctx, sceneId) + if err != nil { + return err + } + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "condition Type value not much", nil) + } + + return p.dbClient.SceneStop(sceneId) +} + +func (p sceneApp) DelSceneById(ctx context.Context, sceneId string) error { + scene, err := p.dbClient.SceneById(sceneId) + if err != nil { + return err + } + + if len(scene.Conditions) == 0 { + return p.dbClient.DeleteSceneById(sceneId) + //return errort.NewCommonEdgeX(errort.DefaultSystemError, "conditions param error", nil) + } + + switch scene.Conditions[0].ConditionType { + case "timer": + conJobApp := resourceContainer.ConJobAppNameFrom(p.dic.Get) + conJobApp.DeleteJob(scene.Id) + case "notify": + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + err = ekuiperApp.DeleteRule(ctx, sceneId) + if err != nil { + return err + } + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "condition Type value not much", nil) + } + return p.dbClient.DeleteSceneById(sceneId) +} + +func (p sceneApp) SceneSearch(ctx context.Context, req dtos.SceneSearchQueryRequest) ([]models.Scene, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.SceneSearch(offset, limit, req) + if err != nil { + return []models.Scene{}, 0, err + } + + return resp, total, nil +} + +func (p sceneApp) SceneLogSearch(ctx context.Context, req dtos.SceneLogSearchQueryRequest) ([]models.SceneLog, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + resp, total, err := p.dbClient.SceneLogSearch(offset, limit, req) + if err != nil { + return []models.SceneLog{}, 0, err + } + return resp, total, nil +} + +func (p sceneApp) buildEkuiperSqlAndAction(req dtos.SceneUpdateRequest) (actions []dtos.Actions, sql string, err error) { + configapp := resourceContainer.ConfigurationFrom(p.dic.Get) + actions = dtos.GetRuleSceneEkuiperActions(configapp.Service.Url()) + option := req.Conditions[0].Option + deviceId := option["device_id"] + deviceName := option["device_name"] + productId := option["product_id"] + productName := option["product_name"] + trigger := option["trigger"] + code := option["code"] + if deviceId == "" || deviceName == "" || productId == "" || productName == "" || code == "" || trigger == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required parameter missing", nil) + return + } + device, err := p.dbClient.DeviceById(deviceId) + if err != nil { + return + } + product, err := p.dbClient.ProductById(productId) + if err != nil { + return + } + if device.ProductId != product.Id { + err = errort.NewCommonEdgeX(errort.DefaultSystemError, "", nil) + return + } + + switch trigger { + case string(constants.DeviceDataTrigger): + var codeFind bool + for _, property := range product.Properties { + if code == property.Code { + codeFind = true + if !property.TypeSpec.Type.AllowSendInEkuiper() { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required parameter missing", nil) + } + + var s int + switch option["value_cycle"] { + case "1分钟周期": + s = 60 + case "5分钟周期": + s = 60 * 5 + case "15分钟周期": + s = 60 * 15 + case "30分钟周期": + s = 60 * 30 + case "60分钟周期": + s = 60 * 60 + default: + } + + switch property.TypeSpec.Type { + + case constants.SpecsTypeInt, constants.SpecsTypeFloat: + valueType := option["value_type"] + if valueType == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required value_type parameter missing", nil) + return + } + switch valueType { + case constants.Original: //原始值 + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + originalTemp := `SELECT rule_id(),json_path_query(data, "$.%s.time") as report_time ,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true and json_path_query(data, "$.%s.value") %s` + sql = fmt.Sprintf(originalTemp, code, deviceId, code, code, decideCondition) + return + case constants.Max: + sqlTemp := `SELECT window_start(),window_end(),rule_id(),deviceId,max(json_path_query(data, "$.%s.value")) as max_%s FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true GROUP BY %s HAVING max_%s %s` + valueCycle := s + if valueCycle == 0 { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required value_cycle parameter missing", nil) + return + } + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + sql = fmt.Sprintf(sqlTemp, code, code, device.Id, code, fmt.Sprintf("TUMBLINGWINDOW(ss, %d)", valueCycle), code, decideCondition) + return + case constants.Min: + sqlTemp := `SELECT window_start(),window_end(),rule_id(),deviceId,min(json_path_query(data, "$.%s.value")) as min_%s FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true GROUP BY %s HAVING min_%s %s` + valueCycle := s + if valueCycle == 0 { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required value_cycle parameter missing", nil) + return + } + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + sql = fmt.Sprintf(sqlTemp, code, code, device.Id, code, fmt.Sprintf("TUMBLINGWINDOW(ss, %d)", valueCycle), code, decideCondition) + return + case constants.Sum: + sqlTemp := `SELECT window_start(),window_end(),rule_id(),deviceId,sum(json_path_query(data, "$.%s.value")) as sum_%s FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true GROUP BY %s HAVING sum_%s %s` + valueCycle := s + if valueCycle == 0 { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required value_cycle parameter missing", nil) + return + } + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + sql = fmt.Sprintf(sqlTemp, code, code, device.Id, code, fmt.Sprintf("TUMBLINGWINDOW(ss, %d)", valueCycle), code, decideCondition) + return + case constants.Avg: + sqlTemp := `SELECT window_start(),window_end(),rule_id(),deviceId,avg(json_path_query(data, "$.%s.value")) as avg_%s FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true GROUP BY %s HAVING avg_%s %s` + valueCycle := s + if valueCycle == 0 { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required value_cycle parameter missing", nil) + return + } + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + sql = fmt.Sprintf(sqlTemp, code, code, device.Id, code, fmt.Sprintf("TUMBLINGWINDOW(ss, %d)", valueCycle), code, decideCondition) + return + } + case constants.SpecsTypeText: + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + st := strings.Split(decideCondition, " ") + if len(st) != 2 { + return + } + sqlTemp := `SELECT rule_id(),json_path_query(data, "$.%s.time") as report_time,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true and json_path_query(data, "$.%s.value") = "%s"` + sql = fmt.Sprintf(sqlTemp, code, deviceId, code, code, st[1]) + return + + case constants.SpecsTypeBool: + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + st := strings.Split(decideCondition, " ") + if len(st) != 2 { + return + } + sqlTemp := `SELECT rule_id(),json_path_query(data, "$.%s.time") as report_time,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true and json_path_query(data, "$.%s.value") = "%s"` + if st[1] == "true" { + sql = fmt.Sprintf(sqlTemp, code, deviceId, code, code, "1") + } else if st[1] == "false" { + sql = fmt.Sprintf(sqlTemp, code, deviceId, code, code, "0") + } + return + case constants.SpecsTypeEnum: + decideCondition := option["decide_condition"] + if decideCondition == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required decide_condition parameter missing", nil) + return + } + st := strings.Split(decideCondition, " ") + if len(st) != 2 { + return + } + sqlTemp := `SELECT rule_id(),json_path_query(data, "$.%s.time") as report_time,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "PROPERTY_REPORT" and json_path_exists(data, "$.%s") = true and json_path_query(data, "$.%s.value") = "%s"` + sql = fmt.Sprintf(sqlTemp, code, deviceId, code, code, st[1]) + return + } + } + } + if !codeFind { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required code parameter missing", nil) + } + case string(constants.DeviceEventTrigger): + var codeFind bool + for _, event := range product.Events { + if code == event.Code { + codeFind = true + sqlTemp := `SELECT rule_id(),json_path_query(data, "$.eventTime") as report_time,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "EVENT_REPORT" and json_path_exists(data, "$.eventCode") = true and json_path_query(data, "$.eventCode") = "%s"` + sql = fmt.Sprintf(sqlTemp, device.Id, code) + return + } + } + if !codeFind { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required code parameter missing", nil) + } + case string(constants.DeviceStatusTrigger): + var status string + deviceStatus := option["status"] + if deviceStatus == "" { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required status parameter missing", nil) + return + } + if deviceStatus == "在线" { + status = constants.DeviceOnline + } else if deviceStatus == "离线" { + status = constants.DeviceOffline + } else { + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required status parameter missing", nil) + return + } + sqlTemp := `SELECT rule_id(),json_path_query(data, "$.time") as report_time,deviceId FROM mqtt_stream where deviceId = "%s" and messageType = "DEVICE_STATUS" and json_path_exists(data, "$.status") = true and json_path_query(data, "$.status") = "%s"` + sql = fmt.Sprintf(sqlTemp, device.Id, status) + return + default: + err = errort.NewCommonEdgeX(errort.DefaultReqParamsError, "required trigger parameter missing", nil) + return + } + return +} + +func (p sceneApp) checkAlertRuleParam(ctx context.Context, scene models.Scene, operate string) error { + if operate == "start" { + if scene.Status == constants.SceneStart { + return errort.NewCommonErr(errort.AlertRuleStatusStarting, fmt.Errorf("scene id(%s) is runing ,not allow start", scene.Id)) + } + } + + var ( + trigger string + ) + + if len(scene.Conditions) != 1 { + trigger = scene.Conditions[0].Option["trigger"] + + switch scene.Conditions[0].ConditionType { + case "timer": + case "notify": + ekuiperApp := resourceContainer.EkuiperAppFrom(p.dic.Get) + exist, err := ekuiperApp.RuleExist(ctx, scene.Id) + if err != nil { + return err + } + if !exist { + + } + + trigger = scene.Conditions[0].Option["trigger"] + if trigger != string(constants.DeviceDataTrigger) || trigger != string(constants.DeviceEventTrigger) || trigger != string(constants.DeviceStatusTrigger) { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) trigger param error", scene.Id)) + } + + option := scene.Conditions[0].Option + deviceId := option["device_id"] + deviceName := option["device_name"] + productId := option["product_id"] + productName := option["product_name"] + //trigger := option["trigger"] + code := option["code"] + if deviceId == "" || deviceName == "" || productId == "" || productName == "" || code == "" || trigger == "" { + return errort.NewCommonEdgeX(errort.SceneRuleParamsError, "required parameter missing", nil) + } + device, err := p.dbClient.DeviceById(deviceId) + if err != nil { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) device not found", scene.Id)) + + } + product, err := p.dbClient.ProductById(productId) + if err != nil { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) actions is null", scene.Id)) + + } + if device.ProductId != product.Id { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) actions is null", scene.Id)) + } + + default: + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) conditionType param errror", scene.Id)) + + } + } + + //------------------------- + + if len(scene.Actions) == 0 { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) actions is null", scene.Id)) + } + + for _, action := range scene.Actions { + //检查产品和设备是否存在 + device, err := p.dbClient.DeviceById(action.DeviceID) + if err != nil { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) device not found", scene.Id)) + } + + product, err := p.dbClient.ProductById(action.ProductID) + if err != nil { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) product not found", scene.Id)) + } + if device.ProductId != product.Id { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) actions is null", scene.Id)) + } + + var find bool + + if trigger == string(constants.DeviceDataTrigger) { + for _, property := range product.Properties { + if property.Code == action.Code { + find = true + break + } + } + } + + if trigger == string(constants.DeviceEventTrigger) { + for _, event := range product.Events { + if event.Code == action.Code { + find = true + break + } + } + } + + if !find { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) code not found", scene.Id)) + + } + if action.Value == "" { + return errort.NewCommonErr(errort.SceneRuleParamsError, fmt.Errorf("scene id(%s) value is null", scene.Id)) + } + + } + return nil +} + +func (p sceneApp) EkuiperNotify(ctx context.Context, req map[string]interface{}) error { + sceneId, ok := req["rule_id"] + if !ok { + return errort.NewCommonErr(errort.DefaultReqParamsError, errors.New("")) + } + var ( + coverSceneId string + ) + switch sceneId.(type) { + case string: + coverSceneId = sceneId.(string) + case int: + coverSceneId = strconv.Itoa(sceneId.(int)) + case int64: + coverSceneId = strconv.Itoa(int(sceneId.(int64))) + case float64: + coverSceneId = fmt.Sprintf("%f", sceneId.(float64)) + case float32: + coverSceneId = fmt.Sprintf("%f", sceneId.(float64)) + } + if coverSceneId == "" { + return errort.NewCommonErr(errort.DefaultReqParamsError, errors.New("")) + } + + scene, err := p.dbClient.SceneById(coverSceneId) + if err != nil { + return err + } + + for _, action := range scene.Actions { + deviceApp := resourceContainer.DeviceItfFrom(p.dic.Get) + execRes := deviceApp.DeviceAction(dtos.JobAction{ + ProductId: action.ProductID, + }) + _, err := p.dbClient.AddSceneLog(models.SceneLog{ + SceneId: scene.Id, + Name: scene.Name, + ExecRes: execRes.ToString(), + }) + if err != nil { + p.lc.Errorf("add sceneLog err %v", err.Error()) + + } + + } + return nil +} + +func (p sceneApp) CheckSceneByDeviceId(ctx context.Context, deviceId string) error { + var req dtos.SceneSearchQueryRequest + req.Status = string(constants.SceneStart) + scenes, _, err := p.SceneSearch(ctx, req) + if err != nil { + return err + } + + for _, scene := range scenes { + for _, condition := range scene.Conditions { + + if condition.Option != nil && condition.Option["device_id"] == deviceId { + return errort.NewCommonEdgeX(errort.DeviceAssociationSceneRule, "This device has been bound to scene rules. Please stop reporting scene rules before proceeding with the operation", nil) + } + } + for _, action := range scene.Actions { + if action.DeviceID == deviceId { + return errort.NewCommonEdgeX(errort.DeviceAssociationSceneRule, "This device has been bound to scene rules. Please stop reporting scene rules before proceeding with the operation", nil) + } + } + } + + return nil +} diff --git a/internal/hummingbird/core/application/schedule.go b/internal/hummingbird/core/application/schedule.go new file mode 100644 index 0000000..ae0af66 --- /dev/null +++ b/internal/hummingbird/core/application/schedule.go @@ -0,0 +1,39 @@ +//go:build !community +// +build !community + +package application + +import ( + "context" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "time" + + //"gitlab.com/tedge/edgex/internal/dtos" + //"gitlab.com/tedge/edgex/internal/pkg/constants" + //"gitlab.com/tedge/edgex/internal/pkg/container" + //pkgContainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/crontab" + //"gitlab.com/tedge/edgex/internal/pkg/di" + //"gitlab.com/tedge/edgex/internal/pkg/errort" + //"gitlab.com/tedge/edgex/internal/pkg/logger" + //resourceContainer "gitlab.com/tedge/edgex/internal/tedge/resource/container" + //"gitlab.com/tedge/edgex/internal/tools/atopclient" +) + +func InitSchedule(dic *di.Container, lc logger.LoggingClient) { + lc.Info("init schedule") + + // 每天 1 点 + crontab.Schedule.AddFunc("0 1 * * *", func() { + lc.Debugf("schedule statistic device msg conut: %v", time.Now().Format("2006-01-02 15:04:05")) + deviceItf := resourceContainer.DeviceItfFrom(dic.Get) + err := deviceItf.DevicesReportMsgGather(context.Background()) + if err != nil { + lc.Error("schedule statistic device err:", err) + } + }) + + crontab.Start() +} diff --git a/internal/hummingbird/core/application/thingmodelapp/thingmodelapp.go b/internal/hummingbird/core/application/thingmodelapp/thingmodelapp.go new file mode 100644 index 0000000..f56267e --- /dev/null +++ b/internal/hummingbird/core/application/thingmodelapp/thingmodelapp.go @@ -0,0 +1,633 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package thingmodelapp + +import ( + "context" + "encoding/json" + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "strings" + "time" +) + +const ( + Property = "property" + Event = "event" + Action = "action" +) + +type thingModelApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func (t thingModelApp) AddThingModel(ctx context.Context, req dtos.ThingModelAddOrUpdateReq) (string, error) { + + product, err := t.dbClient.ProductById(req.ProductId) + if err != nil { + return "", err + } + if product.Status == constants.ProductRelease { + //产品已发布,不能修改物模型 + return "", errort.NewCommonEdgeX(errort.ProductRelease, "Please cancel publishing the product first before proceeding with the operation", nil) + } + + if err = validatorReq(req, product); err != nil { + return "", errort.NewCommonEdgeX(errort.DefaultReqParamsError, "param valida error", err) + } + + switch req.ThingModelType { + case Property: + var property models.Properties + property.ProductId = req.ProductId + property.Name = req.Name + property.Code = req.Code + property.Description = req.Description + if req.Property != nil { + typeSpec, _ := json.Marshal(req.Property.TypeSpec) + property.AccessMode = req.Property.AccessModel + property.Require = req.Property.Require + property.TypeSpec.Type = req.Property.DataType + property.TypeSpec.Specs = string(typeSpec) + } + property.Tag = req.Tag + err = resourceContainer.DataDBClientFrom(t.dic.Get).AddDatabaseField(ctx, req.ProductId, req.Property.DataType, req.Code, req.Name) + if err != nil { + return "", err + } + ds, err := t.dbClient.AddThingModelProperty(property) + if err != nil { + return "", err + } + t.ProductUpdateCallback(req.ProductId) + return ds.Id, nil + case Event: + var event models.Events + event.ProductId = req.ProductId + event.Name = req.Name + event.Code = req.Code + event.Description = req.Description + if req.Event != nil { + var inputOutput []models.InputOutput + for _, outPutParam := range req.Event.OutPutParam { + typeSpec, _ := json.Marshal(outPutParam.TypeSpec) + inputOutput = append(inputOutput, models.InputOutput{ + Code: outPutParam.Code, + Name: outPutParam.Name, + TypeSpec: models.TypeSpec{ + Type: outPutParam.DataType, + Specs: string(typeSpec), + }, + }) + } + event.OutputParams = inputOutput + event.EventType = req.Event.EventType + } + event.Tag = req.Tag + err = resourceContainer.DataDBClientFrom(t.dic.Get).AddDatabaseField(ctx, req.ProductId, "", req.Code, req.Name) + if err != nil { + return "", err + } + ds, err := t.dbClient.AddThingModelEvent(event) + if err != nil { + return "", err + } + t.ProductUpdateCallback(req.ProductId) + return ds.Id, nil + case Action: + var action models.Actions + action.ProductId = req.ProductId + action.Name = req.Name + action.Code = req.Code + action.Description = req.Description + if req.Action != nil { + action.CallType = req.Action.CallType + var inputOutput []models.InputOutput + for _, inPutParam := range req.Action.InPutParam { + typeSpec, _ := json.Marshal(inPutParam.TypeSpec) + inputOutput = append(inputOutput, models.InputOutput{ + Code: inPutParam.Code, + Name: inPutParam.Name, + TypeSpec: models.TypeSpec{ + Type: inPutParam.DataType, + Specs: string(typeSpec), + }, + }) + } + action.InputParams = inputOutput + + var outOutput []models.InputOutput + for _, outPutParam := range req.Action.OutPutParam { + typeSpec, _ := json.Marshal(outPutParam.TypeSpec) + outOutput = append(outOutput, models.InputOutput{ + Code: outPutParam.Code, + Name: outPutParam.Name, + TypeSpec: models.TypeSpec{ + Type: outPutParam.DataType, + Specs: string(typeSpec), + }, + }) + } + action.OutputParams = outOutput + } + action.Tag = req.Tag + err = resourceContainer.DataDBClientFrom(t.dic.Get).AddDatabaseField(ctx, req.ProductId, "", req.Code, req.Name) + if err != nil { + return "", err + } + ds, err := t.dbClient.AddThingModelAction(action) + if err != nil { + return "", err + } + t.ProductUpdateCallback(req.ProductId) + return ds.Id, nil + default: + return "", errort.NewCommonEdgeX(errort.DefaultReqParamsError, "param valida error", fmt.Errorf("req params error")) + } +} + +func (t thingModelApp) UpdateThingModel(ctx context.Context, req dtos.ThingModelAddOrUpdateReq) error { + product, err := t.dbClient.ProductById(req.ProductId) + if err != nil { + return err + } + if product.Status == constants.ProductRelease { + //产品已发布,不能修改物模型 + return errort.NewCommonEdgeX(errort.ProductRelease, "Please cancel publishing the product first before proceeding with the operation", nil) + } + if err := validatorReq(req, product); err != nil { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "param valida error", err) + } + switch req.ThingModelType { + case Property: + var property models.Properties + property.Id = req.Id + property.ProductId = req.ProductId + property.Name = req.Name + property.Code = req.Code + property.Description = req.Description + if req.Property != nil { + typeSpec, _ := json.Marshal(req.Property.TypeSpec) + property.AccessMode = req.Property.AccessModel + property.Require = req.Property.Require + property.TypeSpec.Type = req.Property.DataType + property.TypeSpec.Specs = string(typeSpec) + } + property.Tag = req.Tag + + err = resourceContainer.DataDBClientFrom(t.dic.Get).ModifyDatabaseField(ctx, req.ProductId, req.Property.DataType, req.Code, req.Name) + if err != nil { + return err + } + + err = t.dbClient.UpdateThingModelProperty(property) + if err != nil { + return err + } + t.ProductUpdateCallback(req.ProductId) + return nil + case Event: + var event models.Events + event.Id = req.Id + event.ProductId = req.ProductId + event.Name = req.Name + event.Code = req.Code + event.Description = req.Description + if req.Event != nil { + var inputOutput []models.InputOutput + for _, outPutParam := range req.Event.OutPutParam { + typeSpec, _ := json.Marshal(outPutParam.TypeSpec) + inputOutput = append(inputOutput, models.InputOutput{ + Code: outPutParam.Code, + Name: outPutParam.Name, + TypeSpec: models.TypeSpec{ + Type: outPutParam.DataType, + Specs: string(typeSpec), + }, + }) + } + event.OutputParams = inputOutput + event.EventType = req.Event.EventType + } + event.Tag = req.Tag + err = resourceContainer.DataDBClientFrom(t.dic.Get).ModifyDatabaseField(ctx, req.ProductId, "", req.Code, req.Name) + if err != nil { + return err + } + err = t.dbClient.UpdateThingModelEvent(event) + if err != nil { + return err + } + t.ProductUpdateCallback(req.ProductId) + return nil + case Action: + var action models.Actions + action.Id = req.Id + action.ProductId = req.ProductId + action.Name = req.Name + action.Code = req.Code + action.Description = req.Description + if req.Action != nil { + action.CallType = req.Action.CallType + var inputOutput []models.InputOutput + for _, inPutParam := range req.Action.InPutParam { + typeSpec, _ := json.Marshal(inPutParam.TypeSpec) + inputOutput = append(inputOutput, models.InputOutput{ + Code: inPutParam.Code, + Name: inPutParam.Name, + TypeSpec: models.TypeSpec{ + Type: inPutParam.DataType, + Specs: string(typeSpec), + }, + }) + } + action.InputParams = inputOutput + + var outOutput []models.InputOutput + for _, outPutParam := range req.Action.OutPutParam { + typeSpec, _ := json.Marshal(outPutParam.TypeSpec) + outOutput = append(outOutput, models.InputOutput{ + Code: outPutParam.Code, + Name: outPutParam.Name, + TypeSpec: models.TypeSpec{ + Type: outPutParam.DataType, + Specs: string(typeSpec), + }, + }) + } + action.OutputParams = outOutput + } + action.Tag = req.Tag + err = resourceContainer.DataDBClientFrom(t.dic.Get).ModifyDatabaseField(ctx, req.ProductId, "", req.Code, req.Name) + if err != nil { + return err + } + err = t.dbClient.UpdateThingModelAction(action) + if err != nil { + return err + } + t.ProductUpdateCallback(req.ProductId) + return nil + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "param valida error", fmt.Errorf("req params error")) + } +} + +func validatorReq(req dtos.ThingModelAddOrUpdateReq, product models.Product) error { + if req.ProductId == "" || req.Name == "" || req.Code == "" { + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "params error", nil) + } + switch req.ThingModelType { + case "property": + for _, property := range product.Properties { + if strings.ToLower(property.Code) == strings.ToLower(req.Code) { + return errort.NewCommonEdgeX(errort.ThingModelCodeExist, "code identifier already exists", nil) + } + } + case "action": + for _, action := range product.Actions { + if strings.ToLower(action.Code) == strings.ToLower(req.Code) { + return errort.NewCommonEdgeX(errort.ThingModelCodeExist, "code identifier already exists", nil) + } + } + case "event": + for _, event := range product.Events { + if strings.ToLower(event.Code) == strings.ToLower(req.Code) { + return errort.NewCommonEdgeX(errort.ThingModelCodeExist, "code identifier already exists", nil) + } + } + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "params error", nil) + } + return nil +} + +func NewThingModelApp(ctx context.Context, dic *di.Container) interfaces.ThingModelItf { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + return &thingModelApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} + +func (t thingModelApp) ThingModelDelete(ctx context.Context, id string, thingModelType string) error { + switch thingModelType { + case Property: + propertyInfo, err := t.dbClient.ThingModelPropertyById(id) + if err != nil { + return err + } + err = resourceContainer.DataDBClientFrom(t.dic.Get).DelDatabaseField(ctx, propertyInfo.ProductId, propertyInfo.Code) + if err != nil { + return err + } + err = t.dbClient.ThingModelDeleteProperty(id) + if err != nil { + return err + } + t.ProductUpdateCallback(propertyInfo.ProductId) + case Event: + eventInfo, err := t.dbClient.ThingModelEventById(id) + if err != nil { + return err + } + err = resourceContainer.DataDBClientFrom(t.dic.Get).DelDatabaseField(ctx, eventInfo.ProductId, eventInfo.Code) + if err != nil { + return err + } + err = t.dbClient.ThingModelDeleteEvent(id) + if err != nil { + return err + } + t.ProductUpdateCallback(eventInfo.ProductId) + case Action: + actionInfo, err := t.dbClient.ThingModelActionsById(id) + if err != nil { + return err + } + err = resourceContainer.DataDBClientFrom(t.dic.Get).DelDatabaseField(ctx, actionInfo.ProductId, actionInfo.Code) + if err != nil { + return err + } + err = t.dbClient.ThingModelDeleteAction(id) + if err != nil { + return err + } + t.ProductUpdateCallback(actionInfo.ProductId) + default: + return errort.NewCommonEdgeX(errort.DefaultReqParamsError, "param valida error", fmt.Errorf("req params error")) + } + return nil +} + +func (t thingModelApp) ProductUpdateCallback(productId string) { + go func() { + productService := resourceContainer.ProductAppNameFrom(t.dic.Get) + product, err := productService.ProductModelById(context.Background(), productId) + if err != nil { + return + } + productService.UpdateProductCallBack(product) + }() + +} + +func (t thingModelApp) SystemThingModelSearch(ctx context.Context, req dtos.SystemThingModelSearchReq) (interface{}, error) { + return t.dbClient.SystemThingModelSearch(req.ThingModelType, req.ModelName) +} + +func (t thingModelApp) OpenApiQueryThingModel(ctx context.Context, productId string) (dtos.OpenApiQueryThingModel, error) { + product, err := t.dbClient.ProductById(productId) + if err != nil { + return dtos.OpenApiQueryThingModel{}, err + } + + var response dtos.OpenApiQueryThingModel + + for _, property := range product.Properties { + response.Properties = append(response.Properties, dtos.OpenApiThingModelProperties{ + Id: property.Id, + Name: property.Name, + Code: property.Code, + AccessMode: property.AccessMode, + Require: property.Require, + TypeSpec: property.TypeSpec, + Description: property.Description, + }) + } + + if len(response.Properties) == 0 { + response.Properties = make([]dtos.OpenApiThingModelProperties, 0) + } + + for _, event := range product.Events { + response.Events = append(response.Events, dtos.OpenApiThingModelEvents{ + Id: event.Id, + EventType: event.EventType, + Name: event.Name, + Code: event.Code, + Description: event.Description, + Require: event.Require, + OutputParams: event.OutputParams, + }) + } + + if len(response.Events) == 0 { + response.Events = make([]dtos.OpenApiThingModelEvents, 0) + } + + for _, action := range product.Actions { + response.Services = append(response.Services, dtos.OpenApiThingModelServices{ + Id: action.Id, + Name: action.Name, + Code: action.Code, + Description: action.Description, + Require: action.Require, + CallType: action.CallType, + InputParams: action.InputParams, + OutputParams: action.OutputParams, + }) + } + + if len(response.Services) == 0 { + response.Services = make([]dtos.OpenApiThingModelServices, 0) + } + + return response, nil + +} + +func (t thingModelApp) OpenApiAddThingModel(ctx context.Context, req dtos.OpenApiThingModelAddOrUpdateReq) error { + + _, err := t.dbClient.ProductById(req.ProductId) + if err != nil { + return err + } + var properties []models.Properties + var events []models.Events + var action []models.Actions + for _, property := range req.Properties { + propertyId := property.Id + if propertyId == "" { + propertyId = utils.RandomNum() + } + properties = append(properties, models.Properties{ + Id: propertyId, + ProductId: req.ProductId, + Name: property.Name, + Code: property.Code, + AccessMode: property.AccessMode, + Require: property.Require, + TypeSpec: property.TypeSpec, + Description: property.Description, + Timestamps: models.Timestamps{ + Created: time.Now().UnixMilli(), + }, + }) + } + + for _, event := range req.Events { + eventId := event.Id + if eventId == "" { + eventId = utils.RandomNum() + } + events = append(events, models.Events{ + Id: eventId, + ProductId: req.ProductId, + Name: event.Name, + EventType: event.EventType, + Code: event.Code, + Description: event.Description, + Require: event.Require, + OutputParams: event.OutputParams, + Timestamps: models.Timestamps{ + Created: time.Now().UnixMilli(), + }, + }) + } + + for _, service := range req.Services { + serviceId := service.Id + if serviceId == "" { + serviceId = utils.RandomNum() + } + action = append(action, models.Actions{ + Id: serviceId, + ProductId: req.ProductId, + Name: service.Name, + Code: service.Code, + Description: service.Description, + Require: service.Require, + CallType: service.CallType, + InputParams: service.InputParams, + OutputParams: service.OutputParams, + Timestamps: models.Timestamps{ + Created: time.Now().UnixMilli(), + }, + }) + } + + //t.dbClient.ThingModelActionsById() + + var shouldCallBack bool + updateFunc := func(source interface{}, db *gorm.DB) error { + tx := db.Session(&gorm.Session{FullSaveAssociations: true}).Clauses(clause.OnConflict{UpdateAll: true}).Save(source) + return tx.Error + //tx := db.Save(&source).Error + //return tx + } + + db := t.dbClient.GetDBInstance() + //db.Begin() + if len(properties) > 0 { + shouldCallBack = true + err := updateFunc(properties, db) + if err != nil { + //db.Rollback() + return err + } + } + + if len(events) > 0 { + shouldCallBack = true + err := updateFunc(events, db) + if err != nil { + //db.Rollback() + return err + } + } + + if len(action) > 0 { + shouldCallBack = true + err := updateFunc(action, db) + if err != nil { + //db.Rollback() + return err + } + } + + //db.Commit() + if shouldCallBack { + t.ProductUpdateCallback(req.ProductId) + } + + return nil +} + +func (t thingModelApp) OpenApiDeleteThingModel(ctx context.Context, req dtos.OpenApiThingModelDeleteReq) error { + product, err := t.dbClient.ProductById(req.ProductId) + if err != nil { + return err + } + + var productPropertyIds []string + var productEventIds []string + var productService []string + + for _, property := range product.Properties { + productPropertyIds = append(productPropertyIds, property.Id) + } + + for _, event := range product.Events { + productEventIds = append(productEventIds, event.Id) + } + + for _, action := range product.Actions { + productService = append(productService, action.Id) + } + + for _, id := range req.PropertyIds { + if utils.InStringSlice(id, productPropertyIds) { + if err = t.dbClient.ThingModelDeleteProperty(id); err != nil { + return err + } + } + } + + for _, id := range req.EventIds { + if utils.InStringSlice(id, productEventIds) { + if err = t.dbClient.ThingModelDeleteEvent(id); err != nil { + return err + } + } + } + + for _, id := range req.ServiceIds { + if utils.InStringSlice(id, productService) { + if err = t.dbClient.ThingModelDeleteAction(id); err != nil { + return err + } + } + } + + return nil + +} diff --git a/internal/hummingbird/core/application/thingmodeltemplate/thingmodelapp.go b/internal/hummingbird/core/application/thingmodeltemplate/thingmodelapp.go new file mode 100644 index 0000000..d05a07b --- /dev/null +++ b/internal/hummingbird/core/application/thingmodeltemplate/thingmodelapp.go @@ -0,0 +1,175 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package thingmodeltemplate + +import ( + "context" + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "time" +) + +type thingModelTemplate struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func (m thingModelTemplate) ThingModelTemplateSearch(ctx context.Context, req dtos.ThingModelTemplateRequest) ([]dtos.ThingModelTemplateResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + thingModelTemplates, total, err := m.dbClient.ThingModelTemplateSearch(offset, limit, req) + if err != nil { + m.lc.Errorf("thingModelTemplate Search err %v", err) + return []dtos.ThingModelTemplateResponse{}, 0, err + } + + libs := make([]dtos.ThingModelTemplateResponse, len(thingModelTemplates)) + for i, thingModelTemplate := range thingModelTemplates { + libs[i] = dtos.ThingModelTemplateResponseFromModel(thingModelTemplate) + } + return libs, total, nil +} + +func (m thingModelTemplate) ThingModelTemplateByCategoryKey(ctx context.Context, categoryKey string) (dtos.ThingModelTemplateResponse, error) { + thingModelTemplate, err := m.dbClient.ThingModelTemplateByCategoryKey(categoryKey) + if err != nil { + m.lc.Errorf("thingModelTemplate Search err %v", err) + return dtos.ThingModelTemplateResponse{}, err + } + var libs dtos.ThingModelTemplateResponse + libs = dtos.ThingModelTemplateResponseFromModel(thingModelTemplate) + return libs, nil +} + +func (m thingModelTemplate) Sync(ctx context.Context, versionName string) (int64, error) { + filePath := versionName + "/thing_model_template.json" + cosApp := resourceContainer.CosAppNameFrom(m.dic.Get) + bs, err := cosApp.Get(filePath) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + var cosThingModelTemplateResp []dtos.CosThingModelTemplateResponse + err = json.Unmarshal(bs, &cosThingModelTemplateResp) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + + baseQuery := dtos.BaseSearchConditionQuery{ + IsAll: true, + } + dbreq := dtos.ThingModelTemplateRequest{BaseSearchConditionQuery: baseQuery} + thingModelTemplateResponse, _, err := m.ThingModelTemplateSearch(ctx, dbreq) + if err != nil { + return 0, err + } + + upsertThingModelTemplate := make([]models.ThingModelTemplate, 0) + for _, cosThingModelTemplate := range cosThingModelTemplateResp { + var find bool + for _, localThingModelResponse := range thingModelTemplateResponse { + if cosThingModelTemplate.CategoryKey == localThingModelResponse.CategoryKey { + upsertThingModelTemplate = append(upsertThingModelTemplate, models.ThingModelTemplate{ + Id: localThingModelResponse.Id, + CategoryName: cosThingModelTemplate.CategoryName, + CategoryKey: cosThingModelTemplate.CategoryKey, + ThingModelJSON: cosThingModelTemplate.ThingModelJSON, + }) + find = true + break + } + } + if !find { + upsertThingModelTemplate = append(upsertThingModelTemplate, models.ThingModelTemplate{ + Timestamps: models.Timestamps{ + Created: time.Now().Unix(), + }, + Id: utils.GenUUID(), + CategoryName: cosThingModelTemplate.CategoryName, + CategoryKey: cosThingModelTemplate.CategoryKey, + ThingModelJSON: cosThingModelTemplate.ThingModelJSON, + }) + } + } + rows, err := m.dbClient.BatchUpsertThingModelTemplate(upsertThingModelTemplate) + m.lc.Infof("upsert thingModelTemplate rows %+v", rows) + if err != nil { + return 0, err + } + var modelProperty []models.Properties + var modelEvent []models.Events + var modelAction []models.Actions + + propertyTemp := map[string]struct{}{} + eventTemp := map[string]struct{}{} + actionTemp := map[string]struct{}{} + + for _, cosThingModelTemplate := range cosThingModelTemplateResp { + p, e, a := dtos.GetModelPropertyEventActionByThingModelTemplate(cosThingModelTemplate.ThingModelJSON) + for _, properties := range p { + if _, ok := propertyTemp[properties.Code]; !ok { + properties.Id = utils.GenUUID() + properties.System = true + modelProperty = append(modelProperty, properties) + propertyTemp[properties.Code] = struct{}{} + } + } + for _, event := range e { + if _, ok := eventTemp[event.Code]; !ok { + event.Id = utils.GenUUID() + event.System = true + modelEvent = append(modelEvent, event) + eventTemp[event.Code] = struct{}{} + } + } + for _, action := range a { + if _, ok := actionTemp[action.Code]; !ok { + action.Id = utils.GenUUID() + action.System = true + modelAction = append(modelAction, action) + actionTemp[action.Code] = struct{}{} + } + } + } + m.dbClient.BatchDeleteSystemProperties() + m.dbClient.BatchDeleteSystemActions() + m.dbClient.BatchDeleteSystemEvents() + + m.dbClient.BatchInsertSystemProperties(modelProperty) + m.dbClient.BatchInsertSystemActions(modelAction) + m.dbClient.BatchInsertSystemEvents(modelEvent) + + return rows, nil +} + +func NewThingModelTemplateApp(ctx context.Context, dic *di.Container) interfaces.ThingModelTemplateApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + return &thingModelTemplate{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} diff --git a/internal/hummingbird/core/application/timerapp/timerapp.go b/internal/hummingbird/core/application/timerapp/timerapp.go new file mode 100644 index 0000000..1198788 --- /dev/null +++ b/internal/hummingbird/core/application/timerapp/timerapp.go @@ -0,0 +1,320 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package timerapp + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + + //"github.com/winc-link/hummingbird/internal/edge/sharp/module/timer/db" + + //"github.com/winc-link/hummingbird/internal/pkg/timer/db" + "github.com/winc-link/hummingbird/internal/pkg/timer/jobrunner" + "github.com/winc-link/hummingbird/internal/pkg/timer/jobs" + + "github.com/winc-link/hummingbird/internal/pkg/logger" + "sort" + "sync" + "time" +) + +type EdgeTimer struct { + mutex sync.Mutex + logger logger.LoggingClient + db interfaces.DBClient + + //db db.TimerDBClient + // job id map + jobMap map[string]struct{} + // 任务链表 有序 + entries []*entry + // 停止信号 + stop chan struct{} + // 添加任务channel + add chan *entry + // 更新任务 + //update chan *jobs.UpdateJobStu + // 删除任务 uuid + rm chan string + // 启动标志 + running bool + location *time.Location + f jobrunner.JobRunFunc +} + +func NewCronTimer(ctx context.Context, + f jobrunner.JobRunFunc, dic *di.Container) *EdgeTimer { + dbClient := resourceContainer.DBClientFrom(dic.Get) + l := container.LoggingClientFrom(dic.Get) + et := &EdgeTimer{ + logger: l, + db: dbClient, + rm: make(chan string), + add: make(chan *entry), + entries: nil, + jobMap: make(map[string]struct{}), + stop: make(chan struct{}), + running: false, + location: time.Local, + f: f, + } + // restore + et.restoreJobs() + + go et.run() + return et +} + +func (et *EdgeTimer) restoreJobs() { + scenes, _, _ := et.db.SceneSearch(0, -1, dtos.SceneSearchQueryRequest{}) + + if len(scenes) == 0 { + return + } + + for _, scene := range scenes { + if len(scene.Conditions) > 0 && scene.Status == constants.SceneStart { + if scene.Conditions[0].ConditionType == "timer" { + job, err := scene.ToRuntimeJob() + if err != nil { + et.logger.Errorf("restore jobs runtime job err %v", err.Error()) + continue + } + err = et.AddJobToRunQueue(job) + if err != nil { + et.logger.Errorf("restore jobs add job to queue err %v", err.Error()) + } + } + } + } + return +} + +func (et *EdgeTimer) Stop() { + et.mutex.Lock() + defer et.mutex.Unlock() + if et.running { + close(et.stop) + et.running = false + } +} + +type ( + // 任务 + entry struct { + JobID string + Schedule *jobs.JobSchedule + Next time.Time + Prev time.Time + } +) + +func (e entry) Valid() bool { return e.JobID != "" } + +type byTime []*entry + +func (s byTime) Len() int { return len(s) } +func (s byTime) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byTime) Less(i, j int) bool { + if s[i].Next.IsZero() { + return false + } + if s[j].Next.IsZero() { + return true + } + return s[i].Next.Before(s[j].Next) +} + +func (et *EdgeTimer) now() time.Time { + return time.Now().In(et.location) +} + +func (et *EdgeTimer) run() { + et.mutex.Lock() + if et.running { + et.mutex.Unlock() + return + } + et.running = true + et.mutex.Unlock() + et.logger.Info("edge timer started...") + now := et.now() + for _, entry := range et.entries { + if next, b := entry.Schedule.Next(now); !b { + entry.Next = next + } + } + var timer = time.NewTimer(100000 * time.Hour) + for { + // Determine the next entry to run. + sort.Sort(byTime(et.entries)) + + if len(et.entries) == 0 || et.entries[0].Next.IsZero() { + // If there are no entries yet, just sleep - it still handles new entries + // and stop requests. + timer.Reset(100000 * time.Hour) + } else { + et.logger.Debugf("next wake time: %+v with jobID: %s", et.entries[0].Next, et.entries[0].JobID) + timer.Reset(et.entries[0].Next.Sub(now)) + } + + select { + case now = <-timer.C: + timer.Stop() + now = now.In(et.location) + et.logger.Infof("wake now: %+v with jobID: %s", now, et.entries[0].JobID) + var ( + //finished []int + //eIndex = len(et.entries) - 1 + ) + for i, e := range et.entries { + if e.Next.After(now) || e.Next.IsZero() { + break + } + // async call + go et.f(e.JobID, *et.entries[i].Schedule) + + //times := e.Schedule.ScheduleAdd1() + + if false { + //finished = append(finished, i) + } else { + e.Prev = e.Next + if next, b := e.Schedule.Next(now); !b { + e.Next = next + et.logger.Infof("run now: %+v, entry: jobId: %s, jobName: %s, next: %+v", now, e.JobID, e.Schedule.JobName, e.Next) + // update prev next and runtimes + //if err := et.db.UpdateRuntimeInfo(e.JobID, e.Prev.UnixMilli(), e.Next.UnixMilli(), times); err != nil { + // et.logger.Errorf("update job: %s runtime info error: %s, prev: %d, next: %d", + // e.JobID, err, e.Prev.Unix(), e.Next.Unix()) + //} + } + //} + } + } + //if len(finished) > 0 { + // for i := range finished { + // et.entries[finished[i]], et.entries[eIndex] = et.entries[eIndex], et.entries[finished[i]] + // eIndex-- + // } + // del := et.entries[eIndex+1:] + // ids := make([]string, 0, len(del)) + // for i := range del { + // ids = append(ids, del[i].JobID) + // } + // et.logger.Infof("jobs ended, delete from db: %+v", ids) + // if err := et.db.DeleteJobs(ids); err != nil { + // et.logger.Errorf("jobs ended, delete from db failure: %+v", ids) + // } + // et.entries = et.entries[:eIndex+1] + //} + //et.logger.Infof("entries len: %d", len(et.entries)) + case newEntry := <-et.add: + timer.Stop() + now = et.now() + if next, b := newEntry.Schedule.Next(now); !b { + newEntry.Next = next + et.entries = append(et.entries, newEntry) + et.logger.Infof("added job now: %+v, next: %+v", now, newEntry.Next) + } + et.logger.Infof("added job: %v, now: %+v, next: %+v", newEntry.JobID, now, newEntry.Next) + case entryID := <-et.rm: + timer.Stop() + now = et.now() + et.removeEntry(entryID) + case <-et.stop: + timer.Stop() + et.logger.Info("tedge timer stopped...") + return + } + } +} + +func (et *EdgeTimer) schedule(schedule *jobs.JobSchedule) { + et.mutex.Lock() + defer et.mutex.Unlock() + entry := &entry{ + JobID: schedule.GetJobId(), + Schedule: schedule, + } + if !et.running { + et.entries = append(et.entries, entry) + } else { + et.add <- entry + } +} + +func (et *EdgeTimer) remove(id string) { + if et.running { + et.rm <- id + } else { + et.removeEntry(id) + } +} + +func (et *EdgeTimer) removeEntry(id string) { + var b bool + et.mutex.Lock() + defer et.mutex.Unlock() + for i, e := range et.entries { + if e.JobID == id { + et.entries[i], et.entries[len(et.entries)-1] = et.entries[len(et.entries)-1], et.entries[i] + b = true + break + } + } + if b { + et.entries[len(et.entries)-1] = nil + et.entries = et.entries[:len(et.entries)-1] + delete(et.jobMap, id) + et.logger.Debugf("entry length: %d, deleted job id: %s", len(et.entries), id) + } else { + et.logger.Warnf("unknown jobs,id: %s", id) + } +} + +func (et *EdgeTimer) DeleteJob(id string) { + et.remove(id) +} + +func (et *EdgeTimer) AddJobToRunQueue(j *jobs.JobSchedule) error { + if _, ok := et.jobMap[j.JobID]; ok { + et.logger.Warnf("job is already in map: %s", j.JobID) + return nil + } + // check expire + //if exp, ok := j.TimeData.Expression.(jobs.CronExp); !ok { + // return errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("cron job expression error")) + //} else { + // if _, err := jobs.ParseStandard(exp.CronTab); err != nil { + // return err + // } + //} + + if _, err := jobs.ParseStandard(j.TimeData.Expression); err != nil { + return err + } + + et.schedule(j) + et.mutex.Lock() + defer et.mutex.Unlock() + et.jobMap[j.JobID] = struct{}{} + return nil +} diff --git a/internal/hummingbird/core/application/unittemplate/unitapp.go b/internal/hummingbird/core/application/unittemplate/unitapp.go new file mode 100644 index 0000000..c7ce041 --- /dev/null +++ b/internal/hummingbird/core/application/unittemplate/unitapp.go @@ -0,0 +1,114 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package unittemplate + +import ( + "context" + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" +) + +type unitApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func NewUnitTemplateApp(ctx context.Context, dic *di.Container) interfaces.UnitApp { + lc := container.LoggingClientFrom(dic.Get) + dbClient := resourceContainer.DBClientFrom(dic.Get) + + return &unitApp{ + dic: dic, + dbClient: dbClient, + lc: lc, + } +} + +func (m *unitApp) UnitTemplateSearch(ctx context.Context, req dtos.UnitRequest) ([]dtos.UnitResponse, uint32, error) { + offset, limit := req.BaseSearchConditionQuery.GetPage() + + unitTemplates, total, err := m.dbClient.UnitSearch(offset, limit, req) + if err != nil { + m.lc.Errorf("unit Templates Search err %v", err) + return []dtos.UnitResponse{}, 0, err + } + + libs := make([]dtos.UnitResponse, len(unitTemplates)) + for i, unitTemplate := range unitTemplates { + libs[i] = dtos.UnitTemplateResponseFromModel(unitTemplate) + } + return libs, total, nil +} + +func (m *unitApp) Sync(ctx context.Context, versionName string) (int64, error) { + filePath := versionName + "/unit_template.json" + cosApp := resourceContainer.CosAppNameFrom(m.dic.Get) + bs, err := cosApp.Get(filePath) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + var cosUnitTemplateResp []dtos.CosUnitTemplateResponse + err = json.Unmarshal(bs, &cosUnitTemplateResp) + if err != nil { + m.lc.Errorf(err.Error()) + return 0, err + } + + baseQuery := dtos.BaseSearchConditionQuery{ + IsAll: true, + } + dbreq := dtos.UnitRequest{BaseSearchConditionQuery: baseQuery} + unitTemplateResponse, _, err := m.UnitTemplateSearch(ctx, dbreq) + if err != nil { + return 0, err + } + + upsertUnitTemplate := make([]models.Unit, 0) + for _, cosUnitTemplate := range cosUnitTemplateResp { + var find bool + for _, localTemplateResponse := range unitTemplateResponse { + if cosUnitTemplate.UnitName == localTemplateResponse.UnitName { + upsertUnitTemplate = append(upsertUnitTemplate, models.Unit{ + Id: localTemplateResponse.Id, + UnitName: cosUnitTemplate.UnitName, + Symbol: cosUnitTemplate.Symbol, + }) + find = true + break + } + } + if !find { + upsertUnitTemplate = append(upsertUnitTemplate, models.Unit{ + Id: utils.GenUUID(), + UnitName: cosUnitTemplate.UnitName, + Symbol: cosUnitTemplate.Symbol, + }) + } + } + rows, err := m.dbClient.BatchUpsertUnitTemplate(upsertUnitTemplate) + if err != nil { + return 0, err + } + return rows, nil +} diff --git a/internal/hummingbird/core/application/userapp/user.go b/internal/hummingbird/core/application/userapp/user.go new file mode 100644 index 0000000..0ae0a6f --- /dev/null +++ b/internal/hummingbird/core/application/userapp/user.go @@ -0,0 +1,257 @@ +package userapp + +import ( + "context" + "github.com/dgrijalva/jwt-go" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/models" + pkgcontainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/middleware" + "time" + + //"gitlab.com/tedge/edgex/internal/pkg/container" + //resourceContainer "gitlab.com/tedge/edgex/internal/tedge/resource/container" + // + //"gitlab.com/tedge/edgex/internal/pkg/di" + //"gitlab.com/tedge/edgex/internal/pkg/errort" + //"gitlab.com/tedge/edgex/internal/pkg/logger" + // + jwt2 "github.com/winc-link/hummingbird/internal/tools/jwt" + // + //"github.com/dgrijalva/jwt-go" + //"gitlab.com/tedge/edgex/internal/dtos" + //"gitlab.com/tedge/edgex/internal/models" + //"gitlab.com/tedge/edgex/internal/pkg/middleware" + //"gitlab.com/tedge/edgex/internal/tedge/resource/interfaces" + "golang.org/x/crypto/bcrypt" +) + +const ( + DefaultUserName = "admin" + DefaultLang = "en" +) + +var _ interfaces.UserItf = new(userApp) + +type userApp struct { + dic *di.Container + dbClient interfaces.DBClient + lc logger.LoggingClient +} + +func New(dic *di.Container) *userApp { + return &userApp{ + dic: dic, + lc: pkgcontainer.LoggingClientFrom(dic.Get), + dbClient: container.DBClientFrom(dic.Get), + } +} + +//UserLogin 用户登录 +func (uapp *userApp) UserLogin(ctx context.Context, req dtos.LoginRequest) (res dtos.LoginResponse, err error) { + // 从数据库用户信息 + user, edgeXErr := uapp.dbClient.GetUserByUserName(req.Username) + if edgeXErr != nil { + return res, errort.NewCommonEdgeX(errort.AppPasswordError, "", edgeXErr) + } + + // 校验密码 + cErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)) + if cErr != nil { + err = errort.NewCommonErr(errort.AppPasswordError, cErr) + return + } + + j := jwt2.NewJWT(jwt2.JwtSignKey) + claims := middleware.CustomClaims{ + ID: 1, + Username: req.Username, + StandardClaims: jwt.StandardClaims{ + NotBefore: time.Now().Unix() - 1000, // 签名生效时间 + ExpiresAt: time.Now().Unix() + 60*60*24*3, // 过期时间 7天 + Issuer: jwt2.JwtIssuer, // 签名的发行者 + }, + } + + token, jwtErr := j.CreateToken(claims) + if jwtErr != nil { + err = jwtErr + return + } + lang := user.Lang + if lang == "" { + lang = DefaultLang + } + res = dtos.LoginResponse{ + User: dtos.UserResponse{ + Username: user.Username, + Lang: lang, + }, + ExpiresAt: claims.StandardClaims.ExpiresAt * 1000, + Token: token, + } + return +} + +//InitInfo 查询用户信息 +func (uapp *userApp) InitInfo() (res dtos.InitInfoResponse, err error) { + // 从数据库用户信息 + _, edgeXErr := uapp.dbClient.GetUserByUserName(DefaultUserName) + if edgeXErr != nil { + //if errort.NewCommonEdgeXWrapper(edgeXErr).Code() == { + if errort.Is(errort.DefaultResourcesNotFound, edgeXErr) { + res.IsInit = false + return + } + return + } + res.IsInit = true + return +} + +// InitPassword 初始化密码 +func (uapp *userApp) InitPassword(ctx context.Context, req dtos.InitPasswordRequest) error { + lc := uapp.lc + + // 从数据库用户信息 + _, edgeXErr := uapp.dbClient.GetUserByUserName(DefaultUserName) + if edgeXErr == nil { + return errort.NewCommonErr(errort.AppSystemInitialized, edgeXErr) + } + + // 生成新密码并存储 + newPasswordHash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) + if err != nil { + return err + } + newUser := models.User{ + Username: DefaultUserName, + Password: string(newPasswordHash), + Lang: DefaultLang, + OpenAPIKey: jwt2.GenerateJwtSignKey(), + GatewayKey: jwt2.GenerateJwtSignKey(), + } + jwt2.SetOpenAPIKey(newUser.OpenAPIKey) + jwt2.SetJwtSignKey(newUser.GatewayKey) + //db操作存储 + _, edgeXErr = uapp.dbClient.AddUser(newUser) + if edgeXErr != nil { + lc.Errorf("add user error %v", edgeXErr) + return edgeXErr + } + return nil +} + +// UpdateUserPassword 修改密码 +func (uapp *userApp) UpdateUserPassword(ctx context.Context, username string, req dtos.UpdatePasswordRequest) error { + // 从数据库用户信息 + user, edgeXErr := uapp.dbClient.GetUserByUserName(username) + if edgeXErr != nil { + return edgeXErr + } + err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.CurrentPassword)) + if err != nil { + return err + } + + // 生成新密码并存储 + newPasswordHash, gErr := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) + if gErr != nil { + return err + } + user.Password = string(newPasswordHash) + //db操作存储 + edgeXErr = uapp.dbClient.UpdateUser(user) + if edgeXErr != nil { + return edgeXErr + } + return nil +} + +//OpenApiUserLogin openapi用户登录 +func (uapp *userApp) OpenApiUserLogin(ctx context.Context, req dtos.LoginRequest) (res *dtos.TokenDetail, err error) { + // 从数据库用户信息 + user, edgeXErr := uapp.dbClient.GetUserByUserName(req.Username) + if edgeXErr != nil { + return res, errort.NewCommonErr(errort.AppPasswordError, edgeXErr) + } + + // 校验密码 + cErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)) + if cErr != nil { + err = errort.NewCommonErr(errort.AppPasswordError, cErr) + return + } + td, err := uapp.CreateTokenDetail(req.Username) + if err != nil { + return + } + return td, nil +} + +// CreateTokenDetail 根据用户名生成 token +func (uapp *userApp) CreateTokenDetail(userName string) (*dtos.TokenDetail, error) { + td := &dtos.TokenDetail{ + AccessId: "accessId", + RefreshId: "refreshId", + AtExpires: time.Now().Add(time.Minute * 120).Unix(), //两小时 + RtExpires: time.Now().Add(time.Hour * 24 * 14).Unix(), //两星期 + } + var ( + userID uint = 1 + err error + ) + td.AccessToken, err = uapp.createToken(userID, userName, td.AtExpires, jwt2.OpenAPIKey) + if err != nil { + return nil, err + } + td.RefreshToken, err = uapp.createToken(userID, userName, td.RtExpires, jwt2.RefreshKey) + if err != nil { + return nil, err + } + return td, nil +} + +// CreateToken 生成 token +func (uapp *userApp) createToken(useId uint, userName string, expire int64, signKey string) (string, error) { + j := jwt2.NewJWT(signKey) + claims := middleware.CustomClaims{ + ID: useId, + Username: userName, + StandardClaims: jwt.StandardClaims{ + NotBefore: time.Now().Unix(), // 签名生效时间 + ExpiresAt: expire, + Issuer: jwt2.JwtIssuer, // 签名的发行者 + }, + } + + token, jwtErr := j.CreateToken(claims) + if jwtErr != nil { + err := jwtErr + return "", err + } + return token, nil +} + +func (uapp *userApp) InitJwtKey() { + user, err := uapp.dbClient.GetUserByUserName(DefaultUserName) + if err != nil { + return + } + if user.GatewayKey == "" { + user.GatewayKey = jwt2.GenerateJwtSignKey() + } + if user.OpenAPIKey == "" { + user.OpenAPIKey = jwt2.GenerateJwtSignKey() + } + if err = uapp.dbClient.UpdateUser(user); err != nil { + panic(err) + } + jwt2.SetOpenAPIKey(user.OpenAPIKey) + jwt2.SetJwtSignKey(user.GatewayKey) +} diff --git a/internal/hummingbird/core/bootstrap/database/database.go b/internal/hummingbird/core/bootstrap/database/database.go new file mode 100644 index 0000000..5a6b467 --- /dev/null +++ b/internal/hummingbird/core/bootstrap/database/database.go @@ -0,0 +1,127 @@ +// +// Copyright (C) 2020 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package database + +import ( + "context" + "errors" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/config" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/hummingbird/core/infrastructure/mysql" + "github.com/winc-link/hummingbird/internal/hummingbird/core/infrastructure/sqlite" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/tools/datadb/tdengine" + + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/startup" + "github.com/winc-link/hummingbird/internal/tools/datadb/leveldb" + "sync" + + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + pkgContainer "github.com/winc-link/hummingbird/internal/pkg/container" +) + +// Database contains references to dependencies required by the database bootstrap implementation. +type Database struct { + database *config.ConfigurationStruct +} + +// NewDatabase is a factory method that returns an initialized Database receiver struct. +func NewDatabase(database *config.ConfigurationStruct) Database { + return Database{ + database: database, + } +} + +//Return the dbClient interfaces +func (d Database) newDBClient( + lc logger.LoggingClient) (interfaces.DBClient, error) { + + databaseInfo := d.database.GetDatabaseInfo()["Primary"] + switch databaseInfo.Type { + case string(constants.MySQL): + return mysql.NewClient(dtos.Configuration{ + Dsn: databaseInfo.Dsn, + }, lc) + case string(constants.SQLite): + return sqlite.NewClient(dtos.Configuration{ + Username: databaseInfo.Username, + Password: databaseInfo.Password, + DataSource: databaseInfo.DataSource, + }, lc) + default: + panic(errors.New("database configuration error")) + } +} + +func (d Database) newDataDBClient( + lc logger.LoggingClient) (interfaces.DataDBClient, error) { + dataDbInfo := d.database.GetDataDatabaseInfo()["Primary"] + + switch dataDbInfo.Type { + case string(constants.LevelDB): + return leveldb.NewClient(dtos.Configuration{ + DataSource: dataDbInfo.DataSource, + }, lc) + case string(constants.TDengine): + return tdengine.NewClient(dtos.Configuration{ + Dsn: dataDbInfo.Dsn, + }, lc) + default: + panic(errors.New("database configuration error")) + + } +} + +// BootstrapHandler fulfills the BootstrapHandler contract and initializes the database. +func (d Database) BootstrapHandler( + ctx context.Context, + wg *sync.WaitGroup, + startupTimer startup.Timer, + dic *di.Container) bool { + lc := pkgContainer.LoggingClientFrom(dic.Get) + + // initialize Metadata db. + dbClient, err := d.newDBClient(lc) + if err != nil { + panic(err) + } + + dic.Update(di.ServiceConstructorMap{ + container.DBClientInterfaceName: func(get di.Get) interface{} { + return dbClient + }, + }) + + // initialize Data db. + dataDbClient, err := d.newDataDBClient(lc) + if err != nil { + panic(err) + } + + dic.Update(di.ServiceConstructorMap{ + container.DataDBClientInterfaceName: func(get di.Get) interface{} { + return dataDbClient + }, + }) + + lc.Info("DatabaseInfo connected") + + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + interfaces.DMIFrom(di.GContainer.Get).StopAllInstance() //stop all instance + container.DBClientFrom(di.GContainer.Get).CloseSession() + container.DataDBClientFrom(di.GContainer.Get).CloseSession() + lc.Info("DatabaseInfo disconnected") + } + }() + return true +} diff --git a/internal/hummingbird/core/config/config.go b/internal/hummingbird/core/config/config.go new file mode 100644 index 0000000..3004f53 --- /dev/null +++ b/internal/hummingbird/core/config/config.go @@ -0,0 +1,204 @@ +/******************************************************************************* + * Copyright 2023 Winc link Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package config + +import ( + "fmt" + + "go.uber.org/atomic" + + bootstrapConfig "github.com/winc-link/hummingbird/internal/pkg/config" +) + +var ( + DefaultEnv = "pro" +) + +// Struct used to parse the JSON configuration file +type ConfigurationStruct struct { + Writable WritableInfo + MessageQueue MessageQueueInfo + Clients map[string]bootstrapConfig.ClientInfo + Databases map[string]map[string]bootstrapConfig.Database + Registry bootstrapConfig.RegistryInfo + Service bootstrapConfig.ServiceInfo + RpcServer bootstrapConfig.RPCServiceInfo + SecretStore bootstrapConfig.SecretStoreInfo + WebServer bootstrapConfig.ServiceInfo + DockerManage DockerManage + ApplicationSettings ApplicationSettings + Topics struct { + CommandTopic TopicInfo + } +} + +type WritableInfo struct { + PersistData atomic.Bool `toml:"-"` + PersistPeriod atomic.Int32 `toml:"-"` + LogLevel string + LogPath string + InsecureSecrets bootstrapConfig.InsecureSecrets + DebugProfile bool + IsNewModel bool + LimitMethods []string +} + +type TopicInfo struct { + Topic string +} + +type DockerManage struct { + ContainerConfigPath string + HostRootDir string + DockerApiVersion string + Privileged bool +} + +// MessageQueueInfo provides parameters related to connecting to a message bus +type MessageQueueInfo struct { + // Host is the hostname or IP address of the broker, if applicable. + Host string + // Port defines the port on which to access the message queue. + Port int + // Protocol indicates the protocol to use when accessing the message queue. + Protocol string + // Indicates the message queue platform being used. + Type string + // Indicates the topic the data is published/subscribed + SubscribeTopics []string + // Indicates the topic prefix the data is published to. Note that // will be + // added to this Publish Topic prefix as the complete publish topic + PublishTopicPrefix string + // Provides additional configuration properties which do not fit within the existing field. + // Typically the key is the name of the configuration property and the value is a string representation of the + // desired value for the configuration property. + Optional map[string]string +} + +func (m MessageQueueInfo) Enable() bool { + return !(m.Host == "" || m.Port <= 0) +} + +type ApplicationSettings struct { + DeviceIds string + ResDir string + DriverDir string + IsScreenControl bool + OTADir string + CloseAuthToken bool + GatewayEdition string // 网关版本标识 + WebBuildPath string // 前端路径 + TedgeNumber string +} + +// URL constructs a URL from the protocol, host and port and returns that as a string. +func (m MessageQueueInfo) URL() string { + return fmt.Sprintf("%s://%s:%v", m.Protocol, m.Host, m.Port) +} + +// UpdateFromRaw converts configuration received from the registry to a service-specific configuration struct which is +// then used to overwrite the service's existing configuration struct. +func (c *ConfigurationStruct) UpdateFromRaw(rawConfig interface{}) bool { + configuration, ok := rawConfig.(*ConfigurationStruct) + if ok { + // Check that information was successfully read from Registry + if configuration.Service.Port == 0 { + return false + } + *c = *configuration + } + return ok +} + +// EmptyWritablePtr returns a pointer to a service-specific empty WritableInfo struct. It is used by the bootstrap to +// provide the appropriate structure to registry.C's WatchForChanges(). +func (c *ConfigurationStruct) EmptyWritablePtr() interface{} { + return &WritableInfo{} +} + +// UpdateWritableFromRaw converts configuration received from the registry to a service-specific WritableInfo struct +// which is then used to overwrite the service's existing configuration's WritableInfo struct. +func (c *ConfigurationStruct) UpdateWritableFromRaw(rawWritable interface{}) bool { + writable, ok := rawWritable.(*WritableInfo) + if ok { + c.Writable = *writable + } + return ok +} + +// GetBootstrap returns the configuration elements required by the bootstrap. Currently, a copy of the configuration +// data is returned. This is intended to be temporary -- since ConfigurationStruct drives the configuration.toml's +// structure -- until we can make backwards-breaking configuration.toml changes (which would consolidate these fields +// into an bootstrapConfig.BootstrapConfiguration struct contained within ConfigurationStruct). +func (c *ConfigurationStruct) GetBootstrap() bootstrapConfig.BootstrapConfiguration { + // temporary until we can make backwards-breaking configuration.toml change + return bootstrapConfig.BootstrapConfiguration{ + Clients: c.Clients, + Service: c.Service, + RpcServer: c.RpcServer, + Registry: c.Registry, + SecretStore: c.SecretStore, + } +} + +// GetLogLevel returns the current ConfigurationStruct's log level. +func (c *ConfigurationStruct) GetLogLevel() string { + return c.Writable.LogLevel +} + +func (c *ConfigurationStruct) GetLogPath() string { + return c.Writable.LogPath +} + +// GetRegistryInfo returns the RegistryInfo from the ConfigurationStruct. +func (c *ConfigurationStruct) GetRegistryInfo() bootstrapConfig.RegistryInfo { + return c.Registry +} + +// GetDatabaseInfo returns a database information map. +func (c *ConfigurationStruct) GetDatabaseInfo() map[string]bootstrapConfig.Database { + cfg := c.Databases["Metadata"] + return cfg +} + +// GetDataDatabaseInfo returns a database information map for events & readings. +func (c *ConfigurationStruct) GetDataDatabaseInfo() map[string]bootstrapConfig.Database { + cfg := c.Databases["Data"] + return cfg +} + +// GetDataDatabaseInfo returns a database information map for events & readings. +func (c *ConfigurationStruct) GetRedisInfo() map[string]bootstrapConfig.Database { + cfg := c.Databases["Redis"] + return cfg +} + +// GetInsecureSecrets returns the service's InsecureSecrets. +func (c *ConfigurationStruct) GetInsecureSecrets() bootstrapConfig.InsecureSecrets { + return c.Writable.InsecureSecrets +} + +// 判断是否为物模型 +func (c *ConfigurationStruct) IsThingModel() bool { + return c.Writable.IsNewModel +} + +func (c *ConfigurationStruct) GetPersistData() bool { + return c.Writable.PersistData.Load() +} + +func (c *ConfigurationStruct) GetPersisPeriod() int32 { + return c.Writable.PersistPeriod.Load() +} diff --git a/internal/hummingbird/core/container/agentitf.go b/internal/hummingbird/core/container/agentitf.go new file mode 100644 index 0000000..2e75572 --- /dev/null +++ b/internal/hummingbird/core/container/agentitf.go @@ -0,0 +1,12 @@ +package container + +import ( + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/tools/agentclient" +) + +var AgentClientName = di.TypeInstanceToName((*agentclient.AgentClient)(nil)) + +func AgentClientNameFrom(get di.Get) agentclient.AgentClient { + return get(AgentClientName).(agentclient.AgentClient) +} diff --git a/internal/hummingbird/core/container/alertruleitf.go b/internal/hummingbird/core/container/alertruleitf.go new file mode 100644 index 0000000..0cbf256 --- /dev/null +++ b/internal/hummingbird/core/container/alertruleitf.go @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + AlertRuleAppName = di.TypeInstanceToName((*interfaces.AlertRuleApp)(nil)) +) + +func AlertRuleAppNameFrom(get di.Get) interfaces.AlertRuleApp { + return get(AlertRuleAppName).(interfaces.AlertRuleApp) +} + +var ( + RuleEngineAppName = di.TypeInstanceToName((*interfaces.RuleEngineApp)(nil)) +) + +func RuleEngineAppNameFrom(get di.Get) interfaces.RuleEngineApp { + return get(RuleEngineAppName).(interfaces.RuleEngineApp) +} diff --git a/internal/hummingbird/core/container/categorytemplate.go b/internal/hummingbird/core/container/categorytemplate.go new file mode 100644 index 0000000..48f6a3a --- /dev/null +++ b/internal/hummingbird/core/container/categorytemplate.go @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + CategoryTemplateAppName = di.TypeInstanceToName((*interfaces.CategoryApp)(nil)) +) + +func CategoryTemplateAppFrom(get di.Get) interfaces.CategoryApp { + return get(CategoryTemplateAppName).(interfaces.CategoryApp) +} + +var ( + ThingModelTemplateAppName = di.TypeInstanceToName((*interfaces.ThingModelTemplateApp)(nil)) +) + +func ThingModelTemplateAppFrom(get di.Get) interfaces.ThingModelTemplateApp { + return get(ThingModelTemplateAppName).(interfaces.ThingModelTemplateApp) +} + +var ( + UnitTemplateAppName = di.TypeInstanceToName((*interfaces.UnitApp)(nil)) +) + +func UnitTemplateAppFrom(get di.Get) interfaces.UnitApp { + return get(UnitTemplateAppName).(interfaces.UnitApp) +} + +var ( + DocsAppName = di.TypeInstanceToName((*interfaces.DocsApp)(nil)) +) + +func DocsTemplateAppFrom(get di.Get) interfaces.DocsApp { + return get(DocsAppName).(interfaces.DocsApp) +} + +var ( + QuickNavigationAppName = di.TypeInstanceToName((*interfaces.QuickNavigation)(nil)) +) + +func QuickNavigationAppTemplateAppFrom(get di.Get) interfaces.QuickNavigation { + return get(QuickNavigationAppName).(interfaces.QuickNavigation) +} diff --git a/internal/hummingbird/core/container/config.go b/internal/hummingbird/core/container/config.go new file mode 100644 index 0000000..e93163c --- /dev/null +++ b/internal/hummingbird/core/container/config.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2019 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/core/config" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +// ConfigurationName contains the name of the resource's config.ConfigurationStruct implementation in the DIC. +var ConfigurationName = di.TypeInstanceToName((*config.ConfigurationStruct)(nil)) + +// ConfigurationFrom helper function queries the DIC and returns resource's config.ConfigurationStruct implementation. +func ConfigurationFrom(get di.Get) *config.ConfigurationStruct { + return get(ConfigurationName).(*config.ConfigurationStruct) +} diff --git a/internal/hummingbird/core/container/cos.go b/internal/hummingbird/core/container/cos.go new file mode 100644 index 0000000..0a919ca --- /dev/null +++ b/internal/hummingbird/core/container/cos.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + CosAppName = di.TypeInstanceToName((*interfaces.CosApp)(nil)) +) + +func CosAppNameFrom(get di.Get) interfaces.CosApp { + return get(CosAppName).(interfaces.CosApp) +} diff --git a/internal/hummingbird/core/container/database.go b/internal/hummingbird/core/container/database.go new file mode 100644 index 0000000..451fa64 --- /dev/null +++ b/internal/hummingbird/core/container/database.go @@ -0,0 +1,19 @@ +// +// Copyright (C) 2020 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +// DBClientInterfaceName contains the name of the interfaces.DBClient implementation in the DIC. +var DBClientInterfaceName = di.TypeInstanceToName((*interfaces.DBClient)(nil)) + +// DBClientFrom helper function queries the DIC and returns the interfaces.DBClient implementation. +func DBClientFrom(get di.Get) interfaces.DBClient { + return get(DBClientInterfaceName).(interfaces.DBClient) +} diff --git a/internal/hummingbird/core/container/datadb.go b/internal/hummingbird/core/container/datadb.go new file mode 100644 index 0000000..0562097 --- /dev/null +++ b/internal/hummingbird/core/container/datadb.go @@ -0,0 +1,12 @@ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var DataDBClientInterfaceName = di.TypeInstanceToName((*interfaces.DataDBClient)(nil)) + +func DataDBClientFrom(get di.Get) interfaces.DataDBClient { + return get(DataDBClientInterfaceName).(interfaces.DataDBClient) +} diff --git a/internal/hummingbird/core/container/dataresource.go b/internal/hummingbird/core/container/dataresource.go new file mode 100644 index 0000000..41a579b --- /dev/null +++ b/internal/hummingbird/core/container/dataresource.go @@ -0,0 +1,26 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var DataResourceName = di.TypeInstanceToName((*interfaces.DataResourceApp)(nil)) + +func DataResourceFrom(get di.Get) interfaces.DataResourceApp { + return get(DataResourceName).(interfaces.DataResourceApp) +} diff --git a/internal/hummingbird/core/container/deviceitf.go b/internal/hummingbird/core/container/deviceitf.go new file mode 100644 index 0000000..148452c --- /dev/null +++ b/internal/hummingbird/core/container/deviceitf.go @@ -0,0 +1,14 @@ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +// DeviceItfName +var DeviceItfName = di.TypeInstanceToName((*interfaces.DeviceItf)(nil)) + +// DeviceItfFrom +func DeviceItfFrom(get di.Get) interfaces.DeviceItf { + return get(DeviceItfName).(interfaces.DeviceItf) +} diff --git a/internal/hummingbird/core/container/driverift.go b/internal/hummingbird/core/container/driverift.go new file mode 100644 index 0000000..019d437 --- /dev/null +++ b/internal/hummingbird/core/container/driverift.go @@ -0,0 +1,23 @@ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +// DI +var ( + DriverAppName = di.TypeInstanceToName((*interfaces.DriverLibApp)(nil)) +) + +func DriverAppFrom(get di.Get) interfaces.DriverLibApp { + return get(DriverAppName).(interfaces.DriverLibApp) +} + +var ( + DriverServiceAppName = di.TypeInstanceToName((*interfaces.DriverServiceApp)(nil)) +) + +func DriverServiceAppFrom(get di.Get) interfaces.DriverServiceApp { + return get(DriverServiceAppName).(interfaces.DriverServiceApp) +} diff --git a/internal/hummingbird/core/container/ekuiperitf.go b/internal/hummingbird/core/container/ekuiperitf.go new file mode 100644 index 0000000..8932077 --- /dev/null +++ b/internal/hummingbird/core/container/ekuiperitf.go @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/tools/ekuiperclient" + //"github.com/winc-link/hummingbird/internal/tools/ekuiperclient" +) + +var ( + EkuiperAppName = di.TypeInstanceToName((*ekuiperclient.EkuiperClient)(nil)) +) + +func EkuiperAppFrom(get di.Get) ekuiperclient.EkuiperClient { + return get(EkuiperAppName).(ekuiperclient.EkuiperClient) +} diff --git a/internal/hummingbird/core/container/homepage.go b/internal/hummingbird/core/container/homepage.go new file mode 100644 index 0000000..b034f38 --- /dev/null +++ b/internal/hummingbird/core/container/homepage.go @@ -0,0 +1,25 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + HomePageAppName = di.TypeInstanceToName((*interfaces.HomePageItf)(nil)) +) + +func HomePageAppNameFrom(get di.Get) interfaces.HomePageItf { + return get(HomePageAppName).(interfaces.HomePageItf) +} diff --git a/internal/hummingbird/core/container/hpcloud.go b/internal/hummingbird/core/container/hpcloud.go new file mode 100644 index 0000000..ed11ccc --- /dev/null +++ b/internal/hummingbird/core/container/hpcloud.go @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package container + +import ( + //interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/tools/hpcloudclient" +) + +var ( + HpcServiceAppName = di.TypeInstanceToName((*hpcloudclient.Hpcloud)(nil)) +) + +func HpcServiceAppFrom(get di.Get) hpcloudclient.Hpcloud { + return get(HpcServiceAppName).(hpcloudclient.Hpcloud) +} diff --git a/internal/hummingbird/core/container/langeuageitf.go b/internal/hummingbird/core/container/langeuageitf.go new file mode 100644 index 0000000..99adae1 --- /dev/null +++ b/internal/hummingbird/core/container/langeuageitf.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + LanguageSDKAppName = di.TypeInstanceToName((*interfaces.LanguageSDKApp)(nil)) +) + +func LanguageAppNameFrom(get di.Get) interfaces.LanguageSDKApp { + return get(LanguageSDKAppName).(interfaces.LanguageSDKApp) +} diff --git a/internal/hummingbird/core/container/messageitf.go b/internal/hummingbird/core/container/messageitf.go new file mode 100644 index 0000000..a6fd5bb --- /dev/null +++ b/internal/hummingbird/core/container/messageitf.go @@ -0,0 +1,18 @@ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var MessageItfName = di.TypeInstanceToName((*interfaces.MessageItf)(nil)) + +func MessageItfFrom(get di.Get) interfaces.MessageItf { + return get(MessageItfName).(interfaces.MessageItf) +} + +var MessageStoreItfName = di.TypeInstanceToName((*interfaces.MessageStores)(nil)) + +func MessageStoreItfFrom(get di.Get) interfaces.MessageStores { + return get(MessageStoreItfName).(interfaces.MessageStores) +} diff --git a/internal/hummingbird/core/container/monitoritf.go b/internal/hummingbird/core/container/monitoritf.go new file mode 100644 index 0000000..012b1a9 --- /dev/null +++ b/internal/hummingbird/core/container/monitoritf.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + MonitorAppName = di.TypeInstanceToName((*interfaces.MonitorItf)(nil)) +) + +func MonitorAppNameFrom(get di.Get) interfaces.MonitorItf { + return get(MonitorAppName).(interfaces.MonitorItf) +} diff --git a/internal/hummingbird/core/container/notifyitf.go b/internal/hummingbird/core/container/notifyitf.go new file mode 100644 index 0000000..a8abac0 --- /dev/null +++ b/internal/hummingbird/core/container/notifyitf.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/tools/notify/sms" +) + +var ( + SmsServiceAppName = di.TypeInstanceToName((*sms.SMSer)(nil)) +) + +func SmsServiceAppFrom(get di.Get) sms.SMSer { + return get(SmsServiceAppName).(sms.SMSer) +} diff --git a/internal/hummingbird/core/container/persistitf.go b/internal/hummingbird/core/container/persistitf.go new file mode 100644 index 0000000..9f2b242 --- /dev/null +++ b/internal/hummingbird/core/container/persistitf.go @@ -0,0 +1,25 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var PersistItfName = di.TypeInstanceToName((*interfaces.PersistItf)(nil)) + +func PersistItfFrom(get di.Get) interfaces.PersistItf { + return get(PersistItfName).(interfaces.PersistItf) +} diff --git a/internal/hummingbird/core/container/productitf.go b/internal/hummingbird/core/container/productitf.go new file mode 100644 index 0000000..0a97ef7 --- /dev/null +++ b/internal/hummingbird/core/container/productitf.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + ProductAppName = di.TypeInstanceToName((*interfaces.ProductItf)(nil)) +) + +func ProductAppNameFrom(get di.Get) interfaces.ProductItf { + return get(ProductAppName).(interfaces.ProductItf) +} diff --git a/internal/hummingbird/core/container/scene.go b/internal/hummingbird/core/container/scene.go new file mode 100644 index 0000000..6a87996 --- /dev/null +++ b/internal/hummingbird/core/container/scene.go @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + SceneAppName = di.TypeInstanceToName((*interfaces.SceneApp)(nil)) +) + +func SceneAppNameFrom(get di.Get) interfaces.SceneApp { + return get(SceneAppName).(interfaces.SceneApp) +} + +var ( + ConJobAppName = di.TypeInstanceToName((*interfaces.ConJob)(nil)) +) + +func ConJobAppNameFrom(get di.Get) interfaces.ConJob { + return get(ConJobAppName).(interfaces.ConJob) +} diff --git a/internal/hummingbird/core/container/systemitf.go b/internal/hummingbird/core/container/systemitf.go new file mode 100644 index 0000000..bec1fc1 --- /dev/null +++ b/internal/hummingbird/core/container/systemitf.go @@ -0,0 +1,16 @@ +package container + +import ( + //"gitlab.com/tedge/edgex/internal/pkg/di" + //"gitlab.com/tedge/edgex/internal/tedge/resource/interfaces" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +// SystemItfName +var SystemItfName = di.TypeInstanceToName((*interfaces.SystemItf)(nil)) + +// SystemItfFrom +func SystemItfFrom(get di.Get) interfaces.SystemItf { + return get(SystemItfName).(interfaces.SystemItf) +} diff --git a/internal/hummingbird/core/container/thingmodel.go b/internal/hummingbird/core/container/thingmodel.go new file mode 100644 index 0000000..069e9a8 --- /dev/null +++ b/internal/hummingbird/core/container/thingmodel.go @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + ThingModelAppName = di.TypeInstanceToName((*interfaces.ThingModelItf)(nil)) +) + +func ThingModelAppNameFrom(get di.Get) interfaces.ThingModelItf { + return get(ThingModelAppName).(interfaces.ThingModelItf) +} diff --git a/internal/hummingbird/core/container/userift.go b/internal/hummingbird/core/container/userift.go new file mode 100644 index 0000000..80b1b88 --- /dev/null +++ b/internal/hummingbird/core/container/userift.go @@ -0,0 +1,14 @@ +package container + +import ( + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +// UserItfName +var UserItfName = di.TypeInstanceToName((*interfaces.UserItf)(nil)) + +// UserItfFrom +func UserItfFrom(get di.Get) interfaces.UserItf { + return get(UserItfName).(interfaces.UserItf) +} diff --git a/internal/hummingbird/core/controller/http/gateway/agent.go b/internal/hummingbird/core/controller/http/gateway/agent.go new file mode 100644 index 0000000..44070ef --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/agent.go @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 运维管理 +// @Summary 获取系统性能 +// @Produce json +// @Param request query dtos.SystemMetricsQuery true "参数" +// @Success 200 {object} dtos.SystemMetricsResponse +// @Router /api/v1/metrics/system [get] +//func (ctl *controller) GetSystemMetricsHandle(c *gin.Context) { +// ctl.ProxyAgentServer(c) +//} + +func (c *controller) SystemMetricsHandler(ctx *gin.Context) { + var query = dtos.SystemMetricsQuery{} + if err := ctx.BindQuery(&query); err != nil { + httphelper.RenderFail(ctx, errort.NewCommonErr(errort.DefaultReqParamsError, err), ctx.Writer, c.lc) + return + } + metrics, err := c.getSystemMonitorApp().GetSystemMetrics(ctx, query) + if err != nil { + httphelper.RenderFail(ctx, err, ctx.Writer, c.lc) + return + } + + httphelper.ResultSuccess(metrics, ctx.Writer, c.lc) +} + +// @Tags 运维管理 +// @Summary 操作服务重启 +// @Produce json +// @Param request body dtos.Operation true "操作" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/operation [post] +//func (ctl *controller) OperationServiceHandle(c *gin.Context) { +// ctl.ProxyAgentServer(c) +//} diff --git a/internal/hummingbird/core/controller/http/gateway/categroytemplate.go b/internal/hummingbird/core/controller/http/gateway/categroytemplate.go new file mode 100644 index 0000000..e57f1e6 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/categroytemplate.go @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 标准产品品类 +// @Summary 标准产品品类列表 +// @Produce json +// @Param request query dtos.CategoryTemplateRequest true "参数" +// @Success 200 {array} dtos.CategoryTemplateResponse +// @Router /api/v1/category-template [get] +//@Security ApiKeyAuth +func (ctl *controller) CategoryTemplateSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.CategoryTemplateRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getCategoryTemplateApp().CategoryTemplateSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 标准产品品类 +// @Summary 同步标准产品品类 +// @Produce json +// @Param request query dtos.CategoryTemplateRequest true "参数" +// @Router /api/v1/category-template/sync [post] +//@Security ApiKeyAuth +func (ctl *controller) CategoryTemplateSync(c *gin.Context) { + lc := ctl.lc + var req dtos.CategoryTemplateSyncRequest + urlDecodeParam(&req, c.Request, lc) + _, edgeXErr := ctl.getCategoryTemplateApp().Sync(c, "Ireland") + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/cloudinstance.go b/internal/hummingbird/core/controller/http/gateway/cloudinstance.go new file mode 100644 index 0000000..de955f3 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/cloudinstance.go @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 我的云服务实例 +// @Summary 云服务实例列表 +// @Produce json +// @Param request query dtos.CloudInstanceSearchQueryRequest true "参数" +// @Success 200 {object} dtos.CloudInstanceSearchQueryRequest +// @Router /api/v1/cloud-instance [get] +//@Security ApiKeyAuth +func (ctl *controller) CloudInstanceSearch(c *gin.Context) { + lc := ctl.lc + //var req dtos.CloudInstanceSearchQueryRequest + //urlDecodeParam(&req, c.Request, lc) + //dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + //data, total := 1,0 + data := make([]string, 0) + total := 0 + pageResult := httphelper.NewPageResult(data, uint32(total), 1, 10) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/common.go b/internal/hummingbird/core/controller/http/gateway/common.go new file mode 100644 index 0000000..3aa0adf --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/common.go @@ -0,0 +1,68 @@ +package gateway + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/gorilla/schema" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "net/http" + "net/http/httputil" + "net/url" +) + +const ( + UrlParamSceneId = "sceneId" + UrlParamActionId = "actionId" + UrlParamStrategyId = "strategyId" + UrlParamConditionId = "conditionId" + UrlParamJobId = "jobId" + UrlParamProductId = "productId" + UrlParamCategoryKey = "categoryKey" + UrlParamCloudInstanceId = "cloudInstanceId" + UrlParamDeviceId = "deviceId" + UrlParamFuncPointId = "funcPointId" + UrlParamDeviceLibraryId = "deviceLibraryId" + UrlParamDeviceServiceId = "deviceServiceId" + UrlParamDockerConfigId = "dockerConfigId" + UrlParamRuleId = "ruleId" + UrlDataResourceId = "dataResourceId" + RuleEngineId = "ruleEngineId" +) + +var decoder *schema.Decoder + +func init() { + decoder = schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) +} + +func urlDecodeParam(obj interface{}, r *http.Request, lc logger.LoggingClient) { + err := decoder.Decode(obj, r.URL.Query()) + if err != nil { + lc.Errorf("url decoding err %v", err) + } +} + +func (ctl *controller) ProxyAgentServer(c *gin.Context) { + proxy := ctl.cfg.Clients["Agent"] + ctl.lc.Infof("agentProxy: %v", proxy) + ctl.ServeHTTP(c, fmt.Sprintf("http://%v:%v", proxy.Host, proxy.Port)) +} + +//func (ctl *controller) ProxySharpServer(c *gin.Context) { +// proxy := ctl.cfg.Clients["Sharp"] +// ctl.lc.Infof("sharpProxy: %v", proxy) +// ctl.ServeHTTP(c, fmt.Sprintf("http://%v:%v", proxy.Host, proxy.Port)) +//} + +func (ctl *controller) ServeHTTP(c *gin.Context, URL string) { + parseRootUrl, err := url.Parse(URL) + if err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, ctl.lc) + return + } + proxy := httputil.NewSingleHostReverseProxy(parseRootUrl) + proxy.ServeHTTP(c.Writer, c.Request) +} diff --git a/internal/hummingbird/core/controller/http/gateway/controller.go b/internal/hummingbird/core/controller/http/gateway/controller.go new file mode 100644 index 0000000..705ade5 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/controller.go @@ -0,0 +1,120 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/core/config" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + pkgcontainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" +) + +type controller struct { + dic *di.Container + lc logger.LoggingClient + cfg *config.ConfigurationStruct +} + +func New(dic *di.Container) *controller { + lc := pkgcontainer.LoggingClientFrom(dic.Get) + cfg := container.ConfigurationFrom(dic.Get) + return &controller{ + dic: dic, + lc: lc, + cfg: cfg, + } +} + +func (ctl *controller) getDriverLibApp() interfaces.DriverLibApp { + return container.DriverAppFrom(ctl.dic.Get) +} + +func (ctl *controller) getUserApp() interfaces.UserItf { + return container.UserItfFrom(ctl.dic.Get) +} + +func (ctl *controller) getDriverServiceApp() interfaces.DriverServiceApp { + return container.DriverServiceAppFrom(ctl.dic.Get) +} + +func (ctl *controller) getSystemMonitorApp() interfaces.MonitorItf { + return container.MonitorAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getLanguageApp() interfaces.LanguageSDKApp { + return container.LanguageAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getProductApp() interfaces.ProductItf { + return container.ProductAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getDeviceApp() interfaces.DeviceItf { + return container.DeviceItfFrom(ctl.dic.Get) +} + +func (ctl *controller) getPersistApp() interfaces.PersistItf { + return container.PersistItfFrom(ctl.dic.Get) +} + +func (ctl *controller) getCategoryTemplateApp() interfaces.CategoryApp { + return container.CategoryTemplateAppFrom(ctl.dic.Get) +} + +func (ctl *controller) getThingModelTemplateApp() interfaces.ThingModelTemplateApp { + return container.ThingModelTemplateAppFrom(ctl.dic.Get) +} + +func (ctl *controller) getThingModelApp() interfaces.ThingModelCtlItf { + return container.ThingModelAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getUnitModelApp() interfaces.UnitApp { + return container.UnitTemplateAppFrom(ctl.dic.Get) +} + +func (ctl *controller) getAlertRuleApp() interfaces.AlertRuleApp { + return container.AlertRuleAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getRuleEngineApp() interfaces.RuleEngineApp { + return container.RuleEngineAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getHomePageApp() interfaces.HomePageItf { + return container.HomePageAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getSystemApp() interfaces.SystemItf { + return container.SystemItfFrom(ctl.dic.Get) +} + +func (ctl *controller) getDocsApp() interfaces.DocsApp { + return container.DocsTemplateAppFrom(ctl.dic.Get) +} + +func (ctl *controller) getQuickNavigationApp() interfaces.QuickNavigation { + return container.QuickNavigationAppTemplateAppFrom(ctl.dic.Get) +} + +func (ctl *controller) getDataResourceApp() interfaces.DataResourceApp { + return container.DataResourceFrom(ctl.dic.Get) +} + +func (ctl *controller) getSceneApp() interfaces.SceneApp { + return container.SceneAppNameFrom(ctl.dic.Get) +} diff --git a/internal/hummingbird/core/controller/http/gateway/dataresource.go b/internal/hummingbird/core/controller/http/gateway/dataresource.go new file mode 100644 index 0000000..4e01f68 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/dataresource.go @@ -0,0 +1,139 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 资源管理 +// @Summary 实例类型 +// @Produce json +// @Param request query dtos.AddDataResourceReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/typeresource [get] +func (ctl *controller) DataResourceType(c *gin.Context) { + lc := ctl.lc + types := ctl.getDataResourceApp().DataResourceType(c) + httphelper.ResultSuccess(types, c.Writer, lc) +} + +// @Tags 资源管理 +// @Summary 添加资源管理 +// @Produce json +// @Param request query dtos.AddDataResourceReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/dataresource [post] +func (ctl *controller) DataResourceAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.AddDataResourceReq + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getDataResourceApp().AddDataResource(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 资源管理 +// @Summary 添加资源管理 +// @Produce json +// @Param request query dtos.AddDataResourceReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/dataresource/:resourceId [get] +func (ctl *controller) DataResourceById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlDataResourceId) + dataSource, edgeXErr := ctl.getDataResourceApp().DataResourceById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(dataSource, c.Writer, lc) +} + +// @Tags 资源管理 +// @Summary 修改资源管理 +// @Produce json +// @Param request query dtos.AddDataResourceReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/dataresource [put] +func (ctl *controller) UpdateDataResource(c *gin.Context) { + lc := ctl.lc + var req dtos.UpdateDataResource + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getDataResourceApp().UpdateDataResource(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 资源管理 +// @Summary 删除资源管理 +// @Produce json +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/dataresource/:resourceId [delete] +func (ctl *controller) DataResourceDel(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlDataResourceId) + edgeXErr := ctl.getDataResourceApp().DelDataResourceById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 资源管理 +// @Summary 资源管理查询 +// @Produce json +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/dataresource [get] +func (ctl *controller) DataResourceSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.DataResourceSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getDataResourceApp().DataResourceSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) DataResourceHealth(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlDataResourceId) + edgeXErr := ctl.getDataResourceApp().DataResourceHealth(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/device.go b/internal/hummingbird/core/controller/http/gateway/device.go new file mode 100644 index 0000000..1405e27 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/device.go @@ -0,0 +1,407 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 设备管理 +// @Summary 查询设备列表 +// @Produce json +// @Param request query dtos.DeviceSearchQueryRequest true "参数" +// @Success 200 {array} []dtos.DeviceSearchQueryResponse +// @Router /api/v1/devices [get] +func (ctl *controller) DevicesSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getDeviceApp().DevicesSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 查询详情 +// @Produce json +// @Param deviceId path string true "pid" +// @Success 200 {object} dtos.DeviceInfoResponse +// @Router /api/v1/device/:deviceId [get] +func (ctl *controller) DeviceById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceId) + data, edgeXErr := ctl.getDeviceApp().DeviceById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 删除设备 +// @Produce json +// @Param deviceId path string true "pid" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device/:deviceId [delete] +func (ctl *controller) DeviceDelete(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceId) + edgeXErr := ctl.getDeviceApp().DeleteDeviceById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 批量删除设备 +// @Produce json +// @Param deviceId path string true "pid" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/devices [delete] +func (ctl *controller) DevicesDelete(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceBatchDelete + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + if len(req.DeviceIds) == 0 { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, nil), c.Writer, lc) + return + } + edgeXErr := ctl.getDeviceApp().BatchDeleteDevice(c, req.DeviceIds) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 添加设备 +// @Produce json +// @Param request query dtos.DeviceAddRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device [post] +func (ctl *controller) DeviceByAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceAddRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getDeviceApp().AddDevice(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 查询mqtt连接详情 +// @Produce json +// @Param deviceId path string true "pid" +// @Success 200 {object} dtos.DeviceAuthInfoResponse +// @Router /api/v1/device-mqtt/:deviceId [get] +func (ctl *controller) DeviceMqttInfoById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceId) + data, edgeXErr := ctl.getDeviceApp().DeviceMqttAuthInfo(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +func (ctl *controller) AddMqttAuth(c *gin.Context) { + lc := ctl.lc + var req dtos.AddMqttAuthInfoRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + data, edgeXErr := ctl.getDeviceApp().AddMqttAuth(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 设备导入模版下载 +// @Produce json +// @Param req query dtos.DeviceImportTemplateRequest true "参数" +// @Success 200 {object} string +// @Router /api/v1/devices/import-template [get] +func (ctl *controller) DeviceImportTemplateDownload(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceImportTemplateRequest + + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + file, edgeXErr := ctl.getDeviceApp().DeviceImportTemplateDownload(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + data, _ := file.Excel.WriteToBuffer() + httphelper.ResultExcelData(c, file.FileName, data) +} + +// @Tags 设备管理 +// @Summary 设备导入模版校验 +// @Produce json +// @Success 200 {object} string +// @Router /api/v1/device/upload-validated [post] +func (ctl *controller) UploadValidated(c *gin.Context) { + lc := ctl.lc + files, _ := c.FormFile("file") + f, err := files.Open() + if err != nil { + err = errort.NewCommonErr(errort.DefaultUploadFileErrorCode, err) + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + file, edgeXErr := dtos.NewImportFile(f) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + edgeXErr = ctl.getDeviceApp().UploadValidated(c, file) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 设备导入 +// @Produce json +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/devices/import [post] +func (ctl *controller) DevicesImport(c *gin.Context) { + lc := ctl.lc + //productId := c.Param(UrlParamProductId) + //cloudInstanceId := c.Param(UrlParamCloudInstanceId) + //var req dtos.ProductSearchQueryRequest + var req dtos.DevicesImport + urlDecodeParam(&req, c.Request, lc) + + files, _ := c.FormFile("file") + f, err := files.Open() + if err != nil { + err = errort.NewCommonErr(errort.DefaultUploadFileErrorCode, err) + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + file, edgeXErr := dtos.NewImportFile(f) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + result, edgeXErr := ctl.getDeviceApp().DevicesImport(c, file, req.ProductId, req.DriverInstanceId) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(result, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 更新设备 +// @Produce json +// @Param deviceId path string true "pid" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device/:deviceId [put] +func (ctl *controller) DeviceUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceUpdateRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDeviceApp().DeviceUpdate(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 设备批量绑定驱动 +// @Produce json +// @Param request query dtos.DevicesBindDriver true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/devices/bind-driver [put] +func (ctl *controller) DevicesBindDriver(c *gin.Context) { + lc := ctl.lc + var req dtos.DevicesBindDriver + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDeviceApp().DevicesBindDriver(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 设备批量与驱动解绑 +// @Produce json +// @Param request query dtos.DevicesBindDriver true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/unbind-driver [put] +func (ctl *controller) DevicesUnBindDriver(c *gin.Context) { + lc := ctl.lc + var req dtos.DevicesUnBindDriver + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDeviceApp().DevicesUnBindDriver(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) DevicesBindByProductId(c *gin.Context) { + lc := ctl.lc + var req dtos.DevicesBindProductId + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDeviceApp().DevicesBindProductId(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 查看设备属性数据 +// @Produce json +// @Param request query dtos.ThingModelPropertyDataRequest true "参数" +// @Success 200 {array} []dtos.ThingModelDataResponse +// @Router /api/v1/device/:deviceId/thing-model/property [get] +func (ctl *controller) DeviceThingModelPropertyDataSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelPropertyDataRequest + urlDecodeParam(&req, c.Request, lc) + deviceId := c.Param(UrlParamDeviceId) + req.DeviceId = deviceId + data, edgeXErr := ctl.getPersistApp().SearchDeviceThingModelPropertyData(req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +func (ctl *controller) DeviceThingModelHistoryPropertyDataSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelPropertyDataRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + deviceId := c.Param(UrlParamDeviceId) + req.DeviceId = deviceId + data, total, edgeXErr := ctl.getPersistApp().SearchDeviceThingModelHistoryPropertyData(req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, uint32(total), req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 查看设备事件数据 +// @Produce json +// @Param request query dtos.ThingModelPropertyDataRequest true "参数" +// @Success 200 {array} []dtos.ThingModelEventDataResponse +// @Router /api/v1/device/:deviceId/thing-model/event [get] +func (ctl *controller) DeviceThingModelEventDataSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelEventDataRequest + urlDecodeParam(&req, c.Request, lc) + deviceId := c.Param(UrlParamDeviceId) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + req.DeviceId = deviceId + data, total, edgeXErr := ctl.getPersistApp().SearchDeviceThingModelEventData(req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, uint32(total), req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 设备管理 +// @Summary 查看设备服务调用数据 +// @Produce json +// @Param request query dtos.ThingModelServiceDataRequest true "参数" +// @Success 200 {array} []dtos.ThingModelServiceDataResponse +// @Router /api/v1/device/:deviceId/thing-model/service [get] +func (ctl *controller) DeviceThingModelServiceDataSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelServiceDataRequest + urlDecodeParam(&req, c.Request, lc) + deviceId := c.Param(UrlParamDeviceId) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + req.DeviceId = deviceId + data, total, edgeXErr := ctl.getPersistApp().SearchDeviceThingModelServiceData(req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, uint32(total), req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) DeviceStatusTemplate(c *gin.Context) { + lc := ctl.lc + var deviceStatus []constants.DeviceStatus + deviceStatus = append(append(append(deviceStatus), + constants.DeviceStatusOnline), + constants.DeviceStatusOffline) + httphelper.ResultSuccess(deviceStatus, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/devicealert.go b/internal/hummingbird/core/controller/http/gateway/devicealert.go new file mode 100644 index 0000000..9324ac5 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/devicealert.go @@ -0,0 +1,280 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "time" +) + +// @Tags 告警中心 +// @Summary 添加告警规则 +// @Produce json +// @Param request query dtos.RuleAddRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert-rule [post] +func (ctl *controller) AlertRuleAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.RuleAddRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getAlertRuleApp().AddAlertRule(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 编辑告警规则 +// @Produce json +// @Param request query dtos.RuleUpdateRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert-rule/:ruleId [put] +func (ctl *controller) AlertRuleUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.RuleUpdateRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getAlertRuleApp().UpdateAlertRule(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) AlertRuleUpdateField(c *gin.Context) { + lc := ctl.lc + var req dtos.RuleFieldUpdate + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getAlertRuleApp().UpdateAlertField(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警规则详情 +// @Produce json +// @Param ruleId path string true "ruleId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert-rule/:ruleId [get] +func (ctl *controller) AlertRuleById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamRuleId) + data, edgeXErr := ctl.getAlertRuleApp().AlertRuleById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警规则列表 +// @Produce json +// @Param request query dtos.AlertRuleSearchQueryRequest true "参数" +// @Success 200 {array} []dtos.AlertRuleSearchQueryResponse +// @Router /api/v1/alert-rule [get] +func (ctl *controller) AlertRuleSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.AlertRuleSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getAlertRuleApp().AlertRulesSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警规则启动 +// @Produce json +// @Param ruleId path string true "ruleId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert-rule/:ruleId/start [post] +func (ctl *controller) AlertRuleStart(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamRuleId) + edgeXErr := ctl.getAlertRuleApp().AlertRulesStart(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警规则停止 +// @Produce json +// @Param ruleId path string true "ruleId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert-rule/:ruleId/stop [post] +func (ctl *controller) AlertRuleStop(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamRuleId) + edgeXErr := ctl.getAlertRuleApp().AlertRulesStop(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警规则重启 +// @Produce json +// @Param ruleId path string true "ruleId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert-rule/:ruleId/restart [post] +func (ctl *controller) AlertRuleRestart(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamRuleId) + edgeXErr := ctl.getAlertRuleApp().AlertRulesRestart(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警规则删除 +// @Produce json +// @Param ruleId path string true "ruleId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert-rule/:ruleId [delete] +func (ctl *controller) AlertRuleDelete(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamRuleId) + edgeXErr := ctl.getAlertRuleApp().AlertRulesDelete(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警列表 +// @Produce json +// @Param request query dtos.AlertSearchQueryRequest true "参数" +// @Success 200 {array} []dtos.AlertSearchQueryResponse +// @Router /api/v1/alert-list [get] +func (ctl *controller) AlertSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.AlertSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getAlertRuleApp().AlertSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警列表 +// @Produce json +// @Param request query dtos.AlertSearchQueryRequest true "参数" +// @Success 200 {array} []dtos.AlertSearchQueryResponse +// @Router /api/v1/alert-plate [get] +func (ctl *controller) AlertPlate(c *gin.Context) { + lc := ctl.lc + currentTime := time.Now() + beforeTime := currentTime.AddDate(0, 0, -7).UnixMilli() + data, edgeXErr := ctl.getAlertRuleApp().AlertPlate(c, beforeTime) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 忽略告警 +// @Produce json +// @Router /api/v1/alert-ignore/:ruleId [put] +func (ctl *controller) AlertIgnore(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamRuleId) + edgeXErr := ctl.getAlertRuleApp().AlertIgnore(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 处理告警 +// @Produce json +// @Router /api/v1/alert-treated [post] +func (ctl *controller) AlertTreated(c *gin.Context) { + lc := ctl.lc + var req dtos.AlertTreatedRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getAlertRuleApp().TreatedIgnore(c, req.Id, req.Message) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 告警中心 +// @Summary 告警列表 +// @Produce json +// @Param request query dtos.AlertAddRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/alert [post] +func (ctl *controller) EkuiperAlert(c *gin.Context) { + lc := ctl.lc + + req := make(map[string]interface{}) + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + lc.Info("req....", req) + edgeXErr := ctl.getAlertRuleApp().AddAlert(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/devicelibrary.go b/internal/hummingbird/core/controller/http/gateway/devicelibrary.go new file mode 100644 index 0000000..19ae0ce --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/devicelibrary.go @@ -0,0 +1,270 @@ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 驱动库管理 +// @Summary 新增驱动库 +// @Produce json +// @Param request body dtos.DeviceLibraryAddRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device-libraries [post] +func (ctl *controller) DeviceLibraryAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceLibraryAddRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDriverLibApp().AddDriverLib(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 驱动库管理 +// @Summary 查询驱动 +// @Produce json +// @Param request query dtos.DeviceLibrarySearchQueryRequest true "参数" +// @Success 200 {object} httphelper.ResPageResult +// @Router /api/v1/device-libraries [get] +func (ctl *controller) DeviceLibrariesSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceLibrarySearchQueryRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + + //req2 := dtos.FromDeviceLibrarySearchQueryRequestToRpc(req) + list, total, edgeXErr := ctl.getDriverLibApp().DeviceLibrariesSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + libs := make([]dtos.DeviceLibraryResponse, len(list)) + for i, p := range list { + libs[i] = dtos.DeviceLibraryResponseFromModel(p) + } + pageResult := httphelper.NewPageResult(libs, total, req.Page, req.PageSize) + + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 驱动库管理 +// @Summary 删除驱动 +// @Produce json +// @Param deviceLibraryId path string true "驱动ID" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device-libraries/:deviceLibraryId [delete] +func (ctl *controller) DeviceLibraryDelete(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceLibraryId) + err := ctl.getDriverLibApp().DeleteDeviceLibraryById(c, id) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 驱动库管理 +// @Summary 获取驱动定义配置信息 +// @Produce json +// @Param request query dtos.DeviceLibraryConfigRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device-libraries/config [get] +//func (ctl *controller) DeviceLibraryConfig(c *gin.Context) { +// lc := ctl.lc +// var req dtos.DeviceLibraryConfigRequest +// if err := c.ShouldBindQuery(&req); err != nil { +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// if req.DeviceLibraryId == nil && req.CloudProductId == nil && req.DeviceServiceId == nil && req.DeviceId == nil { +// err := fmt.Errorf("deviceLibraryConfig req is null") +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// +// dl, err := ctl.getDriverLibApp().GetDeviceLibraryConfig(c, req) +// //data, edgeXErr := gatewayapp.DeviceLibraryConfig(c, req) +// if err != nil { +// httphelper.RenderFail(c, err, c.Writer, lc) +// return +// } +// config, err := dl.GetConfigMap() +// if err != nil { +// httphelper.RenderFail(c, err, c.Writer, lc) +// return +// } +// +// httphelper.ResultSuccess(config, c.Writer, lc) +//} + +// @Tags 驱动库管理 +// @Summary 驱动库升级/下载 +// @Produce json +// @Param deviceLibraryId path string true "驱动ID" +// @Param request query dtos.DeviceLibraryUpgradeRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device-libraries/:deviceLibraryId/upgrade-download [put] +// Deprecated +//func (ctl *controller) DeviceLibraryUpgrade(c *gin.Context) { +// lc := ctl.lc +// var req dtos.DeviceLibraryUpgradeRequest +// req.Id = c.Param(UrlParamDeviceLibraryId) +// if err := c.ShouldBind(&req); err != nil { +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// +// err := ctl.getDriverLibApp().UpgradeDeviceLibrary(c, req, true) +// //edgeXErr := gatewayapp.DeviceLibraryUpgrade(c, req) +// if err != nil { +// httphelper.RenderFail(c, err, c.Writer, lc) +// return +// } +// +// httphelper.ResultSuccess(nil, c.Writer, lc) +//} + +// @Tags 驱动库管理 +// @Summary 上传驱动配置文件 +// @Accept multipart/form-data +// @Produce json +// @Success 200 {object} dtos.DeviceLibraryUploadResponse +// @Router /api/v1/device-libraries/upload [post] +//func (ctl *controller) DeviceLibraryUpload(c *gin.Context) { +// lc := ctl.lc +// file, err := c.FormFile("fileName") +// if err != nil { +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// var req dtos.DeviceLibraryUploadRequest +// req.FileName = file.Filename +// if !utils.CheckFileValid(req.FileName) { +// err := fmt.Errorf("file name cannot contain special characters: %s", req.FileName) +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultFileNotSpecialSymbol, err), c.Writer, lc) +// return +// } +// +// if fileSuffix := path.Ext(req.FileName); fileSuffix != ".json" { +// err := fmt.Errorf("file type not json, filename: %s", req.FileName) +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultJsonParseError, err), c.Writer, lc) +// return +// } +// +// uploadType, err := strconv.Atoi(c.Request.PostForm["type"][0]) +// if err != nil { +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// +// req.Type = uploadType +// if req.Type != constants.DeviceLibraryUploadTypeConfig { +// err := fmt.Errorf("req upload type %d not is %d", req.Type, constants.DeviceLibraryUploadTypeConfig) +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// f, err := file.Open() +// if err != nil { +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// +// // 将文件流读入请求中,并校验格式 +// req.FileBytes, err = ioutil.ReadAll(f) +// if err != nil { +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// if !json.Valid(req.FileBytes) { +// err := fmt.Errorf("config file content must be json") +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultJsonParseError, err), c.Writer, lc) +// return +// } +// +// fileName, err := ctl.getDriverLibApp().UploadDeviceLibraryConfig(c, req) +// //resp, edgeXErr := gatewayapp.DeviceLibraryUpload(c, req) +// if err != nil { +// httphelper.RenderFail(c, err, c.Writer, lc) +// return +// } +// resp := dtos.DeviceLibraryUploadResponse{ +// FileName: fileName, +// } +// httphelper.ResultSuccess(resp, c.Writer, lc) +//} + +//@Tags 驱动库管理 +//@Summary 驱动库更新 +//@Produce json +//@Param deviceLibraryId path string true "驱动ID" +//@Param request query dtos.UpdateDeviceLibrary true "参数" +//@Success 200 {object} httphelper.CommonResponse +//@Router /api/v1/device-libraries/:deviceLibraryId [put] +func (ctl *controller) DeviceLibraryUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.UpdateDeviceLibrary + req.Id = c.Param(UrlParamDeviceLibraryId) + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDriverLibApp().UpdateDeviceLibrary(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 驱动库管理 +// @Summary 驱动库配置下载 +// @Produce application/octet-stream +// @Router /api/v1/device-libraries/config/download [get] +//func (ctl *controller) DeviceLibraryConfigDownload(c *gin.Context) { +// //dir, _ := os.Getwd() +// //if dir == "/" { +// // dir = "" +// //} +// //filePath := dir + "/template/driver_config_demo.json" +// // +// //fileName := path.Base(filePath) +// //c.Header("Content-Type", "application/octet-stream") +// //c.Header("Content-Disposition", "attachment; filename="+fileName) +// //c.File(filePath) +// cfg := ctl.getDriverLibApp().ConfigDemo() +// buff := bytes.NewBuffer([]byte(cfg)) +// httphelper.ResultExcelData(c, "driver_config.json", buff) +//} + +//@Tags 驱动库分类 +//@Summary 驱动库分类 +//@Produce json +// @Param request query dtos.DriverClassifyQueryRequest true "参数" +//@Success 200 {object} httphelper.CommonResponse +//@Router /api/v1/device_classify [get] +func (ctl *controller) DeviceClassify(c *gin.Context) { + lc := ctl.lc + var req dtos.DriverClassifyQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + list, total, edgeXErr := ctl.getDriverLibApp().GetDriverClassify(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(list, total, req.Page, req.PageSize) + + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/deviceservice.go b/internal/hummingbird/core/controller/http/gateway/deviceservice.go new file mode 100644 index 0000000..e7238e3 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/deviceservice.go @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 驱动实例管理 +// @Summary 查询驱动实例 +// @Produce json +// @Param request query dtos.DeviceServiceSearchQueryRequest true "参数" +// @Success 200 {object} httphelper.ResPageResult +// @Router /api/v1/device-servers [get] +func (ctl *controller) DeviceServicesSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceServiceSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + + // TODO 驱动实例搜索是否需要查询驱动库 + dss, total, err := ctl.getDriverServiceApp().Search(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + data := make([]dtos.DeviceServiceResponse, len(dss)) + for i, ds := range dss { + dl, err := ctl.getDriverLibApp().DriverLibById(ds.DeviceLibraryId) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + data[i] = dtos.DeviceServiceResponseFromModel(ds, dl) + } + + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 驱动实例管理 +// @Summary 删除驱动实例(废弃,已经改为websockert形式) +// @Produce json +// @Param deviceServiceId path string true "驱动实例 ID" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device_server/:deviceServiceId [delete] +func (ctl *controller) DeviceServiceDelete(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceServiceId) + + err := ctl.getDriverServiceApp().Del(c, id) + //edgeXErr := gatewayapp.DeviceServiceDelete(c, id) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 驱动实例管理 +// @Summary 编辑驱动实例 +// @Produce json +// @Param request body dtos.DeviceServiceUpdateRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/device-server [put] +func (ctl *controller) DeviceServiceUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceServiceUpdateRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDriverServiceApp().Update(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/dockerconfig.go b/internal/hummingbird/core/controller/http/gateway/dockerconfig.go new file mode 100644 index 0000000..8675bed --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/dockerconfig.go @@ -0,0 +1,109 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 镜像仓库管理 +// @Summary 新增镜像 +// @Produce json +// @Param request body dtos.DockerConfigAddRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/docker-configs [post] +func (ctl *controller) DockerConfigAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.DockerConfigAddRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getDriverLibApp().DownConfigAdd(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 镜像仓库管理 +// @Summary 获取镜像列表 +// @Produce json +// @Param request query dtos.DockerConfigSearchQueryRequest true "参数" +// @Success 200 {object} httphelper.ResPageResult +// @Router /api/v1/docker-configs [get] +func (ctl *controller) DockerConfigsSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.DockerConfigSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + + list, total, edgeXErr := ctl.getDriverLibApp().DownConfigSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + dcs := make([]dtos.DockerConfigResponse, len(list)) + for i, v := range list { + dcs[i] = dtos.DockerConfigResponseFromModel(v) + } + pageResult := httphelper.NewPageResult(dcs, total, req.Page, req.PageSize) + + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 镜像仓库管理 +// @Summary 修改仓库信息 +// @Produce json +// @Param request body dtos.DockerConfigUpdateRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/docker-configs/:dockerConfigId [put] +func (ctl *controller) DockerConfigUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.DockerConfigUpdateRequest + req.Id = c.Param(UrlParamDockerConfigId) + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + + edgeXErr := ctl.getDriverLibApp().DownConfigUpdate(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 镜像仓库管理 +// @Summary 删除仓库信息 +// @Produce json +// @Param dockerConfigId path string true "镜像ID" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/docker-configs/:dockerConfigId [delete] +func (ctl *controller) DockerConfigDelete(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDockerConfigId) + edgeXErr := ctl.getDriverLibApp().DownConfigDel(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/docs.go b/internal/hummingbird/core/controller/http/gateway/docs.go new file mode 100644 index 0000000..6c94f84 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/docs.go @@ -0,0 +1,15 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway diff --git a/internal/hummingbird/core/controller/http/gateway/ekuiper.go b/internal/hummingbird/core/controller/http/gateway/ekuiper.go new file mode 100644 index 0000000..19dc414 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/ekuiper.go @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +func (ctl *controller) EkuiperScene(c *gin.Context) { + lc := ctl.lc + + req := make(map[string]interface{}) + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + lc.Info("scene req....", req) + edgeXErr := ctl.getSceneApp().EkuiperNotify(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/homepage.go b/internal/hummingbird/core/controller/http/gateway/homepage.go new file mode 100644 index 0000000..5213408 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/homepage.go @@ -0,0 +1,40 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 首页 +// @Summary 首页 +// @Produce json +// @Param request query dtos.HomePageRequest true "参数" +// @Success 200 {object} httphelper.ResPageResult +// @Router /api/v1/homepage [get] +// @Security ApiKeyAuth +func (ctl *controller) HomePage(c *gin.Context) { + lc := ctl.lc + var req dtos.HomePageRequest + data, edgeXErr := ctl.getHomePageApp().HomePageInfo(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/jobs.go b/internal/hummingbird/core/controller/http/gateway/jobs.go new file mode 100644 index 0000000..ae5620d --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/jobs.go @@ -0,0 +1,64 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +//func (ctl *controller) AddJobHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) QueryJobHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) ChangeJobStatusHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) GetJobHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) UpdateJobHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) DeleteJobHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) CheckJobExistByDeviceIdOrSceneIdHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) QueryJobLogsHandle(c *gin.Context) { +// ctl.ProxySharpServer(c) +//} +// +//func (ctl *controller) ExecJobHandle(c *gin.Context) { +// lc := ctl.lc +// var req dtos.JobAction +// if err := c.ShouldBind(&req); err != nil { +// httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) +// return +// } +// +// //data, total, edgeXErr := ctl.getDeviceApp().DeviceAction(c, req) +// //if edgeXErr != nil { +// // httphelper.RenderFail(c, edgeXErr, c.Writer, lc) +// // return +// //} +// //pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) +// httphelper.ResultSuccess(nil, c.Writer, lc) +//} diff --git a/internal/hummingbird/core/controller/http/gateway/langeuagesdk.go b/internal/hummingbird/core/controller/http/gateway/langeuagesdk.go new file mode 100644 index 0000000..7370c2f --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/langeuagesdk.go @@ -0,0 +1,69 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +func (ctl *controller) LanguageSdkSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.LanguageSDKSearchQueryRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + + list, _, edgeXErr := ctl.getLanguageApp().LanguageSDKSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + responseData := make(map[string]interface{}) + responseData["doc"] = map[string]interface{}{ + "name": "物联网平台文档", + "icon": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAOEAAADhCAMAAAAJbSJIAAAAflBMVEX////2iSnRTif5hAj5jjH96tv1gAj94s" + + "7PSif4jCnSQQHOPQDQSB3129T129XdfGPSQgr2hBjggGf2hiH5iBbROwDegGnTQgrbe2XORB3OQQ/5fQD5hQv+8uf70rX6w5r6uon7y6j838n+9Ov/+v" + + "X3mk/3kTr4pGL4qm37xqDMD8zoAAAChUlEQVR4nO3ZaVOCUBSA4cykhe5FpNVKW9Tq///BaNHRBC6KzOGced+PIAzPDHdRj466kE+kn6DdfDK8ln6GNnM3" + + "w8vjC+mnaC83yn3HdoUuufj2mRW6ZPjrMyr0o5XPpNCt+wwKXRKv+8wJ18afSeG/99Oc0N3E2z5Dwq3xZ0xYMP5MCX3R+DMkLJxfDAlX+0+jworxZ0Lok3HIp1" + + "pYOb8YEJauf0aEtX1KhT48v6gWluw/zQh9aP1TLnTJ7W4+ZUJXZ/1TLAzsP9UL93g/VQmD+2vlwr3GnyJhzf2nWmGD91OFcIf9p0phre9/ioU//282r7PCvdc/JcLG8" + + "0vHhTV+X1It3HP/qUboRwcafx0VuuT6sL6OCRvtPxUID7T+dVbo7+L7s+rGcWXj4qtiadmyh+w80ONp9R0GjyXX9ev30KKw3wsVhYRR8BbB+ggRIkSIECFChAgRIkSIECFC" + + "hAgRIkSIECFChAgRIkSIECFChAgRIkSIECFChAgRIkSIECFChAgRIkSIECFChAgRIkSIECFChAgRNhM+TSZPpoXTKO/ZsHD6w9kg2hK+/GmiqVHhZPXp6MWk8HXNEk0MCk/Tjc" + + "Ov5oSD82zjeLo8bkX4dpL9O5ENTAlnW8Belr0ZEs7mW8CcuJjZEb6nRafS+cyCMM2Xvo9CYH7uPV8k1Qt7vem8BJgTF5+L5kBxYVowBpdlpXhNwvZDiBChfAgRIpQPIUKE8iFEiFA" + + "+hAgRyocQIUL5ECJEKB9ChAjlQ4gQoXwIESKUDyFChPIhRIhQPoQIEcqHECFC+RAiRCgfQoQI5UOIEKF8CBEilA8hQoTyIUSIUD6ECBHKhxBhVVcnXehqp2f+Ag5ihjFgr47/AAAAAElFTkSuQmCC", + "addr": "https://doc.hummingbird.winc-link.com/", + } + responseData["sdk_language"] = list + httphelper.ResultSuccess(responseData, c.Writer, lc) +} + +func (ctl *controller) LanguageSdkSync(c *gin.Context) { + lc := ctl.lc + var req dtos.LanguageSDKSyncRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getLanguageApp().Sync(c, req.VersionName) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/network.go b/internal/hummingbird/core/controller/http/gateway/network.go new file mode 100644 index 0000000..bed5fc3 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/network.go @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 配网助手 +// @Summary 获取网卡列表 +// @Produce json +// @Success 200 {object} dtos.ConfigNetWorkResponse +// @Router /api/v1/local/config/network [get] +func (ctl *controller) ConfigNetWorkGet(c *gin.Context) { + lc := ctl.lc + res, edgeXErr := ctl.getSystemApp().ConfigNetWork(c, false) + if edgeXErr != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, edgeXErr), c.Writer, lc) + return + } + + httphelper.ResultSuccess(res, c.Writer, lc) +} + +// @Tags 配网助手 +// @Summary 修改网卡 +// @Produce json +// @Param req body dtos.ConfigNetworkUpdateRequest true "参数" +// @Success 200 {object} dtos.ConfigNetWorkResponse +// @Router /api/v1/local/config/network [put] +func (ctl *controller) ConfigNetWorkUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.ConfigNetworkUpdateRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getSystemApp().ConfigNetWorkUpdate(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 配网助手 +// @Summary 获取dns +// @Produce json +// @Success 200 {object} dtos.ConfigDnsResponse +// @Router /api/v1/local/config/dns [get] +func (ctl *controller) ConfigDnsGet(c *gin.Context) { + lc := ctl.lc + resp, edgeXErr := ctl.getSystemApp().ConfigDns(c) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(resp, c.Writer, lc) +} + +// @Tags 配网助手 +// @Summary 修改dns +// @Produce json +// @Param req body dtos.ConfigDnsUpdateRequest true "参数" +// @Success 200 {object} dtos.ConfigDnsResponse +// @Router /api/v1/local/config/dns [put] +func (ctl *controller) ConfigDnsUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.ConfigDnsUpdateRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getSystemApp().ConfigDnsUpdate(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/product.go b/internal/hummingbird/core/controller/http/gateway/product.go new file mode 100644 index 0000000..52432f3 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/product.go @@ -0,0 +1,139 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 产品管理 +// @Summary 查询产品列表 +// @Produce json +// @Param request query dtos.ProductSearchQueryRequest true "参数" +// @Success 200 {array} dtos.ProductSearchQueryResponse +// @Router /api/v1/products [get] +// @Security ApiKeyAuth +func (ctl *controller) ProductsSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.ProductSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + + data, total, edgeXErr := ctl.getProductApp().ProductsSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 产品管理 +// @Summary 查询产品详情 +// @Produce json +// @Param productId path string true "pid" +// @Success 200 {object} dtos.ProductSearchByIdResponse +// @Router /api/v1/product/:productId [get] +// @Security ApiKeyAuth +func (ctl *controller) ProductById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamProductId) + data, edgeXErr := ctl.getProductApp().ProductById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 产品管理 +// @Summary 删除产品 +// @Produce json +// @Param productId path string true "pid" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/product/:productId [delete] +// @Security ApiKeyAuth +func (ctl *controller) ProductDelete(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamProductId) + edgeXErr := ctl.getProductApp().ProductDelete(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) ProductRelease(c *gin.Context) { + lc := ctl.lc + productId := c.Param(UrlParamProductId) + edgeXErr := ctl.getProductApp().ProductRelease(c, productId) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) ProductUnRelease(c *gin.Context) { + lc := ctl.lc + productId := c.Param(UrlParamProductId) + edgeXErr := ctl.getProductApp().ProductUnRelease(c, productId) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 产品管理 +// @Summary 添加产品 +// @Produce json +// @Param request body dtos.ProductAddRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/product [post] +// @Security ApiKeyAuth +func (ctl *controller) ProductAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.ProductAddRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getProductApp().AddProduct(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 产品管理 +// @Summary 云平台列表 +// @Produce json +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/iot-platform [get] +// @Security ApiKeyAuth +func (ctl *controller) IotPlatform(c *gin.Context) { + lc := ctl.lc + var iotPlatform []constants.IotPlatform + iotPlatform = append(iotPlatform, constants.IotPlatform_LocalIot) + httphelper.ResultSuccess(iotPlatform, c.Writer, lc) + +} diff --git a/internal/hummingbird/core/controller/http/gateway/quicknavigation.go b/internal/hummingbird/core/controller/http/gateway/quicknavigation.go new file mode 100644 index 0000000..6c94f84 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/quicknavigation.go @@ -0,0 +1,15 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway diff --git a/internal/hummingbird/core/controller/http/gateway/ruleengine.go b/internal/hummingbird/core/controller/http/gateway/ruleengine.go new file mode 100644 index 0000000..b8b6e65 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/ruleengine.go @@ -0,0 +1,169 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 规则引擎 +// @Summary 添加规则引擎 +// @Produce json +// @Param request query dtos.RuleEngineRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/rule-engine [post] +func (ctl *controller) RuleEngineAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.RuleEngineRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getRuleEngineApp().AddRuleEngine(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 规则引擎 +// @Summary 编辑规则引擎 +// @Produce json +// @Param request query dtos.RuleEngineUpdateRequest true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/rule-engine [put] +func (ctl *controller) RuleEngineUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.RuleEngineUpdateRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getRuleEngineApp().UpdateRuleEngine(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 规则引擎 +// @Summary 规则引擎详情 +// @Produce json +// @Param ruleEngineId path string true "ruleEngineId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/rule-engine/:ruleEngineId [get] +func (ctl *controller) RuleEngineById(c *gin.Context) { + lc := ctl.lc + id := c.Param(RuleEngineId) + data, edgeXErr := ctl.getRuleEngineApp().RuleEngineById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 规则引擎 +// @Summary 规则引擎列表 +// @Produce json +// @Param request query dtos.RuleEngineSearchQueryRequest true "参数" +// @Success 200 {array} []dtos.RuleEngineSearchQueryResponse +// @Router /api/v1/rule-engine [get] +func (ctl *controller) RuleEngineSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.RuleEngineSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getRuleEngineApp().RuleEngineSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 规则引擎 +// @Summary 规则引擎启动 +// @Produce json +// @Param ruleEngineId path string true "ruleEngineId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/rule-engine/:ruleEngineId/start [post] +func (ctl *controller) RuleEngineStart(c *gin.Context) { + lc := ctl.lc + id := c.Param(RuleEngineId) + edgeXErr := ctl.getRuleEngineApp().RuleEngineStart(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 规则引擎 +// @Summary 规则引擎停止 +// @Produce json +// @Param ruleEngineId path string true "ruleEngineId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/rule-engine/:ruleEngineId/stop [post] +func (ctl *controller) RuleEngineStop(c *gin.Context) { + lc := ctl.lc + id := c.Param(RuleEngineId) + edgeXErr := ctl.getRuleEngineApp().RuleEngineStop(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 规则引擎 +// @Summary 规则引擎删除 +// @Produce json +// @Param ruleEngineId path string true "ruleEngineId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/rule-engine/:ruleEngineId/delete [delete] +func (ctl *controller) RuleEngineDelete(c *gin.Context) { + lc := ctl.lc + id := c.Param(RuleEngineId) + edgeXErr := ctl.getRuleEngineApp().RuleEngineDelete(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 规则引擎 +// @Summary 规则引擎状态 +// @Produce json +// @Param ruleEngineId path string true "ruleEngineId" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/rule-engine/:ruleEngineId/status [get] +func (ctl *controller) RuleEngineStatus(c *gin.Context) { + lc := ctl.lc + id := c.Param(RuleEngineId) + r, edgeXErr := ctl.getRuleEngineApp().RuleEngineStatus(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(r, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/scene.go b/internal/hummingbird/core/controller/http/gateway/scene.go new file mode 100644 index 0000000..f77e9ca --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/scene.go @@ -0,0 +1,140 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +func (ctl *controller) SceneAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.SceneAddRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getSceneApp().AddScene(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) SceneUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.SceneUpdateRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getSceneApp().UpdateScene(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) SceneById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamSceneId) + scene, edgeXErr := ctl.getSceneApp().SceneById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(scene, c.Writer, lc) +} + +func (ctl *controller) SearchScene(c *gin.Context) { + lc := ctl.lc + var req dtos.SceneSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + + list, total, edgeXErr := ctl.getSceneApp().SceneSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(list, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) SceneLogSearch(c *gin.Context) { + lc := ctl.lc + sceneId := c.Param(UrlParamSceneId) + var req dtos.SceneLogSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + req.SceneId = sceneId + + ctl.lc.Info("sceneLogSearch log:", req) + list, total, edgeXErr := ctl.getSceneApp().SceneLogSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(list, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) SceneStart(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamSceneId) + edgeXErr := ctl.getSceneApp().SceneStartById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) SceneStop(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamSceneId) + edgeXErr := ctl.getSceneApp().SceneStopById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) DeleteScene(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamSceneId) + edgeXErr := ctl.getSceneApp().DelSceneById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) SceneLog(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamSceneId) + edgeXErr := ctl.getSceneApp().DelSceneById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/system.go b/internal/hummingbird/core/controller/http/gateway/system.go new file mode 100644 index 0000000..c611e2e --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/system.go @@ -0,0 +1,67 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "path" +) + +// @Tags 网关管理 +// @Summary 网关备份下载 +// @Produce json +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/system/backup [get] +func (ctl *controller) SystemBackupHandle(c *gin.Context) { + lc := ctl.lc + filePath, edgeXErr := ctl.getSystemApp().SystemBackupFileDownload(c) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + fileName := path.Base(filePath) + c.Header("Content-Type", "application/octet-stream") + c.Header("Content-Disposition", "attachment; filename="+fileName) + c.File(filePath) + + // 删除zip文件 + utils.RemoveFileOrDir(filePath) +} + +func (ctl *controller) SystemRecoverHandle(c *gin.Context) { + lc := ctl.lc + file, err := c.FormFile("fileName") + if err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + dist := "/tmp/tedge-recover.zip" + err = c.SaveUploadedFile(file, dist) + if err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.SystemErrorCode, err), c.Writer, lc) + return + } + edgeXErr := ctl.getSystemApp().SystemRecover(c, dist) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/thingmodel.go b/internal/hummingbird/core/controller/http/gateway/thingmodel.go new file mode 100644 index 0000000..83db682 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/thingmodel.go @@ -0,0 +1,176 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 物模型 +// @Summary 查询系统物模型 +// @Produce json +// @Param request query dtos.SystemThingModelSearchReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/thingmodel/system [get] +// @Security ApiKeyAuth +func (ctl *controller) SystemThingModelSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.SystemThingModelSearchReq + urlDecodeParam(&req, c.Request, lc) + data, edgeXErr := ctl.getThingModelApp().SystemThingModelSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 物模型 +// @Summary 产品添加物模型 +// @Produce json +// @Param request body dtos.ThingModelAddOrUpdateReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/thingmodel [post] +// @Security ApiKeyAuth +func (ctl *controller) ThingModelAdd(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelAddOrUpdateReq + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getThingModelApp().AddThingModel(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 物模型 +// @Summary 修改产品物模型 +// @Produce json +// @Param request body dtos.ThingModelAddOrUpdateReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/thingmodel [put] +// @Security ApiKeyAuth +func (ctl *controller) ThingModelUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelAddOrUpdateReq + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getThingModelApp().UpdateThingModel(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 物模型 +// @Summary 产品删除物模型 +// @Produce json +// @Param request body dtos.ThingModelDeleteReq true "参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/thingmodel [delete] +// @Security ApiKeyAuth +func (ctl *controller) ThingModelDelete(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelDeleteReq + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getThingModelApp().ThingModelDelete(c, req.ThingModelId, req.ThingModelType) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 物模型 +// @Summary 物模型单位 +// @Produce json +// @Param request query dtos.UnitRequest true "参数" +// @Success 200 {array} dtos.UnitResponse +// @Router /api/v1/thingmodel/unit [get] +// @Security ApiKeyAuth +func (ctl *controller) ThingModelUnit(c *gin.Context) { + lc := ctl.lc + var req dtos.UnitRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getUnitModelApp().UnitTemplateSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 物模型 +// @Summary 物模型单位同步 +// @Produce json +// @Param request query dtos.UnitTemplateSyncRequest true "参数" +// @Router /api/v1/thingmodel/unit-sync [post] +// @Security ApiKeyAuth +func (ctl *controller) ThingModelUnitSync(c *gin.Context) { + lc := ctl.lc + var req dtos.UnitTemplateSyncRequest + urlDecodeParam(&req, c.Request, lc) + total, edgeXErr := ctl.getUnitModelApp().Sync(c, "Ireland") + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(total, c.Writer, lc) +} + +func (ctl *controller) ThingModelDocsSync(c *gin.Context) { + lc := ctl.lc + total, edgeXErr := ctl.getDocsApp().SyncDocs(c, "Ireland") + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(total, c.Writer, lc) +} + +func (ctl *controller) ThingModelQuickNavigationSync(c *gin.Context) { + lc := ctl.lc + total, edgeXErr := ctl.getQuickNavigationApp().SyncQuickNavigation(c, "Ireland") + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(total, c.Writer, lc) +} + +//func (ctl *controller) MsgGather(c *gin.Context) { +// lc := ctl.lc +// edgeXErr := ctl.getDeviceApp().DevicesReportMsgGather(context.Background()) +// if edgeXErr != nil { +// httphelper.RenderFail(c, edgeXErr, c.Writer, lc) +// return +// } +// httphelper.ResultSuccess(nil, c.Writer, lc) +//} diff --git a/internal/hummingbird/core/controller/http/gateway/thingmodeltemplate.go b/internal/hummingbird/core/controller/http/gateway/thingmodeltemplate.go new file mode 100644 index 0000000..76e01e1 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/thingmodeltemplate.go @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package gateway + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// @Tags 物模型模版 +// @Summary 物模型模版列表 +// @Produce json +// @Param request query dtos.CategoryTemplateRequest true "参数" +// @Success 200 {array} dtos.CategoryTemplateResponse +// @Router /api/v1/thingmodel-template [get] +//@Security ApiKeyAuth +func (ctl *controller) ThingModelTemplateSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelTemplateRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getThingModelTemplateApp().ThingModelTemplateSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +// @Tags 物模型模版 +// @Summary 物模型模版详情 +// @Produce json +// @Param request query dtos.CategoryTemplateRequest true "参数" +// @Success 200 {array} dtos.ThingModelTemplateResponse +// @Router /api/v1/thingmodel-template [get] +//@Security ApiKeyAuth +func (ctl *controller) ThingModelTemplateByCategoryKey(c *gin.Context) { + lc := ctl.lc + categoryKey := c.Param(UrlParamCategoryKey) + data, edgeXErr := ctl.getThingModelTemplateApp().ThingModelTemplateByCategoryKey(c, categoryKey) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +// @Tags 物模型模版 +// @Summary 同步物模型 +// @Produce json +// @Param request query dtos.CategoryTemplateRequest true "参数" +// @Router /api/v1/thingmodel-template/sync [post] +//@Security ApiKeyAuth +func (ctl *controller) ThingModelTemplateSync(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelTemplateSyncRequest + urlDecodeParam(&req, c.Request, lc) + _, edgeXErr := ctl.getThingModelTemplateApp().Sync(c, "Ireland") + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/gateway/user.go b/internal/hummingbird/core/controller/http/gateway/user.go new file mode 100644 index 0000000..bf12279 --- /dev/null +++ b/internal/hummingbird/core/controller/http/gateway/user.go @@ -0,0 +1,113 @@ +package gateway + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/middleware" + //"github.com/winc-link/hummingbird/internal/pkg/middleware" + //_ "github.com/winc-link/hummingbird/cmd/edge-core/docs" // 千万不要忘了导入把你上一步生成的docs + //gs "github.com/swaggo/gin-swagger" + //"github.com/swaggo/gin-swagger/swaggerFiles" +) + +// @Tags 用户系统 +// @Summary 用户登录 +// @Produce json +// @Param login_request body dtos.LoginRequest true "用户登录参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/auth/login [post] +func (ctl *controller) Login(c *gin.Context) { + lc := ctl.lc + var req dtos.LoginRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + res, edgeXErr := ctl.getUserApp().UserLogin(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + //fmt.Println("res:", res) + httphelper.ResultSuccess(res, c.Writer, lc) +} + +// @Tags 用户系统 +// @Summary 获取网关账号是否初始化 +// @Produce json +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/auth/initInfo [get] +func (ctl *controller) InitInfo(c *gin.Context) { + lc := ctl.lc + res, edgeXErr := ctl.getUserApp().InitInfo() + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(res, c.Writer, lc) +} + +// @Tags 用户系统 +// @Summary 密码初始化 +// @Produce json +// @Param init_password_request body dtos.InitPasswordRequest true "密码初始化参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/auth/init-password [post] +func (ctl *controller) InitPassword(c *gin.Context) { + lc := ctl.lc + var req dtos.InitPasswordRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + + edgeXErr := ctl.getUserApp().InitPassword(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +// @Tags 用户系统 +// @Summary 密码修改 +// @Produce json +// @Param request body dtos.UpdatePasswordRequest true "密码修改参数" +// @Success 200 {object} httphelper.CommonResponse +// @Router /api/v1/auth/password [put] +func (ctl *controller) UpdatePassword(c *gin.Context) { + lc := ctl.lc + var req dtos.UpdatePasswordRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + + // 获取登录用户信息 + value, ok := c.Get(constants.JwtParsedInfo) + if !ok { + err := fmt.Errorf("token is invalid") + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultTokenPermission, err), c.Writer, lc) + return + } + claim, ok := value.(*middleware.CustomClaims) + if !ok { + err := fmt.Errorf("Request token is invalid.") + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultTokenPermission, err), c.Writer, lc) + return + } + + edgeXErr := ctl.getUserApp().UpdateUserPassword(c, claim.Username, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + + httphelper.ResultSuccess(nil, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/openapi/common.go b/internal/hummingbird/core/controller/http/openapi/common.go new file mode 100644 index 0000000..ed694b0 --- /dev/null +++ b/internal/hummingbird/core/controller/http/openapi/common.go @@ -0,0 +1,40 @@ +package openapi + +import ( + "github.com/gorilla/schema" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "net/http" +) + +const ( + UrlParamSceneId = "sceneId" + UrlParamActionId = "actionId" + UrlParamStrategyId = "strategyId" + UrlParamConditionId = "conditionId" + UrlParamJobId = "jobId" + UrlParamProductId = "productId" + UrlParamCategoryKey = "categoryKey" + UrlParamCloudInstanceId = "cloudInstanceId" + UrlParamDeviceId = "deviceId" + UrlParamFuncPointId = "funcPointId" + UrlParamDeviceLibraryId = "deviceLibraryId" + UrlParamDeviceServiceId = "deviceServiceId" + UrlParamDockerConfigId = "dockerConfigId" + UrlParamRuleId = "ruleId" + UrlDataResourceId = "dataResourceId" + RuleEngineId = "ruleEngineId" +) + +var decoder *schema.Decoder + +func init() { + decoder = schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) +} + +func urlDecodeParam(obj interface{}, r *http.Request, lc logger.LoggingClient) { + err := decoder.Decode(obj, r.URL.Query()) + if err != nil { + lc.Errorf("url decoding err %v", err) + } +} diff --git a/internal/hummingbird/core/controller/http/openapi/controller.go b/internal/hummingbird/core/controller/http/openapi/controller.go new file mode 100644 index 0000000..c759afd --- /dev/null +++ b/internal/hummingbird/core/controller/http/openapi/controller.go @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package openapi + +import ( + resourceContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" +) + +type controller struct { + lc logger.LoggingClient + dic *di.Container + //cfg *config.ConfigurationStruct + //gwApp *gatewayapp.GatewayApp +} + +func New(dic *di.Container) *controller { + return &controller{ + lc: container.LoggingClientFrom(dic.Get), + dic: dic, + //cfg: resourceContainer.ConfigurationFrom(dic.Get), + //gwApp: gatewayapp.NewGatewayApp(dic), + } +} + +func (ctl *controller) getUserApp() interfaces.UserItf { + return resourceContainer.UserItfFrom(ctl.dic.Get) +} + +func (ctl *controller) getProductApp() interfaces.ProductItf { + return resourceContainer.ProductAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getDeviceApp() interfaces.DeviceItf { + return resourceContainer.DeviceItfFrom(ctl.dic.Get) +} + +func (ctl *controller) getThingModelApp() interfaces.ThingModelCtlItf { + return resourceContainer.ThingModelAppNameFrom(ctl.dic.Get) +} + +func (ctl *controller) getPersistApp() interfaces.PersistItf { + return resourceContainer.PersistItfFrom(ctl.dic.Get) +} diff --git a/internal/hummingbird/core/controller/http/openapi/device.go b/internal/hummingbird/core/controller/http/openapi/device.go new file mode 100644 index 0000000..b3c9bc3 --- /dev/null +++ b/internal/hummingbird/core/controller/http/openapi/device.go @@ -0,0 +1,101 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package openapi + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +func (ctl *controller) OpenApiCreateDevice(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceAddRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getDeviceApp().AddDevice(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiUpdateDevice(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceId) + var req dtos.DeviceUpdateRequest + req.Id = id + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + err := ctl.getDeviceApp().DeviceUpdate(c, req) + if err != nil { + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiDeviceSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.DeviceSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getDeviceApp().OpenApiDevicesSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) OpenApiDeviceById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceId) + data, edgeXErr := ctl.getDeviceApp().OpenApiDeviceById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +func (ctl *controller) OpenApiDeleteDevice(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceId) + edgeXErr := ctl.getDeviceApp().DeleteDeviceById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiDeviceStatus(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamDeviceId) + data, edgeXErr := ctl.getDeviceApp().OpenApiDeviceStatusById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/openapi/product.go b/internal/hummingbird/core/controller/http/openapi/product.go new file mode 100644 index 0000000..6562cfe --- /dev/null +++ b/internal/hummingbird/core/controller/http/openapi/product.go @@ -0,0 +1,114 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package openapi + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +func (ctl *controller) OpenApiCreateProduct(c *gin.Context) { + lc := ctl.lc + var req dtos.OpenApiAddProductRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + _, edgeXErr := ctl.getProductApp().OpenApiAddProduct(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiProductById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamProductId) + data, edgeXErr := ctl.getProductApp().OpenApiProductById(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +func (ctl *controller) OpenApiProductReleaseById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamProductId) + edgeXErr := ctl.getProductApp().ProductRelease(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiProductUnReleaseById(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamProductId) + edgeXErr := ctl.getProductApp().ProductUnRelease(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiProductSearch(c *gin.Context) { + lc := ctl.lc + var req dtos.ProductSearchQueryRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + + data, total, edgeXErr := ctl.getProductApp().OpenApiProductSearch(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, total, req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) OpenApiDeleteProduct(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamProductId) + edgeXErr := ctl.getProductApp().ProductDelete(c, id) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiUpdateProduct(c *gin.Context) { + lc := ctl.lc + id := c.Param(UrlParamProductId) + var req dtos.OpenApiUpdateProductRequest + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + req.Id = id + edgeXErr := ctl.getProductApp().OpenApiUpdateProduct(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) + +} diff --git a/internal/hummingbird/core/controller/http/openapi/thingmodel.go b/internal/hummingbird/core/controller/http/openapi/thingmodel.go new file mode 100644 index 0000000..024c6f9 --- /dev/null +++ b/internal/hummingbird/core/controller/http/openapi/thingmodel.go @@ -0,0 +1,141 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package openapi + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +func (ctl *controller) OpenApiThingModelAddOrUpdate(c *gin.Context) { + lc := ctl.lc + var req dtos.OpenApiThingModelAddOrUpdateReq + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getThingModelApp().OpenApiAddThingModel(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiThingModel(c *gin.Context) { + lc := ctl.lc + var req dtos.OpenApiQueryThingModelReq + urlDecodeParam(&req, c.Request, lc) + data, edgeXErr := ctl.getThingModelApp().OpenApiQueryThingModel(c, req.ProductId) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(data, c.Writer, lc) +} + +func (ctl *controller) OpenApiDeleteThingModel(c *gin.Context) { + lc := ctl.lc + var req dtos.OpenApiThingModelDeleteReq + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + edgeXErr := ctl.getThingModelApp().OpenApiDeleteThingModel(c, req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + httphelper.ResultSuccess(nil, c.Writer, lc) +} + +func (ctl *controller) OpenApiSetDeviceProperty(c *gin.Context) { + lc := ctl.lc + var req dtos.OpenApiSetDeviceThingModel + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + + var code string + var value interface{} + for s, i := range req.Item { + code = s + value = i + } + execRes := ctl.getDeviceApp().DeviceAction(dtos.JobAction{ + DeviceId: req.DeviceId, + Code: code, + Value: value, + }) + httphelper.ResultSuccess(execRes, c.Writer, lc) +} + +func (ctl *controller) OpenApiInvokeThingService(c *gin.Context) { + lc := ctl.lc + var req dtos.InvokeDeviceServiceReq + if err := c.ShouldBind(&req); err != nil { + httphelper.RenderFail(c, errort.NewCommonErr(errort.DefaultReqParamsError, err), c.Writer, lc) + return + } + + execRes := ctl.getDeviceApp().DeviceInvokeThingService(req) + httphelper.ResultSuccess(execRes, c.Writer, lc) +} + +func (ctl *controller) OpenApiQueryDevicePropertyData(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelPropertyDataRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getPersistApp().SearchDeviceThingModelHistoryPropertyData(req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, uint32(total), req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) OpenApiQueryDeviceEventData(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelEventDataRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + data, total, edgeXErr := ctl.getPersistApp().SearchDeviceThingModelEventData(req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, uint32(total), req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} + +func (ctl *controller) OpenApiQueryDeviceServiceData(c *gin.Context) { + lc := ctl.lc + var req dtos.ThingModelServiceDataRequest + urlDecodeParam(&req, c.Request, lc) + dtos.CorrectionPageParam(&req.BaseSearchConditionQuery) + + data, total, edgeXErr := ctl.getPersistApp().SearchDeviceThingModelServiceData(req) + if edgeXErr != nil { + httphelper.RenderFail(c, edgeXErr, c.Writer, lc) + return + } + pageResult := httphelper.NewPageResult(data, uint32(total), req.Page, req.PageSize) + httphelper.ResultSuccess(pageResult, c.Writer, lc) +} diff --git a/internal/hummingbird/core/controller/http/openapi/user.go b/internal/hummingbird/core/controller/http/openapi/user.go new file mode 100644 index 0000000..4d2f36a --- /dev/null +++ b/internal/hummingbird/core/controller/http/openapi/user.go @@ -0,0 +1,87 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package openapi + +import ( + buildInErrors "errors" + + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + jwt2 "github.com/winc-link/hummingbird/internal/tools/jwt" + "github.com/winc-link/hummingbird/internal/tools/openapihelper" +) + +func (ctl *controller) Login(c *gin.Context) { + var lc = ctl.lc + var req dtos.LoginRequest + if err := c.ShouldBind(&req); err != nil { + lc.Error(err.Error()) + openapihelper.ReaderFail(c, errort.ParamsError) + return + } + tokenDetail, edgeXErr := ctl.getUserApp().OpenApiUserLogin(c, req) + if edgeXErr != nil { + lc.Error(edgeXErr.Error()) + openapihelper.ReaderFail(c, errort.OpenApiErrorCode(errort.NewCommonEdgeXWrapper(edgeXErr).Code())) + return + } + result := struct { + AccessToken string `json:"access_token"` + ExpireTime int64 `json:"expire"` + RefreshToken string `json:"refresh_token"` + }{ + AccessToken: tokenDetail.AccessToken, + ExpireTime: tokenDetail.AtExpires, + RefreshToken: tokenDetail.RefreshToken, + } + openapihelper.ReaderSuccess(c, result) +} + +func (ctl *controller) RefreshToken(c *gin.Context) { + refreshToken := c.Param("refreshToken") + if refreshToken == "" { + openapihelper.ReaderFail(c, errort.TokenValid) + return + } + jwt := jwt2.NewJWT(jwt2.RefreshKey) + claim, err := jwt.ParseToken(refreshToken) + if err != nil { + switch { + case buildInErrors.Is(err, jwt2.TokenExpired): + openapihelper.ReaderFail(c, errort.TokenExpired) + case buildInErrors.Is(err, jwt2.TokenInvalid): + openapihelper.ReaderFail(c, errort.TokenValid) + default: + openapihelper.ReaderFail(c, errort.SystemErrorCode) + } + return + } + tokenDetail, err := ctl.getUserApp().CreateTokenDetail(claim.Username) + if err != nil { + openapihelper.ReaderFail(c, errort.SystemErrorCode) + return + } + result := struct { + AccessToken string `json:"access_token"` + ExpireTime int64 `json:"expire"` + RefreshToken string `json:"refresh_token"` + }{ + AccessToken: tokenDetail.AccessToken, + ExpireTime: tokenDetail.AtExpires, + RefreshToken: tokenDetail.RefreshToken, + } + openapihelper.ReaderSuccess(c, result) +} diff --git a/internal/hummingbird/core/controller/http/websocket/common.go b/internal/hummingbird/core/controller/http/websocket/common.go new file mode 100644 index 0000000..d57b8bc --- /dev/null +++ b/internal/hummingbird/core/controller/http/websocket/common.go @@ -0,0 +1,38 @@ +package websocket + +import ( + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/i18n" + + "github.com/gin-gonic/gin/binding" +) + +type SystemCheckLangReq struct { + Lang string `json:"lang"` +} + +/** +receive: {"code":10003,"data":{"lang":"zh"}} +*/ +func CheckLang(c *wsClient, data interface{}, code dtos.WsCode) { + var req SystemCheckLangReq + bytes, _ := json.Marshal(data) + err := binding.JSON.BindBody(bytes, &req) + if err != nil { + //c.lc.Error(err.Error()) + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + + c.ChangeLang(req.Lang) + c.sendData(code, httphelper.WsResult(errort.DefaultSuccess, nil, "", "")) +} diff --git a/internal/hummingbird/core/controller/http/websocket/devicelibrary.go b/internal/hummingbird/core/controller/http/websocket/devicelibrary.go new file mode 100644 index 0000000..a6bed35 --- /dev/null +++ b/internal/hummingbird/core/controller/http/websocket/devicelibrary.go @@ -0,0 +1,76 @@ +package websocket + +import ( + "encoding/json" + "github.com/gin-gonic/gin/binding" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/i18n" + "time" +) + +/** +receive: {"code":10001,"data":{"id":"3208327514","version":"2.0.1"}} +*/ +func DeviceLibraryUpgrade(c *wsClient, data interface{}, code dtos.WsCode) { + var req dtos.DeviceLibraryUpgradeRequest + bytes, _ := json.Marshal(data) + err := binding.JSON.BindBody(bytes, &req) + if err != nil { + c.lc.Error(err.Error()) + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + + driverApp := container.DriverAppFrom(c.dic.Get) + errCode := errort.DefaultSuccess + err = driverApp.UpgradeDeviceLibrary(c.ctx, req) + if err != nil { + c.lc.Errorf("DeviceLibraryUpgrade err: %+v", err) + errCode = errort.NewCommonEdgeXWrapper(err).Code() + } + + // 响应 + dlName := "" + resp := dtos.DeviceLibraryUpgradeResponse{ + Id: req.Id, + } + dl, err := driverApp.DeviceLibraryById(c.ctx, req.Id) + if err != nil { + c.lc.Errorf("get DeviceLibraryById err:%+v", err) + } else { + dlName = dl.Name + resp.Version = dl.Version + resp.OperateStatus = dl.OperateStatus + } + + isSuccess := true + errMsg := "" + successMsg := "" + status := i18n.DefaultSuccess + if errCode != errort.DefaultSuccess { + errMsg = i18n.TransCode(c.ctx, errCode, nil) + isSuccess = false + status = i18n.DefaultFail + } + + msg := i18n.Trans(i18n.GetLang(c.ctx), i18n.LibraryUpgradeDownloadResp, map[string]interface{}{ + "name": dlName, + "status": i18n.Trans(i18n.GetLang(c.ctx), status, nil), + }) + if isSuccess { + successMsg = msg + } else { + errMsg = msg + ": " + errMsg + } + time.Sleep(2 * time.Second) + c.sendData(code, httphelper.WsResult(errCode, resp, errMsg, successMsg)) +} diff --git a/internal/hummingbird/core/controller/http/websocket/deviceservice.go b/internal/hummingbird/core/controller/http/websocket/deviceservice.go new file mode 100644 index 0000000..abd6afe --- /dev/null +++ b/internal/hummingbird/core/controller/http/websocket/deviceservice.go @@ -0,0 +1,78 @@ +package websocket + +import ( + "encoding/json" + "github.com/gin-gonic/gin/binding" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/i18n" + "time" +) + +/** +receive: {"code":10002,"data":{"id":"935769","run_status":1}} +*/ +func DeviceServiceRunStatus(c *wsClient, data interface{}, code dtos.WsCode) { + var req dtos.UpdateDeviceServiceRunStatusRequest + bytes, _ := json.Marshal(data) + err := binding.JSON.BindBody(bytes, &req) + if err != nil { + c.lc.Error(err.Error()) + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + + driverServerApp := container.DriverServiceAppFrom(c.dic.Get) + err = driverServerApp.UpdateRunStatus(c.ctx, req) + + errCode := errort.DefaultSuccess + errMsg := "" + isSuccess := true + successMsg := "" + status := i18n.DefaultSuccess + + if err != nil { + c.lc.Errorf("DeviceServiceRunStatus err: %+v", err) + edgeX := errort.NewCommonEdgeXWrapper(err) + errCode = edgeX.Code() + errMsg = edgeX.Error() + } + + // 响应 + dsName := "" + resp := dtos.UpdateDeviceServiceRunStatusResponse{ + Id: req.Id, + } + ds, err := driverServerApp.Get(c.ctx, req.Id) + if err != nil { + c.lc.Errorf("get driverServer err:%+v", err) + } else { + resp.RunStatus = ds.RunStatus + } + + if errCode != errort.DefaultSuccess { + if errCode != errort.ContainerRunFail { + errMsg = i18n.TransCode(c.ctx, errCode, nil) + } + isSuccess = false + status = i18n.DefaultFail + } + + msg := i18n.Trans(i18n.GetLang(c.ctx), i18n.ServiceRunStatusResp, map[string]interface{}{ + "name": dsName, + "status": i18n.Trans(i18n.GetLang(c.ctx), status, nil), + }) + if isSuccess { + successMsg = msg + } + time.Sleep(2 * time.Second) + c.sendData(code, httphelper.WsResult(errCode, resp, errMsg, successMsg)) +} diff --git a/internal/hummingbird/core/controller/http/websocket/deviceservicelog.go b/internal/hummingbird/core/controller/http/websocket/deviceservicelog.go new file mode 100644 index 0000000..00df42c --- /dev/null +++ b/internal/hummingbird/core/controller/http/websocket/deviceservicelog.go @@ -0,0 +1,202 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package websocket + +import ( + "encoding/json" + "github.com/gin-gonic/gin/binding" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/i18n" + "time" +) + +/** +receive: {"code":10004,"data":{"id":"123291"}} +*/ +func DeviceServiceLog(c *wsClient, data interface{}, code dtos.WsCode) { + var req dtos.DeviceServiceRunLogRequest + bytes, _ := json.Marshal(data) + err := binding.JSON.BindBody(bytes, &req) + if err != nil { + c.lc.Error(err.Error()) + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + + driverServerApp := container.DriverServiceAppFrom(c.dic.Get) + + ds, err := driverServerApp.Get(c.ctx, req.Id) + if err != nil { + c.lc.Error(err.Error()) + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + + driverLibApp := container.DriverAppFrom(c.dic.Get) + dl, err := driverLibApp.DriverLibById(ds.DeviceLibraryId) + if err != nil { + c.lc.Error(err.Error()) + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + logfilePath := interfaces.DMIFrom(c.dic.Get).GetDriverInstanceLogPath(dl.ContainerName) + + if req.Operate == constants.StatusRead { + //读取日志 + StartReadServiceLog(c, req.Id, logfilePath, code) + + } else if req.Operate == constants.StatusStop { + //停止 + StopReadServiceLog(c, req.Id, code) + errCode := errort.DefaultSuccess + errMsg := "" + successMsg := "" + status := i18n.DefaultSuccess + + msg := i18n.Trans(i18n.GetLang(c.ctx), i18n.CloudInstanceLogResp, map[string]interface{}{ + "name": ds.Name, + "status": i18n.Trans(i18n.GetLang(c.ctx), status, nil), + }) + successMsg = msg + c.sendData(code, httphelper.WsResult(errCode, "", errMsg, successMsg)) + } + +} + +func StartReadServiceLog(c *wsClient, serviceId, logFilePath string, code dtos.WsCode) { + driverServerLogApp := container.HpcServiceAppFrom(c.dic.Get) + hpc := driverServerLogApp.Add(serviceId, logFilePath) + + tails, err := hpc.Read() + if err != nil { + //报错 + } + for { + line, ok := <-tails //遍历chan,读取日志内容 + if !ok { + c.lc.Info("stop") + return + } + c.lc.Infof("msg %+v", line.Text) + c.sendData(code, httphelper.WsResult(errort.DefaultSuccess, line.Text, "", "")) + //return + } +} + +func StopReadServiceLog(c *wsClient, serviceId string, code dtos.WsCode) { + driverServerLogApp := container.HpcServiceAppFrom(c.dic.Get) + hpc := driverServerLogApp.Get(serviceId) + if hpc == nil { + //报错 + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + hpc.Stop() +} + +/** +receive: {"code":10003,"data":{"id":"3208327514"}} +*/ +func DeviceLibraryDelete(c *wsClient, data interface{}, code dtos.WsCode) { + var req dtos.DeviceServiceDeleteRequest + bytes, _ := json.Marshal(data) + err := binding.JSON.BindBody(bytes, &req) + if err != nil { + c.lc.Error(err.Error()) + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } + driverServiceApp := container.DriverServiceAppFrom(c.dic.Get) + // 响应 + dsName := "" + + ds, err := driverServiceApp.Get(c.ctx, req.Id) + if err != nil { + c.sendData( + code, + httphelper.WsResultFail( + errort.DefaultReqParamsError, + i18n.TransCode(c.ctx, errort.DefaultReqParamsError, nil), + ), + ) + return + } else { + dsName = ds.Name + } + + errCode := errort.DefaultSuccess + err = driverServiceApp.Del(c.ctx, req.Id) + if err != nil { + c.lc.Errorf("del cloud service err: %+v", err) + errCode = errort.NewCommonEdgeXWrapper(err).Code() + } + + isSuccess := true + errMsg := "" + successMsg := "" + status := i18n.DefaultSuccess + + if errCode != errort.DefaultSuccess { + errMsg = i18n.TransCode(c.ctx, errCode, nil) + isSuccess = false + status = i18n.DefaultFail + } + msg := i18n.Trans(i18n.GetLang(c.ctx), i18n.AppServiceDeleteResp, map[string]interface{}{ + "name": dsName, + "status": i18n.Trans(i18n.GetLang(c.ctx), status, nil), + }) + if isSuccess { + successMsg = msg + } else { + errMsg = msg + ": " + errMsg + } + time.Sleep(2 * time.Second) + c.sendData(code, httphelper.WsResult(errCode, "", errMsg, successMsg)) +} diff --git a/internal/hummingbird/core/controller/http/websocket/rule.go b/internal/hummingbird/core/controller/http/websocket/rule.go new file mode 100644 index 0000000..bd59141 --- /dev/null +++ b/internal/hummingbird/core/controller/http/websocket/rule.go @@ -0,0 +1,31 @@ +package websocket + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +type WsData struct { + Code dtos.WsCode `json:"code"` + Data interface{} `json:"data"` +} + +type WsResponse struct { + Code dtos.WsCode `json:"code"` + Data httphelper.CommonResponse `json:"data"` +} + +// 前端websockets +// 前端请求处理 +type wsFunc func(*wsClient, interface{}, dtos.WsCode) + +var wsFuncMap = map[dtos.WsCode]wsFunc{ + //驱动相关 + dtos.WsCodeDeviceLibraryUpgrade: DeviceLibraryUpgrade, + dtos.WsCodeDeviceServiceRunStatus: DeviceServiceRunStatus, + dtos.WsCodeDeviceServiceLog: DeviceServiceLog, + dtos.WsCodeDeviceLibraryDelete: DeviceLibraryDelete, + + //多语言 + dtos.WsCodeCheckLang: CheckLang, +} diff --git a/internal/hummingbird/core/controller/http/websocket/seriousalert.go b/internal/hummingbird/core/controller/http/websocket/seriousalert.go new file mode 100644 index 0000000..c3139e3 --- /dev/null +++ b/internal/hummingbird/core/controller/http/websocket/seriousalert.go @@ -0,0 +1,12 @@ +package websocket + +/** +receive: {"code":10200,"data":{}} +*/ +//func SeriousAlert(c *wsClient, data interface{}, code dtos.WsCode) { +// systemItf := container.SystemItfFrom(c.dic.Get) +// systemItf.NewClientIn(dtos.RpcData{ +// Code: code, +// ReqId: c.id, +// }) +//} diff --git a/internal/hummingbird/core/controller/http/websocket/server.go b/internal/hummingbird/core/controller/http/websocket/server.go new file mode 100644 index 0000000..a18ef71 --- /dev/null +++ b/internal/hummingbird/core/controller/http/websocket/server.go @@ -0,0 +1,275 @@ +package websocket + +import ( + "context" + "encoding/json" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" + "github.com/winc-link/hummingbird/internal/pkg/i18n" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/middleware" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Maximum message size allowed from peer. + maxMessageSize = 512 +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +type wsClient struct { + id string + hub *WsServer + ctx *gin.Context + lc logger.LoggingClient + dic *di.Container + conn *websocket.Conn + send chan WsResponse +} + +type WsServer struct { + lc logger.LoggingClient + + // Registered clients. + clients map[*wsClient]bool + + clientIdMap map[string]*wsClient + + // Register requests from the clients. + register chan *wsClient + + // Unregister requests from clients. + unregister chan *wsClient + + broadcast chan WsResponse + ctx context.Context + dic *di.Container +} + +func NewServer(dic *di.Container) *WsServer { + var lc = container.LoggingClientFrom(dic.Get) + c := &WsServer{ + lc: lc, + clients: make(map[*wsClient]bool), + clientIdMap: make(map[string]*wsClient), + register: make(chan *wsClient), + unregister: make(chan *wsClient), + broadcast: make(chan WsResponse), + ctx: context.Background(), + dic: dic, + } + go c.run() + go c.listenBroadcast() + + return c +} + +func (s *WsServer) Handle(c *gin.Context) { + lc := s.lc + w := c.Writer + r := c.Request + s.ctx = r.Context() + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + lc.Error("upgrade webSocket err", err) + httphelper.RenderFail(c, err, c.Writer, lc) + return + } + + // 对于client,以ip:userid 作为唯一请求方,用于断线重连 + clientId := strings.Split(conn.RemoteAddr().String(), ":")[0] + if r.Header.Get("X-Real-Ip") != "" { + clientId = r.Header.Get("X-Real-Ip") + } + lc.Debugf("ws client ip: %s", clientId) + + value, ok := c.Get(constants.JwtParsedInfo) + if ok { + claim, ok := value.(*middleware.CustomClaims) + if ok { + clientId = clientId + ":" + strconv.Itoa(int(claim.ID)) + } + } + + client := &wsClient{ + id: clientId, + hub: s, + conn: conn, + ctx: c, + dic: s.dic, + lc: lc, + send: make(chan WsResponse), + } + s.register <- client + + go client.writePump() + go client.readPump() +} + +func (s *WsServer) run() { + for { + select { + //TODO: 缺少 done 时的退出 + case client := <-s.register: + s.clients[client] = true + s.clientIdMap[client.id] = client + case client := <-s.unregister: + if _, ok := s.clients[client]; ok { + delete(s.clients, client) + } + if _, ok := s.clientIdMap[client.id]; ok { + delete(s.clientIdMap, client.id) + } + case data := <-s.broadcast: + s.lc.Debugf("broadcast forward message to alertclient: %v, data: %+v", len(s.clients), data) + for client := range s.clients { + select { + case client.send <- data: + default: + } + } + } + } +} + +// 监听内部业务推送前端广播消息 +func (s *WsServer) listenBroadcast() { + sc := container.StreamClientFrom(s.dic.Get) + for { + select { + case data := <-sc.Recv(): + var resp httphelper.CommonResponse + if data.ErrCode == errort.DefaultSuccess { + resp = httphelper.NewSuccessCommonResponse(data.Data) + } else { + resp = httphelper.NewFailWithI18nResponse(s.ctx, errort.NewCommonErr(data.ErrCode, nil)) + } + s.broadcast <- WsResponse{ + Code: data.Code, + Data: resp, + } + } + } +} + +func (c *wsClient) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + if !message.Data.Success { + if message.Data.ErrorCode != errort.ContainerRunFail { //如果是ContainerRunFail错误,把原错误返回出去方便排查问题。 + message.Data.ErrorMsg = i18n.TransCode(c.ctx, message.Data.ErrorCode, nil) + } + } + + messageBody, _ := json.Marshal(message) + + c.lc.Infof("websocket to resp data: %s", string(messageBody)) + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // The hub closed the channel. + c.lc.Warn("client send channel closed!") + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + c.lc.Errorf("websocket NextWriter err:", err) + return + } + _, err = w.Write(messageBody) + if err != nil { + c.lc.Errorf("websocket Write err:", err) + } + + if err := w.Close(); err != nil { + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +func (c *wsClient) readPump() { + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, msg, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { + c.lc.Errorf("ReadMessage close info: %v", err) + } + break + } + if !json.Valid(msg) { + c.lc.Errorf("ReadMessage data not is json format") + continue + } + c.lc.Infof("websocket req %+v", string(msg)) + d := WsData{} + err = json.Unmarshal(msg, &d) + if err != nil { + c.lc.Errorf("ReadMessage data unmarshal err: %v", err) + continue + } + if f, ok := wsFuncMap[d.Code]; ok { + go f(c, d.Data, d.Code) + } + } +} + +func (c *wsClient) sendData(code dtos.WsCode, d httphelper.CommonResponse) { + resData := WsResponse{} + resData.Code = code + resData.Data = d + + c.send <- resData +} + +// 切换语言 +func (c *wsClient) ChangeLang(lang string) { + c.ctx.Set(constants.AcceptLanguage, lang) +} diff --git a/internal/hummingbird/core/controller/rpcserver/driverserver/base.go b/internal/hummingbird/core/controller/rpcserver/driverserver/base.go new file mode 100644 index 0000000..3206b7b --- /dev/null +++ b/internal/hummingbird/core/controller/rpcserver/driverserver/base.go @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverserver + +import ( + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "google.golang.org/grpc" +) + +func RegisterRPCService(lc logger.LoggingClient, dic *di.Container, s *grpc.Server) { + NewThingModelServer(lc, dic).RegisterServer(s) + NewDriverDeviceServer(lc, dic).RegisterServer(s) + NewCloudInstanceServer(lc, dic).RegisterServer(s) + NewGatewayServer(lc, dic).RegisterServer(s) + NewDriverStorageServer(lc, dic).RegisterServer(s) + NewProductServer(lc, dic).RegisterServer(s) +} diff --git a/internal/hummingbird/core/controller/rpcserver/driverserver/cloudinstance.go b/internal/hummingbird/core/controller/rpcserver/driverserver/cloudinstance.go new file mode 100644 index 0000000..9507237 --- /dev/null +++ b/internal/hummingbird/core/controller/rpcserver/driverserver/cloudinstance.go @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverserver + +import ( + "context" + "github.com/winc-link/edge-driver-proto/cloudinstance" + "github.com/winc-link/edge-driver-proto/drivercommon" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "google.golang.org/grpc" +) + +type CloudInstanceServer struct { + cloudinstance.UnimplementedCloudInstanceServiceServer + + lc logger.LoggingClient + dic *di.Container +} + +func (s *CloudInstanceServer) DriverReportPlatformInfo(ctx context.Context, request *cloudinstance.DriverReportPlatformInfoRequest) (*cloudinstance.DriverReportPlatformInfoResponse, error) { + response := new(cloudinstance.DriverReportPlatformInfoResponse) + response.BaseResponse = new(drivercommon.CommonResponse) + if request.GetDriverInstanceId() == "" { + response.BaseResponse.Success = false + response.BaseResponse.ErrorMessage = "param error" + return response, nil + } + + driverService := container.DriverServiceAppFrom(s.dic.Get) + req := dtos.DeviceServiceUpdateRequest{} + req.Id = request.GetDriverInstanceId() + req.Platform = constants.TransformEdgePlatformToDbPlatform(request.GetIotPlatform()) + err := driverService.Update(ctx, req) + if err != nil { + response.BaseResponse.ErrorMessage = err.Error() + } + response.BaseResponse.Success = true + return response, nil +} + +func (s *CloudInstanceServer) QueryCloudInstanceByPlatform(ctx context.Context, request *cloudinstance.QueryCloudInstanceByPlatformRequest) (*cloudinstance.QueryCloudInstanceByPlatformResponse, error) { + response := new(cloudinstance.QueryCloudInstanceByPlatformResponse) + return response, nil + +} + +func NewCloudInstanceServer(lc logger.LoggingClient, dic *di.Container) *CloudInstanceServer { + return &CloudInstanceServer{ + lc: lc, + dic: dic, + } +} + +func (s *CloudInstanceServer) RegisterServer(server *grpc.Server) { + cloudinstance.RegisterCloudInstanceServiceServer(server, s) +} diff --git a/internal/hummingbird/core/controller/rpcserver/driverserver/device.go b/internal/hummingbird/core/controller/rpcserver/driverserver/device.go new file mode 100644 index 0000000..1234813 --- /dev/null +++ b/internal/hummingbird/core/controller/rpcserver/driverserver/device.go @@ -0,0 +1,162 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverserver + +import ( + "context" + "fmt" + "github.com/winc-link/edge-driver-proto/drivercommon" + device "github.com/winc-link/edge-driver-proto/driverdevice" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "google.golang.org/grpc" + "strconv" +) + +type DriverDeviceServer struct { + device.UnimplementedRpcDeviceServer + lc logger.LoggingClient + dic *di.Container +} + +func (s *DriverDeviceServer) ConnectIotPlatform(ctx context.Context, request *device.ConnectIotPlatformRequest) (*device.ConnectIotPlatformResponse, error) { + deviceItf := container.DeviceItfFrom(s.dic.Get) + return deviceItf.ConnectIotPlatform(ctx, request), nil +} + +func (s *DriverDeviceServer) DisconnectIotPlatform(ctx context.Context, request *device.DisconnectIotPlatformRequest) (*device.DisconnectIotPlatformResponse, error) { + deviceItf := container.DeviceItfFrom(s.dic.Get) + return deviceItf.DisConnectIotPlatform(ctx, request), nil +} + +func (s *DriverDeviceServer) GetDeviceConnectStatus(ctx context.Context, request *device.GetDeviceConnectStatusRequest) (*device.GetDeviceConnectStatusResponse, error) { + deviceItf := container.DeviceItfFrom(s.dic.Get) + return deviceItf.GetDeviceConnectStatus(ctx, request), nil +} + +func (s *DriverDeviceServer) QueryDeviceList(ctx context.Context, request *device.QueryDeviceListRequest) (*device.QueryDeviceListResponse, error) { + deviceItf := container.DeviceItfFrom(s.dic.Get) + + var platform string + if request.BaseRequest.UseCloudPlatform { + platform = string(constants.TransformEdgePlatformToDbPlatform(request.BaseRequest.GetCloudInstanceInfo().GetIotPlatform())) + } else { + platform = string(constants.IotPlatform_LocalIot) + } + devices, total, err := deviceItf.DevicesModelSearch(ctx, dtos.DeviceSearchQueryRequest{ + DriveInstanceId: request.BaseRequest.DriverInstanceId, + Platform: platform, + }) + response := new(device.QueryDeviceListResponse) + response.BaseResponse = new(drivercommon.CommonResponse) + if err != nil { + response.BaseResponse.Success = false + response.BaseResponse.ErrorMessage = err.Error() + return response, nil + } + response.BaseResponse.Success = true + response.Data = new(device.QueryDeviceListResponse_Data) + response.Data.Total = total + for _, queryResponse := range devices { + response.Data.Devices = append(response.Data.Devices, queryResponse.TransformToDriverDevice()) + } + return response, nil +} + +func (s *DriverDeviceServer) QueryDeviceById(ctx context.Context, request *device.QueryDeviceByIdRequest) (*device.QueryDeviceByIdResponse, error) { + deviceItf := container.DeviceItfFrom(s.dic.Get) + + deviceInfo, err := deviceItf.DeviceModelById(ctx, request.Id) + response := new(device.QueryDeviceByIdResponse) + response.BaseResponse = new(drivercommon.CommonResponse) + if err != nil { + response.BaseResponse.Success = false + response.BaseResponse.ErrorMessage = err.Error() + return response, nil + } + response.BaseResponse.Success = true + response.Data = new(device.QueryDeviceByIdResponse_Data) + response.Data.Device = deviceInfo.TransformToDriverDevice() + return response, nil +} + +func (s *DriverDeviceServer) CreateDevice(ctx context.Context, request *device.CreateDeviceRequest) (*device.CreateDeviceRequestResponse, error) { + response := new(device.CreateDeviceRequestResponse) + response.BaseResponse = new(drivercommon.CommonResponse) + deviceItf := container.DeviceItfFrom(s.dic.Get) + productItf := container.ProductAppNameFrom(s.dic.Get) + productInfo, err := productItf.ProductById(ctx, request.Device.ProductId) + if err != nil { + err = errort.NewCommonErr(errort.ProductNotExist, fmt.Errorf("")) + errWrapper := errort.NewCommonEdgeXWrapper(err) + response.BaseResponse.Success = false + response.BaseResponse.Code = strconv.Itoa(int(errWrapper.Code())) + response.BaseResponse.ErrorMessage = errWrapper.Message() + return response, nil + } + + var insertDevice dtos.DeviceAddRequest + insertDevice.ProductId = productInfo.Id + insertDevice.Platform = constants.IotPlatform_LocalIot + insertDevice.Name = request.Device.Name + insertDevice.DriverInstanceId = request.BaseRequest.GetDriverInstanceId() + + deviceId, err := deviceItf.AddDevice(ctx, insertDevice) + if err != nil { + errWrapper := errort.NewCommonEdgeXWrapper(err) + response.BaseResponse.Success = false + response.BaseResponse.Code = strconv.Itoa(int(errWrapper.Code())) + response.BaseResponse.ErrorMessage = errWrapper.Message() + return response, nil + } + response.BaseResponse.Success = true + response.Data = new(device.CreateDeviceRequestResponse_Data) + response.Data.Devices = new(device.Device) + response.Data.Devices.Id = deviceId + response.Data.Devices.Name = request.Device.Name + response.Data.Devices.Description = request.Device.Description + response.Data.Devices.ProductId = request.Device.ProductId + response.Data.Devices.Status = device.DeviceStatus_OffLine + response.Data.Devices.Platform = drivercommon.IotPlatform_LocalIot + + return response, nil +} + +func (s *DriverDeviceServer) CreateDeviceAndConnect(ctx context.Context, request *device.CreateDeviceAndConnectRequest) (*device.CreateDeviceAndConnectRequestResponse, error) { + //TODO implement me + panic("implement me") +} + +func (s *DriverDeviceServer) DeleteDevice(ctx context.Context, request *device.DeleteDeviceRequest) (*device.DeleteDeviceResponse, error) { + //TODO implement me + panic("implement me") +} + +var _ device.RpcDeviceServer = (*DriverDeviceServer)(nil) + +func NewDriverDeviceServer(lc logger.LoggingClient, dic *di.Container) *DriverDeviceServer { + return &DriverDeviceServer{ + lc: lc, + dic: dic, + } +} + +func (s *DriverDeviceServer) RegisterServer(server *grpc.Server) { + device.RegisterRpcDeviceServer(server, s) +} diff --git a/internal/hummingbird/core/controller/rpcserver/driverserver/driverstorage.go b/internal/hummingbird/core/controller/rpcserver/driverserver/driverstorage.go new file mode 100644 index 0000000..f6d5162 --- /dev/null +++ b/internal/hummingbird/core/controller/rpcserver/driverserver/driverstorage.go @@ -0,0 +1,198 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package driverserver + +import ( + "context" + driverstorage "github.com/winc-link/edge-driver-proto/driverstorge" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/tools/datadb/leveldb" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" + "os" + "strings" + "sync" +) + +type DriverStorageServer struct { + driverstorage.UnimplementedDriverStorageServer + dMap map[string]*leveldb.DriverStorageClient + mu sync.Mutex + dirPath string + lc logger.LoggingClient + dic *di.Container +} + +func (s *DriverStorageServer) All(ctx context.Context, req *driverstorage.AllReq) (*driverstorage.KVs, error) { + id := req.GetDriverServiceId() + if len(id) <= 0 { + return nil, status.Error(codes.InvalidArgument, "driver service not set") + } + s.lc.Infof("get driver storage, driver service: %s", id) + client, err := s.getStorageClient(id) + if err != nil { + s.lc.Errorf("get leveldb client error: %s,driver service: %s", err, id) + return nil, status.Error(codes.Internal, err.Error()) + } + all, err := client.All() + if err != nil { + s.lc.Errorf("get all kvs error: %s,driver service: %s", err, id) + return nil, status.Error(codes.Internal, err.Error()) + } + + var kvs driverstorage.KVs + for k, v := range all { + kvs.Kvs = append(kvs.Kvs, &driverstorage.KV{ + Key: k, + Value: v, + }) + } + return &kvs, nil +} + +func (s *DriverStorageServer) Get(ctx context.Context, req *driverstorage.GetReq) (*driverstorage.KVs, error) { + id := req.GetDriverServiceId() + if len(id) <= 0 { + return nil, status.Error(codes.InvalidArgument, "driver service not set") + } + client, err := s.getStorageClient(id) + if err != nil { + s.lc.Errorf("get leveldb client error: %s,driver service: %s", err, id) + return nil, status.Error(codes.Internal, err.Error()) + } + keys := req.GetKeys() + if len(keys) <= 0 { + return nil, status.Error(codes.InvalidArgument, "keys length is 0") + } + + s.lc.Infof("get driver storage, driver service: %s, keys: %+v", id, keys) + + kvs, _ := client.Get(keys) + // convert + var resp driverstorage.KVs + for k, v := range kvs { + resp.Kvs = append(resp.Kvs, &driverstorage.KV{ + Key: k, + Value: v, + }) + } + return &resp, nil +} + +func (s *DriverStorageServer) Put(ctx context.Context, req *driverstorage.PutReq) (*emptypb.Empty, error) { + id := req.GetDriverServiceId() + if len(id) <= 0 { + return &emptypb.Empty{}, status.Error(codes.InvalidArgument, "driver service not set") + } + client, err := s.getStorageClient(id) + if err != nil { + s.lc.Errorf("get leveldb client error: %s,driver service: %s", err, id) + return &emptypb.Empty{}, status.Error(codes.Internal, err.Error()) + } + data := req.GetData() + keys := make([]string, 0, len(data)) + kvs := make(map[string][]byte, len(data)) + for _, v := range data { + keys = append(keys, v.GetKey()) + kvs[v.GetKey()] = v.GetValue() + } + + s.lc.Infof("put driver storage, driver service: %s, keys: %+v", id, keys) + + if err := client.Put(kvs); err != nil { + return &emptypb.Empty{}, status.Error(codes.Internal, err.Error()) + } + return &emptypb.Empty{}, nil +} + +func (s *DriverStorageServer) Delete(ctx context.Context, req *driverstorage.DeleteReq) (*emptypb.Empty, error) { + id := req.GetDriverServiceId() + if len(id) <= 0 { + return &emptypb.Empty{}, status.Error(codes.InvalidArgument, "driver service not set") + } + client, err := s.getStorageClient(id) + if err != nil { + s.lc.Errorf("get leveldb client error: %s,driver service: %s", err, id) + return &emptypb.Empty{}, status.Error(codes.Internal, err.Error()) + } + + keys := req.GetKeys() + if len(keys) <= 0 { + return nil, status.Error(codes.InvalidArgument, "keys length is 0") + } + + s.lc.Infof("delete driver storage, driver service: %s, keys: %+v", id, keys) + + if err := client.Delete(keys); err != nil { + s.lc.Errorf("driver storage delete keys(%+v) error: %s", keys, err) + return &emptypb.Empty{}, status.Error(codes.Internal, err.Error()) + } + return &emptypb.Empty{}, nil +} + +func NewDriverStorageServer(lc logger.LoggingClient, dic *di.Container) *DriverStorageServer { + var dir string + config := container.ConfigurationFrom(dic.Get) + if ldb, ok := config.Databases["Data"]; !ok { + lc.Errorf("leveldb not config") + os.Exit(-1) + } else { + dir = ldb["Primary"].DataSource + if len(dir) <= 0 { + lc.Errorf("leveldb not config") + os.Exit(-1) + } + } + + strN := strings.SplitN(dir, "/", 2) + if len(strN) < 1 { + lc.Errorf("leveldb config error") + os.Exit(-1) + } + dir = strN[0] + "/" + return &DriverStorageServer{ + lc: lc, + dMap: make(map[string]*leveldb.DriverStorageClient), + dirPath: dir, + dic: dic, + } +} + +func (s *DriverStorageServer) RegisterServer(server *grpc.Server) { + driverstorage.RegisterDriverStorageServer(server, s) +} + +func (s *DriverStorageServer) getStorageClient(id string) (*leveldb.DriverStorageClient, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var ( + err error + ok bool + client *leveldb.DriverStorageClient + ) + + if client, ok = s.dMap[id]; !ok { + if client, err = leveldb.NewDriverStorageClient(s.dirPath, id, s.lc); err != nil { + return nil, err + } + s.dMap[id] = client + } + return client, nil +} diff --git a/internal/hummingbird/core/controller/rpcserver/driverserver/gateway.go b/internal/hummingbird/core/controller/rpcserver/driverserver/gateway.go new file mode 100644 index 0000000..0033dfe --- /dev/null +++ b/internal/hummingbird/core/controller/rpcserver/driverserver/gateway.go @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverserver + +import ( + "context" + "github.com/winc-link/edge-driver-proto/gateway" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" +) + +type GatewayServer struct { + gateway.UnimplementedRpcGatewayServer + + lc logger.LoggingClient + dic *di.Container +} + +func (s *GatewayServer) GetGatewayInfo(ctx context.Context, empty *emptypb.Empty) (*gateway.GateWayInfoResponse, error) { + response := new(gateway.GateWayInfoResponse) + response.Env = "env" + response.GwId = "gatewayId" + response.LocalKey = "localKey" + + return response, nil +} + +func NewGatewayServer(lc logger.LoggingClient, dic *di.Container) *GatewayServer { + return &GatewayServer{ + lc: lc, + dic: dic, + } +} + +func (s *GatewayServer) RegisterServer(server *grpc.Server) { + gateway.RegisterRpcGatewayServer(server, s) +} diff --git a/internal/hummingbird/core/controller/rpcserver/driverserver/product.go b/internal/hummingbird/core/controller/rpcserver/driverserver/product.go new file mode 100644 index 0000000..a233f38 --- /dev/null +++ b/internal/hummingbird/core/controller/rpcserver/driverserver/product.go @@ -0,0 +1,114 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverserver + +import ( + "context" + "github.com/winc-link/edge-driver-proto/drivercommon" + "github.com/winc-link/edge-driver-proto/driverproduct" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "google.golang.org/grpc" +) + +type ProductServer struct { + driverproduct.UnimplementedRpcProductServer + lc logger.LoggingClient + dic *di.Container +} + +func (s *ProductServer) QueryProductList(ctx context.Context, request *driverproduct.QueryProductListRequest) (*driverproduct.QueryProductListResponse, error) { + var productPlatform constants.IotPlatform + if request.BaseRequest != nil && request.BaseRequest.UseCloudPlatform { + s.lc.Infof("request.BaseRequest.GetCloudInstanceInfo().IotPlatform:", request.BaseRequest.GetCloudInstanceInfo().IotPlatform) + switch request.BaseRequest.GetCloudInstanceInfo().IotPlatform { + case drivercommon.IotPlatform_WinCLinkIot: + productPlatform = constants.IotPlatform_WinCLinkIot + case drivercommon.IotPlatform_AliIot: + productPlatform = constants.IotPlatform_AliIot + case drivercommon.IotPlatform_HuaweiIot: + productPlatform = constants.IotPlatform_HuaweiIot + case drivercommon.IotPlatform_TencentIot: + productPlatform = constants.IotPlatform_TencentIot + case drivercommon.IotPlatform_TuyaIot: + productPlatform = constants.IotPlatform_TuyaIot + case drivercommon.IotPlatform_OneNetIot: + productPlatform = constants.IotPlatform_OneNetIot + default: + productPlatform = constants.IotPlatform_LocalIot + } + } else { + productPlatform = constants.IotPlatform_LocalIot + } + + productItf := container.ProductAppNameFrom(s.dic.Get) + + response := new(driverproduct.QueryProductListResponse) + response.Data = new(driverproduct.QueryProductListResponse_Data) + + response.BaseResponse = new(drivercommon.CommonResponse) + + productsModel, totol, err := productItf.ProductsModelSearch(ctx, dtos.ProductSearchQueryRequest{ + Platform: string(productPlatform), + }) + if err != nil { + response.BaseResponse.Success = false + response.BaseResponse.ErrorMessage = err.Error() + return response, nil + } + response.Data.Total = totol + var driverProducts []*driverproduct.Product + for _, productModel := range productsModel { + driverProducts = append(driverProducts, productModel.TransformToDriverProduct()) + + } + response.BaseResponse.Success = true + response.Data.Products = driverProducts + return response, nil +} + +func (s *ProductServer) QueryProductById(ctx context.Context, request *driverproduct.QueryProductByIdRequest) (*driverproduct.QueryProductByIdResponse, error) { + productItf := container.ProductAppNameFrom(s.dic.Get) + + response := new(driverproduct.QueryProductByIdResponse) + response.BaseResponse = new(drivercommon.CommonResponse) + + productModel, err := productItf.ProductModelById(ctx, request.Id) + if err != nil { + response.BaseResponse.Success = false + response.BaseResponse.ErrorMessage = err.Error() + return response, nil + } + response.Data = new(driverproduct.QueryProductByIdResponse_Data) + response.Data.Product = productModel.TransformToDriverProduct() + + response.BaseResponse.Success = true + return response, err + +} + +func NewProductServer(lc logger.LoggingClient, dic *di.Container) *ProductServer { + return &ProductServer{ + lc: lc, + dic: dic, + } +} + +func (s *ProductServer) RegisterServer(server *grpc.Server) { + driverproduct.RegisterRpcProductServer(server, s) +} diff --git a/internal/hummingbird/core/controller/rpcserver/driverserver/thingmodel.go b/internal/hummingbird/core/controller/rpcserver/driverserver/thingmodel.go new file mode 100644 index 0000000..6b181b3 --- /dev/null +++ b/internal/hummingbird/core/controller/rpcserver/driverserver/thingmodel.go @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package driverserver + +import ( + "context" + "github.com/winc-link/edge-driver-proto/drivercommon" + "github.com/winc-link/edge-driver-proto/thingmodel" + "github.com/winc-link/hummingbird/internal/dtos" + coreContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "google.golang.org/grpc" +) + +type ThingModelServer struct { + thingmodel.UnimplementedThingModelUpServiceServer + + lc logger.LoggingClient + dic *di.Container +} + +func (s *ThingModelServer) ThingModelMsgReport(ctx context.Context, msg *thingmodel.ThingModelMsg) (response *drivercommon.CommonResponse, err error) { + message := dtos.ThingModelMessageFromThingModelMsg(msg) + messageItf := coreContainer.MessageItfFrom(s.dic.Get) + return messageItf.ThingModelMsgReport(ctx, message) +} + +var _ thingmodel.ThingModelUpServiceServer = (*ThingModelServer)(nil) + +func NewThingModelServer(lc logger.LoggingClient, dic *di.Container) *ThingModelServer { + return &ThingModelServer{ + lc: lc, + dic: dic, + } +} + +func (s *ThingModelServer) RegisterServer(server *grpc.Server) { + thingmodel.RegisterThingModelUpServiceServer(server, s) +} diff --git a/internal/hummingbird/core/infrastructure/mysql/alertrule.go b/internal/hummingbird/core/infrastructure/mysql/alertrule.go new file mode 100644 index 0000000..44c3781 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/alertrule.go @@ -0,0 +1,217 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "time" +) + +func addAlertRule(c *Client, ds models.AlertRule) (alertRule models.AlertRule, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule creation failed", err) + } + return ds, edgeXErr +} + +func addAlertList(c *Client, ds models.AlertList) (alertRule models.AlertList, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule list creation failed", err) + } + return ds, edgeXErr +} + +func updateAlertRule(c *Client, dl models.AlertRule) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule update failed", err) + } + return nil +} + +func alertRuleById(c *Client, id string) (alertRule models.AlertRule, edgeXErr error) { + if id == "" { + return alertRule, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "alert rule id is empty", nil) + } + err := c.Pool.Table(alertRule.TableName()).First(&alertRule, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return alertRule, errort.NewCommonErr(errort.AlertRuleNotExist, fmt.Errorf("alert rule id(%s) not found", id)) + } + return alertRule, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query alert rule fail (Id:%s), %s", alertRule.Id, err)) + } + return +} + +func alertRuleSearch(c *Client, offset int, limit int, req dtos.AlertRuleSearchQueryRequest) (alertRules []models.AlertRule, count uint32, edgeXErr error) { + dp := models.AlertRule{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.AlertRule{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rules failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&alertRules).Error + if err != nil { + return []models.AlertRule{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rules failed query from the database", err) + } + + return alertRules, uint32(total), nil +} + +func alertListLastSend(c *Client, alertRuleId string) (alertList models.AlertList, edgeXErr error) { + al := models.AlertList{} + err := c.Pool.Table(al.TableName()).Where("alert_rule_id = ?", alertRuleId).Where("is_send", true).Order("created desc").Last(&alertList).Error + if err != nil { + return + } + return +} + +func alertListSearch(c *Client, offset int, limit int, req dtos.AlertSearchQueryRequest) (alertRules []dtos.AlertSearchQueryResponse, count uint32, edgeXErr error) { + var total int64 + dp := models.AlertList{} + tx := c.Pool.Table(dp.TableName()).Select("alert_list.id,alert_list.status," + + "alert_rule.name,alert_list.alert_result,alert_rule.alert_level,alert_list.trigger_time,alert_list.treated_time,alert_list.message,alert_list.is_send").Joins("left join alert_rule on alert_list.alert_rule_id = alert_rule.id") + //tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + if req.Name != "" { + tx.Where("alert_rule.name LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + if req.Status != "" { + tx.Where("alert_list.status = ?", req.Status) + + } + if req.AlertLevel != "" { + tx.Where("alert_rule.alert_level = ?", req.AlertLevel) + } + if req.TriggerStartTime > 0 && req.TriggerEndTime > 0 && req.TriggerEndTime-req.TriggerStartTime > 0 { + tx.Where("alert_list.trigger_time >= ?", req.TriggerStartTime) + tx.Where("alert_list.trigger_time <= ?", req.TriggerEndTime) + } + edgeXErr = tx.Count(&total).Error + if edgeXErr != nil { + return []dtos.AlertSearchQueryResponse{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert list failed query from the database", edgeXErr) + } + tx.Order("alert_list.created desc") + edgeXErr = tx.Offset(offset).Limit(limit).Scan(&alertRules).Error + if edgeXErr != nil { + return []dtos.AlertSearchQueryResponse{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert list failed query from the database", edgeXErr) + } + return alertRules, uint32(total), nil +} + +func deleteAlertRuleById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "alert rule id is empty", nil) + } + err := c.client.DeleteObject(&models.AlertRule{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule deletion failed", err) + } + return nil +} + +func alertRuleStart(c *Client, id string) error { + d := models.AlertRule{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStart}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start alert rule failed", err) + } + return nil +} + +func alertRuleStop(c *Client, id string) error { + d := models.AlertRule{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStop}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "stop alert rule failed", err) + } + return nil +} + +//subQuery := db.Select("AVG(age)").Where("name LIKE ?", "name%").Table("users") +//db.Select("AVG(age) as avgage").Group("name").Having("AVG(age) > (?)", subQuery).Find(&results) +// SELECT AVG(age) as avgage FROM `users` GROUP BY `name` HAVING AVG(age) > (SELECT AVG(age) FROM `users` WHERE name LIKE "name%") + +func alertPlate(c *Client, beforeTime int64) (plate []dtos.AlertPlateQueryResponse, err error) { + d := models.AlertList{} + if beforeTime > 0 { + err = c.Pool.Table(d.TableName()).Raw( + "SELECT count(alert_list.id) AS count,alert_rule.alert_level FROM alert_list "+ + "JOIN alert_rule on alert_list.alert_rule_id = alert_rule.id and alert_list.created > (?) "+ + "GROUP BY alert_rule.alert_level", beforeTime).Scan(&plate).Error + } else { + err = c.Pool.Table(d.TableName()).Raw( + "SELECT count(alert_list.id) AS count,alert_rule.alert_level FROM alert_list " + + "JOIN alert_rule on alert_list.alert_rule_id = alert_rule.id" + + "GROUP BY alert_rule.alert_level").Scan(&plate).Error + } + + return +} + +func alertIgnore(c *Client, id string) (err error) { + d := models.AlertList{} + tx := c.Pool.Table(d.TableName()) + err = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.Ignore, "treated_time": time.Now().UnixMilli()}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert ignore rule failed", err) + } + return nil +} + +func treatedIgnore(c *Client, id string, message string) (err error) { + d := models.AlertList{} + tx := c.Pool.Table(d.TableName()) + err = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.Treated, "message": message, "treated_time": time.Now().UnixMilli()}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert ignore rule failed", err) + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/category.go b/internal/hummingbird/core/infrastructure/mysql/category.go new file mode 100644 index 0000000..865a6e3 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/category.go @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func categoryTemplateSearch(c *Client, offset int, limit int, req dtos.CategoryTemplateRequest) (categoryTemplates []models.CategoryTemplate, count uint32, edgeXErr error) { + cs := models.CategoryTemplate{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + if req.CategoryName != "" { + tx = tx.Where("`category_name` LIKE ?", "%"+req.CategoryName+"%") + } + + if req.Scene != "" { + tx = tx.Where("`scene` = ?", req.Scene) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.CategoryTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "categoryTemplate failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&categoryTemplates).Error + if err != nil { + return []models.CategoryTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "categoryTemplate failed query from the database", err) + } + + return categoryTemplates, uint32(total), nil +} + +func categoryTemplateById(c *Client, id string) (categoryTemplateInfo models.CategoryTemplate, edgeXErr error) { + if id == "" { + return categoryTemplateInfo, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "categoryTemplate id is empty", nil) + } + err := c.client.GetObject(&models.CategoryTemplate{Id: id}, &categoryTemplateInfo) + if err != nil { + if err == gorm.ErrRecordNotFound { + return categoryTemplateInfo, errort.NewCommonErr(errort.CategoryNotExist, fmt.Errorf("categoryTemplate id(%s) not found", id)) + } + return categoryTemplateInfo, err + } + return +} + +func batchUpsertCategoryTemplate(c *Client, d []models.CategoryTemplate) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/client.go b/internal/hummingbird/core/infrastructure/mysql/client.go new file mode 100644 index 0000000..40467b9 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/client.go @@ -0,0 +1,702 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/mysql" + "gorm.io/gorm" + //clientSQLite "github.com/winc-link/hummingbird/internal/tools/sqldb/mysql" +) + +type Client struct { + Pool *gorm.DB + //cache interfaces.Cache + client mysql.ClientSQLite + loggingClient logger.LoggingClient +} + +func NewClient(config dtos.Configuration, lc logger.LoggingClient) (c *Client, errEdgeX error) { + client, err := mysql.NewGormClient(config, lc) + if err != nil { + errEdgeX = errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("database failed to init %w", err)) + return + } + // 自动建表 + if err = client.InitTable( + &models.DeviceLibrary{}, + &models.DeviceService{}, + &models.Device{}, + &models.DockerConfig{}, + &models.AdvanceConfig{}, + &models.SystemMetrics{}, + &models.CategoryTemplate{}, + &models.ThingModelTemplate{}, + &models.DriverClassify{}, + &models.User{}, + &models.LanguageSdk{}, + &models.Metrics{}, + &models.Product{}, + &models.Properties{}, + &models.Actions{}, + &models.Events{}, + &models.Unit{}, + &models.MqttAuth{}, + &models.AlertRule{}, + &models.Scene{}, + &models.SceneLog{}, + &models.AlertList{}, + &models.QuickNavigation{}, + &models.Doc{}, + &models.MsgGather{}, + &models.RuleEngine{}, + &models.DataResource{}, + ); err != nil { + errEdgeX = errort.NewCommonEdgeX(errort.DefaultSystemError, "database failed to init", err) + return + } + c = &Client{ + client: client, + loggingClient: lc, + Pool: client.Pool, + } + return +} + +// CloseSession closes the connections to Redis +func (c *Client) CloseSession() { + c.client.Close() +} + +func (c *Client) GetDBInstance() *gorm.DB { + return c.Pool +} + +func (c *Client) AddDeviceLibrary(dl models.DeviceLibrary) (models.DeviceLibrary, error) { + if len(dl.Id) == 0 { + dl.Id = utils.GenUUID() + } + return addDeviceLibrary(c, dl) +} + +func (c *Client) DockerConfigAdd(dc models.DockerConfig) (models.DockerConfig, error) { + if len(dc.Id) == 0 { + dc.Id = utils.GenUUID() + } + return dockerConfigAdd(c, dc) +} + +func (c *Client) DockerConfigById(id string) (models.DockerConfig, error) { + return dockerConfigById(c, id) +} + +func (c *Client) DockerConfigDelete(id string) error { + return dockerConfigDeleteById(c, id) +} + +func (c *Client) DockerConfigUpdate(dc models.DockerConfig) error { + return dockerConfigUpdate(c, dc) +} + +func (c *Client) DockerConfigsSearch(offset int, limit int, req dtos.DockerConfigSearchQueryRequest) (dcs []models.DockerConfig, total uint32, edgeXErr error) { + return dockerConfigsSearch(c, offset, limit, req) +} + +func (c *Client) DriverClassifySearch(offset int, limit int, req dtos.DriverClassifyQueryRequest) (dcs []models.DriverClassify, total uint32, edgeXErr error) { + return driverClassifySearch(c, offset, limit, req) +} + +func (c *Client) DeviceLibrariesSearch(offset int, limit int, req dtos.DeviceLibrarySearchQueryRequest) (deviceLibraries []models.DeviceLibrary, total uint32, edgeXErr error) { + deviceLibraries, total, edgeXErr = deviceLibrariesSearch(c, offset, limit, req) + if edgeXErr != nil { + return deviceLibraries, total, edgeXErr + } + return deviceLibraries, total, nil +} + +func (c *Client) DeviceServicesSearch(offset int, limit int, req dtos.DeviceServiceSearchQueryRequest) (deviceServices []models.DeviceService, total uint32, edgeXErr error) { + deviceServices, total, edgeXErr = deviceServicesSearch(c, offset, limit, req) + if edgeXErr != nil { + return deviceServices, 0, edgeXErr + } + return deviceServices, total, nil +} + +func (c *Client) DeviceLibraryById(id string) (deviceLibrary models.DeviceLibrary, edgeXErr error) { + return deviceLibraryById(c, id) +} + +func (c *Client) DeleteDeviceLibraryById(id string) error { + return deleteDeviceLibraryById(c, id) +} + +func (c *Client) AddDeviceService(ds models.DeviceService) (models.DeviceService, error) { + // 驱动实例和驱动id一样,为了防止容器实例名冲突导致数据冲突 + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + ds.Name = ds.Name + "-" + ds.Id + return addDeviceService(c, ds) +} + +func (c *Client) UpdateDeviceService(ds models.DeviceService) error { + return updateDeviceService(c, ds) +} + +func (c *Client) UpdateDeviceLibrary(dl models.DeviceLibrary) error { + return updateDeviceLibrary(c, dl) +} + +func (c *Client) DeviceServiceById(id string) (deviceService models.DeviceService, edgeXErr error) { + deviceService, edgeXErr = deviceServiceById(c, id) + if edgeXErr != nil { + return deviceService, edgeXErr + } + + return +} + +func (c *Client) DeleteDeviceServiceById(id string) error { + return deleteDeviceServiceById(c, id) +} + +func (c *Client) ProductById(id string) (product models.Product, edgeXErr error) { + return productById(c, id) +} + +func (c *Client) AddProduct(ds models.Product) (product models.Product, edgeXErr error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addProduct(c, ds) +} + +func (c *Client) ProductByCloudId(id string) (product models.Product, edgeXErr error) { + return productByCloudId(c, id) +} + +func (c *Client) BatchUpsertProduct(p []models.Product) (int64, error) { + return batchUpsertProduct(c, p) +} + +func (c *Client) BatchSaveProduct(p []models.Product) error { + return batchSaveProduct(c, p) +} + +func (c *Client) BatchDeleteProduct(products []models.Product) error { + return batchDeleteProduct(c, products) +} + +func (c *Client) BatchDeleteProperties(propertiesIds []string) error { + return batchDeleteProperties(c, propertiesIds) +} + +func (c *Client) BatchDeleteSystemProperties() error { + return batchDeleteSystemProperties(c) +} + +func (c *Client) BatchInsertSystemProperties(p []models.Properties) (int64, error) { + return batchInsertSystemProperties(c, p) +} + +func (c *Client) BatchDeleteEvents(eventIds []string) error { + return batchDeleteEvents(c, eventIds) +} + +func (c *Client) BatchDeleteSystemEvents() error { + return batchDeleteSystemEvents(c) +} + +func (c *Client) BatchInsertSystemEvents(p []models.Events) (int64, error) { + return batchInsertSystemEvents(c, p) +} + +func (c *Client) BatchDeleteActions(actionIds []string) error { + return batchDeleteActions(c, actionIds) +} +func (c *Client) BatchDeleteSystemActions() error { + return batchDeleteSystemActions(c) +} + +func (c *Client) BatchInsertSystemActions(p []models.Actions) (int64, error) { + return batchInsertSystemActions(c, p) +} + +func (c *Client) DeleteProductById(id string) error { + return deleteProductById(c, id) +} + +func (c *Client) DeleteProductObject(product models.Product) error { + return deleteProductObject(c, product) +} + +func (c *Client) AssociationsDeleteProductObject(product models.Product) error { + return associationsDeleteProductObject(c, product) +} + +func (c *Client) UpdateProduct(ds models.Product) error { + return updateProduct(c, ds) +} + +func (c *Client) AssociationsUpdateProduct(ds models.Product) error { + return associationsUpdateProduct(c, ds) +} + +func (c *Client) BatchUpsertDevice(p []models.Device) (int64, error) { + return batchUpsertDevice(c, p) +} + +func (c *Client) ProductsSearch(offset int, limit int, preload bool, req dtos.ProductSearchQueryRequest) (products []models.Product, total uint32, edgeXErr error) { + products, total, edgeXErr = productsSearch(c, offset, limit, preload, req) + if edgeXErr != nil { + return products, 0, edgeXErr + } + return products, total, nil +} + +func (c *Client) DevicesSearch(offset int, limit int, req dtos.DeviceSearchQueryRequest) (devices []models.Device, total uint32, edgeXErr error) { + devices, total, edgeXErr = devicesSearch(c, offset, limit, req) + if edgeXErr != nil { + return devices, 0, edgeXErr + } + return devices, total, nil +} + +func (c *Client) DeviceById(id string) (device models.Device, edgeXErr error) { + return deviceById(c, id) +} + +func (c *Client) DeviceOnlineById(id string) (edgeXErr error) { + return deviceOnlineById(c, id) +} + +func (c *Client) DeviceOfflineById(id string) (edgeXErr error) { + return deviceOfflineById(c, id) +} + +func (c *Client) DeviceOfflineByCloudInstanceId(id string) (edgeXErr error) { + return deviceOfflineByCloudInstanceId(c, id) +} + +func (c *Client) MsgReportDeviceById(id string) (device models.Device, edgeXErr error) { + return msgReportDeviceById(c, id) +} + +func (c *Client) DeviceByCloudId(id string) (device models.Device, edgeXErr error) { + return deviceByCloudId(c, id) +} + +func (c *Client) DeviceMqttAuthInfo(id string) (device models.MqttAuth, edgeXErr error) { + return deviceMqttAuthInfo(c, id) +} + +func (c *Client) DriverMqttAuthInfo(id string) (device models.MqttAuth, edgeXErr error) { + return driverMqttAuthInfo(c, id) +} + +func (c *Client) AddDevice(ds models.Device) (deviceId string, edgeXErr error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addDevice(c, ds) +} + +func (c *Client) BatchDeleteDevice(ids []string) error { + return batchDeleteDevice(c, ids) +} + +func (c *Client) BatchUnBindDevice(ids []string) error { + return batchUnBindDevice(c, ids) +} + +func (c *Client) BatchBindDevice(ids []string, driverInstanceId string) error { + return batchBindDevice(c, ids, driverInstanceId) +} + +func (c *Client) DeleteDeviceById(id string) error { + return deleteDeviceById(c, id) +} + +func (c *Client) DeleteDeviceByCloudInstanceId(id string) error { + return deleteDeviceByCloudInstanceId(c, id) +} + +func (c *Client) UpdateDevice(ds models.Device) error { + return updateDevice(c, ds) +} + +func (c *Client) AbilityByCode(model interface{}, code, productId string) (interface{}, error) { + return abilityByCode(c, model, code, productId) +} + +func (c *Client) CategoryTemplateSearch(offset int, limit int, req dtos.CategoryTemplateRequest) ([]models.CategoryTemplate, uint32, error) { + return categoryTemplateSearch(c, offset, limit, req) +} + +func (c *Client) UnitSearch(offset int, limit int, req dtos.UnitRequest) ([]models.Unit, uint32, error) { + return unitSearch(c, offset, limit, req) +} + +func (c *Client) BatchUpsertUnitTemplate(p []models.Unit) (int64, error) { + return batchUpsertUnitTemplate(c, p) +} + +func (c *Client) CategoryTemplateById(id string) (models.CategoryTemplate, error) { + return categoryTemplateById(c, id) +} + +func (c *Client) BatchUpsertCategoryTemplate(p []models.CategoryTemplate) (int64, error) { + return batchUpsertCategoryTemplate(c, p) +} + +func (c *Client) ThingModelTemplateSearch(offset int, limit int, req dtos.ThingModelTemplateRequest) ([]models.ThingModelTemplate, uint32, error) { + return thingModelTemplateSearch(c, offset, limit, req) +} +func (c *Client) ThingModelTemplateByCategoryKey(categoryKey string) (models.ThingModelTemplate, error) { + return thingModelTemplateByCategoryKey(c, categoryKey) +} + +func (c *Client) BatchUpsertThingModelTemplate(p []models.ThingModelTemplate) (int64, error) { + return batchUpsertThingModelTemplate(c, p) +} + +func (c *Client) AddThingModelProperty(ds models.Properties) (models.Properties, error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addThingModelProperty(c, ds) +} + +func (c *Client) BatchUpsertThingModel(ds interface{}) (int64, error) { + return batchUpsertThingModel(c, ds) +} + +func (c *Client) AddThingModelEvent(ds models.Events) (models.Events, error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addThingModelEvent(c, ds) +} + +func (c *Client) AddThingModelAction(ds models.Actions) (models.Actions, error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addThingModelAction(c, ds) +} + +func (c *Client) UpdateThingModelProperty(ds models.Properties) error { + return updateThingModelProperty(c, ds) +} + +func (c *Client) UpdateThingModelEvent(ds models.Events) error { + return updateThingModelEvent(c, ds) +} + +func (c *Client) UpdateThingModelAction(ds models.Actions) error { + return updateThingModelAction(c, ds) +} + +func (c *Client) ThingModelDeleteProperty(id string) error { + return deleteThingModelPropertyById(c, id) +} + +func (c *Client) ThingModelDeleteEvent(id string) error { + return deleteThingModelEventById(c, id) +} + +func (c *Client) ThingModelDeleteAction(id string) error { + return deleteThingModelActionById(c, id) +} + +func (c *Client) ThingModelPropertyById(id string) (models.Properties, error) { + return thingModelPropertyById(c, id) +} + +func (c *Client) ThingModelEventById(id string) (models.Events, error) { + return thingModelEventById(c, id) +} + +func (c *Client) ThingModelActionsById(id string) (models.Actions, error) { + return thingModeActionById(c, id) +} +func (c *Client) SystemThingModelSearch(modelType string, ModelName string) (interface{}, error) { + return systemThingModelSearch(c, modelType, ModelName) +} + +func (c *Client) AddMqttAuthInfo(auth models.MqttAuth) (string, error) { + if len(auth.Id) == 0 { + auth.Id = utils.RandomNum() + } + return addMqttAuth(c, auth) +} + +func (c *Client) AddOrUpdateAuth(auth models.MqttAuth) error { + if len(auth.Id) == 0 { + auth.Id = utils.RandomNum() + } + return addOrUpdateAuth(c, auth) +} + +func (c *Client) AddAlertRule(alertRule models.AlertRule) (models.AlertRule, error) { + if len(alertRule.Id) == 0 { + alertRule.Id = utils.RandomNum() + } + return addAlertRule(c, alertRule) +} + +func (c *Client) AddAlertList(alertRule models.AlertList) (models.AlertList, error) { + if len(alertRule.Id) == 0 { + alertRule.Id = utils.RandomNum() + } + return addAlertList(c, alertRule) +} + +func (c *Client) UpdateAlertRule(rule models.AlertRule) error { + return updateAlertRule(c, rule) +} + +func (c *Client) AlertRuleById(id string) (models.AlertRule, error) { + return alertRuleById(c, id) +} + +func (c *Client) AlertRuleSearch(offset int, limit int, req dtos.AlertRuleSearchQueryRequest) (alertRules []models.AlertRule, total uint32, edgeXErr error) { + alertRules, total, edgeXErr = alertRuleSearch(c, offset, limit, req) + if edgeXErr != nil { + return alertRules, 0, edgeXErr + } + return alertRules, total, nil +} + +func (c *Client) AlertListSearch(offset int, limit int, req dtos.AlertSearchQueryRequest) (alertList []dtos.AlertSearchQueryResponse, total uint32, edgeXErr error) { + alertList, total, edgeXErr = alertListSearch(c, offset, limit, req) + if edgeXErr != nil { + return alertList, 0, edgeXErr + } + return alertList, total, nil +} + +func (c *Client) AlertIgnore(id string) (edgeXErr error) { + return alertIgnore(c, id) +} + +func (c *Client) TreatedIgnore(id, message string) (edgeXErr error) { + return treatedIgnore(c, id, message) +} + +func (c *Client) AlertListLastSend(alertRuleId string) (alertList models.AlertList, edgeXErr error) { + return alertListLastSend(c, alertRuleId) +} + +func (c *Client) DeleteAlertRuleById(id string) error { + return deleteAlertRuleById(c, id) +} + +func (c *Client) AlertRuleStart(id string) error { + return alertRuleStart(c, id) +} + +func (c *Client) AlertRuleStop(id string) error { + return alertRuleStop(c, id) +} + +func (c *Client) AlertPlate(beforeTime int64) (plate []dtos.AlertPlateQueryResponse, err error) { + return alertPlate(c, beforeTime) +} + +func (c *Client) QuickNavigationSearch(offset int, limit int, req dtos.QuickNavigationSearchQueryRequest) (quickNavigations []models.QuickNavigation, total uint32, edgeXErr error) { + quickNavigations, total, edgeXErr = quickNavigationSearch(c, offset, limit, req) + if edgeXErr != nil { + return quickNavigations, 0, edgeXErr + } + return quickNavigations, total, nil +} + +func (c *Client) DocsSearch(offset int, limit int, req dtos.DocsSearchQueryRequest) (docs []models.Doc, total uint32, edgeXErr error) { + docs, total, edgeXErr = docsSearch(c, offset, limit, req) + if edgeXErr != nil { + return docs, 0, edgeXErr + } + return docs, total, nil +} + +func (c *Client) BatchUpsertDocsTemplate(ds []models.Doc) (int64, error) { + return batchUpsertDocsTemplate(c, ds) +} + +func (c *Client) BatchUpsertQuickNavigationTemplate(ds []models.QuickNavigation) (int64, error) { + return batchUpsertQuickNavigationTemplate(c, ds) +} + +func (c *Client) DeleteQuickNavigation(id string) error { + return deleteQuickNavigation(c, id) +} + +func (c *Client) GetAdvanceConfig() (models.AdvanceConfig, error) { + return getAdvanceConfig(c) +} + +func (c *Client) UpdateAdvanceConfig(config models.AdvanceConfig) error { + return updateAdvanceConfig(c, config) +} + +func (c *Client) AddMsgGather(msgGather models.MsgGather) error { + if len(msgGather.Id) == 0 { + msgGather.Id = utils.RandomNum() + } + return addMsgGather(c, msgGather) +} + +func (c *Client) MsgGatherSearch(offset int, limit int, req dtos.MsgGatherSearchQueryRequest) (dcs []models.MsgGather, count uint32, edgeXErr error) { + return msgGatherSearch(c, offset, limit, req) +} + +func (c *Client) AddDataResource(dateResource models.DataResource) (string, error) { + if len(dateResource.Id) == 0 { + dateResource.Id = utils.RandomNum() + } + return addDataResource(c, dateResource) +} + +func (c *Client) UpdateDataResource(dateResource models.DataResource) error { + return updateDataResource(c, dateResource) +} + +func (c *Client) DelDataResource(id string) error { + return deleteDataResourceById(c, id) +} + +func (c *Client) UpdateDataResourceHealth(id string, health bool) error { + return updateDataResourceHealth(c, id, health) +} + +func (c *Client) SearchDataResource(offset int, limit int, req dtos.DataResourceSearchQueryRequest) (dataResource []models.DataResource, count uint32, edgeXErr error) { + return dataResourceSearch(c, offset, limit, req) +} + +func (c *Client) DataResourceById(id string) (models.DataResource, error) { + return dataResourceById(c, id) +} + +func (c *Client) AddRuleEngine(ruleEngine models.RuleEngine) (string, error) { + if len(ruleEngine.Id) == 0 { + ruleEngine.Id = utils.RandomNum() + } + return addRuleEngine(c, ruleEngine) +} + +func (c *Client) UpdateRuleEngine(ruleEngine models.RuleEngine) error { + return updateRuleEngine(c, ruleEngine) +} + +func (c *Client) RuleEngineById(id string) (ruleEngine models.RuleEngine, edgeXErr error) { + return ruleEngineById(c, id) +} + +func (c *Client) RuleEngineSearch(offset int, limit int, req dtos.RuleEngineSearchQueryRequest) (ruleEngine []models.RuleEngine, count uint32, edgeXErr error) { + return ruleEngineSearch(c, offset, limit, req) +} + +func (c *Client) RuleEngineStart(id string) error { + return ruleEngineStart(c, id) +} + +func (c *Client) RuleEngineStop(id string) error { + return ruleEngineStop(c, id) +} + +func (c *Client) DeleteRuleEngineById(id string) error { + return deleteRuleEngineById(c, id) +} + +func (c *Client) AddScene(scene models.Scene) (models.Scene, error) { + if len(scene.Id) == 0 { + scene.Id = utils.RandomNum() + } + return addScene(c, scene) +} + +func (c *Client) UpdateScene(scene models.Scene) error { + if len(scene.Id) == 0 { + scene.Id = utils.RandomNum() + } + return updateScene(c, scene) +} +func (c *Client) SceneById(id string) (models.Scene, error) { + return sceneById(c, id) +} + +func (c *Client) SceneStart(id string) error { + return sceneStart(c, id) +} + +func (c *Client) SceneStop(id string) error { + return sceneStop(c, id) +} + +func (c *Client) DeleteSceneById(id string) error { + return deleteSceneById(c, id) +} + +func (c *Client) SceneSearch(offset int, limit int, req dtos.SceneSearchQueryRequest) (scenes []models.Scene, total uint32, edgeXErr error) { + return sceneSearch(c, offset, limit, req) +} + +func (c *Client) AddSceneLog(sceneLog models.SceneLog) (models.SceneLog, error) { + if len(sceneLog.Id) == 0 { + sceneLog.Id = utils.RandomNum() + } + return addSceneLog(c, sceneLog) +} + +func (c *Client) SceneLogSearch(offset int, limit int, req dtos.SceneLogSearchQueryRequest) (sceneLogs []models.SceneLog, total uint32, edgeXErr error) { + return sceneLogSearch(c, offset, limit, req) +} + +func (c *Client) LanguageSdkByName(name string) (cloudService models.LanguageSdk, edgeXErr error) { + return languageByName(c, name) +} + +func (c *Client) LanguageSearch(offset int, limit int, req dtos.LanguageSDKSearchQueryRequest) (languages []models.LanguageSdk, count uint32, edgeXErr error) { + return languageSearch(c, offset, limit, req) +} + +func (c *Client) AddLanguageSdk(ls models.LanguageSdk) (language models.LanguageSdk, edgeXErr error) { + if len(ls.Id) == 0 { + ls.Id = utils.RandomNum() + } + return addLanguageSdk(c, ls) +} + +func (c *Client) UpdateLanguageSdk(ls models.LanguageSdk) error { + return updateLanguageSdk(c, ls) +} + +func (c *Client) UpdateSystemMetrics(metrics dtos.SystemMetrics) error { + return updateSystemMetrics(c, metrics) +} + +func (c *Client) GetSystemMetrics(start, end int64) ([]dtos.SystemMetrics, error) { + return getSystemMetrics(c, start, end) +} + +func (c *Client) RemoveRangeSystemMetrics(min, max string) error { + return removeRangeSystemMetrics(c, min, max) +} diff --git a/internal/hummingbird/core/infrastructure/mysql/config.go b/internal/hummingbird/core/infrastructure/mysql/config.go new file mode 100644 index 0000000..463f253 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/config.go @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "gorm.io/gorm" +) + +func updateAdvanceConfig(c *Client, config models.AdvanceConfig) error { + if err := c.client.UpdateObject(&config); err != nil { + return errort.NewCommonErr(errort.DefaultSystemError, err) + } + return nil +} + +func getAdvanceConfig(c *Client) (models.AdvanceConfig, error) { + var config models.AdvanceConfig + err := c.client.GetObject(&models.AdvanceConfig{ID: constants.DefaultAdvanceConfigID}, &config) + if err != nil { + if err == gorm.ErrRecordNotFound { + config.ID = constants.DefaultAdvanceConfigID + if err = c.client.CreateObject(&config); err != nil { + return models.AdvanceConfig{}, errort.NewCommonErr(errort.DefaultSystemError, err) + } + return config, nil + } + return models.AdvanceConfig{}, errort.NewCommonErr(errort.DefaultSystemError, err) + } + return config, nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/dataresource.go b/internal/hummingbird/core/infrastructure/mysql/dataresource.go new file mode 100644 index 0000000..534674b --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/dataresource.go @@ -0,0 +1,117 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func dataResourceById(c *Client, id string) (dateResource models.DataResource, edgeXErr error) { + if id == "" { + return dateResource, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "dateResource id is empty", nil) + } + err := c.Pool.Table(dateResource.TableName()).First(&dateResource, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return dateResource, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("dateResource id(%s) not found", id)) + } + return dateResource, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query dateResource fail (Id:%s), %s", dateResource.Id, err)) + } + return +} + +func addDataResource(c *Client, ds models.DataResource) (id string, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "data resourced creation failed", err) + } + + return ds.Id, edgeXErr +} + +func updateDataResource(c *Client, dl models.DataResource) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "data resource update failed", err) + } + return nil +} + +func deleteDataResourceById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "id is empty", nil) + } + err := c.client.DeleteObject(&models.DataResource{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "data resourced deletion failed", err) + } + return nil +} + +func updateDataResourceHealth(c *Client, id string, health bool) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "id is empty", nil) + } + d := models.DataResource{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"health": health}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "update data resource failed", err) + } + return nil +} + +func dataResourceSearch(c *Client, offset int, limit int, req dtos.DataResourceSearchQueryRequest) (dataResource []models.DataResource, count uint32, edgeXErr error) { + dl := models.DataResource{} + var total int64 + tx := c.Pool.Table(dl.TableName()) + tx = sqlite.BuildCommonCondition(tx, dl, req.BaseSearchConditionQuery) + // 特殊条件 + if req.Type != "" { + tx = tx.Where("`type` = ?", req.Type) + } + if req.Health != "" { + isHealth := true + if req.Health == SearchReqBoolTrue { + isHealth = true + } else { + isHealth = false + } + tx = tx.Where("`health` = ?", isHealth) + } + err := tx.Count(&total).Error + if err != nil { + return []models.DataResource{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "data resource failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&dataResource).Error + if err != nil { + return []models.DataResource{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "data resource failed query from the database", err) + } + + return dataResource, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/device.go b/internal/hummingbird/core/infrastructure/mysql/device.go new file mode 100644 index 0000000..e1e26da --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/device.go @@ -0,0 +1,307 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func deviceById(c *Client, id string) (device models.Device, edgeXErr error) { + if id == "" { + return device, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + //err := c.client.GetObject(&models.Device{Id: id}, &device) + err := c.Pool.Table(device.TableName()).Preload("Product").First(&device, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return device, errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("device id(%s) not found", id)) + } + return device, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query device fail (Id:%s), %s", device.Id, err)) + } + return +} + +func deviceOnlineById(c *Client, id string) (edgeXErr error) { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + edgeXErr = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.DeviceStatusOnline, "last_online_time": utils.MakeTimestamp()}).Error + if edgeXErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceOnlineById failed", edgeXErr) + } + return nil +} + +func deviceOfflineById(c *Client, id string) (edgeXErr error) { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + edgeXErr = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.DeviceStatusOffline}).Error + if edgeXErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceOnlineById failed", edgeXErr) + } + return nil +} + +func deviceOfflineByCloudInstanceId(c *Client, id string) (edgeXErr error) { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + edgeXErr = tx.Where("cloud_instance_id = ?", id).Updates(map[string]interface{}{"status": constants.DeviceStatusOffline}).Error + if edgeXErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceOnlineById failed", edgeXErr) + } + return nil +} + +func msgReportDeviceById(c *Client, id string) (device models.Device, edgeXErr error) { + if id == "" { + return device, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + //err := c.client.GetObject(&models.Device{Id: id}, &device) + err := c.Pool.Table(device.TableName()).Preload("Product").First(&device, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return device, errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("device id(%s) not found", id)) + } + return device, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query device fail (Id:%s), %s", device.Id, err)) + } + return +} + +func deviceByCloudId(c *Client, id string) (device models.Device, edgeXErr error) { + if id == "" { + return device, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device cloudId is empty", nil) + } + err := c.client.GetObject(&models.Device{CloudDeviceId: id}, &device) + if err != nil { + if err == gorm.ErrRecordNotFound { + return device, errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("device cloudId (%s) not found", id)) + } + return device, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query device fail (cloudId:%s), %s", device.Id, err)) + } + return +} + +func devicesSearch(c *Client, offset int, limit int, req dtos.DeviceSearchQueryRequest) (devices []models.Device, count uint32, edgeXErr error) { + dp := models.Device{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + + if req.Platform != "" { + tx = tx.Where("`platform` = ?", req.Platform) + } + + if req.ProductId != "" { + tx = tx.Where("`product_id` = ?", req.ProductId) + } + + if req.CloudProductId != "" { + tx = tx.Where("`cloud_product_id` = ?", req.CloudProductId) + + } + if req.CloudInstanceId != "" { + tx = tx.Where("`cloud_instance_id` = ?", req.CloudInstanceId) + } + + if req.DriveInstanceId != "" { + tx = tx.Where("`drive_instance_id` = ?", req.DriveInstanceId) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Device{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "devices failed query from the database", err) + } + + err = tx.Offset(offset).Preload("Product").Limit(limit).Find(&devices).Error + if err != nil { + return []models.Device{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "devices failed query from the database", err) + } + + return devices, uint32(total), nil +} + +func batchUpsertDevice(c *Client, d []models.Device) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(&d, 10000) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchDeleteDevice(c *Client, ids []string) error { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, ids).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteDevice failed", err) + } + return nil +} + +func batchUnBindDevice(c *Client, ids []string) error { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id IN ?", ids).Updates(map[string]interface{}{"drive_instance_id": ""}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteDevice failed", err) + } + return nil +} + +func batchBindDevice(c *Client, ids []string, driverInstanceId string) error { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id IN ?", ids).Updates(map[string]interface{}{"drive_instance_id": driverInstanceId}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteDevice failed", err) + } + return nil +} + +func deleteDeviceById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + err := c.client.DeleteObject(&models.Device{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device deletion failed", err) + } + return nil +} + +func deleteDeviceByCloudInstanceId(c *Client, cloudInstanceId string) error { + if cloudInstanceId == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "cloudInstanceId id is empty", nil) + } + err := c.client.DeleteObject(&models.Device{CloudInstanceId: cloudInstanceId}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device deletion failed", err) + } + return nil +} + +func updateDevice(c *Client, dl models.Device) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device update failed", err) + } + return nil +} + +func addDevice(c *Client, device models.Device) (string, error) { + exists, edgeXErr := deviceNameExist(c, device.Name) + if edgeXErr != nil { + return "", edgeXErr + } else if exists { + return "", errort.NewCommonEdgeX(errort.DefaultNameRepeat, fmt.Sprintf("device name %s exists", device.Name), edgeXErr) + } + exists, edgeXErr = productIdExist(c, device.ProductId) + if edgeXErr != nil { + return "", edgeXErr + } else if !exists { + return "", errort.NewCommonEdgeX(errort.DeviceProductIdNotFound, fmt.Sprintf("device product %s not exists", device.ProductId), edgeXErr) + } + ts := utils.MakeTimestamp() + if device.Created == 0 { + device.Created = ts + } + device.Modified = ts + + err := c.client.CreateObject(&device) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "device creation failed", err) + } + + return device.Id, edgeXErr +} + +func addOrUpdateAuth(c *Client, auth models.MqttAuth) error { + exists, err := c.client.ExistObject(&models.MqttAuth{ClientId: auth.ClientId}) + if err != nil { + return err + } + if !exists { + ts := utils.MakeTimestamp() + if auth.Created == 0 { + auth.Created = ts + } + auth.Modified = ts + err = c.client.CreateObject(&auth) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "mqtt auch creation failed", err) + } + } + return nil +} + +func addMqttAuth(c *Client, auth models.MqttAuth) (string, error) { + var edgeXErr error + ts := utils.MakeTimestamp() + if auth.Created == 0 { + auth.Created = ts + } + auth.Modified = ts + + err := c.client.CreateObject(&auth) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "mqtt auch creation failed", err) + } + return auth.Id, edgeXErr +} + +func deviceNameExist(c *Client, name string) (bool, error) { + exists, err := c.client.ExistObject(&models.Product{Name: name}) + if err != nil { + return false, err + } + return exists, nil +} + +func deviceMqttAuthInfo(c *Client, id string) (mqttAuth models.MqttAuth, edgeXErr error) { + if id == "" { + return mqttAuth, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + err := c.client.GetObject(&models.MqttAuth{ResourceId: id, ResourceType: constants.DeviceResource}, &mqttAuth) + if err != nil { + if err == gorm.ErrRecordNotFound { + return mqttAuth, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("mqtt auth resoure id(%s) not found", id)) + } + return mqttAuth, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query mqtt auth fail (resoureId:%s), %s", id, err)) + } + return +} diff --git a/internal/hummingbird/core/infrastructure/mysql/devicelibrary.go b/internal/hummingbird/core/infrastructure/mysql/devicelibrary.go new file mode 100644 index 0000000..733b5c4 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/devicelibrary.go @@ -0,0 +1,157 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func deviceLibraryById(c *Client, id string) (deviceLibrary models.DeviceLibrary, edgeXErr error) { + if id == "" { + return deviceLibrary, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "deviceLibrary id is empty", nil) + } + err := c.client.GetObject(&models.DeviceLibrary{Id: id}, &deviceLibrary) + if err != nil { + if err == gorm.ErrRecordNotFound { + return deviceLibrary, errort.NewCommonErr(errort.DeviceLibraryNotExist, fmt.Errorf("device library id(%s) not found", id)) + } + return deviceLibrary, err + } + return +} + +func deleteDeviceLibraryById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "id is empty", nil) + } + err := c.client.DeleteObject(&models.DeviceLibrary{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device library deletion failed", err) + } + return nil +} + +func deviceLibraryIdExists(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.DeviceLibrary{Id: id}) + if err != nil { + return false, err + } + return exists, nil +} + +func addDeviceLibrary(c *Client, dl models.DeviceLibrary) (models.DeviceLibrary, error) { + // query device library name and id to avoid the conflict + exists, edgeXErr := deviceLibraryIdExists(c, dl.Id) + if edgeXErr != nil { + return dl, edgeXErr + } else if exists { + return dl, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("device library id %s exists", dl.Id), edgeXErr) + } + + // check docker config id exists + exists, edgeXErr = dockerConfigIdExists(c, dl.DockerConfigId) + if edgeXErr != nil { + return dl, edgeXErr + } else if !exists { + return dl, errort.NewCommonEdgeX(errort.DockerImageRepositoryNotFound, fmt.Sprintf("docker config id %s not exists", dl.Id), edgeXErr) + } + + ts := utils.MakeTimestamp() + if dl.Created == 0 { + dl.Created = ts + } + dl.Modified = ts + + err := c.client.CreateObject(&dl) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "device library creation failed", err) + } + + return dl, edgeXErr +} + +const ( + // url中请求参数有判断真假的 + SearchReqBoolTrue = "true" + SearchReqBoolFalse = "false" +) + +func deviceLibrariesSearch(c *Client, offset int, limit int, req dtos.DeviceLibrarySearchQueryRequest) (deviceLibraries []models.DeviceLibrary, count uint32, edgeXErr error) { + dl := models.DeviceLibrary{} + var total int64 + tx := c.Pool.Table(dl.TableName()) + tx = sqlite.BuildCommonCondition(tx, dl, req.BaseSearchConditionQuery) + // 特殊条件 + if req.DockerConfigId != "" { + tx = tx.Where("`docker_config_id` = ?", req.DockerConfigId) + } + if req.IsInternal != "" { + isInternal := true + if req.IsInternal == SearchReqBoolTrue { + isInternal = true + } else { + isInternal = false + } + tx = tx.Where("`is_internal` = ?", isInternal) + } + if req.DockerRepoName != "" { + tx = tx.Where("`docker_repo_name` = ?", req.DockerRepoName) + } + if req.NameAliasLike != "" { + tx = tx.Where("`name` LIKE ? OR `alias` LIKE ? OR `description` LIKE ?", sqlite.MakeLikeParams(req.NameAliasLike), sqlite.MakeLikeParams(req.NameAliasLike), sqlite.MakeLikeParams(req.NameAliasLike)) + } + if req.NoInIds != "" { + tx = tx.Where("`id` NOT IN ?", dtos.ApiParamsStringToArray(req.NoInIds)) + } + if req.ImageIds != "" { + tx = tx.Where("`docker_image_id` IN ?", dtos.ApiParamsStringToArray(req.ImageIds)) + } + if req.NoInImageIds != "" { + tx = tx.Where("`docker_image_id` NOT IN ?", dtos.ApiParamsStringToArray(req.NoInImageIds)) + } + if req.DriverType != 0 { + tx = tx.Where("`driver_type` = ?", req.DriverType) + } + if req.ClassifyId != 0 { + tx = tx.Where("`classify_id` = ?", req.ClassifyId) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.DeviceLibrary{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceLibraries failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Order("id desc").Find(&deviceLibraries).Error + if err != nil { + return []models.DeviceLibrary{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceLibraries failed query from the database", err) + } + + return deviceLibraries, uint32(total), nil +} + +func updateDeviceLibrary(c *Client, dl models.DeviceLibrary) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return err + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/deviceservice.go b/internal/hummingbird/core/infrastructure/mysql/deviceservice.go new file mode 100644 index 0000000..5ef1f5a --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/deviceservice.go @@ -0,0 +1,149 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func deviceServicesSearch(c *Client, offset int, limit int, req dtos.DeviceServiceSearchQueryRequest) (deviceServices []models.DeviceService, count uint32, edgeXErr error) { + ds := models.DeviceService{} + var total int64 + tx := c.Pool.Table(ds.TableName()) + tx = sqlite.BuildCommonCondition(tx, ds, req.BaseSearchConditionQuery) + + if req.DeviceLibraryId != "" { + tx = tx.Where("`device_library_id` = ?", req.DeviceLibraryId) + } + if req.DeviceLibraryIds != "" { + tx = tx.Where("`device_library_id` IN ?", dtos.ApiParamsStringToArray(req.DeviceLibraryIds)) + } + if req.DriverType != 0 { + tx = tx.Where("`driver_type` = ?", req.DriverType) + } + if req.CloudProductId != "" { + tx = tx.Where("`cloud_product_id` = ?", req.CloudProductId) + } + if req.ProductId != "" { + tx = tx.Where("`product_id` = ?", req.ProductId) + } + if req.Platform != "" { + tx = tx.Where("`platform` = ?", req.Platform) + + } + err := tx.Count(&total).Error + if err != nil { + return []models.DeviceService{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceServices failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&deviceServices).Error + if err != nil { + return []models.DeviceService{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceServices failed query from the database", err) + } + + return deviceServices, uint32(total), nil +} + +func deviceServiceIdExist(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.DeviceService{Id: id}) + if err != nil { + return false, err + } + return exists, nil +} + +func addDeviceService(c *Client, ds models.DeviceService) (addedDeviceService models.DeviceService, edgeXErr error) { + exists, edgeXErr := deviceServiceIdExist(c, ds.Id) + if edgeXErr != nil { + return ds, edgeXErr + } else if exists { + return ds, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("device service id %s exists", ds.Id), edgeXErr) + } + + exists, edgeXErr = deviceServiceIdExist(c, ds.Name) + if edgeXErr != nil { + return ds, edgeXErr + } else if exists { + return ds, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("device service name %s exists", ds.Name), edgeXErr) + } + + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "device service creation failed", err) + } + + return ds, edgeXErr +} + +func updateDeviceService(c *Client, ds models.DeviceService) error { + ds.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&ds) + if err != nil { + return err + } + return nil +} + +func deviceServiceById(c *Client, id string) (deviceService models.DeviceService, edgeXErr error) { + if id == "" { + return deviceService, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device service id is empty", nil) + } + err := c.client.GetObject(&models.DeviceService{Id: id}, &deviceService) + if err != nil { + if err == gorm.ErrRecordNotFound { + return deviceService, errort.NewCommonErr(errort.DeviceServiceNotExist, fmt.Errorf("device service id(%s) not found", id)) + } + return deviceService, err + } + return +} + +func deleteDeviceServiceById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device service id is empty", nil) + } + err := c.client.DeleteObject(&models.DeviceService{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device service deletion failed", err) + } + return nil +} + +func driverMqttAuthInfo(c *Client, id string) (mqttAuth models.MqttAuth, edgeXErr error) { + if id == "" { + return mqttAuth, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + err := c.client.GetObject(&models.MqttAuth{ResourceId: id, ResourceType: constants.DriverResource}, &mqttAuth) + if err != nil { + if err == gorm.ErrRecordNotFound { + return mqttAuth, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("mqtt auth resoure id(%s) not found", id)) + } + return mqttAuth, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query mqtt auth fail (resoureId:%s), %s", id, err)) + } + return +} diff --git a/internal/hummingbird/core/infrastructure/mysql/dockerconfig.go b/internal/hummingbird/core/infrastructure/mysql/dockerconfig.go new file mode 100644 index 0000000..88df6c1 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/dockerconfig.go @@ -0,0 +1,107 @@ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + + "gorm.io/gorm" + //"gitlab.com/tedge/edgex/internal/dtos" + //"gitlab.com/tedge/edgex/internal/models" + //"gitlab.com/tedge/edgex/internal/pkg/errort" + //"gitlab.com/tedge/edgex/internal/pkg/utils" + //"gitlab.com/tedge/edgex/internal/tools/sqldb/sqlite" +) + +func dockerConfigIdExists(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.DockerConfig{ + Id: id, + }) + if err != nil { + return false, err + } + return exists, nil +} + +func dockerConfigAdd(c *Client, dc models.DockerConfig) (models.DockerConfig, error) { + exists, edgeXErr := dockerConfigIdExists(c, dc.Id) + if edgeXErr != nil { + return dc, edgeXErr + } else if exists { + return dc, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("docker config %s exists", dc.Id), edgeXErr) + } + ts := utils.MakeTimestamp() + if dc.Created == 0 { + dc.Created = ts + } + dc.Modified = ts + + err := c.client.CreateObject(&dc) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "func point creation failed", err) + return dc, edgeXErr + } + + return dc, edgeXErr +} + +func dockerConfigById(c *Client, id string) (dc models.DockerConfig, edgeXErr error) { + if id == "" { + return dc, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "docker config id is empty", nil) + } + err := c.client.GetObject(&models.DockerConfig{Id: id}, &dc) + if err != nil { + if err == gorm.ErrRecordNotFound { + return dc, errort.NewCommonErr(errort.DockerConfigNotExist, fmt.Errorf("docker config id(%s)not found", id)) + } + return dc, err + } + return +} + +func dockerConfigDeleteById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "docker config id is empty", nil) + } + rawErr := c.client.DeleteObject(&models.DockerConfig{Id: id}) + if rawErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "docker config deletion failed", rawErr) + } + return nil +} + +func dockerConfigUpdate(c *Client, dc models.DockerConfig) error { + dc.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dc) + if err != nil { + return err + } + return nil +} + +func dockerConfigsSearch(c *Client, offset int, limit int, req dtos.DockerConfigSearchQueryRequest) (dcs []models.DockerConfig, count uint32, edgeXErr error) { + d := models.DockerConfig{} + var total int64 + tx := c.Pool.Table(d.TableName()) + tx = sqlite.BuildCommonCondition(tx, d, req.BaseSearchConditionQuery) + if req.Address != "" { + tx = tx.Where("`address` = ?", req.Address) + } + if req.Account != "" { + tx = tx.Where("`account` = ?", req.Account) + } + err := tx.Count(&total).Error + if err != nil { + return []models.DockerConfig{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&dcs).Error + if err != nil { + return []models.DockerConfig{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + + return dcs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/docs.go b/internal/hummingbird/core/infrastructure/mysql/docs.go new file mode 100644 index 0000000..60ed1af --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/docs.go @@ -0,0 +1,60 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func docsSearch(c *Client, offset int, limit int, req dtos.DocsSearchQueryRequest) (docs []models.Doc, count uint32, edgeXErr error) { + dp := models.Doc{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Doc{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "docs failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&docs).Error + if err != nil { + return []models.Doc{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "docs failed query from the database", err) + } + + return docs, uint32(total), nil +} + +func batchUpsertDocsTemplate(c *Client, d []models.Doc) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/driverclassify.go b/internal/hummingbird/core/infrastructure/mysql/driverclassify.go new file mode 100644 index 0000000..d1e64c5 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/driverclassify.go @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" +) + +func driverClassifySearch(c *Client, offset int, limit int, req dtos.DriverClassifyQueryRequest) (dcs []models.DriverClassify, count uint32, edgeXErr error) { + d := models.DriverClassify{} + var total int64 + tx := c.Pool.Table(d.TableName()) + tx = sqlite.BuildCommonCondition(tx, d, req.BaseSearchConditionQuery) + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + err := tx.Count(&total).Error + if err != nil { + return []models.DriverClassify{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + err = tx.Offset(offset).Limit(limit).Find(&dcs).Error + if err != nil { + return []models.DriverClassify{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + return dcs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/language.go b/internal/hummingbird/core/infrastructure/mysql/language.go new file mode 100644 index 0000000..d3d25be --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/language.go @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func languageSearch(c *Client, offset int, limit int, req dtos.LanguageSDKSearchQueryRequest) (languages []models.LanguageSdk, count uint32, edgeXErr error) { + cs := models.LanguageSdk{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + err := tx.Count(&total).Error + if err != nil { + return []models.LanguageSdk{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "language sdk failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&languages).Error + if err != nil { + return []models.LanguageSdk{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "language sdk failed query from the database", err) + } + + return languages, uint32(total), nil +} + +func languageByName(c *Client, name string) (language models.LanguageSdk, edgeXErr error) { + if name == "" { + return language, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "language sdk name id is empty", nil) + } + err := c.client.GetObject(&models.LanguageSdk{Name: name}, &language) + if err != nil { + if err == gorm.ErrRecordNotFound { + return language, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("language sdk (%s) not found", name)) + } + return language, err + } + return +} + +func addLanguageSdk(c *Client, cs models.LanguageSdk) (language models.LanguageSdk, edgeXErr error) { + ts := utils.MakeTimestamp() + if cs.Created == 0 { + cs.Created = ts + } + cs.Modified = ts + + err := c.client.CreateObject(&cs) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "language creation failed", err) + } + + return cs, edgeXErr +} + +func updateLanguageSdk(c *Client, dl models.LanguageSdk) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return err + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/monitor.go b/internal/hummingbird/core/infrastructure/mysql/monitor.go new file mode 100644 index 0000000..43d26e3 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/monitor.go @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" +) + +func updateSystemMetrics(c *Client, metrics dtos.SystemMetrics) error { + var m = models.SystemMetrics{ + Data: metrics.String(), + Timestamp: metrics.Timestamp, + } + return c.client.CreateObject(&m) +} + +func getSystemMetrics(c *Client, start, end int64) ([]dtos.SystemMetrics, error) { + var list []models.SystemMetrics + if err := c.Pool.Where("timestamp >= ? and timestamp <= ?", start, end).Find(&list).Error; err != nil { + return nil, err + } + var metrics = make([]dtos.SystemMetrics, 0) + for _, item := range list { + m, err := dtos.FromModelsSystemMetricsToDTO(item) + if err != nil { + return nil, err + } + metrics = append(metrics, m) + } + return metrics, nil +} + +func removeRangeSystemMetrics(c *Client, min, max string) error { + return c.Pool.Where("timestamp >= ? and timestamp <= ?", min, max).Delete(&models.SystemMetrics{}).Error +} diff --git a/internal/hummingbird/core/infrastructure/mysql/msggather.go b/internal/hummingbird/core/infrastructure/mysql/msggather.go new file mode 100644 index 0000000..fe3cd32 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/msggather.go @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" +) + +func addMsgGather(c *Client, msgGather models.MsgGather) error { + ts := utils.MakeTimestamp() + if msgGather.Created == 0 { + msgGather.Created = ts + } + msgGather.Modified = ts + + err := c.client.CreateObject(&msgGather) + if err != nil { + edgeXErr := errort.NewCommonEdgeX(errort.DefaultSystemError, "add msg gather failed", err) + return edgeXErr + } + + return nil +} + +func msgGatherSearch(c *Client, offset int, limit int, req dtos.MsgGatherSearchQueryRequest) (dcs []models.MsgGather, count uint32, edgeXErr error) { + d := models.MsgGather{} + var total int64 + tx := c.Pool.Table(d.TableName()) + tx = sqlite.BuildCommonCondition(tx, d, req.BaseSearchConditionQuery) + + if len(req.Date) > 0 { + tx = tx.Where("`date` in (?)", req.Date) + } + err := tx.Count(&total).Error + if err != nil { + return []models.MsgGather{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "msg gather failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&dcs).Error + if err != nil { + return []models.MsgGather{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "msg gather failed query from the database", err) + } + + return dcs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/product.go b/internal/hummingbird/core/infrastructure/mysql/product.go new file mode 100644 index 0000000..d5bb1b6 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/product.go @@ -0,0 +1,317 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func productNameExist(c *Client, name string) (bool, error) { + exists, err := c.client.ExistObject(&models.Product{Name: name}) + if err != nil { + return false, err + } + return exists, nil +} + +func productIdExist(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.Product{Id: id}) + if err != nil { + return false, err + } + return exists, nil +} + +func addProduct(c *Client, ds models.Product) (product models.Product, edgeXErr error) { + exists, edgeXErr := productNameExist(c, ds.Name) + if edgeXErr != nil { + return ds, edgeXErr + } else if exists { + return ds, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("product name %s exists", ds.Id), edgeXErr) + } + + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "product creation failed", err) + } + + return ds, edgeXErr +} + +func productById(c *Client, id string) (product models.Product, edgeXErr error) { + if id == "" { + return product, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "product id is empty", nil) + } + err := c.client.GetPreloadObject(&models.Product{Id: id}, &product) + if err != nil { + if err == gorm.ErrRecordNotFound { + return product, errort.NewCommonErr(errort.ProductNotExist, fmt.Errorf("product id(%s) not found", id)) + } + return product, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query product fail (Id:%s), %s", product.Id, err)) + } + return +} + +func productByCloudId(c *Client, id string) (product models.Product, edgeXErr error) { + if id == "" { + return product, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "product id is empty", nil) + } + err := c.client.GetPreloadObject(&models.Product{CloudProductId: id}, &product) + if err != nil { + if err == gorm.ErrRecordNotFound { + return product, errort.NewCommonErr(errort.ProductNotExist, fmt.Errorf("product id(%s) not found", id)) + } + return product, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query product fail (Id:%s), %s", product.Id, err)) + } + //_ = c.cache.SetProduct(product) + return +} + +func productsSearch(c *Client, offset int, limit int, preload bool, req dtos.ProductSearchQueryRequest) (products []models.Product, count uint32, edgeXErr error) { + dp := models.Product{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + + if req.Platform != "" { + tx = tx.Where("`platform` = ?", req.Platform) + } + + if req.CloudInstanceId != "" { + tx = tx.Where("`cloud_instance_id` = ?", req.CloudInstanceId) + } + + if req.ProductId != "" { + tx = tx.Where("`product_id` = ?", req.ProductId) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Product{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "products failed query from the database", err) + } + if preload { + err = tx.Offset(offset).Preload("Properties").Preload("Events").Preload("Actions").Limit(limit).Find(&products).Error + } else { + err = tx.Offset(offset).Limit(limit).Find(&products).Error + } + if err != nil { + return []models.Product{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "products failed query from the database", err) + } + + return products, uint32(total), nil +} + +func batchUpsertProduct(c *Client, d []models.Product) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchSaveProduct(c *Client, d []models.Product) error { + if len(d) <= 0 { + return nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).Save(d) + //num := tx.RowsAffected + err := tx.Error + if err != nil { + return err + } + return nil +} + +func batchDeleteProduct(c *Client, products []models.Product) error { + d := models.Product{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(&products).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product batchDeleteProduct failed", err) + } + return nil +} + +func deleteProductById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "product id is empty", nil) + } + err := c.client.DeleteObject(&models.Product{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product deletion failed", err) + } + return nil +} + +func deleteProductObject(c *Client, product models.Product) error { + err := c.client.DeleteObject(&product) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product deletion failed", err) + } + return nil +} + +func associationsDeleteProductObject(c *Client, product models.Product) error { + err := c.client.AssociationsDeleteObject(&product) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product deletion failed", err) + } + return nil +} + +func updateProduct(c *Client, dl models.Product) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product update failed", err) + } + return nil +} + +func associationsUpdateProduct(c *Client, dl models.Product) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.AssociationsUpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product update failed", err) + } + return nil +} + +func batchDeleteProperties(c *Client, propertiesIds []string) error { + d := models.Properties{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, propertiesIds).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteProperties failed", err) + } + return nil +} + +func batchDeleteSystemProperties(c *Client) error { + d := models.Properties{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where(models.Properties{System: true}).Delete(d).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDelete system property failed", err) + } + return nil +} + +func batchInsertSystemProperties(c *Client, p []models.Properties) (int64, error) { + if len(p) <= 0 { + return 0, nil + } + tx := c.Pool.CreateInBatches(p, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchDeleteEvents(c *Client, eventIds []string) error { + d := models.Events{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, eventIds).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteEvents failed", err) + } + return nil +} + +func batchDeleteSystemEvents(c *Client) error { + d := models.Events{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where(models.Events{System: true}).Delete(d).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batch system Actions failed", err) + } + return nil +} + +func batchInsertSystemEvents(c *Client, p []models.Events) (int64, error) { + if len(p) <= 0 { + return 0, nil + } + tx := c.Pool.CreateInBatches(p, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchDeleteActions(c *Client, actionIds []string) error { + d := models.Actions{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, actionIds).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteActions failed", err) + } + return nil +} + +func batchDeleteSystemActions(c *Client) error { + d := models.Actions{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where(models.Actions{System: true}).Delete(d).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batch Delete system Actions failed", err) + } + return nil +} + +func batchInsertSystemActions(c *Client, p []models.Actions) (int64, error) { + if len(p) <= 0 { + return 0, nil + } + tx := c.Pool.CreateInBatches(p, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/quicknavagation.go b/internal/hummingbird/core/infrastructure/mysql/quicknavagation.go new file mode 100644 index 0000000..0945a28 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/quicknavagation.go @@ -0,0 +1,71 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func quickNavigationSearch(c *Client, offset int, limit int, req dtos.QuickNavigationSearchQueryRequest) (quickNavigation []models.QuickNavigation, count uint32, edgeXErr error) { + dp := models.QuickNavigation{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.QuickNavigation{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "quick navigation failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&quickNavigation).Error + if err != nil { + return []models.QuickNavigation{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "quick navigation failed query from the database", err) + } + + return quickNavigation, uint32(total), nil +} + +func batchUpsertQuickNavigationTemplate(c *Client, d []models.QuickNavigation) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func deleteQuickNavigation(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "quick navigation id is empty", nil) + } + err := c.client.DeleteObject(&models.QuickNavigation{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "quick navigation deletion failed", err) + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/ruleengine.go b/internal/hummingbird/core/infrastructure/mysql/ruleengine.go new file mode 100644 index 0000000..612251a --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/ruleengine.go @@ -0,0 +1,120 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func addRuleEngine(c *Client, ds models.RuleEngine) (id string, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "rule engine creation failed", err) + } + + return ds.Id, edgeXErr +} + +func ruleEngineById(c *Client, id string) (ruleEngine models.RuleEngine, edgeXErr error) { + if id == "" { + return ruleEngine, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "rule engine id is empty", nil) + } + err := c.Pool.Table(ruleEngine.TableName()).Preload("DataResource").First(&ruleEngine, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return ruleEngine, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("rule engine id id(%s) not found", id)) + } + return ruleEngine, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query rule engine id fail (Id:%s), %s", ruleEngine.Id, err)) + } + return +} + +func ruleEngineSearch(c *Client, offset int, limit int, req dtos.RuleEngineSearchQueryRequest) (ruleEngine []models.RuleEngine, count uint32, edgeXErr error) { + dp := models.RuleEngine{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + + err := tx.Count(&total).Error + if err != nil { + return ruleEngine, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "rules engine failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Preload("DataResource").Find(&ruleEngine).Error + if err != nil { + return ruleEngine, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "rules engine failed query from the database", err) + } + + return ruleEngine, uint32(total), nil +} + +func ruleEngineStart(c *Client, id string) error { + d := models.RuleEngine{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStart}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start alert rule failed", err) + } + return nil +} + +func ruleEngineStop(c *Client, id string) error { + d := models.RuleEngine{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStop}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "stop alert rule failed", err) + } + return nil +} + +func deleteRuleEngineById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "rule engine id is empty", nil) + } + err := c.client.DeleteObject(&models.RuleEngine{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "rule engine deletion failed", err) + } + return nil +} + +func updateRuleEngine(c *Client, ruleEngine models.RuleEngine) error { + ruleEngine.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&ruleEngine) + if err != nil { + return err + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/scene.go b/internal/hummingbird/core/infrastructure/mysql/scene.go new file mode 100644 index 0000000..761275a --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/scene.go @@ -0,0 +1,159 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func addScene(c *Client, ds models.Scene) (scene models.Scene, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "scene creation failed", err) + } + return ds, edgeXErr +} + +func updateScene(c *Client, dl models.Scene) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "scene update failed", err) + } + return nil +} + +func sceneById(c *Client, id string) (scene models.Scene, err error) { + if id == "" { + return scene, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "scene id is empty", nil) + } + err = c.client.GetObject(&models.Scene{Id: id}, &scene) + if err != nil { + if err == gorm.ErrRecordNotFound { + return scene, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("scene id(%s) not found", id)) + } + return scene, err + } + return +} + +func sceneStart(c *Client, id string) error { + d := models.Scene{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.SceneStart}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start scene rule failed", err) + } + return nil +} + +func sceneStop(c *Client, id string) error { + d := models.Scene{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.SceneStop}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start scene rule failed", err) + } + return nil +} + +func deleteSceneById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "del scene id is empty", nil) + } + err := c.client.DeleteObject(&models.Scene{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "del scene deletion failed", err) + } + return nil +} + +func sceneSearch(c *Client, offset int, limit int, req dtos.SceneSearchQueryRequest) (scene []models.Scene, count uint32, edgeXErr error) { + dp := models.Scene{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + err := tx.Count(&total).Error + if err != nil { + return scene, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene search failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&scene).Error + if err != nil { + return scene, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene search failed query from the database", err) + } + + return scene, uint32(total), nil +} + +func addSceneLog(c *Client, ds models.SceneLog) (sceneLog models.SceneLog, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "scene log creation failed", err) + } + return ds, edgeXErr +} + +func sceneLogSearch(c *Client, offset int, limit int, req dtos.SceneLogSearchQueryRequest) (sceneLogs []models.SceneLog, count uint32, edgeXErr error) { + dp := models.SceneLog{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.StartAt > 0 && req.EndAt > 0 && req.EndAt-req.StartAt > 0 { + tx.Where("created > ?", req.StartAt).Where("created < ?", req.EndAt) + } + if req.SceneId != "" { + tx = tx.Where("`scene_id` = ?", req.SceneId) + } + + err := tx.Count(&total).Error + if err != nil { + return sceneLogs, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene log search failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&sceneLogs).Error + if err != nil { + return sceneLogs, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene log search failed query from the database", err) + } + + return sceneLogs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/thingmodel.go b/internal/hummingbird/core/infrastructure/mysql/thingmodel.go new file mode 100644 index 0000000..7817eae --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/thingmodel.go @@ -0,0 +1,230 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func addThingModelProperty(c *Client, ds models.Properties) (models.Properties, error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + var edgeXErr error + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model property creation failed", err) + } + return ds, edgeXErr + +} +func batchUpsertThingModel(c *Client, d interface{}) (int64, error) { + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func addThingModelEvent(c *Client, ds models.Events) (models.Events, error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + var edgeXErr error + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model event creation failed", err) + } + + return ds, edgeXErr + +} + +func addThingModelAction(c *Client, ds models.Actions) (models.Actions, error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + var edgeXErr error + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model action creation failed", err) + } + + return ds, edgeXErr +} + +func updateThingModelProperty(c *Client, ds models.Properties) error { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.UpdateObject(&ds) + var edgeXErr error + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model property update failed", err) + } + + return edgeXErr +} + +func updateThingModelEvent(c *Client, ds models.Events) error { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.UpdateObject(&ds) + var edgeXErr error + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model events update failed", err) + } + + return edgeXErr +} + +func updateThingModelAction(c *Client, ds models.Actions) error { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.UpdateObject(&ds) + var edgeXErr error + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model action update failed", err) + } + + return edgeXErr +} + +func deleteThingModelPropertyById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "properties id is empty", nil) + } + err := c.client.DeleteObject(&models.Properties{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "properties deletion failed", err) + } + return nil +} + +func deleteThingModelEventById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "events id is empty", nil) + } + err := c.client.DeleteObject(&models.Events{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "events deletion failed", err) + } + return nil +} + +func deleteThingModelActionById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "actions id is empty", nil) + } + err := c.client.DeleteObject(&models.Actions{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "actions deletion failed", err) + } + return nil +} + +func thingModelPropertyById(c *Client, id string) (models.Properties, error) { + cs := models.Properties{} + var properties models.Properties + tx := c.Pool.Table(cs.TableName()) + tx.Where("id = ?", id) + err := tx.Find(&properties).Error + return properties, err +} + +func thingModelEventById(c *Client, id string) (models.Events, error) { + cs := models.Events{} + var event models.Events + tx := c.Pool.Table(cs.TableName()) + tx.Where("id = ?", id) + err := tx.Find(&event).Error + return event, err +} + +func thingModeActionById(c *Client, id string) (models.Actions, error) { + cs := models.Actions{} + var action models.Actions + tx := c.Pool.Table(cs.TableName()) + tx.Where("id = ?", id) + err := tx.Find(&action).Error + return action, err +} + +func systemThingModelSearch(c *Client, modelType, modelName string) (interface{}, error) { + switch modelType { + case "property": + cs := models.Properties{} + var properties []models.Properties + tx := c.Pool.Table(cs.TableName()) + if modelName != "" { + tx.Where("system =1 and `name` LIKE ?", "%"+modelName+"%") + } else { + tx.Where("system =1") + } + err := tx.Find(&properties).Error + return properties, err + case "event": + cs := models.Events{} + var events []models.Events + tx := c.Pool.Table(cs.TableName()) + if modelName != "" { + tx.Where("system =1 and `name` LIKE ?", "%"+modelName+"%") + } else { + tx.Where("system =1") + } + err := tx.Find(&events).Error + return events, err + case "action": + cs := models.Actions{} + var actions []models.Actions + tx := c.Pool.Table(cs.TableName()) + if modelName != "" { + tx.Where("system =1 and `name` LIKE ?", "%"+modelName+"%") + } else { + tx.Where("system =1") + } + err := tx.Find(&actions).Error + return actions, err + + } + return nil, nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/thingmodeltemplate.go b/internal/hummingbird/core/infrastructure/mysql/thingmodeltemplate.go new file mode 100644 index 0000000..5862d0b --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/thingmodeltemplate.go @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func thingModelTemplateSearch(c *Client, offset int, limit int, req dtos.ThingModelTemplateRequest) (thingModelTemplate []models.ThingModelTemplate, count uint32, edgeXErr error) { + cs := models.ThingModelTemplate{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + if req.CategoryName != "" { + tx = tx.Where("`category_name` = ?", req.CategoryName) + } + + if req.CategoryKey != "" { + tx = tx.Where("`category_key` = ?", req.CategoryKey) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.ThingModelTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model template failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&thingModelTemplate).Error + if err != nil { + return []models.ThingModelTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model template failed query from the database", err) + } + + return thingModelTemplate, uint32(total), nil +} + +func thingModelTemplateByCategoryKey(c *Client, categoryKey string) (thingModelInfo models.ThingModelTemplate, edgeXErr error) { + if categoryKey == "" { + return thingModelInfo, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "thing model template category key is empty", nil) + } + err := c.client.GetObject(&models.ThingModelTemplate{CategoryKey: categoryKey}, &thingModelInfo) + if err != nil { + if err == gorm.ErrRecordNotFound { + return thingModelInfo, errort.NewCommonErr(errort.ThingModelNotExist, fmt.Errorf("thing model template category key(%s) not found", categoryKey)) + } + return thingModelInfo, err + } + return +} + +func batchUpsertThingModelTemplate(c *Client, d []models.ThingModelTemplate) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/thinkmodel.go b/internal/hummingbird/core/infrastructure/mysql/thinkmodel.go new file mode 100644 index 0000000..3dae95c --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/thinkmodel.go @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/models" +) + +func abilityByCode(c *Client, model interface{}, code, productId string) (interface{}, error) { + var err error + switch model.(type) { + case models.Properties: + ability := models.Properties{} + err = c.Pool.Model(&models.Properties{}).Where("code = ? and product_id = ?", code, productId).Find(&ability).Error + if err != nil { + return nil, err + } else { + return ability, nil + } + case models.Events: + ability := models.Events{} + err = c.Pool.Model(&models.Events{}).Where("code = ? and product_id = ?", code, productId).Find(&ability).Error + if err != nil { + return nil, err + } else { + return ability, nil + } + default: + return nil, fmt.Errorf("ability type shoud be propery or event") + } +} diff --git a/internal/hummingbird/core/infrastructure/mysql/unit.go b/internal/hummingbird/core/infrastructure/mysql/unit.go new file mode 100644 index 0000000..d966756 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/unit.go @@ -0,0 +1,63 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mysql + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func unitSearch(c *Client, offset int, limit int, req dtos.UnitRequest) (units []models.Unit, count uint32, edgeXErr error) { + cs := models.Unit{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + if req.UnitName != "" { + tx = tx.Where("`unit_name` LIKE ?", "%"+req.UnitName+"%") + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Unit{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "unit failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&units).Error + if err != nil { + return []models.Unit{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "unit failed query from the database", err) + } + + return units, uint32(total), nil +} + +func batchUpsertUnitTemplate(c *Client, d []models.Unit) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/mysql/user.go b/internal/hummingbird/core/infrastructure/mysql/user.go new file mode 100644 index 0000000..6ca4ef2 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/mysql/user.go @@ -0,0 +1,95 @@ +package mysql + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + + //"gitlab.com/tedge/edgex/internal/pkg/errort" + // + //"gitlab.com/tedge/edgex/internal/models" + "gorm.io/gorm" +) + +func (c *Client) GetUserByUserName(username string) (models.User, error) { + user, edgeXErr := userByUserName(c, username) + if edgeXErr != nil { + return user, edgeXErr + } + return user, nil +} + +func (c *Client) GetAllUser() ([]models.User, error) { + return getAllUser(c) +} + +func (c *Client) AddUsers(users []models.User) error { + return addUsers(c, users) +} + +func (c *Client) AddUser(u models.User) (models.User, error) { + return addUser(c, u) +} + +func (c *Client) UpdateUser(u models.User) error { + return updateUser(c, u) +} + +func userByUserName(c *Client, username string) (models.User, error) { + user := models.User{} + err := c.client.GetObject(&models.User{Username: username}, &user) + if err != nil { + if err == gorm.ErrRecordNotFound { + return user, errort.NewCommonEdgeX(errort.AppPasswordError, fmt.Sprintf("fail to query username %s", username), err) + } else { + return user, err + } + } + return user, nil +} + +func getAllUser(c *Client) ([]models.User, error) { + var users []models.User + if err := c.Pool.Find(&users).Error; err != nil { + return nil, err + } + return users, nil +} + +func addUsers(c *Client, users []models.User) error { + if len(users) <= 0 { + return nil + } + return c.Pool.Create(users).Error +} + +func updateUser(c *Client, u models.User) error { + err := c.Pool.Table(u.TableName()).Where(&models.User{ + Username: u.Username, + }).Save(&u).Error + if err != nil { + return err + } + return nil +} + +func userExist(c *Client, username string) (bool, error) { + exists, err := c.client.ExistObject(&models.User{Username: username}) + if err != nil { + return false, err + } + return exists, nil +} + +func addUser(c *Client, u models.User) (models.User, error) { + exists, edgeXErr := userExist(c, u.Username) + if edgeXErr != nil { + return u, edgeXErr + } else if exists { + return u, errort.NewCommonEdgeX(errort.DefaultNameRepeat, fmt.Sprintf("username %s exists", u.Username), edgeXErr) + } + if err := c.client.CreateObject(&u); err != nil { + return u, errort.NewCommonEdgeX(errort.DefaultSystemError, "user creation failed", err) + } + return u, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/alertrule.go b/internal/hummingbird/core/infrastructure/sqlite/alertrule.go new file mode 100644 index 0000000..cee1226 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/alertrule.go @@ -0,0 +1,217 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "time" +) + +func addAlertRule(c *Client, ds models.AlertRule) (alertRule models.AlertRule, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule creation failed", err) + } + return ds, edgeXErr +} + +func addAlertList(c *Client, ds models.AlertList) (alertRule models.AlertList, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule list creation failed", err) + } + return ds, edgeXErr +} + +func updateAlertRule(c *Client, dl models.AlertRule) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule update failed", err) + } + return nil +} + +func alertRuleById(c *Client, id string) (alertRule models.AlertRule, edgeXErr error) { + if id == "" { + return alertRule, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "alert rule id is empty", nil) + } + err := c.Pool.Table(alertRule.TableName()).First(&alertRule, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return alertRule, errort.NewCommonErr(errort.AlertRuleNotExist, fmt.Errorf("alert rule id(%s) not found", id)) + } + return alertRule, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query alert rule fail (Id:%s), %s", alertRule.Id, err)) + } + return +} + +func alertRuleSearch(c *Client, offset int, limit int, req dtos.AlertRuleSearchQueryRequest) (alertRules []models.AlertRule, count uint32, edgeXErr error) { + dp := models.AlertRule{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.AlertRule{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rules failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&alertRules).Error + if err != nil { + return []models.AlertRule{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rules failed query from the database", err) + } + + return alertRules, uint32(total), nil +} + +func alertListLastSend(c *Client, alertRuleId string) (alertList models.AlertList, edgeXErr error) { + al := models.AlertList{} + err := c.Pool.Table(al.TableName()).Where("alert_rule_id = ?", alertRuleId).Where("is_send", true).Order("created desc").Last(&alertList).Error + if err != nil { + return + } + return +} + +func alertListSearch(c *Client, offset int, limit int, req dtos.AlertSearchQueryRequest) (alertRules []dtos.AlertSearchQueryResponse, count uint32, edgeXErr error) { + var total int64 + dp := models.AlertList{} + tx := c.Pool.Table(dp.TableName()).Select("alert_list.id,alert_list.status," + + "alert_rule.name,alert_list.alert_result,alert_rule.alert_level,alert_list.trigger_time,alert_list.treated_time,alert_list.message,alert_list.is_send").Joins("left join alert_rule on alert_list.alert_rule_id = alert_rule.id") + //tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + if req.Name != "" { + tx.Where("alert_rule.name LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + if req.Status != "" { + tx.Where("alert_list.status = ?", req.Status) + + } + if req.AlertLevel != "" { + tx.Where("alert_rule.alert_level = ?", req.AlertLevel) + } + if req.TriggerStartTime > 0 && req.TriggerEndTime > 0 && req.TriggerEndTime-req.TriggerStartTime > 0 { + tx.Where("alert_list.trigger_time >= ?", req.TriggerStartTime) + tx.Where("alert_list.trigger_time <= ?", req.TriggerEndTime) + } + edgeXErr = tx.Count(&total).Error + if edgeXErr != nil { + return []dtos.AlertSearchQueryResponse{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert list failed query from the database", edgeXErr) + } + tx.Order("alert_list.created desc") + edgeXErr = tx.Offset(offset).Limit(limit).Scan(&alertRules).Error + if edgeXErr != nil { + return []dtos.AlertSearchQueryResponse{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "alert list failed query from the database", edgeXErr) + } + return alertRules, uint32(total), nil +} + +func deleteAlertRuleById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "alert rule id is empty", nil) + } + err := c.client.DeleteObject(&models.AlertRule{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert rule deletion failed", err) + } + return nil +} + +func alertRuleStart(c *Client, id string) error { + d := models.AlertRule{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStart}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start alert rule failed", err) + } + return nil +} + +func alertRuleStop(c *Client, id string) error { + d := models.AlertRule{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStop}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "stop alert rule failed", err) + } + return nil +} + +//subQuery := db.Select("AVG(age)").Where("name LIKE ?", "name%").Table("users") +//db.Select("AVG(age) as avgage").Group("name").Having("AVG(age) > (?)", subQuery).Find(&results) +// SELECT AVG(age) as avgage FROM `users` GROUP BY `name` HAVING AVG(age) > (SELECT AVG(age) FROM `users` WHERE name LIKE "name%") + +func alertPlate(c *Client, beforeTime int64) (plate []dtos.AlertPlateQueryResponse, err error) { + d := models.AlertList{} + if beforeTime > 0 { + err = c.Pool.Table(d.TableName()).Raw( + "SELECT count(alert_list.id) AS count,alert_rule.alert_level FROM alert_list "+ + "JOIN alert_rule on alert_list.alert_rule_id = alert_rule.id and alert_list.created > (?) "+ + "GROUP BY alert_rule.alert_level", beforeTime).Scan(&plate).Error + } else { + err = c.Pool.Table(d.TableName()).Raw( + "SELECT count(alert_list.id) AS count,alert_rule.alert_level FROM alert_list " + + "JOIN alert_rule on alert_list.alert_rule_id = alert_rule.id" + + "GROUP BY alert_rule.alert_level").Scan(&plate).Error + } + + return +} + +func alertIgnore(c *Client, id string) (err error) { + d := models.AlertList{} + tx := c.Pool.Table(d.TableName()) + err = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.Ignore, "treated_time": time.Now().UnixMilli()}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert ignore rule failed", err) + } + return nil +} + +func treatedIgnore(c *Client, id string, message string) (err error) { + d := models.AlertList{} + tx := c.Pool.Table(d.TableName()) + err = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.Treated, "message": message, "treated_time": time.Now().UnixMilli()}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "alert ignore rule failed", err) + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/category.go b/internal/hummingbird/core/infrastructure/sqlite/category.go new file mode 100644 index 0000000..1a7d1f8 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/category.go @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func categoryTemplateSearch(c *Client, offset int, limit int, req dtos.CategoryTemplateRequest) (categoryTemplates []models.CategoryTemplate, count uint32, edgeXErr error) { + cs := models.CategoryTemplate{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + if req.CategoryName != "" { + tx = tx.Where("`category_name` LIKE ?", "%"+req.CategoryName+"%") + } + + if req.Scene != "" { + tx = tx.Where("`scene` = ?", req.Scene) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.CategoryTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "categoryTemplate failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&categoryTemplates).Error + if err != nil { + return []models.CategoryTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "categoryTemplate failed query from the database", err) + } + + return categoryTemplates, uint32(total), nil +} + +func categoryTemplateById(c *Client, id string) (categoryTemplateInfo models.CategoryTemplate, edgeXErr error) { + if id == "" { + return categoryTemplateInfo, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "categoryTemplate id is empty", nil) + } + err := c.client.GetObject(&models.CategoryTemplate{Id: id}, &categoryTemplateInfo) + if err != nil { + if err == gorm.ErrRecordNotFound { + return categoryTemplateInfo, errort.NewCommonErr(errort.CategoryNotExist, fmt.Errorf("categoryTemplate id(%s) not found", id)) + } + return categoryTemplateInfo, err + } + return +} + +func batchUpsertCategoryTemplate(c *Client, d []models.CategoryTemplate) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/client.go b/internal/hummingbird/core/infrastructure/sqlite/client.go new file mode 100644 index 0000000..ec2ae55 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/client.go @@ -0,0 +1,715 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "gorm.io/gorm" + + clientSQLite "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" +) + +type Client struct { + Pool *gorm.DB + //cache interfaces.Cache + client clientSQLite.ClientSQLite + loggingClient logger.LoggingClient +} + +func NewClient(config dtos.Configuration, lc logger.LoggingClient) (c *Client, errEdgeX error) { + client, err := clientSQLite.NewGormClient(config, lc) + if err != nil { + errEdgeX = errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("database failed to init %w", err)) + return + } + // 自动建表 + if err = client.InitTable( + &models.DeviceLibrary{}, + &models.DeviceService{}, + &models.Device{}, + //&models.NetworkConfig{}, + &models.DockerConfig{}, + //&models.GatewayThingModel{}, + &models.AdvanceConfig{}, + //&models.OTAVersion{}, + //&models.CustomMQTT{}, + &models.CategoryTemplate{}, + &models.ThingModelTemplate{}, + &models.DriverClassify{}, + &models.User{}, + &models.LanguageSdk{}, + &models.Product{}, + &models.Properties{}, + &models.Actions{}, + &models.Events{}, + &models.Unit{}, + &models.MqttAuth{}, + &models.SystemMetrics{}, + &models.AlertRule{}, + &models.Scene{}, + &models.SceneLog{}, + &models.AlertList{}, + //&models.DeviceAlertRule{}, + &models.QuickNavigation{}, + &models.Doc{}, + &models.MsgGather{}, + &models.RuleEngine{}, + &models.DataResource{}, + ); err != nil { + errEdgeX = errort.NewCommonEdgeX(errort.DefaultSystemError, "database failed to init", err) + return + } + c = &Client{ + client: client, + loggingClient: lc, + Pool: client.Pool, + } + return +} + +//func (c *Client) WithCache() error { +// err := cache.NewCache() +// if err != nil { +// return err +// } +// c.cache = cache.GetCacheClient() +// return nil +//} + +// CloseSession closes the connections to Redis +func (c *Client) CloseSession() { + //sqlite no need close? +} + +func (c *Client) GetDBInstance() *gorm.DB { + return c.Pool +} + +func (c *Client) AddDeviceLibrary(dl models.DeviceLibrary) (models.DeviceLibrary, error) { + if len(dl.Id) == 0 { + dl.Id = utils.GenUUID() + } + return addDeviceLibrary(c, dl) +} + +func (c *Client) DockerConfigAdd(dc models.DockerConfig) (models.DockerConfig, error) { + if len(dc.Id) == 0 { + dc.Id = utils.GenUUID() + } + return dockerConfigAdd(c, dc) +} + +func (c *Client) DockerConfigById(id string) (models.DockerConfig, error) { + return dockerConfigById(c, id) +} + +func (c *Client) DockerConfigDelete(id string) error { + return dockerConfigDeleteById(c, id) +} + +func (c *Client) DockerConfigUpdate(dc models.DockerConfig) error { + return dockerConfigUpdate(c, dc) +} + +func (c *Client) DockerConfigsSearch(offset int, limit int, req dtos.DockerConfigSearchQueryRequest) (dcs []models.DockerConfig, total uint32, edgeXErr error) { + return dockerConfigsSearch(c, offset, limit, req) +} + +func (c *Client) DriverClassifySearch(offset int, limit int, req dtos.DriverClassifyQueryRequest) (dcs []models.DriverClassify, total uint32, edgeXErr error) { + return driverClassifySearch(c, offset, limit, req) +} + +func (c *Client) DeviceLibrariesSearch(offset int, limit int, req dtos.DeviceLibrarySearchQueryRequest) (deviceLibraries []models.DeviceLibrary, total uint32, edgeXErr error) { + deviceLibraries, total, edgeXErr = deviceLibrariesSearch(c, offset, limit, req) + if edgeXErr != nil { + return deviceLibraries, total, edgeXErr + } + return deviceLibraries, total, nil +} + +func (c *Client) DeviceServicesSearch(offset int, limit int, req dtos.DeviceServiceSearchQueryRequest) (deviceServices []models.DeviceService, total uint32, edgeXErr error) { + deviceServices, total, edgeXErr = deviceServicesSearch(c, offset, limit, req) + if edgeXErr != nil { + return deviceServices, 0, edgeXErr + } + return deviceServices, total, nil +} + +func (c *Client) DeviceLibraryById(id string) (deviceLibrary models.DeviceLibrary, edgeXErr error) { + return deviceLibraryById(c, id) +} + +func (c *Client) DeleteDeviceLibraryById(id string) error { + return deleteDeviceLibraryById(c, id) +} + +func (c *Client) AddDeviceService(ds models.DeviceService) (models.DeviceService, error) { + // 驱动实例和驱动id一样,为了防止容器实例名冲突导致数据冲突 + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + ds.Name = ds.Name + "-" + ds.Id + return addDeviceService(c, ds) +} + +func (c *Client) UpdateDeviceService(ds models.DeviceService) error { + return updateDeviceService(c, ds) +} + +func (c *Client) UpdateDeviceLibrary(dl models.DeviceLibrary) error { + return updateDeviceLibrary(c, dl) +} + +func (c *Client) DeviceServiceById(id string) (deviceService models.DeviceService, edgeXErr error) { + deviceService, edgeXErr = deviceServiceById(c, id) + if edgeXErr != nil { + return deviceService, edgeXErr + } + + return +} + +func (c *Client) DeleteDeviceServiceById(id string) error { + return deleteDeviceServiceById(c, id) +} + +func (c *Client) ProductById(id string) (product models.Product, edgeXErr error) { + return productById(c, id) +} + +func (c *Client) AddProduct(ds models.Product) (product models.Product, edgeXErr error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addProduct(c, ds) +} + +func (c *Client) ProductByCloudId(id string) (product models.Product, edgeXErr error) { + return productByCloudId(c, id) +} + +func (c *Client) BatchUpsertProduct(p []models.Product) (int64, error) { + return batchUpsertProduct(c, p) +} + +func (c *Client) BatchSaveProduct(p []models.Product) error { + return batchSaveProduct(c, p) +} + +func (c *Client) BatchDeleteProduct(products []models.Product) error { + return batchDeleteProduct(c, products) +} + +func (c *Client) BatchDeleteProperties(propertiesIds []string) error { + return batchDeleteProperties(c, propertiesIds) +} + +func (c *Client) BatchDeleteSystemProperties() error { + return batchDeleteSystemProperties(c) +} + +func (c *Client) BatchInsertSystemProperties(p []models.Properties) (int64, error) { + return batchInsertSystemProperties(c, p) +} + +func (c *Client) BatchDeleteEvents(eventIds []string) error { + return batchDeleteEvents(c, eventIds) +} + +func (c *Client) BatchDeleteSystemEvents() error { + return batchDeleteSystemEvents(c) +} + +func (c *Client) BatchInsertSystemEvents(p []models.Events) (int64, error) { + return batchInsertSystemEvents(c, p) +} + +func (c *Client) BatchDeleteActions(actionIds []string) error { + return batchDeleteActions(c, actionIds) +} +func (c *Client) BatchDeleteSystemActions() error { + return batchDeleteSystemActions(c) +} + +func (c *Client) BatchInsertSystemActions(p []models.Actions) (int64, error) { + return batchInsertSystemActions(c, p) +} + +func (c *Client) DeleteProductById(id string) error { + return deleteProductById(c, id) +} + +func (c *Client) DeleteProductObject(product models.Product) error { + return deleteProductObject(c, product) +} + +func (c *Client) AssociationsDeleteProductObject(product models.Product) error { + return associationsDeleteProductObject(c, product) +} + +func (c *Client) UpdateProduct(ds models.Product) error { + return updateProduct(c, ds) +} + +func (c *Client) AssociationsUpdateProduct(ds models.Product) error { + return associationsUpdateProduct(c, ds) +} + +func (c *Client) BatchUpsertDevice(p []models.Device) (int64, error) { + return batchUpsertDevice(c, p) +} + +func (c *Client) ProductsSearch(offset int, limit int, preload bool, req dtos.ProductSearchQueryRequest) (products []models.Product, total uint32, edgeXErr error) { + products, total, edgeXErr = productsSearch(c, offset, limit, preload, req) + if edgeXErr != nil { + return products, 0, edgeXErr + } + return products, total, nil +} + +func (c *Client) DevicesSearch(offset int, limit int, req dtos.DeviceSearchQueryRequest) (devices []models.Device, total uint32, edgeXErr error) { + devices, total, edgeXErr = devicesSearch(c, offset, limit, req) + if edgeXErr != nil { + return devices, 0, edgeXErr + } + return devices, total, nil +} + +func (c *Client) DeviceById(id string) (device models.Device, edgeXErr error) { + return deviceById(c, id) +} + +func (c *Client) DeviceOnlineById(id string) (edgeXErr error) { + return deviceOnlineById(c, id) +} + +func (c *Client) DeviceOfflineById(id string) (edgeXErr error) { + return deviceOfflineById(c, id) +} + +func (c *Client) DeviceOfflineByCloudInstanceId(id string) (edgeXErr error) { + return deviceOfflineByCloudInstanceId(c, id) +} + +func (c *Client) MsgReportDeviceById(id string) (device models.Device, edgeXErr error) { + return msgReportDeviceById(c, id) +} + +func (c *Client) DeviceByCloudId(id string) (device models.Device, edgeXErr error) { + return deviceByCloudId(c, id) +} + +func (c *Client) DeviceMqttAuthInfo(id string) (device models.MqttAuth, edgeXErr error) { + return deviceMqttAuthInfo(c, id) +} + +func (c *Client) DriverMqttAuthInfo(id string) (device models.MqttAuth, edgeXErr error) { + return driverMqttAuthInfo(c, id) +} + +func (c *Client) AddDevice(ds models.Device) (deviceId string, edgeXErr error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addDevice(c, ds) +} + +func (c *Client) BatchDeleteDevice(ids []string) error { + return batchDeleteDevice(c, ids) +} + +func (c *Client) BatchUnBindDevice(ids []string) error { + return batchUnBindDevice(c, ids) +} + +func (c *Client) BatchBindDevice(ids []string, driverInstanceId string) error { + return batchBindDevice(c, ids, driverInstanceId) +} + +func (c *Client) DeleteDeviceById(id string) error { + return deleteDeviceById(c, id) +} + +func (c *Client) DeleteDeviceByCloudInstanceId(id string) error { + return deleteDeviceByCloudInstanceId(c, id) +} + +func (c *Client) UpdateDevice(ds models.Device) error { + return updateDevice(c, ds) +} + +func (c *Client) AbilityByCode(model interface{}, code, productId string) (interface{}, error) { + return abilityByCode(c, model, code, productId) +} + +func (c *Client) CategoryTemplateSearch(offset int, limit int, req dtos.CategoryTemplateRequest) ([]models.CategoryTemplate, uint32, error) { + return categoryTemplateSearch(c, offset, limit, req) +} + +func (c *Client) UnitSearch(offset int, limit int, req dtos.UnitRequest) ([]models.Unit, uint32, error) { + return unitSearch(c, offset, limit, req) +} + +func (c *Client) BatchUpsertUnitTemplate(p []models.Unit) (int64, error) { + return batchUpsertUnitTemplate(c, p) +} + +func (c *Client) CategoryTemplateById(id string) (models.CategoryTemplate, error) { + return categoryTemplateById(c, id) +} + +func (c *Client) BatchUpsertCategoryTemplate(p []models.CategoryTemplate) (int64, error) { + return batchUpsertCategoryTemplate(c, p) +} + +func (c *Client) ThingModelTemplateSearch(offset int, limit int, req dtos.ThingModelTemplateRequest) ([]models.ThingModelTemplate, uint32, error) { + return thingModelTemplateSearch(c, offset, limit, req) +} +func (c *Client) ThingModelTemplateByCategoryKey(categoryKey string) (models.ThingModelTemplate, error) { + return thingModelTemplateByCategoryKey(c, categoryKey) +} + +func (c *Client) BatchUpsertThingModelTemplate(p []models.ThingModelTemplate) (int64, error) { + return batchUpsertThingModelTemplate(c, p) +} + +func (c *Client) AddThingModelProperty(ds models.Properties) (models.Properties, error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addThingModelProperty(c, ds) +} + +func (c *Client) BatchUpsertThingModel(ds interface{}) (int64, error) { + return batchUpsertThingModel(c, ds) +} + +func (c *Client) AddThingModelEvent(ds models.Events) (models.Events, error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addThingModelEvent(c, ds) +} + +func (c *Client) AddThingModelAction(ds models.Actions) (models.Actions, error) { + if len(ds.Id) == 0 { + ds.Id = utils.RandomNum() + } + return addThingModelAction(c, ds) +} + +func (c *Client) UpdateThingModelProperty(ds models.Properties) error { + return updateThingModelProperty(c, ds) +} + +func (c *Client) UpdateThingModelEvent(ds models.Events) error { + return updateThingModelEvent(c, ds) +} + +func (c *Client) UpdateThingModelAction(ds models.Actions) error { + return updateThingModelAction(c, ds) +} + +func (c *Client) ThingModelDeleteProperty(id string) error { + return deleteThingModelPropertyById(c, id) +} + +func (c *Client) ThingModelDeleteEvent(id string) error { + return deleteThingModelEventById(c, id) +} + +func (c *Client) ThingModelDeleteAction(id string) error { + return deleteThingModelActionById(c, id) +} + +func (c *Client) ThingModelPropertyById(id string) (models.Properties, error) { + return thingModelPropertyById(c, id) +} + +func (c *Client) ThingModelEventById(id string) (models.Events, error) { + return thingModelEventById(c, id) +} + +func (c *Client) ThingModelActionsById(id string) (models.Actions, error) { + return thingModeActionById(c, id) +} +func (c *Client) SystemThingModelSearch(modelType string, ModelName string) (interface{}, error) { + return systemThingModelSearch(c, modelType, ModelName) +} + +func (c *Client) AddMqttAuthInfo(auth models.MqttAuth) (string, error) { + if len(auth.Id) == 0 { + auth.Id = utils.RandomNum() + } + return addMqttAuth(c, auth) +} + +func (c *Client) AddOrUpdateAuth(auth models.MqttAuth) error { + if len(auth.Id) == 0 { + auth.Id = utils.RandomNum() + } + return addOrUpdateAuth(c, auth) +} + +func (c *Client) AddAlertRule(alertRule models.AlertRule) (models.AlertRule, error) { + if len(alertRule.Id) == 0 { + alertRule.Id = utils.RandomNum() + } + return addAlertRule(c, alertRule) +} + +func (c *Client) AddAlertList(alertRule models.AlertList) (models.AlertList, error) { + if len(alertRule.Id) == 0 { + alertRule.Id = utils.RandomNum() + } + return addAlertList(c, alertRule) +} + +func (c *Client) UpdateAlertRule(rule models.AlertRule) error { + return updateAlertRule(c, rule) +} + +func (c *Client) AlertRuleById(id string) (models.AlertRule, error) { + return alertRuleById(c, id) +} + +func (c *Client) AlertRuleSearch(offset int, limit int, req dtos.AlertRuleSearchQueryRequest) (alertRules []models.AlertRule, total uint32, edgeXErr error) { + alertRules, total, edgeXErr = alertRuleSearch(c, offset, limit, req) + if edgeXErr != nil { + return alertRules, 0, edgeXErr + } + return alertRules, total, nil +} + +func (c *Client) AlertListSearch(offset int, limit int, req dtos.AlertSearchQueryRequest) (alertList []dtos.AlertSearchQueryResponse, total uint32, edgeXErr error) { + alertList, total, edgeXErr = alertListSearch(c, offset, limit, req) + if edgeXErr != nil { + return alertList, 0, edgeXErr + } + return alertList, total, nil +} + +func (c *Client) AlertIgnore(id string) (edgeXErr error) { + return alertIgnore(c, id) +} + +func (c *Client) TreatedIgnore(id, message string) (edgeXErr error) { + return treatedIgnore(c, id, message) +} + +func (c *Client) AlertListLastSend(alertRuleId string) (alertList models.AlertList, edgeXErr error) { + return alertListLastSend(c, alertRuleId) +} + +func (c *Client) DeleteAlertRuleById(id string) error { + return deleteAlertRuleById(c, id) +} + +func (c *Client) AlertRuleStart(id string) error { + return alertRuleStart(c, id) +} + +func (c *Client) AlertRuleStop(id string) error { + return alertRuleStop(c, id) +} + +func (c *Client) AlertPlate(beforeTime int64) (plate []dtos.AlertPlateQueryResponse, err error) { + return alertPlate(c, beforeTime) +} + +func (c *Client) QuickNavigationSearch(offset int, limit int, req dtos.QuickNavigationSearchQueryRequest) (quickNavigations []models.QuickNavigation, total uint32, edgeXErr error) { + quickNavigations, total, edgeXErr = quickNavigationSearch(c, offset, limit, req) + if edgeXErr != nil { + return quickNavigations, 0, edgeXErr + } + return quickNavigations, total, nil +} + +func (c *Client) DocsSearch(offset int, limit int, req dtos.DocsSearchQueryRequest) (docs []models.Doc, total uint32, edgeXErr error) { + docs, total, edgeXErr = docsSearch(c, offset, limit, req) + if edgeXErr != nil { + return docs, 0, edgeXErr + } + return docs, total, nil +} + +func (c *Client) BatchUpsertDocsTemplate(ds []models.Doc) (int64, error) { + return batchUpsertDocsTemplate(c, ds) +} + +func (c *Client) BatchUpsertQuickNavigationTemplate(ds []models.QuickNavigation) (int64, error) { + return batchUpsertQuickNavigationTemplate(c, ds) +} + +func (c *Client) DeleteQuickNavigation(id string) error { + return deleteQuickNavigation(c, id) +} + +func (c *Client) GetAdvanceConfig() (models.AdvanceConfig, error) { + return getAdvanceConfig(c) +} + +func (c *Client) UpdateAdvanceConfig(config models.AdvanceConfig) error { + return updateAdvanceConfig(c, config) +} + +func (c *Client) AddMsgGather(msgGather models.MsgGather) error { + if len(msgGather.Id) == 0 { + msgGather.Id = utils.RandomNum() + } + return addMsgGather(c, msgGather) +} + +func (c *Client) MsgGatherSearch(offset int, limit int, req dtos.MsgGatherSearchQueryRequest) (dcs []models.MsgGather, count uint32, edgeXErr error) { + return msgGatherSearch(c, offset, limit, req) +} + +func (c *Client) AddDataResource(dateResource models.DataResource) (string, error) { + if len(dateResource.Id) == 0 { + dateResource.Id = utils.RandomNum() + } + return addDataResource(c, dateResource) +} + +func (c *Client) UpdateDataResource(dateResource models.DataResource) error { + return updateDataResource(c, dateResource) +} + +func (c *Client) DelDataResource(id string) error { + return deleteDataResourceById(c, id) +} + +func (c *Client) UpdateDataResourceHealth(id string, health bool) error { + return updateDataResourceHealth(c, id, health) +} + +func (c *Client) SearchDataResource(offset int, limit int, req dtos.DataResourceSearchQueryRequest) (dataResource []models.DataResource, count uint32, edgeXErr error) { + return dataResourceSearch(c, offset, limit, req) +} + +func (c *Client) DataResourceById(id string) (models.DataResource, error) { + return dataResourceById(c, id) +} + +func (c *Client) AddRuleEngine(ruleEngine models.RuleEngine) (string, error) { + if len(ruleEngine.Id) == 0 { + ruleEngine.Id = utils.RandomNum() + } + return addRuleEngine(c, ruleEngine) +} + +func (c *Client) UpdateRuleEngine(ruleEngine models.RuleEngine) error { + return updateRuleEngine(c, ruleEngine) +} + +func (c *Client) RuleEngineById(id string) (ruleEngine models.RuleEngine, edgeXErr error) { + return ruleEngineById(c, id) +} + +func (c *Client) RuleEngineSearch(offset int, limit int, req dtos.RuleEngineSearchQueryRequest) (ruleEngine []models.RuleEngine, count uint32, edgeXErr error) { + return ruleEngineSearch(c, offset, limit, req) +} + +func (c *Client) RuleEngineStart(id string) error { + return ruleEngineStart(c, id) +} + +func (c *Client) RuleEngineStop(id string) error { + return ruleEngineStop(c, id) +} + +func (c *Client) DeleteRuleEngineById(id string) error { + return deleteRuleEngineById(c, id) +} + +func (c *Client) AddScene(scene models.Scene) (models.Scene, error) { + if len(scene.Id) == 0 { + scene.Id = utils.RandomNum() + } + return addScene(c, scene) +} + +func (c *Client) UpdateScene(scene models.Scene) error { + if len(scene.Id) == 0 { + scene.Id = utils.RandomNum() + } + return updateScene(c, scene) +} +func (c *Client) SceneById(id string) (models.Scene, error) { + return sceneById(c, id) +} + +func (c *Client) SceneStart(id string) error { + return sceneStart(c, id) +} + +func (c *Client) SceneStop(id string) error { + return sceneStop(c, id) +} + +func (c *Client) DeleteSceneById(id string) error { + return deleteSceneById(c, id) +} + +func (c *Client) SceneSearch(offset int, limit int, req dtos.SceneSearchQueryRequest) (scenes []models.Scene, total uint32, edgeXErr error) { + return sceneSearch(c, offset, limit, req) +} + +func (c *Client) AddSceneLog(sceneLog models.SceneLog) (models.SceneLog, error) { + if len(sceneLog.Id) == 0 { + sceneLog.Id = utils.RandomNum() + } + return addSceneLog(c, sceneLog) +} + +func (c *Client) SceneLogSearch(offset int, limit int, req dtos.SceneLogSearchQueryRequest) (sceneLogs []models.SceneLog, total uint32, edgeXErr error) { + return sceneLogSearch(c, offset, limit, req) +} + +func (c *Client) LanguageSdkByName(name string) (cloudService models.LanguageSdk, edgeXErr error) { + return languageByName(c, name) +} + +func (c *Client) LanguageSearch(offset int, limit int, req dtos.LanguageSDKSearchQueryRequest) (languages []models.LanguageSdk, count uint32, edgeXErr error) { + return languageSearch(c, offset, limit, req) +} + +func (c *Client) AddLanguageSdk(ls models.LanguageSdk) (language models.LanguageSdk, edgeXErr error) { + if len(ls.Id) == 0 { + ls.Id = utils.RandomNum() + } + return addLanguageSdk(c, ls) +} + +func (c *Client) UpdateLanguageSdk(ls models.LanguageSdk) error { + return updateLanguageSdk(c, ls) +} + +func (c *Client) UpdateSystemMetrics(metrics dtos.SystemMetrics) error { + return updateSystemMetrics(c, metrics) +} + +func (c *Client) GetSystemMetrics(start, end int64) ([]dtos.SystemMetrics, error) { + return getSystemMetrics(c, start, end) +} + +func (c *Client) RemoveRangeSystemMetrics(min, max string) error { + return removeRangeSystemMetrics(c, min, max) +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/config.go b/internal/hummingbird/core/infrastructure/sqlite/config.go new file mode 100644 index 0000000..85bd549 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/config.go @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "gorm.io/gorm" +) + +func updateAdvanceConfig(c *Client, config models.AdvanceConfig) error { + if err := c.client.UpdateObject(&config); err != nil { + return errort.NewCommonErr(errort.DefaultSystemError, err) + } + return nil +} + +func getAdvanceConfig(c *Client) (models.AdvanceConfig, error) { + var config models.AdvanceConfig + err := c.client.GetObject(&models.AdvanceConfig{ID: constants.DefaultAdvanceConfigID}, &config) + if err != nil { + if err == gorm.ErrRecordNotFound { + config.ID = constants.DefaultAdvanceConfigID + if err = c.client.CreateObject(&config); err != nil { + return models.AdvanceConfig{}, errort.NewCommonErr(errort.DefaultSystemError, err) + } + return config, nil + } + return models.AdvanceConfig{}, errort.NewCommonErr(errort.DefaultSystemError, err) + } + return config, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/dataresource.go b/internal/hummingbird/core/infrastructure/sqlite/dataresource.go new file mode 100644 index 0000000..a9cb87b --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/dataresource.go @@ -0,0 +1,117 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func dataResourceById(c *Client, id string) (dateResource models.DataResource, edgeXErr error) { + if id == "" { + return dateResource, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "dateResource id is empty", nil) + } + err := c.Pool.Table(dateResource.TableName()).First(&dateResource, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return dateResource, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("dateResource id(%s) not found", id)) + } + return dateResource, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query dateResource fail (Id:%s), %s", dateResource.Id, err)) + } + return +} + +func addDataResource(c *Client, ds models.DataResource) (id string, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "data resourced creation failed", err) + } + + return ds.Id, edgeXErr +} + +func updateDataResource(c *Client, dl models.DataResource) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "data resource update failed", err) + } + return nil +} + +func deleteDataResourceById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "id is empty", nil) + } + err := c.client.DeleteObject(&models.DataResource{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "data resourced deletion failed", err) + } + return nil +} + +func updateDataResourceHealth(c *Client, id string, health bool) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "id is empty", nil) + } + d := models.DataResource{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"health": health}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "update data resource failed", err) + } + return nil +} + +func dataResourceSearch(c *Client, offset int, limit int, req dtos.DataResourceSearchQueryRequest) (dataResource []models.DataResource, count uint32, edgeXErr error) { + dl := models.DataResource{} + var total int64 + tx := c.Pool.Table(dl.TableName()) + tx = sqlite.BuildCommonCondition(tx, dl, req.BaseSearchConditionQuery) + // 特殊条件 + if req.Type != "" { + tx = tx.Where("`type` = ?", req.Type) + } + if req.Health != "" { + isHealth := true + if req.Health == SearchReqBoolTrue { + isHealth = true + } else { + isHealth = false + } + tx = tx.Where("`health` = ?", isHealth) + } + err := tx.Count(&total).Error + if err != nil { + return []models.DataResource{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "data resource failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&dataResource).Error + if err != nil { + return []models.DataResource{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "data resource failed query from the database", err) + } + + return dataResource, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/device.go b/internal/hummingbird/core/infrastructure/sqlite/device.go new file mode 100644 index 0000000..ceeed1d --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/device.go @@ -0,0 +1,307 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func deviceById(c *Client, id string) (device models.Device, edgeXErr error) { + if id == "" { + return device, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + //err := c.client.GetObject(&models.Device{Id: id}, &device) + err := c.Pool.Table(device.TableName()).Preload("Product").First(&device, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return device, errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("device id(%s) not found", id)) + } + return device, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query device fail (Id:%s), %s", device.Id, err)) + } + return +} + +func deviceOnlineById(c *Client, id string) (edgeXErr error) { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + edgeXErr = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.DeviceStatusOnline, "last_online_time": utils.MakeTimestamp()}).Error + if edgeXErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceOnlineById failed", edgeXErr) + } + return nil +} + +func deviceOfflineById(c *Client, id string) (edgeXErr error) { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + edgeXErr = tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.DeviceStatusOffline}).Error + if edgeXErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceOnlineById failed", edgeXErr) + } + return nil +} + +func deviceOfflineByCloudInstanceId(c *Client, id string) (edgeXErr error) { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + edgeXErr = tx.Where("cloud_instance_id = ?", id).Updates(map[string]interface{}{"status": constants.DeviceStatusOffline}).Error + if edgeXErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceOnlineById failed", edgeXErr) + } + return nil +} + +func msgReportDeviceById(c *Client, id string) (device models.Device, edgeXErr error) { + if id == "" { + return device, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + //err := c.client.GetObject(&models.Device{Id: id}, &device) + err := c.Pool.Table(device.TableName()).Preload("Product").First(&device, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return device, errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("device id(%s) not found", id)) + } + return device, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query device fail (Id:%s), %s", device.Id, err)) + } + return +} + +func deviceByCloudId(c *Client, id string) (device models.Device, edgeXErr error) { + if id == "" { + return device, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device cloudId is empty", nil) + } + err := c.client.GetObject(&models.Device{CloudDeviceId: id}, &device) + if err != nil { + if err == gorm.ErrRecordNotFound { + return device, errort.NewCommonErr(errort.DeviceNotExist, fmt.Errorf("device cloudId (%s) not found", id)) + } + return device, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query device fail (cloudId:%s), %s", device.Id, err)) + } + return +} + +func devicesSearch(c *Client, offset int, limit int, req dtos.DeviceSearchQueryRequest) (devices []models.Device, count uint32, edgeXErr error) { + dp := models.Device{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + + if req.Platform != "" { + tx = tx.Where("`platform` = ?", req.Platform) + } + + if req.ProductId != "" { + tx = tx.Where("`product_id` = ?", req.ProductId) + } + + if req.CloudProductId != "" { + tx = tx.Where("`cloud_product_id` = ?", req.CloudProductId) + + } + if req.CloudInstanceId != "" { + tx = tx.Where("`cloud_instance_id` = ?", req.CloudInstanceId) + } + + if req.DriveInstanceId != "" { + tx = tx.Where("`drive_instance_id` = ?", req.DriveInstanceId) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Device{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "devices failed query from the database", err) + } + + err = tx.Offset(offset).Preload("Product").Limit(limit).Find(&devices).Error + if err != nil { + return []models.Device{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "devices failed query from the database", err) + } + + return devices, uint32(total), nil +} + +func batchUpsertDevice(c *Client, d []models.Device) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(&d, 10000) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchDeleteDevice(c *Client, ids []string) error { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, ids).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteDevice failed", err) + } + return nil +} + +func batchUnBindDevice(c *Client, ids []string) error { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id IN ?", ids).Updates(map[string]interface{}{"drive_instance_id": ""}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteDevice failed", err) + } + return nil +} + +func batchBindDevice(c *Client, ids []string, driverInstanceId string) error { + d := models.Device{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id IN ?", ids).Updates(map[string]interface{}{"drive_instance_id": driverInstanceId}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteDevice failed", err) + } + return nil +} + +func deleteDeviceById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + err := c.client.DeleteObject(&models.Device{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device deletion failed", err) + } + return nil +} + +func deleteDeviceByCloudInstanceId(c *Client, cloudInstanceId string) error { + if cloudInstanceId == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "cloudInstanceId id is empty", nil) + } + err := c.client.DeleteObject(&models.Device{CloudInstanceId: cloudInstanceId}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device deletion failed", err) + } + return nil +} + +func updateDevice(c *Client, dl models.Device) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device update failed", err) + } + return nil +} + +func addDevice(c *Client, device models.Device) (string, error) { + exists, edgeXErr := deviceNameExist(c, device.Name) + if edgeXErr != nil { + return "", edgeXErr + } else if exists { + return "", errort.NewCommonEdgeX(errort.DefaultNameRepeat, fmt.Sprintf("device name %s exists", device.Name), edgeXErr) + } + exists, edgeXErr = productIdExist(c, device.ProductId) + if edgeXErr != nil { + return "", edgeXErr + } else if !exists { + return "", errort.NewCommonEdgeX(errort.DeviceProductIdNotFound, fmt.Sprintf("device product %s not exists", device.ProductId), edgeXErr) + } + ts := utils.MakeTimestamp() + if device.Created == 0 { + device.Created = ts + } + device.Modified = ts + + err := c.client.CreateObject(&device) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "device creation failed", err) + } + + return device.Id, edgeXErr +} + +func addOrUpdateAuth(c *Client, auth models.MqttAuth) error { + exists, err := c.client.ExistObject(&models.MqttAuth{ClientId: auth.ClientId}) + if err != nil { + return err + } + if !exists { + ts := utils.MakeTimestamp() + if auth.Created == 0 { + auth.Created = ts + } + auth.Modified = ts + err = c.client.CreateObject(&auth) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "mqtt auch creation failed", err) + } + } + return nil +} + +func addMqttAuth(c *Client, auth models.MqttAuth) (string, error) { + var edgeXErr error + ts := utils.MakeTimestamp() + if auth.Created == 0 { + auth.Created = ts + } + auth.Modified = ts + + err := c.client.CreateObject(&auth) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "mqtt auch creation failed", err) + } + return auth.Id, edgeXErr +} + +func deviceNameExist(c *Client, name string) (bool, error) { + exists, err := c.client.ExistObject(&models.Product{Name: name}) + if err != nil { + return false, err + } + return exists, nil +} + +func deviceMqttAuthInfo(c *Client, id string) (mqttAuth models.MqttAuth, edgeXErr error) { + if id == "" { + return mqttAuth, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + err := c.client.GetObject(&models.MqttAuth{ResourceId: id, ResourceType: constants.DeviceResource}, &mqttAuth) + if err != nil { + if err == gorm.ErrRecordNotFound { + return mqttAuth, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("mqtt auth resoure id(%s) not found", id)) + } + return mqttAuth, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query mqtt auth fail (resoureId:%s), %s", id, err)) + } + return +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/devicelibrary.go b/internal/hummingbird/core/infrastructure/sqlite/devicelibrary.go new file mode 100644 index 0000000..0dd2b67 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/devicelibrary.go @@ -0,0 +1,157 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func deviceLibraryById(c *Client, id string) (deviceLibrary models.DeviceLibrary, edgeXErr error) { + if id == "" { + return deviceLibrary, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "deviceLibrary id is empty", nil) + } + err := c.client.GetObject(&models.DeviceLibrary{Id: id}, &deviceLibrary) + if err != nil { + if err == gorm.ErrRecordNotFound { + return deviceLibrary, errort.NewCommonErr(errort.DeviceLibraryNotExist, fmt.Errorf("device library id(%s) not found", id)) + } + return deviceLibrary, err + } + return +} + +func deleteDeviceLibraryById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "id is empty", nil) + } + err := c.client.DeleteObject(&models.DeviceLibrary{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device library deletion failed", err) + } + return nil +} + +func deviceLibraryIdExists(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.DeviceLibrary{Id: id}) + if err != nil { + return false, err + } + return exists, nil +} + +func addDeviceLibrary(c *Client, dl models.DeviceLibrary) (models.DeviceLibrary, error) { + // query device library name and id to avoid the conflict + exists, edgeXErr := deviceLibraryIdExists(c, dl.Id) + if edgeXErr != nil { + return dl, edgeXErr + } else if exists { + return dl, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("device library id %s exists", dl.Id), edgeXErr) + } + + // check docker config id exists + exists, edgeXErr = dockerConfigIdExists(c, dl.DockerConfigId) + if edgeXErr != nil { + return dl, edgeXErr + } else if !exists { + return dl, errort.NewCommonEdgeX(errort.DockerImageRepositoryNotFound, fmt.Sprintf("docker config id %s not exists", dl.Id), edgeXErr) + } + + ts := utils.MakeTimestamp() + if dl.Created == 0 { + dl.Created = ts + } + dl.Modified = ts + + err := c.client.CreateObject(&dl) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "device library creation failed", err) + } + + return dl, edgeXErr +} + +const ( + // url中请求参数有判断真假的 + SearchReqBoolTrue = "true" + SearchReqBoolFalse = "false" +) + +func deviceLibrariesSearch(c *Client, offset int, limit int, req dtos.DeviceLibrarySearchQueryRequest) (deviceLibraries []models.DeviceLibrary, count uint32, edgeXErr error) { + dl := models.DeviceLibrary{} + var total int64 + tx := c.Pool.Table(dl.TableName()) + tx = sqlite.BuildCommonCondition(tx, dl, req.BaseSearchConditionQuery) + // 特殊条件 + if req.DockerConfigId != "" { + tx = tx.Where("`docker_config_id` = ?", req.DockerConfigId) + } + if req.IsInternal != "" { + isInternal := true + if req.IsInternal == SearchReqBoolTrue { + isInternal = true + } else { + isInternal = false + } + tx = tx.Where("`is_internal` = ?", isInternal) + } + if req.DockerRepoName != "" { + tx = tx.Where("`docker_repo_name` = ?", req.DockerRepoName) + } + if req.NameAliasLike != "" { + tx = tx.Where("`name` LIKE ? OR `alias` LIKE ? OR `description` LIKE ?", sqlite.MakeLikeParams(req.NameAliasLike), sqlite.MakeLikeParams(req.NameAliasLike), sqlite.MakeLikeParams(req.NameAliasLike)) + } + if req.NoInIds != "" { + tx = tx.Where("`id` NOT IN ?", dtos.ApiParamsStringToArray(req.NoInIds)) + } + if req.ImageIds != "" { + tx = tx.Where("`docker_image_id` IN ?", dtos.ApiParamsStringToArray(req.ImageIds)) + } + if req.NoInImageIds != "" { + tx = tx.Where("`docker_image_id` NOT IN ?", dtos.ApiParamsStringToArray(req.NoInImageIds)) + } + if req.DriverType != 0 { + tx = tx.Where("`driver_type` = ?", req.DriverType) + } + if req.ClassifyId != 0 { + tx = tx.Where("`classify_id` = ?", req.ClassifyId) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.DeviceLibrary{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceLibraries failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Order("id desc").Find(&deviceLibraries).Error + if err != nil { + return []models.DeviceLibrary{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceLibraries failed query from the database", err) + } + + return deviceLibraries, uint32(total), nil +} + +func updateDeviceLibrary(c *Client, dl models.DeviceLibrary) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return err + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/deviceservice.go b/internal/hummingbird/core/infrastructure/sqlite/deviceservice.go new file mode 100644 index 0000000..6d74c12 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/deviceservice.go @@ -0,0 +1,149 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func deviceServicesSearch(c *Client, offset int, limit int, req dtos.DeviceServiceSearchQueryRequest) (deviceServices []models.DeviceService, count uint32, edgeXErr error) { + ds := models.DeviceService{} + var total int64 + tx := c.Pool.Table(ds.TableName()) + tx = sqlite.BuildCommonCondition(tx, ds, req.BaseSearchConditionQuery) + + if req.DeviceLibraryId != "" { + tx = tx.Where("`device_library_id` = ?", req.DeviceLibraryId) + } + if req.DeviceLibraryIds != "" { + tx = tx.Where("`device_library_id` IN ?", dtos.ApiParamsStringToArray(req.DeviceLibraryIds)) + } + if req.DriverType != 0 { + tx = tx.Where("`driver_type` = ?", req.DriverType) + } + if req.CloudProductId != "" { + tx = tx.Where("`cloud_product_id` = ?", req.CloudProductId) + } + if req.ProductId != "" { + tx = tx.Where("`product_id` = ?", req.ProductId) + } + if req.Platform != "" { + tx = tx.Where("`platform` = ?", req.Platform) + + } + err := tx.Count(&total).Error + if err != nil { + return []models.DeviceService{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceServices failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&deviceServices).Error + if err != nil { + return []models.DeviceService{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "deviceServices failed query from the database", err) + } + + return deviceServices, uint32(total), nil +} + +func deviceServiceIdExist(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.DeviceService{Id: id}) + if err != nil { + return false, err + } + return exists, nil +} + +func addDeviceService(c *Client, ds models.DeviceService) (addedDeviceService models.DeviceService, edgeXErr error) { + exists, edgeXErr := deviceServiceIdExist(c, ds.Id) + if edgeXErr != nil { + return ds, edgeXErr + } else if exists { + return ds, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("device service id %s exists", ds.Id), edgeXErr) + } + + exists, edgeXErr = deviceServiceIdExist(c, ds.Name) + if edgeXErr != nil { + return ds, edgeXErr + } else if exists { + return ds, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("device service name %s exists", ds.Name), edgeXErr) + } + + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "device service creation failed", err) + } + + return ds, edgeXErr +} + +func updateDeviceService(c *Client, ds models.DeviceService) error { + ds.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&ds) + if err != nil { + return err + } + return nil +} + +func deviceServiceById(c *Client, id string) (deviceService models.DeviceService, edgeXErr error) { + if id == "" { + return deviceService, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device service id is empty", nil) + } + err := c.client.GetObject(&models.DeviceService{Id: id}, &deviceService) + if err != nil { + if err == gorm.ErrRecordNotFound { + return deviceService, errort.NewCommonErr(errort.DeviceServiceNotExist, fmt.Errorf("device service id(%s) not found", id)) + } + return deviceService, err + } + return +} + +func deleteDeviceServiceById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device service id is empty", nil) + } + err := c.client.DeleteObject(&models.DeviceService{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "device service deletion failed", err) + } + return nil +} + +func driverMqttAuthInfo(c *Client, id string) (mqttAuth models.MqttAuth, edgeXErr error) { + if id == "" { + return mqttAuth, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "device id is empty", nil) + } + err := c.client.GetObject(&models.MqttAuth{ResourceId: id, ResourceType: constants.DriverResource}, &mqttAuth) + if err != nil { + if err == gorm.ErrRecordNotFound { + return mqttAuth, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("mqtt auth resoure id(%s) not found", id)) + } + return mqttAuth, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query mqtt auth fail (resoureId:%s), %s", id, err)) + } + return +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/dockerconfig.go b/internal/hummingbird/core/infrastructure/sqlite/dockerconfig.go new file mode 100644 index 0000000..a441eca --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/dockerconfig.go @@ -0,0 +1,107 @@ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + + "gorm.io/gorm" + //"gitlab.com/tedge/edgex/internal/dtos" + //"gitlab.com/tedge/edgex/internal/models" + //"gitlab.com/tedge/edgex/internal/pkg/errort" + //"gitlab.com/tedge/edgex/internal/pkg/utils" + //"gitlab.com/tedge/edgex/internal/tools/sqldb/sqlite" +) + +func dockerConfigIdExists(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.DockerConfig{ + Id: id, + }) + if err != nil { + return false, err + } + return exists, nil +} + +func dockerConfigAdd(c *Client, dc models.DockerConfig) (models.DockerConfig, error) { + exists, edgeXErr := dockerConfigIdExists(c, dc.Id) + if edgeXErr != nil { + return dc, edgeXErr + } else if exists { + return dc, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("docker config %s exists", dc.Id), edgeXErr) + } + ts := utils.MakeTimestamp() + if dc.Created == 0 { + dc.Created = ts + } + dc.Modified = ts + + err := c.client.CreateObject(&dc) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "func point creation failed", err) + return dc, edgeXErr + } + + return dc, edgeXErr +} + +func dockerConfigById(c *Client, id string) (dc models.DockerConfig, edgeXErr error) { + if id == "" { + return dc, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "docker config id is empty", nil) + } + err := c.client.GetObject(&models.DockerConfig{Id: id}, &dc) + if err != nil { + if err == gorm.ErrRecordNotFound { + return dc, errort.NewCommonErr(errort.DockerConfigNotExist, fmt.Errorf("docker config id(%s)not found", id)) + } + return dc, err + } + return +} + +func dockerConfigDeleteById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "docker config id is empty", nil) + } + rawErr := c.client.DeleteObject(&models.DockerConfig{Id: id}) + if rawErr != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "docker config deletion failed", rawErr) + } + return nil +} + +func dockerConfigUpdate(c *Client, dc models.DockerConfig) error { + dc.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dc) + if err != nil { + return err + } + return nil +} + +func dockerConfigsSearch(c *Client, offset int, limit int, req dtos.DockerConfigSearchQueryRequest) (dcs []models.DockerConfig, count uint32, edgeXErr error) { + d := models.DockerConfig{} + var total int64 + tx := c.Pool.Table(d.TableName()) + tx = sqlite.BuildCommonCondition(tx, d, req.BaseSearchConditionQuery) + if req.Address != "" { + tx = tx.Where("`address` = ?", req.Address) + } + if req.Account != "" { + tx = tx.Where("`account` = ?", req.Account) + } + err := tx.Count(&total).Error + if err != nil { + return []models.DockerConfig{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&dcs).Error + if err != nil { + return []models.DockerConfig{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + + return dcs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/docs.go b/internal/hummingbird/core/infrastructure/sqlite/docs.go new file mode 100644 index 0000000..133f59d --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/docs.go @@ -0,0 +1,60 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func docsSearch(c *Client, offset int, limit int, req dtos.DocsSearchQueryRequest) (docs []models.Doc, count uint32, edgeXErr error) { + dp := models.Doc{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Doc{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "docs failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&docs).Error + if err != nil { + return []models.Doc{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "docs failed query from the database", err) + } + + return docs, uint32(total), nil +} + +func batchUpsertDocsTemplate(c *Client, d []models.Doc) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/driverclassify.go b/internal/hummingbird/core/infrastructure/sqlite/driverclassify.go new file mode 100644 index 0000000..014f673 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/driverclassify.go @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" +) + +func driverClassifySearch(c *Client, offset int, limit int, req dtos.DriverClassifyQueryRequest) (dcs []models.DriverClassify, count uint32, edgeXErr error) { + d := models.DriverClassify{} + var total int64 + tx := c.Pool.Table(d.TableName()) + tx = sqlite.BuildCommonCondition(tx, d, req.BaseSearchConditionQuery) + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + err := tx.Count(&total).Error + if err != nil { + return []models.DriverClassify{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + err = tx.Offset(offset).Limit(limit).Find(&dcs).Error + if err != nil { + return []models.DriverClassify{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "device failed query from the database", err) + } + return dcs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/language.go b/internal/hummingbird/core/infrastructure/sqlite/language.go new file mode 100644 index 0000000..8cefbf6 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/language.go @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func languageSearch(c *Client, offset int, limit int, req dtos.LanguageSDKSearchQueryRequest) (languages []models.LanguageSdk, count uint32, edgeXErr error) { + cs := models.LanguageSdk{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + err := tx.Count(&total).Error + if err != nil { + return []models.LanguageSdk{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "language sdk failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&languages).Error + if err != nil { + return []models.LanguageSdk{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "language sdk failed query from the database", err) + } + + return languages, uint32(total), nil +} + +func languageByName(c *Client, name string) (language models.LanguageSdk, edgeXErr error) { + if name == "" { + return language, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "language sdk name id is empty", nil) + } + err := c.client.GetObject(&models.LanguageSdk{Name: name}, &language) + if err != nil { + if err == gorm.ErrRecordNotFound { + return language, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("language sdk (%s) not found", name)) + } + return language, err + } + return +} + +func addLanguageSdk(c *Client, cs models.LanguageSdk) (language models.LanguageSdk, edgeXErr error) { + ts := utils.MakeTimestamp() + if cs.Created == 0 { + cs.Created = ts + } + cs.Modified = ts + + err := c.client.CreateObject(&cs) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "language creation failed", err) + } + + return cs, edgeXErr +} + +func updateLanguageSdk(c *Client, dl models.LanguageSdk) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return err + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/monitor.go b/internal/hummingbird/core/infrastructure/sqlite/monitor.go new file mode 100644 index 0000000..ddea243 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/monitor.go @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" +) + +func updateSystemMetrics(c *Client, metrics dtos.SystemMetrics) error { + var m = models.SystemMetrics{ + Data: metrics.String(), + Timestamp: metrics.Timestamp, + } + return c.client.CreateObject(&m) +} + +func getSystemMetrics(c *Client, start, end int64) ([]dtos.SystemMetrics, error) { + var list []models.SystemMetrics + if err := c.Pool.Where("timestamp >= ? and timestamp <= ?", start, end).Find(&list).Error; err != nil { + return nil, err + } + var metrics = make([]dtos.SystemMetrics, 0) + for _, item := range list { + m, err := dtos.FromModelsSystemMetricsToDTO(item) + if err != nil { + return nil, err + } + metrics = append(metrics, m) + } + return metrics, nil +} + +func removeRangeSystemMetrics(c *Client, min, max string) error { + return c.Pool.Where("timestamp >= ? and timestamp <= ?", min, max).Delete(&models.SystemMetrics{}).Error +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/msggather.go b/internal/hummingbird/core/infrastructure/sqlite/msggather.go new file mode 100644 index 0000000..26f8c5d --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/msggather.go @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" +) + +func addMsgGather(c *Client, msgGather models.MsgGather) error { + ts := utils.MakeTimestamp() + if msgGather.Created == 0 { + msgGather.Created = ts + } + msgGather.Modified = ts + + err := c.client.CreateObject(&msgGather) + if err != nil { + edgeXErr := errort.NewCommonEdgeX(errort.DefaultSystemError, "add msg gather failed", err) + return edgeXErr + } + + return nil +} + +func msgGatherSearch(c *Client, offset int, limit int, req dtos.MsgGatherSearchQueryRequest) (dcs []models.MsgGather, count uint32, edgeXErr error) { + d := models.MsgGather{} + var total int64 + tx := c.Pool.Table(d.TableName()) + tx = sqlite.BuildCommonCondition(tx, d, req.BaseSearchConditionQuery) + + if len(req.Date) > 0 { + tx = tx.Where("`date` in (?)", req.Date) + } + err := tx.Count(&total).Error + if err != nil { + return []models.MsgGather{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "msg gather failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&dcs).Error + if err != nil { + return []models.MsgGather{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "msg gather failed query from the database", err) + } + + return dcs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/product.go b/internal/hummingbird/core/infrastructure/sqlite/product.go new file mode 100644 index 0000000..bb156ca --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/product.go @@ -0,0 +1,317 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func productNameExist(c *Client, name string) (bool, error) { + exists, err := c.client.ExistObject(&models.Product{Name: name}) + if err != nil { + return false, err + } + return exists, nil +} + +func productIdExist(c *Client, id string) (bool, error) { + exists, err := c.client.ExistObject(&models.Product{Id: id}) + if err != nil { + return false, err + } + return exists, nil +} + +func addProduct(c *Client, ds models.Product) (product models.Product, edgeXErr error) { + exists, edgeXErr := productNameExist(c, ds.Name) + if edgeXErr != nil { + return ds, edgeXErr + } else if exists { + return ds, errort.NewCommonEdgeX(errort.DefaultResourcesRepeat, fmt.Sprintf("product name %s exists", ds.Id), edgeXErr) + } + + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "product creation failed", err) + } + + return ds, edgeXErr +} + +func productById(c *Client, id string) (product models.Product, edgeXErr error) { + if id == "" { + return product, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "product id is empty", nil) + } + err := c.client.GetPreloadObject(&models.Product{Id: id}, &product) + if err != nil { + if err == gorm.ErrRecordNotFound { + return product, errort.NewCommonErr(errort.ProductNotExist, fmt.Errorf("product id(%s) not found", id)) + } + return product, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query product fail (Id:%s), %s", product.Id, err)) + } + return +} + +func productByCloudId(c *Client, id string) (product models.Product, edgeXErr error) { + if id == "" { + return product, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "product id is empty", nil) + } + err := c.client.GetPreloadObject(&models.Product{CloudProductId: id}, &product) + if err != nil { + if err == gorm.ErrRecordNotFound { + return product, errort.NewCommonErr(errort.ProductNotExist, fmt.Errorf("product id(%s) not found", id)) + } + return product, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query product fail (Id:%s), %s", product.Id, err)) + } + //_ = c.cache.SetProduct(product) + return +} + +func productsSearch(c *Client, offset int, limit int, preload bool, req dtos.ProductSearchQueryRequest) (products []models.Product, count uint32, edgeXErr error) { + dp := models.Product{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + + if req.Platform != "" { + tx = tx.Where("`platform` = ?", req.Platform) + } + + if req.CloudInstanceId != "" { + tx = tx.Where("`cloud_instance_id` = ?", req.CloudInstanceId) + } + + if req.ProductId != "" { + tx = tx.Where("`product_id` = ?", req.ProductId) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Product{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "products failed query from the database", err) + } + if preload { + err = tx.Offset(offset).Preload("Properties").Preload("Events").Preload("Actions").Limit(limit).Find(&products).Error + } else { + err = tx.Offset(offset).Limit(limit).Find(&products).Error + } + if err != nil { + return []models.Product{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "products failed query from the database", err) + } + + return products, uint32(total), nil +} + +func batchUpsertProduct(c *Client, d []models.Product) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchSaveProduct(c *Client, d []models.Product) error { + if len(d) <= 0 { + return nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).Save(d) + //num := tx.RowsAffected + err := tx.Error + if err != nil { + return err + } + return nil +} + +func batchDeleteProduct(c *Client, products []models.Product) error { + d := models.Product{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(&products).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product batchDeleteProduct failed", err) + } + return nil +} + +func deleteProductById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "product id is empty", nil) + } + err := c.client.DeleteObject(&models.Product{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product deletion failed", err) + } + return nil +} + +func deleteProductObject(c *Client, product models.Product) error { + err := c.client.DeleteObject(&product) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product deletion failed", err) + } + return nil +} + +func associationsDeleteProductObject(c *Client, product models.Product) error { + err := c.client.AssociationsDeleteObject(&product) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product deletion failed", err) + } + return nil +} + +func updateProduct(c *Client, dl models.Product) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product update failed", err) + } + return nil +} + +func associationsUpdateProduct(c *Client, dl models.Product) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.AssociationsUpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "product update failed", err) + } + return nil +} + +func batchDeleteProperties(c *Client, propertiesIds []string) error { + d := models.Properties{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, propertiesIds).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteProperties failed", err) + } + return nil +} + +func batchDeleteSystemProperties(c *Client) error { + d := models.Properties{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where(models.Properties{System: true}).Delete(d).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDelete system property failed", err) + } + return nil +} + +func batchInsertSystemProperties(c *Client, p []models.Properties) (int64, error) { + if len(p) <= 0 { + return 0, nil + } + tx := c.Pool.CreateInBatches(p, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchDeleteEvents(c *Client, eventIds []string) error { + d := models.Events{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, eventIds).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteEvents failed", err) + } + return nil +} + +func batchDeleteSystemEvents(c *Client) error { + d := models.Events{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where(models.Events{System: true}).Delete(d).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batch system Actions failed", err) + } + return nil +} + +func batchInsertSystemEvents(c *Client, p []models.Events) (int64, error) { + if len(p) <= 0 { + return 0, nil + } + tx := c.Pool.CreateInBatches(p, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func batchDeleteActions(c *Client, actionIds []string) error { + d := models.Actions{} + tx := c.Pool.Table(d.TableName()) + err := tx.Delete(d, actionIds).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batchDeleteActions failed", err) + } + return nil +} + +func batchDeleteSystemActions(c *Client) error { + d := models.Actions{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where(models.Actions{System: true}).Delete(d).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "batch Delete system Actions failed", err) + } + return nil +} + +func batchInsertSystemActions(c *Client, p []models.Actions) (int64, error) { + if len(p) <= 0 { + return 0, nil + } + tx := c.Pool.CreateInBatches(p, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/quicknavagation.go b/internal/hummingbird/core/infrastructure/sqlite/quicknavagation.go new file mode 100644 index 0000000..66be95a --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/quicknavagation.go @@ -0,0 +1,71 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func quickNavigationSearch(c *Client, offset int, limit int, req dtos.QuickNavigationSearchQueryRequest) (quickNavigation []models.QuickNavigation, count uint32, edgeXErr error) { + dp := models.QuickNavigation{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` = ?", req.Name) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.QuickNavigation{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "quick navigation failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&quickNavigation).Error + if err != nil { + return []models.QuickNavigation{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "quick navigation failed query from the database", err) + } + + return quickNavigation, uint32(total), nil +} + +func batchUpsertQuickNavigationTemplate(c *Client, d []models.QuickNavigation) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func deleteQuickNavigation(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "quick navigation id is empty", nil) + } + err := c.client.DeleteObject(&models.QuickNavigation{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "quick navigation deletion failed", err) + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/ruleengine.go b/internal/hummingbird/core/infrastructure/sqlite/ruleengine.go new file mode 100644 index 0000000..9e46a9c --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/ruleengine.go @@ -0,0 +1,120 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func addRuleEngine(c *Client, ds models.RuleEngine) (id string, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "rule engine creation failed", err) + } + + return ds.Id, edgeXErr +} + +func ruleEngineById(c *Client, id string) (ruleEngine models.RuleEngine, edgeXErr error) { + if id == "" { + return ruleEngine, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "rule engine id is empty", nil) + } + err := c.Pool.Table(ruleEngine.TableName()).Preload("DataResource").First(&ruleEngine, id).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return ruleEngine, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("rule engine id id(%s) not found", id)) + } + return ruleEngine, errort.NewCommonErr(errort.DefaultSystemError, fmt.Errorf("query rule engine id fail (Id:%s), %s", ruleEngine.Id, err)) + } + return +} + +func ruleEngineSearch(c *Client, offset int, limit int, req dtos.RuleEngineSearchQueryRequest) (ruleEngine []models.RuleEngine, count uint32, edgeXErr error) { + dp := models.RuleEngine{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + + err := tx.Count(&total).Error + if err != nil { + return ruleEngine, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "rules engine failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Preload("DataResource").Find(&ruleEngine).Error + if err != nil { + return ruleEngine, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "rules engine failed query from the database", err) + } + + return ruleEngine, uint32(total), nil +} + +func ruleEngineStart(c *Client, id string) error { + d := models.RuleEngine{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStart}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start alert rule failed", err) + } + return nil +} + +func ruleEngineStop(c *Client, id string) error { + d := models.RuleEngine{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.RuleStop}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "stop alert rule failed", err) + } + return nil +} + +func deleteRuleEngineById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "rule engine id is empty", nil) + } + err := c.client.DeleteObject(&models.RuleEngine{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "rule engine deletion failed", err) + } + return nil +} + +func updateRuleEngine(c *Client, ruleEngine models.RuleEngine) error { + ruleEngine.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&ruleEngine) + if err != nil { + return err + } + return nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/scene.go b/internal/hummingbird/core/infrastructure/sqlite/scene.go new file mode 100644 index 0000000..52f4ce7 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/scene.go @@ -0,0 +1,159 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" +) + +func addScene(c *Client, ds models.Scene) (scene models.Scene, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "scene creation failed", err) + } + return ds, edgeXErr +} + +func updateScene(c *Client, dl models.Scene) error { + dl.Modified = utils.MakeTimestamp() + err := c.client.UpdateObject(&dl) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "scene update failed", err) + } + return nil +} + +func sceneById(c *Client, id string) (scene models.Scene, err error) { + if id == "" { + return scene, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "scene id is empty", nil) + } + err = c.client.GetObject(&models.Scene{Id: id}, &scene) + if err != nil { + if err == gorm.ErrRecordNotFound { + return scene, errort.NewCommonErr(errort.DefaultResourcesNotFound, fmt.Errorf("scene id(%s) not found", id)) + } + return scene, err + } + return +} + +func sceneStart(c *Client, id string) error { + d := models.Scene{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.SceneStart}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start scene rule failed", err) + } + return nil +} + +func sceneStop(c *Client, id string) error { + d := models.Scene{} + tx := c.Pool.Table(d.TableName()) + err := tx.Where("id = ?", id).Updates(map[string]interface{}{"status": constants.SceneStop}).Error + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "start scene rule failed", err) + } + return nil +} + +func deleteSceneById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "del scene id is empty", nil) + } + err := c.client.DeleteObject(&models.Scene{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "del scene deletion failed", err) + } + return nil +} + +func sceneSearch(c *Client, offset int, limit int, req dtos.SceneSearchQueryRequest) (scene []models.Scene, count uint32, edgeXErr error) { + dp := models.Scene{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.Name != "" { + tx = tx.Where("`name` LIKE ?", sqlite.MakeLikeParams(req.Name)) + } + if req.Status != "" { + tx = tx.Where("`status` = ?", req.Status) + } + err := tx.Count(&total).Error + if err != nil { + return scene, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene search failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&scene).Error + if err != nil { + return scene, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene search failed query from the database", err) + } + + return scene, uint32(total), nil +} + +func addSceneLog(c *Client, ds models.SceneLog) (sceneLog models.SceneLog, edgeXErr error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "scene log creation failed", err) + } + return ds, edgeXErr +} + +func sceneLogSearch(c *Client, offset int, limit int, req dtos.SceneLogSearchQueryRequest) (sceneLogs []models.SceneLog, count uint32, edgeXErr error) { + dp := models.SceneLog{} + var total int64 + tx := c.Pool.Table(dp.TableName()) + tx = sqlite.BuildCommonCondition(tx, dp, req.BaseSearchConditionQuery) + + if req.StartAt > 0 && req.EndAt > 0 && req.EndAt-req.StartAt > 0 { + tx.Where("created > ?", req.StartAt).Where("created < ?", req.EndAt) + } + if req.SceneId != "" { + tx = tx.Where("`scene_id` = ?", req.SceneId) + } + + err := tx.Count(&total).Error + if err != nil { + return sceneLogs, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene log search failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&sceneLogs).Error + if err != nil { + return sceneLogs, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "scene log search failed query from the database", err) + } + + return sceneLogs, uint32(total), nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/thingmodel.go b/internal/hummingbird/core/infrastructure/sqlite/thingmodel.go new file mode 100644 index 0000000..4a91d93 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/thingmodel.go @@ -0,0 +1,230 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/utils" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func addThingModelProperty(c *Client, ds models.Properties) (models.Properties, error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + var edgeXErr error + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model property creation failed", err) + } + return ds, edgeXErr + +} +func batchUpsertThingModel(c *Client, d interface{}) (int64, error) { + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} + +func addThingModelEvent(c *Client, ds models.Events) (models.Events, error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + var edgeXErr error + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model event creation failed", err) + } + + return ds, edgeXErr + +} + +func addThingModelAction(c *Client, ds models.Actions) (models.Actions, error) { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + + var edgeXErr error + err := c.client.CreateObject(&ds) + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model action creation failed", err) + } + + return ds, edgeXErr +} + +func updateThingModelProperty(c *Client, ds models.Properties) error { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.UpdateObject(&ds) + var edgeXErr error + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model property update failed", err) + } + + return edgeXErr +} + +func updateThingModelEvent(c *Client, ds models.Events) error { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.UpdateObject(&ds) + var edgeXErr error + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model events update failed", err) + } + + return edgeXErr +} + +func updateThingModelAction(c *Client, ds models.Actions) error { + ts := utils.MakeTimestamp() + if ds.Created == 0 { + ds.Created = ts + } + ds.Modified = ts + err := c.client.UpdateObject(&ds) + var edgeXErr error + if err != nil { + edgeXErr = errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model action update failed", err) + } + + return edgeXErr +} + +func deleteThingModelPropertyById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "properties id is empty", nil) + } + err := c.client.DeleteObject(&models.Properties{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "properties deletion failed", err) + } + return nil +} + +func deleteThingModelEventById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "events id is empty", nil) + } + err := c.client.DeleteObject(&models.Events{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "events deletion failed", err) + } + return nil +} + +func deleteThingModelActionById(c *Client, id string) error { + if id == "" { + return errort.NewCommonEdgeX(errort.DefaultIdEmpty, "actions id is empty", nil) + } + err := c.client.DeleteObject(&models.Actions{Id: id}) + if err != nil { + return errort.NewCommonEdgeX(errort.DefaultSystemError, "actions deletion failed", err) + } + return nil +} + +func thingModelPropertyById(c *Client, id string) (models.Properties, error) { + cs := models.Properties{} + var properties models.Properties + tx := c.Pool.Table(cs.TableName()) + tx.Where("id = ?", id) + err := tx.Find(&properties).Error + return properties, err +} + +func thingModelEventById(c *Client, id string) (models.Events, error) { + cs := models.Events{} + var event models.Events + tx := c.Pool.Table(cs.TableName()) + tx.Where("id = ?", id) + err := tx.Find(&event).Error + return event, err +} + +func thingModeActionById(c *Client, id string) (models.Actions, error) { + cs := models.Actions{} + var action models.Actions + tx := c.Pool.Table(cs.TableName()) + tx.Where("id = ?", id) + err := tx.Find(&action).Error + return action, err +} + +func systemThingModelSearch(c *Client, modelType, modelName string) (interface{}, error) { + switch modelType { + case "property": + cs := models.Properties{} + var properties []models.Properties + tx := c.Pool.Table(cs.TableName()) + if modelName != "" { + tx.Where("system =1 and `name` LIKE ?", "%"+modelName+"%") + } else { + tx.Where("system =1") + } + err := tx.Find(&properties).Error + return properties, err + case "event": + cs := models.Events{} + var events []models.Events + tx := c.Pool.Table(cs.TableName()) + if modelName != "" { + tx.Where("system =1 and `name` LIKE ?", "%"+modelName+"%") + } else { + tx.Where("system =1") + } + err := tx.Find(&events).Error + return events, err + case "action": + cs := models.Actions{} + var actions []models.Actions + tx := c.Pool.Table(cs.TableName()) + if modelName != "" { + tx.Where("system =1 and `name` LIKE ?", "%"+modelName+"%") + } else { + tx.Where("system =1") + } + err := tx.Find(&actions).Error + return actions, err + + } + return nil, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/thingmodeltemplate.go b/internal/hummingbird/core/infrastructure/sqlite/thingmodeltemplate.go new file mode 100644 index 0000000..dd55f28 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/thingmodeltemplate.go @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func thingModelTemplateSearch(c *Client, offset int, limit int, req dtos.ThingModelTemplateRequest) (thingModelTemplate []models.ThingModelTemplate, count uint32, edgeXErr error) { + cs := models.ThingModelTemplate{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + if req.CategoryName != "" { + tx = tx.Where("`category_name` = ?", req.CategoryName) + } + + if req.CategoryKey != "" { + tx = tx.Where("`category_key` = ?", req.CategoryKey) + } + + err := tx.Count(&total).Error + if err != nil { + return []models.ThingModelTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model template failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&thingModelTemplate).Error + if err != nil { + return []models.ThingModelTemplate{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "thing model template failed query from the database", err) + } + + return thingModelTemplate, uint32(total), nil +} + +func thingModelTemplateByCategoryKey(c *Client, categoryKey string) (thingModelInfo models.ThingModelTemplate, edgeXErr error) { + if categoryKey == "" { + return thingModelInfo, errort.NewCommonEdgeX(errort.DefaultIdEmpty, "thing model template category key is empty", nil) + } + err := c.client.GetObject(&models.ThingModelTemplate{CategoryKey: categoryKey}, &thingModelInfo) + if err != nil { + if err == gorm.ErrRecordNotFound { + return thingModelInfo, errort.NewCommonErr(errort.ThingModelNotExist, fmt.Errorf("thing model template category key(%s) not found", categoryKey)) + } + return thingModelInfo, err + } + return +} + +func batchUpsertThingModelTemplate(c *Client, d []models.ThingModelTemplate) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/thinkmodel.go b/internal/hummingbird/core/infrastructure/sqlite/thinkmodel.go new file mode 100644 index 0000000..87b0a42 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/thinkmodel.go @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/models" +) + +func abilityByCode(c *Client, model interface{}, code, productId string) (interface{}, error) { + var err error + switch model.(type) { + case models.Properties: + ability := models.Properties{} + err = c.Pool.Model(&models.Properties{}).Where("code = ? and product_id = ?", code, productId).Find(&ability).Error + if err != nil { + return nil, err + } else { + return ability, nil + } + case models.Events: + ability := models.Events{} + err = c.Pool.Model(&models.Events{}).Where("code = ? and product_id = ?", code, productId).Find(&ability).Error + if err != nil { + return nil, err + } else { + return ability, nil + } + default: + return nil, fmt.Errorf("ability type shoud be propery or event") + } +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/unit.go b/internal/hummingbird/core/infrastructure/sqlite/unit.go new file mode 100644 index 0000000..9643986 --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/unit.go @@ -0,0 +1,63 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func unitSearch(c *Client, offset int, limit int, req dtos.UnitRequest) (units []models.Unit, count uint32, edgeXErr error) { + cs := models.Unit{} + var total int64 + tx := c.Pool.Table(cs.TableName()) + tx = sqlite.BuildCommonCondition(tx, cs, req.BaseSearchConditionQuery) + + if req.UnitName != "" { + tx = tx.Where("`unit_name` LIKE ?", "%"+req.UnitName+"%") + } + + err := tx.Count(&total).Error + if err != nil { + return []models.Unit{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "unit failed query from the database", err) + } + + err = tx.Offset(offset).Limit(limit).Find(&units).Error + if err != nil { + return []models.Unit{}, 0, errort.NewCommonEdgeX(errort.DefaultSystemError, "unit failed query from the database", err) + } + + return units, uint32(total), nil +} + +func batchUpsertUnitTemplate(c *Client, d []models.Unit) (int64, error) { + if len(d) <= 0 { + return 0, nil + } + tx := c.Pool.Session(&gorm.Session{FullSaveAssociations: true}).Clauses( + clause.OnConflict{ + UpdateAll: true, + }).CreateInBatches(d, sqlite.CreateBatchSize) + num := tx.RowsAffected + err := tx.Error + if err != nil { + return num, err + } + return num, nil +} diff --git a/internal/hummingbird/core/infrastructure/sqlite/user.go b/internal/hummingbird/core/infrastructure/sqlite/user.go new file mode 100644 index 0000000..3971a6f --- /dev/null +++ b/internal/hummingbird/core/infrastructure/sqlite/user.go @@ -0,0 +1,95 @@ +package sqlite + +import ( + "fmt" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + + //"gitlab.com/tedge/edgex/internal/pkg/errort" + // + //"gitlab.com/tedge/edgex/internal/models" + "gorm.io/gorm" +) + +func (c *Client) GetUserByUserName(username string) (models.User, error) { + user, edgeXErr := userByUserName(c, username) + if edgeXErr != nil { + return user, edgeXErr + } + return user, nil +} + +func (c *Client) GetAllUser() ([]models.User, error) { + return getAllUser(c) +} + +func (c *Client) AddUsers(users []models.User) error { + return addUsers(c, users) +} + +func (c *Client) AddUser(u models.User) (models.User, error) { + return addUser(c, u) +} + +func (c *Client) UpdateUser(u models.User) error { + return updateUser(c, u) +} + +func userByUserName(c *Client, username string) (models.User, error) { + user := models.User{} + err := c.client.GetObject(&models.User{Username: username}, &user) + if err != nil { + if err == gorm.ErrRecordNotFound { + return user, errort.NewCommonEdgeX(errort.AppPasswordError, fmt.Sprintf("fail to query username %s", username), err) + } else { + return user, err + } + } + return user, nil +} + +func getAllUser(c *Client) ([]models.User, error) { + var users []models.User + if err := c.Pool.Find(&users).Error; err != nil { + return nil, err + } + return users, nil +} + +func addUsers(c *Client, users []models.User) error { + if len(users) <= 0 { + return nil + } + return c.Pool.Create(users).Error +} + +func updateUser(c *Client, u models.User) error { + err := c.Pool.Table(u.TableName()).Where(&models.User{ + Username: u.Username, + }).Save(&u).Error + if err != nil { + return err + } + return nil +} + +func userExist(c *Client, username string) (bool, error) { + exists, err := c.client.ExistObject(&models.User{Username: username}) + if err != nil { + return false, err + } + return exists, nil +} + +func addUser(c *Client, u models.User) (models.User, error) { + exists, edgeXErr := userExist(c, u.Username) + if edgeXErr != nil { + return u, edgeXErr + } else if exists { + return u, errort.NewCommonEdgeX(errort.DefaultNameRepeat, fmt.Sprintf("username %s exists", u.Username), edgeXErr) + } + if err := c.client.CreateObject(&u); err != nil { + return u, errort.NewCommonEdgeX(errort.DefaultSystemError, "user creation failed", err) + } + return u, nil +} diff --git a/internal/hummingbird/core/initialize/init.go b/internal/hummingbird/core/initialize/init.go new file mode 100644 index 0000000..4a324a4 --- /dev/null +++ b/internal/hummingbird/core/initialize/init.go @@ -0,0 +1,357 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package initialize + +import ( + "context" + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/alertcentreapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/categorytemplate" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/dataresource" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/deviceapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/dmi" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/docapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/driverapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/driverserviceapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/homepageapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/languagesdkapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/messageapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/messagestore" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/monitor" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/persistence" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/productapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/quicknavigationapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/ruleengine" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/scene" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/thingmodelapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/thingmodeltemplate" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/timerapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/unittemplate" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/userapp" + "github.com/winc-link/hummingbird/internal/hummingbird/core/config" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/hummingbird/core/controller/rpcserver/driverserver" + interfaces "github.com/winc-link/hummingbird/internal/hummingbird/core/interface" + "github.com/winc-link/hummingbird/internal/hummingbird/core/route" + "github.com/winc-link/hummingbird/internal/pkg/constants" + pkgContainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/cos" + "github.com/winc-link/hummingbird/internal/pkg/crontab" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/handlers" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "github.com/winc-link/hummingbird/internal/pkg/startup" + "github.com/winc-link/hummingbird/internal/pkg/timer/jobrunner" + "github.com/winc-link/hummingbird/internal/tools/ekuiperclient" + "github.com/winc-link/hummingbird/internal/tools/hpcloudclient" + "github.com/winc-link/hummingbird/internal/tools/notify/sms" + "github.com/winc-link/hummingbird/internal/tools/streamclient" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" + "sync" +) + +// Bootstrap contains references to dependencies required by the BootstrapHandler. +type Bootstrap struct { + router *gin.Engine +} + +// NewBootstrap is a factory method that returns an initialized Bootstrap receiver struct. +func NewBootstrap(router *gin.Engine) *Bootstrap { + return &Bootstrap{ + router: router, + } +} + +func (b *Bootstrap) BootstrapHandler(ctx context.Context, wg *sync.WaitGroup, _ startup.Timer, dic *di.Container) bool { + + configuration := container.ConfigurationFrom(dic.Get) + lc := pkgContainer.LoggingClientFrom(dic.Get) + + if !b.initClient(ctx, wg, dic, configuration, lc) { + return false + } + + // rpc 服务 + if ok := initRPCServer(ctx, wg, dic); !ok { + return false + } + lc.Infof("init rpc server") + + // http 路由 + route.LoadRestRoutes(b.router, dic) + + // 业务逻辑 + application.InitSchedule(dic, lc) + + wg.Add(1) + go func() { + defer wg.Done() + + <-ctx.Done() + crontab.Stop() + }() + + return true +} + +func (b *Bootstrap) initClient(ctx context.Context, wg *sync.WaitGroup, dic *di.Container, configuration *config.ConfigurationStruct, lc logger.LoggingClient) bool { + + appMode, err := dmi.New(dic, ctx, wg, dtos.DriverConfigManage{ + DockerManageConfig: dtos.DockerManageConfig{ + ContainerConfigPath: configuration.DockerManage.ContainerConfigPath, + DockerApiVersion: configuration.DockerManage.DockerApiVersion, + DockerRunMode: constants.NetworkModeHost, + DockerSelfName: constants.CoreServiceName, + Privileged: configuration.DockerManage.Privileged, + }, + }) + if err != nil { + lc.Error("create driver model interface error %v", err) + return false + } + + dic.Update(di.ServiceConstructorMap{ + interfaces.DriverModelInterfaceName: func(get di.Get) interface{} { + return appMode + }, + }) + homePageApp := homepageapp.NewHomePageApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.HomePageAppName: func(get di.Get) interface{} { + return homePageApp + }, + }) + + languageApp := languagesdkapp.NewLanguageSDKApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.LanguageSDKAppName: func(get di.Get) interface{} { + return languageApp + }, + }) + + monitorApp := monitor.NewMonitor(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.MonitorAppName: func(get di.Get) interface{} { + return monitorApp + }, + }) + + streamClient := streamclient.NewStreamClient(lc) + dic.Update(di.ServiceConstructorMap{ + pkgContainer.StreamClientName: func(get di.Get) interface{} { + return streamClient + }, + }) + + driverApp := driverapp.NewDriverApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.DriverAppName: func(get di.Get) interface{} { + return driverApp + }, + }) + + driverServiceApp := driverserviceapp.NewDriverServiceApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.DriverServiceAppName: func(get di.Get) interface{} { + return driverServiceApp + }, + }) + + productApp := productapp.NewProductApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.ProductAppName: func(get di.Get) interface{} { + return productApp + }, + }) + + thingModelApp := thingmodelapp.NewThingModelApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.ThingModelAppName: func(get di.Get) interface{} { + return thingModelApp + }, + }) + + deviceApp := deviceapp.NewDeviceApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.DeviceItfName: func(get di.Get) interface{} { + return deviceApp + }, + }) + + alertCentreApp := alertcentreapp.NewAlertCentreApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.AlertRuleAppName: func(get di.Get) interface{} { + return alertCentreApp + }, + }) + + ruleEngineApp := ruleengine.NewRuleEngineApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.RuleEngineAppName: func(get di.Get) interface{} { + return ruleEngineApp + }, + }) + + sceneApp := scene.NewSceneApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.SceneAppName: func(get di.Get) interface{} { + return sceneApp + }, + }) + + conJobApp := timerapp.NewCronTimer(ctx, jobrunner.NewJobRunFunc(dic), dic) + dic.Update(di.ServiceConstructorMap{ + container.ConJobAppName: func(get di.Get) interface{} { + return conJobApp + }, + }) + + dataResourceApp := dataresource.NewDataResourceApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.DataResourceName: func(get di.Get) interface{} { + return dataResourceApp + }, + }) + + cosApp := cos.NewCos("", "", "") + dic.Update(di.ServiceConstructorMap{ + container.CosAppName: func(get di.Get) interface{} { + return cosApp + }, + }) + + categoryTemplateApp := categorytemplate.NewCategoryTemplateApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.CategoryTemplateAppName: func(get di.Get) interface{} { + return categoryTemplateApp + }, + }) + + unitTemplateApp := unittemplate.NewUnitTemplateApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.UnitTemplateAppName: func(get di.Get) interface{} { + return unitTemplateApp + }, + }) + + docsApp := docapp.NewDocsApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.DocsAppName: func(get di.Get) interface{} { + return docsApp + }, + }) + + quickNavigationApp := quicknavigationapp.NewQuickNavigationApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.QuickNavigationAppName: func(get di.Get) interface{} { + return quickNavigationApp + }, + }) + + thingModelTemplateApp := thingmodeltemplate.NewThingModelTemplateApp(ctx, dic) + dic.Update(di.ServiceConstructorMap{ + container.ThingModelTemplateAppName: func(get di.Get) interface{} { + return thingModelTemplateApp + }, + }) + + hpcloudServiceApp := hpcloudclient.NewHpcloud(lc) + dic.Update(di.ServiceConstructorMap{ + container.HpcServiceAppName: func(get di.Get) interface{} { + return hpcloudServiceApp + }, + }) + + smsServiceApp := sms.NewSmsClient(lc, "", + "", "") + dic.Update(di.ServiceConstructorMap{ + container.SmsServiceAppName: func(get di.Get) interface{} { + return smsServiceApp + }, + }) + + limitMethodApp := application.NewLimitMethodConf(*configuration) + dic.Update(di.ServiceConstructorMap{ + pkgContainer.LimitMethodConfName: func(get di.Get) interface{} { + return limitMethodApp + }, + }) + + ekuiperApp := ekuiperclient.New(configuration.Clients["Ekuiper"].Address(), lc) + dic.Update(di.ServiceConstructorMap{ + container.EkuiperAppName: func(get di.Get) interface{} { + return ekuiperApp + }, + }) + + //agentApp := agentclient.New(configuration.Clients["Agent"].Address()) + //dic.Update(di.ServiceConstructorMap{ + // container.AgentClientName: func(get di.Get) interface{} { + // return agentApp + // }, + //}) + // + //cacheClient := localcache.NewRamCacheClient() + //dic.Update(di.ServiceConstructorMap{ + // pkgContainer.CacheFuncName: func(get di.Get) interface{} { + // return cacheClient + // }, + //}) + + persistItf := persistence.NewPersistApp(dic) + dic.Update(di.ServiceConstructorMap{ + container.PersistItfName: func(get di.Get) interface{} { + return persistItf + }, + }) + + userItf := userapp.New(dic) + dic.Update(di.ServiceConstructorMap{ + container.UserItfName: func(get di.Get) interface{} { + return userItf + }, + }) + + messageItf := messageapp.NewMessageApp(dic) + dic.Update(di.ServiceConstructorMap{ + container.MessageItfName: func(get di.Get) interface{} { + return messageItf + }, + }) + + messageStoreItf := messagestore.NewMessageStore(dic) + dic.Update(di.ServiceConstructorMap{ + container.MessageStoreItfName: func(get di.Get) interface{} { + return messageStoreItf + }, + }) + return true +} + +func initRPCServer(ctx context.Context, wg *sync.WaitGroup, dic *di.Container) bool { + lc := pkgContainer.LoggingClientFrom(dic.Get) + _, err := handlers.NewRPCServer(ctx, wg, dic, func(serve *grpc.Server) { + driverserver.RegisterRPCService(lc, dic, serve) + reflection.Register(serve) + }) + if err != nil { + lc.Errorf("initRPCServer err:%v", err) + return false + } + return true +} diff --git a/internal/hummingbird/core/initialize/initdata.go b/internal/hummingbird/core/initialize/initdata.go new file mode 100644 index 0000000..5e30bc7 --- /dev/null +++ b/internal/hummingbird/core/initialize/initdata.go @@ -0,0 +1,122 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package initialize + +import ( + "context" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + pkgContainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/pkg/logger" + "sync" +) + +func initApp(ctx context.Context, dic *di.Container) bool { + lc := pkgContainer.LoggingClientFrom(dic.Get) + dbClient := container.DBClientFrom(dic.Get) + _, edgeXErr := dbClient.GetUserByUserName("admin") + if edgeXErr != nil { + if errort.Is(errort.AppPasswordError, edgeXErr) { + var wg sync.WaitGroup + wg.Add(6) + go syncQuickNavigation(&wg, dic, lc) + go syncDocTemplate(&wg, dic, lc) + go syncUnitTemplate(&wg, dic, lc) + go syncCategory(&wg, dic, lc) + go syncThingModel(&wg, dic, lc) + go syncDocuments(&wg, dic, lc) + //go initEkuiperStreams(&wg, dic, lc) + wg.Wait() + lc.Infof("initApp end...") + } + + } + return true +} + +func syncCategory(wg *sync.WaitGroup, dic *di.Container, lc logger.LoggingClient) { + defer wg.Done() + categoryApp := container.CategoryTemplateAppFrom(dic.Get) + + _, err := categoryApp.Sync(context.Background(), "Ireland") + lc.Infof("sync category start...") + if err != nil { + lc.Errorf("sync category fail...") + } + lc.Infof("sync category success") + +} + +func syncUnitTemplate(wg *sync.WaitGroup, dic *di.Container, lc logger.LoggingClient) { + defer wg.Done() + unitTempApp := container.UnitTemplateAppFrom(dic.Get) + + _, err := unitTempApp.Sync(context.Background(), "Ireland") + lc.Infof("sync unit start") + + if err != nil { + lc.Errorf("sync unit fail") + } + lc.Infof("sync unit success") + +} + +func syncDocTemplate(wg *sync.WaitGroup, dic *di.Container, lc logger.LoggingClient) { + defer wg.Done() + docApp := container.DocsTemplateAppFrom(dic.Get) + + _, err := docApp.SyncDocs(context.Background(), "Ireland") + lc.Infof("sync doc start") + + if err != nil { + lc.Errorf("sync doc fail") + } + lc.Infof("sync doc success") + +} + +func syncQuickNavigation(wg *sync.WaitGroup, dic *di.Container, lc logger.LoggingClient) { + defer wg.Done() + quickApp := container.QuickNavigationAppTemplateAppFrom(dic.Get) + _, err := quickApp.SyncQuickNavigation(context.Background(), "Ireland") + lc.Infof("sync quickNavigation start") + if err != nil { + lc.Errorf("sync quickNavigation fail...", err.Error()) + } + lc.Infof("sync quickNavigation success") +} + +func syncThingModel(wg *sync.WaitGroup, dic *di.Container, lc logger.LoggingClient) { + defer wg.Done() + thingModelApp := container.ThingModelTemplateAppFrom(dic.Get) + _, err := thingModelApp.Sync(context.Background(), "Ireland") + lc.Infof("sync thingModel start") + if err != nil { + lc.Errorf("sync thingModel fail...", err.Error()) + } + lc.Infof("sync thingModel success") +} + +func syncDocuments(wg *sync.WaitGroup, dic *di.Container, lc logger.LoggingClient) { + defer wg.Done() + languageApp := container.LanguageAppNameFrom(dic.Get) + err := languageApp.Sync(context.Background(), "Ireland") + lc.Infof("sync language start") + if err != nil { + lc.Errorf("sync language fail...", err.Error()) + } + lc.Infof("sync language success") +} diff --git a/internal/hummingbird/core/interface/alertrule.go b/internal/hummingbird/core/interface/alertrule.go new file mode 100644 index 0000000..2cf2eb9 --- /dev/null +++ b/internal/hummingbird/core/interface/alertrule.go @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + //"github.com/winc-link/hummingbird/internal/dtos" +) + +type AlertRuleApp interface { + AddAlertRule(ctx context.Context, req dtos.RuleAddRequest) (string, error) + UpdateAlertRule(ctx context.Context, req dtos.RuleUpdateRequest) error + UpdateAlertField(ctx context.Context, req dtos.RuleFieldUpdate) error + AlertRuleById(ctx context.Context, id string) (dtos.RuleResponse, error) + AlertRulesSearch(ctx context.Context, req dtos.AlertRuleSearchQueryRequest) ([]dtos.AlertRuleSearchQueryResponse, uint32, error) + AlertRulesDelete(ctx context.Context, id string) error + AlertRulesStop(ctx context.Context, id string) error + AlertRulesStart(ctx context.Context, id string) error + AlertRulesRestart(ctx context.Context, id string) error + AlertIgnore(ctx context.Context, id string) error + TreatedIgnore(ctx context.Context, id, message string) error + AlertPlate(ctx context.Context, beforeTime int64) ([]dtos.AlertPlateQueryResponse, error) + AlertSearch(ctx context.Context, req dtos.AlertSearchQueryRequest) ([]dtos.AlertSearchQueryResponse, uint32, error) + AddAlert(ctx context.Context, req map[string]interface{}) error + CheckRuleByProductId(ctx context.Context, productId string) error + CheckRuleByDeviceId(ctx context.Context, deviceId string) error +} + +type RuleEngineApp interface { + AddRuleEngine(ctx context.Context, req dtos.RuleEngineRequest) (string, error) + UpdateRuleEngine(ctx context.Context, req dtos.RuleEngineUpdateRequest) error + UpdateRuleEngineField(ctx context.Context, req dtos.RuleEngineFieldUpdateRequest) error + RuleEngineById(ctx context.Context, id string) (dtos.RuleEngineResponse, error) + RuleEngineSearch(ctx context.Context, req dtos.RuleEngineSearchQueryRequest) ([]dtos.RuleEngineSearchQueryResponse, uint32, error) + RuleEngineDelete(ctx context.Context, id string) error + RuleEngineStop(ctx context.Context, id string) error + RuleEngineStart(ctx context.Context, id string) error + RuleEngineStatus(ctx context.Context, id string) (map[string]interface{}, error) +} diff --git a/internal/hummingbird/core/interface/categroy.go b/internal/hummingbird/core/interface/categroy.go new file mode 100644 index 0000000..c33398e --- /dev/null +++ b/internal/hummingbird/core/interface/categroy.go @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" +) + +type CategoryApp interface { + CategoryTemplateSearch(ctx context.Context, req dtos.CategoryTemplateRequest) ([]dtos.CategoryTemplateResponse, uint32, error) + Sync(ctx context.Context, versionName string) (int64, error) +} + +type UnitApp interface { + UnitTemplateSearch(ctx context.Context, req dtos.UnitRequest) ([]dtos.UnitResponse, uint32, error) + Sync(ctx context.Context, versionName string) (int64, error) +} + +type DocsApp interface { + SyncDocs(ctx context.Context, versionName string) (int64, error) +} + +type QuickNavigation interface { + SyncQuickNavigation(ctx context.Context, versionName string) (int64, error) +} + +type ThingModelTemplateApp interface { + ThingModelTemplateSearch(ctx context.Context, req dtos.ThingModelTemplateRequest) ([]dtos.ThingModelTemplateResponse, uint32, error) + ThingModelTemplateByCategoryKey(ctx context.Context, categoryKey string) (dtos.ThingModelTemplateResponse, error) + Sync(ctx context.Context, versionName string) (int64, error) +} diff --git a/internal/hummingbird/core/interface/cosapp.go b/internal/hummingbird/core/interface/cosapp.go new file mode 100644 index 0000000..5cca412 --- /dev/null +++ b/internal/hummingbird/core/interface/cosapp.go @@ -0,0 +1,20 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package interfaces + +type CosApp interface { + Get(name string) ([]byte, error) + DownloadFiled(name, filepath string) error +} diff --git a/internal/hummingbird/core/interface/datadb.go b/internal/hummingbird/core/interface/datadb.go new file mode 100644 index 0000000..0c8a29b --- /dev/null +++ b/internal/hummingbird/core/interface/datadb.go @@ -0,0 +1,46 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" +) + +type DataDBClient interface { + GetDataDBType() constants.DataType + CloseSession() + + Insert(ctx context.Context, table string, data map[string]interface{}) (err error) + GetDeviceProperty(req dtos.ThingModelPropertyDataRequest, device models.Device) ([]dtos.ReportData, int, error) + GetDeviceService(req dtos.ThingModelServiceDataRequest, device models.Device, product models.Product) ([]dtos.SaveServiceIssueData, int, error) + GetDeviceEvent(req dtos.ThingModelEventDataRequest, device models.Device, product models.Product) ([]dtos.EventData, int, error) + + CreateTable(ctx context.Context, stable, table string) (err error) + DropTable(ctx context.Context, table string) (err error) + + CreateStable(ctx context.Context, product models.Product) (err error) + DropStable(ctx context.Context, table string) (err error) + + AddDatabaseField(ctx context.Context, tableName string, specsType constants.SpecsType, code string, name string) (err error) + DelDatabaseField(ctx context.Context, tableName, code string) (err error) + ModifyDatabaseField(ctx context.Context, tableName string, specsType constants.SpecsType, code string, name string) (err error) + + GetDevicePropertyCount(dtos.ThingModelPropertyDataRequest) (int, error) + GetDeviceEventCount(req dtos.ThingModelEventDataRequest) (int, error) + GetDeviceMsgCountByGiveTime(deviceId string, startTime, endTime int64) (int, error) +} diff --git a/internal/hummingbird/core/interface/dataresource.go b/internal/hummingbird/core/interface/dataresource.go new file mode 100644 index 0000000..5e488a1 --- /dev/null +++ b/internal/hummingbird/core/interface/dataresource.go @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/constants" +) + +type DataResourceApp interface { + AddDataResource(ctx context.Context, req dtos.AddDataResourceReq) (string, error) + DataResourceById(ctx context.Context, id string) (models.DataResource, error) + UpdateDataResource(ctx context.Context, req dtos.UpdateDataResource) error + DelDataResourceById(ctx context.Context, id string) error + DataResourceSearch(ctx context.Context, req dtos.DataResourceSearchQueryRequest) ([]models.DataResource, uint32, error) + DataResourceType(ctx context.Context) []constants.DataResourceType + DataResourceHealth(ctx context.Context, resourceId string) error +} diff --git a/internal/hummingbird/core/interface/db.go b/internal/hummingbird/core/interface/db.go new file mode 100644 index 0000000..8d13430 --- /dev/null +++ b/internal/hummingbird/core/interface/db.go @@ -0,0 +1,187 @@ +// +// Copyright (C) 2020-2021 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package interfaces + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "gorm.io/gorm" +) + +type DBClient interface { + CloseSession() + GetDBInstance() *gorm.DB + + QuickNavigationSearch(offset int, limit int, req dtos.QuickNavigationSearchQueryRequest) ([]models.QuickNavigation, uint32, error) + DocsSearch(offset int, limit int, req dtos.DocsSearchQueryRequest) (docs []models.Doc, total uint32, edgeXErr error) + BatchUpsertDocsTemplate(d []models.Doc) (int64, error) + BatchUpsertQuickNavigationTemplate(ds []models.QuickNavigation) (int64, error) + DeleteQuickNavigation(id string) error + + AddDeviceService(ds models.DeviceService) (models.DeviceService, error) + DeviceServiceById(id string) (models.DeviceService, error) + DeleteDeviceServiceById(id string) error + UpdateDeviceService(ds models.DeviceService) error + DeviceServicesSearch(offset int, limit int, req dtos.DeviceServiceSearchQueryRequest) ([]models.DeviceService, uint32, error) + + AddProduct(ds models.Product) (models.Product, error) + ProductsSearch(offset int, limit int, preload bool, req dtos.ProductSearchQueryRequest) ([]models.Product, uint32, error) + ProductById(id string) (models.Product, error) + ProductByCloudId(id string) (models.Product, error) + BatchUpsertProduct(d []models.Product) (int64, error) + BatchSaveProduct(p []models.Product) error + BatchDeleteProduct(products []models.Product) error + BatchDeleteProperties(propertiesId []string) error + BatchDeleteSystemProperties() error + BatchInsertSystemProperties(p []models.Properties) (int64, error) + BatchDeleteSystemEvents() error + BatchInsertSystemEvents(p []models.Events) (int64, error) + BatchDeleteEvents(eventId []string) error + BatchDeleteActions(actionId []string) error + DeleteProductById(id string) error + BatchDeleteSystemActions() error + BatchInsertSystemActions(p []models.Actions) (int64, error) + DeleteProductObject(d models.Product) error + UpdateProduct(ds models.Product) error + AssociationsUpdateProduct(ds models.Product) error + AssociationsDeleteProductObject(ds models.Product) error + + AddThingModelProperty(ds models.Properties) (models.Properties, error) + BatchUpsertThingModel(ds interface{}) (int64, error) + AddThingModelEvent(ds models.Events) (models.Events, error) + AddThingModelAction(ds models.Actions) (models.Actions, error) + UpdateThingModelProperty(ds models.Properties) error + UpdateThingModelEvent(ds models.Events) error + UpdateThingModelAction(ds models.Actions) error + ThingModelDeleteProperty(id string) error + ThingModelDeleteEvent(id string) error + ThingModelDeleteAction(id string) error + ThingModelPropertyById(id string) (models.Properties, error) + ThingModelEventById(id string) (models.Events, error) + ThingModelActionsById(id string) (models.Actions, error) + SystemThingModelSearch(modelType, modelName string) (interface{}, error) + + CategoryTemplateSearch(offset int, limit int, req dtos.CategoryTemplateRequest) ([]models.CategoryTemplate, uint32, error) + CategoryTemplateById(id string) (models.CategoryTemplate, error) + BatchUpsertCategoryTemplate(d []models.CategoryTemplate) (int64, error) + ThingModelTemplateSearch(offset int, limit int, req dtos.ThingModelTemplateRequest) ([]models.ThingModelTemplate, uint32, error) + ThingModelTemplateByCategoryKey(categoryKey string) (models.ThingModelTemplate, error) + BatchUpsertThingModelTemplate(d []models.ThingModelTemplate) (int64, error) + + UnitSearch(offset int, limit int, req dtos.UnitRequest) ([]models.Unit, uint32, error) + BatchUpsertUnitTemplate(d []models.Unit) (int64, error) + + AddDevice(d models.Device) (string, error) + DeviceById(id string) (models.Device, error) + DeviceOnlineById(id string) (edgeXErr error) + DeviceOfflineById(id string) (edgeXxErr error) + DeviceOfflineByCloudInstanceId(id string) (edgeXErr error) + MsgReportDeviceById(id string) (device models.Device, edgeXErr error) + DeviceByCloudId(id string) (models.Device, error) + DevicesSearch(offset int, limit int, req dtos.DeviceSearchQueryRequest) ([]models.Device, uint32, error) + DeviceMqttAuthInfo(id string) (device models.MqttAuth, edgeXErr error) + DriverMqttAuthInfo(id string) (device models.MqttAuth, edgeXErr error) + AddMqttAuthInfo(auth models.MqttAuth) (string, error) + AddOrUpdateAuth(auth models.MqttAuth) error + BatchUpsertDevice(d []models.Device) (int64, error) + BatchDeleteDevice(deviceIds []string) error + BatchUnBindDevice(ids []string) error + BatchBindDevice(ids []string, driverInstanceId string) error + DeleteDeviceById(id string) error + UpdateDevice(ds models.Device) error + DeleteDeviceByCloudInstanceId(cloudInstanceId string) error + AddDeviceLibrary(dl models.DeviceLibrary) (models.DeviceLibrary, error) + DeviceLibraryById(id string) (models.DeviceLibrary, error) + DeleteDeviceLibraryById(id string) error + DeviceLibrariesSearch(offset int, limit int, req dtos.DeviceLibrarySearchQueryRequest) ([]models.DeviceLibrary, uint32, error) + UpdateDeviceLibrary(dl models.DeviceLibrary) error + + DriverClassifySearch(offset int, limit int, req dtos.DriverClassifyQueryRequest) ([]models.DriverClassify, uint32, error) + + DockerConfigAdd(cfg models.DockerConfig) (models.DockerConfig, error) + DockerConfigById(id string) (models.DockerConfig, error) + DockerConfigUpdate(cfg models.DockerConfig) error + DockerConfigDelete(id string) error + DockerConfigsSearch(offset int, limit int, req dtos.DockerConfigSearchQueryRequest) ([]models.DockerConfig, uint32, error) + + AbilityByCode(model interface{}, code, productId string) (interface{}, error) + + //// 获取高级配置信息 + GetAdvanceConfig() (models.AdvanceConfig, error) + // 更新高级配置信息 + UpdateAdvanceConfig(config models.AdvanceConfig) error + + AddMsgGather(msgGather models.MsgGather) error + MsgGatherSearch(offset int, limit int, req dtos.MsgGatherSearchQueryRequest) (msgGather []models.MsgGather, count uint32, edgeXErr error) + + AddDataResource(dateResource models.DataResource) (string, error) + UpdateDataResource(dateResource models.DataResource) error + DelDataResource(id string) error + //DataResourceById(id string) models.DataResource + + UpdateDataResourceHealth(id string, health bool) error + SearchDataResource(offset int, limit int, req dtos.DataResourceSearchQueryRequest) (dataResource []models.DataResource, count uint32, edgeXErr error) + DataResourceById(id string) (models.DataResource, error) + AddRuleEngine(ruleEngine models.RuleEngine) (string, error) + UpdateRuleEngine(ruleEngine models.RuleEngine) error + RuleEngineById(id string) (ruleEngine models.RuleEngine, edgeXErr error) + RuleEngineSearch(offset int, limit int, req dtos.RuleEngineSearchQueryRequest) (ruleEngine []models.RuleEngine, count uint32, edgeXErr error) + RuleEngineStart(id string) error + RuleEngineStop(id string) error + DeleteRuleEngineById(id string) error + + LanguageSdkByName(name string) (cloudService models.LanguageSdk, edgeXErr error) + LanguageSearch(offset int, limit int, req dtos.LanguageSDKSearchQueryRequest) (languages []models.LanguageSdk, count uint32, edgeXErr error) + AddLanguageSdk(cs models.LanguageSdk) (language models.LanguageSdk, edgeXErr error) + UpdateLanguageSdk(ls models.LanguageSdk) error + + DeviceAlert + UserDB + Scene + SystemMonitor +} + +type SystemMonitor interface { + UpdateSystemMetrics(stats dtos.SystemMetrics) error + GetSystemMetrics(start, end int64) ([]dtos.SystemMetrics, error) + RemoveRangeSystemMetrics(min, max string) error +} + +type UserDB interface { + GetUserByUserName(username string) (models.User, error) + UpdateUser(u models.User) error + AddUser(u models.User) (models.User, error) +} + +type DeviceAlert interface { + AddAlertRule(rule models.AlertRule) (models.AlertRule, error) + UpdateAlertRule(rule models.AlertRule) error + AlertRuleById(id string) (models.AlertRule, error) + AlertRuleSearch(offset int, limit int, req dtos.AlertRuleSearchQueryRequest) (alertRules []models.AlertRule, total uint32, edgeXErr error) + DeleteAlertRuleById(id string) error + AlertRuleStart(id string) error + AlertRuleStop(id string) error + + AlertListLastSend(alertRuleId string) (alertList models.AlertList, edgeXErr error) + AddAlertList(alertRule models.AlertList) (models.AlertList, error) + AlertPlate(beforeTime int64) (plate []dtos.AlertPlateQueryResponse, err error) + AlertListSearch(offset int, limit int, req dtos.AlertSearchQueryRequest) (alertList []dtos.AlertSearchQueryResponse, total uint32, edgeXErr error) + AlertIgnore(id string) (edgeXErr error) + TreatedIgnore(id, message string) (edgeXErr error) +} + +type Scene interface { + AddScene(scene models.Scene) (models.Scene, error) + SceneById(id string) (models.Scene, error) + UpdateScene(scene models.Scene) error + SceneStart(id string) error + SceneStop(id string) error + DeleteSceneById(id string) error + SceneSearch(offset int, limit int, req dtos.SceneSearchQueryRequest) (scenes []models.Scene, total uint32, edgeXErr error) + + AddSceneLog(sceneLog models.SceneLog) (models.SceneLog, error) + SceneLogSearch(offset int, limit int, req dtos.SceneLogSearchQueryRequest) (sceneLogs []models.SceneLog, total uint32, edgeXErr error) +} diff --git a/internal/hummingbird/core/interface/device.go b/internal/hummingbird/core/interface/device.go new file mode 100644 index 0000000..34d7ae6 --- /dev/null +++ b/internal/hummingbird/core/interface/device.go @@ -0,0 +1,70 @@ +package interfaces + +import ( + "context" + "github.com/winc-link/edge-driver-proto/driverdevice" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" +) + +type DeviceItf interface { + DeviceCtlItf + DeviceSyncItf + OpenApiDeviceItf +} + +type DeviceCtlItf interface { + AddDevice(ctx context.Context, req dtos.DeviceAddRequest) (string, error) + + DevicesSearch(ctx context.Context, req dtos.DeviceSearchQueryRequest) ([]dtos.DeviceSearchQueryResponse, uint32, error) + + DevicesModelSearch(ctx context.Context, req dtos.DeviceSearchQueryRequest) ([]models.Device, uint32, error) + + DeviceById(ctx context.Context, id string) (dtos.DeviceInfoResponse, error) + + DeviceModelById(ctx context.Context, id string) (models.Device, error) + + DeviceByCloudId(ctx context.Context, id string) (models.Device, error) + + DeviceUpdate(ctx context.Context, req dtos.DeviceUpdateRequest) error + + DevicesBindDriver(ctx context.Context, req dtos.DevicesBindDriver) error + + DevicesUnBindDriver(ctx context.Context, req dtos.DevicesUnBindDriver) error + + DevicesBindProductId(ctx context.Context, req dtos.DevicesBindProductId) error + + ConnectIotPlatform(ctx context.Context, request *driverdevice.ConnectIotPlatformRequest) *driverdevice.ConnectIotPlatformResponse + + DisConnectIotPlatform(ctx context.Context, request *driverdevice.DisconnectIotPlatformRequest) *driverdevice.DisconnectIotPlatformResponse + + GetDeviceConnectStatus(ctx context.Context, request *driverdevice.GetDeviceConnectStatusRequest) *driverdevice.GetDeviceConnectStatusResponse + + DeviceMqttAuthInfo(ctx context.Context, id string) (dtos.DeviceAuthInfoResponse, error) + + AddMqttAuth(ctx context.Context, req dtos.AddMqttAuthInfoRequest) (string, error) + + DeleteDeviceById(ctx context.Context, id string) error + + BatchDeleteDevice(ctx context.Context, ids []string) error + + DeviceImportTemplateDownload(ctx context.Context, req dtos.DeviceImportTemplateRequest) (*dtos.ExportFile, error) + + DevicesImport(ctx context.Context, file *dtos.ImportFile, productId, driverInstanceId string) (int64, error) + + UploadValidated(ctx context.Context, file *dtos.ImportFile) error + + DevicesReportMsgGather(ctx context.Context) error + + DeviceAction(jobAction dtos.JobAction) dtos.DeviceExecRes + + DeviceInvokeThingService(invokeDeviceServiceReq dtos.InvokeDeviceServiceReq) dtos.DeviceExecRes +} + +type OpenApiDeviceItf interface { + OpenApiDeviceById(ctx context.Context, id string) (dtos.OpenApiDeviceInfoResponse, error) + OpenApiDeviceStatusById(ctx context.Context, id string) (dtos.OpenApiDeviceStatus, error) + OpenApiDevicesSearch(ctx context.Context, req dtos.DeviceSearchQueryRequest) ([]dtos.OpenApiDeviceInfoResponse, uint32, error) +} +type DeviceSyncItf interface { +} diff --git a/internal/hummingbird/core/interface/dmi.go b/internal/hummingbird/core/interface/dmi.go new file mode 100644 index 0000000..8017cbc --- /dev/null +++ b/internal/hummingbird/core/interface/dmi.go @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package interfaces + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/pkg/di" +) + +var ( + DriverModelInterfaceName = di.TypeInstanceToName((*DMI)(nil)) +) + +func DMIFrom(get di.Get) DMI { + return get(DriverModelInterfaceName).(DMI) +} + +type DMI interface { + DriverInInstanceDMI + StopAllInstance() +} + +//驱动相关接口 +type DriverInInstanceDMI interface { + // DownApp 下载驱动 + DownApp(cfg dtos.DockerConfig, app dtos.DeviceLibrary, toVersion string) (string, error) + + RemoveApp(app dtos.DeviceLibrary) error + GetAllApp() []string + // 检查驱动软件情况 + StateApp(dockerImageId string) bool + + InstanceState(ins dtos.DeviceService) bool + // StartInstance 启动实例 + StartInstance(ins dtos.DeviceService, cfg dtos.RunServiceCfg) (string, error) // 返回服务所在的ip + // StopInstance 停止实例 + StopInstance(ins dtos.DeviceService) error + // DeleteInstance 删除实例 + DeleteInstance(ins dtos.DeviceService) error + + GetDriverInstanceLogPath(serviceName string) string + // GetSelfIp 获取当前服务运行的内网ip + GetSelfIp() string +} diff --git a/internal/hummingbird/core/interface/driverapp.go b/internal/hummingbird/core/interface/driverapp.go new file mode 100644 index 0000000..1117e5b --- /dev/null +++ b/internal/hummingbird/core/interface/driverapp.go @@ -0,0 +1,26 @@ +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" +) + +type DriverLibApp interface { + AddDriverLib(ctx context.Context, dl dtos.DeviceLibraryAddRequest) error + DeleteDeviceLibraryById(ctx context.Context, id string) error + DeviceLibraryById(ctx context.Context, id string) (models.DeviceLibrary, error) + DeviceLibrariesSearch(ctx context.Context, req dtos.DeviceLibrarySearchQueryRequest) ([]models.DeviceLibrary, uint32, error) + UpdateDeviceLibrary(ctx context.Context, update dtos.UpdateDeviceLibrary) error + UpgradeDeviceLibrary(ctx context.Context, req dtos.DeviceLibraryUpgradeRequest) error + DriverLibById(dlId string) (models.DeviceLibrary, error) + GetDriverClassify(ctx context.Context, req dtos.DriverClassifyQueryRequest) ([]dtos.DriverClassifyResponse, uint32, error) + GetDeviceLibraryAndMirrorConfig(dlId string) (dl models.DeviceLibrary, dc models.DockerConfig, err error) + DriverDownConfigItf +} +type DriverDownConfigItf interface { + DownConfigAdd(ctx context.Context, req dtos.DockerConfigAddRequest) error + DownConfigUpdate(ctx context.Context, req dtos.DockerConfigUpdateRequest) error + DownConfigSearch(ctx context.Context, req dtos.DockerConfigSearchQueryRequest) ([]models.DockerConfig, uint32, error) + DownConfigDel(ctx context.Context, id string) error +} diff --git a/internal/hummingbird/core/interface/driverserviceapp.go b/internal/hummingbird/core/interface/driverserviceapp.go new file mode 100644 index 0000000..401c040 --- /dev/null +++ b/internal/hummingbird/core/interface/driverserviceapp.go @@ -0,0 +1,27 @@ +package interfaces + +import ( + //"context" + + "context" + //"gitlab.com/tedge/edgex/internal/dtos" + //"gitlab.com/tedge/edgex/internal/models" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" +) + +type DriverServiceApp interface { + GetState(id string) int + SetState(id string, state int) + Start(id string) error // 升级 + Stop(id string) error + ReStart(id string) error + Add(ctx context.Context, ds models.DeviceService) error + Update(ctx context.Context, dto dtos.DeviceServiceUpdateRequest) error + Del(ctx context.Context, id string) error + Get(ctx context.Context, id string) (models.DeviceService, error) + Search(ctx context.Context, req dtos.DeviceServiceSearchQueryRequest) ([]models.DeviceService, uint32, error) + UpdateRunStatus(ctx context.Context, req dtos.UpdateDeviceServiceRunStatusRequest) error + InProgress(id string) bool + Upgrade(dl models.DeviceLibrary) error // 升级驱动实例 +} diff --git a/internal/hummingbird/core/interface/homepage.go b/internal/hummingbird/core/interface/homepage.go new file mode 100644 index 0000000..eeaebc0 --- /dev/null +++ b/internal/hummingbird/core/interface/homepage.go @@ -0,0 +1,21 @@ +/******************************************************************************* + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" +) + +type HomePageItf interface { + HomePageInfo(ctx context.Context, req dtos.HomePageRequest) (response dtos.HomePageResponse, err error) +} diff --git a/internal/hummingbird/core/interface/language.go b/internal/hummingbird/core/interface/language.go new file mode 100644 index 0000000..4f7955d --- /dev/null +++ b/internal/hummingbird/core/interface/language.go @@ -0,0 +1,25 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" +) + +type LanguageSDKApp interface { + LanguageSDKSearch(ctx context.Context, req dtos.LanguageSDKSearchQueryRequest) ([]dtos.LanguageSDKSearchResponse, uint32, error) + Sync(ctx context.Context, versionName string) error +} diff --git a/internal/hummingbird/core/interface/message.go b/internal/hummingbird/core/interface/message.go new file mode 100644 index 0000000..aec564e --- /dev/null +++ b/internal/hummingbird/core/interface/message.go @@ -0,0 +1,36 @@ +package interfaces + +import ( + "context" + "github.com/winc-link/edge-driver-proto/drivercommon" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/core/application/messagestore" + //msgTypes "gitlab.com/tedge/edgex/internal/pkg/messaging/types" + // + //"gitlab.com/tedge/edgex/proto/thingmodel" + // + //pkgMQTT "gitlab.com/tedge/edgex/internal/tools/mqttclient" + // + //mqtt "github.com/eclipse/paho.mqtt.golang" + // + //"gitlab.com/tedge/edgex/internal/dtos" +) + +type MessageStores interface { + StoreMsgId(id string, ch string) + LoadMsgChan(id string) (interface{}, bool) + DeleteMsgId(id string) + GenAckChan(id string) *messagestore.MsgAckChan +} + +type MessageItf interface { + TyCloudMqttItf +} + +type PublishCallback func(ctx context.Context, params ...interface{}) (bool, interface{}) + +type TyCloudMqttItf interface { + // ThingModelMsgReport 物模型消息上报到云端 + ThingModelMsgReport(ctx context.Context, msg dtos.ThingModelMessage) (*drivercommon.CommonResponse, error) + DeviceStatusToMessageBus(ctx context.Context, deviceId, deviceStatus string) +} diff --git a/internal/hummingbird/core/interface/monitor.go b/internal/hummingbird/core/interface/monitor.go new file mode 100644 index 0000000..e50a8df --- /dev/null +++ b/internal/hummingbird/core/interface/monitor.go @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" +) + +type MonitorItf interface { + GetSystemMetrics(ctx context.Context, query dtos.SystemMetricsQuery) (dtos.SystemMetricsResponse, error) +} diff --git a/internal/hummingbird/core/interface/persistence.go b/internal/hummingbird/core/interface/persistence.go new file mode 100644 index 0000000..5215a1a --- /dev/null +++ b/internal/hummingbird/core/interface/persistence.go @@ -0,0 +1,18 @@ +package interfaces + +import ( + "github.com/winc-link/hummingbird/internal/dtos" +) + +type PersistItf interface { + PersistDeviceItf +} + +type PersistDeviceItf interface { + SaveDeviceThingModelData(req dtos.ThingModelMessage) error + SearchDeviceThingModelPropertyData(req dtos.ThingModelPropertyDataRequest) (interface{}, error) + SearchDeviceThingModelHistoryPropertyData(req dtos.ThingModelPropertyDataRequest) (interface{}, int, error) + SearchDeviceThingModelEventData(req dtos.ThingModelEventDataRequest) ([]dtos.ThingModelEventDataResponse, int, error) + SearchDeviceThingModelServiceData(req dtos.ThingModelServiceDataRequest) ([]dtos.ThingModelServiceDataResponse, int, error) + SearchDeviceMsgCount(startTime, endTime int64) (int, error) +} diff --git a/internal/hummingbird/core/interface/product.go b/internal/hummingbird/core/interface/product.go new file mode 100644 index 0000000..3b16397 --- /dev/null +++ b/internal/hummingbird/core/interface/product.go @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" +) + +type ProductItf interface { + ProductCtlItf +} + +type ProductCtlItf interface { + ProductsSearch(ctx context.Context, req dtos.ProductSearchQueryRequest) ([]dtos.ProductSearchQueryResponse, uint32, error) + ProductsModelSearch(ctx context.Context, req dtos.ProductSearchQueryRequest) ([]models.Product, uint32, error) + ProductById(ctx context.Context, id string) (dtos.ProductSearchByIdResponse, error) + ProductModelById(ctx context.Context, id string) (models.Product, error) + ProductDelete(ctx context.Context, id string) error + AddProduct(ctx context.Context, req dtos.ProductAddRequest) (string, error) + ProductRelease(ctx context.Context, productId string) error + ProductUnRelease(ctx context.Context, productId string) error + CreateProductCallBack(productInfo models.Product) + UpdateProductCallBack(productInfo models.Product) + DeleteProductCallBack(productInfo models.Product) + + ProductCtlOpenApiItf +} + +type ProductCtlOpenApiItf interface { + OpenApiAddProduct(ctx context.Context, req dtos.OpenApiAddProductRequest) (string, error) + OpenApiUpdateProduct(ctx context.Context, req dtos.OpenApiUpdateProductRequest) error + OpenApiProductById(ctx context.Context, id string) (dtos.ProductSearchByIdOpenApiResponse, error) + OpenApiProductSearch(ctx context.Context, req dtos.ProductSearchQueryRequest) ([]dtos.ProductSearchOpenApiResponse, uint32, error) +} diff --git a/internal/hummingbird/core/interface/scene.go b/internal/hummingbird/core/interface/scene.go new file mode 100644 index 0000000..4e81a93 --- /dev/null +++ b/internal/hummingbird/core/interface/scene.go @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/timer/jobs" +) + +type SceneApp interface { + AddScene(ctx context.Context, req dtos.SceneAddRequest) (string, error) + UpdateScene(ctx context.Context, req dtos.SceneUpdateRequest) error + SceneById(ctx context.Context, sceneId string) (models.Scene, error) + SceneStartById(ctx context.Context, sceneId string) error + SceneStopById(ctx context.Context, sceneId string) error + DelSceneById(ctx context.Context, sceneId string) error + SceneSearch(ctx context.Context, req dtos.SceneSearchQueryRequest) ([]models.Scene, uint32, error) + CheckSceneByDeviceId(ctx context.Context, deviceId string) error + SceneLogSearch(ctx context.Context, req dtos.SceneLogSearchQueryRequest) ([]models.SceneLog, uint32, error) + EkuiperNotify(ctx context.Context, req map[string]interface{}) error +} + +type ConJob interface { + AddJobToRunQueue(j *jobs.JobSchedule) error + DeleteJob(id string) +} diff --git a/internal/hummingbird/core/interface/system.go b/internal/hummingbird/core/interface/system.go new file mode 100644 index 0000000..4534a54 --- /dev/null +++ b/internal/hummingbird/core/interface/system.go @@ -0,0 +1,41 @@ +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" +) + +//先把这些临时放在这里 +type SystemItf interface { + GwConfigItf + AdvConfigItf + NetworkItf + GatewayItf +} + +type GwConfigItf interface { + LoadGatewayConfig() error + GetGatewayConfig() dtos.EdgeConfig +} + +type AdvConfigItf interface { + GetAdvanceConfig(ctx context.Context) (dtos.AdvanceConfig, error) + UpdateAdvanceConfig(ctx context.Context, cfg dtos.AdvanceConfig) error +} + +type NetworkItf interface { + GetNetworks(ctx context.Context) (dtos.ConfigNetWorkResponse, dtos.ConfigDnsResponse) + ConfigNetWork(ctx context.Context, isFlush bool) (resp dtos.ConfigNetWorkResponse, err error) + ConfigNetWorkUpdate(ctx context.Context, req dtos.ConfigNetworkUpdateRequest) error + ConfigDns(ctx context.Context) (dtos.ConfigDnsResponse, error) + ConfigDnsUpdate(ctx context.Context, req dtos.ConfigDnsUpdateRequest) error +} + +type GatewayItf interface { + SystemBackupFileDownload(ctx context.Context) (string, error) + SystemRecover(ctx context.Context, filepath string) error +} + +type Starter interface { + Conn() error +} diff --git a/internal/hummingbird/core/interface/thingmodel.go b/internal/hummingbird/core/interface/thingmodel.go new file mode 100644 index 0000000..824a7f3 --- /dev/null +++ b/internal/hummingbird/core/interface/thingmodel.go @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" +) + +type ThingModelItf interface { + ThingModelCtlItf +} + +type ThingModelCtlItf interface { + ThingModelDelete(ctx context.Context, id string, thingModelType string) error + AddThingModel(ctx context.Context, req dtos.ThingModelAddOrUpdateReq) (string, error) + UpdateThingModel(ctx context.Context, req dtos.ThingModelAddOrUpdateReq) error + SystemThingModelSearch(ctx context.Context, req dtos.SystemThingModelSearchReq) (interface{}, error) + OpenApiAddThingModel(ctx context.Context, req dtos.OpenApiThingModelAddOrUpdateReq) error + OpenApiQueryThingModel(ctx context.Context, productId string) (dtos.OpenApiQueryThingModel, error) + OpenApiDeleteThingModel(ctx context.Context, req dtos.OpenApiThingModelDeleteReq) error +} diff --git a/internal/hummingbird/core/interface/user.go b/internal/hummingbird/core/interface/user.go new file mode 100644 index 0000000..5c3df2d --- /dev/null +++ b/internal/hummingbird/core/interface/user.go @@ -0,0 +1,17 @@ +package interfaces + +import ( + "context" + "github.com/winc-link/hummingbird/internal/dtos" + //"gitlab.com/tedge/edgex/internal/dtos" +) + +type UserItf interface { + UserLogin(ctx context.Context, req dtos.LoginRequest) (res dtos.LoginResponse, err error) + InitInfo() (res dtos.InitInfoResponse, err error) + InitPassword(ctx context.Context, req dtos.InitPasswordRequest) error + UpdateUserPassword(ctx context.Context, username string, req dtos.UpdatePasswordRequest) error + OpenApiUserLogin(ctx context.Context, req dtos.LoginRequest) (res *dtos.TokenDetail, err error) + CreateTokenDetail(userName string) (*dtos.TokenDetail, error) + InitJwtKey() +} diff --git a/internal/hummingbird/core/main.go b/internal/hummingbird/core/main.go new file mode 100644 index 0000000..102b5ba --- /dev/null +++ b/internal/hummingbird/core/main.go @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package core + +import ( + "context" + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/hummingbird/core/bootstrap/database" + "github.com/winc-link/hummingbird/internal/hummingbird/core/config" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + "github.com/winc-link/hummingbird/internal/hummingbird/core/initialize" + "github.com/winc-link/hummingbird/internal/hummingbird/core/route" + "github.com/winc-link/hummingbird/internal/pkg/bootstrap" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/flags" + "github.com/winc-link/hummingbird/internal/pkg/handlers" + pkghandlers "github.com/winc-link/hummingbird/internal/pkg/handlers" + "github.com/winc-link/hummingbird/internal/pkg/startup" + "os" +) + +func Main(ctx context.Context, cancel context.CancelFunc, router *gin.Engine) { + f := flags.New() + f.Parse(os.Args[1:]) + + configuration := &config.ConfigurationStruct{} + di.GContainer = di.NewContainer(di.ServiceConstructorMap{ + container.ConfigurationName: func(get di.Get) interface{} { + return configuration + }, + }) + startupTimer := startup.NewStartUpTimer(constants.CoreServiceKey) + + bootstrap.Run( + ctx, + cancel, + f, + constants.CoreServiceKey, + constants.ConfigStemCore+constants.ConfigMajorVersion, + configuration, + startupTimer, + di.GContainer, + []handlers.BootstrapHandler{ + database.NewDatabase(configuration).BootstrapHandler, + initialize.NewBootstrap(router).BootstrapHandler, + pkghandlers.NewHttpServer(router, true).BootstrapHandler, + route.NewWebBootstrap().BootstrapHandler, + }) +} diff --git a/internal/hummingbird/core/migrates/migrate.go b/internal/hummingbird/core/migrates/migrate.go new file mode 100644 index 0000000..15fddaa --- /dev/null +++ b/internal/hummingbird/core/migrates/migrate.go @@ -0,0 +1,33 @@ +package migrates + +import ( + "github.com/go-gormigrate/gormigrate/v2" + _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + + bootstrapContainer "github.com/winc-link/hummingbird/internal/hummingbird/core/container" +) + +func Migrate(dic *di.Container) { + dbClient := bootstrapContainer.DBClientFrom(dic.Get) + lc := container.LoggingClientFrom(dic.Get) + m := gormigrate.New(dbClient.GetDBInstance(), gormigrate.DefaultOptions, migrations()) + + if err := m.Migrate(); err != nil { + lc.Errorf("Migration run err: %v", err) + } else { + lc.Info("Migration run successfully") + } +} + +func migrations() []*gormigrate.Migration { + return []*gormigrate.Migration{ + //m_1627356992_funcpoint_properties(), + //m_1630660539_device_expand_data(), + //m_1630660539_update_screen_device(), + //m_1637287851_update_library_internal(), + //m_1641866282_upgrade_cloud_market(), + //m_1648619146_upgrade_driver_app(), + } +} diff --git a/internal/hummingbird/core/route/baserouter.go b/internal/hummingbird/core/route/baserouter.go new file mode 100644 index 0000000..18dfe6d --- /dev/null +++ b/internal/hummingbird/core/route/baserouter.go @@ -0,0 +1,81 @@ +package route + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + pkgContainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/utils" + //"github.com/winc-link/hummingbird/internal/system/monitor/container" + + //"gitlab.com/tedge/edgex/internal/pkg/constants" + "net/http/httputil" + "net/url" + "os" + "regexp" + "strconv" + "strings" + + "github.com/gin-contrib/static" + "github.com/gin-gonic/gin" +) + +func LoadWebProxyRoutes(r *gin.Engine, webBuildPath string, dic *di.Container) { + r.Use(ProxyWeb(r, webBuildPath, dic)).Use(static.ServeRoot("/", webBuildPath)) +} + +//ProxyServer http proxy +func ProxyServer(c *gin.Context, dic *di.Container) { + configuration := container.ConfigurationFrom(dic.Get) + + port := strconv.Itoa(configuration.Service.Port) + addr := configuration.Service.ServerBindAddr + ":" + port + + lc := pkgContainer.LoggingClientFrom(dic.Get) + parseRootUrl, err := url.Parse("http://" + addr) + if err != nil { + lc.Errorf("parse server url err:%v", err) + c.Data(502, "", []byte("proxy server error")) + } + proxy := httputil.NewSingleHostReverseProxy(parseRootUrl) + proxy.ServeHTTP(c.Writer, c.Request) +} + +//ProxyWeb 转发 +func ProxyWeb(g *gin.Engine, webBuildPath string, dic *di.Container) gin.HandlerFunc { + return func(context *gin.Context) { + ReplaceURLPrefix(context, dic) + uri := context.Request.URL.Path + if ok, _ := regexp.MatchString("^(/api/|/v1.0/)", uri); ok { + ProxyServer(context, dic) + context.Abort() + return + } + + absPath := webBuildPath + context.Request.URL.Path + if utils.FilePathIsExist(absPath) { + return + } + context.Request.URL.Path = "/" + // 判断index.html文件是否存在 + indexPath := webBuildPath + "/index.html" + if !utils.FilePathIsExist(indexPath) { + context.Data(404, "", []byte("404 not found")) + context.Abort() + return + } + g.HandleContext(context) + } +} + +func ReplaceURLPrefix(context *gin.Context, dic *di.Container) { + lc := pkgContainer.LoggingClientFrom(dic.Get) + //get prefix from env + prefix := os.Getenv("URLPrefix") + if prefix == "" { + return + } + + prefix = prefix + "/" + context.Request.URL.Path = strings.ReplaceAll(context.Request.URL.Path, prefix, "") + lc.Debugf("after replace url path:%s", context.Request.URL.Path) +} diff --git a/internal/hummingbird/core/route/gateway.go b/internal/hummingbird/core/route/gateway.go new file mode 100644 index 0000000..36c7b44 --- /dev/null +++ b/internal/hummingbird/core/route/gateway.go @@ -0,0 +1,206 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package route + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/core/controller/http/gateway" + "github.com/winc-link/hummingbird/internal/hummingbird/core/controller/http/websocket" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/tools/jwt" + + "github.com/gin-gonic/gin" + + "github.com/swaggo/files" + gs "github.com/swaggo/gin-swagger" +) + +func RegisterGateway(engine *gin.Engine, dic *di.Container) { + ctl := gateway.New(dic) + v1 := engine.Group("/api/v1") + v1.GET("/swagger/*any", gs.WrapHandler(swaggerFiles.Handler)) + + v1.POST("auth/login", ctl.Login) + v1.GET("auth/initInfo", ctl.InitInfo) + v1.POST("auth/init-password", ctl.InitPassword) + + v1.POST("ekuiper/alert", ctl.EkuiperAlert) + v1.POST("ekuiper/scene", ctl.EkuiperScene) //ekuiper 服务调用 + v1.GET("ws/", websocket.NewServer(dic).Handle) + + v1Auth := v1.Group("", jwt.JWTAuth(false)) + v1Auth.PUT("auth/password", ctl.UpdatePassword) + /*******首页 *******/ + { + v1Auth.GET("home-page", ctl.HomePage) + } + + { + v1Auth.GET("local/config/network", ctl.ConfigNetWorkGet) + v1Auth.PUT("local/config/network", ctl.ConfigNetWorkUpdate) + v1Auth.GET("local/config/dns", ctl.ConfigDnsGet) + v1Auth.PUT("local/config/dns", ctl.ConfigDnsUpdate) + } + { + /******* 运维管理-agentclient *******/ + v1Auth.GET("/metrics/system", ctl.SystemMetricsHandler) + } + + /*******云服实例*******/ + { + v1Auth.GET("cloud-instance", ctl.CloudInstanceSearch) + } + + /******* 镜像仓库管理 *******/ + { + v1Auth.POST("docker-configs", ctl.DockerConfigAdd) + v1Auth.GET("docker-configs", ctl.DockerConfigsSearch) + v1Auth.PUT("docker-configs/:dockerConfigId", ctl.DockerConfigUpdate) + v1Auth.DELETE("docker-configs/:dockerConfigId", ctl.DockerConfigDelete) + } + + /*******驱动管理 *******/ + { + v1Auth.POST("device-libraries", ctl.DeviceLibraryAdd) + v1Auth.GET("device-libraries", ctl.DeviceLibrariesSearch) + v1Auth.DELETE("device-libraries/:deviceLibraryId", ctl.DeviceLibraryDelete) + v1Auth.PUT("device-libraries/:deviceLibraryId", ctl.DeviceLibraryUpdate) + } + /*******驱动实例 *******/ + { + v1Auth.GET("device-servers", ctl.DeviceServicesSearch) + v1Auth.PUT("device-server/:deviceServiceId", ctl.DeviceServiceUpdate) + } + + /*******驱动市场分类 *******/ + { + v1Auth.GET("device-classify", ctl.DeviceClassify) + } + + /*******产品管理 *******/ + { + v1Auth.GET("products", ctl.ProductsSearch) + v1Auth.GET("product/:productId", ctl.ProductById) + v1Auth.POST("product", ctl.ProductAdd) + v1Auth.POST("product-release/:productId", ctl.ProductRelease) + v1Auth.POST("product-unrelease/:productId", ctl.ProductUnRelease) + v1Auth.DELETE("product/:productId", ctl.ProductDelete) + v1Auth.GET("iot-platform", ctl.IotPlatform) + } + /*******产品物模型管理 *******/ + { + v1Auth.GET("thingmodel/system", ctl.SystemThingModelSearch) + v1Auth.POST("thingmodel", ctl.ThingModelAdd) + v1Auth.PUT("thingmodel", ctl.ThingModelUpdate) + v1Auth.DELETE("thingmodel", ctl.ThingModelDelete) + v1Auth.GET("thingmodel/unit", ctl.ThingModelUnit) + v1Auth.POST("thingmodel/unit-sync", ctl.ThingModelUnitSync) + v1Auth.POST("thingmodel/docs-sync", ctl.ThingModelDocsSync) + v1Auth.POST("thingmodel/quicknavigation-sync", ctl.ThingModelQuickNavigationSync) + //v1Auth.POST("thingmodel/msg-gather", ctl.MsgGather) + + } + + /*******设备管理 *******/ + { + v1Auth.POST("device", ctl.DeviceByAdd) + v1Auth.GET("devices", ctl.DevicesSearch) + v1Auth.GET("device/:deviceId", ctl.DeviceById) + v1Auth.DELETE("device/:deviceId", ctl.DeviceDelete) + v1Auth.DELETE("devices", ctl.DevicesDelete) + v1Auth.PUT("device/:deviceId", ctl.DeviceUpdate) + v1Auth.GET("device-mqtt/:deviceId", ctl.DeviceMqttInfoById) + v1Auth.POST("device-mqtt", ctl.AddMqttAuth) + v1Auth.GET("device/:deviceId/thing-model/property", ctl.DeviceThingModelPropertyDataSearch) + v1Auth.GET("device/:deviceId/thing-model/history-property", ctl.DeviceThingModelHistoryPropertyDataSearch) + v1Auth.GET("device/:deviceId/thing-model/event", ctl.DeviceThingModelEventDataSearch) + v1Auth.GET("device/:deviceId/thing-model/service", ctl.DeviceThingModelServiceDataSearch) + v1Auth.GET("device/status-template", ctl.DeviceStatusTemplate) + v1Auth.GET("devices/import-template", ctl.DeviceImportTemplateDownload) + v1Auth.POST("devices/import", ctl.DevicesImport) + v1Auth.POST("device/upload-validated", ctl.UploadValidated) + v1Auth.PUT("devices/bind-driver", ctl.DevicesBindDriver) + v1Auth.PUT("devices/unbind-driver", ctl.DevicesUnBindDriver) + v1Auth.PUT("devices/bind-product", ctl.DevicesBindByProductId) + + } + /*******品类、物模型同步接口 *******/ + { + v1Auth.GET("category-template", ctl.CategoryTemplateSearch) + v1Auth.POST("category-template/sync", ctl.CategoryTemplateSync) + v1Auth.GET("thingmodel-template", ctl.ThingModelTemplateSearch) + v1Auth.GET("thingmodel-template/:categoryKey", ctl.ThingModelTemplateByCategoryKey) + v1Auth.POST("thingmodel-template/sync", ctl.ThingModelTemplateSync) + } + + /*******告警中心接口 *******/ + { + v1Auth.POST("alert-rule", ctl.AlertRuleAdd) + v1Auth.PUT("alert-rule/:ruleId", ctl.AlertRuleUpdate) + v1Auth.PUT("rule-field", ctl.AlertRuleUpdateField) + v1Auth.GET("alert-rule/:ruleId", ctl.AlertRuleById) + v1Auth.GET("alert-rule", ctl.AlertRuleSearch) + v1Auth.DELETE("alert-rule/:ruleId", ctl.AlertRuleDelete) + v1Auth.POST("alert-rule/:ruleId/start", ctl.AlertRuleStart) + v1Auth.POST("alert-rule/:ruleId/stop", ctl.AlertRuleStop) + v1Auth.POST("alert-rule/:ruleId/restart", ctl.AlertRuleRestart) + v1Auth.GET("alert-list", ctl.AlertSearch) + v1Auth.GET("alert-plate", ctl.AlertPlate) + v1Auth.PUT("alert-ignore/:ruleId", ctl.AlertIgnore) + v1Auth.POST("alert-treated", ctl.AlertTreated) + + } + /*******规则引擎 *******/ + { + v1Auth.POST("rule-engine", ctl.RuleEngineAdd) + v1Auth.PUT("rule-engine", ctl.RuleEngineUpdate) + v1Auth.GET("rule-engine/:ruleEngineId", ctl.RuleEngineById) + v1Auth.GET("rule-engine", ctl.RuleEngineSearch) + v1Auth.POST("rule-engine/:ruleEngineId/start", ctl.RuleEngineStart) + v1Auth.POST("rule-engine/:ruleEngineId/stop", ctl.RuleEngineStop) + v1Auth.DELETE("rule-engine/:ruleEngineId/delete", ctl.RuleEngineDelete) + v1Auth.GET("rule-engine/:ruleEngineId/status", ctl.RuleEngineStatus) + + } + /*******资源管理 *******/ + { + v1Auth.GET("typeresource", ctl.DataResourceType) + v1Auth.PUT("dataresource", ctl.UpdateDataResource) + v1Auth.POST("dataresource", ctl.DataResourceAdd) + v1Auth.DELETE("dataresource/:dataResourceId", ctl.DataResourceDel) + v1Auth.GET("dataresource", ctl.DataResourceSearch) + v1Auth.GET("dataresource/:dataResourceId", ctl.DataResourceById) + v1Auth.POST("dataresource/:dataResourceId/health", ctl.DataResourceHealth) + + } + /*******场景联动 *******/ + { + + v1Auth.POST("scene", ctl.SceneAdd) + v1Auth.PUT("scene", ctl.SceneUpdate) + v1Auth.GET("scene/:sceneId", ctl.SceneById) + v1Auth.GET("scene", ctl.SearchScene) + v1Auth.POST("scene/:sceneId/start", ctl.SceneStart) + v1Auth.POST("scene/:sceneId/stop", ctl.SceneStop) + v1Auth.DELETE("scene/:sceneId", ctl.DeleteScene) + v1Auth.GET("scene/:sceneId/log", ctl.SceneLogSearch) + } + /*******文档中心(sdk) *******/ + { + + v1Auth.GET("language-sdk", ctl.LanguageSdkSearch) + v1Auth.POST("language-sdk-sync", ctl.LanguageSdkSync) + } + +} diff --git a/internal/hummingbird/core/route/openapi.go b/internal/hummingbird/core/route/openapi.go new file mode 100644 index 0000000..c0d332d --- /dev/null +++ b/internal/hummingbird/core/route/openapi.go @@ -0,0 +1,96 @@ +/******************************************************************************* + * Copyright 2017. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ + +package route + +import ( + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/hummingbird/core/controller/http/openapi" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/errort" + "github.com/winc-link/hummingbird/internal/tools/jwt" + "github.com/winc-link/hummingbird/internal/tools/openapihelper" +) + +func RegisterOpenApi(engine *gin.Engine, dic *di.Container) { + engine.NoRoute(func(c *gin.Context) { + openapihelper.ReaderFail(c, errort.UrlPathIsInvalid) + }) + + ctl := openapi.New(dic) + + v := engine.Group("/v1.0/openapi") + v.POST("/auth/login", ctl.Login) + v.POST("/token/:refreshToken", ctl.RefreshToken) + v1 := v.Group("", jwt.JWTAuth(false)) + + //产品管理的API + { + //创建产品 + v1.POST("/product", ctl.OpenApiCreateProduct) + //更新产品 + v1.PUT("/product/:productId", ctl.OpenApiUpdateProduct) + //查询产品列表 + v1.GET("/products", ctl.OpenApiProductSearch) + //查询产品详细信息。 + v1.GET("/product/:productId", ctl.OpenApiProductById) + // 发布产品 + v1.GET("/product-release/:productId", ctl.OpenApiProductReleaseById) + // 取消发布产品 + v1.GET("/product-unrelease/:productId", ctl.OpenApiProductUnReleaseById) + //删除指定产品。 + v1.DELETE("product/:productId", ctl.OpenApiDeleteProduct) + } + //设备管理的API + { + //创建设备 + v1.POST("/device", ctl.OpenApiCreateDevice) + //更新设备 + v1.PUT("/device/:deviceId", ctl.OpenApiUpdateDevice) + //查询设备列表 + v1.GET("/devices", ctl.OpenApiDeviceSearch) + //查询设备详细信息。 + v1.GET("/device/:deviceId", ctl.OpenApiDeviceById) + //删除指定设备。 + v1.DELETE("/device/:deviceId", ctl.OpenApiDeleteDevice) + //获取设备的运行状态。 + //v1.GET("/deviceStatus/:deviceId", ctl.OpenApiDeviceStatus) + + } + //物模型管理的API + { + //为指定产品的物模型新增功能 + v1.POST("/thingModel", ctl.OpenApiThingModelAddOrUpdate) + //更新指定产品物模型中的单个功能 + v1.PUT("/thingModel", ctl.OpenApiThingModelAddOrUpdate) + //查看指定产品的物模型中的功能定义详情 + v1.GET("/thingModel", ctl.OpenApiThingModel) + //DeleteThingModel + v1.DELETE("/thingModel", ctl.OpenApiDeleteThingModel) + + } + //物模型使用的API + { + //设置设备的属性。 + v1.POST("/setDeviceProperty", ctl.OpenApiSetDeviceProperty) + //调用设备的服务。 + v1.POST("/invokeThingService", ctl.OpenApiInvokeThingService) + //查询设备的属性历史数据。 + v1.GET("/queryDevicePropertyData", ctl.OpenApiQueryDevicePropertyData) + //查询设备的事件历史数据。 + v1.GET("/queryDeviceEventData", ctl.OpenApiQueryDeviceEventData) + //获取设备的服务记录历史数据。 + v1.GET("/queryDeviceServiceData", ctl.OpenApiQueryDeviceServiceData) + } +} diff --git a/internal/hummingbird/core/route/route.go b/internal/hummingbird/core/route/route.go new file mode 100644 index 0000000..13d8eed --- /dev/null +++ b/internal/hummingbird/core/route/route.go @@ -0,0 +1,36 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package route + +import ( + "github.com/gin-contrib/pprof" + "github.com/gin-gonic/gin" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/i18n" + "github.com/winc-link/hummingbird/internal/pkg/middleware" +) + +func LoadRestRoutes(r *gin.Engine, dic *di.Container) { + // add pprof + pprof.Register(r) + + r.Use(i18n.I18nHandlerGin()) + r.Use(middleware.CorrelationHeader()) + + // load gateway router + RegisterGateway(r, dic) + // load open api + RegisterOpenApi(r, dic) +} diff --git a/internal/hummingbird/core/route/webproxy.go b/internal/hummingbird/core/route/webproxy.go new file mode 100644 index 0000000..e669bc1 --- /dev/null +++ b/internal/hummingbird/core/route/webproxy.go @@ -0,0 +1,90 @@ +package route + +import ( + "context" + "github.com/winc-link/hummingbird/internal/hummingbird/core/container" + pkgContainer "github.com/winc-link/hummingbird/internal/pkg/container" + "github.com/winc-link/hummingbird/internal/pkg/di" + "github.com/winc-link/hummingbird/internal/pkg/startup" + "net/http" + "strconv" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +///var/bin/cmd/hummingbird-ui/build +const WebBuildPath = "./cmd/hummingbird-ui/build" +const WebBuildPath2 = "/var/build" + +// WebBootstrap contains references to dependencies required by the BootstrapHandler. +type WebBootstrap struct { + router *gin.Engine + AppMode string +} + +// NewWebBootstrap is a factory method that returns an initialized WebBootstrap receiver struct. +func NewWebBootstrap() *WebBootstrap { + // 不做路由日志输出 + g := gin.New() + g.Use(gin.Recovery(), gin.Logger()) + return &WebBootstrap{ + router: g, + } +} + +// BootstrapHandler fulfills the BootstrapHandler contract and performs initialization needed by the resource service. +func (b *WebBootstrap) BootstrapHandler(ctx context.Context, wg *sync.WaitGroup, _ startup.Timer, dic *di.Container) bool { + configuration := container.ConfigurationFrom(dic.Get) + lc := pkgContainer.LoggingClientFrom(dic.Get) + + lc.Infof("start WebBootstrap BootstrapHandler in...") + + //pwd, _ := os.Getwd() + LoadWebProxyRoutes(b.router, WebBuildPath, dic) + + if configuration.WebServer.Host == "" || configuration.WebServer.Port == 0 { + lc.Errorf("WebServer Host is null OR port is 0") + return false + } + port := strconv.Itoa(configuration.WebServer.Port) + addr := configuration.WebServer.Host + ":" + port + timeout := time.Second * time.Duration(configuration.WebServer.Timeout) + server := &http.Server{ + Addr: addr, + Handler: b.router, + WriteTimeout: timeout, + ReadTimeout: timeout, + } + + wg.Add(1) + go func() { + defer wg.Done() + + <-ctx.Done() + lc.Info("WebProxy server shutting down") + _ = server.Shutdown(context.Background()) + lc.Info("WebProxy server shut down") + }() + + lc.Info("WebProxy server starting (" + addr + ")") + + wg.Add(1) + go func() { + defer func() { + wg.Done() + }() + + err := server.ListenAndServe() + if err != nil { + lc.Errorf("WebProxy server failed: %v", err) + cancel := pkgContainer.CancelFuncFrom(dic.Get) + cancel() // this will caused the service to stop + } else { + lc.Info("WebProxy server stopped") + } + }() + + return true +} diff --git a/internal/hummingbird/mqttbroker/LICENSE b/internal/hummingbird/mqttbroker/LICENSE new file mode 100644 index 0000000..ee1c036 --- /dev/null +++ b/internal/hummingbird/mqttbroker/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 DrmagicE + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/config/api.go b/internal/hummingbird/mqttbroker/config/api.go new file mode 100644 index 0000000..e408d27 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/api.go @@ -0,0 +1,77 @@ +package config + +import ( + "fmt" + "net" + "strings" +) + +// API is the configuration for API server. +// The API server use gRPC-gateway to provide both gRPC and HTTP endpoints. +type API struct { + // GRPC is the gRPC endpoint configuration. + GRPC []*Endpoint `yaml:"grpc"` + // HTTP is the HTTP endpoint configuration. + HTTP []*Endpoint `yaml:"http"` +} + +// Endpoint represents a gRPC or HTTP server endpoint. +type Endpoint struct { + // Address is the bind address of the endpoint. + // Format: [tcp|unix://][]: + // e.g : + // * unix:///var/run/mqttd.sock + // * tcp://127.0.0.1:8080 + // * :8081 (equal to tcp://:8081) + Address string `yaml:"address"` + // Map maps the HTTP endpoint to gRPC endpoint. + // Must be set if the endpoint is representing a HTTP endpoint. + Map string `yaml:"map"` + // TLS is the tls configuration. + TLS *TLSOptions `yaml:"tls"` +} + +var DefaultAPI API + +func (a API) validateAddress(address string, fieldName string) error { + if address == "" { + return fmt.Errorf("%s cannot be empty", fieldName) + } + epParts := strings.SplitN(address, "://", 2) + if len(epParts) == 1 && epParts[0] != "" { + epParts = []string{"tcp", epParts[0]} + } + if len(epParts) != 0 { + switch epParts[0] { + case "tcp": + _, _, err := net.SplitHostPort(epParts[1]) + if err != nil { + return fmt.Errorf("invalid %s: %s", fieldName, err.Error()) + } + case "unix": + default: + return fmt.Errorf("invalid %s schema: %s", fieldName, epParts[0]) + } + } + return nil +} + +func (a API) Validate() error { + for _, v := range a.GRPC { + err := a.validateAddress(v.Address, "endpoint") + if err != nil { + return err + } + } + for _, v := range a.HTTP { + err := a.validateAddress(v.Address, "endpoint") + if err != nil { + return err + } + err = a.validateAddress(v.Map, "map") + if err != nil { + return err + } + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/config/api_test.go b/internal/hummingbird/mqttbroker/config/api_test.go new file mode 100644 index 0000000..e0265c8 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/api_test.go @@ -0,0 +1,95 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAPI_Validate(t *testing.T) { + a := assert.New(t) + + tt := []struct { + cfg API + valid bool + }{ + { + cfg: API{ + GRPC: []*Endpoint{ + { + Address: "udp://127.0.0.1", + }, + }, + HTTP: []*Endpoint{ + {}, + }, + }, + valid: false, + }, + { + cfg: API{ + GRPC: []*Endpoint{ + { + Address: "tcp://127.0.0.1:1234", + }, + }, + HTTP: []*Endpoint{ + { + Address: "udp://127.0.0.1", + }, + }, + }, + valid: false, + }, + { + cfg: API{ + GRPC: []*Endpoint{ + { + Address: "tcp://127.0.0.1:1234", + }, + }, + }, + valid: true, + }, + { + cfg: API{ + GRPC: []*Endpoint{ + { + Address: "tcp://127.0.0.1:1234", + }, + }, + HTTP: []*Endpoint{ + { + Address: "tcp://127.0.0.1:1235", + }, + }, + }, + valid: false, + }, + { + cfg: API{ + GRPC: []*Endpoint{ + { + Address: "unix:///var/run/mqttd.sock", + }, + }, + HTTP: []*Endpoint{ + { + Address: "tcp://127.0.0.1:1235", + Map: "unix:///var/run/mqttd.sock", + }, + }, + }, + valid: true, + }, + } + for _, v := range tt { + err := v.cfg.Validate() + if v.valid { + a.NoError(err) + } else { + a.Error(err) + } + } + +} diff --git a/internal/hummingbird/mqttbroker/config/config.go b/internal/hummingbird/mqttbroker/config/config.go new file mode 100644 index 0000000..d46ecc1 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/config.go @@ -0,0 +1,323 @@ +package config + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "path" + "strings" + + "gopkg.in/natefinch/lumberjack.v2" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/yaml.v2" + + pkgconfig "github.com/winc-link/hummingbird/internal/pkg/config" +) + +const EdgeMqttBroker = "mqtt-broker" + +var ( + defaultPluginConfig = make(map[string]Configuration) + configFileFullPath string + config Config +) + +// Configuration is the interface that enable the implementation to parse config from the global config file. +// Plugin admin and prometheus are two examples. +type Configuration interface { + // Validate validates the configuration. + // If returns error, the broker will not start. + Validate() error + // Unmarshaler defined how to unmarshal YAML into the config structure. + yaml.Unmarshaler +} + +// RegisterDefaultPluginConfig registers the default configuration for the given plugin. +func RegisterDefaultPluginConfig(name string, config Configuration) { + if _, ok := defaultPluginConfig[name]; ok { + panic(fmt.Sprintf("duplicated default config for %s plugin", name)) + } + defaultPluginConfig[name] = config + +} + +// DefaultConfig return the default configuration. +// If config file is not provided, mqttd will start with DefaultConfig. +func DefaultConfig() Config { + c := Config{ + Listeners: DefaultListeners, + MQTT: DefaultMQTTConfig, + API: DefaultAPI, + Log: LogConfig{ + Level: "info", + FilePath: "/var/tedge/logs/mqtt-broker.log", + }, + Plugins: make(pluginConfig), + PluginOrder: []string{"aplugin"}, + Persistence: DefaultPersistenceConfig, + TopicAliasManager: DefaultTopicAliasManager, + } + + for name, v := range defaultPluginConfig { + c.Plugins[name] = v + } + return c +} + +var DefaultListeners = []*ListenerConfig{ + { + Address: "0.0.0.0:58090", + TLSOptions: nil, + Websocket: nil, + }, + { + Address: "0.0.0.0:58091", + Websocket: &WebsocketOptions{ + Path: "/", + }, + }, { + Address: "0.0.0.0:21883", + TLSOptions: &TLSOptions{ + CACert: "/etc/tedge-mqtt-broker/ca.crt", + Cert: "/etc/tedge-mqtt-broker/server.pem", + Key: "/etc/tedge-mqtt-broker/server.key", + }, + }, +} + +// LogConfig is use to configure the log behaviors. +type LogConfig struct { + // Level is the log level. Possible values: debug, info, warn, error + Level string `yaml:"level"` + FilePath string `yaml:"file_path"` + // DumpPacket indicates whether to dump MQTT packet in debug level. + DumpPacket bool `yaml:"dump_packet"` +} + +func (l LogConfig) Validate() error { + level := strings.ToLower(l.Level) + if level != "debug" && level != "info" && level != "warn" && level != "error" { + return fmt.Errorf("invalid log level: %s", l.Level) + } + return nil +} + +// pluginConfig stores the plugin default configuration, key by the plugin name. +// If the plugin has default configuration, it should call RegisterDefaultPluginConfig in it's init function to register. +type pluginConfig map[string]Configuration + +func (p pluginConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + for _, v := range p { + err := unmarshal(v) + if err != nil { + return err + } + } + return nil +} + +// Config is the configration for mqttd. +type Config struct { + Listeners []*ListenerConfig `yaml:"listeners"` + API API `yaml:"api"` + MQTT MQTT `yaml:"mqttclient,omitempty"` + GRPC GRPC `yaml:"gRPC"` + Log LogConfig `yaml:"log"` + PidFile string `yaml:"pid_file"` + ConfigDir string `yaml:"config_dir"` + Plugins pluginConfig `yaml:"plugins"` + // PluginOrder is a slice that contains the name of the plugin which will be loaded. + // Giving a correct order to the slice is significant, + // because it represents the loading order which affect the behavior of the broker. + PluginOrder []string `yaml:"plugin_order"` + Persistence Persistence `yaml:"persistence"` + TopicAliasManager TopicAliasManager `yaml:"topic_alias_manager"` + Database pkgconfig.Database `yaml:"data_base"` +} + +type GRPC struct { + Endpoint string `yaml:"endpoint"` +} + +type TLSOptions struct { + // CACert is the trust CA certificate file. + CACert string `yaml:"cacert"` + // Cert is the path to certificate file. + Cert string `yaml:"cert"` + // Key is the path to key file. + Key string `yaml:"key"` + // Verify indicates whether to verify client cert. + Verify bool `yaml:"verify"` +} + +type ListenerConfig struct { + Address string `yaml:"address"` + *TLSOptions `yaml:"tls"` + Websocket *WebsocketOptions `yaml:"websocket"` +} + +type WebsocketOptions struct { + Path string `yaml:"path"` +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type config Config + raw := config(DefaultConfig()) + if err := unmarshal(&raw); err != nil { + return err + } + emptyMQTT := MQTT{} + if raw.MQTT == emptyMQTT { + raw.MQTT = DefaultMQTTConfig + } + if len(raw.Plugins) == 0 { + raw.Plugins = make(pluginConfig) + for name, v := range defaultPluginConfig { + raw.Plugins[name] = v + } + } else { + for name, v := range raw.Plugins { + if v == nil { + raw.Plugins[name] = defaultPluginConfig[name] + } + } + } + *c = Config(raw) + return nil +} + +func (c Config) Validate() (err error) { + err = c.Log.Validate() + if err != nil { + return err + } + err = c.API.Validate() + if err != nil { + return err + } + err = c.MQTT.Validate() + if err != nil { + return err + } + err = c.Persistence.Validate() + if err != nil { + return err + } + for _, conf := range c.Plugins { + err := conf.Validate() + if err != nil { + return err + } + } + return nil +} + +func ParseConfig(filePath string) (Config, error) { + if filePath == "" { + return DefaultConfig(), nil + } + if _, err := os.Stat(filePath); err != nil { + fmt.Println("unspecificed configuration file, use default config") + return DefaultConfig(), nil + } + b, err := ioutil.ReadFile(filePath) + if err != nil { + return config, err + } + config = DefaultConfig() + err = yaml.Unmarshal(b, &config) + if err != nil { + return config, err + } + config.ConfigDir = path.Dir(filePath) + err = config.Validate() + if err != nil { + return Config{}, err + } + configFileFullPath = filePath + return config, err +} + +func UpdateLogLevel(level string) { + config.Log.Level = level +} + +func GetLogLevel() string { + return config.Log.Level +} + +func WriteToFile() error { + return config.writeToFile() +} + +func (c Config) writeToFile() error { + var ( + err error + buff bytes.Buffer + ) + e := yaml.NewEncoder(&buff) + if err = e.Encode(c); err != nil { + return err + } + if err = ioutil.WriteFile(configFileFullPath+".tmp", buff.Bytes(), 0644); err != nil { + return err + } + os.Remove(configFileFullPath) + return os.Rename(configFileFullPath+".tmp", configFileFullPath) +} + +func (c Config) GetLogger(config LogConfig) (*zap.AtomicLevel, *zap.Logger, error) { + var logLevel zapcore.Level + err := logLevel.UnmarshalText([]byte(config.Level)) + if err != nil { + return nil, nil, err + } + var level = zap.NewAtomicLevelAt(logLevel) + if config.FilePath == "" { + cfg := zap.NewDevelopmentConfig() + cfg.Level = level + cfg.EncoderConfig.ConsoleSeparator = " " + cfg.EncoderConfig.LineEnding = zapcore.DefaultLineEnding + cfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder // zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05.000") + cfg.EncoderConfig.EncodeDuration = zapcore.SecondsDurationEncoder + cfg.EncoderConfig.EncodeCaller = zapcore.ShortCallerEncoder + cfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + logger, err := cfg.Build(zap.AddStacktrace(zapcore.PanicLevel)) + if err != nil { + return nil, nil, err + } + return &level, logger.Named(EdgeMqttBroker), nil + } + + writeSyncer := getLogWriter(config) + encoder := getEncoder() + core := zapcore.NewCore(encoder, writeSyncer, level.Level()) + logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.PanicLevel)) + + return &level, logger.Named(EdgeMqttBroker), nil +} + +func getEncoder() zapcore.Encoder { + encoderConfig := zap.NewProductionEncoderConfig() + encoderConfig.LineEnding = zapcore.DefaultLineEnding + encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder // zapcore.TimeEncoderOfLayout("2006-01-02 15:04:05.000") + encoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + encoderConfig.EncodeDuration = zapcore.SecondsDurationEncoder + encoderConfig.EncodeCaller = zapcore.ShortCallerEncoder + encoderConfig.ConsoleSeparator = " " + return zapcore.NewConsoleEncoder(encoderConfig) +} + +func getLogWriter(cfg LogConfig) zapcore.WriteSyncer { + lumberJackLogger := &lumberjack.Logger{ + Filename: cfg.FilePath, + MaxSize: 10, + MaxBackups: 3, + MaxAge: 7, + Compress: false, + } + return zapcore.AddSync(lumberJackLogger) +} diff --git a/internal/hummingbird/mqttbroker/config/config_mock.go b/internal/hummingbird/mqttbroker/config/config_mock.go new file mode 100644 index 0000000..8304c34 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/config_mock.go @@ -0,0 +1,62 @@ +// Code generated by config. DO NOT EDIT. +// Source: config/config.go + +// Package config is a generated GoMock package. +package config + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockConfiguration is a mock of Configuration interface +type MockConfiguration struct { + ctrl *gomock.Controller + recorder *MockConfigurationMockRecorder +} + +// MockConfigurationMockRecorder is the mock recorder for MockConfiguration +type MockConfigurationMockRecorder struct { + mock *MockConfiguration +} + +// NewMockConfiguration creates a new mock instance +func NewMockConfiguration(ctrl *gomock.Controller) *MockConfiguration { + mock := &MockConfiguration{ctrl: ctrl} + mock.recorder = &MockConfigurationMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConfiguration) EXPECT() *MockConfigurationMockRecorder { + return m.recorder +} + +// Validate mocks base method +func (m *MockConfiguration) Validate() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Validate") + ret0, _ := ret[0].(error) + return ret0 +} + +// Validate indicates an expected call of Validate +func (mr *MockConfigurationMockRecorder) Validate() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockConfiguration)(nil).Validate)) +} + +// UnmarshalYAML mocks base method +func (m *MockConfiguration) UnmarshalYAML(unmarshal func(interface{}) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnmarshalYAML", unmarshal) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnmarshalYAML indicates an expected call of UnmarshalYAML +func (mr *MockConfigurationMockRecorder) UnmarshalYAML(unmarshal interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmarshalYAML", reflect.TypeOf((*MockConfiguration)(nil).UnmarshalYAML), unmarshal) +} diff --git a/internal/hummingbird/mqttbroker/config/config_test.go b/internal/hummingbird/mqttbroker/config/config_test.go new file mode 100644 index 0000000..eef88b0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/config_test.go @@ -0,0 +1,36 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseConfig(t *testing.T) { + var tt = []struct { + caseName string + fileName string + hasErr bool + expected Config + }{ + { + caseName: "defaultConfig", + fileName: "", + hasErr: false, + expected: DefaultConfig(), + }, + } + + for _, v := range tt { + t.Run(v.caseName, func(t *testing.T) { + a := assert.New(t) + c, err := ParseConfig(v.fileName) + if v.hasErr { + a.NotNil(err) + } else { + a.Nil(err) + } + a.Equal(v.expected, c) + }) + } +} diff --git a/internal/hummingbird/mqttbroker/config/mqtt.go b/internal/hummingbird/mqttbroker/config/mqtt.go new file mode 100644 index 0000000..627092e --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/mqtt.go @@ -0,0 +1,118 @@ +package config + +import ( + "fmt" + "time" + + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +const ( + Overlap = "overlap" + OnlyOnce = "onlyonce" +) + +var ( + // DefaultMQTTConfig + DefaultMQTTConfig = MQTT{ + SessionExpiry: 2 * time.Hour, + SessionExpiryCheckInterval: 20 * time.Second, + MessageExpiry: 2 * time.Hour, + InflightExpiry: 30 * time.Second, + MaxPacketSize: packets.MaximumSize, + ReceiveMax: 100, + MaxKeepAlive: 300, + TopicAliasMax: 10, + SubscriptionIDAvailable: true, + SharedSubAvailable: true, + WildcardAvailable: true, + RetainAvailable: true, + MaxQueuedMsg: 1000, + MaxInflight: 100, + MaximumQoS: 2, + QueueQos0Msg: true, + DeliveryMode: OnlyOnce, + AllowZeroLenClientID: true, + } +) + +type MQTT struct { + // SessionExpiry is the maximum session expiry interval in seconds. + SessionExpiry time.Duration `yaml:"session_expiry"` + // SessionExpiryCheckInterval is the interval time for session expiry checker to check whether there + // are expired sessions. + SessionExpiryCheckInterval time.Duration `yaml:"session_expiry_check_interval"` + // MessageExpiry is the maximum lifetime of the message in seconds. + // If a message in the queue is not sent in MessageExpiry time, it will be removed, which means it will not be sent to the subscriber. + MessageExpiry time.Duration `yaml:"message_expiry"` + // InflightExpiry is the lifetime of the "inflight" message in seconds. + // If a "inflight" message is not acknowledged by a client in InflightExpiry time, it will be removed when the message queue is full. + InflightExpiry time.Duration `yaml:"inflight_expiry"` + // MaxPacketSize is the maximum packet size that the server is willing to accept from the client + MaxPacketSize uint32 `yaml:"max_packet_size"` + // ReceiveMax limits the number of QoS 1 and QoS 2 publications that the server is willing to process concurrently for the client. + ReceiveMax uint16 `yaml:"server_receive_maximum"` + // MaxKeepAlive is the maximum keep alive time in seconds allows by the server. + // If the client requests a keepalive time bigger than MaxKeepalive, + // the server will use MaxKeepAlive as the keepalive time. + // In this case, if the client version is v5, the server will set MaxKeepalive into CONNACK to inform the client. + // But if the client version is 3.x, the server has no way to inform the client that the keepalive time has been changed. + MaxKeepAlive uint16 `yaml:"max_keepalive"` + // TopicAliasMax indicates the highest value that the server will accept as a Topic Alias sent by the client. + // No-op if the client version is MQTTv3.x + TopicAliasMax uint16 `yaml:"topic_alias_maximum"` + // SubscriptionIDAvailable indicates whether the server supports Subscription Identifiers. + // No-op if the client version is MQTTv3.x . + SubscriptionIDAvailable bool `yaml:"subscription_identifier_available"` + // SharedSubAvailable indicates whether the server supports Shared Subscriptions. + SharedSubAvailable bool `yaml:"shared_subscription_available"` + // WildcardSubAvailable indicates whether the server supports Wildcard Subscriptions. + WildcardAvailable bool `yaml:"wildcard_subscription_available"` + // RetainAvailable indicates whether the server supports retained messages. + RetainAvailable bool `yaml:"retain_available"` + // MaxQueuedMsg is the maximum queue length of the outgoing messages. + // If the queue is full, some message will be dropped. + // The message dropping strategy is described in the document of the persistence/queue.Store interface. + MaxQueuedMsg int `yaml:"max_queued_messages"` + // MaxInflight limits inflight message length of the outgoing messages. + // Inflight message is also stored in the message queue, so it must be less than or equal to MaxQueuedMsg. + // Inflight message is the QoS 1 or QoS 2 message that has been sent out to a client but not been acknowledged yet. + MaxInflight uint16 `yaml:"max_inflight"` + // MaximumQoS is the highest QOS level permitted for a Publish. + MaximumQoS uint8 `yaml:"maximum_qos"` + // QueueQos0Msg indicates whether to store QoS 0 message for a offline session. + QueueQos0Msg bool `yaml:"queue_qos0_messages"` + // DeliveryMode is the delivery mode. The possible value can be "overlap" or "onlyonce". + // It is possible for a client’s subscriptions to overlap so that a published message might match multiple filters. + // When set to "overlap" , the server will deliver one message for each matching subscription and respecting the subscription’s QoS in each case. + // When set to "onlyonce",the server will deliver the message to the client respecting the maximum QoS of all the matching subscriptions. + DeliveryMode string `yaml:"delivery_mode"` + // AllowZeroLenClientID indicates whether to allow a client to connect with empty client id. + AllowZeroLenClientID bool `yaml:"allow_zero_length_clientid"` +} + +func (c MQTT) Validate() error { + if c.MaximumQoS > packets.Qos2 { + return fmt.Errorf("invalid maximum_qos: %d", c.MaximumQoS) + } + if c.MaxQueuedMsg <= 0 { + return fmt.Errorf("invalid max_queued_messages : %d", c.MaxQueuedMsg) + } + if c.ReceiveMax == 0 { + return fmt.Errorf("server_receive_maximum cannot be 0") + } + if c.MaxPacketSize == 0 { + return fmt.Errorf("max_packet_size cannot be 0") + } + if c.MaxInflight == 0 { + return fmt.Errorf("max_inflight cannot be 0") + } + if c.DeliveryMode != Overlap && c.DeliveryMode != OnlyOnce { + return fmt.Errorf("invalid delivery_mode: %s", c.DeliveryMode) + } + + if c.MaxQueuedMsg < int(c.MaxInflight) { + return fmt.Errorf("max_queued_message cannot be less than max_inflight") + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/config/persistence.go b/internal/hummingbird/mqttbroker/config/persistence.go new file mode 100644 index 0000000..c05a6c5 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/persistence.go @@ -0,0 +1,81 @@ +package config + +import ( + "net" + "time" + + "github.com/pkg/errors" +) + +type PersistenceType = string + +const ( + PersistenceTypeMemory PersistenceType = "memory" + PersistenceTypeRedis PersistenceType = "redis" +) + +var ( + defaultMaxActive = uint(0) + defaultMaxIdle = uint(1000) + // DefaultPersistenceConfig is the default value of Persistence + DefaultPersistenceConfig = Persistence{ + Type: PersistenceTypeMemory, + Redis: RedisPersistence{ + Addr: "127.0.0.1:6379", + Password: "", + Database: 0, + MaxIdle: &defaultMaxIdle, + MaxActive: &defaultMaxActive, + IdleTimeout: 240 * time.Second, + }, + } +) + +// Persistence is the config of backend persistence. +type Persistence struct { + // Type is the persistence type. + // If empty, use "memory" as default. + Type PersistenceType `yaml:"type"` + // Redis is the redis configuration and must be set when Type == "redis". + Redis RedisPersistence `yaml:"redis"` +} + +// RedisPersistence is the configuration of redis persistence. +type RedisPersistence struct { + // Addr is the redis server address. + // If empty, use "127.0.0.1:6379" as default. + Addr string `yaml:"addr"` + // Password is the redis password. + Password string `yaml:"password"` + // Database is the number of the redis database to be connected. + Database uint `yaml:"database"` + // MaxIdle is the maximum number of idle connections in the pool. + // If nil, use 1000 as default. + // This value will pass to redis.Pool.MaxIde. + MaxIdle *uint `yaml:"max_idle"` + // MaxActive is the maximum number of connections allocated by the pool at a given time. + // If nil, use 0 as default. + // If zero, there is no limit on the number of connections in the pool. + // This value will pass to redis.Pool.MaxActive. + MaxActive *uint `yaml:"max_active"` + // Close connections after remaining idle for this duration. If the value + // is zero, then idle connections are not closed. Applications should set + // the timeout to a value less than the server's timeout. + // Ff zero, use 240 * time.Second as default. + // This value will pass to redis.Pool.IdleTimeout. + IdleTimeout time.Duration `yaml:"idle_timeout"` +} + +func (p *Persistence) Validate() error { + if p.Type != PersistenceTypeMemory && p.Type != PersistenceTypeRedis { + return errors.New("invalid persistence type") + } + _, _, err := net.SplitHostPort(p.Redis.Addr) + if err != nil { + return err + } + if p.Redis.Database < 0 { + return errors.New("invalid redis database number") + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/config/testdata/config.yml b/internal/hummingbird/mqttbroker/config/testdata/config.yml new file mode 100644 index 0000000..47828d1 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/testdata/config.yml @@ -0,0 +1,31 @@ +listeners: + - address: ":58090" + websocket: + path: "/" + - address: ":1234" + +mqtt: + session_expiry: 1m + message_expiry: 1m + max_packet_size: 200 + server_receive_maximum: 65535 + max_keepalive: 0 # unlimited + topic_alias_maximum: 0 # 0 means not Supported + subscription_identifier_available: true + wildcard_subscription_available: true + shared_subscription_available: true + maximum_qos: 2 + retain_available: true + max_queued_messages: 1000 + max_inflight: 32 + max_awaiting_rel: 100 + queue_qos0_messages: true + delivery_mode: overlap # overlap or onlyonce + allow_zero_length_clientid: true + +log: + level: debug # debug | info | warning | error + + + + diff --git a/internal/hummingbird/mqttbroker/config/testdata/default_values.yml b/internal/hummingbird/mqttbroker/config/testdata/default_values.yml new file mode 100644 index 0000000..cd24352 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/testdata/default_values.yml @@ -0,0 +1,7 @@ +listeners: +mqtt: +log: + + + + diff --git a/internal/hummingbird/mqttbroker/config/testdata/default_values_expected.yml b/internal/hummingbird/mqttbroker/config/testdata/default_values_expected.yml new file mode 100644 index 0000000..e69de29 diff --git a/internal/hummingbird/mqttbroker/config/topic_alias.go b/internal/hummingbird/mqttbroker/config/topic_alias.go new file mode 100644 index 0000000..8ed8299 --- /dev/null +++ b/internal/hummingbird/mqttbroker/config/topic_alias.go @@ -0,0 +1,19 @@ +package config + +type TopicAliasType = string + +const ( + TopicAliasMgrTypeFIFO TopicAliasType = "fifo" +) + +var ( + // DefaultTopicAliasManager is the default value of TopicAliasManager + DefaultTopicAliasManager = TopicAliasManager{ + Type: TopicAliasMgrTypeFIFO, + } +) + +// TopicAliasManager is the config of the topic alias manager. +type TopicAliasManager struct { + Type TopicAliasType +} diff --git a/internal/hummingbird/mqttbroker/database.go b/internal/hummingbird/mqttbroker/database.go new file mode 100644 index 0000000..967df1c --- /dev/null +++ b/internal/hummingbird/mqttbroker/database.go @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package mqttbroker + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/infrastructure/sqlite" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/interfaces" + "go.uber.org/zap" +) + +var ( + DbClient interfaces.DBClient +) + +func GetDbClient() interfaces.DBClient { + return DbClient +} + +type Database struct { + conf config.Config +} + +// NewDatabase is a factory method that returns an initialized Database receiver struct. +func NewDatabase(conf config.Config) Database { + return Database{ + conf: conf, + } +} + +// init the dbClient interfaces +func (d Database) InitDBClient( + lc *zap.Logger) error { + dbClient, err := sqlite.NewClient(dtos.Configuration{ + Cluster: d.conf.Database.Cluster, + Username: d.conf.Database.Username, + Password: d.conf.Database.Password, + DataSource: d.conf.Database.DataSource, + DatabaseName: d.conf.Database.Name, + }, lc) + if err != nil { + return err + } + DbClient = dbClient + return nil +} diff --git a/internal/hummingbird/mqttbroker/infrastructure/sqlite/client.go b/internal/hummingbird/mqttbroker/infrastructure/sqlite/client.go new file mode 100644 index 0000000..47f1e0c --- /dev/null +++ b/internal/hummingbird/mqttbroker/infrastructure/sqlite/client.go @@ -0,0 +1,60 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import ( + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/models" + "github.com/winc-link/hummingbird/internal/pkg/errort" + clientSQLite "github.com/winc-link/hummingbird/internal/tools/sqldb/sqlite" + "go.uber.org/zap" + "gorm.io/gorm" +) + +type Client struct { + Pool *gorm.DB + client clientSQLite.ClientSQLite + loggingClient *zap.Logger +} + +func NewClient(config dtos.Configuration, lc *zap.Logger) (c *Client, errEdgeX error) { + client, err := clientSQLite.NewGormClient(config, nil) + if err != nil { + errEdgeX = errort.NewCommonEdgeX(errort.DefaultSystemError, "database failed to init", err) + return + } + client.Pool = client.Pool.Debug() + // 自动建表 + if err = client.InitTable( + &models.MqttAuth{}, + ); err != nil { + errEdgeX = errort.NewCommonEdgeX(errort.DefaultSystemError, "database failed to init", err) + return + } + c = &Client{ + client: client, + loggingClient: lc, + Pool: client.Pool, + } + return +} + +func (client *Client) CloseSession() { + return +} + +func (client *Client) GetMqttAutInfo(clientId string) (models.MqttAuth, error) { + return getMqttAutInfo(client, clientId) +} diff --git a/internal/hummingbird/mqttbroker/infrastructure/sqlite/mqttauth.go b/internal/hummingbird/mqttbroker/infrastructure/sqlite/mqttauth.go new file mode 100644 index 0000000..45a4358 --- /dev/null +++ b/internal/hummingbird/mqttbroker/infrastructure/sqlite/mqttauth.go @@ -0,0 +1,26 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package sqlite + +import "github.com/winc-link/hummingbird/internal/models" + +func getMqttAutInfo(c *Client, clientId string) (models.MqttAuth, error) { + var mqttAuth models.MqttAuth + + if err := c.Pool.Where("client_id = ?", clientId).First(&mqttAuth).Error; err != nil { + return models.MqttAuth{}, err + } + return mqttAuth, nil +} diff --git a/internal/hummingbird/mqttbroker/interfaces/db.go b/internal/hummingbird/mqttbroker/interfaces/db.go new file mode 100644 index 0000000..fcd8dd2 --- /dev/null +++ b/internal/hummingbird/mqttbroker/interfaces/db.go @@ -0,0 +1,23 @@ +/******************************************************************************* + * Copyright 2017 Dell Inc. + * Copyright (c) 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + *******************************************************************************/ +package interfaces + +import "github.com/winc-link/hummingbird/internal/models" + +type DBClient interface { + CloseSession() + + GetMqttAutInfo(clientId string) (models.MqttAuth, error) +} diff --git a/internal/hummingbird/mqttbroker/message.go b/internal/hummingbird/mqttbroker/message.go new file mode 100644 index 0000000..d437cdc --- /dev/null +++ b/internal/hummingbird/mqttbroker/message.go @@ -0,0 +1,191 @@ +package mqttbroker + +import ( + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +type Message struct { + Dup bool + QoS uint8 + Retained bool + Topic string + Payload []byte + PacketID packets.PacketID + // The following fields are introduced in v5 specification. + // Excepting MessageExpiry, these fields will not take effect when it represents a v3.x publish packet. + ContentType string + CorrelationData []byte + MessageExpiry uint32 + PayloadFormat packets.PayloadFormat + ResponseTopic string + SubscriptionIdentifier []uint32 + UserProperties []packets.UserProperty +} + +// Copy deep copies the Message and return the new one +func (m *Message) Copy() *Message { + newMsg := &Message{ + Dup: m.Dup, + QoS: m.QoS, + Retained: m.Retained, + Topic: m.Topic, + PacketID: m.PacketID, + ContentType: m.ContentType, + MessageExpiry: m.MessageExpiry, + PayloadFormat: m.PayloadFormat, + ResponseTopic: m.ResponseTopic, + } + newMsg.Payload = make([]byte, len(m.Payload)) + copy(newMsg.Payload, m.Payload) + + if len(m.CorrelationData) != 0 { + newMsg.CorrelationData = make([]byte, len(m.CorrelationData)) + copy(newMsg.CorrelationData, m.CorrelationData) + } + + if len(m.SubscriptionIdentifier) != 0 { + newMsg.SubscriptionIdentifier = make([]uint32, len(m.SubscriptionIdentifier)) + copy(newMsg.SubscriptionIdentifier, m.SubscriptionIdentifier) + } + if len(m.UserProperties) != 0 { + newMsg.UserProperties = make([]packets.UserProperty, len(m.UserProperties)) + for k := range newMsg.UserProperties { + newMsg.UserProperties[k].K = make([]byte, len(m.UserProperties[k].K)) + copy(newMsg.UserProperties[k].K, m.UserProperties[k].K) + + newMsg.UserProperties[k].V = make([]byte, len(m.UserProperties[k].V)) + copy(newMsg.UserProperties[k].V, m.UserProperties[k].V) + } + } + return newMsg + +} + +func getVariablelenght(l int) int { + if l <= 127 { + return 1 + } else if l <= 16383 { + return 2 + } else if l <= 2097151 { + return 3 + } else if l <= 268435455 { + return 4 + } + return 0 +} + +// TotalBytes return the publish packets total bytes. +func (m *Message) TotalBytes(version packets.Version) uint32 { + remainLenght := len(m.Payload) + 2 + len(m.Topic) + if m.QoS > packets.Qos0 { + remainLenght += 2 + } + if version == packets.Version5 { + propertyLenght := 0 + if m.PayloadFormat == packets.PayloadFormatString { + propertyLenght += 2 + } + if l := len(m.ContentType); l != 0 { + propertyLenght += 3 + l + } + if l := len(m.CorrelationData); l != 0 { + propertyLenght += 3 + l + } + + for _, v := range m.SubscriptionIdentifier { + propertyLenght++ + propertyLenght += getVariablelenght(int(v)) + } + + if m.MessageExpiry != 0 { + propertyLenght += 5 + } + if l := len(m.ResponseTopic); l != 0 { + propertyLenght += 3 + l + } + for _, v := range m.UserProperties { + propertyLenght += 5 + len(v.K) + len(v.V) + } + remainLenght += propertyLenght + getVariablelenght(propertyLenght) + } + if remainLenght <= 127 { + return 2 + uint32(remainLenght) + } else if remainLenght <= 16383 { + return 3 + uint32(remainLenght) + } else if remainLenght <= 2097151 { + return 4 + uint32(remainLenght) + } + return 5 + uint32(remainLenght) +} + +// MessageFromPublish create the Message instance from publish packets +func MessageFromPublish(p *packets.Publish) *Message { + m := &Message{ + Dup: p.Dup, + QoS: p.Qos, + Retained: p.Retain, + Topic: string(p.TopicName), + Payload: p.Payload, + } + if p.Version == packets.Version5 { + if p.Properties.PayloadFormat != nil { + m.PayloadFormat = *p.Properties.PayloadFormat + } + if l := len(p.Properties.ContentType); l != 0 { + m.ContentType = string(p.Properties.ContentType) + } + if l := len(p.Properties.CorrelationData); l != 0 { + m.CorrelationData = p.Properties.CorrelationData + } + if p.Properties.MessageExpiry != nil { + m.MessageExpiry = *p.Properties.MessageExpiry + } + if l := len(p.Properties.ResponseTopic); l != 0 { + m.ResponseTopic = string(p.Properties.ResponseTopic) + } + m.UserProperties = p.Properties.User + + } + return m +} + +// MessageToPublish create the publish packet instance from *Message +func MessageToPublish(msg *Message, version packets.Version) *packets.Publish { + pub := &packets.Publish{ + Dup: msg.Dup, + Qos: msg.QoS, + PacketID: msg.PacketID, + Retain: msg.Retained, + TopicName: []byte(msg.Topic), + Payload: msg.Payload, + Version: version, + } + if version == packets.Version5 { + var msgExpiry *uint32 + if e := msg.MessageExpiry; e != 0 { + msgExpiry = &e + } + var contentType []byte + if msg.ContentType != "" { + contentType = []byte(msg.ContentType) + } + var responseTopic []byte + if msg.ResponseTopic != "" { + responseTopic = []byte(msg.ResponseTopic) + } + var payloadFormat *byte + if e := msg.PayloadFormat; e == packets.PayloadFormatString { + payloadFormat = &e + } + pub.Properties = &packets.Properties{ + CorrelationData: msg.CorrelationData, + ContentType: contentType, + MessageExpiry: msgExpiry, + ResponseTopic: responseTopic, + PayloadFormat: payloadFormat, + User: msg.UserProperties, + SubscriptionIdentifier: msg.SubscriptionIdentifier, + } + } + return pub +} diff --git a/internal/hummingbird/mqttbroker/mock_gen.sh b/internal/hummingbird/mqttbroker/mock_gen.sh new file mode 100755 index 0000000..4da3f5b --- /dev/null +++ b/internal/hummingbird/mqttbroker/mock_gen.sh @@ -0,0 +1,23 @@ +mockgen -source=config/config.go -destination=./config/config_mock.go -package=config -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/config +mockgen -source=persistence/queue/elem.go -destination=./persistence/queue/elem_mock.go -package=queue -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/queue +mockgen -source=persistence/queue/queue.go -destination=./persistence/queue/queue_mock.go -package=queue -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/queue +mockgen -source=persistence/session/session.go -destination=./persistence/session/session_mock.go -package=session -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/session +mockgen -source=persistence/subscription/subscription.go -destination=./persistence/subscription/subscription_mock.go -package=subscription -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/subscription +mockgen -source=persistence/unack/unack.go -destination=./persistence/unack/unack_mock.go -package=unack -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/unack +mockgen -source=pkg/packets/packets.go -destination=./pkg/packets/packets_mock.go -package=packets -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/packets +mockgen -source=plugin/auth/account_grpc.pb.go -destination=./plugin/auth/account_grpc.pb_mock.go -package=auth -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/auth +mockgen -source=plugin/federation/federation.pb.go -destination=./plugin/federation/federation.pb_mock.go -package=federation -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/federation +mockgen -source=plugin/federation/peer.go -destination=./plugin/federation/peer_mock.go -package=federation -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/federation +mockgen -source=plugin/federation/membership.go -destination=./plugin/federation/membership_mock.go -package=federation -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/federation +mockgen -source=retained/interface.go -destination=./retained/interface_mock.go -package=retained -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/retained +mockgen -source=server/client.go -destination=./server/client_mock.go -package=server -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/server +mockgen -source=server/persistence.go -destination=./server/persistence_mock.go -package=server -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/server +mockgen -source=server/plugin.go -destination=./server/plugin_mock.go -package=server -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/server +mockgen -source=server/server.go -destination=./server/server_mock.go -package=server -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/server +mockgen -source=server/service.go -destination=./server/service_mock.go -package=server -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/server +mockgen -source=server/stats.go -destination=./server/stats_mock.go -package=server -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/server +mockgen -source=server/topic_alias.go -destination=./server/topic_alias_mock.go -package=server -self_package=gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/server + +# reflection mode. +# gRPC streaming mock issue: https://github.com/golang/mock/pull/163 +mockgen -package=federation -destination=/usr/local/gopath/src/gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/plugin/federation/federation_grpc.pb_mock.go gitlab.com/tedge/edgex/internal/thummingbird/mqttbroker/plugin/federation FederationClient,Federation_EventStreamClient diff --git a/internal/hummingbird/mqttbroker/persistence/encoding/binary.go b/internal/hummingbird/mqttbroker/persistence/encoding/binary.go new file mode 100644 index 0000000..8e79b07 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/encoding/binary.go @@ -0,0 +1,73 @@ +package encoding + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +func WriteUint16(w *bytes.Buffer, i uint16) { + w.WriteByte(byte(i >> 8)) + w.WriteByte(byte(i)) +} + +func WriteBool(w *bytes.Buffer, b bool) { + if b { + w.WriteByte(1) + } else { + w.WriteByte(0) + } +} + +func ReadBool(r *bytes.Buffer) (bool, error) { + b, err := r.ReadByte() + if err != nil { + return false, err + } + if b == 0 { + return false, nil + } + return true, nil +} + +func WriteString(w *bytes.Buffer, s []byte) { + WriteUint16(w, uint16(len(s))) + w.Write(s) +} +func ReadString(r *bytes.Buffer) (b []byte, err error) { + l := make([]byte, 2) + _, err = io.ReadFull(r, l) + if err != nil { + return nil, err + } + length := int(binary.BigEndian.Uint16(l)) + paylaod := make([]byte, length) + + _, err = io.ReadFull(r, paylaod) + if err != nil { + return nil, err + } + return paylaod, nil +} + +func WriteUint32(w *bytes.Buffer, i uint32) { + w.WriteByte(byte(i >> 24)) + w.WriteByte(byte(i >> 16)) + w.WriteByte(byte(i >> 8)) + w.WriteByte(byte(i)) +} + +func ReadUint16(r *bytes.Buffer) (uint16, error) { + if r.Len() < 2 { + return 0, errors.New("invalid length") + } + return binary.BigEndian.Uint16(r.Next(2)), nil +} + +func ReadUint32(r *bytes.Buffer) (uint32, error) { + if r.Len() < 4 { + return 0, errors.New("invalid length") + } + return binary.BigEndian.Uint32(r.Next(4)), nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/encoding/redis.go b/internal/hummingbird/mqttbroker/persistence/encoding/redis.go new file mode 100644 index 0000000..294838c --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/encoding/redis.go @@ -0,0 +1,189 @@ +package encoding + +import ( + "bytes" + "encoding/binary" + "io" + "time" + + "github.com/winc-link/hummingbird/internal/pkg/packets" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" +) + +// EncodeMessage encodes message into bytes and write it to the buffer +func EncodeMessage(msg *gmqtt.Message, b *bytes.Buffer) { + if msg == nil { + return + } + WriteBool(b, msg.Dup) + b.WriteByte(msg.QoS) + WriteBool(b, msg.Retained) + WriteString(b, []byte(msg.Topic)) + WriteString(b, []byte(msg.Payload)) + WriteUint16(b, msg.PacketID) + + if len(msg.ContentType) != 0 { + b.WriteByte(packets.PropContentType) + WriteString(b, []byte(msg.ContentType)) + } + if len(msg.CorrelationData) != 0 { + b.WriteByte(packets.PropCorrelationData) + WriteString(b, []byte(msg.CorrelationData)) + } + if msg.MessageExpiry != 0 { + b.WriteByte(packets.PropMessageExpiry) + WriteUint32(b, msg.MessageExpiry) + } + b.WriteByte(packets.PropPayloadFormat) + b.WriteByte(msg.PayloadFormat) + + if len(msg.ResponseTopic) != 0 { + b.WriteByte(packets.PropResponseTopic) + WriteString(b, []byte(msg.ResponseTopic)) + } + for _, v := range msg.SubscriptionIdentifier { + b.WriteByte(packets.PropSubscriptionIdentifier) + l, _ := packets.DecodeRemainLength(int(v)) + b.Write(l) + } + for _, v := range msg.UserProperties { + b.WriteByte(packets.PropUser) + WriteString(b, v.K) + WriteString(b, v.V) + } + return +} + +// DecodeMessage decodes message from buffer. +func DecodeMessage(b *bytes.Buffer) (msg *gmqtt.Message, err error) { + msg = &gmqtt.Message{} + msg.Dup, err = ReadBool(b) + if err != nil { + return + } + msg.QoS, err = b.ReadByte() + if err != nil { + return + } + msg.Retained, err = ReadBool(b) + if err != nil { + return + } + topic, err := ReadString(b) + if err != nil { + return + } + msg.Topic = string(topic) + msg.Payload, err = ReadString(b) + if err != nil { + return + } + msg.PacketID, err = ReadUint16(b) + if err != nil { + return + } + for { + pt, err := b.ReadByte() + if err == io.EOF { + return msg, nil + } + if err != nil { + return nil, err + } + switch pt { + case packets.PropContentType: + v, err := ReadString(b) + if err != nil { + return nil, err + } + msg.ContentType = string(v) + case packets.PropCorrelationData: + msg.CorrelationData, err = ReadString(b) + if err != nil { + return nil, err + } + case packets.PropMessageExpiry: + msg.MessageExpiry, err = ReadUint32(b) + if err != nil { + return nil, err + } + case packets.PropPayloadFormat: + msg.PayloadFormat, err = b.ReadByte() + if err != nil { + return nil, err + } + case packets.PropResponseTopic: + v, err := ReadString(b) + if err != nil { + return nil, err + } + msg.ResponseTopic = string(v) + case packets.PropSubscriptionIdentifier: + si, err := packets.EncodeRemainLength(b) + if err != nil { + return nil, err + } + msg.SubscriptionIdentifier = append(msg.SubscriptionIdentifier, uint32(si)) + case packets.PropUser: + k, err := ReadString(b) + if err != nil { + return nil, err + } + v, err := ReadString(b) + if err != nil { + return nil, err + } + msg.UserProperties = append(msg.UserProperties, packets.UserProperty{K: k, V: v}) + } + } +} + +// DecodeMessageFromBytes decodes message from bytes. +func DecodeMessageFromBytes(b []byte) (msg *gmqtt.Message, err error) { + if len(b) == 0 { + return nil, nil + } + return DecodeMessage(bytes.NewBuffer(b)) +} + +func EncodeSession(sess *gmqtt.Session, b *bytes.Buffer) { + WriteString(b, []byte(sess.ClientID)) + if sess.Will != nil { + b.WriteByte(1) + EncodeMessage(sess.Will, b) + WriteUint32(b, sess.WillDelayInterval) + } else { + b.WriteByte(0) + } + time := make([]byte, 8) + binary.BigEndian.PutUint64(time, uint64(sess.ConnectedAt.Unix())) + WriteUint32(b, sess.ExpiryInterval) +} + +func DecodeSession(b *bytes.Buffer) (sess *gmqtt.Session, err error) { + sess = &gmqtt.Session{} + cid, err := ReadString(b) + if err != nil { + return nil, err + } + sess.ClientID = string(cid) + willPresent, err := b.ReadByte() + if err != nil { + return + } + if willPresent == 1 { + sess.Will, err = DecodeMessage(b) + if err != nil { + return + } + sess.WillDelayInterval, err = ReadUint32(b) + if err != nil { + return + } + } + t := binary.BigEndian.Uint64(b.Next(8)) + sess.ConnectedAt = time.Unix(int64(t), 0) + sess.ExpiryInterval, err = ReadUint32(b) + return +} diff --git a/internal/hummingbird/mqttbroker/persistence/memory.go b/internal/hummingbird/mqttbroker/persistence/memory.go new file mode 100644 index 0000000..3ae1211 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/memory.go @@ -0,0 +1,55 @@ +package persistence + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + mem_queue "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue/mem" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" + mem_session "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session/mem" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + mem_sub "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription/mem" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" + mem_unack "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack/mem" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +func init() { + server.RegisterPersistenceFactory("memory", NewMemory) +} + +func NewMemory(config config.Config) (server.Persistence, error) { + return &memory{}, nil +} + +type memory struct { +} + +func (m *memory) NewUnackStore(config config.Config, clientID string) (unack.Store, error) { + return mem_unack.New(mem_unack.Options{ + ClientID: clientID, + }), nil +} + +func (m *memory) NewSessionStore(config config.Config) (session.Store, error) { + return mem_session.New(), nil +} + +func (m *memory) Open() error { + return nil +} +func (m *memory) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) { + return mem_queue.New(mem_queue.Options{ + MaxQueuedMsg: config.MQTT.MaxQueuedMsg, + InflightExpiry: config.MQTT.InflightExpiry, + ClientID: clientID, + DefaultNotifier: defaultNotifier, + }) +} + +func (m *memory) NewSubscriptionStore(config config.Config) (subscription.Store, error) { + return mem_sub.NewStore(), nil +} + +func (m *memory) Close() error { + return nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/elem.go b/internal/hummingbird/mqttbroker/persistence/queue/elem.go new file mode 100644 index 0000000..9e49fd0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/elem.go @@ -0,0 +1,116 @@ +package queue + +import ( + "bytes" + "encoding/binary" + "errors" + "time" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/encoding" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +type MessageWithID interface { + ID() packets.PacketID + SetID(id packets.PacketID) +} + +type Publish struct { + *mqttbroker.Message +} + +func (p *Publish) ID() packets.PacketID { + return p.PacketID +} +func (p *Publish) SetID(id packets.PacketID) { + p.PacketID = id +} + +type Pubrel struct { + PacketID packets.PacketID +} + +func (p *Pubrel) ID() packets.PacketID { + return p.PacketID +} +func (p *Pubrel) SetID(id packets.PacketID) { + p.PacketID = id +} + +// Elem represents the element store in the queue. +type Elem struct { + // At represents the entry time. + At time.Time + // Expiry represents the expiry time. + // Empty means never expire. + Expiry time.Time + MessageWithID +} + +// Encode encodes the publish structure into bytes and write it to the buffer +func (p *Publish) Encode(b *bytes.Buffer) { + encoding.EncodeMessage(p.Message, b) +} + +func (p *Publish) Decode(b *bytes.Buffer) (err error) { + msg, err := encoding.DecodeMessage(b) + if err != nil { + return err + } + p.Message = msg + return nil +} + +// Encode encode the pubrel structure into bytes. +func (p *Pubrel) Encode(b *bytes.Buffer) { + encoding.WriteUint16(b, p.PacketID) +} + +func (p *Pubrel) Decode(b *bytes.Buffer) (err error) { + p.PacketID, err = encoding.ReadUint16(b) + return +} + +// Encode encode the elem structure into bytes. +// Format: 8 byte timestamp | 1 byte identifier| data +func (e *Elem) Encode() []byte { + b := bytes.NewBuffer(make([]byte, 0, 100)) + rs := make([]byte, 19) + binary.BigEndian.PutUint64(rs[0:9], uint64(e.At.Unix())) + binary.BigEndian.PutUint64(rs[9:18], uint64(e.Expiry.Unix())) + switch m := e.MessageWithID.(type) { + case *Publish: + rs[18] = 0 + b.Write(rs) + m.Encode(b) + case *Pubrel: + rs[18] = 1 + b.Write(rs) + m.Encode(b) + } + return b.Bytes() +} + +func (e *Elem) Decode(b []byte) (err error) { + if len(b) < 19 { + return errors.New("invalid input length") + } + e.At = time.Unix(int64(binary.BigEndian.Uint64(b[0:9])), 0) + e.Expiry = time.Unix(int64(binary.BigEndian.Uint64(b[9:19])), 0) + switch b[18] { + case 0: // publish + p := &Publish{} + buf := bytes.NewBuffer(b[19:]) + err = p.Decode(buf) + e.MessageWithID = p + case 1: // pubrel + p := &Pubrel{} + buf := bytes.NewBuffer(b[19:]) + err = p.Decode(buf) + e.MessageWithID = p + default: + return errors.New("invalid identifier") + } + return +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/elem_mock.go b/internal/hummingbird/mqttbroker/persistence/queue/elem_mock.go new file mode 100644 index 0000000..8dca88b --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/elem_mock.go @@ -0,0 +1,61 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: persistence/queue/elem.go + +// Package queue is a generated GoMock package. +package queue + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + packets "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// MockMessageWithID is a mock of MessageWithID interface +type MockMessageWithID struct { + ctrl *gomock.Controller + recorder *MockMessageWithIDMockRecorder +} + +// MockMessageWithIDMockRecorder is the mock recorder for MockMessageWithID +type MockMessageWithIDMockRecorder struct { + mock *MockMessageWithID +} + +// NewMockMessageWithID creates a new mock instance +func NewMockMessageWithID(ctrl *gomock.Controller) *MockMessageWithID { + mock := &MockMessageWithID{ctrl: ctrl} + mock.recorder = &MockMessageWithIDMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockMessageWithID) EXPECT() *MockMessageWithIDMockRecorder { + return m.recorder +} + +// ID mocks base method +func (m *MockMessageWithID) ID() packets.PacketID { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ID") + ret0, _ := ret[0].(packets.PacketID) + return ret0 +} + +// ID indicates an expected call of ID +func (mr *MockMessageWithIDMockRecorder) ID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockMessageWithID)(nil).ID)) +} + +// SetID mocks base method +func (m *MockMessageWithID) SetID(id packets.PacketID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetID", id) +} + +// SetID indicates an expected call of SetID +func (mr *MockMessageWithIDMockRecorder) SetID(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetID", reflect.TypeOf((*MockMessageWithID)(nil).SetID), id) +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/elem_test.go b/internal/hummingbird/mqttbroker/persistence/queue/elem_test.go new file mode 100644 index 0000000..d5ba02c --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/elem_test.go @@ -0,0 +1,104 @@ +package queue + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +func assertElemEqual(a *assert.Assertions, expected, actual *Elem) { + expected.At = time.Unix(expected.At.Unix(), 0) + expected.Expiry = time.Unix(expected.Expiry.Unix(), 0) + actual.At = time.Unix(actual.At.Unix(), 0) + actual.Expiry = time.Unix(actual.Expiry.Unix(), 0) + a.Equal(expected, actual) +} + +func TestElem_Encode_Publish(t *testing.T) { + a := assert.New(t) + e := &Elem{ + At: time.Now(), + MessageWithID: &Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 2, + Retained: false, + Topic: "/mytopic", + Payload: []byte("payload"), + PacketID: 2, + ContentType: "type", + CorrelationData: nil, + MessageExpiry: 1, + PayloadFormat: packets.PayloadFormatString, + ResponseTopic: "", + SubscriptionIdentifier: []uint32{1, 2}, + UserProperties: []packets.UserProperty{ + { + K: []byte("1"), + V: []byte("2"), + }, { + K: []byte("3"), + V: []byte("4"), + }, + }, + }, + }, + } + rs := e.Encode() + de := &Elem{} + err := de.Decode(rs) + a.Nil(err) + assertElemEqual(a, e, de) +} +func TestElem_Encode_Pubrel(t *testing.T) { + a := assert.New(t) + e := &Elem{ + At: time.Unix(time.Now().Unix(), 0), + MessageWithID: &Pubrel{ + PacketID: 2, + }, + } + rs := e.Encode() + de := &Elem{} + err := de.Decode(rs) + a.Nil(err) + assertElemEqual(a, e, de) +} + +func Benchmark_Encode_Publish(b *testing.B) { + for i := 0; i < b.N; i++ { + e := &Elem{ + At: time.Unix(time.Now().Unix(), 0), + MessageWithID: &Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 2, + Retained: false, + Topic: "/mytopic", + Payload: []byte("payload"), + PacketID: 2, + ContentType: "type", + CorrelationData: nil, + MessageExpiry: 1, + PayloadFormat: packets.PayloadFormatString, + ResponseTopic: "", + SubscriptionIdentifier: []uint32{1, 2}, + UserProperties: []packets.UserProperty{ + { + K: []byte("1"), + V: []byte("2"), + }, { + K: []byte("3"), + V: []byte("4"), + }, + }, + }, + }, + } + e.Encode() + } +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/error.go b/internal/hummingbird/mqttbroker/persistence/queue/error.go new file mode 100644 index 0000000..42172fa --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/error.go @@ -0,0 +1,23 @@ +package queue + +import ( + "errors" +) + +var ( + ErrClosed = errors.New("queue has been closed") + ErrDropExceedsMaxPacketSize = errors.New("maximum packet size exceeded") + ErrDropQueueFull = errors.New("the message queue is full") + ErrDropExpired = errors.New("the message is expired") + ErrDropExpiredInflight = errors.New("the inflight message is expired") +) + +// InternalError wraps the error of the backend storage. +type InternalError struct { + // Err is the error return by the backend storage. + Err error +} + +func (i *InternalError) Error() string { + return i.Error() +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/mem/mem.go b/internal/hummingbird/mqttbroker/persistence/queue/mem/mem.go new file mode 100644 index 0000000..8057769 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/mem/mem.go @@ -0,0 +1,278 @@ +package mem + +import ( + "container/list" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var _ queue.Store = (*Queue)(nil) + +type Options struct { + MaxQueuedMsg int + InflightExpiry time.Duration + ClientID string + DefaultNotifier queue.Notifier +} + +type Queue struct { + cond *sync.Cond + clientID string + version packets.Version + opts *Options + readBytesLimit uint32 + l *list.List + // current is the next element to read. + current *list.Element + inflightDrained bool + closed bool + // max is the maximum queue length + max int + log *zap.Logger + inflightExpiry time.Duration + notifier queue.Notifier +} + +func New(opts Options) (*Queue, error) { + return &Queue{ + clientID: opts.ClientID, + cond: sync.NewCond(&sync.Mutex{}), + l: list.New(), + max: opts.MaxQueuedMsg, + inflightExpiry: opts.InflightExpiry, + notifier: opts.DefaultNotifier, + log: server.LoggerWithField(zap.String("queue", "memory")), + }, nil +} + +func (q *Queue) Close() error { + q.cond.L.Lock() + defer q.cond.L.Unlock() + q.closed = true + q.cond.Signal() + return nil +} + +func (q *Queue) Init(opts *queue.InitOptions) error { + q.cond.L.Lock() + defer q.cond.L.Unlock() + q.closed = false + q.inflightDrained = false + if opts.CleanStart { + q.l = list.New() + } + q.readBytesLimit = opts.ReadBytesLimit + q.version = opts.Version + q.current = q.l.Front() + q.notifier = opts.Notifier + q.cond.Signal() + return nil +} + +func (*Queue) Clean() error { + return nil +} + +func (q *Queue) Add(elem *queue.Elem) (err error) { + now := time.Now() + var dropErr error + var dropElem *list.Element + var drop bool + q.cond.L.Lock() + defer func() { + q.cond.L.Unlock() + q.cond.Signal() + }() + defer func() { + if drop { + if dropErr == queue.ErrDropExpiredInflight { + q.notifier.NotifyInflightAdded(-1) + } + if dropElem == nil { + q.notifier.NotifyDropped(elem, dropErr) + return + } + if dropElem == q.current { + q.current = q.current.Next() + } + q.l.Remove(dropElem) + q.notifier.NotifyDropped(dropElem.Value.(*queue.Elem), dropErr) + } else { + q.notifier.NotifyMsgQueueAdded(1) + } + e := q.l.PushBack(elem) + if q.current == nil { + q.current = e + } + }() + if q.l.Len() >= q.max { + // set default drop error + dropErr = queue.ErrDropQueueFull + drop = true + + // drop expired inflight message + if v := q.l.Front(); v != q.current && + v != nil && + queue.ElemExpiry(now, v.Value.(*queue.Elem)) { + dropElem = v + dropErr = queue.ErrDropExpiredInflight + return + } + + // drop the current elem if there is no more non-inflight messages. + if q.inflightDrained && q.current == nil { + return + } + for e := q.current; e != nil; e = e.Next() { + pub := e.Value.(*queue.Elem).MessageWithID.(*queue.Publish) + // drop expired non-inflight message + if pub.ID() == 0 && + queue.ElemExpiry(now, e.Value.(*queue.Elem)) { + dropElem = e + dropErr = queue.ErrDropExpired + return + } + // drop qos0 message in the queue + if pub.ID() == 0 && pub.QoS == packets.Qos0 && dropElem == nil { + dropElem = e + } + } + if dropElem != nil { + return + } + if elem.MessageWithID.(*queue.Publish).QoS == packets.Qos0 { + return + } + + if q.inflightDrained { + // drop the front message + dropElem = q.current + return + } + // the messages in the queue are all inflight messages, drop the current elem + return + } + return nil +} + +func (q *Queue) Replace(elem *queue.Elem) (replaced bool, err error) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + unread := q.current + for e := q.l.Front(); e != nil && e != unread; e = e.Next() { + if e.Value.(*queue.Elem).ID() == elem.ID() { + e.Value = elem + return true, nil + } + } + return false, nil +} + +func (q *Queue) Read(pids []packets.PacketID) (rs []*queue.Elem, err error) { + now := time.Now() + q.cond.L.Lock() + defer q.cond.L.Unlock() + if !q.inflightDrained { + panic("must call ReadInflight to drain all inflight messages before Read") + } + for (q.l.Len() == 0 || q.current == nil) && !q.closed { + q.cond.Wait() + } + if q.closed { + return nil, queue.ErrClosed + } + length := q.l.Len() + if len(pids) < length { + length = len(pids) + } + var msgQueueDelta, inflightDelta int + var pflag int + for i := 0; i < length && q.current != nil; i++ { + v := q.current + // remove expired message + if queue.ElemExpiry(now, v.Value.(*queue.Elem)) { + q.current = q.current.Next() + q.notifier.NotifyDropped(v.Value.(*queue.Elem), queue.ErrDropExpired) + q.l.Remove(v) + msgQueueDelta-- + continue + } + // remove message which exceeds maximum packet size + pub := v.Value.(*queue.Elem).MessageWithID.(*queue.Publish) + if size := pub.TotalBytes(q.version); size > q.readBytesLimit { + q.current = q.current.Next() + q.notifier.NotifyDropped(v.Value.(*queue.Elem), queue.ErrDropExceedsMaxPacketSize) + q.l.Remove(v) + msgQueueDelta-- + continue + } + + // remove qos 0 message after read + if pub.QoS == 0 { + q.current = q.current.Next() + q.l.Remove(v) + msgQueueDelta-- + } else { + pub.SetID(pids[pflag]) + // When the message becomes inflight message, update the expiry time. + if q.inflightExpiry != 0 { + v.Value.(*queue.Elem).Expiry = now.Add(q.inflightExpiry) + } + pflag++ + inflightDelta++ + q.current = q.current.Next() + } + rs = append(rs, v.Value.(*queue.Elem)) + } + q.notifier.NotifyMsgQueueAdded(msgQueueDelta) + q.notifier.NotifyInflightAdded(inflightDelta) + return rs, nil +} + +func (q *Queue) ReadInflight(maxSize uint) (rs []*queue.Elem, err error) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + length := q.l.Len() + if length == 0 || q.current == nil { + q.inflightDrained = true + return nil, nil + } + if int(maxSize) < length { + length = int(maxSize) + } + for i := 0; i < length && q.current != nil; i++ { + if e := q.current.Value.(*queue.Elem); e.ID() != 0 { + if q.inflightExpiry != 0 { + e.Expiry = time.Now().Add(q.inflightExpiry) + } + rs = append(rs, e) + q.current = q.current.Next() + } else { + q.inflightDrained = true + break + } + } + return rs, nil +} + +func (q *Queue) Remove(pid packets.PacketID) error { + q.cond.L.Lock() + defer q.cond.L.Unlock() + // Must not remove unread messages. + unread := q.current + for e := q.l.Front(); e != nil && e != unread; e = e.Next() { + if e.Value.(*queue.Elem).ID() == pid { + q.l.Remove(e) + q.notifier.NotifyMsgQueueAdded(-1) + q.notifier.NotifyInflightAdded(-1) + return nil + } + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/queue.go b/internal/hummingbird/mqttbroker/persistence/queue/queue.go new file mode 100644 index 0000000..652ba8d --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/queue.go @@ -0,0 +1,78 @@ +package queue + +import ( + "time" + + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// InitOptions is used to pass some required client information to the queue.Init() +type InitOptions struct { + // CleanStart is the cleanStart field in the connect packet. + CleanStart bool + // Version is the client MQTT protocol version. + Version packets.Version + // ReadBytesLimit indicates the maximum publish size that is allow to read. + ReadBytesLimit uint32 + Notifier Notifier +} + +// Store represents a queue store for one client. +type Store interface { + // Close will be called when the client disconnect. + // This method must unblock the Read method. + Close() error + // Init will be called when the client connect. + // If opts.CleanStart set to true, the implementation should remove any associated data in backend store. + // If it sets to false, the implementation should be able to retrieve the associated data from backend store. + // The opts.version indicates the protocol version of the connected client, it is mainly used to calculate the publish packet size. + Init(opts *InitOptions) error + Clean() error + // Add inserts a elem to the queue. + // When the len of queue is reaching the maximum setting, the implementation should drop messages according the following priorities: + // 1. Drop the expired inflight message. + // 2. Drop the current elem if there is no more non-inflight messages. + // 3. Drop expired non-inflight message. + // 4. Drop qos0 message. + // 5. Drop the front message. + // See queue.mem for more details. + Add(elem *Elem) error + // Replace replaces the PUBLISH with the PUBREL with the same packet id. + Replace(elem *Elem) (replaced bool, err error) + + // Read reads a batch of new message (non-inflight) from the store. The qos0 messages will be removed after read. + // The size of the batch will be less than or equal to the size of the given packet id list. + // The implementation must remove and do not return any : + // 1. expired messages + // 2. publish message which exceeds the InitOptions.ReadBytesLimit + // while reading. + // The caller must call ReadInflight first to read all inflight message before calling this method. + // Calling this method will be blocked until there are any new messages can be read or the store has been closed. + // If the store has been closed, returns nil, ErrClosed. + Read(pids []packets.PacketID) ([]*Elem, error) + + // ReadInflight reads at most maxSize inflight messages. + // The caller must call this method to read all inflight messages before calling Read method. + // Returning 0 length elems means all inflight messages have been read. + ReadInflight(maxSize uint) (elems []*Elem, err error) + + // Remove removes the elem for a given id. + Remove(pid packets.PacketID) error +} + +type Notifier interface { + // NotifyDropped will be called when the element in the queue is dropped. + // The err indicates the reason of why it is dropped. + // The MessageWithID field in elem param can be queue.Pubrel or queue.Publish. + NotifyDropped(elem *Elem, err error) + NotifyInflightAdded(delta int) + NotifyMsgQueueAdded(delta int) +} + +// ElemExpiry return whether the elem is expired +func ElemExpiry(now time.Time, elem *Elem) bool { + if !elem.Expiry.IsZero() { + return now.After(elem.Expiry) + } + return false +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/queue_mock.go b/internal/hummingbird/mqttbroker/persistence/queue/queue_mock.go new file mode 100644 index 0000000..18083d9 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/queue_mock.go @@ -0,0 +1,209 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: persistence/queue/queue.go + +// Package queue is a generated GoMock package. +package queue + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + packets "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// MockStore is a mock of Store interface +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockStore) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockStoreMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStore)(nil).Close)) +} + +// Init mocks base method +func (m *MockStore) Init(opts *InitOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init +func (mr *MockStoreMockRecorder) Init(opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockStore)(nil).Init), opts) +} + +// Clean mocks base method +func (m *MockStore) Clean() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clean") + ret0, _ := ret[0].(error) + return ret0 +} + +// Clean indicates an expected call of Clean +func (mr *MockStoreMockRecorder) Clean() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clean", reflect.TypeOf((*MockStore)(nil).Clean)) +} + +// Add mocks base method +func (m *MockStore) Add(elem *Elem) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", elem) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add +func (mr *MockStoreMockRecorder) Add(elem interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockStore)(nil).Add), elem) +} + +// Replace mocks base method +func (m *MockStore) Replace(elem *Elem) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Replace", elem) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Replace indicates an expected call of Replace +func (mr *MockStoreMockRecorder) Replace(elem interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Replace", reflect.TypeOf((*MockStore)(nil).Replace), elem) +} + +// Read mocks base method +func (m *MockStore) Read(pids []packets.PacketID) ([]*Elem, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", pids) + ret0, _ := ret[0].([]*Elem) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockStoreMockRecorder) Read(pids interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStore)(nil).Read), pids) +} + +// ReadInflight mocks base method +func (m *MockStore) ReadInflight(maxSize uint) ([]*Elem, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInflight", maxSize) + ret0, _ := ret[0].([]*Elem) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInflight indicates an expected call of ReadInflight +func (mr *MockStoreMockRecorder) ReadInflight(maxSize interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInflight", reflect.TypeOf((*MockStore)(nil).ReadInflight), maxSize) +} + +// Remove mocks base method +func (m *MockStore) Remove(pid packets.PacketID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Remove", pid) + ret0, _ := ret[0].(error) + return ret0 +} + +// Remove indicates an expected call of Remove +func (mr *MockStoreMockRecorder) Remove(pid interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockStore)(nil).Remove), pid) +} + +// MockNotifier is a mock of Notifier interface +type MockNotifier struct { + ctrl *gomock.Controller + recorder *MockNotifierMockRecorder +} + +// MockNotifierMockRecorder is the mock recorder for MockNotifier +type MockNotifierMockRecorder struct { + mock *MockNotifier +} + +// NewMockNotifier creates a new mock instance +func NewMockNotifier(ctrl *gomock.Controller) *MockNotifier { + mock := &MockNotifier{ctrl: ctrl} + mock.recorder = &MockNotifierMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockNotifier) EXPECT() *MockNotifierMockRecorder { + return m.recorder +} + +// NotifyDropped mocks base method +func (m *MockNotifier) NotifyDropped(elem *Elem, err error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NotifyDropped", elem, err) +} + +// NotifyDropped indicates an expected call of NotifyDropped +func (mr *MockNotifierMockRecorder) NotifyDropped(elem, err interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyDropped", reflect.TypeOf((*MockNotifier)(nil).NotifyDropped), elem, err) +} + +// NotifyInflightAdded mocks base method +func (m *MockNotifier) NotifyInflightAdded(delta int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NotifyInflightAdded", delta) +} + +// NotifyInflightAdded indicates an expected call of NotifyInflightAdded +func (mr *MockNotifierMockRecorder) NotifyInflightAdded(delta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyInflightAdded", reflect.TypeOf((*MockNotifier)(nil).NotifyInflightAdded), delta) +} + +// NotifyMsgQueueAdded mocks base method +func (m *MockNotifier) NotifyMsgQueueAdded(delta int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NotifyMsgQueueAdded", delta) +} + +// NotifyMsgQueueAdded indicates an expected call of NotifyMsgQueueAdded +func (mr *MockNotifierMockRecorder) NotifyMsgQueueAdded(delta interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyMsgQueueAdded", reflect.TypeOf((*MockNotifier)(nil).NotifyMsgQueueAdded), delta) +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/redis/redis.go b/internal/hummingbird/mqttbroker/persistence/queue/redis/redis.go new file mode 100644 index 0000000..74dd166 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/redis/redis.go @@ -0,0 +1,418 @@ +package redis + +import ( + "sync" + "time" + + redigo "github.com/gomodule/redigo/redis" + "go.uber.org/zap" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/codes" + "github.com/winc-link/hummingbird/internal/pkg/packets" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" +) + +const ( + queuePrefix = "queue:" +) + +var _ queue.Store = (*Queue)(nil) + +func getKey(clientID string) string { + return queuePrefix + clientID +} + +type Options struct { + MaxQueuedMsg int + ClientID string + InflightExpiry time.Duration + Pool *redigo.Pool + DefaultNotifier queue.Notifier +} + +type Queue struct { + cond *sync.Cond + clientID string + version packets.Version + readBytesLimit uint32 + // max is the maximum queue length + max int + // len is the length of the list + len int + pool *redigo.Pool + closed bool + inflightDrained bool + // current is the current read index of Queue list. + current int + readCache map[packets.PacketID][]byte + err error + log *zap.Logger + inflightExpiry time.Duration + notifier queue.Notifier +} + +func New(opts Options) (*Queue, error) { + return &Queue{ + cond: sync.NewCond(&sync.Mutex{}), + clientID: opts.ClientID, + max: opts.MaxQueuedMsg, + len: 0, + pool: opts.Pool, + closed: false, + inflightDrained: false, + current: 0, + inflightExpiry: opts.InflightExpiry, + notifier: opts.DefaultNotifier, + log: server.LoggerWithField(zap.String("queue", "redis")), + }, nil +} + +func wrapError(err error) *codes.Error { + return &codes.Error{ + Code: codes.UnspecifiedError, + ErrorDetails: codes.ErrorDetails{ + ReasonString: []byte(err.Error()), + UserProperties: nil, + }, + } +} + +func (q *Queue) Close() error { + q.cond.L.Lock() + defer func() { + q.cond.L.Unlock() + q.cond.Signal() + }() + q.closed = true + return nil +} + +func (q *Queue) setLen(conn redigo.Conn) error { + l, err := conn.Do("llen", getKey(q.clientID)) + if err != nil { + return err + } + q.len = int(l.(int64)) + return nil +} + +func (q *Queue) Init(opts *queue.InitOptions) error { + q.cond.L.Lock() + defer q.cond.L.Unlock() + conn := q.pool.Get() + defer conn.Close() + + if opts.CleanStart { + _, err := conn.Do("del", getKey(q.clientID)) + if err != nil { + return wrapError(err) + } + } + err := q.setLen(conn) + if err != nil { + return err + } + q.version = opts.Version + q.readBytesLimit = opts.ReadBytesLimit + q.closed = false + q.inflightDrained = false + q.current = 0 + q.readCache = make(map[packets.PacketID][]byte) + q.notifier = opts.Notifier + q.cond.Signal() + return nil +} + +func (q *Queue) Clean() error { + conn := q.pool.Get() + defer conn.Close() + _, err := conn.Do("del", getKey(q.clientID)) + return err +} + +func (q *Queue) Add(elem *queue.Elem) (err error) { + now := time.Now() + conn := q.pool.Get() + q.cond.L.Lock() + var dropErr error + var dropBytes []byte + var dropElem *queue.Elem + var drop bool + defer func() { + conn.Close() + q.cond.L.Unlock() + q.cond.Signal() + }() + + defer func() { + if drop { + if dropErr == queue.ErrDropExpiredInflight { + q.notifier.NotifyInflightAdded(-1) + q.current-- + } + if dropBytes == nil { + q.notifier.NotifyDropped(elem, dropErr) + return + } else { + err = conn.Send("lrem", getKey(q.clientID), 1, dropBytes) + } + q.notifier.NotifyDropped(dropElem, dropErr) + } else { + q.notifier.NotifyMsgQueueAdded(1) + q.len++ + } + _ = conn.Send("rpush", getKey(q.clientID), elem.Encode()) + err = conn.Flush() + }() + if q.len >= q.max { + // set default drop error + dropErr = queue.ErrDropQueueFull + drop = true + var rs []interface{} + // drop expired inflight message + rs, err = redigo.Values(conn.Do("lrange", getKey(q.clientID), 0, q.len)) + if err != nil { + return + } + var frontBytes []byte + var frontElem *queue.Elem + for i := 0; i < len(rs); i++ { + b := rs[i].([]byte) + e := &queue.Elem{} + err = e.Decode(b) + if err != nil { + return + } + // inflight message + if i < q.current && queue.ElemExpiry(now, e) { + dropBytes = b + dropElem = e + dropErr = queue.ErrDropExpiredInflight + return + } + // non-inflight message + if i >= q.current { + if i == q.current { + frontBytes = b + frontElem = e + } + // drop qos0 message in the queue + pub := e.MessageWithID.(*queue.Publish) + // drop expired non-inflight message + if pub.ID() == 0 && queue.ElemExpiry(now, e) { + dropBytes = b + dropElem = e + dropErr = queue.ErrDropExpired + return + } + if pub.ID() == 0 && pub.QoS == packets.Qos0 && dropElem == nil { + dropBytes = b + dropElem = e + } + } + } + // drop the current elem if there is no more non-inflight messages. + if q.inflightDrained && q.current >= q.len { + return + } + rs, err = redigo.Values(conn.Do("lrange", getKey(q.clientID), q.current, q.len)) + if err != nil { + return err + } + if dropElem != nil { + return + } + if elem.MessageWithID.(*queue.Publish).QoS == packets.Qos0 { + return + } + if frontElem != nil { + // drop the front message + dropBytes = frontBytes + dropElem = frontElem + } + // the the messages in the queue are all inflight messages, drop the current elem + return + } + return nil +} + +func (q *Queue) Replace(elem *queue.Elem) (replaced bool, err error) { + conn := q.pool.Get() + q.cond.L.Lock() + defer func() { + conn.Close() + q.cond.L.Unlock() + }() + id := elem.ID() + eb := elem.Encode() + stop := q.current - 1 + if stop < 0 { + stop = 0 + } + rs, err := redigo.Values(conn.Do("lrange", getKey(q.clientID), 0, stop)) + if err != nil { + return false, err + } + for k, v := range rs { + b := v.([]byte) + e := &queue.Elem{} + err = e.Decode(b) + if err != nil { + return false, err + } + if e.ID() == elem.ID() { + _, err = conn.Do("lset", getKey(q.clientID), k, eb) + if err != nil { + return false, err + } + q.readCache[id] = eb + return true, nil + } + } + + return false, nil +} + +func (q *Queue) Read(pids []packets.PacketID) (elems []*queue.Elem, err error) { + now := time.Now() + q.cond.L.Lock() + defer q.cond.L.Unlock() + conn := q.pool.Get() + defer conn.Close() + if !q.inflightDrained { + panic("must call ReadInflight to drain all inflight messages before Read") + } + for q.current >= q.len && !q.closed { + q.cond.Wait() + } + if q.closed { + return nil, queue.ErrClosed + } + rs, err := redigo.Values(conn.Do("lrange", getKey(q.clientID), q.current, q.current+len(pids)-1)) + if err != nil { + return nil, wrapError(err) + } + var msgQueueDelta, inflightDelta int + var pflag int + for i := 0; i < len(rs); i++ { + b := rs[i].([]byte) + e := &queue.Elem{} + err := e.Decode(b) + if err != nil { + return nil, err + } + // remove expired message + if queue.ElemExpiry(now, e) { + err = conn.Send("lrem", getKey(q.clientID), 1, b) + q.len-- + if err != nil { + return nil, err + } + q.notifier.NotifyDropped(e, queue.ErrDropExpired) + msgQueueDelta-- + continue + } + + // remove message which exceeds maximum packet size + pub := e.MessageWithID.(*queue.Publish) + if size := pub.TotalBytes(q.version); size > q.readBytesLimit { + err = conn.Send("lrem", getKey(q.clientID), 1, b) + q.len-- + if err != nil { + return nil, err + } + q.notifier.NotifyDropped(e, queue.ErrDropExceedsMaxPacketSize) + msgQueueDelta-- + continue + } + + if e.MessageWithID.(*queue.Publish).QoS == 0 { + err = conn.Send("lrem", getKey(q.clientID), 1, b) + q.len-- + msgQueueDelta-- + if err != nil { + return nil, err + } + } else { + e.MessageWithID.SetID(pids[pflag]) + if q.inflightExpiry != 0 { + e.Expiry = now.Add(q.inflightExpiry) + } + pflag++ + nb := e.Encode() + + err = conn.Send("lset", getKey(q.clientID), q.current, nb) + q.current++ + inflightDelta++ + q.readCache[e.MessageWithID.ID()] = nb + } + elems = append(elems, e) + } + err = conn.Flush() + q.notifier.NotifyMsgQueueAdded(msgQueueDelta) + q.notifier.NotifyInflightAdded(inflightDelta) + return +} + +func (q *Queue) ReadInflight(maxSize uint) (elems []*queue.Elem, err error) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + conn := q.pool.Get() + defer conn.Close() + rs, err := redigo.Values(conn.Do("lrange", getKey(q.clientID), q.current, q.current+int(maxSize)-1)) + if len(rs) == 0 { + q.inflightDrained = true + return + } + if err != nil { + return nil, wrapError(err) + } + beginIndex := q.current + for index, v := range rs { + b := v.([]byte) + e := &queue.Elem{} + err := e.Decode(b) + if err != nil { + return nil, err + } + id := e.MessageWithID.ID() + if id != 0 { + if q.inflightExpiry != 0 { + e.Expiry = time.Now().Add(q.inflightExpiry) + b = e.Encode() + _, err = conn.Do("lset", getKey(q.clientID), beginIndex+index, b) + if err != nil { + return nil, err + } + } + elems = append(elems, e) + q.readCache[id] = b + q.current++ + } else { + q.inflightDrained = true + return elems, nil + } + } + return +} + +func (q *Queue) Remove(pid packets.PacketID) error { + q.cond.L.Lock() + defer q.cond.L.Unlock() + conn := q.pool.Get() + defer conn.Close() + if b, ok := q.readCache[pid]; ok { + _, err := conn.Do("lrem", getKey(q.clientID), 1, b) + if err != nil { + return err + } + q.notifier.NotifyMsgQueueAdded(-1) + q.notifier.NotifyInflightAdded(-1) + delete(q.readCache, pid) + q.len-- + q.current-- + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/queue/test/test_suite.go b/internal/hummingbird/mqttbroker/persistence/queue/test/test_suite.go new file mode 100644 index 0000000..9ab3eaa --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/queue/test/test_suite.go @@ -0,0 +1,672 @@ +package test + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var ( + TestServerConfig = config.Config{ + MQTT: config.MQTT{ + MaxQueuedMsg: 5, + InflightExpiry: 2 * time.Second, + }, + } + cid = "cid" + TestClientID = cid + TestNotifier = &testNotifier{} +) + +type testNotifier struct { + dropElem []*queue.Elem + dropErr error + inflightLen int + msgQueueLen int +} + +func (t *testNotifier) NotifyDropped(elem *queue.Elem, err error) { + t.dropElem = append(t.dropElem, elem) + t.dropErr = err +} + +func (t *testNotifier) NotifyInflightAdded(delta int) { + t.inflightLen += delta + if t.inflightLen < 0 { + t.inflightLen = 0 + } +} + +func (t *testNotifier) NotifyMsgQueueAdded(delta int) { + t.msgQueueLen += delta + if t.msgQueueLen < 0 { + t.msgQueueLen = 0 + } +} + +func initDrop() { + TestNotifier.dropElem = nil + TestNotifier.dropErr = nil +} + +func initNotifierLen() { + TestNotifier.inflightLen = 0 + TestNotifier.msgQueueLen = 0 +} + +func assertMsgEqual(a *assert.Assertions, expected, actual *queue.Elem) { + expMsg := expected.MessageWithID.(*queue.Publish).Message + actMsg := actual.MessageWithID.(*queue.Publish).Message + a.Equal(expMsg.Topic, actMsg.Topic) + a.Equal(expMsg.QoS, actMsg.QoS) + a.Equal(expMsg.Payload, actMsg.Payload) + a.Equal(expMsg.PacketID, actMsg.PacketID) +} + +func assertQueueLen(a *assert.Assertions, inflightLen, msgQueueLen int) { + a.Equal(inflightLen, TestNotifier.inflightLen) + a.Equal(msgQueueLen, TestNotifier.msgQueueLen) +} + +// 2 inflight message + 3 new message +var initElems = []*queue.Elem{ + { + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: packets.Qos1, + Retained: false, + Topic: "/topic1_qos1", + Payload: []byte("qos1"), + PacketID: 1, + }, + }, + }, { + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: packets.Qos2, + Retained: false, + Topic: "/topic1_qos2", + Payload: []byte("qos2"), + PacketID: 2, + }, + }, + }, { + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: packets.Qos1, + Retained: false, + Topic: "/topic1_qos1", + Payload: []byte("qos1"), + PacketID: 0, + }, + }, + }, + { + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: packets.Qos0, + Retained: false, + Topic: "/topic1_qos0", + Payload: []byte("qos0"), + PacketID: 0, + }, + }, + }, + { + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: packets.Qos2, + Retained: false, + Topic: "/topic1_qos2", + Payload: []byte("qos2"), + PacketID: 0, + }, + }, + }, +} + +func initStore(store queue.Store) error { + return store.Init(&queue.InitOptions{ + CleanStart: true, + Version: packets.Version5, + ReadBytesLimit: 100, + Notifier: TestNotifier, + }) +} + +func add(store queue.Store) error { + for _, v := range initElems { + elem := *v + elem.MessageWithID = &queue.Publish{ + Message: elem.MessageWithID.(*queue.Publish).Message.Copy(), + } + err := store.Add(&elem) + if err != nil { + return err + } + } + TestNotifier.inflightLen = 2 + return nil +} + +func assertDrop(a *assert.Assertions, elem *queue.Elem, err error) { + a.Len(TestNotifier.dropElem, 1) + switch elem.MessageWithID.(type) { + case *queue.Publish: + actual := TestNotifier.dropElem[0].MessageWithID.(*queue.Publish) + pub := elem.MessageWithID.(*queue.Publish) + a.Equal(pub.Message.Topic, actual.Topic) + a.Equal(pub.Message.QoS, actual.QoS) + a.Equal(pub.Payload, actual.Payload) + a.Equal(pub.PacketID, actual.PacketID) + a.Equal(err, TestNotifier.dropErr) + case *queue.Pubrel: + actual := TestNotifier.dropElem[0].MessageWithID.(*queue.Pubrel) + pubrel := elem.MessageWithID.(*queue.Pubrel) + a.Equal(pubrel.PacketID, actual.PacketID) + a.Equal(err, TestNotifier.dropErr) + default: + a.FailNow("unexpected elem type") + + } + initDrop() +} + +func reconnect(a *assert.Assertions, cleanStart bool, store queue.Store) { + a.NoError(store.Close()) + a.NoError(store.Init(&queue.InitOptions{ + CleanStart: cleanStart, + Version: packets.Version5, + ReadBytesLimit: 100, + Notifier: TestNotifier, + })) +} + +type New func(config config.Config, hooks server.Hooks) (server.Persistence, error) + +func TestQueue(t *testing.T, store queue.Store) { + initDrop() + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + a.NoError(initStore(store)) + a.NoError(add(store)) + assertQueueLen(a, 2, 5) + testRead(a, store) + testDrop(a, store) + testReplace(a, store) + testCleanStart(a, store) + testReadExceedsDrop(a, store) + testClose(a, store) +} + +func testDrop(a *assert.Assertions, store queue.Store) { + // wait inflight messages to expire + time.Sleep(TestServerConfig.MQTT.InflightExpiry) + for i := 0; i < 3; i++ { + err := store.Add(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 2, + Retained: false, + Topic: "123", + Payload: []byte("123"), + PacketID: 0, + }, + }, + }) + a.Nil(err) + } + // drop expired inflight message (pid=1) + dropElem := initElems[0] + // queue: 1,2,0(qos2),0(qos2),0(qos2) (1 and 2 are expired inflight messages) + err := store.Add(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 1, + Retained: false, + Topic: "123", + Payload: []byte("123"), + PacketID: 0, + }, + }, + }) + a.NoError(err) + assertDrop(a, dropElem, queue.ErrDropExpiredInflight) + assertQueueLen(a, 1, 5) + + e, err := store.Read([]packets.PacketID{5, 6, 7}) + a.NoError(err) + a.Len(e, 3) + a.EqualValues(5, e[0].MessageWithID.ID()) + a.EqualValues(6, e[1].MessageWithID.ID()) + a.EqualValues(7, e[2].MessageWithID.ID()) + // queue: 2,5(qos2),6(qos2),7(qos2), 0(qos1) (2 is expired inflight message) + assertQueueLen(a, 4, 5) + // drop expired inflight message (pid=2) + dropElem = initElems[1] + err = store.Add(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 1, + Retained: false, + Topic: "1234", + Payload: []byte("1234"), + PacketID: 0, + }, + }, + }) + a.NoError(err) + // queue: 5(qos2),6(qos2),7(qos2), 0(qos1), 0(qos1) + assertDrop(a, dropElem, queue.ErrDropExpiredInflight) + + assertQueueLen(a, 3, 5) + e, err = store.Read([]packets.PacketID{8, 9}) + + a.NoError(err) + + // queue: 5(qos2),6(qos2),7(qos2),8(qos1),9(qos1) + a.Len(e, 2) + a.EqualValues(8, e[0].MessageWithID.ID()) + a.EqualValues(9, e[1].MessageWithID.ID()) + assertQueueLen(a, 5, 5) + + // drop the elem that is going to enqueue if there is no more non-inflight messages. + dropElem = &queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 1, + Retained: false, + Topic: "123", + Payload: []byte("123"), + PacketID: 0, + }, + }, + } + err = store.Add(dropElem) + a.NoError(err) + assertDrop(a, dropElem, queue.ErrDropQueueFull) + assertQueueLen(a, 5, 5) + + // queue: 5(qos2),6(qos2),7(qos2),8(qos1),9(qos1) + a.NoError(store.Remove(5)) + a.NoError(store.Remove(6)) + // queue: 7(qos2),8(qos2),9(qos2) + assertQueueLen(a, 3, 3) + + dropQoS0 := &queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 0, + Retained: false, + Topic: "/t_qos0", + Payload: []byte("test"), + }, + }, + } + a.NoError(store.Add(dropQoS0)) + // queue: 7(qos2),8(qos2),9(qos2),0 (qos0/t_qos0) + assertQueueLen(a, 3, 4) + + // add expired elem + dropExpired := &queue.Elem{ + At: time.Now(), + Expiry: time.Now().Add(-10 * time.Second), + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 0, + Retained: false, + Topic: "/drop", + Payload: []byte("test"), + }, + }, + } + a.NoError(store.Add(dropExpired)) + // queue: 7(qos2),8(qos2),9(qos2), 0(qos0/t_qos0), 0(qos0/drop) + assertQueueLen(a, 3, 5) + dropFront := &queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 1, + Retained: false, + Topic: "/drop_front", + Payload: []byte("drop_front"), + }, + }, + } + // drop the expired non-inflight message + a.NoError(store.Add(dropFront)) + // queue: 7(qos2),8(qos2),9(qos2), 0(qos0/t_qos0), 0(qos1/drop_front) + assertDrop(a, dropExpired, queue.ErrDropExpired) + assertQueueLen(a, 3, 5) + + // drop qos0 message + a.Nil(store.Add(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: packets.Qos1, + Retained: false, + Topic: "/t_qos1", + Payload: []byte("/t_qos1"), + }, + }, + })) + // queue: 7(qos2),8(qos2),9(qos2), 0(qos1/drop_front), 0(qos1/t_qos1) + assertDrop(a, dropQoS0, queue.ErrDropQueueFull) + assertQueueLen(a, 3, 5) + + expiredPub := &queue.Elem{ + At: time.Now(), + Expiry: time.Now().Add(TestServerConfig.MQTT.InflightExpiry), + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 1, + Retained: false, + Topic: "/t", + Payload: []byte("/t"), + }, + }, + } + a.NoError(store.Add(expiredPub)) + // drop the front message + assertDrop(a, dropFront, queue.ErrDropQueueFull) + // queue: 7(qos2),8(qos2),9(qos2), 0(qos1/t_qos1), 0(qos1/t) + assertQueueLen(a, 3, 5) + // replace with an expired pubrel + expiredPubrel := &queue.Elem{ + At: time.Now(), + Expiry: time.Now().Add(-1 * time.Second), + MessageWithID: &queue.Pubrel{ + PacketID: 7, + }, + } + r, err := store.Replace(expiredPubrel) + a.True(r) + a.NoError(err) + assertQueueLen(a, 3, 5) + // queue: 7(qos2-pubrel),8(qos2),9(qos2), 0(qos1/t_qos1), 0(qos1/t) + a.NoError(store.Add(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 1, + Retained: false, + Topic: "/t1", + Payload: []byte("/t1"), + }, + }, + })) + // queue: 8(qos2),9(qos2), 0(qos1/t_qos1), 0(qos1/t), 0(qos1/t1) + assertDrop(a, expiredPubrel, queue.ErrDropExpiredInflight) + assertQueueLen(a, 2, 5) + drop := &queue.Elem{ + At: time.Now(), + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 0, + Retained: false, + Topic: "/t2", + Payload: []byte("/t2"), + }, + }, + } + a.NoError(store.Add(drop)) + assertDrop(a, drop, queue.ErrDropQueueFull) + assertQueueLen(a, 2, 5) + + a.NoError(store.Remove(8)) + a.NoError(store.Remove(9)) + // queue: 0(qos1/t_qos1), 0(qos1/t), 0(qos1/t1) + assertQueueLen(a, 0, 3) + // wait qos1/t to expire. + time.Sleep(TestServerConfig.MQTT.InflightExpiry) + e, err = store.Read([]packets.PacketID{1, 2, 3}) + a.NoError(err) + a.Len(e, 2) + assertQueueLen(a, 2, 2) + a.NoError(store.Remove(1)) + a.NoError(store.Remove(2)) + assertQueueLen(a, 0, 0) +} + +func testRead(a *assert.Assertions, store queue.Store) { + // 2 inflight + e, err := store.ReadInflight(1) + a.Nil(err) + a.Len(e, 1) + assertMsgEqual(a, initElems[0], e[0]) + + e, err = store.ReadInflight(2) + a.Len(e, 1) + assertMsgEqual(a, initElems[1], e[0]) + pids := []packets.PacketID{3, 4, 5} + e, err = store.Read(pids) + a.Len(e, 3) + + // must consume packet id in order and do not skip packet id if there are qos0 messages. + a.EqualValues(3, e[0].MessageWithID.ID()) + a.EqualValues(0, e[1].MessageWithID.ID()) + a.EqualValues(4, e[2].MessageWithID.ID()) + + assertQueueLen(a, 4, 4) + + err = store.Remove(3) + a.NoError(err) + err = store.Remove(4) + a.NoError(err) + assertQueueLen(a, 2, 2) + +} + +func testReplace(a *assert.Assertions, store queue.Store) { + + var elems []*queue.Elem + elems = append(elems, &queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: 2, + Topic: "/t_replace", + Payload: []byte("t_replace"), + }, + }, + }, &queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: 2, + Topic: "/t_replace", + Payload: []byte("t_replace"), + }, + }, + }, &queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: 2, + Topic: "/t_unread", + Payload: []byte("t_unread"), + PacketID: 3, + }, + }, + }) + for i := 0; i < 2; i++ { + elems = append(elems, &queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + QoS: 2, + Topic: "/t_replace", + Payload: []byte("t_replace"), + }, + }, + }) + a.NoError(store.Add(elems[i])) + } + assertQueueLen(a, 0, 2) + + e, err := store.Read([]packets.PacketID{1, 2}) + // queue: 1(qos2),2(qos2) + a.NoError(err) + a.Len(e, 2) + assertQueueLen(a, 2, 2) + r, err := store.Replace(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Pubrel{ + PacketID: 1, + }, + }) + a.True(r) + a.NoError(err) + + r, err = store.Replace(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Pubrel{ + PacketID: 3, + }, + }) + a.False(r) + a.NoError(err) + a.NoError(store.Add(elems[2])) + TestNotifier.inflightLen++ + // queue: 1(qos2-pubrel),2(qos2), 3(qos2) + + r, err = store.Replace(&queue.Elem{ + At: time.Now(), + Expiry: time.Time{}, + MessageWithID: &queue.Pubrel{ + PacketID: packets.PacketID(3), + }}) + a.False(r, "must not replace unread packet") + a.NoError(err) + assertQueueLen(a, 3, 3) + + reconnect(a, false, store) + + inflight, err := store.ReadInflight(5) + a.NoError(err) + a.Len(inflight, 3) + a.Equal(&queue.Pubrel{ + PacketID: 1, + }, inflight[0].MessageWithID) + + elems[1].MessageWithID.SetID(2) + elems[2].MessageWithID.SetID(3) + assertMsgEqual(a, elems[1], inflight[1]) + assertMsgEqual(a, elems[2], inflight[2]) + assertQueueLen(a, 3, 3) + +} + +func testReadExceedsDrop(a *assert.Assertions, store queue.Store) { + // add exceeded message + exceeded := &queue.Elem{ + At: time.Now(), + MessageWithID: &queue.Publish{ + Message: &gmqtt.Message{ + Dup: false, + QoS: 1, + Retained: false, + Topic: "/drop_exceed", + Payload: make([]byte, 100), + }, + }, + } + a.NoError(store.Add(exceeded)) + assertQueueLen(a, 0, 1) + e, err := store.Read([]packets.PacketID{1}) + a.NoError(err) + a.Len(e, 0) + assertDrop(a, exceeded, queue.ErrDropExceedsMaxPacketSize) + assertQueueLen(a, 0, 0) +} + +func testCleanStart(a *assert.Assertions, store queue.Store) { + reconnect(a, true, store) + rs, err := store.ReadInflight(10) + a.NoError(err) + a.Len(rs, 0) + initDrop() + initNotifierLen() +} + +func testClose(a *assert.Assertions, store queue.Store) { + t := time.After(2 * time.Second) + result := make(chan struct { + len int + err error + }) + go func() { + // should block + rs, err := store.Read([]packets.PacketID{1, 2, 3}) + result <- struct { + len int + err error + }{len: len(rs), err: err} + }() + select { + case <-result: + a.Fail("Read must be blocked before Close") + case <-t: + } + a.NoError(store.Close()) + timeout := time.After(5 * time.Second) + select { + case <-timeout: + a.Fail("Read must be unblocked after Close") + case r := <-result: + a.Zero(r.len) + a.Equal(queue.ErrClosed, r.err) + } +} diff --git a/internal/hummingbird/mqttbroker/persistence/redis.go b/internal/hummingbird/mqttbroker/persistence/redis.go new file mode 100644 index 0000000..a54787c --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/redis.go @@ -0,0 +1,96 @@ +package persistence + +import ( + redigo "github.com/gomodule/redigo/redis" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + redis_queue "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue/redis" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" + redis_sess "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session/redis" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + redis_sub "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription/redis" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" + redis_unack "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack/redis" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +func init() { + server.RegisterPersistenceFactory("redis", NewRedis) +} + +func NewRedis(config config.Config) (server.Persistence, error) { + return &redis{ + config: config, + }, nil +} + +type redis struct { + pool *redigo.Pool + config config.Config + onMsgDropped server.OnMsgDropped +} + +func (r *redis) NewUnackStore(config config.Config, clientID string) (unack.Store, error) { + return redis_unack.New(redis_unack.Options{ + ClientID: clientID, + Pool: r.pool, + }), nil +} + +func (r *redis) NewSessionStore(config config.Config) (session.Store, error) { + return redis_sess.New(r.pool), nil +} + +func newPool(config config.Config) *redigo.Pool { + return &redigo.Pool{ + // Dial or DialContext must be set. When both are set, DialContext takes precedence over Dial. + Dial: func() (redigo.Conn, error) { + c, err := redigo.Dial("tcp", config.Persistence.Redis.Addr) + if err != nil { + return nil, err + } + if pswd := config.Persistence.Redis.Password; pswd != "" { + if _, err := c.Do("AUTH", pswd); err != nil { + c.Close() + return nil, err + } + } + if _, err := c.Do("SELECT", config.Persistence.Redis.Database); err != nil { + c.Close() + return nil, err + } + return c, nil + }, + } +} +func (r *redis) Open() error { + r.pool = newPool(r.config) + r.pool.MaxIdle = int(*r.config.Persistence.Redis.MaxIdle) + r.pool.MaxActive = int(*r.config.Persistence.Redis.MaxActive) + r.pool.IdleTimeout = r.config.Persistence.Redis.IdleTimeout + conn := r.pool.Get() + defer conn.Close() + // Test the connection + _, err := conn.Do("PING") + + return err +} + +func (r *redis) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) { + return redis_queue.New(redis_queue.Options{ + MaxQueuedMsg: config.MQTT.MaxQueuedMsg, + InflightExpiry: config.MQTT.InflightExpiry, + ClientID: clientID, + Pool: r.pool, + DefaultNotifier: defaultNotifier, + }) +} + +func (r *redis) NewSubscriptionStore(config config.Config) (subscription.Store, error) { + return redis_sub.New(r.pool), nil +} + +func (r *redis) Close() error { + return r.pool.Close() +} diff --git a/internal/hummingbird/mqttbroker/persistence/session/mem/store.go b/internal/hummingbird/mqttbroker/persistence/session/mem/store.go new file mode 100644 index 0000000..f4aae7b --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/session/mem/store.go @@ -0,0 +1,68 @@ +package mem + +import ( + "sync" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" +) + +var _ session.Store = (*Store)(nil) + +func New() *Store { + return &Store{ + mu: sync.Mutex{}, + sess: make(map[string]*gmqtt.Session), + } +} + +type Store struct { + mu sync.Mutex + sess map[string]*gmqtt.Session +} + +func (s *Store) Set(session *gmqtt.Session) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sess[session.ClientID] = session + return nil +} + +func (s *Store) Remove(clientID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sess, clientID) + return nil +} + +func (s *Store) Get(clientID string) (*gmqtt.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.sess[clientID], nil +} + +func (s *Store) GetAll() ([]*gmqtt.Session, error) { + return nil, nil +} + +func (s *Store) SetSessionExpiry(clientID string, expiry uint32) error { + s.mu.Lock() + defer s.mu.Unlock() + if s, ok := s.sess[clientID]; ok { + s.ExpiryInterval = expiry + + } + return nil +} + +func (s *Store) Iterate(fn session.IterateFn) error { + s.mu.Lock() + defer s.mu.Unlock() + for _, v := range s.sess { + cont := fn(v) + if !cont { + break + } + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/session/redis/store.go b/internal/hummingbird/mqttbroker/persistence/session/redis/store.go new file mode 100644 index 0000000..22afafe --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/session/redis/store.go @@ -0,0 +1,132 @@ +package redis + +import ( + "bytes" + "sync" + "time" + + "github.com/gomodule/redigo/redis" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/encoding" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" +) + +const ( + sessPrefix = "session:" +) + +var _ session.Store = (*Store)(nil) + +type Store struct { + mu sync.Mutex + pool *redis.Pool +} + +func New(pool *redis.Pool) *Store { + return &Store{ + mu: sync.Mutex{}, + pool: pool, + } +} + +func getKey(clientID string) string { + return sessPrefix + clientID +} +func (s *Store) Set(session *gmqtt.Session) error { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + b := &bytes.Buffer{} + encoding.EncodeMessage(session.Will, b) + _, err := c.Do("hset", getKey(session.ClientID), + "client_id", session.ClientID, + "will", b.Bytes(), + "will_delay_interval", session.WillDelayInterval, + "connected_at", session.ConnectedAt.Unix(), + "expiry_interval", session.ExpiryInterval, + ) + return err +} + +func (s *Store) Remove(clientID string) error { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + _, err := c.Do("del", getKey(clientID)) + return err +} + +func (s *Store) Get(clientID string) (*gmqtt.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + return getSessionLocked(getKey(clientID), c) +} + +func getSessionLocked(key string, c redis.Conn) (*gmqtt.Session, error) { + replay, err := redis.Values(c.Do("hmget", key, "client_id", "will", "will_delay_interval", "connected_at", "expiry_interval")) + if err != nil { + return nil, err + } + sess := &gmqtt.Session{} + var connectedAt uint32 + var will []byte + _, err = redis.Scan(replay, &sess.ClientID, &will, &sess.WillDelayInterval, &connectedAt, &sess.ExpiryInterval) + if err != nil { + return nil, err + } + sess.ConnectedAt = time.Unix(int64(connectedAt), 0) + sess.Will, err = encoding.DecodeMessageFromBytes(will) + if err != nil { + return nil, err + } + return sess, nil +} + +func (s *Store) SetSessionExpiry(clientID string, expiry uint32) error { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + _, err := c.Do("hset", getKey(clientID), + "expiry_interval", expiry, + ) + return err +} + +func (s *Store) Iterate(fn session.IterateFn) error { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + iter := 0 + for { + arr, err := redis.Values(c.Do("SCAN", iter, "MATCH", sessPrefix+"*")) + if err != nil { + return err + } + if len(arr) >= 1 { + for _, v := range arr[1:] { + for _, vv := range v.([]interface{}) { + sess, err := getSessionLocked(string(vv.([]uint8)), c) + if err != nil { + return err + } + cont := fn(sess) + if !cont { + return nil + } + } + } + } + iter, _ = redis.Int(arr[0], nil) + if iter == 0 { + break + } + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/session/session.go b/internal/hummingbird/mqttbroker/persistence/session/session.go new file mode 100644 index 0000000..3656c86 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/session/session.go @@ -0,0 +1,15 @@ +package session + +import gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + +// IterateFn is the callback function used by Iterate() +// Return false means to stop the iteration. +type IterateFn func(session *gmqtt.Session) bool + +type Store interface { + Set(session *gmqtt.Session) error + Remove(clientID string) error + Get(clientID string) (*gmqtt.Session, error) + Iterate(fn IterateFn) error + SetSessionExpiry(clientID string, expiry uint32) error +} diff --git a/internal/hummingbird/mqttbroker/persistence/session/session_mock.go b/internal/hummingbird/mqttbroker/persistence/session/session_mock.go new file mode 100644 index 0000000..a6c5230 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/session/session_mock.go @@ -0,0 +1,106 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: persistence/session/session.go + +// Package session is a generated GoMock package. +package session + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" +) + +// MockStore is a mock of Store interface +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Set mocks base method +func (m *MockStore) Set(session *gmqtt.Session) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", session) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set +func (mr *MockStoreMockRecorder) Set(session interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockStore)(nil).Set), session) +} + +// Remove mocks base method +func (m *MockStore) Remove(clientID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Remove", clientID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Remove indicates an expected call of Remove +func (mr *MockStoreMockRecorder) Remove(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockStore)(nil).Remove), clientID) +} + +// Get mocks base method +func (m *MockStore) Get(clientID string) (*gmqtt.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", clientID) + ret0, _ := ret[0].(*gmqtt.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockStoreMockRecorder) Get(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), clientID) +} + +// Iterate mocks base method +func (m *MockStore) Iterate(fn IterateFn) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Iterate", fn) + ret0, _ := ret[0].(error) + return ret0 +} + +// Iterate indicates an expected call of Iterate +func (mr *MockStoreMockRecorder) Iterate(fn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockStore)(nil).Iterate), fn) +} + +// SetSessionExpiry mocks base method +func (m *MockStore) SetSessionExpiry(clientID string, expiry uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetSessionExpiry", clientID, expiry) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetSessionExpiry indicates an expected call of SetSessionExpiry +func (mr *MockStoreMockRecorder) SetSessionExpiry(clientID, expiry interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSessionExpiry", reflect.TypeOf((*MockStore)(nil).SetSessionExpiry), clientID, expiry) +} diff --git a/internal/hummingbird/mqttbroker/persistence/session/test/test_suite.go b/internal/hummingbird/mqttbroker/persistence/session/test/test_suite.go new file mode 100644 index 0000000..41f48d4 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/session/test/test_suite.go @@ -0,0 +1,51 @@ +package test + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" +) + +func TestSuite(t *testing.T, store session.Store) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + var tt = []*gmqtt.Session{ + { + ClientID: "client", + Will: &gmqtt.Message{ + Topic: "topicA", + Payload: []byte("abc"), + }, + WillDelayInterval: 1, + ConnectedAt: time.Unix(1, 0), + ExpiryInterval: 2, + }, { + ClientID: "client2", + Will: nil, + WillDelayInterval: 0, + ConnectedAt: time.Unix(2, 0), + ExpiryInterval: 0, + }, + } + for _, v := range tt { + a.Nil(store.Set(v)) + } + for _, v := range tt { + sess, err := store.Get(v.ClientID) + a.Nil(err) + a.EqualValues(v, sess) + } + var sess []*gmqtt.Session + err := store.Iterate(func(session *gmqtt.Session) bool { + sess = append(sess, session) + return true + }) + a.Nil(err) + a.ElementsMatch(sess, tt) +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/mem/topic_trie.go b/internal/hummingbird/mqttbroker/persistence/subscription/mem/topic_trie.go new file mode 100644 index 0000000..3657d7f --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/mem/topic_trie.go @@ -0,0 +1,201 @@ +package mem + +import ( + "strings" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" +) + +// topicTrie +type topicTrie = topicNode + +// children +type children = map[string]*topicNode + +type clientOpts map[string]*gmqtt.Subscription + +// topicNode +type topicNode struct { + children children + // clients store non-share subscription + clients clientOpts + parent *topicNode // pointer of parent node + topicName string + // shared store shared subscription, key by ShareName + shared map[string]clientOpts +} + +// newTopicTrie create a new trie tree +func newTopicTrie() *topicTrie { + return newNode() +} + +// newNode create a new trie node +func newNode() *topicNode { + return &topicNode{ + children: children{}, + clients: make(clientOpts), + shared: make(map[string]clientOpts), + } +} + +// newChild create a child node of t +func (t *topicNode) newChild() *topicNode { + n := newNode() + n.parent = t + return n +} + +// subscribe add a subscription and return the added node +func (t *topicTrie) subscribe(clientID string, s *gmqtt.Subscription) *topicNode { + topicSlice := strings.Split(s.TopicFilter, "/") + var pNode = t + for _, lv := range topicSlice { + if _, ok := pNode.children[lv]; !ok { + pNode.children[lv] = pNode.newChild() + } + pNode = pNode.children[lv] + } + // shared subscription + if s.ShareName != "" { + if pNode.shared[s.ShareName] == nil { + pNode.shared[s.ShareName] = make(clientOpts) + } + pNode.shared[s.ShareName][clientID] = s + } else { + // non-shared + pNode.clients[clientID] = s + } + pNode.topicName = s.TopicFilter + return pNode +} + +// find walk through the tire and return the node that represent the topicFilter. +// Return nil if not found +func (t *topicTrie) find(topicFilter string) *topicNode { + topicSlice := strings.Split(topicFilter, "/") + var pNode = t + for _, lv := range topicSlice { + if _, ok := pNode.children[lv]; ok { + pNode = pNode.children[lv] + } else { + return nil + } + } + if pNode.topicName == topicFilter { + return pNode + } + return nil +} + +// unsubscribe +func (t *topicTrie) unsubscribe(clientID string, topicName string, shareName string) { + topicSlice := strings.Split(topicName, "/") + l := len(topicSlice) + var pNode = t + for _, lv := range topicSlice { + if _, ok := pNode.children[lv]; ok { + pNode = pNode.children[lv] + } else { + return + } + } + if shareName != "" { + if c := pNode.shared[shareName]; c != nil { + delete(c, clientID) + if len(pNode.shared[shareName]) == 0 { + delete(pNode.shared, shareName) + } + if len(pNode.shared) == 0 && len(pNode.children) == 0 { + delete(pNode.parent.children, topicSlice[l-1]) + } + } + } else { + delete(pNode.clients, clientID) + if len(pNode.clients) == 0 && len(pNode.children) == 0 { + delete(pNode.parent.children, topicSlice[l-1]) + } + } + +} + +// setRs set the node subscription info into rs +func setRs(node *topicNode, rs subscription.ClientSubscriptions) { + for cid, subOpts := range node.clients { + rs[cid] = append(rs[cid], subOpts) + } + + for _, c := range node.shared { + for cid, subOpts := range c { + rs[cid] = append(rs[cid], subOpts) + } + } +} + +// matchTopic get all matched topic for given topicSlice, and set into rs +func (t *topicTrie) matchTopic(topicSlice []string, rs subscription.ClientSubscriptions) { + endFlag := len(topicSlice) == 1 + if cnode := t.children["#"]; cnode != nil { + setRs(cnode, rs) + } + if cnode := t.children["+"]; cnode != nil { + if endFlag { + setRs(cnode, rs) + if n := cnode.children["#"]; n != nil { + setRs(n, rs) + } + } else { + cnode.matchTopic(topicSlice[1:], rs) + } + } + if cnode := t.children[topicSlice[0]]; cnode != nil { + if endFlag { + setRs(cnode, rs) + if n := cnode.children["#"]; n != nil { + setRs(n, rs) + } + } else { + cnode.matchTopic(topicSlice[1:], rs) + } + } +} + +// getMatchedTopicFilter return a map key by clientID that contain all matched topic for the given topicName. +func (t *topicTrie) getMatchedTopicFilter(topicName string) subscription.ClientSubscriptions { + topicLv := strings.Split(topicName, "/") + subs := make(subscription.ClientSubscriptions) + t.matchTopic(topicLv, subs) + return subs +} + +func isSystemTopic(topicName string) bool { + return len(topicName) >= 1 && topicName[0] == '$' +} + +func (t *topicTrie) preOrderTraverse(fn subscription.IterateFn) bool { + if t == nil { + return false + } + if t.topicName != "" { + for clientID, subOpts := range t.clients { + if !fn(clientID, subOpts) { + return false + } + } + + for _, c := range t.shared { + for clientID, subOpts := range c { + if !fn(clientID, subOpts) { + return false + } + } + } + } + for _, c := range t.children { + if !c.preOrderTraverse(fn) { + return false + } + } + return true +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/mem/topic_trie_test.go b/internal/hummingbird/mqttbroker/persistence/subscription/mem/topic_trie_test.go new file mode 100644 index 0000000..56b50ac --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/mem/topic_trie_test.go @@ -0,0 +1,309 @@ +package mem + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var testTopicMatch = []struct { + subTopic string //subscribe topic + topic string //publish topic + isMatch bool +}{ + {subTopic: "#", topic: "/abc/def", isMatch: true}, + {subTopic: "/a", topic: "a", isMatch: false}, + {subTopic: "a/#", topic: "a", isMatch: true}, + {subTopic: "+", topic: "/a", isMatch: false}, + + {subTopic: "a/", topic: "a", isMatch: false}, + {subTopic: "a/+", topic: "a/123/4", isMatch: false}, + {subTopic: "a/#", topic: "a/123/4", isMatch: true}, + + {subTopic: "/a/+/+/abcd", topic: "/a/dfdf/3434/abcd", isMatch: true}, + {subTopic: "/a/+/+/abcd", topic: "/a/dfdf/3434/abcdd", isMatch: false}, + {subTopic: "/a/+/abc/", topic: "/a/dfdf/abc/", isMatch: true}, + {subTopic: "/a/+/abc/", topic: "/a/dfdf/abc", isMatch: false}, + {subTopic: "/a/+/+/", topic: "/a/dfdf/", isMatch: false}, + {subTopic: "/a/+/+", topic: "/a/dfdf/", isMatch: true}, + {subTopic: "/a/+/+/#", topic: "/a/dfdf/", isMatch: true}, +} + +var topicMatchQosTest = []struct { + topics []packets.Topic + matchTopic struct { + name string // matched topic name + qos uint8 // matched qos + } +}{ + { + topics: []packets.Topic{ + { + SubOptions: packets.SubOptions{ + Qos: packets.Qos1, + }, + Name: "a/b", + }, + { + Name: "a/#", + SubOptions: packets.SubOptions{ + Qos: packets.Qos2, + }, + }, + { + Name: "a/+", + SubOptions: packets.SubOptions{ + Qos: packets.Qos0, + }, + }, + }, + matchTopic: struct { + name string + qos uint8 + }{ + name: "a/b", + qos: packets.Qos2, + }, + }, +} + +var testSubscribeAndFind = struct { + subTopics map[string][]packets.Topic // subscription + findTopics map[string][]struct { //key by clientID + exist bool + topicName string + wantQos uint8 + } +}{ + subTopics: map[string][]packets.Topic{ + "cid1": { + { + SubOptions: packets.SubOptions{ + Qos: packets.Qos1, + }, Name: "t1/t2/+"}, + {SubOptions: packets.SubOptions{ + Qos: packets.Qos2, + }, Name: "t1/t2/"}, + {SubOptions: packets.SubOptions{ + Qos: packets.Qos0, + }, Name: "t1/t2/cid1"}, + }, + "cid2": { + {SubOptions: packets.SubOptions{ + Qos: packets.Qos2, + }, Name: "t1/t2/+"}, + {SubOptions: packets.SubOptions{ + Qos: packets.Qos1, + }, Name: "t1/t2/"}, + {SubOptions: packets.SubOptions{ + Qos: packets.Qos0, + }, Name: "t1/t2/cid2"}, + }, + }, + findTopics: map[string][]struct { //key by clientID + exist bool + topicName string + wantQos uint8 + }{ + "cid1": { + {exist: true, topicName: "t1/t2/+", wantQos: packets.Qos1}, + {exist: true, topicName: "t1/t2/", wantQos: packets.Qos2}, + {exist: false, topicName: "t1/t2/cid2"}, + {exist: false, topicName: "t1/t2/cid3"}, + }, + "cid2": { + {exist: true, topicName: "t1/t2/+", wantQos: packets.Qos2}, + {exist: true, topicName: "t1/t2/", wantQos: packets.Qos1}, + {exist: false, topicName: "t1/t2/cid1"}, + }, + }, +} + +var testUnsubscribe = struct { + subTopics map[string][]packets.Topic //key by clientID + unsubscribe map[string][]string // clientID => topic name + afterUnsub map[string][]struct { // test after unsubscribe, key by clientID + exist bool + topicName string + wantQos uint8 + } +}{ + subTopics: map[string][]packets.Topic{ + "cid1": { + {SubOptions: packets.SubOptions{ + Qos: packets.Qos1, + }, Name: "t1/t2/t3"}, + {SubOptions: packets.SubOptions{ + Qos: packets.Qos2, + }, Name: "t1/t2"}, + }, + "cid2": { + { + SubOptions: packets.SubOptions{ + Qos: packets.Qos2, + }, + Name: "t1/t2/t3"}, + { + SubOptions: packets.SubOptions{ + Qos: packets.Qos1, + }, Name: "t1/t2"}, + }, + }, + unsubscribe: map[string][]string{ + "cid1": {"t1/t2/t3", "t4/t5"}, + "cid2": {"t1/t2/t3"}, + }, + afterUnsub: map[string][]struct { // test after unsubscribe + exist bool + topicName string + wantQos uint8 + }{ + "cid1": { + {exist: false, topicName: "t1/t2/t3"}, + {exist: true, topicName: "t1/t2", wantQos: packets.Qos2}, + }, + "cid2": { + {exist: false, topicName: "t1/t2/+"}, + {exist: true, topicName: "t1/t2", wantQos: packets.Qos1}, + }, + }, +} + +var testPreOrderTraverse = struct { + topics []packets.Topic + clientID string +}{ + topics: []packets.Topic{ + { + SubOptions: packets.SubOptions{ + Qos: packets.Qos0, + }, + Name: "a/b/c", + }, + { + SubOptions: packets.SubOptions{ + Qos: packets.Qos1, + }, + Name: "/a/b/c", + }, + { + SubOptions: packets.SubOptions{ + Qos: packets.Qos2, + }, + Name: "b/c/d", + }, + }, + clientID: "abc", +} + +func TestTopicTrie_matchedClients(t *testing.T) { + a := assert.New(t) + for _, v := range testTopicMatch { + trie := newTopicTrie() + trie.subscribe("cid", &gmqtt.Subscription{ + TopicFilter: v.subTopic, + }) + qos := trie.getMatchedTopicFilter(v.topic) + if v.isMatch { + a.EqualValues(qos["cid"][0].QoS, 0, v.subTopic) + } else { + _, ok := qos["cid"] + a.False(ok, v.subTopic) + } + } +} + +func TestTopicTrie_matchedClients_Qos(t *testing.T) { + a := assert.New(t) + for _, v := range topicMatchQosTest { + trie := newTopicTrie() + for _, tt := range v.topics { + trie.subscribe("cid", &gmqtt.Subscription{ + TopicFilter: tt.Name, + QoS: tt.Qos, + }) + } + rs := trie.getMatchedTopicFilter(v.matchTopic.name) + a.EqualValues(v.matchTopic.qos, rs["cid"][0].QoS) + } +} + +func TestTopicTrie_subscribeAndFind(t *testing.T) { + a := assert.New(t) + trie := newTopicTrie() + for cid, v := range testSubscribeAndFind.subTopics { + for _, topic := range v { + trie.subscribe(cid, &gmqtt.Subscription{ + TopicFilter: topic.Name, + QoS: topic.Qos, + }) + } + } + for cid, v := range testSubscribeAndFind.findTopics { + for _, tt := range v { + node := trie.find(tt.topicName) + if tt.exist { + a.Equal(tt.wantQos, node.clients[cid].QoS) + } else { + if node != nil { + _, ok := node.clients[cid] + a.False(ok) + } + } + } + } +} + +func TestTopicTrie_unsubscribe(t *testing.T) { + a := assert.New(t) + trie := newTopicTrie() + for cid, v := range testUnsubscribe.subTopics { + for _, topic := range v { + trie.subscribe(cid, &gmqtt.Subscription{ + TopicFilter: topic.Name, + QoS: topic.Qos, + }) + } + } + for cid, v := range testUnsubscribe.unsubscribe { + for _, tt := range v { + trie.unsubscribe(cid, tt, "") + } + } + for cid, v := range testUnsubscribe.afterUnsub { + for _, tt := range v { + matched := trie.getMatchedTopicFilter(tt.topicName) + if tt.exist { + a.EqualValues(matched[cid][0].QoS, tt.wantQos) + } else { + a.Equal(0, len(matched)) + } + } + } +} + +func TestTopicTrie_preOrderTraverse(t *testing.T) { + a := assert.New(t) + trie := newTopicTrie() + for _, v := range testPreOrderTraverse.topics { + trie.subscribe(testPreOrderTraverse.clientID, &gmqtt.Subscription{ + TopicFilter: v.Name, + QoS: v.Qos, + }) + } + var rs []packets.Topic + trie.preOrderTraverse(func(clientID string, subscription *gmqtt.Subscription) bool { + a.Equal(testPreOrderTraverse.clientID, clientID) + rs = append(rs, packets.Topic{ + SubOptions: packets.SubOptions{ + Qos: subscription.QoS, + }, + Name: subscription.TopicFilter, + }) + return true + }) + a.ElementsMatch(testPreOrderTraverse.topics, rs) +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/mem/trie_db.go b/internal/hummingbird/mqttbroker/persistence/subscription/mem/trie_db.go new file mode 100644 index 0000000..c1b6b45 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/mem/trie_db.go @@ -0,0 +1,385 @@ +package mem + +import ( + "strings" + "sync" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" +) + +var _ subscription.Store = (*TrieDB)(nil) + +// TrieDB implement the subscription.Interface, it use trie tree to store topics. +type TrieDB struct { + sync.RWMutex + userIndex map[string]map[string]*topicNode // [clientID][topicFilter] + userTrie *topicTrie + + // system topic which begin with "$" + systemIndex map[string]map[string]*topicNode // [clientID][topicFilter] + systemTrie *topicTrie + + // shared subscription which begin with "$share" + sharedIndex map[string]map[string]*topicNode // [clientID][shareName/topicFilter] + sharedTrie *topicTrie + + // statistics of the server and each client + stats subscription.Stats + clientStats map[string]*subscription.Stats // [clientID] + +} + +func (db *TrieDB) Init(clientIDs []string) error { + return nil +} + +func (db *TrieDB) Close() error { + return nil +} + +func iterateShared(fn subscription.IterateFn, options subscription.IterationOptions, index map[string]map[string]*topicNode, trie *topicTrie) bool { + // 查询指定topicFilter + if options.TopicName != "" && options.MatchType == subscription.MatchName { //寻找指定topicName + var shareName string + var topicFilter string + if strings.HasPrefix(options.TopicName, "$share/") { + shared := strings.SplitN(options.TopicName, "/", 3) + shareName = shared[1] + topicFilter = shared[2] + } else { + return true + } + node := trie.find(topicFilter) + if node == nil { + return true + } + if options.ClientID != "" { // 指定topicName & 指定clientID + if c := node.shared[shareName]; c != nil { + if sub, ok := c[options.ClientID]; ok { + if !fn(options.ClientID, sub) { + return false + } + } + } + } else { + if c := node.shared[shareName]; c != nil { + for clientID, sub := range c { + if !fn(clientID, sub) { + return false + } + } + } + } + return true + } + // 查询Match指定topicFilter + if options.TopicName != "" && options.MatchType == subscription.MatchFilter { // match指定的topicfilter + node := trie.getMatchedTopicFilter(options.TopicName) + if node == nil { + return true + } + if options.ClientID != "" { + for _, v := range node[options.ClientID] { + if !fn(options.ClientID, v) { + return false + } + } + } else { + for clientID, subs := range node { + for _, v := range subs { + if !fn(clientID, v) { + return false + } + } + } + } + return true + } + // 查询指定clientID下的所有topic + if options.ClientID != "" { + for _, v := range index[options.ClientID] { + for _, c := range v.shared { + if sub, ok := c[options.ClientID]; ok { + if !fn(options.ClientID, sub) { + return false + } + } + } + } + return true + } + // 遍历 + return trie.preOrderTraverse(fn) +} + +func iterateNonShared(fn subscription.IterateFn, options subscription.IterationOptions, index map[string]map[string]*topicNode, trie *topicTrie) bool { + // 查询指定topicFilter + if options.TopicName != "" && options.MatchType == subscription.MatchName { //寻找指定topicName + node := trie.find(options.TopicName) + if node == nil { + return true + } + if options.ClientID != "" { // 指定topicName & 指定clientID + if sub, ok := node.clients[options.ClientID]; ok { + if !fn(options.ClientID, sub) { + return false + } + } + + for _, v := range node.shared { + if sub, ok := v[options.ClientID]; ok { + if !fn(options.ClientID, sub) { + return false + } + } + } + + } else { + // 指定topic name 不指定clientid + for clientID, sub := range node.clients { + if !fn(clientID, sub) { + return false + } + } + for _, c := range node.shared { + for clientID, sub := range c { + if !fn(clientID, sub) { + return false + } + } + } + + } + return true + } + // 查询Match指定topicFilter + if options.TopicName != "" && options.MatchType == subscription.MatchFilter { // match指定的topicfilter + node := trie.getMatchedTopicFilter(options.TopicName) + if node == nil { + return true + } + if options.ClientID != "" { + for _, v := range node[options.ClientID] { + if !fn(options.ClientID, v) { + return false + } + } + } else { + for clientID, subs := range node { + for _, v := range subs { + if !fn(clientID, v) { + return false + } + } + } + } + return true + } + // 查询指定clientID下的所有topic + if options.ClientID != "" { + for _, v := range index[options.ClientID] { + sub := v.clients[options.ClientID] + if !fn(options.ClientID, sub) { + return false + } + } + return true + } + // 遍历 + return trie.preOrderTraverse(fn) + +} + +// IterateLocked is the non thread-safe version of Iterate +func (db *TrieDB) IterateLocked(fn subscription.IterateFn, options subscription.IterationOptions) { + if options.Type&subscription.TypeShared == subscription.TypeShared { + if !iterateShared(fn, options, db.sharedIndex, db.sharedTrie) { + return + } + } + if options.Type&subscription.TypeNonShared == subscription.TypeNonShared { + // The Server MUST NOT match Topic Filters starting with a wildcard character (# or +) with Topic Names beginning with a $ character [MQTT-4.7.2-1] + if !(options.TopicName != "" && isSystemTopic(options.TopicName)) { + if !iterateNonShared(fn, options, db.userIndex, db.userTrie) { + return + } + } + } + if options.Type&subscription.TypeSYS == subscription.TypeSYS { + if options.TopicName != "" && !isSystemTopic(options.TopicName) { + return + } + if !iterateNonShared(fn, options, db.systemIndex, db.systemTrie) { + return + } + } +} +func (db *TrieDB) Iterate(fn subscription.IterateFn, options subscription.IterationOptions) { + db.RLock() + defer db.RUnlock() + db.IterateLocked(fn, options) +} + +// GetStats is the non thread-safe version of GetStats +func (db *TrieDB) GetStatusLocked() subscription.Stats { + return db.stats +} + +// GetStats returns the statistic information of the store +func (db *TrieDB) GetStats() subscription.Stats { + db.RLock() + defer db.RUnlock() + return db.GetStatusLocked() +} + +// GetClientStatsLocked the non thread-safe version of GetClientStats +func (db *TrieDB) GetClientStatsLocked(clientID string) (subscription.Stats, error) { + if stats, ok := db.clientStats[clientID]; !ok { + return subscription.Stats{}, subscription.ErrClientNotExists + } else { + return *stats, nil + } +} + +func (db *TrieDB) GetClientStats(clientID string) (subscription.Stats, error) { + db.RLock() + defer db.RUnlock() + return db.GetClientStatsLocked(clientID) +} + +// NewStore create a new TrieDB instance +func NewStore() *TrieDB { + return &TrieDB{ + userIndex: make(map[string]map[string]*topicNode), + userTrie: newTopicTrie(), + + systemIndex: make(map[string]map[string]*topicNode), + systemTrie: newTopicTrie(), + + sharedIndex: make(map[string]map[string]*topicNode), + sharedTrie: newTopicTrie(), + + clientStats: make(map[string]*subscription.Stats), + } +} + +// SubscribeLocked is the non thread-safe version of Subscribe +func (db *TrieDB) SubscribeLocked(clientID string, subscriptions ...*gmqtt.Subscription) subscription.SubscribeResult { + var node *topicNode + var index map[string]map[string]*topicNode + rs := make(subscription.SubscribeResult, len(subscriptions)) + for k, sub := range subscriptions { + topicName := sub.TopicFilter + rs[k].Subscription = sub + if sub.ShareName != "" { + node = db.sharedTrie.subscribe(clientID, sub) + index = db.sharedIndex + } else if isSystemTopic(topicName) { + node = db.systemTrie.subscribe(clientID, sub) + index = db.systemIndex + } else { + node = db.userTrie.subscribe(clientID, sub) + index = db.userIndex + } + if index[clientID] == nil { + index[clientID] = make(map[string]*topicNode) + if db.clientStats[clientID] == nil { + db.clientStats[clientID] = &subscription.Stats{} + } + } + if _, ok := index[clientID][topicName]; !ok { + db.stats.SubscriptionsTotal++ + db.stats.SubscriptionsCurrent++ + db.clientStats[clientID].SubscriptionsTotal++ + db.clientStats[clientID].SubscriptionsCurrent++ + } else { + rs[k].AlreadyExisted = true + } + index[clientID][topicName] = node + } + return rs +} + +// SubscribeLocked add subscriptions for the client +func (db *TrieDB) Subscribe(clientID string, subscriptions ...*gmqtt.Subscription) (subscription.SubscribeResult, error) { + db.Lock() + defer db.Unlock() + return db.SubscribeLocked(clientID, subscriptions...), nil +} + +// UnsubscribeLocked is the non thread-safe version of Unsubscribe +func (db *TrieDB) UnsubscribeLocked(clientID string, topics ...string) { + var index map[string]map[string]*topicNode + var topicTrie *topicTrie + for _, topic := range topics { + var shareName string + shareName, topic := subscription.SplitTopic(topic) + if shareName != "" { + topicTrie = db.sharedTrie + index = db.sharedIndex + } else if isSystemTopic(topic) { + index = db.systemIndex + topicTrie = db.systemTrie + } else { + index = db.userIndex + topicTrie = db.userTrie + } + if _, ok := index[clientID]; ok { + if _, ok := index[clientID][topic]; ok { + db.stats.SubscriptionsCurrent-- + db.clientStats[clientID].SubscriptionsCurrent-- + } + delete(index[clientID], topic) + } + topicTrie.unsubscribe(clientID, topic, shareName) + } +} + +// Unsubscribe remove subscriptions for the client +func (db *TrieDB) Unsubscribe(clientID string, topics ...string) error { + db.Lock() + defer db.Unlock() + db.UnsubscribeLocked(clientID, topics...) + return nil +} + +func (db *TrieDB) unsubscribeAll(index map[string]map[string]*topicNode, clientID string) { + db.stats.SubscriptionsCurrent -= uint64(len(index[clientID])) + if db.clientStats[clientID] != nil { + db.clientStats[clientID].SubscriptionsCurrent -= uint64(len(index[clientID])) + } + for topicName, node := range index[clientID] { + delete(node.clients, clientID) + if len(node.clients) == 0 && len(node.children) == 0 { + ss := strings.Split(topicName, "/") + delete(node.parent.children, ss[len(ss)-1]) + } + } + delete(index, clientID) +} + +// UnsubscribeAllLocked is the non thread-safe version of UnsubscribeAll +func (db *TrieDB) UnsubscribeAllLocked(clientID string) { + db.unsubscribeAll(db.userIndex, clientID) + db.unsubscribeAll(db.systemIndex, clientID) + db.unsubscribeAll(db.sharedIndex, clientID) +} + +// UnsubscribeAll delete all subscriptions of the client +func (db *TrieDB) UnsubscribeAll(clientID string) error { + db.Lock() + defer db.Unlock() + // user topics + db.UnsubscribeAllLocked(clientID) + return nil +} + +// getMatchedTopicFilter return a map key by clientID that contain all matched topic for the given topicName. +func (db *TrieDB) getMatchedTopicFilter(topicName string) subscription.ClientSubscriptions { + // system topic + if isSystemTopic(topicName) { + return db.systemTrie.getMatchedTopicFilter(topicName) + } + return db.userTrie.getMatchedTopicFilter(topicName) +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/redis/subscription.go b/internal/hummingbird/mqttbroker/persistence/subscription/redis/subscription.go new file mode 100644 index 0000000..765ba06 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/redis/subscription.go @@ -0,0 +1,177 @@ +package redis + +import ( + "bytes" + "strings" + "sync" + + redigo "github.com/gomodule/redigo/redis" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/encoding" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription/mem" +) + +const ( + subPrefix = "sub:" +) + +var _ subscription.Store = (*sub)(nil) + +func EncodeSubscription(sub *mqttbroker.Subscription) []byte { + w := &bytes.Buffer{} + encoding.WriteString(w, []byte(sub.ShareName)) + encoding.WriteString(w, []byte(sub.TopicFilter)) + encoding.WriteUint32(w, sub.ID) + w.WriteByte(sub.QoS) + encoding.WriteBool(w, sub.NoLocal) + encoding.WriteBool(w, sub.RetainAsPublished) + w.WriteByte(sub.RetainHandling) + return w.Bytes() +} + +func DecodeSubscription(b []byte) (*gmqtt.Subscription, error) { + sub := &gmqtt.Subscription{} + r := bytes.NewBuffer(b) + share, err := encoding.ReadString(r) + if err != nil { + return &gmqtt.Subscription{}, err + } + sub.ShareName = string(share) + topic, err := encoding.ReadString(r) + if err != nil { + return &gmqtt.Subscription{}, err + } + sub.TopicFilter = string(topic) + sub.ID, err = encoding.ReadUint32(r) + if err != nil { + return &gmqtt.Subscription{}, err + } + sub.QoS, err = r.ReadByte() + if err != nil { + return &gmqtt.Subscription{}, err + } + sub.NoLocal, err = encoding.ReadBool(r) + if err != nil { + return &gmqtt.Subscription{}, err + } + sub.RetainAsPublished, err = encoding.ReadBool(r) + if err != nil { + return &gmqtt.Subscription{}, err + } + sub.RetainHandling, err = r.ReadByte() + if err != nil { + return nil, err + } + return sub, nil +} + +func New(pool *redigo.Pool) *sub { + return &sub{ + mu: &sync.Mutex{}, + memStore: mem.NewStore(), + pool: pool, + } +} + +type sub struct { + mu *sync.Mutex + memStore *mem.TrieDB + pool *redigo.Pool +} + +// Init loads the subscriptions of given clientIDs from backend into memory. +func (s *sub) Init(clientIDs []string) error { + if len(clientIDs) == 0 { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + for _, v := range clientIDs { + rs, err := redigo.Values(c.Do("hgetall", subPrefix+v)) + if err != nil { + return err + } + for i := 1; i < len(rs); i = i + 2 { + sub, err := DecodeSubscription(rs[i].([]byte)) + if err != nil { + return err + } + s.memStore.SubscribeLocked(strings.TrimLeft(v, subPrefix), sub) + } + } + return nil +} + +func (s *sub) Close() error { + _ = s.memStore.Close() + return s.pool.Close() +} + +func (s *sub) Subscribe(clientID string, subscriptions ...*gmqtt.Subscription) (rs subscription.SubscribeResult, err error) { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + // hset sub:clientID topicFilter xxx + for _, v := range subscriptions { + err = c.Send("hset", subPrefix+clientID, subscription.GetFullTopicName(v.ShareName, v.TopicFilter), EncodeSubscription(v)) + if err != nil { + return nil, err + } + } + err = c.Flush() + if err != nil { + return nil, err + } + rs = s.memStore.SubscribeLocked(clientID, subscriptions...) + return rs, nil +} + +func (s *sub) Unsubscribe(clientID string, topics ...string) error { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + _, err := c.Do("hdel", subPrefix+clientID, topics) + if err != nil { + return err + } + s.memStore.UnsubscribeLocked(clientID, topics...) + return nil +} + +func (s *sub) UnsubscribeAll(clientID string) error { + s.mu.Lock() + defer s.mu.Unlock() + c := s.pool.Get() + defer c.Close() + _, err := c.Do("del", subPrefix+clientID) + if err != nil { + return err + } + s.memStore.UnsubscribeAllLocked(clientID) + return nil +} + +func (s *sub) Iterate(fn subscription.IterateFn, options subscription.IterationOptions) { + s.mu.Lock() + defer s.mu.Unlock() + s.memStore.IterateLocked(fn, options) +} + +func (s *sub) GetStats() subscription.Stats { + s.mu.Lock() + defer s.mu.Unlock() + return s.memStore.GetStatusLocked() +} + +func (s *sub) GetClientStats(clientID string) (subscription.Stats, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.memStore.GetClientStatsLocked(clientID) +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/redis/subscription_test.go b/internal/hummingbird/mqttbroker/persistence/subscription/redis/subscription_test.go new file mode 100644 index 0000000..528f40a --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/redis/subscription_test.go @@ -0,0 +1,38 @@ +package redis + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" +) + +func TestEncodeDecodeSubscription(t *testing.T) { + a := assert.New(t) + tt := []*mqttbroker.Subscription{ + { + ShareName: "shareName", + TopicFilter: "filter", + ID: 1, + QoS: 1, + NoLocal: false, + RetainAsPublished: false, + RetainHandling: 0, + }, { + ShareName: "", + TopicFilter: "abc", + ID: 0, + QoS: 2, + NoLocal: false, + RetainAsPublished: true, + RetainHandling: 1, + }, + } + + for _, v := range tt { + b := EncodeSubscription(v) + sub, err := DecodeSubscription(b) + a.Nil(err) + a.Equal(v, sub) + } +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/subscription.go b/internal/hummingbird/mqttbroker/persistence/subscription/subscription.go new file mode 100644 index 0000000..0811b43 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/subscription.go @@ -0,0 +1,193 @@ +package subscription + +import ( + "errors" + "strings" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// IterationType specifies the types of subscription that will be iterated. +type IterationType byte + +const ( + // TypeSYS represents system topic, which start with '$'. + TypeSYS IterationType = 1 << iota + // TypeSYS represents shared topic, which start with '$share/'. + TypeShared + // TypeNonShared represents non-shared topic. + TypeNonShared + TypeAll = TypeSYS | TypeShared | TypeNonShared +) + +var ( + ErrClientNotExists = errors.New("client not exists") +) + +// MatchType specifies what match operation will be performed during the iteration. +type MatchType byte + +const ( + MatchName MatchType = 1 << iota + MatchFilter +) + +// FromTopic returns the subscription instance for given topic and subscription id. +func FromTopic(topic packets.Topic, id uint32) *gmqtt.Subscription { + shareName, topicFilter := SplitTopic(topic.Name) + s := &gmqtt.Subscription{ + ShareName: shareName, + TopicFilter: topicFilter, + ID: id, + QoS: topic.Qos, + NoLocal: topic.NoLocal, + RetainAsPublished: topic.RetainAsPublished, + RetainHandling: topic.RetainHandling, + } + return s +} + +// IterateFn is the callback function used by iterate() +// Return false means to stop the iteration. +type IterateFn func(clientID string, sub *gmqtt.Subscription) bool + +// SubscribeResult is the result of Subscribe() +type SubscribeResult = []struct { + // Topic is the Subscribed topic + Subscription *gmqtt.Subscription + // AlreadyExisted shows whether the topic is already existed. + AlreadyExisted bool +} + +// Stats is the statistics information of the store +type Stats struct { + // SubscriptionsTotal shows how many subscription has been added to the store. + // Duplicated subscription is not counting. + SubscriptionsTotal uint64 + // SubscriptionsCurrent shows the current subscription number in the store. + SubscriptionsCurrent uint64 +} + +// ClientSubscriptions groups the subscriptions by client id. +type ClientSubscriptions map[string][]*gmqtt.Subscription + +// IterationOptions +type IterationOptions struct { + // Type specifies the types of subscription that will be iterated. + // For example, if Type = TypeShared | TypeNonShared , then all shared and non-shared subscriptions will be iterated + Type IterationType + // ClientID specifies the subscriber client id. + ClientID string + // TopicName represents topic filter or topic name. This field works together with MatchType. + TopicName string + // MatchType specifies the matching type of the iteration. + // if MatchName, the IterateFn will be called when the subscription topic filter is equal to TopicName. + // if MatchTopic, the IterateFn will be called when the TopicName match the subscription topic filter. + MatchType MatchType +} + +// Store is the interface used by gmqtt.server to handler the operations of subscriptions. +// This interface provides the ability for extensions to interact with the subscriptions. +// Notice: +// This methods will not trigger any gmqtt hooks. +type Store interface { + // Init will be called only once after the server start, the implementation should load the subscriptions of the given alertclient into memory. + Init(clientIDs []string) error + // Subscribe adds subscriptions to a specific client. + // Notice: + // This method will succeed even if the client is not exists, the subscriptions + // will affect the new client with the client id. + Subscribe(clientID string, subscriptions ...*gmqtt.Subscription) (rs SubscribeResult, err error) + // Unsubscribe removes subscriptions of a specific client. + Unsubscribe(clientID string, topics ...string) error + // UnsubscribeAll removes all subscriptions of a specific client. + UnsubscribeAll(clientID string) error + // Iterate iterates all subscriptions. The callback is called once for each subscription. + // If callback return false, the iteration will be stopped. + // Notice: + // The results are not sorted in any way, no ordering of any kind is guaranteed. + // This method will walk through all subscriptions, + // so it is a very expensive operation. Do not call it frequently. + Iterate(fn IterateFn, options IterationOptions) + + Close() error + StatsReader +} + +// GetTopicMatched returns the subscriptions that match the passed topic. +func GetTopicMatched(store Store, topicFilter string, t IterationType) ClientSubscriptions { + rs := make(ClientSubscriptions) + store.Iterate(func(clientID string, subscription *gmqtt.Subscription) bool { + rs[clientID] = append(rs[clientID], subscription) + return true + }, IterationOptions{ + Type: t, + TopicName: topicFilter, + MatchType: MatchFilter, + }) + if len(rs) == 0 { + return nil + } + return rs +} + +// Get returns the subscriptions that equals the passed topic filter. +func Get(store Store, topicFilter string, t IterationType) ClientSubscriptions { + rs := make(ClientSubscriptions) + store.Iterate(func(clientID string, subscription *gmqtt.Subscription) bool { + rs[clientID] = append(rs[clientID], subscription) + return true + }, IterationOptions{ + Type: t, + TopicName: topicFilter, + MatchType: MatchName, + }) + if len(rs) == 0 { + return nil + } + return rs +} + +// GetClientSubscriptions returns the subscriptions of a specific client. +func GetClientSubscriptions(store Store, clientID string, t IterationType) []*gmqtt.Subscription { + var rs []*gmqtt.Subscription + store.Iterate(func(clientID string, subscription *gmqtt.Subscription) bool { + rs = append(rs, subscription) + return true + }, IterationOptions{ + Type: t, + ClientID: clientID, + }) + return rs +} + +// StatsReader provides the ability to get statistics information. +type StatsReader interface { + // GetStats return the global stats. + GetStats() Stats + // GetClientStats return the stats of a specific client. + // If stats not exists, return an error. + GetClientStats(clientID string) (Stats, error) +} + +// SplitTopic returns the shareName and topicFilter of the given topic. +// If the topic is invalid, returns empty strings. +func SplitTopic(topic string) (shareName, topicFilter string) { + if strings.HasPrefix(topic, "$share/") { + shared := strings.SplitN(topic, "/", 3) + if len(shared) < 3 { + return "", "" + } + return shared[1], shared[2] + } + return "", topic +} + +// GetFullTopicName returns the full topic name of given shareName and topicFilter +func GetFullTopicName(shareName, topicFilter string) string { + if shareName != "" { + return "$share/" + shareName + "/" + topicFilter + } + return topicFilter +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/subscription_mock.go b/internal/hummingbird/mqttbroker/persistence/subscription/subscription_mock.go new file mode 100644 index 0000000..5beb0aa --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/subscription_mock.go @@ -0,0 +1,209 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: persistence/subscription/subscription.go + +// Package subscription is a generated GoMock package. +package subscription + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" +) + +// MockStore is a mock of Store interface +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Init mocks base method +func (m *MockStore) Init(clientIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", clientIDs) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init +func (mr *MockStoreMockRecorder) Init(clientIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockStore)(nil).Init), clientIDs) +} + +// Subscribe mocks base method +func (m *MockStore) Subscribe(clientID string, subscriptions ...*gmqtt.Subscription) (SubscribeResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{clientID} + for _, a := range subscriptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Subscribe", varargs...) + ret0, _ := ret[0].(SubscribeResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Subscribe indicates an expected call of Subscribe +func (mr *MockStoreMockRecorder) Subscribe(clientID interface{}, subscriptions ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{clientID}, subscriptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockStore)(nil).Subscribe), varargs...) +} + +// Unsubscribe mocks base method +func (m *MockStore) Unsubscribe(clientID string, topics ...string) error { + m.ctrl.T.Helper() + varargs := []interface{}{clientID} + for _, a := range topics { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Unsubscribe", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unsubscribe indicates an expected call of Unsubscribe +func (mr *MockStoreMockRecorder) Unsubscribe(clientID interface{}, topics ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{clientID}, topics...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsubscribe", reflect.TypeOf((*MockStore)(nil).Unsubscribe), varargs...) +} + +// UnsubscribeAll mocks base method +func (m *MockStore) UnsubscribeAll(clientID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnsubscribeAll", clientID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnsubscribeAll indicates an expected call of UnsubscribeAll +func (mr *MockStoreMockRecorder) UnsubscribeAll(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAll", reflect.TypeOf((*MockStore)(nil).UnsubscribeAll), clientID) +} + +// Iterate mocks base method +func (m *MockStore) Iterate(fn IterateFn, options IterationOptions) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Iterate", fn, options) +} + +// Iterate indicates an expected call of Iterate +func (mr *MockStoreMockRecorder) Iterate(fn, options interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockStore)(nil).Iterate), fn, options) +} + +// Close mocks base method +func (m *MockStore) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockStoreMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStore)(nil).Close)) +} + +// GetStats mocks base method +func (m *MockStore) GetStats() Stats { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStats") + ret0, _ := ret[0].(Stats) + return ret0 +} + +// GetStats indicates an expected call of GetStats +func (mr *MockStoreMockRecorder) GetStats() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStats", reflect.TypeOf((*MockStore)(nil).GetStats)) +} + +// GetClientStats mocks base method +func (m *MockStore) GetClientStats(clientID string) (Stats, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientStats", clientID) + ret0, _ := ret[0].(Stats) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetClientStats indicates an expected call of GetClientStats +func (mr *MockStoreMockRecorder) GetClientStats(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientStats", reflect.TypeOf((*MockStore)(nil).GetClientStats), clientID) +} + +// MockStatsReader is a mock of StatsReader interface +type MockStatsReader struct { + ctrl *gomock.Controller + recorder *MockStatsReaderMockRecorder +} + +// MockStatsReaderMockRecorder is the mock recorder for MockStatsReader +type MockStatsReaderMockRecorder struct { + mock *MockStatsReader +} + +// NewMockStatsReader creates a new mock instance +func NewMockStatsReader(ctrl *gomock.Controller) *MockStatsReader { + mock := &MockStatsReader{ctrl: ctrl} + mock.recorder = &MockStatsReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStatsReader) EXPECT() *MockStatsReaderMockRecorder { + return m.recorder +} + +// GetStats mocks base method +func (m *MockStatsReader) GetStats() Stats { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStats") + ret0, _ := ret[0].(Stats) + return ret0 +} + +// GetStats indicates an expected call of GetStats +func (mr *MockStatsReaderMockRecorder) GetStats() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStats", reflect.TypeOf((*MockStatsReader)(nil).GetStats)) +} + +// GetClientStats mocks base method +func (m *MockStatsReader) GetClientStats(clientID string) (Stats, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientStats", clientID) + ret0, _ := ret[0].(Stats) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetClientStats indicates an expected call of GetClientStats +func (mr *MockStatsReaderMockRecorder) GetClientStats(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientStats", reflect.TypeOf((*MockStatsReader)(nil).GetClientStats), clientID) +} diff --git a/internal/hummingbird/mqttbroker/persistence/subscription/test/test_suite.go b/internal/hummingbird/mqttbroker/persistence/subscription/test/test_suite.go new file mode 100644 index 0000000..0738aea --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/subscription/test/test_suite.go @@ -0,0 +1,601 @@ +package test + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var ( + topicA = &gmqtt.Subscription{ + TopicFilter: "topic/A", + ID: 1, + + QoS: 1, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 1, + } + + topicB = &gmqtt.Subscription{ + TopicFilter: "topic/B", + + QoS: 1, + NoLocal: false, + RetainAsPublished: true, + RetainHandling: 0, + } + + systemTopicA = &gmqtt.Subscription{ + TopicFilter: "$topic/A", + ID: 1, + + QoS: 1, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 1, + } + + systemTopicB = &gmqtt.Subscription{ + TopicFilter: "$topic/B", + + QoS: 1, + NoLocal: false, + RetainAsPublished: true, + RetainHandling: 0, + } + + sharedTopicA1 = &gmqtt.Subscription{ + ShareName: "name1", + TopicFilter: "topic/A", + ID: 1, + + QoS: 1, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 1, + } + + sharedTopicB1 = &gmqtt.Subscription{ + ShareName: "name1", + TopicFilter: "topic/B", + ID: 1, + + QoS: 1, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 1, + } + + sharedTopicA2 = &gmqtt.Subscription{ + ShareName: "name2", + TopicFilter: "topic/A", + ID: 1, + + QoS: 1, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 1, + } + + sharedTopicB2 = &gmqtt.Subscription{ + ShareName: "name2", + TopicFilter: "topic/B", + ID: 1, + + QoS: 1, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 1, + } +) + +var testSubs = []struct { + clientID string + subs []*gmqtt.Subscription +}{ + // non-share and non-system subscription + { + + clientID: "client1", + subs: []*gmqtt.Subscription{ + topicA, topicB, + }, + }, { + clientID: "client2", + subs: []*gmqtt.Subscription{ + topicA, topicB, + }, + }, + // system subscription + { + + clientID: "client1", + subs: []*gmqtt.Subscription{ + systemTopicA, systemTopicB, + }, + }, { + + clientID: "client2", + subs: []*gmqtt.Subscription{ + systemTopicA, systemTopicB, + }, + }, + // share subscription + { + clientID: "client1", + subs: []*gmqtt.Subscription{ + sharedTopicA1, sharedTopicB1, sharedTopicA2, sharedTopicB2, + }, + }, + { + clientID: "client2", + subs: []*gmqtt.Subscription{ + sharedTopicA1, sharedTopicB1, sharedTopicA2, sharedTopicB2, + }, + }, +} + +func testAddSubscribe(t *testing.T, store subscription.Store) { + a := assert.New(t) + for _, v := range testSubs { + _, err := store.Subscribe(v.clientID, v.subs...) + a.Nil(err) + } +} + +func testGetStatus(t *testing.T, store subscription.Store) { + a := assert.New(t) + var err error + tt := []struct { + clientID string + topic packets.Topic + }{ + {clientID: "id0", topic: packets.Topic{Name: "name0", SubOptions: packets.SubOptions{Qos: packets.Qos0}}}, + {clientID: "id1", topic: packets.Topic{Name: "name1", SubOptions: packets.SubOptions{Qos: packets.Qos1}}}, + {clientID: "id2", topic: packets.Topic{Name: "name2", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id3", topic: packets.Topic{Name: "name3", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id4", topic: packets.Topic{Name: "name3", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id4", topic: packets.Topic{Name: "name4", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + // test $share and system topic + {clientID: "id4", topic: packets.Topic{Name: "$share/abc/name4", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id4", topic: packets.Topic{Name: "$SYS/abc/def", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + } + for _, v := range tt { + _, err = store.Subscribe(v.clientID, subscription.FromTopic(v.topic, 0)) + a.NoError(err) + } + stats := store.GetStats() + expectedTotal, expectedCurrent := len(tt), len(tt) + + a.EqualValues(expectedTotal, stats.SubscriptionsTotal) + a.EqualValues(expectedCurrent, stats.SubscriptionsCurrent) + + // If subscribe duplicated topic, total and current statistics should not increase + _, err = store.Subscribe("id0", subscription.FromTopic(packets.Topic{SubOptions: packets.SubOptions{Qos: packets.Qos0}, Name: "name0"}, 0)) + a.NoError(err) + _, err = store.Subscribe("id4", subscription.FromTopic(packets.Topic{SubOptions: packets.SubOptions{Qos: packets.Qos2}, Name: "$share/abc/name4"}, 0)) + a.NoError(err) + + stats = store.GetStats() + a.EqualValues(expectedTotal, stats.SubscriptionsTotal) + a.EqualValues(expectedCurrent, stats.SubscriptionsCurrent) + + utt := []struct { + clientID string + topic packets.Topic + }{ + {clientID: "id0", topic: packets.Topic{Name: "name0", SubOptions: packets.SubOptions{Qos: packets.Qos0}}}, + {clientID: "id1", topic: packets.Topic{Name: "name1", SubOptions: packets.SubOptions{Qos: packets.Qos1}}}, + } + expectedCurrent -= 2 + for _, v := range utt { + a.NoError(store.Unsubscribe(v.clientID, v.topic.Name)) + } + stats = store.GetStats() + a.EqualValues(expectedTotal, stats.SubscriptionsTotal) + a.EqualValues(expectedCurrent, stats.SubscriptionsCurrent) + + //if unsubscribe not exists topic, current statistics should not decrease + a.NoError(store.Unsubscribe("id0", "name555")) + stats = store.GetStats() + a.EqualValues(len(tt), stats.SubscriptionsTotal) + a.EqualValues(expectedCurrent, stats.SubscriptionsCurrent) + + a.NoError(store.Unsubscribe("id4", "$share/abc/name4")) + + expectedCurrent -= 1 + stats = store.GetStats() + a.EqualValues(expectedTotal, stats.SubscriptionsTotal) + a.EqualValues(expectedCurrent, stats.SubscriptionsCurrent) + + a.NoError(store.UnsubscribeAll("id4")) + expectedCurrent -= 3 + stats = store.GetStats() + a.EqualValues(len(tt), stats.SubscriptionsTotal) + a.EqualValues(expectedCurrent, stats.SubscriptionsCurrent) +} + +func testGetClientStats(t *testing.T, store subscription.Store) { + a := assert.New(t) + var err error + tt := []struct { + clientID string + topic packets.Topic + }{ + {clientID: "id0", topic: packets.Topic{Name: "name0", SubOptions: packets.SubOptions{Qos: packets.Qos0}}}, + {clientID: "id0", topic: packets.Topic{Name: "name1", SubOptions: packets.SubOptions{Qos: packets.Qos1}}}, + // test $share and system topic + {clientID: "id0", topic: packets.Topic{Name: "$share/abc/name5", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id0", topic: packets.Topic{Name: "$SYS/a/b/c", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + + {clientID: "id1", topic: packets.Topic{Name: "name0", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id1", topic: packets.Topic{Name: "$share/abc/name5", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id2", topic: packets.Topic{Name: "$SYS/a/b/c", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + {clientID: "id2", topic: packets.Topic{Name: "name5", SubOptions: packets.SubOptions{Qos: packets.Qos2}}}, + } + for _, v := range tt { + _, err = store.Subscribe(v.clientID, subscription.FromTopic(v.topic, 0)) + a.NoError(err) + } + stats, _ := store.GetClientStats("id0") + a.EqualValues(4, stats.SubscriptionsTotal) + a.EqualValues(4, stats.SubscriptionsCurrent) + + a.NoError(store.UnsubscribeAll("id0")) + stats, _ = store.GetClientStats("id0") + a.EqualValues(4, stats.SubscriptionsTotal) + a.EqualValues(0, stats.SubscriptionsCurrent) +} + +func TestSuite(t *testing.T, new func() subscription.Store) { + a := assert.New(t) + store := new() + a.Nil(store.Init(nil)) + defer store.Close() + for i := 0; i <= 1; i++ { + testAddSubscribe(t, store) + t.Run("testGetTopic"+strconv.Itoa(i), func(t *testing.T) { + testGetTopic(t, store) + }) + t.Run("testTopicMatch"+strconv.Itoa(i), func(t *testing.T) { + testTopicMatch(t, store) + }) + t.Run("testIterate"+strconv.Itoa(i), func(t *testing.T) { + testIterate(t, store) + }) + t.Run("testUnsubscribe"+strconv.Itoa(i), func(t *testing.T) { + testUnsubscribe(t, store) + }) + } + + store2 := new() + a.Nil(store2.Init(nil)) + defer store2.Close() + t.Run("testGetStatus", func(t *testing.T) { + testGetStatus(t, store2) + }) + + store3 := new() + a.Nil(store3.Init(nil)) + defer store3.Close() + t.Run("testGetStatus", func(t *testing.T) { + testGetClientStats(t, store3) + }) +} +func testGetTopic(t *testing.T, store subscription.Store) { + a := assert.New(t) + + rs := subscription.Get(store, topicA.TopicFilter, subscription.TypeAll) + a.Equal(topicA, rs["client1"][0]) + a.Equal(topicA, rs["client2"][0]) + + rs = subscription.Get(store, topicA.TopicFilter, subscription.TypeNonShared) + a.Equal(topicA, rs["client1"][0]) + a.Equal(topicA, rs["client2"][0]) + + rs = subscription.Get(store, systemTopicA.TopicFilter, subscription.TypeAll) + a.Equal(systemTopicA, rs["client1"][0]) + a.Equal(systemTopicA, rs["client2"][0]) + + rs = subscription.Get(store, systemTopicA.TopicFilter, subscription.TypeSYS) + a.Equal(systemTopicA, rs["client1"][0]) + a.Equal(systemTopicA, rs["client2"][0]) + + rs = subscription.Get(store, "$share/"+sharedTopicA1.ShareName+"/"+sharedTopicA1.TopicFilter, subscription.TypeAll) + a.Equal(sharedTopicA1, rs["client1"][0]) + a.Equal(sharedTopicA1, rs["client2"][0]) + +} +func testTopicMatch(t *testing.T, store subscription.Store) { + a := assert.New(t) + rs := subscription.GetTopicMatched(store, topicA.TopicFilter, subscription.TypeAll) + a.ElementsMatch([]*gmqtt.Subscription{topicA, sharedTopicA1, sharedTopicA2}, rs["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{topicA, sharedTopicA1, sharedTopicA2}, rs["client2"]) + + rs = subscription.GetTopicMatched(store, topicA.TopicFilter, subscription.TypeNonShared) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, rs["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, rs["client2"]) + + rs = subscription.GetTopicMatched(store, topicA.TopicFilter, subscription.TypeShared) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2}, rs["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2}, rs["client2"]) + + rs = subscription.GetTopicMatched(store, systemTopicA.TopicFilter, subscription.TypeSYS) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, rs["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, rs["client2"]) + +} +func testUnsubscribe(t *testing.T, store subscription.Store) { + a := assert.New(t) + a.Nil(store.Unsubscribe("client1", topicA.TopicFilter)) + rs := subscription.Get(store, topicA.TopicFilter, subscription.TypeAll) + a.Nil(rs["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, rs["client2"]) + a.Nil(store.UnsubscribeAll("client2")) + a.Nil(store.UnsubscribeAll("client1")) + var iterationCalled bool + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + iterationCalled = true + return true + }, subscription.IterationOptions{Type: subscription.TypeAll}) + a.False(iterationCalled) +} +func testIterate(t *testing.T, store subscription.Store) { + a := assert.New(t) + + var iterationCalled bool + // invalid subscription.IterationOptions + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + iterationCalled = true + return true + }, subscription.IterationOptions{}) + a.False(iterationCalled) + testIterateNonShared(t, store) + testIterateShared(t, store) + testIterateSystem(t, store) +} +func testIterateNonShared(t *testing.T, store subscription.Store) { + a := assert.New(t) + // iterate all non-shared subscriptions. + got := make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeNonShared, + }) + a.ElementsMatch([]*gmqtt.Subscription{topicA, topicB}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{topicA, topicB}, got["client2"]) + + // iterate all non-shared subscriptions with ClientID option. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeNonShared, + ClientID: "client1", + }) + + a.ElementsMatch([]*gmqtt.Subscription{topicA, topicB}, got["client1"]) + a.Len(got["client2"], 0) + + // iterate all non-shared subscriptions that matched given topic name. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeNonShared, + MatchType: subscription.MatchName, + TopicName: topicA.TopicFilter, + }) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, got["client2"]) + + // iterate all non-shared subscriptions that matched given topic name and client id + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeNonShared, + MatchType: subscription.MatchName, + TopicName: topicA.TopicFilter, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, got["client1"]) + a.Len(got["client2"], 0) + + // iterate all non-shared subscriptions that matched given topic filter. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeNonShared, + MatchType: subscription.MatchFilter, + TopicName: topicA.TopicFilter, + }) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, got["client2"]) + + // iterate all non-shared subscriptions that matched given topic filter and client id + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeNonShared, + MatchType: subscription.MatchFilter, + TopicName: topicA.TopicFilter, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{topicA}, got["client1"]) + a.Len(got["client2"], 0) +} +func testIterateShared(t *testing.T, store subscription.Store) { + a := assert.New(t) + // iterate all shared subscriptions. + got := make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeShared, + }) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2, sharedTopicB1, sharedTopicB2}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2, sharedTopicB1, sharedTopicB2}, got["client2"]) + + // iterate all shared subscriptions with ClientID option. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeShared, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2, sharedTopicB1, sharedTopicB2}, got["client1"]) + a.Len(got["client2"], 0) + + // iterate all shared subscriptions that matched given topic filter. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeShared, + MatchType: subscription.MatchName, + TopicName: "$share/" + sharedTopicA1.ShareName + "/" + sharedTopicA1.TopicFilter, + }) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1}, got["client2"]) + + // iterate all shared subscriptions that matched given topic filter and client id + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeShared, + MatchType: subscription.MatchName, + TopicName: "$share/" + sharedTopicA1.ShareName + "/" + sharedTopicA1.TopicFilter, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1}, got["client1"]) + a.Len(got["client2"], 0) + + // iterate all shared subscriptions that matched given topic name. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeShared, + MatchType: subscription.MatchFilter, + TopicName: sharedTopicA1.TopicFilter, + }) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2}, got["client2"]) + + // iterate all shared subscriptions that matched given topic name and clientID + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeShared, + MatchType: subscription.MatchFilter, + TopicName: sharedTopicA1.TopicFilter, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{sharedTopicA1, sharedTopicA2}, got["client1"]) + a.Len(got["client2"], 0) + +} +func testIterateSystem(t *testing.T, store subscription.Store) { + a := assert.New(t) + // iterate all system subscriptions. + got := make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeSYS, + }) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA, systemTopicB}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA, systemTopicB}, got["client2"]) + + // iterate all system subscriptions with ClientID option. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeSYS, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA, systemTopicB}, got["client1"]) + a.Len(got["client2"], 0) + + // iterate all system subscriptions that matched given topic filter. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeSYS, + MatchType: subscription.MatchName, + TopicName: systemTopicA.TopicFilter, + }) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, got["client2"]) + + // iterate all system subscriptions that matched given topic filter and client id + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeSYS, + MatchType: subscription.MatchName, + TopicName: systemTopicA.TopicFilter, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, got["client1"]) + a.Len(got["client2"], 0) + + // iterate all system subscriptions that matched given topic name. + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeSYS, + MatchType: subscription.MatchFilter, + TopicName: systemTopicA.TopicFilter, + }) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, got["client1"]) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, got["client2"]) + + // iterate all system subscriptions that matched given topic name and clientID + got = make(subscription.ClientSubscriptions) + store.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + got[clientID] = append(got[clientID], sub) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeSYS, + MatchType: subscription.MatchFilter, + TopicName: systemTopicA.TopicFilter, + ClientID: "client1", + }) + a.ElementsMatch([]*gmqtt.Subscription{systemTopicA}, got["client1"]) + a.Len(got["client2"], 0) +} diff --git a/internal/hummingbird/mqttbroker/persistence/unack/mem/mem.go b/internal/hummingbird/mqttbroker/persistence/unack/mem/mem.go new file mode 100644 index 0000000..3fa0eac --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/unack/mem/mem.go @@ -0,0 +1,44 @@ +package mem + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var _ unack.Store = (*Store)(nil) + +type Store struct { + clientID string + unackpublish map[packets.PacketID]struct{} +} + +type Options struct { + ClientID string +} + +func New(opts Options) *Store { + return &Store{ + clientID: opts.ClientID, + unackpublish: make(map[packets.PacketID]struct{}), + } +} + +func (s *Store) Init(cleanStart bool) error { + if cleanStart { + s.unackpublish = make(map[packets.PacketID]struct{}) + } + return nil +} + +func (s *Store) Set(id packets.PacketID) (bool, error) { + if _, ok := s.unackpublish[id]; ok { + return true, nil + } + s.unackpublish[id] = struct{}{} + return false, nil +} + +func (s *Store) Remove(id packets.PacketID) error { + delete(s.unackpublish, id) + return nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/unack/redis/redis.go b/internal/hummingbird/mqttbroker/persistence/unack/redis/redis.go new file mode 100644 index 0000000..08eb728 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/unack/redis/redis.go @@ -0,0 +1,75 @@ +package redis + +import ( + "github.com/gomodule/redigo/redis" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +const ( + unackPrefix = "unack:" +) + +var _ unack.Store = (*Store)(nil) + +type Store struct { + clientID string + pool *redis.Pool + unackpublish map[packets.PacketID]struct{} +} + +type Options struct { + ClientID string + Pool *redis.Pool +} + +func New(opts Options) *Store { + return &Store{ + clientID: opts.ClientID, + pool: opts.Pool, + unackpublish: make(map[packets.PacketID]struct{}), + } +} + +func getKey(clientID string) string { + return unackPrefix + clientID +} +func (s *Store) Init(cleanStart bool) error { + if cleanStart { + c := s.pool.Get() + defer c.Close() + s.unackpublish = make(map[packets.PacketID]struct{}) + _, err := c.Do("del", getKey(s.clientID)) + if err != nil { + return err + } + } + return nil +} + +func (s *Store) Set(id packets.PacketID) (bool, error) { + // from cache + if _, ok := s.unackpublish[id]; ok { + return true, nil + } + c := s.pool.Get() + defer c.Close() + _, err := c.Do("hset", getKey(s.clientID), id, 1) + if err != nil { + return false, err + } + s.unackpublish[id] = struct{}{} + return false, nil +} + +func (s *Store) Remove(id packets.PacketID) error { + c := s.pool.Get() + defer c.Close() + _, err := c.Do("hdel", getKey(s.clientID), id) + if err != nil { + return err + } + delete(s.unackpublish, id) + return nil +} diff --git a/internal/hummingbird/mqttbroker/persistence/unack/test/test_suite.go b/internal/hummingbird/mqttbroker/persistence/unack/test/test_suite.go new file mode 100644 index 0000000..732eff0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/unack/test/test_suite.go @@ -0,0 +1,54 @@ +package test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var ( + TestServerConfig = config.Config{} + cid = "cid" + TestClientID = cid +) + +func TestSuite(t *testing.T, store unack.Store) { + a := assert.New(t) + a.Nil(store.Init(false)) + for i := packets.PacketID(1); i < 10; i++ { + rs, err := store.Set(i) + a.Nil(err) + a.False(rs) + rs, err = store.Set(i) + a.Nil(err) + a.True(rs) + err = store.Remove(i) + a.Nil(err) + rs, err = store.Set(i) + a.Nil(err) + a.False(rs) + + } + a.Nil(store.Init(false)) + for i := packets.PacketID(1); i < 10; i++ { + rs, err := store.Set(i) + a.Nil(err) + a.True(rs) + err = store.Remove(i) + a.Nil(err) + rs, err = store.Set(i) + a.Nil(err) + a.False(rs) + } + a.Nil(store.Init(true)) + for i := packets.PacketID(1); i < 10; i++ { + rs, err := store.Set(i) + a.Nil(err) + a.False(rs) + } + +} diff --git a/internal/hummingbird/mqttbroker/persistence/unack/unack.go b/internal/hummingbird/mqttbroker/persistence/unack/unack.go new file mode 100644 index 0000000..0d29a2c --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/unack/unack.go @@ -0,0 +1,19 @@ +package unack + +import ( + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// Store represents a unack store for one client. +// Unack store is used to persist the unacknowledged qos2 messages. +type Store interface { + // Init will be called when the client connect. + // If cleanStart set to true, the implementation should remove any associated data in backend store. + // If it set to false, the implementation should retrieve the associated data from backend store. + Init(cleanStart bool) error + // Set sets the given id into store. + // The return boolean indicates whether the id exist. + Set(id packets.PacketID) (bool, error) + // Remove removes the given id from store. + Remove(id packets.PacketID) error +} diff --git a/internal/hummingbird/mqttbroker/persistence/unack/unack_mock.go b/internal/hummingbird/mqttbroker/persistence/unack/unack_mock.go new file mode 100644 index 0000000..3653be4 --- /dev/null +++ b/internal/hummingbird/mqttbroker/persistence/unack/unack_mock.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: persistence/unack/unack.go + +// Package unack is a generated GoMock package. +package unack + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + packets "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// MockStore is a mock of Store interface +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Init mocks base method +func (m *MockStore) Init(cleanStart bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", cleanStart) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init +func (mr *MockStoreMockRecorder) Init(cleanStart interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockStore)(nil).Init), cleanStart) +} + +// Set mocks base method +func (m *MockStore) Set(id packets.PacketID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", id) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Set indicates an expected call of Set +func (mr *MockStoreMockRecorder) Set(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockStore)(nil).Set), id) +} + +// Remove mocks base method +func (m *MockStore) Remove(id packets.PacketID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Remove", id) + ret0, _ := ret[0].(error) + return ret0 +} + +// Remove indicates an expected call of Remove +func (mr *MockStoreMockRecorder) Remove(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockStore)(nil).Remove), id) +} diff --git a/internal/hummingbird/mqttbroker/plugin/README.md b/internal/hummingbird/mqttbroker/plugin/README.md new file mode 100644 index 0000000..b792a23 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/README.md @@ -0,0 +1,39 @@ +# Plugin + +[Gmqtt插件机制详解](https://juejin.cn/post/6908305981923409934) + +## How to write plugins + +Gmqtt uses code generator to generate plugin template. + +First, install the CLI tool: + +```bash +# run under gmqtt project root directory. +go install ./cmd/gmqctl +``` + +Enjoy: + +```bash +$ gmqctl gen plugin --help +code generator + +Usage: + gmqctl gen plugin [flags] + +Examples: +The following command will generate a code template for the 'awesome' plugin, which makes use of OnBasicAuth and OnSubscribe hook and enables the configuration in ./plugin directory. + +gmqctl gen plugin -n awesome -H OnBasicAuth,OnSubscribe -c true -o ./plugin + +Flags: + -c, --config Whether the plugin needs a configuration. + -h, --help help for plugin + -H, --hooks string The hooks use by the plugin, multiple hooks are separated by ',' + -n, --name string The plugin name. + -o, --output string The output directory. + +``` + +Details...TODO diff --git a/internal/hummingbird/mqttbroker/plugin/admin/README.md b/internal/hummingbird/mqttbroker/plugin/admin/README.md new file mode 100644 index 0000000..d6fed8e --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/README.md @@ -0,0 +1,83 @@ +# admin + +Admin plugin use [grpc-gateway](https://github.com/grpc-ecosystem/grpc-gateway) to provide both REST HTTP and GRPC APIs +for integration with external systems. + +# API Doc + +See [swagger](https://github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/blob/master/plugin/admin/swagger) + +# Examples + +## List Clients + +```bash +$ curl 127.0.0.1:57091/v1/clients +``` + +Response: + +```json +{ + "clients": [ + { + "client_id": "ab", + "username": "", + "keep_alive": 60, + "version": 4, + "remote_addr": "127.0.0.1:51637", + "local_addr": "127.0.0.1:58090", + "connected_at": "2020-12-12T12:26:36Z", + "disconnected_at": null, + "session_expiry": 7200, + "max_inflight": 100, + "inflight_len": 0, + "max_queue": 100, + "queue_len": 0, + "subscriptions_current": 0, + "subscriptions_total": 0, + "packets_received_bytes": "54", + "packets_received_nums": "3", + "packets_send_bytes": "8", + "packets_send_nums": "2", + "message_dropped": "0" + } + ], + "total_count": 1 + } +``` + +## Filter Subscriptions + +```bash +$ curl 127.0.0.1:57091/v1/filter_subscriptions?filter_type=1,2,3&match_type=1&topic_name=/a +``` + +This curl is able to filter the subscription that the topic name is equal to "/a". + +Response: + +```json +{ + "subscriptions": [ + { + "topic_name": "/a", + "id": 0, + "qos": 1, + "no_local": false, + "retain_as_published": false, + "retain_handling": 0, + "client_id": "ab" + } + ] +} +``` + +## Publish Message + +```bash +$ curl -X POST 127.0.0.1:57091/v1/publish -d '{"topic_name":"a","payload":"test","qos":1}' +``` + +This curl will publish the message to the broker.The broker will check if there are matched topics and send the message +to the subscribers, just like received a message from a MQTT client. \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/plugin/admin/admin.go b/internal/hummingbird/mqttbroker/plugin/admin/admin.go new file mode 100644 index 0000000..5cbb9d1 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/admin.go @@ -0,0 +1,72 @@ +package admin + +import ( + "go.uber.org/zap" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +var _ server.Plugin = (*Admin)(nil) + +const Name = "admin" + +func init() { + server.RegisterPlugin(Name, New) +} + +func New(config config.Config) (server.Plugin, error) { + return &Admin{}, nil +} + +var log *zap.Logger + +// Admin providers gRPC and HTTP API that enables the external system to interact with the broker. +type Admin struct { + statsReader server.StatsReader + publisher server.Publisher + clientService server.ClientService + store *store +} + +func (a *Admin) registerHTTP(g server.APIRegistrar) (err error) { + err = g.RegisterHTTPHandler(RegisterClientServiceHandlerFromEndpoint) + if err != nil { + return err + } + err = g.RegisterHTTPHandler(RegisterSubscriptionServiceHandlerFromEndpoint) + if err != nil { + return err + } + err = g.RegisterHTTPHandler(RegisterPublishServiceHandlerFromEndpoint) + if err != nil { + return err + } + return nil +} + +func (a *Admin) Load(service server.Server) error { + log = server.LoggerWithField(zap.String("plugin", Name)) + apiRegistrar := service.APIRegistrar() + RegisterClientServiceServer(apiRegistrar, &clientService{a: a}) + RegisterSubscriptionServiceServer(apiRegistrar, &subscriptionService{a: a}) + RegisterPublishServiceServer(apiRegistrar, &publisher{a: a}) + err := a.registerHTTP(apiRegistrar) + if err != nil { + return err + } + a.statsReader = service.StatsManager() + a.store = newStore(a.statsReader, service.GetConfig()) + a.store.subscriptionService = service.SubscriptionService() + a.publisher = service.Publisher() + a.clientService = service.ClientService() + return nil +} + +func (a *Admin) Unload() error { + return nil +} + +func (a *Admin) Name() string { + return Name +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/client.go b/internal/hummingbird/mqttbroker/plugin/admin/client.go new file mode 100644 index 0000000..daae73a --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/client.go @@ -0,0 +1,58 @@ +package admin + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" +) + +type clientService struct { + a *Admin +} + +func (c *clientService) mustEmbedUnimplementedClientServiceServer() { + return +} + +// List lists alertclient information which the session is valid in the broker (both connected and disconnected). +func (c *clientService) List(ctx context.Context, req *ListClientRequest) (*ListClientResponse, error) { + page, pageSize := GetPage(req.Page, req.PageSize) + clients, total, err := c.a.store.GetClients(page, pageSize) + if err != nil { + return &ListClientResponse{}, err + } + return &ListClientResponse{ + Clients: clients, + TotalCount: total, + }, nil +} + +// Get returns the client information for given request client id. +func (c *clientService) Get(ctx context.Context, req *GetClientRequest) (*GetClientResponse, error) { + if req.ClientId == "" { + return nil, ErrInvalidArgument("client_id", "") + } + client := c.a.store.GetClientByID(req.ClientId) + if client == nil { + return nil, ErrNotFound + } + return &GetClientResponse{ + Client: client, + }, nil +} + +// Delete force disconnect. +func (c *clientService) Delete(ctx context.Context, req *DeleteClientRequest) (*empty.Empty, error) { + if req.ClientId == "" { + return nil, ErrInvalidArgument("client_id", "") + } + if req.CleanSession { + c.a.clientService.TerminateSession(req.ClientId) + } else { + client := c.a.clientService.GetClient(req.ClientId) + if client != nil { + client.Close() + } + } + return &empty.Empty{}, nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/client.pb.go b/internal/hummingbird/mqttbroker/plugin/admin/client.pb.go new file mode 100644 index 0000000..197f22c --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/client.pb.go @@ -0,0 +1,739 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.22.0 +// protoc v3.13.0 +// source: client.proto + +package admin + +import ( + reflect "reflect" + sync "sync" + + proto "github.com/golang/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + timestamp "github.com/golang/protobuf/ptypes/timestamp" + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type ListClientRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + PageSize uint32 `protobuf:"varint,1,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"` + Page uint32 `protobuf:"varint,2,opt,name=page,proto3" json:"page,omitempty"` +} + +func (x *ListClientRequest) Reset() { + *x = ListClientRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_client_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListClientRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListClientRequest) ProtoMessage() {} + +func (x *ListClientRequest) ProtoReflect() protoreflect.Message { + mi := &file_client_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListClientRequest.ProtoReflect.Descriptor instead. +func (*ListClientRequest) Descriptor() ([]byte, []int) { + return file_client_proto_rawDescGZIP(), []int{0} +} + +func (x *ListClientRequest) GetPageSize() uint32 { + if x != nil { + return x.PageSize + } + return 0 +} + +func (x *ListClientRequest) GetPage() uint32 { + if x != nil { + return x.Page + } + return 0 +} + +type ListClientResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Clients []*Client `protobuf:"bytes,1,rep,name=alertclient,proto3" json:"alertclient,omitempty"` + TotalCount uint32 `protobuf:"varint,2,opt,name=total_count,json=totalCount,proto3" json:"total_count,omitempty"` +} + +func (x *ListClientResponse) Reset() { + *x = ListClientResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_client_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListClientResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListClientResponse) ProtoMessage() {} + +func (x *ListClientResponse) ProtoReflect() protoreflect.Message { + mi := &file_client_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListClientResponse.ProtoReflect.Descriptor instead. +func (*ListClientResponse) Descriptor() ([]byte, []int) { + return file_client_proto_rawDescGZIP(), []int{1} +} + +func (x *ListClientResponse) GetClients() []*Client { + if x != nil { + return x.Clients + } + return nil +} + +func (x *ListClientResponse) GetTotalCount() uint32 { + if x != nil { + return x.TotalCount + } + return 0 +} + +type GetClientRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` +} + +func (x *GetClientRequest) Reset() { + *x = GetClientRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_client_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetClientRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetClientRequest) ProtoMessage() {} + +func (x *GetClientRequest) ProtoReflect() protoreflect.Message { + mi := &file_client_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetClientRequest.ProtoReflect.Descriptor instead. +func (*GetClientRequest) Descriptor() ([]byte, []int) { + return file_client_proto_rawDescGZIP(), []int{2} +} + +func (x *GetClientRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +type GetClientResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Client *Client `protobuf:"bytes,1,opt,name=client,proto3" json:"client,omitempty"` +} + +func (x *GetClientResponse) Reset() { + *x = GetClientResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_client_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetClientResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetClientResponse) ProtoMessage() {} + +func (x *GetClientResponse) ProtoReflect() protoreflect.Message { + mi := &file_client_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetClientResponse.ProtoReflect.Descriptor instead. +func (*GetClientResponse) Descriptor() ([]byte, []int) { + return file_client_proto_rawDescGZIP(), []int{3} +} + +func (x *GetClientResponse) GetClient() *Client { + if x != nil { + return x.Client + } + return nil +} + +type DeleteClientRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + CleanSession bool `protobuf:"varint,2,opt,name=clean_session,json=cleanSession,proto3" json:"clean_session,omitempty"` +} + +func (x *DeleteClientRequest) Reset() { + *x = DeleteClientRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_client_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeleteClientRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteClientRequest) ProtoMessage() {} + +func (x *DeleteClientRequest) ProtoReflect() protoreflect.Message { + mi := &file_client_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteClientRequest.ProtoReflect.Descriptor instead. +func (*DeleteClientRequest) Descriptor() ([]byte, []int) { + return file_client_proto_rawDescGZIP(), []int{4} +} + +func (x *DeleteClientRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *DeleteClientRequest) GetCleanSession() bool { + if x != nil { + return x.CleanSession + } + return false +} + +type Client struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"` + KeepAlive int32 `protobuf:"varint,3,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` + Version int32 `protobuf:"varint,4,opt,name=version,proto3" json:"version,omitempty"` + RemoteAddr string `protobuf:"bytes,5,opt,name=remote_addr,json=remoteAddr,proto3" json:"remote_addr,omitempty"` + LocalAddr string `protobuf:"bytes,6,opt,name=local_addr,json=localAddr,proto3" json:"local_addr,omitempty"` + ConnectedAt *timestamp.Timestamp `protobuf:"bytes,7,opt,name=connected_at,json=connectedAt,proto3" json:"connected_at,omitempty"` + DisconnectedAt *timestamp.Timestamp `protobuf:"bytes,8,opt,name=disconnected_at,json=disconnectedAt,proto3" json:"disconnected_at,omitempty"` + SessionExpiry uint32 `protobuf:"varint,9,opt,name=session_expiry,json=sessionExpiry,proto3" json:"session_expiry,omitempty"` + MaxInflight uint32 `protobuf:"varint,10,opt,name=max_inflight,json=maxInflight,proto3" json:"max_inflight,omitempty"` + InflightLen uint32 `protobuf:"varint,11,opt,name=inflight_len,json=inflightLen,proto3" json:"inflight_len,omitempty"` + MaxQueue uint32 `protobuf:"varint,12,opt,name=max_queue,json=maxQueue,proto3" json:"max_queue,omitempty"` + QueueLen uint32 `protobuf:"varint,13,opt,name=queue_len,json=queueLen,proto3" json:"queue_len,omitempty"` + SubscriptionsCurrent uint32 `protobuf:"varint,14,opt,name=subscriptions_current,json=subscriptionsCurrent,proto3" json:"subscriptions_current,omitempty"` + SubscriptionsTotal uint32 `protobuf:"varint,15,opt,name=subscriptions_total,json=subscriptionsTotal,proto3" json:"subscriptions_total,omitempty"` + PacketsReceivedBytes uint64 `protobuf:"varint,16,opt,name=packets_received_bytes,json=packetsReceivedBytes,proto3" json:"packets_received_bytes,omitempty"` + PacketsReceivedNums uint64 `protobuf:"varint,17,opt,name=packets_received_nums,json=packetsReceivedNums,proto3" json:"packets_received_nums,omitempty"` + PacketsSendBytes uint64 `protobuf:"varint,18,opt,name=packets_send_bytes,json=packetsSendBytes,proto3" json:"packets_send_bytes,omitempty"` + PacketsSendNums uint64 `protobuf:"varint,19,opt,name=packets_send_nums,json=packetsSendNums,proto3" json:"packets_send_nums,omitempty"` + MessageDropped uint64 `protobuf:"varint,20,opt,name=message_dropped,json=messageDropped,proto3" json:"message_dropped,omitempty"` +} + +func (x *Client) Reset() { + *x = Client{} + if protoimpl.UnsafeEnabled { + mi := &file_client_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Client) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Client) ProtoMessage() {} + +func (x *Client) ProtoReflect() protoreflect.Message { + mi := &file_client_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Client.ProtoReflect.Descriptor instead. +func (*Client) Descriptor() ([]byte, []int) { + return file_client_proto_rawDescGZIP(), []int{5} +} + +func (x *Client) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *Client) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *Client) GetKeepAlive() int32 { + if x != nil { + return x.KeepAlive + } + return 0 +} + +func (x *Client) GetVersion() int32 { + if x != nil { + return x.Version + } + return 0 +} + +func (x *Client) GetRemoteAddr() string { + if x != nil { + return x.RemoteAddr + } + return "" +} + +func (x *Client) GetLocalAddr() string { + if x != nil { + return x.LocalAddr + } + return "" +} + +func (x *Client) GetConnectedAt() *timestamp.Timestamp { + if x != nil { + return x.ConnectedAt + } + return nil +} + +func (x *Client) GetDisconnectedAt() *timestamp.Timestamp { + if x != nil { + return x.DisconnectedAt + } + return nil +} + +func (x *Client) GetSessionExpiry() uint32 { + if x != nil { + return x.SessionExpiry + } + return 0 +} + +func (x *Client) GetMaxInflight() uint32 { + if x != nil { + return x.MaxInflight + } + return 0 +} + +func (x *Client) GetInflightLen() uint32 { + if x != nil { + return x.InflightLen + } + return 0 +} + +func (x *Client) GetMaxQueue() uint32 { + if x != nil { + return x.MaxQueue + } + return 0 +} + +func (x *Client) GetQueueLen() uint32 { + if x != nil { + return x.QueueLen + } + return 0 +} + +func (x *Client) GetSubscriptionsCurrent() uint32 { + if x != nil { + return x.SubscriptionsCurrent + } + return 0 +} + +func (x *Client) GetSubscriptionsTotal() uint32 { + if x != nil { + return x.SubscriptionsTotal + } + return 0 +} + +func (x *Client) GetPacketsReceivedBytes() uint64 { + if x != nil { + return x.PacketsReceivedBytes + } + return 0 +} + +func (x *Client) GetPacketsReceivedNums() uint64 { + if x != nil { + return x.PacketsReceivedNums + } + return 0 +} + +func (x *Client) GetPacketsSendBytes() uint64 { + if x != nil { + return x.PacketsSendBytes + } + return 0 +} + +func (x *Client) GetPacketsSendNums() uint64 { + if x != nil { + return x.PacketsSendNums + } + return 0 +} + +func (x *Client) GetMessageDropped() uint64 { + if x != nil { + return x.MessageDropped + } + return 0 +} + +var File_client_proto protoreflect.FileDescriptor + +var file_client_proto_rawDesc = []byte{ + 0x0a, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0f, + 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x1a, + 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, + 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, + 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, + 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x44, 0x0a, 0x11, 0x4c, + 0x69, 0x73, 0x74, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x1b, 0x0a, 0x09, 0x70, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x61, 0x67, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, 0x0a, + 0x04, 0x70, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x61, 0x67, + 0x65, 0x22, 0x68, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x31, 0x0a, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, + 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x52, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x1f, 0x0a, 0x0b, 0x74, 0x6f, + 0x74, 0x61, 0x6c, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x0a, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x22, 0x2f, 0x0a, 0x10, 0x47, + 0x65, 0x74, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0x44, 0x0a, 0x11, + 0x47, 0x65, 0x74, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x06, 0x63, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x22, 0x57, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x23, 0x0a, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x5f, + 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x63, + 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0xb8, 0x06, 0x0a, 0x06, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x12, + 0x1d, 0x0a, 0x0a, 0x6b, 0x65, 0x65, 0x70, 0x5f, 0x61, 0x6c, 0x69, 0x76, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x41, 0x6c, 0x69, 0x76, 0x65, 0x12, 0x18, + 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x72, + 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x12, 0x1d, 0x0a, 0x0a, 0x6c, 0x6f, 0x63, + 0x61, 0x6c, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6c, + 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x64, 0x64, 0x72, 0x12, 0x3d, 0x0a, 0x0c, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0b, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x43, 0x0a, 0x0f, 0x64, 0x69, 0x73, 0x63, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0e, 0x64, 0x69, + 0x73, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x25, 0x0a, 0x0e, + 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x18, 0x09, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0d, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x45, 0x78, 0x70, + 0x69, 0x72, 0x79, 0x12, 0x21, 0x0a, 0x0c, 0x6d, 0x61, 0x78, 0x5f, 0x69, 0x6e, 0x66, 0x6c, 0x69, + 0x67, 0x68, 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x6d, 0x61, 0x78, 0x49, 0x6e, + 0x66, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x69, 0x6e, 0x66, 0x6c, 0x69, 0x67, + 0x68, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x69, 0x6e, + 0x66, 0x6c, 0x69, 0x67, 0x68, 0x74, 0x4c, 0x65, 0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x6d, 0x61, 0x78, + 0x5f, 0x71, 0x75, 0x65, 0x75, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x6d, 0x61, + 0x78, 0x51, 0x75, 0x65, 0x75, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x71, 0x75, 0x65, 0x75, 0x65, 0x5f, + 0x6c, 0x65, 0x6e, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x71, 0x75, 0x65, 0x75, 0x65, + 0x4c, 0x65, 0x6e, 0x12, 0x33, 0x0a, 0x15, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x5f, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x18, 0x0e, 0x20, 0x01, + 0x28, 0x0d, 0x52, 0x14, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x43, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x74, 0x12, 0x2f, 0x0a, 0x13, 0x73, 0x75, 0x62, 0x73, + 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x5f, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, + 0x0f, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x12, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x54, 0x6f, 0x74, 0x61, 0x6c, 0x12, 0x34, 0x0a, 0x16, 0x70, 0x61, 0x63, + 0x6b, 0x65, 0x74, 0x73, 0x5f, 0x72, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x5f, 0x62, 0x79, + 0x74, 0x65, 0x73, 0x18, 0x10, 0x20, 0x01, 0x28, 0x04, 0x52, 0x14, 0x70, 0x61, 0x63, 0x6b, 0x65, + 0x74, 0x73, 0x52, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, + 0x32, 0x0a, 0x15, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x5f, 0x72, 0x65, 0x63, 0x65, 0x69, + 0x76, 0x65, 0x64, 0x5f, 0x6e, 0x75, 0x6d, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x04, 0x52, 0x13, + 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x52, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x4e, + 0x75, 0x6d, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x5f, 0x73, + 0x65, 0x6e, 0x64, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x10, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x53, 0x65, 0x6e, 0x64, 0x42, 0x79, 0x74, 0x65, + 0x73, 0x12, 0x2a, 0x0a, 0x11, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x5f, 0x73, 0x65, 0x6e, + 0x64, 0x5f, 0x6e, 0x75, 0x6d, 0x73, 0x18, 0x13, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0f, 0x70, 0x61, + 0x63, 0x6b, 0x65, 0x74, 0x73, 0x53, 0x65, 0x6e, 0x64, 0x4e, 0x75, 0x6d, 0x73, 0x12, 0x27, 0x0a, + 0x0f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, + 0x18, 0x14, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x44, + 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x32, 0xcd, 0x02, 0x0a, 0x0d, 0x43, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x64, 0x0a, 0x04, 0x4c, 0x69, 0x73, 0x74, + 0x12, 0x22, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, + 0x70, 0x69, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x23, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, + 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x43, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x13, 0x82, 0xd3, 0xe4, 0x93, 0x02, + 0x0d, 0x12, 0x0b, 0x2f, 0x76, 0x31, 0x2f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x6d, + 0x0a, 0x03, 0x47, 0x65, 0x74, 0x12, 0x21, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, + 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, + 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1f, 0x82, 0xd3, + 0xe4, 0x93, 0x02, 0x19, 0x12, 0x17, 0x2f, 0x76, 0x31, 0x2f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x73, 0x2f, 0x7b, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x7d, 0x12, 0x67, 0x0a, + 0x06, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x12, 0x24, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, + 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x1f, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x19, 0x2a, 0x17, 0x2f, + 0x76, 0x31, 0x2f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x7b, 0x63, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x7d, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x3b, 0x61, 0x64, 0x6d, 0x69, + 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_client_proto_rawDescOnce sync.Once + file_client_proto_rawDescData = file_client_proto_rawDesc +) + +func file_client_proto_rawDescGZIP() []byte { + file_client_proto_rawDescOnce.Do(func() { + file_client_proto_rawDescData = protoimpl.X.CompressGZIP(file_client_proto_rawDescData) + }) + return file_client_proto_rawDescData +} + +var file_client_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_client_proto_goTypes = []interface{}{ + (*ListClientRequest)(nil), // 0: gmqtt.admin.api.ListClientRequest + (*ListClientResponse)(nil), // 1: gmqtt.admin.api.ListClientResponse + (*GetClientRequest)(nil), // 2: gmqtt.admin.api.GetClientRequest + (*GetClientResponse)(nil), // 3: gmqtt.admin.api.GetClientResponse + (*DeleteClientRequest)(nil), // 4: gmqtt.admin.api.DeleteClientRequest + (*Client)(nil), // 5: gmqtt.admin.api.Client + (*timestamp.Timestamp)(nil), // 6: google.protobuf.Timestamp + (*empty.Empty)(nil), // 7: google.protobuf.Empty +} +var file_client_proto_depIdxs = []int32{ + 5, // 0: gmqtt.admin.api.ListClientResponse.alertclient:type_name -> gmqtt.admin.api.Client + 5, // 1: gmqtt.admin.api.GetClientResponse.client:type_name -> gmqtt.admin.api.Client + 6, // 2: gmqtt.admin.api.Client.connected_at:type_name -> google.protobuf.Timestamp + 6, // 3: gmqtt.admin.api.Client.disconnected_at:type_name -> google.protobuf.Timestamp + 0, // 4: gmqtt.admin.api.ClientService.List:input_type -> gmqtt.admin.api.ListClientRequest + 2, // 5: gmqtt.admin.api.ClientService.Get:input_type -> gmqtt.admin.api.GetClientRequest + 4, // 6: gmqtt.admin.api.ClientService.Delete:input_type -> gmqtt.admin.api.DeleteClientRequest + 1, // 7: gmqtt.admin.api.ClientService.List:output_type -> gmqtt.admin.api.ListClientResponse + 3, // 8: gmqtt.admin.api.ClientService.Get:output_type -> gmqtt.admin.api.GetClientResponse + 7, // 9: gmqtt.admin.api.ClientService.Delete:output_type -> google.protobuf.Empty + 7, // [7:10] is the sub-list for method output_type + 4, // [4:7] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name +} + +func init() { file_client_proto_init() } +func file_client_proto_init() { + if File_client_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_client_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListClientRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_client_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListClientResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_client_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetClientRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_client_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetClientResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_client_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeleteClientRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_client_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Client); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_client_proto_rawDesc, + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_client_proto_goTypes, + DependencyIndexes: file_client_proto_depIdxs, + MessageInfos: file_client_proto_msgTypes, + }.Build() + File_client_proto = out.File + file_client_proto_rawDesc = nil + file_client_proto_goTypes = nil + file_client_proto_depIdxs = nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/client.pb.gw.go b/internal/hummingbird/mqttbroker/plugin/admin/client.pb.gw.go new file mode 100644 index 0000000..f547a16 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/client.pb.gw.go @@ -0,0 +1,373 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: client.proto + +/* +Package admin is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package admin + +import ( + "context" + "io" + "net/http" + + "github.com/golang/protobuf/descriptor" + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" +) + +// Suppress "imported and not used" errors +var _ codes.Code +var _ io.Reader +var _ status.Status +var _ = runtime.String +var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage + +var ( + filter_ClientService_List_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +) + +func request_ClientService_List_0(ctx context.Context, marshaler runtime.Marshaler, client ClientServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ListClientRequest + var metadata runtime.ServerMetadata + + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_ClientService_List_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.List(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_ClientService_List_0(ctx context.Context, marshaler runtime.Marshaler, server ClientServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ListClientRequest + var metadata runtime.ServerMetadata + + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_ClientService_List_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.List(ctx, &protoReq) + return msg, metadata, err + +} + +func request_ClientService_Get_0(ctx context.Context, marshaler runtime.Marshaler, client ClientServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq GetClientRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["client_id"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "client_id") + } + + protoReq.ClientId, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "client_id", err) + } + + msg, err := client.Get(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_ClientService_Get_0(ctx context.Context, marshaler runtime.Marshaler, server ClientServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq GetClientRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["client_id"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "client_id") + } + + protoReq.ClientId, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "client_id", err) + } + + msg, err := server.Get(ctx, &protoReq) + return msg, metadata, err + +} + +var ( + filter_ClientService_Delete_0 = &utilities.DoubleArray{Encoding: map[string]int{"client_id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} +) + +func request_ClientService_Delete_0(ctx context.Context, marshaler runtime.Marshaler, client ClientServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq DeleteClientRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["client_id"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "client_id") + } + + protoReq.ClientId, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "client_id", err) + } + + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_ClientService_Delete_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Delete(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_ClientService_Delete_0(ctx context.Context, marshaler runtime.Marshaler, server ClientServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq DeleteClientRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["client_id"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "client_id") + } + + protoReq.ClientId, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "client_id", err) + } + + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_ClientService_Delete_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Delete(ctx, &protoReq) + return msg, metadata, err + +} + +// RegisterClientServiceHandlerServer registers the http handlers for service ClientService to "mux". +// UnaryRPC :call ClientServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +func RegisterClientServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server ClientServiceServer) error { + + mux.Handle("GET", pattern_ClientService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_ClientService_List_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_ClientService_List_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_ClientService_Get_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_ClientService_Get_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_ClientService_Get_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("DELETE", pattern_ClientService_Delete_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_ClientService_Delete_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_ClientService_Delete_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +// RegisterClientServiceHandlerFromEndpoint is same as RegisterClientServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterClientServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterClientServiceHandler(ctx, mux, conn) +} + +// RegisterClientServiceHandler registers the http handlers for service ClientService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterClientServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterClientServiceHandlerClient(ctx, mux, NewClientServiceClient(conn)) +} + +// RegisterClientServiceHandlerClient registers the http handlers for service ClientService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "ClientServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "ClientServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "ClientServiceClient" to call the correct interceptors. +func RegisterClientServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client ClientServiceClient) error { + + mux.Handle("GET", pattern_ClientService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_ClientService_List_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_ClientService_List_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_ClientService_Get_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_ClientService_Get_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_ClientService_Get_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("DELETE", pattern_ClientService_Delete_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_ClientService_Delete_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_ClientService_Delete_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_ClientService_List_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "alertclient"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_ClientService_Get_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 1, 0, 4, 1, 5, 2}, []string{"v1", "alertclient", "client_id"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_ClientService_Delete_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 1, 0, 4, 1, 5, 2}, []string{"v1", "alertclient", "client_id"}, "", runtime.AssumeColonVerbOpt(true))) +) + +var ( + forward_ClientService_List_0 = runtime.ForwardResponseMessage + + forward_ClientService_Get_0 = runtime.ForwardResponseMessage + + forward_ClientService_Delete_0 = runtime.ForwardResponseMessage +) diff --git a/internal/hummingbird/mqttbroker/plugin/admin/client_grpc.pb.go b/internal/hummingbird/mqttbroker/plugin/admin/client_grpc.pb.go new file mode 100644 index 0000000..134d130 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/client_grpc.pb.go @@ -0,0 +1,179 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package admin + +import ( + context "context" + + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion7 + +// ClientServiceClient is the client API for ClientService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ClientServiceClient interface { + // List alertclient + List(ctx context.Context, in *ListClientRequest, opts ...grpc.CallOption) (*ListClientResponse, error) + // Get the client for given client id. + // Return NotFound error when client not found. + Get(ctx context.Context, in *GetClientRequest, opts ...grpc.CallOption) (*GetClientResponse, error) + // Disconnect the client for given client id. + Delete(ctx context.Context, in *DeleteClientRequest, opts ...grpc.CallOption) (*empty.Empty, error) +} + +type clientServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewClientServiceClient(cc grpc.ClientConnInterface) ClientServiceClient { + return &clientServiceClient{cc} +} + +func (c *clientServiceClient) List(ctx context.Context, in *ListClientRequest, opts ...grpc.CallOption) (*ListClientResponse, error) { + out := new(ListClientResponse) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.ClientService/List", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *clientServiceClient) Get(ctx context.Context, in *GetClientRequest, opts ...grpc.CallOption) (*GetClientResponse, error) { + out := new(GetClientResponse) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.ClientService/Get", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *clientServiceClient) Delete(ctx context.Context, in *DeleteClientRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.ClientService/Delete", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ClientServiceServer is the server API for ClientService service. +// All implementations must embed UnimplementedClientServiceServer +// for forward compatibility +type ClientServiceServer interface { + // List alertclient + List(context.Context, *ListClientRequest) (*ListClientResponse, error) + // Get the client for given client id. + // Return NotFound error when client not found. + Get(context.Context, *GetClientRequest) (*GetClientResponse, error) + // Disconnect the client for given client id. + Delete(context.Context, *DeleteClientRequest) (*empty.Empty, error) + mustEmbedUnimplementedClientServiceServer() +} + +// UnimplementedClientServiceServer must be embedded to have forward compatible implementations. +type UnimplementedClientServiceServer struct { +} + +func (UnimplementedClientServiceServer) List(context.Context, *ListClientRequest) (*ListClientResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method List not implemented") +} +func (UnimplementedClientServiceServer) Get(context.Context, *GetClientRequest) (*GetClientResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Get not implemented") +} +func (UnimplementedClientServiceServer) Delete(context.Context, *DeleteClientRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Delete not implemented") +} +func (UnimplementedClientServiceServer) mustEmbedUnimplementedClientServiceServer() {} + +// UnsafeClientServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ClientServiceServer will +// result in compilation errors. +type UnsafeClientServiceServer interface { + mustEmbedUnimplementedClientServiceServer() +} + +func RegisterClientServiceServer(s grpc.ServiceRegistrar, srv ClientServiceServer) { + s.RegisterService(&_ClientService_serviceDesc, srv) +} + +func _ClientService_List_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListClientRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ClientServiceServer).List(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.ClientService/List", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ClientServiceServer).List(ctx, req.(*ListClientRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ClientService_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetClientRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ClientServiceServer).Get(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.ClientService/Get", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ClientServiceServer).Get(ctx, req.(*GetClientRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ClientService_Delete_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteClientRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ClientServiceServer).Delete(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.ClientService/Delete", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ClientServiceServer).Delete(ctx, req.(*DeleteClientRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _ClientService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "gmqtt.admin.api.ClientService", + HandlerType: (*ClientServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "List", + Handler: _ClientService_List_Handler, + }, + { + MethodName: "Get", + Handler: _ClientService_Get_Handler, + }, + { + MethodName: "Delete", + Handler: _ClientService_Delete_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "client.proto", +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/client_test.go b/internal/hummingbird/mqttbroker/plugin/admin/client_test.go new file mode 100644 index 0000000..7c77000 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/client_test.go @@ -0,0 +1,179 @@ +package admin + +import ( + "context" + "net" + "strconv" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var mockConfig = config.Config{ + MQTT: config.MQTT{ + MaxQueuedMsg: 10, + }, +} + +type dummyConn struct { + net.Conn +} + +// LocalAddr returns the local network address. +func (d *dummyConn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} +func (d *dummyConn) RemoteAddr() net.Addr { + return &net.TCPAddr{} +} + +func TestClientService_List_Get(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cs := server.NewMockClientService(ctrl) + sr := server.NewMockStatsReader(ctrl) + + admin := &Admin{ + statsReader: sr, + clientService: cs, + store: newStore(sr, mockConfig), + } + c := &clientService{ + a: admin, + } + now := time.Now() + client := server.NewMockClient(ctrl) + client.EXPECT().Version().Return(packets.Version5).AnyTimes() + client.EXPECT().Connection().Return(&dummyConn{}).AnyTimes() + client.EXPECT().ConnectedAt().Return(now).AnyTimes() + created := admin.OnSessionCreatedWrapper(func(ctx context.Context, client server.Client) {}) + for i := 0; i < 10; i++ { + sr.EXPECT().GetClientStats(strconv.Itoa(i)).AnyTimes() + client.EXPECT().ClientOptions().Return(&server.ClientOptions{ + ClientID: strconv.Itoa(i), + Username: strconv.Itoa(i), + KeepAlive: uint16(i), + SessionExpiry: uint32(i), + MaxInflight: uint16(i), + ReceiveMax: uint16(i), + ClientMaxPacketSize: uint32(i), + ServerMaxPacketSize: uint32(i), + ClientTopicAliasMax: uint16(i), + ServerTopicAliasMax: uint16(i), + RequestProblemInfo: true, + UserProperties: []*packets.UserProperty{ + { + K: []byte{1, 2}, + V: []byte{1, 2}, + }, + }, + RetainAvailable: true, + WildcardSubAvailable: true, + SubIDAvailable: true, + SharedSubAvailable: true, + }) + created(context.Background(), client) + } + + resp, err := c.List(context.Background(), &ListClientRequest{ + PageSize: 0, + Page: 0, + }) + a.Nil(err) + a.Len(resp.Clients, 10) + for k, v := range resp.Clients { + addr := net.TCPAddr{} + a.Equal(&Client{ + ClientId: strconv.Itoa(k), + Username: strconv.Itoa(k), + KeepAlive: int32(k), + Version: int32(packets.Version5), + RemoteAddr: addr.String(), + LocalAddr: addr.String(), + ConnectedAt: timestamppb.New(now), + DisconnectedAt: nil, + SessionExpiry: uint32(k), + MaxInflight: uint32(k), + MaxQueue: uint32(mockConfig.MQTT.MaxQueuedMsg), + PacketsReceivedBytes: 0, + PacketsReceivedNums: 0, + PacketsSendBytes: 0, + PacketsSendNums: 0, + MessageDropped: 0, + }, v) + } + + getResp, err := c.Get(context.Background(), &GetClientRequest{ + ClientId: "1", + }) + a.Nil(err) + a.Equal(resp.Clients[1], getResp.Client) + + pagingResp, err := c.List(context.Background(), &ListClientRequest{ + PageSize: 2, + Page: 2, + }) + a.Nil(err) + a.Len(pagingResp.Clients, 2) + a.Equal(resp.Clients[2], pagingResp.Clients[0]) + a.Equal(resp.Clients[3], pagingResp.Clients[1]) +} + +func TestClientService_Delete(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cs := server.NewMockClientService(ctrl) + sr := server.NewMockStatsReader(ctrl) + + admin := &Admin{ + statsReader: sr, + clientService: cs, + store: newStore(sr, mockConfig), + } + c := &clientService{ + a: admin, + } + client := server.NewMockClient(ctrl) + client.EXPECT().Close() + cs.EXPECT().GetClient("1").Return(client) + _, err := c.Delete(context.Background(), &DeleteClientRequest{ + ClientId: "1", + CleanSession: false, + }) + a.Nil(err) +} + +func TestClientService_Delete_CleanSession(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cs := server.NewMockClientService(ctrl) + sr := server.NewMockStatsReader(ctrl) + + admin := &Admin{ + statsReader: sr, + clientService: cs, + store: newStore(sr, mockConfig), + } + c := &clientService{ + a: admin, + } + cs.EXPECT().TerminateSession("1") + _, err := c.Delete(context.Background(), &DeleteClientRequest{ + ClientId: "1", + CleanSession: true, + }) + a.Nil(err) +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/config.go b/internal/hummingbird/mqttbroker/plugin/admin/config.go new file mode 100644 index 0000000..f27d19e --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/config.go @@ -0,0 +1,78 @@ +package admin + +import ( + "errors" + "net" +) + +// Config is the configuration for the admin plugin. +type Config struct { + HTTP HTTPConfig `yaml:"http"` + GRPC GRPCConfig `yaml:"grpc"` +} + +// HTTPConfig is the configuration for http endpoint. +type HTTPConfig struct { + // Enable indicates whether to expose http endpoint. + Enable bool `yaml:"enable"` + // Addr is the address that the http server listen on. + Addr string `yaml:"http_addr"` +} + +// GRPCConfig is the configuration for gRPC endpoint. +type GRPCConfig struct { + // Addr is the address that the gRPC server listen on. + Addr string `yaml:"http_addr"` +} + +// Validate validates the configuration, and return an error if it is invalid. +func (c *Config) Validate() error { + if c.HTTP.Enable { + _, _, err := net.SplitHostPort(c.HTTP.Addr) + if err != nil { + return errors.New("invalid http_addr") + } + } + _, _, err := net.SplitHostPort(c.GRPC.Addr) + if err != nil { + return errors.New("invalid grpc_addr") + } + return nil +} + +// DefaultConfig is the default configuration. +var DefaultConfig = Config{ + HTTP: HTTPConfig{ + Enable: true, + Addr: "127.0.0.1:57091", + }, + GRPC: GRPCConfig{ + Addr: "unix://./mqttd.sock", + }, +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type cfg Config + var v = &struct { + Admin cfg `yaml:"admin"` + }{ + Admin: cfg(DefaultConfig), + } + if err := unmarshal(v); err != nil { + return err + } + emptyGRPC := GRPCConfig{} + if v.Admin.GRPC == emptyGRPC { + v.Admin.GRPC = DefaultConfig.GRPC + } + emptyHTTP := HTTPConfig{} + if v.Admin.HTTP == emptyHTTP { + v.Admin.HTTP = DefaultConfig.HTTP + } + empty := cfg(Config{}) + if v.Admin == empty { + v.Admin = cfg(DefaultConfig) + } + *c = Config(v.Admin) + return nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/hooks.go b/internal/hummingbird/mqttbroker/plugin/admin/hooks.go new file mode 100644 index 0000000..6297c3f --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/hooks.go @@ -0,0 +1,61 @@ +package admin + +import ( + "context" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +func (a *Admin) HookWrapper() server.HookWrapper { + return server.HookWrapper{ + OnSessionCreatedWrapper: a.OnSessionCreatedWrapper, + OnSessionResumedWrapper: a.OnSessionResumedWrapper, + OnClosedWrapper: a.OnClosedWrapper, + OnSessionTerminatedWrapper: a.OnSessionTerminatedWrapper, + OnSubscribedWrapper: a.OnSubscribedWrapper, + OnUnsubscribedWrapper: a.OnUnsubscribedWrapper, + } +} + +func (a *Admin) OnSessionCreatedWrapper(pre server.OnSessionCreated) server.OnSessionCreated { + return func(ctx context.Context, client server.Client) { + pre(ctx, client) + a.store.addClient(client) + } +} + +func (a *Admin) OnSessionResumedWrapper(pre server.OnSessionResumed) server.OnSessionResumed { + return func(ctx context.Context, client server.Client) { + pre(ctx, client) + a.store.addClient(client) + } +} + +func (a *Admin) OnClosedWrapper(pre server.OnClosed) server.OnClosed { + return func(ctx context.Context, client server.Client, err error) { + pre(ctx, client, err) + a.store.setClientDisconnected(client.ClientOptions().ClientID) + } +} + +func (a *Admin) OnSessionTerminatedWrapper(pre server.OnSessionTerminated) server.OnSessionTerminated { + return func(ctx context.Context, clientID string, reason server.SessionTerminatedReason) { + pre(ctx, clientID, reason) + a.store.removeClient(clientID) + } +} + +func (a *Admin) OnSubscribedWrapper(pre server.OnSubscribed) server.OnSubscribed { + return func(ctx context.Context, client server.Client, subscription *gmqtt.Subscription) { + pre(ctx, client, subscription) + a.store.addSubscription(client.ClientOptions().ClientID, subscription) + } +} + +func (a *Admin) OnUnsubscribedWrapper(pre server.OnUnsubscribed) server.OnUnsubscribed { + return func(ctx context.Context, client server.Client, topicName string) { + pre(ctx, client, topicName) + a.store.removeSubscription(client.ClientOptions().ClientID, topicName) + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/protos/client.proto b/internal/hummingbird/mqttbroker/plugin/admin/protos/client.proto new file mode 100644 index 0000000..e306ff1 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/protos/client.proto @@ -0,0 +1,78 @@ +syntax = "proto3"; + +package gmqtt.admin.api; +option go_package = ".;admin"; + +import "google/api/annotations.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + +message ListClientRequest { + uint32 page_size = 1; + uint32 page = 2; +} + +message ListClientResponse { + repeated Client clients = 1; + uint32 total_count = 2; +} + +message GetClientRequest { + string client_id = 1; +} + +message GetClientResponse { + Client client = 1; +} + + +message DeleteClientRequest { + string client_id = 1; + bool clean_session = 2; +} + +message Client { + string client_id = 1; + string username = 2; + int32 keep_alive = 3; + int32 version = 4; + string remote_addr = 5; + string local_addr = 6; + google.protobuf.Timestamp connected_at = 7; + google.protobuf.Timestamp disconnected_at = 8; + uint32 session_expiry = 9; + uint32 max_inflight = 10; + uint32 inflight_len = 11; + uint32 max_queue = 12; + uint32 queue_len = 13; + uint32 subscriptions_current = 14; + uint32 subscriptions_total = 15; + uint64 packets_received_bytes = 16; + uint64 packets_received_nums = 17; + uint64 packets_send_bytes = 18; + uint64 packets_send_nums = 19; + uint64 message_dropped = 20; +} + + +service ClientService { + // List clients + rpc List (ListClientRequest) returns (ListClientResponse){ + option (google.api.http) = { + get: "/v1/clients" + }; + } + // Get the client for given client id. + // Return NotFound error when client not found. + rpc Get (GetClientRequest) returns (GetClientResponse){ + option (google.api.http) = { + get: "/v1/clients/{client_id}" + }; + } + // Disconnect the client for given client id. + rpc Delete (DeleteClientRequest) returns (google.protobuf.Empty) { + option (google.api.http) = { + delete: "/v1/clients/{client_id}" + }; + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/protos/proto_gen.sh b/internal/hummingbird/mqttbroker/plugin/admin/protos/proto_gen.sh new file mode 100755 index 0000000..261d433 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/protos/proto_gen.sh @@ -0,0 +1,8 @@ +protoc -I. \ +-I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway \ +-I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ +--go-grpc_out=../ \ +--go_out=../ \ +--grpc-gateway_out=../ \ +--swagger_out=../swagger \ +*.proto \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/plugin/admin/protos/publish.proto b/internal/hummingbird/mqttbroker/plugin/admin/protos/publish.proto new file mode 100644 index 0000000..0fd6674 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/protos/publish.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; + +package gmqtt.admin.api; +option go_package = ".;admin"; + +import "google/api/annotations.proto"; +import "google/protobuf/empty.proto"; + +message PublishRequest { + string topic_name = 1; + string payload = 2; + uint32 qos = 3; + bool retained = 4; + // the following fields are using in v5 client. + string content_type = 5; + string correlation_data = 6; + uint32 message_expiry = 7; + uint32 payload_format = 8; + string response_topic = 9; + repeated UserProperties user_properties = 10; +} + +message UserProperties { + bytes K = 1; + bytes V = 2; +} + +service PublishService { + // Publish message to broker + rpc Publish (PublishRequest) returns (google.protobuf.Empty){ + option (google.api.http) = { + post: "/v1/publish" + body:"*" + }; + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/protos/subscription.proto b/internal/hummingbird/mqttbroker/plugin/admin/protos/subscription.proto new file mode 100644 index 0000000..cb4c609 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/protos/subscription.proto @@ -0,0 +1,104 @@ +syntax = "proto3"; + +package gmqtt.admin.api; +option go_package = ".;admin"; + +import "google/api/annotations.proto"; +import "google/protobuf/empty.proto"; + +enum SubFilterType { + SUB_FILTER_TYPE_SYS_UNSPECIFIED = 0; + SUB_FILTER_TYPE_SYS = 1; + SUB_FILTER_TYPE_SHARED = 2; + SUB_FILTER_TYPE_NON_SHARED = 3; +} +enum SubMatchType { + SUB_MATCH_TYPE_MATCH_UNSPECIFIED = 0; + SUB_MATCH_TYPE_MATCH_NAME = 1; + SUB_MATCH_TYPE_MATCH_FILTER = 2; +} + +message ListSubscriptionRequest { + uint32 page_size = 1; + uint32 page = 2; +} + +message ListSubscriptionResponse { + repeated Subscription subscriptions = 1; + uint32 total_count = 2; +} +message FilterSubscriptionRequest { + // If set, only filter the subscriptions that belongs to the client. + string client_id = 1; + // filter_type indicates what kinds of topics are going to filter. + // If there are multiple types, use ',' to separate. e.g : 1,2 + // There are 3 kinds of topic can be filtered, defined by SubFilterType: + // 1 = System Topic(begin with '$') + // 2 = Shared Topic + // 3 = NonShared Topic + string filter_type = 2; + // If 1 (SUB_MATCH_TYPE_MATCH_NAME), the server will return subscriptions which has the same topic name with request topic_name. + // If 2 (SUB_MATCH_TYPE_MATCH_FILTER),the server will return subscriptions which match the request topic_name . + // match_type must be set when filter_type is not empty. + SubMatchType match_type = 3; + // topic_name must be set when match_type is not zero. + string topic_name = 4; + // The maximum subscriptions can be returned. + int32 limit = 5; +} +message FilterSubscriptionResponse { + repeated Subscription subscriptions = 1; +} + +message SubscribeRequest { + string client_id = 1; + repeated Subscription subscriptions = 2; +} + +message SubscribeResponse { + // indicates whether it is a new subscription or the subscription is already existed. + repeated bool new = 1; +} + +message UnsubscribeRequest { + string client_id = 1; + repeated string topics = 2; +} + +message Subscription { + string topic_name = 1; + uint32 id = 2; + uint32 qos = 3; + bool no_local = 4; + bool retain_as_published = 5; + uint32 retain_handling = 6; + string client_id = 7; +} +service SubscriptionService { + // List subscriptions. + rpc List (ListSubscriptionRequest) returns (ListSubscriptionResponse){ + option (google.api.http) = { + get: "/v1/subscriptions" + }; + } + // Filter subscriptions, paging is not supported in this API. + rpc Filter(FilterSubscriptionRequest) returns (FilterSubscriptionResponse) { + option (google.api.http) = { + get: "/v1/filter_subscriptions" + }; + } + // Subscribe topics for the client. + rpc Subscribe (SubscribeRequest) returns (SubscribeResponse) { + option (google.api.http) = { + post: "/v1/subscribe" + body:"*" + }; + } + // Unsubscribe topics for the client. + rpc Unsubscribe (UnsubscribeRequest) returns (google.protobuf.Empty) { + option (google.api.http) = { + post: "/v1/unsubscribe" + body:"*" + }; + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/publish.go b/internal/hummingbird/mqttbroker/plugin/admin/publish.go new file mode 100644 index 0000000..75a742f --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/publish.go @@ -0,0 +1,56 @@ +package admin + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +type publisher struct { + a *Admin +} + +func (p *publisher) mustEmbedUnimplementedPublishServiceServer() { + return +} + +// Publish publishes a message into broker. +func (p *publisher) Publish(ctx context.Context, req *PublishRequest) (resp *empty.Empty, err error) { + if !packets.ValidV5Topic([]byte(req.TopicName)) { + return nil, ErrInvalidArgument("topic_name", "") + } + if req.Qos > uint32(packets.Qos2) { + return nil, ErrInvalidArgument("qos", "") + } + if req.PayloadFormat != 0 && req.PayloadFormat != 1 { + return nil, ErrInvalidArgument("payload_format", "") + } + if req.ResponseTopic != "" && !packets.ValidV5Topic([]byte(req.ResponseTopic)) { + return nil, ErrInvalidArgument("response_topic", "") + } + var userPpt []packets.UserProperty + for _, v := range req.UserProperties { + userPpt = append(userPpt, packets.UserProperty{ + K: v.K, + V: v.V, + }) + } + + p.a.publisher.Publish(&mqttbroker.Message{ + Dup: false, + QoS: byte(req.Qos), + Retained: req.Retained, + Topic: req.TopicName, + Payload: []byte(req.Payload), + ContentType: req.ContentType, + CorrelationData: []byte(req.CorrelationData), + MessageExpiry: req.MessageExpiry, + PayloadFormat: packets.PayloadFormat(req.PayloadFormat), + ResponseTopic: req.ResponseTopic, + UserProperties: userPpt, + }) + return &empty.Empty{}, nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/publish.pb.go b/internal/hummingbird/mqttbroker/plugin/admin/publish.pb.go new file mode 100644 index 0000000..ec19475 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/publish.pb.go @@ -0,0 +1,331 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.22.0 +// protoc v3.13.0 +// source: publish.proto + +package admin + +import ( + reflect "reflect" + sync "sync" + + proto "github.com/golang/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type PublishRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TopicName string `protobuf:"bytes,1,opt,name=topic_name,json=topicName,proto3" json:"topic_name,omitempty"` + Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` + Qos uint32 `protobuf:"varint,3,opt,name=qos,proto3" json:"qos,omitempty"` + Retained bool `protobuf:"varint,4,opt,name=retained,proto3" json:"retained,omitempty"` + // the following fields are using in v5 client. + ContentType string `protobuf:"bytes,5,opt,name=content_type,json=contentType,proto3" json:"content_type,omitempty"` + CorrelationData string `protobuf:"bytes,6,opt,name=correlation_data,json=correlationData,proto3" json:"correlation_data,omitempty"` + MessageExpiry uint32 `protobuf:"varint,7,opt,name=message_expiry,json=messageExpiry,proto3" json:"message_expiry,omitempty"` + PayloadFormat uint32 `protobuf:"varint,8,opt,name=payload_format,json=payloadFormat,proto3" json:"payload_format,omitempty"` + ResponseTopic string `protobuf:"bytes,9,opt,name=response_topic,json=responseTopic,proto3" json:"response_topic,omitempty"` + UserProperties []*UserProperties `protobuf:"bytes,10,rep,name=user_properties,json=userProperties,proto3" json:"user_properties,omitempty"` +} + +func (x *PublishRequest) Reset() { + *x = PublishRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_publish_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PublishRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PublishRequest) ProtoMessage() {} + +func (x *PublishRequest) ProtoReflect() protoreflect.Message { + mi := &file_publish_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PublishRequest.ProtoReflect.Descriptor instead. +func (*PublishRequest) Descriptor() ([]byte, []int) { + return file_publish_proto_rawDescGZIP(), []int{0} +} + +func (x *PublishRequest) GetTopicName() string { + if x != nil { + return x.TopicName + } + return "" +} + +func (x *PublishRequest) GetPayload() string { + if x != nil { + return x.Payload + } + return "" +} + +func (x *PublishRequest) GetQos() uint32 { + if x != nil { + return x.Qos + } + return 0 +} + +func (x *PublishRequest) GetRetained() bool { + if x != nil { + return x.Retained + } + return false +} + +func (x *PublishRequest) GetContentType() string { + if x != nil { + return x.ContentType + } + return "" +} + +func (x *PublishRequest) GetCorrelationData() string { + if x != nil { + return x.CorrelationData + } + return "" +} + +func (x *PublishRequest) GetMessageExpiry() uint32 { + if x != nil { + return x.MessageExpiry + } + return 0 +} + +func (x *PublishRequest) GetPayloadFormat() uint32 { + if x != nil { + return x.PayloadFormat + } + return 0 +} + +func (x *PublishRequest) GetResponseTopic() string { + if x != nil { + return x.ResponseTopic + } + return "" +} + +func (x *PublishRequest) GetUserProperties() []*UserProperties { + if x != nil { + return x.UserProperties + } + return nil +} + +type UserProperties struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + K []byte `protobuf:"bytes,1,opt,name=K,proto3" json:"K,omitempty"` + V []byte `protobuf:"bytes,2,opt,name=V,proto3" json:"V,omitempty"` +} + +func (x *UserProperties) Reset() { + *x = UserProperties{} + if protoimpl.UnsafeEnabled { + mi := &file_publish_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UserProperties) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UserProperties) ProtoMessage() {} + +func (x *UserProperties) ProtoReflect() protoreflect.Message { + mi := &file_publish_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UserProperties.ProtoReflect.Descriptor instead. +func (*UserProperties) Descriptor() ([]byte, []int) { + return file_publish_proto_rawDescGZIP(), []int{1} +} + +func (x *UserProperties) GetK() []byte { + if x != nil { + return x.K + } + return nil +} + +func (x *UserProperties) GetV() []byte { + if x != nil { + return x.V + } + return nil +} + +var File_publish_proto protoreflect.FileDescriptor + +var file_publish_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x0f, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, + 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, + 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, + 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x84, 0x03, 0x0a, 0x0e, + 0x50, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, + 0x0a, 0x0a, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, + 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x71, 0x6f, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x71, 0x6f, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x74, + 0x61, 0x69, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x72, 0x65, 0x74, + 0x61, 0x69, 0x6e, 0x65, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, + 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x63, 0x6f, 0x6e, + 0x74, 0x65, 0x6e, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x29, 0x0a, 0x10, 0x63, 0x6f, 0x72, 0x72, + 0x65, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0f, 0x63, 0x6f, 0x72, 0x72, 0x65, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x44, + 0x61, 0x74, 0x61, 0x12, 0x25, 0x0a, 0x0e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x65, + 0x78, 0x70, 0x69, 0x72, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0d, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x45, 0x78, 0x70, 0x69, 0x72, 0x79, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x61, + 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x5f, 0x66, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x18, 0x08, 0x20, 0x01, + 0x28, 0x0d, 0x52, 0x0d, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x46, 0x6f, 0x72, 0x6d, 0x61, + 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x74, 0x6f, + 0x70, 0x69, 0x63, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x54, 0x6f, 0x70, 0x69, 0x63, 0x12, 0x48, 0x0a, 0x0f, 0x75, 0x73, 0x65, 0x72, + 0x5f, 0x70, 0x72, 0x6f, 0x70, 0x65, 0x72, 0x74, 0x69, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1f, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x70, 0x65, 0x72, 0x74, 0x69, + 0x65, 0x73, 0x52, 0x0e, 0x75, 0x73, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x70, 0x65, 0x72, 0x74, 0x69, + 0x65, 0x73, 0x22, 0x2c, 0x0a, 0x0e, 0x55, 0x73, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x70, 0x65, 0x72, + 0x74, 0x69, 0x65, 0x73, 0x12, 0x0c, 0x0a, 0x01, 0x4b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x01, 0x4b, 0x12, 0x0c, 0x0a, 0x01, 0x56, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x01, 0x56, + 0x32, 0x6c, 0x0a, 0x0e, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x5a, 0x0a, 0x07, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x12, 0x1f, 0x2e, + 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, + 0x50, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x16, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x10, 0x22, 0x0b, + 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x3a, 0x01, 0x2a, 0x42, 0x09, + 0x5a, 0x07, 0x2e, 0x3b, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, +} + +var ( + file_publish_proto_rawDescOnce sync.Once + file_publish_proto_rawDescData = file_publish_proto_rawDesc +) + +func file_publish_proto_rawDescGZIP() []byte { + file_publish_proto_rawDescOnce.Do(func() { + file_publish_proto_rawDescData = protoimpl.X.CompressGZIP(file_publish_proto_rawDescData) + }) + return file_publish_proto_rawDescData +} + +var file_publish_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_publish_proto_goTypes = []interface{}{ + (*PublishRequest)(nil), // 0: gmqtt.admin.api.PublishRequest + (*UserProperties)(nil), // 1: gmqtt.admin.api.UserProperties + (*empty.Empty)(nil), // 2: google.protobuf.Empty +} +var file_publish_proto_depIdxs = []int32{ + 1, // 0: gmqtt.admin.api.PublishRequest.user_properties:type_name -> gmqtt.admin.api.UserProperties + 0, // 1: gmqtt.admin.api.PublishService.Publish:input_type -> gmqtt.admin.api.PublishRequest + 2, // 2: gmqtt.admin.api.PublishService.Publish:output_type -> google.protobuf.Empty + 2, // [2:3] is the sub-list for method output_type + 1, // [1:2] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_publish_proto_init() } +func file_publish_proto_init() { + if File_publish_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_publish_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PublishRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_publish_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UserProperties); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_publish_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_publish_proto_goTypes, + DependencyIndexes: file_publish_proto_depIdxs, + MessageInfos: file_publish_proto_msgTypes, + }.Build() + File_publish_proto = out.File + file_publish_proto_rawDesc = nil + file_publish_proto_goTypes = nil + file_publish_proto_depIdxs = nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/publish.pb.gw.go b/internal/hummingbird/mqttbroker/plugin/admin/publish.pb.gw.go new file mode 100644 index 0000000..f644890 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/publish.pb.gw.go @@ -0,0 +1,163 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: publish.proto + +/* +Package admin is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package admin + +import ( + "context" + "io" + "net/http" + + "github.com/golang/protobuf/descriptor" + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" +) + +// Suppress "imported and not used" errors +var _ codes.Code +var _ io.Reader +var _ status.Status +var _ = runtime.String +var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage + +func request_PublishService_Publish_0(ctx context.Context, marshaler runtime.Marshaler, client PublishServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq PublishRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Publish(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_PublishService_Publish_0(ctx context.Context, marshaler runtime.Marshaler, server PublishServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq PublishRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Publish(ctx, &protoReq) + return msg, metadata, err + +} + +// RegisterPublishServiceHandlerServer registers the http handlers for service PublishService to "mux". +// UnaryRPC :call PublishServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +func RegisterPublishServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server PublishServiceServer) error { + + mux.Handle("POST", pattern_PublishService_Publish_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_PublishService_Publish_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_PublishService_Publish_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +// RegisterPublishServiceHandlerFromEndpoint is same as RegisterPublishServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterPublishServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterPublishServiceHandler(ctx, mux, conn) +} + +// RegisterPublishServiceHandler registers the http handlers for service PublishService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterPublishServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterPublishServiceHandlerClient(ctx, mux, NewPublishServiceClient(conn)) +} + +// RegisterPublishServiceHandlerClient registers the http handlers for service PublishService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "PublishServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "PublishServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "PublishServiceClient" to call the correct interceptors. +func RegisterPublishServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client PublishServiceClient) error { + + mux.Handle("POST", pattern_PublishService_Publish_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_PublishService_Publish_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_PublishService_Publish_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_PublishService_Publish_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "publish"}, "", runtime.AssumeColonVerbOpt(true))) +) + +var ( + forward_PublishService_Publish_0 = runtime.ForwardResponseMessage +) diff --git a/internal/hummingbird/mqttbroker/plugin/admin/publish_grpc.pb.go b/internal/hummingbird/mqttbroker/plugin/admin/publish_grpc.pb.go new file mode 100644 index 0000000..5c3cc8d --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/publish_grpc.pb.go @@ -0,0 +1,101 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package admin + +import ( + context "context" + + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion7 + +// PublishServiceClient is the client API for PublishService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type PublishServiceClient interface { + // Publish message to broker + Publish(ctx context.Context, in *PublishRequest, opts ...grpc.CallOption) (*empty.Empty, error) +} + +type publishServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewPublishServiceClient(cc grpc.ClientConnInterface) PublishServiceClient { + return &publishServiceClient{cc} +} + +func (c *publishServiceClient) Publish(ctx context.Context, in *PublishRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.PublishService/Publish", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// PublishServiceServer is the server API for PublishService service. +// All implementations must embed UnimplementedPublishServiceServer +// for forward compatibility +type PublishServiceServer interface { + // Publish message to broker + Publish(context.Context, *PublishRequest) (*empty.Empty, error) + mustEmbedUnimplementedPublishServiceServer() +} + +// UnimplementedPublishServiceServer must be embedded to have forward compatible implementations. +type UnimplementedPublishServiceServer struct { +} + +func (UnimplementedPublishServiceServer) Publish(context.Context, *PublishRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Publish not implemented") +} +func (UnimplementedPublishServiceServer) mustEmbedUnimplementedPublishServiceServer() {} + +// UnsafePublishServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to PublishServiceServer will +// result in compilation errors. +type UnsafePublishServiceServer interface { + mustEmbedUnimplementedPublishServiceServer() +} + +func RegisterPublishServiceServer(s grpc.ServiceRegistrar, srv PublishServiceServer) { + s.RegisterService(&_PublishService_serviceDesc, srv) +} + +func _PublishService_Publish_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PublishRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(PublishServiceServer).Publish(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.PublishService/Publish", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(PublishServiceServer).Publish(ctx, req.(*PublishRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _PublishService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "gmqtt.admin.api.PublishService", + HandlerType: (*PublishServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Publish", + Handler: _PublishService_Publish_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "publish.proto", +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/publish_test.go b/internal/hummingbird/mqttbroker/plugin/admin/publish_test.go new file mode 100644 index 0000000..81fc4db --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/publish_test.go @@ -0,0 +1,125 @@ +package admin + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/status" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +func TestPublisher_Publish(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mp := server.NewMockPublisher(ctrl) + pub := &publisher{ + a: &Admin{ + publisher: mp, + }, + } + msg := &gmqtt.Message{ + QoS: 1, + Retained: true, + Topic: "topic", + Payload: []byte("abc"), + ContentType: "ct", + CorrelationData: []byte("co"), + MessageExpiry: 1, + PayloadFormat: 1, + ResponseTopic: "resp", + UserProperties: []packets.UserProperty{ + { + K: []byte("K"), + V: []byte("V"), + }, + }, + } + mp.EXPECT().Publish(msg) + _, err := pub.Publish(context.Background(), &PublishRequest{ + TopicName: msg.Topic, + Payload: string(msg.Payload), + Qos: uint32(msg.QoS), + Retained: msg.Retained, + ContentType: msg.ContentType, + CorrelationData: string(msg.CorrelationData), + MessageExpiry: msg.MessageExpiry, + PayloadFormat: uint32(msg.PayloadFormat), + ResponseTopic: msg.ResponseTopic, + UserProperties: []*UserProperties{ + { + K: []byte("K"), + V: []byte("V"), + }, + }, + }) + a.Nil(err) +} + +func TestPublisher_Publish_InvalidArgument(t *testing.T) { + var tt = []struct { + name string + field string + req *PublishRequest + }{ + { + name: "invalid_topic_name", + field: "topic_name", + req: &PublishRequest{ + TopicName: "$share/a", + Qos: 2, + }, + }, + { + name: "invalid_qos", + field: "qos", + req: &PublishRequest{ + TopicName: "a", + Qos: 3, + }, + }, + { + name: "invalid_payload_format", + field: "payload_format", + req: &PublishRequest{ + TopicName: "a", + Qos: 2, + PayloadFormat: 3, + }, + }, + { + name: "invalid_response_topic", + field: "response_topic", + req: &PublishRequest{ + TopicName: "a", + Qos: 2, + PayloadFormat: 1, + ResponseTopic: "#/", + }, + }, + } + for _, v := range tt { + t.Run(v.name, func(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mp := server.NewMockPublisher(ctrl) + pub := &publisher{ + a: &Admin{ + publisher: mp, + }, + } + _, err := pub.Publish(context.Background(), v.req) + s, ok := status.FromError(err) + a.True(ok) + a.Contains(s.Message(), v.field) + }) + } + +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/store.go b/internal/hummingbird/mqttbroker/plugin/admin/store.go new file mode 100644 index 0000000..050afcb --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/store.go @@ -0,0 +1,159 @@ +package admin + +import ( + "container/list" + "sync" + + "google.golang.org/protobuf/types/known/timestamppb" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +type store struct { + clientMu sync.RWMutex + clientIndexer *Indexer + subMu sync.RWMutex + subIndexer *Indexer + config config.Config + statsReader server.StatsReader + subscriptionService server.SubscriptionService +} + +func newStore(statsReader server.StatsReader, config config.Config) *store { + return &store{ + clientIndexer: NewIndexer(), + subIndexer: NewIndexer(), + statsReader: statsReader, + config: config, + } +} + +func (s *store) addSubscription(clientID string, sub *gmqtt.Subscription) { + s.subMu.Lock() + defer s.subMu.Unlock() + + subInfo := &Subscription{ + TopicName: sub.GetFullTopicName(), + Id: sub.ID, + Qos: uint32(sub.QoS), + NoLocal: sub.NoLocal, + RetainAsPublished: sub.RetainAsPublished, + RetainHandling: uint32(sub.RetainHandling), + ClientId: clientID, + } + key := clientID + "_" + sub.GetFullTopicName() + s.subIndexer.Set(key, subInfo) + +} + +func (s *store) removeSubscription(clientID string, topicName string) { + s.subMu.Lock() + defer s.subMu.Unlock() + s.subIndexer.Remove(clientID + "_" + topicName) +} + +func (s *store) addClient(client server.Client) { + c := newClientInfo(client, uint32(s.config.MQTT.MaxQueuedMsg)) + s.clientMu.Lock() + s.clientIndexer.Set(c.ClientId, c) + s.clientMu.Unlock() +} + +func (s *store) setClientDisconnected(clientID string) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + l := s.clientIndexer.GetByID(clientID) + if l == nil { + return + } + l.Value.(*Client).DisconnectedAt = timestamppb.Now() +} + +func (s *store) removeClient(clientID string) { + s.clientMu.Lock() + s.clientIndexer.Remove(clientID) + s.clientMu.Unlock() +} + +// GetClientByID returns the client information for the given client id. +func (s *store) GetClientByID(clientID string) *Client { + s.clientMu.RLock() + defer s.clientMu.RUnlock() + c := s.getClientByIDLocked(clientID) + fillClientInfo(c, s.statsReader) + return c +} + +func newClientInfo(client server.Client, maxQueue uint32) *Client { + clientOptions := client.ClientOptions() + rs := &Client{ + ClientId: clientOptions.ClientID, + Username: clientOptions.Username, + KeepAlive: int32(clientOptions.KeepAlive), + Version: int32(client.Version()), + RemoteAddr: client.Connection().RemoteAddr().String(), + LocalAddr: client.Connection().LocalAddr().String(), + ConnectedAt: timestamppb.New(client.ConnectedAt()), + DisconnectedAt: nil, + SessionExpiry: clientOptions.SessionExpiry, + MaxInflight: uint32(clientOptions.MaxInflight), + MaxQueue: maxQueue, + } + return rs +} + +func (s *store) getClientByIDLocked(clientID string) *Client { + if i := s.clientIndexer.GetByID(clientID); i != nil { + return i.Value.(*Client) + } + return nil +} + +func fillClientInfo(c *Client, stsReader server.StatsReader) { + if c == nil { + return + } + sts, ok := stsReader.GetClientStats(c.ClientId) + if !ok { + return + } + c.SubscriptionsCurrent = uint32(sts.SubscriptionStats.SubscriptionsCurrent) + c.SubscriptionsTotal = uint32(sts.SubscriptionStats.SubscriptionsTotal) + c.PacketsReceivedBytes = sts.PacketStats.BytesReceived.Total + c.PacketsReceivedNums = sts.PacketStats.ReceivedTotal.Total + c.PacketsSendBytes = sts.PacketStats.BytesSent.Total + c.PacketsSendNums = sts.PacketStats.SentTotal.Total + c.MessageDropped = sts.MessageStats.GetDroppedTotal() + c.InflightLen = uint32(sts.MessageStats.InflightCurrent) + c.QueueLen = uint32(sts.MessageStats.QueuedCurrent) +} + +// GetClients +func (s *store) GetClients(page, pageSize uint) (rs []*Client, total uint32, err error) { + rs = make([]*Client, 0) + fn := func(elem *list.Element) { + c := elem.Value.(*Client) + fillClientInfo(c, s.statsReader) + rs = append(rs, elem.Value.(*Client)) + } + s.clientMu.RLock() + defer s.clientMu.RUnlock() + offset, n := GetOffsetN(page, pageSize) + s.clientIndexer.Iterate(fn, offset, n) + return rs, uint32(s.clientIndexer.Len()), nil +} + +// GetSubscriptions +func (s *store) GetSubscriptions(page, pageSize uint) (rs []*Subscription, total uint32, err error) { + rs = make([]*Subscription, 0) + fn := func(elem *list.Element) { + rs = append(rs, elem.Value.(*Subscription)) + } + s.subMu.RLock() + defer s.subMu.RUnlock() + offset, n := GetOffsetN(page, pageSize) + s.subIndexer.Iterate(fn, offset, n) + return rs, uint32(s.subIndexer.Len()), nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/subscription.go b/internal/hummingbird/mqttbroker/plugin/admin/subscription.go new file mode 100644 index 0000000..d4fde11 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/subscription.go @@ -0,0 +1,182 @@ +package admin + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/golang/protobuf/ptypes/empty" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +type subscriptionService struct { + a *Admin +} + +func (s *subscriptionService) mustEmbedUnimplementedSubscriptionServiceServer() { + return +} + +// List lists subscriptions in the broker. +func (s *subscriptionService) List(ctx context.Context, req *ListSubscriptionRequest) (*ListSubscriptionResponse, error) { + page, pageSize := GetPage(req.Page, req.PageSize) + subs, total, err := s.a.store.GetSubscriptions(page, pageSize) + if err != nil { + return &ListSubscriptionResponse{}, err + } + return &ListSubscriptionResponse{ + Subscriptions: subs, + TotalCount: total, + }, nil +} + +// Filter filters subscriptions with the request params. +// Paging is not supported, and the results are not sorted in any way. +// Using huge req.Limit can impact performance. +func (s *subscriptionService) Filter(ctx context.Context, req *FilterSubscriptionRequest) (resp *FilterSubscriptionResponse, err error) { + var iterType subscription.IterationType + iterOpts := subscription.IterationOptions{ + ClientID: req.ClientId, + TopicName: req.TopicName, + } + if req.FilterType == "" { + iterType = subscription.TypeAll + } else { + types := strings.Split(req.FilterType, ",") + for _, v := range types { + if v == "" { + continue + } + i, err := strconv.Atoi(v) + + if err != nil { + return nil, ErrInvalidArgument("filter_type", err.Error()) + } + switch SubFilterType(i) { + + case SubFilterType_SUB_FILTER_TYPE_SYS: + iterType |= subscription.TypeSYS + case SubFilterType_SUB_FILTER_TYPE_SHARED: + iterType |= subscription.TypeShared + case SubFilterType_SUB_FILTER_TYPE_NON_SHARED: + iterType |= subscription.TypeNonShared + default: + return nil, ErrInvalidArgument("filter_type", "") + } + } + } + + iterOpts.Type = iterType + + if req.MatchType == SubMatchType_SUB_MATCH_TYPE_MATCH_NAME { + iterOpts.MatchType = subscription.MatchName + } else if req.MatchType == SubMatchType_SUB_MATCH_TYPE_MATCH_FILTER { + iterOpts.MatchType = subscription.MatchFilter + } + if iterOpts.TopicName == "" && iterOpts.MatchType != 0 { + return nil, ErrInvalidArgument("topic_name", "cannot be empty while match_type Set") + } + if iterOpts.TopicName != "" && iterOpts.MatchType == 0 { + return nil, ErrInvalidArgument("match_type", "cannot be empty while topic_name Set") + } + if iterOpts.TopicName != "" { + if !packets.ValidV5Topic([]byte(iterOpts.TopicName)) { + return nil, ErrInvalidArgument("topic_name", "") + } + } + + if req.Limit > 1000 { + return nil, ErrInvalidArgument("limit", fmt.Sprintf("limit too large, must <= 1000")) + } + if req.Limit == 0 { + req.Limit = 20 + } + resp = &FilterSubscriptionResponse{ + Subscriptions: make([]*Subscription, 0), + } + i := int32(0) + s.a.store.subscriptionService.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + if i != req.Limit { + resp.Subscriptions = append(resp.Subscriptions, &Subscription{ + TopicName: subscription.GetFullTopicName(sub.ShareName, sub.TopicFilter), + Id: sub.ID, + Qos: uint32(sub.QoS), + NoLocal: sub.NoLocal, + RetainAsPublished: sub.RetainAsPublished, + RetainHandling: uint32(sub.RetainHandling), + ClientId: clientID, + }) + } + i++ + return true + }, iterOpts) + + return resp, nil +} + +// Subscribe makes subscriptions for the given client. +func (s *subscriptionService) Subscribe(ctx context.Context, req *SubscribeRequest) (resp *SubscribeResponse, err error) { + if req.ClientId == "" { + return nil, ErrInvalidArgument("client_id", "cannot be empty") + } + if len(req.Subscriptions) == 0 { + return nil, ErrInvalidArgument("subIndexer", "zero length subIndexer") + } + var subs []*gmqtt.Subscription + for k, v := range req.Subscriptions { + shareName, name := subscription.SplitTopic(v.TopicName) + sub := &gmqtt.Subscription{ + ShareName: shareName, + TopicFilter: name, + ID: v.Id, + QoS: uint8(v.Qos), + NoLocal: v.NoLocal, + RetainAsPublished: v.RetainAsPublished, + RetainHandling: byte(v.RetainHandling), + } + err := sub.Validate() + if err != nil { + return nil, ErrInvalidArgument(fmt.Sprintf("subIndexer[%d]", k), err.Error()) + } + subs = append(subs, sub) + } + rs, err := s.a.store.subscriptionService.Subscribe(req.ClientId, subs...) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to subscribe: %s", err.Error()) + } + resp = &SubscribeResponse{ + New: make([]bool, 0), + } + for _, v := range rs { + resp.New = append(resp.New, !v.AlreadyExisted) + } + return resp, nil + +} + +// Unsubscribe unsubscribe topic for the given client. +func (s *subscriptionService) Unsubscribe(ctx context.Context, req *UnsubscribeRequest) (resp *empty.Empty, err error) { + if req.ClientId == "" { + return nil, ErrInvalidArgument("client_id", "cannot be empty") + } + if len(req.Topics) == 0 { + return nil, ErrInvalidArgument("topics", "zero length topics") + } + + for k, v := range req.Topics { + if !packets.ValidV5Topic([]byte(v)) { + return nil, ErrInvalidArgument(fmt.Sprintf("topics[%d]", k), "") + } + } + err = s.a.store.subscriptionService.Unsubscribe(req.ClientId, req.Topics...) + if err != nil { + return nil, status.Error(codes.Internal, fmt.Sprintf("failed to unsubscribe: %s", err.Error())) + } + return &empty.Empty{}, nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/subscription.pb.go b/internal/hummingbird/mqttbroker/plugin/admin/subscription.pb.go new file mode 100644 index 0000000..b0891e5 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/subscription.pb.go @@ -0,0 +1,922 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.22.0 +// protoc v3.13.0 +// source: subscription.proto + +package admin + +import ( + reflect "reflect" + sync "sync" + + proto "github.com/golang/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type SubFilterType int32 + +const ( + SubFilterType_SUB_FILTER_TYPE_SYS_UNSPECIFIED SubFilterType = 0 + SubFilterType_SUB_FILTER_TYPE_SYS SubFilterType = 1 + SubFilterType_SUB_FILTER_TYPE_SHARED SubFilterType = 2 + SubFilterType_SUB_FILTER_TYPE_NON_SHARED SubFilterType = 3 +) + +// Enum value maps for SubFilterType. +var ( + SubFilterType_name = map[int32]string{ + 0: "SUB_FILTER_TYPE_SYS_UNSPECIFIED", + 1: "SUB_FILTER_TYPE_SYS", + 2: "SUB_FILTER_TYPE_SHARED", + 3: "SUB_FILTER_TYPE_NON_SHARED", + } + SubFilterType_value = map[string]int32{ + "SUB_FILTER_TYPE_SYS_UNSPECIFIED": 0, + "SUB_FILTER_TYPE_SYS": 1, + "SUB_FILTER_TYPE_SHARED": 2, + "SUB_FILTER_TYPE_NON_SHARED": 3, + } +) + +func (x SubFilterType) Enum() *SubFilterType { + p := new(SubFilterType) + *p = x + return p +} + +func (x SubFilterType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (SubFilterType) Descriptor() protoreflect.EnumDescriptor { + return file_subscription_proto_enumTypes[0].Descriptor() +} + +func (SubFilterType) Type() protoreflect.EnumType { + return &file_subscription_proto_enumTypes[0] +} + +func (x SubFilterType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use SubFilterType.Descriptor instead. +func (SubFilterType) EnumDescriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{0} +} + +type SubMatchType int32 + +const ( + SubMatchType_SUB_MATCH_TYPE_MATCH_UNSPECIFIED SubMatchType = 0 + SubMatchType_SUB_MATCH_TYPE_MATCH_NAME SubMatchType = 1 + SubMatchType_SUB_MATCH_TYPE_MATCH_FILTER SubMatchType = 2 +) + +// Enum value maps for SubMatchType. +var ( + SubMatchType_name = map[int32]string{ + 0: "SUB_MATCH_TYPE_MATCH_UNSPECIFIED", + 1: "SUB_MATCH_TYPE_MATCH_NAME", + 2: "SUB_MATCH_TYPE_MATCH_FILTER", + } + SubMatchType_value = map[string]int32{ + "SUB_MATCH_TYPE_MATCH_UNSPECIFIED": 0, + "SUB_MATCH_TYPE_MATCH_NAME": 1, + "SUB_MATCH_TYPE_MATCH_FILTER": 2, + } +) + +func (x SubMatchType) Enum() *SubMatchType { + p := new(SubMatchType) + *p = x + return p +} + +func (x SubMatchType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (SubMatchType) Descriptor() protoreflect.EnumDescriptor { + return file_subscription_proto_enumTypes[1].Descriptor() +} + +func (SubMatchType) Type() protoreflect.EnumType { + return &file_subscription_proto_enumTypes[1] +} + +func (x SubMatchType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use SubMatchType.Descriptor instead. +func (SubMatchType) EnumDescriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{1} +} + +type ListSubscriptionRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + PageSize uint32 `protobuf:"varint,1,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"` + Page uint32 `protobuf:"varint,2,opt,name=page,proto3" json:"page,omitempty"` +} + +func (x *ListSubscriptionRequest) Reset() { + *x = ListSubscriptionRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListSubscriptionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListSubscriptionRequest) ProtoMessage() {} + +func (x *ListSubscriptionRequest) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListSubscriptionRequest.ProtoReflect.Descriptor instead. +func (*ListSubscriptionRequest) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{0} +} + +func (x *ListSubscriptionRequest) GetPageSize() uint32 { + if x != nil { + return x.PageSize + } + return 0 +} + +func (x *ListSubscriptionRequest) GetPage() uint32 { + if x != nil { + return x.Page + } + return 0 +} + +type ListSubscriptionResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Subscriptions []*Subscription `protobuf:"bytes,1,rep,name=subscriptions,proto3" json:"subscriptions,omitempty"` + TotalCount uint32 `protobuf:"varint,2,opt,name=total_count,json=totalCount,proto3" json:"total_count,omitempty"` +} + +func (x *ListSubscriptionResponse) Reset() { + *x = ListSubscriptionResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListSubscriptionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListSubscriptionResponse) ProtoMessage() {} + +func (x *ListSubscriptionResponse) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListSubscriptionResponse.ProtoReflect.Descriptor instead. +func (*ListSubscriptionResponse) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{1} +} + +func (x *ListSubscriptionResponse) GetSubscriptions() []*Subscription { + if x != nil { + return x.Subscriptions + } + return nil +} + +func (x *ListSubscriptionResponse) GetTotalCount() uint32 { + if x != nil { + return x.TotalCount + } + return 0 +} + +type FilterSubscriptionRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // If set, only filter the subscriptions that belongs to the client. + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + // filter_type indicates what kinds of topics are going to filter. + // If there are multiple types, use ',' to separate. e.g : 1,2 + // There are 3 kinds of topic can be filtered, defined by SubFilterType: + // 1 = System Topic(begin with '$') + // 2 = Shared Topic + // 3 = NonShared Topic + FilterType string `protobuf:"bytes,2,opt,name=filter_type,json=filterType,proto3" json:"filter_type,omitempty"` + // If 1 (SUB_MATCH_TYPE_MATCH_NAME), the server will return subscriptions which has the same topic name with request topic_name. + // If 2 (SUB_MATCH_TYPE_MATCH_FILTER),the server will return subscriptions which match the request topic_name . + // match_type must be set when filter_type is not empty. + MatchType SubMatchType `protobuf:"varint,3,opt,name=match_type,json=matchType,proto3,enum=gmqtt.admin.api.SubMatchType" json:"match_type,omitempty"` + // topic_name must be set when match_type is not zero. + TopicName string `protobuf:"bytes,4,opt,name=topic_name,json=topicName,proto3" json:"topic_name,omitempty"` + // The maximum subscriptions can be returned. + Limit int32 `protobuf:"varint,5,opt,name=limit,proto3" json:"limit,omitempty"` +} + +func (x *FilterSubscriptionRequest) Reset() { + *x = FilterSubscriptionRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *FilterSubscriptionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FilterSubscriptionRequest) ProtoMessage() {} + +func (x *FilterSubscriptionRequest) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FilterSubscriptionRequest.ProtoReflect.Descriptor instead. +func (*FilterSubscriptionRequest) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{2} +} + +func (x *FilterSubscriptionRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *FilterSubscriptionRequest) GetFilterType() string { + if x != nil { + return x.FilterType + } + return "" +} + +func (x *FilterSubscriptionRequest) GetMatchType() SubMatchType { + if x != nil { + return x.MatchType + } + return SubMatchType_SUB_MATCH_TYPE_MATCH_UNSPECIFIED +} + +func (x *FilterSubscriptionRequest) GetTopicName() string { + if x != nil { + return x.TopicName + } + return "" +} + +func (x *FilterSubscriptionRequest) GetLimit() int32 { + if x != nil { + return x.Limit + } + return 0 +} + +type FilterSubscriptionResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Subscriptions []*Subscription `protobuf:"bytes,1,rep,name=subscriptions,proto3" json:"subscriptions,omitempty"` +} + +func (x *FilterSubscriptionResponse) Reset() { + *x = FilterSubscriptionResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *FilterSubscriptionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FilterSubscriptionResponse) ProtoMessage() {} + +func (x *FilterSubscriptionResponse) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FilterSubscriptionResponse.ProtoReflect.Descriptor instead. +func (*FilterSubscriptionResponse) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{3} +} + +func (x *FilterSubscriptionResponse) GetSubscriptions() []*Subscription { + if x != nil { + return x.Subscriptions + } + return nil +} + +type SubscribeRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + Subscriptions []*Subscription `protobuf:"bytes,2,rep,name=subscriptions,proto3" json:"subscriptions,omitempty"` +} + +func (x *SubscribeRequest) Reset() { + *x = SubscribeRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SubscribeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SubscribeRequest) ProtoMessage() {} + +func (x *SubscribeRequest) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SubscribeRequest.ProtoReflect.Descriptor instead. +func (*SubscribeRequest) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{4} +} + +func (x *SubscribeRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *SubscribeRequest) GetSubscriptions() []*Subscription { + if x != nil { + return x.Subscriptions + } + return nil +} + +type SubscribeResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // indicates whether it is a new subscription or the subscription is already existed. + New []bool `protobuf:"varint,1,rep,packed,name=new,proto3" json:"new,omitempty"` +} + +func (x *SubscribeResponse) Reset() { + *x = SubscribeResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SubscribeResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SubscribeResponse) ProtoMessage() {} + +func (x *SubscribeResponse) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SubscribeResponse.ProtoReflect.Descriptor instead. +func (*SubscribeResponse) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{5} +} + +func (x *SubscribeResponse) GetNew() []bool { + if x != nil { + return x.New + } + return nil +} + +type UnsubscribeRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` + Topics []string `protobuf:"bytes,2,rep,name=topics,proto3" json:"topics,omitempty"` +} + +func (x *UnsubscribeRequest) Reset() { + *x = UnsubscribeRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UnsubscribeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UnsubscribeRequest) ProtoMessage() {} + +func (x *UnsubscribeRequest) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UnsubscribeRequest.ProtoReflect.Descriptor instead. +func (*UnsubscribeRequest) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{6} +} + +func (x *UnsubscribeRequest) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +func (x *UnsubscribeRequest) GetTopics() []string { + if x != nil { + return x.Topics + } + return nil +} + +type Subscription struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TopicName string `protobuf:"bytes,1,opt,name=topic_name,json=topicName,proto3" json:"topic_name,omitempty"` + Id uint32 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + Qos uint32 `protobuf:"varint,3,opt,name=qos,proto3" json:"qos,omitempty"` + NoLocal bool `protobuf:"varint,4,opt,name=no_local,json=noLocal,proto3" json:"no_local,omitempty"` + RetainAsPublished bool `protobuf:"varint,5,opt,name=retain_as_published,json=retainAsPublished,proto3" json:"retain_as_published,omitempty"` + RetainHandling uint32 `protobuf:"varint,6,opt,name=retain_handling,json=retainHandling,proto3" json:"retain_handling,omitempty"` + ClientId string `protobuf:"bytes,7,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` +} + +func (x *Subscription) Reset() { + *x = Subscription{} + if protoimpl.UnsafeEnabled { + mi := &file_subscription_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Subscription) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Subscription) ProtoMessage() {} + +func (x *Subscription) ProtoReflect() protoreflect.Message { + mi := &file_subscription_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Subscription.ProtoReflect.Descriptor instead. +func (*Subscription) Descriptor() ([]byte, []int) { + return file_subscription_proto_rawDescGZIP(), []int{7} +} + +func (x *Subscription) GetTopicName() string { + if x != nil { + return x.TopicName + } + return "" +} + +func (x *Subscription) GetId() uint32 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *Subscription) GetQos() uint32 { + if x != nil { + return x.Qos + } + return 0 +} + +func (x *Subscription) GetNoLocal() bool { + if x != nil { + return x.NoLocal + } + return false +} + +func (x *Subscription) GetRetainAsPublished() bool { + if x != nil { + return x.RetainAsPublished + } + return false +} + +func (x *Subscription) GetRetainHandling() uint32 { + if x != nil { + return x.RetainHandling + } + return 0 +} + +func (x *Subscription) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + +var File_subscription_proto protoreflect.FileDescriptor + +var file_subscription_proto_rawDesc = []byte{ + 0x0a, 0x12, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0f, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, + 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, + 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x22, 0x4a, 0x0a, 0x17, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x70, + 0x61, 0x67, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, + 0x70, 0x61, 0x67, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x67, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x61, 0x67, 0x65, 0x22, 0x80, 0x01, 0x0a, + 0x18, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x0d, 0x73, 0x75, 0x62, + 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x1d, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, + 0x70, 0x69, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x0d, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x1f, + 0x0a, 0x0b, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x22, + 0xcc, 0x01, 0x0a, 0x19, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, + 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, + 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x66, 0x69, + 0x6c, 0x74, 0x65, 0x72, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0a, 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x3c, 0x0a, 0x0a, 0x6d, + 0x61, 0x74, 0x63, 0x68, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x1d, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, + 0x69, 0x2e, 0x53, 0x75, 0x62, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x54, 0x79, 0x70, 0x65, 0x52, 0x09, + 0x6d, 0x61, 0x74, 0x63, 0x68, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x70, + 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x74, + 0x6f, 0x70, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, + 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x22, 0x61, + 0x0a, 0x1a, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x43, 0x0a, 0x0d, + 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, + 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x0d, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x22, 0x74, 0x0a, 0x10, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x49, 0x64, 0x12, 0x43, 0x0a, 0x0d, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x67, 0x6d, 0x71, 0x74, + 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x53, 0x75, 0x62, 0x73, + 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0d, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, + 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x25, 0x0a, 0x11, 0x53, 0x75, 0x62, 0x73, 0x63, + 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x6e, 0x65, 0x77, 0x18, 0x01, 0x20, 0x03, 0x28, 0x08, 0x52, 0x03, 0x6e, 0x65, 0x77, 0x22, 0x49, + 0x0a, 0x12, 0x55, 0x6e, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, + 0x64, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x06, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x73, 0x22, 0xe0, 0x01, 0x0a, 0x0c, 0x53, 0x75, + 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, + 0x70, 0x69, 0x63, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x74, 0x6f, 0x70, 0x69, 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x02, 0x69, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x71, 0x6f, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x71, 0x6f, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x6e, + 0x6f, 0x5f, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x6e, + 0x6f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x12, 0x2e, 0x0a, 0x13, 0x72, 0x65, 0x74, 0x61, 0x69, 0x6e, + 0x5f, 0x61, 0x73, 0x5f, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x65, 0x64, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x11, 0x72, 0x65, 0x74, 0x61, 0x69, 0x6e, 0x41, 0x73, 0x50, 0x75, 0x62, + 0x6c, 0x69, 0x73, 0x68, 0x65, 0x64, 0x12, 0x27, 0x0a, 0x0f, 0x72, 0x65, 0x74, 0x61, 0x69, 0x6e, + 0x5f, 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x69, 0x6e, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x0e, 0x72, 0x65, 0x74, 0x61, 0x69, 0x6e, 0x48, 0x61, 0x6e, 0x64, 0x6c, 0x69, 0x6e, 0x67, 0x12, + 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x2a, 0x89, 0x01, 0x0a, + 0x0d, 0x53, 0x75, 0x62, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x23, + 0x0a, 0x1f, 0x53, 0x55, 0x42, 0x5f, 0x46, 0x49, 0x4c, 0x54, 0x45, 0x52, 0x5f, 0x54, 0x59, 0x50, + 0x45, 0x5f, 0x53, 0x59, 0x53, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, + 0x44, 0x10, 0x00, 0x12, 0x17, 0x0a, 0x13, 0x53, 0x55, 0x42, 0x5f, 0x46, 0x49, 0x4c, 0x54, 0x45, + 0x52, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x59, 0x53, 0x10, 0x01, 0x12, 0x1a, 0x0a, 0x16, + 0x53, 0x55, 0x42, 0x5f, 0x46, 0x49, 0x4c, 0x54, 0x45, 0x52, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, + 0x53, 0x48, 0x41, 0x52, 0x45, 0x44, 0x10, 0x02, 0x12, 0x1e, 0x0a, 0x1a, 0x53, 0x55, 0x42, 0x5f, + 0x46, 0x49, 0x4c, 0x54, 0x45, 0x52, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4e, 0x4f, 0x4e, 0x5f, + 0x53, 0x48, 0x41, 0x52, 0x45, 0x44, 0x10, 0x03, 0x2a, 0x74, 0x0a, 0x0c, 0x53, 0x75, 0x62, 0x4d, + 0x61, 0x74, 0x63, 0x68, 0x54, 0x79, 0x70, 0x65, 0x12, 0x24, 0x0a, 0x20, 0x53, 0x55, 0x42, 0x5f, + 0x4d, 0x41, 0x54, 0x43, 0x48, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x41, 0x54, 0x43, 0x48, + 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x1d, + 0x0a, 0x19, 0x53, 0x55, 0x42, 0x5f, 0x4d, 0x41, 0x54, 0x43, 0x48, 0x5f, 0x54, 0x59, 0x50, 0x45, + 0x5f, 0x4d, 0x41, 0x54, 0x43, 0x48, 0x5f, 0x4e, 0x41, 0x4d, 0x45, 0x10, 0x01, 0x12, 0x1f, 0x0a, + 0x1b, 0x53, 0x55, 0x42, 0x5f, 0x4d, 0x41, 0x54, 0x43, 0x48, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, + 0x4d, 0x41, 0x54, 0x43, 0x48, 0x5f, 0x46, 0x49, 0x4c, 0x54, 0x45, 0x52, 0x10, 0x02, 0x32, 0xe9, + 0x03, 0x0a, 0x13, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x76, 0x0a, 0x04, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x28, + 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, + 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x29, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, + 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, + 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x19, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x13, 0x12, 0x11, 0x2f, 0x76, 0x31, + 0x2f, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x83, + 0x01, 0x0a, 0x06, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x12, 0x2a, 0x2e, 0x67, 0x6d, 0x71, 0x74, + 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x46, 0x69, 0x6c, 0x74, + 0x65, 0x72, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2b, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, + 0x6d, 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x53, 0x75, + 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x20, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1a, 0x12, 0x18, 0x2f, 0x76, 0x31, 0x2f, + 0x66, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x5f, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x6c, 0x0a, 0x09, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, + 0x65, 0x12, 0x21, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, + 0x69, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x18, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x12, + 0x22, 0x0d, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x3a, + 0x01, 0x2a, 0x12, 0x66, 0x0a, 0x0b, 0x55, 0x6e, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, + 0x65, 0x12, 0x23, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x55, 0x6e, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x1a, + 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x14, 0x22, 0x0f, 0x2f, 0x76, 0x31, 0x2f, 0x75, 0x6e, 0x73, 0x75, + 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x3a, 0x01, 0x2a, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x3b, + 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_subscription_proto_rawDescOnce sync.Once + file_subscription_proto_rawDescData = file_subscription_proto_rawDesc +) + +func file_subscription_proto_rawDescGZIP() []byte { + file_subscription_proto_rawDescOnce.Do(func() { + file_subscription_proto_rawDescData = protoimpl.X.CompressGZIP(file_subscription_proto_rawDescData) + }) + return file_subscription_proto_rawDescData +} + +var file_subscription_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_subscription_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_subscription_proto_goTypes = []interface{}{ + (SubFilterType)(0), // 0: gmqtt.admin.api.SubFilterType + (SubMatchType)(0), // 1: gmqtt.admin.api.SubMatchType + (*ListSubscriptionRequest)(nil), // 2: gmqtt.admin.api.ListSubscriptionRequest + (*ListSubscriptionResponse)(nil), // 3: gmqtt.admin.api.ListSubscriptionResponse + (*FilterSubscriptionRequest)(nil), // 4: gmqtt.admin.api.FilterSubscriptionRequest + (*FilterSubscriptionResponse)(nil), // 5: gmqtt.admin.api.FilterSubscriptionResponse + (*SubscribeRequest)(nil), // 6: gmqtt.admin.api.SubscribeRequest + (*SubscribeResponse)(nil), // 7: gmqtt.admin.api.SubscribeResponse + (*UnsubscribeRequest)(nil), // 8: gmqtt.admin.api.UnsubscribeRequest + (*Subscription)(nil), // 9: gmqtt.admin.api.Subscription + (*empty.Empty)(nil), // 10: google.protobuf.Empty +} +var file_subscription_proto_depIdxs = []int32{ + 9, // 0: gmqtt.admin.api.ListSubscriptionResponse.subscriptions:type_name -> gmqtt.admin.api.Subscription + 1, // 1: gmqtt.admin.api.FilterSubscriptionRequest.match_type:type_name -> gmqtt.admin.api.SubMatchType + 9, // 2: gmqtt.admin.api.FilterSubscriptionResponse.subscriptions:type_name -> gmqtt.admin.api.Subscription + 9, // 3: gmqtt.admin.api.SubscribeRequest.subscriptions:type_name -> gmqtt.admin.api.Subscription + 2, // 4: gmqtt.admin.api.SubscriptionService.List:input_type -> gmqtt.admin.api.ListSubscriptionRequest + 4, // 5: gmqtt.admin.api.SubscriptionService.Filter:input_type -> gmqtt.admin.api.FilterSubscriptionRequest + 6, // 6: gmqtt.admin.api.SubscriptionService.Subscribe:input_type -> gmqtt.admin.api.SubscribeRequest + 8, // 7: gmqtt.admin.api.SubscriptionService.Unsubscribe:input_type -> gmqtt.admin.api.UnsubscribeRequest + 3, // 8: gmqtt.admin.api.SubscriptionService.List:output_type -> gmqtt.admin.api.ListSubscriptionResponse + 5, // 9: gmqtt.admin.api.SubscriptionService.Filter:output_type -> gmqtt.admin.api.FilterSubscriptionResponse + 7, // 10: gmqtt.admin.api.SubscriptionService.Subscribe:output_type -> gmqtt.admin.api.SubscribeResponse + 10, // 11: gmqtt.admin.api.SubscriptionService.Unsubscribe:output_type -> google.protobuf.Empty + 8, // [8:12] is the sub-list for method output_type + 4, // [4:8] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name +} + +func init() { file_subscription_proto_init() } +func file_subscription_proto_init() { + if File_subscription_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_subscription_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListSubscriptionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_subscription_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListSubscriptionResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_subscription_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*FilterSubscriptionRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_subscription_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*FilterSubscriptionResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_subscription_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SubscribeRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_subscription_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SubscribeResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_subscription_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UnsubscribeRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_subscription_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Subscription); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_subscription_proto_rawDesc, + NumEnums: 2, + NumMessages: 8, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_subscription_proto_goTypes, + DependencyIndexes: file_subscription_proto_depIdxs, + EnumInfos: file_subscription_proto_enumTypes, + MessageInfos: file_subscription_proto_msgTypes, + }.Build() + File_subscription_proto = out.File + file_subscription_proto_rawDesc = nil + file_subscription_proto_goTypes = nil + file_subscription_proto_depIdxs = nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/subscription.pb.gw.go b/internal/hummingbird/mqttbroker/plugin/admin/subscription.pb.gw.go new file mode 100644 index 0000000..3103e9c --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/subscription.pb.gw.go @@ -0,0 +1,395 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: subscription.proto + +/* +Package admin is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package admin + +import ( + "context" + "io" + "net/http" + + "github.com/golang/protobuf/descriptor" + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" +) + +// Suppress "imported and not used" errors +var _ codes.Code +var _ io.Reader +var _ status.Status +var _ = runtime.String +var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage + +var ( + filter_SubscriptionService_List_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +) + +func request_SubscriptionService_List_0(ctx context.Context, marshaler runtime.Marshaler, client SubscriptionServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ListSubscriptionRequest + var metadata runtime.ServerMetadata + + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_SubscriptionService_List_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.List(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_SubscriptionService_List_0(ctx context.Context, marshaler runtime.Marshaler, server SubscriptionServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ListSubscriptionRequest + var metadata runtime.ServerMetadata + + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_SubscriptionService_List_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.List(ctx, &protoReq) + return msg, metadata, err + +} + +var ( + filter_SubscriptionService_Filter_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +) + +func request_SubscriptionService_Filter_0(ctx context.Context, marshaler runtime.Marshaler, client SubscriptionServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq FilterSubscriptionRequest + var metadata runtime.ServerMetadata + + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_SubscriptionService_Filter_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Filter(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_SubscriptionService_Filter_0(ctx context.Context, marshaler runtime.Marshaler, server SubscriptionServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq FilterSubscriptionRequest + var metadata runtime.ServerMetadata + + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_SubscriptionService_Filter_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Filter(ctx, &protoReq) + return msg, metadata, err + +} + +func request_SubscriptionService_Subscribe_0(ctx context.Context, marshaler runtime.Marshaler, client SubscriptionServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq SubscribeRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Subscribe(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_SubscriptionService_Subscribe_0(ctx context.Context, marshaler runtime.Marshaler, server SubscriptionServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq SubscribeRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Subscribe(ctx, &protoReq) + return msg, metadata, err + +} + +func request_SubscriptionService_Unsubscribe_0(ctx context.Context, marshaler runtime.Marshaler, client SubscriptionServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq UnsubscribeRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Unsubscribe(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_SubscriptionService_Unsubscribe_0(ctx context.Context, marshaler runtime.Marshaler, server SubscriptionServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq UnsubscribeRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Unsubscribe(ctx, &protoReq) + return msg, metadata, err + +} + +// RegisterSubscriptionServiceHandlerServer registers the http handlers for service SubscriptionService to "mux". +// UnaryRPC :call SubscriptionServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +func RegisterSubscriptionServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server SubscriptionServiceServer) error { + + mux.Handle("GET", pattern_SubscriptionService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_SubscriptionService_List_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_List_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_SubscriptionService_Filter_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_SubscriptionService_Filter_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_Filter_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_SubscriptionService_Subscribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_SubscriptionService_Subscribe_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_Subscribe_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_SubscriptionService_Unsubscribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_SubscriptionService_Unsubscribe_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_Unsubscribe_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +// RegisterSubscriptionServiceHandlerFromEndpoint is same as RegisterSubscriptionServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterSubscriptionServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterSubscriptionServiceHandler(ctx, mux, conn) +} + +// RegisterSubscriptionServiceHandler registers the http handlers for service SubscriptionService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterSubscriptionServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterSubscriptionServiceHandlerClient(ctx, mux, NewSubscriptionServiceClient(conn)) +} + +// RegisterSubscriptionServiceHandlerClient registers the http handlers for service SubscriptionService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "SubscriptionServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "SubscriptionServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "SubscriptionServiceClient" to call the correct interceptors. +func RegisterSubscriptionServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client SubscriptionServiceClient) error { + + mux.Handle("GET", pattern_SubscriptionService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_SubscriptionService_List_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_List_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_SubscriptionService_Filter_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_SubscriptionService_Filter_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_Filter_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_SubscriptionService_Subscribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_SubscriptionService_Subscribe_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_Subscribe_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_SubscriptionService_Unsubscribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_SubscriptionService_Unsubscribe_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_SubscriptionService_Unsubscribe_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_SubscriptionService_List_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "subscriptions"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_SubscriptionService_Filter_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "filter_subscriptions"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_SubscriptionService_Subscribe_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "subscribe"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_SubscriptionService_Unsubscribe_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "unsubscribe"}, "", runtime.AssumeColonVerbOpt(true))) +) + +var ( + forward_SubscriptionService_List_0 = runtime.ForwardResponseMessage + + forward_SubscriptionService_Filter_0 = runtime.ForwardResponseMessage + + forward_SubscriptionService_Subscribe_0 = runtime.ForwardResponseMessage + + forward_SubscriptionService_Unsubscribe_0 = runtime.ForwardResponseMessage +) diff --git a/internal/hummingbird/mqttbroker/plugin/admin/subscription_grpc.pb.go b/internal/hummingbird/mqttbroker/plugin/admin/subscription_grpc.pb.go new file mode 100644 index 0000000..b8c80d7 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/subscription_grpc.pb.go @@ -0,0 +1,215 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package admin + +import ( + context "context" + + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion7 + +// SubscriptionServiceClient is the client API for SubscriptionService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type SubscriptionServiceClient interface { + // List subscriptions. + List(ctx context.Context, in *ListSubscriptionRequest, opts ...grpc.CallOption) (*ListSubscriptionResponse, error) + // Filter subscriptions, paging is not supported in this API. + Filter(ctx context.Context, in *FilterSubscriptionRequest, opts ...grpc.CallOption) (*FilterSubscriptionResponse, error) + // Subscribe topics for the client. + Subscribe(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (*SubscribeResponse, error) + // Unsubscribe topics for the client. + Unsubscribe(ctx context.Context, in *UnsubscribeRequest, opts ...grpc.CallOption) (*empty.Empty, error) +} + +type subscriptionServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewSubscriptionServiceClient(cc grpc.ClientConnInterface) SubscriptionServiceClient { + return &subscriptionServiceClient{cc} +} + +func (c *subscriptionServiceClient) List(ctx context.Context, in *ListSubscriptionRequest, opts ...grpc.CallOption) (*ListSubscriptionResponse, error) { + out := new(ListSubscriptionResponse) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.SubscriptionService/List", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *subscriptionServiceClient) Filter(ctx context.Context, in *FilterSubscriptionRequest, opts ...grpc.CallOption) (*FilterSubscriptionResponse, error) { + out := new(FilterSubscriptionResponse) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.SubscriptionService/Filter", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *subscriptionServiceClient) Subscribe(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (*SubscribeResponse, error) { + out := new(SubscribeResponse) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.SubscriptionService/Subscribe", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *subscriptionServiceClient) Unsubscribe(ctx context.Context, in *UnsubscribeRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.admin.api.SubscriptionService/Unsubscribe", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// SubscriptionServiceServer is the server API for SubscriptionService service. +// All implementations must embed UnimplementedSubscriptionServiceServer +// for forward compatibility +type SubscriptionServiceServer interface { + // List subscriptions. + List(context.Context, *ListSubscriptionRequest) (*ListSubscriptionResponse, error) + // Filter subscriptions, paging is not supported in this API. + Filter(context.Context, *FilterSubscriptionRequest) (*FilterSubscriptionResponse, error) + // Subscribe topics for the client. + Subscribe(context.Context, *SubscribeRequest) (*SubscribeResponse, error) + // Unsubscribe topics for the client. + Unsubscribe(context.Context, *UnsubscribeRequest) (*empty.Empty, error) + mustEmbedUnimplementedSubscriptionServiceServer() +} + +// UnimplementedSubscriptionServiceServer must be embedded to have forward compatible implementations. +type UnimplementedSubscriptionServiceServer struct { +} + +func (UnimplementedSubscriptionServiceServer) List(context.Context, *ListSubscriptionRequest) (*ListSubscriptionResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method List not implemented") +} +func (UnimplementedSubscriptionServiceServer) Filter(context.Context, *FilterSubscriptionRequest) (*FilterSubscriptionResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Filter not implemented") +} +func (UnimplementedSubscriptionServiceServer) Subscribe(context.Context, *SubscribeRequest) (*SubscribeResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Subscribe not implemented") +} +func (UnimplementedSubscriptionServiceServer) Unsubscribe(context.Context, *UnsubscribeRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Unsubscribe not implemented") +} +func (UnimplementedSubscriptionServiceServer) mustEmbedUnimplementedSubscriptionServiceServer() {} + +// UnsafeSubscriptionServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to SubscriptionServiceServer will +// result in compilation errors. +type UnsafeSubscriptionServiceServer interface { + mustEmbedUnimplementedSubscriptionServiceServer() +} + +func RegisterSubscriptionServiceServer(s grpc.ServiceRegistrar, srv SubscriptionServiceServer) { + s.RegisterService(&_SubscriptionService_serviceDesc, srv) +} + +func _SubscriptionService_List_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListSubscriptionRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SubscriptionServiceServer).List(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.SubscriptionService/List", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SubscriptionServiceServer).List(ctx, req.(*ListSubscriptionRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _SubscriptionService_Filter_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(FilterSubscriptionRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SubscriptionServiceServer).Filter(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.SubscriptionService/Filter", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SubscriptionServiceServer).Filter(ctx, req.(*FilterSubscriptionRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _SubscriptionService_Subscribe_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SubscribeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SubscriptionServiceServer).Subscribe(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.SubscriptionService/Subscribe", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SubscriptionServiceServer).Subscribe(ctx, req.(*SubscribeRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _SubscriptionService_Unsubscribe_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(UnsubscribeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(SubscriptionServiceServer).Unsubscribe(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.admin.api.SubscriptionService/Unsubscribe", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(SubscriptionServiceServer).Unsubscribe(ctx, req.(*UnsubscribeRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _SubscriptionService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "gmqtt.admin.api.SubscriptionService", + HandlerType: (*SubscriptionServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "List", + Handler: _SubscriptionService_List_Handler, + }, + { + MethodName: "Filter", + Handler: _SubscriptionService_Filter_Handler, + }, + { + MethodName: "Subscribe", + Handler: _SubscriptionService_Subscribe_Handler, + }, + { + MethodName: "Unsubscribe", + Handler: _SubscriptionService_Unsubscribe_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "subscription.proto", +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/subscription_test.go b/internal/hummingbird/mqttbroker/plugin/admin/subscription_test.go new file mode 100644 index 0000000..dd82d65 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/subscription_test.go @@ -0,0 +1,390 @@ +package admin + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +func TestSubscriptionService_List(t *testing.T) { + + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := server.NewMockSubscriptionService(ctrl) + client := server.NewMockClient(ctrl) + admin := &Admin{ + store: newStore(nil, mockConfig), + } + sub := &subscriptionService{ + a: admin, + } + client.EXPECT().ClientOptions().Return(&server.ClientOptions{ClientID: "id"}) + subscribe := admin.OnSubscribedWrapper(func(ctx context.Context, client server.Client, subscription *gmqtt.Subscription) {}) + + subsc := &gmqtt.Subscription{ + ShareName: "abc", + TopicFilter: "t", + ID: 1, + QoS: 2, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 2, + } + subscribe(context.Background(), client, subsc) + sub.a.store.subscriptionService = ss + + resp, err := sub.List(context.Background(), &ListSubscriptionRequest{ + PageSize: 0, + Page: 0, + }) + + a.Nil(err) + a.Len(resp.Subscriptions, 1) + rs := resp.Subscriptions[0] + a.EqualValues(subsc.QoS, rs.Qos) + a.EqualValues(subsc.GetFullTopicName(), rs.TopicName) + a.EqualValues(subsc.ID, rs.Id) + a.EqualValues(subsc.RetainHandling, rs.RetainHandling) + a.EqualValues(subsc.NoLocal, rs.NoLocal) +} + +func TestSubscriptionService_Filter(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := server.NewMockSubscriptionService(ctrl) + admin := &Admin{ + store: newStore(nil, mockConfig), + } + sub := &subscriptionService{ + a: admin, + } + sub.a.store.subscriptionService = ss + + ss.EXPECT().Iterate(gomock.Any(), subscription.IterationOptions{ + Type: subscription.TypeAll, + ClientID: "cid", + TopicName: "abc", + MatchType: subscription.MatchName, + }) + + _, err := sub.Filter(context.Background(), &FilterSubscriptionRequest{ + ClientId: "cid", + FilterType: "1,2,3", + MatchType: SubMatchType_SUB_MATCH_TYPE_MATCH_NAME, + TopicName: "abc", + Limit: 1, + }) + a.Nil(err) + + ss.EXPECT().Iterate(gomock.Any(), subscription.IterationOptions{ + Type: subscription.TypeAll, + ClientID: "cid", + TopicName: "abc", + MatchType: subscription.MatchName, + }) + + // test default filter type + _, err = sub.Filter(context.Background(), &FilterSubscriptionRequest{ + ClientId: "cid", + FilterType: "", + MatchType: SubMatchType_SUB_MATCH_TYPE_MATCH_NAME, + TopicName: "abc", + Limit: 1, + }) + a.Nil(err) + + ss.EXPECT().Iterate(gomock.Any(), subscription.IterationOptions{ + Type: subscription.TypeNonShared | subscription.TypeSYS, + ClientID: "cid", + TopicName: "abc", + MatchType: subscription.MatchName, + }) + + _, err = sub.Filter(context.Background(), &FilterSubscriptionRequest{ + ClientId: "cid", + FilterType: "1,3", + MatchType: SubMatchType_SUB_MATCH_TYPE_MATCH_NAME, + TopicName: "abc", + Limit: 1, + }) + a.Nil(err) + + ss.EXPECT().Iterate(gomock.Any(), subscription.IterationOptions{ + Type: subscription.TypeNonShared | subscription.TypeSYS, + ClientID: "cid", + TopicName: "abc", + MatchType: subscription.MatchFilter, + }) + + _, err = sub.Filter(context.Background(), &FilterSubscriptionRequest{ + ClientId: "cid", + FilterType: "1,3", + MatchType: SubMatchType_SUB_MATCH_TYPE_MATCH_FILTER, + TopicName: "abc", + Limit: 1, + }) + a.Nil(err) + +} + +func TestSubscriptionService_Filter_InvalidArgument(t *testing.T) { + var tt = []struct { + name string + field string + req *FilterSubscriptionRequest + }{ + { + name: "empty_topic_name_with_match_name", + field: "match_type", + req: &FilterSubscriptionRequest{ + ClientId: "", + FilterType: "", + MatchType: SubMatchType_SUB_MATCH_TYPE_MATCH_NAME, + TopicName: "", + Limit: 1, + }, + }, + { + name: "empty_topic_name_with_match_filter", + field: "match_type", + req: &FilterSubscriptionRequest{ + ClientId: "", + FilterType: "", + MatchType: SubMatchType_SUB_MATCH_TYPE_MATCH_FILTER, + TopicName: "", + Limit: 1, + }, + }, + { + name: "invalid_topic_name", + field: "topic_name", + req: &FilterSubscriptionRequest{ + ClientId: "", + FilterType: "", + MatchType: 0, + TopicName: "##", + Limit: 1, + }, + }, + } + for _, v := range tt { + t.Run(v.name, func(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := server.NewMockSubscriptionService(ctrl) + admin := &Admin{ + store: newStore(nil, mockConfig), + } + sub := &subscriptionService{ + a: admin, + } + sub.a.store.subscriptionService = ss + + _, err := sub.Filter(context.Background(), v.req) + a.NotNil(err) + a.Contains(err.Error(), v.field) + + }) + } + +} + +func TestSubscriptionService_Subscribe(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := server.NewMockSubscriptionService(ctrl) + admin := &Admin{ + store: newStore(nil, mockConfig), + } + sub := &subscriptionService{ + a: admin, + } + sub.a.store.subscriptionService = ss + + subs := []*Subscription{ + { + TopicName: "$share/a/b", + Id: 1, + Qos: 2, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 2, + }, { + TopicName: "abc", + Id: 1, + Qos: 2, + NoLocal: true, + RetainAsPublished: true, + RetainHandling: 2, + }, + } + var expectedSubs []*gmqtt.Subscription + for _, v := range subs { + shareName, filter := subscription.SplitTopic(v.TopicName) + s := &gmqtt.Subscription{ + ShareName: shareName, + TopicFilter: filter, + ID: v.Id, + QoS: byte(v.Qos), + NoLocal: v.NoLocal, + RetainAsPublished: v.RetainAsPublished, + RetainHandling: byte(v.RetainHandling), + } + expectedSubs = append(expectedSubs, s) + } + + ss.EXPECT().Subscribe("cid", expectedSubs).Return(subscription.SubscribeResult{ + { + AlreadyExisted: true, + }, { + AlreadyExisted: false, + }, + }, nil) + + resp, err := sub.Subscribe(context.Background(), &SubscribeRequest{ + ClientId: "cid", + Subscriptions: subs, + }) + a.Nil(err) + resp.New = []bool{false, true} + +} + +func TestSubscriptionService_Subscribe_InvalidArgument(t *testing.T) { + var tt = []struct { + name string + req *SubscribeRequest + }{ + { + name: "empty_client_id", + req: &SubscribeRequest{ + ClientId: "", + Subscriptions: nil, + }, + }, + { + name: "empty_subscriptions", + req: &SubscribeRequest{ + ClientId: "cid", + }, + }, + { + name: "invalid_subscriptions", + req: &SubscribeRequest{ + ClientId: "cid", + Subscriptions: []*Subscription{ + { + TopicName: "##", + Id: 0, + Qos: 0, + NoLocal: false, + RetainAsPublished: false, + RetainHandling: 0, + }, + }, + }, + }, + } + for _, v := range tt { + t.Run(v.name, func(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := server.NewMockSubscriptionService(ctrl) + admin := &Admin{ + store: newStore(nil, mockConfig), + } + sub := &subscriptionService{ + a: admin, + } + sub.a.store.subscriptionService = ss + + _, err := sub.Subscribe(context.Background(), v.req) + a.NotNil(err) + }) + } + +} + +func TestSubscriptionService_Unsubscribe(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := server.NewMockSubscriptionService(ctrl) + admin := &Admin{ + store: newStore(nil, mockConfig), + } + sub := &subscriptionService{ + a: admin, + } + sub.a.store.subscriptionService = ss + + topics := []string{ + "a", "b", + } + ss.EXPECT().Unsubscribe("cid", topics) + _, err := sub.Unsubscribe(context.Background(), &UnsubscribeRequest{ + ClientId: "cid", + Topics: topics, + }) + a.Nil(err) + +} + +func TestSubscriptionService_Unsubscribe_InvalidArgument(t *testing.T) { + var tt = []struct { + name string + req *UnsubscribeRequest + }{ + { + name: "empty_client_id", + req: &UnsubscribeRequest{ + ClientId: "", + Topics: nil, + }, + }, + { + name: "invalid_topic_name", + req: &UnsubscribeRequest{ + ClientId: "cid", + Topics: []string{"+", "##"}, + }, + }, + } + for _, v := range tt { + t.Run(v.name, func(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := server.NewMockSubscriptionService(ctrl) + admin := &Admin{ + store: newStore(nil, mockConfig), + } + sub := &subscriptionService{ + a: admin, + } + sub.a.store.subscriptionService = ss + + _, err := sub.Unsubscribe(context.Background(), v.req) + a.NotNil(err) + }) + } + +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/swagger/client.swagger.json b/internal/hummingbird/mqttbroker/plugin/admin/swagger/client.swagger.json new file mode 100644 index 0000000..21cc4c0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/swagger/client.swagger.json @@ -0,0 +1,260 @@ +{ + "swagger": "2.0", + "info": { + "title": "client.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": { + "/v1/clients": { + "get": { + "summary": "List clients", + "operationId": "List", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiListClientResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "page_size", + "in": "query", + "required": false, + "type": "integer", + "format": "int64" + }, + { + "name": "page", + "in": "query", + "required": false, + "type": "integer", + "format": "int64" + } + ], + "tags": [ + "ClientService" + ] + } + }, + "/v1/clients/{client_id}": { + "get": { + "summary": "Get the client for given client id.\nReturn NotFound error when client not found.", + "operationId": "Get", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiGetClientResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "client_id", + "in": "path", + "required": true, + "type": "string" + } + ], + "tags": [ + "ClientService" + ] + }, + "delete": { + "summary": "Disconnect the client for given client id.", + "operationId": "Delete", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "client_id", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "clean_session", + "in": "query", + "required": false, + "type": "boolean", + "format": "boolean" + } + ], + "tags": [ + "ClientService" + ] + } + } + }, + "definitions": { + "apiClient": { + "type": "object", + "properties": { + "client_id": { + "type": "string" + }, + "username": { + "type": "string" + }, + "keep_alive": { + "type": "integer", + "format": "int32" + }, + "version": { + "type": "integer", + "format": "int32" + }, + "remote_addr": { + "type": "string" + }, + "local_addr": { + "type": "string" + }, + "connected_at": { + "type": "string", + "format": "date-time" + }, + "disconnected_at": { + "type": "string", + "format": "date-time" + }, + "session_expiry": { + "type": "integer", + "format": "int64" + }, + "max_inflight": { + "type": "integer", + "format": "int64" + }, + "inflight_len": { + "type": "integer", + "format": "int64" + }, + "max_queue": { + "type": "integer", + "format": "int64" + }, + "queue_len": { + "type": "integer", + "format": "int64" + }, + "subscriptions_current": { + "type": "integer", + "format": "int64" + }, + "subscriptions_total": { + "type": "integer", + "format": "int64" + }, + "packets_received_bytes": { + "type": "string", + "format": "uint64" + }, + "packets_received_nums": { + "type": "string", + "format": "uint64" + }, + "packets_send_bytes": { + "type": "string", + "format": "uint64" + }, + "packets_send_nums": { + "type": "string", + "format": "uint64" + }, + "message_dropped": { + "type": "string", + "format": "uint64" + } + } + }, + "apiGetClientResponse": { + "type": "object", + "properties": { + "client": { + "$ref": "#/definitions/apiClient" + } + } + }, + "apiListClientResponse": { + "type": "object", + "properties": { + "clients": { + "type": "array", + "items": { + "$ref": "#/definitions/apiClient" + } + }, + "total_count": { + "type": "integer", + "format": "int64" + } + } + }, + "protobufAny": { + "type": "object", + "properties": { + "type_url": { + "type": "string" + }, + "value": { + "type": "string", + "format": "byte" + } + } + }, + "runtimeError": { + "type": "object", + "properties": { + "error": { + "type": "string" + }, + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "$ref": "#/definitions/protobufAny" + } + } + } + } + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/swagger/publish.swagger.json b/internal/hummingbird/mqttbroker/plugin/admin/swagger/publish.swagger.json new file mode 100644 index 0000000..2a4f776 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/swagger/publish.swagger.json @@ -0,0 +1,139 @@ +{ + "swagger": "2.0", + "info": { + "title": "publish.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": { + "/v1/publish": { + "post": { + "summary": "Publish message to broker", + "operationId": "Publish", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/apiPublishRequest" + } + } + ], + "tags": [ + "PublishService" + ] + } + } + }, + "definitions": { + "apiPublishRequest": { + "type": "object", + "properties": { + "topic_name": { + "type": "string" + }, + "payload": { + "type": "string" + }, + "qos": { + "type": "integer", + "format": "int64" + }, + "retained": { + "type": "boolean", + "format": "boolean" + }, + "content_type": { + "type": "string", + "description": "the following fields are using in v5 client." + }, + "correlation_data": { + "type": "string" + }, + "message_expiry": { + "type": "integer", + "format": "int64" + }, + "payload_format": { + "type": "integer", + "format": "int64" + }, + "response_topic": { + "type": "string" + }, + "user_properties": { + "type": "array", + "items": { + "$ref": "#/definitions/apiUserProperties" + } + } + } + }, + "apiUserProperties": { + "type": "object", + "properties": { + "K": { + "type": "string", + "format": "byte" + }, + "V": { + "type": "string", + "format": "byte" + } + } + }, + "protobufAny": { + "type": "object", + "properties": { + "type_url": { + "type": "string" + }, + "value": { + "type": "string", + "format": "byte" + } + } + }, + "runtimeError": { + "type": "object", + "properties": { + "error": { + "type": "string" + }, + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "$ref": "#/definitions/protobufAny" + } + } + } + } + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/swagger/subscription.swagger.json b/internal/hummingbird/mqttbroker/plugin/admin/swagger/subscription.swagger.json new file mode 100644 index 0000000..496a5dd --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/swagger/subscription.swagger.json @@ -0,0 +1,329 @@ +{ + "swagger": "2.0", + "info": { + "title": "subscription.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": { + "/v1/filter_subscriptions": { + "get": { + "summary": "Filter subscriptions, paging is not supported in this API.", + "operationId": "Filter", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiFilterSubscriptionResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "client_id", + "description": "If set, only filter the subscriptions that belongs to the client.", + "in": "query", + "required": false, + "type": "string" + }, + { + "name": "filter_type", + "description": "filter_type indicates what kinds of topics are going to filter.\nIf there are multiple types, use ',' to separate. e.g : 1,2\nThere are 3 kinds of topic can be filtered, defined by SubFilterType:\n1 = System Topic(begin with '$')\n2 = Shared Topic\n3 = NonShared Topic.", + "in": "query", + "required": false, + "type": "string" + }, + { + "name": "match_type", + "description": "If 1 (SUB_MATCH_TYPE_MATCH_NAME), the server will return subscriptions which has the same topic name with request topic_name.\nIf 2 (SUB_MATCH_TYPE_MATCH_FILTER),the server will return subscriptions which match the request topic_name .\nmatch_type must be set when filter_type is not empty.", + "in": "query", + "required": false, + "type": "string", + "enum": [ + "SUB_MATCH_TYPE_MATCH_UNSPECIFIED", + "SUB_MATCH_TYPE_MATCH_NAME", + "SUB_MATCH_TYPE_MATCH_FILTER" + ], + "default": "SUB_MATCH_TYPE_MATCH_UNSPECIFIED" + }, + { + "name": "topic_name", + "description": "topic_name must be set when match_type is not zero.", + "in": "query", + "required": false, + "type": "string" + }, + { + "name": "limit", + "description": "The maximum subscriptions can be returned.", + "in": "query", + "required": false, + "type": "integer", + "format": "int32" + } + ], + "tags": [ + "SubscriptionService" + ] + } + }, + "/v1/subscribe": { + "post": { + "summary": "Subscribe topics for the client.", + "operationId": "Subscribe", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiSubscribeResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/apiSubscribeRequest" + } + } + ], + "tags": [ + "SubscriptionService" + ] + } + }, + "/v1/subscriptions": { + "get": { + "summary": "List subscriptions.", + "operationId": "List", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiListSubscriptionResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "page_size", + "in": "query", + "required": false, + "type": "integer", + "format": "int64" + }, + { + "name": "page", + "in": "query", + "required": false, + "type": "integer", + "format": "int64" + } + ], + "tags": [ + "SubscriptionService" + ] + } + }, + "/v1/unsubscribe": { + "post": { + "summary": "Unsubscribe topics for the client.", + "operationId": "Unsubscribe", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/apiUnsubscribeRequest" + } + } + ], + "tags": [ + "SubscriptionService" + ] + } + } + }, + "definitions": { + "apiFilterSubscriptionResponse": { + "type": "object", + "properties": { + "subscriptions": { + "type": "array", + "items": { + "$ref": "#/definitions/apiSubscription" + } + } + } + }, + "apiListSubscriptionResponse": { + "type": "object", + "properties": { + "subscriptions": { + "type": "array", + "items": { + "$ref": "#/definitions/apiSubscription" + } + }, + "total_count": { + "type": "integer", + "format": "int64" + } + } + }, + "apiSubMatchType": { + "type": "string", + "enum": [ + "SUB_MATCH_TYPE_MATCH_UNSPECIFIED", + "SUB_MATCH_TYPE_MATCH_NAME", + "SUB_MATCH_TYPE_MATCH_FILTER" + ], + "default": "SUB_MATCH_TYPE_MATCH_UNSPECIFIED" + }, + "apiSubscribeRequest": { + "type": "object", + "properties": { + "client_id": { + "type": "string" + }, + "subscriptions": { + "type": "array", + "items": { + "$ref": "#/definitions/apiSubscription" + } + } + } + }, + "apiSubscribeResponse": { + "type": "object", + "properties": { + "new": { + "type": "array", + "items": { + "type": "boolean", + "format": "boolean" + }, + "description": "indicates whether it is a new subscription or the subscription is already existed." + } + } + }, + "apiSubscription": { + "type": "object", + "properties": { + "topic_name": { + "type": "string" + }, + "id": { + "type": "integer", + "format": "int64" + }, + "qos": { + "type": "integer", + "format": "int64" + }, + "no_local": { + "type": "boolean", + "format": "boolean" + }, + "retain_as_published": { + "type": "boolean", + "format": "boolean" + }, + "retain_handling": { + "type": "integer", + "format": "int64" + }, + "client_id": { + "type": "string" + } + } + }, + "apiUnsubscribeRequest": { + "type": "object", + "properties": { + "client_id": { + "type": "string" + }, + "topics": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "protobufAny": { + "type": "object", + "properties": { + "type_url": { + "type": "string" + }, + "value": { + "type": "string", + "format": "byte" + } + } + }, + "runtimeError": { + "type": "object", + "properties": { + "error": { + "type": "string" + }, + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "$ref": "#/definitions/protobufAny" + } + } + } + } + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/utils.go b/internal/hummingbird/mqttbroker/plugin/admin/utils.go new file mode 100644 index 0000000..d931311 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/utils.go @@ -0,0 +1,109 @@ +package admin + +import ( + "container/list" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// ErrNotFound represents a not found error. +var ErrNotFound = status.Error(codes.NotFound, "not found") + +// Indexer provides a index for a ordered list that supports queries in O(1). +// All methods are not concurrency-safe. +type Indexer struct { + index map[string]*list.Element + rows *list.List +} + +// NewIndexer is the constructor of Indexer. +func NewIndexer() *Indexer { + return &Indexer{ + index: make(map[string]*list.Element), + rows: list.New(), + } +} + +// Set sets the value for the id. +func (i *Indexer) Set(id string, value interface{}) { + if e, ok := i.index[id]; ok { + e.Value = value + } else { + elem := i.rows.PushBack(value) + i.index[id] = elem + } +} + +// Remove removes and returns the value for the given id. +// Return nil if not found. +func (i *Indexer) Remove(id string) *list.Element { + elem := i.index[id] + if elem != nil { + i.rows.Remove(elem) + } + delete(i.index, id) + return elem +} + +// GetByID returns the value for the given id. +// Return nil if not found. +// Notice: Any access to the return *list.Element also require the mutex, +// because the Set method can modify the Value for *list.Element when updating the Value for the same id. +// If the caller needs the Value in *list.Element, it must get the Value before the next Set is called. +func (i *Indexer) GetByID(id string) *list.Element { + return i.index[id] +} + +// Iterate iterates at most n elements in the list begin from offset. +// Notice: Any access to the *list.Element in fn also require the mutex, +// because the Set method can modify the Value for *list.Element when updating the Value for the same id. +// If the caller needs the Value in *list.Element, it must get the Value before the next Set is called. +func (i *Indexer) Iterate(fn func(elem *list.Element), offset, n uint) { + if i.rows.Len() < int(offset) { + return + } + var j uint + for e := i.rows.Front(); e != nil; e = e.Next() { + if j >= offset && j < offset+n { + fn(e) + } + if j == offset+n { + break + } + j++ + } +} + +// Len returns the length of list. +func (i *Indexer) Len() int { + return i.rows.Len() +} + +// GetPage gets page and pageSize from request params. +func GetPage(reqPage, reqPageSize uint32) (page, pageSize uint) { + page = 1 + pageSize = 20 + if reqPage != 0 { + page = uint(reqPage) + } + if reqPageSize != 0 { + pageSize = uint(reqPageSize) + } + return +} + +func GetOffsetN(page, pageSize uint) (offset, n uint) { + offset = (page - 1) * pageSize + n = pageSize + return +} + +// ErrInvalidArgument is a wrapper function for easier invalid argument error handling. +func ErrInvalidArgument(name string, msg string) error { + errString := "invalid " + name + if msg != "" { + errString = errString + ":" + msg + } + return status.Error(codes.InvalidArgument, errString) +} diff --git a/internal/hummingbird/mqttbroker/plugin/admin/utils_test.go b/internal/hummingbird/mqttbroker/plugin/admin/utils_test.go new file mode 100644 index 0000000..44415c9 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/admin/utils_test.go @@ -0,0 +1,37 @@ +package admin + +import ( + "container/list" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIndexer(t *testing.T) { + a := assert.New(t) + i := NewIndexer() + for j := 0; j < 100; j++ { + i.Set(strconv.Itoa(j), j) + a.EqualValues(j, i.GetByID(strconv.Itoa(j)).Value) + } + a.EqualValues(100, i.Len()) + + var jj int + i.Iterate(func(elem *list.Element) { + v := elem.Value.(int) + a.Equal(jj, v) + jj++ + }, 0, uint(i.Len())) + + e := i.Remove("5") + a.Equal(5, e.Value.(int)) + + var rs []int + i.Iterate(func(elem *list.Element) { + rs = append(rs, elem.Value.(int)) + }, 4, 2) + // 5 is removed + a.Equal([]int{4, 6}, rs) + +} diff --git a/internal/hummingbird/mqttbroker/plugin/aplugin/aplugin.go b/internal/hummingbird/mqttbroker/plugin/aplugin/aplugin.go new file mode 100644 index 0000000..eee1f56 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/aplugin/aplugin.go @@ -0,0 +1,177 @@ +package aplugin + +import ( + "context" + "crypto/md5" + "encoding/hex" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/aplugin/snowflake" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "go.uber.org/zap" +) + +var _ server.Plugin = (*APlugin)(nil) + +const Name = "aplugin" + +func init() { + server.RegisterPlugin(Name, New) + config.RegisterDefaultPluginConfig(Name, &DefaultConfig) +} + +func New(config config.Config) (server.Plugin, error) { + return newAPlugin() +} + +type APlugin struct { + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + node *snowflake.Node + log *zap.SugaredLogger + ackMap sync.Map // async ack + driverClients map[string]*DriverClient // driver map, key is username + publisher server.Publisher + publishChan chan PublishInfo +} + +func newAPlugin() (*APlugin, error) { + node, err := snowflake.NewNode(1) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) + return &APlugin{ + node: node, + ctx: ctx, + cancel: cancel, + wg: &sync.WaitGroup{}, + log: zap.NewNop().Sugar(), + driverClients: make(map[string]*DriverClient), + publishChan: make(chan PublishInfo, 32), + }, nil +} + +func (t *APlugin) Load(service server.Server) error { + t.log = server.LoggerWithField(zap.String("plugin", Name)).Sugar() + t.publisher = service.Publisher() + t.wg.Add(1) + go func() { + defer t.wg.Done() + for { + select { + case <-t.ctx.Done(): + t.log.Infof("plugin(%s) exit", t.Name()) + return + case msg := <-t.publishChan: + t.log.Infof("publish msg: topic: %s, payload: %s", msg.Topic, string(msg.Payload)) + t.publisher.Publish(&mqttbroker.Message{ + QoS: 1, + Topic: msg.Topic, + Payload: msg.Payload, + }) + } + } + }() + return nil +} + +func (t *APlugin) Unload() error { + t.cancel() + t.wg.Wait() + return nil +} + +func (t *APlugin) Name() string { + return Name +} + +func (t *APlugin) genAckChan(id int64) *MsgAckChan { + ack := &MsgAckChan{ + Id: id, + DataChan: make(chan interface{}, 1), + } + t.ackMap.Store(id, ack) + return ack +} + +func (t *APlugin) publishWithAckMsg(id int64, topic string, tp int, msg interface{}) (*MsgAckChan, error) { + payload, err := json.Marshal(msg) + if err != nil { + return nil, err + } + buff, err := json.Marshal(AsyncMsg{ + Id: id, + Type: tp, + Data: payload, + }) + if err != nil { + return nil, err + } + ackChan := t.genAckChan(id) + select { + case <-time.After(time.Second): + t.ackMap.Delete(id) + return nil, errors.New("send auth msg to publish chan timeout") + case t.publishChan <- PublishInfo{ + Topic: topic, + Payload: buff, + }: + return ackChan, nil + } +} + +func (t *APlugin) publishNotifyMsg(id int64, topic string, tp int, msg interface{}) error { + payload, err := json.Marshal(msg) + if err != nil { + return err + } + buff, err := json.Marshal(AsyncMsg{ + Id: id, + Type: tp, + Data: payload, + }) + if err != nil { + return err + } + select { + case <-time.After(time.Second): + return errors.New("send auth msg to publish chan timeout") + case t.publishChan <- PublishInfo{ + Topic: topic, + Payload: buff, + }: + return nil + } +} + +func (t *APlugin) validate(username, password string) error { + //t.log.Debugf("got clientId: %s, username: %s, password: %s", clientId, username, password) + passwd, err := md5GenPasswd(username) + if err != nil { + return err + } + if passwd[8:24] != password { + return errors.New("auth failure") + } + return nil +} + +func md5GenPasswd(username string) (string, error) { + h := md5.New() + _, err := h.Write([]byte(username)) + if err != nil { + return "", err + } + rs := h.Sum(nil) + return hex.EncodeToString(rs), nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/aplugin/config.go b/internal/hummingbird/mqttbroker/plugin/aplugin/config.go new file mode 100644 index 0000000..7f1cbd3 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/aplugin/config.go @@ -0,0 +1,18 @@ +package aplugin + +// Config is the configuration for the aplugin plugin. +type Config struct { + // add your config fields +} + +// Validate validates the configuration, and return an error if it is invalid. +func (c *Config) Validate() error { + return nil +} + +// DefaultConfig is the default configuration. +var DefaultConfig = Config{} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + return nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/aplugin/datatypes.go b/internal/hummingbird/mqttbroker/plugin/aplugin/datatypes.go new file mode 100644 index 0000000..983dc82 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/aplugin/datatypes.go @@ -0,0 +1,192 @@ +package aplugin + +import ( + "encoding/json" + "sync" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +const ( + Auth = iota + 1 // 连接鉴权 + Sub // 设备订阅校验 + Pub // 设备发布校验 + UnSub + Connected + Closed +) + +type PublishInfo struct { + Topic string + Payload []byte +} + +type MsgAckChan struct { + Mu sync.Mutex + Id int64 + IsClosed bool + DataChan chan interface{} // auth ack, device sub ack, device pub ack +} + +func (mac *MsgAckChan) tryCloseChan() { + mac.Mu.Lock() + defer mac.Mu.Unlock() + if !mac.IsClosed { + close(mac.DataChan) + mac.IsClosed = true + } +} + +func (mac *MsgAckChan) trySendDataAndCloseChan(data interface{}) bool { + mac.Mu.Lock() + defer mac.Mu.Unlock() + if !mac.IsClosed { + mac.DataChan <- data + close(mac.DataChan) + mac.IsClosed = true + return true + } + return false +} + +type ( + // AsyncMsg 异步消息统一收发 + AsyncMsg struct { + Id int64 + Type int // 1:连接鉴权,2:设备订阅校验,3:设备发布校验,4:unsub,6:closed + Data json.RawMessage // auth ack sub ack pub ack + } +) + +type ( + AuthCheck struct { + ClientId string + Username string + Password string + Pass bool + Msg string + } +) + +type PubTopic struct { + ClientId string + Username string + Topic string + QoS byte + Retained bool + Pass bool + Msg string +} + +// SubTopic 三方设备或服务订阅topic校验 +type ( + SubTopic struct { + Topic string + QoS byte + Pass bool + Msg string + } + SubTopics struct { + ClientId string + Username string + Topics []SubTopic + } +) + +type ( + // ConnectedNotify 三方设备或服务连接成功后通知对应驱动 + ConnectedNotify struct { + ClientId string + Username string + IP string + Port string + } + + // ClosedNotify 三方设备或服务断开连接后通知对应驱动 + ClosedNotify struct { + ClientId string + Username string + } + + UnSubNotify struct { + ClientId string + Username string + Topics []string + } +) + +type DriverClient struct { + mu sync.RWMutex + ClientId string + Username string + PubTopic string + SubTopic string + ClientMap map[string]*ThirdClient // key is clientId +} + +func (dc *DriverClient) AddThirdClient(client server.Client) { + dc.mu.Lock() + defer dc.mu.Unlock() + dc.ClientMap[client.ClientOptions().ClientID] = newThirdClient(client) +} + +func (dc *DriverClient) DeleteThirdClient(clientId string) { + dc.mu.Lock() + defer dc.mu.Unlock() + delete(dc.ClientMap, clientId) +} + +type ThirdClient struct { + mu sync.RWMutex + client server.Client + subs map[string]struct{} + pubs map[string]struct{} +} + +func (tc *ThirdClient) AddTopics(topics []string, t int) { + tc.mu.Lock() + defer tc.mu.Unlock() + + if t == Sub { + for i := range topics { + tc.subs[topics[i]] = struct{}{} + } + } else if t == Pub { + for i := range topics { + tc.pubs[topics[i]] = struct{}{} + } + } +} + +func (tc *ThirdClient) DeleteTopics(topics []string, t int) { + tc.mu.Lock() + defer tc.mu.Unlock() + + if t == UnSub { + for i := range topics { + delete(tc.subs, topics[i]) + } + } +} + +func (tc *ThirdClient) CheckTopic(topic string, t int) bool { + tc.mu.RLock() + defer tc.mu.RUnlock() + + if t == Sub { + _, ok := tc.subs[topic] + return ok + } else if t == Pub { + _, ok := tc.pubs[topic] + return ok + } + return false +} + +func newThirdClient(c server.Client) *ThirdClient { + return &ThirdClient{ + client: c, + subs: make(map[string]struct{}), + pubs: make(map[string]struct{}), + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/aplugin/hooks.go b/internal/hummingbird/mqttbroker/plugin/aplugin/hooks.go new file mode 100644 index 0000000..c507c61 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/aplugin/hooks.go @@ -0,0 +1,45 @@ +package aplugin + +import ( + "context" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +func (t *APlugin) HookWrapper() server.HookWrapper { + return server.HookWrapper{} +} + +func (t *APlugin) OnBasicAuthWrapper(pre server.OnBasicAuth) server.OnBasicAuth { + return func(ctx context.Context, client server.Client, req *server.ConnectRequest) (err error) { + return nil + } +} + +func (t *APlugin) OnSubscribeWrapper(pre server.OnSubscribe) server.OnSubscribe { + return func(ctx context.Context, client server.Client, req *server.SubscribeRequest) error { + return nil + } +} + +func (t *APlugin) OnUnsubscribeWrapper(pre server.OnUnsubscribe) server.OnUnsubscribe { + return func(ctx context.Context, client server.Client, req *server.UnsubscribeRequest) error { + return nil + } +} + +func (t *APlugin) OnMsgArrivedWrapper(pre server.OnMsgArrived) server.OnMsgArrived { + return func(ctx context.Context, client server.Client, req *server.MsgArrivedRequest) error { + return nil + } +} + +func (t *APlugin) OnConnectedWrapper(pre server.OnConnected) server.OnConnected { + return func(ctx context.Context, client server.Client) { + } +} + +func (t *APlugin) OnClosedWrapper(pre server.OnClosed) server.OnClosed { + return func(ctx context.Context, client server.Client, err error) { + + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/aplugin/snowflake/snowflake.go b/internal/hummingbird/mqttbroker/plugin/aplugin/snowflake/snowflake.go new file mode 100644 index 0000000..f46643f --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/aplugin/snowflake/snowflake.go @@ -0,0 +1,391 @@ +/* +Copyright (c) 2016, Bruce +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +// Package snowflake provides a very simple Twitter snowflake generator and parser. +package snowflake + +import ( + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "strconv" + "sync" + "time" +) + +var ( + // Epoch is set to the twitter snowflake epoch of Nov 04 2010 01:42:54 UTC in milliseconds + // You may customize this to set a different epoch for your application. + Epoch int64 = 1288834974657 + + // NodeBits holds the number of bits to use for Node + // Remember, you have a total 22 bits to share between Node/Step + NodeBits uint8 = 10 + + // StepBits holds the number of bits to use for Step + // Remember, you have a total 22 bits to share between Node/Step + StepBits uint8 = 12 + + // DEPRECATED: the below four variables will be removed in a future release. + mu sync.Mutex + nodeMax int64 = -1 ^ (-1 << NodeBits) + nodeMask = nodeMax << StepBits + stepMask int64 = -1 ^ (-1 << StepBits) + timeShift = NodeBits + StepBits + nodeShift = StepBits +) + +const encodeBase32Map = "ybndrfg8ejkmcpqxot1uwisza345h769" + +var decodeBase32Map [256]byte + +const encodeBase58Map = "123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ" + +var decodeBase58Map [256]byte + +// A JSONSyntaxError is returned from UnmarshalJSON if an invalid ID is provided. +type JSONSyntaxError struct{ original []byte } + +func (j JSONSyntaxError) Error() string { + return fmt.Sprintf("invalid snowflake ID %q", string(j.original)) +} + +// ErrInvalidBase58 is returned by ParseBase58 when given an invalid []byte +var ErrInvalidBase58 = errors.New("invalid base58") + +// ErrInvalidBase32 is returned by ParseBase32 when given an invalid []byte +var ErrInvalidBase32 = errors.New("invalid base32") + +// Create maps for decoding Base58/Base32. +// This speeds up the process tremendously. +func init() { + + for i := 0; i < len(decodeBase58Map); i++ { + decodeBase58Map[i] = 0xFF + } + + for i := 0; i < len(encodeBase58Map); i++ { + decodeBase58Map[encodeBase58Map[i]] = byte(i) + } + + for i := 0; i < len(decodeBase32Map); i++ { + decodeBase32Map[i] = 0xFF + } + + for i := 0; i < len(encodeBase32Map); i++ { + decodeBase32Map[encodeBase32Map[i]] = byte(i) + } +} + +// A Node struct holds the basic information needed for a snowflake generator +// node +type Node struct { + mu sync.Mutex + epoch time.Time + time int64 + node int64 + step int64 + + nodeMax int64 + nodeMask int64 + stepMask int64 + timeShift uint8 + nodeShift uint8 +} + +// An ID is a custom type used for a snowflake ID. This is used so we can +// attach methods onto the ID. +type ID int64 + +// NewNode returns a new snowflake node that can be used to generate snowflake +// IDs +func NewNode(node int64) (*Node, error) { + + // re-calc in case custom NodeBits or StepBits were set + // DEPRECATED: the below block will be removed in a future release. + mu.Lock() + nodeMax = -1 ^ (-1 << NodeBits) + nodeMask = nodeMax << StepBits + stepMask = -1 ^ (-1 << StepBits) + timeShift = NodeBits + StepBits + nodeShift = StepBits + mu.Unlock() + + n := Node{} + n.node = node + n.nodeMax = -1 ^ (-1 << NodeBits) + n.nodeMask = n.nodeMax << StepBits + n.stepMask = -1 ^ (-1 << StepBits) + n.timeShift = NodeBits + StepBits + n.nodeShift = StepBits + + if n.node < 0 || n.node > n.nodeMax { + return nil, errors.New("Node number must be between 0 and " + strconv.FormatInt(n.nodeMax, 10)) + } + + var curTime = time.Now() + // add time.Duration to curTime to make sure we use the monotonic clock if available + n.epoch = curTime.Add(time.Unix(Epoch/1000, (Epoch%1000)*1000000).Sub(curTime)) + + return &n, nil +} + +// Generate creates and returns a unique snowflake ID +// To help guarantee uniqueness +// - Make sure your system is keeping accurate system time +// - Make sure you never have multiple nodes running with the same node ID +func (n *Node) Generate() ID { + + n.mu.Lock() + + now := time.Since(n.epoch).Nanoseconds() / 1000000 + + if now == n.time { + n.step = (n.step + 1) & n.stepMask + + if n.step == 0 { + for now <= n.time { + now = time.Since(n.epoch).Nanoseconds() / 1000000 + } + } + } else { + n.step = 0 + } + + n.time = now + + r := ID((now)<= 32 { + b = append(b, encodeBase32Map[f%32]) + f /= 32 + } + b = append(b, encodeBase32Map[f]) + + for x, y := 0, len(b)-1; x < y; x, y = x+1, y-1 { + b[x], b[y] = b[y], b[x] + } + + return string(b) +} + +// ParseBase32 parses a base32 []byte into a snowflake ID +// NOTE: There are many different base32 implementations so becareful when +// doing any interoperation. +func ParseBase32(b []byte) (ID, error) { + + var id int64 + + for i := range b { + if decodeBase32Map[b[i]] == 0xFF { + return -1, ErrInvalidBase32 + } + id = id*32 + int64(decodeBase32Map[b[i]]) + } + + return ID(id), nil +} + +// Base36 returns a base36 string of the snowflake ID +func (f ID) Base36() string { + return strconv.FormatInt(int64(f), 36) +} + +// ParseBase36 converts a Base36 string into a snowflake ID +func ParseBase36(id string) (ID, error) { + i, err := strconv.ParseInt(id, 36, 64) + return ID(i), err +} + +// Base58 returns a base58 string of the snowflake ID +func (f ID) Base58() string { + + if f < 58 { + return string(encodeBase58Map[f]) + } + + b := make([]byte, 0, 11) + for f >= 58 { + b = append(b, encodeBase58Map[f%58]) + f /= 58 + } + b = append(b, encodeBase58Map[f]) + + for x, y := 0, len(b)-1; x < y; x, y = x+1, y-1 { + b[x], b[y] = b[y], b[x] + } + + return string(b) +} + +// ParseBase58 parses a base58 []byte into a snowflake ID +func ParseBase58(b []byte) (ID, error) { + + var id int64 + + for i := range b { + if decodeBase58Map[b[i]] == 0xFF { + return -1, ErrInvalidBase58 + } + id = id*58 + int64(decodeBase58Map[b[i]]) + } + + return ID(id), nil +} + +// Base64 returns a base64 string of the snowflake ID +func (f ID) Base64() string { + return base64.StdEncoding.EncodeToString(f.Bytes()) +} + +// ParseBase64 converts a base64 string into a snowflake ID +func ParseBase64(id string) (ID, error) { + b, err := base64.StdEncoding.DecodeString(id) + if err != nil { + return -1, err + } + return ParseBytes(b) + +} + +// Bytes returns a byte slice of the snowflake ID +func (f ID) Bytes() []byte { + return []byte(f.String()) +} + +// ParseBytes converts a byte slice into a snowflake ID +func ParseBytes(id []byte) (ID, error) { + i, err := strconv.ParseInt(string(id), 10, 64) + return ID(i), err +} + +// IntBytes returns an array of bytes of the snowflake ID, encoded as a +// big endian integer. +func (f ID) IntBytes() [8]byte { + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(f)) + return b +} + +// ParseIntBytes converts an array of bytes encoded as big endian integer as +// a snowflake ID +func ParseIntBytes(id [8]byte) ID { + return ID(int64(binary.BigEndian.Uint64(id[:]))) +} + +// Time returns an int64 unix timestamp in milliseconds of the snowflake ID time +// DEPRECATED: the below function will be removed in a future release. +func (f ID) Time() int64 { + return (int64(f) >> timeShift) + Epoch +} + +// Node returns an int64 of the snowflake ID node number +// DEPRECATED: the below function will be removed in a future release. +func (f ID) Node() int64 { + return int64(f) & nodeMask >> nodeShift +} + +// Step returns an int64 of the snowflake step (or sequence) number +// DEPRECATED: the below function will be removed in a future release. +func (f ID) Step() int64 { + return int64(f) & stepMask +} + +// MarshalJSON returns a json byte array string of the snowflake ID. +func (f ID) MarshalJSON() ([]byte, error) { + buff := make([]byte, 0, 22) + buff = append(buff, '"') + buff = strconv.AppendInt(buff, int64(f), 10) + buff = append(buff, '"') + return buff, nil +} + +// UnmarshalJSON converts a json byte array of a snowflake ID into an ID type. +func (f *ID) UnmarshalJSON(b []byte) error { + if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { + return JSONSyntaxError{b} + } + + i, err := strconv.ParseInt(string(b[1:len(b)-1]), 10, 64) + if err != nil { + return err + } + + *f = ID(i) + return nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/README.md b/internal/hummingbird/mqttbroker/plugin/auth/README.md new file mode 100644 index 0000000..41d8c3b --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/README.md @@ -0,0 +1,7 @@ +# Auth + +Auth plugin provides a simple username/password authentication mechanism. + +# API Doc + +See [swagger](https://github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/blob/master/plugin/auth/swagger) diff --git a/internal/hummingbird/mqttbroker/plugin/auth/account.pb.go b/internal/hummingbird/mqttbroker/plugin/auth/account.pb.go new file mode 100644 index 0000000..c50d3eb --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/account.pb.go @@ -0,0 +1,614 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.22.0 +// protoc v3.13.0 +// source: account.proto + +package auth + +import ( + reflect "reflect" + sync "sync" + + proto "github.com/golang/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type ListAccountsRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + PageSize uint32 `protobuf:"varint,1,opt,name=page_size,json=pageSize,proto3" json:"page_size,omitempty"` + Page uint32 `protobuf:"varint,2,opt,name=page,proto3" json:"page,omitempty"` +} + +func (x *ListAccountsRequest) Reset() { + *x = ListAccountsRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_account_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListAccountsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListAccountsRequest) ProtoMessage() {} + +func (x *ListAccountsRequest) ProtoReflect() protoreflect.Message { + mi := &file_account_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListAccountsRequest.ProtoReflect.Descriptor instead. +func (*ListAccountsRequest) Descriptor() ([]byte, []int) { + return file_account_proto_rawDescGZIP(), []int{0} +} + +func (x *ListAccountsRequest) GetPageSize() uint32 { + if x != nil { + return x.PageSize + } + return 0 +} + +func (x *ListAccountsRequest) GetPage() uint32 { + if x != nil { + return x.Page + } + return 0 +} + +type ListAccountsResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Accounts []*Account `protobuf:"bytes,1,rep,name=accounts,proto3" json:"accounts,omitempty"` + TotalCount uint32 `protobuf:"varint,2,opt,name=total_count,json=totalCount,proto3" json:"total_count,omitempty"` +} + +func (x *ListAccountsResponse) Reset() { + *x = ListAccountsResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_account_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListAccountsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListAccountsResponse) ProtoMessage() {} + +func (x *ListAccountsResponse) ProtoReflect() protoreflect.Message { + mi := &file_account_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListAccountsResponse.ProtoReflect.Descriptor instead. +func (*ListAccountsResponse) Descriptor() ([]byte, []int) { + return file_account_proto_rawDescGZIP(), []int{1} +} + +func (x *ListAccountsResponse) GetAccounts() []*Account { + if x != nil { + return x.Accounts + } + return nil +} + +func (x *ListAccountsResponse) GetTotalCount() uint32 { + if x != nil { + return x.TotalCount + } + return 0 +} + +type GetAccountRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` +} + +func (x *GetAccountRequest) Reset() { + *x = GetAccountRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_account_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetAccountRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetAccountRequest) ProtoMessage() {} + +func (x *GetAccountRequest) ProtoReflect() protoreflect.Message { + mi := &file_account_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetAccountRequest.ProtoReflect.Descriptor instead. +func (*GetAccountRequest) Descriptor() ([]byte, []int) { + return file_account_proto_rawDescGZIP(), []int{2} +} + +func (x *GetAccountRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +type GetAccountResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Account *Account `protobuf:"bytes,1,opt,name=account,proto3" json:"account,omitempty"` +} + +func (x *GetAccountResponse) Reset() { + *x = GetAccountResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_account_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetAccountResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetAccountResponse) ProtoMessage() {} + +func (x *GetAccountResponse) ProtoReflect() protoreflect.Message { + mi := &file_account_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetAccountResponse.ProtoReflect.Descriptor instead. +func (*GetAccountResponse) Descriptor() ([]byte, []int) { + return file_account_proto_rawDescGZIP(), []int{3} +} + +func (x *GetAccountResponse) GetAccount() *Account { + if x != nil { + return x.Account + } + return nil +} + +type UpdateAccountRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` +} + +func (x *UpdateAccountRequest) Reset() { + *x = UpdateAccountRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_account_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UpdateAccountRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UpdateAccountRequest) ProtoMessage() {} + +func (x *UpdateAccountRequest) ProtoReflect() protoreflect.Message { + mi := &file_account_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UpdateAccountRequest.ProtoReflect.Descriptor instead. +func (*UpdateAccountRequest) Descriptor() ([]byte, []int) { + return file_account_proto_rawDescGZIP(), []int{4} +} + +func (x *UpdateAccountRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *UpdateAccountRequest) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +type Account struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` +} + +func (x *Account) Reset() { + *x = Account{} + if protoimpl.UnsafeEnabled { + mi := &file_account_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Account) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Account) ProtoMessage() {} + +func (x *Account) ProtoReflect() protoreflect.Message { + mi := &file_account_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Account.ProtoReflect.Descriptor instead. +func (*Account) Descriptor() ([]byte, []int) { + return file_account_proto_rawDescGZIP(), []int{5} +} + +func (x *Account) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *Account) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +type DeleteAccountRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` +} + +func (x *DeleteAccountRequest) Reset() { + *x = DeleteAccountRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_account_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeleteAccountRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteAccountRequest) ProtoMessage() {} + +func (x *DeleteAccountRequest) ProtoReflect() protoreflect.Message { + mi := &file_account_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteAccountRequest.ProtoReflect.Descriptor instead. +func (*DeleteAccountRequest) Descriptor() ([]byte, []int) { + return file_account_proto_rawDescGZIP(), []int{6} +} + +func (x *DeleteAccountRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +var File_account_proto protoreflect.FileDescriptor + +var file_account_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x0e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x61, 0x70, 0x69, 0x1a, + 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, + 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, + 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x46, 0x0a, 0x13, 0x4c, 0x69, + 0x73, 0x74, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x70, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x61, 0x67, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, + 0x0a, 0x04, 0x70, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x70, 0x61, + 0x67, 0x65, 0x22, 0x6c, 0x0a, 0x14, 0x4c, 0x69, 0x73, 0x74, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x33, 0x0a, 0x08, 0x61, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, + 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x41, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x52, 0x08, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x12, + 0x1f, 0x0a, 0x0b, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x43, 0x6f, 0x75, 0x6e, 0x74, + 0x22, 0x2f, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x22, 0x47, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x31, 0x0a, 0x07, 0x61, 0x63, 0x63, 0x6f, 0x75, + 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, + 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x52, 0x07, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x22, 0x4e, 0x0a, 0x14, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1a, + 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x41, 0x0a, 0x07, 0x41, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x32, 0x0a, + 0x14, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x32, 0xbd, 0x03, 0x0a, 0x0e, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x12, 0x67, 0x0a, 0x04, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x23, 0x2e, 0x67, + 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4c, 0x69, + 0x73, 0x74, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x24, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x61, + 0x70, 0x69, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x14, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x0e, 0x12, + 0x0c, 0x2f, 0x76, 0x31, 0x2f, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x12, 0x6d, 0x0a, + 0x03, 0x47, 0x65, 0x74, 0x12, 0x21, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x75, 0x74, + 0x68, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x47, 0x65, 0x74, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, + 0x61, 0x75, 0x74, 0x68, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x47, 0x65, 0x74, 0x41, 0x63, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1f, 0x82, 0xd3, 0xe4, + 0x93, 0x02, 0x19, 0x12, 0x17, 0x2f, 0x76, 0x31, 0x2f, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, + 0x73, 0x2f, 0x7b, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x12, 0x6a, 0x0a, 0x06, + 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x24, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, + 0x75, 0x74, 0x68, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x41, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x22, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1c, 0x22, 0x17, 0x2f, 0x76, + 0x31, 0x2f, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x2f, 0x7b, 0x75, 0x73, 0x65, 0x72, + 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x3a, 0x01, 0x2a, 0x12, 0x67, 0x0a, 0x06, 0x44, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x12, 0x24, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x1f, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x19, 0x2a, 0x17, 0x2f, 0x76, 0x31, 0x2f, 0x61, 0x63, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x73, 0x2f, 0x7b, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, + 0x7d, 0x42, 0x08, 0x5a, 0x06, 0x2e, 0x3b, 0x61, 0x75, 0x74, 0x68, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, +} + +var ( + file_account_proto_rawDescOnce sync.Once + file_account_proto_rawDescData = file_account_proto_rawDesc +) + +func file_account_proto_rawDescGZIP() []byte { + file_account_proto_rawDescOnce.Do(func() { + file_account_proto_rawDescData = protoimpl.X.CompressGZIP(file_account_proto_rawDescData) + }) + return file_account_proto_rawDescData +} + +var file_account_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_account_proto_goTypes = []interface{}{ + (*ListAccountsRequest)(nil), // 0: gmqtt.auth.api.ListAccountsRequest + (*ListAccountsResponse)(nil), // 1: gmqtt.auth.api.ListAccountsResponse + (*GetAccountRequest)(nil), // 2: gmqtt.auth.api.GetAccountRequest + (*GetAccountResponse)(nil), // 3: gmqtt.auth.api.GetAccountResponse + (*UpdateAccountRequest)(nil), // 4: gmqtt.auth.api.UpdateAccountRequest + (*Account)(nil), // 5: gmqtt.auth.api.Account + (*DeleteAccountRequest)(nil), // 6: gmqtt.auth.api.DeleteAccountRequest + (*empty.Empty)(nil), // 7: google.protobuf.Empty +} +var file_account_proto_depIdxs = []int32{ + 5, // 0: gmqtt.auth.api.ListAccountsResponse.accounts:type_name -> gmqtt.auth.api.Account + 5, // 1: gmqtt.auth.api.GetAccountResponse.account:type_name -> gmqtt.auth.api.Account + 0, // 2: gmqtt.auth.api.AccountService.List:input_type -> gmqtt.auth.api.ListAccountsRequest + 2, // 3: gmqtt.auth.api.AccountService.Get:input_type -> gmqtt.auth.api.GetAccountRequest + 4, // 4: gmqtt.auth.api.AccountService.Update:input_type -> gmqtt.auth.api.UpdateAccountRequest + 6, // 5: gmqtt.auth.api.AccountService.Delete:input_type -> gmqtt.auth.api.DeleteAccountRequest + 1, // 6: gmqtt.auth.api.AccountService.List:output_type -> gmqtt.auth.api.ListAccountsResponse + 3, // 7: gmqtt.auth.api.AccountService.Get:output_type -> gmqtt.auth.api.GetAccountResponse + 7, // 8: gmqtt.auth.api.AccountService.Update:output_type -> google.protobuf.Empty + 7, // 9: gmqtt.auth.api.AccountService.Delete:output_type -> google.protobuf.Empty + 6, // [6:10] is the sub-list for method output_type + 2, // [2:6] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_account_proto_init() } +func file_account_proto_init() { + if File_account_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_account_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListAccountsRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_account_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListAccountsResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_account_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetAccountRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_account_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetAccountResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_account_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UpdateAccountRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_account_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Account); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_account_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeleteAccountRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_account_proto_rawDesc, + NumEnums: 0, + NumMessages: 7, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_account_proto_goTypes, + DependencyIndexes: file_account_proto_depIdxs, + MessageInfos: file_account_proto_msgTypes, + }.Build() + File_account_proto = out.File + file_account_proto_rawDesc = nil + file_account_proto_goTypes = nil + file_account_proto_depIdxs = nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/account.pb.gw.go b/internal/hummingbird/mqttbroker/plugin/auth/account.pb.gw.go new file mode 100644 index 0000000..ac69e99 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/account.pb.gw.go @@ -0,0 +1,472 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: account.proto + +/* +Package auth is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package auth + +import ( + "context" + "io" + "net/http" + + "github.com/golang/protobuf/descriptor" + "github.com/golang/protobuf/proto" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" +) + +// Suppress "imported and not used" errors +var _ codes.Code +var _ io.Reader +var _ status.Status +var _ = runtime.String +var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage + +var ( + filter_AccountService_List_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +) + +func request_AccountService_List_0(ctx context.Context, marshaler runtime.Marshaler, client AccountServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ListAccountsRequest + var metadata runtime.ServerMetadata + + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_AccountService_List_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.List(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_AccountService_List_0(ctx context.Context, marshaler runtime.Marshaler, server AccountServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ListAccountsRequest + var metadata runtime.ServerMetadata + + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_AccountService_List_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.List(ctx, &protoReq) + return msg, metadata, err + +} + +func request_AccountService_Get_0(ctx context.Context, marshaler runtime.Marshaler, client AccountServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq GetAccountRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["username"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "username") + } + + protoReq.Username, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "username", err) + } + + msg, err := client.Get(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_AccountService_Get_0(ctx context.Context, marshaler runtime.Marshaler, server AccountServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq GetAccountRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["username"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "username") + } + + protoReq.Username, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "username", err) + } + + msg, err := server.Get(ctx, &protoReq) + return msg, metadata, err + +} + +func request_AccountService_Update_0(ctx context.Context, marshaler runtime.Marshaler, client AccountServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq UpdateAccountRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["username"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "username") + } + + protoReq.Username, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "username", err) + } + + msg, err := client.Update(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_AccountService_Update_0(ctx context.Context, marshaler runtime.Marshaler, server AccountServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq UpdateAccountRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["username"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "username") + } + + protoReq.Username, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "username", err) + } + + msg, err := server.Update(ctx, &protoReq) + return msg, metadata, err + +} + +func request_AccountService_Delete_0(ctx context.Context, marshaler runtime.Marshaler, client AccountServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq DeleteAccountRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["username"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "username") + } + + protoReq.Username, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "username", err) + } + + msg, err := client.Delete(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_AccountService_Delete_0(ctx context.Context, marshaler runtime.Marshaler, server AccountServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq DeleteAccountRequest + var metadata runtime.ServerMetadata + + var ( + val string + ok bool + err error + _ = err + ) + + val, ok = pathParams["username"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "username") + } + + protoReq.Username, err = runtime.String(val) + + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "username", err) + } + + msg, err := server.Delete(ctx, &protoReq) + return msg, metadata, err + +} + +// RegisterAccountServiceHandlerServer registers the http handlers for service AccountService to "mux". +// UnaryRPC :call AccountServiceServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +func RegisterAccountServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server AccountServiceServer) error { + + mux.Handle("GET", pattern_AccountService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_AccountService_List_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_List_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_AccountService_Get_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_AccountService_Get_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_Get_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_AccountService_Update_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_AccountService_Update_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_Update_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("DELETE", pattern_AccountService_Delete_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_AccountService_Delete_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_Delete_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +// RegisterAccountServiceHandlerFromEndpoint is same as RegisterAccountServiceHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterAccountServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterAccountServiceHandler(ctx, mux, conn) +} + +// RegisterAccountServiceHandler registers the http handlers for service AccountService to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterAccountServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterAccountServiceHandlerClient(ctx, mux, NewAccountServiceClient(conn)) +} + +// RegisterAccountServiceHandlerClient registers the http handlers for service AccountService +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "AccountServiceClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "AccountServiceClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "AccountServiceClient" to call the correct interceptors. +func RegisterAccountServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client AccountServiceClient) error { + + mux.Handle("GET", pattern_AccountService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_AccountService_List_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_List_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_AccountService_Get_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_AccountService_Get_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_Get_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_AccountService_Update_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_AccountService_Update_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_Update_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("DELETE", pattern_AccountService_Delete_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_AccountService_Delete_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_AccountService_Delete_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_AccountService_List_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "accounts"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_AccountService_Get_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 1, 0, 4, 1, 5, 2}, []string{"v1", "accounts", "username"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_AccountService_Update_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 1, 0, 4, 1, 5, 2}, []string{"v1", "accounts", "username"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_AccountService_Delete_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 1, 0, 4, 1, 5, 2}, []string{"v1", "accounts", "username"}, "", runtime.AssumeColonVerbOpt(true))) +) + +var ( + forward_AccountService_List_0 = runtime.ForwardResponseMessage + + forward_AccountService_Get_0 = runtime.ForwardResponseMessage + + forward_AccountService_Update_0 = runtime.ForwardResponseMessage + + forward_AccountService_Delete_0 = runtime.ForwardResponseMessage +) diff --git a/internal/hummingbird/mqttbroker/plugin/auth/account_grpc.pb.go b/internal/hummingbird/mqttbroker/plugin/auth/account_grpc.pb.go new file mode 100644 index 0000000..f733dd1 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/account_grpc.pb.go @@ -0,0 +1,219 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package auth + +import ( + context "context" + + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion7 + +// AccountServiceClient is the client API for AccountService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type AccountServiceClient interface { + // List all accounts + List(ctx context.Context, in *ListAccountsRequest, opts ...grpc.CallOption) (*ListAccountsResponse, error) + // Get the account for given username. + // Return NotFound error when account not found. + Get(ctx context.Context, in *GetAccountRequest, opts ...grpc.CallOption) (*GetAccountResponse, error) + // Update the password for the account. + // This API will create the account if not exists. + Update(ctx context.Context, in *UpdateAccountRequest, opts ...grpc.CallOption) (*empty.Empty, error) + // Delete the account for given username + Delete(ctx context.Context, in *DeleteAccountRequest, opts ...grpc.CallOption) (*empty.Empty, error) +} + +type accountServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewAccountServiceClient(cc grpc.ClientConnInterface) AccountServiceClient { + return &accountServiceClient{cc} +} + +func (c *accountServiceClient) List(ctx context.Context, in *ListAccountsRequest, opts ...grpc.CallOption) (*ListAccountsResponse, error) { + out := new(ListAccountsResponse) + err := c.cc.Invoke(ctx, "/gmqtt.auth.api.AccountService/List", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *accountServiceClient) Get(ctx context.Context, in *GetAccountRequest, opts ...grpc.CallOption) (*GetAccountResponse, error) { + out := new(GetAccountResponse) + err := c.cc.Invoke(ctx, "/gmqtt.auth.api.AccountService/Get", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *accountServiceClient) Update(ctx context.Context, in *UpdateAccountRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.auth.api.AccountService/Update", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *accountServiceClient) Delete(ctx context.Context, in *DeleteAccountRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.auth.api.AccountService/Delete", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// AccountServiceServer is the server API for AccountService service. +// All implementations must embed UnimplementedAccountServiceServer +// for forward compatibility +type AccountServiceServer interface { + // List all accounts + List(context.Context, *ListAccountsRequest) (*ListAccountsResponse, error) + // Get the account for given username. + // Return NotFound error when account not found. + Get(context.Context, *GetAccountRequest) (*GetAccountResponse, error) + // Update the password for the account. + // This API will create the account if not exists. + Update(context.Context, *UpdateAccountRequest) (*empty.Empty, error) + // Delete the account for given username + Delete(context.Context, *DeleteAccountRequest) (*empty.Empty, error) + mustEmbedUnimplementedAccountServiceServer() +} + +// UnimplementedAccountServiceServer must be embedded to have forward compatible implementations. +type UnimplementedAccountServiceServer struct { +} + +func (UnimplementedAccountServiceServer) List(context.Context, *ListAccountsRequest) (*ListAccountsResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method List not implemented") +} +func (UnimplementedAccountServiceServer) Get(context.Context, *GetAccountRequest) (*GetAccountResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Get not implemented") +} +func (UnimplementedAccountServiceServer) Update(context.Context, *UpdateAccountRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Update not implemented") +} +func (UnimplementedAccountServiceServer) Delete(context.Context, *DeleteAccountRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Delete not implemented") +} +func (UnimplementedAccountServiceServer) mustEmbedUnimplementedAccountServiceServer() {} + +// UnsafeAccountServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to AccountServiceServer will +// result in compilation errors. +type UnsafeAccountServiceServer interface { + mustEmbedUnimplementedAccountServiceServer() +} + +func RegisterAccountServiceServer(s grpc.ServiceRegistrar, srv AccountServiceServer) { + s.RegisterService(&_AccountService_serviceDesc, srv) +} + +func _AccountService_List_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListAccountsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AccountServiceServer).List(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.auth.api.AccountService/List", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AccountServiceServer).List(ctx, req.(*ListAccountsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _AccountService_Get_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetAccountRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AccountServiceServer).Get(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.auth.api.AccountService/Get", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AccountServiceServer).Get(ctx, req.(*GetAccountRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _AccountService_Update_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(UpdateAccountRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AccountServiceServer).Update(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.auth.api.AccountService/Update", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AccountServiceServer).Update(ctx, req.(*UpdateAccountRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _AccountService_Delete_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteAccountRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AccountServiceServer).Delete(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.auth.api.AccountService/Delete", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AccountServiceServer).Delete(ctx, req.(*DeleteAccountRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _AccountService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "gmqtt.auth.api.AccountService", + HandlerType: (*AccountServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "List", + Handler: _AccountService_List_Handler, + }, + { + MethodName: "Get", + Handler: _AccountService_Get_Handler, + }, + { + MethodName: "Update", + Handler: _AccountService_Update_Handler, + }, + { + MethodName: "Delete", + Handler: _AccountService_Delete_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "account.proto", +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/account_grpc.pb_mock.go b/internal/hummingbird/mqttbroker/plugin/auth/account_grpc.pb_mock.go new file mode 100644 index 0000000..b655e3d --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/account_grpc.pb_mock.go @@ -0,0 +1,247 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: plugin/auth/account_grpc.pb.go + +// Package auth is a generated GoMock package. +package auth + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" +) + +// MockAccountServiceClient is a mock of AccountServiceClient interface +type MockAccountServiceClient struct { + ctrl *gomock.Controller + recorder *MockAccountServiceClientMockRecorder +} + +// MockAccountServiceClientMockRecorder is the mock recorder for MockAccountServiceClient +type MockAccountServiceClientMockRecorder struct { + mock *MockAccountServiceClient +} + +// NewMockAccountServiceClient creates a new mock instance +func NewMockAccountServiceClient(ctrl *gomock.Controller) *MockAccountServiceClient { + mock := &MockAccountServiceClient{ctrl: ctrl} + mock.recorder = &MockAccountServiceClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAccountServiceClient) EXPECT() *MockAccountServiceClientMockRecorder { + return m.recorder +} + +// List mocks base method +func (m *MockAccountServiceClient) List(ctx context.Context, in *ListAccountsRequest, opts ...grpc.CallOption) (*ListAccountsResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, in} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "List", varargs...) + ret0, _ := ret[0].(*ListAccountsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List +func (mr *MockAccountServiceClientMockRecorder) List(ctx, in interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, in}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockAccountServiceClient)(nil).List), varargs...) +} + +// Get mocks base method +func (m *MockAccountServiceClient) Get(ctx context.Context, in *GetAccountRequest, opts ...grpc.CallOption) (*GetAccountResponse, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, in} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*GetAccountResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockAccountServiceClientMockRecorder) Get(ctx, in interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, in}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAccountServiceClient)(nil).Get), varargs...) +} + +// Update mocks base method +func (m *MockAccountServiceClient) Update(ctx context.Context, in *UpdateAccountRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, in} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Update", varargs...) + ret0, _ := ret[0].(*empty.Empty) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update +func (mr *MockAccountServiceClientMockRecorder) Update(ctx, in interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, in}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockAccountServiceClient)(nil).Update), varargs...) +} + +// Delete mocks base method +func (m *MockAccountServiceClient) Delete(ctx context.Context, in *DeleteAccountRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, in} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Delete", varargs...) + ret0, _ := ret[0].(*empty.Empty) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Delete indicates an expected call of Delete +func (mr *MockAccountServiceClientMockRecorder) Delete(ctx, in interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, in}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockAccountServiceClient)(nil).Delete), varargs...) +} + +// MockAccountServiceServer is a mock of AccountServiceServer interface +type MockAccountServiceServer struct { + ctrl *gomock.Controller + recorder *MockAccountServiceServerMockRecorder +} + +// MockAccountServiceServerMockRecorder is the mock recorder for MockAccountServiceServer +type MockAccountServiceServerMockRecorder struct { + mock *MockAccountServiceServer +} + +// NewMockAccountServiceServer creates a new mock instance +func NewMockAccountServiceServer(ctrl *gomock.Controller) *MockAccountServiceServer { + mock := &MockAccountServiceServer{ctrl: ctrl} + mock.recorder = &MockAccountServiceServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAccountServiceServer) EXPECT() *MockAccountServiceServerMockRecorder { + return m.recorder +} + +// List mocks base method +func (m *MockAccountServiceServer) List(arg0 context.Context, arg1 *ListAccountsRequest) (*ListAccountsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*ListAccountsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List +func (mr *MockAccountServiceServerMockRecorder) List(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockAccountServiceServer)(nil).List), arg0, arg1) +} + +// Get mocks base method +func (m *MockAccountServiceServer) Get(arg0 context.Context, arg1 *GetAccountRequest) (*GetAccountResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].(*GetAccountResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockAccountServiceServerMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAccountServiceServer)(nil).Get), arg0, arg1) +} + +// Update mocks base method +func (m *MockAccountServiceServer) Update(arg0 context.Context, arg1 *UpdateAccountRequest) (*empty.Empty, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0, arg1) + ret0, _ := ret[0].(*empty.Empty) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update +func (mr *MockAccountServiceServerMockRecorder) Update(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockAccountServiceServer)(nil).Update), arg0, arg1) +} + +// Delete mocks base method +func (m *MockAccountServiceServer) Delete(arg0 context.Context, arg1 *DeleteAccountRequest) (*empty.Empty, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0, arg1) + ret0, _ := ret[0].(*empty.Empty) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Delete indicates an expected call of Delete +func (mr *MockAccountServiceServerMockRecorder) Delete(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockAccountServiceServer)(nil).Delete), arg0, arg1) +} + +// mustEmbedUnimplementedAccountServiceServer mocks base method +func (m *MockAccountServiceServer) mustEmbedUnimplementedAccountServiceServer() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "mustEmbedUnimplementedAccountServiceServer") +} + +// mustEmbedUnimplementedAccountServiceServer indicates an expected call of mustEmbedUnimplementedAccountServiceServer +func (mr *MockAccountServiceServerMockRecorder) mustEmbedUnimplementedAccountServiceServer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "mustEmbedUnimplementedAccountServiceServer", reflect.TypeOf((*MockAccountServiceServer)(nil).mustEmbedUnimplementedAccountServiceServer)) +} + +// MockUnsafeAccountServiceServer is a mock of UnsafeAccountServiceServer interface +type MockUnsafeAccountServiceServer struct { + ctrl *gomock.Controller + recorder *MockUnsafeAccountServiceServerMockRecorder +} + +// MockUnsafeAccountServiceServerMockRecorder is the mock recorder for MockUnsafeAccountServiceServer +type MockUnsafeAccountServiceServerMockRecorder struct { + mock *MockUnsafeAccountServiceServer +} + +// NewMockUnsafeAccountServiceServer creates a new mock instance +func NewMockUnsafeAccountServiceServer(ctrl *gomock.Controller) *MockUnsafeAccountServiceServer { + mock := &MockUnsafeAccountServiceServer{ctrl: ctrl} + mock.recorder = &MockUnsafeAccountServiceServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockUnsafeAccountServiceServer) EXPECT() *MockUnsafeAccountServiceServerMockRecorder { + return m.recorder +} + +// mustEmbedUnimplementedAccountServiceServer mocks base method +func (m *MockUnsafeAccountServiceServer) mustEmbedUnimplementedAccountServiceServer() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "mustEmbedUnimplementedAccountServiceServer") +} + +// mustEmbedUnimplementedAccountServiceServer indicates an expected call of mustEmbedUnimplementedAccountServiceServer +func (mr *MockUnsafeAccountServiceServerMockRecorder) mustEmbedUnimplementedAccountServiceServer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "mustEmbedUnimplementedAccountServiceServer", reflect.TypeOf((*MockUnsafeAccountServiceServer)(nil).mustEmbedUnimplementedAccountServiceServer)) +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/auth.go b/internal/hummingbird/mqttbroker/plugin/auth/auth.go new file mode 100644 index 0000000..c4e5509 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/auth.go @@ -0,0 +1,179 @@ +package auth + +import ( + "crypto/md5" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "hash" + "io/ioutil" + "os" + "path" + "sync" + + "go.uber.org/zap" + "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v2" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/admin" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +var _ server.Plugin = (*Auth)(nil) + +const Name = "auth" + +func init() { + server.RegisterPlugin(Name, New) + config.RegisterDefaultPluginConfig(Name, &DefaultConfig) +} + +func New(config config.Config) (server.Plugin, error) { + a := &Auth{ + config: config.Plugins[Name].(*Config), + indexer: admin.NewIndexer(), + pwdDir: config.ConfigDir, + } + a.saveFile = a.saveFileHandler + return a, nil +} + +var log *zap.Logger + +// Auth provides the username/password authentication for mqttbroker. +// The authentication data is persist in config.PasswordFile. +type Auth struct { + config *Config + pwdDir string + // gard indexer + mu sync.RWMutex + // store username/password + indexer *admin.Indexer + // saveFile persists the account data to password file. + saveFile func() error +} + +// generatePassword generates the hashed password for the plain password. +func (a *Auth) generatePassword(password string) (hashedPassword string, err error) { + var h hash.Hash + switch a.config.Hash { + case Plain: + return password, nil + case MD5: + h = md5.New() + case SHA256: + h = sha256.New() + case Bcrypt: + pwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost) + return string(pwd), err + default: + // just in case. + panic("invalid hash type") + } + _, err = h.Write([]byte(password)) + if err != nil { + return "", err + } + rs := h.Sum(nil) + return hex.EncodeToString(rs), nil +} + +func (a *Auth) mustEmbedUnimplementedAccountServiceServer() { + return +} + +func (a *Auth) validate(username, password string) (permitted bool, err error) { + a.mu.RLock() + elem := a.indexer.GetByID(username) + a.mu.RUnlock() + var hashedPassword string + if elem == nil { + return false, nil + } + ac := elem.Value.(*Account) + hashedPassword = ac.Password + var h hash.Hash + switch a.config.Hash { + case Plain: + return hashedPassword == password, nil + case MD5: + h = md5.New() + case SHA256: + h = sha256.New() + case Bcrypt: + return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) == nil, nil + default: + // just in case. + panic("invalid hash type") + } + _, err = h.Write([]byte(password)) + if err != nil { + return false, err + } + rs := h.Sum(nil) + return hashedPassword == hex.EncodeToString(rs), nil +} + +var registerAPI = func(service server.Server, a *Auth) error { + apiRegistrar := service.APIRegistrar() + RegisterAccountServiceServer(apiRegistrar, a) + err := apiRegistrar.RegisterHTTPHandler(RegisterAccountServiceHandlerFromEndpoint) + return err +} + +func (a *Auth) Load(service server.Server) error { + err := registerAPI(service, a) + log = server.LoggerWithField(zap.String("plugin", Name)) + + var pwdFile string + if path.IsAbs(a.config.PasswordFile) { + pwdFile = a.config.PasswordFile + } else { + pwdFile = path.Join(a.pwdDir, a.config.PasswordFile) + } + f, err := os.OpenFile(pwdFile, os.O_CREATE|os.O_RDONLY, 0666) + if err != nil { + return err + } + defer f.Close() + b, err := ioutil.ReadAll(f) + if err != nil { + return err + } + var acts []*Account + err = yaml.Unmarshal(b, &acts) + if err != nil { + return err + } + log.Info("authentication data loaded", + zap.String("hash", a.config.Hash), + zap.Int("account_nums", len(acts)), + zap.String("password_file", pwdFile)) + + dup := make(map[string]struct{}) + for _, v := range acts { + if v.Username == "" { + return errors.New("detect empty username in password file") + } + if _, ok := dup[v.Username]; ok { + return fmt.Errorf("detect duplicated username in password file: %s", v.Username) + } + dup[v.Username] = struct{}{} + } + a.mu.Lock() + defer a.mu.Unlock() + for _, v := range acts { + a.indexer.Set(v.Username, v) + } + return nil +} + +func (a *Auth) Unload() error { + return nil +} + +func (a *Auth) Name() string { + return Name +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/auth_test.go b/internal/hummingbird/mqttbroker/plugin/auth/auth_test.go new file mode 100644 index 0000000..525dcb2 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/auth_test.go @@ -0,0 +1,156 @@ +package auth + +import ( + "os" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/admin" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +func init() { + registerAPI = func(service server.Server, a *Auth) error { + return nil + } +} +func TestAuth_validate(t *testing.T) { + var tt = []struct { + name string + username string + password string + }{ + { + name: Plain, + username: "user", + password: "道路千万条,安全第一条,密码不规范,绩效两行泪", + }, { + name: MD5, + username: "user", + password: "道路千万条,安全第一条,密码不规范,绩效两行泪", + }, { + name: SHA256, + username: "user", + password: "道路千万条,安全第一条,密码不规范,绩效两行泪", + }, { + name: Bcrypt, + username: "user", + password: "道路千万条,安全第一条,密码不规范,绩效两行泪", + }, + } + for _, v := range tt { + t.Run(v.name, func(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + auth := &Auth{ + config: &Config{ + Hash: v.name, + }, + indexer: admin.NewIndexer(), + } + + hashed, err := auth.generatePassword(v.password) + a.Nil(err) + auth.indexer.Set(v.username, &Account{ + Username: v.username, + Password: hashed, + }) + ok, err := auth.validate(v.username, v.password) + a.True(ok) + a.Nil(err) + }) + } + +} + +func TestAuth_EmptyPassword(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + auth := &Auth{ + config: &Config{ + Hash: Plain, + }, + indexer: admin.NewIndexer(), + } + + hashed, err := auth.generatePassword("abc") + a.Nil(err) + auth.indexer.Set("user", &Account{ + Username: "user", + Password: hashed, + }) + ok, err := auth.validate("user", "") + a.False(ok) + a.Nil(err) +} + +func TestAuth_Load_CreateFile(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + path := "./testdata/file_not_exists.yml" + defer os.Remove("./testdata/file_not_exists.yml") + cfg := DefaultConfig + cfg.PasswordFile = path + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + a.Nil(err) + ms := server.NewMockServer(ctrl) + a.Nil(auth.Load(ms)) +} + +func TestAuth_Load_WithDuplicatedUsername(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + path := "./testdata/gmqtt_password_duplicated.yml" + cfg := DefaultConfig + cfg.PasswordFile = path + cfg.Hash = Plain + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + a.Nil(err) + ms := server.NewMockServer(ctrl) + a.Error(auth.Load(ms)) +} + +func TestAuth_Load_OK(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + path := "./testdata/gmqtt_password.yml" + cfg := DefaultConfig + cfg.PasswordFile = path + cfg.Hash = Plain + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + a.Nil(err) + ms := server.NewMockServer(ctrl) + a.Nil(auth.Load(ms)) + + au := auth.(*Auth) + p, err := au.validate("u1", "p1") + a.True(p) + a.Nil(err) + + p, err = au.validate("u2", "p2") + a.True(p) + a.Nil(err) +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/config.go b/internal/hummingbird/mqttbroker/plugin/auth/config.go new file mode 100644 index 0000000..d1e1aab --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/config.go @@ -0,0 +1,65 @@ +package auth + +import ( + "errors" + "fmt" +) + +type hashType = string + +const ( + Plain hashType = "plain" + MD5 = "md5" + SHA256 = "sha256" + Bcrypt = "bcrypt" +) + +var ValidateHashType = []string{ + Plain, MD5, SHA256, Bcrypt, +} + +// Config is the configuration for the auth plugin. +type Config struct { + // PasswordFile is the file to store username and password. + PasswordFile string `yaml:"password_file"` + // Hash is the password hash algorithm. + // Possible values: plain | md5 | sha256 | bcrypt + Hash string `yaml:"hash"` +} + +// validate validates the configuration, and return an error if it is invalid. +func (c *Config) Validate() error { + if c.PasswordFile == "" { + return errors.New("password_file must be set") + } + for _, v := range ValidateHashType { + if v == c.Hash { + return nil + } + } + return fmt.Errorf("invalid hash type: %s", c.Hash) +} + +// DefaultConfig is the default configuration. +var DefaultConfig = Config{ + Hash: MD5, + PasswordFile: "./gmqtt_password.yml", +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type cfg Config + var v = &struct { + Auth cfg `yaml:"auth"` + }{ + Auth: cfg(DefaultConfig), + } + if err := unmarshal(v); err != nil { + return err + } + empty := cfg(Config{}) + if v.Auth == empty { + v.Auth = cfg(DefaultConfig) + } + *c = Config(v.Auth) + return nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/grpc_handler.go b/internal/hummingbird/mqttbroker/plugin/auth/grpc_handler.go new file mode 100644 index 0000000..53136f7 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/grpc_handler.go @@ -0,0 +1,150 @@ +package auth + +import ( + "bufio" + "container/list" + "context" + "io/ioutil" + "os" + + "github.com/golang/protobuf/ptypes/empty" + "go.uber.org/zap" + "gopkg.in/yaml.v2" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/admin" +) + +// List lists all accounts +func (a *Auth) List(ctx context.Context, req *ListAccountsRequest) (resp *ListAccountsResponse, err error) { + page, pageSize := admin.GetPage(req.Page, req.PageSize) + offset, n := admin.GetOffsetN(page, pageSize) + a.mu.RLock() + defer a.mu.RUnlock() + resp = &ListAccountsResponse{ + Accounts: []*Account{}, + TotalCount: 0, + } + a.indexer.Iterate(func(elem *list.Element) { + resp.Accounts = append(resp.Accounts, elem.Value.(*Account)) + }, offset, n) + + resp.TotalCount = uint32(a.indexer.Len()) + return resp, nil +} + +// Get gets the account for given username. +// Return NotFound error when account not found. +func (a *Auth) Get(ctx context.Context, req *GetAccountRequest) (resp *GetAccountResponse, err error) { + if req.Username == "" { + return nil, admin.ErrInvalidArgument("username", "cannot be empty") + } + a.mu.RLock() + defer a.mu.RUnlock() + resp = &GetAccountResponse{} + if e := a.indexer.GetByID(req.Username); e != nil { + resp.Account = e.Value.(*Account) + return resp, nil + } + return nil, admin.ErrNotFound +} + +// saveFileHandler is the default handler for auth.saveFile, must call after auth.mu is locked +func (a *Auth) saveFileHandler() error { + tmpfile, err := ioutil.TempFile("./", "gmqtt_password") + if err != nil { + return err + } + w := bufio.NewWriter(tmpfile) + // get all accounts + var accounts []*Account + a.indexer.Iterate(func(elem *list.Element) { + accounts = append(accounts, elem.Value.(*Account)) + }, 0, uint(a.indexer.Len())) + + b, err := yaml.Marshal(accounts) + if err != nil { + return err + } + + _, err = w.Write(b) + if err != nil { + return err + } + err = w.Flush() + if err != nil { + return err + } + tmpfile.Close() + // replace the old password file. + return os.Rename(tmpfile.Name(), a.config.PasswordFile) +} + +// Update updates the password for the account. +// Create a new account if the account for the username is not exists. +// Update will persist the account data to the password file. +func (a *Auth) Update(ctx context.Context, req *UpdateAccountRequest) (resp *empty.Empty, err error) { + if req.Username == "" { + return nil, admin.ErrInvalidArgument("username", "cannot be empty") + } + hashedPassword, err := a.generatePassword(req.Password) + if err != nil { + return &empty.Empty{}, err + } + a.mu.Lock() + defer a.mu.Unlock() + var oact *Account + elem := a.indexer.GetByID(req.Username) + if elem != nil { + oact = elem.Value.(*Account) + } + a.indexer.Set(req.Username, &Account{ + Username: req.Username, + Password: hashedPassword, + }) + err = a.saveFile() + if err != nil { + // should rollback if failed to persist to file. + if oact == nil { + a.indexer.Remove(req.Username) + return &empty.Empty{}, err + } + a.indexer.Set(req.Username, &Account{ + Username: req.Username, + Password: oact.Password, + }) + } + if oact == nil { + log.Info("new account created", zap.String("username", req.Username)) + } else { + log.Info("password updated", zap.String("username", req.Username)) + } + + return &empty.Empty{}, err +} + +// Delete deletes the account for the username. +func (a *Auth) Delete(ctx context.Context, req *DeleteAccountRequest) (resp *empty.Empty, err error) { + if req.Username == "" { + return nil, admin.ErrInvalidArgument("username", "cannot be empty") + } + a.mu.Lock() + defer a.mu.Unlock() + act := a.indexer.GetByID(req.Username) + if act == nil { + // fast path + return &empty.Empty{}, nil + } + oact := act.Value + a.indexer.Remove(req.Username) + err = a.saveFile() + if err != nil { + // should rollback if failed to persist to file + a.indexer.Set(req.Username, &Account{ + Username: req.Username, + Password: oact.(*Account).Password, + }) + return &empty.Empty{}, err + } + log.Info("account deleted", zap.String("username", req.Username)) + return &empty.Empty{}, nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/grpc_handler_test.go b/internal/hummingbird/mqttbroker/plugin/auth/grpc_handler_test.go new file mode 100644 index 0000000..be3baf8 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/grpc_handler_test.go @@ -0,0 +1,210 @@ +package auth + +import ( + "context" + "errors" + "io/ioutil" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "gopkg.in/yaml.v2" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" +) + +func TestAuth_List_Get_Delete(t *testing.T) { + a := assert.New(t) + path := "./testdata/gmqtt_password.yml" + cfg := DefaultConfig + cfg.PasswordFile = path + cfg.Hash = Plain + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + a.Nil(err) + err = auth.Load(nil) + a.Nil(err) + au := auth.(*Auth) + au.saveFile = func() error { + return nil + } + resp, err := au.List(context.Background(), &ListAccountsRequest{ + PageSize: 0, + Page: 0, + }) + a.Nil(err) + + a.EqualValues(2, resp.TotalCount) + a.Len(resp.Accounts, 2) + + act := make(map[string]string) + act["u1"] = "p1" + act["u2"] = "p2" + for _, v := range resp.Accounts { + a.Equal(act[v.Username], v.Password) + } + + getResp, err := au.Get(context.Background(), &GetAccountRequest{ + Username: "u1", + }) + a.Nil(err) + a.Equal("u1", getResp.Account.Username) + a.Equal("p1", getResp.Account.Password) + + _, err = au.Delete(context.Background(), &DeleteAccountRequest{ + Username: "u1", + }) + a.Nil(err) + + getResp, err = au.Get(context.Background(), &GetAccountRequest{ + Username: "u1", + }) + s, ok := status.FromError(err) + a.True(ok) + a.Equal(codes.NotFound, s.Code()) + +} + +func TestAuth_Update(t *testing.T) { + a := assert.New(t) + path := "./testdata/gmqtt_password.yml" + cfg := DefaultConfig + cfg.PasswordFile = path + cfg.Hash = Plain + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + a.Nil(err) + err = auth.Load(nil) + a.Nil(err) + au := auth.(*Auth) + au.saveFile = func() error { + return nil + } + _, err = au.Update(context.Background(), &UpdateAccountRequest{ + Username: "u1", + Password: "p2", + }) + a.Nil(err) + + l := au.indexer.GetByID("u1") + act := l.Value.(*Account) + a.Equal("p2", act.Password) + + // test rollback + au.saveFile = func() error { + return errors.New("some error") + } + _, err = au.Update(context.Background(), &UpdateAccountRequest{ + Username: "u1", + Password: "u3", + }) + a.NotNil(err) + l = au.indexer.GetByID("u1") + act = l.Value.(*Account) + // not change because fails to persist to password file. + a.Equal("p2", act.Password) + + _, err = au.Update(context.Background(), &UpdateAccountRequest{ + Username: "u10", + Password: "p3", + }) + a.NotNil(err) + // not exists because fails to persist to password file. + l = au.indexer.GetByID("u10") + a.Nil(l) + +} + +func TestAuth_Delete(t *testing.T) { + a := assert.New(t) + path := "./testdata/gmqtt_password.yml" + cfg := DefaultConfig + cfg.PasswordFile = path + cfg.Hash = Plain + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + a.Nil(err) + err = auth.Load(nil) + a.Nil(err) + au := auth.(*Auth) + au.saveFile = func() error { + return errors.New("some error") + } + _, err = au.Delete(context.Background(), &DeleteAccountRequest{ + Username: "u1", + }) + a.NotNil(err) + + resp, err := au.Get(context.Background(), &GetAccountRequest{ + Username: "u1", + }) + a.Nil(err) + a.Equal("u1", resp.Account.Username) + a.Equal("p1", resp.Account.Password) + + au.saveFile = func() error { + return nil + } + + _, err = au.Delete(context.Background(), &DeleteAccountRequest{ + Username: "u1", + }) + a.Nil(err) + + resp, err = au.Get(context.Background(), &GetAccountRequest{ + Username: "u1", + }) + s, ok := status.FromError(err) + a.True(ok) + a.Equal(codes.NotFound, s.Code()) +} + +func TestAuth_saveFileHandler(t *testing.T) { + a := assert.New(t) + path := "./testdata/gmqtt_password_save.yml" + originBytes, err := ioutil.ReadFile(path) + a.Nil(err) + defer func() { + // restore + ioutil.WriteFile(path, originBytes, 0666) + }() + cfg := DefaultConfig + cfg.PasswordFile = path + cfg.Hash = Plain + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + a.Nil(err) + err = auth.Load(nil) + a.Nil(err) + au := auth.(*Auth) + au.indexer.Set("u1", &Account{ + Username: "u1", + Password: "p11", + }) + au.indexer.Remove("u2") + err = au.saveFileHandler() + a.Nil(err) + b, err := ioutil.ReadFile(path) + a.Nil(err) + + var rs []*Account + err = yaml.Unmarshal(b, &rs) + a.Nil(err) + a.Len(rs, 1) + a.Equal("u1", rs[0].Username) + a.Equal("p11", rs[0].Password) + +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/hooks.go b/internal/hummingbird/mqttbroker/plugin/auth/hooks.go new file mode 100644 index 0000000..b36e991 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/hooks.go @@ -0,0 +1,45 @@ +package auth + +import ( + "context" + + "go.uber.org/zap" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/codes" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +func (a *Auth) HookWrapper() server.HookWrapper { + return server.HookWrapper{ + OnBasicAuthWrapper: a.OnBasicAuthWrapper, + } +} + +func (a *Auth) OnBasicAuthWrapper(pre server.OnBasicAuth) server.OnBasicAuth { + return func(ctx context.Context, client server.Client, req *server.ConnectRequest) (err error) { + err = pre(ctx, client, req) + if err != nil { + return err + } + ok, err := a.validate(string(req.Connect.Username), string(req.Connect.Password)) + if err != nil { + return err + } + if !ok { + log.Debug("authentication failed", zap.String("username", string(req.Connect.Username))) + v := client.Version() + if packets.IsVersion3X(v) { + return &codes.Error{ + Code: codes.V3NotAuthorized, + } + } + if packets.IsVersion5(v) { + return &codes.Error{ + Code: codes.NotAuthorized, + } + } + } + return nil + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/hooks_test.go b/internal/hummingbird/mqttbroker/plugin/auth/hooks_test.go new file mode 100644 index 0000000..fafdd13 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/hooks_test.go @@ -0,0 +1,57 @@ +package auth + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +func TestAuth_OnBasicAuthWrapper(t *testing.T) { + a := assert.New(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + path := "./testdata/gmqtt_password.yml" + cfg := DefaultConfig + cfg.PasswordFile = path + cfg.Hash = Plain + auth, err := New(config.Config{ + Plugins: map[string]config.Configuration{ + "auth": &cfg, + }, + }) + mockClient := server.NewMockClient(ctrl) + mockClient.EXPECT().Version().Return(packets.Version311).AnyTimes() + a.Nil(err) + a.Nil(auth.Load(nil)) + au := auth.(*Auth) + var preCalled bool + fn := au.OnBasicAuthWrapper(func(ctx context.Context, client server.Client, req *server.ConnectRequest) (err error) { + preCalled = true + return nil + }) + // pass + a.Nil(fn(context.Background(), mockClient, &server.ConnectRequest{ + Connect: &packets.Connect{ + Username: []byte("u1"), + Password: []byte("p1"), + }, + })) + a.True(preCalled) + + // fail + a.NotNil(fn(context.Background(), mockClient, &server.ConnectRequest{ + Connect: &packets.Connect{ + Username: []byte("u1"), + Password: []byte("p11"), + }, + })) + + a.Nil(au.Unload()) +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/protos/account.proto b/internal/hummingbird/mqttbroker/plugin/auth/protos/account.proto new file mode 100644 index 0000000..a31df60 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/protos/account.proto @@ -0,0 +1,70 @@ +syntax = "proto3"; + +package gmqtt.auth.api; +option go_package = ".;auth"; + +import "google/api/annotations.proto"; +import "google/protobuf/empty.proto"; + +message ListAccountsRequest { + uint32 page_size = 1; + uint32 page = 2; +} + +message ListAccountsResponse { + repeated Account accounts = 1; + uint32 total_count = 2; +} + +message GetAccountRequest { + string username = 1; +} + +message GetAccountResponse { + Account account = 1; +} + +message UpdateAccountRequest { + string username = 1; + string password = 2; +} + +message Account { + string username = 1; + string password = 2; +} + +message DeleteAccountRequest { + string username = 1; +} + +service AccountService { + // List all accounts + rpc List (ListAccountsRequest) returns (ListAccountsResponse){ + option (google.api.http) = { + get: "/v1/accounts" + }; + } + + // Get the account for given username. + // Return NotFound error when account not found. + rpc Get (GetAccountRequest) returns (GetAccountResponse){ + option (google.api.http) = { + get: "/v1/accounts/{username}" + }; + } + // Update the password for the account. + // This API will create the account if not exists. + rpc Update(UpdateAccountRequest) returns (google.protobuf.Empty) { + option (google.api.http) = { + post: "/v1/accounts/{username}" + body:"*" + }; + } + // Delete the account for given username + rpc Delete (DeleteAccountRequest) returns (google.protobuf.Empty) { + option (google.api.http) = { + delete: "/v1/accounts/{username}" + }; + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/protos/proto_gen.sh b/internal/hummingbird/mqttbroker/plugin/auth/protos/proto_gen.sh new file mode 100755 index 0000000..261d433 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/protos/proto_gen.sh @@ -0,0 +1,8 @@ +protoc -I. \ +-I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway \ +-I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ +--go-grpc_out=../ \ +--go_out=../ \ +--grpc-gateway_out=../ \ +--swagger_out=../swagger \ +*.proto \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/plugin/auth/swagger/account.swagger.json b/internal/hummingbird/mqttbroker/plugin/auth/swagger/account.swagger.json new file mode 100644 index 0000000..aa9f57a --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/swagger/account.swagger.json @@ -0,0 +1,231 @@ +{ + "swagger": "2.0", + "info": { + "title": "account.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": { + "/v1/accounts": { + "get": { + "summary": "List all accounts", + "operationId": "List", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiListAccountsResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "page_size", + "in": "query", + "required": false, + "type": "integer", + "format": "int64" + }, + { + "name": "page", + "in": "query", + "required": false, + "type": "integer", + "format": "int64" + } + ], + "tags": [ + "AccountService" + ] + } + }, + "/v1/accounts/{username}": { + "get": { + "summary": "Get the account for given username.\nReturn NotFound error when account not found.", + "operationId": "Get", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiGetAccountResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "username", + "in": "path", + "required": true, + "type": "string" + } + ], + "tags": [ + "AccountService" + ] + }, + "delete": { + "summary": "Delete the account for given username", + "operationId": "Delete", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "username", + "in": "path", + "required": true, + "type": "string" + } + ], + "tags": [ + "AccountService" + ] + }, + "post": { + "summary": "Update the password for the account.\nThis API will create the account if not exists.", + "operationId": "Update", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "username", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/apiUpdateAccountRequest" + } + } + ], + "tags": [ + "AccountService" + ] + } + } + }, + "definitions": { + "apiAccount": { + "type": "object", + "properties": { + "username": { + "type": "string" + }, + "password": { + "type": "string" + } + } + }, + "apiGetAccountResponse": { + "type": "object", + "properties": { + "account": { + "$ref": "#/definitions/apiAccount" + } + } + }, + "apiListAccountsResponse": { + "type": "object", + "properties": { + "accounts": { + "type": "array", + "items": { + "$ref": "#/definitions/apiAccount" + } + }, + "total_count": { + "type": "integer", + "format": "int64" + } + } + }, + "apiUpdateAccountRequest": { + "type": "object", + "properties": { + "username": { + "type": "string" + }, + "password": { + "type": "string" + } + } + }, + "protobufAny": { + "type": "object", + "properties": { + "type_url": { + "type": "string" + }, + "value": { + "type": "string", + "format": "byte" + } + } + }, + "runtimeError": { + "type": "object", + "properties": { + "error": { + "type": "string" + }, + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "$ref": "#/definitions/protobufAny" + } + } + } + } + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password.yml b/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password.yml new file mode 100644 index 0000000..a6017fd --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password.yml @@ -0,0 +1,4 @@ +- username: u1 + password: p1 +- username: u2 + password: p2 \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password_duplicated.yml b/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password_duplicated.yml new file mode 100644 index 0000000..e5471d5 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password_duplicated.yml @@ -0,0 +1,4 @@ +- username: u1 + password: p1 +- username: u1 + password: p1 \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password_save.yml b/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password_save.yml new file mode 100644 index 0000000..a6017fd --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/auth/testdata/gmqtt_password_save.yml @@ -0,0 +1,4 @@ +- username: u1 + password: p1 +- username: u2 + password: p2 \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/plugin/federation/README.md b/internal/hummingbird/mqttbroker/plugin/federation/README.md new file mode 100644 index 0000000..475a64a --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/README.md @@ -0,0 +1,255 @@ +# Federation + +Federation is a kind of clustering mechanism which provides high-availability and horizontal scaling. In Federation +mode, multiple gmqtt brokers can be grouped together and "act as one". However, it is impossible to fulfill all +requirements in MQTT specification in a distributed environment. There are some limitations: + +1. Persistent session cannot be resumed from another node. +2. Clients with same client id can connect to different nodes at the same time and will not be kicked out. + +This is because session information only stores in local node and does not share between nodes. + +## Quick Start + +The following commands will start a two nodes federation, the configuration files can be found [here](./examples). +Start node1 in Terminal1: + +```bash +$ mqttd start -c path/to/retry_join/node1_config.yml +``` + +Start node2 in Terminate2: + +```bash +$ mqttd start -c path/to/retry_join/node2_config2.yml +``` + +After node1 and node2 is started, they will join into one federation atomically. + +We can test the federation with `mosquitto_pub/sub`: +Connect to node2 and subscribe topicA: + +```bash +$ mosquitto_sub -t topicA -h 127.0.0.1 -p 1884 +``` + +Connect to node1 and send a message to topicA: + +```bash +$ mosquitto_pub -t topicA -m 123 -h 127.0.0.1 -p 1883 +``` + +The `mosquitto_sub` will receive "123" and print it in the terminal. + +```bash +$ mosquitto_sub -t topicA -h 127.0.0.1 -p 1884 +123 +``` + +## Join Nodes via REST API + +Federation provides gRPC/REST API to join/leave and query members information, +see [swagger](./swagger/federation.swagger.json) for details. In addition to join nodes upon starting up, you can join a +node into federation by using `Join` API. + +Start node3 with the configuration with empty `retry_join` which means that the node will not join any nodes upon +starting up. + +```bash +$ mqttd start -c path/to/retry_join/join_node3_config.yml +``` + +We can send `Join` request to any nodes in the federation to get node3 joined, for example, sends `Join` request to +node1: + +```bash +$ curl -X POST -d '{"hosts":["127.0.0.1:8932"]}' '127.0.0.1:57091/v1/federation/join' +{} +``` + +And check the members in federation: + +```bash +curl http://127.0.0.1:57091/v1/federation/members +{ + "members": [ + { + "name": "node1", + "addr": "192.168.0.105:8902", + "tags": { + "fed_addr": "192.168.0.105:8901" + }, + "status": "STATUS_ALIVE" + }, + { + "name": "node2", + "addr": "192.168.0.105:8912", + "tags": { + "fed_addr": "192.168.0.105:8911" + }, + "status": "STATUS_ALIVE" + }, + { + "name": "node3", + "addr": "192.168.0.105:8932", + "tags": { + "fed_addr": "192.168.0.105:8931" + }, + "status": "STATUS_ALIVE" + } + ] +}% +``` + +You will see there are 3 nodes ara alive in the federation. + +## Configuration + +```go +// Config is the configuration for the federation plugin. +type Config struct { + // NodeName is the unique identifier for the node in the federation. Defaults to hostname. + NodeName string `yaml:"node_name"` + // FedAddr is the gRPC server listening address for the federation internal communication. + // Defaults to :8901. + // If the port is missing, the default federation port (8901) will be used. + FedAddr string `yaml:"fed_addr"` + // AdvertiseFedAddr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. + // Defaults to "FedAddr" or the private IP address of the node if the IP in "FedAddr" is 0.0.0.0. + // However, in some cases, there may be a routable address that cannot be bound. + // If the port is missing, the default federation port (8901) will be used. + AdvertiseFedAddr string `yaml:"advertise_fed_addr"` + // GossipAddr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 + GossipAddr string `yaml:"gossip_addr"` + // AdvertiseGossipAddr is used to change the gossip server address that we advertise to other nodes in the cluster. + // Defaults to "GossipAddr" or the private IP address of the node if the IP in "GossipAddr" is 0.0.0.0. + // If the port is missing, the default gossip port (8902) will be used. + AdvertiseGossipAddr string `yaml:"advertise_gossip_addr"` + // RetryJoin is the address of other nodes to join upon starting up. + // If port is missing, the default gossip port (8902) will be used. + RetryJoin []string `yaml:"retry_join"` + // RetryInterval is the time to wait between join attempts. Defaults to 5s. + RetryInterval time.Duration `yaml:"retry_interval"` + // RetryTimeout is the timeout to wait before joining all nodes in RetryJoin successfully. + // If timeout expires, the server will exit with error. Defaults to 1m. + RetryTimeout time.Duration `yaml:"retry_timeout"` + // SnapshotPath will be pass to "SnapshotPath" in serf configuration. + // When Serf is started with a snapshot, + // it will attempt to join all the previously known nodes until one + // succeeds and will also avoid replaying old user events. + SnapshotPath string `yaml:"snapshot_path"` + // RejoinAfterLeave will be pass to "RejoinAfterLeave" in serf configuration. + // It controls our interaction with the snapshot file. + // When set to false (default), a leave causes a Serf to not rejoin + // the cluster until an explicit join is received. If this is set to + // true, we ignore the leave, and rejoin the cluster on start. + RejoinAfterLeave bool `yaml:"rejoin_after_leave"` +} +``` + +## Implementation Details + +### Inner-node Communication + +Nodes in the same federation communicate with each other through a couple of gRPC streaming apis: + +```proto +message Event { + uint64 id = 1; + oneof Event { + Subscribe Subscribe = 2; + Message message = 3; + Unsubscribe unsubscribe = 4; + } +} +service Federation { + rpc Hello(ClientHello) returns (ServerHello){} + rpc EventStream (stream Event) returns (stream Ack){} +} +``` + +In general, a node is both Client and Server which implements the `Federation` gRPC service. + +* As Client, the node will send subscribe, unsubscribe and message published events to other nodes if necessary. + Each event has a EventID, which is incremental and unique in a session. +* As Server, when receives a event from Client, the node returns an acknowledgement after the event has been handled + successfully. + +### Session State + +The event is designed to be idempotent and will be delivered at least once, just like the QoS 1 message in MQTT +protocol. In order to implement QoS 1 protocol flows, the Client and Server need to associate state with a SessionID, +this is referred to as the Session State. The Server also stores the federation tree and retained messages as part of +the Session State. + +The Session State in the Client consists of: + +* Events which have been sent to the Server, but have not been acknowledged. +* Events pending transmission to the Server. + +The Session State in the Server consists of: + +* The existence of a Session, even if the rest of the Session State is empty. +* The EventID of the next event that the Server is willing to accept. +* Events which have been received from the Client, but have not sent acknowledged yet. + +The Session State stores in memory only. When the Client starts, it generates a random UUID as SessionID. When the +Client detects a new node is joined or reconnects to the Server, it sends the `Hello` request which contains the +SessionID to perform a handshake. During the handshake, the Server will check whether the session for the SessionID +exists. + +* If the session not exists, the Server sends response with `clean_start=true`. +* If the session exists, the Server sends response with `clean_start=false` and sets the next EventID that it is willing + to accept to `next_event_id`. + +After handshake succeed, the Client will start `EventStream`: + +* If the Client receives `clean_start=true`, it sends all local subscriptions and retained messages to the Server in + order to sync the full state. +* If the Client receives `clean_start=false`, it sends events of which the EventID is greater than or equal + to `next_event_id`. + +### Subscription Tree + +Each node in the federation will have two subscription trees, the local tree and the federation tree. The local tree +stores subscriptions for local clients which is managed by gmqtt core and the federation tree stores the subscriptions +for remote nodes which is managed by the federation plugin. The federation tree takes node name as subscriber identifier +for subscriptions. + +* When receives a sub/unsub packet from a local client, the node will update it's local tree first and then broadcasts + the event to other nodes. +* When receives sub/unsub event from a remote node, the node will only update it's federation tree. + +With the federation tree, the node can determine which node the incoming message should be routed to. For example, Node1 +and Node2 are in the same federation. Client1 connects to Node1 and subscribes to topic a/b, the subscription trees of +these two nodes are as follows: + +Node1 local tree: + +| subscriber | topic | +|------------|-------| +| client1 | a/b | + +Node1 federation tree: +empty. + +Node2 local tree: +empty. + +Node2 federation tree: + +| subscriber | topic | +|------------|-------| +| node1 | a/b | + +### Message Distribution Process + +When an MQTT client publishes a message, the node where it is located queries the federation tree and forwards the +message to the relevant node according to the message topic, and then the relevant node retrieves the local subscription +tree and sends the message to the relevant subscriber. + +### Membership Management + +Federation uses [Serf](https://github.com/hashicorp/serf) to manage membership. + + diff --git a/internal/hummingbird/mqttbroker/plugin/federation/config.go b/internal/hummingbird/mqttbroker/plugin/federation/config.go new file mode 100644 index 0000000..f0e6385 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/config.go @@ -0,0 +1,189 @@ +package federation + +import ( + "fmt" + "net" + "os" + "strconv" + "strings" + "time" + + "github.com/hashicorp/go-sockaddr" +) + +// Default config. +const ( + DefaultFedPort = "8901" + DefaultGossipPort = "8902" + DefaultRetryInterval = 5 * time.Second + DefaultRetryTimeout = 1 * time.Minute +) + +// stub function for testing +var getPrivateIP = sockaddr.GetPrivateIP + +// Config is the configuration for the federation plugin. +type Config struct { + // NodeName is the unique identifier for the node in the federation. Defaults to hostname. + NodeName string `yaml:"node_name"` + // FedAddr is the gRPC server listening address for the federation internal communication. + // Defaults to :8901. + // If the port is missing, the default federation port (8901) will be used. + FedAddr string `yaml:"fed_addr"` + // AdvertiseFedAddr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. + // Defaults to "FedAddr" or the private IP address of the node if the IP in "FedAddr" is 0.0.0.0. + // However, in some cases, there may be a routable address that cannot be bound. + // If the port is missing, the default federation port (8901) will be used. + AdvertiseFedAddr string `yaml:"advertise_fed_addr"` + // GossipAddr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 + GossipAddr string `yaml:"gossip_addr"` + // AdvertiseGossipAddr is used to change the gossip server address that we advertise to other nodes in the cluster. + // Defaults to "GossipAddr" or the private IP address of the node if the IP in "GossipAddr" is 0.0.0.0. + // If the port is missing, the default gossip port (8902) will be used. + AdvertiseGossipAddr string `yaml:"advertise_gossip_addr"` + // RetryJoin is the address of other nodes to join upon starting up. + // If port is missing, the default gossip port (8902) will be used. + RetryJoin []string `yaml:"retry_join"` + // RetryInterval is the time to wait between join attempts. Defaults to 5s. + RetryInterval time.Duration `yaml:"retry_interval"` + // RetryTimeout is the timeout to wait before joining all nodes in RetryJoin successfully. + // If timeout expires, the server will exit with error. Defaults to 1m. + RetryTimeout time.Duration `yaml:"retry_timeout"` + // SnapshotPath will be pass to "SnapshotPath" in serf configuration. + // When Serf is started with a snapshot, + // it will attempt to join all the previously known nodes until one + // succeeds and will also avoid replaying old user events. + SnapshotPath string `yaml:"snapshot_path"` + // RejoinAfterLeave will be pass to "RejoinAfterLeave" in serf configuration. + // It controls our interaction with the snapshot file. + // When set to false (default), a leave causes a Serf to not rejoin + // the cluster until an explicit join is received. If this is set to + // true, we ignore the leave, and rejoin the cluster on start. + RejoinAfterLeave bool `yaml:"rejoin_after_leave"` +} + +func isPortNumber(port string) bool { + i, err := strconv.Atoi(port) + if err != nil { + return false + } + if 1 <= i && i <= 65535 { + return true + } + return false +} + +func getAddr(addr string, defaultPort string, fieldName string, usePrivate bool) (string, error) { + if addr == "" { + return "", fmt.Errorf("missing %s", fieldName) + } + host, port, err := net.SplitHostPort(addr) + if port == "" { + port = defaultPort + } + if addr[len(addr)-1] == ':' { + return "", fmt.Errorf("invalid %s", fieldName) + } + if err != nil && strings.Contains(err.Error(), "missing port in address") { + host, port, err = net.SplitHostPort(addr + ":" + defaultPort) + if err != nil { + return "", fmt.Errorf("invalid %s: %s", fieldName, err) + } + } else if err != nil { + return "", fmt.Errorf("invalid %s: %s", fieldName, err) + } + if usePrivate && (host == "0.0.0.0" || host == "") { + host, err = getPrivateIP() + if err != nil { + return "", err + } + } + if !isPortNumber(port) { + return "", fmt.Errorf("invalid port number: %s", port) + } + return net.JoinHostPort(host, port), nil +} + +// Validate validates the configuration, and return an error if it is invalid. +func (c *Config) Validate() (err error) { + if c.NodeName == "" { + hostName, err := os.Hostname() + if err != nil { + return err + } + c.NodeName = hostName + } + c.FedAddr, err = getAddr(c.FedAddr, DefaultFedPort, "fed_addr", false) + if err != nil { + return err + } + c.GossipAddr, err = getAddr(c.GossipAddr, DefaultGossipPort, "gossip_addr", false) + if err != nil { + return err + } + if c.AdvertiseFedAddr == "" { + c.AdvertiseFedAddr = c.FedAddr + } + c.AdvertiseFedAddr, err = getAddr(c.AdvertiseFedAddr, DefaultFedPort, "advertise_fed_addr", true) + if err != nil { + return err + } + if c.AdvertiseGossipAddr == "" { + c.AdvertiseGossipAddr = c.GossipAddr + } + c.AdvertiseGossipAddr, err = getAddr(c.AdvertiseGossipAddr, DefaultGossipPort, "advertise_gossip_addr", true) + if err != nil { + return err + } + + for k, v := range c.RetryJoin { + c.RetryJoin[k], err = getAddr(v, DefaultGossipPort, "retry_join", false) + if err != nil { + return err + } + } + if c.RetryInterval <= 0 { + return fmt.Errorf("invalid retry_join: %d", c.RetryInterval) + } + + if c.RetryTimeout <= 0 { + return fmt.Errorf("invalid retry_timeout: %d", c.RetryTimeout) + } + return nil +} + +// DefaultConfig is the default configuration. +var DefaultConfig = Config{} + +func init() { + hostName, err := os.Hostname() + if err != nil { + panic(err) + } + DefaultConfig = Config{ + NodeName: hostName, + FedAddr: ":" + DefaultFedPort, + GossipAddr: ":" + DefaultFedPort, + RetryJoin: nil, + RetryInterval: DefaultRetryInterval, + RetryTimeout: DefaultRetryTimeout, + } +} + +func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type cfg Config + df := cfg(DefaultConfig) + var v = &struct { + Federation *cfg `yaml:"federation"` + }{ + Federation: &df, + } + if err := unmarshal(v); err != nil { + return err + } + if v.Federation == nil { + v.Federation = &df + } + *c = Config(*v.Federation) + return nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/examples/join_node3_config.yml b/internal/hummingbird/mqttbroker/plugin/federation/examples/join_node3_config.yml new file mode 100644 index 0000000..7fb037a --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/examples/join_node3_config.yml @@ -0,0 +1,69 @@ +listeners: + - address: ":1885" +api: + grpc: + - address: "tcp://127.0.0.1:8284" + http: + - address: "tcp://127.0.0.1:8283" + map: "tcp://127.0.0.1:8284" # The backend gRPC server endpoint +mqtt: + session_expiry: 2h + session_expiry_check_timer: 20s + message_expiry: 2h + max_packet_size: 268435456 + server_receive_maximum: 100 + max_keepalive: 60 + topic_alias_maximum: 10 + subscription_identifier_available: true + wildcard_subscription_available: true + shared_subscription_available: true + maximum_qos: 2 + retain_available: true + max_queued_messages: 10000 + max_inflight: 1000 + queue_qos0_messages: true + delivery_mode: onlyonce # overlap or onlyonce + allow_zero_length_clientid: true + +plugins: + federation: + # node_name is the unique identifier for the node in the federation. Defaults to hostname. + node_name: node3 + # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 + fed_addr: :8931 + # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. + # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. + # If the port is missing, the default federation port (8901) will be used. + advertise_fed_addr: :8931 + # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 + gossip_addr: :8932 + # retry_join is the address of other nodes to join upon starting up. + # If port is missing, the default gossip port (8902) will be used. + #retry_join: + # - 127.0.0.1:8912 + # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. + # It controls our interaction with the snapshot file. + # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. + # If this is set to true, we ignore the leave, and rejoin the cluster on start. + rejoin_after_leave: false + # snapshot_path will be pass to "SnapshotPath" in serf configuration. + # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one + # succeeds and will also avoid replaying old user events. + snapshot_path: + +# plugin loading orders +plugin_order: + # Uncomment auth to enable authentication. + # - auth + #- prometheus + #- admin + - federation +log: + level: debug # debug | info | warn | error + format: text # json | text + # whether to dump MQTT packet in debug level + dump_packet: false + + + + diff --git a/internal/hummingbird/mqttbroker/plugin/federation/examples/node1_config.yml b/internal/hummingbird/mqttbroker/plugin/federation/examples/node1_config.yml new file mode 100644 index 0000000..2b68b86 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/examples/node1_config.yml @@ -0,0 +1,70 @@ +listeners: + - address: ":58090" +api: + grpc: + - address: "tcp://127.0.0.1:57090" + http: + - address: "tcp://127.0.0.1:57091" + map: "tcp://127.0.0.1:57090" # The backend gRPC server endpoint +mqtt: + session_expiry: 2h + session_expiry_check_timer: 20s + message_expiry: 2h + max_packet_size: 268435456 + server_receive_maximum: 100 + max_keepalive: 60 + topic_alias_maximum: 10 + subscription_identifier_available: true + wildcard_subscription_available: true + shared_subscription_available: true + maximum_qos: 2 + retain_available: true + max_queued_messages: 10000 + max_inflight: 1000 + queue_qos0_messages: true + delivery_mode: onlyonce # overlap or onlyonce + allow_zero_length_clientid: true + +plugins: + federation: + # node_name is the unique identifier for the node in the federation. Defaults to hostname. + node_name: node1 + # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 + fed_addr: :8901 + # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. + # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. + # If the port is missing, the default federation port (8901) will be used. + advertise_fed_addr: :8901 + # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 + gossip_addr: :8902 + # retry_join is the address of other nodes to join upon starting up. + # If port is missing, the default gossip port (8902) will be used. + retry_join: + # Change 127.0.0.1 to real routable ip address if you run gmqtt in multiple nodes. + - 127.0.0.1:8912 + # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. + # It controls our interaction with the snapshot file. + # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. + # If this is set to true, we ignore the leave, and rejoin the cluster on start. + rejoin_after_leave: false + # snapshot_path will be pass to "SnapshotPath" in serf configuration. + # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one + # succeeds and will also avoid replaying old user events. + snapshot_path: + +# plugin loading orders +plugin_order: + # Uncomment auth to enable authentication. + # - auth + #- prometheus + #- admin + - federation +log: + level: debug # debug | info | warn | error + format: text # json | text + # whether to dump MQTT packet in debug level + dump_packet: false + + + + diff --git a/internal/hummingbird/mqttbroker/plugin/federation/examples/node2_config.yml b/internal/hummingbird/mqttbroker/plugin/federation/examples/node2_config.yml new file mode 100644 index 0000000..843def8 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/examples/node2_config.yml @@ -0,0 +1,70 @@ +listeners: + - address: ":1884" +api: + grpc: + - address: "tcp://127.0.0.1:8184" + http: + - address: "tcp://127.0.0.1:8183" + map: "tcp://127.0.0.1:8184" # The backend gRPC server endpoint +mqtt: + session_expiry: 2h + session_expiry_check_timer: 20s + message_expiry: 2h + max_packet_size: 268435456 + server_receive_maximum: 100 + max_keepalive: 60 + topic_alias_maximum: 10 + subscription_identifier_available: true + wildcard_subscription_available: true + shared_subscription_available: true + maximum_qos: 2 + retain_available: true + max_queued_messages: 10000 + max_inflight: 1000 + queue_qos0_messages: true + delivery_mode: onlyonce # overlap or onlyonce + allow_zero_length_clientid: true + +plugins: + federation: + # node_name is the unique identifier for the node in the federation. Defaults to hostname. + node_name: node2 + # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 + fed_addr: :8911 + # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. + # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. + # If the port is missing, the default federation port (8901) will be used. + advertise_fed_addr: :8911 + # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 + gossip_addr: :8912 + # retry_join is the address of other nodes to join upon starting up. + # If port is missing, the default gossip port (8902) will be used. + retry_join: + # Change 127.0.0.1 to real routable ip address if you run gmqtt in multiple nodes. + - 127.0.0.1:8902 + # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. + # It controls our interaction with the snapshot file. + # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. + # If this is set to true, we ignore the leave, and rejoin the cluster on start. + rejoin_after_leave: false + # snapshot_path will be pass to "SnapshotPath" in serf configuration. + # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one + # succeeds and will also avoid replaying old user events. + snapshot_path: + +# plugin loading orders +plugin_order: + # Uncomment auth to enable authentication. + # - auth + #- prometheus + #- admin + - federation +log: + level: debug # debug | info | warn | error + format: text # json | text + # whether to dump MQTT packet in debug level + dump_packet: false + + + + diff --git a/internal/hummingbird/mqttbroker/plugin/federation/federation.go b/internal/hummingbird/mqttbroker/plugin/federation/federation.go new file mode 100644 index 0000000..9b41549 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/federation.go @@ -0,0 +1,624 @@ +package federation + +import ( + "container/list" + "context" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/hashicorp/logutils" + "github.com/hashicorp/serf/serf" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription/mem" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/retained" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var _ server.Plugin = (*Federation)(nil) + +const Name = "federation" + +func init() { + server.RegisterPlugin(Name, New) + config.RegisterDefaultPluginConfig(Name, &DefaultConfig) +} + +func getSerfLogger(level string) (io.Writer, error) { + logLevel := strings.ToUpper(level) + var zapLevel zapcore.Level + err := zapLevel.UnmarshalText([]byte(logLevel)) + if err != nil { + return nil, err + } + zp, err := zap.NewStdLogAt(log, zapLevel) + if err != nil { + return nil, err + } + filter := &logutils.LevelFilter{ + Levels: []logutils.LogLevel{"DEBUG", "INFO", "WARN", "ERROR"}, + MinLevel: logutils.LogLevel(logLevel), + Writer: zp.Writer(), + } + return filter, nil +} + +func getSerfConfig(cfg *Config, eventCh chan serf.Event, logOut io.Writer) *serf.Config { + serfCfg := serf.DefaultConfig() + serfCfg.SnapshotPath = cfg.SnapshotPath + serfCfg.RejoinAfterLeave = cfg.RejoinAfterLeave + serfCfg.NodeName = cfg.NodeName + serfCfg.EventCh = eventCh + host, port, _ := net.SplitHostPort(cfg.GossipAddr) + if host != "" { + serfCfg.MemberlistConfig.BindAddr = host + } + p, _ := strconv.Atoi(port) + serfCfg.MemberlistConfig.BindPort = p + + // set advertise + host, port, _ = net.SplitHostPort(cfg.AdvertiseGossipAddr) + if host != "" { + serfCfg.MemberlistConfig.AdvertiseAddr = host + } + p, _ = strconv.Atoi(port) + serfCfg.MemberlistConfig.AdvertisePort = p + + serfCfg.Tags = map[string]string{"fed_addr": cfg.AdvertiseFedAddr} + serfCfg.LogOutput = logOut + serfCfg.MemberlistConfig.LogOutput = logOut + return serfCfg +} + +func New(config config.Config) (server.Plugin, error) { + log = server.LoggerWithField(zap.String("plugin", Name)) + cfg := config.Plugins[Name].(*Config) + f := &Federation{ + config: cfg, + nodeName: cfg.NodeName, + localSubStore: &localSubStore{}, + fedSubStore: &fedSubStore{ + TrieDB: mem.NewStore(), + sharedSent: map[string]uint64{}, + }, + serfEventCh: make(chan serf.Event, 10000), + sessionMgr: &sessionMgr{ + sessions: map[string]*session{}, + }, + peers: make(map[string]*peer), + exit: make(chan struct{}), + wg: &sync.WaitGroup{}, + } + logOut, err := getSerfLogger(config.Log.Level) + if err != nil { + return nil, err + } + serfCfg := getSerfConfig(cfg, f.serfEventCh, logOut) + s, err := serf.Create(serfCfg) + if err != nil { + return nil, err + } + f.serf = s + return f, nil +} + +var log *zap.Logger + +type Federation struct { + config *Config + nodeName string + serfMu sync.Mutex + serf iSerf + serfEventCh chan serf.Event + sessionMgr *sessionMgr + // localSubStore store the subscriptions for the local node. + // The local node will only broadcast "new subscriptions" to other nodes. + // "New subscription" is the first subscription for a topic name. + // It means that if two client in the local node subscribe the same topic, only the first subscription will be broadcast. + localSubStore *localSubStore + // fedSubStore store federation subscription tree which take nodeName as the subscriber identifier. + // It is used to determine which node the incoming message should be routed to. + fedSubStore *fedSubStore + // retainedStore store is the retained store of the gmqtt core. + // Retained message will be broadcast to other nodes in the federation. + retainedStore retained.Store + publisher server.Publisher + exit chan struct{} + memberMu sync.Mutex + peers map[string]*peer + wg *sync.WaitGroup +} + +type fedSubStore struct { + *mem.TrieDB + sharedMu sync.Mutex + // sharedSent store the number of shared topic sent. + // It is used to select which node the message should be send to with round-robin strategy + sharedSent map[string]uint64 +} + +type sessionMgr struct { + sync.RWMutex + sessions map[string]*session +} + +func (s *sessionMgr) add(nodeName string, id string) (cleanStart bool, nextID uint64) { + s.Lock() + defer s.Unlock() + if v, ok := s.sessions[nodeName]; ok && v.id == id { + nextID = v.nextEventID + } else { + // v.id != id indicates that the client side may recover from crash and need to rebuild the full state. + cleanStart = true + } + if cleanStart { + s.sessions[nodeName] = &session{ + id: id, + nodeName: nodeName, + // TODO config + seenEvents: newLRUCache(100), + nextEventID: 0, + close: make(chan struct{}), + } + } + return +} + +func (s *sessionMgr) del(nodeName string) { + s.Lock() + defer s.Unlock() + if sess := s.sessions[nodeName]; sess != nil { + close(sess.close) + } + delete(s.sessions, nodeName) +} + +func (s *sessionMgr) get(nodeName string) *session { + s.RLock() + defer s.RUnlock() + return s.sessions[nodeName] +} + +// ForceLeave forces a member of a Serf cluster to enter the "left" state. +// Note that if the member is still actually alive, it will eventually rejoin the cluster. +// The true purpose of this method is to force remove "failed" nodes +// See https://www.serf.io/docs/commands/force-leave.html for details. +func (f *Federation) ForceLeave(ctx context.Context, req *ForceLeaveRequest) (*empty.Empty, error) { + if req.NodeName == "" { + return nil, errors.New("host can not be empty") + } + return &empty.Empty{}, f.serf.RemoveFailedNode(req.NodeName) +} + +// ListMembers lists all known members in the Serf cluster. +func (f *Federation) ListMembers(ctx context.Context, req *empty.Empty) (resp *ListMembersResponse, err error) { + resp = &ListMembersResponse{} + for _, v := range f.serf.Members() { + resp.Members = append(resp.Members, &Member{ + Name: v.Name, + Addr: net.JoinHostPort(v.Addr.String(), strconv.Itoa(int(v.Port))), + Tags: v.Tags, + Status: Status(v.Status), + }) + } + return resp, nil +} + +// Leave triggers a graceful leave for the local node. +// This is used to ensure other nodes see the node as "left" instead of "failed". +// Note that a leaved node cannot re-join the cluster unless you restart the leaved node. +func (f *Federation) Leave(ctx context.Context, req *empty.Empty) (resp *empty.Empty, err error) { + return &empty.Empty{}, f.serf.Leave() +} + +func (f *Federation) mustEmbedUnimplementedMembershipServer() { + return +} + +// Join tells the local node to join the an existing cluster. +// See https://www.serf.io/docs/commands/join.html for details. +func (f *Federation) Join(ctx context.Context, req *JoinRequest) (resp *empty.Empty, err error) { + for k, v := range req.Hosts { + req.Hosts[k], err = getAddr(v, DefaultGossipPort, "hosts", false) + if err != nil { + return &empty.Empty{}, status.Error(codes.InvalidArgument, err.Error()) + } + } + _, err = f.serf.Join(req.Hosts, true) + if err != nil { + return nil, err + } + return &empty.Empty{}, nil +} + +type localSubStore struct { + localStore server.SubscriptionService + sync.Mutex + // [clientID][topicName] + index map[string]map[string]struct{} + // topics store the reference counter for each topic. (map[topicName]uint64) + topics map[string]uint64 +} + +// init loads all subscriptions from gmqtt core into federation plugin. +func (l *localSubStore) init(sub server.SubscriptionService) { + l.localStore = sub + l.index = make(map[string]map[string]struct{}) + l.topics = make(map[string]uint64) + l.Lock() + defer l.Unlock() + // copy and convert subscription tree into localSubStore + sub.Iterate(func(clientID string, sub *mqttbroker.Subscription) bool { + l.subscribeLocked(clientID, sub.GetFullTopicName()) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeAll, + }) +} + +// subscribe subscribe the topicName for the client and increase the reference counter of the topicName. +// It returns whether the subscription is new +func (l *localSubStore) subscribe(clientID string, topicName string) (new bool) { + l.Lock() + defer l.Unlock() + return l.subscribeLocked(clientID, topicName) +} + +func (l *localSubStore) subscribeLocked(clientID string, topicName string) (new bool) { + if _, ok := l.index[clientID]; !ok { + l.index[clientID] = make(map[string]struct{}) + } + if _, ok := l.index[clientID][topicName]; !ok { + l.index[clientID][topicName] = struct{}{} + l.topics[topicName]++ + if l.topics[topicName] == 1 { + return true + } + } + return false +} + +func (l *localSubStore) decTopicCounterLocked(topicName string) { + if _, ok := l.topics[topicName]; ok { + l.topics[topicName]-- + if l.topics[topicName] <= 0 { + delete(l.topics, topicName) + } + } +} + +// unsubscribe unsubscribe the topicName for the client and decrease the reference counter of the topicName. +// It returns whether the topicName is removed (reference counter == 0) +func (l *localSubStore) unsubscribe(clientID string, topicName string) (remove bool) { + l.Lock() + defer l.Unlock() + if v, ok := l.index[clientID]; ok { + if _, ok := v[topicName]; ok { + delete(v, topicName) + if len(v) == 0 { + delete(l.index, clientID) + } + l.decTopicCounterLocked(topicName) + return l.topics[topicName] == 0 + } + } + return false + +} + +// unsubscribeAll unsubscribes all topics for the given client. +// Typically, this function is called when the client session has terminated. +// It returns any topic that is removed。 +func (l *localSubStore) unsubscribeAll(clientID string) (remove []string) { + l.Lock() + defer l.Unlock() + for topicName := range l.index[clientID] { + l.decTopicCounterLocked(topicName) + if l.topics[topicName] == 0 { + remove = append(remove, topicName) + } + } + delete(l.index, clientID) + return remove +} + +type session struct { + id string + nodeName string + nextEventID uint64 + // seenEvents cache recently seen events to avoid duplicate events. + seenEvents *lruCache + close chan struct{} +} + +// lruCache is the cache for recently seen events. +type lruCache struct { + l *list.List + items map[uint64]struct{} + size int +} + +func newLRUCache(size int) *lruCache { + return &lruCache{ + l: list.New(), + items: make(map[uint64]struct{}), + size: size, + } +} + +func (l *lruCache) set(id uint64) (exist bool) { + if _, ok := l.items[id]; ok { + return true + } + if l.size == len(l.items) { + elem := l.l.Front() + delete(l.items, elem.Value.(uint64)) + l.l.Remove(elem) + } + l.items[id] = struct{}{} + l.l.PushBack(id) + return false +} + +func getNodeNameFromContext(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", status.Errorf(codes.DataLoss, "EventStream: failed to get metadata") + } + s := md.Get("node_name") + if len(s) == 0 { + return "", status.Errorf(codes.InvalidArgument, "EventStream: missing node_name metadata") + } + nodeName := s[0] + if nodeName == "" { + return "", status.Errorf(codes.InvalidArgument, "EventStream: missing node_name metadata") + } + return nodeName, nil +} + +// Hello is the handler for the handshake process before opening the event stream. +func (f *Federation) Hello(ctx context.Context, req *ClientHello) (resp *ServerHello, err error) { + nodeName, err := getNodeNameFromContext(ctx) + if err != nil { + return nil, err + } + f.memberMu.Lock() + p := f.peers[nodeName] + f.memberMu.Unlock() + if p == nil { + return nil, status.Errorf(codes.Internal, "Hello: the node [%s] has not yet joined", nodeName) + } + + cleanStart, nextID := f.sessionMgr.add(nodeName, req.SessionId) + if cleanStart { + _ = f.fedSubStore.UnsubscribeAll(nodeName) + } + resp = &ServerHello{ + CleanStart: cleanStart, + NextEventId: nextID, + } + return resp, nil +} + +func (f *Federation) eventStreamHandler(sess *session, in *Event) (ack *Ack) { + eventID := in.Id + // duplicated event, ignore it + if sess.seenEvents.set(eventID) { + log.Warn("ignore duplicated event", zap.String("event", in.String())) + return &Ack{ + EventId: eventID, + } + } + if sub := in.GetSubscribe(); sub != nil { + _, _ = f.fedSubStore.Subscribe(sess.nodeName, &mqttbroker.Subscription{ + ShareName: sub.ShareName, + TopicFilter: sub.TopicFilter, + }) + return &Ack{EventId: eventID} + } + if msg := in.GetMessage(); msg != nil { + pubMsg := eventToMessage(msg) + f.publisher.Publish(pubMsg) + if pubMsg.Retained { + f.retainedStore.AddOrReplace(pubMsg) + } + return &Ack{EventId: eventID} + } + if unsub := in.GetUnsubscribe(); unsub != nil { + _ = f.fedSubStore.Unsubscribe(sess.nodeName, unsub.TopicName) + return &Ack{EventId: eventID} + } + return nil +} + +func (f *Federation) EventStream(stream Federation_EventStreamServer) (err error) { + defer func() { + if err != nil && err != io.EOF { + log.Error("EventStream error", zap.Error(err)) + } + }() + md, ok := metadata.FromIncomingContext(stream.Context()) + if !ok { + return status.Errorf(codes.DataLoss, "EventStream: failed to get metadata") + } + s := md.Get("node_name") + if len(s) == 0 { + return status.Errorf(codes.InvalidArgument, "EventStream: missing node_name metadata") + } + nodeName := s[0] + if nodeName == "" { + return status.Errorf(codes.InvalidArgument, "EventStream: missing node_name metadata") + } + sess := f.sessionMgr.get(nodeName) + if sess == nil { + return status.Errorf(codes.Internal, "EventStream: node [%s] does not exist", nodeName) + } + errCh := make(chan error, 1) + done := make(chan struct{}) + // close the session if the client node has been mark as failed. + go func() { + <-sess.close + errCh <- fmt.Errorf("EventStream: the session of node [%s] has been closed", nodeName) + close(done) + }() + go func() { + for { + var in *Event + select { + case <-done: + default: + in, err = stream.Recv() + if err != nil { + errCh <- err + return + } + if ce := log.Check(zapcore.DebugLevel, "event received"); ce != nil { + ce.Write(zap.String("event", in.String())) + } + + ack := f.eventStreamHandler(sess, in) + + err = stream.Send(ack) + if err != nil { + errCh <- err + return + } + if ce := log.Check(zapcore.DebugLevel, "event ack sent"); ce != nil { + ce.Write(zap.Uint64("id", ack.EventId)) + } + sess.nextEventID = ack.EventId + 1 + } + } + }() + err = <-errCh + if err == io.EOF { + return nil + } + return err +} + +func (f *Federation) mustEmbedUnimplementedFederationServer() { + return +} + +var registerAPI = func(service server.Server, f *Federation) error { + apiRegistrar := service.APIRegistrar() + RegisterMembershipServer(apiRegistrar, f) + err := apiRegistrar.RegisterHTTPHandler(RegisterMembershipHandlerFromEndpoint) + return err +} + +func (f *Federation) Load(service server.Server) error { + err := registerAPI(service, f) + if err != nil { + return err + } + f.localSubStore.init(service.SubscriptionService()) + f.retainedStore = service.RetainedService() + f.publisher = service.Publisher() + srv := grpc.NewServer() + RegisterFederationServer(srv, f) + l, err := net.Listen("tcp", f.config.FedAddr) + if err != nil { + return err + } + go func() { + err := srv.Serve(l) + if err != nil { + panic(err) + } + }() + t := time.NewTimer(0) + timeout := time.NewTimer(f.config.RetryTimeout) + for { + select { + case <-timeout.C: + log.Error("retry timeout", zap.Error(err)) + if err != nil { + err = fmt.Errorf("retry timeout: %s", err.Error()) + return err + } + return errors.New("retry timeout") + case <-t.C: + err = f.startSerf(t) + if err == nil { + log.Info("retry join succeed") + return nil + } + log.Info("retry join failed", zap.Error(err)) + } + } +} + +func (f *Federation) Unload() error { + err := f.serf.Leave() + if err != nil { + return err + } + return f.serf.Shutdown() +} + +func (f *Federation) Name() string { + return Name +} + +func messageToEvent(msg *mqttbroker.Message) *Message { + eventMsg := &Message{ + TopicName: msg.Topic, + Payload: string(msg.Payload), + Qos: uint32(msg.QoS), + Retained: msg.Retained, + ContentType: msg.ContentType, + CorrelationData: string(msg.CorrelationData), + MessageExpiry: msg.MessageExpiry, + PayloadFormat: uint32(msg.PayloadFormat), + ResponseTopic: msg.ResponseTopic, + } + for _, v := range msg.UserProperties { + ppt := &UserProperty{ + K: make([]byte, len(v.K)), + V: make([]byte, len(v.V)), + } + copy(ppt.K, v.K) + copy(ppt.V, v.V) + eventMsg.UserProperties = append(eventMsg.UserProperties, ppt) + } + return eventMsg +} + +func eventToMessage(event *Message) *mqttbroker.Message { + pubMsg := &mqttbroker.Message{ + QoS: byte(event.Qos), + Retained: event.Retained, + Topic: event.TopicName, + Payload: []byte(event.Payload), + ContentType: event.ContentType, + CorrelationData: []byte(event.CorrelationData), + MessageExpiry: event.MessageExpiry, + PayloadFormat: packets.PayloadFormat(event.PayloadFormat), + ResponseTopic: event.ResponseTopic, + } + for _, v := range event.UserProperties { + pubMsg.UserProperties = append(pubMsg.UserProperties, packets.UserProperty{ + K: v.K, + V: v.V, + }) + } + return pubMsg +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/federation.pb.go b/internal/hummingbird/mqttbroker/plugin/federation/federation.pb.go new file mode 100644 index 0000000..bd2f9e0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/federation.pb.go @@ -0,0 +1,1205 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.22.0 +// protoc v3.13.0 +// source: federation.proto + +package federation + +import ( + reflect "reflect" + sync "sync" + + proto "github.com/golang/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + _ "google.golang.org/genproto/googleapis/api/annotations" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// This is a compile-time assertion that a sufficiently up-to-date version +// of the legacy proto package is being used. +const _ = proto.ProtoPackageIsVersion4 + +type Status int32 + +const ( + Status_STATUS_UNSPECIFIED Status = 0 + Status_STATUS_ALIVE Status = 1 + Status_STATUS_LEAVING Status = 2 + Status_STATUS_LEFT Status = 3 + Status_STATUS_FAILED Status = 4 +) + +// Enum value maps for Status. +var ( + Status_name = map[int32]string{ + 0: "STATUS_UNSPECIFIED", + 1: "STATUS_ALIVE", + 2: "STATUS_LEAVING", + 3: "STATUS_LEFT", + 4: "STATUS_FAILED", + } + Status_value = map[string]int32{ + "STATUS_UNSPECIFIED": 0, + "STATUS_ALIVE": 1, + "STATUS_LEAVING": 2, + "STATUS_LEFT": 3, + "STATUS_FAILED": 4, + } +) + +func (x Status) Enum() *Status { + p := new(Status) + *p = x + return p +} + +func (x Status) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Status) Descriptor() protoreflect.EnumDescriptor { + return file_federation_proto_enumTypes[0].Descriptor() +} + +func (Status) Type() protoreflect.EnumType { + return &file_federation_proto_enumTypes[0] +} + +func (x Status) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Status.Descriptor instead. +func (Status) EnumDescriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{0} +} + +type Event struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + // Types that are assignable to Event: + // *Event_Subscribe + // *Event_Message + // *Event_Unsubscribe + Event isEvent_Event `protobuf_oneof:"Event"` +} + +func (x *Event) Reset() { + *x = Event{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Event) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Event) ProtoMessage() {} + +func (x *Event) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Event.ProtoReflect.Descriptor instead. +func (*Event) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{0} +} + +func (x *Event) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (m *Event) GetEvent() isEvent_Event { + if m != nil { + return m.Event + } + return nil +} + +func (x *Event) GetSubscribe() *Subscribe { + if x, ok := x.GetEvent().(*Event_Subscribe); ok { + return x.Subscribe + } + return nil +} + +func (x *Event) GetMessage() *Message { + if x, ok := x.GetEvent().(*Event_Message); ok { + return x.Message + } + return nil +} + +func (x *Event) GetUnsubscribe() *Unsubscribe { + if x, ok := x.GetEvent().(*Event_Unsubscribe); ok { + return x.Unsubscribe + } + return nil +} + +type isEvent_Event interface { + isEvent_Event() +} + +type Event_Subscribe struct { + Subscribe *Subscribe `protobuf:"bytes,2,opt,name=Subscribe,proto3,oneof"` +} + +type Event_Message struct { + Message *Message `protobuf:"bytes,3,opt,name=message,proto3,oneof"` +} + +type Event_Unsubscribe struct { + Unsubscribe *Unsubscribe `protobuf:"bytes,4,opt,name=unsubscribe,proto3,oneof"` +} + +func (*Event_Subscribe) isEvent_Event() {} + +func (*Event_Message) isEvent_Event() {} + +func (*Event_Unsubscribe) isEvent_Event() {} + +// Subscribe represents the subscription for a node, it is used to route message among nodes, +// so only shared_name and topic_filter is required. +type Subscribe struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ShareName string `protobuf:"bytes,1,opt,name=share_name,json=shareName,proto3" json:"share_name,omitempty"` + TopicFilter string `protobuf:"bytes,2,opt,name=topic_filter,json=topicFilter,proto3" json:"topic_filter,omitempty"` +} + +func (x *Subscribe) Reset() { + *x = Subscribe{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Subscribe) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Subscribe) ProtoMessage() {} + +func (x *Subscribe) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Subscribe.ProtoReflect.Descriptor instead. +func (*Subscribe) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{1} +} + +func (x *Subscribe) GetShareName() string { + if x != nil { + return x.ShareName + } + return "" +} + +func (x *Subscribe) GetTopicFilter() string { + if x != nil { + return x.TopicFilter + } + return "" +} + +type Message struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TopicName string `protobuf:"bytes,1,opt,name=topic_name,json=topicName,proto3" json:"topic_name,omitempty"` + Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` + Qos uint32 `protobuf:"varint,3,opt,name=qos,proto3" json:"qos,omitempty"` + Retained bool `protobuf:"varint,4,opt,name=retained,proto3" json:"retained,omitempty"` + // the following fields are using in v5 client. + ContentType string `protobuf:"bytes,5,opt,name=content_type,json=contentType,proto3" json:"content_type,omitempty"` + CorrelationData string `protobuf:"bytes,6,opt,name=correlation_data,json=correlationData,proto3" json:"correlation_data,omitempty"` + MessageExpiry uint32 `protobuf:"varint,7,opt,name=message_expiry,json=messageExpiry,proto3" json:"message_expiry,omitempty"` + PayloadFormat uint32 `protobuf:"varint,8,opt,name=payload_format,json=payloadFormat,proto3" json:"payload_format,omitempty"` + ResponseTopic string `protobuf:"bytes,9,opt,name=response_topic,json=responseTopic,proto3" json:"response_topic,omitempty"` + UserProperties []*UserProperty `protobuf:"bytes,10,rep,name=user_properties,json=userProperties,proto3" json:"user_properties,omitempty"` +} + +func (x *Message) Reset() { + *x = Message{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{2} +} + +func (x *Message) GetTopicName() string { + if x != nil { + return x.TopicName + } + return "" +} + +func (x *Message) GetPayload() string { + if x != nil { + return x.Payload + } + return "" +} + +func (x *Message) GetQos() uint32 { + if x != nil { + return x.Qos + } + return 0 +} + +func (x *Message) GetRetained() bool { + if x != nil { + return x.Retained + } + return false +} + +func (x *Message) GetContentType() string { + if x != nil { + return x.ContentType + } + return "" +} + +func (x *Message) GetCorrelationData() string { + if x != nil { + return x.CorrelationData + } + return "" +} + +func (x *Message) GetMessageExpiry() uint32 { + if x != nil { + return x.MessageExpiry + } + return 0 +} + +func (x *Message) GetPayloadFormat() uint32 { + if x != nil { + return x.PayloadFormat + } + return 0 +} + +func (x *Message) GetResponseTopic() string { + if x != nil { + return x.ResponseTopic + } + return "" +} + +func (x *Message) GetUserProperties() []*UserProperty { + if x != nil { + return x.UserProperties + } + return nil +} + +type UserProperty struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + K []byte `protobuf:"bytes,1,opt,name=K,proto3" json:"K,omitempty"` + V []byte `protobuf:"bytes,2,opt,name=V,proto3" json:"V,omitempty"` +} + +func (x *UserProperty) Reset() { + *x = UserProperty{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UserProperty) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UserProperty) ProtoMessage() {} + +func (x *UserProperty) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UserProperty.ProtoReflect.Descriptor instead. +func (*UserProperty) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{3} +} + +func (x *UserProperty) GetK() []byte { + if x != nil { + return x.K + } + return nil +} + +func (x *UserProperty) GetV() []byte { + if x != nil { + return x.V + } + return nil +} + +type Unsubscribe struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TopicName string `protobuf:"bytes,1,opt,name=topic_name,json=topicName,proto3" json:"topic_name,omitempty"` +} + +func (x *Unsubscribe) Reset() { + *x = Unsubscribe{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Unsubscribe) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Unsubscribe) ProtoMessage() {} + +func (x *Unsubscribe) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Unsubscribe.ProtoReflect.Descriptor instead. +func (*Unsubscribe) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{4} +} + +func (x *Unsubscribe) GetTopicName() string { + if x != nil { + return x.TopicName + } + return "" +} + +type Ack struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EventId uint64 `protobuf:"varint,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"` +} + +func (x *Ack) Reset() { + *x = Ack{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Ack) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Ack) ProtoMessage() {} + +func (x *Ack) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Ack.ProtoReflect.Descriptor instead. +func (*Ack) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{5} +} + +func (x *Ack) GetEventId() uint64 { + if x != nil { + return x.EventId + } + return 0 +} + +// ClientHello is the request message in handshake process. +type ClientHello struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SessionId string `protobuf:"bytes,1,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` +} + +func (x *ClientHello) Reset() { + *x = ClientHello{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ClientHello) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ClientHello) ProtoMessage() {} + +func (x *ClientHello) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ClientHello.ProtoReflect.Descriptor instead. +func (*ClientHello) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{6} +} + +func (x *ClientHello) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +// ServerHello is the response message in handshake process. +type ServerHello struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CleanStart bool `protobuf:"varint,1,opt,name=clean_start,json=cleanStart,proto3" json:"clean_start,omitempty"` + NextEventId uint64 `protobuf:"varint,2,opt,name=next_event_id,json=nextEventId,proto3" json:"next_event_id,omitempty"` +} + +func (x *ServerHello) Reset() { + *x = ServerHello{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ServerHello) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ServerHello) ProtoMessage() {} + +func (x *ServerHello) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ServerHello.ProtoReflect.Descriptor instead. +func (*ServerHello) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{7} +} + +func (x *ServerHello) GetCleanStart() bool { + if x != nil { + return x.CleanStart + } + return false +} + +func (x *ServerHello) GetNextEventId() uint64 { + if x != nil { + return x.NextEventId + } + return 0 +} + +type JoinRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Hosts []string `protobuf:"bytes,1,rep,name=hosts,proto3" json:"hosts,omitempty"` +} + +func (x *JoinRequest) Reset() { + *x = JoinRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *JoinRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*JoinRequest) ProtoMessage() {} + +func (x *JoinRequest) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use JoinRequest.ProtoReflect.Descriptor instead. +func (*JoinRequest) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{8} +} + +func (x *JoinRequest) GetHosts() []string { + if x != nil { + return x.Hosts + } + return nil +} + +type Member struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Addr string `protobuf:"bytes,2,opt,name=addr,proto3" json:"addr,omitempty"` + Tags map[string]string `protobuf:"bytes,3,rep,name=tags,proto3" json:"tags,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + Status Status `protobuf:"varint,4,opt,name=status,proto3,enum=gmqtt.federation.api.Status" json:"status,omitempty"` +} + +func (x *Member) Reset() { + *x = Member{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Member) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Member) ProtoMessage() {} + +func (x *Member) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Member.ProtoReflect.Descriptor instead. +func (*Member) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{9} +} + +func (x *Member) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *Member) GetAddr() string { + if x != nil { + return x.Addr + } + return "" +} + +func (x *Member) GetTags() map[string]string { + if x != nil { + return x.Tags + } + return nil +} + +func (x *Member) GetStatus() Status { + if x != nil { + return x.Status + } + return Status_STATUS_UNSPECIFIED +} + +type ListMembersResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Members []*Member `protobuf:"bytes,1,rep,name=members,proto3" json:"members,omitempty"` +} + +func (x *ListMembersResponse) Reset() { + *x = ListMembersResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListMembersResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListMembersResponse) ProtoMessage() {} + +func (x *ListMembersResponse) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListMembersResponse.ProtoReflect.Descriptor instead. +func (*ListMembersResponse) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{10} +} + +func (x *ListMembersResponse) GetMembers() []*Member { + if x != nil { + return x.Members + } + return nil +} + +type ForceLeaveRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + NodeName string `protobuf:"bytes,1,opt,name=node_name,json=nodeName,proto3" json:"node_name,omitempty"` +} + +func (x *ForceLeaveRequest) Reset() { + *x = ForceLeaveRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_federation_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ForceLeaveRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ForceLeaveRequest) ProtoMessage() {} + +func (x *ForceLeaveRequest) ProtoReflect() protoreflect.Message { + mi := &file_federation_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ForceLeaveRequest.ProtoReflect.Descriptor instead. +func (*ForceLeaveRequest) Descriptor() ([]byte, []int) { + return file_federation_proto_rawDescGZIP(), []int{11} +} + +func (x *ForceLeaveRequest) GetNodeName() string { + if x != nil { + return x.NodeName + } + return "" +} + +var File_federation_proto protoreflect.FileDescriptor + +var file_federation_proto_rawDesc = []byte{ + 0x0a, 0x10, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x12, 0x14, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x22, 0xe3, 0x01, 0x0a, 0x05, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x0e, 0x0a, + 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x69, 0x64, 0x12, 0x3f, 0x0a, + 0x09, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1f, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, + 0x65, 0x48, 0x00, 0x52, 0x09, 0x53, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x12, 0x39, + 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1d, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x48, 0x00, + 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x45, 0x0a, 0x0b, 0x75, 0x6e, 0x73, + 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, + 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x55, 0x6e, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, + 0x65, 0x48, 0x00, 0x52, 0x0b, 0x75, 0x6e, 0x73, 0x75, 0x62, 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, + 0x42, 0x07, 0x0a, 0x05, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x22, 0x4d, 0x0a, 0x09, 0x53, 0x75, 0x62, + 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x68, 0x61, 0x72, 0x65, 0x5f, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x68, 0x61, 0x72, + 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x5f, 0x66, + 0x69, 0x6c, 0x74, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x74, 0x6f, 0x70, + 0x69, 0x63, 0x46, 0x69, 0x6c, 0x74, 0x65, 0x72, 0x22, 0x80, 0x03, 0x0a, 0x07, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x10, 0x0a, + 0x03, 0x71, 0x6f, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x71, 0x6f, 0x73, 0x12, + 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x08, 0x72, 0x65, 0x74, 0x61, 0x69, 0x6e, 0x65, 0x64, 0x12, 0x21, 0x0a, 0x0c, 0x63, + 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x0b, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x29, + 0x0a, 0x10, 0x63, 0x6f, 0x72, 0x72, 0x65, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x64, 0x61, + 0x74, 0x61, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x63, 0x6f, 0x72, 0x72, 0x65, 0x6c, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x44, 0x61, 0x74, 0x61, 0x12, 0x25, 0x0a, 0x0e, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x5f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x0d, 0x52, 0x0d, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x45, 0x78, 0x70, 0x69, 0x72, 0x79, + 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x5f, 0x66, 0x6f, 0x72, 0x6d, + 0x61, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0d, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, + 0x64, 0x46, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x72, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x5f, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x54, 0x6f, 0x70, 0x69, 0x63, 0x12, 0x4b, + 0x0a, 0x0f, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x70, 0x72, 0x6f, 0x70, 0x65, 0x72, 0x74, 0x69, 0x65, + 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, + 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x55, + 0x73, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x70, 0x65, 0x72, 0x74, 0x79, 0x52, 0x0e, 0x75, 0x73, 0x65, + 0x72, 0x50, 0x72, 0x6f, 0x70, 0x65, 0x72, 0x74, 0x69, 0x65, 0x73, 0x22, 0x2a, 0x0a, 0x0c, 0x55, + 0x73, 0x65, 0x72, 0x50, 0x72, 0x6f, 0x70, 0x65, 0x72, 0x74, 0x79, 0x12, 0x0c, 0x0a, 0x01, 0x4b, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x01, 0x4b, 0x12, 0x0c, 0x0a, 0x01, 0x56, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x01, 0x56, 0x22, 0x2c, 0x0a, 0x0b, 0x55, 0x6e, 0x73, 0x75, 0x62, + 0x73, 0x63, 0x72, 0x69, 0x62, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x5f, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x74, 0x6f, 0x70, 0x69, + 0x63, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x20, 0x0a, 0x03, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, + 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, + 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0x2c, 0x0a, 0x0b, 0x43, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, + 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x65, 0x73, 0x73, + 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x52, 0x0a, 0x0b, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x48, + 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x5f, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x63, 0x6c, 0x65, 0x61, 0x6e, + 0x53, 0x74, 0x61, 0x72, 0x74, 0x12, 0x22, 0x0a, 0x0d, 0x6e, 0x65, 0x78, 0x74, 0x5f, 0x65, 0x76, + 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0b, 0x6e, 0x65, + 0x78, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x22, 0x23, 0x0a, 0x0b, 0x4a, 0x6f, 0x69, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x68, 0x6f, 0x73, 0x74, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x68, 0x6f, 0x73, 0x74, 0x73, 0x22, 0xdb, + 0x01, 0x0a, 0x06, 0x4d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, + 0x04, 0x61, 0x64, 0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x61, 0x64, 0x64, + 0x72, 0x12, 0x3a, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x26, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x2e, 0x54, 0x61, + 0x67, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x12, 0x34, 0x0a, + 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1c, 0x2e, + 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x1a, 0x37, 0x0a, 0x09, 0x54, 0x61, 0x67, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, + 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, + 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x4d, 0x0a, 0x13, + 0x4c, 0x69, 0x73, 0x74, 0x4d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x36, 0x0a, 0x07, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, + 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4d, 0x65, 0x6d, 0x62, + 0x65, 0x72, 0x52, 0x07, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x73, 0x22, 0x30, 0x0a, 0x11, 0x46, + 0x6f, 0x72, 0x63, 0x65, 0x4c, 0x65, 0x61, 0x76, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x1b, 0x0a, 0x09, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x2a, 0x6a, 0x0a, + 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x16, 0x0a, 0x12, 0x53, 0x54, 0x41, 0x54, 0x55, + 0x53, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, + 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x41, 0x4c, 0x49, 0x56, 0x45, 0x10, + 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, 0x4c, 0x45, 0x41, 0x56, + 0x49, 0x4e, 0x47, 0x10, 0x02, 0x12, 0x0f, 0x0a, 0x0b, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, 0x5f, + 0x4c, 0x45, 0x46, 0x54, 0x10, 0x03, 0x12, 0x11, 0x0a, 0x0d, 0x53, 0x54, 0x41, 0x54, 0x55, 0x53, + 0x5f, 0x46, 0x41, 0x49, 0x4c, 0x45, 0x44, 0x10, 0x04, 0x32, 0xb1, 0x03, 0x0a, 0x0a, 0x4d, 0x65, + 0x6d, 0x62, 0x65, 0x72, 0x73, 0x68, 0x69, 0x70, 0x12, 0x61, 0x0a, 0x04, 0x4a, 0x6f, 0x69, 0x6e, + 0x12, 0x21, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x1e, 0x82, 0xd3, 0xe4, + 0x93, 0x02, 0x18, 0x22, 0x13, 0x2f, 0x76, 0x31, 0x2f, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x2f, 0x6a, 0x6f, 0x69, 0x6e, 0x3a, 0x01, 0x2a, 0x12, 0x58, 0x0a, 0x05, 0x4c, + 0x65, 0x61, 0x76, 0x65, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x16, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x1f, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x19, 0x22, 0x14, 0x2f, 0x76, + 0x31, 0x2f, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x6c, 0x65, 0x61, + 0x76, 0x65, 0x3a, 0x01, 0x2a, 0x12, 0x74, 0x0a, 0x0a, 0x46, 0x6f, 0x72, 0x63, 0x65, 0x4c, 0x65, + 0x61, 0x76, 0x65, 0x12, 0x27, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, + 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x46, 0x6f, 0x72, 0x63, 0x65, + 0x4c, 0x65, 0x61, 0x76, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x25, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1f, 0x22, 0x1a, 0x2f, 0x76, + 0x31, 0x2f, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x66, 0x6f, 0x72, + 0x63, 0x65, 0x5f, 0x6c, 0x65, 0x61, 0x76, 0x65, 0x3a, 0x01, 0x2a, 0x12, 0x70, 0x0a, 0x0b, 0x4c, + 0x69, 0x73, 0x74, 0x4d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x73, 0x12, 0x16, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, + 0x74, 0x79, 0x1a, 0x29, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4d, 0x65, + 0x6d, 0x62, 0x65, 0x72, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1e, 0x82, + 0xd3, 0xe4, 0x93, 0x02, 0x18, 0x12, 0x16, 0x2f, 0x76, 0x31, 0x2f, 0x66, 0x65, 0x64, 0x65, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72, 0x73, 0x32, 0xaa, 0x01, + 0x0a, 0x0a, 0x46, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x4f, 0x0a, 0x05, + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x12, 0x21, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, + 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x43, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x1a, 0x21, 0x2e, 0x67, 0x6d, 0x71, 0x74, 0x74, + 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, 0x2e, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x22, 0x00, 0x12, 0x4b, 0x0a, + 0x0b, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x1b, 0x2e, 0x67, + 0x6d, 0x71, 0x74, 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, + 0x61, 0x70, 0x69, 0x2e, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x1a, 0x19, 0x2e, 0x67, 0x6d, 0x71, 0x74, + 0x74, 0x2e, 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x61, 0x70, 0x69, + 0x2e, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x0e, 0x5a, 0x0c, 0x2e, 0x3b, + 0x66, 0x65, 0x64, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +} + +var ( + file_federation_proto_rawDescOnce sync.Once + file_federation_proto_rawDescData = file_federation_proto_rawDesc +) + +func file_federation_proto_rawDescGZIP() []byte { + file_federation_proto_rawDescOnce.Do(func() { + file_federation_proto_rawDescData = protoimpl.X.CompressGZIP(file_federation_proto_rawDescData) + }) + return file_federation_proto_rawDescData +} + +var file_federation_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_federation_proto_msgTypes = make([]protoimpl.MessageInfo, 13) +var file_federation_proto_goTypes = []interface{}{ + (Status)(0), // 0: gmqtt.federation.api.Status + (*Event)(nil), // 1: gmqtt.federation.api.Event + (*Subscribe)(nil), // 2: gmqtt.federation.api.Subscribe + (*Message)(nil), // 3: gmqtt.federation.api.Message + (*UserProperty)(nil), // 4: gmqtt.federation.api.UserProperty + (*Unsubscribe)(nil), // 5: gmqtt.federation.api.Unsubscribe + (*Ack)(nil), // 6: gmqtt.federation.api.Ack + (*ClientHello)(nil), // 7: gmqtt.federation.api.ClientHello + (*ServerHello)(nil), // 8: gmqtt.federation.api.ServerHello + (*JoinRequest)(nil), // 9: gmqtt.federation.api.JoinRequest + (*Member)(nil), // 10: gmqtt.federation.api.Member + (*ListMembersResponse)(nil), // 11: gmqtt.federation.api.ListMembersResponse + (*ForceLeaveRequest)(nil), // 12: gmqtt.federation.api.ForceLeaveRequest + nil, // 13: gmqtt.federation.api.Member.TagsEntry + (*empty.Empty)(nil), // 14: google.protobuf.Empty +} +var file_federation_proto_depIdxs = []int32{ + 2, // 0: gmqtt.federation.api.Event.Subscribe:type_name -> gmqtt.federation.api.Subscribe + 3, // 1: gmqtt.federation.api.Event.message:type_name -> gmqtt.federation.api.Message + 5, // 2: gmqtt.federation.api.Event.unsubscribe:type_name -> gmqtt.federation.api.Unsubscribe + 4, // 3: gmqtt.federation.api.Message.user_properties:type_name -> gmqtt.federation.api.UserProperty + 13, // 4: gmqtt.federation.api.Member.tags:type_name -> gmqtt.federation.api.Member.TagsEntry + 0, // 5: gmqtt.federation.api.Member.status:type_name -> gmqtt.federation.api.Status + 10, // 6: gmqtt.federation.api.ListMembersResponse.members:type_name -> gmqtt.federation.api.Member + 9, // 7: gmqtt.federation.api.Membership.Join:input_type -> gmqtt.federation.api.JoinRequest + 14, // 8: gmqtt.federation.api.Membership.Leave:input_type -> google.protobuf.Empty + 12, // 9: gmqtt.federation.api.Membership.ForceLeave:input_type -> gmqtt.federation.api.ForceLeaveRequest + 14, // 10: gmqtt.federation.api.Membership.ListMembers:input_type -> google.protobuf.Empty + 7, // 11: gmqtt.federation.api.Federation.Hello:input_type -> gmqtt.federation.api.ClientHello + 1, // 12: gmqtt.federation.api.Federation.EventStream:input_type -> gmqtt.federation.api.Event + 14, // 13: gmqtt.federation.api.Membership.Join:output_type -> google.protobuf.Empty + 14, // 14: gmqtt.federation.api.Membership.Leave:output_type -> google.protobuf.Empty + 14, // 15: gmqtt.federation.api.Membership.ForceLeave:output_type -> google.protobuf.Empty + 11, // 16: gmqtt.federation.api.Membership.ListMembers:output_type -> gmqtt.federation.api.ListMembersResponse + 8, // 17: gmqtt.federation.api.Federation.Hello:output_type -> gmqtt.federation.api.ServerHello + 6, // 18: gmqtt.federation.api.Federation.EventStream:output_type -> gmqtt.federation.api.Ack + 13, // [13:19] is the sub-list for method output_type + 7, // [7:13] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name +} + +func init() { file_federation_proto_init() } +func file_federation_proto_init() { + if File_federation_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_federation_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Event); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Subscribe); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Message); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UserProperty); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Unsubscribe); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Ack); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ClientHello); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ServerHello); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*JoinRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Member); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListMembersResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_federation_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForceLeaveRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_federation_proto_msgTypes[0].OneofWrappers = []interface{}{ + (*Event_Subscribe)(nil), + (*Event_Message)(nil), + (*Event_Unsubscribe)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_federation_proto_rawDesc, + NumEnums: 1, + NumMessages: 13, + NumExtensions: 0, + NumServices: 2, + }, + GoTypes: file_federation_proto_goTypes, + DependencyIndexes: file_federation_proto_depIdxs, + EnumInfos: file_federation_proto_enumTypes, + MessageInfos: file_federation_proto_msgTypes, + }.Build() + File_federation_proto = out.File + file_federation_proto_rawDesc = nil + file_federation_proto_goTypes = nil + file_federation_proto_depIdxs = nil +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/federation.pb.gw.go b/internal/hummingbird/mqttbroker/plugin/federation/federation.pb.gw.go new file mode 100644 index 0000000..031a6fa --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/federation.pb.gw.go @@ -0,0 +1,382 @@ +// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. +// source: federation.proto + +/* +Package federation is a reverse proxy. + +It translates gRPC into RESTful JSON APIs. +*/ +package federation + +import ( + "context" + "io" + "net/http" + + "github.com/golang/protobuf/descriptor" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/empty" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" +) + +// Suppress "imported and not used" errors +var _ codes.Code +var _ io.Reader +var _ status.Status +var _ = runtime.String +var _ = utilities.NewDoubleArray +var _ = descriptor.ForMessage + +func request_Membership_Join_0(ctx context.Context, marshaler runtime.Marshaler, client MembershipClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq JoinRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Join(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_Membership_Join_0(ctx context.Context, marshaler runtime.Marshaler, server MembershipServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq JoinRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Join(ctx, &protoReq) + return msg, metadata, err + +} + +func request_Membership_Leave_0(ctx context.Context, marshaler runtime.Marshaler, client MembershipClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq empty.Empty + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.Leave(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_Membership_Leave_0(ctx context.Context, marshaler runtime.Marshaler, server MembershipServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq empty.Empty + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.Leave(ctx, &protoReq) + return msg, metadata, err + +} + +func request_Membership_ForceLeave_0(ctx context.Context, marshaler runtime.Marshaler, client MembershipClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ForceLeaveRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := client.ForceLeave(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_Membership_ForceLeave_0(ctx context.Context, marshaler runtime.Marshaler, server MembershipServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq ForceLeaveRequest + var metadata runtime.ServerMetadata + + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) + } + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + + msg, err := server.ForceLeave(ctx, &protoReq) + return msg, metadata, err + +} + +func request_Membership_ListMembers_0(ctx context.Context, marshaler runtime.Marshaler, client MembershipClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq empty.Empty + var metadata runtime.ServerMetadata + + msg, err := client.ListMembers(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err + +} + +func local_request_Membership_ListMembers_0(ctx context.Context, marshaler runtime.Marshaler, server MembershipServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var protoReq empty.Empty + var metadata runtime.ServerMetadata + + msg, err := server.ListMembers(ctx, &protoReq) + return msg, metadata, err + +} + +// RegisterMembershipHandlerServer registers the http handlers for service Membership to "mux". +// UnaryRPC :call MembershipServer directly. +// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. +func RegisterMembershipHandlerServer(ctx context.Context, mux *runtime.ServeMux, server MembershipServer) error { + + mux.Handle("POST", pattern_Membership_Join_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_Membership_Join_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_Join_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_Membership_Leave_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_Membership_Leave_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_Leave_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_Membership_ForceLeave_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_Membership_ForceLeave_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_ForceLeave_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_Membership_ListMembers_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateIncomingContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_Membership_ListMembers_0(rctx, inboundMarshaler, server, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_ListMembers_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +// RegisterMembershipHandlerFromEndpoint is same as RegisterMembershipHandler but +// automatically dials to "endpoint" and closes the connection when "ctx" gets done. +func RegisterMembershipHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { + conn, err := grpc.Dial(endpoint, opts...) + if err != nil { + return err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + } + }() + }() + + return RegisterMembershipHandler(ctx, mux, conn) +} + +// RegisterMembershipHandler registers the http handlers for service Membership to "mux". +// The handlers forward requests to the grpc endpoint over "conn". +func RegisterMembershipHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { + return RegisterMembershipHandlerClient(ctx, mux, NewMembershipClient(conn)) +} + +// RegisterMembershipHandlerClient registers the http handlers for service Membership +// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "MembershipClient". +// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "MembershipClient" +// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in +// "MembershipClient" to call the correct interceptors. +func RegisterMembershipHandlerClient(ctx context.Context, mux *runtime.ServeMux, client MembershipClient) error { + + mux.Handle("POST", pattern_Membership_Join_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Membership_Join_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_Join_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_Membership_Leave_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Membership_Leave_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_Leave_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("POST", pattern_Membership_ForceLeave_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Membership_ForceLeave_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_ForceLeave_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + mux.Handle("GET", pattern_Membership_ListMembers_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + rctx, err := runtime.AnnotateContext(ctx, mux, req) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_Membership_ListMembers_0(rctx, inboundMarshaler, client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, md) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + + forward_Membership_ListMembers_0(ctx, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + + }) + + return nil +} + +var ( + pattern_Membership_Join_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "federation", "join"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_Membership_Leave_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "federation", "leave"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_Membership_ForceLeave_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "federation", "force_leave"}, "", runtime.AssumeColonVerbOpt(true))) + + pattern_Membership_ListMembers_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "federation", "members"}, "", runtime.AssumeColonVerbOpt(true))) +) + +var ( + forward_Membership_Join_0 = runtime.ForwardResponseMessage + + forward_Membership_Leave_0 = runtime.ForwardResponseMessage + + forward_Membership_ForceLeave_0 = runtime.ForwardResponseMessage + + forward_Membership_ListMembers_0 = runtime.ForwardResponseMessage +) diff --git a/internal/hummingbird/mqttbroker/plugin/federation/federation.pb_mock.go b/internal/hummingbird/mqttbroker/plugin/federation/federation.pb_mock.go new file mode 100644 index 0000000..71e7af2 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/federation.pb_mock.go @@ -0,0 +1,46 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: plugin/federation/federation.pb.go + +// Package federation is a generated GoMock package. +package federation + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockisEvent_Event is a mock of isEvent_Event interface +type MockisEvent_Event struct { + ctrl *gomock.Controller + recorder *MockisEvent_EventMockRecorder +} + +// MockisEvent_EventMockRecorder is the mock recorder for MockisEvent_Event +type MockisEvent_EventMockRecorder struct { + mock *MockisEvent_Event +} + +// NewMockisEvent_Event creates a new mock instance +func NewMockisEvent_Event(ctrl *gomock.Controller) *MockisEvent_Event { + mock := &MockisEvent_Event{ctrl: ctrl} + mock.recorder = &MockisEvent_EventMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockisEvent_Event) EXPECT() *MockisEvent_EventMockRecorder { + return m.recorder +} + +// isEvent_Event mocks base method +func (m *MockisEvent_Event) isEvent_Event() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "isEvent_Event") +} + +// isEvent_Event indicates an expected call of isEvent_Event +func (mr *MockisEvent_EventMockRecorder) isEvent_Event() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isEvent_Event", reflect.TypeOf((*MockisEvent_Event)(nil).isEvent_Event)) +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/federation_grpc.pb.go b/internal/hummingbird/mqttbroker/plugin/federation/federation_grpc.pb.go new file mode 100644 index 0000000..d032640 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/federation_grpc.pb.go @@ -0,0 +1,379 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package federation + +import ( + context "context" + + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion7 + +// MembershipClient is the client API for Membership service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type MembershipClient interface { + // Join tells the local node to join the an existing cluster. + // See https://www.serf.io/docs/commands/join.html for details. + Join(ctx context.Context, in *JoinRequest, opts ...grpc.CallOption) (*empty.Empty, error) + // Leave triggers a graceful leave for the local node. + // This is used to ensure other nodes see the node as "left" instead of "failed". + // Note that a leaved node cannot re-join the cluster unless you restart the leaved node. + Leave(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) + // ForceLeave force forces a member of a Serf cluster to enter the "left" state. + // Note that if the member is still actually alive, it will eventually rejoin the cluster. + // The true purpose of this method is to force remove "failed" nodes + // See https://www.serf.io/docs/commands/force-leave.html for details. + ForceLeave(ctx context.Context, in *ForceLeaveRequest, opts ...grpc.CallOption) (*empty.Empty, error) + // ListMembers lists all known members in the Serf cluster. + ListMembers(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*ListMembersResponse, error) +} + +type membershipClient struct { + cc grpc.ClientConnInterface +} + +func NewMembershipClient(cc grpc.ClientConnInterface) MembershipClient { + return &membershipClient{cc} +} + +func (c *membershipClient) Join(ctx context.Context, in *JoinRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.federation.api.Membership/Join", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *membershipClient) Leave(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.federation.api.Membership/Leave", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *membershipClient) ForceLeave(ctx context.Context, in *ForceLeaveRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/gmqtt.federation.api.Membership/ForceLeave", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *membershipClient) ListMembers(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*ListMembersResponse, error) { + out := new(ListMembersResponse) + err := c.cc.Invoke(ctx, "/gmqtt.federation.api.Membership/ListMembers", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// MembershipServer is the server API for Membership service. +// All implementations must embed UnimplementedMembershipServer +// for forward compatibility +type MembershipServer interface { + // Join tells the local node to join the an existing cluster. + // See https://www.serf.io/docs/commands/join.html for details. + Join(context.Context, *JoinRequest) (*empty.Empty, error) + // Leave triggers a graceful leave for the local node. + // This is used to ensure other nodes see the node as "left" instead of "failed". + // Note that a leaved node cannot re-join the cluster unless you restart the leaved node. + Leave(context.Context, *empty.Empty) (*empty.Empty, error) + // ForceLeave force forces a member of a Serf cluster to enter the "left" state. + // Note that if the member is still actually alive, it will eventually rejoin the cluster. + // The true purpose of this method is to force remove "failed" nodes + // See https://www.serf.io/docs/commands/force-leave.html for details. + ForceLeave(context.Context, *ForceLeaveRequest) (*empty.Empty, error) + // ListMembers lists all known members in the Serf cluster. + ListMembers(context.Context, *empty.Empty) (*ListMembersResponse, error) + mustEmbedUnimplementedMembershipServer() +} + +// UnimplementedMembershipServer must be embedded to have forward compatible implementations. +type UnimplementedMembershipServer struct { +} + +func (UnimplementedMembershipServer) Join(context.Context, *JoinRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Join not implemented") +} +func (UnimplementedMembershipServer) Leave(context.Context, *empty.Empty) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Leave not implemented") +} +func (UnimplementedMembershipServer) ForceLeave(context.Context, *ForceLeaveRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method ForceLeave not implemented") +} +func (UnimplementedMembershipServer) ListMembers(context.Context, *empty.Empty) (*ListMembersResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListMembers not implemented") +} +func (UnimplementedMembershipServer) mustEmbedUnimplementedMembershipServer() {} + +// UnsafeMembershipServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to MembershipServer will +// result in compilation errors. +type UnsafeMembershipServer interface { + mustEmbedUnimplementedMembershipServer() +} + +func RegisterMembershipServer(s grpc.ServiceRegistrar, srv MembershipServer) { + s.RegisterService(&_Membership_serviceDesc, srv) +} + +func _Membership_Join_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(JoinRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MembershipServer).Join(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.federation.api.Membership/Join", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MembershipServer).Join(ctx, req.(*JoinRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Membership_Leave_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(empty.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MembershipServer).Leave(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.federation.api.Membership/Leave", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MembershipServer).Leave(ctx, req.(*empty.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _Membership_ForceLeave_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ForceLeaveRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MembershipServer).ForceLeave(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.federation.api.Membership/ForceLeave", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MembershipServer).ForceLeave(ctx, req.(*ForceLeaveRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Membership_ListMembers_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(empty.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MembershipServer).ListMembers(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.federation.api.Membership/ListMembers", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MembershipServer).ListMembers(ctx, req.(*empty.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +var _Membership_serviceDesc = grpc.ServiceDesc{ + ServiceName: "gmqtt.federation.api.Membership", + HandlerType: (*MembershipServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Join", + Handler: _Membership_Join_Handler, + }, + { + MethodName: "Leave", + Handler: _Membership_Leave_Handler, + }, + { + MethodName: "ForceLeave", + Handler: _Membership_ForceLeave_Handler, + }, + { + MethodName: "ListMembers", + Handler: _Membership_ListMembers_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "federation.proto", +} + +// FederationClient is the client API for Federation service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type FederationClient interface { + Hello(ctx context.Context, in *ClientHello, opts ...grpc.CallOption) (*ServerHello, error) + EventStream(ctx context.Context, opts ...grpc.CallOption) (Federation_EventStreamClient, error) +} + +type federationClient struct { + cc grpc.ClientConnInterface +} + +func NewFederationClient(cc grpc.ClientConnInterface) FederationClient { + return &federationClient{cc} +} + +func (c *federationClient) Hello(ctx context.Context, in *ClientHello, opts ...grpc.CallOption) (*ServerHello, error) { + out := new(ServerHello) + err := c.cc.Invoke(ctx, "/gmqtt.federation.api.Federation/Hello", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *federationClient) EventStream(ctx context.Context, opts ...grpc.CallOption) (Federation_EventStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &_Federation_serviceDesc.Streams[0], "/gmqtt.federation.api.Federation/EventStream", opts...) + if err != nil { + return nil, err + } + x := &federationEventStreamClient{stream} + return x, nil +} + +type Federation_EventStreamClient interface { + Send(*Event) error + Recv() (*Ack, error) + grpc.ClientStream +} + +type federationEventStreamClient struct { + grpc.ClientStream +} + +func (x *federationEventStreamClient) Send(m *Event) error { + return x.ClientStream.SendMsg(m) +} + +func (x *federationEventStreamClient) Recv() (*Ack, error) { + m := new(Ack) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// FederationServer is the server API for Federation service. +// All implementations must embed UnimplementedFederationServer +// for forward compatibility +type FederationServer interface { + Hello(context.Context, *ClientHello) (*ServerHello, error) + EventStream(Federation_EventStreamServer) error + mustEmbedUnimplementedFederationServer() +} + +// UnimplementedFederationServer must be embedded to have forward compatible implementations. +type UnimplementedFederationServer struct { +} + +func (UnimplementedFederationServer) Hello(context.Context, *ClientHello) (*ServerHello, error) { + return nil, status.Errorf(codes.Unimplemented, "method Hello not implemented") +} +func (UnimplementedFederationServer) EventStream(Federation_EventStreamServer) error { + return status.Errorf(codes.Unimplemented, "method EventStream not implemented") +} +func (UnimplementedFederationServer) mustEmbedUnimplementedFederationServer() {} + +// UnsafeFederationServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to FederationServer will +// result in compilation errors. +type UnsafeFederationServer interface { + mustEmbedUnimplementedFederationServer() +} + +func RegisterFederationServer(s grpc.ServiceRegistrar, srv FederationServer) { + s.RegisterService(&_Federation_serviceDesc, srv) +} + +func _Federation_Hello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ClientHello) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(FederationServer).Hello(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/gmqtt.federation.api.Federation/Hello", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(FederationServer).Hello(ctx, req.(*ClientHello)) + } + return interceptor(ctx, in, info, handler) +} + +func _Federation_EventStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(FederationServer).EventStream(&federationEventStreamServer{stream}) +} + +type Federation_EventStreamServer interface { + Send(*Ack) error + Recv() (*Event, error) + grpc.ServerStream +} + +type federationEventStreamServer struct { + grpc.ServerStream +} + +func (x *federationEventStreamServer) Send(m *Ack) error { + return x.ServerStream.SendMsg(m) +} + +func (x *federationEventStreamServer) Recv() (*Event, error) { + m := new(Event) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _Federation_serviceDesc = grpc.ServiceDesc{ + ServiceName: "gmqtt.federation.api.Federation", + HandlerType: (*FederationServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Hello", + Handler: _Federation_Hello_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "EventStream", + Handler: _Federation_EventStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "federation.proto", +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/federation_grpc.pb_mock.go b/internal/hummingbird/mqttbroker/plugin/federation/federation_grpc.pb_mock.go new file mode 100644 index 0000000..daccfe8 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/federation_grpc.pb_mock.go @@ -0,0 +1,214 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/plugin/federation (interfaces: FederationClient,Federation_EventStreamClient) + +// Package federation is a generated GoMock package. +package federation + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + grpc "google.golang.org/grpc" + metadata "google.golang.org/grpc/metadata" +) + +// MockFederationClient is a mock of FederationClient interface +type MockFederationClient struct { + ctrl *gomock.Controller + recorder *MockFederationClientMockRecorder +} + +// MockFederationClientMockRecorder is the mock recorder for MockFederationClient +type MockFederationClientMockRecorder struct { + mock *MockFederationClient +} + +// NewMockFederationClient creates a new mock instance +func NewMockFederationClient(ctrl *gomock.Controller) *MockFederationClient { + mock := &MockFederationClient{ctrl: ctrl} + mock.recorder = &MockFederationClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockFederationClient) EXPECT() *MockFederationClientMockRecorder { + return m.recorder +} + +// EventStream mocks base method +func (m *MockFederationClient) EventStream(arg0 context.Context, arg1 ...grpc.CallOption) (Federation_EventStreamClient, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "EventStream", varargs...) + ret0, _ := ret[0].(Federation_EventStreamClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EventStream indicates an expected call of EventStream +func (mr *MockFederationClientMockRecorder) EventStream(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EventStream", reflect.TypeOf((*MockFederationClient)(nil).EventStream), varargs...) +} + +// Hello mocks base method +func (m *MockFederationClient) Hello(arg0 context.Context, arg1 *ClientHello, arg2 ...grpc.CallOption) (*ServerHello, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Hello", varargs...) + ret0, _ := ret[0].(*ServerHello) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Hello indicates an expected call of Hello +func (mr *MockFederationClientMockRecorder) Hello(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Hello", reflect.TypeOf((*MockFederationClient)(nil).Hello), varargs...) +} + +// MockFederation_EventStreamClient is a mock of Federation_EventStreamClient interface +type MockFederation_EventStreamClient struct { + ctrl *gomock.Controller + recorder *MockFederation_EventStreamClientMockRecorder +} + +// MockFederation_EventStreamClientMockRecorder is the mock recorder for MockFederation_EventStreamClient +type MockFederation_EventStreamClientMockRecorder struct { + mock *MockFederation_EventStreamClient +} + +// NewMockFederation_EventStreamClient creates a new mock instance +func NewMockFederation_EventStreamClient(ctrl *gomock.Controller) *MockFederation_EventStreamClient { + mock := &MockFederation_EventStreamClient{ctrl: ctrl} + mock.recorder = &MockFederation_EventStreamClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockFederation_EventStreamClient) EXPECT() *MockFederation_EventStreamClientMockRecorder { + return m.recorder +} + +// CloseSend mocks base method +func (m *MockFederation_EventStreamClient) CloseSend() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseSend") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseSend indicates an expected call of CloseSend +func (mr *MockFederation_EventStreamClientMockRecorder) CloseSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).CloseSend)) +} + +// Context mocks base method +func (m *MockFederation_EventStreamClient) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context +func (mr *MockFederation_EventStreamClientMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).Context)) +} + +// Header mocks base method +func (m *MockFederation_EventStreamClient) Header() (metadata.MD, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Header") + ret0, _ := ret[0].(metadata.MD) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Header indicates an expected call of Header +func (mr *MockFederation_EventStreamClientMockRecorder) Header() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).Header)) +} + +// Recv mocks base method +func (m *MockFederation_EventStreamClient) Recv() (*Ack, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*Ack) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv +func (mr *MockFederation_EventStreamClientMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).Recv)) +} + +// RecvMsg mocks base method +func (m *MockFederation_EventStreamClient) RecvMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecvMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg +func (mr *MockFederation_EventStreamClientMockRecorder) RecvMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).RecvMsg), arg0) +} + +// Send mocks base method +func (m *MockFederation_EventStreamClient) Send(arg0 *Event) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send +func (mr *MockFederation_EventStreamClientMockRecorder) Send(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).Send), arg0) +} + +// SendMsg mocks base method +func (m *MockFederation_EventStreamClient) SendMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg +func (mr *MockFederation_EventStreamClientMockRecorder) SendMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).SendMsg), arg0) +} + +// Trailer mocks base method +func (m *MockFederation_EventStreamClient) Trailer() metadata.MD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Trailer") + ret0, _ := ret[0].(metadata.MD) + return ret0 +} + +// Trailer indicates an expected call of Trailer +func (mr *MockFederation_EventStreamClientMockRecorder) Trailer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockFederation_EventStreamClient)(nil).Trailer)) +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/hooks.go b/internal/hummingbird/mqttbroker/plugin/federation/hooks.go new file mode 100644 index 0000000..82a2848 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/hooks.go @@ -0,0 +1,232 @@ +package federation + +import ( + "context" + "sort" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" +) + +func (f *Federation) HookWrapper() server.HookWrapper { + return server.HookWrapper{ + OnSubscribedWrapper: f.OnSubscribedWrapper, + OnUnsubscribedWrapper: f.OnUnsubscribedWrapper, + OnMsgArrivedWrapper: f.OnMsgArrivedWrapper, + OnSessionTerminatedWrapper: f.OnSessionTerminatedWrapper, + OnWillPublishWrapper: f.OnWillPublishWrapper, + } +} + +func (f *Federation) OnSubscribedWrapper(pre server.OnSubscribed) server.OnSubscribed { + return func(ctx context.Context, client server.Client, subscription *gmqtt.Subscription) { + pre(ctx, client, subscription) + if subscription != nil { + if !f.localSubStore.subscribe(client.ClientOptions().ClientID, subscription.GetFullTopicName()) { + return + } + // only send new subscription + f.memberMu.Lock() + defer f.memberMu.Unlock() + for _, v := range f.peers { + sub := &Subscribe{ + ShareName: subscription.ShareName, + TopicFilter: subscription.TopicFilter, + } + v.queue.add(&Event{ + Event: &Event_Subscribe{ + Subscribe: sub, + }}) + } + } + } +} + +func (f *Federation) OnUnsubscribedWrapper(pre server.OnUnsubscribed) server.OnUnsubscribed { + return func(ctx context.Context, client server.Client, topicName string) { + pre(ctx, client, topicName) + if !f.localSubStore.unsubscribe(client.ClientOptions().ClientID, topicName) { + return + } + // only unsubscribe topic if there is no local subscriber anymore. + f.memberMu.Lock() + defer f.memberMu.Unlock() + for _, v := range f.peers { + unsub := &Unsubscribe{ + TopicName: topicName, + } + v.queue.add(&Event{ + Event: &Event_Unsubscribe{ + Unsubscribe: unsub, + }}) + } + } +} + +func sendSharedMsg(fs *fedSubStore, sharedList map[string][]string, send func(nodeName string, topicName string)) { + // shared subscription + fs.sharedMu.Lock() + defer fs.sharedMu.Unlock() + for topicName, v := range sharedList { + sort.Strings(v) + mod := fs.sharedSent[topicName] % (uint64(len(v))) + fs.sharedSent[topicName]++ + send(v[mod], topicName) + } +} + +// sendMessage sends messages to cluster nodes. +// For retained message, broadcasts the message to all nodes to update their local retained store. +// For none retained message , send it to the nodes which have matched topics. +// For shared subscription, we should either only send the message to local subscriber or only send the message to one node. +// If drop is true, the local node will drop the message. +// If options is not nil, the local node will apply the options to topic matching process. +func (f *Federation) sendMessage(msg *gmqtt.Message) (drop bool, options *subscription.IterationOptions) { + f.memberMu.Lock() + defer f.memberMu.Unlock() + + if msg.Retained { + eventMsg := messageToEvent(msg) + for _, v := range f.peers { + v.queue.add(&Event{ + Event: &Event_Message{ + Message: eventMsg, + }}) + } + return + } + + // shared topic => []nodeName. + sharedList := make(map[string][]string) + // append local shared subscription + f.localSubStore.localStore.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + fullTopic := sub.GetFullTopicName() + sharedList[fullTopic] = append(sharedList[fullTopic], f.nodeName) + return true + }, subscription.IterationOptions{ + Type: subscription.TypeShared, + TopicName: msg.Topic, + MatchType: subscription.MatchFilter, + }) + + // store non-shared topic, key by nodeName + nonShared := make(map[string]struct{}) + + f.fedSubStore.Iterate(func(nodeName string, sub *gmqtt.Subscription) bool { + if sub.ShareName != "" { + fullTopic := sub.GetFullTopicName() + sharedList[fullTopic] = append(sharedList[fullTopic], nodeName) + return true + } + nonShared[nodeName] = struct{}{} + return true + }, subscription.IterationOptions{ + Type: subscription.TypeAll, + TopicName: msg.Topic, + MatchType: subscription.MatchFilter, + }) + + sent := make(map[string]struct{}) + // shared subscription + sendSharedMsg(f.fedSubStore, sharedList, func(nodeName string, topicName string) { + // Do nothing if it is the local node. + if nodeName == f.nodeName { + return + } + if _, ok := sent[nodeName]; ok { + return + } + sent[nodeName] = struct{}{} + if p, ok := f.peers[nodeName]; ok { + eventMsg := messageToEvent(msg) + p.queue.add(&Event{ + Event: &Event_Message{ + Message: eventMsg, + }}) + drop = true + nonSharedOpts := subscription.IterationOptions{ + Type: subscription.TypeAll ^ subscription.TypeShared, + TopicName: msg.Topic, + MatchType: subscription.MatchFilter, + } + f.localSubStore.localStore.Iterate(func(clientID string, sub *gmqtt.Subscription) bool { + // If the message also matches non-shared subscription in local node, it can not be dropped. + // But the broker must not match any local shared subscriptions for this message, + // so we modify the iterationOptions to ignore shared subscriptions. + drop = false + options = &nonSharedOpts + return false + }, nonSharedOpts) + } + }) + // non-shared subscription + for nodeName := range nonShared { + if _, ok := sent[nodeName]; ok { + continue + } + if p, ok := f.peers[nodeName]; ok { + eventMsg := messageToEvent(msg) + p.queue.add(&Event{ + Event: &Event_Message{ + Message: eventMsg, + }}) + } + } + return +} +func (f *Federation) OnMsgArrivedWrapper(pre server.OnMsgArrived) server.OnMsgArrived { + return func(ctx context.Context, client server.Client, req *server.MsgArrivedRequest) error { + err := pre(ctx, client, req) + if err != nil { + return err + } + if req.Message != nil { + drop, opts := f.sendMessage(req.Message) + if drop { + req.Drop() + } + if opts != nil { + req.IterationOptions = *opts + } + } + return nil + } +} + +func (f *Federation) OnSessionTerminatedWrapper(pre server.OnSessionTerminated) server.OnSessionTerminated { + return func(ctx context.Context, clientID string, reason server.SessionTerminatedReason) { + pre(ctx, clientID, reason) + if unsubs := f.localSubStore.unsubscribeAll(clientID); len(unsubs) != 0 { + f.memberMu.Lock() + defer f.memberMu.Unlock() + for _, v := range f.peers { + for _, topicName := range unsubs { + unsub := &Unsubscribe{ + TopicName: topicName, + } + v.queue.add(&Event{ + Event: &Event_Unsubscribe{ + Unsubscribe: unsub, + }}) + } + } + } + } + +} + +func (f *Federation) OnWillPublishWrapper(pre server.OnWillPublish) server.OnWillPublish { + return func(ctx context.Context, clientID string, req *server.WillMsgRequest) { + pre(ctx, clientID, req) + if req.Message != nil { + drop, opts := f.sendMessage(req.Message) + if drop { + req.Drop() + } + if opts != nil { + req.IterationOptions = *opts + } + } + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/membership.go b/internal/hummingbird/mqttbroker/plugin/federation/membership.go new file mode 100644 index 0000000..421a60f --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/membership.go @@ -0,0 +1,100 @@ +package federation + +import ( + "time" + + "github.com/google/uuid" + "github.com/hashicorp/serf/serf" + "go.uber.org/zap" +) + +// iSerf is the interface for *serf.Serf. +// It is used for test. +type iSerf interface { + Join(existing []string, ignoreOld bool) (int, error) + RemoveFailedNode(node string) error + Leave() error + Members() []serf.Member + Shutdown() error +} + +var servePeerEventStream = func(p *peer) { + p.serveEventStream() +} + +func (f *Federation) startSerf(t *time.Timer) error { + defer func() { + t.Reset(f.config.RetryInterval) + }() + if _, err := f.serf.Join(f.config.RetryJoin, true); err != nil { + return err + } + go f.eventHandler() + return nil +} + +func (f *Federation) eventHandler() { + for { + select { + case evt := <-f.serfEventCh: + switch evt.EventType() { + case serf.EventMemberJoin: + f.nodeJoin(evt.(serf.MemberEvent)) + case serf.EventMemberLeave, serf.EventMemberFailed, serf.EventMemberReap: + f.nodeFail(evt.(serf.MemberEvent)) + case serf.EventUser: + case serf.EventMemberUpdate: + // TODO + case serf.EventQuery: // Ignore + default: + } + case <-f.exit: + f.memberMu.Lock() + for _, v := range f.peers { + v.stop() + } + f.memberMu.Unlock() + return + } + } +} + +func (f *Federation) nodeJoin(member serf.MemberEvent) { + f.memberMu.Lock() + defer f.memberMu.Unlock() + for _, v := range member.Members { + if v.Name == f.nodeName { + continue + } + log.Info("member joined", zap.String("node_name", v.Name)) + if _, ok := f.peers[v.Name]; !ok { + p := &peer{ + fed: f, + member: v, + exit: make(chan struct{}), + sessionID: uuid.New().String(), + queue: newEventQueue(), + localName: f.nodeName, + } + f.peers[v.Name] = p + go servePeerEventStream(p) + } + } +} + +func (f *Federation) nodeFail(member serf.MemberEvent) { + f.memberMu.Lock() + defer f.memberMu.Unlock() + for _, v := range member.Members { + if v.Name == f.nodeName { + continue + } + if p, ok := f.peers[v.Name]; ok { + log.Error("node failed, close stream client", zap.String("node_name", v.Name)) + p.stop() + delete(f.peers, v.Name) + _ = f.fedSubStore.UnsubscribeAll(v.Name) + f.sessionMgr.del(v.Name) + } + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/membership_mock.go b/internal/hummingbird/mqttbroker/plugin/federation/membership_mock.go new file mode 100644 index 0000000..40bfd44 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/membership_mock.go @@ -0,0 +1,106 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: plugin/federation/membership.go + +// Package federation is a generated GoMock package. +package federation + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + serf "github.com/hashicorp/serf/serf" +) + +// MockiSerf is a mock of iSerf interface +type MockiSerf struct { + ctrl *gomock.Controller + recorder *MockiSerfMockRecorder +} + +// MockiSerfMockRecorder is the mock recorder for MockiSerf +type MockiSerfMockRecorder struct { + mock *MockiSerf +} + +// NewMockiSerf creates a new mock instance +func NewMockiSerf(ctrl *gomock.Controller) *MockiSerf { + mock := &MockiSerf{ctrl: ctrl} + mock.recorder = &MockiSerfMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockiSerf) EXPECT() *MockiSerfMockRecorder { + return m.recorder +} + +// Join mocks base method +func (m *MockiSerf) Join(existing []string, ignoreOld bool) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Join", existing, ignoreOld) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Join indicates an expected call of Join +func (mr *MockiSerfMockRecorder) Join(existing, ignoreOld interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Join", reflect.TypeOf((*MockiSerf)(nil).Join), existing, ignoreOld) +} + +// RemoveFailedNode mocks base method +func (m *MockiSerf) RemoveFailedNode(node string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveFailedNode", node) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveFailedNode indicates an expected call of RemoveFailedNode +func (mr *MockiSerfMockRecorder) RemoveFailedNode(node interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveFailedNode", reflect.TypeOf((*MockiSerf)(nil).RemoveFailedNode), node) +} + +// Leave mocks base method +func (m *MockiSerf) Leave() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Leave") + ret0, _ := ret[0].(error) + return ret0 +} + +// Leave indicates an expected call of Leave +func (mr *MockiSerfMockRecorder) Leave() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Leave", reflect.TypeOf((*MockiSerf)(nil).Leave)) +} + +// Members mocks base method +func (m *MockiSerf) Members() []serf.Member { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Members") + ret0, _ := ret[0].([]serf.Member) + return ret0 +} + +// Members indicates an expected call of Members +func (mr *MockiSerfMockRecorder) Members() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Members", reflect.TypeOf((*MockiSerf)(nil).Members)) +} + +// Shutdown mocks base method +func (m *MockiSerf) Shutdown() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Shutdown") + ret0, _ := ret[0].(error) + return ret0 +} + +// Shutdown indicates an expected call of Shutdown +func (mr *MockiSerfMockRecorder) Shutdown() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockiSerf)(nil).Shutdown)) +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/peer.go b/internal/hummingbird/mqttbroker/plugin/federation/peer.go new file mode 100644 index 0000000..6284227 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/peer.go @@ -0,0 +1,381 @@ +package federation + +import ( + "container/list" + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/hashicorp/serf/serf" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" +) + +type peerState byte + +const ( + peerStateStopped peerState = iota + 1 + peerStateStreaming +) + +// peer represents a remote node which act as the event stream server. +type peer struct { + fed *Federation + localName string + member serf.Member + exit chan struct{} + // local session id + sessionID string + queue queue + // stateMu guards the following fields + stateMu sync.Mutex + state peerState + // client-side stream + stream *stream +} + +type stream struct { + queue queue + conn *grpc.ClientConn + client Federation_EventStreamClient + close chan struct{} + errOnce sync.Once + err error + wg sync.WaitGroup +} + +// interface for testing +type queue interface { + clear() + close() + open() + setReadPosition(id uint64) + add(event *Event) + fetchEvents() []*Event + ack(id uint64) +} + +// eventQueue store the events that are ready to send. +// TODO add max buffer size +type eventQueue struct { + cond *sync.Cond + nextID uint64 + l *list.List + nextRead *list.Element + closed bool +} + +func newEventQueue() *eventQueue { + return &eventQueue{ + cond: sync.NewCond(&sync.Mutex{}), + nextID: 0, + l: list.New(), + closed: false, + } +} + +func (e *eventQueue) clear() { + e.cond.L.Lock() + defer e.cond.L.Unlock() + e.nextID = 0 + e.l = list.New() + e.nextRead = nil + e.closed = false +} + +func (e *eventQueue) close() { + e.cond.L.Lock() + defer e.cond.L.Unlock() + e.closed = true + e.cond.Signal() +} + +func (e *eventQueue) open() { + e.cond.L.Lock() + defer e.cond.L.Unlock() + e.closed = false + e.cond.Signal() +} + +func (e *eventQueue) setReadPosition(id uint64) { + e.cond.L.Lock() + defer e.cond.L.Unlock() + for elem := e.l.Front(); elem != nil; elem = elem.Next() { + ev := elem.Value.(*Event) + if ev.Id == id { + e.nextRead = elem + return + } + } +} + +func (e *eventQueue) add(event *Event) { + e.cond.L.Lock() + defer func() { + e.cond.L.Unlock() + e.cond.Signal() + }() + event.Id = e.nextID + e.nextID++ + elem := e.l.PushBack(event) + if e.nextRead == nil { + e.nextRead = elem + } +} + +func (e *eventQueue) fetchEvents() []*Event { + e.cond.L.Lock() + defer e.cond.L.Unlock() + + for (e.l.Len() == 0 || e.nextRead == nil) && !e.closed { + e.cond.Wait() + } + if e.closed { + return nil + } + ev := make([]*Event, 0) + var elem *list.Element + elem = e.nextRead + for i := 0; i < 100; i++ { + ev = append(ev, elem.Value.(*Event)) + elem = elem.Next() + if elem == nil { + break + } + } + e.nextRead = elem + return ev +} + +func (e *eventQueue) ack(id uint64) { + e.cond.L.Lock() + defer func() { + e.cond.L.Unlock() + e.cond.Signal() + }() + var next *list.Element + for elem := e.l.Front(); elem != nil; elem = next { + next = elem.Next() + req := elem.Value.(*Event) + if req.Id <= id { + e.l.Remove(elem) + } + if req.Id == id { + return + } + } +} + +func (p *peer) stop() { + select { + case <-p.exit: + default: + close(p.exit) + } + p.stateMu.Lock() + state := p.state + if state == peerStateStreaming { + _ = p.stream.conn.Close() + } + p.state = peerStateStopped + p.stateMu.Unlock() + if state == peerStateStreaming { + p.stream.wg.Wait() + } +} + +func (p *peer) serveEventStream() { + timer := time.NewTimer(0) + var reconnectCount int + for { + select { + case <-p.exit: + return + case <-timer.C: + err := p.serveStream(reconnectCount, timer) + select { + case <-p.exit: + return + default: + } + if err != nil { + log.Error("stream broken, reconnecting", zap.Error(err), + zap.Int("reconnect_count", reconnectCount)) + reconnectCount++ + continue + } + return + } + } +} + +func (p *peer) initStream(client FederationClient, conn *grpc.ClientConn) (s *stream, err error) { + p.stateMu.Lock() + defer func() { + if err == nil { + p.state = peerStateStreaming + } + p.stateMu.Unlock() + }() + if p.state == peerStateStopped { + return nil, errors.New("peer has been stopped") + } + helloMD := metadata.Pairs("node_name", p.localName) + helloCtx := metadata.NewOutgoingContext(context.Background(), helloMD) + sh, err := client.Hello(helloCtx, &ClientHello{ + SessionId: p.sessionID, + }) + if err != nil { + return nil, fmt.Errorf("handshake error: %s", err.Error()) + } + log.Info("handshake succeed", zap.String("remote_node", p.member.Name), zap.Bool("clean_start", sh.CleanStart)) + if sh.CleanStart { + p.queue.clear() + // sync full state + p.fed.localSubStore.Lock() + for k := range p.fed.localSubStore.topics { + shareName, topicFilter := subscription.SplitTopic(k) + p.queue.add(&Event{ + Event: &Event_Subscribe{Subscribe: &Subscribe{ + ShareName: shareName, + TopicFilter: topicFilter, + }}, + }) + } + p.fed.localSubStore.Unlock() + + p.fed.retainedStore.Iterate(func(message *mqttbroker.Message) bool { + // TODO add timestamp to retained message and use Last Write Wins (LWW) to resolve write conflicts. + p.queue.add(&Event{ + Event: &Event_Message{ + Message: messageToEvent(message.Copy()), + }, + }) + return true + }) + } + p.queue.setReadPosition(sh.NextEventId) + md := metadata.Pairs("node_name", p.localName) + ctx := metadata.NewOutgoingContext(context.Background(), md) + c, err := client.EventStream(ctx) + if err != nil { + return nil, err + } + p.queue.open() + s = &stream{ + queue: p.queue, + conn: conn, + client: c, + close: make(chan struct{}), + } + p.stream = s + return s, nil +} + +func (p *peer) serveStream(reconnectCount int, backoff *time.Timer) (err error) { + defer func() { + if err != nil { + du := time.Duration(0) + if reconnectCount != 0 { + du = time.Duration(reconnectCount) * 500 * time.Millisecond + } + if max := 2 * time.Second; du > max { + du = max + } + backoff.Reset(du) + } + }() + addr := p.member.Tags["fed_addr"] + conn, err := grpc.Dial(addr, grpc.WithInsecure()) + if err != nil { + return err + } + client := NewFederationClient(conn) + s, err := p.initStream(client, conn) + if err != nil { + return err + } + return s.serve() +} + +func (s *stream) serve() error { + s.wg.Add(2) + go s.readLoop() + go s.sendEvents() + s.wg.Wait() + return s.err +} + +func (s *stream) setError(err error) { + s.errOnce.Do(func() { + s.queue.close() + s.conn.Close() + close(s.close) + if err != nil && err != io.EOF { + log.Error("stream error", zap.Error(err)) + s.err = err + } + }) +} + +func (s *stream) readLoop() { + var err error + var resp *Ack + defer func() { + if re := recover(); re != nil { + err = errors.New(fmt.Sprint(re)) + } + s.setError(err) + s.wg.Done() + }() + for { + select { + case <-s.close: + return + default: + resp, err = s.client.Recv() + if err != nil { + return + } + s.queue.ack(resp.EventId) + if ce := log.Check(zapcore.DebugLevel, "event acked"); ce != nil { + ce.Write(zap.Uint64("id", resp.EventId)) + } + } + } +} + +func (s *stream) sendEvents() { + var err error + defer func() { + if re := recover(); re != nil { + err = errors.New(fmt.Sprint(re)) + } + s.setError(err) + s.wg.Done() + }() + for { + events := s.queue.fetchEvents() + // stream has been closed + if events == nil { + return + } + for _, v := range events { + err := s.client.Send(v) + if err != nil { + return + } + if ce := log.Check(zapcore.DebugLevel, "event sent"); ce != nil { + ce.Write(zap.String("event", v.String())) + } + } + } +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/peer_mock.go b/internal/hummingbird/mqttbroker/plugin/federation/peer_mock.go new file mode 100644 index 0000000..849fb24 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/peer_mock.go @@ -0,0 +1,120 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: plugin/federation/peer.go + +// Package federation is a generated GoMock package. +package federation + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// Mockqueue is a mock of queue interface +type Mockqueue struct { + ctrl *gomock.Controller + recorder *MockqueueMockRecorder +} + +// MockqueueMockRecorder is the mock recorder for Mockqueue +type MockqueueMockRecorder struct { + mock *Mockqueue +} + +// NewMockqueue creates a new mock instance +func NewMockqueue(ctrl *gomock.Controller) *Mockqueue { + mock := &Mockqueue{ctrl: ctrl} + mock.recorder = &MockqueueMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *Mockqueue) EXPECT() *MockqueueMockRecorder { + return m.recorder +} + +// clear mocks base method +func (m *Mockqueue) clear() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "clear") +} + +// clear indicates an expected call of clear +func (mr *MockqueueMockRecorder) clear() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "clear", reflect.TypeOf((*Mockqueue)(nil).clear)) +} + +// close mocks base method +func (m *Mockqueue) close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "close") +} + +// close indicates an expected call of close +func (mr *MockqueueMockRecorder) close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "close", reflect.TypeOf((*Mockqueue)(nil).close)) +} + +// open mocks base method +func (m *Mockqueue) open() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "open") +} + +// open indicates an expected call of open +func (mr *MockqueueMockRecorder) open() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "open", reflect.TypeOf((*Mockqueue)(nil).open)) +} + +// setReadPosition mocks base method +func (m *Mockqueue) setReadPosition(id uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "setReadPosition", id) +} + +// setReadPosition indicates an expected call of setReadPosition +func (mr *MockqueueMockRecorder) setReadPosition(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setReadPosition", reflect.TypeOf((*Mockqueue)(nil).setReadPosition), id) +} + +// add mocks base method +func (m *Mockqueue) add(event *Event) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "add", event) +} + +// add indicates an expected call of add +func (mr *MockqueueMockRecorder) add(event interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "add", reflect.TypeOf((*Mockqueue)(nil).add), event) +} + +// fetchEvents mocks base method +func (m *Mockqueue) fetchEvents() []*Event { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "fetchEvents") + ret0, _ := ret[0].([]*Event) + return ret0 +} + +// fetchEvents indicates an expected call of fetchEvents +func (mr *MockqueueMockRecorder) fetchEvents() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "fetchEvents", reflect.TypeOf((*Mockqueue)(nil).fetchEvents)) +} + +// ack mocks base method +func (m *Mockqueue) ack(id uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ack", id) +} + +// ack indicates an expected call of ack +func (mr *MockqueueMockRecorder) ack(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ack", reflect.TypeOf((*Mockqueue)(nil).ack), id) +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/protos/federation.proto b/internal/hummingbird/mqttbroker/plugin/federation/protos/federation.proto new file mode 100644 index 0000000..7398d65 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/protos/federation.proto @@ -0,0 +1,129 @@ +syntax = "proto3"; + +package gmqtt.federation.api; +option go_package = ".;federation"; + +import "google/api/annotations.proto"; +import "google/protobuf/empty.proto"; + +message Event { + uint64 id = 1; + oneof Event { + Subscribe Subscribe = 2; + Message message = 3; + Unsubscribe unsubscribe = 4; + } +} + +// Subscribe represents the subscription for a node, it is used to route message among nodes, +// so only shared_name and topic_filter is required. +message Subscribe { + string share_name = 1; + string topic_filter = 2; +} + +message Message{ + string topic_name = 1; + string payload = 2; + uint32 qos = 3; + bool retained = 4; + // the following fields are using in v5 client. + string content_type = 5; + string correlation_data = 6; + uint32 message_expiry = 7; + uint32 payload_format = 8; + string response_topic = 9; + repeated UserProperty user_properties = 10; +} + +message UserProperty { + bytes K = 1; + bytes V = 2; +} +message Unsubscribe{ + string topic_name = 1; +} + +message Ack { + uint64 event_id = 1; +} + +// ClientHello is the request message in handshake process. +message ClientHello { + string session_id = 1; +} + +// ServerHello is the response message in handshake process. +message ServerHello{ + bool clean_start = 1; + uint64 next_event_id = 2; +} + +message JoinRequest { + repeated string hosts = 1; +} + + +message Member { + string name = 1; + string addr = 2; + map tags = 3; + Status status = 4; +} + +enum Status { + STATUS_UNSPECIFIED = 0; + STATUS_ALIVE = 1; + STATUS_LEAVING = 2; + STATUS_LEFT = 3; + STATUS_FAILED = 4; +} + +message ListMembersResponse { + repeated Member members = 1; +} + +message ForceLeaveRequest { + string node_name = 1; +} + +service Membership { + // Join tells the local node to join the an existing cluster. + // See https://www.serf.io/docs/commands/join.html for details. + rpc Join(JoinRequest) returns (google.protobuf.Empty){ + option (google.api.http) = { + post: "/v1/federation/join" + body:"*" + }; + } + // Leave triggers a graceful leave for the local node. + // This is used to ensure other nodes see the node as "left" instead of "failed". + // Note that a leaved node cannot re-join the cluster unless you restart the leaved node. + rpc Leave(google.protobuf.Empty) returns (google.protobuf.Empty){ + option (google.api.http) = { + post: "/v1/federation/leave" + body:"*" + }; + } + // ForceLeave force forces a member of a Serf cluster to enter the "left" state. + // Note that if the member is still actually alive, it will eventually rejoin the cluster. + // The true purpose of this method is to force remove "failed" nodes + // See https://www.serf.io/docs/commands/force-leave.html for details. + rpc ForceLeave(ForceLeaveRequest) returns (google.protobuf.Empty){ + option (google.api.http) = { + post: "/v1/federation/force_leave" + body:"*" + }; + } + // ListMembers lists all known members in the Serf cluster. + rpc ListMembers(google.protobuf.Empty) returns (ListMembersResponse){ + option (google.api.http) = { + get: "/v1/federation/members" + }; + } +} + +service Federation { + rpc Hello(ClientHello) returns (ServerHello){} + rpc EventStream (stream Event) returns (stream Ack){} +} diff --git a/internal/hummingbird/mqttbroker/plugin/federation/protos/proto_gen.sh b/internal/hummingbird/mqttbroker/plugin/federation/protos/proto_gen.sh new file mode 100755 index 0000000..261d433 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/protos/proto_gen.sh @@ -0,0 +1,8 @@ +protoc -I. \ +-I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway \ +-I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ +--go-grpc_out=../ \ +--go_out=../ \ +--grpc-gateway_out=../ \ +--swagger_out=../swagger \ +*.proto \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/plugin/federation/swagger/federation.swagger.json b/internal/hummingbird/mqttbroker/plugin/federation/swagger/federation.swagger.json new file mode 100644 index 0000000..033dbf0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/plugin/federation/swagger/federation.swagger.json @@ -0,0 +1,357 @@ +{ + "swagger": "2.0", + "info": { + "title": "federation.proto", + "version": "version not set" + }, + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": { + "/v1/federation/force_leave": { + "post": { + "summary": "ForceLeave force forces a member of a Serf cluster to enter the \"left\" state.\nNote that if the member is still actually alive, it will eventually rejoin the cluster.\nThe true purpose of this method is to force remove \"failed\" nodes\nSee https://www.serf.io/docs/commands/force-leave.html for details.", + "operationId": "ForceLeave", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/apiForceLeaveRequest" + } + } + ], + "tags": [ + "Membership" + ] + } + }, + "/v1/federation/join": { + "post": { + "summary": "Join tells the local node to join the an existing cluster.\nSee https://www.serf.io/docs/commands/join.html for details.", + "operationId": "Join", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/apiJoinRequest" + } + } + ], + "tags": [ + "Membership" + ] + } + }, + "/v1/federation/leave": { + "post": { + "summary": "Leave triggers a graceful leave for the local node.\nThis is used to ensure other nodes see the node as \"left\" instead of \"failed\".\nNote that a leaved node cannot re-join the cluster unless you restart the leaved node.", + "operationId": "Leave", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "properties": {} + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "properties": {} + } + } + ], + "tags": [ + "Membership" + ] + } + }, + "/v1/federation/members": { + "get": { + "summary": "ListMembers lists all known members in the Serf cluster.", + "operationId": "ListMembers", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/apiListMembersResponse" + } + }, + "default": { + "description": "An unexpected error response", + "schema": { + "$ref": "#/definitions/runtimeError" + } + } + }, + "tags": [ + "Membership" + ] + } + } + }, + "definitions": { + "apiAck": { + "type": "object", + "properties": { + "event_id": { + "type": "string", + "format": "uint64" + } + } + }, + "apiForceLeaveRequest": { + "type": "object", + "properties": { + "node_name": { + "type": "string" + } + } + }, + "apiJoinRequest": { + "type": "object", + "properties": { + "hosts": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "apiListMembersResponse": { + "type": "object", + "properties": { + "members": { + "type": "array", + "items": { + "$ref": "#/definitions/apiMember" + } + } + } + }, + "apiMember": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "addr": { + "type": "string" + }, + "tags": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "status": { + "$ref": "#/definitions/apiStatus" + } + } + }, + "apiMessage": { + "type": "object", + "properties": { + "topic_name": { + "type": "string" + }, + "payload": { + "type": "string" + }, + "qos": { + "type": "integer", + "format": "int64" + }, + "retained": { + "type": "boolean", + "format": "boolean" + }, + "content_type": { + "type": "string", + "description": "the following fields are using in v5 client." + }, + "correlation_data": { + "type": "string" + }, + "message_expiry": { + "type": "integer", + "format": "int64" + }, + "payload_format": { + "type": "integer", + "format": "int64" + }, + "response_topic": { + "type": "string" + }, + "user_properties": { + "type": "array", + "items": { + "$ref": "#/definitions/apiUserProperty" + } + } + } + }, + "apiServerHello": { + "type": "object", + "properties": { + "clean_start": { + "type": "boolean", + "format": "boolean" + }, + "next_event_id": { + "type": "string", + "format": "uint64" + } + }, + "description": "ServerHello is the response message in handshake process." + }, + "apiStatus": { + "type": "string", + "enum": [ + "STATUS_UNSPECIFIED", + "STATUS_ALIVE", + "STATUS_LEAVING", + "STATUS_LEFT", + "STATUS_FAILED" + ], + "default": "STATUS_UNSPECIFIED" + }, + "apiSubscribe": { + "type": "object", + "properties": { + "share_name": { + "type": "string" + }, + "topic_filter": { + "type": "string" + } + }, + "description": "Subscribe represents the subscription for a node, it is used to route message among nodes,\nso only shared_name and topic_filter is required." + }, + "apiUnsubscribe": { + "type": "object", + "properties": { + "topic_name": { + "type": "string" + } + } + }, + "apiUserProperty": { + "type": "object", + "properties": { + "K": { + "type": "string", + "format": "byte" + }, + "V": { + "type": "string", + "format": "byte" + } + } + }, + "protobufAny": { + "type": "object", + "properties": { + "type_url": { + "type": "string" + }, + "value": { + "type": "string", + "format": "byte" + } + } + }, + "runtimeError": { + "type": "object", + "properties": { + "error": { + "type": "string" + }, + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "$ref": "#/definitions/protobufAny" + } + } + } + }, + "runtimeStreamError": { + "type": "object", + "properties": { + "grpc_code": { + "type": "integer", + "format": "int32" + }, + "http_code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "http_status": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "$ref": "#/definitions/protobufAny" + } + } + } + } + } +} diff --git a/internal/hummingbird/mqttbroker/retained/interface.go b/internal/hummingbird/mqttbroker/retained/interface.go new file mode 100644 index 0000000..51ba811 --- /dev/null +++ b/internal/hummingbird/mqttbroker/retained/interface.go @@ -0,0 +1,32 @@ +package retained + +import "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + +// IterateFn is the callback function used by iterate() +// Return false means to stop the iteration. +type IterateFn func(message *mqttbroker.Message) bool + +// Store is the interface used by mqttbroker.server and external logic to handler the operations of retained messages. +// User can get the implementation from mqttbroker.Server interface. +// This interface provides the ability for extensions to interact with the retained message store. +// Notice: +// This methods will not trigger any gmqtt hooks. +type Store interface { + // GetRetainedMessage returns the message that equals the passed topic. + GetRetainedMessage(topicName string) *mqttbroker.Message + // ClearAll clears all retained messages. + ClearAll() + // AddOrReplace adds or replaces a retained message. + AddOrReplace(message *mqttbroker.Message) + // remove removes a retained message. + Remove(topicName string) + // GetMatchedMessages returns the retained messages that match the passed topic filter. + GetMatchedMessages(topicFilter string) []*mqttbroker.Message + // Iterate iterate all retained messages. The callback is called once for each message. + // If callback return false, the iteration will be stopped. + // Notice: + // The results are not sorted in any way, no ordering of any kind is guaranteed. + // This method will walk through all retained messages, + // so this will be a expensive operation if there are a large number of retained messages. + Iterate(fn IterateFn) +} diff --git a/internal/hummingbird/mqttbroker/retained/interface_mock.go b/internal/hummingbird/mqttbroker/retained/interface_mock.go new file mode 100644 index 0000000..10b0156 --- /dev/null +++ b/internal/hummingbird/mqttbroker/retained/interface_mock.go @@ -0,0 +1,111 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: retained/interface.go + +// Package retained is a generated GoMock package. +package retained + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" +) + +// MockStore is a mock of Store interface +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// GetRetainedMessage mocks base method +func (m *MockStore) GetRetainedMessage(topicName string) *gmqtt.Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRetainedMessage", topicName) + ret0, _ := ret[0].(*gmqtt.Message) + return ret0 +} + +// GetRetainedMessage indicates an expected call of GetRetainedMessage +func (mr *MockStoreMockRecorder) GetRetainedMessage(topicName interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRetainedMessage", reflect.TypeOf((*MockStore)(nil).GetRetainedMessage), topicName) +} + +// ClearAll mocks base method +func (m *MockStore) ClearAll() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClearAll") +} + +// ClearAll indicates an expected call of ClearAll +func (mr *MockStoreMockRecorder) ClearAll() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearAll", reflect.TypeOf((*MockStore)(nil).ClearAll)) +} + +// AddOrReplace mocks base method +func (m *MockStore) AddOrReplace(message *gmqtt.Message) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddOrReplace", message) +} + +// AddOrReplace indicates an expected call of AddOrReplace +func (mr *MockStoreMockRecorder) AddOrReplace(message interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddOrReplace", reflect.TypeOf((*MockStore)(nil).AddOrReplace), message) +} + +// Remove mocks base method +func (m *MockStore) Remove(topicName string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Remove", topicName) +} + +// Remove indicates an expected call of Remove +func (mr *MockStoreMockRecorder) Remove(topicName interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockStore)(nil).Remove), topicName) +} + +// GetMatchedMessages mocks base method +func (m *MockStore) GetMatchedMessages(topicFilter string) []*gmqtt.Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMatchedMessages", topicFilter) + ret0, _ := ret[0].([]*gmqtt.Message) + return ret0 +} + +// GetMatchedMessages indicates an expected call of GetMatchedMessages +func (mr *MockStoreMockRecorder) GetMatchedMessages(topicFilter interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMatchedMessages", reflect.TypeOf((*MockStore)(nil).GetMatchedMessages), topicFilter) +} + +// Iterate mocks base method +func (m *MockStore) Iterate(fn IterateFn) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Iterate", fn) +} + +// Iterate indicates an expected call of Iterate +func (mr *MockStoreMockRecorder) Iterate(fn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockStore)(nil).Iterate), fn) +} diff --git a/internal/hummingbird/mqttbroker/retained/trie/retain_trie.go b/internal/hummingbird/mqttbroker/retained/trie/retain_trie.go new file mode 100644 index 0000000..fc456ff --- /dev/null +++ b/internal/hummingbird/mqttbroker/retained/trie/retain_trie.go @@ -0,0 +1,152 @@ +package trie + +import ( + "strings" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/retained" +) + +// topicTrie +type topicTrie = topicNode + +// children +type children = map[string]*topicNode + +// topicNode +type topicNode struct { + children children + msg *mqttbroker.Message + parent *topicNode // pointer of parent node + topicName string +} + +// newTopicTrie create a new trie tree +func newTopicTrie() *topicTrie { + return newNode() +} + +// newNode create a new trie node +func newNode() *topicNode { + return &topicNode{ + children: children{}, + } +} + +// newChild create a child node of t +func (t *topicNode) newChild() *topicNode { + return &topicNode{ + children: children{}, + parent: t, + } +} + +// find walk through the tire and return the node that represent the topicName +// return nil if not found +func (t *topicTrie) find(topicName string) *topicNode { + topicSlice := strings.Split(topicName, "/") + var pNode = t + for _, lv := range topicSlice { + if _, ok := pNode.children[lv]; ok { + pNode = pNode.children[lv] + } else { + return nil + } + } + if pNode.msg != nil { + return pNode + } + return nil +} + +// matchTopic walk through the tire and call the fn callback for each message witch match the topic filter. +func (t *topicTrie) matchTopic(topicSlice []string, fn retained.IterateFn) { + endFlag := len(topicSlice) == 1 + switch topicSlice[0] { + case "#": + t.preOrderTraverse(fn) + case "+": + // 当前层的所有 + for _, v := range t.children { + if endFlag { + if v.msg != nil { + fn(v.msg) + } + } else { + v.matchTopic(topicSlice[1:], fn) + } + } + default: + if n := t.children[topicSlice[0]]; n != nil { + if endFlag { + if n.msg != nil { + fn(n.msg) + } + } else { + n.matchTopic(topicSlice[1:], fn) + } + } + } +} + +func (t *topicTrie) getMatchedMessages(topicFilter string) []*mqttbroker.Message { + topicLv := strings.Split(topicFilter, "/") + var rs []*mqttbroker.Message + t.matchTopic(topicLv, func(message *mqttbroker.Message) bool { + rs = append(rs, message.Copy()) + return true + }) + return rs +} + +func isSystemTopic(topicName string) bool { + return len(topicName) >= 1 && topicName[0] == '$' +} + +// addRetainMsg add a retain message +func (t *topicTrie) addRetainMsg(topicName string, message *mqttbroker.Message) { + topicSlice := strings.Split(topicName, "/") + var pNode = t + for _, lv := range topicSlice { + if _, ok := pNode.children[lv]; !ok { + pNode.children[lv] = pNode.newChild() + } + pNode = pNode.children[lv] + } + pNode.msg = message + pNode.topicName = topicName +} + +func (t *topicTrie) remove(topicName string) { + topicSlice := strings.Split(topicName, "/") + l := len(topicSlice) + var pNode = t + for _, lv := range topicSlice { + if _, ok := pNode.children[lv]; ok { + pNode = pNode.children[lv] + } else { + return + } + } + pNode.msg = nil + if len(pNode.children) == 0 { + delete(pNode.parent.children, topicSlice[l-1]) + } +} + +func (t *topicTrie) preOrderTraverse(fn retained.IterateFn) bool { + if t == nil { + return false + } + if t.msg != nil { + if !fn(t.msg) { + return false + } + } + for _, c := range t.children { + if !c.preOrderTraverse(fn) { + return false + } + } + return true +} diff --git a/internal/hummingbird/mqttbroker/retained/trie/trie_db.go b/internal/hummingbird/mqttbroker/retained/trie/trie_db.go new file mode 100644 index 0000000..c843263 --- /dev/null +++ b/internal/hummingbird/mqttbroker/retained/trie/trie_db.go @@ -0,0 +1,79 @@ +package trie + +import ( + "sync" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/retained" +) + +// trieDB implement the retain.Store, it use trie tree to store retain messages . +type trieDB struct { + sync.RWMutex + userTrie *topicTrie + systemTrie *topicTrie +} + +func (t *trieDB) Iterate(fn retained.IterateFn) { + t.RLock() + defer t.RUnlock() + if !t.userTrie.preOrderTraverse(fn) { + return + } + t.systemTrie.preOrderTraverse(fn) +} + +func (t *trieDB) getTrie(topicName string) *topicTrie { + if isSystemTopic(topicName) { + return t.systemTrie + } + return t.userTrie +} + +// GetRetainedMessage return the retain message of the given topic name. +// return nil if the topic name not exists +func (t *trieDB) GetRetainedMessage(topicName string) *mqttbroker.Message { + t.RLock() + defer t.RUnlock() + node := t.getTrie(topicName).find(topicName) + if node != nil { + return node.msg.Copy() + } + return nil +} + +// ClearAll clear all retain messages. +func (t *trieDB) ClearAll() { + t.Lock() + defer t.Unlock() + t.systemTrie = newTopicTrie() + t.userTrie = newTopicTrie() +} + +// AddOrReplace add or replace a retain message. +func (t *trieDB) AddOrReplace(message *mqttbroker.Message) { + t.Lock() + defer t.Unlock() + t.getTrie(message.Topic).addRetainMsg(message.Topic, message) +} + +// remove remove the retain message of the topic name. +func (t *trieDB) Remove(topicName string) { + t.Lock() + defer t.Unlock() + t.getTrie(topicName).remove(topicName) +} + +// GetMatchedMessages returns all messages that match the topic filter. +func (t *trieDB) GetMatchedMessages(topicFilter string) []*mqttbroker.Message { + t.RLock() + defer t.RUnlock() + return t.getTrie(topicFilter).getMatchedMessages(topicFilter) +} + +func NewStore() *trieDB { + return &trieDB{ + userTrie: newTopicTrie(), + systemTrie: newTopicTrie(), + } +} diff --git a/internal/hummingbird/mqttbroker/retained/trie/trie_db_test.go b/internal/hummingbird/mqttbroker/retained/trie/trie_db_test.go new file mode 100644 index 0000000..8727769 --- /dev/null +++ b/internal/hummingbird/mqttbroker/retained/trie/trie_db_test.go @@ -0,0 +1,274 @@ +package trie + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" +) + +func TestTrieDB_ClearAll(t *testing.T) { + a := assert.New(t) + s := NewStore() + s.AddOrReplace(&gmqtt.Message{ + Topic: "a/b/c", + }) + s.AddOrReplace(&gmqtt.Message{ + Topic: "a/b/c/d", + Payload: []byte{1, 2, 3}, + }) + s.ClearAll() + a.Nil(s.GetRetainedMessage("a/b/c")) + a.Nil(s.GetRetainedMessage("a/b/c/d")) +} + +func TestTrieDB_GetRetainedMessage(t *testing.T) { + a := assert.New(t) + s := NewStore() + tt := []*gmqtt.Message{ + { + Topic: "a/b/c/d", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a/b/c/", + Payload: []byte{1, 2, 3, 4}, + }, + { + Topic: "a/", + Payload: []byte{1, 2, 3}, + }, + } + for _, v := range tt { + s.AddOrReplace(v) + } + for _, v := range tt { + rs := s.GetRetainedMessage(v.Topic) + a.Equal(v.Topic, rs.Topic) + a.Equal(v.Payload, rs.Payload) + } + a.Nil(s.GetRetainedMessage("a/b")) +} + +func TestTrieDB_GetMatchedMessages(t *testing.T) { + a := assert.New(t) + s := NewStore() + msgs := []*gmqtt.Message{ + { + Topic: "a/b/c/d", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a/b/c/", + Payload: []byte{1, 2, 3, 4}, + }, + { + Topic: "a/", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a/b", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "b/a", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a", + Payload: []byte{1, 2, 3}, + }, + } + var tt = []struct { + TopicFilter string + expected map[string]*gmqtt.Message + }{ + { + TopicFilter: "a/+/c/", + expected: map[string]*gmqtt.Message{ + "a/b/c/": { + Payload: []byte{1, 2, 3, 4}, + }, + }, + }, + { + TopicFilter: "a/+", + expected: map[string]*gmqtt.Message{ + "a/": { + Payload: []byte{1, 2, 3}, + }, + "a/b": { + Payload: []byte{1, 2, 3}, + }, + }, + }, + { + TopicFilter: "#", + expected: map[string]*gmqtt.Message{ + "a/b/c/d": { + Payload: []byte{1, 2, 3}, + }, + "a/b/c/": { + Payload: []byte{1, 2, 3, 4}, + }, + "a/": { + Payload: []byte{1, 2, 3}, + }, + "a/b": { + Payload: []byte{1, 2, 3}, + }, + "b/a": { + Payload: []byte{1, 2, 3}, + }, + "a": { + Payload: []byte{1, 2, 3}, + }, + }, + }, + { + TopicFilter: "a/#", + expected: map[string]*gmqtt.Message{ + "a/b/c/d": { + Payload: []byte{1, 2, 3}, + }, + "a/b/c/": { + Payload: []byte{1, 2, 3, 4}, + }, + "a/": { + Payload: []byte{1, 2, 3}, + }, + "a/b": { + Payload: []byte{1, 2, 3}, + }, + "a": { + Payload: []byte{1, 2, 3}, + }, + }, + }, + { + TopicFilter: "a/b/c/d", + expected: map[string]*gmqtt.Message{ + "a/b/c/d": { + Payload: []byte{1, 2, 3}, + }, + }, + }, + } + for _, v := range msgs { + s.AddOrReplace(v) + } + for _, v := range tt { + t.Run(v.TopicFilter, func(t *testing.T) { + rs := s.GetMatchedMessages(v.TopicFilter) + a.Equal(len(v.expected), len(rs)) + got := make(map[string]*gmqtt.Message) + for _, v := range rs { + got[v.Topic] = v + } + for k, v := range v.expected { + a.Equal(v.Payload, got[k].Payload) + } + }) + + } +} + +func TestTrieDB_Remove(t *testing.T) { + a := assert.New(t) + s := NewStore() + s.AddOrReplace(&gmqtt.Message{ + Topic: "a/b/c", + }) + s.AddOrReplace(&gmqtt.Message{ + Topic: "a/b/c/d", + Payload: []byte{1, 2, 3}, + }) + a.NotNil(s.GetRetainedMessage("a/b/c")) + s.Remove("a/b/c") + a.Nil(s.GetRetainedMessage("a/b/c")) +} + +func TestTrieDB_Iterate(t *testing.T) { + a := assert.New(t) + s := NewStore() + msgs := []*gmqtt.Message{ + { + Topic: "a/b/c/d", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a/b/c/", + Payload: []byte{1, 2, 3, 4}, + }, + { + Topic: "a/", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a/b", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "$SYS/a/b", + Payload: []byte{1, 2, 3}, + }, + } + + for _, v := range msgs { + s.AddOrReplace(v) + } + var rs []*gmqtt.Message + s.Iterate(func(message *gmqtt.Message) bool { + rs = append(rs, message) + return true + }) + a.ElementsMatch(msgs, rs) +} + +func TestTrieDB_Iterate_Cancel(t *testing.T) { + a := assert.New(t) + s := NewStore() + msgs := []*gmqtt.Message{ + { + Topic: "a/b/c/d", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a/b/c/", + Payload: []byte{1, 2, 3, 4}, + }, + { + Topic: "a/", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a/b", + Payload: []byte{1, 2, 3}, + }, + { + Topic: "a", + Payload: []byte{1, 2, 3}, + }, + } + + for _, v := range msgs { + s.AddOrReplace(v) + } + var i int + var rs []*gmqtt.Message + s.Iterate(func(message *gmqtt.Message) bool { + if i == 2 { + return false + } + rs = append(rs, message) + i++ + return true + }) + a.Len(rs, 2) + +} diff --git a/internal/hummingbird/mqttbroker/server/api_registrar.go b/internal/hummingbird/mqttbroker/server/api_registrar.go new file mode 100644 index 0000000..bc237b3 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/api_registrar.go @@ -0,0 +1,308 @@ +package server + +import ( + "context" + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "net/http" + "strings" + + grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/grpc-ecosystem/grpc-gateway/utilities" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc" + gcodes "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + + "github.com/winc-link/hummingbird/internal/dtos" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/pkg/constants" + "github.com/winc-link/hummingbird/internal/pkg/httphelper" +) + +// APIRegistrar is the registrar for all gRPC servers and HTTP servers. +// It provides the ability for plugins to register gRPC and HTTP handler. +type APIRegistrar interface { + // RegisterHTTPHandler registers the handler to all http servers. + RegisterHTTPHandler(fn HTTPHandler) error + // RegisterService registers a service and its implementation to all gRPC servers. + RegisterService(desc *grpc.ServiceDesc, impl interface{}) +} + +type apiRegistrar struct { + gRPCServers []*gRPCServer + httpServers []*httpServer +} + +// RegisterService implements APIRegistrar interface +func (a *apiRegistrar) RegisterService(desc *grpc.ServiceDesc, impl interface{}) { + for _, v := range a.gRPCServers { + v.server.RegisterService(desc, impl) + } +} + +// RegisterHTTPHandler implements APIRegistrar interface +func (a *apiRegistrar) RegisterHTTPHandler(fn HTTPHandler) error { + var err error + for _, v := range a.httpServers { + schema, addr := splitEndpoint(v.gRPCEndpoint) + if schema == "unix" { + err = fn(context.Background(), v.mux, v.gRPCEndpoint, []grpc.DialOption{grpc.WithInsecure()}) + if err != nil { + return err + } + continue + } + err = fn(context.Background(), v.mux, addr, []grpc.DialOption{grpc.WithInsecure()}) + if err != nil { + return err + } + } + return nil +} + +type gRPCServer struct { + server *grpc.Server + serve func(errChan chan error) error + shutdown func() + endpoint string +} + +type httpServer struct { + gRPCEndpoint string + endpoint string + mux *runtime.ServeMux + tlsCfg *tls.Config + serve func(errChan chan error) error + shutdown func() +} + +// HTTPHandler is the http handler defined by gRPC-gateway. +type HTTPHandler = func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) + +func splitEndpoint(endpoint string) (schema string, addr string) { + epParts := strings.SplitN(endpoint, "://", 2) + if len(epParts) == 1 && epParts[0] != "" { + epParts = []string{"tcp", epParts[0]} + } + return epParts[0], epParts[1] +} + +func buildTLSConfig(cfg *config.TLSOptions) (*tls.Config, error) { + c, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key) + if err != nil { + return nil, err + } + certPool := x509.NewCertPool() + if cfg.CACert != "" { + b, err := ioutil.ReadFile(cfg.CACert) + if err != nil { + return nil, err + } + certPool.AppendCertsFromPEM(b) + } + var cliAuthType tls.ClientAuthType + if cfg.Verify { + cliAuthType = tls.RequireAndVerifyClientCert + } + tlsCfg := &tls.Config{ + Certificates: []tls.Certificate{c}, + ClientCAs: certPool, + ClientAuth: cliAuthType, + } + return tlsCfg, nil +} + +func buildGRPCServer(endpoint *config.Endpoint) (*gRPCServer, error) { + var cred credentials.TransportCredentials + if cfg := endpoint.TLS; cfg != nil { + tlsCfg, err := buildTLSConfig(cfg) + if err != nil { + return nil, err + } + cred = credentials.NewTLS(tlsCfg) + } + server := grpc.NewServer( + grpc.Creds(cred), + grpc.ChainUnaryInterceptor( + grpc_zap.UnaryServerInterceptor(zaplog.Logger, grpc_zap.WithLevels(func(code gcodes.Code) zapcore.Level { + if code == gcodes.OK { + return zapcore.DebugLevel + } + return grpc_zap.DefaultClientCodeToLevel(code) + })), + grpc_prometheus.UnaryServerInterceptor), + ) + grpc_prometheus.Register(server) + shutdown := func() { + server.Stop() + } + serve := func(errChan chan error) error { + schema, addr := splitEndpoint(endpoint.Address) + l, err := net.Listen(schema, addr) + if err != nil { + return err + } + go func() { + select { + case errChan <- server.Serve(l): + default: + } + }() + return nil + } + + return &gRPCServer{ + server: server, + serve: serve, + shutdown: shutdown, + endpoint: endpoint.Address, + }, nil +} + +func buildHTTPServer(endpoint *config.Endpoint) (*httpServer, error) { + var tlsCfg *tls.Config + var err error + if cfg := endpoint.TLS; cfg != nil { + tlsCfg, err = buildTLSConfig(cfg) + if err != nil { + return nil, err + } + } + mux := runtime.NewServeMux(runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{OrigName: true, EmitDefaults: true})) + server := &http.Server{ + Handler: mux, + } + mux.Handle("POST", runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, + []string{"v1", "advance", "config"}, "", runtime.AssumeColonVerbOpt(true))), + func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + var ( + level string + cfg dtos.AdvanceConfig + ) + // set log level + inboundMarshaler, _ := runtime.MarshalerForRequest(mux, req) + reader, err := utilities.IOReaderFactory(req.Body) + if err != nil { + zaplog.Sugar().Errorf("get req body error: %s", err) + httphelper.RenderFailNoLog(req.Context(), err, w) + return + } + if err = inboundMarshaler.NewDecoder(reader()).Decode(&cfg); err != nil { + zaplog.Sugar().Errorf("unmarshal req body error: %s", err) + httphelper.RenderFailNoLog(req.Context(), err, w) + return + } + zaplog.Sugar().Infof("recv change log level request: %d", cfg.LogLevel) + if level, err = zaplog.SetLogLevel(cfg.LogLevel); err != nil { + zaplog.Sugar().Errorf("set log level error: %s", err) + httphelper.RenderFailNoLog(req.Context(), err, w) + return + } + // write to configuration file + config.UpdateLogLevel(strings.ToLower(level)) + if err = config.WriteToFile(); err != nil { + zaplog.Sugar().Errorf("write to configuration file error: %s", err) + httphelper.RenderFailNoLog(req.Context(), err, w) + return + } + httphelper.ResultSuccessNoLog(cfg, w) + }) + + mux.Handle("GET", runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, + []string{"v1", "advance", "config"}, "", runtime.AssumeColonVerbOpt(true))), + func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + level := config.GetLogLevel() + li := LogLevelMap[strings.ToUpper(level)] + httphelper.ResultSuccessNoLog(dtos.AdvanceConfig{ + LogLevel: constants.LogLevel(li + 1), + }, w) + }) + + shutdown := func() { + server.Shutdown(context.Background()) + } + serve := func(errChan chan error) error { + schema, addr := splitEndpoint(endpoint.Address) + l, err := net.Listen(schema, addr) + if err != nil { + return err + } + if tlsCfg != nil { + l = tls.NewListener(l, tlsCfg) + } + go func() { + select { + case errChan <- server.Serve(l): + default: + } + }() + return nil + } + + return &httpServer{ + gRPCEndpoint: endpoint.Map, + mux: mux, + serve: serve, + shutdown: shutdown, + endpoint: endpoint.Address, + }, nil +} + +func (srv *server) exit() { + select { + case <-srv.exitChan: + default: + close(srv.exitChan) + } +} + +func (srv *server) serveAPIServer() { + var err error + defer func() { + srv.wg.Done() + if err != nil { + zaplog.Error("serveAPIServer error", zap.Error(err)) + srv.setError(err) + } + }() + errChan := make(chan error, 1) + defer func() { + for _, v := range srv.apiRegistrar.gRPCServers { + v.shutdown() + } + for _, v := range srv.apiRegistrar.httpServers { + v.shutdown() + } + }() + for _, v := range srv.apiRegistrar.gRPCServers { + err = v.serve(errChan) + if err != nil { + return + } + zaplog.Info("gRPC server started", zap.String("bind_address", v.endpoint)) + } + + for _, v := range srv.apiRegistrar.httpServers { + err = v.serve(errChan) + if err != nil { + return + } + zaplog.Info("HTTP server started", zap.String("bind_address", v.endpoint), zap.String("gRPC_endpoint", v.gRPCEndpoint)) + } + + for { + select { + case <-srv.exitChan: + return + case err = <-errChan: + return + } + + } +} diff --git a/internal/hummingbird/mqttbroker/server/client.go b/internal/hummingbird/mqttbroker/server/client.go new file mode 100644 index 0000000..e896656 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/client.go @@ -0,0 +1,1486 @@ +package server + +import ( + "bufio" + "bytes" + "context" + "crypto/md5" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "net" + "os" + "reflect" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" + "github.com/winc-link/hummingbird/internal/pkg/bitmap" + "github.com/winc-link/hummingbird/internal/pkg/codes" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// Error +var ( + ErrConnectTimeOut = errors.New("connect time out") +) + +// Client status +const ( + Connecting = iota + Connected +) + +const ( + readBufferSize = 1024 + writeBufferSize = 1024 +) + +var ( + bufioReaderPool sync.Pool + bufioWriterPool sync.Pool +) + +func kvsToProperties(kvs []struct { + K []byte + V []byte +}) []packets.UserProperty { + u := make([]packets.UserProperty, len(kvs)) + for k, v := range kvs { + u[k].K = v.K + u[k].V = v.V + } + return u +} + +func newBufioReaderSize(r io.Reader, size int) *bufio.Reader { + if v := bufioReaderPool.Get(); v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br + } + return bufio.NewReaderSize(r, size) +} + +func putBufioReader(br *bufio.Reader) { + br.Reset(nil) + bufioReaderPool.Put(br) +} + +func newBufioWriterSize(w io.Writer, size int) *bufio.Writer { + if v := bufioWriterPool.Get(); v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } + return bufio.NewWriterSize(w, size) +} + +func putBufioWriter(bw *bufio.Writer) { + bw.Reset(nil) + bufioWriterPool.Put(bw) +} + +func (client *client) ClientOptions() *ClientOptions { + return client.opts +} + +// ClientOptions is the options which controls how the server interacts with the client. +// It will be set after the client has connected. +type ClientOptions struct { + // ClientID is the client id for the client. + ClientID string + // Username is the username for the client. + Username string + // KeepAlive is the keep alive time in seconds for the client. + // The server will close the client if no there is no packet has been received for 1.5 times the KeepAlive time. + KeepAlive uint16 + // SessionExpiry is the session expiry interval in seconds. + // If the client version is v5, this value will be set into CONNACK Session Expiry Interval property. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901082 + SessionExpiry uint32 + // MaxInflight limits the number of QoS 1 and QoS 2 publications that the client is willing to process concurrently. + // For v3 client, it is default to config.MQTT.MaxInflight. + // For v5 client, it is the minimum of config.MQTT.MaxInflight and Receive Maximum property in CONNECT packet. + MaxInflight uint16 + // ReceiveMax limits the number of QoS 1 and QoS 2 publications that the server is willing to process concurrently for the Client. + // If the client version is v5, this value will be set into Receive Maximum property in CONNACK packet. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901083 + ReceiveMax uint16 + // ClientMaxPacketSize is the maximum packet size that the client is willing to accept. + // The server will drop the packet if it exceeds ClientMaxPacketSize. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901050 + ClientMaxPacketSize uint32 + // ServerMaxPacketSize is the maximum packet size that the server is willing to accept from the client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901086 + ServerMaxPacketSize uint32 + // ClientTopicAliasMax is highest value that the client will accept as a Topic Alias sent by the server. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901051 + ClientTopicAliasMax uint16 + // ServerTopicAliasMax is highest value that the server will accept as a Topic Alias sent by the client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901088 + ServerTopicAliasMax uint16 + // RequestProblemInfo is the value to indicate whether the Reason String or User Properties should be sent in the case of failures. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901053 + RequestProblemInfo bool + // UserProperties is the user properties provided by the client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901090 + UserProperties []*packets.UserProperty + // WildcardSubAvailable indicates whether the client is permitted to send retained messages. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901091 + RetainAvailable bool + // WildcardSubAvailable indicates whether the client is permitted to subscribe Wildcard Subscriptions. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901091 + WildcardSubAvailable bool + // SubIDAvailable indicates whether the client is permitted to set Subscription Identifiers. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901092 + SubIDAvailable bool + // SharedSubAvailable indicates whether the client is permitted to subscribe Shared Subscriptions. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901093 + SharedSubAvailable bool + // AuthMethod is the auth method send by the client. + // Only MQTT v5 client can set this value. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901055 + AuthMethod []byte +} + +// Client represent a mqttclient client. +type Client interface { + // ClientOptions return a reference of ClientOptions. Do not edit. + // This is mainly used in hooks. + ClientOptions() *ClientOptions + // SessionInfo return a reference of session information of the client. Do not edit. + // Session info will be available after the client has passed OnSessionCreated or OnSessionResume. + SessionInfo() *gmqtt.Session + // Version return the protocol version of the used client. + Version() packets.Version + // ConnectedAt returns the connected time + ConnectedAt() time.Time + // Connection returns the raw net.Conn + Connection() net.Conn + // Close closes the client connection. + Close() + // Disconnect sends a disconnect packet to client, it is use to close v5 client. + Disconnect(disconnect *packets.Disconnect) +} + +// client represents a MQTT client and implements the Client interface +type client struct { + connectedAt int64 + server *server + wg sync.WaitGroup + rwc net.Conn //raw tcp connection + bufr *bufio.Reader + bufw *bufio.Writer + packetReader *packets.Reader + packetWriter *packets.Writer + in chan packets.Packet + out chan packets.Packet + close chan struct{} + closed chan struct{} + connected chan struct{} + status int32 + // if 1, when client close, the session expiry interval will be ignored and the session will be removed. + forceRemoveSession int32 + error chan error + errOnce sync.Once + err error + + opts *ClientOptions //set up before OnConnect() + session *gmqtt.Session + + cleanWillFlag bool // whether to remove will Msg + + disconnect *packets.Disconnect + + topicAliasManager TopicAliasManager + version packets.Version + aliasMapper [][]byte + + // gard serverReceiveMaximumQuota + serverQuotaMu sync.Mutex + serverReceiveMaximumQuota uint16 + + config config.Config + + queueStore queue.Store + unackStore unack.Store + pl *packetIDLimiter + queueNotifier *queueNotifier + // register requests the broker to add the client into the "active client list" before sending a positive CONNACK to the client. + register func(connect *packets.Connect, client *client) (sessionResume bool, err error) + // unregister requests the broker to remove the client from the "active client list" when the client is disconnected. + unregister func(client *client) + // deliverMessage + deliverMessage func(srcClientID string, msg *gmqtt.Message, options subscription.IterationOptions) (matched bool) +} + +func (client *client) SessionInfo() *gmqtt.Session { + return client.session +} + +func (client *client) Version() packets.Version { + return client.version +} + +func (client *client) Disconnect(disconnect *packets.Disconnect) { + client.write(disconnect) +} + +// ConnectedAt +func (client *client) ConnectedAt() time.Time { + return time.Unix(atomic.LoadInt64(&client.connectedAt), 0) +} + +// Connection returns the raw net.Conn +func (client *client) Connection() net.Conn { + return client.rwc +} + +func (client *client) setConnecting() { + atomic.StoreInt32(&client.status, Connecting) +} + +func (client *client) setConnected(time time.Time) { + atomic.StoreInt64(&client.connectedAt, time.Unix()) + atomic.StoreInt32(&client.status, Connected) +} + +//Status returns client's status +func (client *client) Status() int32 { + return atomic.LoadInt32(&client.status) +} + +// IsConnected returns whether the client is connected or not. +func (client *client) IsConnected() bool { + return client.Status() == Connected +} + +func (client *client) setError(err error) { + client.errOnce.Do(func() { + if err != nil && err != io.EOF { + zaplog.Error("connection lost", + zap.String("client_id", client.opts.ClientID), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.Error(err)) + client.err = err + if client.version == packets.Version5 { + if code, ok := err.(*codes.Error); ok { + if client.IsConnected() { + // send Disconnect + client.write(&packets.Disconnect{ + Version: packets.Version5, + Code: code.Code, + Properties: &packets.Properties{ + ReasonString: code.ReasonString, + User: kvsToProperties(code.UserProperties), + }, + }) + } + } + } + } + }) +} + +func (client *client) writeLoop() { + var err error + srv := client.server + defer func() { + if re := recover(); re != nil { + err = errors.New(fmt.Sprint(re)) + } + client.setError(err) + }() + for { + select { + case <-client.close: + return + case packet := <-client.out: + switch p := packet.(type) { + case *packets.Publish: + if client.version == packets.Version5 { + if client.opts.ClientTopicAliasMax > 0 { + // use alias if exist + if alias, ok := client.topicAliasManager.Check(p); ok { + p.TopicName = []byte{} + p.Properties.TopicAlias = &alias + } else { + // alias not exist + if alias != 0 { + p.Properties.TopicAlias = &alias + } + } + } + } + // OnDelivered hook + if srv.hooks.OnDelivered != nil { + srv.hooks.OnDelivered(context.Background(), client, gmqtt.MessageFromPublish(p)) + } + srv.statsManager.messageSent(p.Qos, client.opts.ClientID) + case *packets.Puback, *packets.Pubcomp: + if client.version == packets.Version5 { + client.addServerQuota() + } + case *packets.Pubrec: + if client.version == packets.Version5 && p.Code >= codes.UnspecifiedError { + client.addServerQuota() + } + } + err = client.writePacket(packet) + if err != nil { + return + } + srv.statsManager.packetSent(packet, client.opts.ClientID) + if _, ok := packet.(*packets.Disconnect); ok { + _ = client.rwc.Close() + return + } + } + } +} + +func (client *client) writePacket(packet packets.Packet) error { + if client.server.config.Log.DumpPacket { + if ce := zaplog.Check(zapcore.DebugLevel, "sending packet"); ce != nil { + ce.Write( + zap.String("packet", packet.String()), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID), + ) + } + } + + return client.packetWriter.WriteAndFlush(packet) +} + +func (client *client) addServerQuota() { + client.serverQuotaMu.Lock() + if client.serverReceiveMaximumQuota < client.opts.ReceiveMax { + client.serverReceiveMaximumQuota++ + } + client.serverQuotaMu.Unlock() +} + +func (client *client) tryDecServerQuota() error { + client.serverQuotaMu.Lock() + defer client.serverQuotaMu.Unlock() + if client.serverReceiveMaximumQuota == 0 { + return codes.NewError(codes.RecvMaxExceeded) + } + client.serverReceiveMaximumQuota-- + return nil +} + +func (client *client) readLoop() { + var err error + srv := client.server + defer func() { + if re := recover(); re != nil { + err = errors.New(fmt.Sprint(re)) + } + client.setError(err) + close(client.in) + }() + for { + var packet packets.Packet + if client.IsConnected() { + if keepAlive := client.opts.KeepAlive; keepAlive != 0 { //KeepAlive + _ = client.rwc.SetReadDeadline(time.Now().Add(time.Duration(keepAlive/2+keepAlive) * time.Second)) + } + } + packet, err = client.packetReader.ReadPacket() + if err != nil { + if err != io.EOF && packet != nil { + zaplog.Error("read error", zap.String("packet_type", reflect.TypeOf(packet).String())) + } + return + } + + if pub, ok := packet.(*packets.Publish); ok { + srv.statsManager.messageReceived(pub.Qos, client.opts.ClientID) + if client.version == packets.Version5 && pub.Qos > packets.Qos0 { + err = client.tryDecServerQuota() + if err != nil { + return + } + } + } + client.in <- packet + <-client.connected + srv.statsManager.packetReceived(packet, client.opts.ClientID) + if client.server.config.Log.DumpPacket { + if ce := zaplog.Check(zapcore.DebugLevel, "received packet"); ce != nil { + ce.Write( + zap.String("packet", packet.String()), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID), + ) + } + } + } +} + +// Close closes the client connection. The returned channel will be closed after unregisterClient process has been done +func (client *client) Close() { + if client.rwc != nil { + _ = client.rwc.Close() + } +} + +var pid = os.Getpid() +var counter uint32 +var machineID = readMachineID() + +func readMachineID() []byte { + id := make([]byte, 3) + hostname, err1 := os.Hostname() + if err1 != nil { + _, err2 := io.ReadFull(rand.Reader, id) + if err2 != nil { + panic(fmt.Errorf("cannot get hostname: %v; %v", err1, err2)) + } + return id + } + hw := md5.New() + hw.Write([]byte(hostname)) + copy(id, hw.Sum(nil)) + return id +} + +func getRandomUUID() string { + var b [12]byte + // Timestamp, 4 bytes, big endian + binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) + // Machine, first 3 bytes of md5(hostname) + b[4] = machineID[0] + b[5] = machineID[1] + b[6] = machineID[2] + // Pid, 2 bytes, specs don't specify endianness, but we use big endian. + b[7] = byte(pid >> 8) + b[8] = byte(pid) + // Increment, 3 bytes, big endian + i := atomic.AddUint32(&counter, 1) + b[9] = byte(i >> 16) + b[10] = byte(i >> 8) + b[11] = byte(i) + return fmt.Sprintf(`%x`, string(b[:])) +} + +func bool2Byte(bo bool) *byte { + var b byte + if bo { + b = 1 + } else { + b = 0 + } + return &b +} + +func convertUint16(u *uint16, defaultValue uint16) uint16 { + if u == nil { + return defaultValue + } + return *u +} + +func convertUint32(u *uint32, defaultValue uint32) uint32 { + if u == nil { + return defaultValue + } + return *u +} + +func sendErrConnack(cli *client, err error) { + codeErr := converError(err) + // Override the error code if it is invalid for V3 client. + if packets.IsVersion3X(cli.version) && codeErr.Code > codes.V3NotAuthorized { + codeErr.Code = codes.NotAuthorized + } + cli.out <- &packets.Connack{ + Version: cli.version, + Code: codeErr.Code, + Properties: getErrorProperties(cli, &codeErr.ErrorDetails), + } +} + +func (client *client) connectWithTimeOut() (ok bool) { + // if any error occur, this function should set the error to the client and return false + var err error + defer func() { + if err != nil { + client.setError(err) + ok = false + } else { + ok = true + } + close(client.connected) + }() + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + var conn *packets.Connect + var authOpts *AuthOptions + // for enhanced auth + var onAuth OnAuth + + for { + select { + case p := <-client.in: + if p == nil { + return + } + code := codes.Success + var authData []byte + switch p.(type) { + case *packets.Connect: + if conn != nil { + err = codes.ErrProtocol + break + } + conn = p.(*packets.Connect) + var resp *EnhancedAuthResponse + authOpts, resp, err = client.connectHandler(conn) + if err != nil { + break + } + if resp != nil && resp.Continue { + code = codes.ContinueAuthentication + authData = resp.AuthData + onAuth = resp.OnAuth + } else { + code = codes.Success + } + case *packets.Auth: + if conn == nil || packets.IsVersion3X(client.version) { + err = codes.ErrProtocol + break + } + if onAuth == nil { + err = codes.ErrProtocol + break + } + au := p.(*packets.Auth) + if au.Code != codes.ContinueAuthentication { + err = codes.ErrProtocol + break + } + + var authResp *AuthResponse + authResp, err = client.authHandler(au, authOpts, onAuth) + if err != nil { + break + } + if authResp.Continue { + code = codes.ContinueAuthentication + authData = authResp.AuthData + } else { + code = codes.Success + } + default: + err = &codes.Error{ + Code: codes.MalformedPacket, + } + break + } + // authentication fail + if err != nil { + sendErrConnack(client, err) + return + } + // continue authentication (ContinueAuthentication is introduced in V5) + if code == codes.ContinueAuthentication { + client.out <- &packets.Auth{ + Code: code, + Properties: &packets.Properties{ + AuthMethod: conn.Properties.AuthMethod, + AuthData: authData, + }, + } + continue + } + + // authentication success + client.opts.RetainAvailable = authOpts.RetainAvailable + client.opts.WildcardSubAvailable = authOpts.WildcardSubAvailable + client.opts.SubIDAvailable = authOpts.SubIDAvailable + client.opts.SharedSubAvailable = authOpts.SharedSubAvailable + client.opts.SessionExpiry = authOpts.SessionExpiry + client.opts.MaxInflight = authOpts.MaxInflight + client.opts.ReceiveMax = authOpts.ReceiveMax + client.opts.ClientMaxPacketSize = math.MaxUint32 // unlimited + client.opts.ServerMaxPacketSize = authOpts.MaxPacketSize + client.opts.ServerTopicAliasMax = authOpts.TopicAliasMax + client.opts.Username = string(conn.Username) + + if len(conn.ClientID) == 0 { + if len(authOpts.AssignedClientID) != 0 { + client.opts.ClientID = string(authOpts.AssignedClientID) + } else { + client.opts.ClientID = getRandomUUID() + authOpts.AssignedClientID = []byte(client.opts.ClientID) + } + } else { + client.opts.ClientID = string(conn.ClientID) + } + + var connackPpt *packets.Properties + if client.version == packets.Version5 { + client.opts.MaxInflight = convertUint16(conn.Properties.ReceiveMaximum, client.opts.MaxInflight) + client.opts.ClientMaxPacketSize = convertUint32(conn.Properties.MaximumPacketSize, client.opts.ClientMaxPacketSize) + client.opts.ClientTopicAliasMax = convertUint16(conn.Properties.TopicAliasMaximum, client.opts.ClientTopicAliasMax) + client.opts.AuthMethod = conn.Properties.AuthMethod + client.serverReceiveMaximumQuota = client.opts.ReceiveMax + client.aliasMapper = make([][]byte, client.opts.ReceiveMax+1) + client.opts.KeepAlive = authOpts.KeepAlive + + var maxQoS byte + if authOpts.MaximumQoS >= 2 { + maxQoS = byte(1) + } else { + maxQoS = byte(0) + } + + connackPpt = &packets.Properties{ + SessionExpiryInterval: &authOpts.SessionExpiry, + ReceiveMaximum: &authOpts.ReceiveMax, + MaximumQoS: &maxQoS, + RetainAvailable: bool2Byte(authOpts.RetainAvailable), + TopicAliasMaximum: &authOpts.TopicAliasMax, + WildcardSubAvailable: bool2Byte(authOpts.WildcardSubAvailable), + SubIDAvailable: bool2Byte(authOpts.SubIDAvailable), + SharedSubAvailable: bool2Byte(authOpts.SharedSubAvailable), + MaximumPacketSize: &authOpts.MaxPacketSize, + ServerKeepAlive: &authOpts.KeepAlive, + AssignedClientID: authOpts.AssignedClientID, + ResponseInfo: authOpts.ResponseInfo, + } + } else { + client.opts.KeepAlive = conn.KeepAlive + } + + if keepAlive := client.opts.KeepAlive; keepAlive != 0 { //KeepAlive + _ = client.rwc.SetReadDeadline(time.Now().Add(time.Duration(keepAlive/2+keepAlive) * time.Second)) + } + client.newPacketIDLimiter(client.opts.MaxInflight) + + var sessionResume bool + sessionResume, err = client.register(conn, client) + if err != nil { + sendErrConnack(client, err) + return + } + connack := conn.NewConnackPacket(codes.Success, sessionResume) + if conn.Version == packets.Version5 { + connack.Properties = connackPpt + } + client.write(connack) + return + case <-timeout.C: + err = ErrConnectTimeOut + return + } + } +} + +func (client *client) basicAuth(conn *packets.Connect, authOpts *AuthOptions) (err error) { + srv := client.server + if srv.hooks.OnBasicAuth != nil { + err = srv.hooks.OnBasicAuth(context.Background(), client, &ConnectRequest{ + Connect: conn, + Options: authOpts, + }) + + } + return err +} + +func (client *client) enhancedAuth(conn *packets.Connect, authOpts *AuthOptions) (resp *EnhancedAuthResponse, err error) { + srv := client.server + if srv.hooks.OnEnhancedAuth == nil { + return nil, errors.New("OnEnhancedAuth hook is nil") + } + + resp, err = srv.hooks.OnEnhancedAuth(context.Background(), client, &ConnectRequest{ + Connect: conn, + Options: authOpts, + }) + if err == nil && resp == nil { + err = errors.New("return nil response from OnEnhancedAuth hook") + } + return resp, err +} + +func (client *client) connectHandler(conn *packets.Connect) (authOpts *AuthOptions, enhancedResp *EnhancedAuthResponse, err error) { + if !client.config.MQTT.AllowZeroLenClientID && len(conn.ClientID) == 0 { + err = &codes.Error{ + Code: codes.ClientIdentifierNotValid, + } + return + } + client.version = conn.Version + // default auth options + authOpts = client.defaultAuthOptions(conn) + + if packets.IsVersion3X(client.version) || (packets.IsVersion5(client.version) && conn.Properties.AuthMethod == nil) { + err = client.basicAuth(conn, authOpts) + } + if client.version == packets.Version5 && conn.Properties.AuthMethod != nil { + enhancedResp, err = client.enhancedAuth(conn, authOpts) + } + + return +} + +func (client *client) authHandler(auth *packets.Auth, authOpts *AuthOptions, onAuth OnAuth) (resp *AuthResponse, err error) { + authResp, err := onAuth(context.Background(), client, &AuthRequest{ + Auth: auth, + Options: authOpts, + }) + if err == nil && authResp == nil { + return nil, errors.New("return nil response from OnAuth hook") + } + return authResp, err +} + +func getErrorProperties(client *client, errDetails *codes.ErrorDetails) *packets.Properties { + if client.version == packets.Version5 && client.opts.RequestProblemInfo && errDetails != nil { + return &packets.Properties{ + ReasonString: errDetails.ReasonString, + User: kvsToProperties(errDetails.UserProperties), + } + } + return nil +} + +func (client *client) defaultAuthOptions(connect *packets.Connect) *AuthOptions { + opts := &AuthOptions{ + SessionExpiry: uint32(client.config.MQTT.SessionExpiry.Seconds()), + ReceiveMax: client.config.MQTT.ReceiveMax, + MaximumQoS: client.config.MQTT.MaximumQoS, + MaxPacketSize: client.config.MQTT.MaxPacketSize, + TopicAliasMax: client.config.MQTT.TopicAliasMax, + RetainAvailable: client.config.MQTT.RetainAvailable, + WildcardSubAvailable: client.config.MQTT.WildcardAvailable, + SubIDAvailable: client.config.MQTT.SubscriptionIDAvailable, + SharedSubAvailable: client.config.MQTT.SharedSubAvailable, + KeepAlive: client.config.MQTT.MaxKeepAlive, + MaxInflight: client.config.MQTT.MaxInflight, + } + if connect.KeepAlive < opts.KeepAlive { + opts.KeepAlive = connect.KeepAlive + } + if client.version == packets.Version5 { + if i := connect.Properties.SessionExpiryInterval; i == nil { + opts.SessionExpiry = 0 + } else if *i < opts.SessionExpiry { + opts.SessionExpiry = *i + + } + } + return opts +} + +func (client *client) internalClose() { + if client.IsConnected() { + // OnClosed hooks + if client.server.hooks.OnClosed != nil { + client.server.hooks.OnClosed(context.Background(), client, client.err) + } + client.unregister(client) + client.server.statsManager.clientDisconnected(client.opts.ClientID) + } + putBufioReader(client.bufr) + putBufioWriter(client.bufw) + close(client.closed) + +} + +func (client *client) checkMaxPacketSize(msg *gmqtt.Message) (valid bool) { + totalBytes := msg.TotalBytes(packets.Version5) + if client.opts.ClientMaxPacketSize != 0 && totalBytes > client.opts.ClientMaxPacketSize { + return false + } + return true +} + +func (client *client) write(packets packets.Packet) { + select { + case <-client.close: + return + case client.out <- packets: + } +} + +func (client *client) subscribeHandler(sub *packets.Subscribe) *codes.Error { + srv := client.server + suback := &packets.Suback{ + Version: sub.Version, + PacketID: sub.PacketID, + Properties: &packets.Properties{}, + Payload: make([]codes.Code, len(sub.Topics)), + } + var subID uint32 + now := time.Now() + if client.version == packets.Version5 { + if client.opts.SubIDAvailable && len(sub.Properties.SubscriptionIdentifier) != 0 { + subID = sub.Properties.SubscriptionIdentifier[0] + } + if !client.config.MQTT.SubscriptionIDAvailable && subID != 0 { + return &codes.Error{ + Code: codes.SubIDNotSupported, + } + } + } + subReq := &SubscribeRequest{ + Subscribe: sub, + Subscriptions: make(map[string]*struct { + Sub *gmqtt.Subscription + Error error + }), + ID: subID, + } + + for _, v := range sub.Topics { + subReq.Subscriptions[v.Name] = &struct { + Sub *gmqtt.Subscription + Error error + }{Sub: subscription.FromTopic(v, subID), Error: nil} + } + + if srv.hooks.OnSubscribe != nil { + err := srv.hooks.OnSubscribe(context.Background(), client, subReq) + if ce := converError(err); ce != nil { + suback.Properties = getErrorProperties(client, &ce.ErrorDetails) + for k := range suback.Payload { + if packets.IsVersion3X(client.version) { + suback.Payload[k] = packets.SubscribeFailure + } else { + suback.Payload[k] = ce.Code + } + } + client.write(suback) + return nil + } + } + for k, v := range sub.Topics { + sub := subReq.Subscriptions[v.Name].Sub + subErr := converError(subReq.Subscriptions[v.Name].Error) + var isShared bool + code := sub.QoS + if client.version == packets.Version5 { + if sub.ShareName != "" { + isShared = true + if !client.opts.SharedSubAvailable { + code = codes.SharedSubNotSupported + } + } + if !client.opts.SubIDAvailable && subID != 0 { + code = codes.SubIDNotSupported + } + if !client.opts.WildcardSubAvailable { + for _, c := range sub.TopicFilter { + if c == '+' || c == '#' { + code = codes.WildcardSubNotSupported + break + } + } + } + } + + var subRs subscription.SubscribeResult + var err error + if subErr != nil { + code = subErr.Code + if packets.IsVersion3X(client.version) { + code = packets.SubscribeFailure + } + } + if code < packets.SubscribeFailure { + subRs, err = srv.subscriptionsDB.Subscribe(client.opts.ClientID, sub) + if err != nil { + zaplog.Error("failed to subscribe topic", + zap.String("topic", v.Name), + zap.Uint8("qos", v.Qos), + zap.String("client_id", client.opts.ClientID), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.Error(err)) + code = packets.SubscribeFailure + } + } + suback.Payload[k] = code + if code < packets.SubscribeFailure { + if srv.hooks.OnSubscribed != nil { + srv.hooks.OnSubscribed(context.Background(), client, sub) + } + zaplog.Info("subscribe succeeded", + zap.String("topic", sub.TopicFilter), + zap.Uint8("qos", sub.QoS), + zap.Uint8("retain_handling", sub.RetainHandling), + zap.Bool("retain_as_published", sub.RetainAsPublished), + zap.Bool("no_local", sub.NoLocal), + zap.Uint32("id", sub.ID), + zap.String("client_id", client.opts.ClientID), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + ) + // The spec does not specify whether the retain message should follow the 'no-local' option rule. + // Gmqtt follows the mosquitto implementation which will send retain messages to no-local subscriptions. + // For details: https://github.com/eclipse/mosquitto/issues/1796 + if !isShared && ((!subRs[0].AlreadyExisted && v.RetainHandling != 2) || v.RetainHandling == 0) { + msgs := srv.retainedDB.GetMatchedMessages(sub.TopicFilter) + for _, v := range msgs { + if v.QoS > subRs[0].Subscription.QoS { + v.QoS = subRs[0].Subscription.QoS + } + v.Dup = false + if !sub.RetainAsPublished { + v.Retained = false + } + var expiry time.Time + if v.MessageExpiry != 0 { + expiry = now.Add(time.Second * time.Duration(v.MessageExpiry)) + } + err := client.queueStore.Add(&queue.Elem{ + At: now, + Expiry: expiry, + MessageWithID: &queue.Publish{ + Message: v, + }, + }) + if err != nil { + client.queueNotifier.notifyDropped(v, &queue.InternalError{Err: err}) + if codesErr, ok := err.(*codes.Error); ok { + return codesErr + } + return &codes.Error{ + Code: codes.UnspecifiedError, + } + } + } + } + } else { + zaplog.Info("subscribe failed", + zap.String("topic", sub.TopicFilter), + zap.Uint8("qos", suback.Payload[k]), + zap.String("client_id", client.opts.ClientID), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + ) + } + } + client.write(suback) + return nil +} + +func (client *client) publishHandler(pub *packets.Publish) *codes.Error { + srv := client.server + var dup bool + + // check retain available + if !client.opts.RetainAvailable && pub.Retain { + return &codes.Error{ + Code: codes.RetainNotSupported, + } + } + var msg *gmqtt.Message + msg = gmqtt.MessageFromPublish(pub) + + if client.version == packets.Version5 && pub.Properties.TopicAlias != nil { + if *pub.Properties.TopicAlias >= client.opts.ServerTopicAliasMax { + return &codes.Error{ + Code: codes.TopicAliasInvalid, + } + } + topicAlias := *pub.Properties.TopicAlias + name := client.aliasMapper[int(topicAlias)] + if len(pub.TopicName) == 0 { + if len(name) == 0 { + return &codes.Error{ + Code: codes.TopicAliasInvalid, + } + } + msg.Topic = string(name) + } else { + client.aliasMapper[topicAlias] = pub.TopicName + } + + } + + if pub.Qos == packets.Qos2 { + exist, err := client.unackStore.Set(pub.PacketID) + if err != nil { + return converError(err) + } + if exist { + dup = true + } + } + + if pub.Retain { + if len(pub.Payload) == 0 { + srv.retainedDB.Remove(string(pub.TopicName)) + } else { + srv.retainedDB.AddOrReplace(msg.Copy()) + } + } + + var err error + var topicMatched bool + if !dup { + opts := defaultIterateOptions(msg.Topic) + if srv.hooks.OnMsgArrived != nil { + req := &MsgArrivedRequest{ + Publish: pub, + Message: msg, + IterationOptions: opts, + } + err = srv.hooks.OnMsgArrived(context.Background(), client, req) + msg = req.Message + opts = req.IterationOptions + } + if msg != nil && err == nil { + topicMatched = client.deliverMessage(client.opts.ClientID, msg, opts) + } + } + + var ack packets.Packet + // ack properties + var ppt *packets.Properties + code := codes.Success + if client.version == packets.Version5 { + if !topicMatched && err == nil { + code = codes.NotMatchingSubscribers + } + if codeErr := converError(err); codeErr != nil { + ppt = getErrorProperties(client, &codeErr.ErrorDetails) + code = codeErr.Code + } + + } + if pub.Qos == packets.Qos1 { + ack = pub.NewPuback(code, ppt) + } + if pub.Qos == packets.Qos2 { + ack = pub.NewPubrec(code, ppt) + if code >= codes.UnspecifiedError { + err = client.unackStore.Remove(pub.PacketID) + if err != nil { + return converError(err) + } + } + } + if ack != nil { + client.write(ack) + } + return nil + +} + +func converError(err error) *codes.Error { + if err == nil { + return nil + } + if e, ok := err.(*codes.Error); ok { + return e + } + return &codes.Error{ + Code: codes.UnspecifiedError, + ErrorDetails: codes.ErrorDetails{ + ReasonString: []byte(err.Error()), + }, + } +} + +func (client *client) pubackHandler(puback *packets.Puback) *codes.Error { + err := client.queueStore.Remove(puback.PacketID) + if err != nil { + return converError(err) + } + client.pl.release(puback.PacketID) + if ce := zaplog.Check(zapcore.DebugLevel, "unset inflight"); ce != nil { + ce.Write(zap.String("clientID", client.opts.ClientID), + zap.Uint16("pid", puback.PacketID), + ) + } + return nil +} +func (client *client) pubrelHandler(pubrel *packets.Pubrel) *codes.Error { + err := client.unackStore.Remove(pubrel.PacketID) + if err != nil { + return converError(err) + } + pubcomp := pubrel.NewPubcomp() + client.write(pubcomp) + return nil +} +func (client *client) pubrecHandler(pubrec *packets.Pubrec) { + if client.version == packets.Version5 && pubrec.Code >= codes.UnspecifiedError { + err := client.queueStore.Remove(pubrec.PacketID) + client.pl.release(pubrec.PacketID) + if err != nil { + client.setError(err) + } + return + } + pubrel := pubrec.NewPubrel() + _, err := client.queueStore.Replace(&queue.Elem{ + At: time.Now(), + MessageWithID: &queue.Pubrel{ + PacketID: pubrel.PacketID, + }}) + if err != nil { + client.setError(err) + } + client.write(pubrel) +} +func (client *client) pubcompHandler(pubcomp *packets.Pubcomp) { + err := client.queueStore.Remove(pubcomp.PacketID) + client.pl.release(pubcomp.PacketID) + if err != nil { + client.setError(err) + } + +} +func (client *client) pingreqHandler(pingreq *packets.Pingreq) { + resp := pingreq.NewPingresp() + client.write(resp) +} +func (client *client) unsubscribeHandler(unSub *packets.Unsubscribe) { + srv := client.server + unSuback := &packets.Unsuback{ + Version: unSub.Version, + PacketID: unSub.PacketID, + Properties: &packets.Properties{}, + } + cs := make([]codes.Code, len(unSub.Topics)) + defer func() { + if client.version == packets.Version5 { + unSuback.Payload = cs + } + client.write(unSuback) + }() + req := &UnsubscribeRequest{ + Unsubscribe: unSub, + Unsubs: make(map[string]*struct { + TopicName string + Error error + }), + } + + for _, v := range unSub.Topics { + req.Unsubs[v] = &struct { + TopicName string + Error error + }{TopicName: v} + } + if srv.hooks.OnUnsubscribe != nil { + err := srv.hooks.OnUnsubscribe(context.Background(), client, req) + if ce := converError(err); ce != nil { + unSuback.Properties = getErrorProperties(client, &ce.ErrorDetails) + for k := range cs { + cs[k] = ce.Code + } + return + } + } + for k, v := range unSub.Topics { + code := codes.Success + topicName := req.Unsubs[v].TopicName + ce := converError(req.Unsubs[v].Error) + if ce != nil { + code = ce.Code + } + if code == codes.Success { + err := srv.subscriptionsDB.Unsubscribe(client.opts.ClientID, topicName) + if ce := converError(err); ce != nil { + code = ce.Code + } + } + if code == codes.Success { + if srv.hooks.OnUnsubscribed != nil { + srv.hooks.OnUnsubscribed(context.Background(), client, topicName) + } + zaplog.Info("unsubscribed succeed", + zap.String("topic", topicName), + zap.String("client_id", client.opts.ClientID), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + ) + } else { + zaplog.Info("unsubscribed failed", + zap.String("topic", topicName), + zap.String("client_id", client.opts.ClientID), + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.Uint8("code", code)) + } + cs[k] = code + + } +} + +func (client *client) reAuthHandler(auth *packets.Auth) *codes.Error { + srv := client.server + // default code + code := codes.Success + var resp *AuthResponse + var err error + if srv.hooks.OnReAuth != nil { + resp, err = srv.hooks.OnReAuth(context.Background(), client, auth) + ce := converError(err) + if ce != nil { + return ce + } + } else { + return codes.ErrProtocol + } + if resp.Continue { + code = codes.ContinueAuthentication + } + client.write(&packets.Auth{ + Code: code, + Properties: &packets.Properties{ + AuthMethod: client.opts.AuthMethod, + AuthData: resp.AuthData, + }, + }) + return nil +} + +func (client *client) disconnectHandler(dis *packets.Disconnect) *codes.Error { + if client.version == packets.Version5 { + disExpiry := convertUint32(dis.Properties.SessionExpiryInterval, 0) + sess, err := client.server.sessionStore.Get(client.opts.ClientID) + if err != nil { + return &codes.Error{ + Code: codes.UnspecifiedError, + ErrorDetails: codes.ErrorDetails{ + ReasonString: []byte(err.Error()), + }, + } + } + if sess.ExpiryInterval == 0 && disExpiry != 0 { + return &codes.Error{ + Code: codes.ProtocolError, + } + } + if disExpiry != 0 { + err := client.server.sessionStore.SetSessionExpiry(sess.ClientID, disExpiry) + if err != nil { + zaplog.Error("fail to set session expiry", + zap.String("client_id", client.opts.ClientID), + zap.Error(err)) + } + } + } + client.disconnect = dis + // 不发送will message + client.cleanWillFlag = true + return nil +} + +//读处理 +func (client *client) readHandle() { + var err error + defer func() { + if re := recover(); re != nil { + err = errors.New(fmt.Sprint(re)) + } + client.setError(err) + close(client.close) + }() + for packet := range client.in { + if client.version == packets.Version5 { + if client.opts.ServerMaxPacketSize != 0 && packets.TotalBytes(packet) > client.opts.ServerMaxPacketSize { + err = codes.NewError(codes.PacketTooLarge) + return + } + } + var codeErr *codes.Error + switch packet.(type) { + case *packets.Subscribe: + codeErr = client.subscribeHandler(packet.(*packets.Subscribe)) + case *packets.Publish: + codeErr = client.publishHandler(packet.(*packets.Publish)) + case *packets.Puback: + codeErr = client.pubackHandler(packet.(*packets.Puback)) + case *packets.Pubrel: + codeErr = client.pubrelHandler(packet.(*packets.Pubrel)) + case *packets.Pubrec: + client.pubrecHandler(packet.(*packets.Pubrec)) + case *packets.Pubcomp: + client.pubcompHandler(packet.(*packets.Pubcomp)) + case *packets.Pingreq: + client.pingreqHandler(packet.(*packets.Pingreq)) + case *packets.Unsubscribe: + client.unsubscribeHandler(packet.(*packets.Unsubscribe)) + case *packets.Disconnect: + codeErr = client.disconnectHandler(packet.(*packets.Disconnect)) + return + case *packets.Auth: + auth := packet.(*packets.Auth) + if client.version != packets.Version5 { + err = codes.ErrProtocol + return + } + if !bytes.Equal(client.opts.AuthMethod, auth.Properties.AuthData) { + codeErr = codes.ErrProtocol + return + } + codeErr = client.reAuthHandler(auth) + + default: + err = codes.ErrProtocol + } + if codeErr != nil { + err = codeErr + return + } + } + +} + +func (client *client) newPacketIDLimiter(limit uint16) { + client.pl = &packetIDLimiter{ + cond: sync.NewCond(&sync.Mutex{}), + used: 0, + limit: limit, + exit: false, + freePid: 1, + lockedPid: bitmap.New(packets.MaxPacketID), + } +} + +func (client *client) pollInflights() (cont bool, err error) { + var elems []*queue.Elem + elems, err = client.queueStore.ReadInflight(uint(client.opts.MaxInflight)) + if err != nil || len(elems) == 0 { + return false, err + } + client.pl.lock() + defer client.pl.unlock() + for _, v := range elems { + id := v.MessageWithID.ID() + switch m := v.MessageWithID.(type) { + case *queue.Publish: + m.Dup = true + // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Subscription_Options + // The Server need not use the same set of Subscription Identifiers in the retransmitted PUBLISH packet. + m.SubscriptionIdentifier = nil + client.pl.markUsedLocked(id) + client.write(gmqtt.MessageToPublish(m.Message, client.version)) + case *queue.Pubrel: + client.write(&packets.Pubrel{PacketID: id}) + } + } + + return true, nil +} + +func (client *client) pollNewMessages(ids []packets.PacketID) (unused []packets.PacketID, err error) { + now := time.Now() + var elems []*queue.Elem + elems, err = client.queueStore.Read(ids) + if err != nil { + return nil, err + } + for _, v := range elems { + switch m := v.MessageWithID.(type) { + case *queue.Publish: + if m.QoS != packets.Qos0 { + ids = ids[1:] + } + if client.version == packets.Version5 && m.Message.MessageExpiry != 0 { + d := uint32(now.Sub(v.At).Seconds()) + m.Message.MessageExpiry = d + } + client.write(gmqtt.MessageToPublish(m.Message, client.version)) + case *queue.Pubrel: + } + } + return ids, err +} +func (client *client) pollMessageHandler() { + var err error + defer func() { + if re := recover(); re != nil { + err = errors.New(fmt.Sprint(re)) + } + client.setError(err) + }() + // drain all inflight messages + cont := true + for cont { + cont, err = client.pollInflights() + if err != nil { + return + } + } + var ids []packets.PacketID + for { + max := uint16(100) + if client.opts.MaxInflight < max { + max = client.opts.MaxInflight + } + ids = client.pl.pollPacketIDs(max) + if ids == nil { + return + } + ids, err = client.pollNewMessages(ids) + if err != nil { + return + } + client.pl.batchRelease(ids) + } +} + +//server goroutine结束的条件:1客户端断开连接 或 2发生错误 +func (client *client) serve() { + defer client.internalClose() + readWg := &sync.WaitGroup{} + + readWg.Add(1) + go func() { //read + client.readLoop() + readWg.Done() + }() + + client.wg.Add(1) + go func() { //write + client.writeLoop() + client.wg.Done() + }() + + if ok := client.connectWithTimeOut(); ok { + client.wg.Add(2) + go func() { + client.pollMessageHandler() + client.wg.Done() + }() + go func() { + client.readHandle() + client.wg.Done() + }() + + } + readWg.Wait() + + if client.queueStore != nil { + qerr := client.queueStore.Close() + if qerr != nil { + zaplog.Error("fail to close message queue", zap.String("client_id", client.opts.ClientID), zap.Error(qerr)) + } + } + if client.pl != nil { + client.pl.close() + } + client.wg.Wait() + _ = client.rwc.Close() +} diff --git a/internal/hummingbird/mqttbroker/server/client_mock.go b/internal/hummingbird/mqttbroker/server/client_mock.go new file mode 100644 index 0000000..dcaad24 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/client_mock.go @@ -0,0 +1,132 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server/client.go + +// Package server is a generated GoMock package. +package server + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + packets "github.com/winc-link/hummingbird/internal/pkg/packets" + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" +) + +// MockClient is a mock of Client interface +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// ClientOptions mocks base method +func (m *MockClient) ClientOptions() *ClientOptions { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClientOptions") + ret0, _ := ret[0].(*ClientOptions) + return ret0 +} + +// ClientOptions indicates an expected call of ClientOptions +func (mr *MockClientMockRecorder) ClientOptions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientOptions", reflect.TypeOf((*MockClient)(nil).ClientOptions)) +} + +// SessionInfo mocks base method +func (m *MockClient) SessionInfo() *gmqtt.Session { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SessionInfo") + ret0, _ := ret[0].(*gmqtt.Session) + return ret0 +} + +// SessionInfo indicates an expected call of SessionInfo +func (mr *MockClientMockRecorder) SessionInfo() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionInfo", reflect.TypeOf((*MockClient)(nil).SessionInfo)) +} + +// Version mocks base method +func (m *MockClient) Version() packets.Version { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Version") + ret0, _ := ret[0].(packets.Version) + return ret0 +} + +// Version indicates an expected call of Version +func (mr *MockClientMockRecorder) Version() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Version", reflect.TypeOf((*MockClient)(nil).Version)) +} + +// ConnectedAt mocks base method +func (m *MockClient) ConnectedAt() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectedAt") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// ConnectedAt indicates an expected call of ConnectedAt +func (mr *MockClientMockRecorder) ConnectedAt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectedAt", reflect.TypeOf((*MockClient)(nil).ConnectedAt)) +} + +// Connection mocks base method +func (m *MockClient) Connection() net.Conn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connection") + ret0, _ := ret[0].(net.Conn) + return ret0 +} + +// Connection indicates an expected call of Connection +func (mr *MockClientMockRecorder) Connection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connection", reflect.TypeOf((*MockClient)(nil).Connection)) +} + +// Close mocks base method +func (m *MockClient) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close +func (mr *MockClientMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClient)(nil).Close)) +} + +// Disconnect mocks base method +func (m *MockClient) Disconnect(disconnect *packets.Disconnect) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Disconnect", disconnect) +} + +// Disconnect indicates an expected call of Disconnect +func (mr *MockClientMockRecorder) Disconnect(disconnect interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockClient)(nil).Disconnect), disconnect) +} diff --git a/internal/hummingbird/mqttbroker/server/hook.go b/internal/hummingbird/mqttbroker/server/hook.go new file mode 100644 index 0000000..078df37 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/hook.go @@ -0,0 +1,338 @@ +package server + +import ( + "context" + "net" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +type Hooks struct { + OnAccept + OnStop + OnSubscribe + OnSubscribed + OnUnsubscribe + OnUnsubscribed + OnMsgArrived + OnBasicAuth + OnEnhancedAuth + OnReAuth + OnConnected + OnSessionCreated + OnSessionResumed + OnSessionTerminated + OnDelivered + OnClosed + OnMsgDropped + OnWillPublish + OnWillPublished +} + +// WillMsgRequest is the input param for OnWillPublish hook. +type WillMsgRequest struct { + // Message is the message that is going to send. + // The caller can edit this field to modify the will message. + // If nil, the broker will drop the message. + Message *mqttbroker.Message + // IterationOptions is the same as MsgArrivedRequest.IterationOptions, + // see MsgArrivedRequest for details + IterationOptions subscription.IterationOptions +} + +// Drop drops the will message, so the message will not be delivered to any alertclient. +func (w *WillMsgRequest) Drop() { + w.Message = nil +} + +// OnWillPublish will be called before the client with the given clientID sending the will message. +// It provides the ability to modify the message before sending. +type OnWillPublish func(ctx context.Context, clientID string, req *WillMsgRequest) + +type OnWillPublishWrapper func(OnWillPublish) OnWillPublish + +// OnWillPublished will be called after the will message has been sent by the client. +// The msg param is immutable, DO NOT EDIT. +type OnWillPublished func(ctx context.Context, clientID string, msg *mqttbroker.Message) + +type OnWillPublishedWrapper func(OnWillPublished) OnWillPublished + +// OnAccept will be called after a new connection established in TCP server. +// If returns false, the connection will be close directly. +type OnAccept func(ctx context.Context, conn net.Conn) bool + +type OnAcceptWrapper func(OnAccept) OnAccept + +// OnStop will be called on server.Stop() +type OnStop func(ctx context.Context) + +type OnStopWrapper func(OnStop) OnStop + +// SubscribeRequest represents the subscribe request made by a SUBSCRIBE packet. +type SubscribeRequest struct { + // Subscribe is the SUBSCRIBE packet. It is immutable, do not edit. + Subscribe *packets.Subscribe + // Subscriptions wraps all subscriptions by the full topic name. + // You can modify the value of the map to edit the subscription. But must not change the length of the map. + Subscriptions map[string]*struct { + // Sub is the subscription. + Sub *mqttbroker.Subscription + // Error indicates whether to allow the subscription. + // Return nil means it is allow to make the subscription. + // Return an error means it is not allow to make the subscription. + // It is recommended to use *codes.Error if you want to disallow the subscription. e.g:&codes.Error{Code:codes.NotAuthorized} + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901178 + Error error + } + // ID is the subscription id, this value will override the id of subscriptions in Subscriptions.Sub. + // This field take no effect on v3 client. + ID uint32 +} + +// GrantQoS grants the qos to the subscription for the given topic name. +func (s *SubscribeRequest) GrantQoS(topicName string, qos packets.QoS) *SubscribeRequest { + if sub := s.Subscriptions[topicName]; sub != nil { + sub.Sub.QoS = qos + } + return s +} + +// Reject rejects the subscription for the given topic name. +func (s *SubscribeRequest) Reject(topicName string, err error) { + if sub := s.Subscriptions[topicName]; sub != nil { + sub.Error = err + } +} + +// SetID sets the subscription id for the subscriptions +func (s *SubscribeRequest) SetID(id uint32) *SubscribeRequest { + s.ID = id + return s +} + +// OnSubscribe will be called when receive a SUBSCRIBE packet. +// It provides the ability to modify and authorize the subscriptions. +// If return an error, the returned error will override the error set in SubscribeRequest. +type OnSubscribe func(ctx context.Context, client Client, req *SubscribeRequest) error + +type OnSubscribeWrapper func(OnSubscribe) OnSubscribe + +// OnSubscribed will be called after the topic subscribe successfully +type OnSubscribed func(ctx context.Context, client Client, subscription *mqttbroker.Subscription) + +type OnSubscribedWrapper func(OnSubscribed) OnSubscribed + +// OnUnsubscribed will be called after the topic has been unsubscribed +type OnUnsubscribed func(ctx context.Context, client Client, topicName string) + +type OnUnsubscribedWrapper func(OnUnsubscribed) OnUnsubscribed + +// UnsubscribeRequest is the input param for OnSubscribed hook. +type UnsubscribeRequest struct { + // Unsubscribe is the UNSUBSCRIBE packet. It is immutable, do not edit. + Unsubscribe *packets.Unsubscribe + // Unsubs groups all unsubscribe topic by the full topic name. + // You can modify the value of the map to edit the unsubscribe topic. But you cannot change the length of the map. + Unsubs map[string]*struct { + // TopicName is the topic that is going to unsubscribe. + TopicName string + // Error indicates whether to allow the unsubscription. + // Return nil means it is allow to unsubscribe the topic. + // Return an error means it is not allow to unsubscribe the topic. + // It is recommended to use *codes.Error if you want to disallow the unsubscription. e.g:&codes.Error{Code:codes.NotAuthorized} + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901194 + Error error + } +} + +// Reject rejects the subscription for the given topic name. +func (u *UnsubscribeRequest) Reject(topicName string, err error) { + if sub := u.Unsubs[topicName]; sub != nil { + sub.Error = err + } +} + +// OnUnsubscribe will be called when receive a UNSUBSCRIBE packet. +// User can use this function to modify and authorize unsubscription. +// If return an error, the returned error will override the error set in UnsubscribeRequest. +type OnUnsubscribe func(ctx context.Context, client Client, req *UnsubscribeRequest) error + +type OnUnsubscribeWrapper func(OnUnsubscribe) OnUnsubscribe + +// OnMsgArrived will be called when receive a Publish packets.It provides the ability to modify the message before topic match process. +// The return error is for V5 client to provide additional information for diagnostics and will be ignored if the version of used client is V3. +// If the returned error type is *codes.Error, the code, reason string and user property will be set into the ack packet(puback for qos1, and pubrel for qos2); +// otherwise, the code,reason string will be set to 0x80 and error.Error(). +type OnMsgArrived func(ctx context.Context, client Client, req *MsgArrivedRequest) error + +// MsgArrivedRequest is the input param for OnMsgArrived hook. +type MsgArrivedRequest struct { + // Publish is the origin MQTT PUBLISH packet, it is immutable. DO NOT EDIT. + Publish *packets.Publish + // Message is the message that is going to be passed to topic match process. + // The caller can modify it. + Message *mqttbroker.Message + // IterationOptions provides the the ability to change the options of topic matching process. + // In most of cases, you don't need to modify it. + // The default value is: + // subscription.IterationOptions{ + // Type: subscription.TypeAll, + // MatchType: subscription.MatchFilter, + // TopicName: msg.Topic, + // } + // The user of this field is the federation plugin. + // It will change the Type from subscription.TypeAll to subscription.subscription.TypeAll ^ subscription.TypeShared + // that will prevent publishing the shared message to local client. + IterationOptions subscription.IterationOptions +} + +// Drop drops the message, so the message will not be delivered to any alertclient. +func (m *MsgArrivedRequest) Drop() { + m.Message = nil +} + +type OnMsgArrivedWrapper func(OnMsgArrived) OnMsgArrived + +// OnClosed will be called after the tcp connection of the client has been closed +type OnClosed func(ctx context.Context, client Client, err error) + +type OnClosedWrapper func(OnClosed) OnClosed + +// AuthOptions provides several options which controls how the server interacts with the client. +// The default value of these options is defined in the configuration file. +type AuthOptions struct { + // SessionExpiry is session expired time in seconds. + SessionExpiry uint32 + // ReceiveMax limits the number of QoS 1 and QoS 2 publications that the server is willing to process concurrently for the client. + // If the client version is v5, this value will be set into Receive Maximum property in CONNACK packet. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901083 + ReceiveMax uint16 + // MaximumQoS is the highest QOS level permitted for a Publish. + MaximumQoS uint8 + // MaxPacketSize is the maximum packet size that the server is willing to accept from the client. + // If the client version is v5, this value will be set into Receive Maximum property in CONNACK packet. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901086 + MaxPacketSize uint32 + // TopicAliasMax indicates the highest value that the server will accept as a Topic Alias sent by the client. + // The server uses this value to limit the number of Topic Aliases that it is willing to hold on this connection. + // This option only affect v5 client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901088 + TopicAliasMax uint16 + // RetainAvailable indicates whether the server supports retained messages. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901085 + RetainAvailable bool + // WildcardSubAvailable indicates whether the server supports Wildcard Subscriptions. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901091 + WildcardSubAvailable bool + // SubIDAvailable indicates whether the server supports Subscription Identifiers. + // This option only affect v5 client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901092 + SubIDAvailable bool + // SharedSubAvailable indicates whether the server supports Shared Subscriptions. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901093 + SharedSubAvailable bool + // KeepAlive is the keep alive time assigned by the server. + // This option only affect v5 client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901094 + KeepAlive uint16 + // UserProperties is be used to provide additional information to the client. + // This option only affect v5 client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901090 + UserProperties []*packets.UserProperty + // AssignedClientID allows the server to assign a client id for the client. + // It will override the client id in the connect packet. + AssignedClientID []byte + // ResponseInfo is used as the basis for creating a Response Topic. + // This option only affect v5 client. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901095 + ResponseInfo []byte + // MaxInflight limits the number of QoS 1 and QoS 2 publications that the client is willing to process concurrently. + MaxInflight uint16 +} + +// OnBasicAuth will be called when receive v311 connect packet or v5 connect packet with empty auth method property. +type OnBasicAuth func(ctx context.Context, client Client, req *ConnectRequest) (err error) + +// ConnectRequest represents a connect request made by a CONNECT packet. +type ConnectRequest struct { + // Connect is the CONNECT packet.It is immutable, do not edit. + Connect *packets.Connect + // Options represents the setting which will be applied to the current client if auth success. + // Caller can edit this property to change the setting. + Options *AuthOptions +} + +type OnBasicAuthWrapper func(OnBasicAuth) OnBasicAuth + +// OnEnhancedAuth will be called when receive v5 connect packet with auth method property. +type OnEnhancedAuth func(ctx context.Context, client Client, req *ConnectRequest) (resp *EnhancedAuthResponse, err error) + +type EnhancedAuthResponse struct { + Continue bool + OnAuth OnAuth + AuthData []byte +} +type OnEnhancedAuthWrapper func(OnEnhancedAuth) OnEnhancedAuth + +type AuthRequest struct { + Auth *packets.Auth + Options *AuthOptions +} + +// ReAuthResponse is the response of the OnAuth hook. +type AuthResponse struct { + // Continue indicate that whether more authentication data is needed. + Continue bool + // AuthData is the auth data property of the auth packet. + AuthData []byte +} + +type OnAuth func(ctx context.Context, client Client, req *AuthRequest) (*AuthResponse, error) + +type OnReAuthWrapper func(OnReAuth) OnReAuth + +type OnReAuth func(ctx context.Context, client Client, auth *packets.Auth) (*AuthResponse, error) + +type OnAuthWrapper func(OnAuth) OnAuth + +// OnConnected will be called when a mqttclient client connect successfully. +type OnConnected func(ctx context.Context, client Client) + +type OnConnectedWrapper func(OnConnected) OnConnected + +// OnSessionCreated will be called when new session created. +type OnSessionCreated func(ctx context.Context, client Client) + +type OnSessionCreatedWrapper func(OnSessionCreated) OnSessionCreated + +// OnSessionResumed will be called when session resumed. +type OnSessionResumed func(ctx context.Context, client Client) + +type OnSessionResumedWrapper func(OnSessionResumed) OnSessionResumed + +type SessionTerminatedReason byte + +const ( + NormalTermination SessionTerminatedReason = iota + TakenOverTermination + ExpiredTermination +) + +// OnSessionTerminated will be called when session has been terminated. +type OnSessionTerminated func(ctx context.Context, clientID string, reason SessionTerminatedReason) + +type OnSessionTerminatedWrapper func(OnSessionTerminated) OnSessionTerminated + +// OnDelivered will be called when publishing a message to a client. +type OnDelivered func(ctx context.Context, client Client, msg *mqttbroker.Message) + +type OnDeliveredWrapper func(OnDelivered) OnDelivered + +// OnMsgDropped will be called after the Msg dropped. +// The err indicates the reason of dropping. +// See: persistence/queue/error.go +type OnMsgDropped func(ctx context.Context, clientID string, msg *mqttbroker.Message, err error) + +type OnMsgDroppedWrapper func(OnMsgDropped) OnMsgDropped diff --git a/internal/hummingbird/mqttbroker/server/limiter.go b/internal/hummingbird/mqttbroker/server/limiter.go new file mode 100644 index 0000000..2a848f6 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/limiter.go @@ -0,0 +1,116 @@ +package server + +import ( + "sync" + + "github.com/winc-link/hummingbird/internal/pkg/bitmap" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +func newPacketIDLimiter(limit uint16) *packetIDLimiter { + return &packetIDLimiter{ + cond: sync.NewCond(&sync.Mutex{}), + used: 0, + limit: limit, + exit: false, + freePid: 1, + lockedPid: bitmap.New(packets.MaxPacketID), + } +} + +// packetIDLimiter limit the generation of packet id to keep the number of inflight messages +// always less or equal than receive maximum setting of the client. +type packetIDLimiter struct { + cond *sync.Cond + used uint16 + limit uint16 + exit bool + lockedPid *bitmap.Bitmap // packet id in-use + freePid packets.PacketID // next available id +} + +func (p *packetIDLimiter) close() { + p.cond.L.Lock() + p.exit = true + p.cond.L.Unlock() + p.cond.Signal() +} + +// pollPacketIDs returns at most max number of unused packetID and marks them as used for a client. +// If there is no available id, the call will be blocked until at least one packet id is available or the limiter has been closed. +// return 0 means the limiter is closed. +// the return number = min(max, i.used). +func (p *packetIDLimiter) pollPacketIDs(max uint16) (id []packets.PacketID) { + p.cond.L.Lock() + defer p.cond.L.Unlock() + for p.used >= p.limit && !p.exit { + p.cond.Wait() + } + if p.exit { + return nil + } + n := max + if remain := p.limit - p.used; remain < max { + n = remain + } + for j := uint16(0); j < n; j++ { + for p.lockedPid.Get(p.freePid) == 1 { + if p.freePid == packets.MaxPacketID { + p.freePid = packets.MinPacketID + } else { + p.freePid++ + } + } + id = append(id, p.freePid) + p.used++ + p.lockedPid.Set(p.freePid, 1) + if p.freePid == packets.MaxPacketID { + p.freePid = packets.MinPacketID + } else { + p.freePid++ + } + } + return id +} + +// release marks the given id list as unused +func (p *packetIDLimiter) release(id packets.PacketID) { + p.cond.L.Lock() + p.releaseLocked(id) + p.cond.L.Unlock() + p.cond.Signal() + +} +func (p *packetIDLimiter) releaseLocked(id packets.PacketID) { + if p.lockedPid.Get(id) == 1 { + p.lockedPid.Set(id, 0) + p.used-- + } +} + +func (p *packetIDLimiter) batchRelease(id []packets.PacketID) { + p.cond.L.Lock() + for _, v := range id { + p.releaseLocked(v) + } + p.cond.L.Unlock() + p.cond.Signal() + +} + +// markInUsed marks the given id as used. +func (p *packetIDLimiter) markUsedLocked(id packets.PacketID) { + p.used++ + p.lockedPid.Set(id, 1) +} + +func (p *packetIDLimiter) lock() { + p.cond.L.Lock() +} +func (p *packetIDLimiter) unlock() { + p.cond.L.Unlock() +} +func (p *packetIDLimiter) unlockAndSignal() { + p.cond.L.Unlock() + p.cond.Signal() +} diff --git a/internal/hummingbird/mqttbroker/server/options.go b/internal/hummingbird/mqttbroker/server/options.go new file mode 100644 index 0000000..1b4d318 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/options.go @@ -0,0 +1,50 @@ +package server + +import ( + "net" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" +) + +type Options func(srv *server) + +// WithConfig set the config of the server +func WithConfig(config config.Config) Options { + return func(srv *server) { + srv.config = config + } +} + +// WithTCPListener set tcp listener(s) of the server. Default listen on :58090. +func WithTCPListener(lns ...net.Listener) Options { + return func(srv *server) { + srv.tcpListener = append(srv.tcpListener, lns...) + } +} + +// WithWebsocketServer set websocket server(s) of the server. +func WithWebsocketServer(ws ...*WsServer) Options { + return func(srv *server) { + srv.websocketServer = ws + } +} + +// WithPlugin set plugin(s) of the server. +func WithPlugin(plugin ...Plugin) Options { + return func(srv *server) { + srv.plugins = append(srv.plugins, plugin...) + } +} + +// WithHook set hooks of the server. Notice: WithPlugin() will overwrite hooks. +func WithHook(hooks Hooks) Options { + return func(srv *server) { + srv.hooks = hooks + } +} + +func WithLogger(logger *DefaultLogger) Options { + return func(srv *server) { + zaplog = logger + } +} diff --git a/internal/hummingbird/mqttbroker/server/persistence.go b/internal/hummingbird/mqttbroker/server/persistence.go new file mode 100644 index 0000000..333b370 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/persistence.go @@ -0,0 +1,20 @@ +package server + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" +) + +type NewPersistence func(config config.Config) (Persistence, error) + +type Persistence interface { + Open() error + NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) + NewSubscriptionStore(config config.Config) (subscription.Store, error) + NewSessionStore(config config.Config) (session.Store, error) + NewUnackStore(config config.Config, clientID string) (unack.Store, error) + Close() error +} diff --git a/internal/hummingbird/mqttbroker/server/persistence_mock.go b/internal/hummingbird/mqttbroker/server/persistence_mock.go new file mode 100644 index 0000000..acf1da0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/persistence_mock.go @@ -0,0 +1,127 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server/persistence.go + +// Package server is a generated GoMock package. +package server + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + config "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + queue "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + session "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" + subscription "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + unack "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" +) + +// MockPersistence is a mock of Persistence interface +type MockPersistence struct { + ctrl *gomock.Controller + recorder *MockPersistenceMockRecorder +} + +// MockPersistenceMockRecorder is the mock recorder for MockPersistence +type MockPersistenceMockRecorder struct { + mock *MockPersistence +} + +// NewMockPersistence creates a new mock instance +func NewMockPersistence(ctrl *gomock.Controller) *MockPersistence { + mock := &MockPersistence{ctrl: ctrl} + mock.recorder = &MockPersistenceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPersistence) EXPECT() *MockPersistenceMockRecorder { + return m.recorder +} + +// Open mocks base method +func (m *MockPersistence) Open() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open") + ret0, _ := ret[0].(error) + return ret0 +} + +// Open indicates an expected call of Open +func (mr *MockPersistenceMockRecorder) Open() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockPersistence)(nil).Open)) +} + +// NewQueueStore mocks base method +func (m *MockPersistence) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewQueueStore", config, defaultNotifier, clientID) + ret0, _ := ret[0].(queue.Store) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewQueueStore indicates an expected call of NewQueueStore +func (mr *MockPersistenceMockRecorder) NewQueueStore(config, defaultNotifier, clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewQueueStore", reflect.TypeOf((*MockPersistence)(nil).NewQueueStore), config, defaultNotifier, clientID) +} + +// NewSubscriptionStore mocks base method +func (m *MockPersistence) NewSubscriptionStore(config config.Config) (subscription.Store, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSubscriptionStore", config) + ret0, _ := ret[0].(subscription.Store) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewSubscriptionStore indicates an expected call of NewSubscriptionStore +func (mr *MockPersistenceMockRecorder) NewSubscriptionStore(config interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSubscriptionStore", reflect.TypeOf((*MockPersistence)(nil).NewSubscriptionStore), config) +} + +// NewSessionStore mocks base method +func (m *MockPersistence) NewSessionStore(config config.Config) (session.Store, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSessionStore", config) + ret0, _ := ret[0].(session.Store) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewSessionStore indicates an expected call of NewSessionStore +func (mr *MockPersistenceMockRecorder) NewSessionStore(config interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSessionStore", reflect.TypeOf((*MockPersistence)(nil).NewSessionStore), config) +} + +// NewUnackStore mocks base method +func (m *MockPersistence) NewUnackStore(config config.Config, clientID string) (unack.Store, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewUnackStore", config, clientID) + ret0, _ := ret[0].(unack.Store) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewUnackStore indicates an expected call of NewUnackStore +func (mr *MockPersistenceMockRecorder) NewUnackStore(config, clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewUnackStore", reflect.TypeOf((*MockPersistence)(nil).NewUnackStore), config, clientID) +} + +// Close mocks base method +func (m *MockPersistence) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockPersistenceMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPersistence)(nil).Close)) +} diff --git a/internal/hummingbird/mqttbroker/server/plugin.go b/internal/hummingbird/mqttbroker/server/plugin.go new file mode 100644 index 0000000..1fdd962 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/plugin.go @@ -0,0 +1,44 @@ +package server + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" +) + +// HookWrapper groups all hook wrappers function +type HookWrapper struct { + OnBasicAuthWrapper OnBasicAuthWrapper + OnEnhancedAuthWrapper OnEnhancedAuthWrapper + OnConnectedWrapper OnConnectedWrapper + OnReAuthWrapper OnReAuthWrapper + OnSessionCreatedWrapper OnSessionCreatedWrapper + OnSessionResumedWrapper OnSessionResumedWrapper + OnSessionTerminatedWrapper OnSessionTerminatedWrapper + OnSubscribeWrapper OnSubscribeWrapper + OnSubscribedWrapper OnSubscribedWrapper + OnUnsubscribeWrapper OnUnsubscribeWrapper + OnUnsubscribedWrapper OnUnsubscribedWrapper + OnMsgArrivedWrapper OnMsgArrivedWrapper + OnMsgDroppedWrapper OnMsgDroppedWrapper + OnDeliveredWrapper OnDeliveredWrapper + OnClosedWrapper OnClosedWrapper + OnAcceptWrapper OnAcceptWrapper + OnStopWrapper OnStopWrapper + OnWillPublishWrapper OnWillPublishWrapper + OnWillPublishedWrapper OnWillPublishedWrapper +} + +// NewPlugin is the constructor of a plugin. +type NewPlugin func(config config.Config) (Plugin, error) + +// Plugin is the interface need to be implemented for every plugins. +type Plugin interface { + // Load will be called in server.Run(). If return error, the server will panic. + Load(service Server) error + // Unload will be called when the server is shutdown, the return error is only for logging + Unload() error + // HookWrapper returns all hook wrappers that used by the plugin. + // Return a empty wrapper if the plugin does not need any hooks + HookWrapper() HookWrapper + // Name return the plugin name + Name() string +} diff --git a/internal/hummingbird/mqttbroker/server/plugin_mock.go b/internal/hummingbird/mqttbroker/server/plugin_mock.go new file mode 100644 index 0000000..a829d29 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/plugin_mock.go @@ -0,0 +1,90 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server/plugin.go + +// Package server is a generated GoMock package. +package server + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockPlugin is a mock of Plugin interface +type MockPlugin struct { + ctrl *gomock.Controller + recorder *MockPluginMockRecorder +} + +// MockPluginMockRecorder is the mock recorder for MockPlugin +type MockPluginMockRecorder struct { + mock *MockPlugin +} + +// NewMockPlugin creates a new mock instance +func NewMockPlugin(ctrl *gomock.Controller) *MockPlugin { + mock := &MockPlugin{ctrl: ctrl} + mock.recorder = &MockPluginMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPlugin) EXPECT() *MockPluginMockRecorder { + return m.recorder +} + +// Load mocks base method +func (m *MockPlugin) Load(service Server) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Load", service) + ret0, _ := ret[0].(error) + return ret0 +} + +// Load indicates an expected call of Load +func (mr *MockPluginMockRecorder) Load(service interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockPlugin)(nil).Load), service) +} + +// Unload mocks base method +func (m *MockPlugin) Unload() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unload") + ret0, _ := ret[0].(error) + return ret0 +} + +// Unload indicates an expected call of Unload +func (mr *MockPluginMockRecorder) Unload() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unload", reflect.TypeOf((*MockPlugin)(nil).Unload)) +} + +// HookWrapper mocks base method +func (m *MockPlugin) HookWrapper() HookWrapper { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HookWrapper") + ret0, _ := ret[0].(HookWrapper) + return ret0 +} + +// HookWrapper indicates an expected call of HookWrapper +func (mr *MockPluginMockRecorder) HookWrapper() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HookWrapper", reflect.TypeOf((*MockPlugin)(nil).HookWrapper)) +} + +// Name mocks base method +func (m *MockPlugin) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name +func (mr *MockPluginMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockPlugin)(nil).Name)) +} diff --git a/internal/hummingbird/mqttbroker/server/publish_service.go b/internal/hummingbird/mqttbroker/server/publish_service.go new file mode 100644 index 0000000..255a990 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/publish_service.go @@ -0,0 +1,13 @@ +package server + +import "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + +type publishService struct { + server *server +} + +func (p *publishService) Publish(message *mqttbroker.Message) { + p.server.mu.Lock() + p.server.deliverMessage("", message, defaultIterateOptions(message.Topic)) + p.server.mu.Unlock() +} diff --git a/internal/hummingbird/mqttbroker/server/queue_notifier.go b/internal/hummingbird/mqttbroker/server/queue_notifier.go new file mode 100644 index 0000000..b82e8d2 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/queue_notifier.go @@ -0,0 +1,69 @@ +package server + +import ( + "context" + + "go.uber.org/zap" + + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" +) + +// queueNotifier implements queue.Notifier interface. +type queueNotifier struct { + dropHook OnMsgDropped + sts *statsManager + cli *client +} + +// defaultNotifier is used to init the notifier when using a persistent session store (e.g redis) which can load session data +// while bootstrapping. +func defaultNotifier(dropHook OnMsgDropped, sts *statsManager, clientID string) *queueNotifier { + return &queueNotifier{ + dropHook: dropHook, + sts: sts, + cli: &client{opts: &ClientOptions{ClientID: clientID}, status: Connected + 1}, + } +} + +func (q *queueNotifier) notifyDropped(msg *gmqtt.Message, err error) { + cid := q.cli.opts.ClientID + zaplog.Warn("message dropped", zap.String("client_id", cid), zap.Error(err)) + q.sts.messageDropped(msg.QoS, q.cli.opts.ClientID, err) + if q.dropHook != nil { + q.dropHook(context.Background(), cid, msg, err) + } +} + +func (q *queueNotifier) NotifyDropped(elem *queue.Elem, err error) { + cid := q.cli.opts.ClientID + if err == queue.ErrDropExpiredInflight && q.cli.IsConnected() { + q.cli.pl.release(elem.ID()) + } + if pub, ok := elem.MessageWithID.(*queue.Publish); ok { + q.notifyDropped(pub.Message, err) + } else { + zaplog.Warn("message dropped", zap.String("client_id", cid), zap.Error(err)) + } +} + +func (q *queueNotifier) NotifyInflightAdded(delta int) { + cid := q.cli.opts.ClientID + if delta > 0 { + q.sts.addInflight(cid, uint64(delta)) + } + if delta < 0 { + q.sts.decInflight(cid, uint64(-delta)) + } + +} + +func (q *queueNotifier) NotifyMsgQueueAdded(delta int) { + cid := q.cli.opts.ClientID + if delta > 0 { + q.sts.addQueueLen(cid, uint64(delta)) + } + if delta < 0 { + q.sts.decQueueLen(cid, uint64(-delta)) + } +} diff --git a/internal/hummingbird/mqttbroker/server/server.go b/internal/hummingbird/mqttbroker/server/server.go new file mode 100644 index 0000000..f77fd96 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/server.go @@ -0,0 +1,1509 @@ +package server + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/unack" + retained_trie "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/retained/trie" + "github.com/winc-link/hummingbird/internal/pkg/codes" + "github.com/winc-link/hummingbird/internal/pkg/constants" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/retained" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var ( + // ErrInvalWsMsgType [MQTT-6.0.0-1] + ErrInvalWsMsgType = errors.New("invalid websocket message type") + statusPanic = "invalid server status" + plugins = make(map[string]NewPlugin) + topicAliasMgrFactory = make(map[string]NewTopicAliasManager) + persistenceFactories = make(map[string]NewPersistence) +) + +func defaultIterateOptions(topicName string) subscription.IterationOptions { + return subscription.IterationOptions{ + Type: subscription.TypeAll, + TopicName: topicName, + MatchType: subscription.MatchFilter, + } +} + +func RegisterPersistenceFactory(name string, new NewPersistence) { + if _, ok := persistenceFactories[name]; ok { + panic("duplicated persistence factory: " + name) + } + persistenceFactories[name] = new +} + +func RegisterTopicAliasMgrFactory(name string, new NewTopicAliasManager) { + if _, ok := topicAliasMgrFactory[name]; ok { + panic("duplicated topic alias manager factory: " + name) + } + topicAliasMgrFactory[name] = new +} + +func RegisterPlugin(name string, new NewPlugin) { + if _, ok := plugins[name]; ok { + panic("duplicated plugin: " + name) + } + plugins[name] = new +} + +// Server status +const ( + serverStatusInit = iota + serverStatusStarted +) + +var LogLevelMap = map[string]zapcore.Level{ + "DEBUG": zapcore.DebugLevel, + "INFO": zapcore.InfoLevel, + "WARN": zapcore.WarnLevel, + "ERROR": zapcore.ErrorLevel, +} + +type DefaultLogger struct { + Level *zap.AtomicLevel + *zap.Logger +} + +func (df *DefaultLogger) SetLogLevel(level constants.LogLevel) (string, error) { + lStr := constants.LogMap[level] + if len(lStr) <= 0 { + return "", fmt.Errorf("recv invilad log level: %d", level) + } + zl := LogLevelMap[lStr] + df.Level.SetLevel(zl) + return lStr, nil +} + +var zaplog *DefaultLogger + +func init() { + zaplog = &DefaultLogger{ + Logger: zap.NewNop(), + } +} + +// LoggerWithField release fields to a new logger. +// Plugins can use this method to release plugin name field. +func LoggerWithField(fields ...zap.Field) *zap.Logger { + return zaplog.With(fields...) +} + +// Server interface represents a mqttclient server instance. +type Server interface { + // Publisher returns the Publisher + Publisher() Publisher + // GetConfig returns the config of the server + GetConfig() config.Config + // StatsManager returns StatsReader + StatsManager() StatsReader + // Stop stop the server gracefully + Stop(ctx context.Context) error + // ApplyConfig will replace the config of the server + ApplyConfig(config config.Config) + + ClientService() ClientService + + SubscriptionService() SubscriptionService + + RetainedService() RetainedService + // Plugins returns all enabled plugins + Plugins() []Plugin + APIRegistrar() APIRegistrar +} + +type clientService struct { + srv *server + sessionStore session.Store +} + +func (c *clientService) IterateSession(fn session.IterateFn) error { + return c.sessionStore.Iterate(fn) +} + +func (c *clientService) IterateClient(fn ClientIterateFn) { + c.srv.mu.Lock() + defer c.srv.mu.Unlock() + + for _, v := range c.srv.clients { + if !fn(v) { + return + } + } +} + +func (c *clientService) GetClient(clientID string) Client { + c.srv.mu.Lock() + defer c.srv.mu.Unlock() + if c, ok := c.srv.clients[clientID]; ok { + return c + } + return nil +} + +func (c *clientService) GetSession(clientID string) (*mqttbroker.Session, error) { + return c.sessionStore.Get(clientID) +} + +func (c *clientService) TerminateSession(clientID string) { + c.srv.mu.Lock() + defer c.srv.mu.Unlock() + if cli, ok := c.srv.clients[clientID]; ok { + atomic.StoreInt32(&cli.forceRemoveSession, 1) + cli.Close() + return + } + if _, ok := c.srv.offlineClients[clientID]; ok { + err := c.srv.sessionTerminatedLocked(clientID, NormalTermination) + if err != nil { + err = fmt.Errorf("session terminated fail: %s", err.Error()) + zaplog.Error("session terminated fail", zap.Error(err)) + } + } + +} + +// server represents a mqttclient server instance. +// Create a server by using New() +type server struct { + wg sync.WaitGroup + initOnce sync.Once + stopOnce sync.Once + mu sync.RWMutex //gard clients & offlineClients map + status int32 //server status + // clients stores the online clients + clients map[string]*client + // offlineClients store the expired time of all disconnected clients + // with valid session(not expired). Key by clientID + offlineClients map[string]time.Time + willMessage map[string]*willMsg + tcpListener []net.Listener //tcp listeners + websocketServer []*WsServer //websocket serverStop + errOnce sync.Once + err error + exitChan chan struct{} + exitedChan chan struct{} + + retainedDB retained.Store + subscriptionsDB subscription.Store //store subscriptions + + persistence Persistence + queueStore map[string]queue.Store + unackStore map[string]unack.Store + sessionStore session.Store + + // guards config + configMu sync.RWMutex + config config.Config + hooks Hooks + plugins []Plugin + statsManager *statsManager + publishService Publisher + newTopicAliasManager NewTopicAliasManager + + clientService *clientService + apiRegistrar *apiRegistrar +} + +func (srv *server) APIRegistrar() APIRegistrar { + return srv.apiRegistrar +} + +func (srv *server) Plugins() []Plugin { + srv.mu.Lock() + defer srv.mu.Unlock() + p := make([]Plugin, len(srv.plugins)) + copy(p, srv.plugins) + return p + +} + +func (srv *server) RetainedService() RetainedService { + return srv.retainedDB +} + +func (srv *server) ClientService() ClientService { + return srv.clientService +} + +func (srv *server) ApplyConfig(config config.Config) { + srv.configMu.Lock() + defer srv.configMu.Unlock() + srv.config = config + +} + +func (srv *server) SubscriptionService() SubscriptionService { + return srv.subscriptionsDB +} + +func (srv *server) RetainedStore() retained.Store { + return srv.retainedDB +} + +func (srv *server) Publisher() Publisher { + return srv.publishService +} + +func (srv *server) checkStatus() { + if srv.Status() != serverStatusInit { + panic(statusPanic) + } +} + +type DeliveryMode = string + +const ( + Overlap DeliveryMode = "overlap" + OnlyOnce DeliveryMode = "onlyonce" +) + +// GetConfig returns the config of the server +func (srv *server) GetConfig() config.Config { + srv.configMu.Lock() + defer srv.configMu.Unlock() + return srv.config +} + +// StatsManager returns StatsReader +func (srv *server) StatsManager() StatsReader { + return srv.statsManager +} + +// Status returns the server status +func (srv *server) Status() int32 { + return atomic.LoadInt32(&srv.status) +} + +func (srv *server) sessionTerminatedLocked(clientID string, reason SessionTerminatedReason) (err error) { + err = srv.removeSessionLocked(clientID) + if srv.hooks.OnSessionTerminated != nil { + srv.hooks.OnSessionTerminated(context.Background(), clientID, reason) + } + srv.statsManager.sessionTerminated(clientID, reason) + return err +} + +func uint16P(v uint16) *uint16 { + return &v +} + +func setWillProperties(willPpt *packets.Properties, msg *mqttbroker.Message) { + if willPpt != nil { + if willPpt.PayloadFormat != nil { + msg.PayloadFormat = *willPpt.PayloadFormat + } + if willPpt.MessageExpiry != nil { + msg.MessageExpiry = *willPpt.MessageExpiry + } + if willPpt.ContentType != nil { + msg.ContentType = string(willPpt.ContentType) + } + if willPpt.ResponseTopic != nil { + msg.ResponseTopic = string(willPpt.ResponseTopic) + } + if willPpt.CorrelationData != nil { + msg.CorrelationData = willPpt.CorrelationData + } + msg.UserProperties = willPpt.User + } +} + +func (srv *server) lockDuplicatedID(c *client) (oldSession *mqttbroker.Session, err error) { + for { + srv.mu.Lock() + oldSession, err = srv.sessionStore.Get(c.opts.ClientID) + if err != nil { + srv.mu.Unlock() + zaplog.Error("fail to get session", + zap.String("remote_addr", c.rwc.RemoteAddr().String()), + zap.String("client_id", c.opts.ClientID)) + return + } + if oldSession != nil { + var oldClient *client + oldClient = srv.clients[oldSession.ClientID] + srv.mu.Unlock() + if oldClient == nil { + srv.mu.Lock() + break + } + // if there is a duplicated online client, close if first. + zaplog.Info("logging with duplicate ClientID", + zap.String("remote", c.rwc.RemoteAddr().String()), + zap.String("client_id", oldSession.ClientID), + ) + oldClient.setError(codes.NewError(codes.SessionTakenOver)) + oldClient.Close() + <-oldClient.closed + continue + } + break + } + return +} + +// 已经判断是成功了,注册 +func (srv *server) registerClient(connect *packets.Connect, client *client) (sessionResume bool, err error) { + var qs queue.Store + var ua unack.Store + var sess *mqttbroker.Session + var oldSession *mqttbroker.Session + now := time.Now() + oldSession, err = srv.lockDuplicatedID(client) + if err != nil { + return + } + defer func() { + if err == nil { + var willMsg *mqttbroker.Message + var willDelayInterval, expiryInterval uint32 + if connect.WillFlag { + willMsg = &mqttbroker.Message{ + QoS: connect.WillQos, + Topic: string(connect.WillTopic), + Payload: connect.WillMsg, + } + setWillProperties(connect.WillProperties, willMsg) + } + // use default expiry if the client version is version3.1.1 + if packets.IsVersion3X(client.version) && !connect.CleanStart { + expiryInterval = uint32(srv.config.MQTT.SessionExpiry.Seconds()) + } else if connect.Properties != nil { + willDelayInterval = convertUint32(connect.WillProperties.WillDelayInterval, 0) + expiryInterval = client.opts.SessionExpiry + } + sess = &mqttbroker.Session{ + ClientID: client.opts.ClientID, + Will: willMsg, + ConnectedAt: time.Now(), + WillDelayInterval: willDelayInterval, + ExpiryInterval: expiryInterval, + } + err = srv.sessionStore.Set(sess) + } + + if err == nil { + client.session = sess + if sessionResume { + // If a new Network Connection to this Session is made before the Will Delay Interval has passed, + // the Server MUST NOT send the Will Message [MQTT-3.1.3-9]. + if w, ok := srv.willMessage[client.opts.ClientID]; ok { + w.signal(false) + } + if srv.hooks.OnSessionResumed != nil { + srv.hooks.OnSessionResumed(context.Background(), client) + } + srv.statsManager.sessionActive(false) + } else { + if srv.hooks.OnSessionCreated != nil { + srv.hooks.OnSessionCreated(context.Background(), client) + } + srv.statsManager.sessionActive(true) + } + srv.clients[client.opts.ClientID] = client + srv.unackStore[client.opts.ClientID] = ua + srv.queueStore[client.opts.ClientID] = qs + client.queueStore = qs + client.unackStore = ua + if client.version == packets.Version5 { + client.topicAliasManager = srv.newTopicAliasManager(client.config, client.opts.ClientTopicAliasMax, client.opts.ClientID) + } + } + srv.mu.Unlock() + + if err == nil { + if srv.hooks.OnConnected != nil { + srv.hooks.OnConnected(context.Background(), client) + } + } + }() + + client.setConnected(time.Now()) + srv.statsManager.clientConnected(client.opts.ClientID) + + if oldSession != nil { + if !oldSession.IsExpired(now) && !connect.CleanStart { + sessionResume = true + } + // clean old session + if !sessionResume { + err = srv.sessionTerminatedLocked(oldSession.ClientID, TakenOverTermination) + if err != nil { + err = fmt.Errorf("session terminated fail: %w", err) + zaplog.Error("session terminated fail", zap.Error(err)) + } + // Send will message because the previous session is ended. + if w, ok := srv.willMessage[client.opts.ClientID]; ok { + w.signal(true) + } + } else { + qs = srv.queueStore[client.opts.ClientID] + if qs != nil { + err = qs.Init(&queue.InitOptions{ + CleanStart: false, + Version: client.version, + ReadBytesLimit: client.opts.ClientMaxPacketSize, + Notifier: client.queueNotifier, + }) + if err != nil { + return + } + } + ua = srv.unackStore[client.opts.ClientID] + if ua != nil { + err = ua.Init(false) + if err != nil { + return + } + } + if ua == nil || qs == nil { + // This could happen if backend store loss some data which will bring the session into "inconsistent state". + // We should create a new session and prevent the client reuse the inconsistent one. + sessionResume = false + zaplog.Error("detect inconsistent session state", + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID)) + } else { + zaplog.Info("logged in with session reuse", + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID)) + } + + } + } + if !sessionResume { + // create new session + // It is ok to pass nil to defaultNotifier, because we will call Init to override it. + qs, err = srv.persistence.NewQueueStore(srv.config, nil, client.opts.ClientID) + if err != nil { + return + } + err = qs.Init(&queue.InitOptions{ + CleanStart: true, + Version: client.version, + ReadBytesLimit: client.opts.ClientMaxPacketSize, + Notifier: client.queueNotifier, + }) + if err != nil { + return + } + + ua, err = srv.persistence.NewUnackStore(srv.config, client.opts.ClientID) + if err != nil { + return + } + err = ua.Init(true) + if err != nil { + return + } + zaplog.Info("logged in with new session", + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID), + ) + } + delete(srv.offlineClients, client.opts.ClientID) + return +} + +type willMsg struct { + msg *mqttbroker.Message + // If true, send the msg. + // If false, discard the msg. + send chan bool +} + +func (w *willMsg) signal(send bool) { + select { + case w.send <- send: + default: + } +} + +// sendWillLocked sends the will message for the client, this function must be guard by srv.Lock. +func (srv *server) sendWillLocked(msg *mqttbroker.Message, clientID string) { + req := &WillMsgRequest{ + Message: msg, + } + if srv.hooks.OnWillPublish != nil { + srv.hooks.OnWillPublish(context.Background(), clientID, req) + } + // the will message is dropped + if req.Message == nil { + return + } + srv.deliverMessage(clientID, msg, defaultIterateOptions(msg.Topic)) + if srv.hooks.OnWillPublished != nil { + srv.hooks.OnWillPublished(context.Background(), clientID, req.Message) + } +} + +func (srv *server) unregisterClient(client *client) { + srv.mu.Lock() + defer srv.mu.Unlock() + now := time.Now() + var storeSession bool + if sess, err := srv.sessionStore.Get(client.opts.ClientID); sess != nil { + forceRemove := atomic.LoadInt32(&client.forceRemoveSession) + if forceRemove != 1 { + if client.version == packets.Version5 && client.disconnect != nil { + sess.ExpiryInterval = convertUint32(client.disconnect.Properties.SessionExpiryInterval, sess.ExpiryInterval) + } + if sess.ExpiryInterval != 0 { + storeSession = true + } + } + // need to send will message + if !client.cleanWillFlag && sess.Will != nil { + willDelayInterval := sess.WillDelayInterval + if sess.ExpiryInterval <= sess.WillDelayInterval { + willDelayInterval = sess.ExpiryInterval + } + msg := sess.Will.Copy() + if willDelayInterval != 0 && storeSession { + wm := &willMsg{ + msg: msg, + send: make(chan bool, 1), + } + srv.willMessage[client.opts.ClientID] = wm + t := time.NewTimer(time.Duration(willDelayInterval) * time.Second) + go func(clientID string) { + var send bool + select { + case send = <-wm.send: + t.Stop() + case <-t.C: + send = true + } + srv.mu.Lock() + defer srv.mu.Unlock() + delete(srv.willMessage, clientID) + if !send { + return + } + srv.sendWillLocked(msg, clientID) + }(client.opts.ClientID) + } else { + srv.sendWillLocked(msg, client.opts.ClientID) + } + } + if storeSession { + expiredTime := now.Add(time.Duration(sess.ExpiryInterval) * time.Second) + srv.offlineClients[client.opts.ClientID] = expiredTime + delete(srv.clients, client.opts.ClientID) + zaplog.Info("logged out and storing session", + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID), + zap.Time("expired_at", expiredTime), + ) + return + } + } else { + zaplog.Error("fail to get session", + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID), + zap.Error(err)) + } + zaplog.Info("logged out and cleaning session", + zap.String("remote_addr", client.rwc.RemoteAddr().String()), + zap.String("client_id", client.opts.ClientID), + ) + _ = srv.sessionTerminatedLocked(client.opts.ClientID, NormalTermination) +} + +func (srv *server) addMsgToQueueLocked(now time.Time, clientID string, msg *mqttbroker.Message, sub *mqttbroker.Subscription, ids []uint32, q queue.Store) { + mqttCfg := srv.config.MQTT + if !mqttCfg.QueueQos0Msg { + // If the client with the clientID is not connected, skip qos0 messages. + if c := srv.clients[clientID]; c == nil && msg.QoS == packets.Qos0 { + return + } + } + if msg.QoS > sub.QoS { + msg.QoS = sub.QoS + } + for _, id := range ids { + if id != 0 { + msg.SubscriptionIdentifier = append(msg.SubscriptionIdentifier, id) + } + } + msg.Dup = false + if !sub.RetainAsPublished { + msg.Retained = false + } + var expiry time.Time + if mqttCfg.MessageExpiry != 0 { + if msg.MessageExpiry != 0 && int(msg.MessageExpiry) <= int(mqttCfg.MessageExpiry) { + expiry = now.Add(time.Duration(msg.MessageExpiry) * time.Second) + } else { + expiry = now.Add(mqttCfg.MessageExpiry) + } + } else if msg.MessageExpiry != 0 { + expiry = now.Add(time.Duration(msg.MessageExpiry) * time.Second) + } + err := q.Add(&queue.Elem{ + At: now, + Expiry: expiry, + MessageWithID: &queue.Publish{ + Message: msg, + }, + }) + if err != nil { + srv.clients[clientID].queueNotifier.notifyDropped(msg, &queue.InternalError{Err: err}) + return + } +} + +// sharedList is the subscriber (client id) list of shared subscriptions. (key by topic name). +type sharedList map[string][]struct { + clientID string + sub *mqttbroker.Subscription +} + +// maxQos records the maximum qos subscription for the non-shared topic. (key by topic name). +type maxQos map[string]*struct { + sub *mqttbroker.Subscription + subIDs []uint32 +} + +// deliverHandler controllers the delivery behaviors according to the DeliveryMode config. (overlap or onlyonce) +type deliverHandler struct { + fn subscription.IterateFn + sl sharedList + mq maxQos + matched bool + now time.Time + msg *mqttbroker.Message + srv *server +} + +func newDeliverHandler(mode string, srcClientID string, msg *mqttbroker.Message, now time.Time, srv *server) *deliverHandler { + d := &deliverHandler{ + sl: make(sharedList), + mq: make(maxQos), + msg: msg, + srv: srv, + now: now, + } + var iterateFn subscription.IterateFn + d.fn = func(clientID string, sub *mqttbroker.Subscription) bool { + if sub.NoLocal && clientID == srcClientID { + return true + } + d.matched = true + if sub.ShareName != "" { + fullTopic := sub.GetFullTopicName() + d.sl[fullTopic] = append(d.sl[fullTopic], struct { + clientID string + sub *mqttbroker.Subscription + }{clientID: clientID, sub: sub}) + return true + } + return iterateFn(clientID, sub) + } + if mode == Overlap { + iterateFn = func(clientID string, sub *mqttbroker.Subscription) bool { + if qs := srv.queueStore[clientID]; qs != nil { + srv.addMsgToQueueLocked(now, clientID, msg.Copy(), sub, []uint32{sub.ID}, qs) + } + return true + } + } else { + iterateFn = func(clientID string, sub *mqttbroker.Subscription) bool { + // If the delivery mode is onlyOnce, set the message qos to the maximum qos in matched subscriptions. + if d.mq[clientID] == nil { + d.mq[clientID] = &struct { + sub *mqttbroker.Subscription + subIDs []uint32 + }{sub: sub, subIDs: []uint32{sub.ID}} + return true + } + if d.mq[clientID].sub.QoS < sub.QoS { + d.mq[clientID].sub = sub + } + d.mq[clientID].subIDs = append(d.mq[clientID].subIDs, sub.ID) + return true + } + } + return d +} + +func (d *deliverHandler) flush() { + // shared subscription + // TODO enable customize balance strategy of shared subscription + for _, v := range d.sl { + var rs struct { + clientID string + sub *mqttbroker.Subscription + } + // random + rs = v[rand.Intn(len(v))] + if c, ok := d.srv.queueStore[rs.clientID]; ok { + d.srv.addMsgToQueueLocked(d.now, rs.clientID, d.msg.Copy(), rs.sub, []uint32{rs.sub.ID}, c) + } + } + // For onlyonce mode, send the non-shared messages. + for clientID, v := range d.mq { + if qs := d.srv.queueStore[clientID]; qs != nil { + d.srv.addMsgToQueueLocked(d.now, clientID, d.msg.Copy(), v.sub, v.subIDs, qs) + } + } +} + +// deliverMessage send msg to matched client, must call under srv.mu.Lock +func (srv *server) deliverMessage(srcClientID string, msg *mqttbroker.Message, options subscription.IterationOptions) (matched bool) { + now := time.Now() + d := newDeliverHandler(srv.config.MQTT.DeliveryMode, srcClientID, msg, now, srv) + srv.subscriptionsDB.Iterate(d.fn, options) + d.flush() + return d.matched +} + +func (srv *server) removeSessionLocked(clientID string) (err error) { + delete(srv.clients, clientID) + delete(srv.offlineClients, clientID) + + var errs []string + var queueErr, sessionErr, subErr error + if qs := srv.queueStore[clientID]; qs != nil { + queueErr = qs.Clean() + if queueErr != nil { + zaplog.Error("fail to clean message queue", + zap.String("client_id", clientID), + zap.Error(queueErr)) + errs = append(errs, "fail to clean message queue: "+queueErr.Error()) + } + delete(srv.queueStore, clientID) + } + sessionErr = srv.sessionStore.Remove(clientID) + if sessionErr != nil { + zaplog.Error("fail to remove session", + zap.String("client_id", clientID), + zap.Error(sessionErr)) + + errs = append(errs, "fail to remove session: "+sessionErr.Error()) + } + subErr = srv.subscriptionsDB.UnsubscribeAll(clientID) + if subErr != nil { + zaplog.Error("fail to remove subscription", + zap.String("client_id", clientID), + zap.Error(subErr)) + + errs = append(errs, "fail to remove subscription: "+subErr.Error()) + } + + if errs != nil { + return errors.New(strings.Join(errs, ";")) + } + return nil +} + +// sessionExpireCheck 判断是否超时 +// sessionExpireCheck check and terminate expired sessions +func (srv *server) sessionExpireCheck() { + now := time.Now() + srv.mu.Lock() + for cid, expiredTime := range srv.offlineClients { + if now.After(expiredTime) { + zaplog.Info("session expired", zap.String("client_id", cid)) + _ = srv.sessionTerminatedLocked(cid, ExpiredTermination) + + } + } + srv.mu.Unlock() +} + +// server event loop +func (srv *server) eventLoop() { + sessionExpireTimer := time.NewTicker(time.Second * 20) + defer func() { + sessionExpireTimer.Stop() + srv.wg.Done() + }() + for { + select { + case <-srv.exitChan: + return + case <-sessionExpireTimer.C: + srv.sessionExpireCheck() + } + + } +} + +// WsServer is used to build websocket server +type WsServer struct { + Server *http.Server + Path string // Url path + CertFile string //TLS configration + KeyFile string //TLS configration +} + +func defaultServer() *server { + srv := &server{ + status: serverStatusInit, + exitChan: make(chan struct{}), + exitedChan: make(chan struct{}), + clients: make(map[string]*client), + offlineClients: make(map[string]time.Time), + willMessage: make(map[string]*willMsg), + retainedDB: retained_trie.NewStore(), + config: config.DefaultConfig(), + queueStore: make(map[string]queue.Store), + unackStore: make(map[string]unack.Store), + } + srv.publishService = &publishService{server: srv} + return srv +} + +// New returns a gmqtt server instance with the given options +func New(opts ...Options) *server { + srv := defaultServer() + for _, fn := range opts { + fn(srv) + } + return srv +} + +func (srv *server) init(opts ...Options) (err error) { + for _, fn := range opts { + fn(srv) + } + err = srv.initPluginHooks() + if err != nil { + return err + } + var pe Persistence + peType := srv.config.Persistence.Type + if newFn := persistenceFactories[peType]; newFn != nil { + pe, err = newFn(srv.config) + if err != nil { + return err + } + } else { + return fmt.Errorf("persistence factory: %s not found", peType) + } + err = pe.Open() + if err != nil { + return err + } + zaplog.Info("open persistence succeeded", zap.String("type", peType)) + srv.persistence = pe + + srv.subscriptionsDB, err = srv.persistence.NewSubscriptionStore(srv.config) + if err != nil { + return err + } + st, err := srv.persistence.NewSessionStore(srv.config) + if err != nil { + return err + } + srv.sessionStore = st + var sts []*mqttbroker.Session + var cids []string + + err = st.Iterate(func(session *mqttbroker.Session) bool { + sts = append(sts, session) + cids = append(cids, session.ClientID) + return true + }) + if err != nil { + return err + } + zaplog.Info("init session store succeeded", zap.String("type", peType), zap.Int("session_total", len(cids))) + + srv.statsManager = newStatsManager(srv.subscriptionsDB) + srv.clientService = &clientService{ + srv: srv, + sessionStore: srv.sessionStore, + } + + // init queue store & unack store from persistence + for _, v := range sts { + q, err := srv.persistence.NewQueueStore(srv.config, defaultNotifier(srv.hooks.OnMsgDropped, srv.statsManager, v.ClientID), v.ClientID) + if err != nil { + return err + } + srv.queueStore[v.ClientID] = q + srv.offlineClients[v.ClientID] = time.Now().Add(time.Duration(v.ExpiryInterval) * time.Second) + + ua, err := srv.persistence.NewUnackStore(srv.config, v.ClientID) + if err != nil { + return err + } + srv.unackStore[v.ClientID] = ua + } + zaplog.Info("init queue store succeeded", zap.String("type", peType), zap.Int("session_total", len(cids))) + zaplog.Info("init subscription store succeeded", zap.String("type", peType), zap.Int("client_total", len(cids))) + err = srv.subscriptionsDB.Init(cids) + if err != nil { + return err + } + + topicAliasMgrFactory := topicAliasMgrFactory[srv.config.TopicAliasManager.Type] + if topicAliasMgrFactory != nil { + srv.newTopicAliasManager = topicAliasMgrFactory + } else { + return fmt.Errorf("topic alias manager : %s not found", srv.config.TopicAliasManager.Type) + } + err = srv.initAPIRegistrar() + if err != nil { + return err + } + return srv.loadPlugins() +} + +func (srv *server) initAPIRegistrar() error { + registrar := &apiRegistrar{} + for _, v := range srv.config.API.HTTP { + server, err := buildHTTPServer(v) + if err != nil { + return err + } + registrar.httpServers = append(registrar.httpServers, server) + + } + for _, v := range srv.config.API.GRPC { + server, err := buildGRPCServer(v) + if err != nil { + return err + } + registrar.gRPCServers = append(registrar.gRPCServers, server) + } + srv.apiRegistrar = registrar + return nil +} + +// Init initialises the options. +func (srv *server) Init(opts ...Options) (err error) { + srv.initOnce.Do(func() { + err = srv.init(opts...) + }) + return err +} + +// Client returns the client for given clientID +func (srv *server) Client(clientID string) Client { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.clients[clientID] +} + +func (srv *server) serveTCP(l net.Listener) { + defer func() { + l.Close() + }() + var tempDelay time.Duration + for { + rw, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + return + } + if srv.hooks.OnAccept != nil { + if !srv.hooks.OnAccept(context.Background(), rw) { + rw.Close() + continue + } + } + client, err := srv.newClient(rw) + if err != nil { + zaplog.Error("new client fail", zap.Error(err)) + return + } + go client.serve() + } +} + +var defaultUpgrader = &websocket.Upgrader{ + ReadBufferSize: readBufferSize, + WriteBufferSize: writeBufferSize, + CheckOrigin: func(r *http.Request) bool { + return true + }, + Subprotocols: []string{"mqttclient"}, +} + +//实现io.ReadWriter接口 +// wsConn implements the io.readWriter +type wsConn struct { + net.Conn + c *websocket.Conn +} + +func (ws *wsConn) Close() error { + return ws.Conn.Close() +} + +func (ws *wsConn) Read(p []byte) (n int, err error) { + msgType, r, err := ws.c.NextReader() + if err != nil { + return 0, err + } + if msgType != websocket.BinaryMessage { + return 0, ErrInvalWsMsgType + } + return r.Read(p) +} + +func (ws *wsConn) Write(p []byte) (n int, err error) { + err = ws.c.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return 0, err + } + return len(p), err +} + +func (srv *server) serveWebSocket(ws *WsServer) { + var err error + if ws.CertFile != "" && ws.KeyFile != "" { + err = ws.Server.ListenAndServeTLS(ws.CertFile, ws.KeyFile) + } else { + err = ws.Server.ListenAndServe() + } + if err != nil && err != http.ErrServerClosed { + srv.setError(fmt.Errorf("serveWebSocket error: %s", err.Error())) + } +} + +func (srv *server) newClient(c net.Conn) (*client, error) { + srv.configMu.Lock() + cfg := srv.config + srv.configMu.Unlock() + client := &client{ + server: srv, + rwc: c, + bufr: newBufioReaderSize(c, readBufferSize), + bufw: newBufioWriterSize(c, writeBufferSize), + close: make(chan struct{}), + closed: make(chan struct{}), + connected: make(chan struct{}), + error: make(chan error, 1), + in: make(chan packets.Packet, 8), + out: make(chan packets.Packet, 8), + status: Connecting, + opts: &ClientOptions{}, + cleanWillFlag: false, + config: cfg, + register: srv.registerClient, + unregister: srv.unregisterClient, + deliverMessage: func(srcClientID string, msg *mqttbroker.Message, options subscription.IterationOptions) (matched bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.deliverMessage(srcClientID, msg, options) + }, + } + client.packetReader = packets.NewReader(client.bufr) + client.packetWriter = packets.NewWriter(client.bufw) + client.queueNotifier = &queueNotifier{ + dropHook: srv.hooks.OnMsgDropped, + sts: srv.statsManager, + cli: client, + } + client.setConnecting() + + return client, nil +} + +func (srv *server) initPluginHooks() error { + zaplog.Info("init plugin hook wrappers") + var ( + onAcceptWrappers []OnAcceptWrapper + onBasicAuthWrappers []OnBasicAuthWrapper + onEnhancedAuthWrappers []OnEnhancedAuthWrapper + onReAuthWrappers []OnReAuthWrapper + onConnectedWrappers []OnConnectedWrapper + onSessionCreatedWrapper []OnSessionCreatedWrapper + onSessionResumedWrapper []OnSessionResumedWrapper + onSessionTerminatedWrapper []OnSessionTerminatedWrapper + onSubscribeWrappers []OnSubscribeWrapper + onSubscribedWrappers []OnSubscribedWrapper + onUnsubscribeWrappers []OnUnsubscribeWrapper + onUnsubscribedWrappers []OnUnsubscribedWrapper + onMsgArrivedWrappers []OnMsgArrivedWrapper + OnDeliveredWrappers []OnDeliveredWrapper + OnClosedWrappers []OnClosedWrapper + onStopWrappers []OnStopWrapper + onMsgDroppedWrappers []OnMsgDroppedWrapper + onWillPublishWrappers []OnWillPublishWrapper + onWillPublishedWrappers []OnWillPublishedWrapper + ) + for _, v := range srv.config.PluginOrder { + plg, err := plugins[v](srv.config) + if err != nil { + return err + } + srv.plugins = append(srv.plugins, plg) + } + + for _, p := range srv.plugins { + hooks := p.HookWrapper() + // init all hook wrappers + if hooks.OnAcceptWrapper != nil { + onAcceptWrappers = append(onAcceptWrappers, hooks.OnAcceptWrapper) + } + if hooks.OnBasicAuthWrapper != nil { + onBasicAuthWrappers = append(onBasicAuthWrappers, hooks.OnBasicAuthWrapper) + } + if hooks.OnEnhancedAuthWrapper != nil { + onEnhancedAuthWrappers = append(onEnhancedAuthWrappers, hooks.OnEnhancedAuthWrapper) + } + if hooks.OnReAuthWrapper != nil { + onReAuthWrappers = append(onReAuthWrappers, hooks.OnReAuthWrapper) + } + if hooks.OnConnectedWrapper != nil { + onConnectedWrappers = append(onConnectedWrappers, hooks.OnConnectedWrapper) + } + if hooks.OnSessionCreatedWrapper != nil { + onSessionCreatedWrapper = append(onSessionCreatedWrapper, hooks.OnSessionCreatedWrapper) + } + if hooks.OnSessionResumedWrapper != nil { + onSessionResumedWrapper = append(onSessionResumedWrapper, hooks.OnSessionResumedWrapper) + } + if hooks.OnSessionTerminatedWrapper != nil { + onSessionTerminatedWrapper = append(onSessionTerminatedWrapper, hooks.OnSessionTerminatedWrapper) + } + if hooks.OnSubscribeWrapper != nil { + onSubscribeWrappers = append(onSubscribeWrappers, hooks.OnSubscribeWrapper) + } + if hooks.OnSubscribedWrapper != nil { + onSubscribedWrappers = append(onSubscribedWrappers, hooks.OnSubscribedWrapper) + } + if hooks.OnUnsubscribeWrapper != nil { + onUnsubscribeWrappers = append(onUnsubscribeWrappers, hooks.OnUnsubscribeWrapper) + } + if hooks.OnUnsubscribedWrapper != nil { + onUnsubscribedWrappers = append(onUnsubscribedWrappers, hooks.OnUnsubscribedWrapper) + } + if hooks.OnMsgArrivedWrapper != nil { + onMsgArrivedWrappers = append(onMsgArrivedWrappers, hooks.OnMsgArrivedWrapper) + } + if hooks.OnMsgDroppedWrapper != nil { + onMsgDroppedWrappers = append(onMsgDroppedWrappers, hooks.OnMsgDroppedWrapper) + } + if hooks.OnDeliveredWrapper != nil { + OnDeliveredWrappers = append(OnDeliveredWrappers, hooks.OnDeliveredWrapper) + } + if hooks.OnClosedWrapper != nil { + OnClosedWrappers = append(OnClosedWrappers, hooks.OnClosedWrapper) + } + if hooks.OnStopWrapper != nil { + onStopWrappers = append(onStopWrappers, hooks.OnStopWrapper) + } + if hooks.OnWillPublishWrapper != nil { + onWillPublishWrappers = append(onWillPublishWrappers, hooks.OnWillPublishWrapper) + } + if hooks.OnWillPublishedWrapper != nil { + onWillPublishedWrappers = append(onWillPublishedWrappers, hooks.OnWillPublishedWrapper) + } + } + if onAcceptWrappers != nil { + onAccept := func(ctx context.Context, conn net.Conn) bool { + return true + } + for i := len(onAcceptWrappers); i > 0; i-- { + onAccept = onAcceptWrappers[i-1](onAccept) + } + srv.hooks.OnAccept = onAccept + } + if onBasicAuthWrappers != nil { + onBasicAuth := func(ctx context.Context, client Client, req *ConnectRequest) error { + return nil + } + for i := len(onBasicAuthWrappers); i > 0; i-- { + onBasicAuth = onBasicAuthWrappers[i-1](onBasicAuth) + } + srv.hooks.OnBasicAuth = onBasicAuth + } + if onEnhancedAuthWrappers != nil { + onEnhancedAuth := func(ctx context.Context, client Client, req *ConnectRequest) (resp *EnhancedAuthResponse, err error) { + return &EnhancedAuthResponse{ + Continue: false, + }, nil + } + for i := len(onEnhancedAuthWrappers); i > 0; i-- { + onEnhancedAuth = onEnhancedAuthWrappers[i-1](onEnhancedAuth) + } + srv.hooks.OnEnhancedAuth = onEnhancedAuth + } + + if onConnectedWrappers != nil { + onConnected := func(ctx context.Context, client Client) {} + for i := len(onConnectedWrappers); i > 0; i-- { + onConnected = onConnectedWrappers[i-1](onConnected) + } + srv.hooks.OnConnected = onConnected + } + if onSessionCreatedWrapper != nil { + onSessionCreated := func(ctx context.Context, client Client) {} + for i := len(onSessionCreatedWrapper); i > 0; i-- { + onSessionCreated = onSessionCreatedWrapper[i-1](onSessionCreated) + } + srv.hooks.OnSessionCreated = onSessionCreated + } + if onSessionResumedWrapper != nil { + onSessionResumed := func(ctx context.Context, client Client) {} + for i := len(onSessionResumedWrapper); i > 0; i-- { + onSessionResumed = onSessionResumedWrapper[i-1](onSessionResumed) + } + srv.hooks.OnSessionResumed = onSessionResumed + } + if onSessionTerminatedWrapper != nil { + onSessionTerminated := func(ctx context.Context, clientID string, reason SessionTerminatedReason) {} + for i := len(onSessionTerminatedWrapper); i > 0; i-- { + onSessionTerminated = onSessionTerminatedWrapper[i-1](onSessionTerminated) + } + srv.hooks.OnSessionTerminated = onSessionTerminated + } + if onSubscribeWrappers != nil { + onSubscribe := func(ctx context.Context, client Client, req *SubscribeRequest) error { + return nil + } + for i := len(onSubscribeWrappers); i > 0; i-- { + onSubscribe = onSubscribeWrappers[i-1](onSubscribe) + } + srv.hooks.OnSubscribe = onSubscribe + } + if onSubscribedWrappers != nil { + onSubscribed := func(ctx context.Context, client Client, subscription *mqttbroker.Subscription) {} + for i := len(onSubscribedWrappers); i > 0; i-- { + onSubscribed = onSubscribedWrappers[i-1](onSubscribed) + } + srv.hooks.OnSubscribed = onSubscribed + } + if onUnsubscribeWrappers != nil { + onUnsubscribe := func(ctx context.Context, client Client, req *UnsubscribeRequest) error { + return nil + } + for i := len(onUnsubscribeWrappers); i > 0; i-- { + onUnsubscribe = onUnsubscribeWrappers[i-1](onUnsubscribe) + } + srv.hooks.OnUnsubscribe = onUnsubscribe + } + if onUnsubscribedWrappers != nil { + onUnsubscribed := func(ctx context.Context, client Client, topicName string) {} + for i := len(onUnsubscribedWrappers); i > 0; i-- { + onUnsubscribed = onUnsubscribedWrappers[i-1](onUnsubscribed) + } + srv.hooks.OnUnsubscribed = onUnsubscribed + } + if onMsgArrivedWrappers != nil { + onMsgArrived := func(ctx context.Context, client Client, req *MsgArrivedRequest) error { + return nil + } + for i := len(onMsgArrivedWrappers); i > 0; i-- { + onMsgArrived = onMsgArrivedWrappers[i-1](onMsgArrived) + } + srv.hooks.OnMsgArrived = onMsgArrived + } + if OnDeliveredWrappers != nil { + OnDelivered := func(ctx context.Context, client Client, msg *mqttbroker.Message) {} + for i := len(OnDeliveredWrappers); i > 0; i-- { + OnDelivered = OnDeliveredWrappers[i-1](OnDelivered) + } + srv.hooks.OnDelivered = OnDelivered + } + if OnClosedWrappers != nil { + OnClosed := func(ctx context.Context, client Client, err error) {} + for i := len(OnClosedWrappers); i > 0; i-- { + OnClosed = OnClosedWrappers[i-1](OnClosed) + } + srv.hooks.OnClosed = OnClosed + } + if onStopWrappers != nil { + onStop := func(ctx context.Context) {} + for i := len(onStopWrappers); i > 0; i-- { + onStop = onStopWrappers[i-1](onStop) + } + srv.hooks.OnStop = onStop + } + if onMsgDroppedWrappers != nil { + onMsgDropped := func(ctx context.Context, clientID string, msg *mqttbroker.Message, err error) {} + for i := len(onMsgDroppedWrappers); i > 0; i-- { + onMsgDropped = onMsgDroppedWrappers[i-1](onMsgDropped) + } + srv.hooks.OnMsgDropped = onMsgDropped + } + if onWillPublishWrappers != nil { + onWillPublish := func(ctx context.Context, clientID string, req *WillMsgRequest) {} + for i := len(onWillPublishWrappers); i > 0; i-- { + onWillPublish = onWillPublishWrappers[i-1](onWillPublish) + } + srv.hooks.OnWillPublish = onWillPublish + } + if onWillPublishedWrappers != nil { + onWillPublished := func(ctx context.Context, clientID string, msg *mqttbroker.Message) {} + for i := len(onWillPublishedWrappers); i > 0; i-- { + onWillPublished = onWillPublishedWrappers[i-1](onWillPublished) + } + srv.hooks.OnWillPublished = onWillPublished + } + return nil +} + +func (srv *server) loadPlugins() error { + for _, p := range srv.plugins { + zaplog.Info("loading plugin", zap.String("name", p.Name())) + err := p.Load(srv) + if err != nil { + return err + } + } + return nil +} + +func (srv *server) wsHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + c, err := defaultUpgrader.Upgrade(w, r, nil) + if err != nil { + zaplog.Error("websocket upgrade error", zap.String("Msg", err.Error())) + return + } + defer c.Close() + conn := &wsConn{c.UnderlyingConn(), c} + client, err := srv.newClient(conn) + if err != nil { + zaplog.Error("new client fail", zap.Error(err)) + return + } + client.serve() + } +} + +func (srv *server) setError(err error) { + srv.errOnce.Do(func() { + srv.err = err + srv.exit() + }) +} + +// Run starts the mqttclient server. +func (srv *server) Run() (err error) { + err = srv.Init() + if err != nil { + return err + } + var tcps []string + //var ws []string + for _, v := range srv.tcpListener { + tcps = append(tcps, v.Addr().String()) + } + //for _, v := range srv.websocketServer { + // ws = append(ws, v.Server.Addr) + //} + zaplog.Info("mqtt -broker server started", zap.Strings("tcp server listen on", tcps)) + + srv.status = serverStatusStarted + srv.wg.Add(1) + go srv.eventLoop() + //go srv.serveAPIServer() + for _, ln := range srv.tcpListener { + go srv.serveTCP(ln) + } + //for _, server := range srv.websocketServer { + // mux := http.NewServeMux() + // mux.Handle(server.Path, srv.wsHandler()) + // server.Server.Handler = mux + // go srv.serveWebSocket(server) + //} + srv.wg.Wait() + <-srv.exitedChan + return srv.err +} + +// Stop gracefully stops the mqttclient server by the following steps: +// 1. Closing all opening TCP listeners and shutting down all opening websocket servers +// 2. Closing all idle connections +// 3. Waiting for all connections have been closed +// 4. Triggering OnStop() +func (srv *server) Stop(ctx context.Context) error { + var err error + srv.stopOnce.Do(func() { + zaplog.Info("stopping gmqtt server") + defer func() { + defer close(srv.exitedChan) + zaplog.Info("server stopped") + }() + srv.exit() + + for _, l := range srv.tcpListener { + l.Close() + } + for _, ws := range srv.websocketServer { + ws.Server.Shutdown(ctx) + } + // close all idle alertclient + srv.mu.Lock() + chs := make([]chan struct{}, len(srv.clients)) + i := 0 + for _, c := range srv.clients { + chs[i] = c.closed + i++ + c.Close() + } + srv.mu.Unlock() + + done := make(chan struct{}) + if len(chs) != 0 { + go func() { + for _, v := range chs { + <-v + } + close(done) + }() + } else { + close(done) + } + + select { + case <-ctx.Done(): + zaplog.Warn("server stop timeout, force exit", zap.String("error", ctx.Err().Error())) + err = ctx.Err() + return + case <-done: + for _, v := range srv.plugins { + zaplog.Info("unloading plugin", zap.String("name", v.Name())) + err := v.Unload() + if err != nil { + zaplog.Warn("plugin unload error", zap.String("error", err.Error())) + } + } + if srv.hooks.OnStop != nil { + srv.hooks.OnStop(context.Background()) + } + } + }) + return err +} diff --git a/internal/hummingbird/mqttbroker/server/server_mock.go b/internal/hummingbird/mqttbroker/server/server_mock.go new file mode 100644 index 0000000..c7f59e2 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/server_mock.go @@ -0,0 +1,174 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server/server.go + +// Package server is a generated GoMock package. +package server + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + config "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" +) + +// MockServer is a mock of Server interface +type MockServer struct { + ctrl *gomock.Controller + recorder *MockServerMockRecorder +} + +// MockServerMockRecorder is the mock recorder for MockServer +type MockServerMockRecorder struct { + mock *MockServer +} + +// NewMockServer creates a new mock instance +func NewMockServer(ctrl *gomock.Controller) *MockServer { + mock := &MockServer{ctrl: ctrl} + mock.recorder = &MockServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockServer) EXPECT() *MockServerMockRecorder { + return m.recorder +} + +// Publisher mocks base method +func (m *MockServer) Publisher() Publisher { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Publisher") + ret0, _ := ret[0].(Publisher) + return ret0 +} + +// Publisher indicates an expected call of Publisher +func (mr *MockServerMockRecorder) Publisher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publisher", reflect.TypeOf((*MockServer)(nil).Publisher)) +} + +// GetConfig mocks base method +func (m *MockServer) GetConfig() config.Config { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConfig") + ret0, _ := ret[0].(config.Config) + return ret0 +} + +// GetConfig indicates an expected call of GetConfig +func (mr *MockServerMockRecorder) GetConfig() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockServer)(nil).GetConfig)) +} + +// StatsManager mocks base method +func (m *MockServer) StatsManager() StatsReader { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StatsManager") + ret0, _ := ret[0].(StatsReader) + return ret0 +} + +// StatsManager indicates an expected call of StatsManager +func (mr *MockServerMockRecorder) StatsManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StatsManager", reflect.TypeOf((*MockServer)(nil).StatsManager)) +} + +// Stop mocks base method +func (m *MockServer) Stop(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stop", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Stop indicates an expected call of Stop +func (mr *MockServerMockRecorder) Stop(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockServer)(nil).Stop), ctx) +} + +// ApplyConfig mocks base method +func (m *MockServer) ApplyConfig(config config.Config) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ApplyConfig", config) +} + +// ApplyConfig indicates an expected call of ApplyConfig +func (mr *MockServerMockRecorder) ApplyConfig(config interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyConfig", reflect.TypeOf((*MockServer)(nil).ApplyConfig), config) +} + +// ClientService mocks base method +func (m *MockServer) ClientService() ClientService { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClientService") + ret0, _ := ret[0].(ClientService) + return ret0 +} + +// ClientService indicates an expected call of ClientService +func (mr *MockServerMockRecorder) ClientService() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClientService", reflect.TypeOf((*MockServer)(nil).ClientService)) +} + +// SubscriptionService mocks base method +func (m *MockServer) SubscriptionService() SubscriptionService { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscriptionService") + ret0, _ := ret[0].(SubscriptionService) + return ret0 +} + +// SubscriptionService indicates an expected call of SubscriptionService +func (mr *MockServerMockRecorder) SubscriptionService() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscriptionService", reflect.TypeOf((*MockServer)(nil).SubscriptionService)) +} + +// RetainedService mocks base method +func (m *MockServer) RetainedService() RetainedService { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RetainedService") + ret0, _ := ret[0].(RetainedService) + return ret0 +} + +// RetainedService indicates an expected call of RetainedService +func (mr *MockServerMockRecorder) RetainedService() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RetainedService", reflect.TypeOf((*MockServer)(nil).RetainedService)) +} + +// Plugins mocks base method +func (m *MockServer) Plugins() []Plugin { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Plugins") + ret0, _ := ret[0].([]Plugin) + return ret0 +} + +// Plugins indicates an expected call of Plugins +func (mr *MockServerMockRecorder) Plugins() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Plugins", reflect.TypeOf((*MockServer)(nil).Plugins)) +} + +// APIRegistrar mocks base method +func (m *MockServer) APIRegistrar() APIRegistrar { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "APIRegistrar") + ret0, _ := ret[0].(APIRegistrar) + return ret0 +} + +// APIRegistrar indicates an expected call of APIRegistrar +func (mr *MockServerMockRecorder) APIRegistrar() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "APIRegistrar", reflect.TypeOf((*MockServer)(nil).APIRegistrar)) +} diff --git a/internal/hummingbird/mqttbroker/server/service.go b/internal/hummingbird/mqttbroker/server/service.go new file mode 100644 index 0000000..84ba664 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/service.go @@ -0,0 +1,54 @@ +package server + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/retained" +) + +// Publisher provides the ability to Publish messages to the broker. +type Publisher interface { + // Publish Publish a message to broker. + // Calling this method will not trigger OnMsgArrived hook. + Publish(message *mqttbroker.Message) +} + +// ClientIterateFn is the callback function used by ClientService.IterateClient +// Return false means to stop the iteration. +type ClientIterateFn = func(client Client) bool + +// ClientService provides the ability to query and close alertclient. +type ClientService interface { + IterateSession(fn session.IterateFn) error + GetSession(clientID string) (*mqttbroker.Session, error) + GetClient(clientID string) Client + IterateClient(fn ClientIterateFn) + TerminateSession(clientID string) +} + +// SubscriptionService providers the ability to query and add/delete subscriptions. +type SubscriptionService interface { + // Subscribe adds subscriptions to a specific client. + // Notice: + // This method will succeed even if the client is not exists, the subscriptions + // will affect the new client with the client id. + Subscribe(clientID string, subscriptions ...*mqttbroker.Subscription) (rs subscription.SubscribeResult, err error) + // Unsubscribe removes subscriptions of a specific client. + Unsubscribe(clientID string, topics ...string) error + // UnsubscribeAll removes all subscriptions of a specific client. + UnsubscribeAll(clientID string) error + // Iterate iterates all subscriptions. The callback is called once for each subscription. + // If callback return false, the iteration will be stopped. + // Notice: + // The results are not sorted in any way, no ordering of any kind is guaranteed. + // This method will walk through all subscriptions, + // so it is a very expensive operation. Do not call it frequently. + Iterate(fn subscription.IterateFn, options subscription.IterationOptions) + subscription.StatsReader +} + +// RetainedService providers the ability to query and add/delete retained messages. +type RetainedService interface { + retained.Store +} diff --git a/internal/hummingbird/mqttbroker/server/service_mock.go b/internal/hummingbird/mqttbroker/server/service_mock.go new file mode 100644 index 0000000..93d9f08 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/service_mock.go @@ -0,0 +1,356 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server/service.go + +// Package server is a generated GoMock package. +package server + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + gmqtt "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker" + session "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/session" + subscription "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + retained "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/retained" +) + +// MockPublisher is a mock of Publisher interface +type MockPublisher struct { + ctrl *gomock.Controller + recorder *MockPublisherMockRecorder +} + +// MockPublisherMockRecorder is the mock recorder for MockPublisher +type MockPublisherMockRecorder struct { + mock *MockPublisher +} + +// NewMockPublisher creates a new mock instance +func NewMockPublisher(ctrl *gomock.Controller) *MockPublisher { + mock := &MockPublisher{ctrl: ctrl} + mock.recorder = &MockPublisherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPublisher) EXPECT() *MockPublisherMockRecorder { + return m.recorder +} + +// Publish mocks base method +func (m *MockPublisher) Publish(message *gmqtt.Message) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Publish", message) +} + +// Publish indicates an expected call of Publish +func (mr *MockPublisherMockRecorder) Publish(message interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockPublisher)(nil).Publish), message) +} + +// MockClientService is a mock of ClientService interface +type MockClientService struct { + ctrl *gomock.Controller + recorder *MockClientServiceMockRecorder +} + +// MockClientServiceMockRecorder is the mock recorder for MockClientService +type MockClientServiceMockRecorder struct { + mock *MockClientService +} + +// NewMockClientService creates a new mock instance +func NewMockClientService(ctrl *gomock.Controller) *MockClientService { + mock := &MockClientService{ctrl: ctrl} + mock.recorder = &MockClientServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClientService) EXPECT() *MockClientServiceMockRecorder { + return m.recorder +} + +// IterateSession mocks base method +func (m *MockClientService) IterateSession(fn session.IterateFn) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IterateSession", fn) + ret0, _ := ret[0].(error) + return ret0 +} + +// IterateSession indicates an expected call of IterateSession +func (mr *MockClientServiceMockRecorder) IterateSession(fn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateSession", reflect.TypeOf((*MockClientService)(nil).IterateSession), fn) +} + +// GetSession mocks base method +func (m *MockClientService) GetSession(clientID string) (*gmqtt.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSession", clientID) + ret0, _ := ret[0].(*gmqtt.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSession indicates an expected call of GetSession +func (mr *MockClientServiceMockRecorder) GetSession(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockClientService)(nil).GetSession), clientID) +} + +// GetClient mocks base method +func (m *MockClientService) GetClient(clientID string) Client { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClient", clientID) + ret0, _ := ret[0].(Client) + return ret0 +} + +// GetClient indicates an expected call of GetClient +func (mr *MockClientServiceMockRecorder) GetClient(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockClientService)(nil).GetClient), clientID) +} + +// IterateClient mocks base method +func (m *MockClientService) IterateClient(fn ClientIterateFn) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "IterateClient", fn) +} + +// IterateClient indicates an expected call of IterateClient +func (mr *MockClientServiceMockRecorder) IterateClient(fn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateClient", reflect.TypeOf((*MockClientService)(nil).IterateClient), fn) +} + +// TerminateSession mocks base method +func (m *MockClientService) TerminateSession(clientID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "TerminateSession", clientID) +} + +// TerminateSession indicates an expected call of TerminateSession +func (mr *MockClientServiceMockRecorder) TerminateSession(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateSession", reflect.TypeOf((*MockClientService)(nil).TerminateSession), clientID) +} + +// MockSubscriptionService is a mock of SubscriptionService interface +type MockSubscriptionService struct { + ctrl *gomock.Controller + recorder *MockSubscriptionServiceMockRecorder +} + +// MockSubscriptionServiceMockRecorder is the mock recorder for MockSubscriptionService +type MockSubscriptionServiceMockRecorder struct { + mock *MockSubscriptionService +} + +// NewMockSubscriptionService creates a new mock instance +func NewMockSubscriptionService(ctrl *gomock.Controller) *MockSubscriptionService { + mock := &MockSubscriptionService{ctrl: ctrl} + mock.recorder = &MockSubscriptionServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSubscriptionService) EXPECT() *MockSubscriptionServiceMockRecorder { + return m.recorder +} + +// Subscribe mocks base method +func (m *MockSubscriptionService) Subscribe(clientID string, subscriptions ...*gmqtt.Subscription) (subscription.SubscribeResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{clientID} + for _, a := range subscriptions { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Subscribe", varargs...) + ret0, _ := ret[0].(subscription.SubscribeResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Subscribe indicates an expected call of Subscribe +func (mr *MockSubscriptionServiceMockRecorder) Subscribe(clientID interface{}, subscriptions ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{clientID}, subscriptions...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockSubscriptionService)(nil).Subscribe), varargs...) +} + +// Unsubscribe mocks base method +func (m *MockSubscriptionService) Unsubscribe(clientID string, topics ...string) error { + m.ctrl.T.Helper() + varargs := []interface{}{clientID} + for _, a := range topics { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Unsubscribe", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unsubscribe indicates an expected call of Unsubscribe +func (mr *MockSubscriptionServiceMockRecorder) Unsubscribe(clientID interface{}, topics ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{clientID}, topics...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsubscribe", reflect.TypeOf((*MockSubscriptionService)(nil).Unsubscribe), varargs...) +} + +// UnsubscribeAll mocks base method +func (m *MockSubscriptionService) UnsubscribeAll(clientID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnsubscribeAll", clientID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnsubscribeAll indicates an expected call of UnsubscribeAll +func (mr *MockSubscriptionServiceMockRecorder) UnsubscribeAll(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAll", reflect.TypeOf((*MockSubscriptionService)(nil).UnsubscribeAll), clientID) +} + +// Iterate mocks base method +func (m *MockSubscriptionService) Iterate(fn subscription.IterateFn, options subscription.IterationOptions) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Iterate", fn, options) +} + +// Iterate indicates an expected call of Iterate +func (mr *MockSubscriptionServiceMockRecorder) Iterate(fn, options interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockSubscriptionService)(nil).Iterate), fn, options) +} + +// GetStats mocks base method +func (m *MockSubscriptionService) GetStats() subscription.Stats { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStats") + ret0, _ := ret[0].(subscription.Stats) + return ret0 +} + +// GetStats indicates an expected call of GetStats +func (mr *MockSubscriptionServiceMockRecorder) GetStats() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStats", reflect.TypeOf((*MockSubscriptionService)(nil).GetStats)) +} + +// GetClientStats mocks base method +func (m *MockSubscriptionService) GetClientStats(clientID string) (subscription.Stats, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientStats", clientID) + ret0, _ := ret[0].(subscription.Stats) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetClientStats indicates an expected call of GetClientStats +func (mr *MockSubscriptionServiceMockRecorder) GetClientStats(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientStats", reflect.TypeOf((*MockSubscriptionService)(nil).GetClientStats), clientID) +} + +// MockRetainedService is a mock of RetainedService interface +type MockRetainedService struct { + ctrl *gomock.Controller + recorder *MockRetainedServiceMockRecorder +} + +// MockRetainedServiceMockRecorder is the mock recorder for MockRetainedService +type MockRetainedServiceMockRecorder struct { + mock *MockRetainedService +} + +// NewMockRetainedService creates a new mock instance +func NewMockRetainedService(ctrl *gomock.Controller) *MockRetainedService { + mock := &MockRetainedService{ctrl: ctrl} + mock.recorder = &MockRetainedServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRetainedService) EXPECT() *MockRetainedServiceMockRecorder { + return m.recorder +} + +// GetRetainedMessage mocks base method +func (m *MockRetainedService) GetRetainedMessage(topicName string) *gmqtt.Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRetainedMessage", topicName) + ret0, _ := ret[0].(*gmqtt.Message) + return ret0 +} + +// GetRetainedMessage indicates an expected call of GetRetainedMessage +func (mr *MockRetainedServiceMockRecorder) GetRetainedMessage(topicName interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRetainedMessage", reflect.TypeOf((*MockRetainedService)(nil).GetRetainedMessage), topicName) +} + +// ClearAll mocks base method +func (m *MockRetainedService) ClearAll() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClearAll") +} + +// ClearAll indicates an expected call of ClearAll +func (mr *MockRetainedServiceMockRecorder) ClearAll() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearAll", reflect.TypeOf((*MockRetainedService)(nil).ClearAll)) +} + +// AddOrReplace mocks base method +func (m *MockRetainedService) AddOrReplace(message *gmqtt.Message) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddOrReplace", message) +} + +// AddOrReplace indicates an expected call of AddOrReplace +func (mr *MockRetainedServiceMockRecorder) AddOrReplace(message interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddOrReplace", reflect.TypeOf((*MockRetainedService)(nil).AddOrReplace), message) +} + +// Remove mocks base method +func (m *MockRetainedService) Remove(topicName string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Remove", topicName) +} + +// Remove indicates an expected call of Remove +func (mr *MockRetainedServiceMockRecorder) Remove(topicName interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockRetainedService)(nil).Remove), topicName) +} + +// GetMatchedMessages mocks base method +func (m *MockRetainedService) GetMatchedMessages(topicFilter string) []*gmqtt.Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMatchedMessages", topicFilter) + ret0, _ := ret[0].([]*gmqtt.Message) + return ret0 +} + +// GetMatchedMessages indicates an expected call of GetMatchedMessages +func (mr *MockRetainedServiceMockRecorder) GetMatchedMessages(topicFilter interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMatchedMessages", reflect.TypeOf((*MockRetainedService)(nil).GetMatchedMessages), topicFilter) +} + +// Iterate mocks base method +func (m *MockRetainedService) Iterate(fn retained.IterateFn) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Iterate", fn) +} + +// Iterate indicates an expected call of Iterate +func (mr *MockRetainedServiceMockRecorder) Iterate(fn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockRetainedService)(nil).Iterate), fn) +} diff --git a/internal/hummingbird/mqttbroker/server/stats.go b/internal/hummingbird/mqttbroker/server/stats.go new file mode 100644 index 0000000..34aa0e8 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/stats.go @@ -0,0 +1,498 @@ +package server + +import ( + "sync" + "sync/atomic" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/queue" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/persistence/subscription" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +type statsManager struct { + subStatsReader subscription.StatsReader + totalStats *GlobalStats + clientMu sync.Mutex + clientStats map[string]*ClientStats +} + +func (s *statsManager) getClientStats(clientID string) (stats *ClientStats) { + if stats = s.clientStats[clientID]; stats == nil { + subStats, _ := s.subStatsReader.GetClientStats(clientID) + + stats = &ClientStats{ + SubscriptionStats: subStats, + } + s.clientStats[clientID] = stats + } + return stats +} +func (s *statsManager) packetReceived(packet packets.Packet, clientID string) { + s.totalStats.PacketStats.add(packet, true) + s.clientMu.Lock() + defer s.clientMu.Unlock() + s.getClientStats(clientID).PacketStats.add(packet, true) +} +func (s *statsManager) packetSent(packet packets.Packet, clientID string) { + s.totalStats.PacketStats.add(packet, false) + s.clientMu.Lock() + defer s.clientMu.Unlock() + s.getClientStats(clientID).PacketStats.add(packet, false) +} +func (s *statsManager) clientPacketReceived(packet packets.Packet, clientID string) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + s.getClientStats(clientID).PacketStats.add(packet, true) +} +func (s *statsManager) clientPacketSent(packet packets.Packet, clientID string) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + s.getClientStats(clientID).PacketStats.add(packet, false) +} + +func (s *statsManager) clientConnected(clientID string) { + atomic.AddUint64(&s.totalStats.ConnectionStats.ConnectedTotal, 1) +} + +func (s *statsManager) clientDisconnected(clientID string) { + atomic.AddUint64(&s.totalStats.ConnectionStats.DisconnectedTotal, 1) + s.sessionInActive() +} + +func (s *statsManager) sessionActive(create bool) { + if create { + atomic.AddUint64(&s.totalStats.ConnectionStats.SessionCreatedTotal, 1) + } else { + atomic.AddUint64(&s.totalStats.ConnectionStats.InactiveCurrent, ^uint64(0)) + } + atomic.AddUint64(&s.totalStats.ConnectionStats.ActiveCurrent, 1) +} + +func (s *statsManager) sessionInActive() { + atomic.AddUint64(&s.totalStats.ConnectionStats.ActiveCurrent, ^uint64(0)) + atomic.AddUint64(&s.totalStats.ConnectionStats.InactiveCurrent, 1) +} + +func (s *statsManager) sessionTerminated(clientID string, reason SessionTerminatedReason) { + var i *uint64 + switch reason { + case NormalTermination: + i = &s.totalStats.ConnectionStats.SessionTerminated.Normal + case ExpiredTermination: + i = &s.totalStats.ConnectionStats.SessionTerminated.Expired + case TakenOverTermination: + i = &s.totalStats.ConnectionStats.SessionTerminated.TakenOver + } + atomic.AddUint64(i, 1) + atomic.AddUint64(&s.totalStats.ConnectionStats.InactiveCurrent, ^uint64(0)) + s.clientMu.Lock() + defer s.clientMu.Unlock() + delete(s.clientStats, clientID) +} + +func (s *statsManager) messageDropped(qos uint8, clientID string, err error) { + switch qos { + case packets.Qos0: + s.totalStats.MessageStats.Qos0.DroppedTotal.messageDropped(err) + s.clientMu.Lock() + defer s.clientMu.Unlock() + s.getClientStats(clientID).MessageStats.Qos0.DroppedTotal.messageDropped(err) + case packets.Qos1: + s.totalStats.MessageStats.Qos1.DroppedTotal.messageDropped(err) + s.clientMu.Lock() + defer s.clientMu.Unlock() + s.getClientStats(clientID).MessageStats.Qos1.DroppedTotal.messageDropped(err) + case packets.Qos2: + s.totalStats.MessageStats.Qos2.DroppedTotal.messageDropped(err) + s.clientMu.Lock() + defer s.clientMu.Unlock() + s.getClientStats(clientID).MessageStats.Qos2.DroppedTotal.messageDropped(err) + } +} +func (d *DroppedTotal) messageDropped(err error) { + switch err { + case queue.ErrDropExceedsMaxPacketSize: + atomic.AddUint64(&d.ExceedsMaxPacketSize, 1) + case queue.ErrDropQueueFull: + atomic.AddUint64(&d.QueueFull, 1) + case queue.ErrDropExpired: + atomic.AddUint64(&d.Expired, 1) + case queue.ErrDropExpiredInflight: + atomic.AddUint64(&d.InflightExpired, 1) + default: + atomic.AddUint64(&d.Internal, 1) + } +} + +func (s *statsManager) messageReceived(qos uint8, clientID string) { + switch qos { + case packets.Qos0: + atomic.AddUint64(&s.totalStats.MessageStats.Qos0.ReceivedTotal, 1) + s.clientMu.Lock() + defer s.clientMu.Unlock() + atomic.AddUint64(&s.getClientStats(clientID).MessageStats.Qos0.ReceivedTotal, 1) + case packets.Qos1: + atomic.AddUint64(&s.totalStats.MessageStats.Qos1.ReceivedTotal, 1) + s.clientMu.Lock() + defer s.clientMu.Unlock() + atomic.AddUint64(&s.getClientStats(clientID).MessageStats.Qos0.ReceivedTotal, 1) + case packets.Qos2: + atomic.AddUint64(&s.totalStats.MessageStats.Qos2.ReceivedTotal, 1) + s.clientMu.Lock() + defer s.clientMu.Unlock() + atomic.AddUint64(&s.getClientStats(clientID).MessageStats.Qos0.ReceivedTotal, 1) + } +} + +func (s *statsManager) messageSent(qos uint8, clientID string) { + switch qos { + case packets.Qos0: + atomic.AddUint64(&s.totalStats.MessageStats.Qos0.SentTotal, 1) + s.clientMu.Lock() + defer s.clientMu.Unlock() + atomic.AddUint64(&s.getClientStats(clientID).MessageStats.Qos0.SentTotal, 1) + case packets.Qos1: + atomic.AddUint64(&s.totalStats.MessageStats.Qos1.SentTotal, 1) + s.clientMu.Lock() + defer s.clientMu.Unlock() + atomic.AddUint64(&s.getClientStats(clientID).MessageStats.Qos0.SentTotal, 1) + case packets.Qos2: + atomic.AddUint64(&s.totalStats.MessageStats.Qos2.SentTotal, 1) + s.clientMu.Lock() + defer s.clientMu.Unlock() + atomic.AddUint64(&s.getClientStats(clientID).MessageStats.Qos0.SentTotal, 1) + } +} + +// StatsReader interface provides the ability to access the statistics of the server +type StatsReader interface { + // GetGlobalStats returns the server statistics. + GetGlobalStats() GlobalStats + // GetClientStats returns the client statistics for the given client id + GetClientStats(clientID string) (sts ClientStats, exist bool) +} + +// PacketStats represents the statistics of MQTT Packet. +type PacketStats struct { + BytesReceived PacketBytes + ReceivedTotal PacketCount + BytesSent PacketBytes + SentTotal PacketCount +} + +func (p *PacketStats) add(pt packets.Packet, receive bool) { + b := packets.TotalBytes(pt) + var bytes *PacketBytes + var count *PacketCount + if receive { + bytes = &p.BytesReceived + count = &p.ReceivedTotal + } else { + bytes = &p.BytesSent + count = &p.SentTotal + } + switch pt.(type) { + case *packets.Auth: + atomic.AddUint64(&bytes.Auth, uint64(b)) + atomic.AddUint64(&count.Auth, 1) + case *packets.Connect: + atomic.AddUint64(&bytes.Connect, uint64(b)) + atomic.AddUint64(&count.Connect, 1) + case *packets.Connack: + atomic.AddUint64(&bytes.Connack, uint64(b)) + atomic.AddUint64(&count.Connack, 1) + case *packets.Disconnect: + atomic.AddUint64(&bytes.Disconnect, uint64(b)) + atomic.AddUint64(&count.Disconnect, 1) + case *packets.Pingreq: + atomic.AddUint64(&bytes.Pingreq, uint64(b)) + atomic.AddUint64(&count.Pingreq, 1) + case *packets.Pingresp: + atomic.AddUint64(&bytes.Pingresp, uint64(b)) + atomic.AddUint64(&count.Pingresp, 1) + case *packets.Puback: + atomic.AddUint64(&bytes.Puback, uint64(b)) + atomic.AddUint64(&count.Puback, 1) + case *packets.Pubcomp: + atomic.AddUint64(&bytes.Pubcomp, uint64(b)) + atomic.AddUint64(&count.Pubcomp, 1) + case *packets.Publish: + atomic.AddUint64(&bytes.Publish, uint64(b)) + atomic.AddUint64(&count.Publish, 1) + case *packets.Pubrec: + atomic.AddUint64(&bytes.Pubrec, uint64(b)) + atomic.AddUint64(&count.Pubrec, 1) + case *packets.Pubrel: + atomic.AddUint64(&bytes.Pubrel, uint64(b)) + atomic.AddUint64(&count.Pubrel, 1) + case *packets.Suback: + atomic.AddUint64(&bytes.Suback, uint64(b)) + atomic.AddUint64(&count.Suback, 1) + case *packets.Subscribe: + atomic.AddUint64(&bytes.Subscribe, uint64(b)) + atomic.AddUint64(&count.Subscribe, 1) + case *packets.Unsuback: + atomic.AddUint64(&bytes.Unsuback, uint64(b)) + atomic.AddUint64(&count.Unsuback, 1) + case *packets.Unsubscribe: + atomic.AddUint64(&bytes.Unsubscribe, uint64(b)) + atomic.AddUint64(&count.Unsubscribe, 1) + } + atomic.AddUint64(&bytes.Total, uint64(b)) + atomic.AddUint64(&count.Total, 1) + +} +func (p *PacketStats) copy() *PacketStats { + return &PacketStats{ + BytesReceived: p.BytesReceived.copy(), + ReceivedTotal: p.ReceivedTotal.copy(), + BytesSent: p.BytesSent.copy(), + SentTotal: p.SentTotal.copy(), + } +} + +// PacketBytes represents total bytes of each in type have been received or sent. +type PacketBytes struct { + Auth uint64 + Connect uint64 + Connack uint64 + Disconnect uint64 + Pingreq uint64 + Pingresp uint64 + Puback uint64 + Pubcomp uint64 + Publish uint64 + Pubrec uint64 + Pubrel uint64 + Suback uint64 + Subscribe uint64 + Unsuback uint64 + Unsubscribe uint64 + Total uint64 +} + +func (p *PacketBytes) copy() PacketBytes { + return PacketBytes{ + Connect: atomic.LoadUint64(&p.Connect), + Connack: atomic.LoadUint64(&p.Connack), + Disconnect: atomic.LoadUint64(&p.Disconnect), + Pingreq: atomic.LoadUint64(&p.Pingreq), + Pingresp: atomic.LoadUint64(&p.Pingresp), + Puback: atomic.LoadUint64(&p.Puback), + Pubcomp: atomic.LoadUint64(&p.Pubcomp), + Publish: atomic.LoadUint64(&p.Publish), + Pubrec: atomic.LoadUint64(&p.Pubrec), + Pubrel: atomic.LoadUint64(&p.Pubrel), + Suback: atomic.LoadUint64(&p.Suback), + Subscribe: atomic.LoadUint64(&p.Subscribe), + Unsuback: atomic.LoadUint64(&p.Unsuback), + Unsubscribe: atomic.LoadUint64(&p.Unsubscribe), + Total: atomic.LoadUint64(&p.Total), + } +} + +// PacketCount represents total number of each in type have been received or sent. +type PacketCount = PacketBytes + +// ConnectionStats provides the statistics of client connections. +type ConnectionStats struct { + ConnectedTotal uint64 + DisconnectedTotal uint64 + SessionCreatedTotal uint64 + SessionTerminated struct { + TakenOver uint64 + Expired uint64 + Normal uint64 + } + // ActiveCurrent is the number of used active session. + ActiveCurrent uint64 + // InactiveCurrent is the number of used inactive session. + InactiveCurrent uint64 +} + +func (c *ConnectionStats) copy() *ConnectionStats { + return &ConnectionStats{ + ConnectedTotal: atomic.LoadUint64(&c.ConnectedTotal), + DisconnectedTotal: atomic.LoadUint64(&c.DisconnectedTotal), + SessionCreatedTotal: atomic.LoadUint64(&c.SessionCreatedTotal), + SessionTerminated: struct { + TakenOver uint64 + Expired uint64 + Normal uint64 + }{ + TakenOver: atomic.LoadUint64(&c.SessionTerminated.TakenOver), + Expired: atomic.LoadUint64(&c.SessionTerminated.Expired), + Normal: atomic.LoadUint64(&c.SessionTerminated.Normal), + }, + ActiveCurrent: atomic.LoadUint64(&c.ActiveCurrent), + InactiveCurrent: atomic.LoadUint64(&c.InactiveCurrent), + } +} + +type DroppedTotal struct { + Internal uint64 + ExceedsMaxPacketSize uint64 + QueueFull uint64 + Expired uint64 + InflightExpired uint64 +} + +type MessageQosStats struct { + DroppedTotal DroppedTotal + ReceivedTotal uint64 + SentTotal uint64 +} + +func (m *MessageQosStats) GetDroppedTotal() uint64 { + return m.DroppedTotal.Internal + m.DroppedTotal.Expired + m.DroppedTotal.ExceedsMaxPacketSize + m.DroppedTotal.QueueFull + m.DroppedTotal.InflightExpired +} + +// MessageStats represents the statistics of PUBLISH in, separated by QOS. +type MessageStats struct { + Qos0 MessageQosStats + Qos1 MessageQosStats + Qos2 MessageQosStats + InflightCurrent uint64 + QueuedCurrent uint64 +} + +func (m *MessageStats) GetDroppedTotal() uint64 { + return m.Qos0.GetDroppedTotal() + m.Qos1.GetDroppedTotal() + m.Qos2.GetDroppedTotal() +} + +func (s *statsManager) addInflight(clientID string, delta uint64) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + sts := s.getClientStats(clientID) + atomic.AddUint64(&sts.MessageStats.InflightCurrent, delta) + atomic.AddUint64(&s.totalStats.MessageStats.InflightCurrent, 1) +} +func (s *statsManager) decInflight(clientID string, delta uint64) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + sts := s.getClientStats(clientID) + // Avoid the counter to be negative. + // This could happen if the broker is start with persistence data loaded and send messages from the persistent queue. + // Because the statistic data is not persistent, the init value is always 0. + if atomic.LoadUint64(&sts.MessageStats.InflightCurrent) == 0 { + return + } + atomic.AddUint64(&sts.MessageStats.InflightCurrent, ^uint64(delta-1)) + atomic.AddUint64(&s.totalStats.MessageStats.InflightCurrent, ^uint64(delta-1)) +} + +func (s *statsManager) addQueueLen(clientID string, delta uint64) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + sts := s.getClientStats(clientID) + atomic.AddUint64(&sts.MessageStats.QueuedCurrent, delta) + atomic.AddUint64(&s.totalStats.MessageStats.QueuedCurrent, delta) +} +func (s *statsManager) decQueueLen(clientID string, delta uint64) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + sts := s.getClientStats(clientID) + // Avoid the counter to be negative. + // This could happen if the broker is start with persistence data loaded and send messages from the persistent queue. + // Because the statistic data is not persistent, the init value is always 0. + if atomic.LoadUint64(&sts.MessageStats.QueuedCurrent) == 0 { + return + } + atomic.AddUint64(&sts.MessageStats.QueuedCurrent, ^uint64(delta-1)) + atomic.AddUint64(&s.totalStats.MessageStats.QueuedCurrent, ^uint64(delta-1)) +} + +func (m *MessageStats) copy() *MessageStats { + return &MessageStats{ + Qos0: MessageQosStats{ + DroppedTotal: DroppedTotal{ + Internal: atomic.LoadUint64(&m.Qos0.DroppedTotal.Internal), + ExceedsMaxPacketSize: atomic.LoadUint64(&m.Qos0.DroppedTotal.ExceedsMaxPacketSize), + QueueFull: atomic.LoadUint64(&m.Qos0.DroppedTotal.QueueFull), + Expired: atomic.LoadUint64(&m.Qos0.DroppedTotal.Expired), + InflightExpired: atomic.LoadUint64(&m.Qos0.DroppedTotal.InflightExpired), + }, + ReceivedTotal: atomic.LoadUint64(&m.Qos0.ReceivedTotal), + SentTotal: atomic.LoadUint64(&m.Qos0.SentTotal), + }, + Qos1: MessageQosStats{ + DroppedTotal: DroppedTotal{ + Internal: atomic.LoadUint64(&m.Qos1.DroppedTotal.Internal), + ExceedsMaxPacketSize: atomic.LoadUint64(&m.Qos1.DroppedTotal.ExceedsMaxPacketSize), + QueueFull: atomic.LoadUint64(&m.Qos1.DroppedTotal.QueueFull), + Expired: atomic.LoadUint64(&m.Qos1.DroppedTotal.Expired), + InflightExpired: atomic.LoadUint64(&m.Qos1.DroppedTotal.InflightExpired), + }, + ReceivedTotal: atomic.LoadUint64(&m.Qos1.ReceivedTotal), + SentTotal: atomic.LoadUint64(&m.Qos1.SentTotal), + }, + Qos2: MessageQosStats{ + DroppedTotal: DroppedTotal{ + Internal: atomic.LoadUint64(&m.Qos2.DroppedTotal.Internal), + ExceedsMaxPacketSize: atomic.LoadUint64(&m.Qos2.DroppedTotal.ExceedsMaxPacketSize), + QueueFull: atomic.LoadUint64(&m.Qos2.DroppedTotal.QueueFull), + Expired: atomic.LoadUint64(&m.Qos2.DroppedTotal.Expired), + InflightExpired: atomic.LoadUint64(&m.Qos2.DroppedTotal.InflightExpired), + }, + ReceivedTotal: atomic.LoadUint64(&m.Qos2.ReceivedTotal), + SentTotal: atomic.LoadUint64(&m.Qos2.SentTotal), + }, + InflightCurrent: atomic.LoadUint64(&m.InflightCurrent), + QueuedCurrent: atomic.LoadUint64(&m.QueuedCurrent), + } +} + +// GlobalStats is the collection of global statistics. +type GlobalStats struct { + ConnectionStats ConnectionStats + PacketStats PacketStats + MessageStats MessageStats + SubscriptionStats subscription.Stats +} + +// ClientStats is the statistic information of one client. +type ClientStats struct { + PacketStats PacketStats + MessageStats MessageStats + SubscriptionStats subscription.Stats +} + +func (c ClientStats) GetDroppedTotal() uint64 { + return c.MessageStats.Qos0.GetDroppedTotal() + c.MessageStats.Qos1.GetDroppedTotal() + c.MessageStats.Qos2.GetDroppedTotal() +} + +// GetGlobalStats returns the GlobalStats +func (s *statsManager) GetGlobalStats() GlobalStats { + return GlobalStats{ + PacketStats: *s.totalStats.PacketStats.copy(), + ConnectionStats: *s.totalStats.ConnectionStats.copy(), + MessageStats: *s.totalStats.MessageStats.copy(), + SubscriptionStats: s.subStatsReader.GetStats(), + } +} + +// GetClientStats returns the client statistic information for given client id. +func (s *statsManager) GetClientStats(clientID string) (ClientStats, bool) { + s.clientMu.Lock() + defer s.clientMu.Unlock() + if stats := s.clientStats[clientID]; stats == nil { + return ClientStats{}, false + } else { + s, _ := s.subStatsReader.GetClientStats(clientID) + return ClientStats{ + PacketStats: *stats.PacketStats.copy(), + MessageStats: *stats.MessageStats.copy(), + SubscriptionStats: s, + }, true + } + +} + +func newStatsManager(subStatsReader subscription.StatsReader) *statsManager { + return &statsManager{ + subStatsReader: subStatsReader, + totalStats: &GlobalStats{}, + clientMu: sync.Mutex{}, + clientStats: make(map[string]*ClientStats), + } +} diff --git a/internal/hummingbird/mqttbroker/server/stats_mock.go b/internal/hummingbird/mqttbroker/server/stats_mock.go new file mode 100644 index 0000000..95dcaac --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/stats_mock.go @@ -0,0 +1,63 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server/stats.go + +// Package server is a generated GoMock package. +package server + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockStatsReader is a mock of StatsReader interface +type MockStatsReader struct { + ctrl *gomock.Controller + recorder *MockStatsReaderMockRecorder +} + +// MockStatsReaderMockRecorder is the mock recorder for MockStatsReader +type MockStatsReaderMockRecorder struct { + mock *MockStatsReader +} + +// NewMockStatsReader creates a new mock instance +func NewMockStatsReader(ctrl *gomock.Controller) *MockStatsReader { + mock := &MockStatsReader{ctrl: ctrl} + mock.recorder = &MockStatsReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStatsReader) EXPECT() *MockStatsReaderMockRecorder { + return m.recorder +} + +// GetGlobalStats mocks base method +func (m *MockStatsReader) GetGlobalStats() GlobalStats { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGlobalStats") + ret0, _ := ret[0].(GlobalStats) + return ret0 +} + +// GetGlobalStats indicates an expected call of GetGlobalStats +func (mr *MockStatsReaderMockRecorder) GetGlobalStats() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalStats", reflect.TypeOf((*MockStatsReader)(nil).GetGlobalStats)) +} + +// GetClientStats mocks base method +func (m *MockStatsReader) GetClientStats(clientID string) (ClientStats, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientStats", clientID) + ret0, _ := ret[0].(ClientStats) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetClientStats indicates an expected call of GetClientStats +func (mr *MockStatsReaderMockRecorder) GetClientStats(clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientStats", reflect.TypeOf((*MockStatsReader)(nil).GetClientStats), clientID) +} diff --git a/internal/hummingbird/mqttbroker/server/testdata/ca.pem b/internal/hummingbird/mqttbroker/server/testdata/ca.pem new file mode 100644 index 0000000..9b51f66 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/testdata/ca.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC0zCCAbsCFHNchtacwmLUKBWHbVAe0M+2H+tEMA0GCSqGSIb3DQEBCwUAMCUx +CzAJBgNVBAYTAkNOMRYwFAYDVQQDDA1kcm1hZ2ljLmxvY2FsMCAXDTIxMDEyNDEz +NDUxOVoYDzIxMjAxMjMxMTM0NTE5WjAlMQswCQYDVQQGEwJDTjEWMBQGA1UEAwwN +ZHJtYWdpYy5sb2NhbDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAL3v +sYOylxpCCNWLyLOjL+smnZgFsbt7PL9wxOJOgTFVesVV/mRnlydn9Ism9ERCIBHF +yfsX6lnOKkqGisoTt5DuBphwJeZjSJYOjTIgQdcVbLyvspPN1+no2qAO/jv1Fsg6 +WXmq6lkEc1LPE+fVlQG9pl6ypBdrrCzGKFtfEI+B3nuDIlLhzt2avZ4RmaFZjQJW +WtoWHN56ujoZLzUVv+tjc/wgmMTCA36TYS5jBWXOPfqvg0hYRBysqBZACu2jZ/R0 +qeCvLwJemUlECpiNbEn2w9ApEKlyM58ArXlVLixkYZxEVa+Ai6q7aYRpNvg3gKZw +R/9zWPn/0t8u4Z7GwykCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAESviZdpHauCg +2ir8kn314rMK9QrK/nt60z+Cd5FkaFKiHQUuD+obXzri5R2qzHNZLJdOmpzaI+1e +tGHJ1jh0J1ShMDDr9qA/CknBM3r/dzDHneNb8B0xFxOABI5vcywG/xM8Dv/dBIuF +PuWvjvw7EJI3i6Vy2tR885ksDB/ucoNSpWevdXDJdoUxA88vNgt1nMzMy4+IGBYf +TwqK++T3V5DCGQj+24eYgiShHAIchbgoB8F+Suvseo8kd9kFsoORToqmoTM1J4JM +K4u8Xvh5sRnbo7ViwcKPAD3fI14z0mqFCObllp6ynib6WAAGztW2F4khyY4WTHH9 +MNE9/UMb6Q== +-----END CERTIFICATE----- diff --git a/internal/hummingbird/mqttbroker/server/testdata/extfile.cnf b/internal/hummingbird/mqttbroker/server/testdata/extfile.cnf new file mode 100644 index 0000000..63a1d15 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/testdata/extfile.cnf @@ -0,0 +1,2 @@ +subjectAltName = DNS:drmagic.local,IP:127.0.0.1 +extendedKeyUsage = serverAuth diff --git a/internal/hummingbird/mqttbroker/server/testdata/openssl.conf b/internal/hummingbird/mqttbroker/server/testdata/openssl.conf new file mode 100644 index 0000000..132b248 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/testdata/openssl.conf @@ -0,0 +1,7 @@ +[req] +distinguished_name = req_distinguished_name +prompt = no + +[req_distinguished_name] +C = CN +CN = drmagic.local \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/server/testdata/server-cert.pem b/internal/hummingbird/mqttbroker/server/testdata/server-cert.pem new file mode 100644 index 0000000..f3a65c0 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/testdata/server-cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBDCCAeygAwIBAgIUFxnmuzcIANni8M9czNSrB3dPRSIwDQYJKoZIhvcNAQEL +BQAwJTELMAkGA1UEBhMCQ04xFjAUBgNVBAMMDWRybWFnaWMubG9jYWwwIBcNMjEw +MTI0MTM0NTE5WhgPMjEyMDEyMzExMzQ1MTlaMBgxFjAUBgNVBAMMDWRybWFnaWMu +bG9jYWwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDW37XX6xuVLB4e +KEJfCJ53bb4hJYE4bLidC0a47pUoLQ+eQF0Dp6vF+d79d35Vq9WJJ2AceZO8s1Zf +41aGWlioFk8M6TjFvIxEQMdeDkmsnsCGjNBfWHgESv3OZm2pIF2Hww9aNIz2f24K +QBcRYc7wYKYYyfBscMa+aPg7PkaK9FUU/eV+QJHLMElYEwa/vicJ1qeNjlyBczvY +ckDn2TQdk6eOkwSn6ZlsjGlK0B5cb3dL3UCUBeZjLvEOiB4oXXiBuHCnjhrjoJxL +SbZ9MPxQgN/Np34PBjR7dXxIKnLJzk8E+fD7o+yNcUo9Y4JPiJXom6isiwWSIVSa +fP23DGxhAgMBAAGjNzA1MB4GA1UdEQQXMBWCDWRybWFnaWMubG9jYWyHBH8AAAEw +EwYDVR0lBAwwCgYIKwYBBQUHAwEwDQYJKoZIhvcNAQELBQADggEBALM8lVW82KRr +XVh829urMs6emCjeYhdqHFk8QyX48IprOTBFmTrVFNfD8zcX6NlhsPxPFjsWy5ND +E7T0qROQ0x4/9oe8Hr6+wm1qXfSD02aBop+67WBBUFI2bGm44ZSCeEL/1GACaZry +h+knJAQQp5+mHszQDz2XaqzUOE6tfa9guRUHo9GVO9oIJdP/DjaT9XpsNdHczZdD +1H4Yweit61JaiizA1nMJ1LT0mq8P780InbTgj1r/WgfVhlO1CZ6L3IxGLEUxHg3y +TFRG6Z82rrxi1DA20NZPeB3nDTS7IeIEDpYIn88olTLStiKN6du805nsZ/5+clNs +Fn4N5RMn2vc= +-----END CERTIFICATE----- diff --git a/internal/hummingbird/mqttbroker/server/testdata/server-key.pem b/internal/hummingbird/mqttbroker/server/testdata/server-key.pem new file mode 100644 index 0000000..f019f26 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/testdata/server-key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA1t+11+sblSweHihCXwied22+ISWBOGy4nQtGuO6VKC0PnkBd +A6erxfne/Xd+VavViSdgHHmTvLNWX+NWhlpYqBZPDOk4xbyMREDHXg5JrJ7AhozQ +X1h4BEr9zmZtqSBdh8MPWjSM9n9uCkAXEWHO8GCmGMnwbHDGvmj4Oz5GivRVFP3l +fkCRyzBJWBMGv74nCdanjY5cgXM72HJA59k0HZOnjpMEp+mZbIxpStAeXG93S91A +lAXmYy7xDogeKF14gbhwp44a46CcS0m2fTD8UIDfzad+DwY0e3V8SCpyyc5PBPnw ++6PsjXFKPWOCT4iV6JuorIsFkiFUmnz9twxsYQIDAQABAoIBAFsPx8LPsorPfZwO +N8KKpo26hn8Jo+/Ds6Fqa/hns/Ko1huc705jOprWQDhu8a1g+0f61fJ7W6722b4d +XEfn9faWLb4tAJBcTZ2HTnZ/2506UiEzgANIPOSk21cjdYndW4XzlogGCU9Vxc62 +RpBpQQgCDaInwqpSSQfc+IYy6DZuekPCbm3hBhEF9grY9j2/QrVHKjbx9rGA6jD3 +31FjL1SqGrkgduef2EW9geoduVSGnEyYU0CoVZb/es51c/5rAzv91eC21egu0OPq +XPFtDM5Gz/4iC9wQ7k2EDF5LiKR49DKmAJM4FSRZKCqDYY7NtFHILXC9S5rMeuMQ +1mVnknkCgYEA/ntEJyIc8h/QzKuLi5ydMZzIXmF6aKW1Gz0ZwMKKDhWMiNLyp+Kc +N2RXdTlmyZNkzcnm0/SmUyaoC5o9NJIVg7HfdyYhgsN/MBHJ9Q+ISHbAogcVugk0 +3CRZ9c3kyVkqiLJeY194/rpI/S7m+/VkeyNqcwBedb2/CDrJP0q0Rk8CgYEA2Cfw +/XVYG3TqziDqWonZtINc275yP5ecw4N9qXIjuoqH6L7N/MzC52QfkpJFoO+bcGxe +umg4mjFA67RCCpFLY2jhh4nS4bAwh3bM+EXgnL1rrgPAz4ZasnsbXEcivP+7a4SZ +pKRT/20CUjimeykhZxzAZcvuENCkGA+WmpWVJk8CgYEA+hIFkfMKwL+U/lsgoMwB +CKzJlT1y/XzA8IhlUy+YXGi+lgG9ZE7iNeiLrO0AXdtSdosOIoDKJPHatrQVqyBW +tfhH4Rz+VzJnPMRuUju2L4dKmq4dopfDcwThxhNS3K2bh4LIEBzUmHRUnz/EyhmF +aSAPTf0x1b/lBmBGPMTbTC8CgYAJmDRFO9kuVtE5VxKv9CB6t73+bwSpN/SYZRTF +2bAmTpHbzeRczUX1eWdBXUbD7v7KTbUitw+UII2OKNEpoOtkvToNhxuaMvTkfmx4 +tLlUm7/U2IvNalxKQdakEPBEzWEnU5pySW0FEHSi66rQGrJF3mvX2OZ3TpuKCd8Y +e31EVwKBgDWP9poeeli+0+EqbsZtP7lGN63yp4XYMYim9zKojmAu6tuAUkUCbIzn +5Jdeoe6I7LgTpnahC+c+a8/4JuPKrYeWV/Tf9R+zatGyYzZ6W4qzlpYJaE6zFAJq +b+mx/By39E+WH3bRQBhcsp1hmxclFhd4KWLdqm5+Zsycu7JvUSjY +-----END RSA PRIVATE KEY----- diff --git a/internal/hummingbird/mqttbroker/server/testdata/test_gen.sh b/internal/hummingbird/mqttbroker/server/testdata/test_gen.sh new file mode 100755 index 0000000..c7d1949 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/testdata/test_gen.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# ca key +openssl genrsa -out ca-key.pem 2048 +# ca certificate +openssl req -new -x509 -days 36500 -key ca-key.pem -out ca.pem -config openssl.conf +# server key +openssl genrsa -out server-key.pem 2048 +# server csr +openssl req -subj "/CN=drmagic.local" -new -key server-key.pem -out server.csr +# sign the public key with our CA +openssl x509 -req -days 36500 -in server.csr -CA ca.pem -CAkey ca-key.pem \ + -CAcreateserial -out server-cert.pem -extfile extfile.cnf + +rm ./ca.srl ./ca-key.pem ./server.csr \ No newline at end of file diff --git a/internal/hummingbird/mqttbroker/server/topic_alias.go b/internal/hummingbird/mqttbroker/server/topic_alias.go new file mode 100644 index 0000000..84ae5ba --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/topic_alias.go @@ -0,0 +1,20 @@ +package server + +import ( + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +type NewTopicAliasManager func(config config.Config, maxAlias uint16, clientID string) TopicAliasManager + +// TopicAliasManager manage the topic alias for a V5 client. +// see topicalias/fifo for more details. +type TopicAliasManager interface { + // Check return the alias number and whether the alias exist. + // For examples: + // If the Publish alias exist and the manager decides to use the alias, it return the alias number and true. + // If the Publish alias exist, but the manager decides not to use alias, it return 0 and true. + // If the Publish alias not exist and the manager decides to assign a new alias, it return the new alias and false. + // If the Publish alias not exist, but the manager decides not to assign alias, it return the 0 and false. + Check(publish *packets.Publish) (alias uint16, exist bool) +} diff --git a/internal/hummingbird/mqttbroker/server/topic_alias_mock.go b/internal/hummingbird/mqttbroker/server/topic_alias_mock.go new file mode 100644 index 0000000..e8280b3 --- /dev/null +++ b/internal/hummingbird/mqttbroker/server/topic_alias_mock.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: server/topic_alias.go + +// Package server is a generated GoMock package. +package server + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + packets "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// MockTopicAliasManager is a mock of TopicAliasManager interface +type MockTopicAliasManager struct { + ctrl *gomock.Controller + recorder *MockTopicAliasManagerMockRecorder +} + +// MockTopicAliasManagerMockRecorder is the mock recorder for MockTopicAliasManager +type MockTopicAliasManagerMockRecorder struct { + mock *MockTopicAliasManager +} + +// NewMockTopicAliasManager creates a new mock instance +func NewMockTopicAliasManager(ctrl *gomock.Controller) *MockTopicAliasManager { + mock := &MockTopicAliasManager{ctrl: ctrl} + mock.recorder = &MockTopicAliasManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockTopicAliasManager) EXPECT() *MockTopicAliasManagerMockRecorder { + return m.recorder +} + +// Check mocks base method +func (m *MockTopicAliasManager) Check(publish *packets.Publish) (uint16, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Check", publish) + ret0, _ := ret[0].(uint16) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Check indicates an expected call of Check +func (mr *MockTopicAliasManagerMockRecorder) Check(publish interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockTopicAliasManager)(nil).Check), publish) +} diff --git a/internal/hummingbird/mqttbroker/session.go b/internal/hummingbird/mqttbroker/session.go new file mode 100644 index 0000000..e55f9d4 --- /dev/null +++ b/internal/hummingbird/mqttbroker/session.go @@ -0,0 +1,24 @@ +package mqttbroker + +import ( + "time" +) + +// Session represents a MQTT session. +type Session struct { + // ClientID represents the client id. + ClientID string + // Will is the will message of the client, can be nil if there is no will message. + Will *Message + // WillDelayInterval represents the Will Delay Interval in seconds + WillDelayInterval uint32 + // ConnectedAt is the session create time. + ConnectedAt time.Time + // ExpiryInterval represents the Session Expiry Interval in seconds + ExpiryInterval uint32 +} + +// IsExpired return whether the session is expired +func (s *Session) IsExpired(now time.Time) bool { + return s.ConnectedAt.Add(time.Duration(s.ExpiryInterval) * time.Second).Before(now) +} diff --git a/internal/hummingbird/mqttbroker/subscription.go b/internal/hummingbird/mqttbroker/subscription.go new file mode 100644 index 0000000..62a50d7 --- /dev/null +++ b/internal/hummingbird/mqttbroker/subscription.go @@ -0,0 +1,65 @@ +package mqttbroker + +import ( + "errors" + + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +// Subscription represents a subscription in mqttbroker. +type Subscription struct { + // ShareName is the share name of a shared subscription. + // set to "" if it is a non-shared subscription. + ShareName string + // TopicFilter is the topic filter which does not include the share name. + TopicFilter string + // ID is the subscription identifier + ID uint32 + // The following fields are Subscription Options. + // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901169 + + // QoS is the qos level of the Subscription. + QoS packets.QoS + // NoLocal is the No Local option. + NoLocal bool + // RetainAsPublished is the Retain As Published option. + RetainAsPublished bool + // RetainHandling the Retain Handling option. + RetainHandling byte +} + +// GetFullTopicName returns the full topic name of the subscription. +func (s *Subscription) GetFullTopicName() string { + if s.ShareName != "" { + return "$share/" + s.ShareName + "/" + s.TopicFilter + } + return s.TopicFilter +} + +// Copy makes a copy of subscription. +func (s *Subscription) Copy() *Subscription { + return &Subscription{ + ShareName: s.ShareName, + TopicFilter: s.TopicFilter, + ID: s.ID, + QoS: s.QoS, + NoLocal: s.NoLocal, + RetainAsPublished: s.RetainAsPublished, + RetainHandling: s.RetainHandling, + } +} + +// Validate returns whether the subscription is valid. +// If you can ensure the subscription is valid then just skip the validation. +func (s *Subscription) Validate() error { + if !packets.ValidV5Topic([]byte(s.GetFullTopicName())) { + return errors.New("invalid topic name") + } + if s.QoS > 2 { + return errors.New("invalid qos") + } + if s.RetainHandling != 0 && s.RetainHandling != 1 && s.RetainHandling != 2 { + return errors.New("invalid retain handling") + } + return nil +} diff --git a/internal/hummingbird/mqttbroker/topicalias/fifo/fifo.go b/internal/hummingbird/mqttbroker/topicalias/fifo/fifo.go new file mode 100644 index 0000000..dda43f5 --- /dev/null +++ b/internal/hummingbird/mqttbroker/topicalias/fifo/fifo.go @@ -0,0 +1,68 @@ +package fifo + +import ( + "container/list" + + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/config" + "github.com/winc-link/hummingbird/internal/hummingbird/mqttbroker/server" + "github.com/winc-link/hummingbird/internal/pkg/packets" +) + +var _ server.TopicAliasManager = (*Queue)(nil) + +func init() { + server.RegisterTopicAliasMgrFactory("fifo", New) +} + +// New is the constructor of Queue. +func New(config config.Config, maxAlias uint16, clientID string) server.TopicAliasManager { + return &Queue{ + clientID: clientID, + topicAlias: &topicAlias{ + max: int(maxAlias), + alias: list.New(), + index: make(map[string]uint16), + }, + } +} + +// Queue is the fifo queue which store all topic alias for one client +type Queue struct { + clientID string + topicAlias *topicAlias +} +type topicAlias struct { + max int + alias *list.List + // topic name => alias + index map[string]uint16 +} +type aliasElem struct { + topic string + alias uint16 +} + +func (q *Queue) Check(publish *packets.Publish) (alias uint16, exist bool) { + topicName := string(publish.TopicName) + // alias exist + if a, ok := q.topicAlias.index[topicName]; ok { + return a, true + } + l := q.topicAlias.alias.Len() + // alias has been exhausted + if l == q.topicAlias.max { + first := q.topicAlias.alias.Front() + elem := first.Value.(*aliasElem) + q.topicAlias.alias.Remove(first) + delete(q.topicAlias.index, elem.topic) + alias = elem.alias + } else { + alias = uint16(l + 1) + } + q.topicAlias.alias.PushBack(&aliasElem{ + topic: topicName, + alias: alias, + }) + q.topicAlias.index[topicName] = alias + return +}