diff --git a/plugins/reminder/reminder.go b/plugins/reminder/reminder.go index 7c87371..1699da2 100644 --- a/plugins/reminder/reminder.go +++ b/plugins/reminder/reminder.go @@ -3,29 +3,34 @@ package reminder import ( + "errors" "fmt" - "sort" + "log" "strconv" "strings" "sync" "time" + "github.com/jmoiron/sqlx" "github.com/velour/catbase/bot" "github.com/velour/catbase/bot/msg" "github.com/velour/catbase/config" ) +const ( + TIMESTAMP = "2006-01-02 15:04:05" +) + type ReminderPlugin struct { Bot bot.Bot - reminders []*Reminder + db *sqlx.DB mutex *sync.Mutex timer *time.Timer config *config.Config - nextReminderId int } type Reminder struct { - id int + id int64 from string who string what string @@ -33,65 +38,40 @@ type Reminder struct { channel string } -type reminderSlice []*Reminder - -func (s reminderSlice) Len() int { - return len(s) -} - -func (s reminderSlice) Less(i, j int) bool { - return s[i].when.Before(s[j].when) -} - -func (s reminderSlice) Swap(i, j int) { - s[i], s[j] = s[j], s[i] -} - func New(bot bot.Bot) *ReminderPlugin { + log.SetFlags(log.LstdFlags | log.Lshortfile) + if bot.DBVersion() == 1 { + if _, err := bot.DB().Exec(`create table if not exists reminders ( + id integer primary key, + fromWho string, + toWho string, + what string, + remindWhen string, + channel string + );`); err != nil { + log.Fatal(err) + } + } + dur, _ := time.ParseDuration("1h") timer := time.NewTimer(dur) timer.Stop() plugin := &ReminderPlugin{ Bot: bot, - reminders: []*Reminder{}, + db: bot.DB(), mutex: &sync.Mutex{}, timer: timer, config: bot.Config(), - nextReminderId: 0, } + + plugin.queueUpNextReminder() + go reminderer(plugin) return plugin } -func reminderer(p *ReminderPlugin) { - //welcome to the reminderererererererererer - for { - <-p.timer.C - - p.mutex.Lock() - - reminder := p.reminders[0] - if len(p.reminders) >= 2 { - p.reminders = p.reminders[1:] - p.timer.Reset(p.reminders[0].when.Sub(time.Now())) - } else { - p.reminders = []*Reminder{} - } - - p.mutex.Unlock() - - if reminder.from == reminder.who { - reminder.from = "you" - } - - message := fmt.Sprintf("Hey %s, %s wanted you to be reminded: %s", reminder.who, reminder.from, reminder.what) - p.Bot.SendMessage(reminder.channel, message) - } - -} - func (p *ReminderPlugin) Message(message msg.Message) bool { channel := message.Channel from := message.User.Name @@ -111,8 +91,6 @@ func (p *ReminderPlugin) Message(message msg.Message) bool { return true } - reminders := []*Reminder{} - operator := strings.ToLower(parts[2]) doConfirm := true @@ -120,20 +98,18 @@ func (p *ReminderPlugin) Message(message msg.Message) bool { if operator == "in" { //one off reminder //remind who in dur blah - when := time.Now().Add(dur) + when := time.Now().UTC().Add(dur) what := strings.Join(parts[4:], " ") - id := p.nextReminderId - p.nextReminderId++ - - reminders = append(reminders, &Reminder{ - id: id, + p.addReminder(&Reminder{ + id: -1, from: from, who: who, what: what, when: when, channel: channel, }) + } else if operator == "every" && strings.ToLower(parts[4]) == "for" { //batch add, especially for reminding msherms to buy a kit //remind who every dur for dur2 blah @@ -143,8 +119,8 @@ func (p *ReminderPlugin) Message(message msg.Message) bool { return true } - when := time.Now().Add(dur) - endTime := time.Now().Add(dur2) + when := time.Now().UTC().Add(dur) + endTime := time.Now().UTC().Add(dur2) what := strings.Join(parts[6:], " ") for i := 0; when.Before(endTime); i++ { @@ -154,11 +130,8 @@ func (p *ReminderPlugin) Message(message msg.Message) bool { break } - id := p.nextReminderId - p.nextReminderId++ - - reminders = append(reminders, &Reminder{ - id: id, + p.addReminder(&Reminder{ + id: int64(-1), from: from, who: who, what: what, @@ -178,59 +151,26 @@ func (p *ReminderPlugin) Message(message msg.Message) bool { p.Bot.SendMessage(channel, response) } - p.mutex.Lock() - - p.timer.Stop() - - p.reminders = append(p.reminders, reminders...) - - sort.Sort(reminderSlice(p.reminders)) - - if len(p.reminders) > 0 { - p.timer.Reset(p.reminders[0].when.Sub(time.Now())) - } - - p.mutex.Unlock() + p.queueUpNextReminder() return true } } else if len(parts) == 2 && strings.ToLower(parts[0]) == "list" && strings.ToLower(parts[1]) == "reminders" { - var response string - p.mutex.Lock() - if len(p.reminders) == 0 { - response = "no pending reminders" + response, err := p.getAllRemindersFormatted(channel) + if err != nil { + p.Bot.SendMessage(channel, "listing failed.") } else { - counter := 1 - for _, reminder := range p.reminders { - if reminder.channel == channel { - response += fmt.Sprintf("%d) %s -> %s :: %s @ %s (id=%d)\n", counter, reminder.from, reminder.who, reminder.what, reminder.when, reminder.id) - counter++ - } - } + p.Bot.SendMessage(channel, response) } - p.mutex.Unlock() - p.Bot.SendMessage(channel, response) return true } else if len(parts) == 3 && strings.ToLower(parts[0]) == "cancel" && strings.ToLower(parts[1]) == "reminder" { - id, err := strconv.Atoi(parts[2]) + id, err := strconv.ParseInt(parts[2], 10, 64) if err != nil { p.Bot.SendMessage(channel, fmt.Sprintf("couldn't parse id: %s", parts[2])) } else { - p.mutex.Lock() - deleted := false - for i, reminder := range p.reminders { - if reminder.id == id { - copy(p.reminders[i:], p.reminders[i+1:]) - p.reminders[len(p.reminders)-1] = nil - p.reminders = p.reminders[:len(p.reminders)-1] - deleted = true - break - } - } - p.mutex.Unlock() - - if deleted { + err := p.deleteReminder(id) + if err == nil { p.Bot.SendMessage(channel, fmt.Sprintf("successfully canceled reminder: %s", parts[2])) } else { p.Bot.SendMessage(channel, fmt.Sprintf("failed to find and cancel reminder: %s", parts[2])) @@ -257,3 +197,128 @@ func (p *ReminderPlugin) BotMessage(message msg.Message) bool { func (p *ReminderPlugin) RegisterWeb() *string { return nil } + +func (p *ReminderPlugin) getNextReminder() *Reminder { + p.mutex.Lock() + defer p.mutex.Unlock() + rows, err := p.db.Query("select id, fromWho, toWho, what, remindWhen, channel from reminders order by remindWhen asc limit 1;") + if err != nil { + log.Print(err) + return nil + } + defer rows.Close() + + once := false + var reminder *Reminder + for rows.Next() { + if once { + log.Print("somehow got multiple rows") + } + reminder = &Reminder{} + + var when string + err := rows.Scan(&reminder.id, &reminder.from, &reminder.who, &reminder.what, &when, &reminder.channel) + if err != nil { + log.Print(err) + return nil + } + reminder.when, err = time.Parse(TIMESTAMP, when) + if err != nil { + log.Print(err) + return nil + } + + once = true + } + + return reminder +} + +func (p *ReminderPlugin) addReminder(reminder *Reminder) error { + p.mutex.Lock() + defer p.mutex.Unlock() + _, err := p.db.Exec(`insert into reminders (fromWho, toWho, what, remindWhen, channel) values (?, ?, ?, ?, ?);`, + reminder.from, reminder.who, reminder.what, reminder.when.Format(TIMESTAMP), reminder.channel) + + if err != nil { + log.Print(err) + } + return err +} + +func (p *ReminderPlugin) deleteReminder(id int64) error { + p.mutex.Lock() + defer p.mutex.Unlock() + res, err := p.db.Exec(`delete from reminders where id = ?;`, id) + if err != nil { + log.Print(err) + } else { + if affected, err := res.RowsAffected(); err != nil { + return err + } else if affected != 1 { + return errors.New("didn't delete any rows") + } + } + return err +} + +func (p *ReminderPlugin) getAllRemindersFormatted(channel string) (string, error) { + p.mutex.Lock() + defer p.mutex.Unlock() + rows, err := p.db.Query("select id, fromWho, toWho, what, remindWhen from reminders order by remindWhen asc;") + if err != nil { + log.Print(err) + return "", nil + } + defer rows.Close() + reminders := "" + counter := 1 + reminder := &Reminder{} + for rows.Next() { + var when string + err := rows.Scan(&reminder.id, &reminder.from, &reminder.who, &reminder.what, &when) + if err != nil { + return "", err + } + reminders += fmt.Sprintf("%d) %s -> %s :: %s @ %s (%d)\n", counter, reminder.from, reminder.who, reminder.what, when, reminder.id) + counter++ + } + if counter == 1 { + return "no pending reminders", nil + } + + return reminders, nil +} + +func (p *ReminderPlugin) queueUpNextReminder() { + nextReminder := p.getNextReminder() + + if nextReminder != nil { + p.timer.Reset(nextReminder.when.Sub(time.Now().UTC())) + } +} + +func reminderer(p *ReminderPlugin) { + for { + <-p.timer.C + + reminder := p.getNextReminder() + + if reminder != nil && time.Now().UTC().After(reminder.when) { + if reminder.from == reminder.who { + reminder.from = "you" + } + + message := fmt.Sprintf("Hey %s, %s wanted you to be reminded: %s", reminder.who, reminder.from, reminder.what) + p.Bot.SendMessage(reminder.channel, message) + + if err:= p.deleteReminder(reminder.id); err != nil { + log.Print(reminder.id) + log.Print(err) + log.Fatal("this will cause problems, we need to stop now.") + } + } + + p.queueUpNextReminder() + } +} diff --git a/plugins/reminder/reminder_test.go b/plugins/reminder/reminder_test.go index 841d1ce..aab65f3 100644 --- a/plugins/reminder/reminder_test.go +++ b/plugins/reminder/reminder_test.go @@ -128,13 +128,13 @@ func TestCancel(t *testing.T) { assert.NotNil(t, c) res := c.Message(makeMessage("!remind testuser in 1m don't fail this test")) assert.True(t, res) - res = c.Message(makeMessage("!cancel reminder 0")) + res = c.Message(makeMessage("!cancel reminder 1")) assert.True(t, res) res = c.Message(makeMessage("!list reminders")) assert.True(t, res) assert.Len(t, mb.Messages, 3) assert.Contains(t, mb.Messages[0], "Sure tester, I'll remind testuser.") - assert.Contains(t, mb.Messages[1], "successfully canceled reminder: 0") + assert.Contains(t, mb.Messages[1], "successfully canceled reminder: 1") assert.Contains(t, mb.Messages[2], "no pending reminders") } @@ -142,10 +142,10 @@ func TestCancelMiss(t *testing.T) { mb := bot.NewMockBot() c := New(mb) assert.NotNil(t, c) - res := c.Message(makeMessage("!cancel reminder 0")) + res := c.Message(makeMessage("!cancel reminder 1")) assert.True(t, res) assert.Len(t, mb.Messages, 1) - assert.Contains(t, mb.Messages[0], "failed to find and cancel reminder: 0") + assert.Contains(t, mb.Messages[0], "failed to find and cancel reminder: 1") } func TestHelp(t *testing.T) {