Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 120 additions & 25 deletions pkg/tmessage/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tmessage

import (
"context"
"encoding/json"
"errors"
"io"
"os"
Expand Down Expand Up @@ -56,25 +57,26 @@ func FromFile(ctx context.Context, pool dcpool.Pool, kvd storage.Storage, files
}

// parseFile resolves the chat peer for a Telegram export file via getChatInfo,
// logs the peer's id and visible name at debug level, and returns the Dialog
// produced by collectFile.
//
// NOTE(review): this span is a rendered diff without +/- markers. The os.Open /
// deferred Close / f.Seek lines and the getChatInfo(..., f) call appear to be
// the removed side; the getChatInfo(..., file) and collectFile lines appear to
// be the added side — confirm against the actual commit.
func parseFile(ctx context.Context, client *tg.Client, kvd storage.Storage, file string, onlyMedia bool) (*Dialog, error) {
f, err := os.Open(file)
if err != nil {
return nil, err
}
defer func(f *os.File) {
_ = f.Close()
}(f)

peer, err := getChatInfo(ctx, client, kvd, f)
peer, err := getChatInfo(ctx, client, kvd, file)
if err != nil {
return nil, err
}
logctx.From(ctx).Debug("Got peer info",
zap.Int64("id", peer.ID()),
zap.String("name", peer.VisibleName()))

if _, err = f.Seek(0, io.SeekStart); err != nil {
return collectFile(ctx, file, peer, onlyMedia)
}

// collectFile opens file and passes its contents to collect, returning the
// Dialog that collect builds for peer.
//
// It opens its own handle rather than sharing the one used for chat ID probing,
// so that probing cannot race with jstream's streaming decoder on file offsets.
func collectFile(ctx context.Context, file string, peer peers.Peer, onlyMedia bool) (*Dialog, error) {
	src, err := os.Open(file)
	if err != nil {
		return nil, err
	}
	// The handle is only read from, so the Close error is deliberately dropped.
	defer func() { _ = src.Close() }()

	return collect(ctx, src, peer, onlyMedia)
}
Expand Down Expand Up @@ -117,30 +119,123 @@ func collect(ctx context.Context, r io.Reader, peer peers.Peer, onlyMedia bool)
return m, nil
}

// getChatInfo resolves the peer a Telegram export file belongs to: it opens
// file, reads the top-level chat id with readChatID, rejects a zero id, and
// looks the peer up through tutil.GetInputPeer using a peers manager backed by
// kvd.
//
// NOTE(review): this span is a rendered diff without +/- markers. The first
// signature (taking an io.Reader), the jstream decoder line, `chatID :=
// int64(0)` and the `for mv := range d.Stream()` fragment appear to be removed
// lines interleaved with the new implementation — confirm against the commit.
func getChatInfo(ctx context.Context, client *tg.Client, kvd storage.Storage, r io.Reader) (peers.Peer, error) {
d := jstream.NewDecoder(r, 1).EmitKV()
func getChatInfo(ctx context.Context, client *tg.Client, kvd storage.Storage, file string) (peers.Peer, error) {
f, err := os.Open(file)
if err != nil {
return nil, err
}
defer func(f *os.File) {
_ = f.Close()
}(f)

chatID := int64(0)
chatID, err := readChatID(f)
if err != nil {
return nil, err
}
if chatID == 0 {
return nil, errors.New("can't get chat type or chat id")
}

for mv := range d.Stream() {
_kv, ok := mv.Value.(jstream.KV)
if !ok {
continue
manager := peers.Options{Storage: storage.NewPeers(kvd)}.Build(client)
return tutil.GetInputPeer(ctx, manager, strconv.FormatInt(chatID, 10))
}

// readChatID scans the top-level object of a Telegram export read from r and
// returns the value stored under keyID as an int64.
//
// The decoder is switched to json.Number via UseNumber, the first token must be
// '{', and every top-level value whose key is not keyID is discarded with
// skipJSONValue. The value for keyID is decoded by decodeChatID. An error is
// returned if the input is not an object, a key is not a string, a token cannot
// be read, or the object ends without a keyID key.
//
// NOTE(review): this span is a rendered diff without +/- markers. The lines
// referencing _kv and chatID, and the `return nil, ...` line, appear to be
// leftovers of the removed jstream implementation — confirm against the commit.
func readChatID(r io.Reader) (int64, error) {
d := json.NewDecoder(r)
d.UseNumber()

tok, err := d.Token()
if err != nil {
return 0, err
}
if delim, ok := tok.(json.Delim); !ok || delim != '{' {
return 0, errors.New("expected telegram export JSON object")
}

for d.More() {
tok, err = d.Token()
if err != nil {
return 0, err
}

if _kv.Key == keyID {
chatID = int64(_kv.Value.(float64))
k, ok := tok.(string)
if !ok {
return 0, errors.New("expected telegram export JSON object key")
}

if chatID != 0 {
break
if k != keyID {
if err = skipJSONValue(d); err != nil {
return 0, err
}
continue
}

return decodeChatID(d)
}

if chatID == 0 {
return nil, errors.New("can't get chat type or chat id")
return 0, errors.New("can't get chat type or chat id")
}

// decodeChatID reads the next token from d and converts it to an int64 chat id.
// A json.Number is converted with Int64 and a string is parsed as a base-10
// integer; any other token yields an "invalid telegram export chat id" error.
//
// NOTE(review): this span is a rendered diff without +/- markers. The two lines
// that build a peers manager and call tutil.GetInputPeer appear to be removed
// lines from the old getChatInfo rather than part of this function — confirm
// against the commit.
func decodeChatID(d *json.Decoder) (int64, error) {
tok, err := d.Token()
if err != nil {
return 0, err
}

manager := peers.Options{Storage: storage.NewPeers(kvd)}.Build(client)
return tutil.GetInputPeer(ctx, manager, strconv.FormatInt(chatID, 10))
switch v := tok.(type) {
case json.Number:
return v.Int64()
case string:
return strconv.ParseInt(v, 10, 64)
default:
return 0, errors.New("invalid telegram export chat id")
}
}

// skipJSONValue consumes exactly one complete JSON value from d and discards
// it.
//
// A scalar (string, number, bool, null) is a single token. For an object or an
// array, tokens are consumed up to and including the delimiter that closes it.
// Nesting is tracked with a counter instead of recursion, so the Go call stack
// stays flat no matter how deeply the skipped value is nested.
//
// Errors from d.Token (syntax errors, io.EOF on truncated input) are returned
// unchanged. If the next token is a closing delimiter, meaning d is not
// positioned at the start of a value, an "unexpected JSON delimiter" error is
// returned.
func skipJSONValue(d *json.Decoder) error {
	tok, err := d.Token()
	if err != nil {
		return err
	}

	delim, ok := tok.(json.Delim)
	if !ok {
		// Scalar: the single token was the whole value.
		return nil
	}
	if delim != '{' && delim != '[' {
		return errors.New("unexpected JSON delimiter")
	}

	// Token rejects a closing delimiter that does not match the innermost open
	// container, so counting opens and closes is enough to find the end.
	for depth := 1; depth > 0; {
		tok, err = d.Token()
		if err != nil {
			return err
		}

		// String tokens such as "}" have type string, not json.Delim, and so
		// never match these cases.
		switch tok {
		case json.Delim('{'), json.Delim('['):
			depth++
		case json.Delim('}'), json.Delim(']'):
			depth--
		}
	}

	return nil
}
141 changes: 141 additions & 0 deletions pkg/tmessage/files_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package tmessage

import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"testing"

"github.com/gotd/td/constant"
"github.com/gotd/td/telegram/peers"
"github.com/gotd/td/tg"
"github.com/stretchr/testify/require"
)

// TestReadChatIDDoesNotInterfereWithCollect probes the chat id through a file
// handle, rewinds that same handle, and checks that collect still returns all
// messages of the export, verifying the first and last message ids.
func TestReadChatIDDoesNotInterfereWithCollect(t *testing.T) {
	const total = 5000

	exportPath := writeTelegramExport(t, total)

	export, err := os.Open(exportPath)
	require.NoError(t, err)
	defer func() {
		require.NoError(t, export.Close())
	}()

	chatID, err := readChatID(export)
	require.NoError(t, err)
	require.Equal(t, int64(123456), chatID)

	// Rewind to the start of the file (whence 0) before handing it to collect.
	_, err = export.Seek(0, 0)
	require.NoError(t, err)

	dialog, err := collect(context.Background(), export, testPeer{id: chatID}, true)
	require.NoError(t, err)
	require.Len(t, dialog.Messages, total)
	require.Equal(t, 1, dialog.Messages[0])
	require.Equal(t, total, dialog.Messages[total-1])
}

// TestCollectFileIsStableAcrossRepeatedReads runs collectFile ten times over
// the same export and expects the full message count on every pass.
func TestCollectFileIsStableAcrossRepeatedReads(t *testing.T) {
	const (
		total    = 5000
		attempts = 10
	)

	exportPath := writeTelegramExport(t, total)
	chat := testPeer{id: 123456}

	for attempt := 0; attempt < attempts; attempt++ {
		dialog, err := collectFile(context.Background(), exportPath, chat, true)
		require.NoError(t, err)
		require.Len(t, dialog.Messages, total)
	}
}

// TestReadChatIDSkipsTopLevelValues feeds readChatID an export whose chat id
// comes last and is encoded as a string. The "id" nested inside the messages
// array must be skipped rather than returned.
func TestReadChatIDSkipsTopLevelValues(t *testing.T) {
	const export = `{"messages":[{"id":1,"type":"message","file":"1.jpg"}],"name":"test","id":"789"}`

	got, err := readChatID(strings.NewReader(export))
	require.NoError(t, err)
	require.Equal(t, int64(789), got)
}

// writeTelegramExport writes a synthetic Telegram export named export.json into
// t.TempDir() and returns its path.
//
// The export is a single JSON object with chat id 123456 and a "messages" array
// holding `messages` entries whose ids run from 1 to messages, each referencing
// photos/<id>.jpg. A non-positive count yields an empty array.
//
// The document is assembled in memory and written with one call, instead of
// issuing one or two unbuffered file writes per message. Any write failure
// fails the test immediately.
func writeTelegramExport(t *testing.T, messages int) string {
	t.Helper()

	var doc strings.Builder
	doc.WriteString(`{"name":"test","type":"private_group","id":123456,"messages":[`)
	for i := 1; i <= messages; i++ {
		if i > 1 {
			doc.WriteByte(',')
		}
		// Writes to a strings.Builder never fail, so the error is not checked.
		fmt.Fprintf(&doc, `{"id":%d,"type":"message","date_unixtime":"1710000000","file":"photos/%d.jpg","text":""}`, i, i)
	}
	doc.WriteString(`]}`)

	path := filepath.Join(t.TempDir(), "export.json")
	// 0o666 matches the mode os.Create would have used.
	require.NoError(t, os.WriteFile(path, []byte(doc.String()), 0o666))

	return path
}

// testPeer is a minimal stub that these tests pass wherever a peers.Peer is
// expected (collect and collectFile). Only the id is configurable; every other
// method returns a fixed value.
type testPeer struct {
	id int64
}

// ID returns the id the stub was constructed with.
func (tp testPeer) ID() int64 {
	return tp.id
}

// TDLibPeerID returns the stub's id converted to constant.TDLibPeerID.
func (tp testPeer) TDLibPeerID() constant.TDLibPeerID {
	return constant.TDLibPeerID(tp.id)
}

// VisibleName always returns "test".
func (tp testPeer) VisibleName() string {
	return "test"
}

// Username reports that the stub has no username.
func (tp testPeer) Username() (string, bool) {
	return "", false
}

// Restricted reports no restriction reasons.
func (tp testPeer) Restricted() ([]tg.RestrictionReason, bool) {
	return nil, false
}

// Verified always returns false.
func (tp testPeer) Verified() bool {
	return false
}

// Scam always returns false.
func (tp testPeer) Scam() bool {
	return false
}

// Fake always returns false.
func (tp testPeer) Fake() bool {
	return false
}

// InputPeer wraps the stub's id in a tg.InputPeerChat.
func (tp testPeer) InputPeer() tg.InputPeerClass {
	return &tg.InputPeerChat{ChatID: tp.id}
}

// Sync does nothing and returns nil.
func (tp testPeer) Sync(context.Context) error {
	return nil
}

// Manager returns nil; the stub is not backed by a peers.Manager.
func (tp testPeer) Manager() *peers.Manager {
	return nil
}

// Report does nothing and returns nil.
func (tp testPeer) Report(context.Context, tg.ReportReasonClass, string) error {
	return nil
}

// Photo reports that the stub has no photo.
func (tp testPeer) Photo(context.Context) (*tg.Photo, bool, error) {
	return nil, false, nil
}