@@ -533,8 +533,9 @@ func detectAESAcceleration() bool {
533533
534534// WaitSSHReady waits until the SSH server is ready to accept connections.
535535// The dialContext function is used to create a connection to the SSH server.
536- // The addr, user, instanceName parameter is used for ssh.ClientConn creation.
536+ // The addr, user, and instanceName parameters are used for ssh.ClientConn creation.
537537// The timeoutSeconds parameter specifies the maximum number of seconds to wait.
538+ // If the SSH server uses a host key not in the instance's ssh_known_hosts, the host key will be updated.
538539func WaitSSHReady (ctx context.Context , dialContext func (context.Context ) (net.Conn , error ), addr , user , instanceName string , timeoutSeconds int ) error {
539540 ctx , cancel := context .WithTimeout (ctx , time .Duration (timeoutSeconds )* time .Second )
540541 defer cancel ()
@@ -544,16 +545,22 @@ func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Co
544545 if err != nil {
545546 return err
546547 }
547- // Prepare HostKeyCallback
548+ // Prepare HostKeyCallback that corrects known_hosts when the host key is not found.
549+ // This is required when the SSH server generates a new host key on the first boot,
550+ // instead of accepting the provided host keys via cloud-init.
551+ // e.g. https://github.com/lima-vm/alpine-lima
548552 hostKeyChecker , err := HostKeyCheckerWithKeysInKnownHosts (instanceName )
549553 if err != nil {
550554 return err
551555 }
556+ // Ensure known_hosts is updated at the end
557+ defer hostKeyChecker .updateKnownHosts ()
558+
552559 // Prepare ssh client config
553560 sshConfig := & ssh.ClientConfig {
554561 User : user ,
555562 Auth : []ssh.AuthMethod {ssh .PublicKeys (signer )},
556- HostKeyCallback : hostKeyChecker ,
563+ HostKeyCallback : hostKeyChecker . check ,
557564 Timeout : 10 * time .Second ,
558565 }
559566 // Wait until the SSH server is available.
@@ -580,11 +587,16 @@ func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Co
580587 }
581588}
582589
590+ // errHostKeyMismatch is returned when the SSH host key does not match known hosts.
591+ var errHostKeyMismatch = errors .New ("ssh: host key mismatch" )
592+
583593func isRetryableError (err error ) bool {
584594 // Port forwarder accepted the connection, but the destination is not ready yet.
585595 return osutil .IsConnectionResetError (err ) ||
586596 // SSH server not ready yet (e.g. host key not generated on initial boot).
587- strings .HasSuffix (err .Error (), "no supported methods remain" )
597+ strings .HasSuffix (err .Error (), "no supported methods remain" ) ||
598+ // Host key is not yet in known_hosts, but will be corrected, so we can retry.
599+ errors .Is (err , errHostKeyMismatch )
588600}
589601
590602// userPrivateKeySigner returns the user's private key signer.
@@ -606,19 +618,71 @@ func userPrivateKeySigner() (ssh.Signer, error) {
606618 return signer , nil
607619}
608620
609- func HostKeyCheckerWithKeysInKnownHosts (instanceName string ) (ssh.HostKeyCallback , error ) {
621+ type hostKeyCheckerWithCorrectingUnknownKeys struct {
622+ instanceName string
623+ publicKeys map [string ]ssh.PublicKey
624+ unknownKeys map [string ]ssh.PublicKey
625+ }
626+
627+ // check checks whether the given host key is in the known hosts.
628+ // If the host key is not found, it is recorded in unknownKeys for later correction.
629+ func (h * hostKeyCheckerWithCorrectingUnknownKeys ) check (hostname string , remote net.Addr , key ssh.PublicKey ) error {
630+ marshaledKey := string (key .Marshal ())
631+ if _ , ok := h .publicKeys [marshaledKey ]; ok {
632+ return nil
633+ }
634+ if _ , ok := h .unknownKeys [marshaledKey ]; ok {
635+ return nil
636+ }
637+ logrus .Warnf ("SSH host key for instance %q not found in %s; adding it" , h .instanceName , filenames .SSHKnownHosts )
638+ h .unknownKeys [marshaledKey ] = key
639+ // If always returning nil here, GitHub Advanced Security may report "Use of insecure HostKeyCallback implementation".
640+ // So, we return an error here to make the SSH client report the host key mismatch.
641+ return errHostKeyMismatch
642+ }
643+
644+ // updateKnownHosts updates the known_hosts file with any unknown keys recorded during checks.
645+ // It is required to call this method after using the hostKeyCheckerWithCorrectingUnknownKeys
646+ // to ensure that the known_hosts file is updated appropriately.
647+ func (h * hostKeyCheckerWithCorrectingUnknownKeys ) updateKnownHosts () error {
648+ if len (h .unknownKeys ) == 0 {
649+ return nil
650+ }
651+ // If there are unknown keys, our provided host keys via cloud-init were not accepted by the Guest SSH server.
652+ // We need to replace known_hosts file with the unknown keys that guest SSH server is actually using.
653+ logrus .Infof ("Updating %s file for instance %q with %d new host key(s)" , filenames .SSHKnownHosts , h .instanceName , len (h .unknownKeys ))
654+ instanceDir , err := dirnames .InstanceDir (h .instanceName )
655+ if err != nil {
656+ return fmt .Errorf ("failed to get instance dir for instance %q: %w" , h .instanceName , err )
657+ }
658+ knownHostsPath := filepath .Join (instanceDir , filenames .SSHKnownHosts )
659+ hostname := hostname .FromInstName (h .instanceName )
660+ var sshKnownHosts []byte
661+ for _ , key := range h .unknownKeys {
662+ sshKnownHosts = append (sshKnownHosts , knownhosts .Line ([]string {hostname }, key )... )
663+ sshKnownHosts = append (sshKnownHosts , '\n' )
664+ }
665+ if err := os .WriteFile (knownHostsPath , sshKnownHosts , 0o644 ); err != nil {
666+ return fmt .Errorf ("failed to write known_hosts file at %s: %w" , knownHostsPath , err )
667+ }
668+ return nil
669+ }
670+
671+ // HostKeyCheckerWithKeysInKnownHosts creates a host key checker using the known hosts file
672+ // located in the instance directory.
673+ func HostKeyCheckerWithKeysInKnownHosts (instanceName string ) (hostKeyCheckerWithCorrectingUnknownKeys , error ) {
610674 publicKeys , err := PublicKeysFromKnownHosts (instanceName )
611675 if err != nil {
612- return nil , err
676+ return hostKeyCheckerWithCorrectingUnknownKeys {} , err
613677 }
614- return func ( _ string , _ net. Addr , key ssh.PublicKey ) error {
615- keyBytes := key . Marshal ()
616- for _ , pk := range publicKeys {
617- if bytes . Equal ( keyBytes , pk . Marshal ()) {
618- return nil
619- }
620- }
621- return errors . New ( "ssh: host key mismatch" )
678+ publicKeysMap := make ( map [ string ] ssh.PublicKey , len ( publicKeys ))
679+ for _ , pk := range publicKeys {
680+ publicKeysMap [ string ( pk . Marshal ())] = pk
681+ }
682+ return hostKeyCheckerWithCorrectingUnknownKeys {
683+ instanceName : instanceName ,
684+ publicKeys : publicKeysMap ,
685+ unknownKeys : make ( map [ string ]ssh. PublicKey ),
622686 }, nil
623687}
624688
0 commit comments