From e0de872740760655c0dd9d65695882e120e738bf Mon Sep 17 00:00:00 2001 From: JiXieShi Date: Fri, 22 May 2026 02:35:30 +0800 Subject: [PATCH] refactor: extract pkg/charset and internal/event packages Extract ConvertChunk/FormatHexFrame into pkg/charset (zero external deps). Extract UIEvent/UIEventKind/UIPanelKind types into internal/event. Update all references across main package to use qualified imports. Co-Authored-By: Claude Opus 4.7 --- .gitignore | 5 +- .goreleaser.yaml | 122 ++++----- README.md | 106 ++++---- app.go | 388 +++++++++++++++++++++++++++ app_test.go | 256 ++++++++++++++++++ command.go | 515 ++++++++++++++++++++++++++++++++---- command_test.go | 496 ++++++++++++++++++++++++++++++++++ config.go | 56 ++-- config_test.go | 238 +++++++++++++++++ escape_test.go | 123 +++++++++ flag.go | 70 ++++- forwarding.go | 311 ++++++++++++++++++++++ forwarding_test.go | 261 ++++++++++++++++++ internal/event/event.go | 30 +++ main_other.go | 5 + main_windows.go | 14 + mutual.go | 33 --- pkg/charset/charset.go | 43 +++ pkg/charset/charset_test.go | 93 +++++++ plugin.go | 262 ++++++++++++++++++ plugin_test.go | 241 +++++++++++++++++ plugins/demo.lua | 14 + tui_hotkeys.go | 153 +++++++++++ tui_model.go | 268 +++++++++++++++++++ tui_panels.go | 322 ++++++++++++++++++++++ tui_test.go | 309 ++++++++++++++++++++++ 26 files changed, 4504 insertions(+), 230 deletions(-) create mode 100644 app.go create mode 100644 app_test.go create mode 100644 command_test.go create mode 100644 config_test.go create mode 100644 escape_test.go create mode 100644 forwarding.go create mode 100644 forwarding_test.go create mode 100644 internal/event/event.go create mode 100644 main_other.go create mode 100644 main_windows.go create mode 100644 pkg/charset/charset.go create mode 100644 pkg/charset/charset_test.go create mode 100644 plugin.go create mode 100644 plugin_test.go create mode 100644 plugins/demo.lua create mode 100644 tui_hotkeys.go create mode 100644 tui_model.go create mode 100644 tui_panels.go create mode 100644 tui_test.go diff --git a/.gitignore b/.gitignore index 869e5f9..8d0fab5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,6 @@ dist/ /go.sum /view/* .claude/ -COM.exe -coverage.out \ No newline at end of file +*.exe +coverage.out +CLAUDE.md diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 1fca519..a39e2fc 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,61 +1,61 @@ -#file: noinspection YAMLSchemaValidation -# This is an example .goreleaser.yml file with some sensible defaults. -# Make sure to check the documentation at https://goreleaser.com - -# The lines below are called `modelines`. See `:help modeline` -# Feel free to remove those if you don't want/need to use them. -# yaml-language-server: $schema=https://goreleaser.com/static/schema.json -# vim: set ts=2 sw=2 tw=0 fo=cnqoj - -version: 1 - -before: - hooks: - # You may remove this if you don't use go modules. -# - go mod tidy - # you may remove this if you don't need go generate -# - go generate ./... - -builds: - - env: - - CGO_ENABLED=0 - goos: - - linux - - windows - - darwin - ldflags: - - -s -w - -upx: - - enabled: true - goos: - - windows - goarch: - - amd64 - -archives: - - format: tar.gz - # this name template makes the OS and Arch compatible with the results of `uname`. - name_template: >- - {{ .ProjectName }}_ - {{- title .Os }}_ - {{- if eq .Arch "amd64" }}x86_64 - {{- else if eq .Arch "386" }}i386 - {{- else }}{{ .Arch }}{{ end }} - {{- if .Arm }}v{{ .Arm }}{{ end }} - # use zip for windows archives - format_overrides: - - goos: windows - format: zip -checksum: - name_template: 'checksums.txt' - -snapshot: - name_template: 'v1.0.0-snapshot' - -changelog: - sort: asc - filters: - exclude: - - "^docs:" - - "^test:" +#file: noinspection YAMLSchemaValidation +# This is an example .goreleaser.yml file with some sensible defaults. +# Make sure to check the documentation at https://goreleaser.com + +# The lines below are called `modelines`. See `:help modeline` +# Feel free to remove those if you don't want/need to use them. +# yaml-language-server: $schema=https://goreleaser.com/static/schema.json +# vim: set ts=2 sw=2 tw=0 fo=cnqoj + +version: 1 + +before: + hooks: + # You may remove this if you don't use go modules. +# - go mod tidy + # you may remove this if you don't need go generate +# - go generate ./... + +builds: + - env: + - CGO_ENABLED=0 + goos: + - linux + - windows + - darwin + ldflags: + - -s -w + +upx: + - enabled: true + goos: + - windows + goarch: + - amd64 + +archives: + - format: tar.gz + # this name template makes the OS and Arch compatible with the results of `uname`. + name_template: >- + {{ .ProjectName }}_ + {{- title .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + # use zip for windows archives + format_overrides: + - goos: windows + format: zip +checksum: + name_template: 'checksums.txt' + +snapshot: + name_template: 'v1.0.0-snapshot' + +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" diff --git a/README.md b/README.md index a00bc55..9b280f8 100644 --- a/README.md +++ b/README.md @@ -1,54 +1,54 @@ -# SerialTerminalForWindowsTerminal -在开始这个项目之前,我发现Windows Terminal对串口设备的支持并不理想。 - -我试用了一段时间[Zhou-zhi-peng的SerialPortForWindowsTerminal](https://github.com/Zhou-zhi-peng/SerialPortForWindowsTerminal/)项目。 - -然而,这个项目存在着编码转换的问题,导致数据显示乱码,并且作者目前并没有进行后续支持。因此,我决定创建了这个项目。 - -## 功能进展 -* [x] Hex接收发送(大写hex与原文同显) -* [x] 双向编码转换 -* [x] 活动端口探测 -* [x] 数据日志保存 -* [x] Hex断帧设置 -* [x] UDP数据转发(支持多服) -* [x] TCP数据转发(支持多服) -* [x] 参数交互配置 -* [x] Ctrl组合键 -* [x] 文件接收发送(trzsz lrzsz都支持) - -## 运行示例 - -1. 参数帮助 `./COM` - - ![img1.png](image/img1.png) - -2. 输入设备输出UTF8 终端输出GBK `./COM -p COM8 -b 115200 -o GBK` - - ![img2.png](image/img2.png) -3. 彩色终端输出 - - ![img3.png](image/img3.png) - -4. Hex接收 `./COM -p COM8 -b 115200 -i hex` - - ![img4.png](image/img4.png) -5. Hex发送 `./COM -p COM8 -b 115200` - - ![img5.png](image/img5.png) -6. 交互配置 `./COM` - - ![img6.png](image/img6.png) -7. Ctrl组合键发送指令.ctrl `.ctrl c` - - ![img7.png](image/img7.png) -8. 文件上传演示 `index.html` - ![img8.png](image/img8.png) - 内容对比 - ![img11.png](image/img11.png) -9. 时间戳 `./COM -p COM8 -t` - ![img9.png](image/img9.png) -10. 格式修改 `./COM -p COM11 -t='<2006-01-02 15:04:05>'` - ![img10.png](image/img10.png) -11. 多服同步转发 `./COM -p COM11 -f 1 -a 127.0.0.1:23456 -f 1 -a 127.0.0.1:23457` +# SerialTerminalForWindowsTerminal +在开始这个项目之前,我发现Windows Terminal对串口设备的支持并不理想。 + +我试用了一段时间[Zhou-zhi-peng的SerialPortForWindowsTerminal](https://github.com/Zhou-zhi-peng/SerialPortForWindowsTerminal/)项目。 + +然而,这个项目存在着编码转换的问题,导致数据显示乱码,并且作者目前并没有进行后续支持。因此,我决定创建了这个项目。 + +## 功能进展 +* [x] Hex接收发送(大写hex与原文同显) +* [x] 双向编码转换 +* [x] 活动端口探测 +* [x] 数据日志保存 +* [x] Hex断帧设置 +* [x] UDP数据转发(支持多服) +* [x] TCP数据转发(支持多服) +* [x] 参数交互配置 +* [x] Ctrl组合键 +* [x] 文件接收发送(trzsz lrzsz都支持) + +## 运行示例 + +1. 参数帮助 `./COM` + + ![img1.png](image/img1.png) + +2. 输入设备输出UTF8 终端输出GBK `./COM -p COM8 -b 115200 -o GBK` + + ![img2.png](image/img2.png) +3. 彩色终端输出 + + ![img3.png](image/img3.png) + +4. Hex接收 `./COM -p COM8 -b 115200 -i hex` + + ![img4.png](image/img4.png) +5. Hex发送 `./COM -p COM8 -b 115200` + + ![img5.png](image/img5.png) +6. 交互配置 `./COM` + + ![img6.png](image/img6.png) +7. Ctrl组合键发送指令.ctrl `.ctrl c` + + ![img7.png](image/img7.png) +8. 文件上传演示 `index.html` + ![img8.png](image/img8.png) + 内容对比 + ![img11.png](image/img11.png) +9. 时间戳 `./COM -p COM8 -t` + ![img9.png](image/img9.png) +10. 格式修改 `./COM -p COM11 -t='<2006-01-02 15:04:05>'` + ![img10.png](image/img10.png) +11. 多服同步转发 `./COM -p COM11 -f 1 -a 127.0.0.1:23456 -f 1 -a 127.0.0.1:23457` ![img12.png](image/img12.png) \ No newline at end of file diff --git a/app.go b/app.go new file mode 100644 index 0000000..1c87681 --- /dev/null +++ b/app.go @@ -0,0 +1,388 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/charset" +) + +type App struct { + cfg *Config + forward *ForwardManager + plugins *PluginManager + dispatcher *CommandDispatcher + + uiEvents chan event.UIEvent + done chan struct{} + + stdinMu sync.Mutex + closeOnce sync.Once + closedFlag atomic.Bool + uiEnabled atomic.Bool + + logFile *os.File +} + +func NewApp(cfg *Config) (*App, error) { + f, err := openLogFile() + if err != nil { + return nil, err + } + + a := &App{ + cfg: cfg, + plugins: NewPluginManager(), + uiEvents: make(chan event.UIEvent, 512), + done: make(chan struct{}), + logFile: f, + } + a.uiEnabled.Store(true) + + a.forward = NewForwardManager(a.writeRawToSession, a.Notifyf) + a.forward.SetInboundReporter(a.reportForwardIngress) + a.dispatcher = NewCommandDispatcher(a) + if err = a.loadDefaultDemoPlugin(); err != nil { + return nil, err + } + return a, nil +} + +func (a *App) loadDefaultDemoPlugin() error { + demoPath := filepath.Join("plugins", "demo.lua") + if _, err := os.Stat(demoPath); err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + name, err := a.plugins.Load(demoPath) + if err != nil { + return err + } + return a.plugins.Disable(name) +} + +func (a *App) Notifyf(format string, args ...any) { + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: fmt.Sprintf(format, args...)}) +} + +func (a *App) Statusf(format string, args ...any) { + a.emit(event.UIEvent{Kind: event.UIEventStatus, Text: fmt.Sprintf(format, args...)}) +} + +func (a *App) ShowModal(title, text string) { + a.emit(event.UIEvent{Kind: event.UIEventModal, Title: title, Text: text}) +} + +func (a *App) OpenPanel(panel event.UIPanelKind) { + a.emit(event.UIEvent{Kind: event.UIEventPanel, Panel: panel}) +} + +func (a *App) SetUIEnabled(enabled bool) { + a.uiEnabled.Store(enabled) +} + +func (a *App) UIEnabled() bool { + return a.uiEnabled.Load() +} + +func (a *App) emit(ev event.UIEvent) { + if ev.Kind != event.UIEventPanel && ev.Text == "" { + return + } + + if !a.UIEnabled() { + switch ev.Kind { + case event.UIEventOutput: + _, _ = io.WriteString(out, ev.Text) + case event.UIEventStatus: + _, _ = io.WriteString(out, ev.Text) + if !strings.HasSuffix(ev.Text, "\n") { + _, _ = io.WriteString(out, "\n") + } + case event.UIEventModal: + _, _ = io.WriteString(out, "\n["+ev.Title+"]\n"+ev.Text+"\n") + } + if ev.Kind == event.UIEventOutput { + a.appendLog(ev.Text) + } + return + } + + select { + case a.uiEvents <- ev: + default: + // Keep UI responsive; drop oldest when overloaded. + select { + case <-a.uiEvents: + default: + } + a.uiEvents <- ev + } + + if ev.Kind == event.UIEventOutput { + a.appendLog(ev.Text) + } +} + +func (a *App) appendLog(text string) { + if a.logFile == nil { + return + } + + _, _ = a.logFile.WriteString(text) +} + +func (a *App) isClosed() bool { + return a.closedFlag.Load() +} + +func (a *App) Close() { + a.closeOnce.Do(func() { + a.closedFlag.Store(true) + close(a.done) + a.forward.Close() + a.plugins.Close() + CloseTrzsz() + CloseSerial() + if a.logFile != nil { + _ = a.logFile.Close() + } + }) +} + +func (a *App) waitDone() <-chan struct{} { + return a.done +} + +func (a *App) loadConfiguredForwards() { + for i, mode := range config.forWard { + m := FoeWardMode(mode) + if m == NOT { + continue + } + if i >= len(config.address) { + a.Notifyf("[forward] skip #%d: missing address", i) + continue + } + addr := strings.TrimSpace(config.address[i]) + if addr == "" { + continue + } + if _, err := a.forward.Add(m, addr); err != nil { + a.Notifyf("[forward] add %s %s failed: %v", m.String(), addr, err) + } + } +} + +func (a *App) reportForwardIngress(id int, chunk []byte) { + if len(chunk) == 0 { + return + } + + if strings.EqualFold(a.cfg.inputCode, "hex") { + a.Notifyf("[forward#%d -> serial] % X\n", id, chunk) + return + } + + converted, err := charset.ConvertChunk(chunk, a.cfg.inputCode, a.cfg.outputCode) + if err != nil { + converted = bytes.Clone(chunk) + } + text := string(converted) + if !strings.HasSuffix(text, "\n") { + text += "\n" + } + a.Notifyf("[forward#%d -> serial] %s", id, text) +} + +func (a *App) writeRawToSession(data []byte) error { + if len(data) == 0 { + return nil + } + + a.stdinMu.Lock() + defer a.stdinMu.Unlock() + _, err := stdinPipe.Write(data) + return err +} + +func (a *App) writeToSession(data []byte) error { + processed, err := a.plugins.ProcessInput(data) + if err != nil { + return err + } + if len(processed) == 0 { + return nil + } + + return a.writeRawToSession(processed) +} + +func (a *App) sendLine(line string) error { + if strings.TrimSpace(line) == "" { + return nil + } + + payload := append([]byte(line), []byte(a.cfg.endStr)...) + return a.writeToSession(payload) +} + +func (a *App) sendCtrl(letter byte) error { + if letter >= 'A' && letter <= 'Z' { + letter = letter + ('a' - 'A') + } + control := []byte{letter & 0x1f} + _, err := serialPort.Write(control) + return err +} + +func (a *App) handleLine(line string) { + line = strings.TrimRight(line, "\r\n") + if strings.TrimSpace(line) == "" { + return + } + + if strings.HasPrefix(strings.TrimSpace(line), ".") { + next, allow, err := a.plugins.ProcessCommand(line) + if err != nil { + a.Notifyf("[plugin] command hook failed: %v", err) + return + } + if !allow { + a.Notifyf("[plugin] command blocked") + return + } + if next != "" { + line = next + } + handled, err := a.dispatcher.Execute(line) + if err != nil { + a.Statusf("[cmd] %v", err) + } + if handled { + return + } + } + + if err := a.sendLine(line); err != nil { + a.Statusf("[send] %v", err) + } +} + +func (a *App) startOutputLoop() { + if strings.EqualFold(a.cfg.inputCode, "hex") { + go a.readHexOutput() + return + } + + go a.readTextOutput() +} + +func (a *App) readHexOutput() { + frameSize := a.cfg.frameSize + if frameSize <= 0 { + frameSize = 16 + } + + buf := make([]byte, frameSize) + for { + n, err := stdoutPipe.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + a.forward.Broadcast(chunk) + outChunk, hookErr := a.plugins.ProcessOutput(chunk) + if hookErr != nil { + a.Notifyf("[plugin] output hook failed: %v", hookErr) + continue + } + if len(outChunk) == 0 { + continue + } + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: charset.FormatHexFrame(outChunk, a.cfg.timesTamp, a.cfg.timesFmt)}) + } + if err != nil { + if err != io.EOF { + a.Notifyf("[output] %v", err) + } + return + } + + select { + case <-a.done: + return + default: + } + } +} + +func (a *App) readTextOutput() { + buf := make([]byte, 4096) + for { + n, err := stdoutPipe.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + a.forward.Broadcast(chunk) + + outChunk, hookErr := a.plugins.ProcessOutput(chunk) + if hookErr != nil { + a.Notifyf("[plugin] output hook failed: %v", hookErr) + continue + } + if len(outChunk) == 0 { + continue + } + + converted, convErr := charset.ConvertChunk(outChunk, a.cfg.inputCode, a.cfg.outputCode) + if convErr != nil { + a.Notifyf("[output] convert failed: %v", convErr) + converted = bytes.Clone(outChunk) + } + + text := string(converted) + if a.cfg.timesTamp { + text = prefixLines(text, time.Now().Format(a.cfg.timesFmt)+" ") + } + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: text}) + } + if err != nil { + if err != io.EOF { + a.Notifyf("[output] %v", err) + } + return + } + + select { + case <-a.done: + return + default: + } + } +} + +func prefixLines(s, prefix string) string { + if s == "" || prefix == "" { + return s + } + + lines := strings.SplitAfter(s, "\n") + for i, line := range lines { + if line == "" { + continue + } + lines[i] = prefix + line + } + return strings.Join(lines, "") +} diff --git a/app_test.go b/app_test.go new file mode 100644 index 0000000..40a3a0b --- /dev/null +++ b/app_test.go @@ -0,0 +1,256 @@ +package main + +import ( + "io" + "net" + "testing" + "time" + + "go.bug.st/serial" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" +) + +func TestPrefixLines(t *testing.T) { + tests := []struct { + name string + in string + prefix string + want string + }{ + {name: "empty", in: "", prefix: "X ", want: ""}, + {name: "no-prefix", in: "a\n", prefix: "", want: "a\n"}, + {name: "single-line", in: "abc", prefix: "T ", want: "T abc"}, + {name: "multi-line", in: "a\nb\n", prefix: "P ", want: "P a\nP b\n"}, + } + + for _, tt := range tests { + got := prefixLines(tt.in, tt.prefix) + if got != tt.want { + t.Fatalf("%s: prefixLines got=%q want=%q", tt.name, got, tt.want) + } + } +} + +func TestAppUIEvents(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 8)} + a.SetUIEnabled(true) + + a.Notifyf("hello %s", "world") + a.Statusf("ok") + a.ShowModal("Title", "Body") + + ev1 := mustReadEvent(t, a.uiEvents) + if ev1.Kind != event.UIEventOutput || ev1.Text != "hello world" { + t.Fatalf("unexpected output event: %+v", ev1) + } + + ev2 := mustReadEvent(t, a.uiEvents) + if ev2.Kind != event.UIEventStatus || ev2.Text != "ok" { + t.Fatalf("unexpected status event: %+v", ev2) + } + + ev3 := mustReadEvent(t, a.uiEvents) + if ev3.Kind != event.UIEventModal || ev3.Title != "Title" || ev3.Text != "Body" { + t.Fatalf("unexpected modal event: %+v", ev3) + } +} + +func TestSendLine(t *testing.T) { + setupTestPipes() + a := &App{ + cfg: &Config{endStr: "\r\n"}, + plugins: NewPluginManager(), + uiEvents: make(chan event.UIEvent, 8), + done: make(chan struct{}), + } + a.SetUIEnabled(true) + + if err := a.sendLine("hello"); err != nil { + t.Fatalf("sendLine failed: %v", err) + } + + if err := a.sendLine(""); err != nil { + t.Fatalf("sendLine empty string should be no-op: %v", err) + } + if err := a.sendLine(" "); err != nil { + t.Fatalf("sendLine whitespace should be no-op: %v", err) + } +} + +func TestHandleLine(t *testing.T) { + setupTestPipes() + a := &App{ + cfg: &Config{endStr: "\n", inputCode: "UTF-8", outputCode: "UTF-8"}, + plugins: NewPluginManager(), + uiEvents: make(chan event.UIEvent, 8), + done: make(chan struct{}), + } + a.SetUIEnabled(true) + a.forward = NewForwardManager(func([]byte) error { return nil }, func(string, ...any) {}) + a.dispatcher = NewCommandDispatcher(a) + + a.handleLine("hello") + a.handleLine("") + a.handleLine(".help") + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventModal || ev.Title == "" { + t.Fatalf("expected .help modal, got %+v", ev) + } +} + +func TestEmitNonUI(t *testing.T) { + oldOut := out + out = io.Discard + defer func() { out = oldOut }() + + a := &App{ + uiEvents: make(chan event.UIEvent, 4), + logFile: nil, + } + a.SetUIEnabled(false) + + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "serial data\n"}) + a.emit(event.UIEvent{Kind: event.UIEventStatus, Text: "status msg"}) + a.emit(event.UIEvent{Kind: event.UIEventModal, Title: "T", Text: "body"}) + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: ""}) +} + +func TestEmitUISaturation(t *testing.T) { + a := &App{ + uiEvents: make(chan event.UIEvent, 2), + } + a.SetUIEnabled(true) + + // Fill channel + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "a"}) + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "b"}) + // This should drop oldest and insert newest + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "c"}) + + ev := mustReadEvent(t, a.uiEvents) + if ev.Text != "b" { + t.Fatalf("expected b after drop, got %q", ev.Text) + } + ev = mustReadEvent(t, a.uiEvents) + if ev.Text != "c" { + t.Fatalf("expected c, got %q", ev.Text) + } +} + +func TestAppClose(t *testing.T) { + a := &App{ + done: make(chan struct{}), + plugins: NewPluginManager(), + forward: NewForwardManager(func([]byte) error { return nil }, func(string, ...any) {}), + uiEvents: make(chan event.UIEvent, 4), + } + a.SetUIEnabled(true) + + a.Close() + if !a.isClosed() { + t.Fatalf("expected app closed") + } + // Second close should be safe + a.Close() +} + +func TestLoadConfiguredForwards(t *testing.T) { + oldCfg := config + defer func() { config = oldCfg }() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + config = Config{ + forWard: []int{int(TCPC), int(NOT), int(UDPC)}, + address: []string{listener.Addr().String(), "", ""}, + } + + a := &App{ + cfg: &config, + forward: NewForwardManager(func([]byte) error { return nil }, func(string, ...any) {}), + uiEvents: make(chan event.UIEvent, 8), + done: make(chan struct{}), + } + a.SetUIEnabled(true) + + a.loadConfiguredForwards() + // TCPC should be added, NOT skipped, UDPC skipped (empty address) + items := a.forward.List() + if len(items) != 1 || items[0].Mode != "tcp" { + t.Fatalf("expected 1 TCP forward, got %+v", items) + } +} + +func TestReportForwardIngress(t *testing.T) { + a := &App{ + cfg: &Config{inputCode: "UTF-8", outputCode: "UTF-8"}, + uiEvents: make(chan event.UIEvent, 4), + } + a.SetUIEnabled(true) + + a.reportForwardIngress(1, []byte("test")) + + // Hex mode + a.cfg.inputCode = "hex" + a.reportForwardIngress(2, []byte{0x41, 0x42}) + + // Empty chunk + a.reportForwardIngress(3, nil) +} + +func TestSendCtrl(t *testing.T) { + oldSp := serialPort + defer func() { serialPort = oldSp }() + + // Use a mock serial port + serialPort = &mockSerialPort{} + a := &App{ + cfg: &Config{}, + uiEvents: make(chan event.UIEvent, 4), + } + a.SetUIEnabled(true) + + if err := a.sendCtrl('c'); err != nil { + t.Fatalf("sendCtrl('c') failed: %v", err) + } + if err := a.sendCtrl('C'); err != nil { + t.Fatalf("sendCtrl('C') failed: %v", err) + } + if err := a.sendCtrl('A'); err != nil { + t.Fatalf("sendCtrl('A') failed: %v", err) + } +} + +type mockSerialPort struct{} + +func (m *mockSerialPort) Write(p []byte) (int, error) { return len(p), nil } +func (m *mockSerialPort) Read(p []byte) (int, error) { return 0, io.EOF } +func (m *mockSerialPort) Close() error { return nil } +func (m *mockSerialPort) SetMode(mode *serial.Mode) error { return nil } +func (m *mockSerialPort) SetDTR(dtr bool) error { return nil } +func (m *mockSerialPort) SetRTS(rts bool) error { return nil } +func (m *mockSerialPort) GetModemStatusBits() (*serial.ModemStatusBits, error) { + return &serial.ModemStatusBits{}, nil +} +func (m *mockSerialPort) ResetInputBuffer() error { return nil } +func (m *mockSerialPort) ResetOutputBuffer() error { return nil } +func (m *mockSerialPort) SetReadTimeout(t time.Duration) error { return nil } +func (m *mockSerialPort) Break(t time.Duration) error { return nil } +func (m *mockSerialPort) Drain() error { return nil } + +func mustReadEvent(t *testing.T, ch <-chan event.UIEvent) event.UIEvent { + t.Helper() + select { + case ev := <-ch: + return ev + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for UI event") + return event.UIEvent{} + } +} diff --git a/command.go b/command.go index dfe9c35..68bd473 100644 --- a/command.go +++ b/command.go @@ -3,62 +3,475 @@ package main import ( "encoding/hex" "fmt" - "log" - "os" + "sort" + "strconv" "strings" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" ) -type Command struct { - name string - description string - function func() +type CommandHandler func(args []string) error +type CommandCompleter func(args []string) []string + +type RuntimeCommand struct { + Name string + Usage string + Description string + Handler CommandHandler + Completer CommandCompleter } -var ( - commands []Command - args []string -) +type CommandDispatcher struct { + app *App + commands map[string]*RuntimeCommand + order []string +} -func cmdhelp() { - var page = 0 - strout(out, config.outputCode, fmt.Sprintf(">-------Help(%v)-------<\n", page)) - for i := 0; i < len(commands); i++ { - strout(out, config.outputCode, fmt.Sprintf(" %-10v --%v\n", commands[i].name, commands[i].description)) +func NewCommandDispatcher(app *App) *CommandDispatcher { + d := &CommandDispatcher{ + app: app, + commands: make(map[string]*RuntimeCommand), + } + + d.registerAll() + return d +} + +func (d *CommandDispatcher) register(cmd RuntimeCommand) { + key := strings.ToLower(cmd.Name) + d.commands[key] = &cmd + d.order = append(d.order, key) +} + +func (d *CommandDispatcher) registerAll() { + d.register(RuntimeCommand{ + Name: ".help", + Usage: ".help", + Description: "show command help", + Handler: func(args []string) error { + d.app.ShowModal("Command Help", d.HelpText()) + return nil + }, + }) + + d.register(RuntimeCommand{ + Name: ".exit", + Usage: ".exit", + Description: "exit local terminal", + Handler: func(args []string) error { + d.app.Statusf("[local] exiting") + d.app.Close() + return nil + }, + }) + + d.register(RuntimeCommand{ + Name: ".hex", + Usage: ".hex ", + Description: "send raw hex bytes", + Handler: func(args []string) error { + if len(args) < 2 { + return fmt.Errorf("usage: .hex ") + } + hexStr := strings.Join(args[1:], "") + b, err := hex.DecodeString(hexStr) + if err != nil { + return err + } + return d.app.writeToSession(b) + }, + }) + + d.register(RuntimeCommand{ + Name: ".forward", + Usage: ".forward ", + Description: "manage forwarding at runtime", + Handler: d.handleForwardCommand, + Completer: completeForward, + }) + + d.register(RuntimeCommand{ + Name: ".plugin", + Usage: ".plugin ", + Description: "manage lua plugins", + Handler: d.handlePluginCommand, + Completer: completePlugin, + }) + + d.register(RuntimeCommand{ + Name: ".mode", + Usage: ".mode ", + Description: "show or update runtime terminal mode", + Handler: func(args []string) error { + return d.handleModeCommand(args) + }, + Completer: completeMode, + }) +} + +func (d *CommandDispatcher) Execute(line string) (bool, error) { + args := strings.Fields(strings.TrimSpace(line)) + if len(args) == 0 { + return false, nil + } + if !strings.HasPrefix(args[0], ".") { + return false, nil + } + + cmd, ok := d.commands[strings.ToLower(args[0])] + if !ok { + return true, fmt.Errorf("unknown command: %s", args[0]) + } + + if err := cmd.Handler(args); err != nil { + return true, err + } + return true, nil +} + +func (d *CommandDispatcher) HelpText() string { + keys := make([]string, 0, len(d.order)) + for _, k := range d.order { + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + b.WriteString("Commands:\n") + for _, k := range keys { + cmd := d.commands[k] + b.WriteString(fmt.Sprintf(" %-12s %-40s %s\n", cmd.Name, cmd.Usage, cmd.Description)) + } + return b.String() +} + +func (d *CommandDispatcher) Complete(line string) (string, []string) { + trimmed := strings.TrimLeft(line, " ") + if trimmed == "" { + return line, nil + } + + args := strings.Fields(trimmed) + endsWithSpace := strings.HasSuffix(line, " ") + + if len(args) == 0 { + return line, nil + } + + if len(args) == 1 && !endsWithSpace { + return completeFirstToken(line, args[0], d.commandNames()) + } + + cmdName := strings.ToLower(args[0]) + cmd, ok := d.commands[cmdName] + if !ok || cmd.Completer == nil { + return line, nil + } + + compArgs := args + if endsWithSpace { + compArgs = append(compArgs, "") + } + + cands := cmd.Completer(compArgs) + if len(cands) == 0 { + return line, nil + } + + current := compArgs[len(compArgs)-1] + base := strings.TrimSuffix(line, current) + + matches := filterPrefix(cands, current) + if len(matches) == 0 { + matches = cands + } + if len(matches) == 1 { + return base + matches[0], matches + } + + return line, matches +} + +func (d *CommandDispatcher) commandNames() []string { + names := make([]string, 0, len(d.commands)) + for _, cmd := range d.commands { + names = append(names, cmd.Name) + } + sort.Strings(names) + return names +} + +func completeFirstToken(line, token string, cands []string) (string, []string) { + matches := filterPrefix(cands, token) + if len(matches) == 0 { + return line, nil + } + if len(matches) == 1 { + prefix := strings.TrimSuffix(line, token) + return prefix + matches[0] + " ", matches + } + return line, matches +} + +func filterPrefix(cands []string, cur string) []string { + if cur == "" { + return append([]string{}, cands...) + } + res := make([]string, 0, len(cands)) + for _, c := range cands { + if strings.HasPrefix(strings.ToLower(c), strings.ToLower(cur)) { + res = append(res, c) + } + } + return res +} + +func completeForward(args []string) []string { + if len(args) <= 2 { + return []string{"list", "add", "remove", "enable", "disable", "update", "stats"} + } + + if len(args) == 3 && args[1] == "add" { + return []string{"tcp", "udp"} + } + + if len(args) == 4 && args[1] == "update" { + return []string{"tcp", "udp"} + } + + return nil +} + +func completePlugin(args []string) []string { + if len(args) <= 2 { + return []string{"list", "load", "unload", "enable", "disable", "reload"} + } + return nil +} + +func completeMode(args []string) []string { + if len(args) <= 2 { + return []string{"show", "set"} + } + + if len(args) == 3 && args[1] == "set" { + return []string{"in", "out", "end", "frame", "timestamp", "timefmt"} + } + + if len(args) == 4 && args[1] == "set" && args[2] == "timestamp" { + return []string{"on", "off"} + } + + return nil +} + +func (d *CommandDispatcher) handleForwardCommand(args []string) error { + if len(args) < 2 { + if d.app.UIEnabled() { + d.app.OpenPanel(event.UIPanelForward) + return nil + } + args = []string{".forward", "list"} + } + + sub := strings.ToLower(args[1]) + switch sub { + case "list", "stats": + if d.app.UIEnabled() { + d.app.OpenPanel(event.UIPanelForward) + return nil + } + + items := d.app.forward.List() + if len(items) == 0 { + d.app.Notifyf("[forward] empty") + return nil + } + d.app.Notifyf("[forward] ID Mode Enabled Connected Address InBytes OutBytes LastError") + for _, it := range items { + d.app.Notifyf("[forward] %d %s %v %v %s %d %d %s", it.ID, it.Mode, it.Enabled, it.Connected, it.Address, it.ReadBytes, it.WriteByte, it.LastError) + } + return nil + + case "add": + if len(args) < 4 { + return fmt.Errorf("usage: .forward add
") + } + mode, ok := parseForwardMode(args[2]) + if !ok { + return fmt.Errorf("unknown forward mode: %s", args[2]) + } + id, err := d.app.forward.Add(mode, args[3]) + if err != nil { + return err + } + d.app.Statusf("[forward] added #%d", id) + return nil + + case "remove", "enable", "disable": + if len(args) < 3 { + return fmt.Errorf("usage: .forward %s ", sub) + } + id, err := strconv.Atoi(args[2]) + if err != nil { + return err + } + switch sub { + case "remove": + return d.app.forward.Remove(id) + case "enable": + return d.app.forward.Enable(id) + case "disable": + return d.app.forward.Disable(id) + } + + case "update": + if len(args) < 5 { + return fmt.Errorf("usage: .forward update
") + } + id, err := strconv.Atoi(args[2]) + if err != nil { + return err + } + mode, ok := parseForwardMode(args[3]) + if !ok { + return fmt.Errorf("unknown forward mode: %s", args[3]) + } + if err = d.app.forward.Update(id, mode, args[4]); err != nil { + return err + } + d.app.Statusf("[forward] updated #%d", id) + return nil + } + + return fmt.Errorf("unknown subcommand: %s", sub) +} + +func (d *CommandDispatcher) handlePluginCommand(args []string) error { + if len(args) < 2 { + if d.app.UIEnabled() { + d.app.OpenPanel(event.UIPanelPlugin) + return nil + } + args = []string{".plugin", "list"} + } + + sub := strings.ToLower(args[1]) + switch sub { + case "list": + if d.app.UIEnabled() { + d.app.OpenPanel(event.UIPanelPlugin) + return nil + } + + items := d.app.plugins.List() + if len(items) == 0 { + d.app.Notifyf("[plugin] empty") + return nil + } + for _, it := range items { + d.app.Notifyf("[plugin] %s enabled=%v path=%s", it.Name, it.Enabled, it.Path) + } + return nil + + case "load": + if len(args) < 3 { + return fmt.Errorf("usage: .plugin load ") + } + name, err := d.app.plugins.Load(args[2]) + if err != nil { + return err + } + d.app.Statusf("[plugin] loaded %s", name) + return nil + + case "unload", "enable", "disable", "reload": + if len(args) < 3 { + return fmt.Errorf("usage: .plugin %s ", sub) + } + name := args[2] + switch sub { + case "unload": + return d.app.plugins.Unload(name) + case "enable": + return d.app.plugins.Enable(name) + case "disable": + return d.app.plugins.Disable(name) + case "reload": + return d.app.plugins.Reload(name) + } + } + + return fmt.Errorf("unknown subcommand: %s", sub) +} + +func (d *CommandDispatcher) handleModeCommand(args []string) error { + if len(args) < 2 || strings.EqualFold(args[1], "show") { + if d.app.UIEnabled() { + d.app.OpenPanel(event.UIPanelMode) + return nil + } + + d.app.Notifyf("[mode] input=%s output=%s end=%q hex=%v frame=%d timestamp=%v timefmt=%q forwardTargets=%d plugins=%d", + d.app.cfg.inputCode, + d.app.cfg.outputCode, + d.app.cfg.endStr, + strings.EqualFold(d.app.cfg.inputCode, "hex"), + d.app.cfg.frameSize, + d.app.cfg.timesTamp, + d.app.cfg.timesFmt, + len(d.app.forward.List()), + len(d.app.plugins.List()), + ) + return nil + } + + if !strings.EqualFold(args[1], "set") { + return fmt.Errorf("usage: .mode ") + } + if len(args) < 4 { + return fmt.Errorf("usage: .mode set ") + } + + field := strings.ToLower(args[2]) + value := strings.Join(args[3:], " ") + + switch field { + case "in": + d.app.cfg.inputCode = value + case "out": + d.app.cfg.outputCode = value + case "end": + d.app.cfg.endStr = value + case "frame": + n, err := strconv.Atoi(value) + if err != nil || n <= 0 { + return fmt.Errorf("frame must be a positive integer") + } + d.app.cfg.frameSize = n + case "timestamp": + enabled, ok := parseOnOff(value) + if !ok { + return fmt.Errorf("timestamp value must be on/off") + } + d.app.cfg.timesTamp = enabled + case "timefmt": + d.app.cfg.timesFmt = value + default: + return fmt.Errorf("unknown mode field: %s", field) + } + + d.app.Statusf("[mode] %s=%q", field, value) + return nil +} + +func parseOnOff(v string) (bool, bool) { + switch strings.ToLower(strings.TrimSpace(v)) { + case "on", "true", "1", "yes": + return true, true + case "off", "false", "0", "no": + return false, true + default: + return false, false } } -func cmdexit() { - CloseTrzsz() - CloseSerial() - os.Exit(0) -} -func cmdargs() { - strout(out, config.outputCode, fmt.Sprintf(">-------Args(%v)-------<\n", len(args)-1)) - strout(out, config.outputCode, fmt.Sprintf("%q\n", args[1:])) -} -func cmdctrl() { - var err error - b := []byte(args[1]) - x := []byte{b[0] & 0x1f} - _, err = serialPort.Write(x) - ErrorF(err) - strout(out, config.outputCode, fmt.Sprintf("Ctrl+%s\n", b)) -} -func cmdhex() { - strout(out, config.outputCode, fmt.Sprintf(">-----Hex Send-----<\n")) - strout(out, config.outputCode, fmt.Sprintf("%q\n", args[1:])) - s := strings.Join(args[1:], "") - b, err := hex.DecodeString(s) - if err != nil { - log.Fatal(err) - } - _, err = serialPort.Write(b) - if err != nil { - log.Fatal(err) - } -} -func cmdinit() { - commands = append(commands, Command{name: ".help", description: "帮助信息", function: cmdhelp}) - commands = append(commands, Command{name: ".ctrl", description: "发送Ctrl组合键", function: cmdctrl}) - commands = append(commands, Command{name: ".hex", description: "发送Hex", function: cmdhex}) - commands = append(commands, Command{name: ".exit", description: "退出终端", function: cmdexit}) -} diff --git a/command_test.go b/command_test.go new file mode 100644 index 0000000..4ff83dd --- /dev/null +++ b/command_test.go @@ -0,0 +1,496 @@ +package main + +import ( + "io" + "strings" + "testing" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" +) + +func setupTestPipes() { + var cr *io.PipeReader + cr, stdinPipe = io.Pipe() + go func() { + buf := make([]byte, 4096) + for { + _, err := cr.Read(buf) + if err != nil { + return + } + } + }() +} + +func newTestAppForCommand() *App { + a := &App{ + cfg: &Config{inputCode: "UTF-8", outputCode: "UTF-8", endStr: "\n"}, + plugins: NewPluginManager(), + uiEvents: make(chan event.UIEvent, 32), + done: make(chan struct{}), + } + a.SetUIEnabled(true) + a.forward = NewForwardManager(func([]byte) error { return nil }, func(string, ...any) {}) + a.dispatcher = NewCommandDispatcher(a) + return a +} + +func TestCommandCompleteRoot(t *testing.T) { + a := newTestAppForCommand() + line, cands := a.dispatcher.Complete(".") + if line != "." { + t.Fatalf("expected line unchanged for ambiguous completion, got %q", line) + } + if len(cands) == 0 { + t.Fatalf("expected root command candidates") + } + for _, c := range cands { + if c == ".ctrl" { + t.Fatalf(".ctrl should be removed from command set") + } + } +} + +func TestCommandCompleteForwardSubcommands(t *testing.T) { + a := newTestAppForCommand() + _, cands := a.dispatcher.Complete(".forward ") + joined := strings.Join(cands, ",") + for _, name := range []string{"list", "add", "remove", "enable", "disable", "update", "stats"} { + if !strings.Contains(joined, name) { + t.Fatalf("missing forward candidate %q in %v", name, cands) + } + } +} + +func TestCommandExecuteUnknown(t *testing.T) { + a := newTestAppForCommand() + handled, err := a.dispatcher.Execute(".unknown") + if !handled { + t.Fatalf("unknown command should be marked handled") + } + if err == nil { + t.Fatalf("expected unknown command error") + } +} + +func TestCommandExecuteHelpShowsModal(t *testing.T) { + a := newTestAppForCommand() + handled, err := a.dispatcher.Execute(".help") + if err != nil || !handled { + t.Fatalf(".help execute failed handled=%v err=%v", handled, err) + } + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventModal || ev.Title == "" { + t.Fatalf("expected help modal event, got %+v", ev) + } +} + +func TestCommandExecuteForwardListShowsPanel(t *testing.T) { + a := newTestAppForCommand() + handled, err := a.dispatcher.Execute(".forward list") + if err != nil || !handled { + t.Fatalf(".forward list execute failed handled=%v err=%v", handled, err) + } + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventPanel || ev.Panel != event.UIPanelForward { + t.Fatalf("expected forward panel event, got %+v", ev) + } +} + +func TestCommandExecutePluginListShowsPanel(t *testing.T) { + a := newTestAppForCommand() + if _, err := a.plugins.Load("plugins/demo.lua"); err == nil { + _ = a.plugins.Disable("demo") + } + handled, err := a.dispatcher.Execute(".plugin list") + if err != nil || !handled { + t.Fatalf(".plugin list execute failed handled=%v err=%v", handled, err) + } + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventPanel || ev.Panel != event.UIPanelPlugin { + t.Fatalf("expected plugin panel event, got %+v", ev) + } +} + +func TestCommandExecutePluginWithoutSubcommandShowsPanel(t *testing.T) { + a := newTestAppForCommand() + handled, err := a.dispatcher.Execute(".plugin") + if err != nil || !handled { + t.Fatalf(".plugin execute failed handled=%v err=%v", handled, err) + } + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventPanel || ev.Panel != event.UIPanelPlugin { + t.Fatalf("expected plugin panel event for bare command, got %+v", ev) + } +} + +func TestCommandExecuteModeShowsPanel(t *testing.T) { + a := newTestAppForCommand() + handled, err := a.dispatcher.Execute(".mode show") + if err != nil || !handled { + t.Fatalf(".mode execute failed handled=%v err=%v", handled, err) + } + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventPanel || ev.Panel != event.UIPanelMode { + t.Fatalf("expected mode panel event, got %+v", ev) + } +} + +func TestCommandExecuteModeSet(t *testing.T) { + a := newTestAppForCommand() + handled, err := a.dispatcher.Execute(".mode set end \\r\\n") + if err != nil || !handled { + t.Fatalf(".mode set end failed handled=%v err=%v", handled, err) + } + if a.cfg.endStr != "\\r\\n" { + t.Fatalf("mode set end not applied, got=%q", a.cfg.endStr) + } + + handled, err = a.dispatcher.Execute(".mode set timestamp on") + if err != nil || !handled { + t.Fatalf(".mode set timestamp failed handled=%v err=%v", handled, err) + } + if !a.cfg.timesTamp { + t.Fatalf("mode set timestamp should enable timesTamp") + } +} + +func TestParseOnOff(t *testing.T) { + tests := []struct { + in string + val bool + valid bool + }{ + {in: "on", val: true, valid: true}, + {in: "true", val: true, valid: true}, + {in: "1", val: true, valid: true}, + {in: "yes", val: true, valid: true}, + {in: "off", val: false, valid: true}, + {in: "false", val: false, valid: true}, + {in: "0", val: false, valid: true}, + {in: "no", val: false, valid: true}, + {in: "", val: false, valid: false}, + {in: "maybe", val: false, valid: false}, + } + + for _, tt := range tests { + got, ok := parseOnOff(tt.in) + if ok != tt.valid || got != tt.val { + t.Fatalf("parseOnOff(%q) got=(%v,%v) want=(%v,%v)", tt.in, got, ok, tt.val, tt.valid) + } + } +} + +func TestCompleteForward(t *testing.T) { + tests := []struct { + args []string + want []string + }{ + {args: []string{".forward"}, want: []string{"list", "add", "remove", "enable", "disable", "update", "stats"}}, + {args: []string{".forward", ""}, want: []string{"list", "add", "remove", "enable", "disable", "update", "stats"}}, + {args: []string{".forward", "add", ""}, want: []string{"tcp", "udp"}}, + {args: []string{".forward", "update", "1", ""}, want: []string{"tcp", "udp"}}, + {args: []string{".forward", "list", "1"}, want: nil}, + } + for _, tt := range tests { + got := completeForward(tt.args) + if !stringSlicesEqual(got, tt.want) { + t.Fatalf("completeForward(%v) got=%v want=%v", tt.args, got, tt.want) + } + } +} + +func TestCompletePlugin(t *testing.T) { + tests := []struct { + args []string + want []string + }{ + {args: []string{".plugin"}, want: []string{"list", "load", "unload", "enable", "disable", "reload"}}, + {args: []string{".plugin", "load", ""}, want: nil}, + {args: []string{".plugin", "unload", "demo"}, want: nil}, + } + for _, tt := range tests { + got := completePlugin(tt.args) + if !stringSlicesEqual(got, tt.want) { + t.Fatalf("completePlugin(%v) got=%v want=%v", tt.args, got, tt.want) + } + } +} + +func TestCompleteMode(t *testing.T) { + tests := []struct { + args []string + want []string + }{ + {args: []string{".mode"}, want: []string{"show", "set"}}, + {args: []string{".mode", "set", ""}, want: []string{"in", "out", "end", "frame", "timestamp", "timefmt"}}, + {args: []string{".mode", "set", "timestamp", ""}, want: []string{"on", "off"}}, + {args: []string{".mode", "set", "in", ""}, want: nil}, + } + for _, tt := range tests { + got := completeMode(tt.args) + if !stringSlicesEqual(got, tt.want) { + t.Fatalf("completeMode(%v) got=%v want=%v", tt.args, got, tt.want) + } + } +} + +func stringSlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestHelpText(t *testing.T) { + a := newTestAppForCommand() + text := a.dispatcher.HelpText() + for _, cmd := range []string{".help", ".exit", ".hex", ".forward", ".plugin", ".mode"} { + if !strings.Contains(text, cmd) { + t.Fatalf("HelpText missing command %q", cmd) + } + } +} + +func TestCommandExecuteHex(t *testing.T) { + setupTestPipes() + a := newTestAppForCommand() + handled, err := a.dispatcher.Execute(".hex 41 42 43") + if err != nil || !handled { + t.Fatalf(".hex valid failed handled=%v err=%v", handled, err) + } + + handled, err = a.dispatcher.Execute(".hex") + if !handled || err == nil { + t.Fatalf(".hex no args should error, handled=%v err=%v", handled, err) + } + + handled, err = a.dispatcher.Execute(".hex xyz") + if !handled || err == nil { + t.Fatalf(".hex invalid hex should error, handled=%v err=%v", handled, err) + } +} + +func TestCommandExecuteExit(t *testing.T) { + a := newTestAppForCommand() + a.Close() + if !a.isClosed() { + t.Fatalf("expected app closed after Close()") + } +} + +func TestCommandExecuteModeSetAll(t *testing.T) { + a := newTestAppForCommand() + + handled, err := a.dispatcher.Execute(".mode set frame 32") + if err != nil || !handled { + t.Fatalf(".mode set frame failed: handled=%v err=%v", handled, err) + } + if a.cfg.frameSize != 32 { + t.Fatalf("frameSize not set, got=%d", a.cfg.frameSize) + } + + handled, err = a.dispatcher.Execute(".mode set timefmt 2006") + if err != nil || !handled { + t.Fatalf(".mode set timefmt failed: handled=%v err=%v", handled, err) + } + if a.cfg.timesFmt != "2006" { + t.Fatalf("timesFmt not set, got=%q", a.cfg.timesFmt) + } + + handled, err = a.dispatcher.Execute(".mode set out GBK") + if err != nil || !handled { + t.Fatalf(".mode set out failed: handled=%v err=%v", handled, err) + } + if a.cfg.outputCode != "GBK" { + t.Fatalf("outputCode not set, got=%q", a.cfg.outputCode) + } + + handled, err = a.dispatcher.Execute(".mode set in GBK") + if err != nil || !handled { + t.Fatalf(".mode set in failed: handled=%v err=%v", handled, err) + } + if a.cfg.inputCode != "GBK" { + t.Fatalf("inputCode not set, got=%q", a.cfg.inputCode) + } +} + +func TestCommandExecuteModeErrors(t *testing.T) { + a := newTestAppForCommand() + + handled, err := a.dispatcher.Execute(".mode") + if err != nil || !handled { + t.Fatalf(".mode with no subcommand in UI mode shows panel, handled=%v err=%v", handled, err) + } + + _, err = a.dispatcher.Execute(".mode set") + if err == nil { + t.Fatalf(".mode set with no args should error") + } + + _, err = a.dispatcher.Execute(".mode set frame abc") + if err == nil { + t.Fatalf(".mode set frame with non-int should error") + } + + _, err = a.dispatcher.Execute(".mode set timestamp maybe") + if err == nil { + t.Fatalf(".mode set timestamp with invalid value should error") + } + + _, err = a.dispatcher.Execute(".mode set invalid_field value") + if err == nil { + t.Fatalf(".mode set unknown field should error") + } +} + +func TestHandleForwardCommandErrors(t *testing.T) { + a := newTestAppForCommand() + + _, err := a.dispatcher.Execute(".forward add") + if err == nil { + t.Fatalf(".forward add with no args should error") + } + + _, err = a.dispatcher.Execute(".forward add badmode 127.0.0.1:1") + if err == nil { + t.Fatalf(".forward add with invalid mode should error") + } + + _, err = a.dispatcher.Execute(".forward remove abc") + if err == nil { + t.Fatalf(".forward remove with non-int ID should error") + } + + _, err = a.dispatcher.Execute(".forward remove 999") + if err == nil { + t.Fatalf(".forward remove non-existing should error") + } + + _, err = a.dispatcher.Execute(".forward enable abc") + if err == nil { + t.Fatalf(".forward enable with non-int ID should error") + } + + _, err = a.dispatcher.Execute(".forward disable abc") + if err == nil { + t.Fatalf(".forward disable with non-int ID should error") + } + + _, err = a.dispatcher.Execute(".forward update") + if err == nil { + t.Fatalf(".forward update with no args should error") + } + + _, err = a.dispatcher.Execute(".forward update 1") + if err == nil { + t.Fatalf(".forward update with missing addr should error") + } + + _, err = a.dispatcher.Execute(".forward update 1 badmode 127.0.0.1:1") + if err == nil { + t.Fatalf(".forward update with invalid mode should error") + } + + _, err = a.dispatcher.Execute(".forward unknown_sub") + if err == nil { + t.Fatalf(".forward unknown subcommand should error") + } +} + +func TestHandleForwardCommandNoUI(t *testing.T) { + a := newTestAppForCommand() + a.SetUIEnabled(false) + + handled, err := a.dispatcher.Execute(".forward") + if err != nil || !handled { + t.Fatalf(".forward in non-UI should default to list, handled=%v err=%v", handled, err) + } + + handled, err = a.dispatcher.Execute(".forward list") + if err != nil || !handled { + t.Fatalf(".forward list in non-UI failed: %v", err) + } +} + +func TestHandlePluginCommandErrors(t *testing.T) { + a := newTestAppForCommand() + + _, err := a.dispatcher.Execute(".plugin load") + if err == nil { + t.Fatalf(".plugin load with no path should error") + } + + _, err = a.dispatcher.Execute(".plugin unload") + if err == nil { + t.Fatalf(".plugin unload with no name should error") + } + + _, err = a.dispatcher.Execute(".plugin enable") + if err == nil { + t.Fatalf(".plugin enable with no name should error") + } + + _, err = a.dispatcher.Execute(".plugin disable") + if err == nil { + t.Fatalf(".plugin disable with no name should error") + } + + _, err = a.dispatcher.Execute(".plugin reload") + if err == nil { + t.Fatalf(".plugin reload with no name should error") + } + + _, err = a.dispatcher.Execute(".plugin unknown_sub") + if err == nil { + t.Fatalf(".plugin unknown subcommand should error") + } +} + +func TestHandlePluginCommandNoUI(t *testing.T) { + a := newTestAppForCommand() + a.SetUIEnabled(false) + + handled, err := a.dispatcher.Execute(".plugin") + if err != nil || !handled { + t.Fatalf(".plugin in non-UI should default to list, handled=%v err=%v", handled, err) + } +} + +func TestCompleteFirstTokenEdgeCases(t *testing.T) { + a := newTestAppForCommand() + line, cands := a.dispatcher.Complete(".he") + if line != ".he" { + t.Fatalf("ambiguous completion should not change line, got=%q", line) + } + found := false + for _, c := range cands { + if c == ".help" { + found = true + break + } + } + if !found { + t.Fatalf("expected .help in completion candidates, got %v", cands) + } + + line, cands = a.dispatcher.Complete(".exi") + if line != ".exit " || len(cands) != 1 || cands[0] != ".exit" { + t.Fatalf("exact completion of .exi failed: line=%q cands=%v", line, cands) + } + + line, _ = a.dispatcher.Complete("") + if line != "" { + t.Fatalf("empty completion should be noop, got=%q", line) + } +} diff --git a/config.go b/config.go index 3cb43b8..d91f02b 100644 --- a/config.go +++ b/config.go @@ -2,9 +2,8 @@ package main import ( "fmt" - "log" - "net" "os" + "strings" "time" ) @@ -24,7 +23,10 @@ type Config struct { timesTamp bool timesFmt string address []string + enableGUI bool + hotkeyMod string } + type FoeWardMode int const ( @@ -35,34 +37,48 @@ const ( var config Config -func setForWardClient(mode FoeWardMode, add string) (conn net.Conn) { - var err error - switch mode { - case NOT: - +func (m FoeWardMode) Network() string { + switch m { case TCPC: - conn, err = net.Dial("tcp", add) - if err != nil { - log.Fatal(err) - } + return "tcp" case UDPC: - conn, err = net.Dial("udp", add) - if err != nil { - log.Fatal(err) - } + return "udp" default: - panic("未知模式设置") + return "" } - return conn } -func checkLogOpen() { +func (m FoeWardMode) String() string { + switch m { + case TCPC: + return "tcp" + case UDPC: + return "udp" + default: + return "none" + } +} + +func parseForwardMode(v string) (FoeWardMode, bool) { + switch strings.ToLower(strings.TrimSpace(v)) { + case "tcp", "tcp-c", "tcpc", "1": + return TCPC, true + case "udp", "udp-c", "udpc", "2": + return UDPC, true + default: + return NOT, false + } +} + +func openLogFile() (*os.File, error) { if config.enableLog { path := fmt.Sprintf(config.logFilePath, config.portName, time.Now().Format("2006_01_02T150405")) f, err := os.OpenFile(path, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0666) if err != nil { - log.Fatal(err) + return nil, err } - outs = append(outs, f) + return f, nil } + + return nil, nil } diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..f634a61 --- /dev/null +++ b/config_test.go @@ -0,0 +1,238 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "github.com/spf13/pflag" +) + +func TestForwardModeNetworkAndString(t *testing.T) { + tests := []struct { + mode FoeWardMode + network string + name string + }{ + {mode: NOT, network: "", name: "none"}, + {mode: TCPC, network: "tcp", name: "tcp"}, + {mode: UDPC, network: "udp", name: "udp"}, + } + + for _, tt := range tests { + if got := tt.mode.Network(); got != tt.network { + t.Fatalf("Network() mode=%v got=%q want=%q", tt.mode, got, tt.network) + } + if got := tt.mode.String(); got != tt.name { + t.Fatalf("String() mode=%v got=%q want=%q", tt.mode, got, tt.name) + } + } +} + +func TestParseForwardMode(t *testing.T) { + tests := []struct { + input string + mode FoeWardMode + ok bool + }{ + {input: "tcp", mode: TCPC, ok: true}, + {input: "TCP-C", mode: TCPC, ok: true}, + {input: "1", mode: TCPC, ok: true}, + {input: "udp", mode: UDPC, ok: true}, + {input: " 2 ", mode: UDPC, ok: true}, + {input: "unknown", mode: NOT, ok: false}, + {input: "", mode: NOT, ok: false}, + } + + for _, tt := range tests { + got, ok := parseForwardMode(tt.input) + if ok != tt.ok || got != tt.mode { + t.Fatalf("parseForwardMode(%q) got=(%v,%v) want=(%v,%v)", tt.input, got, ok, tt.mode, tt.ok) + } + } +} + +func TestOpenLogFile(t *testing.T) { + old := config + defer func() { config = old }() + + config = Config{ + enableLog: true, + portName: "COM1", + logFilePath: filepath.Join(t.TempDir(), "%s-%s.log"), + } + + f, err := openLogFile() + if err != nil { + t.Fatalf("openLogFile() unexpected error: %v", err) + } + if f == nil { + t.Fatalf("openLogFile() got nil file when enableLog=true") + } + _ = f.Close() + + config.enableLog = false + f, err = openLogFile() + if err != nil { + t.Fatalf("openLogFile() unexpected error with enableLog=false: %v", err) + } + if f != nil { + t.Fatalf("openLogFile() expected nil file when enableLog=false") + } +} + +func TestFlagFindValue(t *testing.T) { + s := "str" + sl := []string{"a"} + n := 1 + il := []int{1} + b := true + ext := "ext" + + tests := []struct { + name string + v ptrVal + want ValType + }{ + {name: "string", v: ptrVal{string: &s}, want: stringVal}, + {name: "stringSlice", v: ptrVal{sl: &sl}, want: sliceStrVal}, + {name: "bool", v: ptrVal{bool: &b}, want: boolVal}, + {name: "int", v: ptrVal{int: &n}, want: intVal}, + {name: "intSlice", v: ptrVal{il: &il}, want: sliceIntVal}, + {name: "ext", v: ptrVal{ext: &ext}, want: extVal}, + {name: "none", v: ptrVal{}, want: notVal}, + } + + for _, tt := range tests { + got := flagFindValue(tt.v) + if got != tt.want { + t.Fatalf("%s: flagFindValue got=%v want=%v", tt.name, got, tt.want) + } + } +} + +func TestFlagExt(t *testing.T) { + old := config + defer func() { config = old }() + + config = Config{} + flagExt() + if config.enableLog { + t.Fatalf("expected enableLog=false when logFilePath empty") + } + if config.timesTamp { + t.Fatalf("expected timesTamp=false when timesFmt empty") + } + if config.hotkeyMod != "ctrl+alt" { + t.Fatalf("expected default hotkeyMod=ctrl+alt, got=%q", config.hotkeyMod) + } + + config = Config{logFilePath: "/tmp/log.txt"} + flagExt() + if !config.enableLog { + t.Fatalf("expected enableLog=true when logFilePath set") + } + + config = Config{timesFmt: "2006-01-02"} + flagExt() + if !config.timesTamp { + t.Fatalf("expected timesTamp=true when timesFmt set") + } + + config = Config{hotkeyMod: ""} + flagExt() + if config.hotkeyMod != "ctrl+alt" { + t.Fatalf("empty hotkeyMod should default to ctrl+alt") + } + + config = Config{hotkeyMod: "ctrl+shift"} + flagExt() + if config.hotkeyMod != "ctrl+shift" { + t.Fatalf("expected ctrl+shift preserved") + } + + config = Config{hotkeyMod: " CTRL+SHIFT "} + flagExt() + if config.hotkeyMod != "ctrl+shift" { + t.Fatalf("expected whitespace+case normalization, got=%q", config.hotkeyMod) + } + + config = Config{hotkeyMod: "invalid"} + flagExt() + if config.hotkeyMod != "ctrl+alt" { + t.Fatalf("invalid hotkeyMod should default to ctrl+alt, got=%q", config.hotkeyMod) + } +} + +func TestFlagInit(t *testing.T) { + var testStr string + var testBool bool + var testInt int + var testExt string + var testSl []string + var testIl []int + + f := Flag{ + v: ptrVal{string: &testStr}, + sStr: "X", lStr: "test-str", dv: Val{string: "hello"}, help: "test string", + } + flagInit(&f) + if pflag.Lookup("test-str") == nil { + t.Fatalf("string flag not registered") + } + + boolF := Flag{ + v: ptrVal{bool: &testBool}, + sStr: "Y", lStr: "test-bool", dv: Val{bool: true}, help: "test bool", + } + flagInit(&boolF) + + intF := Flag{ + v: ptrVal{int: &testInt}, + sStr: "Z", lStr: "test-int", dv: Val{int: 42}, help: "test int", + } + flagInit(&intF) + + extF := Flag{ + v: ptrVal{ext: &testExt}, + sStr: "E", lStr: "test-ext", dv: Val{extdef: "default-val", string: ""}, help: "test ext", + } + flagInit(&extF) + + slF := Flag{ + v: ptrVal{sl: &testSl}, + sStr: "1", lStr: "test-sl", dv: Val{string: "a"}, help: "test sl", + } + flagInit(&slF) + + ilF := Flag{ + v: ptrVal{il: &testIl}, + sStr: "2", lStr: "test-il", dv: Val{int: 1}, help: "test il", + } + flagInit(&ilF) +} + +func TestNormalizeFlags(t *testing.T) { + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + os.Args = []string{"COM.exe", "-port", "COM17", "-baud", "9600", "-p", "COM1", "--gui", "COM17"} + normalizeFlags() + + args := os.Args + if args[1] != "--port" { + t.Fatalf("expected -port -> --port, got %q", args[1]) + } + if args[3] != "--baud" { + t.Fatalf("expected -baud -> --baud, got %q", args[3]) + } + if args[5] != "-p" { + t.Fatalf("expected -p unchanged, got %q", args[5]) + } + if args[7] != "--gui" { + t.Fatalf("expected --gui unchanged, got %q", args[7]) + } + if args[8] != "COM17" { + t.Fatalf("expected value unchanged, got %q", args[8]) + } +} diff --git a/escape_test.go b/escape_test.go new file mode 100644 index 0000000..fe13197 --- /dev/null +++ b/escape_test.go @@ -0,0 +1,123 @@ +package main + +import ( + "testing" +) + +func TestParseCSIu(t *testing.T) { + tests := []struct { + name string + seq []byte + cp int + mod int + ok bool + }{ + { + name: "ctrl+alt+c lowercase", + seq: []byte{0x1b, '[', '9', '9', ';', '6', 'u'}, + cp: 99, mod: 6, ok: true, + }, + { + name: "ctrl+shift+c uppercase", + seq: []byte{0x1b, '[', '6', '7', ';', '5', 'u'}, + cp: 67, mod: 5, ok: true, + }, + { + name: "too short", + seq: []byte{0x1b, '[', '9', '9'}, + cp: 0, mod: 0, ok: false, + }, + { + name: "no escape prefix", + seq: []byte{'[', '9', '9', ';', '6', 'u'}, + cp: 0, mod: 0, ok: false, + }, + { + name: "no u terminator", + seq: []byte{0x1b, '[', '9', '9', ';', '6', 'x'}, + cp: 0, mod: 0, ok: false, + }, + { + name: "bad format no semicolon", + seq: []byte{0x1b, '[', '9', '9', '6', 'u'}, + cp: 0, mod: 0, ok: false, + }, + { + name: "empty", + seq: []byte{}, + cp: 0, mod: 0, ok: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cp, mod, ok := parseCSIu(tt.seq) + if ok != tt.ok || cp != tt.cp || mod != tt.mod { + t.Fatalf("parseCSIu(%v) got=(%d,%d,%v) want=(%d,%d,%v)", tt.seq, cp, mod, ok, tt.cp, tt.mod, tt.ok) + } + }) + } +} + +func TestIsExitHotkeySeq(t *testing.T) { + oldCfg := config + defer func() { config = oldCfg }() + + config = Config{hotkeyMod: "ctrl+alt"} + + // CSI u Ctrl+Alt+C (mod=6) + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '6', 'u'}) { + t.Fatalf("Ctrl+Alt+C CSI should exit with ctrl+alt config") + } + // CSI u Ctrl+Alt+Shift+C (mod=7, includes Ctrl+Alt) + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '7', 'u'}) { + t.Fatalf("Ctrl+Alt+Shift+C should also exit") + } + // CSI u Ctrl+Shift+C (mod=5) + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '5', 'u'}) { + t.Fatalf("Ctrl+Shift+C should NOT exit with ctrl+alt config") + } + // CSI for other key + if isExitHotkeySeq([]byte{0x1b, '[', '9', '7', ';', '6', 'u'}) { + t.Fatalf("Ctrl+Alt+A should not exit") + } + + // Simple ESC c (Alt+C) should NOT exit — requires Ctrl modifier + if isExitHotkeySeq([]byte{0x1b, 'c'}) { + t.Fatalf("Alt+C (ESC c) should NOT exit — Ctrl modifier required") + } + + // Switch to ctrl+shift + config = Config{hotkeyMod: "ctrl+shift"} + + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '5', 'u'}) { + t.Fatalf("Ctrl+Shift+C should exit with ctrl+shift config") + } + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '7', 'u'}) { + t.Fatalf("Ctrl+Shift+Alt+C should also exit (includes Ctrl+Shift)") + } + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '6', 'u'}) { + t.Fatalf("Ctrl+Alt+C should NOT exit with ctrl+shift config") + } + // Simple ESC c should NOT exit with ctrl+shift + if isExitHotkeySeq([]byte{0x1b, 'c'}) { + t.Fatalf("ESC c should NOT exit with ctrl+shift config") + } + // Non-CSI garbage + if isExitHotkeySeq([]byte{0x1b, 'x'}) { + t.Fatalf("ESC x should not exit") + } + if isExitHotkeySeq([]byte("hello")) { + t.Fatalf("plain bytes should not exit") + } + + config = Config{hotkeyMod: "ctrl+alt"} + // Ctrl only (mod=4) should not exit (requires Alt too) + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '4', 'u'}) { + t.Fatalf("Ctrl+C (without Alt) should not exit") + } + // Alt only (mod=2) should not exit (requires Ctrl too) + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '2', 'u'}) { + t.Fatalf("Alt+C (without Ctrl) should not exit") + } +} diff --git a/flag.go b/flag.go index 1c05a78..6ddf4e5 100644 --- a/flag.go +++ b/flag.go @@ -14,6 +14,8 @@ import ( "github.com/spf13/pflag" "go.bug.st/serial" "log" + "os" + "sort" "strconv" "strings" ) @@ -58,7 +60,9 @@ var ( address = Flag{ptrVal{sl: &config.address}, "a", "address", Val{string: "127.0.0.1:12345"}, "转发服务地址(支持多次传入)"} frameSize = Flag{ptrVal{int: &config.frameSize}, "F", "Frame", Val{int: 16}, "帧大小"} parityBit = Flag{ptrVal{int: &config.parityBit}, "v", "verify", Val{int: 0}, "奇偶校验(0:无校验、1:奇校验、2:偶校验、3:1校验、4:0校验)"} - flags = []Flag{portName, baudRate, dataBits, stopBits, outputCode, inputCode, endStr, forWard, address, frameSize, parityBit, logExt, timeExt} + guiMode = Flag{ptrVal{bool: &config.enableGUI}, "g", "gui", Val{bool: false}, "启用TUI交互界面"} + hotkeyMod = Flag{ptrVal{string: &config.hotkeyMod}, "k", "hotkey-mod", Val{string: "ctrl+alt"}, "本地快捷键修饰(ctrl+alt|ctrl+shift)"} + flags = []Flag{portName, baudRate, dataBits, stopBits, outputCode, inputCode, endStr, forWard, address, frameSize, parityBit, logExt, timeExt, guiMode, hotkeyMod} ) var ( @@ -79,42 +83,81 @@ const ( intVal boolVal extVal + sliceStrVal + sliceIntVal ) -func printUsage(ports []string) { - fmt.Printf("\n参数帮助:\n") +func normalizeFlags() { + known := make(map[string]bool, len(flags)) for _, f := range flags { + known[f.lStr] = true + } + for i, arg := range os.Args[1:] { + if strings.HasPrefix(arg, "-") && !strings.HasPrefix(arg, "--") { + name := strings.TrimPrefix(arg, "-") + if known[name] { + os.Args[i+1] = "--" + name + } + } + } +} + +func printUsage(ports []string) { + sorted := make([]Flag, len(flags)) + copy(sorted, flags) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].lStr < sorted[j].lStr + }) + + fmt.Printf("\n参数帮助:\n") + fmt.Printf(" %-6s %-14s %-8s %-44s %s\n", "短参", "长参", "类型", "说明", "默认值") + fmt.Printf(" %-6s %-14s %-8s %-44s %s\n", "------", "------", "------", "------", "------") + for _, f := range sorted { flagprint(f) } - fmt.Printf("\n在线串口: %v\n", strings.Join(ports, ",")) + fmt.Printf("\n在线串口: %v\n", strings.Join(ports, ", ")) } + func flagFindValue(v ptrVal) ValType { if v.string != nil { return stringVal } + if v.sl != nil { + return sliceStrVal + } if v.bool != nil { return boolVal } if v.int != nil { return intVal } + if v.il != nil { + return sliceIntVal + } if v.ext != nil { return extVal } return notVal } + func flagprint(f Flag) { + short := "-" + f.sStr + long := "--" + f.lStr + help := f.help + switch flagFindValue(f.v) { case stringVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%q\n", f.sStr, f.lStr, f.dv.string, f.help, f.dv.string) + fmt.Printf(" %-6s %-14s %-8s %-44s %q\n", short, long, "string", help, f.dv.string) case intVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%v\n", f.sStr, f.lStr, f.dv.int, f.help, f.dv.int) + fmt.Printf(" %-6s %-14s %-8s %-44s %v\n", short, long, "int", help, f.dv.int) case boolVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%v\n", f.sStr, f.lStr, f.dv.bool, f.help, f.dv.bool) + fmt.Printf(" %-6s %-14s %-8s %-44s %v\n", short, long, "bool", help, f.dv.bool) case extVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%v\n", f.sStr, f.lStr, f.dv.extdef, f.help, f.dv.extdef) - default: - panic("unhandled default case") + fmt.Printf(" %-6s %-14s %-8s %-44s %v\n", short, long, "string", help, f.dv.extdef) + case sliceStrVal: + fmt.Printf(" %-6s %-14s %-8s %-44s %q\n", short, long, "[]string", help, f.dv.string) + case sliceIntVal: + fmt.Printf(" %-6s %-14s %-8s %-44s %v\n", short, long, "[]int", help, f.dv.int) } } func flagInit(f *Flag) { @@ -145,6 +188,13 @@ func flagExt() { if config.timesFmt != "" { config.timesTamp = true } + if config.hotkeyMod == "" { + config.hotkeyMod = "ctrl+alt" + } + config.hotkeyMod = strings.ToLower(strings.TrimSpace(config.hotkeyMod)) + if config.hotkeyMod != "ctrl+alt" && config.hotkeyMod != "ctrl+shift" { + config.hotkeyMod = "ctrl+alt" + } } func getCliFlag() { ports, err := serial.GetPortsList() diff --git a/forwarding.go b/forwarding.go new file mode 100644 index 0000000..7a324e5 --- /dev/null +++ b/forwarding.go @@ -0,0 +1,311 @@ +package main + +import ( + "fmt" + "net" + "sort" + "sync" + "sync/atomic" + "time" +) + +type ForwardStats struct { + ReadBytes uint64 + WrittenBytes uint64 + LastError string +} + +type ForwardTarget struct { + ID int + Mode FoeWardMode + Address string + Enabled bool + Connected bool + CreatedAt time.Time + + conn net.Conn + stats ForwardStats + mu sync.Mutex + closeCh chan struct{} + closed bool +} + +type ForwardSnapshot struct { + ID int + Mode string + Address string + Enabled bool + Connected bool + ReadBytes uint64 + WriteByte uint64 + LastError string +} + +type ForwardManager struct { + mu sync.RWMutex + targets map[int]*ForwardTarget + nextID int + writeToSerial func([]byte) error + notify func(string, ...any) + onInbound func(int, []byte) +} + +func NewForwardManager(writeToSerial func([]byte) error, notify func(string, ...any)) *ForwardManager { + return &ForwardManager{ + targets: make(map[int]*ForwardTarget), + nextID: 1, + writeToSerial: writeToSerial, + notify: notify, + } +} + +func (m *ForwardManager) SetInboundReporter(fn func(int, []byte)) { + m.mu.Lock() + defer m.mu.Unlock() + m.onInbound = fn +} + +func (m *ForwardManager) Add(mode FoeWardMode, address string) (int, error) { + if mode == NOT { + return 0, fmt.Errorf("forward mode cannot be none") + } + + t := &ForwardTarget{ + Mode: mode, + Address: address, + Enabled: true, + CreatedAt: time.Now(), + closeCh: make(chan struct{}), + } + + conn, err := net.Dial(mode.Network(), address) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + + t.conn = conn + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.readLoop(t, conn, t.closeCh) + m.notify("[forward] #%d %s %s connected", t.ID, t.Mode.String(), t.Address) + return t.ID, nil +} + +func (m *ForwardManager) readLoop(t *ForwardTarget, conn net.Conn, stop <-chan struct{}) { + buf := make([]byte, 4096) + for { + n, err := conn.Read(buf) + if n > 0 { + atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) + chunk := make([]byte, n) + copy(chunk, buf[:n]) + if wErr := m.writeToSerial(chunk); wErr != nil { + t.stats.LastError = wErr.Error() + m.notify("[forward] #%d write serial error: %v", t.ID, wErr) + } else if m.onInbound != nil { + m.onInbound(t.ID, chunk) + } + } + + if err != nil { + t.mu.Lock() + if t.conn == conn { + t.Connected = false + } + t.stats.LastError = err.Error() + t.mu.Unlock() + m.notify("[forward] #%d disconnected: %v", t.ID, err) + return + } + + select { + case <-stop: + return + default: + } + } +} + +func (m *ForwardManager) Remove(id int) error { + m.mu.Lock() + t, ok := m.targets[id] + if !ok { + m.mu.Unlock() + return fmt.Errorf("forward #%d not found", id) + } + delete(m.targets, id) + m.mu.Unlock() + + t.close() + m.notify("[forward] #%d removed", id) + return nil +} + +func (m *ForwardManager) Enable(id int) error { + m.mu.RLock() + t, ok := m.targets[id] + m.mu.RUnlock() + if !ok { + return fmt.Errorf("forward #%d not found", id) + } + + t.mu.Lock() + defer t.mu.Unlock() + if t.Enabled && t.Connected { + return nil + } + + conn, err := net.Dial(t.Mode.Network(), t.Address) + if err != nil { + t.stats.LastError = err.Error() + return err + } + + t.Enabled = true + t.Connected = true + t.conn = conn + t.closeCh = make(chan struct{}) + t.closed = false + go m.readLoop(t, conn, t.closeCh) + m.notify("[forward] #%d enabled", id) + return nil +} + +func (m *ForwardManager) Update(id int, mode FoeWardMode, address string) error { + if mode == NOT { + return fmt.Errorf("forward mode cannot be none") + } + + m.mu.RLock() + t, ok := m.targets[id] + m.mu.RUnlock() + if !ok { + return fmt.Errorf("forward #%d not found", id) + } + + t.mu.Lock() + wasEnabled := t.Enabled + t.Mode = mode + t.Address = address + t.mu.Unlock() + + // Restart the target to apply new mode/address when enabled. + t.close() + + if !wasEnabled { + m.notify("[forward] #%d updated (disabled)", id) + return nil + } + + return m.Enable(id) +} + +func (m *ForwardManager) Disable(id int) error { + m.mu.RLock() + t, ok := m.targets[id] + m.mu.RUnlock() + if !ok { + return fmt.Errorf("forward #%d not found", id) + } + + t.mu.Lock() + t.Enabled = false + t.mu.Unlock() + t.close() + m.notify("[forward] #%d disabled", id) + return nil +} + +func (m *ForwardManager) Broadcast(data []byte) { + if len(data) == 0 { + return + } + + m.mu.RLock() + items := make([]*ForwardTarget, 0, len(m.targets)) + for _, t := range m.targets { + items = append(items, t) + } + m.mu.RUnlock() + + for _, t := range items { + if !t.Enabled || !t.Connected || t.conn == nil { + continue + } + + n, err := t.conn.Write(data) + if err != nil { + t.stats.LastError = err.Error() + m.notify("[forward] #%d write error: %v", t.ID, err) + continue + } + + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } +} + +func (m *ForwardManager) List() []ForwardSnapshot { + m.mu.RLock() + items := make([]ForwardSnapshot, 0, len(m.targets)) + for _, t := range m.targets { + items = append(items, ForwardSnapshot{ + ID: t.ID, + Mode: t.Mode.String(), + Address: t.Address, + Enabled: t.Enabled, + Connected: t.Connected, + ReadBytes: atomic.LoadUint64(&t.stats.ReadBytes), + WriteByte: atomic.LoadUint64(&t.stats.WrittenBytes), + LastError: t.stats.LastError, + }) + } + m.mu.RUnlock() + + sort.Slice(items, func(i, j int) bool { + return items[i].ID < items[j].ID + }) + + return items +} + +func (m *ForwardManager) Close() { + m.mu.Lock() + items := make([]*ForwardTarget, 0, len(m.targets)) + for _, t := range m.targets { + items = append(items, t) + } + m.targets = map[int]*ForwardTarget{} + m.mu.Unlock() + + for _, t := range items { + t.close() + } +} + +func (t *ForwardTarget) close() { + t.mu.Lock() + if t.closed { + t.mu.Unlock() + return + } + t.closed = true + ch := t.closeCh + conn := t.conn + t.conn = nil + t.Connected = false + t.mu.Unlock() + + if ch != nil { + close(ch) + } + if conn != nil { + _ = conn.Close() + } +} diff --git a/forwarding_test.go b/forwarding_test.go new file mode 100644 index 0000000..0ae2059 --- /dev/null +++ b/forwarding_test.go @@ -0,0 +1,261 @@ +package main + +import ( + "net" + "testing" + "time" +) + +func TestForwardManagerTCPFlow(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + acceptCh := make(chan net.Conn, 1) + errCh := make(chan error, 1) + go func() { + conn, e := listener.Accept() + if e != nil { + errCh <- e + return + } + acceptCh <- conn + }() + + serialCh := make(chan string, 2) + mgr := NewForwardManager(func(b []byte) error { + serialCh <- string(b) + return nil + }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCPC, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + var serverConn net.Conn + select { + case serverConn = <-acceptCh: + case e := <-errCh: + t.Fatalf("accept failed: %v", e) + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for accepted connection") + } + defer serverConn.Close() + + items := mgr.List() + if len(items) != 1 || items[0].ID != id || !items[0].Enabled { + t.Fatalf("unexpected list after add: %+v", items) + } + + if err = serverConn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("SetReadDeadline failed: %v", err) + } + mgr.Broadcast([]byte("from-app")) + buf := make([]byte, 64) + n, err := serverConn.Read(buf) + if err != nil { + t.Fatalf("server read from broadcast failed: %v", err) + } + if string(buf[:n]) != "from-app" { + t.Fatalf("broadcast payload mismatch got=%q", string(buf[:n])) + } + + if _, err = serverConn.Write([]byte("from-remote")); err != nil { + t.Fatalf("server write failed: %v", err) + } + select { + case got := <-serialCh: + if got != "from-remote" { + t.Fatalf("writeToSerial payload mismatch got=%q", got) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for writeToSerial callback") + } + + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + items = mgr.List() + if len(items) != 1 || items[0].Enabled { + t.Fatalf("Disable() did not update state: %+v", items) + } + + if err = mgr.Remove(id); err != nil { + t.Fatalf("Remove() failed: %v", err) + } + if got := mgr.List(); len(got) != 0 { + t.Fatalf("expected empty list after remove, got=%+v", got) + } +} + +func TestForwardManagerErrorCases(t *testing.T) { + mgr := NewForwardManager(func([]byte) error { return nil }, func(string, ...any) {}) + defer mgr.Close() + + if _, err := mgr.Add(NOT, "127.0.0.1:1"); err == nil { + t.Fatalf("Add(NOT) expected error") + } + + if err := mgr.Remove(999); err == nil { + t.Fatalf("Remove(non-existing) expected error") + } + + if err := mgr.Disable(999); err == nil { + t.Fatalf("Disable(non-existing) expected error") + } + + if err := mgr.Enable(999); err == nil { + t.Fatalf("Enable(non-existing) expected error") + } + + if err := mgr.Update(999, TCPC, "127.0.0.1:1"); err == nil { + t.Fatalf("Update(non-existing) expected error") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + id, err := mgr.Add(TCPC, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + if err = mgr.Update(id, NOT, "127.0.0.1:1"); err == nil { + t.Fatalf("Update(NOT) expected error") + } +} + +func TestForwardManagerSetInboundReporter(t *testing.T) { + reported := make(chan []byte, 1) + mgr := NewForwardManager(func([]byte) error { return nil }, func(string, ...any) {}) + defer mgr.Close() + mgr.SetInboundReporter(func(id int, chunk []byte) { + reported <- chunk + }) + if mgr.onInbound == nil { + t.Fatalf("SetInboundReporter should set onInbound") + } +} + +func TestForwardManagerBroadcastToDisabled(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + writeCh := make(chan []byte, 4) + mgr := NewForwardManager(func([]byte) error { + writeCh <- nil + return nil + }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCPC, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + // Disable and verify broadcast skips it + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + + mgr.Broadcast([]byte("should-not-arrive")) + + // No writeToSerial should be triggered + select { + case <-writeCh: + t.Fatalf("broadcast should not write to serial when disabled") + default: + } + + // Empty data should be no-op + mgr.Broadcast(nil) + mgr.Broadcast([]byte{}) +} + +func TestForwardManagerEnable(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + writeCh := make(chan []byte, 2) + mgr := NewForwardManager(func([]byte) error { + writeCh <- nil + return nil + }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCPC, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + + // Re-enable should create a new connection + if err = mgr.Enable(id); err != nil { + t.Fatalf("Enable() failed: %v", err) + } + + items := mgr.List() + if len(items) != 1 || !items[0].Enabled { + t.Fatalf("expected enabled after Enable(), got=%+v", items) + } + + // Enable again (should be no-op since already enabled and connected) + if err = mgr.Enable(id); err != nil { + t.Fatalf("second Enable() should succeed: %v", err) + } +} + +func TestForwardManagerUpdate(t *testing.T) { + l1, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen 1 failed: %v", err) + } + defer l1.Close() + + l2, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen 2 failed: %v", err) + } + defer l2.Close() + + mgr := NewForwardManager(func([]byte) error { return nil }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCPC, l1.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + // Update to new address (reconnects) + if err = mgr.Update(id, TCPC, l2.Addr().String()); err != nil { + t.Fatalf("Update() failed: %v", err) + } + + items := mgr.List() + if len(items) != 1 || items[0].Address != l2.Addr().String() { + t.Fatalf("update should change address, got=%+v", items) + } + + // Update disabled target + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + if err = mgr.Update(id, TCPC, l1.Addr().String()); err != nil { + t.Fatalf("Update() on disabled should succeed: %v", err) + } +} diff --git a/internal/event/event.go b/internal/event/event.go new file mode 100644 index 0000000..47f82ae --- /dev/null +++ b/internal/event/event.go @@ -0,0 +1,30 @@ +// Package event defines UI event types shared between app, console, and tui packages. +package event + +// UIEventKind classifies a UI event. +type UIEventKind int + +const ( + UIEventOutput UIEventKind = iota + UIEventStatus + UIEventModal + UIEventPanel +) + +// UIPanelKind identifies a modal panel type. +type UIPanelKind int + +const ( + UIPanelNone UIPanelKind = iota + UIPanelForward + UIPanelPlugin + UIPanelMode +) + +// UIEvent is emitted by the app core and consumed by TUI or console frontends. +type UIEvent struct { + Kind UIEventKind + Title string + Text string + Panel UIPanelKind +} diff --git a/main_other.go b/main_other.go new file mode 100644 index 0000000..2660e3b --- /dev/null +++ b/main_other.go @@ -0,0 +1,5 @@ +//go:build !windows + +package main + +func enableVTInput(fd int) {} diff --git a/main_windows.go b/main_windows.go new file mode 100644 index 0000000..f9089a1 --- /dev/null +++ b/main_windows.go @@ -0,0 +1,14 @@ +//go:build windows + +package main + +import ( + "golang.org/x/sys/windows" +) + +func enableVTInput(fd int) { + var mode uint32 + if err := windows.GetConsoleMode(windows.Handle(fd), &mode); err == nil { + _ = windows.SetConsoleMode(windows.Handle(fd), mode|windows.ENABLE_VIRTUAL_TERMINAL_INPUT) + } +} diff --git a/mutual.go b/mutual.go index 05686f1..0d7467d 100644 --- a/mutual.go +++ b/mutual.go @@ -1,15 +1,10 @@ package main import ( - "bytes" - "fmt" "github.com/trzsz/trzsz-go/trzsz" - "github.com/zimolab/charsetconv" "go.bug.st/serial" "io" "os" - "strings" - "time" ) var ( @@ -21,31 +16,3 @@ var ( stdinPipe *io.PipeWriter clientOut *io.PipeWriter ) - -func convertChunk(chunk []byte, srcCode, dstCode string) ([]byte, error) { - if len(chunk) == 0 { - return nil, nil - } - - if strings.EqualFold(srcCode, dstCode) { - dup := make([]byte, len(chunk)) - copy(dup, chunk) - return dup, nil - } - - var buf bytes.Buffer - err := charsetconv.ConvertWith(bytes.NewReader(chunk), charsetconv.Charset(srcCode), &buf, charsetconv.Charset(dstCode), false) - if err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - -func formatHexFrame(frame []byte, withTimestamp bool, tsFmt string) string { - if withTimestamp { - return fmt.Sprintf("%v % X %q \n", time.Now().Format(tsFmt), frame, frame) - } - - return fmt.Sprintf("% X %q \n", frame, frame) -} diff --git a/pkg/charset/charset.go b/pkg/charset/charset.go new file mode 100644 index 0000000..7ab0fc6 --- /dev/null +++ b/pkg/charset/charset.go @@ -0,0 +1,43 @@ +// Package charset provides character-set conversion and hex formatting utilities. +package charset + +import ( + "bytes" + "fmt" + "strings" + "time" + + "github.com/zimolab/charsetconv" +) + +// ConvertChunk converts a byte chunk from srcCode charset to dstCode charset. +// Returns nil, nil when input is empty. Returns a copied slice when charsets match. +func ConvertChunk(chunk []byte, srcCode, dstCode string) ([]byte, error) { + if len(chunk) == 0 { + return nil, nil + } + + if strings.EqualFold(srcCode, dstCode) { + dup := make([]byte, len(chunk)) + copy(dup, chunk) + return dup, nil + } + + var buf bytes.Buffer + err := charsetconv.ConvertWith(bytes.NewReader(chunk), charsetconv.Charset(srcCode), &buf, charsetconv.Charset(dstCode), false) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// FormatHexFrame formats a byte frame as hex + printable representation. +// Optionally prefixes with a timestamp using the given format string. +func FormatHexFrame(frame []byte, withTimestamp bool, tsFmt string) string { + if withTimestamp { + return fmt.Sprintf("%v % X %q \n", time.Now().Format(tsFmt), frame, frame) + } + + return fmt.Sprintf("% X %q \n", frame, frame) +} diff --git a/pkg/charset/charset_test.go b/pkg/charset/charset_test.go new file mode 100644 index 0000000..ef6ba4e --- /dev/null +++ b/pkg/charset/charset_test.go @@ -0,0 +1,93 @@ +package charset + +import ( + "bytes" + "strings" + "testing" +) + +func TestConvertChunk(t *testing.T) { + t.Run("empty", func(t *testing.T) { + out, err := ConvertChunk(nil, "UTF-8", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk(nil) unexpected error: %v", err) + } + if out != nil { + t.Fatalf("ConvertChunk(nil) expected nil output") + } + }) + + t.Run("same-charset-copy", func(t *testing.T) { + in := []byte("hello") + out, err := ConvertChunk(in, "UTF-8", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk same charset unexpected error: %v", err) + } + if !bytes.Equal(out, in) { + t.Fatalf("ConvertChunk same charset mismatch got=%q want=%q", out, in) + } + + out[0] = 'H' + if in[0] != 'h' { + t.Fatalf("ConvertChunk should return a copied slice") + } + }) +} + +func TestFormatHexFrame(t *testing.T) { + frame := []byte("AB") + out := FormatHexFrame(frame, false, "") + if !strings.Contains(out, "41 42") { + t.Fatalf("FormatHexFrame missing hex bytes: %q", out) + } + if !strings.Contains(out, "\"AB\"") { + t.Fatalf("FormatHexFrame missing quoted bytes: %q", out) + } + + outTS := FormatHexFrame([]byte("A"), true, "2006") + if !strings.Contains(outTS, "41") || !strings.Contains(outTS, "\"A\"") { + t.Fatalf("FormatHexFrame(withTimestamp) malformed output: %q", outTS) + } +} + +func TestConvertChunkCharsetConversion(t *testing.T) { + t.Run("gbk-to-utf8", func(t *testing.T) { + // Chinese "你好" in GBK: 0xC4 0xE3 0xBA 0xC3 + gbkHello := []byte{0xC4, 0xE3, 0xBA, 0xC3} + out, err := ConvertChunk(gbkHello, "GBK", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk GBK->UTF-8 unexpected error: %v", err) + } + if string(out) != "你好" { + t.Fatalf("ConvertChunk GBK->UTF-8 got=%q want=%q", string(out), "你好") + } + }) + + t.Run("same-charset-different-case", func(t *testing.T) { + in := []byte("hello") + out, err := ConvertChunk(in, "utf-8", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk case-diff unexpected error: %v", err) + } + if !bytes.Equal(out, in) { + t.Fatalf("ConvertChunk case-diff mismatch got=%q want=%q", out, in) + } + }) + + t.Run("invalid-charset", func(t *testing.T) { + _, err := ConvertChunk([]byte("hello"), "INVALID-CHARSET-NAME", "UTF-8") + if err == nil { + t.Fatalf("ConvertChunk invalid charset should error") + } + }) + + t.Run("empty-input", func(t *testing.T) { + out, err := ConvertChunk([]byte{}, "GBK", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk empty unexpected error: %v", err) + } + if out != nil { + t.Fatalf("ConvertChunk empty input should return nil") + } + }) +} diff --git a/plugin.go b/plugin.go new file mode 100644 index 0000000..5d1bf66 --- /dev/null +++ b/plugin.go @@ -0,0 +1,262 @@ +package main + +import ( + "fmt" + "path/filepath" + "sort" + "strings" + "sync" + + lua "github.com/yuin/gopher-lua" +) + +type LuaPlugin struct { + Name string + Path string + Enabled bool + L *lua.LState + callMu sync.Mutex +} + +type PluginSnapshot struct { + Name string + Path string + Enabled bool +} + +type PluginManager struct { + mu sync.RWMutex + plugins map[string]*LuaPlugin +} + +func NewPluginManager() *PluginManager { + return &PluginManager{plugins: make(map[string]*LuaPlugin)} +} + +func (m *PluginManager) Load(path string) (string, error) { + abs, err := filepath.Abs(path) + if err != nil { + return "", err + } + + name := strings.TrimSuffix(filepath.Base(abs), filepath.Ext(abs)) + if name == "" { + return "", fmt.Errorf("invalid plugin name") + } + + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.plugins[name]; ok { + return "", fmt.Errorf("plugin %s already loaded", name) + } + + state := lua.NewState() + if err = state.DoFile(abs); err != nil { + state.Close() + return "", err + } + + m.plugins[name] = &LuaPlugin{ + Name: name, + Path: abs, + Enabled: true, + L: state, + } + + return name, nil +} + +func (m *PluginManager) Unload(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + p, ok := m.plugins[name] + if !ok { + return fmt.Errorf("plugin %s not found", name) + } + + p.L.Close() + delete(m.plugins, name) + return nil +} + +func (m *PluginManager) Enable(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + p, ok := m.plugins[name] + if !ok { + return fmt.Errorf("plugin %s not found", name) + } + p.Enabled = true + return nil +} + +func (m *PluginManager) Disable(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + p, ok := m.plugins[name] + if !ok { + return fmt.Errorf("plugin %s not found", name) + } + p.Enabled = false + return nil +} + +func (m *PluginManager) Reload(name string) error { + m.mu.Lock() + p, ok := m.plugins[name] + m.mu.Unlock() + if !ok { + return fmt.Errorf("plugin %s not found", name) + } + + path := p.Path + if err := m.Unload(name); err != nil { + return err + } + _, err := m.Load(path) + return err +} + +func (m *PluginManager) List() []PluginSnapshot { + m.mu.RLock() + res := make([]PluginSnapshot, 0, len(m.plugins)) + for _, p := range m.plugins { + res = append(res, PluginSnapshot{Name: p.Name, Path: p.Path, Enabled: p.Enabled}) + } + m.mu.RUnlock() + + sort.Slice(res, func(i, j int) bool { + return res[i].Name < res[j].Name + }) + return res +} + +func (m *PluginManager) ProcessInput(data []byte) ([]byte, error) { + return m.processDataHook("OnInput", data) +} + +func (m *PluginManager) ProcessOutput(data []byte) ([]byte, error) { + return m.processDataHook("OnOutput", data) +} + +func (m *PluginManager) processDataHook(name string, data []byte) ([]byte, error) { + m.mu.RLock() + plugins := make([]*LuaPlugin, 0, len(m.plugins)) + for _, p := range m.plugins { + plugins = append(plugins, p) + } + m.mu.RUnlock() + + current := data + for _, p := range plugins { + if !p.Enabled { + continue + } + p.callMu.Lock() + ret, called, err := callStringHook(p.L, name, string(current)) + p.callMu.Unlock() + if err != nil { + return nil, fmt.Errorf("plugin %s %s: %w", p.Name, name, err) + } + if !called { + continue + } + if ret == nil { + return nil, nil + } + current = []byte(*ret) + } + + return current, nil +} + +func (m *PluginManager) ProcessCommand(line string) (string, bool, error) { + m.mu.RLock() + plugins := make([]*LuaPlugin, 0, len(m.plugins)) + for _, p := range m.plugins { + plugins = append(plugins, p) + } + m.mu.RUnlock() + + current := line + allow := true + for _, p := range plugins { + if !p.Enabled { + continue + } + p.callMu.Lock() + next, nextAllow, called, err := callCommandHook(p.L, "OnCommand", current) + p.callMu.Unlock() + if err != nil { + return "", false, fmt.Errorf("plugin %s OnCommand: %w", p.Name, err) + } + if !called { + continue + } + allow = allow && nextAllow + if !allow { + return "", false, nil + } + if next != "" { + current = next + } + } + + return current, true, nil +} + +func (m *PluginManager) Close() { + m.mu.Lock() + defer m.mu.Unlock() + for _, p := range m.plugins { + p.L.Close() + } + m.plugins = map[string]*LuaPlugin{} +} + +func callStringHook(L *lua.LState, name string, payload string) (*string, bool, error) { + fn := L.GetGlobal(name) + if fn.Type() == lua.LTNil { + return nil, false, nil + } + + if err := L.CallByParam(lua.P{Fn: fn, NRet: 1, Protect: true}, lua.LString(payload)); err != nil { + return nil, true, err + } + + ret := L.Get(-1) + L.Pop(1) + if ret.Type() == lua.LTNil { + return nil, true, nil + } + + s := ret.String() + return &s, true, nil +} + +func callCommandHook(L *lua.LState, name, line string) (string, bool, bool, error) { + fn := L.GetGlobal(name) + if fn.Type() == lua.LTNil { + return "", true, false, nil + } + + if err := L.CallByParam(lua.P{Fn: fn, NRet: 2, Protect: true}, lua.LString(line)); err != nil { + return "", true, true, err + } + + allowVal := L.Get(-1) + lineVal := L.Get(-2) + L.Pop(2) + + allow := true + if allowVal.Type() == lua.LTBool { + allow = lua.LVAsBool(allowVal) + } + + next := "" + if lineVal.Type() != lua.LTNil { + next = lineVal.String() + } + + return next, allow, true, nil +} diff --git a/plugin_test.go b/plugin_test.go new file mode 100644 index 0000000..887f265 --- /dev/null +++ b/plugin_test.go @@ -0,0 +1,241 @@ +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func writeLuaScript(t *testing.T, name, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), name) + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write lua script failed: %v", err) + } + return path +} + +func TestPluginManagerLoadAndHooks(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "rewrite.lua", ` +function OnInput(s) + return s .. "-in" +end + +function OnOutput(s) + return s .. "-out" +end + +function OnCommand(line) + return line .. " --lua", true +end +`) + + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if name != "rewrite" { + t.Fatalf("unexpected plugin name: %q", name) + } + + in, err := m.ProcessInput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessInput() failed: %v", err) + } + if string(in) != "abc-in" { + t.Fatalf("ProcessInput() got=%q want=%q", in, "abc-in") + } + + out, err := m.ProcessOutput([]byte("xyz")) + if err != nil { + t.Fatalf("ProcessOutput() failed: %v", err) + } + if string(out) != "xyz-out" { + t.Fatalf("ProcessOutput() got=%q want=%q", out, "xyz-out") + } + + line, allow, err := m.ProcessCommand(".help") + if err != nil { + t.Fatalf("ProcessCommand() failed: %v", err) + } + if !allow || line != ".help --lua" { + t.Fatalf("ProcessCommand() got=(%q,%v) want=(%q,true)", line, allow, ".help --lua") + } +} + +func TestPluginManagerDisableAndUnload(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "simple.lua", ` +function OnInput(s) + return s .. "-x" +end +`) + + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if err = m.Disable(name); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + got, err := m.ProcessInput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessInput() with disabled plugin failed: %v", err) + } + if string(got) != "abc" { + t.Fatalf("disabled plugin should not modify input, got=%q", got) + } + + if err = m.Enable(name); err != nil { + t.Fatalf("Enable() failed: %v", err) + } + got, err = m.ProcessInput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessInput() after enable failed: %v", err) + } + if string(got) != "abc-x" { + t.Fatalf("enabled plugin should modify input, got=%q", got) + } + + if err = m.Unload(name); err != nil { + t.Fatalf("Unload() failed: %v", err) + } + if len(m.List()) != 0 { + t.Fatalf("Unload() should remove plugin from list") + } +} + +func TestPluginManagerOutputDrop(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "drop.lua", ` +function OnOutput(s) + return nil +end +`) + + if _, err := m.Load(path); err != nil { + t.Fatalf("Load() failed: %v", err) + } + + out, err := m.ProcessOutput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessOutput() failed: %v", err) + } + if out != nil { + t.Fatalf("expected nil output when plugin returns nil") + } +} + +func TestPluginManagerReload(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "reloadable.lua", ` +function OnInput(s) + return s .. "-v1" +end +`) + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if err = m.Reload(name); err != nil { + t.Fatalf("Reload() failed: %v", err) + } + + out, err := m.ProcessInput([]byte("test")) + if err != nil { + t.Fatalf("ProcessInput() after reload failed: %v", err) + } + if string(out) != "test-v1" { + t.Fatalf("reloaded plugin should still work, got=%q", out) + } + + if err = m.Reload("nonexistent"); err == nil { + t.Fatalf("Reload() non-existent should error") + } +} + +func TestPluginManagerCommandBlock(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "blocker.lua", ` +function OnCommand(line) + return line, false +end +`) + + if _, err := m.Load(path); err != nil { + t.Fatalf("Load() failed: %v", err) + } + + line, allow, err := m.ProcessCommand(".exit") + if err != nil { + t.Fatalf("ProcessCommand() failed: %v", err) + } + if allow { + t.Fatalf("command should be blocked, got allow=%v line=%q", allow, line) + } +} + +func TestPluginManagerLoadErrors(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + _, err := m.Load("nonexistent_file.lua") + if err == nil { + t.Fatalf("Load() non-existent file should error") + } + + path := writeLuaScript(t, "bad.lua", "this is not valid lua {{{") + _, err = m.Load(path) + if err == nil { + t.Fatalf("Load() invalid lua should error") + } +} + +func TestPluginManagerDuplicateLoad(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "once.lua", "function OnInput(s) return s end") + _, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + _, err = m.Load(path) + if err == nil { + t.Fatalf("Load() duplicate should error") + } +} + +func TestPluginManagerListWithDisabled(t *testing.T) { + m := NewPluginManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "mylist.lua", "function OnInput(s) return s end") + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if err = m.Disable(name); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + + items := m.List() + if len(items) != 1 || items[0].Enabled { + t.Fatalf("expected disabled in list, got %+v", items) + } +} diff --git a/plugins/demo.lua b/plugins/demo.lua new file mode 100644 index 0000000..6499689 --- /dev/null +++ b/plugins/demo.lua @@ -0,0 +1,14 @@ +-- Demo Lua plugin for the runtime plugin system. +-- It is shipped disabled by default and only runs after `.plugin load`. + +function OnInput(payload) + return payload +end + +function OnOutput(payload) + return payload +end + +function OnCommand(line) + return line, true +end diff --git a/tui_hotkeys.go b/tui_hotkeys.go new file mode 100644 index 0000000..14af8ca --- /dev/null +++ b/tui_hotkeys.go @@ -0,0 +1,153 @@ +package main + +import ( + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" +) + +func handleLocalHotkey(m *uiModel, key string) bool { + if m.isLocalHotkey(key, "h") { + modifier := strings.ToUpper(normalizeHotkeyPrefix(m.app.cfg.hotkeyMod)) + m.app.ShowModal("Shortcuts", modifier+"+C => local exit\nCtrl+C => remote interrupt\n"+modifier+"+F => forward panel\n"+modifier+"+P => plugin panel\n"+modifier+"+M => mode panel\nF1 => shortcut help") + return true + } + if m.isLocalHotkey(key, "f") { + m.app.OpenPanel(event.UIPanelForward) + return true + } + if m.isLocalHotkey(key, "p") { + m.app.OpenPanel(event.UIPanelPlugin) + return true + } + if m.isLocalHotkey(key, "m") { + m.app.OpenPanel(event.UIPanelMode) + return true + } + return false +} + +func (m *uiModel) isLocalHotkey(key, action string) bool { + parts := strings.Split(strings.ToLower(key), "+") + if len(parts) < 2 || parts[len(parts)-1] != action { + return false + } + + hasCtrl := false + hasAlt := false + hasShift := false + for _, p := range parts[:len(parts)-1] { + switch p { + case "ctrl": + hasCtrl = true + case "alt": + hasAlt = true + case "shift": + hasShift = true + } + } + + mod := normalizeHotkeyPrefix(m.app.cfg.hotkeyMod) + if mod == "ctrl+shift" { + return hasCtrl && hasShift + } + return hasCtrl && hasAlt +} + +func normalizeHotkeyPrefix(mod string) string { + mod = strings.ToLower(strings.TrimSpace(mod)) + if mod != "ctrl+alt" && mod != "ctrl+shift" { + mod = "ctrl+alt" + } + return mod +} + +func hotkeyWith(mod, action string) string { + return normalizeHotkeyPrefix(mod) + "+" + action +} + +func parseCtrlKey(key string) (byte, bool) { + if !strings.HasPrefix(key, "ctrl+") || strings.HasPrefix(key, "ctrl+shift+") { + return 0, false + } + + parts := strings.Split(key, "+") + if len(parts) != 2 || len(parts[1]) != 1 { + return 0, false + } + ch := parts[1][0] + if ch < 'a' || ch > 'z' { + return 0, false + } + return ch, true +} + +func (m *uiModel) handleViewportKey(msg tea.KeyMsg) bool { + if !m.ready || m.showModal { + return false + } + + key := strings.ToLower(msg.String()) + switch key { + case "pgup", "ctrl+u", "alt+up", "up": + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + _ = cmd + m.followTail = false + return true + case "pgdown", "ctrl+d", "alt+down", "down": + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + _ = cmd + return true + case "home", "g": + m.viewport.GotoTop() + m.followTail = false + return true + case "end", "shift+g": + m.viewport.GotoBottom() + m.followTail = true + return true + default: + return false + } +} + +func (m *uiModel) resetCompletion() { + m.completionActive = false + m.completionBase = "" + m.completionCandidates = nil + m.completionIndex = 0 +} + +func (m *uiModel) stepCompletion(direction int) { + if len(m.completionCandidates) == 0 { + m.resetCompletion() + return + } + if direction >= 0 { + m.completionIndex = (m.completionIndex + 1) % len(m.completionCandidates) + } else { + m.completionIndex = (m.completionIndex - 1 + len(m.completionCandidates)) % len(m.completionCandidates) + } + m.applyCompletion() +} + +func (m *uiModel) applyCompletion() { + if len(m.completionCandidates) == 0 { + return + } + m.input.SetValue(m.completionBase + m.completionCandidates[m.completionIndex] + " ") +} + +func completionBase(line string) string { + if strings.HasSuffix(line, " ") { + return line + } + i := strings.LastIndex(line, " ") + if i < 0 { + return "" + } + return line[:i+1] +} diff --git a/tui_model.go b/tui_model.go new file mode 100644 index 0000000..7ef6a6e --- /dev/null +++ b/tui_model.go @@ -0,0 +1,268 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbles/textinput" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" +) + +type doneMsg struct{} + +type modeItem struct { + key string + label string + value string +} + +type panelLine struct { + text string + selected bool +} + +type uiModel struct { + app *App + + viewport viewport.Model + input textinput.Model + + ready bool + width int + height int + statusLine string + suggestions []string + content strings.Builder + followTail bool + + showModal bool + modalTitle string + modalBody string + + panelKind event.UIPanelKind + panelIndex int + + forwardItems []ForwardSnapshot + pluginItems []PluginSnapshot + modeItems []modeItem + + promptActive bool + promptTitle string + promptHint string + promptInput textinput.Model + promptSubmit func(string) + + completionActive bool + completionBase string + completionCandidates []string + completionIndex int +} + +func newUIModel(app *App) *uiModel { + in := textinput.New() + // bubbles v0.18.0 computes placeholder width using display cells, + // which can panic on CJK placeholders. Keep this ASCII-only. + in.Placeholder = "Type to send to remote, use .help for commands" + in.Focus() + in.CharLimit = 0 + in.Prompt = "> " + in.Width = 80 + + return &uiModel{app: app, input: in, followTail: true} +} + +func (m *uiModel) Init() tea.Cmd { + return tea.Batch(waitUIEvent(m.app.uiEvents), waitDone(m.app.waitDone()), textinput.Blink) +} + +func waitUIEvent(ch <-chan event.UIEvent) tea.Cmd { + return func() tea.Msg { + ev, ok := <-ch + if !ok { + return doneMsg{} + } + return ev + } +} + +func waitDone(ch <-chan struct{}) tea.Cmd { + return func() tea.Msg { + <-ch + return doneMsg{} + } +} + +func (m *uiModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case doneMsg: + return m, tea.Quit + + case event.UIEvent: + switch msg.Kind { + case event.UIEventOutput, event.UIEventStatus: + if msg.Kind == event.UIEventOutput { + m.appendOutput(msg.Text) + } else { + m.statusLine = msg.Text + } + case event.UIEventModal: + m.showModal = true + m.panelKind = event.UIPanelNone + m.modalTitle = msg.Title + m.modalBody = msg.Text + m.promptActive = false + case event.UIEventPanel: + m.openPanel(msg.Panel) + } + return m, waitUIEvent(m.app.uiEvents) + + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + inputHeight := 3 + statusHeight := 2 + viewportHeight := msg.Height - inputHeight - statusHeight + if viewportHeight < 3 { + viewportHeight = 3 + } + + if !m.ready { + m.viewport = viewport.New(msg.Width, viewportHeight) + m.viewport.YPosition = 0 + m.viewport.SetContent(m.content.String()) + m.ready = true + } else { + m.viewport.Width = msg.Width + m.viewport.Height = viewportHeight + } + + m.input.Width = msg.Width - 4 + m.viewport.GotoBottom() + m.followTail = true + return m, nil + + case tea.KeyMsg: + keyStr := strings.ToLower(msg.String()) + if m.handleViewportKey(msg) { + return m, nil + } + if keyStr != "tab" && keyStr != "shift+tab" { + m.resetCompletion() + } + + if m.showModal && m.handleModalKey(msg) { + return m, nil + } + + if m.isLocalHotkey(keyStr, "c") { + m.app.Statusf("[local] exiting by %s+C", strings.ToUpper(normalizeHotkeyPrefix(m.app.cfg.hotkeyMod))) + m.app.Close() + return m, tea.Quit + } + + if handleLocalHotkey(m, keyStr) { + return m, nil + } + + // Some terminals can't encode Ctrl+Alt/Shift+H distinctly and report Ctrl+H. + if keyStr == "ctrl+h" { + handleLocalHotkey(m, hotkeyWith(m.app.cfg.hotkeyMod, "h")) + return m, nil + } + + if letter, ok := parseCtrlKey(keyStr); ok { + if err := m.app.sendCtrl(letter); err != nil { + m.app.Notifyf("[remote] ctrl send failed: %v", err) + } + return m, nil + } + + switch keyStr { + case "f1": + handleLocalHotkey(m, hotkeyWith(m.app.cfg.hotkeyMod, "h")) + return m, nil + + case "tab", "shift+tab": + direction := 1 + if keyStr == "shift+tab" { + direction = -1 + } + + if m.completionActive && len(m.completionCandidates) > 0 { + m.stepCompletion(direction) + return m, nil + } + + line, cands := m.app.dispatcher.Complete(m.input.Value()) + m.suggestions = cands + if len(cands) == 0 { + return m, nil + } + if len(cands) == 1 { + m.input.SetValue(line) + return m, nil + } + + m.completionActive = true + m.completionBase = completionBase(m.input.Value()) + m.completionCandidates = append([]string(nil), cands...) + if direction < 0 { + m.completionIndex = len(cands) - 1 + } else { + m.completionIndex = 0 + } + m.applyCompletion() + return m, nil + + case "enter": + line := m.input.Value() + m.input.SetValue("") + m.suggestions = nil + m.followTail = true + m.app.handleLine(line) + return m, nil + } + } + + var cmd tea.Cmd + m.input, cmd = m.input.Update(msg) + return m, cmd +} + +func (m *uiModel) View() string { + if !m.ready { + return "Initializing..." + } + + suggest := "Tab: no candidates" + if len(m.suggestions) > 1 { + suggest = "Tab candidates: " + strings.Join(m.suggestions, " ") + } else if len(m.suggestions) == 1 { + suggest = "Tab: " + m.suggestions[0] + } + modifier := strings.ToUpper(normalizeHotkeyPrefix(m.app.cfg.hotkeyMod)) + hotkeys := "Hotkeys: Ctrl+C remote | " + modifier + "+C local | " + modifier + "+F forward | " + modifier + "+P plugins | " + modifier + "+M mode | F1 help" + hotkeys = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("245")).Render(hotkeys) + status := m.statusLine + if status == "" { + status = "Ready" + } + status = lipgloss.NewStyle().Foreground(lipgloss.Color("250")).Faint(true).Render(status) + base := fmt.Sprintf("%s\n%s\n%s\n%s\n%s", m.viewport.View(), suggest, status, m.input.View(), hotkeys) + if !m.showModal { + return fillScreen(m.width, m.height, base) + } + + if m.promptActive { + return renderCenteredModalContent(m.width, m.height, m.renderPrompt()) + } + + if m.panelKind != event.UIPanelNone { + return renderCenteredModalContent(m.width, m.height, m.renderPanel()) + } + + return renderCenteredModal(m.width, m.height, m.modalTitle, m.modalBody) +} diff --git a/tui_panels.go b/tui_panels.go new file mode 100644 index 0000000..f42aad7 --- /dev/null +++ b/tui_panels.go @@ -0,0 +1,322 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" +) + +func (m *uiModel) handleModalKey(msg tea.KeyMsg) bool { + keyStr := strings.ToLower(msg.String()) + + if m.promptActive { + return m.handlePromptKey(msg) + } + if keyStr == "esc" { + m.closeModal() + return true + } + if m.panelKind == event.UIPanelNone { + if keyStr == "enter" { + m.closeModal() + } + return true + } + + switch m.panelKind { + case event.UIPanelForward: + return m.handleForwardPanelKey(keyStr) + case event.UIPanelPlugin: + return m.handlePluginPanelKey(keyStr) + case event.UIPanelMode: + return m.handleModePanelKey(keyStr) + default: + return true + } +} + +func (m *uiModel) closeModal() { + m.showModal = false + m.panelKind = event.UIPanelNone + m.modalTitle = "" + m.modalBody = "" + m.promptActive = false + m.promptSubmit = nil +} + +func (m *uiModel) openPanel(kind event.UIPanelKind) { + m.showModal = true + m.panelKind = kind + m.panelIndex = 0 + m.promptActive = false + m.promptSubmit = nil + m.refreshPanel() +} + +func (m *uiModel) refreshPanel() { + switch m.panelKind { + case event.UIPanelForward: + m.forwardItems = m.app.forward.List() + m.panelIndex = clampIndex(m.panelIndex, len(m.forwardItems)) + case event.UIPanelPlugin: + m.pluginItems = m.app.plugins.List() + m.panelIndex = clampIndex(m.panelIndex, len(m.pluginItems)) + case event.UIPanelMode: + m.modeItems = m.buildModeItems() + m.panelIndex = clampIndex(m.panelIndex, len(m.modeItems)) + } +} + +func (m *uiModel) buildModeItems() []modeItem { + return []modeItem{{"in", "Input Charset", m.app.cfg.inputCode}, {"out", "Output Charset", m.app.cfg.outputCode}, {"end", "Line End", fmt.Sprintf("%q", m.app.cfg.endStr)}, {"frame", "Hex Frame Size", fmt.Sprintf("%d", m.app.cfg.frameSize)}, {"timestamp", "Timestamp", fmt.Sprintf("%v", m.app.cfg.timesTamp)}, {"timefmt", "Timestamp Format", m.app.cfg.timesFmt}} +} + +func (m *uiModel) handleForwardPanelKey(key string) bool { + switch key { + case "up", "k": + if m.panelIndex > 0 { + m.panelIndex-- + } + return true + case "down", "j": + if m.panelIndex < len(m.forwardItems)-1 { + m.panelIndex++ + } + return true + case "r": + m.refreshPanel() + return true + case "a": + m.startPrompt("Add Forward", "tcp 127.0.0.1:12345", "", func(v string) { + parts := strings.Fields(v) + if len(parts) < 2 { + m.app.Statusf("[forward] usage:
") + return + } + m.app.handleLine(fmt.Sprintf(".forward add %s %s", parts[0], parts[1])) + m.refreshPanel() + }) + return true + } + if len(m.forwardItems) == 0 { + return true + } + + sel := m.forwardItems[m.panelIndex] + switch key { + case "enter": + if sel.Enabled { + m.app.handleLine(fmt.Sprintf(".forward disable %d", sel.ID)) + } else { + m.app.handleLine(fmt.Sprintf(".forward enable %d", sel.ID)) + } + m.refreshPanel() + return true + case "d", "delete", "backspace": + m.app.handleLine(fmt.Sprintf(".forward remove %d", sel.ID)) + m.refreshPanel() + return true + case "u": + m.startPrompt("Update Forward #"+fmt.Sprint(sel.ID), "tcp 127.0.0.1:12345", fmt.Sprintf("%s %s", sel.Mode, sel.Address), func(v string) { + parts := strings.Fields(v) + if len(parts) < 2 { + m.app.Statusf("[forward] usage:
") + return + } + m.app.handleLine(fmt.Sprintf(".forward update %d %s %s", sel.ID, parts[0], parts[1])) + m.refreshPanel() + }) + return true + default: + return true + } +} + +func (m *uiModel) handlePluginPanelKey(key string) bool { + switch key { + case "up", "k": + if m.panelIndex > 0 { + m.panelIndex-- + } + return true + case "down", "j": + if m.panelIndex < len(m.pluginItems)-1 { + m.panelIndex++ + } + return true + case "r": + m.refreshPanel() + return true + case "l": + m.startPrompt("Load Plugin", "./plugins/demo.lua", "", func(v string) { + path := strings.TrimSpace(v) + if path == "" { + m.app.Statusf("[plugin] load path is empty") + return + } + m.app.handleLine(fmt.Sprintf(".plugin load %s", path)) + m.refreshPanel() + }) + return true + } + if len(m.pluginItems) == 0 { + return true + } + + sel := m.pluginItems[m.panelIndex] + switch key { + case "enter": + if sel.Enabled { + m.app.handleLine(fmt.Sprintf(".plugin disable %s", sel.Name)) + } else { + m.app.handleLine(fmt.Sprintf(".plugin enable %s", sel.Name)) + } + m.refreshPanel() + return true + case "u": + m.app.handleLine(fmt.Sprintf(".plugin reload %s", sel.Name)) + m.refreshPanel() + return true + case "d", "delete", "backspace": + m.app.handleLine(fmt.Sprintf(".plugin unload %s", sel.Name)) + m.refreshPanel() + return true + default: + return true + } +} + +func (m *uiModel) handleModePanelKey(key string) bool { + switch key { + case "up", "k": + if m.panelIndex > 0 { + m.panelIndex-- + } + return true + case "down", "j": + if m.panelIndex < len(m.modeItems)-1 { + m.panelIndex++ + } + return true + case "r": + m.refreshPanel() + return true + } + if len(m.modeItems) == 0 { + return true + } + + sel := m.modeItems[m.panelIndex] + switch key { + case " ": + if sel.key == "timestamp" { + if m.app.cfg.timesTamp { + m.app.handleLine(".mode set timestamp off") + } else { + m.app.handleLine(".mode set timestamp on") + } + m.refreshPanel() + } + return true + case "enter", "e": + initial := strings.Trim(sel.value, "\"") + m.startPrompt("Edit Mode: "+sel.label, "new value", initial, func(v string) { + m.app.handleLine(fmt.Sprintf(".mode set %s %s", sel.key, v)) + m.refreshPanel() + }) + return true + default: + return true + } +} + +func (m *uiModel) startPrompt(title, hint, initial string, submit func(string)) { + in := textinput.New() + in.Prompt = "> " + in.Placeholder = hint + in.SetValue(initial) + in.Focus() + in.CharLimit = 0 + in.Width = 64 + + m.promptActive = true + m.promptTitle = title + m.promptHint = hint + m.promptInput = in + m.promptSubmit = submit +} + +func (m *uiModel) handlePromptKey(msg tea.KeyMsg) bool { + key := strings.ToLower(msg.String()) + switch key { + case "esc": + m.promptActive = false + m.promptSubmit = nil + return true + case "enter": + value := strings.TrimSpace(m.promptInput.Value()) + submit := m.promptSubmit + m.promptActive = false + m.promptSubmit = nil + if submit != nil { + submit(value) + } + return true + default: + var cmd tea.Cmd + m.promptInput, cmd = m.promptInput.Update(msg) + _ = cmd + return true + } +} + +func (m *uiModel) renderPanel() string { + switch m.panelKind { + case event.UIPanelForward: + return m.renderForwardPanel() + case event.UIPanelPlugin: + return m.renderPluginPanel() + case event.UIPanelMode: + return m.renderModePanel() + default: + return renderModal("Info", "No panel", m.availableModalWidth()) + } +} + +func (m *uiModel) renderForwardPanel() string { + lines := make([]panelLine, 0, len(m.forwardItems)+2) + if len(m.forwardItems) == 0 { + lines = append(lines, panelLine{text: "No forwarding targets. Press 'a' to add one."}) + } else { + lines = append(lines, panelLine{text: "ID Mode Enabled Connected Address InBytes OutBytes"}) + for i, it := range m.forwardItems { + lines = append(lines, panelLine{text: fmt.Sprintf("%-3d %-5s %-7v %-9v %-22s %-7d %-8d", it.ID, it.Mode, it.Enabled, it.Connected, it.Address, it.ReadBytes, it.WriteByte), selected: i == m.panelIndex}) + } + } + return renderPanelModal("Forward Panel", lines, "Up/Down select | Enter toggle enable | a add | u update | d remove | r refresh | Esc close", m.availableModalWidth()) +} + +func (m *uiModel) renderPluginPanel() string { + lines := make([]panelLine, 0, len(m.pluginItems)+2) + if len(m.pluginItems) == 0 { + lines = append(lines, panelLine{text: "No plugins loaded. Press 'l' to load one."}) + } else { + lines = append(lines, panelLine{text: "Name Enabled Path"}) + for i, it := range m.pluginItems { + lines = append(lines, panelLine{text: fmt.Sprintf("%-20s %-7v %s", it.Name, it.Enabled, it.Path), selected: i == m.panelIndex}) + } + } + return renderPanelModal("Plugin Panel", lines, "Up/Down select | Enter toggle enable | l load | u reload | d unload | r refresh | Esc close", m.availableModalWidth()) +} + +func (m *uiModel) renderModePanel() string { + lines := make([]panelLine, 0, len(m.modeItems)+2) + lines = append(lines, panelLine{text: "Field Value"}) + for i, it := range m.modeItems { + lines = append(lines, panelLine{text: fmt.Sprintf("%-16s %s", it.label, it.value), selected: i == m.panelIndex}) + } + return renderPanelModal("Mode Panel", lines, "Up/Down select | Enter edit value | Space toggle timestamp | r refresh | Esc close", m.availableModalWidth()) +} diff --git a/tui_test.go b/tui_test.go new file mode 100644 index 0000000..2f0a3ee --- /dev/null +++ b/tui_test.go @@ -0,0 +1,309 @@ +package main + +import ( + "strings" + "testing" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" +) + +func TestParseCtrlKey(t *testing.T) { + tests := []struct { + in string + want byte + ok bool + reason string + }{ + {in: "ctrl+c", want: 'c', ok: true, reason: "plain ctrl"}, + {in: "ctrl+shift+c", ok: false, reason: "ctrl+shift reserved for local"}, + {in: "ctrl+enter", ok: false, reason: "non-letter"}, + {in: "alt+c", ok: false, reason: "wrong modifier"}, + } + + for _, tt := range tests { + got, ok := parseCtrlKey(tt.in) + if ok != tt.ok || got != tt.want { + t.Fatalf("%s parseCtrlKey(%q) got=(%q,%v) want=(%q,%v)", tt.reason, tt.in, got, ok, tt.want, tt.ok) + } + } +} + +func TestRenderModal(t *testing.T) { + modal := renderModal("Title", "line1\nline2", 80) + if !strings.Contains(modal, "Title") { + t.Fatalf("renderModal missing title: %q", modal) + } + if !strings.Contains(modal, "line1") || !strings.Contains(modal, "line2") { + t.Fatalf("renderModal missing lines: %q", modal) + } + if !strings.Contains(modal, "╭") || !strings.Contains(modal, "╮") || !strings.Contains(modal, "╰") || !strings.Contains(modal, "╯") { + t.Fatalf("renderModal missing box borders: %q", modal) + } +} + +func TestHandleCtrlShiftLocalHelp(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 4), cfg: &Config{hotkeyMod: "ctrl+alt"}} + a.SetUIEnabled(true) + m := uiModel{app: a} + + ok := handleLocalHotkey(&m, "ctrl+alt+h") + if !ok { + t.Fatalf("expected local hotkey to be handled") + } + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventModal { + t.Fatalf("expected modal event, got %+v", ev) + } +} + +func TestNormalizeHotkeyPrefix(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", "ctrl+alt"}, + {"ctrl+alt", "ctrl+alt"}, + {"ctrl+shift", "ctrl+shift"}, + {"CTRL+ALT", "ctrl+alt"}, + {" ctrl+SHIFT ", "ctrl+shift"}, + {"invalid", "ctrl+alt"}, + } + + for _, tt := range tests { + got := normalizeHotkeyPrefix(tt.in) + if got != tt.want { + t.Fatalf("normalizeHotkeyPrefix(%q) got=%q want=%q", tt.in, got, tt.want) + } + } +} + +func TestHotkeyWith(t *testing.T) { + got := hotkeyWith("ctrl+alt", "h") + if got != "ctrl+alt+h" { + t.Fatalf("hotkeyWith ctrl+alt+h got=%q", got) + } + got = hotkeyWith("ctrl+shift", "c") + if got != "ctrl+shift+c" { + t.Fatalf("hotkeyWith ctrl+shift+c got=%q", got) + } +} + +func TestIsLocalHotkeyAll(t *testing.T) { + tests := []struct { + key, mod string + action string + want bool + }{ + {"ctrl+alt+c", "ctrl+alt", "c", true}, + {"ctrl+shift+c", "ctrl+shift", "c", true}, + {"ctrl+alt+c", "ctrl+shift", "c", false}, + {"ctrl+shift+c", "ctrl+alt", "c", false}, + {"alt+c", "ctrl+alt", "c", false}, + {"ctrl+c", "ctrl+alt", "c", false}, + } + + for _, tt := range tests { + a := &App{cfg: &Config{hotkeyMod: tt.mod}} + m := uiModel{app: a} + got := m.isLocalHotkey(tt.key, tt.action) + if got != tt.want { + t.Fatalf("isLocalHotkey(%q, %q) hotkeyMod=%q got=%v want=%v", tt.key, tt.action, tt.mod, got, tt.want) + } + } +} + +func TestParseCtrlKeyEdgeCases(t *testing.T) { + tests := []struct { + in string + want byte + ok bool + }{ + {in: "ctrl+z", want: 'z', ok: true}, + {in: "ctrl+a", want: 'a', ok: true}, + {in: "ctrl+shift+c", want: 0, ok: false}, + {in: "ctrl+alt+c", want: 0, ok: false}, + {in: "ctrl+", want: 0, ok: false}, + {in: "ctrl+ab", want: 0, ok: false}, + {in: "ctrl+A", want: 0, ok: false}, + {in: "ctrl+1", want: 0, ok: false}, + } + + for _, tt := range tests { + got, ok := parseCtrlKey(tt.in) + if ok != tt.ok || got != tt.want { + t.Fatalf("parseCtrlKey(%q) got=(%q,%v) want=(%q,%v)", tt.in, got, ok, tt.want, tt.ok) + } + } +} + +func TestRenderModalLongContent(t *testing.T) { + longBody := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10\nline11\nline12\nline13\nline14" + modal := renderModal("Title", longBody, 80) + if !strings.Contains(modal, "... (press Esc/Enter to close)") { + t.Fatalf("long modal should be truncated: %q", modal) + } + if strings.Contains(modal, "line14") { + t.Fatalf("line14 should not appear in truncated modal") + } +} + +func TestRenderModalEmpty(t *testing.T) { + modal := renderModal("", "", 80) + if !strings.Contains(modal, "Info") { + t.Fatalf("empty title should default to Info: %q", modal) + } +} + +func TestTruncateToWidth(t *testing.T) { + tests := []struct { + in string + width int + want string + }{ + {"hello", 3, "hel"}, + {"hello", 10, "hello"}, + {"", 5, ""}, + {"hello", 0, "hello"}, + } + + for _, tt := range tests { + got := truncateToWidth(tt.in, tt.width) + if got != tt.want { + t.Fatalf("truncateToWidth(%q, %d) got=%q want=%q", tt.in, tt.width, got, tt.want) + } + } +} + +func TestClampIndex(t *testing.T) { + tests := []struct { + idx, n int + want int + }{ + {2, 5, 2}, + {-1, 5, 0}, + {10, 5, 4}, + {0, 0, 0}, + {0, 1, 0}, + } + + for _, tt := range tests { + got := clampIndex(tt.idx, tt.n) + if got != tt.want { + t.Fatalf("clampIndex(%d, %d) got=%d want=%d", tt.idx, tt.n, got, tt.want) + } + } +} + +func TestMinInt(t *testing.T) { + if got := minInt(1, 2); got != 1 { + t.Fatalf("minInt(1,2) got=%d", got) + } + if got := minInt(5, 3); got != 3 { + t.Fatalf("minInt(5,3) got=%d", got) + } + if got := minInt(0, 0); got != 0 { + t.Fatalf("minInt(0,0) got=%d", got) + } +} + +func TestMaxIntFunc(t *testing.T) { + if got := maxInt(1, 2); got != 2 { + t.Fatalf("maxInt(1,2) got=%d", got) + } + if got := maxInt(5, 3, 7); got != 7 { + t.Fatalf("maxInt(5,3,7) got=%d", got) + } +} + +func TestHandleLocalHotkeyForward(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 4), cfg: &Config{hotkeyMod: "ctrl+alt"}} + a.SetUIEnabled(true) + m := uiModel{app: a} + + if !handleLocalHotkey(&m, "ctrl+alt+f") { + t.Fatalf("expected forward hotkey handled") + } + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventPanel || ev.Panel != event.UIPanelForward { + t.Fatalf("expected forward panel, got %+v", ev) + } +} + +func TestHandleLocalHotkeyPlugin(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 4), cfg: &Config{hotkeyMod: "ctrl+alt"}} + a.SetUIEnabled(true) + m := uiModel{app: a} + + if !handleLocalHotkey(&m, "ctrl+alt+p") { + t.Fatalf("expected plugin hotkey handled") + } + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventPanel || ev.Panel != event.UIPanelPlugin { + t.Fatalf("expected plugin panel, got %+v", ev) + } +} + +func TestHandleLocalHotkeyMode(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 4), cfg: &Config{hotkeyMod: "ctrl+alt"}} + a.SetUIEnabled(true) + m := uiModel{app: a} + + if !handleLocalHotkey(&m, "ctrl+alt+m") { + t.Fatalf("expected mode hotkey handled") + } + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventPanel || ev.Panel != event.UIPanelMode { + t.Fatalf("expected mode panel, got %+v", ev) + } +} + +func TestHandleLocalHotkeyUnknown(t *testing.T) { + a := &App{cfg: &Config{hotkeyMod: "ctrl+alt"}} + m := uiModel{app: a} + + if handleLocalHotkey(&m, "ctrl+alt+x") { + t.Fatalf("unknown hotkey should not be handled") + } +} + +func TestHandleLocalHotkeyCtrlShift(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 4), cfg: &Config{hotkeyMod: "ctrl+shift"}} + a.SetUIEnabled(true) + m := uiModel{app: a} + + if !handleLocalHotkey(&m, "ctrl+shift+h") { + t.Fatalf("expected ctrl+shift+h to be handled") + } + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventModal { + t.Fatalf("expected help modal with ctrl+shift+h") + } +} + +func TestRenderPanelModal(t *testing.T) { + lines := []panelLine{ + {text: "Header", selected: false}, + {text: "Selected Row", selected: true}, + } + out := renderPanelModal("Test Panel", lines, "Footer text", 80) + if !strings.Contains(out, "Test Panel") { + t.Fatalf("missing title: %q", out) + } + if !strings.Contains(out, "Header") { + t.Fatalf("missing header line: %q", out) + } + if !strings.Contains(out, "Selected Row") { + t.Fatalf("missing selected line: %q", out) + } + if !strings.Contains(out, "Footer text") { + t.Fatalf("missing footer: %q", out) + } +} + +func TestStyleFunctions(t *testing.T) { + _ = modalFooterLineStyle() + rendered := selectedPanelLineStyle().Render("test") + if !strings.Contains(rendered, "test") { + t.Fatalf("selectedPanelLineStyle should render text") + } +}