Skip to content

Commit 1d80083

Browse files
committed
pkg/sshutil: Update ssh_known_hosts if the SSH server does not accept host keys via cloud-init
Prepare HostKeyCallback that corrects known_hosts when the host key is not found. This is required when the SSH server generates a new host key on the first boot, instead of accepting the provided host keys via cloud-init. e.g. https://github.com/lima-vm/alpine-lima Signed-off-by: Norio Nomura <[email protected]>
1 parent 618e332 commit 1d80083

File tree

4 files changed

+118
-17
lines changed

4 files changed

+118
-17
lines changed

pkg/driver/krunkit/krunkit_driver_darwin_arm64.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/lima-vm/lima/v2/pkg/networks/usernet"
2828
"github.com/lima-vm/lima/v2/pkg/osutil"
2929
"github.com/lima-vm/lima/v2/pkg/ptr"
30+
"github.com/lima-vm/lima/v2/pkg/sshutil"
3031
)
3132

3233
type LimaKrunkitDriver struct {
@@ -315,6 +316,17 @@ func (l *LimaKrunkitDriver) GuestAgentConn(_ context.Context) (net.Conn, string,
315316
return nil, "unix", nil
316317
}
317318

318-
func (l *LimaKrunkitDriver) AdditionalSetupForSSH(_ context.Context) error {
319+
func (l *LimaKrunkitDriver) AdditionalSetupForSSH(ctx context.Context) error {
320+
// Wait until the port is available.
321+
addr := net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", l.SSHLocalPort))
322+
dialContext := func(ctx context.Context) (net.Conn, error) {
323+
dialer := net.Dialer{Timeout: 1 * time.Second}
324+
return dialer.DialContext(ctx, "tcp", addr)
325+
}
326+
user := *l.Instance.Config.User.Name
327+
instanceName := l.Instance.Name
328+
if err := sshutil.WaitSSHReady(ctx, dialContext, addr, user, instanceName, 600); err != nil {
329+
return err
330+
}
319331
return nil
320332
}

pkg/driver/qemu/qemu_driver.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/lima-vm/lima/v2/pkg/osutil"
3838
"github.com/lima-vm/lima/v2/pkg/ptr"
3939
"github.com/lima-vm/lima/v2/pkg/reflectutil"
40+
"github.com/lima-vm/lima/v2/pkg/sshutil"
4041
"github.com/lima-vm/lima/v2/pkg/version/versionutil"
4142
)
4243

@@ -721,6 +722,17 @@ func (l *LimaQemuDriver) ForwardGuestAgent() bool {
721722
return l.vSockPort == 0 && l.virtioPort == ""
722723
}
723724

724-
func (l *LimaQemuDriver) AdditionalSetupForSSH(_ context.Context) error {
725+
func (l *LimaQemuDriver) AdditionalSetupForSSH(ctx context.Context) error {
726+
// Wait until the port is available.
727+
addr := net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", l.SSHLocalPort))
728+
dialContext := func(ctx context.Context) (net.Conn, error) {
729+
dialer := net.Dialer{Timeout: 1 * time.Second}
730+
return dialer.DialContext(ctx, "tcp", addr)
731+
}
732+
user := *l.Instance.Config.User.Name
733+
instanceName := l.Instance.Name
734+
if err := sshutil.WaitSSHReady(ctx, dialContext, addr, user, instanceName, 600); err != nil {
735+
return err
736+
}
725737
return nil
726738
}

pkg/driver/wsl2/wsl_driver_windows.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"fmt"
1111
"net"
1212
"regexp"
13+
"time"
1314

1415
"github.com/Microsoft/go-winio"
1516
"github.com/Microsoft/go-winio/pkg/guid"
@@ -21,6 +22,7 @@ import (
2122
"github.com/lima-vm/lima/v2/pkg/limayaml"
2223
"github.com/lima-vm/lima/v2/pkg/ptr"
2324
"github.com/lima-vm/lima/v2/pkg/reflectutil"
25+
"github.com/lima-vm/lima/v2/pkg/sshutil"
2426
"github.com/lima-vm/lima/v2/pkg/windows"
2527
)
2628

@@ -358,6 +360,17 @@ func (l *LimaWslDriver) ForwardGuestAgent() bool {
358360
return l.vSockPort == 0 && l.virtioPort == ""
359361
}
360362

361-
func (l *LimaWslDriver) AdditionalSetupForSSH(_ context.Context) error {
363+
func (l *LimaWslDriver) AdditionalSetupForSSH(ctx context.Context) error {
364+
// Wait until the port is available.
365+
addr := net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", l.SSHLocalPort))
366+
dialContext := func(ctx context.Context) (net.Conn, error) {
367+
dialer := net.Dialer{Timeout: 1 * time.Second}
368+
return dialer.DialContext(ctx, "tcp", addr)
369+
}
370+
user := *l.Instance.Config.User.Name
371+
instanceName := l.Instance.Name
372+
if err := sshutil.WaitSSHReady(ctx, dialContext, addr, user, instanceName, 600); err != nil {
373+
return err
374+
}
362375
return nil
363376
}

pkg/sshutil/sshutil.go

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -533,8 +533,9 @@ func detectAESAcceleration() bool {
533533

534534
// WaitSSHReady waits until the SSH server is ready to accept connections.
535535
// The dialContext function is used to create a connection to the SSH server.
536-
// The addr, user, instanceName parameter is used for ssh.ClientConn creation.
536+
// The addr, user, and instanceName parameters are used for ssh.ClientConn creation.
537537
// The timeoutSeconds parameter specifies the maximum number of seconds to wait.
538+
// If the SSH server uses a host key not in the instance's ssh_known_hosts, the host key will be updated.
538539
func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addr, user, instanceName string, timeoutSeconds int) error {
539540
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
540541
defer cancel()
@@ -544,16 +545,22 @@ func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Co
544545
if err != nil {
545546
return err
546547
}
547-
// Prepare HostKeyCallback
548+
// Prepare HostKeyCallback that corrects known_hosts when the host key is not found.
549+
// This is required when the SSH server generates a new host key on the first boot,
550+
// instead of accepting the provided host keys via cloud-init.
551+
// e.g. https://github.com/lima-vm/alpine-lima
548552
hostKeyChecker, err := HostKeyCheckerWithKeysInKnownHosts(instanceName)
549553
if err != nil {
550554
return err
551555
}
556+
// Ensure known_hosts is updated at the end
557+
defer hostKeyChecker.updateKnownHosts()
558+
552559
// Prepare ssh client config
553560
sshConfig := &ssh.ClientConfig{
554561
User: user,
555562
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
556-
HostKeyCallback: hostKeyChecker,
563+
HostKeyCallback: hostKeyChecker.check,
557564
Timeout: 10 * time.Second,
558565
}
559566
// Wait until the SSH server is available.
@@ -580,11 +587,16 @@ func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Co
580587
}
581588
}
582589

590+
// errHostKeyMismatch is returned when the SSH host key does not match known hosts.
591+
var errHostKeyMismatch = errors.New("ssh: host key mismatch")
592+
583593
func isRetryableError(err error) bool {
584594
// Port forwarder accepted the connection, but the destination is not ready yet.
585595
return osutil.IsConnectionResetError(err) ||
586596
// SSH server not ready yet (e.g. host key not generated on initial boot).
587-
strings.HasSuffix(err.Error(), "no supported methods remain")
597+
strings.HasSuffix(err.Error(), "no supported methods remain") ||
598+
// Host key is not yet in known_hosts, but will be corrected, so we can retry.
599+
errors.Is(err, errHostKeyMismatch)
588600
}
589601

590602
// userPrivateKeySigner returns the user's private key signer.
@@ -606,19 +618,71 @@ func userPrivateKeySigner() (ssh.Signer, error) {
606618
return signer, nil
607619
}
608620

609-
func HostKeyCheckerWithKeysInKnownHosts(instanceName string) (ssh.HostKeyCallback, error) {
621+
type hostKeyCheckerWithCorrectingUnknownKeys struct {
622+
instanceName string
623+
publicKeys map[string]ssh.PublicKey
624+
unknownKeys map[string]ssh.PublicKey
625+
}
626+
627+
// check checks whether the given host key is in the known hosts.
628+
// If the host key is not found, it is recorded in unknownKeys for later correction.
629+
func (h *hostKeyCheckerWithCorrectingUnknownKeys) check(hostname string, remote net.Addr, key ssh.PublicKey) error {
630+
marshaledKey := string(key.Marshal())
631+
if _, ok := h.publicKeys[marshaledKey]; ok {
632+
return nil
633+
}
634+
if _, ok := h.unknownKeys[marshaledKey]; ok {
635+
return nil
636+
}
637+
logrus.Warnf("SSH host key for instance %q not found in %s; adding it", h.instanceName, filenames.SSHKnownHosts)
638+
h.unknownKeys[marshaledKey] = key
639+
// If always returning nil here, GitHub Advanced Security may report "Use of insecure HostKeyCallback implementation".
640+
// So, we return an error here to make the SSH client report the host key mismatch.
641+
return errHostKeyMismatch
642+
}
643+
644+
// updateKnownHosts updates the known_hosts file with any unknown keys recorded during checks.
645+
// It is required to call this method after using the hostKeyCheckerWithCorrectingUnknownKeys
646+
// to ensure that the known_hosts file is updated appropriately.
647+
func (h *hostKeyCheckerWithCorrectingUnknownKeys) updateKnownHosts() error {
648+
if len(h.unknownKeys) == 0 {
649+
return nil
650+
}
651+
// If there are unknown keys, our provided host keys via cloud-init were not accepted by the Guest SSH server.
652+
// We need to replace known_hosts file with the unknown keys that guest SSH server is actually using.
653+
logrus.Infof("Updating %s file for instance %q with %d new host key(s)", filenames.SSHKnownHosts, h.instanceName, len(h.unknownKeys))
654+
instanceDir, err := dirnames.InstanceDir(h.instanceName)
655+
if err != nil {
656+
return fmt.Errorf("failed to get instance dir for instance %q: %w", h.instanceName, err)
657+
}
658+
knownHostsPath := filepath.Join(instanceDir, filenames.SSHKnownHosts)
659+
hostname := hostname.FromInstName(h.instanceName)
660+
var sshKnownHosts []byte
661+
for _, key := range h.unknownKeys {
662+
sshKnownHosts = append(sshKnownHosts, knownhosts.Line([]string{hostname}, key)...)
663+
sshKnownHosts = append(sshKnownHosts, '\n')
664+
}
665+
if err := os.WriteFile(knownHostsPath, sshKnownHosts, 0o644); err != nil {
666+
return fmt.Errorf("failed to write known_hosts file at %s: %w", knownHostsPath, err)
667+
}
668+
return nil
669+
}
670+
671+
// HostKeyCheckerWithKeysInKnownHosts creates a host key checker using the known hosts file
672+
// located in the instance directory.
673+
func HostKeyCheckerWithKeysInKnownHosts(instanceName string) (hostKeyCheckerWithCorrectingUnknownKeys, error) {
610674
publicKeys, err := PublicKeysFromKnownHosts(instanceName)
611675
if err != nil {
612-
return nil, err
676+
return hostKeyCheckerWithCorrectingUnknownKeys{}, err
613677
}
614-
return func(_ string, _ net.Addr, key ssh.PublicKey) error {
615-
keyBytes := key.Marshal()
616-
for _, pk := range publicKeys {
617-
if bytes.Equal(keyBytes, pk.Marshal()) {
618-
return nil
619-
}
620-
}
621-
return errors.New("ssh: host key mismatch")
678+
publicKeysMap := make(map[string]ssh.PublicKey, len(publicKeys))
679+
for _, pk := range publicKeys {
680+
publicKeysMap[string(pk.Marshal())] = pk
681+
}
682+
return hostKeyCheckerWithCorrectingUnknownKeys{
683+
instanceName: instanceName,
684+
publicKeys: publicKeysMap,
685+
unknownKeys: make(map[string]ssh.PublicKey),
622686
}, nil
623687
}
624688

0 commit comments

Comments
 (0)