mirror of
https://github.com/AlchemillaHQ/Sylve.git
synced 2026-06-15 00:56:36 +03:00
cluster: embed ssh instead of sshd
This commit is contained in:
@@ -122,6 +122,10 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
if err := cS.(*cluster.Service).StartEmbeddedSSHServer(qCtx); err != nil {
|
||||
logger.L.Error().Err(err).Msg("Failed to start embedded cluster SSH server")
|
||||
}
|
||||
|
||||
if err := zelta.EnsureZeltaInstalled(); err != nil {
|
||||
logger.L.Error().Err(err).Msg("Failed to install Zelta")
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/alchemillahq/sylve/internal/config"
|
||||
@@ -34,6 +35,8 @@ type Service struct {
|
||||
Transport *raft.NetworkTransport
|
||||
AuthService serviceInterfaces.AuthServiceInterface
|
||||
JailService jailServiceInterfaces.JailServiceInterface
|
||||
|
||||
embeddedSSHOnce sync.Once
|
||||
}
|
||||
|
||||
func NewClusterService(db *gorm.DB, authService serviceInterfaces.AuthServiceInterface, jailService jailServiceInterfaces.JailServiceInterface) clusterServiceInterfaces.ClusterServiceInterface {
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
// SPDX-License-Identifier: BSD-2-Clause
|
||||
//
|
||||
// Copyright (c) 2025 The FreeBSD Foundation.
|
||||
//
|
||||
// This software was developed by Hayzam Sherif <hayzam@alchemilla.io>
|
||||
// of Alchemilla Ventures Pvt. Ltd. <hello@alchemilla.io>,
|
||||
// under sponsorship from the FreeBSD Foundation.
|
||||
|
||||
package cluster
|
||||
|
||||
import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/exec"
	"strings"
	"syscall"

	"github.com/alchemillahq/sylve/internal/logger"
	"golang.org/x/crypto/ssh"
)
|
||||
|
||||
func (s *Service) StartEmbeddedSSHServer(ctx context.Context) error {
|
||||
var startErr error
|
||||
s.embeddedSSHOnce.Do(func() {
|
||||
privatePath, err := s.ClusterSSHPrivateKeyPath()
|
||||
if err != nil {
|
||||
startErr = fmt.Errorf("embedded_ssh_private_key_failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
privateRaw, err := os.ReadFile(privatePath)
|
||||
if err != nil {
|
||||
startErr = fmt.Errorf("embedded_ssh_private_key_read_failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
hostSigner, err := ssh.ParsePrivateKey(privateRaw)
|
||||
if err != nil {
|
||||
startErr = fmt.Errorf("embedded_ssh_private_key_parse_failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
serverConfig := &ssh.ServerConfig{
|
||||
PublicKeyCallback: s.embeddedSSHPublicKeyCallback,
|
||||
}
|
||||
serverConfig.AddHostKey(hostSigner)
|
||||
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", ClusterEmbeddedSSHPort)
|
||||
listener, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
startErr = fmt.Errorf("embedded_ssh_listen_failed: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.L.Info().
|
||||
Str("addr", listenAddr).
|
||||
Msg("embedded_cluster_ssh_started")
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
go s.embeddedSSHAcceptLoop(ctx, listener, serverConfig)
|
||||
})
|
||||
|
||||
return startErr
|
||||
}
|
||||
|
||||
func (s *Service) embeddedSSHPublicKeyCallback(conn ssh.ConnMetadata, presentedKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if strings.TrimSpace(conn.User()) != "root" {
|
||||
return nil, fmt.Errorf("invalid_user")
|
||||
}
|
||||
|
||||
identities, err := s.ListClusterSSHIdentities()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list_cluster_identities_failed: %w", err)
|
||||
}
|
||||
|
||||
for _, identity := range identities {
|
||||
pub := strings.TrimSpace(identity.PublicKey)
|
||||
if pub == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pub + "\n"))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(parsedKey.Marshal(), presentedKey.Marshal()) {
|
||||
return &ssh.Permissions{
|
||||
Extensions: map[string]string{
|
||||
"node_uuid": identity.NodeUUID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unauthorized_key")
|
||||
}
|
||||
|
||||
func (s *Service) embeddedSSHAcceptLoop(ctx context.Context, listener net.Listener, serverConfig *ssh.ServerConfig) {
|
||||
for {
|
||||
rawConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
logger.L.Warn().Err(err).Msg("embedded_ssh_accept_failed")
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleEmbeddedSSHConn(ctx, rawConn, serverConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleEmbeddedSSHConn(ctx context.Context, rawConn net.Conn, serverConfig *ssh.ServerConfig) {
|
||||
defer rawConn.Close()
|
||||
|
||||
_, chans, reqs, err := ssh.NewServerConn(rawConn, serverConfig)
|
||||
if err != nil {
|
||||
logger.L.Warn().Err(err).Msg("embedded_ssh_handshake_failed")
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
for newChannel := range chans {
|
||||
if newChannel.ChannelType() != "session" {
|
||||
_ = newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
continue
|
||||
}
|
||||
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
logger.L.Warn().Err(err).Msg("embedded_ssh_channel_accept_failed")
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleEmbeddedSSHSession(ctx, channel, requests)
|
||||
}
|
||||
}
|
||||
|
||||
// parseExecRequestPayload extracts the command string from the payload of an
// SSH "exec" channel request.
//
// The payload is a 4-byte big-endian length followed by that many bytes of
// command text; bytes after the declared length are ignored. It returns
// "invalid_exec_payload" when the length prefix is missing and
// "invalid_exec_payload_size" when the declared length exceeds the data
// actually present.
func parseExecRequestPayload(payload []byte) (string, error) {
	const lengthPrefix = 4

	if len(payload) < lengthPrefix {
		return "", fmt.Errorf("invalid_exec_payload")
	}

	size := binary.BigEndian.Uint32(payload[:lengthPrefix])
	body := payload[lengthPrefix:]

	// Compare in unsigned 64-bit space: converting the untrusted length to int
	// and adding the prefix can overflow where int is 32 bits, which would let
	// an oversized length slip past the check and panic on the slice below.
	if uint64(size) > uint64(len(body)) {
		return "", fmt.Errorf("invalid_exec_payload_size")
	}

	return string(body[:size]), nil
}
|
||||
|
||||
// exitCodeFromErr maps the error returned by running a command to the value
// reported to the SSH client as "exit-status".
//
//   - nil                               -> 0
//   - process terminated by signal N    -> 128+N (e.g. 137 for SIGKILL)
//   - process exited with status C >= 0 -> C
//   - anything else                     -> 1
//
// Signal termination is read from the process wait status rather than by
// matching the error text, so every signal is covered and unrelated errors
// whose message happens to contain "signal: ..." are not misclassified.
func exitCodeFromErr(err error) uint32 {
	if err == nil {
		return 0
	}

	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) {
		return 1
	}

	// Sys dereferences the ProcessState, so guard against an ExitError that
	// was built without one.
	if exitErr.ProcessState != nil {
		if ws, ok := exitErr.Sys().(syscall.WaitStatus); ok && ws.Signaled() {
			return uint32(128 + int(ws.Signal()))
		}
	}

	if code := exitErr.ExitCode(); code >= 0 {
		return uint32(code)
	}

	return 1
}
|
||||
|
||||
func (s *Service) handleEmbeddedSSHSession(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
|
||||
defer channel.Close()
|
||||
|
||||
execReceived := false
|
||||
for req := range requests {
|
||||
switch req.Type {
|
||||
case "exec":
|
||||
if execReceived {
|
||||
_ = req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
execReceived = true
|
||||
|
||||
command, err := parseExecRequestPayload(req.Payload)
|
||||
if err != nil {
|
||||
_ = req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
_ = req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
|
||||
_ = req.Reply(true, nil)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
|
||||
cmd.Stdin = channel
|
||||
cmd.Stdout = channel
|
||||
cmd.Stderr = channel.Stderr()
|
||||
|
||||
runErr := cmd.Run()
|
||||
exitCode := exitCodeFromErr(runErr)
|
||||
_, _ = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{Status: exitCode}))
|
||||
return
|
||||
default:
|
||||
_ = req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -29,8 +28,7 @@ const (
|
||||
clusterSSHPrivateFileName = "id_ed25519"
|
||||
clusterSSHPublicFileName = "id_ed25519.pub"
|
||||
|
||||
clusterManagedKeyStart = "# --- sylve cluster replication keys start ---"
|
||||
clusterManagedKeyEnd = "# --- sylve cluster replication keys end ---"
|
||||
ClusterEmbeddedSSHPort = 8122
|
||||
)
|
||||
|
||||
func (s *Service) clusterSSHDir() (string, error) {
|
||||
@@ -50,14 +48,6 @@ func (s *Service) clusterSSHDir() (string, error) {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (s *Service) clusterSSHPrivateKeyPath() (string, error) {
|
||||
dir, err := s.clusterSSHDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(dir, clusterSSHPrivateFileName), nil
|
||||
}
|
||||
|
||||
func (s *Service) ClusterSSHPrivateKeyPath() (string, error) {
|
||||
privatePath, _, _, err := s.ensureLocalClusterSSHKeyPair()
|
||||
if err != nil {
|
||||
@@ -160,7 +150,7 @@ func (s *Service) EnsureAndPublishLocalSSHIdentity() error {
|
||||
NodeUUID: strings.TrimSpace(detail.NodeID),
|
||||
SSHUser: "root",
|
||||
SSHHost: s.localClusterSSHHost(),
|
||||
SSHPort: 22,
|
||||
SSHPort: ClusterEmbeddedSSHPort,
|
||||
PublicKey: pubKey,
|
||||
}
|
||||
|
||||
@@ -240,92 +230,7 @@ func (s *Service) forwardSSHIdentityToLeader(identity clusterModels.ClusterSSHId
|
||||
return fmt.Errorf("forward_ssh_identity_to_leader_failed: %w", lastErr)
|
||||
}
|
||||
|
||||
func replaceManagedSSHBlock(existing string, managed []string) string {
|
||||
managedSet := make(map[string]struct{}, len(managed))
|
||||
normalized := make([]string, 0, len(managed))
|
||||
for _, line := range managed {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := managedSet[line]; ok {
|
||||
continue
|
||||
}
|
||||
managedSet[line] = struct{}{}
|
||||
normalized = append(normalized, line)
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
|
||||
blockLines := []string{clusterManagedKeyStart}
|
||||
blockLines = append(blockLines, normalized...)
|
||||
blockLines = append(blockLines, clusterManagedKeyEnd)
|
||||
block := strings.Join(blockLines, "\n")
|
||||
|
||||
start := strings.Index(existing, clusterManagedKeyStart)
|
||||
end := strings.Index(existing, clusterManagedKeyEnd)
|
||||
if start >= 0 && end > start {
|
||||
end += len(clusterManagedKeyEnd)
|
||||
left := strings.TrimRight(existing[:start], "\n")
|
||||
right := strings.TrimLeft(existing[end:], "\n")
|
||||
parts := []string{}
|
||||
if left != "" {
|
||||
parts = append(parts, left)
|
||||
}
|
||||
parts = append(parts, block)
|
||||
if right != "" {
|
||||
parts = append(parts, right)
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(parts, "\n\n")) + "\n"
|
||||
}
|
||||
|
||||
base := strings.TrimSpace(existing)
|
||||
if base == "" {
|
||||
return block + "\n"
|
||||
}
|
||||
return base + "\n\n" + block + "\n"
|
||||
}
|
||||
|
||||
func (s *Service) ReconcileClusterSSHAuthorizedKeys() error {
|
||||
identities, err := s.ListClusterSSHIdentities()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
managed := make([]string, 0, len(identities))
|
||||
for _, identity := range identities {
|
||||
pub := strings.TrimSpace(identity.PublicKey)
|
||||
if pub == "" {
|
||||
continue
|
||||
}
|
||||
managed = append(managed, pub)
|
||||
}
|
||||
|
||||
sshDir := "/root/.ssh"
|
||||
if err := os.MkdirAll(sshDir, 0700); err != nil {
|
||||
return fmt.Errorf("root_ssh_dir_create_failed: %w", err)
|
||||
}
|
||||
if err := os.Chmod(sshDir, 0700); err != nil {
|
||||
return fmt.Errorf("root_ssh_dir_chmod_failed: %w", err)
|
||||
}
|
||||
|
||||
authKeysPath := filepath.Join(sshDir, "authorized_keys")
|
||||
existing := ""
|
||||
if raw, readErr := os.ReadFile(authKeysPath); readErr == nil {
|
||||
existing = string(raw)
|
||||
}
|
||||
|
||||
next := replaceManagedSSHBlock(existing, managed)
|
||||
if err := os.WriteFile(authKeysPath, []byte(next), 0600); err != nil {
|
||||
return fmt.Errorf("authorized_keys_write_failed: %w", err)
|
||||
}
|
||||
if err := os.Chmod(authKeysPath, 0600); err != nil {
|
||||
return fmt.Errorf("authorized_keys_chmod_failed: %w", err)
|
||||
}
|
||||
|
||||
logger.L.Debug().
|
||||
Int("managed_keys", len(managed)).
|
||||
Str("path", authKeysPath).
|
||||
Msg("cluster_ssh_authorized_keys_reconciled")
|
||||
|
||||
logger.L.Debug().Msg("cluster_ssh_authorized_keys_reconcile_skipped_embedded_ssh_mode")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -203,6 +203,11 @@ func (s *Service) runReplicationPolicy(ctx context.Context, policy *clusterModel
|
||||
s.updateReplicationPolicyResult(policy, err)
|
||||
return err
|
||||
}
|
||||
if len(sourceDatasets) == 0 {
|
||||
runErr := fmt.Errorf("no_source_datasets_found")
|
||||
s.updateReplicationPolicyResult(policy, runErr)
|
||||
return runErr
|
||||
}
|
||||
|
||||
identities, err := s.Cluster.ListClusterSSHIdentities()
|
||||
if err != nil {
|
||||
@@ -252,21 +257,29 @@ func (s *Service) runReplicationPolicy(ctx context.Context, policy *clusterModel
|
||||
})
|
||||
|
||||
var runErr error
|
||||
eligibleTargets := 0
|
||||
skippedOffline := 0
|
||||
skippedNoIdentity := 0
|
||||
attemptedTransfers := 0
|
||||
for _, target := range targets {
|
||||
targetNodeID := strings.TrimSpace(target.NodeID)
|
||||
if targetNodeID == "" || targetNodeID == localNodeID {
|
||||
continue
|
||||
}
|
||||
if status, ok := statusByNode[targetNodeID]; ok && status != "online" {
|
||||
skippedOffline++
|
||||
continue
|
||||
}
|
||||
|
||||
identity, ok := identityByNode[targetNodeID]
|
||||
if !ok {
|
||||
skippedNoIdentity++
|
||||
continue
|
||||
}
|
||||
eligibleTargets++
|
||||
|
||||
for _, sourceDataset := range sourceDatasets {
|
||||
attemptedTransfers++
|
||||
backupRoot, destSuffix := splitDatasetForTarget(sourceDataset)
|
||||
targetSpec := &clusterModels.BackupTarget{
|
||||
SSHHost: fmt.Sprintf("%s@%s", strings.TrimSpace(identity.SSHUser), strings.TrimSpace(identity.SSHHost)),
|
||||
@@ -297,6 +310,14 @@ func (s *Service) runReplicationPolicy(ctx context.Context, policy *clusterModel
|
||||
}
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
if eligibleTargets == 0 {
|
||||
runErr = fmt.Errorf("no_eligible_replication_targets (offline=%d missing_identity=%d)", skippedOffline, skippedNoIdentity)
|
||||
} else if attemptedTransfers == 0 {
|
||||
runErr = fmt.Errorf("no_replication_transfers_executed")
|
||||
}
|
||||
}
|
||||
|
||||
s.finalizeReplicationEvent(&event, runErr)
|
||||
s.updateReplicationPolicyResult(policy, runErr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user