diff --git a/plugins/reminder/reminder.go b/plugins/reminder/reminder.go index 6fea17d..a5cf7cb 100644 --- a/plugins/reminder/reminder.go +++ b/plugins/reminder/reminder.go @@ -5,6 +5,7 @@ package reminder import ( "errors" "fmt" + "regexp" "strconv" "strings" "sync" @@ -28,12 +29,13 @@ const ( ) type ReminderPlugin struct { - bot bot.Bot - db *sqlx.DB - mutex *sync.Mutex - timer *time.Timer - config *config.Config - when *when.Parser + bot bot.Bot + db *sqlx.DB + mutex *sync.Mutex + timer *time.Timer + config *config.Config + when *when.Parser + lastReminder map[string]*Reminder } type Reminder struct { @@ -66,39 +68,59 @@ func New(b bot.Bot) *ReminderPlugin { w.Add(common.All...) plugin := &ReminderPlugin{ - bot: b, - db: b.DB(), - mutex: &sync.Mutex{}, - timer: timer, - config: b.Config(), - when: w, + bot: b, + db: b.DB(), + mutex: &sync.Mutex{}, + timer: timer, + config: b.Config(), + when: w, + lastReminder: map[string]*Reminder{}, } plugin.queueUpNextReminder() - go reminderer(b.DefaultConnector(), plugin) + go plugin.reminderer(b.DefaultConnector()) + b.RegisterRegexCmd(plugin, bot.Message, regexp.MustCompile(`(?i)^snooze (?P.+)$`), plugin.snooze) b.Register(plugin, bot.Message, plugin.message) b.Register(plugin, bot.Help, plugin.help) return plugin } +func (p *ReminderPlugin) snooze(r bot.Request) bool { + lastReminder := p.lastReminder[r.Msg.Channel] + if lastReminder == nil { + p.bot.Send(r.Conn, bot.Message, r.Msg.Channel, "My memory is too small to contain a snoozed reminder.") + return true + } + durationTxt := replaceDuration(p.when, r.Values["duration"]) + dur, err := time.ParseDuration(durationTxt) + if err != nil { + p.bot.Send(r.Conn, bot.Message, r.Msg.Channel, "Whoa, cowboy. I can't parse that time duration.") + return true + } + lastReminder.when = time.Now().UTC().Add(dur) + p.addReminder(lastReminder) + delete(p.lastReminder, r.Msg.Channel) + p.bot.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Okay, I'll let you know in %s", dur)) + p.queueUpNextReminder() + return true +} + +func replaceDuration(when *when.Parser, txt string) string { + t, err := when.Parse(txt, time.Now()) + if t != nil && err == nil { + return txt[0:t.Index] + t.Time.Sub(time.Now()).String() + txt[t.Index+len(t.Text):] + } + return txt +} + func (p *ReminderPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool { channel := message.Channel from := message.User.Name - var dur, dur2 time.Duration - t, err := p.when.Parse(message.Body, time.Now()) - // Allowing err to fallthrough for other parsing - if t != nil && err == nil { - t2 := t.Time.Sub(time.Now()).String() - message.Body = string(message.Body[0:t.Index]) + t2 + string(message.Body[t.Index+len(t.Text):]) - log.Debug(). - Str("body", message.Body). - Str("text", t.Text). - Msg("Got time request") - } + message.Body = replaceDuration(p.when, message.Body) parts := strings.Fields(message.Body) if len(parts) >= 5 { @@ -108,7 +130,7 @@ func (p *ReminderPlugin) message(c bot.Connector, kind bot.Kind, message msg.Mes who = from } - dur, err = time.ParseDuration(parts[3]) + dur, err := time.ParseDuration(parts[3]) if err != nil { p.bot.Send(c, bot.Message, channel, "Easy cowboy, not sure I can parse that duration. Try something like '1.5h' or '2h45m'.") return true @@ -135,7 +157,7 @@ func (p *ReminderPlugin) message(c bot.Connector, kind bot.Kind, message msg.Mes } else if operator == "every" && strings.ToLower(parts[4]) == "for" { //batch add, especially for reminding msherms to buy a kit //remind who every dur for dur2 blah - dur2, err = time.ParseDuration(parts[5]) + dur2, err := time.ParseDuration(parts[5]) if err != nil { log.Error().Err(err) p.bot.Send(c, bot.Message, channel, "Easy cowboy, not sure I can parse that duration. Try something like '1.5h' or '2h45m'.") @@ -352,7 +374,8 @@ func (p *ReminderPlugin) queueUpNextReminder() { } } -func reminderer(c bot.Connector, p *ReminderPlugin) { +func (p *ReminderPlugin) reminderer(c bot.Connector) { + for { <-p.timer.C diff --git a/plugins/twitch/twitch_test.go b/plugins/twitch/twitch_test.go index 6351da1..b112acc 100644 --- a/plugins/twitch/twitch_test.go +++ b/plugins/twitch/twitch_test.go @@ -41,7 +41,7 @@ func makeTwitchPlugin(t *testing.T) (*TwitchPlugin, *bot.MockBot) { mb := bot.NewMockBot() c := New(mb) mb.Config().Set("twitch.clientid", "fake") - mb.Config().Set("twitch.authorization", "fake") + mb.Config().Set("twitch.secret", "fake") c.c.SetArray("Twitch.Channels", []string{"test"}) c.c.SetArray("Twitch.test.Users", []string{"drseabass"}) assert.NotNil(t, c)