diff --git a/app/chat/export.go b/app/chat/export.go index 591695908c..20d9aa49de 100644 --- a/app/chat/export.go +++ b/app/chat/export.go @@ -8,6 +8,7 @@ import ( "time" "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" "github.com/fatih/color" "github.com/go-faster/jx" "github.com/gotd/td/telegram" @@ -150,6 +151,9 @@ func Export(ctx context.Context, c *telegram.Client, kvd storage.Storage, opts E defer enc.ArrEnd() count := int64(0) + expander := newExportMessageExpander(opts, filter, func(ctx context.Context, msg *tg.Message) ([]*tg.Message, error) { + return tutil.GetGroupedMessages(ctx, c.API(), peer.InputPeer(), msg) + }) loop: for iter.Next(ctx) { @@ -173,52 +177,131 @@ loop: if !ok { continue } - // only get media messages - media, ok := tmedia.GetMedia(m) - if !ok && !opts.All { - continue - } - b, err := texpr.Run(filter, texpr.ConvertEnvMessage(m)) + messages, err := expander.Expand(ctx, m) if err != nil { - return fmt.Errorf("failed to run filter: %w", err) - } - if !b.(bool) { // filtered - continue + return err } + for _, message := range messages { + mb, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + enc.Raw(mb) - fileName := "" - if media != nil { // #207 - fileName = media.Name + count++ + tracker.SetValue(count) } - t := &Message{ - ID: m.ID, - Type: "message", - File: fileName, + } + + if err = iter.Err(); err != nil { + return err + } + + tracker.MarkAsDone() + prog.Wait(ctx, pw) + return nil +} + +type groupedMessageResolver func(context.Context, *tg.Message) ([]*tg.Message, error) + +type exportMessageExpander struct { + opts ExportOptions + filter *vm.Program + resolveGrouped groupedMessageResolver + seen map[int]struct{} +} + +func newExportMessageExpander( + opts ExportOptions, + filter *vm.Program, + resolveGrouped groupedMessageResolver, +) *exportMessageExpander { + return &exportMessageExpander{ + opts: opts, + filter: filter, + resolveGrouped: resolveGrouped, + seen: make(map[int]struct{}), + } +} + +func (e *exportMessageExpander) Expand(ctx context.Context, msg *tg.Message) ([]*Message, error) { + if _, ok := e.seen[msg.ID]; ok { + return nil, nil + } + + matched, err := e.matchesFilter(msg) + if err != nil { + return nil, err + } + if !matched { + return nil, nil + } + + messages := []*tg.Message{msg} + if _, ok := msg.GetGroupedID(); ok && e.resolveGrouped != nil { + grouped, err := e.resolveGrouped(ctx, msg) + if err != nil { + return nil, fmt.Errorf("failed to resolve grouped message %d: %w", msg.ID, err) } - if opts.WithContent { - t.Date = m.Date - t.Text = m.Message + if len(grouped) > 0 { + messages = grouped } - if opts.Raw { - t.Raw = m + } + + exported := make([]*Message, 0, len(messages)) + for _, message := range messages { + if _, ok := e.seen[message.ID]; ok { + continue } - mb, err := json.Marshal(t) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) + out, ok := e.convert(message) + if !ok { + continue } - enc.Raw(mb) - count++ - tracker.SetValue(count) + e.seen[message.ID] = struct{}{} + exported = append(exported, out) } - if err = iter.Err(); err != nil { - return err + return exported, nil +} + +func (e *exportMessageExpander) matchesFilter(msg *tg.Message) (bool, error) { + b, err := texpr.Run(e.filter, texpr.ConvertEnvMessage(msg)) + if err != nil { + return false, fmt.Errorf("failed to run filter: %w", err) } - tracker.MarkAsDone() - prog.Wait(ctx, pw) - return nil + matched, ok := b.(bool) + if !ok { + return false, fmt.Errorf("filter returned %T, expected bool", b) + } + return matched, nil +} + +func (e *exportMessageExpander) convert(msg *tg.Message) (*Message, bool) { + media, ok := tmedia.GetMedia(msg) + if !ok && !e.opts.All { + return nil, false + } + + fileName := "" + if media != nil { // #207 + fileName = media.Name + } + out := &Message{ + ID: msg.ID, + Type: "message", + File: fileName, + } + if e.opts.WithContent { + out.Date = msg.Date + out.Text = msg.Message + } + if e.opts.Raw { + out.Raw = msg + } + + return out, true } diff --git a/app/chat/export_test.go b/app/chat/export_test.go new file mode 100644 index 0000000000..30e5c2e5cf --- /dev/null +++ b/app/chat/export_test.go @@ -0,0 +1,80 @@ +package chat + +import ( + "context" + "testing" + + "github.com/expr-lang/expr" + "github.com/gotd/td/tg" + "github.com/stretchr/testify/require" +) + +func TestExportExpandsGroupedMediaWhenFilteredMessageMatches(t *testing.T) { + ctx := context.Background() + filter, err := expr.Compile(`Message contains "#sample"`, expr.AsBool()) + require.NoError(t, err) + + grouped := []*tg.Message{ + testExportVideoMessage(10, 100, ""), + testExportVideoMessage(11, 100, `#sample`), + testExportVideoMessage(12, 100, ""), + } + calls := 0 + expander := newExportMessageExpander(ExportOptions{}, filter, func(context.Context, *tg.Message) ([]*tg.Message, error) { + calls++ + return grouped, nil + }) + + messages, err := expander.Expand(ctx, grouped[1]) + require.NoError(t, err) + require.Equal(t, []int{10, 11, 12}, exportMessageIDs(messages)) + + messages, err = expander.Expand(ctx, grouped[1]) + require.NoError(t, err) + require.Empty(t, messages) + require.Equal(t, 1, calls) +} + +func TestExportDoesNotExpandGroupedMediaWhenFilterDoesNotMatch(t *testing.T) { + ctx := context.Background() + filter, err := expr.Compile(`Message contains "#sample"`, expr.AsBool()) + require.NoError(t, err) + + expander := newExportMessageExpander(ExportOptions{}, filter, func(context.Context, *tg.Message) ([]*tg.Message, error) { + t.Fatal("group resolver should not run when the filtered message does not match") + return nil, nil + }) + + messages, err := expander.Expand(ctx, testExportVideoMessage(10, 100, "other")) + require.NoError(t, err) + require.Empty(t, messages) +} + +func exportMessageIDs(messages []*Message) []int { + ids := make([]int, 0, len(messages)) + for _, message := range messages { + ids = append(ids, message.ID) + } + return ids +} + +func testExportVideoMessage(id int, groupedID int64, text string) *tg.Message { + msg := &tg.Message{ + ID: id, + Date: 1, + Message: text, + } + msg.SetGroupedID(groupedID) + msg.SetMedia(&tg.MessageMediaDocument{ + Document: &tg.Document{ + ID: int64(id), + MimeType: "video/mp4", + Size: 1024, + DCID: 1, + Attributes: []tg.DocumentAttributeClass{ + &tg.DocumentAttributeFilename{FileName: "video.mp4"}, + }, + }, + }) + return msg +}