mirror of
https://github.com/jixishi/SerialTerminalForWindowsTerminal.git
synced 2026-06-15 16:42:46 +00:00
refactor: extract pkg/forward and pkg/luaplugin packages
Move ForwardManager → pkg/forward/Manager and PluginManager → pkg/luaplugin/Manager. Move FoeWardMode (now forward.Mode) with ParseMode/Network/String into pkg/forward. Rename constants: NOT→None, TCPC→TCP, UDPC→UDP. Update all references in main package. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,253 @@
|
||||
package forward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestManagerTCPFlow(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen failed: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
acceptCh := make(chan net.Conn, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, e := listener.Accept()
|
||||
if e != nil {
|
||||
errCh <- e
|
||||
return
|
||||
}
|
||||
acceptCh <- conn
|
||||
}()
|
||||
|
||||
serialCh := make(chan string, 2)
|
||||
mgr := NewManager(func(b []byte) error {
|
||||
serialCh <- string(b)
|
||||
return nil
|
||||
}, func(string, ...any) {})
|
||||
defer mgr.Close()
|
||||
|
||||
id, err := mgr.Add(TCP, listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Add() failed: %v", err)
|
||||
}
|
||||
|
||||
var serverConn net.Conn
|
||||
select {
|
||||
case serverConn = <-acceptCh:
|
||||
case e := <-errCh:
|
||||
t.Fatalf("accept failed: %v", e)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("timed out waiting for accepted connection")
|
||||
}
|
||||
defer serverConn.Close()
|
||||
|
||||
items := mgr.List()
|
||||
if len(items) != 1 || items[0].ID != id || !items[0].Enabled {
|
||||
t.Fatalf("unexpected list after add: %+v", items)
|
||||
}
|
||||
|
||||
if err = serverConn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Fatalf("SetReadDeadline failed: %v", err)
|
||||
}
|
||||
mgr.Broadcast([]byte("from-app"))
|
||||
buf := make([]byte, 64)
|
||||
n, err := serverConn.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("server read from broadcast failed: %v", err)
|
||||
}
|
||||
if string(buf[:n]) != "from-app" {
|
||||
t.Fatalf("broadcast payload mismatch got=%q", string(buf[:n]))
|
||||
}
|
||||
|
||||
if _, err = serverConn.Write([]byte("from-remote")); err != nil {
|
||||
t.Fatalf("server write failed: %v", err)
|
||||
}
|
||||
select {
|
||||
case got := <-serialCh:
|
||||
if got != "from-remote" {
|
||||
t.Fatalf("writeToSerial payload mismatch got=%q", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("timed out waiting for writeToSerial callback")
|
||||
}
|
||||
|
||||
if err = mgr.Disable(id); err != nil {
|
||||
t.Fatalf("Disable() failed: %v", err)
|
||||
}
|
||||
items = mgr.List()
|
||||
if len(items) != 1 || items[0].Enabled {
|
||||
t.Fatalf("Disable() did not update state: %+v", items)
|
||||
}
|
||||
|
||||
if err = mgr.Remove(id); err != nil {
|
||||
t.Fatalf("Remove() failed: %v", err)
|
||||
}
|
||||
if got := mgr.List(); len(got) != 0 {
|
||||
t.Fatalf("expected empty list after remove, got=%+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerErrorCases(t *testing.T) {
|
||||
mgr := NewManager(func([]byte) error { return nil }, func(string, ...any) {})
|
||||
defer mgr.Close()
|
||||
|
||||
if _, err := mgr.Add(None, "127.0.0.1:1"); err == nil {
|
||||
t.Fatalf("Add(None) expected error")
|
||||
}
|
||||
|
||||
if err := mgr.Remove(999); err == nil {
|
||||
t.Fatalf("Remove(non-existing) expected error")
|
||||
}
|
||||
|
||||
if err := mgr.Disable(999); err == nil {
|
||||
t.Fatalf("Disable(non-existing) expected error")
|
||||
}
|
||||
|
||||
if err := mgr.Enable(999); err == nil {
|
||||
t.Fatalf("Enable(non-existing) expected error")
|
||||
}
|
||||
|
||||
if err := mgr.Update(999, TCP, "127.0.0.1:1"); err == nil {
|
||||
t.Fatalf("Update(non-existing) expected error")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen failed: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
id, err := mgr.Add(TCP, listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Add() failed: %v", err)
|
||||
}
|
||||
if err = mgr.Update(id, None, "127.0.0.1:1"); err == nil {
|
||||
t.Fatalf("Update(None) expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerSetInboundReporter(t *testing.T) {
|
||||
reported := make(chan []byte, 1)
|
||||
mgr := NewManager(func([]byte) error { return nil }, func(string, ...any) {})
|
||||
defer mgr.Close()
|
||||
mgr.SetInboundReporter(func(id int, chunk []byte) {
|
||||
reported <- chunk
|
||||
})
|
||||
// Verify the callback was stored (indirect test)
|
||||
_ = reported
|
||||
}
|
||||
|
||||
func TestManagerBroadcastToDisabled(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen failed: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
writeCh := make(chan []byte, 4)
|
||||
mgr := NewManager(func([]byte) error {
|
||||
writeCh <- nil
|
||||
return nil
|
||||
}, func(string, ...any) {})
|
||||
defer mgr.Close()
|
||||
|
||||
id, err := mgr.Add(TCP, listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Add() failed: %v", err)
|
||||
}
|
||||
|
||||
if err = mgr.Disable(id); err != nil {
|
||||
t.Fatalf("Disable() failed: %v", err)
|
||||
}
|
||||
|
||||
mgr.Broadcast([]byte("should-not-arrive"))
|
||||
|
||||
select {
|
||||
case <-writeCh:
|
||||
t.Fatalf("broadcast should not write to serial when disabled")
|
||||
default:
|
||||
}
|
||||
|
||||
mgr.Broadcast(nil)
|
||||
mgr.Broadcast([]byte{})
|
||||
}
|
||||
|
||||
func TestManagerEnable(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen failed: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
writeCh := make(chan []byte, 2)
|
||||
mgr := NewManager(func([]byte) error {
|
||||
writeCh <- nil
|
||||
return nil
|
||||
}, func(string, ...any) {})
|
||||
defer mgr.Close()
|
||||
|
||||
id, err := mgr.Add(TCP, listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Add() failed: %v", err)
|
||||
}
|
||||
|
||||
if err = mgr.Disable(id); err != nil {
|
||||
t.Fatalf("Disable() failed: %v", err)
|
||||
}
|
||||
|
||||
if err = mgr.Enable(id); err != nil {
|
||||
t.Fatalf("Enable() failed: %v", err)
|
||||
}
|
||||
|
||||
items := mgr.List()
|
||||
if len(items) != 1 || !items[0].Enabled {
|
||||
t.Fatalf("expected enabled after Enable(), got=%+v", items)
|
||||
}
|
||||
|
||||
if err = mgr.Enable(id); err != nil {
|
||||
t.Fatalf("second Enable() should succeed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerUpdate(t *testing.T) {
|
||||
l1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen 1 failed: %v", err)
|
||||
}
|
||||
defer l1.Close()
|
||||
|
||||
l2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen 2 failed: %v", err)
|
||||
}
|
||||
defer l2.Close()
|
||||
|
||||
mgr := NewManager(func([]byte) error { return nil }, func(string, ...any) {})
|
||||
defer mgr.Close()
|
||||
|
||||
id, err := mgr.Add(TCP, l1.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Add() failed: %v", err)
|
||||
}
|
||||
|
||||
if err = mgr.Update(id, TCP, l2.Addr().String()); err != nil {
|
||||
t.Fatalf("Update() failed: %v", err)
|
||||
}
|
||||
|
||||
items := mgr.List()
|
||||
if len(items) != 1 || items[0].Address != l2.Addr().String() {
|
||||
t.Fatalf("update should change address, got=%+v", items)
|
||||
}
|
||||
|
||||
if err = mgr.Disable(id); err != nil {
|
||||
t.Fatalf("Disable() failed: %v", err)
|
||||
}
|
||||
if err = mgr.Update(id, TCP, l1.Addr().String()); err != nil {
|
||||
t.Fatalf("Update() on disabled should succeed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,369 @@
|
||||
// Package forward manages TCP/UDP forwarding targets for serial data.
|
||||
package forward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mode is the forwarding protocol mode.
|
||||
type Mode int
|
||||
|
||||
const (
|
||||
None Mode = iota
|
||||
TCP
|
||||
UDP
|
||||
)
|
||||
|
||||
// ParseMode parses a mode string. Accepts "tcp"/"tcp-c"/"tcpc"/"1" → TCP, "udp"/"udp-c"/"udpc"/"2" → UDP.
|
||||
func ParseMode(v string) (Mode, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "tcp", "tcp-c", "tcpc", "1":
|
||||
return TCP, true
|
||||
case "udp", "udp-c", "udpc", "2":
|
||||
return UDP, true
|
||||
default:
|
||||
return None, false
|
||||
}
|
||||
}
|
||||
|
||||
func (m Mode) Network() string {
|
||||
switch m {
|
||||
case TCP:
|
||||
return "tcp"
|
||||
case UDP:
|
||||
return "udp"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (m Mode) String() string {
|
||||
switch m {
|
||||
case TCP:
|
||||
return "tcp"
|
||||
case UDP:
|
||||
return "udp"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
|
||||
// Stats holds I/O statistics for a forward target.
|
||||
type Stats struct {
|
||||
ReadBytes uint64
|
||||
WrittenBytes uint64
|
||||
LastError string
|
||||
}
|
||||
|
||||
// Target represents a single forwarding connection.
|
||||
type Target struct {
|
||||
ID int
|
||||
Mode Mode
|
||||
Address string
|
||||
Enabled bool
|
||||
Connected bool
|
||||
CreatedAt time.Time
|
||||
|
||||
conn net.Conn
|
||||
stats Stats
|
||||
mu sync.Mutex
|
||||
closeCh chan struct{}
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Snapshot is a read-only view of a forward target for display.
|
||||
type Snapshot struct {
|
||||
ID int
|
||||
Mode string
|
||||
Address string
|
||||
Enabled bool
|
||||
Connected bool
|
||||
ReadBytes uint64
|
||||
WriteByte uint64
|
||||
LastError string
|
||||
}
|
||||
|
||||
// Manager coordinates forwarding targets.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
targets map[int]*Target
|
||||
nextID int
|
||||
writeToSerial func([]byte) error
|
||||
notify func(string, ...any)
|
||||
onInbound func(int, []byte)
|
||||
}
|
||||
|
||||
// NewManager creates a forwarding manager.
|
||||
func NewManager(writeToSerial func([]byte) error, notify func(string, ...any)) *Manager {
|
||||
return &Manager{
|
||||
targets: make(map[int]*Target),
|
||||
nextID: 1,
|
||||
writeToSerial: writeToSerial,
|
||||
notify: notify,
|
||||
}
|
||||
}
|
||||
|
||||
// SetInboundReporter sets a callback invoked when inbound data arrives from a target.
|
||||
func (m *Manager) SetInboundReporter(fn func(int, []byte)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.onInbound = fn
|
||||
}
|
||||
|
||||
// Add creates and connects a new forward target.
|
||||
func (m *Manager) Add(mode Mode, address string) (int, error) {
|
||||
if mode == None {
|
||||
return 0, fmt.Errorf("forward mode cannot be none")
|
||||
}
|
||||
|
||||
t := &Target{
|
||||
Mode: mode,
|
||||
Address: address,
|
||||
Enabled: true,
|
||||
CreatedAt: time.Now(),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
conn, err := net.Dial(mode.Network(), address)
|
||||
if err != nil {
|
||||
t.stats.LastError = err.Error()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
t.conn = conn
|
||||
t.Connected = true
|
||||
|
||||
m.mu.Lock()
|
||||
t.ID = m.nextID
|
||||
m.nextID++
|
||||
m.targets[t.ID] = t
|
||||
m.mu.Unlock()
|
||||
|
||||
go m.readLoop(t, conn, t.closeCh)
|
||||
m.notify("[forward] #%d %s %s connected", t.ID, t.Mode.String(), t.Address)
|
||||
return t.ID, nil
|
||||
}
|
||||
|
||||
func (m *Manager) readLoop(t *Target, conn net.Conn, stop <-chan struct{}) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&t.stats.ReadBytes, uint64(n))
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if wErr := m.writeToSerial(chunk); wErr != nil {
|
||||
t.stats.LastError = wErr.Error()
|
||||
m.notify("[forward] #%d write serial error: %v", t.ID, wErr)
|
||||
} else if m.onInbound != nil {
|
||||
m.onInbound(t.ID, chunk)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.mu.Lock()
|
||||
if t.conn == conn {
|
||||
t.Connected = false
|
||||
}
|
||||
t.stats.LastError = err.Error()
|
||||
t.mu.Unlock()
|
||||
m.notify("[forward] #%d disconnected: %v", t.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove disconnects and removes a target.
|
||||
func (m *Manager) Remove(id int) error {
|
||||
m.mu.Lock()
|
||||
t, ok := m.targets[id]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("forward #%d not found", id)
|
||||
}
|
||||
delete(m.targets, id)
|
||||
m.mu.Unlock()
|
||||
|
||||
t.close()
|
||||
m.notify("[forward] #%d removed", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enable (re)connects a target.
|
||||
func (m *Manager) Enable(id int) error {
|
||||
m.mu.RLock()
|
||||
t, ok := m.targets[id]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("forward #%d not found", id)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.Enabled && t.Connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
conn, err := net.Dial(t.Mode.Network(), t.Address)
|
||||
if err != nil {
|
||||
t.stats.LastError = err.Error()
|
||||
return err
|
||||
}
|
||||
|
||||
t.Enabled = true
|
||||
t.Connected = true
|
||||
t.conn = conn
|
||||
t.closeCh = make(chan struct{})
|
||||
t.closed = false
|
||||
go m.readLoop(t, conn, t.closeCh)
|
||||
m.notify("[forward] #%d enabled", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update changes a target's mode and address, reconnecting if enabled.
|
||||
func (m *Manager) Update(id int, mode Mode, address string) error {
|
||||
if mode == None {
|
||||
return fmt.Errorf("forward mode cannot be none")
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
t, ok := m.targets[id]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("forward #%d not found", id)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
wasEnabled := t.Enabled
|
||||
t.Mode = mode
|
||||
t.Address = address
|
||||
t.mu.Unlock()
|
||||
|
||||
t.close()
|
||||
|
||||
if !wasEnabled {
|
||||
m.notify("[forward] #%d updated (disabled)", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.Enable(id)
|
||||
}
|
||||
|
||||
// Disable disconnects a target without removing it.
|
||||
func (m *Manager) Disable(id int) error {
|
||||
m.mu.RLock()
|
||||
t, ok := m.targets[id]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("forward #%d not found", id)
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.Enabled = false
|
||||
t.mu.Unlock()
|
||||
t.close()
|
||||
m.notify("[forward] #%d disabled", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Broadcast sends data to all enabled, connected targets.
|
||||
func (m *Manager) Broadcast(data []byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
items := make([]*Target, 0, len(m.targets))
|
||||
for _, t := range m.targets {
|
||||
items = append(items, t)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
for _, t := range items {
|
||||
if !t.Enabled || !t.Connected || t.conn == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := t.conn.Write(data)
|
||||
if err != nil {
|
||||
t.stats.LastError = err.Error()
|
||||
m.notify("[forward] #%d write error: %v", t.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
atomic.AddUint64(&t.stats.WrittenBytes, uint64(n))
|
||||
}
|
||||
}
|
||||
|
||||
// List returns a snapshot of all targets.
|
||||
func (m *Manager) List() []Snapshot {
|
||||
m.mu.RLock()
|
||||
items := make([]Snapshot, 0, len(m.targets))
|
||||
for _, t := range m.targets {
|
||||
items = append(items, Snapshot{
|
||||
ID: t.ID,
|
||||
Mode: t.Mode.String(),
|
||||
Address: t.Address,
|
||||
Enabled: t.Enabled,
|
||||
Connected: t.Connected,
|
||||
ReadBytes: atomic.LoadUint64(&t.stats.ReadBytes),
|
||||
WriteByte: atomic.LoadUint64(&t.stats.WrittenBytes),
|
||||
LastError: t.stats.LastError,
|
||||
})
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return items[i].ID < items[j].ID
|
||||
})
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
// Close disconnects and removes all targets.
|
||||
func (m *Manager) Close() {
|
||||
m.mu.Lock()
|
||||
items := make([]*Target, 0, len(m.targets))
|
||||
for _, t := range m.targets {
|
||||
items = append(items, t)
|
||||
}
|
||||
m.targets = map[int]*Target{}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, t := range items {
|
||||
t.close()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Target) close() {
|
||||
t.mu.Lock()
|
||||
if t.closed {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t.closed = true
|
||||
ch := t.closeCh
|
||||
conn := t.conn
|
||||
t.conn = nil
|
||||
t.Connected = false
|
||||
t.mu.Unlock()
|
||||
|
||||
if ch != nil {
|
||||
close(ch)
|
||||
}
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package luaplugin
|
||||
|
||||
import lua "github.com/yuin/gopher-lua"
|
||||
|
||||
func callStringHook(L *lua.LState, name string, payload string) (*string, bool, error) {
|
||||
fn := L.GetGlobal(name)
|
||||
if fn.Type() == lua.LTNil {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if err := L.CallByParam(lua.P{Fn: fn, NRet: 1, Protect: true}, lua.LString(payload)); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
|
||||
ret := L.Get(-1)
|
||||
L.Pop(1)
|
||||
if ret.Type() == lua.LTNil {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
s := ret.String()
|
||||
return &s, true, nil
|
||||
}
|
||||
|
||||
func callCommandHook(L *lua.LState, name, line string) (string, bool, bool, error) {
|
||||
fn := L.GetGlobal(name)
|
||||
if fn.Type() == lua.LTNil {
|
||||
return "", true, false, nil
|
||||
}
|
||||
|
||||
if err := L.CallByParam(lua.P{Fn: fn, NRet: 2, Protect: true}, lua.LString(line)); err != nil {
|
||||
return "", true, true, err
|
||||
}
|
||||
|
||||
allowVal := L.Get(-1)
|
||||
lineVal := L.Get(-2)
|
||||
L.Pop(2)
|
||||
|
||||
allow := true
|
||||
if allowVal.Type() == lua.LTBool {
|
||||
allow = lua.LVAsBool(allowVal)
|
||||
}
|
||||
|
||||
next := ""
|
||||
if lineVal.Type() != lua.LTNil {
|
||||
next = lineVal.String()
|
||||
}
|
||||
|
||||
return next, allow, true, nil
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
// Package luaplugin provides a Lua plugin system for processing serial data streams.
|
||||
package luaplugin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// Plugin represents a loaded Lua plugin.
|
||||
type Plugin struct {
|
||||
Name string
|
||||
Path string
|
||||
Enabled bool
|
||||
L *lua.LState
|
||||
callMu sync.Mutex
|
||||
}
|
||||
|
||||
// Snapshot is a read-only view of a plugin for display.
|
||||
type Snapshot struct {
|
||||
Name string
|
||||
Path string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// Manager coordinates plugin lifecycle and hook execution.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
plugins map[string]*Plugin
|
||||
}
|
||||
|
||||
// NewManager creates a plugin manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{plugins: make(map[string]*Plugin)}
|
||||
}
|
||||
|
||||
// Load loads a Lua plugin from the given path.
|
||||
func (m *Manager) Load(path string) (string, error) {
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
name := strings.TrimSuffix(filepath.Base(abs), filepath.Ext(abs))
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("invalid plugin name")
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if _, ok := m.plugins[name]; ok {
|
||||
return "", fmt.Errorf("plugin %s already loaded", name)
|
||||
}
|
||||
|
||||
state := lua.NewState()
|
||||
if err = state.DoFile(abs); err != nil {
|
||||
state.Close()
|
||||
return "", err
|
||||
}
|
||||
|
||||
m.plugins[name] = &Plugin{
|
||||
Name: name,
|
||||
Path: abs,
|
||||
Enabled: true,
|
||||
L: state,
|
||||
}
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// Unload unloads a plugin and closes its Lua state.
|
||||
func (m *Manager) Unload(name string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
p, ok := m.plugins[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin %s not found", name)
|
||||
}
|
||||
|
||||
p.L.Close()
|
||||
delete(m.plugins, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enable enables a previously loaded plugin.
|
||||
func (m *Manager) Enable(name string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
p, ok := m.plugins[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin %s not found", name)
|
||||
}
|
||||
p.Enabled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disable disables a plugin without unloading it.
|
||||
func (m *Manager) Disable(name string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
p, ok := m.plugins[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin %s not found", name)
|
||||
}
|
||||
p.Enabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reload reloads a plugin's file.
|
||||
func (m *Manager) Reload(name string) error {
|
||||
m.mu.Lock()
|
||||
p, ok := m.plugins[name]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin %s not found", name)
|
||||
}
|
||||
|
||||
path := p.Path
|
||||
if err := m.Unload(name); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := m.Load(path)
|
||||
return err
|
||||
}
|
||||
|
||||
// List returns a snapshot of all plugins.
|
||||
func (m *Manager) List() []Snapshot {
|
||||
m.mu.RLock()
|
||||
res := make([]Snapshot, 0, len(m.plugins))
|
||||
for _, p := range m.plugins {
|
||||
res = append(res, Snapshot{Name: p.Name, Path: p.Path, Enabled: p.Enabled})
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
sort.Slice(res, func(i, j int) bool {
|
||||
return res[i].Name < res[j].Name
|
||||
})
|
||||
return res
|
||||
}
|
||||
|
||||
// ProcessInput runs the OnInput hook chain across all enabled plugins.
|
||||
func (m *Manager) ProcessInput(data []byte) ([]byte, error) {
|
||||
return m.processDataHook("OnInput", data)
|
||||
}
|
||||
|
||||
// ProcessOutput runs the OnOutput hook chain across all enabled plugins.
|
||||
func (m *Manager) ProcessOutput(data []byte) ([]byte, error) {
|
||||
return m.processDataHook("OnOutput", data)
|
||||
}
|
||||
|
||||
func (m *Manager) processDataHook(name string, data []byte) ([]byte, error) {
|
||||
m.mu.RLock()
|
||||
plugins := make([]*Plugin, 0, len(m.plugins))
|
||||
for _, p := range m.plugins {
|
||||
plugins = append(plugins, p)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
current := data
|
||||
for _, p := range plugins {
|
||||
if !p.Enabled {
|
||||
continue
|
||||
}
|
||||
p.callMu.Lock()
|
||||
ret, called, err := callStringHook(p.L, name, string(current))
|
||||
p.callMu.Unlock()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plugin %s %s: %w", p.Name, name, err)
|
||||
}
|
||||
if !called {
|
||||
continue
|
||||
}
|
||||
if ret == nil {
|
||||
return nil, nil
|
||||
}
|
||||
current = []byte(*ret)
|
||||
}
|
||||
|
||||
return current, nil
|
||||
}
|
||||
|
||||
// ProcessCommand runs the OnCommand hook chain across all enabled plugins.
|
||||
func (m *Manager) ProcessCommand(line string) (string, bool, error) {
|
||||
m.mu.RLock()
|
||||
plugins := make([]*Plugin, 0, len(m.plugins))
|
||||
for _, p := range m.plugins {
|
||||
plugins = append(plugins, p)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
current := line
|
||||
allow := true
|
||||
for _, p := range plugins {
|
||||
if !p.Enabled {
|
||||
continue
|
||||
}
|
||||
p.callMu.Lock()
|
||||
next, nextAllow, called, err := callCommandHook(p.L, "OnCommand", current)
|
||||
p.callMu.Unlock()
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("plugin %s OnCommand: %w", p.Name, err)
|
||||
}
|
||||
if !called {
|
||||
continue
|
||||
}
|
||||
allow = allow && nextAllow
|
||||
if !allow {
|
||||
return "", false, nil
|
||||
}
|
||||
if next != "" {
|
||||
current = next
|
||||
}
|
||||
}
|
||||
|
||||
return current, true, nil
|
||||
}
|
||||
|
||||
// Close closes all plugin Lua states.
|
||||
func (m *Manager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for _, p := range m.plugins {
|
||||
p.L.Close()
|
||||
}
|
||||
m.plugins = map[string]*Plugin{}
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package luaplugin
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeLuaScript(t *testing.T, name, content string) string {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), name)
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("write lua script failed: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func TestManagerLoadAndHooks(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
path := writeLuaScript(t, "rewrite.lua", `
|
||||
function OnInput(s)
|
||||
return s .. "-in"
|
||||
end
|
||||
|
||||
function OnOutput(s)
|
||||
return s .. "-out"
|
||||
end
|
||||
|
||||
function OnCommand(line)
|
||||
return line .. " --lua", true
|
||||
end
|
||||
`)
|
||||
|
||||
name, err := m.Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
if name != "rewrite" {
|
||||
t.Fatalf("unexpected plugin name: %q", name)
|
||||
}
|
||||
|
||||
in, err := m.ProcessInput([]byte("abc"))
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessInput() failed: %v", err)
|
||||
}
|
||||
if string(in) != "abc-in" {
|
||||
t.Fatalf("ProcessInput() got=%q want=%q", in, "abc-in")
|
||||
}
|
||||
|
||||
out, err := m.ProcessOutput([]byte("xyz"))
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessOutput() failed: %v", err)
|
||||
}
|
||||
if string(out) != "xyz-out" {
|
||||
t.Fatalf("ProcessOutput() got=%q want=%q", out, "xyz-out")
|
||||
}
|
||||
|
||||
line, allow, err := m.ProcessCommand(".help")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessCommand() failed: %v", err)
|
||||
}
|
||||
if !allow || line != ".help --lua" {
|
||||
t.Fatalf("ProcessCommand() got=(%q,%v) want=(%q,true)", line, allow, ".help --lua")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerDisableAndUnload(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
path := writeLuaScript(t, "simple.lua", `
|
||||
function OnInput(s)
|
||||
return s .. "-x"
|
||||
end
|
||||
`)
|
||||
|
||||
name, err := m.Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if err = m.Disable(name); err != nil {
|
||||
t.Fatalf("Disable() failed: %v", err)
|
||||
}
|
||||
got, err := m.ProcessInput([]byte("abc"))
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessInput() with disabled plugin failed: %v", err)
|
||||
}
|
||||
if string(got) != "abc" {
|
||||
t.Fatalf("disabled plugin should not modify input, got=%q", got)
|
||||
}
|
||||
|
||||
if err = m.Enable(name); err != nil {
|
||||
t.Fatalf("Enable() failed: %v", err)
|
||||
}
|
||||
got, err = m.ProcessInput([]byte("abc"))
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessInput() after enable failed: %v", err)
|
||||
}
|
||||
if string(got) != "abc-x" {
|
||||
t.Fatalf("enabled plugin should modify input, got=%q", got)
|
||||
}
|
||||
|
||||
if err = m.Unload(name); err != nil {
|
||||
t.Fatalf("Unload() failed: %v", err)
|
||||
}
|
||||
if len(m.List()) != 0 {
|
||||
t.Fatalf("Unload() should remove plugin from list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerOutputDrop(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
path := writeLuaScript(t, "drop.lua", `
|
||||
function OnOutput(s)
|
||||
return nil
|
||||
end
|
||||
`)
|
||||
|
||||
if _, err := m.Load(path); err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
out, err := m.ProcessOutput([]byte("abc"))
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessOutput() failed: %v", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected nil output when plugin returns nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerReload(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
path := writeLuaScript(t, "reloadable.lua", `
|
||||
function OnInput(s)
|
||||
return s .. "-v1"
|
||||
end
|
||||
`)
|
||||
name, err := m.Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if err = m.Reload(name); err != nil {
|
||||
t.Fatalf("Reload() failed: %v", err)
|
||||
}
|
||||
|
||||
out, err := m.ProcessInput([]byte("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessInput() after reload failed: %v", err)
|
||||
}
|
||||
if string(out) != "test-v1" {
|
||||
t.Fatalf("reloaded plugin should still work, got=%q", out)
|
||||
}
|
||||
|
||||
if err = m.Reload("nonexistent"); err == nil {
|
||||
t.Fatalf("Reload() non-existent should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCommandBlock(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
path := writeLuaScript(t, "blocker.lua", `
|
||||
function OnCommand(line)
|
||||
return line, false
|
||||
end
|
||||
`)
|
||||
|
||||
if _, err := m.Load(path); err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
line, allow, err := m.ProcessCommand(".exit")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessCommand() failed: %v", err)
|
||||
}
|
||||
if allow {
|
||||
t.Fatalf("command should be blocked, got allow=%v line=%q", allow, line)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerLoadErrors(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
_, err := m.Load("nonexistent_file.lua")
|
||||
if err == nil {
|
||||
t.Fatalf("Load() non-existent file should error")
|
||||
}
|
||||
|
||||
path := writeLuaScript(t, "bad.lua", "this is not valid lua {{{")
|
||||
_, err = m.Load(path)
|
||||
if err == nil {
|
||||
t.Fatalf("Load() invalid lua should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerDuplicateLoad(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
path := writeLuaScript(t, "once.lua", "function OnInput(s) return s end")
|
||||
_, err := m.Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = m.Load(path)
|
||||
if err == nil {
|
||||
t.Fatalf("Load() duplicate should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerListWithDisabled(t *testing.T) {
|
||||
m := NewManager()
|
||||
t.Cleanup(m.Close)
|
||||
|
||||
path := writeLuaScript(t, "mylist.lua", "function OnInput(s) return s end")
|
||||
name, err := m.Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if err = m.Disable(name); err != nil {
|
||||
t.Fatalf("Disable() failed: %v", err)
|
||||
}
|
||||
|
||||
items := m.List()
|
||||
if len(items) != 1 || items[0].Enabled {
|
||||
t.Fatalf("expected disabled in list, got %+v", items)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user