Skip to content

Commit f3eb19f

Browse files
committed
refactor: concurrent and chores
1 parent 0c9623c commit f3eb19f

File tree

27 files changed

+357
-189
lines changed

27 files changed

+357
-189
lines changed

cmd/main.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ func loadConfig[T config.IConfig](path string) (T, error) {
3333
return *cfg, nil
3434
}
3535

36-
func start(prog program.IProgram) {
37-
program.Program = prog
38-
prog.Run()
39-
}
40-
4136
func init() {
4237
rootCmd.PersistentFlags().BoolVarP(&config.IsVerbose, "verbose", "v", false, "verbose output")
4338

@@ -58,7 +53,8 @@ func init() {
5853
logger.Error("Error creating server: ", err)
5954
return
6055
}
61-
start(s)
56+
program.Program = s
57+
s.Run()
6258
},
6359
})
6460

@@ -74,7 +70,9 @@ func init() {
7470
logger.Error("Error loading config: ", err)
7571
return
7672
}
77-
start(client.NewClient(ctx, cfg))
73+
c := client.NewClient(ctx, cfg)
74+
program.Program = c
75+
c.Run()
7876
},
7977
})
8078

pkg/arch/acceptors/acceptor_udp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ package acceptors
22

33
import (
44
"context"
5-
"net"
65

6+
"club.asynclab/asrp/pkg/base/network"
77
"club.asynclab/asrp/pkg/comm"
88
)
99

