diff --git a/pkg/forward/manager.go b/pkg/forward/manager.go index e788da9..1c29f1a 100644 --- a/pkg/forward/manager.go +++ b/pkg/forward/manager.go @@ -270,37 +270,47 @@ func (m *Manager) acceptLoop(t *Target) { } } +func (m *Manager) processChunk(t *Target, data []byte) { + if len(data) == 0 { + return + } + n := len(data) + atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) + chunk := make([]byte, n) + copy(chunk, data) + if wErr := m.writeToSerial(chunk); wErr != nil { + t.stats.LastError = wErr.Error() + m.notify("[forward] #%d write serial error: %v", t.ID, wErr) + } else if m.onInbound != nil { + m.onInbound(t.ID, chunk) + } +} + +func (m *Manager) readLoopError(t *Target, err error) { + select { + case <-t.closeCh: + return + default: + } + t.Connected = false + t.stats.LastError = err.Error() + m.notify("[forward] #%d disconnected: %v", t.ID, err) +} + func (m *Manager) readLoopPacket(t *Target) { buf := make([]byte, 4096) for { n, addr, err := t.packetConn.ReadFrom(buf) if n > 0 { - atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) - chunk := make([]byte, n) - copy(chunk, buf[:n]) - if wErr := m.writeToSerial(chunk); wErr != nil { - t.stats.LastError = wErr.Error() - m.notify("[forward] #%d write serial error: %v", t.ID, wErr) - } else if m.onInbound != nil { - m.onInbound(t.ID, chunk) - } - // Track remote address for Broadcast + m.processChunk(t, buf[:n]) t.mu.Lock() t.remoteAddrs[addr.String()] = addr t.mu.Unlock() } if err != nil { - select { - case <-t.closeCh: - return - default: - } - t.Connected = false - t.stats.LastError = err.Error() - m.notify("[forward] #%d disconnected: %v", t.ID, err) + m.readLoopError(t, err) return } - select { case <-t.closeCh: return @@ -314,28 +324,12 @@ func (m *Manager) readLoopSerial(t *Target) { for { n, err := t.serialPort.Read(buf) if n > 0 { - atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) - chunk := make([]byte, n) - copy(chunk, buf[:n]) - if wErr := m.writeToSerial(chunk); wErr != nil { - t.stats.LastError = wErr.Error() - m.notify("[forward] #%d write serial error: %v", t.ID, wErr) - } else if m.onInbound != nil { - m.onInbound(t.ID, chunk) - } + m.processChunk(t, buf[:n]) } if err != nil { - select { - case <-t.closeCh: - return - default: - } - t.Connected = false - t.stats.LastError = err.Error() - m.notify("[forward] #%d disconnected: %v", t.ID, err) + m.readLoopError(t, err) return } - select { case <-t.closeCh: return @@ -349,22 +343,11 @@ func (m *Manager) readLoop(t *Target, conn net.Conn, stop <-chan struct{}) { for { n, err := conn.Read(buf) if n > 0 { - atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) - chunk := make([]byte, n) - copy(chunk, buf[:n]) - if wErr := m.writeToSerial(chunk); wErr != nil { - t.stats.LastError = wErr.Error() - m.notify("[forward] #%d write serial error: %v", t.ID, wErr) - } else if m.onInbound != nil { - m.onInbound(t.ID, chunk) - } + m.processChunk(t, buf[:n]) } - if err != nil { t.Connected = false t.stats.LastError = err.Error() - - // Remove from TCP server conns if applicable if t.Mode == TCPServer { t.connsMu.Lock() delete(t.conns, conn) @@ -374,7 +357,6 @@ func (m *Manager) readLoop(t *Target, conn net.Conn, stop <-chan struct{}) { _ = conn.Close() return } - select { case <-stop: _ = conn.Close()