diff --git a/server/config.go b/server/config.go index be5f574..c690af6 100644 --- a/server/config.go +++ b/server/config.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "slices" @@ -89,5 +90,106 @@ func parseConfig(data []byte) (Config, error) { config.AllowedOrigins = allowedOrigins{"*"} } + if err := (&config).Validate(); err != nil { + return Config{}, err + } + return config, nil } + +var validSourceTypes = []string{ + "local:file", + "local:docker", + "local:openclaw", + "ssh:file", + "ssh:docker", + "ssh:openclaw", +} + +func (c *Config) Validate() error { + // Validate servers + serverNames := make(map[string]bool) + for i, server := range c.Servers { + if server.Name == "" { + return fmt.Errorf("servers[%d].name is required", i) + } + if serverNames[server.Name] { + return fmt.Errorf("servers[%d].name %q is duplicated", i, server.Name) + } + serverNames[server.Name] = true + + if server.Host == "" { + return fmt.Errorf("servers[%d].host is required", i) + } + if server.Port <= 0 || server.Port > 65535 { + return fmt.Errorf("servers[%d].port must be between 1 and 65535, got %d", i, server.Port) + } + if server.Username == "" { + return fmt.Errorf("servers[%d].username is required", i) + } + if server.Password == "" && server.PrivateKeyPath == "" { + return fmt.Errorf("servers[%d]: either password or privateKeyPath is required", i) + } + } + + // Validate sources + sourceNames := make(map[string]bool) + for i, source := range c.Sources { + if source.Name == "" { + return fmt.Errorf("sources[%d].name is required", i) + } + if sourceNames[source.Name] { + return fmt.Errorf("sources[%d].name %q is duplicated", i, source.Name) + } + sourceNames[source.Name] = true + + if source.Type == "" { + return fmt.Errorf("sources[%d].type is required", i) + } + if !slices.Contains(validSourceTypes, source.Type) { + return fmt.Errorf("sources[%d].type must be one of %v, got %q", i, validSourceTypes, source.Type) + } + + // Validate type-specific required fields + switch source.Type { + case "local:file", "ssh:file": + if source.Path == "" { + return fmt.Errorf("sources[%d].path is required for type %q", i, source.Type) + } + case "local:docker", "ssh:docker": + if source.ContainerId == "" { + return fmt.Errorf("sources[%d].containerId is required for type %q", i, source.Type) + } + } + + // Validate SSH-specific fields + if slices.Contains([]string{"ssh:file", "ssh:docker", "ssh:openclaw"}, source.Type) { + if source.ServerName != "" { + if !serverNames[source.ServerName] { + return fmt.Errorf("sources[%d].serverName references non-existent server %q", i, source.ServerName) + } + } else { + // Direct SSH config validation + if source.Host == "" { + return fmt.Errorf("sources[%d].host is required for SSH source type %q", i, source.Type) + } + if source.Port <= 0 || source.Port > 65535 { + return fmt.Errorf("sources[%d].port must be between 1 and 65535, got %d", i, source.Port) + } + if source.Username == "" { + return fmt.Errorf("sources[%d].username is required for SSH source type %q", i, source.Type) + } + if source.Password == "" && source.PrivateKeyPath == "" { + return fmt.Errorf("sources[%d]: either password or privateKeyPath is required for SSH source type %q", i, source.Type) + } + } + } + } + + // Validate port + if c.Port < 0 || c.Port > 65535 { + return fmt.Errorf("port must be between 0 and 65535, got %d", c.Port) + } + + return nil +} diff --git a/server/config_test.go b/server/config_test.go new file mode 100644 index 0000000..f421df6 --- /dev/null +++ b/server/config_test.go @@ -0,0 +1,295 @@ +package main + +import ( + "strings" + "testing" +) + +func TestConfigValidation(t *testing.T) { + tests := []struct { + name string + config string + wantError string + }{ + { + name: "valid local file source", + config: ` +[[sources]] +name = "test" +type = "local:file" +path = "/tmp/test.log" +`, + wantError: "", + }, + { + name: "valid local docker source", + config: ` +[[sources]] +name = "test" +type = "local:docker" +containerId = "abc123" +`, + wantError: "", + }, + { + name: "valid SSH source with direct config", + config: ` +[[sources]] +name = "test" +type = "ssh:file" +path = "/var/log/test.log" +host = "example.com" +port = 22 +username = "user" +password = "pass" +`, + wantError: "", + }, + { + name: "valid SSH source with server reference", + config: ` +[[servers]] +name = "prod" +host = "example.com" +port = 22 +username = "user" +password = "pass" + +[[sources]] +name = "test" +type = "ssh:file" +path = "/var/log/test.log" +serverName = "prod" +`, + wantError: "", + }, + { + name: "missing source name", + config: ` +[[sources]] +type = "local:file" +path = "/tmp/test.log" +`, + wantError: "sources[0].name is required", + }, + { + name: "missing source type", + config: ` +[[sources]] +name = "test" +path = "/tmp/test.log" +`, + wantError: "sources[0].type is required", + }, + { + name: "invalid source type", + config: ` +[[sources]] +name = "test" +type = "invalid:type" +path = "/tmp/test.log" +`, + wantError: "sources[0].type must be one of", + }, + { + name: "missing path for file source", + config: ` +[[sources]] +name = "test" +type = "local:file" +`, + wantError: "sources[0].path is required for type \"local:file\"", + }, + { + name: "missing containerId for docker source", + config: ` +[[sources]] +name = "test" +type = "local:docker" +`, + wantError: "sources[0].containerId is required for type \"local:docker\"", + }, + { + name: "SSH source missing host", + config: ` +[[sources]] +name = "test" +type = "ssh:file" +path = "/var/log/test.log" +port = 22 +username = "user" +password = "pass" +`, + wantError: "sources[0].host is required for SSH source type \"ssh:file\"", + }, + { + name: "SSH source missing username", + config: ` +[[sources]] +name = "test" +type = "ssh:file" +path = "/var/log/test.log" +host = "example.com" +port = 22 +password = "pass" +`, + wantError: "sources[0].username is required for SSH source type \"ssh:file\"", + }, + { + name: "SSH source missing authentication", + config: ` +[[sources]] +name = "test" +type = "ssh:file" +path = "/var/log/test.log" +host = "example.com" +port = 22 +username = "user" +`, + wantError: "sources[0]: either password or privateKeyPath is required for SSH source type \"ssh:file\"", + }, + { + name: "SSH source invalid port", + config: ` +[[sources]] +name = "test" +type = "ssh:file" +path = "/var/log/test.log" +host = "example.com" +port = 99999 +username = "user" +password = "pass" +`, + wantError: "sources[0].port must be between 1 and 65535", + }, + { + name: "server missing name", + config: ` +[[servers]] +host = "example.com" +port = 22 +username = "user" +password = "pass" +`, + wantError: "servers[0].name is required", + }, + { + name: "server missing host", + config: ` +[[servers]] +name = "prod" +port = 22 +username = "user" +password = "pass" +`, + wantError: "servers[0].host is required", + }, + { + name: "server invalid port", + config: ` +[[servers]] +name = "prod" +host = "example.com" +port = 0 +username = "user" +password = "pass" +`, + wantError: "servers[0].port must be between 1 and 65535", + }, + { + name: "server missing username", + config: ` +[[servers]] +name = "prod" +host = "example.com" +port = 22 +password = "pass" +`, + wantError: "servers[0].username is required", + }, + { + name: "server missing authentication", + config: ` +[[servers]] +name = "prod" +host = "example.com" +port = 22 +username = "user" +`, + wantError: "servers[0]: either password or privateKeyPath is required", + }, + { + name: "duplicate source names", + config: ` +[[sources]] +name = "test" +type = "local:file" +path = "/tmp/test1.log" + +[[sources]] +name = "test" +type = "local:file" +path = "/tmp/test2.log" +`, + wantError: "sources[1].name \"test\" is duplicated", + }, + { + name: "duplicate server names", + config: ` +[[servers]] +name = "prod" +host = "example.com" +port = 22 +username = "user" +password = "pass" + +[[servers]] +name = "prod" +host = "example2.com" +port = 22 +username = "user" +password = "pass" +`, + wantError: "servers[1].name \"prod\" is duplicated", + }, + { + name: "non-existent server reference", + config: ` +[[sources]] +name = "test" +type = "ssh:file" +path = "/var/log/test.log" +serverName = "nonexistent" +`, + wantError: "sources[0].serverName references non-existent server \"nonexistent\"", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config, err := parseConfig([]byte(tt.config)) + + if tt.wantError == "" { + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + } else { + if err == nil { + t.Errorf("expected error containing %q, got no error", tt.wantError) + } else if !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("expected error containing %q, got: %v", tt.wantError, err) + } + } + + // For valid configs, verify the config is properly set + if tt.wantError == "" && err == nil { + if config.Port == 0 { + t.Error("expected default port to be set") + } + if len(config.AllowedOrigins) == 0 { + t.Error("expected default allowed origins to be set") + } + } + }) + } +}