aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex@adnab.me>2020-02-26 16:30:10 +0100
committerAlex Auvolat <alex@adnab.me>2020-02-26 16:30:10 +0100
commitf3f1b8d981d818b38713fe84deb206720b0fcb10 (patch)
tree7fd0f5d3e71b3553a4c57db71cb01d2426075967
parent67c7f7361d63a282788f159494a6f43172c8806a (diff)
downloadeasybridge-f3f1b8d981d818b38713fe84deb206720b0fcb10.tar.gz
easybridge-f3f1b8d981d818b38713fe84deb206720b0fcb10.zip
Room autorejoin
-rw-r--r--account.go65
-rw-r--r--connector/xmpp/xmpp.go3
-rw-r--r--db.go16
-rw-r--r--main.go22
-rw-r--r--server.go12
5 files changed, 87 insertions, 31 deletions
diff --git a/account.go b/account.go
index 8da6d44..a76b782 100644
--- a/account.go
+++ b/account.go
@@ -3,6 +3,7 @@ package main
import (
"fmt"
"strings"
+ "sync"
log "github.com/sirupsen/logrus"
@@ -18,17 +19,23 @@ type Account struct {
JoinedRooms map[RoomID]bool
}
+var accountsLock sync.Mutex
var registeredAccounts = map[string]map[string]*Account{}
func AddAccount(a *Account) {
+ accountsLock.Lock()
+ defer accountsLock.Unlock()
+
if _, ok := registeredAccounts[a.MatrixUser]; !ok {
registeredAccounts[a.MatrixUser] = make(map[string]*Account)
}
registeredAccounts[a.MatrixUser][a.AccountName] = a
- ezbrSystemSendf(a.MatrixUser, "Connecting to account %s (%s)", a.AccountName, a.Protocol)
}
func FindAccount(mxUser string, name string) *Account {
+ accountsLock.Lock()
+ defer accountsLock.Unlock()
+
if u, ok := registeredAccounts[mxUser]; ok {
if a, ok := u[name]; ok {
return a
@@ -38,6 +45,9 @@ func FindAccount(mxUser string, name string) *Account {
}
func FindJoinedAccount(mxUser string, protocol string, room RoomID) *Account {
+ accountsLock.Lock()
+ defer accountsLock.Unlock()
+
if u, ok := registeredAccounts[mxUser]; ok {
for _, acct := range u {
if acct.Protocol == protocol {
@@ -51,17 +61,55 @@ func FindJoinedAccount(mxUser string, protocol string, room RoomID) *Account {
}
func RemoveAccount(mxUser string, name string) {
+ accountsLock.Lock()
+ defer accountsLock.Unlock()
+
if u, ok := registeredAccounts[mxUser]; ok {
delete(u, name)
}
}
+// ----
+
func (a *Account) ezbrMessagef(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
msg = fmt.Sprintf("%s: %s", a.Protocol, msg)
ezbrSystemSend(a.MatrixUser, msg)
}
+func (a *Account) connect(config map[string]string, join_rooms []string) {
+ ezbrSystemSendf(a.MatrixUser, "Connecting to account %s (%s)", a.AccountName, a.Protocol)
+
+ err := a.Conn.Configure(config)
+ if err != nil {
+ ezbrSystemSendf(a.MatrixUser, "%s (%s) cannot connect: %s", a.AccountName, a.Protocol, err.Error())
+ return
+ }
+
+ for _, room := range join_rooms {
+ var entry DbJoinedRoom
+ db.Where(&DbJoinedRoom{
+ MxUserID: a.MatrixUser,
+ Protocol: a.Protocol,
+ AccountName: a.AccountName,
+ RoomID: RoomID(room),
+ }).FirstOrCreate(&entry)
+ }
+
+ var autojoin []DbJoinedRoom
+ db.Where(&DbJoinedRoom{
+ MxUserID: a.MatrixUser,
+ Protocol: a.Protocol,
+ AccountName: a.AccountName,
+ }).Find(&autojoin)
+ for _, aj := range autojoin {
+ err := a.Conn.Join(aj.RoomID)
+ if err != nil {
+ ezbrSystemSendf(a.MatrixUser, "%s (%s) cannot join %s: %s", a.AccountName, a.Protocol, aj.RoomID, err.Error())
+ }
+ }
+}
+
// ---- Begin event handlers ----
func (a *Account) Joined(roomId RoomID) {
@@ -69,6 +117,14 @@ func (a *Account) Joined(roomId RoomID) {
if err != nil {
a.ezbrMessagef("Dropping Account.Joined %s: %s", roomId, err.Error())
}
+
+ var entry DbJoinedRoom
+ db.Where(&DbJoinedRoom{
+ MxUserID: a.MatrixUser,
+ Protocol: a.Protocol,
+ AccountName: a.AccountName,
+ RoomID: roomId,
+ }).FirstOrCreate(&entry)
}
func (a *Account) joinedInternal(roomId RoomID) error {
@@ -95,6 +151,13 @@ func (a *Account) Left(roomId RoomID) {
if err != nil {
a.ezbrMessagef("Dropping Account.Left %s: %s", roomId, err.Error())
}
+
+ db.Where(&DbJoinedRoom{
+ MxUserID: a.MatrixUser,
+ Protocol: a.Protocol,
+ AccountName: a.AccountName,
+ RoomID: roomId,
+ }).Delete(&DbJoinedRoom{})
}
func (a *Account) leftInternal(roomId RoomID) error {
diff --git a/connector/xmpp/xmpp.go b/connector/xmpp/xmpp.go
index 7e35135..727c4d8 100644
--- a/connector/xmpp/xmpp.go
+++ b/connector/xmpp/xmpp.go
@@ -81,7 +81,8 @@ func (xm *XMPP) Configure(c Configuration) error {
return fmt.Errorf("JID %s not on server %s", xm.jid, xm.server)
}
xm.jid_localpart = jid_parts[0]
- xm.nickname = xm.jid_localpart
+
+ xm.nickname = c.GetString("nickname", xm.jid_locakpart)
xm.password, err = c.GetString("password")
if err != nil {
diff --git a/db.go b/db.go
index fe3d1e3..602bd1f 100644
--- a/db.go
+++ b/db.go
@@ -36,6 +36,9 @@ func InitDb() error {
db.AutoMigrate(&DbPmRoomMap{})
db.Model(&DbPmRoomMap{}).AddIndex("idx_protocol_user_account_user", "protocol", "user_id", "mx_user_id", "account_name")
+ db.AutoMigrate(&DbJoinedRoom{})
+ db.Model(&DbJoinedRoom{}).AddIndex("idx_user_protocol_account", "mx_user_id", "protocol", "account_name")
+
return nil
}
@@ -86,6 +89,19 @@ type DbPmRoomMap struct {
MxRoomID string `gorm:"index:mxroomoid"`
}
+// List of joined channels to be re-joined on reconnect
+type DbJoinedRoom struct {
+ gorm.Model
+
+ // User id and account name
+ MxUserID string
+ Protocol string
+ AccountName string
+
+ // Room ID
+ RoomID connector.RoomID
+}
+
// ---- Simple locking mechanism
var dbLocks [256]sync.Mutex
diff --git a/main.go b/main.go
index d74e3d8..3cec532 100644
--- a/main.go
+++ b/main.go
@@ -29,7 +29,7 @@ type ConfigAccount struct {
}
type ConfigFile struct {
- HttpBindAddr string `json:"http_bind_addr"`
+ ASBindAddr string `json:"appservice_bind_addr"`
Registration string `json:"registration"`
Server string `json:"homeserver_url"`
DbType string `json:"db_type"`
@@ -45,7 +45,7 @@ var registration *mxlib.Registration
func readConfig() ConfigFile {
config_file := ConfigFile{
- HttpBindAddr: "0.0.0.0:8321",
+ ASBindAddr: "0.0.0.0:8321",
Registration: "./registration.yaml",
Server: "http://localhost:8008",
DbType: "sqlite3",
@@ -192,7 +192,7 @@ func main() {
}
conn.SetHandler(account)
AddAccount(account)
- go connectAndJoin(account, params)
+ go account.connect(params.Config, params.Rooms)
}
}
@@ -201,19 +201,3 @@ func main() {
log.Fatal(err)
}
}
-
-func connectAndJoin(account *Account, params ConfigAccount) {
- log.Printf("Connecting to %s", params.Protocol)
- err := account.Conn.Configure(params.Config)
- if err != nil {
- log.Printf("Could not connect to %s: %s", params.Protocol, err)
- } else {
- log.Printf("Connected to %s, now joining %#v", params.Protocol, params.Rooms)
- for _, room := range params.Rooms {
- err := account.Conn.Join(connector.RoomID(room))
- if err != nil {
- log.Printf("Could not join %s: %s", room, err)
- }
- }
- }
-}
diff --git a/server.go b/server.go
index 10721be..84a1e85 100644
--- a/server.go
+++ b/server.go
@@ -13,14 +13,6 @@ import (
"git.deuxfleurs.fr/Deuxfleurs/easybridge/mxlib"
)
-type Config struct {
- HttpBindAddr string
- Server string
- DbType string
- DbPath string
- MatrixDomain string
-}
-
var mx *mxlib.Client
func StartAppService() (chan error, error) {
@@ -55,8 +47,8 @@ func StartAppService() (chan error, error) {
errch := make(chan error)
go func() {
- log.Printf("Starting HTTP server on %s", config.HttpBindAddr)
- err := http.ListenAndServe(config.HttpBindAddr, checkTokenAndLog(router))
+ log.Printf("Starting HTTP server on %s", config.ASBindAddr)
+ err := http.ListenAndServe(config.ASBindAddr, checkTokenAndLog(router))
if err != nil {
errch <- err
}