From f3f1b8d981d818b38713fe84deb206720b0fcb10 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Wed, 26 Feb 2020 16:30:10 +0100 Subject: Room autorejoin --- account.go | 65 +++++++++++++++++++++++++++++++++++++++++++++++++- connector/xmpp/xmpp.go | 3 ++- db.go | 16 +++++++++++++ main.go | 22 +++-------------- server.go | 12 ++-------- 5 files changed, 87 insertions(+), 31 deletions(-) diff --git a/account.go b/account.go index 8da6d44..a76b782 100644 --- a/account.go +++ b/account.go @@ -3,6 +3,7 @@ package main import ( "fmt" "strings" + "sync" log "github.com/sirupsen/logrus" @@ -18,17 +19,23 @@ type Account struct { JoinedRooms map[RoomID]bool } +var accountsLock sync.Mutex var registeredAccounts = map[string]map[string]*Account{} func AddAccount(a *Account) { + accountsLock.Lock() + defer accountsLock.Unlock() + if _, ok := registeredAccounts[a.MatrixUser]; !ok { registeredAccounts[a.MatrixUser] = make(map[string]*Account) } registeredAccounts[a.MatrixUser][a.AccountName] = a - ezbrSystemSendf(a.MatrixUser, "Connecting to account %s (%s)", a.AccountName, a.Protocol) } func FindAccount(mxUser string, name string) *Account { + accountsLock.Lock() + defer accountsLock.Unlock() + if u, ok := registeredAccounts[mxUser]; ok { if a, ok := u[name]; ok { return a @@ -38,6 +45,9 @@ func FindAccount(mxUser string, name string) *Account { } func FindJoinedAccount(mxUser string, protocol string, room RoomID) *Account { + accountsLock.Lock() + defer accountsLock.Unlock() + if u, ok := registeredAccounts[mxUser]; ok { for _, acct := range u { if acct.Protocol == protocol { @@ -51,17 +61,55 @@ func FindJoinedAccount(mxUser string, protocol string, room RoomID) *Account { } func RemoveAccount(mxUser string, name string) { + accountsLock.Lock() + defer accountsLock.Unlock() + if u, ok := registeredAccounts[mxUser]; ok { delete(u, name) } } +// ---- + func (a *Account) ezbrMessagef(format string, args ...interface{}) { msg := fmt.Sprintf(format, args...) msg = fmt.Sprintf("%s: %s", a.Protocol, msg) ezbrSystemSend(a.MatrixUser, msg) } +func (a *Account) connect(config map[string]string, join_rooms []string) { + ezbrSystemSendf(a.MatrixUser, "Connecting to account %s (%s)", a.AccountName, a.Protocol) + + err := a.Conn.Configure(config) + if err != nil { + ezbrSystemSendf(a.MatrixUser, "%s (%s) cannot connect: %s", a.AccountName, a.Protocol, err.Error()) + return + } + + for _, room := range join_rooms { + var entry DbJoinedRoom + db.Where(&DbJoinedRoom{ + MxUserID: a.MatrixUser, + Protocol: a.Protocol, + AccountName: a.AccountName, + RoomID: RoomID(room), + }).FirstOrCreate(&entry) + } + + var autojoin []DbJoinedRoom + db.Where(&DbJoinedRoom{ + MxUserID: a.MatrixUser, + Protocol: a.Protocol, + AccountName: a.AccountName, + }).Find(&autojoin) + for _, aj := range autojoin { + err := a.Conn.Join(aj.RoomID) + if err != nil { + ezbrSystemSendf(a.MatrixUser, "%s (%s) cannot join %s: %s", a.AccountName, a.Protocol, aj.RoomID, err.Error()) + } + } +} + // ---- Begin event handlers ---- func (a *Account) Joined(roomId RoomID) { @@ -69,6 +117,14 @@ func (a *Account) Joined(roomId RoomID) { if err != nil { a.ezbrMessagef("Dropping Account.Joined %s: %s", roomId, err.Error()) } + + var entry DbJoinedRoom + db.Where(&DbJoinedRoom{ + MxUserID: a.MatrixUser, + Protocol: a.Protocol, + AccountName: a.AccountName, + RoomID: roomId, + }).FirstOrCreate(&entry) } func (a *Account) joinedInternal(roomId RoomID) error { @@ -95,6 +151,13 @@ func (a *Account) Left(roomId RoomID) { if err != nil { a.ezbrMessagef("Dropping Account.Left %s: %s", roomId, err.Error()) } + + db.Where(&DbJoinedRoom{ + MxUserID: a.MatrixUser, + Protocol: a.Protocol, + AccountName: a.AccountName, + RoomID: roomId, + }).Delete(&DbJoinedRoom{}) } func (a *Account) leftInternal(roomId RoomID) error { diff --git a/connector/xmpp/xmpp.go b/connector/xmpp/xmpp.go index 7e35135..727c4d8 100644 --- a/connector/xmpp/xmpp.go +++ b/connector/xmpp/xmpp.go @@ -81,7 +81,8 @@ func (xm *XMPP) Configure(c Configuration) error { return fmt.Errorf("JID %s not on server %s", xm.jid, xm.server) } xm.jid_localpart = jid_parts[0] - xm.nickname = xm.jid_localpart + + xm.nickname = c.GetString("nickname", xm.jid_locakpart) xm.password, err = c.GetString("password") if err != nil { diff --git a/db.go b/db.go index fe3d1e3..602bd1f 100644 --- a/db.go +++ b/db.go @@ -36,6 +36,9 @@ func InitDb() error { db.AutoMigrate(&DbPmRoomMap{}) db.Model(&DbPmRoomMap{}).AddIndex("idx_protocol_user_account_user", "protocol", "user_id", "mx_user_id", "account_name") + db.AutoMigrate(&DbJoinedRoom{}) + db.Model(&DbJoinedRoom{}).AddIndex("idx_user_protocol_account", "mx_user_id", "protocol", "account_name") + return nil } @@ -86,6 +89,19 @@ type DbPmRoomMap struct { MxRoomID string `gorm:"index:mxroomoid"` } +// List of joined channels to be re-joined on reconnect +type DbJoinedRoom struct { + gorm.Model + + // User id and account name + MxUserID string + Protocol string + AccountName string + + // Room ID + RoomID connector.RoomID +} + // ---- Simple locking mechanism var dbLocks [256]sync.Mutex diff --git a/main.go b/main.go index d74e3d8..3cec532 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,7 @@ type ConfigAccount struct { } type ConfigFile struct { - HttpBindAddr string `json:"http_bind_addr"` + ASBindAddr string `json:"appservice_bind_addr"` Registration string `json:"registration"` Server string `json:"homeserver_url"` DbType string `json:"db_type"` @@ -45,7 +45,7 @@ var registration *mxlib.Registration func readConfig() ConfigFile { config_file := ConfigFile{ - HttpBindAddr: "0.0.0.0:8321", + ASBindAddr: "0.0.0.0:8321", Registration: "./registration.yaml", Server: "http://localhost:8008", DbType: "sqlite3", @@ -192,7 +192,7 @@ func main() { } conn.SetHandler(account) AddAccount(account) - go connectAndJoin(account, params) + go account.connect(params.Config, params.Rooms) } } @@ -201,19 +201,3 @@ func main() { log.Fatal(err) } } - -func connectAndJoin(account *Account, params ConfigAccount) { - log.Printf("Connecting to %s", params.Protocol) - err := account.Conn.Configure(params.Config) - if err != nil { - log.Printf("Could not connect to %s: %s", params.Protocol, err) - } else { - log.Printf("Connected to %s, now joining %#v", params.Protocol, params.Rooms) - for _, room := range params.Rooms { - err := account.Conn.Join(connector.RoomID(room)) - if err != nil { - log.Printf("Could not join %s: %s", room, err) - } - } - } -} diff --git a/server.go b/server.go index 10721be..84a1e85 100644 --- a/server.go +++ b/server.go @@ -13,14 +13,6 @@ import ( "git.deuxfleurs.fr/Deuxfleurs/easybridge/mxlib" ) -type Config struct { - HttpBindAddr string - Server string - DbType string - DbPath string - MatrixDomain string -} - var mx *mxlib.Client func StartAppService() (chan error, error) { @@ -55,8 +47,8 @@ func StartAppService() (chan error, error) { errch := make(chan error) go func() { - log.Printf("Starting HTTP server on %s", config.HttpBindAddr) - err := http.ListenAndServe(config.HttpBindAddr, checkTokenAndLog(router)) + log.Printf("Starting HTTP server on %s", config.ASBindAddr) + err := http.ListenAndServe(config.ASBindAddr, checkTokenAndLog(router)) if err != nil { errch <- err } -- cgit v1.2.3