diff --git a/internal/command/commands.go b/internal/command/commands.go index 4af7618..006a288 100644 --- a/internal/command/commands.go +++ b/internal/command/commands.go @@ -39,7 +39,7 @@ func (d *Dispatcher) handleForwardCommand(args []string) error { case "add": if len(args) < 4 { - return fmt.Errorf("usage: .forward add
") + return fmt.Errorf("usage: .forward add
") } mode, ok := forward.ParseMode(args[2]) if !ok { @@ -71,7 +71,7 @@ func (d *Dispatcher) handleForwardCommand(args []string) error { case "update": if len(args) < 5 { - return fmt.Errorf("usage: .forward update
") + return fmt.Errorf("usage: .forward update
") } id, err := strconv.Atoi(args[2]) if err != nil { diff --git a/internal/command/complete.go b/internal/command/complete.go index c09b12e..20db514 100644 --- a/internal/command/complete.go +++ b/internal/command/complete.go @@ -29,15 +29,15 @@ func filterPrefix(cands []string, cur string) []string { func completeForward(args []string) []string { if len(args) <= 2 { - return []string{"list", "add", "remove", "enable", "disable", "update", "stats"} + return []string{"list", "add", "remove", "enable", "disable", "update"} } if len(args) == 3 && args[1] == "add" { - return []string{"tcp", "udp"} + return []string{"tcp", "udp", "tcp-s", "udp-s", "com"} } if len(args) == 4 && args[1] == "update" { - return []string{"tcp", "udp"} + return []string{"tcp", "udp", "tcp-s", "udp-s", "com"} } return nil diff --git a/internal/command/dispatcher.go b/internal/command/dispatcher.go index 9f1f7d8..f18a035 100644 --- a/internal/command/dispatcher.go +++ b/internal/command/dispatcher.go @@ -99,8 +99,8 @@ func (d *Dispatcher) registerAll() { d.register(RuntimeCommand{ Name: ".forward", - Usage: ".forward ", - Description: "manage forwarding at runtime", + Usage: ".forward ", + Description: "manage forwarding (tcp/udp/tcp-s/udp-s/com)", Handler: d.handleForwardCommand, Completer: completeForward, }) diff --git a/internal/flag/flag.go b/internal/flag/flag.go index a5356f8..dd5792d 100644 --- a/internal/flag/flag.go +++ b/internal/flag/flag.go @@ -37,7 +37,7 @@ func Init(cfg *config.Config) { pflag.IntVarP(&cfg.ParityBit, "verify", "v", 0, "parity (0:none,1:odd,2:even,3:mark,4:space)") pflag.BoolVarP(&cfg.EnableGUI, "gui", "g", false, "enable TUI mode") pflag.StringVarP(&cfg.HotkeyMod, "hotkey-mod", "k", "ctrl+alt", "hotkey modifier (ctrl+alt|ctrl+shift)") - pflag.IntSliceVarP(&cfg.ForWard, "forward", "f", nil, "forward mode (0:none,1:TCP,2:UDP)") + pflag.IntSliceVarP(&cfg.ForWard, "forward", "f", nil, "forward mode (0:none,1:TCP,2:UDP,3:TCP-S,4:UDP-S,5:COM)") pflag.StringArrayVarP(&cfg.Address, "address", "a", nil, "forward address") pflag.StringVarP(&cfg.LogFilePath, "log", "l", "", "log file path") _ = pflag.Lookup("log") // mark for NoOptDefVal @@ -99,7 +99,7 @@ func PrintUsage(ports []string) { {"-v", "--verify", "int", "parity", "0"}, {"-g", "--gui", "bool", "enable TUI", "false"}, {"-k", "--hotkey-mod", "string", "hotkey modifier", "ctrl+alt"}, - {"-f", "--forward", "[]int", "forward mode", "0"}, + {"-f", "--forward", "[]int", "forward (0:none,1:TCP,2:UDP,3:TCP-S,4:UDP-S,5:COM)", "0"}, {"-a", "--address", "[]string", "forward address", "127.0.0.1:12345"}, {"-l", "--log", "string", "log path", "./%s-$s.txt"}, {"-t", "--time", "string", "timestamp format", "[06-01-02 15:04:05.000]"}, @@ -122,7 +122,7 @@ var ( datas = []string{"5", "6", "7", "8"} stops = []string{"1", "1.5", "2"} paritys = []string{"None", "Odd", "Even", "Mark", "Space"} - forwards = []string{"No", "TCP-C", "UDP-C"} + forwards = []string{"No", "TCP-C", "UDP-C", "TCP-S", "UDP-S", "COM"} ) // GetCliFlag runs an interactive configuration wizard when no port is specified. diff --git a/internal/tui/panels.go b/internal/tui/panels.go index 7655df5..0faab58 100644 --- a/internal/tui/panels.go +++ b/internal/tui/panels.go @@ -103,10 +103,10 @@ func (m *Model) handleForwardPanelKey(key string) bool { m.refreshPanel() return true case "a": - m.startPrompt("Add Forward", "tcp 127.0.0.1:12345", "", func(v string) { + m.startPrompt("Add Forward", "tcp 127.0.0.1:12345 (tcp|udp|tcp-s|udp-s|com)", "", func(v string) { parts := strings.Fields(v) if len(parts) < 2 { - m.panelError = "usage:
" + m.panelError = "usage:
" return } mode, ok := forward.ParseMode(parts[0]) @@ -158,7 +158,7 @@ func (m *Model) handleForwardPanelKey(key string) bool { m.startPrompt("Update Forward #"+fmt.Sprint(sel.ID), "tcp 127.0.0.1:12345", fmt.Sprintf("%s %s", sel.Mode, sel.Address), func(v string) { parts := strings.Fields(v) if len(parts) < 2 { - m.panelError = "usage:
" + m.panelError = "usage:
" return } mode, ok := forward.ParseMode(parts[0]) diff --git a/pkg/forward/manager.go b/pkg/forward/manager.go index 6e578eb..e788da9 100644 --- a/pkg/forward/manager.go +++ b/pkg/forward/manager.go @@ -1,4 +1,4 @@ -// Package forward manages TCP/UDP forwarding targets for serial data. +// Package forward manages TCP/UDP/COM forwarding targets for serial data. package forward import ( @@ -9,24 +9,35 @@ import ( "sync" "sync/atomic" "time" + + "go.bug.st/serial" ) // Mode is the forwarding protocol mode. type Mode int const ( - None Mode = iota - TCP - UDP + None Mode = 0 + TCP Mode = 1 + UDP Mode = 2 + TCPServer Mode = 3 + UDPServer Mode = 4 + COMPort Mode = 5 ) -// ParseMode parses a mode string. Accepts "tcp"/"tcp-c"/"tcpc"/"1" → TCP, "udp"/"udp-c"/"udpc"/"2" → UDP. +// ParseMode parses a mode string. func ParseMode(v string) (Mode, bool) { switch strings.ToLower(strings.TrimSpace(v)) { case "tcp", "tcp-c", "tcpc", "1": return TCP, true case "udp", "udp-c", "udpc", "2": return UDP, true + case "tcp-s", "tcps", "tcp-server", "3": + return TCPServer, true + case "udp-s", "udps", "udp-server", "4": + return UDPServer, true + case "com", "serial", "5": + return COMPort, true default: return None, false } @@ -34,10 +45,12 @@ func ParseMode(v string) (Mode, bool) { func (m Mode) Network() string { switch m { - case TCP: + case TCP, TCPServer: return "tcp" - case UDP: + case UDP, UDPServer: return "udp" + case COMPort: + return "serial" default: return "" } @@ -49,6 +62,12 @@ func (m Mode) String() string { return "tcp" case UDP: return "udp" + case TCPServer: + return "tcp-s" + case UDPServer: + return "udp-s" + case COMPort: + return "com" default: return "none" } @@ -70,13 +89,34 @@ type Target struct { Connected bool CreatedAt time.Time - conn net.Conn + // Client-mode connection (TCP/UDP client) + conn net.Conn + + // Server-mode fields + listener net.Listener // TCP server listener + conns map[net.Conn]struct{} // TCP server accepted connections + connsMu sync.Mutex + + // UDP server + packetConn net.PacketConn // UDP server listener + remoteAddrs map[string]net.Addr // known UDP remotes + + // COM port + serialPort serial.Port + stats Stats mu sync.Mutex closeCh chan struct{} closed bool } +// AcceptedConns returns the number of accepted connections (TCP server only). +func (t *Target) acceptedConns() int { + t.connsMu.Lock() + defer t.connsMu.Unlock() + return len(t.conns) +} + // Snapshot is a read-only view of a forward target for display. type Snapshot struct { ID int @@ -87,6 +127,7 @@ type Snapshot struct { ReadBytes uint64 WriteByte uint64 LastError string + Conns int // accepted connection count (TCP server) } // Manager coordinates forwarding targets. @@ -130,26 +171,179 @@ func (m *Manager) Add(mode Mode, address string) (int, error) { closeCh: make(chan struct{}), } - conn, err := net.Dial(mode.Network(), address) - if err != nil { - t.stats.LastError = err.Error() - return 0, err + switch mode { + case TCP, UDP: + conn, err := net.Dial(mode.Network(), address) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.conn = conn + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.readLoop(t, conn, t.closeCh) + + case TCPServer: + listener, err := net.Listen("tcp", address) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.listener = listener + t.conns = make(map[net.Conn]struct{}) + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.acceptLoop(t) + + case UDPServer: + pc, err := net.ListenPacket("udp", address) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.packetConn = pc + t.remoteAddrs = make(map[string]net.Addr) + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.readLoopPacket(t) + + case COMPort: + sp, err := serial.Open(address, &serial.Mode{BaudRate: 115200, DataBits: 8, StopBits: 0, Parity: 0}) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.serialPort = sp + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.readLoopSerial(t) } - t.conn = conn - t.Connected = true - - m.mu.Lock() - t.ID = m.nextID - m.nextID++ - m.targets[t.ID] = t - m.mu.Unlock() - - go m.readLoop(t, conn, t.closeCh) m.notify("[forward] #%d %s %s connected", t.ID, t.Mode.String(), t.Address) return t.ID, nil } +func (m *Manager) acceptLoop(t *Target) { + for { + conn, err := t.listener.Accept() + if err != nil { + select { + case <-t.closeCh: + return + default: + } + t.stats.LastError = err.Error() + m.notify("[forward] #%d accept error: %v", t.ID, err) + return + } + + t.connsMu.Lock() + t.conns[conn] = struct{}{} + t.connsMu.Unlock() + + m.notify("[forward] #%d accepted %s", t.ID, conn.RemoteAddr()) + go m.readLoop(t, conn, t.closeCh) + } +} + +func (m *Manager) readLoopPacket(t *Target) { + buf := make([]byte, 4096) + for { + n, addr, err := t.packetConn.ReadFrom(buf) + if n > 0 { + atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) + chunk := make([]byte, n) + copy(chunk, buf[:n]) + if wErr := m.writeToSerial(chunk); wErr != nil { + t.stats.LastError = wErr.Error() + m.notify("[forward] #%d write serial error: %v", t.ID, wErr) + } else if m.onInbound != nil { + m.onInbound(t.ID, chunk) + } + // Track remote address for Broadcast + t.mu.Lock() + t.remoteAddrs[addr.String()] = addr + t.mu.Unlock() + } + if err != nil { + select { + case <-t.closeCh: + return + default: + } + t.Connected = false + t.stats.LastError = err.Error() + m.notify("[forward] #%d disconnected: %v", t.ID, err) + return + } + + select { + case <-t.closeCh: + return + default: + } + } +} + +func (m *Manager) readLoopSerial(t *Target) { + buf := make([]byte, 4096) + for { + n, err := t.serialPort.Read(buf) + if n > 0 { + atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) + chunk := make([]byte, n) + copy(chunk, buf[:n]) + if wErr := m.writeToSerial(chunk); wErr != nil { + t.stats.LastError = wErr.Error() + m.notify("[forward] #%d write serial error: %v", t.ID, wErr) + } else if m.onInbound != nil { + m.onInbound(t.ID, chunk) + } + } + if err != nil { + select { + case <-t.closeCh: + return + default: + } + t.Connected = false + t.stats.LastError = err.Error() + m.notify("[forward] #%d disconnected: %v", t.ID, err) + return + } + + select { + case <-t.closeCh: + return + default: + } + } +} + func (m *Manager) readLoop(t *Target, conn net.Conn, stop <-chan struct{}) { buf := make([]byte, 4096) for { @@ -167,18 +361,28 @@ func (m *Manager) readLoop(t *Target, conn net.Conn, stop <-chan struct{}) { } if err != nil { - t.mu.Lock() - if t.conn == conn { - t.Connected = false - } + t.Connected = false t.stats.LastError = err.Error() - t.mu.Unlock() + + // Remove from TCP server conns if applicable + if t.Mode == TCPServer { + t.connsMu.Lock() + delete(t.conns, conn) + t.connsMu.Unlock() + } m.notify("[forward] #%d disconnected: %v", t.ID, err) + _ = conn.Close() return } select { case <-stop: + _ = conn.Close() + if t.Mode == TCPServer { + t.connsMu.Lock() + delete(t.conns, conn) + t.connsMu.Unlock() + } return default: } @@ -216,18 +420,59 @@ func (m *Manager) Enable(id int) error { return nil } - conn, err := net.Dial(t.Mode.Network(), t.Address) - if err != nil { - t.stats.LastError = err.Error() - return err + switch t.Mode { + case TCP, UDP: + conn, err := net.Dial(t.Mode.Network(), t.Address) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.conn = conn + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.readLoop(t, conn, t.closeCh) + + case TCPServer: + listener, err := net.Listen("tcp", t.Address) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.listener = listener + t.conns = make(map[net.Conn]struct{}) + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.acceptLoop(t) + + case UDPServer: + pc, err := net.ListenPacket("udp", t.Address) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.packetConn = pc + t.remoteAddrs = make(map[string]net.Addr) + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.readLoopPacket(t) + + case COMPort: + sp, err := serial.Open(t.Address, &serial.Mode{BaudRate: 115200, DataBits: 8, StopBits: 0, Parity: 0}) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.serialPort = sp + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.readLoopSerial(t) } t.Enabled = true - t.Connected = true - t.conn = conn - t.closeCh = make(chan struct{}) - t.closed = false - go m.readLoop(t, conn, t.closeCh) m.notify("[forward] #%d enabled", id) return nil } @@ -292,18 +537,67 @@ func (m *Manager) Broadcast(data []byte) { m.mu.RUnlock() for _, t := range items { - if !t.Enabled || !t.Connected || t.conn == nil { + if !t.Enabled || !t.Connected { continue } - n, err := t.conn.Write(data) - if err != nil { - t.stats.LastError = err.Error() - m.notify("[forward] #%d write error: %v", t.ID, err) - continue - } + switch t.Mode { + case TCP, UDP: + if t.conn == nil { + continue + } + n, err := t.conn.Write(data) + if err != nil { + t.stats.LastError = err.Error() + m.notify("[forward] #%d write error: %v", t.ID, err) + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } - atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + case TCPServer: + t.connsMu.Lock() + conns := make([]net.Conn, 0, len(t.conns)) + for c := range t.conns { + conns = append(conns, c) + } + t.connsMu.Unlock() + for _, c := range conns { + n, err := c.Write(data) + if err != nil { + t.stats.LastError = err.Error() + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } + } + + case UDPServer: + t.mu.Lock() + addrs := make([]net.Addr, 0, len(t.remoteAddrs)) + for _, addr := range t.remoteAddrs { + addrs = append(addrs, addr) + } + t.mu.Unlock() + for _, addr := range addrs { + n, err := t.packetConn.WriteTo(data, addr) + if err != nil { + t.stats.LastError = err.Error() + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } + } + + case COMPort: + if t.serialPort == nil { + continue + } + n, err := t.serialPort.Write(data) + if err != nil { + t.stats.LastError = err.Error() + m.notify("[forward] #%d write error: %v", t.ID, err) + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } + } } } @@ -321,6 +615,7 @@ func (m *Manager) List() []Snapshot { ReadBytes: atomic.LoadUint64(&t.stats.ReadBytes), WriteByte: atomic.LoadUint64(&t.stats.WrittenBytes), LastError: t.stats.LastError, + Conns: t.acceptedConns(), }) } m.mu.RUnlock() @@ -356,7 +651,13 @@ func (t *Target) close() { t.closed = true ch := t.closeCh conn := t.conn + listener := t.listener + pc := t.packetConn + sp := t.serialPort t.conn = nil + t.listener = nil + t.packetConn = nil + t.serialPort = nil t.Connected = false t.mu.Unlock() @@ -366,4 +667,13 @@ func (t *Target) close() { if conn != nil { _ = conn.Close() } + if listener != nil { + _ = listener.Close() + } + if pc != nil { + _ = pc.Close() + } + if sp != nil { + _ = sp.Close() + } }