mirror of https://github.com/velour/catbase.git
Compare commits
No commits in common. "d69d708245a8d0139cc76e01d5bd06d6d66554f5" and "030d0371f99a0a4e182749ee646c0880be6b705e" have entirely different histories.
d69d708245
...
030d0371f9
|
@ -126,7 +126,6 @@ func (s *SlackApp) Serve() error {
|
||||||
w.Write([]byte(r.Challenge))
|
w.Write([]byte(r.Challenge))
|
||||||
} else if eventsAPIEvent.Type == slackevents.CallbackEvent {
|
} else if eventsAPIEvent.Type == slackevents.CallbackEvent {
|
||||||
innerEvent := eventsAPIEvent.InnerEvent
|
innerEvent := eventsAPIEvent.InnerEvent
|
||||||
typ := innerEvent.Type
|
|
||||||
switch ev := innerEvent.Data.(type) {
|
switch ev := innerEvent.Data.(type) {
|
||||||
case *slackevents.MessageAction:
|
case *slackevents.MessageAction:
|
||||||
log.Debug().Interface("ev", ev).Msg("MessageAction")
|
log.Debug().Interface("ev", ev).Msg("MessageAction")
|
||||||
|
@ -138,15 +137,11 @@ func (s *SlackApp) Serve() error {
|
||||||
//s.msgReceivd(ev)
|
//s.msgReceivd(ev)
|
||||||
case *slackevents.MessageEvent:
|
case *slackevents.MessageEvent:
|
||||||
s.msgReceivd(ev)
|
s.msgReceivd(ev)
|
||||||
case *slackevents.ReactionAddedEvent:
|
case *slack.ReactionAddedEvent:
|
||||||
err := s.reactionReceived(ev)
|
s.reactionReceived(ev)
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("error with reaction recording")
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Interface("ev", ev).
|
Interface("ev", ev).
|
||||||
Interface("type", typ).
|
|
||||||
Msg("Unknown CallbackEvent")
|
Msg("Unknown CallbackEvent")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -612,8 +607,7 @@ func (s *SlackApp) log(msg, channel string) error {
|
||||||
return f.Sync()
|
return f.Sync()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SlackApp) reactionReceived(event *slackevents.ReactionAddedEvent) error {
|
func (s *SlackApp) reactionReceived(event *slack.ReactionAddedEvent) error {
|
||||||
log.Debug().Msgf("reactionReceived(%+v)", event)
|
|
||||||
name, _ := s.getUser(event.User, "unknown")
|
name, _ := s.getUser(event.User, "unknown")
|
||||||
|
|
||||||
ch, err := s.getChannel(event.Item.Channel)
|
ch, err := s.getChannel(event.Item.Channel)
|
||||||
|
@ -638,7 +632,7 @@ func (s *SlackApp) reactionReceived(event *slackevents.ReactionAddedEvent) error
|
||||||
body = fmt.Sprintf("%s: %s", u, m.Text)
|
body = fmt.Sprintf("%s: %s", u, m.Text)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
log.Debug().Interface("msg", m).Msg("Unexpected type in reaction received")
|
log.Debug().Interface("msg", m).Msg("Unexpected type")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -647,8 +641,6 @@ func (s *SlackApp) reactionReceived(event *slackevents.ReactionAddedEvent) error
|
||||||
fixDate(tstamp, "2006-01-02 15:04:05"),
|
fixDate(tstamp, "2006-01-02 15:04:05"),
|
||||||
name, body, event.Reaction)
|
name, body, event.Reaction)
|
||||||
|
|
||||||
log.Debug().Msgf("Made it to reaction received, logging %v: %v", msg, channel)
|
|
||||||
|
|
||||||
return s.log(msg, channel)
|
return s.log(msg, channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
5
go.mod
5
go.mod
|
@ -14,7 +14,6 @@ require (
|
||||||
github.com/antchfx/xpath v1.1.1 // indirect
|
github.com/antchfx/xpath v1.1.1 // indirect
|
||||||
github.com/armon/go-radix v1.0.0 // indirect
|
github.com/armon/go-radix v1.0.0 // indirect
|
||||||
github.com/azr/backoff v0.0.0-20160115115103-53511d3c7330 // indirect
|
github.com/azr/backoff v0.0.0-20160115115103-53511d3c7330 // indirect
|
||||||
github.com/cdipaolo/goml v0.0.0-20190412180403-e1f51f713598
|
|
||||||
github.com/chrissexton/gofuck v1.0.0
|
github.com/chrissexton/gofuck v1.0.0
|
||||||
github.com/chrissexton/leftpad v0.0.0-20181207133115-1e93189d2fff
|
github.com/chrissexton/leftpad v0.0.0-20181207133115-1e93189d2fff
|
||||||
github.com/chrissexton/sentiment v0.0.0-20190927141846-d69c422ba035
|
github.com/chrissexton/sentiment v0.0.0-20190927141846-d69c422ba035
|
||||||
|
@ -29,6 +28,7 @@ require (
|
||||||
github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82 // indirect
|
github.com/gonum/floats v0.0.0-20181209220543-c233463c7e82 // indirect
|
||||||
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 // indirect
|
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 // indirect
|
||||||
github.com/google/uuid v1.1.1
|
github.com/google/uuid v1.1.1
|
||||||
|
github.com/gorilla/websocket v1.4.1 // indirect
|
||||||
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 // indirect
|
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 // indirect
|
||||||
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6
|
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6
|
||||||
github.com/james-bowman/sparse v0.0.0-20190423065201-80c6877364c7 // indirect
|
github.com/james-bowman/sparse v0.0.0-20190423065201-80c6877364c7 // indirect
|
||||||
|
@ -44,11 +44,10 @@ require (
|
||||||
github.com/mmcdole/goxpp v0.0.0-20181012175147-0068e33feabf // indirect
|
github.com/mmcdole/goxpp v0.0.0-20181012175147-0068e33feabf // indirect
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||||
github.com/olebedev/when v0.0.0-20190311101825-c3b538a97254
|
github.com/olebedev/when v0.0.0-20190311101825-c3b538a97254
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
|
||||||
github.com/robertkrimen/otto v0.0.0-20180617131154-15f95af6e78d // indirect
|
github.com/robertkrimen/otto v0.0.0-20180617131154-15f95af6e78d // indirect
|
||||||
github.com/rs/zerolog v1.15.0
|
github.com/rs/zerolog v1.15.0
|
||||||
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca // indirect
|
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca // indirect
|
||||||
github.com/slack-go/slack v0.6.5
|
github.com/slack-go/slack v0.6.4
|
||||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||||
github.com/stretchr/objx v0.2.0 // indirect
|
github.com/stretchr/objx v0.2.0 // indirect
|
||||||
github.com/stretchr/testify v1.4.0
|
github.com/stretchr/testify v1.4.0
|
||||||
|
|
12
go.sum
12
go.sum
|
@ -70,8 +70,10 @@ github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 h1:8jtTdc+Nfj9AR+0s
|
||||||
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029/go.mod h1:Pu4dmpkhSyOzRwuXkOgAvijx4o+4YMUJJo9OvPYMkks=
|
github.com/gonum/internal v0.0.0-20181124074243-f884aa714029/go.mod h1:Pu4dmpkhSyOzRwuXkOgAvijx4o+4YMUJJo9OvPYMkks=
|
||||||
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
|
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
|
||||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
github.com/gorilla/websocket v1.2.0 h1:VJtLvh6VQym50czpZzx07z/kw9EgAxI3x1ZB8taTMQQ=
|
||||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.2.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
||||||
|
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
|
||||||
|
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 h1:KUDFlmBg2buRWNzIcwLlKvfcnujcHQRQ1As1LoaCLAM=
|
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1 h1:KUDFlmBg2buRWNzIcwLlKvfcnujcHQRQ1As1LoaCLAM=
|
||||||
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o=
|
github.com/inconshreveable/log15 v0.0.0-20200109203555-b30bc20e4fd1/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o=
|
||||||
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6 h1:k8+n5sfvxlixRNVkbelPGzEYjbGIKaBnRzRlx2NCtYA=
|
github.com/james-bowman/nlp v0.0.0-20191016091239-d9dbfaff30c6 h1:k8+n5sfvxlixRNVkbelPGzEYjbGIKaBnRzRlx2NCtYA=
|
||||||
|
@ -111,8 +113,6 @@ github.com/olebedev/when v0.0.0-20190311101825-c3b538a97254/go.mod h1:DPucAeQGDP
|
||||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/robertkrimen/otto v0.0.0-20180617131154-15f95af6e78d h1:1VUlQbCfkoSGv7qP7Y+ro3ap1P1pPZxgdGVqiTVy5C4=
|
github.com/robertkrimen/otto v0.0.0-20180617131154-15f95af6e78d h1:1VUlQbCfkoSGv7qP7Y+ro3ap1P1pPZxgdGVqiTVy5C4=
|
||||||
|
@ -122,8 +122,8 @@ github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY=
|
||||||
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
||||||
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca h1:NugYot0LIVPxTvN8n+Kvkn6TrbMyxQiuvKdEwFdR9vI=
|
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca h1:NugYot0LIVPxTvN8n+Kvkn6TrbMyxQiuvKdEwFdR9vI=
|
||||||
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca/go.mod h1:uugorj2VCxiV1x+LzaIdVa9b4S4qGAcH6cbhh4qVxOU=
|
github.com/saintfish/chardet v0.0.0-20120816061221-3af4cd4741ca/go.mod h1:uugorj2VCxiV1x+LzaIdVa9b4S4qGAcH6cbhh4qVxOU=
|
||||||
github.com/slack-go/slack v0.6.5 h1:IkDKtJ2IROJNoe3d6mW870/NRKvq2fhLB/Q5XmzWk00=
|
github.com/slack-go/slack v0.6.4 h1:cxOqFgM5RW6mdEyDqAJutFk3qiORK9oHRKi5bPqkY9o=
|
||||||
github.com/slack-go/slack v0.6.5/go.mod h1:FGqNzJBmxIsZURAxh2a8D21AnOVvvXZvGligs4npPUM=
|
github.com/slack-go/slack v0.6.4/go.mod h1:sGRjv3w+ERAUMMMbldHObQPBcNSyVB7KLKYfnwUFBfw=
|
||||||
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
|
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
|
||||||
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
|
||||||
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
|
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
|
||||||
|
|
|
@ -1,59 +0,0 @@
|
||||||
package reaction
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io/ioutil"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/cdipaolo/goml/base"
|
|
||||||
"github.com/cdipaolo/goml/text"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MetaData is the JSON-decoded description that accompanies a persisted
// naive Bayes model.
type MetaData struct {
	// NClasses is the class count handed to text.NewNaiveBayes when the
	// model is rebuilt.
	NClasses uint8
	// ClassList is indexed by the model's numeric class to obtain the
	// emojy name for that class.
	ClassList []string
}
|
|
||||||
|
|
||||||
type bayesReactor struct {
|
|
||||||
model *text.NaiveBayes
|
|
||||||
meta MetaData
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBayesReactor(jsonPath string) *bayesReactor {
|
|
||||||
reactor := &bayesReactor{}
|
|
||||||
f, err := ioutil.ReadFile(jsonPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msgf("error reading json")
|
|
||||||
return reactor
|
|
||||||
}
|
|
||||||
var meta MetaData
|
|
||||||
err = json.Unmarshal(f, &meta)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msgf("error reading json")
|
|
||||||
return reactor
|
|
||||||
}
|
|
||||||
reactor.meta = meta
|
|
||||||
|
|
||||||
stream := make(chan base.TextDatapoint, 100)
|
|
||||||
//errors := make(chan error)
|
|
||||||
model := text.NewNaiveBayes(stream, meta.NClasses, base.OnlyWordsAndNumbers)
|
|
||||||
err = model.RestoreFromFile(strings.TrimSuffix(jsonPath, ".json"))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msgf("error reading json")
|
|
||||||
return reactor
|
|
||||||
}
|
|
||||||
reactor.model = model
|
|
||||||
|
|
||||||
return reactor
|
|
||||||
}
|
|
||||||
|
|
||||||
// React returns an emojy and probability given an input
|
|
||||||
func (b *bayesReactor) React(input string) (string, float64) {
|
|
||||||
if b.model == nil {
|
|
||||||
return "", 0.0
|
|
||||||
}
|
|
||||||
class, prob := b.model.Probability(input)
|
|
||||||
emojy := b.meta.ClassList[class]
|
|
||||||
return emojy, prob
|
|
||||||
}
|
|
|
@ -3,12 +3,12 @@
|
||||||
package reaction
|
package reaction
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/chrissexton/sentiment"
|
"github.com/chrissexton/sentiment"
|
||||||
|
|
||||||
"github.com/velour/catbase/bot"
|
"github.com/velour/catbase/bot"
|
||||||
"github.com/velour/catbase/bot/msg"
|
"github.com/velour/catbase/bot/msg"
|
||||||
"github.com/velour/catbase/config"
|
"github.com/velour/catbase/config"
|
||||||
|
@ -19,7 +19,6 @@ type ReactionPlugin struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
|
|
||||||
model sentiment.Models
|
model sentiment.Models
|
||||||
br *bayesReactor
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(b bot.Bot) *ReactionPlugin {
|
func New(b bot.Bot) *ReactionPlugin {
|
||||||
|
@ -27,31 +26,35 @@ func New(b bot.Bot) *ReactionPlugin {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("Couldn't restore sentiment model")
|
log.Fatal().Err(err).Msg("Couldn't restore sentiment model")
|
||||||
}
|
}
|
||||||
c := b.Config()
|
|
||||||
path := c.GetString("reaction.modelpath", "emojy.model.json")
|
|
||||||
rp := &ReactionPlugin{
|
rp := &ReactionPlugin{
|
||||||
bot: b,
|
bot: b,
|
||||||
config: c,
|
config: b.Config(),
|
||||||
model: model,
|
model: model,
|
||||||
br: newBayesReactor(path),
|
|
||||||
}
|
}
|
||||||
b.Register(rp, bot.Message, rp.message)
|
b.Register(rp, bot.Message, rp.message)
|
||||||
return rp
|
return rp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ReactionPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
|
func (p *ReactionPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
|
||||||
emojy, prob := p.br.React(message.Body)
|
chance := p.config.GetFloat64("Reaction.GeneralChance", 0.01)
|
||||||
target := p.config.GetFloat64("reaction.confidence", 0.5)
|
if rand.Float64() < chance {
|
||||||
|
analysis := p.model.SentimentAnalysis(message.Body, sentiment.English)
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Float64("prob", prob).
|
Uint8("score", analysis.Score).
|
||||||
Float64("target", target).
|
Str("body", message.Body).
|
||||||
Bool("accept", prob > target).
|
Msg("sentiment of statement")
|
||||||
Str("emojy", emojy).
|
|
||||||
Msgf("Reaction check")
|
|
||||||
|
|
||||||
if prob > target {
|
var reactions []string
|
||||||
p.bot.Send(c, bot.Reaction, message.Channel, emojy, message)
|
if analysis.Score > 0 {
|
||||||
|
reactions = p.config.GetArray("Reaction.PositiveReactions", []string{})
|
||||||
|
} else {
|
||||||
|
reactions = p.config.GetArray("Reaction.NegativeReactions", []string{})
|
||||||
|
}
|
||||||
|
|
||||||
|
reaction := reactions[rand.Intn(len(reactions))]
|
||||||
|
|
||||||
|
p.bot.Send(c, bot.Reaction, message.Channel, reaction, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.checkReactions(c, message)
|
p.checkReactions(c, message)
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/cdipaolo/goml/base"
|
|
||||||
"github.com/cdipaolo/goml/text"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MetaData is the JSON-decoded description stored next to a persisted naive
// Bayes model.
type MetaData struct {
	// NClasses is the class count passed to text.NewNaiveBayes when the
	// model is rebuilt.
	NClasses uint8
	// ClassList is indexed by the model's numeric class to obtain the
	// emojy name that is printed.
	ClassList []string
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
log.Logger = log.With().Caller().Stack().Logger()
|
|
||||||
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
|
||||||
|
|
||||||
jsonPath := flag.String("path", "", "path to model JSON")
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
if *jsonPath == "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "You must provide a model path.\n\n")
|
|
||||||
flag.Usage()
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
input := strings.Join(flag.Args(), " ")
|
|
||||||
|
|
||||||
f, err := ioutil.ReadFile(*jsonPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Err(err).Msgf("error reading json")
|
|
||||||
}
|
|
||||||
var meta MetaData
|
|
||||||
err = json.Unmarshal(f, &meta)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Err(err).Msgf("error reading json")
|
|
||||||
}
|
|
||||||
|
|
||||||
stream := make(chan base.TextDatapoint, 100)
|
|
||||||
//errors := make(chan error)
|
|
||||||
model := text.NewNaiveBayes(stream, meta.NClasses, base.OnlyWordsAndNumbers)
|
|
||||||
err = model.RestoreFromFile(strings.TrimSuffix(*jsonPath, ".json"))
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Err(err).Msgf("error reading json")
|
|
||||||
}
|
|
||||||
|
|
||||||
class, prob := model.Probability(input)
|
|
||||||
emojy := meta.ClassList[class]
|
|
||||||
fmt.Printf("%s: %s (%.2f)\n", input, emojy, prob)
|
|
||||||
}
|
|
|
@ -1,185 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/cdipaolo/goml/base"
|
|
||||||
"github.com/cdipaolo/goml/text"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// logEntry holds the four named captures of one reaction log line.
type logEntry struct {
	// Who is the text captured between "<" and ">".
	Who string
	// Author is the text captured after "reacted to".
	Author string
	// Body is the text captured between the author and " with".
	Body string
	// Emojy is the reaction name captured between the final pair of colons.
	Emojy string
}
|
|
||||||
|
|
||||||
// logs is the ordered collection of reaction entries gathered from the log
// files.
type logs []logEntry
|
|
||||||
|
|
||||||
// emojySet records which emojy names have been seen; a present key maps to
// true.
type emojySet map[string]bool
|
|
||||||
|
|
||||||
// MetaData is written as JSON next to the trained model so that a later
// load can rebuild it.
type MetaData struct {
	// NClasses is the class count the model was created with.
	NClasses uint8
	// ClassList holds the emojy name for each numeric class, positioned at
	// that class's index.
	ClassList []string
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
log.Logger = log.With().Caller().Stack().Logger()
|
|
||||||
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
|
||||||
|
|
||||||
logDir := flag.String("path", "", "path to logs")
|
|
||||||
outFile := flag.String("out", "emojy.model", "path to store model")
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
if *logDir == "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "You must provide a log path.\n\n")
|
|
||||||
flag.Usage()
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
logs, classes := prepLogs(*logDir)
|
|
||||||
model, meta := bayes(logs, classes)
|
|
||||||
err := model.PersistToFile(*outFile)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Err(err).Msgf("failed to save model")
|
|
||||||
}
|
|
||||||
metaJSON, err := json.Marshal(meta)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Err(err).Msgf("failed to save model")
|
|
||||||
}
|
|
||||||
err = ioutil.WriteFile(*outFile+".json", metaJSON, 0666)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Err(err).Msgf("failed to save model")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// re matches a single reaction log line of the form
//
//	[<anything>] <Who> reacted to Author: Body with :Emojy:
//
// case-insensitively and anchored at both ends. Who, Author and Emojy are
// runs of punctuation and alphanumerics; Body is any non-empty text. The
// four parts are exposed as the named groups Who, Author, Body and Emojy.
var re = regexp.MustCompile(`(?i)^\[.+\] <(?P<Who>[[:punct:][:alnum:]]+)> reacted to (?P<Author>[[:punct:][:alnum:]]+): (?P<Body>.+) with :(?P<Emojy>[[:punct:][:alnum:]]+):$`)
|
|
||||||
|
|
||||||
func prepLogs(path string) (logs, emojySet) {
|
|
||||||
entries := logs{}
|
|
||||||
emojies := emojySet{}
|
|
||||||
err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if info.IsDir() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
tmp, err := ioutil.ReadFile(path)
|
|
||||||
content := string(tmp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, line := range strings.Split(content, "\n") {
|
|
||||||
if strings.Contains(line, "unknown event") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !re.MatchString(line) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry := parseEntry(line)
|
|
||||||
emojies[entry.Emojy] = true
|
|
||||||
log.Debug().
|
|
||||||
Interface("entry", entry).
|
|
||||||
Str("line", line).
|
|
||||||
Msgf("Found emojy reaction entry")
|
|
||||||
entries = append(entries, entry)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal().Msgf("Error walking: %s", err)
|
|
||||||
}
|
|
||||||
return entries, emojies
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEntry(content string) logEntry {
|
|
||||||
out := logEntry{}
|
|
||||||
subs := re.FindStringSubmatch(content)
|
|
||||||
if len(subs) == 0 {
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
for i, n := range re.SubexpNames() {
|
|
||||||
switch n {
|
|
||||||
case "Who":
|
|
||||||
out.Who = subs[i]
|
|
||||||
case "Author":
|
|
||||||
out.Author = subs[i]
|
|
||||||
case "Body":
|
|
||||||
out.Body = subs[i]
|
|
||||||
case "Emojy":
|
|
||||||
out.Emojy = subs[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func bayes(logs logs, classes emojySet) (*text.NaiveBayes, MetaData) {
|
|
||||||
// create the channel of data and errors
|
|
||||||
stream := make(chan base.TextDatapoint, 100)
|
|
||||||
errors := make(chan error)
|
|
||||||
|
|
||||||
nClasses := uint8(len(classes))
|
|
||||||
|
|
||||||
classMap := map[string]uint8{}
|
|
||||||
classList := []string{}
|
|
||||||
for k, _ := range classes {
|
|
||||||
classList = append(classList, k)
|
|
||||||
classMap[k] = uint8(len(classList) - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Strs("classList", classList).Interface("classMap", classMap).Int("nLogs", len(logs)).Msgf("about to train")
|
|
||||||
|
|
||||||
// make a new NaiveBayes model with
|
|
||||||
// 2 classes expected (classes in
|
|
||||||
// datapoints will now expect {0,1}.
|
|
||||||
// in general, given n as the classes
|
|
||||||
// variable, the model will expect
|
|
||||||
// datapoint classes in {0,...,n-1})
|
|
||||||
//
|
|
||||||
// Note that the model is filtering
|
|
||||||
// the text to omit anything except
|
|
||||||
// words and numbers (and spaces
|
|
||||||
// obviously)
|
|
||||||
model := text.NewNaiveBayes(stream, nClasses, base.OnlyWordsAndNumbers)
|
|
||||||
go model.OnlineLearn(errors)
|
|
||||||
|
|
||||||
for _, l := range logs {
|
|
||||||
stream <- base.TextDatapoint{
|
|
||||||
X: l.Body,
|
|
||||||
Y: classMap[l.Emojy],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
close(stream)
|
|
||||||
for {
|
|
||||||
err := <-errors
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Error passed")
|
|
||||||
} else {
|
|
||||||
// training is done!
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// now you can predict like normal
|
|
||||||
in := "Should work properly once that number of documents increases."
|
|
||||||
class := model.Predict(in) // 0
|
|
||||||
emojy := classList[class]
|
|
||||||
log.Debug().Msgf("Class prediction for %s: %v", in, emojy)
|
|
||||||
|
|
||||||
meta := MetaData{
|
|
||||||
NClasses: nClasses,
|
|
||||||
ClassList: classList,
|
|
||||||
}
|
|
||||||
|
|
||||||
return model, meta
|
|
||||||
}
|
|
Loading…
Reference in New Issue