1010
func NewAcceptorUDP(parentCtx context.Context, addr string) (*Acceptor[comm.UDP], error) {
11-
listener, err := net.Listen("udp", addr) // TODO
11+
listener, err := network.NewUDPListener(addr)
1212
if err != nil {
1313
return nil, err
1414
}

pkg/arch/acceptors/acceptors.go

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,27 @@ func (acceptor *Acceptor[T]) HandlePacket(pkt packet.IPacket) bool {
8888
return true
8989
}
9090

91-
func (acceptor *Acceptor[T]) init() {
92-
go func() {
93-
pattern.NewConfigSelectContextAndChannel[*comm.Conn]().
94-
WithCtx(acceptor.GetCtx()).
95-
WithGoroutine(func(ch chan *comm.Conn) {
96-
for {
97-
conn, err := acceptor.listener.Accept()
98-
if err != nil {
99-
if acceptor.GetCtx().Err() != nil {
100-
return
101-
}
102-
continue
91+
// ----------------------------------------------
92+
93+
func (acceptor *Acceptor[T]) routineRead() {
94+
pattern.NewConfigSelectContextAndChannel[*comm.Conn]().
95+
WithCtx(acceptor.GetCtx()).
96+
WithGoroutine(func(ch chan *comm.Conn) {
97+
for {
98+
conn, err := acceptor.listener.Accept()
99+
if err != nil {
100+
if acceptor.GetCtx().Err() != nil {
101+
return
103102
}
104-
ch <- conn
103+
continue
105104
}
106-
}).
107-
WithChannelHandler(func(conn *comm.Conn) { go acceptor.handleConnection(conn) }).
108-
Run()
109-
}()
105+
ch <- conn
106+
}
107+
}).
108+
WithChannelHandler(func(conn *comm.Conn) { go acceptor.handleConnection(conn) }).
109+
Run()
110+
}
111+
112+
func (acceptor *Acceptor[T]) init() {
113+
go acceptor.routineRead()
110114
}

pkg/arch/arch.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,5 @@ type IAcceptor interface {
5252

5353
type ForwarderWithValues struct {
5454
IForwarder
55-
Name string
56-
FrontendAddr string
57-
Priority int
58-
Weight int
55+
InitPacket *packet.PacketProxyNegotiationRequest
5956
}

pkg/arch/connectors/connectors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ func (connector *Connector[T]) initConnection(conn *comm.Conn) error {
7373

7474
r := &packet.PacketProxyNegotiationRequest{
7575
Name: connector.proxyConfig.Name,
76+
Proto: connector.proxyConfig.Proto,
7677
FrontendAddr: connector.proxyConfig.Frontend,
7778
Token: connector.remoteConfig.Token,
7879
Priority: connector.proxyConfig.Priority,

pkg/arch/dialers/dialer_udp.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ package dialers
33
import (
44
"context"
55
"net"
6-
"time"
76

8-
"club.asynclab/asrp/pkg/base/network"
97
"club.asynclab/asrp/pkg/comm"
108
)
119

@@ -23,5 +21,5 @@ func (impl *DialerImplUDP) Dial(ctx context.Context, addr string) (*comm.Conn, e
2321
if err != nil {
2422
return nil, err
2523
}
26-
return comm.NewConnWithParentCtx(ctx, network.NewConnWithTimeout(conn, 60*time.Second, 0)), nil
24+
return comm.NewConnWithParentCtx(ctx, conn), nil
2725
}

pkg/arch/dialers/dialers.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ func (dialer *Dialer[T]) HandlePacket(pkt packet.IPacket) bool {
9292
return false
9393
}
9494

95-
switch pkt := pkt.(type) { // TODO dialer失败就发送给dispatcher重分配请求
95+
switch pkt := pkt.(type) {
9696
case *packet.PacketProxyData:
9797
ok := true
9898

9999
dialer.conns.Compute(func(v *concurrent.ConcurrentIndexMap[*comm.Conn]) {
100100
if _, ok := v.Load(pkt.Uuid); !ok {
101101
conn, err := dialer.impl.Dial(dialer.ctx, dialer.addr)
102-
if err != nil {
102+
if err != nil { // TODO dialer失败就发送给dispatcher重分配请求
103103
ok = false
104104
logger.Error(fmt.Sprintf("Error dialing: %v", err))
105105
return

pkg/arch/dispatchers/dispatchers.go

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ var logger = logging.GetLogger()
1717

1818
type Dispatcher struct {
1919
concurrent.MetaConcurrentStructure[Dispatcher]
20-
ctx context.Context
21-
ctxCancel context.CancelFunc
22-
forwarders *structure.IndexMap[*arch.ForwarderWithValues]
23-
totalWeights map[int]int // priority -> totalWeight
24-
currentIndex int
25-
senderPacket *channel.SafeSender[packet.IPacket]
26-
connsMap *concurrent.ConcurrentMap[string, string]
20+
ctx context.Context
21+
ctxCancel context.CancelFunc
22+
forwarders *structure.IndexMap[*arch.ForwarderWithValues]
23+
totalWeights map[uint32]uint32 // priority -> totalWeight
24+
currentIndex int
25+
senderPacket *channel.SafeSender[packet.IPacket]
26+
connsMap *concurrent.ConcurrentMap[string, string] // conn -> forwarder
27+
connsMapBackward *concurrent.ConcurrentMap[string, *structure.HashSet[string]] // forwarder -> conns
2728
}
2829

2930
func NewDispatcher(parentCtx context.Context) *Dispatcher {
@@ -32,11 +33,12 @@ func NewDispatcher(parentCtx context.Context) *Dispatcher {
3233
ctx: ctx,
3334
ctxCancel: cancel,
3435
forwarders: structure.NewIndexMap[*arch.ForwarderWithValues](),
35-
totalWeights: make(map[int]int),
36+
totalWeights: make(map[uint32]uint32),
3637
currentIndex: 0,
3738
MetaConcurrentStructure: *concurrent.NewMetaSyncStructure[Dispatcher](),
3839
senderPacket: channel.NewSafeSenderWithParentCtxAndSize[packet.IPacket](ctx, 16),
3940
connsMap: concurrent.NewSyncMap[string, string](),
41+
connsMapBackward: concurrent.NewSyncMap[string, *structure.HashSet[string]](),
4042
}
4143

4244
return dispatcher
@@ -57,23 +59,12 @@ func (dispatcher *Dispatcher) Close() error {
5759

5860
// ---------------------------------------------------------------------
5961

60-
func (dispatcher *Dispatcher) packetHandlerMiddleware(forwarderUuid string, f func(packet.IPacket) bool) func(packet.IPacket) bool {
61-
return func(pkt packet.IPacket) bool {
62-
switch pkt := pkt.(type) {
63-
case *packet.PacketProxyData:
64-
dispatcher.connsMap.LoadOrStore(pkt.Uuid, forwarderUuid)
65-
}
66-
67-
return f(pkt)
68-
}
69-
}
70-
7162
func (dispatcher *Dispatcher) AddForwarder(fwv *arch.ForwarderWithValues) (uuid string) {
7263
dispatcher.Lock.Lock()
7364
defer dispatcher.Lock.Unlock()
7465

7566
uuid = dispatcher.forwarders.Store(fwv)
76-
dispatcher.totalWeights[fwv.Priority] += fwv.Weight
67+
dispatcher.totalWeights[fwv.InitPacket.Priority] += fwv.InitPacket.Weight
7768

7869
go channel.ConsumeWithCtx(dispatcher.GetCtx(), fwv.GetChanSendPacket(), dispatcher.senderPacket.Push)
7970

@@ -88,12 +79,14 @@ func (dispatcher *Dispatcher) RemoveForwarder(uuid string) {
8879
if !ok {
8980
return
9081
}
91-
dispatcher.totalWeights[conn.Priority] -= conn.Weight
92-
if dispatcher.totalWeights[conn.Priority] == 0 {
93-
delete(dispatcher.totalWeights, conn.Priority)
82+
dispatcher.totalWeights[conn.InitPacket.Priority] -= conn.InitPacket.Weight
83+
if dispatcher.totalWeights[conn.InitPacket.Priority] == 0 {
84+
delete(dispatcher.totalWeights, conn.InitPacket.Priority)
9485
}
9586
dispatcher.forwarders.Delete(uuid)
96-
dispatcher.connsMap.Delete(uuid)
87+
if conns, ok := dispatcher.connsMapBackward.LoadAndDelete(uuid); ok {
88+
conns.Stream().ForEach(func(t string) { dispatcher.connsMap.Delete(t) })
89+
}
9790
}
9891

9992
func (dispatcher *Dispatcher) Next() (uuid string, forwarder arch.IForwarder, ok bool) {
@@ -104,21 +97,21 @@ func (dispatcher *Dispatcher) Next() (uuid string, forwarder arch.IForwarder, ok
10497
return
10598
}
10699

107-
totalWeight, _ok := hof.NewStreamWithMap(dispatcher.totalWeights).Max(func(bigger container.Entry[int, int], smaller container.Entry[int, int]) bool {
100+
totalWeight, _ok := hof.NewStreamWithMap(dispatcher.totalWeights).Max(func(bigger container.Entry[uint32, uint32], smaller container.Entry[uint32, uint32]) bool {
108101
return bigger.GetKey() > smaller.GetKey()
109102
})
110103

111104
if !_ok || totalWeight.GetValue() == 0 {
112105
return
113106
}
114107

115-
dispatcher.currentIndex = (dispatcher.currentIndex + 1) % totalWeight.GetValue()
108+
dispatcher.currentIndex = (dispatcher.currentIndex + 1) % int(totalWeight.GetValue())
116109
i := dispatcher.currentIndex
117110

118111
dispatcher.forwarders.Stream().Filter(func(t container.Entry[string, *arch.ForwarderWithValues]) bool {
119-
return t.GetValue().Priority == totalWeight.GetKey()
112+
return t.GetValue().InitPacket.Priority == totalWeight.GetKey()
120113
}).Range(func(t container.Entry[string, *arch.ForwarderWithValues]) bool {
121-
i -= t.GetValue().Weight
114+
i -= int(t.GetValue().InitPacket.Weight)
122115
if i < 0 {
123116
uuid, forwarder, ok = t.GetKey(), t.GetValue(), true
124117
return false
@@ -152,16 +145,18 @@ func (dispatcher *Dispatcher) HandlePacket(pkt packet.IPacket) bool {
152145
switch pkt := pkt.(type) {
153146
case packet.IPacketForConn:
154147
if uuid, ok := dispatcher.connsMap.Load(pkt.GetUuid()); ok {
155-
if v, ok := dispatcher.forwarders.Load(uuid); ok {
156-
v.HandlePacket(pkt)
157-
return true
148+
if forwarder, ok := dispatcher.forwarders.Load(uuid); ok {
149+
forwarder.HandlePacket(pkt)
150+
}
151+
} else {
152+
if uuid, forwarder, ok := dispatcher.Next(); ok {
153+
actual, _ := dispatcher.connsMap.LoadOrStore(pkt.GetUuid(), uuid)
154+
actual2, _ := dispatcher.connsMapBackward.LoadOrStore(actual, structure.NewHashSet[string]())
155+
actual2.Store(pkt.GetUuid())
156+
forwarder.HandlePacket(pkt)
158157
}
159158
}
160159
}
161160

162-
if uuid, forwarder, ok := dispatcher.Next(); ok {
163-
dispatcher.packetHandlerMiddleware(uuid, forwarder.HandlePacket)(pkt)
164-
}
165-
166161
return true
167162
}

pkg/arch/forwarders/forwarders.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func NewForwarder(conn *comm.Conn) *Forwarder {
2222
forwarder := &Forwarder{
2323
Tip: "[Common Forwarder]: ",
2424
conn: conn,
25-
senderPacket: channel.NewSafeSenderWithSize[packet.IPacket](16),
25+
senderPacket: channel.NewSafeSenderWithParentCtxAndSize[packet.IPacket](conn.GetCtx(), 16),
2626
}
2727

2828
forwarder.init()
@@ -69,20 +69,22 @@ func (forwarder *Forwarder) GetChanSendPacket() <-chan packet.IPacket {
6969

7070
// ---------------------------------------------------------------------
7171

72-
func (forwarder *Forwarder) init() {
73-
go func() {
74-
pattern.NewConfigSelectContextAndChannel[packet.IPacket]().
75-
WithCtx(forwarder.GetCtx()).
76-
WithGoroutine(func(ch chan packet.IPacket) {
77-
for {
78-
pkt := forwarder.receivePacket()
79-
if pkt == nil {
80-
return
81-
}
82-
ch <- pkt
72+
func (forwarder *Forwarder) routineRead() {
73+
pattern.NewConfigSelectContextAndChannel[packet.IPacket]().
74+
WithCtx(forwarder.GetCtx()).
75+
WithGoroutine(func(ch chan packet.IPacket) {
76+
for {
77+
pkt := forwarder.receivePacket()
78+
if pkt == nil {
79+
return
8380
}
84-
}).
85-
WithChannelHandler(func(pkt packet.IPacket) { forwarder.senderPacket.Push(pkt) }).
86-
Run()
87-
}()
81+
ch <- pkt
82+
}
83+
}).
84+
WithChannelHandler(func(pkt packet.IPacket) { forwarder.senderPacket.Push(pkt) }).
85+
Run()
86+
}
87+
88+
func (forwarder *Forwarder) init() {
89+
go forwarder.routineRead()
8890
}

pkg/arch/ushers/ushers.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,15 @@ func (usher *Usher[T]) handleConnection(conn *comm.Conn) {
7575
}
7676

7777
usher.senderForwarder.Push(&arch.ForwarderWithValues{
78-
IForwarder: f,
79-
Name: p.Name,
80-
FrontendAddr: p.FrontendAddr,
81-
Priority: int(p.Priority),
82-
Weight: int(p.Weight),
78+
IForwarder: f,
79+
InitPacket: p,
8380
})
8481
}
8582

86-
func (usher *Usher[T]) init() {
87-
go pattern.NewConfigSelectContextAndChannel[*comm.Conn]().
83+
// -------------------------------------------
84+
85+
func (usher *Usher[T]) routineRead() {
86+
pattern.NewConfigSelectContextAndChannel[*comm.Conn]().
8887
WithCtx(usher.GetCtx()).
8988
WithGoroutine(func(ch chan *comm.Conn) {
9089
for {
@@ -102,3 +101,7 @@ func (usher *Usher[T]) init() {
102101
WithChannelHandler(func(conn *comm.Conn) { go usher.handleConnection(conn) }).
103102
Run()
104103
}
104+
105+
func (usher *Usher[T]) init() {
106+
go usher.routineRead()
107+
}

0 commit comments

Comments
 (0)