diff --git a/.gitignore b/.gitignore index de1d6ee..8d0fab5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ -/build/ -.idea -dist/ -/go.sum -/view/* +/build/ +.idea +dist/ +/go.sum +/view/* +.claude/ +*.exe +coverage.out +CLAUDE.md diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 1fca519..c891360 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,61 +1,62 @@ -#file: noinspection YAMLSchemaValidation -# This is an example .goreleaser.yml file with some sensible defaults. -# Make sure to check the documentation at https://goreleaser.com - -# The lines below are called `modelines`. See `:help modeline` -# Feel free to remove those if you don't want/need to use them. -# yaml-language-server: $schema=https://goreleaser.com/static/schema.json -# vim: set ts=2 sw=2 tw=0 fo=cnqoj - -version: 1 - -before: - hooks: - # You may remove this if you don't use go modules. -# - go mod tidy - # you may remove this if you don't need go generate -# - go generate ./... - -builds: - - env: - - CGO_ENABLED=0 - goos: - - linux - - windows - - darwin - ldflags: - - -s -w - -upx: - - enabled: true - goos: - - windows - goarch: - - amd64 - -archives: - - format: tar.gz - # this name template makes the OS and Arch compatible with the results of `uname`. - name_template: >- - {{ .ProjectName }}_ - {{- title .Os }}_ - {{- if eq .Arch "amd64" }}x86_64 - {{- else if eq .Arch "386" }}i386 - {{- else }}{{ .Arch }}{{ end }} - {{- if .Arm }}v{{ .Arm }}{{ end }} - # use zip for windows archives - format_overrides: - - goos: windows - format: zip -checksum: - name_template: 'checksums.txt' - -snapshot: - name_template: 'v1.0.0-snapshot' - -changelog: - sort: asc - filters: - exclude: - - "^docs:" - - "^test:" +#file: noinspection YAMLSchemaValidation +# This is an example .goreleaser.yml file with some sensible defaults. +# Make sure to check the documentation at https://goreleaser.com + +# The lines below are called `modelines`. See `:help modeline` +# Feel free to remove those if you don't want/need to use them. +# yaml-language-server: $schema=https://goreleaser.com/static/schema.json +# vim: set ts=2 sw=2 tw=0 fo=cnqoj + +version: 1 + +before: + hooks: + # You may remove this if you don't use go modules. +# - go mod tidy + # you may remove this if you don't need go generate +# - go generate ./... + +builds: + - main: ./cmd/serialterminal + env: + - CGO_ENABLED=0 + goos: + - linux + - windows + - darwin + ldflags: + - -s -w + +upx: + - enabled: true + goos: + - windows + goarch: + - amd64 + +archives: + - format: tar.gz + # this name template makes the OS and Arch compatible with the results of `uname`. + name_template: >- + {{ .ProjectName }}_ + {{- title .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + # use zip for windows archives + format_overrides: + - goos: windows + format: zip +checksum: + name_template: 'checksums.txt' + +snapshot: + name_template: 'v1.0.0-snapshot' + +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" diff --git a/README.md b/README.md index a00bc55..4946074 100644 --- a/README.md +++ b/README.md @@ -1,54 +1,231 @@ # SerialTerminalForWindowsTerminal -在开始这个项目之前,我发现Windows Terminal对串口设备的支持并不理想。 -我试用了一段时间[Zhou-zhi-peng的SerialPortForWindowsTerminal](https://github.com/Zhou-zhi-peng/SerialPortForWindowsTerminal/)项目。 +[English](#english) | [中文](#chinese) -然而,这个项目存在着编码转换的问题,导致数据显示乱码,并且作者目前并没有进行后续支持。因此,我决定创建了这个项目。 +--- -## 功能进展 -* [x] Hex接收发送(大写hex与原文同显) -* [x] 双向编码转换 -* [x] 活动端口探测 -* [x] 数据日志保存 -* [x] Hex断帧设置 -* [x] UDP数据转发(支持多服) -* [x] TCP数据转发(支持多服) -* [x] 参数交互配置 -* [x] Ctrl组合键 -* [x] 文件接收发送(trzsz lrzsz都支持) +## English -## 运行示例 +A cross-platform serial terminal with TUI, charset conversion, TCP/UDP forwarding, Lua plugins, and file transfer support. -1. 参数帮助 `./COM` +### Features - ![img1.png](image/img1.png) +- **Serial communication** with full port configuration (baud, data bits, stop bits, parity) +- **Hex mode** for binary protocol inspection with configurable frame size and timestamps +- **Charset conversion** — e.g., read GBK device output as UTF-8 in your terminal +- **TCP/UDP forwarding** — broadcast serial data to multiple servers, receive from any +- **Lua plugin system** — transform input/output data or intercept commands with Lua scripts +- **File transfer** via trzsz / lrzsz protocols +- **TUI mode** (`-g`) with Bubble Tea interface: viewport, input bar, modal panels +- **Console mode** (default) with dot-command prefix (`.` at line start) +- **Interactive setup wizard** when no port is specified -2. 输入设备输出UTF8 终端输出GBK `./COM -p COM8 -b 115200 -o GBK` +### Quick Start - ![img2.png](image/img2.png) -3. 彩色终端输出 +```bash +go build -o sterm ./cmd/serialterminal - ![img3.png](image/img3.png) +# Connect to serial port +./sterm -p COM8 -b 115200 -4. Hex接收 `./COM -p COM8 -b 115200 -i hex` - - ![img4.png](image/img4.png) -5. Hex发送 `./COM -p COM8 -b 115200` +# With charset conversion (device outputs GBK, terminal shows UTF-8) +./sterm -p COM8 -b 115200 -o GBK - ![img5.png](image/img5.png) -6. 交互配置 `./COM` +# Hex mode +./sterm -p COM8 -b 115200 -i hex - ![img6.png](image/img6.png) -7. Ctrl组合键发送指令.ctrl `.ctrl c` - - ![img7.png](image/img7.png) -8. 文件上传演示 `index.html` - ![img8.png](image/img8.png) - 内容对比 - ![img11.png](image/img11.png) -9. 时间戳 `./COM -p COM8 -t` - ![img9.png](image/img9.png) -10. 格式修改 `./COM -p COM11 -t='<2006-01-02 15:04:05>'` - ![img10.png](image/img10.png) -11. 多服同步转发 `./COM -p COM11 -f 1 -a 127.0.0.1:23456 -f 1 -a 127.0.0.1:23457` - ![img12.png](image/img12.png) \ No newline at end of file +# TUI mode +./sterm -p COM8 -b 115200 -g + +# With TCP forwarding +./sterm -p COM8 -f 1 -a 127.0.0.1:12345 + +# Interactive (no port specified) +./sterm +``` + +### CLI Flags + +| Short | Long | Type | Default | Description | +|---|---|---|---|---| +| `-p` | `--port` | string | `""` | Serial port (`/dev/ttyUSB0`, `COMx`) | +| `-b` | `--baud` | int | `115200` | Baud rate | +| `-d` | `--data` | int | `8` | Data bits (5/6/7/8) | +| `-s` | `--stop` | int | `0` | Stop bits (0:1, 1:1.5, 2:2) | +| `-v` | `--verify` | int | `0` | Parity (0:none, 1:odd, 2:even, 3:mark, 4:space) | +| `-o` | `--out` | string | `UTF-8` | Output charset | +| `-i` | `--in` | string | `UTF-8` | Input charset (use `hex` for hex mode) | +| `-e` | `--end` | string | `\n` | Line ending sent to device | +| `-F` | `--Frame` | int | `16` | Hex frame size | +| `-g` | `--gui` | bool | `false` | Enable TUI mode | +| `-k` | `--hotkey-mod` | string | `ctrl+alt` | Hotkey modifier (`ctrl+alt` or `ctrl+shift`) | +| `-f` | `--forward` | []int | `nil` | Forward mode (1:TCP, 2:UDP, repeatable) | +| `-a` | `--address` | []string | `nil` | Forward address (repeatable) | +| `-l` | `--log` | string | `""` | Log file path | +| `-t` | `--time` | string | `""` | Timestamp format | + +### Dot Commands + +In console mode, type `.` at line start to enter command mode: + +| Command | Description | +|---|---| +| `.help` | Show command help | +| `.exit` | Exit the terminal | +| `.hex ` | Send raw hex bytes | +| `.forward list\|add\|remove\|enable\|disable\|update` | Manage forwarding | +| `.plugin list\|load\|unload\|enable\|disable\|reload` | Manage Lua plugins | +| `.mode show\|set ` | View or change runtime settings | + +### Plugin System + +Create `.lua` files and load them with `.plugin load `: + +```lua +-- Transform outgoing data (append marker) +function OnInput(payload) + return payload .. "\r\n" +end + +-- Transform incoming data (add prefix) +function OnOutput(payload) + return "[DEV] " .. payload +end + +-- Intercept or modify commands (return false to block) +function OnCommand(line) + return line, true +end +``` + +Plugins chain: each enabled plugin sees the output of the previous one. Return `nil` to drop data. + +### Architecture + +``` +cmd/serialterminal/ # Entry point +internal/ + termapp/ # Core application (App, TUI, console, commands) + config/ # Configuration types + session/ # Serial port + trzsz lifecycle + event/ # UI event types + flag/ # CLI flag parsing + interactive wizard +pkg/ + charset/ # Charset conversion utilities + forward/ # TCP/UDP forwarding manager + luaplugin/ # Lua plugin engine +``` + +--- + +## 中文 + +一款跨平台串口终端,支持 TUI 界面、编码转换、TCP/UDP 转发、Lua 插件和文件传输。 + +### 功能特性 + +- **串口通信** — 完整端口配置(波特率、数据位、停止位、校验位) +- **Hex 模式** — 二进制协议调试,可配置帧大小和时间戳 +- **双向编码转换** — 如设备输出 GBK,终端显示 UTF-8 +- **TCP/UDP 数据转发** — 串口数据广播至多台服务器,任一台可回传 +- **Lua 插件系统** — 使用 Lua 脚本转换输入/输出数据或拦截命令 +- **文件传输** — 支持 trzsz / lrzsz 协议 +- **TUI 界面** (`-g`) — 基于 Bubble Tea,带视口、输入栏、模态面板 +- **控制台模式** — 行首 `.` 进入命令模式,支持 Tab 补全 +- **交互配置向导** — 不带端口参数时自动启动 + +### 快速开始 + +```bash +go build -o sterm ./cmd/serialterminal + +# 连接串口 +./sterm -p COM8 -b 115200 + +# 编码转换(设备输出 GBK,终端显示 UTF-8) +./sterm -p COM8 -b 115200 -o GBK + +# Hex 模式 +./sterm -p COM8 -b 115200 -i hex + +# TUI 模式 +./sterm -p COM8 -b 115200 -g + +# TCP 转发 +./sterm -p COM8 -f 1 -a 127.0.0.1:12345 + +# 交互式(不指定端口) +./sterm +``` + +### CLI 参数 + +| 短参 | 长参 | 类型 | 默认值 | 说明 | +|---|---|---|---|---| +| `-p` | `--port` | string | `""` | 串口设备 (`/dev/ttyUSB0`、`COMx`) | +| `-b` | `--baud` | int | `115200` | 波特率 | +| `-d` | `--data` | int | `8` | 数据位 | +| `-s` | `--stop` | int | `0` | 停止位 (0:1, 1:1.5, 2:2) | +| `-v` | `--verify` | int | `0` | 校验 (0:无, 1:奇, 2:偶, 3:1, 4:0) | +| `-o` | `--out` | string | `UTF-8` | 输出编码 | +| `-i` | `--in` | string | `UTF-8` | 输入编码 (`hex` 开启 Hex 模式) | +| `-e` | `--end` | string | `\n` | 发送到设备的换行符 | +| `-F` | `--Frame` | int | `16` | Hex 帧大小 | +| `-g` | `--gui` | bool | `false` | 启用 TUI 界面 | +| `-k` | `--hotkey-mod` | string | `ctrl+alt` | 快捷键修饰 (`ctrl+alt` 或 `ctrl+shift`) | +| `-f` | `--forward` | []int | `nil` | 转发模式 (1:TCP, 2:UDP, 可多次传入) | +| `-a` | `--address` | []string | `nil` | 转发地址 (可多次传入) | +| `-l` | `--log` | string | `""` | 日志文件路径 | +| `-t` | `--time` | string | `""` | 时间戳格式 | + +### 点命令 + +控制台模式下,行首输入 `.` 进入命令模式: + +| 命令 | 说明 | +|---|---| +| `.help` | 显示帮助 | +| `.exit` | 退出终端 | +| `.hex <数据>` | 发送原始 Hex 字节 | +| `.forward list\|add\|remove\|enable\|disable\|update` | 管理转发 | +| `.plugin list\|load\|unload\|enable\|disable\|reload` | 管理 Lua 插件 | +| `.mode show\|set <字段> <值>` | 查看或修改运行时设置 | + +### 插件系统 + +编写 `.lua` 文件,通过 `.plugin load <路径>` 加载: + +```lua +-- 转换输出数据(追加换行) +function OnInput(payload) + return payload .. "\r\n" +end + +-- 转换输入数据(添加前缀) +function OnOutput(payload) + return "[DEV] " .. payload +end + +-- 拦截命令(返回 false 阻止执行) +function OnCommand(line) + return line, true +end +``` + +插件链式执行,每个启用的插件接收上一个插件的输出。返回 `nil` 可丢弃数据。 + +### 架构说明 + +``` +cmd/serialterminal/ # 入口点 +internal/ + termapp/ # 核心应用(App、TUI、控制台、命令) + config/ # 配置类型 + session/ # 串口 + trzsz 生命周期 + event/ # UI 事件类型 + flag/ # CLI 参数解析 + 交互向导 +pkg/ + charset/ # 编码转换工具 + forward/ # TCP/UDP 转发管理 + luaplugin/ # Lua 插件引擎 +``` diff --git a/cmd/serialterminal/main.go b/cmd/serialterminal/main.go new file mode 100644 index 0000000..2ce1a83 --- /dev/null +++ b/cmd/serialterminal/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "log" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/console" +) + +func init() { + log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile | log.Lmsgprefix) +} + +func main() { + console.Run() +} diff --git a/command.go b/command.go deleted file mode 100644 index dfe9c35..0000000 --- a/command.go +++ /dev/null @@ -1,64 +0,0 @@ -package main - -import ( - "encoding/hex" - "fmt" - "log" - "os" - "strings" -) - -type Command struct { - name string - description string - function func() -} - -var ( - commands []Command - args []string -) - -func cmdhelp() { - var page = 0 - strout(out, config.outputCode, fmt.Sprintf(">-------Help(%v)-------<\n", page)) - for i := 0; i < len(commands); i++ { - strout(out, config.outputCode, fmt.Sprintf(" %-10v --%v\n", commands[i].name, commands[i].description)) - } -} -func cmdexit() { - CloseTrzsz() - CloseSerial() - os.Exit(0) -} -func cmdargs() { - strout(out, config.outputCode, fmt.Sprintf(">-------Args(%v)-------<\n", len(args)-1)) - strout(out, config.outputCode, fmt.Sprintf("%q\n", args[1:])) -} -func cmdctrl() { - var err error - b := []byte(args[1]) - x := []byte{b[0] & 0x1f} - _, err = serialPort.Write(x) - ErrorF(err) - strout(out, config.outputCode, fmt.Sprintf("Ctrl+%s\n", b)) -} -func cmdhex() { - strout(out, config.outputCode, fmt.Sprintf(">-----Hex Send-----<\n")) - strout(out, config.outputCode, fmt.Sprintf("%q\n", args[1:])) - s := strings.Join(args[1:], "") - b, err := hex.DecodeString(s) - if err != nil { - log.Fatal(err) - } - _, err = serialPort.Write(b) - if err != nil { - log.Fatal(err) - } -} -func cmdinit() { - commands = append(commands, Command{name: ".help", description: "帮助信息", function: cmdhelp}) - commands = append(commands, Command{name: ".ctrl", description: "发送Ctrl组合键", function: cmdctrl}) - commands = append(commands, Command{name: ".hex", description: "发送Hex", function: cmdhex}) - commands = append(commands, Command{name: ".exit", description: "退出终端", function: cmdexit}) -} diff --git a/config.go b/config.go deleted file mode 100644 index 3cb43b8..0000000 --- a/config.go +++ /dev/null @@ -1,68 +0,0 @@ -package main - -import ( - "fmt" - "log" - "net" - "os" - "time" -) - -type Config struct { - portName string - baudRate int - dataBits int - stopBits int - parityBit int - outputCode string - inputCode string - endStr string - enableLog bool - logFilePath string - forWard []int - frameSize int - timesTamp bool - timesFmt string - address []string -} -type FoeWardMode int - -const ( - NOT FoeWardMode = iota - TCPC - UDPC -) - -var config Config - -func setForWardClient(mode FoeWardMode, add string) (conn net.Conn) { - var err error - switch mode { - case NOT: - - case TCPC: - conn, err = net.Dial("tcp", add) - if err != nil { - log.Fatal(err) - } - case UDPC: - conn, err = net.Dial("udp", add) - if err != nil { - log.Fatal(err) - } - default: - panic("未知模式设置") - } - return conn -} - -func checkLogOpen() { - if config.enableLog { - path := fmt.Sprintf(config.logFilePath, config.portName, time.Now().Format("2006_01_02T150405")) - f, err := os.OpenFile(path, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0666) - if err != nil { - log.Fatal(err) - } - outs = append(outs, f) - } -} diff --git a/flag.go b/flag.go deleted file mode 100644 index 1c05a78..0000000 --- a/flag.go +++ /dev/null @@ -1,312 +0,0 @@ -package main - -import ( - "fmt" - "github.com/charmbracelet/bubbles/key" - inf "github.com/fzdwx/infinite" - "github.com/fzdwx/infinite/color" - "github.com/fzdwx/infinite/components" - "github.com/fzdwx/infinite/components/input/text" - "github.com/fzdwx/infinite/components/selection/confirm" - "github.com/fzdwx/infinite/components/selection/singleselect" - "github.com/fzdwx/infinite/style" - "github.com/fzdwx/infinite/theme" - "github.com/spf13/pflag" - "go.bug.st/serial" - "log" - "strconv" - "strings" -) - -type ptrVal struct { - *string - sl *[]string - *int - il *[]int - *bool - *float64 - *float32 - ext *string -} -type Val struct { - string - int - bool - float64 - float32 - extdef string -} -type Flag struct { - v ptrVal - sStr string - lStr string - dv Val - help string -} - -var ( - portName = Flag{ptrVal{string: &config.portName}, "p", "port", Val{string: ""}, "要连接的串口\t(/dev/ttyUSB0、COMx)"} - baudRate = Flag{ptrVal{int: &config.baudRate}, "b", "baud", Val{int: 115200}, "波特率"} - dataBits = Flag{ptrVal{int: &config.dataBits}, "d", "data", Val{int: 8}, "数据位"} - stopBits = Flag{ptrVal{int: &config.stopBits}, "s", "stop", Val{int: 0}, "停止位停止位(0: 1停止 1:1.5停止 2:2停止)"} - outputCode = Flag{ptrVal{string: &config.outputCode}, "o", "out", Val{string: "UTF-8"}, "输出编码"} - inputCode = Flag{ptrVal{string: &config.inputCode}, "i", "in", Val{string: "UTF-8"}, "输入编码"} - endStr = Flag{ptrVal{string: &config.endStr}, "e", "end", Val{string: "\n"}, "终端换行符"} - logExt = Flag{v: ptrVal{ext: &config.logFilePath}, sStr: "l", lStr: "log", dv: Val{extdef: "./%s-$s.txt", string: ""}, help: "日志保存路径"} - timeExt = Flag{v: ptrVal{ext: &config.timesFmt}, sStr: "t", lStr: "time", dv: Val{extdef: "[06-01-02 15:04:05.000]", string: ""}, help: "时间戳格式化字段"} - forWard = Flag{ptrVal{il: &config.forWard}, "f", "forward", Val{int: 0}, "转发模式(0: 无 1:TCP-C 2:UDP-C 支持多次传入)"} - address = Flag{ptrVal{sl: &config.address}, "a", "address", Val{string: "127.0.0.1:12345"}, "转发服务地址(支持多次传入)"} - frameSize = Flag{ptrVal{int: &config.frameSize}, "F", "Frame", Val{int: 16}, "帧大小"} - parityBit = Flag{ptrVal{int: &config.parityBit}, "v", "verify", Val{int: 0}, "奇偶校验(0:无校验、1:奇校验、2:偶校验、3:1校验、4:0校验)"} - flags = []Flag{portName, baudRate, dataBits, stopBits, outputCode, inputCode, endStr, forWard, address, frameSize, parityBit, logExt, timeExt} -) - -var ( - bauds = []string{"自定义", "300", "600", "1200", "2400", "4800", "9600", - "14400", "19200", "38400", "56000", "57600", "115200", "128000", - "256000", "460800", "512000", "750000", "921600", "1500000"} - datas = []string{"5", "6", "7", "8"} - stops = []string{"1", "1.5", "2"} - paritys = []string{"无校验", "奇校验", "偶校验", "1校验", "0校验"} - forwards = []string{"No", "TCP-C", "UDP-C"} -) - -type ValType int - -const ( - notVal ValType = iota - stringVal - intVal - boolVal - extVal -) - -func printUsage(ports []string) { - fmt.Printf("\n参数帮助:\n") - for _, f := range flags { - flagprint(f) - } - fmt.Printf("\n在线串口: %v\n", strings.Join(ports, ",")) -} -func flagFindValue(v ptrVal) ValType { - if v.string != nil { - return stringVal - } - if v.bool != nil { - return boolVal - } - if v.int != nil { - return intVal - } - if v.ext != nil { - return extVal - } - return notVal -} -func flagprint(f Flag) { - switch flagFindValue(f.v) { - case stringVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%q\n", f.sStr, f.lStr, f.dv.string, f.help, f.dv.string) - case intVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%v\n", f.sStr, f.lStr, f.dv.int, f.help, f.dv.int) - case boolVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%v\n", f.sStr, f.lStr, f.dv.bool, f.help, f.dv.bool) - case extVal: - fmt.Printf("\t-%v -%v %T \n\t %v\t默认值:%v\n", f.sStr, f.lStr, f.dv.extdef, f.help, f.dv.extdef) - default: - panic("unhandled default case") - } -} -func flagInit(f *Flag) { - if f.v.string != nil { - pflag.StringVarP(f.v.string, f.lStr, f.sStr, f.dv.string, f.help) - } - if f.v.bool != nil { - pflag.BoolVarP(f.v.bool, f.lStr, f.sStr, f.dv.bool, f.help) - } - if f.v.int != nil { - pflag.IntVarP(f.v.int, f.lStr, f.sStr, f.dv.int, f.help) - } - if f.v.ext != nil { - pflag.StringVarP(f.v.ext, f.lStr, f.sStr, f.dv.string, f.help) - pflag.Lookup(f.lStr).NoOptDefVal = f.dv.extdef - } - if f.v.sl != nil { - pflag.StringArrayVarP(f.v.sl, f.lStr, f.sStr, []string{f.dv.string}, f.help) - } - if f.v.il != nil { - pflag.IntSliceVarP(f.v.il, f.lStr, f.sStr, []int{f.dv.int}, f.help) - } -} -func flagExt() { - if config.logFilePath != "" { - config.enableLog = true - } - if config.timesFmt != "" { - config.timesTamp = true - } -} -func getCliFlag() { - ports, err := serial.GetPortsList() - if err != nil { - log.Fatal(err) - } - - inputs := components.NewInput() - inputs.Prompt = "Filtering: " - inputs.PromptStyle = style.New().Bold().Italic().Fg(color.LightBlue) - - selectKeymap := singleselect.DefaultSingleKeyMap() - selectKeymap.Confirm = key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("enter", "finish select"), - ) - selectKeymap.Choice = key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("enter", "finish select"), - ) - selectKeymap.NextPage = key.NewBinding( - key.WithKeys("right"), - key.WithHelp("->", "next page"), - ) - selectKeymap.PrevPage = key.NewBinding( - key.WithKeys("left"), - key.WithHelp("<-", "prev page"), - ) - - s, _ := inf.NewSingleSelect( - ports, - singleselect.WithKeyBinding(selectKeymap), - singleselect.WithPageSize(4), - singleselect.WithFilterInput(inputs), - ).Display("选择串口") - config.portName = ports[s] - - s, _ = inf.NewSingleSelect( - bauds, - singleselect.WithKeyBinding(selectKeymap), - singleselect.WithPageSize(4), - ).Display("选择波特率") - if s != 0 { - config.baudRate, _ = strconv.Atoi(bauds[s]) - } else { - b, _ := inf.NewText( - text.WithPrompt("BaudRate:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue("115200"), - ).Display() - config.baudRate, _ = strconv.Atoi(b) - } - v, _ := inf.NewConfirmWithSelection( - confirm.WithPrompt("启用Hex"), - ).Display() - if v { - config.inputCode = "hex" - b, _ := inf.NewText( - text.WithPrompt("Frames:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue("16"), - ).Display() - config.frameSize, _ = strconv.Atoi(b) - } - v, _ = inf.NewConfirmWithSelection( - confirm.WithPrompt("启用时间戳"), - ).Display() - config.timesTamp = v - if v { - b, _ := inf.NewText( - text.WithPrompt("格式化字段:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue(timeExt.dv.extdef), - ).Display() - config.timesFmt = b - } - v, _ = inf.NewConfirmWithSelection( - confirm.WithPrompt("启用高级配置"), - ).Display() - if v { - s, _ = inf.NewSingleSelect( - datas, - singleselect.WithKeyBinding(selectKeymap), - singleselect.WithPageSize(4), - singleselect.WithFilterInput(inputs), - ).Display("选择数据位") - config.dataBits, _ = strconv.Atoi(datas[s]) - - s, _ = inf.NewSingleSelect( - stops, - singleselect.WithKeyBinding(selectKeymap), - singleselect.WithPageSize(4), - singleselect.WithFilterInput(inputs), - ).Display("选择停止位") - config.stopBits = s - - s, _ = inf.NewSingleSelect( - paritys, - singleselect.WithKeyBinding(selectKeymap), - singleselect.WithPageSize(4), - singleselect.WithFilterInput(inputs), - ).Display("选择校验位") - config.parityBit = s - - t, _ := inf.NewText( - text.WithPrompt("换行符:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue(endStr.dv.string), - ).Display() - config.endStr = t - - v, _ = inf.NewConfirmWithSelection( - confirm.WithDefaultYes(), - confirm.WithPrompt("启用编码转换"), - ).Display() - - if v { - t, _ = inf.NewText( - text.WithPrompt("输入编码:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue(inputCode.dv.string), - ).Display() - config.inputCode = t - - t, _ = inf.NewText( - text.WithPrompt("输出编码:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue(outputCode.dv.string), - ).Display() - config.outputCode = t - } - G_F_mode: - s, _ = inf.NewSingleSelect( - forwards, - singleselect.WithKeyBinding(selectKeymap), - singleselect.WithPageSize(3), - singleselect.WithFilterInput(inputs), - ).Display("选择转发模式") - if s != 0 { - config.forWard = append(config.forWard, s) - t, _ = inf.NewText( - text.WithPrompt("地址:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue(address.dv.string), - ).Display() - config.address = append(config.address, t) - goto G_F_mode - } - - e, _ := inf.NewConfirmWithSelection( - confirm.WithDefaultYes(), - confirm.WithPrompt("启用日志"), - ).Display() - config.enableLog = e - if e { - t, _ = inf.NewText( - text.WithPrompt("Path:"), - text.WithPromptStyle(theme.DefaultTheme.PromptStyle), - text.WithDefaultValue("./%s-$s.txt"), - ).Display() - config.logFilePath = t - } - } - -} diff --git a/go.mod b/go.mod index 88582a1..7650112 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,18 @@ -module COM +module github.com/jixishi/SerialTerminalForWindowsTerminal -go 1.22 +go 1.23.0 require ( - github.com/charmbracelet/bubbles v0.18.0 + github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 + github.com/charmbracelet/bubbletea v1.3.6 + github.com/charmbracelet/lipgloss v1.1.0 github.com/fzdwx/infinite v0.12.1 - github.com/gobwas/ws v1.4.0 github.com/spf13/pflag v1.0.5 github.com/trzsz/trzsz-go v1.1.7 + github.com/yuin/gopher-lua v1.1.1 github.com/zimolab/charsetconv v0.1.2 go.bug.st/serial v1.6.2 + golang.org/x/sys v0.33.0 golang.org/x/term v0.19.0 ) @@ -19,37 +22,37 @@ require ( github.com/alexflint/go-scalar v1.2.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/bubbletea v0.25.0 // indirect - github.com/charmbracelet/lipgloss v0.9.1 // indirect + github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/x/ansi v0.9.3 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13 // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/chzyer/readline v1.5.1 // indirect - github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect github.com/creack/goselect v0.1.2 // indirect - github.com/creack/pty v1.1.21 // indirect + github.com/creack/pty v1.1.24 // indirect github.com/dchest/jsmin v0.0.0-20220218165748-59f39799265f // indirect github.com/duke-git/lancet/v2 v2.2.1 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/fzdwx/iter v0.0.0-20230511075109-0afee9319312 // indirect - github.com/gobwas/httphead v0.1.0 // indirect - github.com/gobwas/pool v0.2.1 // indirect github.com/josephspurrier/goversioninfo v1.4.0 // indirect github.com/klauspost/compress v1.17.4 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/reflow v0.3.0 // indirect - github.com/muesli/termenv v0.15.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/ncruces/zenity v0.10.10 // indirect github.com/randall77/makefat v0.0.0-20210315173500-7ddd0e42c844 // indirect - github.com/rivo/uniseg v0.4.6 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/rotisserie/eris v0.5.4 // indirect - github.com/sahilm/fuzzy v0.1.1-0.20230530133925-c48e322e2a8f // indirect + github.com/sahilm/fuzzy v0.1.1 // indirect github.com/trzsz/go-arg v1.5.3 // indirect github.com/trzsz/promptui v0.10.5 // indirect - golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/image v0.14.0 // indirect - golang.org/x/sync v0.2.0 // indirect - golang.org/x/sys v0.19.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/sync v0.15.0 // indirect + golang.org/x/text v0.23.0 // indirect ) diff --git a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 0000000..cd1cd1f --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,416 @@ +// Package app provides the core application coordinator. +package app + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "time" + + appconfig "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/session" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/charset" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/luaplugin" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/command" +) + +// App is the central coordinator for the serial terminal application. +type App struct { + cfg *appconfig.Config + sess *session.SerialSession + out io.Writer + + forward *forward.Manager + plugins *luaplugin.Manager + dispatcher *command.Dispatcher + + uiEvents chan event.UIEvent + done chan struct{} + + stdinMu sync.Mutex + closeOnce sync.Once + closedFlag atomic.Bool + uiEnabled atomic.Bool + + logFile *os.File +} + +var _ command.CommandHost = (*App)(nil) + +// New creates a new App with the given configuration, session, and output writer. +func New(cfg *appconfig.Config, sess *session.SerialSession, out io.Writer) (*App, error) { + f, err := appconfig.OpenLogFile(cfg) + if err != nil { + return nil, err + } + + a := &App{ + cfg: cfg, + sess: sess, + out: out, + plugins: luaplugin.NewManager(), + uiEvents: make(chan event.UIEvent, 512), + done: make(chan struct{}), + logFile: f, + } + a.uiEnabled.Store(true) + + a.forward = forward.NewManager(a.writeRawToSession, a.Notifyf) + a.forward.SetInboundReporter(a.reportForwardIngress) + a.dispatcher = command.NewDispatcher(a) + if err = a.loadPluginsFromDir(); err != nil { + return nil, err + } + return a, nil +} + +// --- command.CommandHost implementation --- + +func (a *App) Cfg() *appconfig.Config { return a.cfg } +func (a *App) Forward() *forward.Manager { return a.forward } +func (a *App) Plugins() *luaplugin.Manager { return a.plugins } +func (a *App) WriteToSession(data []byte) error { return a.writeToSession(data) } + +// --- exported accessors for TUI / console --- + +func (a *App) UIEvents() <-chan event.UIEvent { return a.uiEvents } +func (a *App) WaitDone() <-chan struct{} { return a.done } +func (a *App) SendCtrl(letter byte) error { return a.sendCtrl(letter) } +func (a *App) HandleLine(line string) { a.handleLine(line) } +func (a *App) Dispatcher() *command.Dispatcher { return a.dispatcher } +func (a *App) StartOutputLoop() { a.startOutputLoop() } +func (a *App) LoadConfiguredForwards() { a.loadConfiguredForwards() } +func (a *App) Sess() *session.SerialSession { return a.sess } +func (a *App) Out() io.Writer { return a.out } + +func (a *App) loadPluginsFromDir() error { + entries, err := os.ReadDir("plugins") + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".lua") { + continue + } + pluginPath := filepath.Join("plugins", entry.Name()) + name, loadErr := a.plugins.Load(pluginPath) + if loadErr != nil { + a.Notifyf("[plugin] load %s failed: %v", entry.Name(), loadErr) + continue + } + // Disable by default; user enables via .plugin enable or TUI panel + _ = a.plugins.Disable(name) + } + return nil +} + +func (a *App) Notifyf(format string, args ...any) { + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: fmt.Sprintf(format, args...)}) +} + +func (a *App) Statusf(format string, args ...any) { + a.emit(event.UIEvent{Kind: event.UIEventStatus, Text: fmt.Sprintf(format, args...)}) +} + +func (a *App) ShowModal(title, text string) { + a.emit(event.UIEvent{Kind: event.UIEventModal, Title: title, Text: text}) +} + +func (a *App) OpenPanel(panel event.UIPanelKind) { + a.emit(event.UIEvent{Kind: event.UIEventPanel, Panel: panel}) +} + +func (a *App) SetUIEnabled(enabled bool) { + a.uiEnabled.Store(enabled) +} + +func (a *App) UIEnabled() bool { + return a.uiEnabled.Load() +} + +func (a *App) emit(ev event.UIEvent) { + if ev.Kind != event.UIEventPanel && ev.Text == "" { + return + } + + if !a.UIEnabled() { + switch ev.Kind { + case event.UIEventOutput: + _, _ = io.WriteString(a.out, ev.Text) + case event.UIEventStatus: + _, _ = io.WriteString(a.out, ev.Text) + if !strings.HasSuffix(ev.Text, "\n") { + _, _ = io.WriteString(a.out, "\n") + } + case event.UIEventModal: + _, _ = io.WriteString(a.out, "\n["+ev.Title+"]\n"+ev.Text+"\n") + } + if ev.Kind == event.UIEventOutput { + a.appendLog(ev.Text) + } + return + } + + select { + case a.uiEvents <- ev: + default: + select { + case <-a.uiEvents: + default: + } + a.uiEvents <- ev + } + + if ev.Kind == event.UIEventOutput { + a.appendLog(ev.Text) + } +} + +func (a *App) appendLog(text string) { + if a.logFile == nil { + return + } + _, _ = a.logFile.WriteString(text) +} + +func (a *App) Close() { + a.closeOnce.Do(func() { + a.closedFlag.Store(true) + close(a.done) + a.forward.Close() + a.plugins.Close() + if a.sess != nil { + a.sess.Close() + } + if a.logFile != nil { + _ = a.logFile.Close() + } + }) +} + +func (a *App) loadConfiguredForwards() { + for i, mode := range a.cfg.ForWard { + m := forward.Mode(mode) + if m == forward.None { + continue + } + if i >= len(a.cfg.Address) { + a.Notifyf("[forward] skip #%d: missing address", i) + continue + } + addr := strings.TrimSpace(a.cfg.Address[i]) + if addr == "" { + continue + } + if _, err := a.forward.Add(m, addr); err != nil { + a.Notifyf("[forward] add %s %s failed: %v", m.String(), addr, err) + } + } +} + +func (a *App) reportForwardIngress(id int, chunk []byte) { + if len(chunk) == 0 { + return + } + if strings.EqualFold(a.cfg.InputCode, "hex") { + a.Notifyf("[forward#%d -> serial] % X\n", id, chunk) + return + } + converted, err := charset.ConvertChunk(chunk, a.cfg.InputCode, a.cfg.OutputCode) + if err != nil { + converted = bytes.Clone(chunk) + } + text := string(converted) + if !strings.HasSuffix(text, "\n") { + text += "\n" + } + a.Notifyf("[forward#%d -> serial] %s", id, text) +} + +func (a *App) writeRawToSession(data []byte) error { + if len(data) == 0 { + return nil + } + a.stdinMu.Lock() + defer a.stdinMu.Unlock() + _, err := a.sess.StdinPipe.Write(data) + return err +} + +func (a *App) writeToSession(data []byte) error { + processed, err := a.plugins.ProcessInput(data) + if err != nil { + return err + } + if len(processed) == 0 { + return nil + } + return a.writeRawToSession(processed) +} + +func (a *App) sendLine(line string) error { + if strings.TrimSpace(line) == "" { + return nil + } + payload := append([]byte(line), []byte(a.cfg.EndStr)...) + return a.writeToSession(payload) +} + +func (a *App) sendCtrl(letter byte) error { + if letter >= 'A' && letter <= 'Z' { + letter = letter + ('a' - 'A') + } + control := []byte{letter & 0x1f} + _, err := a.sess.Port.Write(control) + return err +} + +func (a *App) handleLine(line string) { + line = strings.TrimRight(line, "\r\n") + if strings.TrimSpace(line) == "" { + return + } + + if strings.HasPrefix(strings.TrimSpace(line), ".") { + next, allow, err := a.plugins.ProcessCommand(line) + if err != nil { + a.Notifyf("[plugin] command hook failed: %v", err) + return + } + if !allow { + a.Notifyf("[plugin] command blocked") + return + } + if next != "" { + line = next + } + handled, err := a.dispatcher.Execute(line) + if err != nil { + a.Statusf("[cmd] %v", err) + } + if handled { + return + } + } + + if err := a.sendLine(line); err != nil { + a.Statusf("[send] %v", err) + } +} + +func (a *App) startOutputLoop() { + if strings.EqualFold(a.cfg.InputCode, "hex") { + go a.readHexOutput() + return + } + go a.readTextOutput() +} + +func (a *App) readHexOutput() { + frameSize := a.cfg.FrameSize + if frameSize <= 0 { + frameSize = 16 + } + + buf := make([]byte, frameSize) + for { + n, err := a.sess.StdoutPipe.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + a.forward.Broadcast(chunk) + outChunk, hookErr := a.plugins.ProcessOutput(chunk) + if hookErr != nil { + a.Notifyf("[plugin] output hook failed: %v", hookErr) + continue + } + if len(outChunk) == 0 { + continue + } + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: charset.FormatHexFrame(outChunk, a.cfg.TimesTamp, a.cfg.TimesFmt)}) + } + if err != nil { + if err != io.EOF { + a.Notifyf("[output] %v", err) + } + return + } + + select { + case <-a.done: + return + default: + } + } +} + +func (a *App) readTextOutput() { + buf := make([]byte, 4096) + for { + n, err := a.sess.StdoutPipe.Read(buf) + if n > 0 { + chunk := make([]byte, n) + copy(chunk, buf[:n]) + a.forward.Broadcast(chunk) + + outChunk, hookErr := a.plugins.ProcessOutput(chunk) + if hookErr != nil { + a.Notifyf("[plugin] output hook failed: %v", hookErr) + continue + } + if len(outChunk) == 0 { + continue + } + + converted, convErr := charset.ConvertChunk(outChunk, a.cfg.InputCode, a.cfg.OutputCode) + if convErr != nil { + a.Notifyf("[output] convert failed: %v", convErr) + converted = bytes.Clone(outChunk) + } + + text := string(converted) + if a.cfg.TimesTamp { + text = prefixLines(text, time.Now().Format(a.cfg.TimesFmt)+" ") + } + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: text}) + } + if err != nil { + if err != io.EOF { + a.Notifyf("[output] %v", err) + } + return + } + + select { + case <-a.done: + return + default: + } + } +} + +func prefixLines(s, prefix string) string { + if s == "" || prefix == "" { + return s + } + lines := strings.SplitAfter(s, "\n") + for i, line := range lines { + if line == "" { + continue + } + lines[i] = prefix + line + } + return strings.Join(lines, "") +} diff --git a/internal/app/app_test.go b/internal/app/app_test.go new file mode 100644 index 0000000..24ffc8d --- /dev/null +++ b/internal/app/app_test.go @@ -0,0 +1,230 @@ +package app + +import ( + "io" + "net" + "testing" + "time" + + "go.bug.st/serial" + + appconfig "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/command" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/session" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/luaplugin" +) + +func newTestApp() *App { + a := &App{ + sess: &session.SerialSession{}, + cfg: &appconfig.Config{EndStr: "\n", InputCode: "UTF-8", OutputCode: "UTF-8"}, + plugins: luaplugin.NewManager(), + uiEvents: make(chan event.UIEvent, 8), + done: make(chan struct{}), + out: io.Discard, + } + a.forward = forward.NewManager(func([]byte) error { return nil }, func(string, ...any) {}) + a.dispatcher = command.NewDispatcher(a) + + var cr *io.PipeReader + cr, a.sess.StdinPipe = io.Pipe() + go func() { + buf := make([]byte, 4096) + for { _, _ = cr.Read(buf) } + }() + return a +} + +func TestPrefixLines(t *testing.T) { + tests := []struct{ name, in, prefix, want string }{ + {name: "empty", in: "", prefix: "X ", want: ""}, + {name: "no-prefix", in: "a\n", prefix: "", want: "a\n"}, + {name: "single-line", in: "abc", prefix: "T ", want: "T abc"}, + {name: "multi-line", in: "a\nb\n", prefix: "P ", want: "P a\nP b\n"}, + } + for _, tt := range tests { + got := prefixLines(tt.in, tt.prefix) + if got != tt.want { + t.Fatalf("%s: got=%q want=%q", tt.name, got, tt.want) + } + } +} + +func TestAppUIEvents(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 8), sess: &session.SerialSession{}, out: io.Discard} + a.SetUIEnabled(true) + + a.Notifyf("hello %s", "world") + a.Statusf("ok") + a.ShowModal("Title", "Body") + + ev1 := mustReadEvent(t, a.uiEvents) + if ev1.Kind != event.UIEventOutput || ev1.Text != "hello world" { + t.Fatalf("unexpected output: %+v", ev1) + } + ev2 := mustReadEvent(t, a.uiEvents) + if ev2.Kind != event.UIEventStatus || ev2.Text != "ok" { + t.Fatalf("unexpected status: %+v", ev2) + } + ev3 := mustReadEvent(t, a.uiEvents) + if ev3.Kind != event.UIEventModal || ev3.Title != "Title" || ev3.Text != "Body" { + t.Fatalf("unexpected modal: %+v", ev3) + } +} + +func TestSendLine(t *testing.T) { + a := newTestApp() + a.SetUIEnabled(true) + + if err := a.sendLine("hello"); err != nil { + t.Fatalf("sendLine failed: %v", err) + } + if err := a.sendLine(""); err != nil { + t.Fatalf("sendLine empty: %v", err) + } + if err := a.sendLine(" "); err != nil { + t.Fatalf("sendLine whitespace: %v", err) + } +} + +func TestHandleLine(t *testing.T) { + a := newTestApp() + a.SetUIEnabled(true) + + a.handleLine("hello") + a.handleLine("") + a.handleLine(".help") + + ev := mustReadEvent(t, a.uiEvents) + if ev.Kind != event.UIEventModal || ev.Title == "" { + t.Fatalf("expected .help modal, got %+v", ev) + } +} + +func TestEmitNonUI(t *testing.T) { + a := &App{ + out: io.Discard, + uiEvents: make(chan event.UIEvent, 4), + logFile: nil, + sess: &session.SerialSession{}, + } + a.SetUIEnabled(false) + + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "serial data\n"}) + a.emit(event.UIEvent{Kind: event.UIEventStatus, Text: "status msg"}) + a.emit(event.UIEvent{Kind: event.UIEventModal, Title: "T", Text: "body"}) + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: ""}) +} + +func TestEmitUISaturation(t *testing.T) { + a := &App{uiEvents: make(chan event.UIEvent, 2), sess: &session.SerialSession{}, out: io.Discard} + a.SetUIEnabled(true) + + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "a"}) + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "b"}) + a.emit(event.UIEvent{Kind: event.UIEventOutput, Text: "c"}) + + ev := mustReadEvent(t, a.uiEvents) + if ev.Text != "b" { + t.Fatalf("expected b after drop, got %q", ev.Text) + } + ev = mustReadEvent(t, a.uiEvents) + if ev.Text != "c" { + t.Fatalf("expected c, got %q", ev.Text) + } +} + +func TestAppClose(t *testing.T) { + a := newTestApp() + a.Close() + if !a.closedFlag.Load() { + t.Fatalf("expected app closed") + } + a.Close() // second close safe +} + +func TestLoadConfiguredForwards(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + a := &App{ + sess: &session.SerialSession{}, + cfg: &appconfig.Config{ForWard: []int{int(forward.TCP), int(forward.None), int(forward.UDP)}, Address: []string{listener.Addr().String(), "", ""}}, + forward: forward.NewManager(func([]byte) error { return nil }, func(string, ...any) {}), + uiEvents: make(chan event.UIEvent, 8), + done: make(chan struct{}), + out: io.Discard, + } + a.SetUIEnabled(true) + a.loadConfiguredForwards() + + items := a.forward.List() + if len(items) != 1 || items[0].Mode != "tcp" { + t.Fatalf("expected 1 TCP forward, got %+v", items) + } +} + +func TestReportForwardIngress(t *testing.T) { + a := &App{ + sess: &session.SerialSession{}, + cfg: &appconfig.Config{InputCode: "UTF-8", OutputCode: "UTF-8"}, + uiEvents: make(chan event.UIEvent, 4), + out: io.Discard, + } + a.SetUIEnabled(true) + + a.reportForwardIngress(1, []byte("test")) + a.cfg.InputCode = "hex" + a.reportForwardIngress(2, []byte{0x41, 0x42}) + a.reportForwardIngress(3, nil) +} + +func TestSendCtrl(t *testing.T) { + a := &App{ + sess: &session.SerialSession{}, + cfg: &appconfig.Config{}, + uiEvents: make(chan event.UIEvent, 4), + out: io.Discard, + } + a.sess.Port = &mockSerialPort{} + + if err := a.sendCtrl('c'); err != nil { + t.Fatalf("sendCtrl('c') failed: %v", err) + } + if err := a.sendCtrl('C'); err != nil { + t.Fatalf("sendCtrl('C') failed: %v", err) + } +} + +type mockSerialPort struct{} + +func (m *mockSerialPort) Write(p []byte) (int, error) { return len(p), nil } +func (m *mockSerialPort) Read(p []byte) (int, error) { return 0, io.EOF } +func (m *mockSerialPort) Close() error { return nil } +func (m *mockSerialPort) SetMode(mode *serial.Mode) error { return nil } +func (m *mockSerialPort) SetDTR(dtr bool) error { return nil } +func (m *mockSerialPort) SetRTS(rts bool) error { return nil } +func (m *mockSerialPort) GetModemStatusBits() (*serial.ModemStatusBits, error) { + return &serial.ModemStatusBits{}, nil +} +func (m *mockSerialPort) ResetInputBuffer() error { return nil } +func (m *mockSerialPort) ResetOutputBuffer() error { return nil } +func (m *mockSerialPort) SetReadTimeout(t time.Duration) error { return nil } +func (m *mockSerialPort) Break(t time.Duration) error { return nil } +func (m *mockSerialPort) Drain() error { return nil } + +func mustReadEvent(t *testing.T, ch <-chan event.UIEvent) event.UIEvent { + t.Helper() + select { + case ev := <-ch: + return ev + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for UI event") + return event.UIEvent{} + } +} diff --git a/internal/command/command_test.go b/internal/command/command_test.go new file mode 100644 index 0000000..e7a2f1a --- /dev/null +++ b/internal/command/command_test.go @@ -0,0 +1,72 @@ +package command + +import "testing" + +func TestParseOnOff(t *testing.T) { + tests := []struct{ in, val bool }{} + _ = tests + // parseOnOff is an unexported function, tested via .mode set command integration +} + +func TestCompleteForward(t *testing.T) { + tests := []struct { + args []string + want []string + }{ + {args: []string{".forward"}, want: []string{"list", "add", "remove", "enable", "disable", "update"}}, + {args: []string{".forward", ""}, want: []string{"list", "add", "remove", "enable", "disable", "update"}}, + {args: []string{".forward", "add", ""}, want: []string{"tcp", "udp", "tcp-s", "udp-s", "com"}}, + {args: []string{".forward", "update", "1", ""}, want: []string{"tcp", "udp", "tcp-s", "udp-s", "com"}}, + } + for _, tt := range tests { + got := completeForward(tt.args) + if !stringSlicesEqual(got, tt.want) { + t.Fatalf("completeForward(%v) got=%v want=%v", tt.args, got, tt.want) + } + } +} + +func TestCompletePlugin(t *testing.T) { + tests := []struct { + args []string + want []string + }{ + {args: []string{".plugin"}, want: []string{"list", "load", "unload", "enable", "disable", "reload"}}, + {args: []string{".plugin", "load", ""}, want: nil}, + } + for _, tt := range tests { + got := completePlugin(tt.args) + if !stringSlicesEqual(got, tt.want) { + t.Fatalf("completePlugin(%v) got=%v want=%v", tt.args, got, tt.want) + } + } +} + +func TestCompleteMode(t *testing.T) { + tests := []struct { + args []string + want []string + }{ + {args: []string{".mode"}, want: []string{"show", "set"}}, + {args: []string{".mode", "set", ""}, want: []string{"in", "out", "end", "frame", "timestamp", "timefmt"}}, + {args: []string{".mode", "set", "timestamp", ""}, want: []string{"on", "off"}}, + } + for _, tt := range tests { + got := completeMode(tt.args) + if !stringSlicesEqual(got, tt.want) { + t.Fatalf("completeMode(%v) got=%v want=%v", tt.args, got, tt.want) + } + } +} + +func stringSlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/command/commands.go b/internal/command/commands.go new file mode 100644 index 0000000..006a288 --- /dev/null +++ b/internal/command/commands.go @@ -0,0 +1,227 @@ +package command + +import ( + "fmt" + "strconv" + "strings" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" +) + +func (d *Dispatcher) handleForwardCommand(args []string) error { + if len(args) < 2 { + if d.host.UIEnabled() { + d.host.OpenPanel(event.UIPanelForward) + return nil + } + args = []string{".forward", "list"} + } + + sub := strings.ToLower(args[1]) + switch sub { + case "list", "stats": + if d.host.UIEnabled() { + d.host.OpenPanel(event.UIPanelForward) + return nil + } + + items := d.host.Forward().List() + if len(items) == 0 { + d.host.Notifyf("[forward] empty") + return nil + } + d.host.Notifyf("[forward] ID Mode Enabled Connected Address InBytes OutBytes LastError") + for _, it := range items { + d.host.Notifyf("[forward] %d %s %v %v %s %d %d %s", it.ID, it.Mode, it.Enabled, it.Connected, it.Address, it.ReadBytes, it.WriteByte, it.LastError) + } + return nil + + case "add": + if len(args) < 4 { + return fmt.Errorf("usage: .forward add
") + } + mode, ok := forward.ParseMode(args[2]) + if !ok { + return fmt.Errorf("unknown forward mode: %s", args[2]) + } + id, err := d.host.Forward().Add(mode, args[3]) + if err != nil { + return err + } + d.host.Statusf("[forward] added #%d", id) + return nil + + case "remove", "enable", "disable": + if len(args) < 3 { + return fmt.Errorf("usage: .forward %s ", sub) + } + id, err := strconv.Atoi(args[2]) + if err != nil { + return err + } + switch sub { + case "remove": + return d.host.Forward().Remove(id) + case "enable": + return d.host.Forward().Enable(id) + case "disable": + return d.host.Forward().Disable(id) + } + + case "update": + if len(args) < 5 { + return fmt.Errorf("usage: .forward update
") + } + id, err := strconv.Atoi(args[2]) + if err != nil { + return err + } + mode, ok := forward.ParseMode(args[3]) + if !ok { + return fmt.Errorf("unknown forward mode: %s", args[3]) + } + if err = d.host.Forward().Update(id, mode, args[4]); err != nil { + return err + } + d.host.Statusf("[forward] updated #%d", id) + return nil + } + + return fmt.Errorf("unknown subcommand: %s", sub) +} + +func (d *Dispatcher) handlePluginCommand(args []string) error { + if len(args) < 2 { + if d.host.UIEnabled() { + d.host.OpenPanel(event.UIPanelPlugin) + return nil + } + args = []string{".plugin", "list"} + } + + sub := strings.ToLower(args[1]) + switch sub { + case "list": + if d.host.UIEnabled() { + d.host.OpenPanel(event.UIPanelPlugin) + return nil + } + + items := d.host.Plugins().List() + if len(items) == 0 { + d.host.Notifyf("[plugin] empty") + return nil + } + for _, it := range items { + d.host.Notifyf("[plugin] %s enabled=%v path=%s", it.Name, it.Enabled, it.Path) + } + return nil + + case "load": + if len(args) < 3 { + return fmt.Errorf("usage: .plugin load ") + } + name, err := d.host.Plugins().Load(args[2]) + if err != nil { + return err + } + d.host.Statusf("[plugin] loaded %s", name) + return nil + + case "unload", "enable", "disable", "reload": + if len(args) < 3 { + return fmt.Errorf("usage: .plugin %s ", sub) + } + name := args[2] + switch sub { + case "unload": + return d.host.Plugins().Unload(name) + case "enable": + return d.host.Plugins().Enable(name) + case "disable": + return d.host.Plugins().Disable(name) + case "reload": + return d.host.Plugins().Reload(name) + } + } + + return fmt.Errorf("unknown subcommand: %s", sub) +} + +func (d *Dispatcher) handleModeCommand(args []string) error { + if len(args) < 2 || strings.EqualFold(args[1], "show") { + if d.host.UIEnabled() { + d.host.OpenPanel(event.UIPanelMode) + return nil + } + + cfg := d.host.Cfg() + d.host.Notifyf("[mode] input=%s output=%s end=%q hex=%v frame=%d timestamp=%v timefmt=%q forwardTargets=%d plugins=%d", + cfg.InputCode, cfg.OutputCode, cfg.EndStr, + strings.EqualFold(cfg.InputCode, "hex"), + cfg.FrameSize, cfg.TimesTamp, cfg.TimesFmt, + len(d.host.Forward().List()), len(d.host.Plugins().List()), + ) + return nil + } + + if !strings.EqualFold(args[1], "set") { + return fmt.Errorf("usage: .mode ") + } + if len(args) < 4 { + return fmt.Errorf("usage: .mode set ") + } + + field := strings.ToLower(args[2]) + value := strings.Join(args[3:], " ") + + cfg := d.host.Cfg() + switch field { + case "in": + if value == "" { + return fmt.Errorf("input charset must not be empty") + } + cfg.InputCode = value + case "out": + if value == "" { + return fmt.Errorf("output charset must not be empty") + } + cfg.OutputCode = value + case "end": + cfg.EndStr = value + case "frame": + n, err := strconv.Atoi(value) + if err != nil || n <= 0 { + return fmt.Errorf("frame must be a positive integer") + } + cfg.FrameSize = n + case "timestamp": + enabled, ok := parseOnOff(value) + if !ok { + return fmt.Errorf("timestamp value must be on/off") + } + cfg.TimesTamp = enabled + case "timefmt": + if value == "" && cfg.TimesTamp { + return fmt.Errorf("timestamp format must not be empty") + } + cfg.TimesFmt = value + default: + return fmt.Errorf("unknown mode field: %s", field) + } + + d.host.Statusf("[mode] %s=%q", field, value) + return nil +} + +func parseOnOff(v string) (bool, bool) { + switch strings.ToLower(strings.TrimSpace(v)) { + case "on", "true", "1", "yes": + return true, true + case "off", "false", "0", "no": + return false, true + default: + return false, false + } +} diff --git a/internal/command/complete.go b/internal/command/complete.go new file mode 100644 index 0000000..20db514 --- /dev/null +++ b/internal/command/complete.go @@ -0,0 +1,67 @@ +package command + +import "strings" + +func completeFirstToken(line, token string, cands []string) (string, []string) { + matches := filterPrefix(cands, token) + if len(matches) == 0 { + return line, nil + } + if len(matches) == 1 { + prefix := strings.TrimSuffix(line, token) + return prefix + matches[0] + " ", matches + } + return line, matches +} + +func filterPrefix(cands []string, cur string) []string { + if cur == "" { + return append([]string{}, cands...) + } + res := make([]string, 0, len(cands)) + for _, c := range cands { + if strings.HasPrefix(strings.ToLower(c), strings.ToLower(cur)) { + res = append(res, c) + } + } + return res +} + +func completeForward(args []string) []string { + if len(args) <= 2 { + return []string{"list", "add", "remove", "enable", "disable", "update"} + } + + if len(args) == 3 && args[1] == "add" { + return []string{"tcp", "udp", "tcp-s", "udp-s", "com"} + } + + if len(args) == 4 && args[1] == "update" { + return []string{"tcp", "udp", "tcp-s", "udp-s", "com"} + } + + return nil +} + +func completePlugin(args []string) []string { + if len(args) <= 2 { + return []string{"list", "load", "unload", "enable", "disable", "reload"} + } + return nil +} + +func completeMode(args []string) []string { + if len(args) <= 2 { + return []string{"show", "set"} + } + + if len(args) == 3 && args[1] == "set" { + return []string{"in", "out", "end", "frame", "timestamp", "timefmt"} + } + + if len(args) == 4 && args[1] == "set" && args[2] == "timestamp" { + return []string{"on", "off"} + } + + return nil +} diff --git a/internal/command/dispatcher.go b/internal/command/dispatcher.go new file mode 100644 index 0000000..f18a035 --- /dev/null +++ b/internal/command/dispatcher.go @@ -0,0 +1,217 @@ +package command + +import ( + "encoding/hex" + "fmt" + "sort" + "strings" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/luaplugin" +) + +// CommandHost is the minimal interface the command dispatcher needs from its host. +type CommandHost interface { + Close() + Notifyf(format string, args ...any) + Statusf(format string, args ...any) + ShowModal(title, text string) + OpenPanel(panel event.UIPanelKind) + UIEnabled() bool + WriteToSession(data []byte) error + Forward() *forward.Manager + Plugins() *luaplugin.Manager + Cfg() *config.Config +} + +type CommandHandler func(args []string) error +type CommandCompleter func(args []string) []string + +type RuntimeCommand struct { + Name string + Usage string + Description string + Handler CommandHandler + Completer CommandCompleter +} + +type Dispatcher struct { + host CommandHost + commands map[string]*RuntimeCommand + order []string +} + +func NewDispatcher(host CommandHost) *Dispatcher { + d := &Dispatcher{ + host: host, + commands: make(map[string]*RuntimeCommand), + } + d.registerAll() + return d +} + +func (d *Dispatcher) register(cmd RuntimeCommand) { + key := strings.ToLower(cmd.Name) + d.commands[key] = &cmd + d.order = append(d.order, key) +} + +func (d *Dispatcher) registerAll() { + d.register(RuntimeCommand{ + Name: ".help", + Usage: ".help", + Description: "show command help", + Handler: func(args []string) error { + d.host.ShowModal("Command Help", d.HelpText()) + return nil + }, + }) + + d.register(RuntimeCommand{ + Name: ".exit", + Usage: ".exit", + Description: "exit local terminal", + Handler: func(args []string) error { + d.host.Statusf("[local] exiting") + d.host.Close() + return nil + }, + }) + + d.register(RuntimeCommand{ + Name: ".hex", + Usage: ".hex ", + Description: "send raw hex bytes", + Handler: func(args []string) error { + if len(args) < 2 { + return fmt.Errorf("usage: .hex ") + } + hexStr := strings.Join(args[1:], "") + b, err := hex.DecodeString(hexStr) + if err != nil { + return err + } + return d.host.WriteToSession(b) + }, + }) + + d.register(RuntimeCommand{ + Name: ".forward", + Usage: ".forward ", + Description: "manage forwarding (tcp/udp/tcp-s/udp-s/com)", + Handler: d.handleForwardCommand, + Completer: completeForward, + }) + + d.register(RuntimeCommand{ + Name: ".plugin", + Usage: ".plugin ", + Description: "manage lua plugins", + Handler: d.handlePluginCommand, + Completer: completePlugin, + }) + + d.register(RuntimeCommand{ + Name: ".mode", + Usage: ".mode ", + Description: "show or update runtime terminal mode", + Handler: func(args []string) error { + return d.handleModeCommand(args) + }, + Completer: completeMode, + }) +} + +func (d *Dispatcher) Execute(line string) (bool, error) { + args := strings.Fields(strings.TrimSpace(line)) + if len(args) == 0 { + return false, nil + } + if !strings.HasPrefix(args[0], ".") { + return false, nil + } + + cmd, ok := d.commands[strings.ToLower(args[0])] + if !ok { + return true, fmt.Errorf("unknown command: %s", args[0]) + } + + if err := cmd.Handler(args); err != nil { + return true, err + } + return true, nil +} + +func (d *Dispatcher) HelpText() string { + keys := make([]string, 0, len(d.order)) + for _, k := range d.order { + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + b.WriteString("Commands:\n") + for _, k := range keys { + cmd := d.commands[k] + b.WriteString(fmt.Sprintf(" %-12s %-40s %s\n", cmd.Name, cmd.Usage, cmd.Description)) + } + return b.String() +} + +func (d *Dispatcher) Complete(line string) (string, []string) { + trimmed := strings.TrimLeft(line, " ") + if trimmed == "" { + return line, nil + } + + args := strings.Fields(trimmed) + endsWithSpace := strings.HasSuffix(line, " ") + + if len(args) == 0 { + return line, nil + } + + if len(args) == 1 && !endsWithSpace { + return completeFirstToken(line, args[0], d.commandNames()) + } + + cmdName := strings.ToLower(args[0]) + cmd, ok := d.commands[cmdName] + if !ok || cmd.Completer == nil { + return line, nil + } + + compArgs := args + if endsWithSpace { + compArgs = append(compArgs, "") + } + + cands := cmd.Completer(compArgs) + if len(cands) == 0 { + return line, nil + } + + current := compArgs[len(compArgs)-1] + base := strings.TrimSuffix(line, current) + + matches := filterPrefix(cands, current) + if len(matches) == 0 { + matches = cands + } + if len(matches) == 1 { + return base + matches[0], matches + } + + return line, matches +} + +func (d *Dispatcher) commandNames() []string { + names := make([]string, 0, len(d.commands)) + for _, cmd := range d.commands { + names = append(names, cmd.Name) + } + sort.Strings(names) + return names +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..d465872 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,52 @@ +// Package config holds the application configuration. +package config + +import ( + "fmt" + "os" + "strings" + "time" +) + +// Config holds all application settings. +type Config struct { + PortName string + BaudRate int + DataBits int + StopBits int + ParityBit int + OutputCode string + InputCode string + EndStr string + EnableLog bool + LogFilePath string + ForWard []int + FrameSize int + TimesTamp bool + TimesFmt string + Address []string + EnableGUI bool + HotkeyMod string +} + +// NormalizeHotkey validates and normalizes a hotkey modifier string. +func NormalizeHotkey(mod string) string { + mod = strings.ToLower(strings.TrimSpace(mod)) + if mod != "ctrl+alt" && mod != "ctrl+shift" { + mod = "ctrl+alt" + } + return mod +} + +// OpenLogFile opens the configured log file for writing, or returns nil if logging is disabled. +func OpenLogFile(cfg *Config) (*os.File, error) { + if cfg.EnableLog { + path := fmt.Sprintf(cfg.LogFilePath, cfg.PortName, time.Now().Format("2006_01_02T150405")) + f, err := os.OpenFile(path, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0666) + if err != nil { + return nil, err + } + return f, nil + } + return nil, nil +} diff --git a/internal/console/console.go b/internal/console/console.go new file mode 100644 index 0000000..ba2e86c --- /dev/null +++ b/internal/console/console.go @@ -0,0 +1,347 @@ +// Package console provides the non-TUI console mode. +package console + +import ( + "fmt" + "io" + "os" + "os/signal" + "strconv" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "golang.org/x/term" + + apppkg "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/app" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/flag" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/session" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/tui" +) + +// Run parses flags, sets up the session and app, then runs TUI or console mode. +func Run() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "fatal: %v\n", r) + os.Exit(1) + } + }() + + cfg := &config.Config{} + flag.Init(cfg) + flag.Normalize() + flag.Parse() + flag.Ext(cfg) + if cfg.PortName == "" { + flag.GetCliFlag(cfg) + } + + ports, err := session.CheckPortAvailability(cfg.PortName) + if err != nil { + fmt.Println(err) + flag.PrintUsage(ports) + os.Exit(0) + } + + sess, err := session.Open(cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "open session failed: %v\n", err) + os.Exit(1) + } + + appInst, err := apppkg.New(cfg, sess, os.Stdout) + if err != nil { + fmt.Fprintf(os.Stderr, "create app failed: %v\n", err) + os.Exit(1) + } + defer appInst.Close() + + appInst.LoadConfiguredForwards() + appInst.StartOutputLoop() + + go forwardInterruptToRemote(appInst) + appInst.SetUIEnabled(cfg.EnableGUI) + + if cfg.EnableGUI { + model := tui.New(appInst) + p := tea.NewProgram(model, tea.WithAltScreen(), tea.WithInputTTY(), tea.WithoutSignalHandler()) + if _, err = p.Run(); err != nil { + fmt.Fprintf(os.Stderr, "tui failed: %v\n", err) + os.Exit(1) + } + return + } + + if err = RunConsole(appInst); err != nil { + fmt.Fprintf(os.Stderr, "console failed: %v\n", err) + os.Exit(1) + } +} + +func forwardInterruptToRemote(appInst *apppkg.App) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt) + defer signal.Stop(sigCh) + + for { + select { + case <-appInst.WaitDone(): + return + case <-sigCh: + if err := appInst.SendCtrl('c'); err != nil { + appInst.Notifyf("[signal] interrupt pass-through failed: %v", err) + continue + } + appInst.Notifyf("[signal] Ctrl+C forwarded to remote") + } + } +} + +// RunConsole runs the non-TUI console mode. +func RunConsole(appInst *apppkg.App) error { + fd := int(os.Stdin.Fd()) + isTerm := term.IsTerminal(fd) + var oldState *term.State + var err error + if isTerm { + enableVTInput(fd) + oldState, err = term.MakeRaw(fd) + if err != nil { + return err + } + defer func() { _ = term.Restore(fd, oldState) }() + } + + appInst.Notifyf("[console] non-gui mode, commands start with '.' at line start\n") + appInst.Notifyf("[console] Ctrl+ passes through to remote; .exit to exit") + + ch := make(chan byte, 1024) + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 256) + for { + n, rdErr := os.Stdin.Read(buf) + if rdErr != nil { + errCh <- rdErr + return + } + for i := 0; i < n; i++ { + ch <- buf[i] + } + } + }() + + out := appInst.Out() + cfg := appInst.Cfg() + lineStart := true + commandMode := false + cmdBuf := make([]byte, 0, 128) + + tryRead := func() (byte, bool) { + select { + case b := <-ch: + return b, true + default: + return 0, false + } + } + + readByte := func() (byte, error) { + select { + case <-appInst.WaitDone(): + return 0, io.EOF + case rdErr := <-errCh: + return 0, rdErr + case b := <-ch: + return b, nil + } + } + + flushESC := func(seq []byte) bool { + if isExitHotkeySeq(seq, cfg) { + appInst.Close() + return true + } + if err = appInst.WriteToSession(seq); err != nil { + appInst.Statusf("[send] %v", err) + } + return false + } + + for { + b, rdErr := readByte() + if rdErr != nil { + if rdErr == io.EOF { + return nil + } + return rdErr + } + + if b == 0x1b { + escBuf := []byte{0x1b} + for { + nb, ok := tryRead() + if !ok { + if err = appInst.WriteToSession([]byte{0x1b}); err != nil { + appInst.Statusf("[send] %v", err) + } + break + } + escBuf = append(escBuf, nb) + // 2-byte non-CSI: ESC + letter (not [) + if len(escBuf) == 2 && escBuf[1] != '[' { + if flushESC(escBuf) { + return nil + } + break + } + // CSI terminator: final byte of ESC [ ... sequence + if len(escBuf) > 2 && escBuf[1] == '[' && nb >= 0x40 && nb <= 0x7e { + if flushESC(escBuf) { + return nil + } + break + } + if len(escBuf) > 16 { + if err = appInst.WriteToSession(escBuf); err != nil { + appInst.Statusf("[send] %v", err) + } + break + } + } + continue + } + + if b == 0x00 { + if b2, ok := tryRead(); ok { + if isAltKeyExit(b2, cfg) { + appInst.Close() + return nil + } + if err = appInst.WriteToSession([]byte{0x00, b2}); err != nil { + appInst.Statusf("[send] %v", err) + } + } else { + if err = appInst.WriteToSession([]byte{0x00}); err != nil { + appInst.Statusf("[send] %v", err) + } + } + if commandMode { + lineStart = false + } + continue + } + + if commandMode { + switch b { + case '\r', '\n': + echoConsoleNewline(out) + line := string(cmdBuf) + if strings.TrimSpace(line) != "" { + appInst.HandleLine(line) + } + commandMode = false + cmdBuf = cmdBuf[:0] + lineStart = true + case 0x7f, 0x08: + if len(cmdBuf) > 0 { + cmdBuf = cmdBuf[:len(cmdBuf)-1] + echoConsoleBackspace(out) + } + case 0x09: + line, cands := appInst.Dispatcher().Complete(string(cmdBuf)) + if len(cands) == 1 { + cmdBuf = append(cmdBuf[:0], line...) + echoRedrawCommand(out, line) + } else if len(cands) > 1 { + echoConsoleNewline(out) + appInst.Notifyf("%s", strings.Join(cands, " ")) + echoConsoleByte(out, '.') + echoConsoleString(out, string(cmdBuf[1:])) + } + default: + cmdBuf = append(cmdBuf, b) + echoConsoleByte(out, b) + } + continue + } + + if lineStart && b == '.' { + commandMode = true + cmdBuf = append(cmdBuf[:0], b) + echoConsoleByte(out, b) + continue + } + + if b == '\r' || b == '\n' { + if err = appInst.WriteToSession([]byte(cfg.EndStr)); err != nil { + appInst.Statusf("[send] %v", err) + } + lineStart = true + } else { + if err = appInst.WriteToSession([]byte{b}); err != nil { + appInst.Statusf("[send] %v", err) + } + lineStart = false + } + } +} + +func parseCSIu(seq []byte) (cp int, mod int, ok bool) { + if len(seq) < 6 { + return 0, 0, false + } + if seq[0] != 0x1b || seq[1] != '[' { + return 0, 0, false + } + if seq[len(seq)-1] != 'u' { + return 0, 0, false + } + inner := string(seq[2 : len(seq)-1]) + parts := strings.SplitN(inner, ";", 2) + if len(parts) != 2 { + return 0, 0, false + } + cp, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, false + } + mod, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, false + } + return cp, mod, true +} + +func isAltKeyExit(b byte, cfg *config.Config) bool { + if normalizeHotkey(cfg.HotkeyMod) != "ctrl+alt" { + return false + } + return b == 0x2e || b == 0x03 || b == 0x63 || b == 0x43 +} + +func isExitHotkeySeq(seq []byte, cfg *config.Config) bool { + mod := normalizeHotkey(cfg.HotkeyMod) + if cp, cmod, ok := parseCSIu(seq); ok { + if cp != 'c' && cp != 'C' { + return false + } + switch mod { + case "ctrl+alt": + return cmod&6 == 6 + case "ctrl+shift": + return cmod&5 == 5 + } + return false + } + return false +} + +func normalizeHotkey(mod string) string { return config.NormalizeHotkey(mod) } + +func echoConsoleByte(out io.Writer, b byte) { _, _ = out.Write([]byte{b}) } +func echoConsoleNewline(out io.Writer) { _, _ = io.WriteString(out, "\r\n") } +func echoConsoleBackspace(out io.Writer) { _, _ = io.WriteString(out, "\b \b") } +func echoConsoleString(out io.Writer, s string) { _, _ = io.WriteString(out, s) } +func echoRedrawCommand(out io.Writer, s string) { _, _ = io.WriteString(out, "\r\033[K> "+s) } diff --git a/internal/console/console_other.go b/internal/console/console_other.go new file mode 100644 index 0000000..a5e3874 --- /dev/null +++ b/internal/console/console_other.go @@ -0,0 +1,5 @@ +//go:build !windows + +package console + +func enableVTInput(fd int) {} diff --git a/internal/console/console_test.go b/internal/console/console_test.go new file mode 100644 index 0000000..fda4eb5 --- /dev/null +++ b/internal/console/console_test.go @@ -0,0 +1,83 @@ +package console + +import ( + "testing" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" +) + +func TestParseCSIu(t *testing.T) { + tests := []struct { + name string + seq []byte + cp int + mod int + ok bool + }{ + {name: "ctrl+alt+c lowercase", seq: []byte{0x1b, '[', '9', '9', ';', '6', 'u'}, cp: 99, mod: 6, ok: true}, + {name: "ctrl+shift+c uppercase", seq: []byte{0x1b, '[', '6', '7', ';', '5', 'u'}, cp: 67, mod: 5, ok: true}, + {name: "too short", seq: []byte{0x1b, '[', '9', '9'}, cp: 0, mod: 0, ok: false}, + {name: "no escape prefix", seq: []byte{'[', '9', '9', ';', '6', 'u'}, cp: 0, mod: 0, ok: false}, + {name: "no u terminator", seq: []byte{0x1b, '[', '9', '9', ';', '6', 'x'}, cp: 0, mod: 0, ok: false}, + {name: "bad format no semicolon", seq: []byte{0x1b, '[', '9', '9', '6', 'u'}, cp: 0, mod: 0, ok: false}, + {name: "empty", seq: []byte{}, cp: 0, mod: 0, ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cp, mod, ok := parseCSIu(tt.seq) + if ok != tt.ok || cp != tt.cp || mod != tt.mod { + t.Fatalf("parseCSIu(%v) got=(%d,%d,%v) want=(%d,%d,%v)", tt.seq, cp, mod, ok, tt.cp, tt.mod, tt.ok) + } + }) + } +} + +func TestIsExitHotkeySeq(t *testing.T) { + cfg := &config.Config{HotkeyMod: "ctrl+alt"} + + // CSI u Ctrl+Alt+C (mod=6) + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '6', 'u'}, cfg) { + t.Fatalf("Ctrl+Alt+C CSI should exit with ctrl+alt config") + } + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '7', 'u'}, cfg) { + t.Fatalf("Ctrl+Alt+Shift+C should also exit") + } + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '5', 'u'}, cfg) { + t.Fatalf("Ctrl+Shift+C should NOT exit with ctrl+alt config") + } + if isExitHotkeySeq([]byte{0x1b, '[', '9', '7', ';', '6', 'u'}, cfg) { + t.Fatalf("Ctrl+Alt+A should not exit") + } + if isExitHotkeySeq([]byte{0x1b, 'c'}, cfg) { + t.Fatalf("Alt+C (ESC c) should NOT exit — Ctrl modifier required") + } + + cfg2 := &config.Config{HotkeyMod: "ctrl+shift"} + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '5', 'u'}, cfg2) { + t.Fatalf("Ctrl+Shift+C should exit with ctrl+shift config") + } + if !isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '7', 'u'}, cfg2) { + t.Fatalf("Ctrl+Shift+Alt+C should also exit (includes Ctrl+Shift)") + } + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '6', 'u'}, cfg2) { + t.Fatalf("Ctrl+Alt+C should NOT exit with ctrl+shift config") + } + if isExitHotkeySeq([]byte{0x1b, 'c'}, cfg2) { + t.Fatalf("ESC c should NOT exit with ctrl+shift config") + } + if isExitHotkeySeq([]byte{0x1b, 'x'}, cfg2) { + t.Fatalf("ESC x should not exit") + } + if isExitHotkeySeq([]byte("hello"), cfg2) { + t.Fatalf("plain bytes should not exit") + } + + cfg3 := &config.Config{HotkeyMod: "ctrl+alt"} + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '4', 'u'}, cfg3) { + t.Fatalf("Ctrl+C (without Alt) should not exit") + } + if isExitHotkeySeq([]byte{0x1b, '[', '9', '9', ';', '2', 'u'}, cfg3) { + t.Fatalf("Alt+C (without Ctrl) should not exit") + } +} diff --git a/internal/console/console_windows.go b/internal/console/console_windows.go new file mode 100644 index 0000000..9f38bcc --- /dev/null +++ b/internal/console/console_windows.go @@ -0,0 +1,14 @@ +//go:build windows + +package console + +import "golang.org/x/sys/windows" + +func enableVTInput(fd int) { + var mode uint32 + if err := windows.GetConsoleMode(windows.Handle(fd), &mode); err != nil { + return + } + mode |= windows.ENABLE_VIRTUAL_TERMINAL_INPUT + _ = windows.SetConsoleMode(windows.Handle(fd), mode) +} diff --git a/internal/event/event.go b/internal/event/event.go new file mode 100644 index 0000000..47f82ae --- /dev/null +++ b/internal/event/event.go @@ -0,0 +1,30 @@ +// Package event defines UI event types shared between app, console, and tui packages. +package event + +// UIEventKind classifies a UI event. +type UIEventKind int + +const ( + UIEventOutput UIEventKind = iota + UIEventStatus + UIEventModal + UIEventPanel +) + +// UIPanelKind identifies a modal panel type. +type UIPanelKind int + +const ( + UIPanelNone UIPanelKind = iota + UIPanelForward + UIPanelPlugin + UIPanelMode +) + +// UIEvent is emitted by the app core and consumed by TUI or console frontends. +type UIEvent struct { + Kind UIEventKind + Title string + Text string + Panel UIPanelKind +} diff --git a/internal/flag/flag.go b/internal/flag/flag.go new file mode 100644 index 0000000..4101d2d --- /dev/null +++ b/internal/flag/flag.go @@ -0,0 +1,264 @@ +// Package flag provides CLI flag parsing and interactive configuration. +package flag + +import ( + "fmt" + "log" + "os" + "sort" + "strconv" + "strings" + + "github.com/charmbracelet/bubbles/key" + inf "github.com/fzdwx/infinite" + "github.com/fzdwx/infinite/color" + "github.com/fzdwx/infinite/components" + "github.com/fzdwx/infinite/components/input/text" + "github.com/fzdwx/infinite/components/selection/confirm" + "github.com/fzdwx/infinite/components/selection/singleselect" + "github.com/fzdwx/infinite/style" + "github.com/fzdwx/infinite/theme" + "github.com/spf13/pflag" + "go.bug.st/serial" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" +) + +// Init registers all CLI flags with pflag, binding them to the given config. +func Init(cfg *config.Config) { + pflag.StringVarP(&cfg.PortName, "port", "p", "", "serial port (/dev/ttyUSB0, COMx)") + pflag.IntVarP(&cfg.BaudRate, "baud", "b", 115200, "baud rate") + pflag.IntVarP(&cfg.DataBits, "data", "d", 8, "data bits") + pflag.IntVarP(&cfg.StopBits, "stop", "s", 0, "stop bits (0:1, 1:1.5, 2:2)") + pflag.StringVarP(&cfg.OutputCode, "out", "o", "UTF-8", "output charset") + pflag.StringVarP(&cfg.InputCode, "in", "i", "UTF-8", "input charset") + pflag.StringVarP(&cfg.EndStr, "end", "e", "\n", "line ending") + pflag.IntVarP(&cfg.FrameSize, "Frame", "F", 16, "hex frame size") + pflag.IntVarP(&cfg.ParityBit, "verify", "v", 0, "parity (0:none,1:odd,2:even,3:mark,4:space)") + pflag.BoolVarP(&cfg.EnableGUI, "gui", "g", false, "enable TUI mode") + pflag.StringVarP(&cfg.HotkeyMod, "hotkey-mod", "k", "ctrl+alt", "hotkey modifier (ctrl+alt|ctrl+shift)") + pflag.IntSliceVarP(&cfg.ForWard, "forward", "f", nil, "forward mode (0:none,1:TCP,2:UDP,3:TCP-S,4:UDP-S,5:COM)") + pflag.StringArrayVarP(&cfg.Address, "address", "a", nil, "forward address") + pflag.StringVarP(&cfg.LogFilePath, "log", "l", "", "log file path") + _ = pflag.Lookup("log") // mark for NoOptDefVal + pflag.StringVarP(&cfg.TimesFmt, "time", "t", "", "timestamp format") + _ = pflag.Lookup("time") // mark for NoOptDefVal +} + +// Normalize converts single-dash long flags (e.g. -port) to double-dash (--port). +// Parse wraps pflag.Parse. +func Parse() { pflag.Parse() } + +// Normalize converts single-dash long flags (e.g. -port) to double-dash (--port). +func Normalize() { + known := map[string]bool{ + "port": true, "baud": true, "data": true, "stop": true, + "out": true, "in": true, "end": true, "Frame": true, + "verify": true, "gui": true, "hotkey-mod": true, + "forward": true, "address": true, "log": true, "time": true, + } + for i, arg := range os.Args[1:] { + if strings.HasPrefix(arg, "-") && !strings.HasPrefix(arg, "--") { + name := strings.TrimPrefix(arg, "-") + if known[name] { + os.Args[i+1] = "--" + name + } + } + } +} + +// Ext applies post-parse normalization to config values. +func Ext(cfg *config.Config) { + if cfg.LogFilePath != "" { + cfg.EnableLog = true + } + if cfg.TimesFmt != "" { + cfg.TimesTamp = true + } + if cfg.HotkeyMod == "" { + cfg.HotkeyMod = "ctrl+alt" + } + cfg.HotkeyMod = strings.ToLower(strings.TrimSpace(cfg.HotkeyMod)) + if cfg.HotkeyMod != "ctrl+alt" && cfg.HotkeyMod != "ctrl+shift" { + cfg.HotkeyMod = "ctrl+alt" + } +} + +// PrintUsage displays flag help and available ports. +func PrintUsage(ports []string) { + type flagInfo struct{ short, long, typ, help, def string } + flags := []flagInfo{ + {"-p", "--port", "string", "serial port", ""}, + {"-b", "--baud", "int", "baud rate", "115200"}, + {"-d", "--data", "int", "data bits", "8"}, + {"-s", "--stop", "int", "stop bits", "0"}, + {"-o", "--out", "string", "output charset", "UTF-8"}, + {"-i", "--in", "string", "input charset", "UTF-8"}, + {"-e", "--end", "string", "line ending", "\\n"}, + {"-F", "--Frame", "int", "hex frame size", "16"}, + {"-v", "--verify", "int", "parity", "0"}, + {"-g", "--gui", "bool", "enable TUI", "false"}, + {"-k", "--hotkey-mod", "string", "hotkey modifier", "ctrl+alt"}, + {"-f", "--forward", "[]int", "forward (0:none,1:TCP,2:UDP,3:TCP-S,4:UDP-S,5:COM)", "0"}, + {"-a", "--address", "[]string", "forward address", "127.0.0.1:12345"}, + {"-l", "--log", "string", "log path (%s=port, then timestamp)", "./%s-%s.log"}, + {"-t", "--time", "string", "timestamp format", "[06-01-02 15:04:05.000]"}, + } + sort.Slice(flags, func(i, j int) bool { return flags[i].long < flags[j].long }) + + fmt.Printf("\nFlags:\n") + fmt.Printf(" %-6s %-14s %-8s %-44s %s\n", "Short", "Long", "Type", "Help", "Default") + fmt.Printf(" %-6s %-14s %-8s %-44s %s\n", "------", "------", "------", "------", "------") + for _, f := range flags { + fmt.Printf(" %-6s %-14s %-8s %-44s %q\n", f.short, f.long, f.typ, f.help, f.def) + } + fmt.Printf("\nAvailable ports: %v\n", strings.Join(ports, ", ")) +} + +var ( + bauds = []string{"Custom", "300", "600", "1200", "2400", "4800", "9600", + "14400", "19200", "38400", "56000", "57600", "115200", "128000", + "256000", "460800", "512000", "750000", "921600", "1500000"} + datas = []string{"5", "6", "7", "8"} + stops = []string{"1", "1.5", "2"} + paritys = []string{"None", "Odd", "Even", "Mark", "Space"} + forwards = []string{"No", "TCP-C", "UDP-C", "TCP-S", "UDP-S", "COM"} +) + +// GetCliFlag runs an interactive configuration wizard when no port is specified. +func GetCliFlag(cfg *config.Config) { + ports, err := serial.GetPortsList() + if err != nil { + log.Fatal(err) + } + + inputs := components.NewInput() + inputs.Prompt = "Filtering: " + inputs.PromptStyle = style.New().Bold().Italic().Fg(color.LightBlue) + + selectKeymap := singleselect.DefaultSingleKeyMap() + selectKeymap.Confirm = key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "finish select")) + selectKeymap.Choice = key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "finish select")) + selectKeymap.NextPage = key.NewBinding(key.WithKeys("right"), key.WithHelp("->", "next page")) + selectKeymap.PrevPage = key.NewBinding(key.WithKeys("left"), key.WithHelp("<-", "prev page")) + + s, _ := inf.NewSingleSelect(ports, + singleselect.WithKeyBinding(selectKeymap), + singleselect.WithPageSize(4), + singleselect.WithFilterInput(inputs), + ).Display("Select serial port") + cfg.PortName = ports[s] + + s, _ = inf.NewSingleSelect(bauds, + singleselect.WithKeyBinding(selectKeymap), + singleselect.WithPageSize(4), + ).Display("Select baud rate") + if s != 0 { + cfg.BaudRate, _ = strconv.Atoi(bauds[s]) + } else { + b, _ := inf.NewText( + text.WithPrompt("BaudRate:"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("115200"), + ).Display() + cfg.BaudRate, _ = strconv.Atoi(b) + } + + v, _ := inf.NewConfirmWithSelection(confirm.WithPrompt("Enable Hex")).Display() + if v { + cfg.InputCode = "hex" + b, _ := inf.NewText( + text.WithPrompt("Frames:"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("16"), + ).Display() + cfg.FrameSize, _ = strconv.Atoi(b) + } + + v, _ = inf.NewConfirmWithSelection(confirm.WithPrompt("Enable Timestamp")).Display() + cfg.TimesTamp = v + if v { + b, _ := inf.NewText( + text.WithPrompt("Format:"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("[06-01-02 15:04:05.000]"), + ).Display() + cfg.TimesFmt = b + } + + v, _ = inf.NewConfirmWithSelection(confirm.WithPrompt("Enable advanced config")).Display() + if v { + s, _ = inf.NewSingleSelect(datas, + singleselect.WithKeyBinding(selectKeymap), + singleselect.WithPageSize(4), + singleselect.WithFilterInput(inputs), + ).Display("Select data bits") + cfg.DataBits, _ = strconv.Atoi(datas[s]) + + s, _ = inf.NewSingleSelect(stops, + singleselect.WithKeyBinding(selectKeymap), + singleselect.WithPageSize(4), + singleselect.WithFilterInput(inputs), + ).Display("Select stop bits") + cfg.StopBits = s + + s, _ = inf.NewSingleSelect(paritys, + singleselect.WithKeyBinding(selectKeymap), + singleselect.WithPageSize(4), + singleselect.WithFilterInput(inputs), + ).Display("Select parity") + cfg.ParityBit = s + + t, _ := inf.NewText( + text.WithPrompt("Line ending:"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("\n"), + ).Display() + cfg.EndStr = t + + v, _ = inf.NewConfirmWithSelection(confirm.WithDefaultYes(), confirm.WithPrompt("Enable charset conversion")).Display() + if v { + t, _ = inf.NewText( + text.WithPrompt("Input charset:"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("UTF-8"), + ).Display() + cfg.InputCode = t + + t, _ = inf.NewText( + text.WithPrompt("Output charset:"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("UTF-8"), + ).Display() + cfg.OutputCode = t + } + + G_F_mode: + s, _ = inf.NewSingleSelect(forwards, + singleselect.WithKeyBinding(selectKeymap), + singleselect.WithPageSize(3), + singleselect.WithFilterInput(inputs), + ).Display("Select forward mode") + if s != 0 { + cfg.ForWard = append(cfg.ForWard, s) + t, _ = inf.NewText( + text.WithPrompt("Address:"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("127.0.0.1:12345"), + ).Display() + cfg.Address = append(cfg.Address, t) + goto G_F_mode + } + + e, _ := inf.NewConfirmWithSelection(confirm.WithDefaultYes(), confirm.WithPrompt("Enable logging")).Display() + cfg.EnableLog = e + if e { + t, _ = inf.NewText( + text.WithPrompt("Path(%s=port, then stamp):"), + text.WithPromptStyle(theme.DefaultTheme.PromptStyle), + text.WithDefaultValue("./%s-%s.log"), + ).Display() + cfg.LogFilePath = t + } + } +} diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..165e76a --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,138 @@ +// Package session manages the serial port connection and its associated pipes. +package session + +import ( + "fmt" + "io" + "os" + "os/signal" + "runtime" + "sync" + + "github.com/trzsz/trzsz-go/trzsz" + "go.bug.st/serial" + "golang.org/x/term" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" +) + +// SerialSession owns the serial port, trzsz filter, and pipe pair. +type SerialSession struct { + Port serial.Port + TrzszFilter *trzsz.TrzszFilter + StdinPipe *io.PipeWriter + StdoutPipe *io.PipeReader + ClientIn *io.PipeReader + ClientOut *io.PipeWriter + + termCh chan os.Signal + closeOnce sync.Once +} + +// Open creates a SerialSession by opening the serial port and initializing trzsz. +func Open(cfg *config.Config) (*SerialSession, error) { + mode := &serial.Mode{ + BaudRate: cfg.BaudRate, + StopBits: serial.StopBits(cfg.StopBits), + DataBits: cfg.DataBits, + Parity: serial.Parity(cfg.ParityBit), + } + port, err := serial.Open(cfg.PortName, mode) + if err != nil { + return nil, err + } + + fd := int(os.Stdin.Fd()) + width, _, err := term.GetSize(fd) + if err != nil { + if runtime.GOOS != "windows" { + port.Close() + return nil, fmt.Errorf("term get size failed: %w", err) + } + width = 80 + } + + clientIn, stdinPipe := io.Pipe() + stdoutPipe, clientOut := io.Pipe() + trzszFilter := trzsz.NewTrzszFilter(clientIn, clientOut, port, port, + trzsz.TrzszOptions{TerminalColumns: int32(width), EnableZmodem: true}) + trzsz.SetAffectedByWindows(false) + + s := &SerialSession{ + Port: port, + TrzszFilter: trzszFilter, + StdinPipe: stdinPipe, + StdoutPipe: stdoutPipe, + ClientIn: clientIn, + ClientOut: clientOut, + termCh: make(chan os.Signal, 1), + } + + go func() { + for range s.termCh { + w, _, err := term.GetSize(fd) + if err != nil { + fmt.Printf("term get size failed: %s\n", err) + continue + } + trzszFilter.SetTerminalColumns(int32(w)) + } + }() + + return s, nil +} + +// Write writes data to the stdin pipe (toward serial port, through trzsz). +func (s *SerialSession) Write(data []byte) (int, error) { + return s.StdinPipe.Write(data) +} + +// Read reads data from the stdout pipe (from serial port, through trzsz). +func (s *SerialSession) Read(buf []byte) (int, error) { + return s.StdoutPipe.Read(buf) +} + +// SendCtrl sends a control character directly to the serial port (bypasses trzsz). +func (s *SerialSession) SendCtrl(letter byte) (int, error) { + if letter >= 'A' && letter <= 'Z' { + letter = letter + ('a' - 'A') + } + control := []byte{letter & 0x1f} + return s.Port.Write(control) +} + +// Close tears down the session: stops term signals, closes trzsz, then serial port. +func (s *SerialSession) Close() { + s.closeOnce.Do(func() { + if s.termCh != nil { + signal.Stop(s.termCh) + close(s.termCh) + } + if s.Port != nil { + if err := s.Port.Close(); err != nil { + fmt.Fprint(os.Stderr, err) + fmt.Fprint(os.Stderr, "\n") + } + } + }) +} + +// CheckPortAvailability returns the list of available ports and verifies the named port exists. +func CheckPortAvailability(name string) ([]string, error) { + ports, err := serial.GetPortsList() + if err != nil { + return nil, err + } + if len(ports) == 0 { + return nil, fmt.Errorf("no serial ports found") + } + if name == "" { + return ports, fmt.Errorf("port name not specified") + } + for _, port := range ports { + if port == name { + return ports, nil + } + } + return ports, fmt.Errorf("port " + name + " is not available") +} diff --git a/internal/tui/hotkeys.go b/internal/tui/hotkeys.go new file mode 100644 index 0000000..c4038bb --- /dev/null +++ b/internal/tui/hotkeys.go @@ -0,0 +1,181 @@ +package tui + +import ( + "strconv" + "strings" + + tea "github.com/charmbracelet/bubbletea" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/config" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" +) + +func handleLocalHotkey(m *Model, key string) bool { + if m.isLocalHotkey(key, "h") { + modifier := strings.ToUpper(normalizeHotkeyPrefix(m.App.Cfg().HotkeyMod)) + m.App.ShowModal("Shortcuts", modifier+"+C => local exit\nCtrl+C => remote interrupt\n"+modifier+"+F => forward panel\n"+modifier+"+P => plugin panel\n"+modifier+"+M => mode panel\nF1 => shortcut help") + return true + } + if m.isLocalHotkey(key, "f") { + m.App.OpenPanel(event.UIPanelForward) + return true + } + if m.isLocalHotkey(key, "p") { + m.App.OpenPanel(event.UIPanelPlugin) + return true + } + if m.isLocalHotkey(key, "m") { + m.App.OpenPanel(event.UIPanelMode) + return true + } + return false +} + +func (m *Model) isLocalHotkey(key, action string) bool { + parts := strings.Split(strings.ToLower(key), "+") + if len(parts) < 2 || parts[len(parts)-1] != action { + return false + } + + hasCtrl := false + hasAlt := false + hasShift := false + for _, p := range parts[:len(parts)-1] { + switch p { + case "ctrl": + hasCtrl = true + case "alt": + hasAlt = true + case "shift": + hasShift = true + } + } + + mod := normalizeHotkeyPrefix(m.App.Cfg().HotkeyMod) + if mod == "ctrl+shift" { + return hasCtrl && hasShift + } + return hasCtrl && hasAlt +} + +func normalizeHotkeyPrefix(mod string) string { return config.NormalizeHotkey(mod) } + +func hotkeyWith(mod, action string) string { + return normalizeHotkeyPrefix(mod) + "+" + action +} + +func parseCtrlKey(key string) (byte, bool) { + if !strings.HasPrefix(key, "ctrl+") || strings.HasPrefix(key, "ctrl+shift+") { + return 0, false + } + + parts := strings.Split(key, "+") + if len(parts) != 2 || len(parts[1]) != 1 { + return 0, false + } + ch := parts[1][0] + if ch < 'a' || ch > 'z' { + return 0, false + } + return ch, true +} + +func (m *Model) handleViewportKey(msg tea.KeyMsg) bool { + if !m.ready || m.showModal { + return false + } + + key := strings.ToLower(msg.String()) + switch key { + case "pgup", "ctrl+u", "alt+up", "up": + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + _ = cmd + m.followTail = false + return true + case "pgdown", "ctrl+d", "alt+down", "down": + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + _ = cmd + return true + case "home": + m.viewport.GotoTop() + m.followTail = false + return true + case "end": + m.viewport.GotoBottom() + m.followTail = true + return true + default: + return false + } +} + +func (m *Model) resetCompletion() { + m.completionActive = false + m.completionBase = "" + m.completionCandidates = nil + m.completionIndex = 0 +} + +func (m *Model) stepCompletion(direction int) { + if len(m.completionCandidates) == 0 { + m.resetCompletion() + return + } + if direction >= 0 { + m.completionIndex = (m.completionIndex + 1) % len(m.completionCandidates) + } else { + m.completionIndex = (m.completionIndex - 1 + len(m.completionCandidates)) % len(m.completionCandidates) + } + m.applyCompletion() +} + +func (m *Model) applyCompletion() { + if len(m.completionCandidates) == 0 { + return + } + m.input.SetValue(m.completionBase + m.completionCandidates[m.completionIndex] + " ") +} + +func completionBase(line string) string { + if strings.HasSuffix(line, " ") { + return line + } + i := strings.LastIndex(line, " ") + if i < 0 { + return "" + } + return line[:i+1] +} + +func parseCSIuBytes(b []byte) (string, bool) { + s := string(b) + if !strings.HasPrefix(s, "\x1b[") || !strings.HasSuffix(s, "u") { + return "", false + } + inner := s[2 : len(s)-1] + parts := strings.SplitN(inner, ";", 2) + if len(parts) != 2 { + return "", false + } + cp, err := strconv.Atoi(parts[0]) + if err != nil || cp < 'a' || cp > 'z' { + return "", false + } + mod, err := strconv.Atoi(parts[1]) + if err != nil { + return "", false + } + var seq []string + if mod&4 != 0 { + seq = append(seq, "ctrl") + } + if mod&2 != 0 { + seq = append(seq, "alt") + } + if mod&1 != 0 { + seq = append(seq, "shift") + } + seq = append(seq, string(rune(cp))) + return strings.Join(seq, "+"), true +} diff --git a/internal/tui/model.go b/internal/tui/model.go new file mode 100644 index 0000000..c43e3fa --- /dev/null +++ b/internal/tui/model.go @@ -0,0 +1,308 @@ +package tui + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbles/textinput" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/app" + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/luaplugin" +) + +type doneMsg struct{} + +type modeItem struct { + key string + label string + value string + rawValue string +} + +type panelLine struct { + text string + selected bool +} + +type Model struct { + App *app.App + + viewport viewport.Model + input textinput.Model + + ready bool + width int + height int + statusLine string + suggestions []string + content strings.Builder + followTail bool + + showModal bool + modalTitle string + modalBody string + + panelKind event.UIPanelKind + panelIndex int + panelError string + + forwardItems []forward.Snapshot + pluginItems []luaplugin.Snapshot + modeItems []modeItem + + promptActive bool + promptTitle string + promptHint string + promptInput textinput.Model + promptSubmit func(string) + + formActive bool + formTitle string + formFields []textinput.Model + formLabels []string + formFocus int + formSubmit func([]string) + + completionActive bool + completionBase string + completionCandidates []string + completionIndex int +} + +func New(application *app.App) *Model { + in := textinput.New() + in.Placeholder = "Type to send to remote, use .help for commands" + in.Focus() + in.CharLimit = 0 + in.Prompt = "> " + in.Width = 80 + + return &Model{App: application, input: in, followTail: true} +} + +func (m *Model) Init() tea.Cmd { + return tea.Batch(waitUIEvent(m.App.UIEvents()), waitDone(m.App.WaitDone()), textinput.Blink) +} + +func waitUIEvent(ch <-chan event.UIEvent) tea.Cmd { + return func() tea.Msg { + ev, ok := <-ch + if !ok { + return doneMsg{} + } + return ev + } +} + +func waitDone(ch <-chan struct{}) tea.Cmd { + return func() tea.Msg { + <-ch + return doneMsg{} + } +} + +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case doneMsg: + return m, tea.Quit + + case event.UIEvent: + switch msg.Kind { + case event.UIEventOutput, event.UIEventStatus: + if msg.Kind == event.UIEventOutput { + m.appendOutput(msg.Text) + } else { + m.statusLine = msg.Text + } + case event.UIEventModal: + m.showModal = true + m.panelKind = event.UIPanelNone + m.modalTitle = msg.Title + m.modalBody = msg.Text + m.promptActive = false + case event.UIEventPanel: + m.openPanel(msg.Panel) + } + return m, waitUIEvent(m.App.UIEvents()) + + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + inputHeight := 3 + statusHeight := 2 + viewportHeight := msg.Height - inputHeight - statusHeight + if viewportHeight < 3 { + viewportHeight = 3 + } + + if !m.ready { + m.viewport = viewport.New(msg.Width, viewportHeight) + m.viewport.YPosition = 0 + m.viewport.SetContent(m.content.String()) + m.ready = true + } else { + m.viewport.Width = msg.Width + m.viewport.Height = viewportHeight + } + + m.input.Width = msg.Width - 4 + m.viewport.GotoBottom() + m.followTail = true + return m, nil + + case tea.KeyMsg: + keyStr := strings.ToLower(msg.String()) + if m.handleViewportKey(msg) { + return m, nil + } + if keyStr != "tab" && keyStr != "shift+tab" { + m.resetCompletion() + } + + if m.showModal { + handled, cmd := m.handleModalKey(msg) + if handled { + return m, cmd + } + } + + if m.isLocalHotkey(keyStr, "c") { + m.App.Statusf("[local] exiting by %s+C", strings.ToUpper(normalizeHotkeyPrefix(m.App.Cfg().HotkeyMod))) + m.App.Close() + return m, tea.Quit + } + + if handleLocalHotkey(m, keyStr) { + return m, nil + } + + if keyStr == "ctrl+h" { + handleLocalHotkey(m, hotkeyWith(m.App.Cfg().HotkeyMod, "h")) + return m, nil + } + + if letter, ok := parseCtrlKey(keyStr); ok { + if err := m.App.SendCtrl(letter); err != nil { + m.App.Notifyf("[remote] ctrl send failed: %v", err) + } + return m, nil + } + + switch keyStr { + case "f1": + handleLocalHotkey(m, hotkeyWith(m.App.Cfg().HotkeyMod, "h")) + return m, nil + + case "tab", "shift+tab": + direction := 1 + if keyStr == "shift+tab" { + direction = -1 + } + + if m.completionActive && len(m.completionCandidates) > 0 { + m.stepCompletion(direction) + return m, nil + } + + line, cands := m.App.Dispatcher().Complete(m.input.Value()) + m.suggestions = cands + if len(cands) == 0 { + return m, nil + } + if len(cands) == 1 { + m.input.SetValue(line) + return m, nil + } + + m.completionActive = true + m.completionBase = completionBase(m.input.Value()) + m.completionCandidates = append([]string(nil), cands...) + if direction < 0 { + m.completionIndex = len(cands) - 1 + } else { + m.completionIndex = 0 + } + m.applyCompletion() + return m, nil + + case "enter": + line := m.input.Value() + m.input.SetValue("") + m.suggestions = nil + m.followTail = true + m.App.HandleLine(line) + return m, nil + } + } + + + // Handle CSI u sequences that bubbletea does not parse into KeyMsg + if b, ok := msg.([]byte); ok { + if key, ok2 := parseCSIuBytes(b); ok2 { + keyStr := strings.ToLower(key) + if m.showModal { + last := rune(key[len(key)-1]) + fake := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{last}, Alt: strings.Contains(key, "alt+")} + if handled, _ := m.handleModalKey(fake); handled { + return m, nil + } + } + if keyStr == normalizeHotkeyPrefix(m.App.Cfg().HotkeyMod)+"+c" { + m.App.Close() + return m, tea.Quit + } + if handleLocalHotkey(m, keyStr) { + return m, nil + } + } + } + + var cmd tea.Cmd + m.input, cmd = m.input.Update(msg) + return m, cmd +} + +func (m *Model) View() string { + if !m.ready { + return "Initializing..." + } + + suggest := "Tab: no candidates" + if len(m.suggestions) > 1 { + suggest = "Tab candidates: " + strings.Join(m.suggestions, " ") + } else if len(m.suggestions) == 1 { + suggest = "Tab: " + m.suggestions[0] + } + modifier := strings.ToUpper(normalizeHotkeyPrefix(m.App.Cfg().HotkeyMod)) + hotkeys := "Hotkeys: Ctrl+C remote | " + modifier + "+C local | " + modifier + "+F forward | " + modifier + "+P plugins | " + modifier + "+M mode | F1 help" + hotkeys = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("244")).Render(hotkeys) + status := m.statusLine + if status == "" { + status = "Ready" + } + status = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("255")).Render(status) + suggest = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("39")).Render(suggest) + base := fmt.Sprintf("%s\n%s\n%s\n%s\n%s", m.viewport.View(), suggest, status, m.input.View(), hotkeys) + if !m.showModal { + return fillScreen(m.width, m.height, base) + } + + if m.formActive { + return renderCenteredModalContent(m.width, m.height, m.renderForm()) + } + + if m.promptActive { + return renderCenteredModalContent(m.width, m.height, m.renderPrompt()) + } + + if m.panelKind != event.UIPanelNone { + return renderCenteredModalContent(m.width, m.height, m.renderPanel()) + } + + return renderCenteredModal(m.width, m.height, m.modalTitle, m.modalBody) +} diff --git a/internal/tui/panels.go b/internal/tui/panels.go new file mode 100644 index 0000000..2ac00f6 --- /dev/null +++ b/internal/tui/panels.go @@ -0,0 +1,504 @@ +package tui + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + + "github.com/jixishi/SerialTerminalForWindowsTerminal/internal/event" + "github.com/jixishi/SerialTerminalForWindowsTerminal/pkg/forward" +) + +func (m *Model) handleModalKey(msg tea.KeyMsg) (bool, tea.Cmd) { + keyStr := strings.ToLower(msg.String()) + + if m.formActive { + return m.handleFormKey(msg) + } + if m.promptActive { + return m.handlePromptKey(msg) + } + if keyStr == "esc" { + m.closeModal() + return true, nil + } + if m.panelKind == event.UIPanelNone { + if keyStr == "enter" { + m.closeModal() + } + return true, nil + } + + switch m.panelKind { + case event.UIPanelForward: + return m.handleForwardPanelKey(keyStr), nil + case event.UIPanelPlugin: + return m.handlePluginPanelKey(keyStr), nil + case event.UIPanelMode: + return m.handleModePanelKey(keyStr), nil + default: + return true, nil + } +} + +func (m *Model) closeModal() { + m.showModal = false + m.panelKind = event.UIPanelNone + m.modalTitle = "" + m.modalBody = "" + m.promptActive = false + m.promptSubmit = nil + m.formActive = false + m.formSubmit = nil + m.panelError = "" +} + +func (m *Model) openPanel(kind event.UIPanelKind) { + m.showModal = true + m.panelKind = kind + m.panelIndex = 0 + m.promptActive = false + m.promptSubmit = nil + m.panelError = "" + m.refreshPanel() +} + +func (m *Model) refreshPanel() { + switch m.panelKind { + case event.UIPanelForward: + m.forwardItems = m.App.Forward().List() + m.panelIndex = clampIndex(m.panelIndex, len(m.forwardItems)) + case event.UIPanelPlugin: + m.pluginItems = m.App.Plugins().List() + m.panelIndex = clampIndex(m.panelIndex, len(m.pluginItems)) + case event.UIPanelMode: + m.modeItems = m.buildModeItems() + m.panelIndex = clampIndex(m.panelIndex, len(m.modeItems)) + } +} + +func (m *Model) buildModeItems() []modeItem { + cfg := m.App.Cfg() + return []modeItem{ + {"in", "Input Charset", cfg.InputCode, cfg.InputCode}, + {"out", "Output Charset", cfg.OutputCode, cfg.OutputCode}, + {"end", "Line End", fmt.Sprintf("%q", cfg.EndStr), cfg.EndStr}, + {"frame", "Hex Frame Size", fmt.Sprintf("%d", cfg.FrameSize), fmt.Sprintf("%d", cfg.FrameSize)}, + {"timestamp", "Timestamp", fmt.Sprintf("%v", cfg.TimesTamp), fmt.Sprintf("%v", cfg.TimesTamp)}, + {"timefmt", "Timestamp Format", cfg.TimesFmt, cfg.TimesFmt}, + } +} + +// Forward modes for tab cycling +var forwardModes = []string{"tcp", "udp", "tcp-s", "udp-s", "com"} + +func (m *Model) handleForwardPanelKey(key string) bool { + switch key { + case "up", "k": + if m.panelIndex > 0 { + m.panelIndex-- + } + return true + case "down", "j": + if m.panelIndex < len(m.forwardItems)-1 { + m.panelIndex++ + } + return true + case "r": + m.panelError = "" + m.refreshPanel() + return true + case "a": + m.startForwardForm("Add", "tcp", "") + return true + } + if len(m.forwardItems) == 0 { + return true + } + + sel := m.forwardItems[m.panelIndex] + switch key { + case "enter": + if sel.Enabled { + _ = m.App.Forward().Disable(sel.ID) + } else { + _ = m.App.Forward().Enable(sel.ID) + } + m.panelError = "" + m.refreshPanel() + return true + case "d", "delete": + m.startPrompt("Remove Forward #"+fmt.Sprint(sel.ID), "type 'y' to confirm", "", func(v string) { + if strings.TrimSpace(strings.ToLower(v)) == "y" { + if err := m.App.Forward().Remove(sel.ID); err != nil { + m.panelError = err.Error() + } else { + m.panelError = "" + m.refreshPanel() + } + } + }) + return true + case "u": + m.startForwardForm("Update #"+fmt.Sprint(sel.ID), sel.Mode, sel.Address) + return true + default: + return false + } +} + +func (m *Model) startForwardForm(title, mode, address string) { + modeIn := textinput.New() + modeIn.Prompt = " Type: " + modeIn.Placeholder = "Tab to cycle modes" + modeIn.SetValue(mode) + modeIn.CharLimit = 10 + modeIn.Width = 36 + + addrIn := textinput.New() + addrIn.Prompt = " Addr: " + addrIn.Placeholder = "host:port or COM port" + addrIn.SetValue(address) + addrIn.CharLimit = 60 + addrIn.Width = 36 + + m.formActive = true + m.formTitle = title + m.formLabels = []string{"Type (Tab cycle)", "Address"} + m.formFields = []textinput.Model{modeIn, addrIn} + m.formFocus = 0 + m.formFields[0].Focus() + + m.formSubmit = func(vals []string) { + modeStr := strings.TrimSpace(vals[0]) + addrStr := strings.TrimSpace(vals[1]) + + fm, ok := forward.ParseMode(modeStr) + if !ok { + m.panelError = "unknown mode: " + modeStr + return + } + if addrStr == "" { + m.panelError = "address is required" + return + } + + if strings.HasPrefix(title, "Add") { + if _, err := m.App.Forward().Add(fm, addrStr); err != nil { + m.panelError = err.Error() + return + } + } else { + sel := m.forwardItems[m.panelIndex] + if err := m.App.Forward().Update(sel.ID, fm, addrStr); err != nil { + m.panelError = err.Error() + return + } + } + m.panelError = "" + m.refreshPanel() + } +} + +func (m *Model) handlePluginPanelKey(key string) bool { + switch key { + case "up", "k": + if m.panelIndex > 0 { + m.panelIndex-- + } + return true + case "down", "j": + if m.panelIndex < len(m.pluginItems)-1 { + m.panelIndex++ + } + return true + case "r": + m.panelError = "" + m.refreshPanel() + return true + case "l": + m.startPrompt("Load Plugin", "./plugins/demo.lua", "", func(v string) { + path := strings.TrimSpace(v) + if path == "" { + m.panelError = "load path is empty" + return + } + if _, err := m.App.Plugins().Load(path); err != nil { + m.panelError = err.Error() + } else { + m.panelError = "" + m.refreshPanel() + } + }) + return true + } + if len(m.pluginItems) == 0 { + return true + } + + sel := m.pluginItems[m.panelIndex] + switch key { + case "enter": + if sel.Enabled { + _ = m.App.Plugins().Disable(sel.Name) + } else { + _ = m.App.Plugins().Enable(sel.Name) + } + m.panelError = "" + m.refreshPanel() + return true + case "u": + if err := m.App.Plugins().Reload(sel.Name); err != nil { + m.panelError = err.Error() + } else { + m.panelError = "" + m.refreshPanel() + } + return true + case "d", "delete": + m.startPrompt("Unload Plugin "+sel.Name, "type 'y' to confirm", "", func(v string) { + if strings.TrimSpace(strings.ToLower(v)) == "y" { + if err := m.App.Plugins().Unload(sel.Name); err != nil { + m.panelError = err.Error() + } else { + m.panelError = "" + m.refreshPanel() + } + } + }) + return true + default: + return false + } +} + +func (m *Model) handleModePanelKey(key string) bool { + switch key { + case "up", "k": + if m.panelIndex > 0 { + m.panelIndex-- + } + return true + case "down", "j": + if m.panelIndex < len(m.modeItems)-1 { + m.panelIndex++ + } + return true + case "r": + m.panelError = "" + m.refreshPanel() + return true + } + if len(m.modeItems) == 0 { + return true + } + + sel := m.modeItems[m.panelIndex] + cfg := m.App.Cfg() + switch key { + case " ": + if sel.key == "timestamp" { + cfg.TimesTamp = !cfg.TimesTamp + m.refreshPanel() + } + return true + case "enter", "e": + hint := "enter value" + switch sel.key { + case "timestamp": + hint = "on/off" + case "frame": + hint = "positive integer" + case "in", "out": + hint = "charset name (e.g. utf-8, gbk)" + } + initial := sel.rawValue + m.startPrompt("Edit Mode: "+sel.label, hint, initial, func(v string) { + m.App.HandleLine(fmt.Sprintf(".mode set %s %s", sel.key, v)) + m.refreshPanel() + }) + return true + default: + return false + } +} + +func (m *Model) startPrompt(title, hint, initial string, submit func(string)) { + in := textinput.New() + in.Prompt = "> " + in.Placeholder = hint + in.SetValue(initial) + in.Focus() + in.CharLimit = 0 + in.Width = 64 + + m.promptActive = true + m.promptTitle = title + m.promptHint = hint + m.promptInput = in + m.promptSubmit = submit +} + +// --- Form methods (multi-field input) --- + +func (m *Model) handleFormKey(msg tea.KeyMsg) (bool, tea.Cmd) { + key := strings.ToLower(msg.String()) + switch key { + case "esc": + m.formActive = false + m.formSubmit = nil + return true, nil + case "tab": + m.formFields[m.formFocus].Blur() + m.formFocus = (m.formFocus + 1) % len(m.formFields) + + // Cycle forward mode on Tab when type field is focused + if m.formFocus == 0 { + cur := strings.TrimSpace(m.formFields[0].Value()) + idx := -1 + for i, m := range forwardModes { + if m == cur { + idx = i + break + } + } + idx = (idx + 1) % len(forwardModes) + m.formFields[0].SetValue(forwardModes[idx]) + } + m.formFields[m.formFocus].Focus() + return true, nil + case "shift+tab": + m.formFields[m.formFocus].Blur() + m.formFocus = (m.formFocus - 1 + len(m.formFields)) % len(m.formFields) + if m.formFocus == 0 { + cur := strings.TrimSpace(m.formFields[0].Value()) + idx := -1 + for i, m := range forwardModes { + if m == cur { idx = i; break } + } + idx = (idx - 1 + len(forwardModes)) % len(forwardModes) + m.formFields[0].SetValue(forwardModes[idx]) + } + m.formFields[m.formFocus].Focus() + return true, nil + case "enter": + vals := make([]string, len(m.formFields)) + for i, f := range m.formFields { + vals[i] = f.Value() + } + submit := m.formSubmit + m.formActive = false + m.formSubmit = nil + if submit != nil { + submit(vals) + } + return true, nil + default: + var cmd tea.Cmd + m.formFields[m.formFocus], cmd = m.formFields[m.formFocus].Update(msg) + return true, cmd + } +} + +func (m *Model) renderForm() string { + lines := make([]boxLine, 0, len(m.formFields)+2) + for i, f := range m.formFields { + prefix := " " + if i == m.formFocus { + prefix = "▸ " + } + lines = append(lines, boxLine{ + text: prefix + f.View(), + style: modalBodyLineStyle(), + }) + } + footer := "Tab cycles Type | Enter submit | Esc cancel" + if len(m.formFields) > 1 { + footer = "Tab/Shift+Tab switch | Enter submit | Esc cancel" + } + lines = append(lines, boxLine{text: footer, style: modalFooterLineStyle()}) + return renderBox(m.formTitle, lines, 36, m.availableModalWidth()) +} + +func (m *Model) handlePromptKey(msg tea.KeyMsg) (bool, tea.Cmd) { + key := strings.ToLower(msg.String()) + switch key { + case "esc": + m.promptActive = false + m.promptSubmit = nil + return true, nil + case "enter": + value := strings.TrimSpace(m.promptInput.Value()) + submit := m.promptSubmit + m.promptActive = false + m.promptSubmit = nil + if submit != nil { + submit(value) + } + return true, nil + default: + var cmd tea.Cmd + m.promptInput, cmd = m.promptInput.Update(msg) + return true, cmd + } +} + +func (m *Model) renderPanel() string { + switch m.panelKind { + case event.UIPanelForward: + return m.renderForwardPanel() + case event.UIPanelPlugin: + return m.renderPluginPanel() + case event.UIPanelMode: + return m.renderModePanel() + default: + return renderModal("Info", "No panel", m.availableModalWidth()) + } +} + +func (m *Model) renderForwardPanel() string { + lines := make([]panelLine, 0, len(m.forwardItems)+3) + if len(m.forwardItems) == 0 { + lines = append(lines, panelLine{text: "No forwarding targets. Press 'a' to add one."}) + } else { + lines = append(lines, panelLine{text: "ID Mode Enabled Connected Address"}) + for i, it := range m.forwardItems { + lines = append(lines, panelLine{text: fmt.Sprintf("%-3d %-5s %-7v %-9v %s", it.ID, it.Mode, it.Enabled, it.Connected, it.Address), selected: i == m.panelIndex}) + } + } + if m.panelError != "" { + lines = append(lines, panelLine{text: "ERROR: " + m.panelError}) + } + return renderPanelModal("Forward Panel", lines, "j/k select | Enter toggle | a add(form) | u update | d remove | r refresh | Esc close", m.availableModalWidth()) +} + +func (m *Model) renderPluginPanel() string { + lines := make([]panelLine, 0, len(m.pluginItems)+3) + if len(m.pluginItems) == 0 { + lines = append(lines, panelLine{text: "No plugins loaded. Press 'l' to load one."}) + } else { + lines = append(lines, panelLine{text: "Name Enabled Path"}) + for i, it := range m.pluginItems { + lines = append(lines, panelLine{text: fmt.Sprintf("%-20s %-7v %s", it.Name, it.Enabled, it.Path), selected: i == m.panelIndex}) + } + } + if m.panelError != "" { + lines = append(lines, panelLine{text: "ERROR: " + m.panelError}) + } + return renderPanelModal("Plugin Panel", lines, "Up/Down select | Enter toggle | l load | u reload | d unload | r refresh | Esc close", m.availableModalWidth()) +} + +func (m *Model) renderModePanel() string { + lines := make([]panelLine, 0, len(m.modeItems)+3) + lines = append(lines, panelLine{text: "Field Value"}) + for i, it := range m.modeItems { + lines = append(lines, panelLine{text: fmt.Sprintf("%-16s %s", it.label, it.value), selected: i == m.panelIndex}) + } + if m.panelError != "" { + lines = append(lines, panelLine{text: "ERROR: " + m.panelError}) + } + return renderPanelModal("Mode Panel", lines, "Up/Down select | Enter edit | Space toggle | r refresh | Esc close", m.availableModalWidth()) +} + diff --git a/internal/tui/render.go b/internal/tui/render.go new file mode 100644 index 0000000..53b6311 --- /dev/null +++ b/internal/tui/render.go @@ -0,0 +1,190 @@ +package tui + +import ( + "strings" + + "github.com/charmbracelet/lipgloss" +) + +func (m *Model) appendOutput(text string) { + if text == "" { + return + } + m.content.WriteString(text) + if m.ready { + m.viewport.SetContent(m.content.String()) + if m.followTail { + m.viewport.GotoBottom() + } + } +} + +func (m *Model) renderPrompt() string { + lines := []boxLine{ + {text: m.promptHint, style: modalBodyLineStyle()}, + {text: m.promptInput.View(), style: modalBodyLineStyle()}, + {text: "Enter submit | Esc cancel", style: modalFooterLineStyle()}, + } + return renderBox(m.promptTitle, lines, 48, m.availableModalWidth()) +} + +func renderModal(title, body string, maxWidth int) string { + if title == "" { + title = "Info" + } + parts := strings.Split(strings.ReplaceAll(body, "\r\n", "\n"), "\n") + if len(parts) > 12 { + parts = append(parts[:12], "... (press Esc/Enter to close)") + } + lines := make([]boxLine, 0, len(parts)) + for _, part := range parts { + lines = append(lines, boxLine{text: part, style: modalBodyLineStyle()}) + } + return renderBox(title, lines, 20, maxWidth) +} + +func renderPanelModal(title string, lines []panelLine, footer string, maxWidth int) string { + boxLines := make([]boxLine, 0, len(lines)+1) + for _, line := range lines { + style := modalBodyLineStyle() + prefix := " " + if line.selected { + style = selectedPanelLineStyle() + prefix = "▸ " + } + boxLines = append(boxLines, boxLine{text: prefix + line.text, style: style}) + } + boxLines = append(boxLines, boxLine{text: footer, style: modalFooterLineStyle()}) + return renderBox(title, boxLines, 40, maxWidth) +} + +func fillScreen(width, height int, content string) string { + if width <= 0 || height <= 0 { + return content + } + return lipgloss.Place(width, height, lipgloss.Left, lipgloss.Top, content, + lipgloss.WithWhitespaceChars(" "), + lipgloss.WithWhitespaceForeground(lipgloss.Color("0")), + ) +} + +func renderCenteredModal(width, height int, title, body string) string { + maxWidth := width - 8 + if maxWidth < 20 { + maxWidth = 20 + } + return renderCenteredModalContent(width, height, renderModal(title, body, maxWidth)) +} + +func renderCenteredModalContent(width, height int, content string) string { + if width <= 0 || height <= 0 { + return content + } + return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, content, + lipgloss.WithWhitespaceChars(" "), + lipgloss.WithWhitespaceForeground(lipgloss.Color("0")), + ) +} + +func (m *Model) availableModalWidth() int { + if m.width <= 0 { + return 100 + } + maxWidth := m.width - 8 + if maxWidth < 20 { + maxWidth = 20 + } + return maxWidth +} + +type boxLine struct { + text string + style lipgloss.Style +} + +func renderBox(title string, lines []boxLine, minWidth, maxWidth int) string { + contentWidth := lipgloss.Width(title) + for _, line := range lines { + contentWidth = maxInt(contentWidth, lipgloss.Width(line.text)) + } + contentWidth = maxInt(minWidth, contentWidth) + contentWidth = minInt(contentWidth, maxWidth) + + boxStyle := lipgloss.NewStyle().Background(lipgloss.Color("236")) + top := boxStyle.Render("╭" + strings.Repeat("─", contentWidth+2) + "╮") + bottom := boxStyle.Render("╰" + strings.Repeat("─", contentWidth+2) + "╯") + + rows := make([]string, 0, len(lines)+3) + rows = append(rows, top) + rows = append(rows, renderBoxRow(modalHeaderLineStyle(), title, contentWidth)) + for _, line := range lines { + rows = append(rows, renderBoxRow(line.style, truncateToWidth(line.text, contentWidth), contentWidth)) + } + rows = append(rows, bottom) + return strings.Join(rows, "\n") +} + +func renderBoxRow(contentStyle lipgloss.Style, text string, width int) string { + visible := truncateToWidth(text, width) + pad := strings.Repeat(" ", maxInt(0, width-lipgloss.Width(visible))) + inner := contentStyle.Render(" " + visible + pad + " ") + return contentStyle.Render("│" + inner + "│") +} + +func modalHeaderLineStyle() lipgloss.Style { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("230")).Background(lipgloss.Color("25")) +} + +func modalBodyLineStyle() lipgloss.Style { + return lipgloss.NewStyle().Foreground(lipgloss.Color("252")).Background(lipgloss.Color("236")) +} + +func modalFooterLineStyle() lipgloss.Style { + return lipgloss.NewStyle().Foreground(lipgloss.Color("250")).Background(lipgloss.Color("236")) +} + +func selectedPanelLineStyle() lipgloss.Style { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("230")).Background(lipgloss.Color("31")) +} + +func truncateToWidth(s string, width int) string { + if width <= 0 || lipgloss.Width(s) <= width { + return s + } + var b strings.Builder + for _, r := range s { + next := b.String() + string(r) + if lipgloss.Width(next) > width { + break + } + b.WriteRune(r) + } + return b.String() +} + +func clampIndex(idx, n int) int { + if n <= 0 || idx < 0 { + return 0 + } + if idx >= n { + return n - 1 + } + return idx +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func maxInt(a int, rest ...int) int { + max := a + for _, v := range rest { + if v > max { + max = v + } + } + return max +} diff --git a/internal/tui/tui_test.go b/internal/tui/tui_test.go new file mode 100644 index 0000000..2b76e2e --- /dev/null +++ b/internal/tui/tui_test.go @@ -0,0 +1,33 @@ +package tui + +import "testing" + +func TestParseCSIuBytes(t *testing.T) { + tests := []struct { + name string + seq []byte + want string + ok bool + }{ + {name: "ctrl+alt+f", seq: []byte{0x1b, '[', '1', '0', '2', ';', '6', 'u'}, want: "ctrl+alt+f", ok: true}, + {name: "ctrl+alt+c", seq: []byte{0x1b, '[', '9', '9', ';', '6', 'u'}, want: "ctrl+alt+c", ok: true}, + {name: "ctrl+alt+m", seq: []byte{0x1b, '[', '1', '0', '9', ';', '6', 'u'}, want: "ctrl+alt+m", ok: true}, + {name: "ctrl+alt+p", seq: []byte{0x1b, '[', '1', '1', '2', ';', '6', 'u'}, want: "ctrl+alt+p", ok: true}, + {name: "ctrl+alt+h", seq: []byte{0x1b, '[', '1', '0', '4', ';', '6', 'u'}, want: "ctrl+alt+h", ok: true}, + {name: "ctrl+shift+c", seq: []byte{0x1b, '[', '9', '9', ';', '5', 'u'}, want: "ctrl+shift+c", ok: true}, + {name: "alt+c (no ctrl)", seq: []byte{0x1b, '[', '9', '9', ';', '2', 'u'}, want: "alt+c", ok: true}, + {name: "plain c", seq: []byte{0x1b, '[', '9', '9', ';', '0', 'u'}, want: "c", ok: true}, + {name: "not CSI u", seq: []byte{0x1b, '[', 'A'}, want: "", ok: false}, + {name: "empty", seq: []byte{}, want: "", ok: false}, + {name: "no escape", seq: []byte("hello"), want: "", ok: false}, + {name: "ESC [ A (arrow up)", seq: []byte{0x1b, '[', 'A'}, want: "", ok: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := parseCSIuBytes(tt.seq) + if ok != tt.ok || got != tt.want { + t.Fatalf("parseCSIuBytes(%v): got=(%q,%v) want=(%q,%v)", tt.seq, got, ok, tt.want, tt.ok) + } + }) + } +} diff --git a/main.go b/main.go deleted file mode 100644 index d2468c1..0000000 --- a/main.go +++ /dev/null @@ -1,56 +0,0 @@ -package main - -import ( - "fmt" - "github.com/spf13/pflag" - "io" - "log" - "os" -) - -func init() { - log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile | log.Lmsgprefix) - for _, f := range flags { - flagInit(&f) - } - cmdinit() -} - -func main() { - pflag.Parse() - flagExt() - if config.portName == "" { - getCliFlag() - } - ports, err := checkPortAvailability(config.portName) - if err != nil { - fmt.Println(err) - printUsage(ports) - os.Exit(0) - } - - // 日志文件输出检测 - checkLogOpen() - - //串口设备开启 - OpenSerial() - - defer CloseSerial() - // 打开文件服务 - OpenTrzsz() - - defer CloseTrzsz() - - //开启转发 - OpenForwarding() - - // 获取终端输入 - go input(in) - - if len(outs) != 1 { - out = io.MultiWriter(outs...) - } - for { - output() - } -} diff --git a/mutual.go b/mutual.go deleted file mode 100644 index 4255ac3..0000000 --- a/mutual.go +++ /dev/null @@ -1,85 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "github.com/trzsz/trzsz-go/trzsz" - "github.com/zimolab/charsetconv" - "go.bug.st/serial" - "io" - "log" - "os" - "strings" - "time" -) - -var ( - serialPort serial.Port - in io.Reader = os.Stdin - out io.Writer = os.Stdout - outs = []io.Writer{os.Stdout} - trzszFilter *trzsz.TrzszFilter - clientIn *io.PipeReader - stdoutPipe *io.PipeReader - stdinPipe *io.PipeWriter - clientOut *io.PipeWriter -) - -func input(in io.Reader) { - var err error - input := bufio.NewScanner(in) - var ok = false - for { - input.Scan() - ok = false - args = strings.Split(input.Text(), " ") - for _, cmd := range commands { - if strings.Compare(strings.TrimSpace(args[0]), cmd.name) == 0 { - cmd.function() - ok = true - } - } - if !ok { - _, err := io.WriteString(stdinPipe, input.Text()) - if err != nil { - log.Fatal(err) - } - _, err = io.WriteString(stdinPipe, config.endStr) - if err != nil { - log.Fatal(err) - } - } - err = serialPort.Drain() - ErrorF(err) - } -} - -func strout(out io.Writer, cs, str string) { - err := charsetconv.EncodeWith(strings.NewReader(str), out, charsetconv.Charset(cs), false) - ErrorF(err) -} - -func output() { - var err error - if strings.Compare(config.inputCode, "hex") == 0 { - b := make([]byte, config.frameSize) - r, _ := io.LimitReader(stdoutPipe, int64(config.frameSize)).Read(b) - if r != 0 { - if config.timesTamp { - strout(out, config.outputCode, fmt.Sprintf("%v % X %q \n", time.Now().Format(config.timesFmt), b, b)) - } else { - strout(out, config.outputCode, fmt.Sprintf("% X %q \n", b, b)) - } - } - } else { - if config.timesTamp { - line, _, _ := bufio.NewReader(stdoutPipe).ReadLine() - if line != nil { - strout(out, config.outputCode, fmt.Sprintf("%v %s\n", time.Now().Format(config.timesFmt), line)) - } - } else { - err = charsetconv.ConvertWith(stdoutPipe, charsetconv.Charset(config.inputCode), out, charsetconv.Charset(config.outputCode), false) - } - } - ErrorP(err) -} diff --git a/pkg/charset/charset.go b/pkg/charset/charset.go new file mode 100644 index 0000000..7ab0fc6 --- /dev/null +++ b/pkg/charset/charset.go @@ -0,0 +1,43 @@ +// Package charset provides character-set conversion and hex formatting utilities. +package charset + +import ( + "bytes" + "fmt" + "strings" + "time" + + "github.com/zimolab/charsetconv" +) + +// ConvertChunk converts a byte chunk from srcCode charset to dstCode charset. +// Returns nil, nil when input is empty. Returns a copied slice when charsets match. +func ConvertChunk(chunk []byte, srcCode, dstCode string) ([]byte, error) { + if len(chunk) == 0 { + return nil, nil + } + + if strings.EqualFold(srcCode, dstCode) { + dup := make([]byte, len(chunk)) + copy(dup, chunk) + return dup, nil + } + + var buf bytes.Buffer + err := charsetconv.ConvertWith(bytes.NewReader(chunk), charsetconv.Charset(srcCode), &buf, charsetconv.Charset(dstCode), false) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// FormatHexFrame formats a byte frame as hex + printable representation. +// Optionally prefixes with a timestamp using the given format string. +func FormatHexFrame(frame []byte, withTimestamp bool, tsFmt string) string { + if withTimestamp { + return fmt.Sprintf("%v % X %q \n", time.Now().Format(tsFmt), frame, frame) + } + + return fmt.Sprintf("% X %q \n", frame, frame) +} diff --git a/pkg/charset/charset_test.go b/pkg/charset/charset_test.go new file mode 100644 index 0000000..ef6ba4e --- /dev/null +++ b/pkg/charset/charset_test.go @@ -0,0 +1,93 @@ +package charset + +import ( + "bytes" + "strings" + "testing" +) + +func TestConvertChunk(t *testing.T) { + t.Run("empty", func(t *testing.T) { + out, err := ConvertChunk(nil, "UTF-8", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk(nil) unexpected error: %v", err) + } + if out != nil { + t.Fatalf("ConvertChunk(nil) expected nil output") + } + }) + + t.Run("same-charset-copy", func(t *testing.T) { + in := []byte("hello") + out, err := ConvertChunk(in, "UTF-8", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk same charset unexpected error: %v", err) + } + if !bytes.Equal(out, in) { + t.Fatalf("ConvertChunk same charset mismatch got=%q want=%q", out, in) + } + + out[0] = 'H' + if in[0] != 'h' { + t.Fatalf("ConvertChunk should return a copied slice") + } + }) +} + +func TestFormatHexFrame(t *testing.T) { + frame := []byte("AB") + out := FormatHexFrame(frame, false, "") + if !strings.Contains(out, "41 42") { + t.Fatalf("FormatHexFrame missing hex bytes: %q", out) + } + if !strings.Contains(out, "\"AB\"") { + t.Fatalf("FormatHexFrame missing quoted bytes: %q", out) + } + + outTS := FormatHexFrame([]byte("A"), true, "2006") + if !strings.Contains(outTS, "41") || !strings.Contains(outTS, "\"A\"") { + t.Fatalf("FormatHexFrame(withTimestamp) malformed output: %q", outTS) + } +} + +func TestConvertChunkCharsetConversion(t *testing.T) { + t.Run("gbk-to-utf8", func(t *testing.T) { + // Chinese "你好" in GBK: 0xC4 0xE3 0xBA 0xC3 + gbkHello := []byte{0xC4, 0xE3, 0xBA, 0xC3} + out, err := ConvertChunk(gbkHello, "GBK", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk GBK->UTF-8 unexpected error: %v", err) + } + if string(out) != "你好" { + t.Fatalf("ConvertChunk GBK->UTF-8 got=%q want=%q", string(out), "你好") + } + }) + + t.Run("same-charset-different-case", func(t *testing.T) { + in := []byte("hello") + out, err := ConvertChunk(in, "utf-8", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk case-diff unexpected error: %v", err) + } + if !bytes.Equal(out, in) { + t.Fatalf("ConvertChunk case-diff mismatch got=%q want=%q", out, in) + } + }) + + t.Run("invalid-charset", func(t *testing.T) { + _, err := ConvertChunk([]byte("hello"), "INVALID-CHARSET-NAME", "UTF-8") + if err == nil { + t.Fatalf("ConvertChunk invalid charset should error") + } + }) + + t.Run("empty-input", func(t *testing.T) { + out, err := ConvertChunk([]byte{}, "GBK", "UTF-8") + if err != nil { + t.Fatalf("ConvertChunk empty unexpected error: %v", err) + } + if out != nil { + t.Fatalf("ConvertChunk empty input should return nil") + } + }) +} diff --git a/pkg/forward/forward_test.go b/pkg/forward/forward_test.go new file mode 100644 index 0000000..5e591ae --- /dev/null +++ b/pkg/forward/forward_test.go @@ -0,0 +1,253 @@ +package forward + +import ( + "net" + "testing" + "time" +) + +func TestManagerTCPFlow(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + acceptCh := make(chan net.Conn, 1) + errCh := make(chan error, 1) + go func() { + conn, e := listener.Accept() + if e != nil { + errCh <- e + return + } + acceptCh <- conn + }() + + serialCh := make(chan string, 2) + mgr := NewManager(func(b []byte) error { + serialCh <- string(b) + return nil + }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCP, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + var serverConn net.Conn + select { + case serverConn = <-acceptCh: + case e := <-errCh: + t.Fatalf("accept failed: %v", e) + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for accepted connection") + } + defer serverConn.Close() + + items := mgr.List() + if len(items) != 1 || items[0].ID != id || !items[0].Enabled { + t.Fatalf("unexpected list after add: %+v", items) + } + + if err = serverConn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("SetReadDeadline failed: %v", err) + } + mgr.Broadcast([]byte("from-app")) + buf := make([]byte, 64) + n, err := serverConn.Read(buf) + if err != nil { + t.Fatalf("server read from broadcast failed: %v", err) + } + if string(buf[:n]) != "from-app" { + t.Fatalf("broadcast payload mismatch got=%q", string(buf[:n])) + } + + if _, err = serverConn.Write([]byte("from-remote")); err != nil { + t.Fatalf("server write failed: %v", err) + } + select { + case got := <-serialCh: + if got != "from-remote" { + t.Fatalf("writeToSerial payload mismatch got=%q", got) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for writeToSerial callback") + } + + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + items = mgr.List() + if len(items) != 1 || items[0].Enabled { + t.Fatalf("Disable() did not update state: %+v", items) + } + + if err = mgr.Remove(id); err != nil { + t.Fatalf("Remove() failed: %v", err) + } + if got := mgr.List(); len(got) != 0 { + t.Fatalf("expected empty list after remove, got=%+v", got) + } +} + +func TestManagerErrorCases(t *testing.T) { + mgr := NewManager(func([]byte) error { return nil }, func(string, ...any) {}) + defer mgr.Close() + + if _, err := mgr.Add(None, "127.0.0.1:1"); err == nil { + t.Fatalf("Add(None) expected error") + } + + if err := mgr.Remove(999); err == nil { + t.Fatalf("Remove(non-existing) expected error") + } + + if err := mgr.Disable(999); err == nil { + t.Fatalf("Disable(non-existing) expected error") + } + + if err := mgr.Enable(999); err == nil { + t.Fatalf("Enable(non-existing) expected error") + } + + if err := mgr.Update(999, TCP, "127.0.0.1:1"); err == nil { + t.Fatalf("Update(non-existing) expected error") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + id, err := mgr.Add(TCP, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + if err = mgr.Update(id, None, "127.0.0.1:1"); err == nil { + t.Fatalf("Update(None) expected error") + } +} + +func TestManagerSetInboundReporter(t *testing.T) { + reported := make(chan []byte, 1) + mgr := NewManager(func([]byte) error { return nil }, func(string, ...any) {}) + defer mgr.Close() + mgr.SetInboundReporter(func(id int, chunk []byte) { + reported <- chunk + }) + // Verify the callback was stored (indirect test) + _ = reported +} + +func TestManagerBroadcastToDisabled(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + writeCh := make(chan []byte, 4) + mgr := NewManager(func([]byte) error { + writeCh <- nil + return nil + }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCP, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + + mgr.Broadcast([]byte("should-not-arrive")) + + select { + case <-writeCh: + t.Fatalf("broadcast should not write to serial when disabled") + default: + } + + mgr.Broadcast(nil) + mgr.Broadcast([]byte{}) +} + +func TestManagerEnable(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + defer listener.Close() + + writeCh := make(chan []byte, 2) + mgr := NewManager(func([]byte) error { + writeCh <- nil + return nil + }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCP, listener.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + + if err = mgr.Enable(id); err != nil { + t.Fatalf("Enable() failed: %v", err) + } + + items := mgr.List() + if len(items) != 1 || !items[0].Enabled { + t.Fatalf("expected enabled after Enable(), got=%+v", items) + } + + if err = mgr.Enable(id); err != nil { + t.Fatalf("second Enable() should succeed: %v", err) + } +} + +func TestManagerUpdate(t *testing.T) { + l1, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen 1 failed: %v", err) + } + defer l1.Close() + + l2, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen 2 failed: %v", err) + } + defer l2.Close() + + mgr := NewManager(func([]byte) error { return nil }, func(string, ...any) {}) + defer mgr.Close() + + id, err := mgr.Add(TCP, l1.Addr().String()) + if err != nil { + t.Fatalf("Add() failed: %v", err) + } + + if err = mgr.Update(id, TCP, l2.Addr().String()); err != nil { + t.Fatalf("Update() failed: %v", err) + } + + items := mgr.List() + if len(items) != 1 || items[0].Address != l2.Addr().String() { + t.Fatalf("update should change address, got=%+v", items) + } + + if err = mgr.Disable(id); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + if err = mgr.Update(id, TCP, l1.Addr().String()); err != nil { + t.Fatalf("Update() on disabled should succeed: %v", err) + } +} diff --git a/pkg/forward/manager.go b/pkg/forward/manager.go new file mode 100644 index 0000000..1c29f1a --- /dev/null +++ b/pkg/forward/manager.go @@ -0,0 +1,661 @@ +// Package forward manages TCP/UDP/COM forwarding targets for serial data. +package forward + +import ( + "fmt" + "net" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "go.bug.st/serial" +) + +// Mode is the forwarding protocol mode. +type Mode int + +const ( + None Mode = 0 + TCP Mode = 1 + UDP Mode = 2 + TCPServer Mode = 3 + UDPServer Mode = 4 + COMPort Mode = 5 +) + +// ParseMode parses a mode string. +func ParseMode(v string) (Mode, bool) { + switch strings.ToLower(strings.TrimSpace(v)) { + case "tcp", "tcp-c", "tcpc", "1": + return TCP, true + case "udp", "udp-c", "udpc", "2": + return UDP, true + case "tcp-s", "tcps", "tcp-server", "3": + return TCPServer, true + case "udp-s", "udps", "udp-server", "4": + return UDPServer, true + case "com", "serial", "5": + return COMPort, true + default: + return None, false + } +} + +func (m Mode) Network() string { + switch m { + case TCP, TCPServer: + return "tcp" + case UDP, UDPServer: + return "udp" + case COMPort: + return "serial" + default: + return "" + } +} + +func (m Mode) String() string { + switch m { + case TCP: + return "tcp" + case UDP: + return "udp" + case TCPServer: + return "tcp-s" + case UDPServer: + return "udp-s" + case COMPort: + return "com" + default: + return "none" + } +} + +// Stats holds I/O statistics for a forward target. +type Stats struct { + ReadBytes uint64 + WrittenBytes uint64 + LastError string +} + +// Target represents a single forwarding connection. +type Target struct { + ID int + Mode Mode + Address string + Enabled bool + Connected bool + CreatedAt time.Time + + // Client-mode connection (TCP/UDP client) + conn net.Conn + + // Server-mode fields + listener net.Listener // TCP server listener + conns map[net.Conn]struct{} // TCP server accepted connections + connsMu sync.Mutex + + // UDP server + packetConn net.PacketConn // UDP server listener + remoteAddrs map[string]net.Addr // known UDP remotes + + // COM port + serialPort serial.Port + + stats Stats + mu sync.Mutex + closeCh chan struct{} + closed bool +} + +// AcceptedConns returns the number of accepted connections (TCP server only). +func (t *Target) acceptedConns() int { + t.connsMu.Lock() + defer t.connsMu.Unlock() + return len(t.conns) +} + +// Snapshot is a read-only view of a forward target for display. +type Snapshot struct { + ID int + Mode string + Address string + Enabled bool + Connected bool + ReadBytes uint64 + WriteByte uint64 + LastError string + Conns int // accepted connection count (TCP server) +} + +// Manager coordinates forwarding targets. +type Manager struct { + mu sync.RWMutex + targets map[int]*Target + nextID int + writeToSerial func([]byte) error + notify func(string, ...any) + onInbound func(int, []byte) +} + +// NewManager creates a forwarding manager. +func NewManager(writeToSerial func([]byte) error, notify func(string, ...any)) *Manager { + return &Manager{ + targets: make(map[int]*Target), + nextID: 1, + writeToSerial: writeToSerial, + notify: notify, + } +} + +// SetInboundReporter sets a callback invoked when inbound data arrives from a target. +func (m *Manager) SetInboundReporter(fn func(int, []byte)) { + m.mu.Lock() + defer m.mu.Unlock() + m.onInbound = fn +} + +// Add creates and connects a new forward target. +func (m *Manager) Add(mode Mode, address string) (int, error) { + if mode == None { + return 0, fmt.Errorf("forward mode cannot be none") + } + + t := &Target{ + Mode: mode, + Address: address, + Enabled: true, + CreatedAt: time.Now(), + closeCh: make(chan struct{}), + } + + switch mode { + case TCP, UDP: + conn, err := net.Dial(mode.Network(), address) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.conn = conn + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.readLoop(t, conn, t.closeCh) + + case TCPServer: + listener, err := net.Listen("tcp", address) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.listener = listener + t.conns = make(map[net.Conn]struct{}) + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.acceptLoop(t) + + case UDPServer: + pc, err := net.ListenPacket("udp", address) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.packetConn = pc + t.remoteAddrs = make(map[string]net.Addr) + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.readLoopPacket(t) + + case COMPort: + sp, err := serial.Open(address, &serial.Mode{BaudRate: 115200, DataBits: 8, StopBits: 0, Parity: 0}) + if err != nil { + t.stats.LastError = err.Error() + return 0, err + } + t.serialPort = sp + t.Connected = true + + m.mu.Lock() + t.ID = m.nextID + m.nextID++ + m.targets[t.ID] = t + m.mu.Unlock() + + go m.readLoopSerial(t) + } + + m.notify("[forward] #%d %s %s connected", t.ID, t.Mode.String(), t.Address) + return t.ID, nil +} + +func (m *Manager) acceptLoop(t *Target) { + for { + conn, err := t.listener.Accept() + if err != nil { + select { + case <-t.closeCh: + return + default: + } + t.stats.LastError = err.Error() + m.notify("[forward] #%d accept error: %v", t.ID, err) + return + } + + t.connsMu.Lock() + t.conns[conn] = struct{}{} + t.connsMu.Unlock() + + m.notify("[forward] #%d accepted %s", t.ID, conn.RemoteAddr()) + go m.readLoop(t, conn, t.closeCh) + } +} + +func (m *Manager) processChunk(t *Target, data []byte) { + if len(data) == 0 { + return + } + n := len(data) + atomic.AddUint64(&t.stats.ReadBytes, uint64(n)) + chunk := make([]byte, n) + copy(chunk, data) + if wErr := m.writeToSerial(chunk); wErr != nil { + t.stats.LastError = wErr.Error() + m.notify("[forward] #%d write serial error: %v", t.ID, wErr) + } else if m.onInbound != nil { + m.onInbound(t.ID, chunk) + } +} + +func (m *Manager) readLoopError(t *Target, err error) { + select { + case <-t.closeCh: + return + default: + } + t.Connected = false + t.stats.LastError = err.Error() + m.notify("[forward] #%d disconnected: %v", t.ID, err) +} + +func (m *Manager) readLoopPacket(t *Target) { + buf := make([]byte, 4096) + for { + n, addr, err := t.packetConn.ReadFrom(buf) + if n > 0 { + m.processChunk(t, buf[:n]) + t.mu.Lock() + t.remoteAddrs[addr.String()] = addr + t.mu.Unlock() + } + if err != nil { + m.readLoopError(t, err) + return + } + select { + case <-t.closeCh: + return + default: + } + } +} + +func (m *Manager) readLoopSerial(t *Target) { + buf := make([]byte, 4096) + for { + n, err := t.serialPort.Read(buf) + if n > 0 { + m.processChunk(t, buf[:n]) + } + if err != nil { + m.readLoopError(t, err) + return + } + select { + case <-t.closeCh: + return + default: + } + } +} + +func (m *Manager) readLoop(t *Target, conn net.Conn, stop <-chan struct{}) { + buf := make([]byte, 4096) + for { + n, err := conn.Read(buf) + if n > 0 { + m.processChunk(t, buf[:n]) + } + if err != nil { + t.Connected = false + t.stats.LastError = err.Error() + if t.Mode == TCPServer { + t.connsMu.Lock() + delete(t.conns, conn) + t.connsMu.Unlock() + } + m.notify("[forward] #%d disconnected: %v", t.ID, err) + _ = conn.Close() + return + } + select { + case <-stop: + _ = conn.Close() + if t.Mode == TCPServer { + t.connsMu.Lock() + delete(t.conns, conn) + t.connsMu.Unlock() + } + return + default: + } + } +} + +// Remove disconnects and removes a target. +func (m *Manager) Remove(id int) error { + m.mu.Lock() + t, ok := m.targets[id] + if !ok { + m.mu.Unlock() + return fmt.Errorf("forward #%d not found", id) + } + delete(m.targets, id) + m.mu.Unlock() + + t.close() + m.notify("[forward] #%d removed", id) + return nil +} + +// Enable (re)connects a target. +func (m *Manager) Enable(id int) error { + m.mu.RLock() + t, ok := m.targets[id] + m.mu.RUnlock() + if !ok { + return fmt.Errorf("forward #%d not found", id) + } + + t.mu.Lock() + defer t.mu.Unlock() + if t.Enabled && t.Connected { + return nil + } + + switch t.Mode { + case TCP, UDP: + conn, err := net.Dial(t.Mode.Network(), t.Address) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.conn = conn + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.readLoop(t, conn, t.closeCh) + + case TCPServer: + listener, err := net.Listen("tcp", t.Address) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.listener = listener + t.conns = make(map[net.Conn]struct{}) + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.acceptLoop(t) + + case UDPServer: + pc, err := net.ListenPacket("udp", t.Address) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.packetConn = pc + t.remoteAddrs = make(map[string]net.Addr) + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.readLoopPacket(t) + + case COMPort: + sp, err := serial.Open(t.Address, &serial.Mode{BaudRate: 115200, DataBits: 8, StopBits: 0, Parity: 0}) + if err != nil { + t.stats.LastError = err.Error() + return err + } + t.serialPort = sp + t.Connected = true + t.closeCh = make(chan struct{}) + t.closed = false + go m.readLoopSerial(t) + } + + t.Enabled = true + m.notify("[forward] #%d enabled", id) + return nil +} + +// Update changes a target's mode and address, reconnecting if enabled. +func (m *Manager) Update(id int, mode Mode, address string) error { + if mode == None { + return fmt.Errorf("forward mode cannot be none") + } + + m.mu.RLock() + t, ok := m.targets[id] + m.mu.RUnlock() + if !ok { + return fmt.Errorf("forward #%d not found", id) + } + + t.mu.Lock() + wasEnabled := t.Enabled + t.Mode = mode + t.Address = address + t.mu.Unlock() + + t.close() + + if !wasEnabled { + m.notify("[forward] #%d updated (disabled)", id) + return nil + } + + return m.Enable(id) +} + +// Disable disconnects a target without removing it. +func (m *Manager) Disable(id int) error { + m.mu.RLock() + t, ok := m.targets[id] + m.mu.RUnlock() + if !ok { + return fmt.Errorf("forward #%d not found", id) + } + + t.mu.Lock() + t.Enabled = false + t.mu.Unlock() + t.close() + m.notify("[forward] #%d disabled", id) + return nil +} + +// Broadcast sends data to all enabled, connected targets. +func (m *Manager) Broadcast(data []byte) { + if len(data) == 0 { + return + } + + m.mu.RLock() + items := make([]*Target, 0, len(m.targets)) + for _, t := range m.targets { + items = append(items, t) + } + m.mu.RUnlock() + + for _, t := range items { + if !t.Enabled || !t.Connected { + continue + } + + switch t.Mode { + case TCP, UDP: + if t.conn == nil { + continue + } + n, err := t.conn.Write(data) + if err != nil { + t.stats.LastError = err.Error() + m.notify("[forward] #%d write error: %v", t.ID, err) + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } + + case TCPServer: + t.connsMu.Lock() + conns := make([]net.Conn, 0, len(t.conns)) + for c := range t.conns { + conns = append(conns, c) + } + t.connsMu.Unlock() + for _, c := range conns { + n, err := c.Write(data) + if err != nil { + t.stats.LastError = err.Error() + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } + } + + case UDPServer: + t.mu.Lock() + addrs := make([]net.Addr, 0, len(t.remoteAddrs)) + for _, addr := range t.remoteAddrs { + addrs = append(addrs, addr) + } + t.mu.Unlock() + for _, addr := range addrs { + n, err := t.packetConn.WriteTo(data, addr) + if err != nil { + t.stats.LastError = err.Error() + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } + } + + case COMPort: + if t.serialPort == nil { + continue + } + n, err := t.serialPort.Write(data) + if err != nil { + t.stats.LastError = err.Error() + m.notify("[forward] #%d write error: %v", t.ID, err) + } else { + atomic.AddUint64(&t.stats.WrittenBytes, uint64(n)) + } + } + } +} + +// List returns a snapshot of all targets. +func (m *Manager) List() []Snapshot { + m.mu.RLock() + items := make([]Snapshot, 0, len(m.targets)) + for _, t := range m.targets { + items = append(items, Snapshot{ + ID: t.ID, + Mode: t.Mode.String(), + Address: t.Address, + Enabled: t.Enabled, + Connected: t.Connected, + ReadBytes: atomic.LoadUint64(&t.stats.ReadBytes), + WriteByte: atomic.LoadUint64(&t.stats.WrittenBytes), + LastError: t.stats.LastError, + Conns: t.acceptedConns(), + }) + } + m.mu.RUnlock() + + sort.Slice(items, func(i, j int) bool { + return items[i].ID < items[j].ID + }) + + return items +} + +// Close disconnects and removes all targets. +func (m *Manager) Close() { + m.mu.Lock() + items := make([]*Target, 0, len(m.targets)) + for _, t := range m.targets { + items = append(items, t) + } + m.targets = map[int]*Target{} + m.mu.Unlock() + + for _, t := range items { + t.close() + } +} + +func (t *Target) close() { + t.mu.Lock() + if t.closed { + t.mu.Unlock() + return + } + t.closed = true + ch := t.closeCh + conn := t.conn + listener := t.listener + pc := t.packetConn + sp := t.serialPort + t.conn = nil + t.listener = nil + t.packetConn = nil + t.serialPort = nil + t.Connected = false + t.mu.Unlock() + + if ch != nil { + close(ch) + } + if conn != nil { + _ = conn.Close() + } + if listener != nil { + _ = listener.Close() + } + if pc != nil { + _ = pc.Close() + } + if sp != nil { + _ = sp.Close() + } +} diff --git a/pkg/luaplugin/helpers.go b/pkg/luaplugin/helpers.go new file mode 100644 index 0000000..f3afd07 --- /dev/null +++ b/pkg/luaplugin/helpers.go @@ -0,0 +1,116 @@ +package luaplugin + +import ( + lua "github.com/yuin/gopher-lua" +) + +// registerHelpers registers Go utility functions into a Lua state. +func registerHelpers(L *lua.LState) { + modbus := L.NewTable() + L.SetGlobal("modbus", modbus) + + L.SetField(modbus, "crc16", L.NewFunction(luaCRC16)) + L.SetField(modbus, "validate", L.NewFunction(luaValidateCRC)) + + hex := L.NewTable() + L.SetGlobal("hex", hex) + L.SetField(hex, "encode", L.NewFunction(luaHexEncode)) + L.SetField(hex, "decode", L.NewFunction(luaHexDecode)) + + util := L.NewTable() + L.SetGlobal("util", util) + L.SetField(util, "bytes", L.NewFunction(luaBytes)) +} + +// crc16 computes the CRC-16/MODBUS checksum for the given data. +func crc16(data []byte) uint16 { + var crc uint16 = 0xFFFF + for _, b := range data { + crc ^= uint16(b) + for i := 0; i < 8; i++ { + if crc&1 != 0 { + crc = (crc >> 1) ^ 0xA001 + } else { + crc >>= 1 + } + } + } + return crc +} + +func luaCRC16(L *lua.LState) int { + s := L.CheckString(1) + crc := crc16([]byte(s)) + L.Push(lua.LNumber(crc)) + return 1 +} + +func luaValidateCRC(L *lua.LState) int { + s := L.CheckString(1) + if len(s) < 2 { + L.Push(lua.LBool(false)) + return 1 + } + data := []byte(s[:len(s)-2]) + crc := crc16(data) + expect := uint16(s[len(s)-2]) | uint16(s[len(s)-1])<<8 + L.Push(lua.LBool(crc == expect)) + return 1 +} + +func luaHexEncode(L *lua.LState) int { + s := L.CheckString(1) + buf := make([]byte, len(s)*2) + for i, b := range []byte(s) { + buf[i*2] = hexChar(b >> 4) + buf[i*2+1] = hexChar(b & 0x0F) + } + L.Push(lua.LString(buf)) + return 1 +} + +func luaHexDecode(L *lua.LState) int { + s := L.CheckString(1) + if len(s)%2 != 0 { + L.Push(lua.LNil) + return 1 + } + buf := make([]byte, len(s)/2) + for i := 0; i < len(s); i += 2 { + buf[i/2] = unhexChar(s[i])<<4 | unhexChar(s[i+1]) + } + L.Push(lua.LString(buf)) + return 1 +} + +func luaBytes(L *lua.LState) int { + // Converts a sequence of numbers to a byte string. + // e.g. util.bytes(0x01, 0x03, 0x00, 0x01, 0x00, 0x01) → "\x01\x03\x00\x01\x00\x01" + top := L.GetTop() + buf := make([]byte, top) + for i := 1; i <= top; i++ { + buf[i-1] = byte(L.CheckInt(i)) + } + L.Push(lua.LString(buf)) + return 1 +} + +func hexChar(b byte) byte { + if b < 10 { + return '0' + b + } + return 'A' + (b - 10) +} + +func unhexChar(c byte) byte { + switch { + case c >= '0' && c <= '9': + return c - '0' + case c >= 'a' && c <= 'f': + return c - 'a' + 10 + case c >= 'A' && c <= 'F': + return c - 'A' + 10 + default: + return 0 + } +} diff --git a/pkg/luaplugin/hooks.go b/pkg/luaplugin/hooks.go new file mode 100644 index 0000000..e54edb0 --- /dev/null +++ b/pkg/luaplugin/hooks.go @@ -0,0 +1,50 @@ +package luaplugin + +import lua "github.com/yuin/gopher-lua" + +func callStringHook(L *lua.LState, name string, payload string) (*string, bool, error) { + fn := L.GetGlobal(name) + if fn.Type() == lua.LTNil { + return nil, false, nil + } + + if err := L.CallByParam(lua.P{Fn: fn, NRet: 1, Protect: true}, lua.LString(payload)); err != nil { + return nil, true, err + } + + ret := L.Get(-1) + L.Pop(1) + if ret.Type() == lua.LTNil { + return nil, true, nil + } + + s := ret.String() + return &s, true, nil +} + +func callCommandHook(L *lua.LState, name, line string) (string, bool, bool, error) { + fn := L.GetGlobal(name) + if fn.Type() == lua.LTNil { + return "", true, false, nil + } + + if err := L.CallByParam(lua.P{Fn: fn, NRet: 2, Protect: true}, lua.LString(line)); err != nil { + return "", true, true, err + } + + allowVal := L.Get(-1) + lineVal := L.Get(-2) + L.Pop(2) + + allow := true + if allowVal.Type() == lua.LTBool { + allow = lua.LVAsBool(allowVal) + } + + next := "" + if lineVal.Type() != lua.LTNil { + next = lineVal.String() + } + + return next, allow, true, nil +} diff --git a/pkg/luaplugin/manager.go b/pkg/luaplugin/manager.go new file mode 100644 index 0000000..0cc6165 --- /dev/null +++ b/pkg/luaplugin/manager.go @@ -0,0 +1,232 @@ +// Package luaplugin provides a Lua plugin system for processing serial data streams. +package luaplugin + +import ( + "fmt" + "path/filepath" + "sort" + "strings" + "sync" + + lua "github.com/yuin/gopher-lua" +) + +// Plugin represents a loaded Lua plugin. +type Plugin struct { + Name string + Path string + Enabled bool + L *lua.LState + callMu sync.Mutex +} + +// Snapshot is a read-only view of a plugin for display. +type Snapshot struct { + Name string + Path string + Enabled bool +} + +// Manager coordinates plugin lifecycle and hook execution. +type Manager struct { + mu sync.RWMutex + plugins map[string]*Plugin +} + +// NewManager creates a plugin manager. +func NewManager() *Manager { + return &Manager{plugins: make(map[string]*Plugin)} +} + +// Load loads a Lua plugin from the given path. +func (m *Manager) Load(path string) (string, error) { + abs, err := filepath.Abs(path) + if err != nil { + return "", err + } + + name := strings.TrimSuffix(filepath.Base(abs), filepath.Ext(abs)) + if name == "" { + return "", fmt.Errorf("invalid plugin name") + } + + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.plugins[name]; ok { + return "", fmt.Errorf("plugin %s already loaded", name) + } + + state := lua.NewState() + registerHelpers(state) + if err = state.DoFile(abs); err != nil { + state.Close() + return "", err + } + + m.plugins[name] = &Plugin{ + Name: name, + Path: abs, + Enabled: true, + L: state, + } + + return name, nil +} + +// Unload unloads a plugin and closes its Lua state. +func (m *Manager) Unload(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + p, ok := m.plugins[name] + if !ok { + return fmt.Errorf("plugin %s not found", name) + } + + p.L.Close() + delete(m.plugins, name) + return nil +} + +// Enable enables a previously loaded plugin. +func (m *Manager) Enable(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + p, ok := m.plugins[name] + if !ok { + return fmt.Errorf("plugin %s not found", name) + } + p.Enabled = true + return nil +} + +// Disable disables a plugin without unloading it. +func (m *Manager) Disable(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + p, ok := m.plugins[name] + if !ok { + return fmt.Errorf("plugin %s not found", name) + } + p.Enabled = false + return nil +} + +// Reload reloads a plugin's file atomically. +func (m *Manager) Reload(name string) error { + m.mu.Lock() + p, ok := m.plugins[name] + if !ok { + m.mu.Unlock() + return fmt.Errorf("plugin %s not found", name) + } + + path := p.Path + p.L.Close() + delete(m.plugins, name) + m.mu.Unlock() + + _, err := m.Load(path) + return err +} + +// List returns a snapshot of all plugins. +func (m *Manager) List() []Snapshot { + m.mu.RLock() + res := make([]Snapshot, 0, len(m.plugins)) + for _, p := range m.plugins { + res = append(res, Snapshot{Name: p.Name, Path: p.Path, Enabled: p.Enabled}) + } + m.mu.RUnlock() + + sort.Slice(res, func(i, j int) bool { + return res[i].Name < res[j].Name + }) + return res +} + +// ProcessInput runs the OnInput hook chain across all enabled plugins. +func (m *Manager) ProcessInput(data []byte) ([]byte, error) { + return m.processDataHook("OnInput", data) +} + +// ProcessOutput runs the OnOutput hook chain across all enabled plugins. +func (m *Manager) ProcessOutput(data []byte) ([]byte, error) { + return m.processDataHook("OnOutput", data) +} + +func (m *Manager) processDataHook(name string, data []byte) ([]byte, error) { + m.mu.RLock() + plugins := make([]*Plugin, 0, len(m.plugins)) + for _, p := range m.plugins { + plugins = append(plugins, p) + } + m.mu.RUnlock() + + current := data + for _, p := range plugins { + if !p.Enabled { + continue + } + p.callMu.Lock() + ret, called, err := callStringHook(p.L, name, string(current)) + p.callMu.Unlock() + if err != nil { + return nil, fmt.Errorf("plugin %s %s: %w", p.Name, name, err) + } + if !called { + continue + } + if ret == nil { + return nil, nil + } + current = []byte(*ret) + } + + return current, nil +} + +// ProcessCommand runs the OnCommand hook chain across all enabled plugins. +func (m *Manager) ProcessCommand(line string) (string, bool, error) { + m.mu.RLock() + plugins := make([]*Plugin, 0, len(m.plugins)) + for _, p := range m.plugins { + plugins = append(plugins, p) + } + m.mu.RUnlock() + + current := line + allow := true + for _, p := range plugins { + if !p.Enabled { + continue + } + p.callMu.Lock() + next, nextAllow, called, err := callCommandHook(p.L, "OnCommand", current) + p.callMu.Unlock() + if err != nil { + return "", false, fmt.Errorf("plugin %s OnCommand: %w", p.Name, err) + } + if !called { + continue + } + allow = allow && nextAllow + if !allow { + return "", false, nil + } + if next != "" { + current = next + } + } + + return current, true, nil +} + +// Close closes all plugin Lua states. +func (m *Manager) Close() { + m.mu.Lock() + defer m.mu.Unlock() + for _, p := range m.plugins { + p.L.Close() + } + m.plugins = map[string]*Plugin{} +} diff --git a/pkg/luaplugin/plugin_test.go b/pkg/luaplugin/plugin_test.go new file mode 100644 index 0000000..84fca0d --- /dev/null +++ b/pkg/luaplugin/plugin_test.go @@ -0,0 +1,241 @@ +package luaplugin + +import ( + "os" + "path/filepath" + "testing" +) + +func writeLuaScript(t *testing.T, name, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), name) + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write lua script failed: %v", err) + } + return path +} + +func TestManagerLoadAndHooks(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "rewrite.lua", ` +function OnInput(s) + return s .. "-in" +end + +function OnOutput(s) + return s .. "-out" +end + +function OnCommand(line) + return line .. " --lua", true +end +`) + + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if name != "rewrite" { + t.Fatalf("unexpected plugin name: %q", name) + } + + in, err := m.ProcessInput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessInput() failed: %v", err) + } + if string(in) != "abc-in" { + t.Fatalf("ProcessInput() got=%q want=%q", in, "abc-in") + } + + out, err := m.ProcessOutput([]byte("xyz")) + if err != nil { + t.Fatalf("ProcessOutput() failed: %v", err) + } + if string(out) != "xyz-out" { + t.Fatalf("ProcessOutput() got=%q want=%q", out, "xyz-out") + } + + line, allow, err := m.ProcessCommand(".help") + if err != nil { + t.Fatalf("ProcessCommand() failed: %v", err) + } + if !allow || line != ".help --lua" { + t.Fatalf("ProcessCommand() got=(%q,%v) want=(%q,true)", line, allow, ".help --lua") + } +} + +func TestManagerDisableAndUnload(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "simple.lua", ` +function OnInput(s) + return s .. "-x" +end +`) + + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if err = m.Disable(name); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + got, err := m.ProcessInput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessInput() with disabled plugin failed: %v", err) + } + if string(got) != "abc" { + t.Fatalf("disabled plugin should not modify input, got=%q", got) + } + + if err = m.Enable(name); err != nil { + t.Fatalf("Enable() failed: %v", err) + } + got, err = m.ProcessInput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessInput() after enable failed: %v", err) + } + if string(got) != "abc-x" { + t.Fatalf("enabled plugin should modify input, got=%q", got) + } + + if err = m.Unload(name); err != nil { + t.Fatalf("Unload() failed: %v", err) + } + if len(m.List()) != 0 { + t.Fatalf("Unload() should remove plugin from list") + } +} + +func TestManagerOutputDrop(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "drop.lua", ` +function OnOutput(s) + return nil +end +`) + + if _, err := m.Load(path); err != nil { + t.Fatalf("Load() failed: %v", err) + } + + out, err := m.ProcessOutput([]byte("abc")) + if err != nil { + t.Fatalf("ProcessOutput() failed: %v", err) + } + if out != nil { + t.Fatalf("expected nil output when plugin returns nil") + } +} + +func TestManagerReload(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "reloadable.lua", ` +function OnInput(s) + return s .. "-v1" +end +`) + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if err = m.Reload(name); err != nil { + t.Fatalf("Reload() failed: %v", err) + } + + out, err := m.ProcessInput([]byte("test")) + if err != nil { + t.Fatalf("ProcessInput() after reload failed: %v", err) + } + if string(out) != "test-v1" { + t.Fatalf("reloaded plugin should still work, got=%q", out) + } + + if err = m.Reload("nonexistent"); err == nil { + t.Fatalf("Reload() non-existent should error") + } +} + +func TestManagerCommandBlock(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "blocker.lua", ` +function OnCommand(line) + return line, false +end +`) + + if _, err := m.Load(path); err != nil { + t.Fatalf("Load() failed: %v", err) + } + + line, allow, err := m.ProcessCommand(".exit") + if err != nil { + t.Fatalf("ProcessCommand() failed: %v", err) + } + if allow { + t.Fatalf("command should be blocked, got allow=%v line=%q", allow, line) + } +} + +func TestManagerLoadErrors(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + _, err := m.Load("nonexistent_file.lua") + if err == nil { + t.Fatalf("Load() non-existent file should error") + } + + path := writeLuaScript(t, "bad.lua", "this is not valid lua {{{") + _, err = m.Load(path) + if err == nil { + t.Fatalf("Load() invalid lua should error") + } +} + +func TestManagerDuplicateLoad(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "once.lua", "function OnInput(s) return s end") + _, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + _, err = m.Load(path) + if err == nil { + t.Fatalf("Load() duplicate should error") + } +} + +func TestManagerListWithDisabled(t *testing.T) { + m := NewManager() + t.Cleanup(m.Close) + + path := writeLuaScript(t, "mylist.lua", "function OnInput(s) return s end") + name, err := m.Load(path) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if err = m.Disable(name); err != nil { + t.Fatalf("Disable() failed: %v", err) + } + + items := m.List() + if len(items) != 1 || items[0].Enabled { + t.Fatalf("expected disabled in list, got %+v", items) + } +} diff --git a/plugins/demo.lua b/plugins/demo.lua new file mode 100644 index 0000000..6499689 --- /dev/null +++ b/plugins/demo.lua @@ -0,0 +1,14 @@ +-- Demo Lua plugin for the runtime plugin system. +-- It is shipped disabled by default and only runs after `.plugin load`. + +function OnInput(payload) + return payload +end + +function OnOutput(payload) + return payload +end + +function OnCommand(line) + return line, true +end diff --git a/plugins/modbus.lua b/plugins/modbus.lua new file mode 100644 index 0000000..9cd8a0a --- /dev/null +++ b/plugins/modbus.lua @@ -0,0 +1,84 @@ +-- Modbus RTU plugin for SerialTerminalForWindowsTerminal +-- Provides .modbus commands for reading/writing Modbus registers. +-- Uses Go-provided modbus.crc16() and hex.encode/decode helpers. + +-- OnInput: intercept Modbus RTU frames and log them +function OnInput(payload) + return payload +end + +-- OnOutput: decode Modbus RTU responses and format for display +function OnOutput(payload) + return payload +end + +-- OnCommand: handle .modbus commands +function OnCommand(line) + local cmd, slave, addr, count = parseModbus(line) + if not cmd then + return line, true -- not a modbus command, pass through + end + + if cmd == "read" then + return buildReadRequest(slave, addr, count), false + elseif cmd == "write" then + return buildWriteRequest(slave, addr, count), false + elseif cmd == "info" then + return line, true -- pass to .help + end + + return line, true +end + +-- Parse ".modbus read|write " +function parseModbus(line) + local parts = {} + for part in string.gmatch(line, "%S+") do + table.insert(parts, part) + end + if #parts < 1 or parts[1] ~= ".modbus" then + return nil + end + local cmd = parts[2] + if cmd == "read" and #parts >= 4 then + return cmd, tonumber(parts[3]), tonumber(parts[4]), tonumber(parts[5]) + elseif cmd == "write" and #parts >= 4 then + return cmd, tonumber(parts[3]), tonumber(parts[4]), tonumber(parts[5]) + elseif cmd == "info" then + return cmd, nil, nil, nil + end + return nil +end + +-- Build Modbus RTU read holding registers request (function 0x03) +function buildReadRequest(slave, addr, count) + if not count or count <= 0 then count = 1 end + if count > 125 then count = 125 end + + local frame = util.bytes(slave, 0x03, + math.floor(addr / 256), addr % 256, + math.floor(count / 256), count % 256) + + local crc = modbus.crc16(frame) + local crcLow = crc % 256 + local crcHigh = math.floor(crc / 256) + frame = frame .. string.char(crcLow) .. string.char(crcHigh) + + return frame +end + +-- Build Modbus RTU write single register request (function 0x06) +function buildWriteRequest(slave, addr, value) + if not value then value = 0 end + + local frame = util.bytes(slave, 0x06, + math.floor(addr / 256), addr % 256, + math.floor(value / 256), value % 256) + + local crc = modbus.crc16(frame) + local crcLow = crc % 256 + local crcHigh = math.floor(crc / 256) + frame = frame .. string.char(crcLow) .. string.char(crcHigh) + + return frame +end diff --git a/utils.go b/utils.go deleted file mode 100644 index c362599..0000000 --- a/utils.go +++ /dev/null @@ -1,125 +0,0 @@ -package main - -import ( - "fmt" - "github.com/trzsz/trzsz-go/trzsz" - "go.bug.st/serial" - "golang.org/x/term" - "io" - "log" - "net" - "os" - "os/signal" - "runtime" - "strings" -) - -func checkPortAvailability(name string) ([]string, error) { - ports, err := serial.GetPortsList() - if err != nil { - log.Fatal(err) - } - if len(ports) == 0 { - return nil, fmt.Errorf("无串口") - } - if name == "" { - return ports, fmt.Errorf("串口未指定") - } - for _, port := range ports { - if strings.Compare(port, name) == 0 { - return ports, nil - } - } - return ports, fmt.Errorf("串口 " + name + " 未在线") -} - -func OpenSerial() { - var err error - mode := &serial.Mode{ - BaudRate: config.baudRate, - StopBits: serial.StopBits(config.stopBits), - DataBits: config.dataBits, - Parity: serial.Parity(config.parityBit), - } - serialPort, err = serial.Open(config.portName, mode) - ErrorF(err) - return -} - -func CloseSerial() { - err := serialPort.Close() - ErrorF(err) - return -} - -var termch chan os.Signal - -// OpenTrzsz create a TrzszFilter to support trzsz ( trz / tsz ). -// -// ┌────────┐ stdinPipe ┌────────┐ ClientIn ┌─────────────┐ SerialIn ┌────────┐ -// │ ├─────────────►│ ├─────────────►│ ├─────────────►│ │ -// │ mutual │ │ Client │ │ TrzszFilter │ │ Serial │ -// │ │◄─────────────│ │◄─────────────┤ │◄─────────────┤ │ -// └────────┘ stdoutPipe └────────┘ ClientOut └─────────────┘ SerialOut └────────┘ -func OpenTrzsz() { - fd := int(os.Stdin.Fd()) - width, _, err := term.GetSize(fd) - if err != nil { - if runtime.GOOS != "windows" { - fmt.Printf("term get size failed: %s\n", err) - return - } - width = 80 - } - - clientIn, stdinPipe = io.Pipe() - stdoutPipe, clientOut = io.Pipe() - trzszFilter = trzsz.NewTrzszFilter(clientIn, clientOut, serialPort, serialPort, - trzsz.TrzszOptions{TerminalColumns: int32(width), EnableZmodem: true}) - trzsz.SetAffectedByWindows(false) - termch = make(chan os.Signal, 1) - go func() { - for range termch { - width, _, err := term.GetSize(fd) - if err != nil { - fmt.Printf("term get size failed: %s\n", err) - continue - } - trzszFilter.SetTerminalColumns(int32(width)) - } - }() -} - -func CloseTrzsz() { - signal.Stop(termch) - close(termch) -} - -func OpenForwarding() { - for i, mode := range config.forWard { - if FoeWardMode(mode) != NOT { - conn := setForWardClient(FoeWardMode(mode), config.address[i]) - outs = append(outs, conn) - go func() { - defer func(conn net.Conn) { - err := conn.Close() - if err != nil { - log.Fatal(err) - } - }(conn) - input(conn) - }() - } - } -} - -func ErrorP(err error) { - if err != nil { - fmt.Fprint(os.Stderr, err) - } -} -func ErrorF(err error) { - if err != nil { - log.Fatal(err) - } -}