cluster: embed ssh instead of sshd

This commit is contained in:
hayzam
2026-03-01 15:30:16 +05:30
parent 74fa47ed6a
commit 07c6f7d407
5 changed files with 256 additions and 98 deletions
+4
View File
@@ -122,6 +122,10 @@ func main() {
}
}
if err := cS.(*cluster.Service).StartEmbeddedSSHServer(qCtx); err != nil {
logger.L.Error().Err(err).Msg("Failed to start embedded cluster SSH server")
}
if err := zelta.EnsureZeltaInstalled(); err != nil {
logger.L.Error().Err(err).Msg("Failed to install Zelta")
}
+3
View File
@@ -12,6 +12,7 @@ import (
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/alchemillahq/sylve/internal/config"
@@ -34,6 +35,8 @@ type Service struct {
Transport *raft.NetworkTransport
AuthService serviceInterfaces.AuthServiceInterface
JailService jailServiceInterfaces.JailServiceInterface
embeddedSSHOnce sync.Once
}
func NewClusterService(db *gorm.DB, authService serviceInterfaces.AuthServiceInterface, jailService jailServiceInterfaces.JailServiceInterface) clusterServiceInterfaces.ClusterServiceInterface {
+225
View File
@@ -0,0 +1,225 @@
// SPDX-License-Identifier: BSD-2-Clause
//
// Copyright (c) 2025 The FreeBSD Foundation.
//
// This software was developed by Hayzam Sherif <hayzam@alchemilla.io>
// of Alchemilla Ventures Pvt. Ltd. <hello@alchemilla.io>,
// under sponsorship from the FreeBSD Foundation.
package cluster
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"os"
"os/exec"
"strings"
"github.com/alchemillahq/sylve/internal/logger"
"golang.org/x/crypto/ssh"
)
func (s *Service) StartEmbeddedSSHServer(ctx context.Context) error {
var startErr error
s.embeddedSSHOnce.Do(func() {
privatePath, err := s.ClusterSSHPrivateKeyPath()
if err != nil {
startErr = fmt.Errorf("embedded_ssh_private_key_failed: %w", err)
return
}
privateRaw, err := os.ReadFile(privatePath)
if err != nil {
startErr = fmt.Errorf("embedded_ssh_private_key_read_failed: %w", err)
return
}
hostSigner, err := ssh.ParsePrivateKey(privateRaw)
if err != nil {
startErr = fmt.Errorf("embedded_ssh_private_key_parse_failed: %w", err)
return
}
serverConfig := &ssh.ServerConfig{
PublicKeyCallback: s.embeddedSSHPublicKeyCallback,
}
serverConfig.AddHostKey(hostSigner)
listenAddr := fmt.Sprintf("0.0.0.0:%d", ClusterEmbeddedSSHPort)
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
startErr = fmt.Errorf("embedded_ssh_listen_failed: %w", err)
return
}
logger.L.Info().
Str("addr", listenAddr).
Msg("embedded_cluster_ssh_started")
go func() {
<-ctx.Done()
_ = listener.Close()
}()
go s.embeddedSSHAcceptLoop(ctx, listener, serverConfig)
})
return startErr
}
func (s *Service) embeddedSSHPublicKeyCallback(conn ssh.ConnMetadata, presentedKey ssh.PublicKey) (*ssh.Permissions, error) {
if strings.TrimSpace(conn.User()) != "root" {
return nil, fmt.Errorf("invalid_user")
}
identities, err := s.ListClusterSSHIdentities()
if err != nil {
return nil, fmt.Errorf("list_cluster_identities_failed: %w", err)
}
for _, identity := range identities {
pub := strings.TrimSpace(identity.PublicKey)
if pub == "" {
continue
}
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pub + "\n"))
if err != nil {
continue
}
if bytes.Equal(parsedKey.Marshal(), presentedKey.Marshal()) {
return &ssh.Permissions{
Extensions: map[string]string{
"node_uuid": identity.NodeUUID,
},
}, nil
}
}
return nil, fmt.Errorf("unauthorized_key")
}
func (s *Service) embeddedSSHAcceptLoop(ctx context.Context, listener net.Listener, serverConfig *ssh.ServerConfig) {
for {
rawConn, err := listener.Accept()
if err != nil {
if ctx.Err() != nil {
return
}
logger.L.Warn().Err(err).Msg("embedded_ssh_accept_failed")
continue
}
go s.handleEmbeddedSSHConn(ctx, rawConn, serverConfig)
}
}
func (s *Service) handleEmbeddedSSHConn(ctx context.Context, rawConn net.Conn, serverConfig *ssh.ServerConfig) {
defer rawConn.Close()
_, chans, reqs, err := ssh.NewServerConn(rawConn, serverConfig)
if err != nil {
logger.L.Warn().Err(err).Msg("embedded_ssh_handshake_failed")
return
}
go ssh.DiscardRequests(reqs)
for newChannel := range chans {
if newChannel.ChannelType() != "session" {
_ = newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
logger.L.Warn().Err(err).Msg("embedded_ssh_channel_accept_failed")
continue
}
go s.handleEmbeddedSSHSession(ctx, channel, requests)
}
}
func parseExecRequestPayload(payload []byte) (string, error) {
if len(payload) < 4 {
return "", fmt.Errorf("invalid_exec_payload")
}
size := int(binary.BigEndian.Uint32(payload[:4]))
if size < 0 || len(payload) < 4+size {
return "", fmt.Errorf("invalid_exec_payload_size")
}
return string(payload[4 : 4+size]), nil
}
func exitCodeFromErr(err error) uint32 {
if err == nil {
return 0
}
var exitErr *exec.ExitError
if ok := strings.Contains(err.Error(), "signal: killed"); ok {
return 137
}
if ok := strings.Contains(err.Error(), "signal: terminated"); ok {
return 143
}
if ok := strings.Contains(err.Error(), "signal: interrupt"); ok {
return 130
}
if errors.As(err, &exitErr) {
if code := exitErr.ExitCode(); code >= 0 {
return uint32(code)
}
}
return 1
}
func (s *Service) handleEmbeddedSSHSession(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
defer channel.Close()
execReceived := false
for req := range requests {
switch req.Type {
case "exec":
if execReceived {
_ = req.Reply(false, nil)
continue
}
execReceived = true
command, err := parseExecRequestPayload(req.Payload)
if err != nil {
_ = req.Reply(false, nil)
return
}
command = strings.TrimSpace(command)
if command == "" {
_ = req.Reply(false, nil)
return
}
_ = req.Reply(true, nil)
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
cmd.Stdin = channel
cmd.Stdout = channel
cmd.Stderr = channel.Stderr()
runErr := cmd.Run()
exitCode := exitCodeFromErr(runErr)
_, _ = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{Status: exitCode}))
return
default:
_ = req.Reply(false, nil)
}
}
}
+3 -98
View File
@@ -13,7 +13,6 @@ import (
"net"
"os"
"path/filepath"
"sort"
"strings"
"time"
@@ -29,8 +28,7 @@ const (
clusterSSHPrivateFileName = "id_ed25519"
clusterSSHPublicFileName = "id_ed25519.pub"
clusterManagedKeyStart = "# --- sylve cluster replication keys start ---"
clusterManagedKeyEnd = "# --- sylve cluster replication keys end ---"
ClusterEmbeddedSSHPort = 8122
)
func (s *Service) clusterSSHDir() (string, error) {
@@ -50,14 +48,6 @@ func (s *Service) clusterSSHDir() (string, error) {
return path, nil
}
func (s *Service) clusterSSHPrivateKeyPath() (string, error) {
dir, err := s.clusterSSHDir()
if err != nil {
return "", err
}
return filepath.Join(dir, clusterSSHPrivateFileName), nil
}
func (s *Service) ClusterSSHPrivateKeyPath() (string, error) {
privatePath, _, _, err := s.ensureLocalClusterSSHKeyPair()
if err != nil {
@@ -160,7 +150,7 @@ func (s *Service) EnsureAndPublishLocalSSHIdentity() error {
NodeUUID: strings.TrimSpace(detail.NodeID),
SSHUser: "root",
SSHHost: s.localClusterSSHHost(),
SSHPort: 22,
SSHPort: ClusterEmbeddedSSHPort,
PublicKey: pubKey,
}
@@ -240,92 +230,7 @@ func (s *Service) forwardSSHIdentityToLeader(identity clusterModels.ClusterSSHId
return fmt.Errorf("forward_ssh_identity_to_leader_failed: %w", lastErr)
}
func replaceManagedSSHBlock(existing string, managed []string) string {
managedSet := make(map[string]struct{}, len(managed))
normalized := make([]string, 0, len(managed))
for _, line := range managed {
line = strings.TrimSpace(line)
if line == "" {
continue
}
if _, ok := managedSet[line]; ok {
continue
}
managedSet[line] = struct{}{}
normalized = append(normalized, line)
}
sort.Strings(normalized)
blockLines := []string{clusterManagedKeyStart}
blockLines = append(blockLines, normalized...)
blockLines = append(blockLines, clusterManagedKeyEnd)
block := strings.Join(blockLines, "\n")
start := strings.Index(existing, clusterManagedKeyStart)
end := strings.Index(existing, clusterManagedKeyEnd)
if start >= 0 && end > start {
end += len(clusterManagedKeyEnd)
left := strings.TrimRight(existing[:start], "\n")
right := strings.TrimLeft(existing[end:], "\n")
parts := []string{}
if left != "" {
parts = append(parts, left)
}
parts = append(parts, block)
if right != "" {
parts = append(parts, right)
}
return strings.TrimSpace(strings.Join(parts, "\n\n")) + "\n"
}
base := strings.TrimSpace(existing)
if base == "" {
return block + "\n"
}
return base + "\n\n" + block + "\n"
}
func (s *Service) ReconcileClusterSSHAuthorizedKeys() error {
identities, err := s.ListClusterSSHIdentities()
if err != nil {
return err
}
managed := make([]string, 0, len(identities))
for _, identity := range identities {
pub := strings.TrimSpace(identity.PublicKey)
if pub == "" {
continue
}
managed = append(managed, pub)
}
sshDir := "/root/.ssh"
if err := os.MkdirAll(sshDir, 0700); err != nil {
return fmt.Errorf("root_ssh_dir_create_failed: %w", err)
}
if err := os.Chmod(sshDir, 0700); err != nil {
return fmt.Errorf("root_ssh_dir_chmod_failed: %w", err)
}
authKeysPath := filepath.Join(sshDir, "authorized_keys")
existing := ""
if raw, readErr := os.ReadFile(authKeysPath); readErr == nil {
existing = string(raw)
}
next := replaceManagedSSHBlock(existing, managed)
if err := os.WriteFile(authKeysPath, []byte(next), 0600); err != nil {
return fmt.Errorf("authorized_keys_write_failed: %w", err)
}
if err := os.Chmod(authKeysPath, 0600); err != nil {
return fmt.Errorf("authorized_keys_chmod_failed: %w", err)
}
logger.L.Debug().
Int("managed_keys", len(managed)).
Str("path", authKeysPath).
Msg("cluster_ssh_authorized_keys_reconciled")
logger.L.Debug().Msg("cluster_ssh_authorized_keys_reconcile_skipped_embedded_ssh_mode")
return nil
}
+21
View File
@@ -203,6 +203,11 @@ func (s *Service) runReplicationPolicy(ctx context.Context, policy *clusterModel
s.updateReplicationPolicyResult(policy, err)
return err
}
if len(sourceDatasets) == 0 {
runErr := fmt.Errorf("no_source_datasets_found")
s.updateReplicationPolicyResult(policy, runErr)
return runErr
}
identities, err := s.Cluster.ListClusterSSHIdentities()
if err != nil {
@@ -252,21 +257,29 @@ func (s *Service) runReplicationPolicy(ctx context.Context, policy *clusterModel
})
var runErr error
eligibleTargets := 0
skippedOffline := 0
skippedNoIdentity := 0
attemptedTransfers := 0
for _, target := range targets {
targetNodeID := strings.TrimSpace(target.NodeID)
if targetNodeID == "" || targetNodeID == localNodeID {
continue
}
if status, ok := statusByNode[targetNodeID]; ok && status != "online" {
skippedOffline++
continue
}
identity, ok := identityByNode[targetNodeID]
if !ok {
skippedNoIdentity++
continue
}
eligibleTargets++
for _, sourceDataset := range sourceDatasets {
attemptedTransfers++
backupRoot, destSuffix := splitDatasetForTarget(sourceDataset)
targetSpec := &clusterModels.BackupTarget{
SSHHost: fmt.Sprintf("%s@%s", strings.TrimSpace(identity.SSHUser), strings.TrimSpace(identity.SSHHost)),
@@ -297,6 +310,14 @@ func (s *Service) runReplicationPolicy(ctx context.Context, policy *clusterModel
}
}
if runErr == nil {
if eligibleTargets == 0 {
runErr = fmt.Errorf("no_eligible_replication_targets (offline=%d missing_identity=%d)", skippedOffline, skippedNoIdentity)
} else if attemptedTransfers == 0 {
runErr = fmt.Errorf("no_replication_transfers_executed")
}
}
s.finalizeReplicationEvent(&event, runErr)
s.updateReplicationPolicyResult(policy, runErr)