From a1524a7e1766cc36234e6a29ef7c506d3b0035b3 Mon Sep 17 00:00:00 2001 From: JiXieShi Date: Sat, 23 May 2026 21:49:43 +0800 Subject: [PATCH] refactor: extract internal/session and eliminate I/O globals Move serial port, trzsz filter, and pipe lifecycle into internal/session.SerialSession. Replace 8 global I/O vars (serialPort, trzszFilter, stdinPipe, stdoutPipe, clientIn, clientOut, termch, termchOnce) with single sess variable. Delete utils.go. Co-Authored-By: Claude Opus 4.7 --- app.go | 12 ++-- app_test.go | 10 ++- command_test.go | 6 +- internal/session/session.go | 138 ++++++++++++++++++++++++++++++++++++ main.go | 13 ++-- mutual.go | 13 ++-- utils.go | 110 ---------------------------- 7 files changed, 165 insertions(+), 137 deletions(-) create mode 100644 internal/session/session.go delete mode 100644 utils.go diff --git a/app.go b/app.go index 9394f50..ec204f0 100644 --- a/app.go +++ b/app.go @@ -156,8 +156,8 @@ func (a *App) Close() { close(a.done) a.forward.Close() a.plugins.Close() - CloseTrzsz() - CloseSerial() + sess.Close() + if a.logFile != nil { _ = a.logFile.Close() } @@ -216,7 +216,7 @@ func (a *App) writeRawToSession(data []byte) error { a.stdinMu.Lock() defer a.stdinMu.Unlock() - _, err := stdinPipe.Write(data) + _, err := sess.StdinPipe.Write(data) return err } @@ -246,7 +246,7 @@ func (a *App) sendCtrl(letter byte) error { letter = letter + ('a' - 'A') } control := []byte{letter & 0x1f} - _, err := serialPort.Write(control) + _, err := sess.Port.Write(control) return err } @@ -300,7 +300,7 @@ func (a *App) readHexOutput() { buf := make([]byte, frameSize) for { - n, err := stdoutPipe.Read(buf) + n, err := sess.StdoutPipe.Read(buf) if n > 0 { chunk := make([]byte, n) copy(chunk, buf[:n]) @@ -333,7 +333,7 @@ func (a *App) readHexOutput() { func (a *App) readTextOutput() { buf := make([]byte, 4096) for { - n, err := stdoutPipe.Read(buf) + n, err := sess.StdoutPipe.Read(buf) if n > 0 { chunk := make([]byte, n) copy(chunk, buf[:n]) diff --git a/app_test.go b/app_test.go index 6af906c..d630042 100644 --- a/app_test.go +++ b/app_test.go @@ -9,6 +9,7 @@ import ( "go.bug.st/serial" "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/session" "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/luaplugin" ) @@ -204,11 +205,14 @@ func TestReportForwardIngress(t *testing.T) { } func TestSendCtrl(t *testing.T) { - oldSp := serialPort - defer func() { serialPort = oldSp }() + if sess == nil { + sess = &session.SerialSession{} + } + oldSp := sess.Port + defer func() { sess.Port = oldSp }() // Use a mock serial port - serialPort = &mockSerialPort{} + sess.Port = &mockSerialPort{} a := &App{ cfg: &Config{}, uiEvents: make(chan event.UIEvent, 4), diff --git a/command_test.go b/command_test.go index ab8b285..abeecaa 100644 --- a/command_test.go +++ b/command_test.go @@ -6,13 +6,17 @@ import ( "testing" "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/session" "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/luaplugin" ) func setupTestPipes() { + if sess == nil { + sess = &session.SerialSession{} + } var cr *io.PipeReader - cr, stdinPipe = io.Pipe() + cr, sess.StdinPipe = io.Pipe() go func() { buf := make([]byte, 4096) for { diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..165e76a --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,138 @@ +// Package session manages the serial port connection and its associated pipes. +package session + +import ( + "fmt" + "io" + "os" + "os/signal" + "runtime" + "sync" + + "github.com/trzsz/trzsz-go/trzsz" + "go.bug.st/serial" + "golang.org/x/term" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" +) + +// SerialSession owns the serial port, trzsz filter, and pipe pair. +type SerialSession struct { + Port serial.Port + TrzszFilter *trzsz.TrzszFilter + StdinPipe *io.PipeWriter + StdoutPipe *io.PipeReader + ClientIn *io.PipeReader + ClientOut *io.PipeWriter + + termCh chan os.Signal + closeOnce sync.Once +} + +// Open creates a SerialSession by opening the serial port and initializing trzsz. +func Open(cfg *config.Config) (*SerialSession, error) { + mode := &serial.Mode{ + BaudRate: cfg.BaudRate, + StopBits: serial.StopBits(cfg.StopBits), + DataBits: cfg.DataBits, + Parity: serial.Parity(cfg.ParityBit), + } + port, err := serial.Open(cfg.PortName, mode) + if err != nil { + return nil, err + } + + fd := int(os.Stdin.Fd()) + width, _, err := term.GetSize(fd) + if err != nil { + if runtime.GOOS != "windows" { + port.Close() + return nil, fmt.Errorf("term get size failed: %w", err) + } + width = 80 + } + + clientIn, stdinPipe := io.Pipe() + stdoutPipe, clientOut := io.Pipe() + trzszFilter := trzsz.NewTrzszFilter(clientIn, clientOut, port, port, + trzsz.TrzszOptions{TerminalColumns: int32(width), EnableZmodem: true}) + trzsz.SetAffectedByWindows(false) + + s := &SerialSession{ + Port: port, + TrzszFilter: trzszFilter, + StdinPipe: stdinPipe, + StdoutPipe: stdoutPipe, + ClientIn: clientIn, + ClientOut: clientOut, + termCh: make(chan os.Signal, 1), + } + + go func() { + for range s.termCh { + w, _, err := term.GetSize(fd) + if err != nil { + fmt.Printf("term get size failed: %s\n", err) + continue + } + trzszFilter.SetTerminalColumns(int32(w)) + } + }() + + return s, nil +} + +// Write writes data to the stdin pipe (toward serial port, through trzsz). +func (s *SerialSession) Write(data []byte) (int, error) { + return s.StdinPipe.Write(data) +} + +// Read reads data from the stdout pipe (from serial port, through trzsz). +func (s *SerialSession) Read(buf []byte) (int, error) { + return s.StdoutPipe.Read(buf) +} + +// SendCtrl sends a control character directly to the serial port (bypasses trzsz). +func (s *SerialSession) SendCtrl(letter byte) (int, error) { + if letter >= 'A' && letter <= 'Z' { + letter = letter + ('a' - 'A') + } + control := []byte{letter & 0x1f} + return s.Port.Write(control) +} + +// Close tears down the session: stops term signals, closes trzsz, then serial port. +func (s *SerialSession) Close() { + s.closeOnce.Do(func() { + if s.termCh != nil { + signal.Stop(s.termCh) + close(s.termCh) + } + if s.Port != nil { + if err := s.Port.Close(); err != nil { + fmt.Fprint(os.Stderr, err) + fmt.Fprint(os.Stderr, "\n") + } + } + }) +} + +// CheckPortAvailability returns the list of available ports and verifies the named port exists. +func CheckPortAvailability(name string) ([]string, error) { + ports, err := serial.GetPortsList() + if err != nil { + return nil, err + } + if len(ports) == 0 { + return nil, fmt.Errorf("no serial ports found") + } + if name == "" { + return ports, fmt.Errorf("port name not specified") + } + for _, port := range ports { + if port == name { + return ports, nil + } + } + return ports, fmt.Errorf("port " + name + " is not available") +} diff --git a/main.go b/main.go index ec8b0d7..99721bc 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/session" "golang.org/x/term" ) @@ -35,20 +36,16 @@ func main() { if cfg.PortName == "" { getCliFlag() } - ports, err := checkPortAvailability(cfg.PortName) + ports, err := session.CheckPortAvailability(cfg.PortName) if err != nil { fmt.Println(err) printUsage(ports) os.Exit(0) } - if err = OpenSerial(); err != nil { - fmt.Fprintf(os.Stderr, "open serial failed: %v\n", err) - os.Exit(1) - } - - if err = OpenTrzsz(); err != nil { - fmt.Fprintf(os.Stderr, "open trzsz failed: %v\n", err) + sess, err = session.Open(cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "open session failed: %v\n", err) os.Exit(1) } diff --git a/mutual.go b/mutual.go index 0d7467d..66d6e5a 100644 --- a/mutual.go +++ b/mutual.go @@ -1,18 +1,13 @@ package main import ( - "github.com/trzsz/trzsz-go/trzsz" - "go.bug.st/serial" "io" "os" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/session" ) var ( - serialPort serial.Port - out io.Writer = os.Stdout - trzszFilter *trzsz.TrzszFilter - clientIn *io.PipeReader - stdoutPipe *io.PipeReader - stdinPipe *io.PipeWriter - clientOut *io.PipeWriter + sess *session.SerialSession + out io.Writer = os.Stdout ) diff --git a/utils.go b/utils.go deleted file mode 100644 index 6f6550c..0000000 --- a/utils.go +++ /dev/null @@ -1,110 +0,0 @@ -package main - -import ( - "fmt" - "github.com/trzsz/trzsz-go/trzsz" - "go.bug.st/serial" - "golang.org/x/term" - "io" - "os" - "os/signal" - "runtime" - "strings" - "sync" -) - -func checkPortAvailability(name string) ([]string, error) { - ports, err := serial.GetPortsList() - if err != nil { - return nil, err - } - if len(ports) == 0 { - return nil, fmt.Errorf("无串口") - } - if name == "" { - return ports, fmt.Errorf("串口未指定") - } - for _, port := range ports { - if strings.Compare(port, name) == 0 { - return ports, nil - } - } - return ports, fmt.Errorf("串口 " + name + " 未在线") -} - -func OpenSerial() error { - mode := &serial.Mode{ - BaudRate: cfg.BaudRate, - StopBits: serial.StopBits(cfg.StopBits), - DataBits: cfg.DataBits, - Parity: serial.Parity(cfg.ParityBit), - } - var err error - serialPort, err = serial.Open(cfg.PortName, mode) - return err -} - -func CloseSerial() { - if serialPort == nil { - return - } - - if err := serialPort.Close(); err != nil { - fmt.Fprint(os.Stderr, err) - fmt.Fprint(os.Stderr, "\n") - } -} - -var termch chan os.Signal -var termchOnce sync.Once - -// OpenTrzsz create a TrzszFilter to support trzsz ( trz / tsz ). -// -// ┌────────┐ stdinPipe ┌────────┐ ClientIn ┌─────────────┐ SerialIn ┌────────┐ -// │ ├─────────────►│ ├─────────────►│ ├─────────────►│ │ -// │ mutual │ │ Client │ │ TrzszFilter │ │ Serial │ -// │ │◄─────────────│ │◄─────────────┤ │◄─────────────┤ │ -// └────────┘ stdoutPipe └────────┘ ClientOut └─────────────┘ SerialOut └────────┘ -func OpenTrzsz() error { - fd := int(os.Stdin.Fd()) - width, _, err := term.GetSize(fd) - if err != nil { - if runtime.GOOS != "windows" { - return fmt.Errorf("term get size failed: %w", err) - } - width = 80 - } - - clientIn, stdinPipe = io.Pipe() - stdoutPipe, clientOut = io.Pipe() - trzszFilter = trzsz.NewTrzszFilter(clientIn, clientOut, serialPort, serialPort, - trzsz.TrzszOptions{TerminalColumns: int32(width), EnableZmodem: true}) - trzsz.SetAffectedByWindows(false) - termch = make(chan os.Signal, 1) - termchOnce = sync.Once{} - - go func() { - for range termch { - width, _, err := term.GetSize(fd) - if err != nil { - fmt.Printf("term get size failed: %s\n", err) - continue - } - trzszFilter.SetTerminalColumns(int32(width)) - } - }() - - return nil -} - -func CloseTrzsz() { - if termch == nil { - return - } - - termchOnce.Do(func() { - signal.Stop(termch) - close(termch) - }) -} -