diff --git a/cmd/configure.go b/cmd/configure.go index 0fbf4d75..1f80017f 100644 --- a/cmd/configure.go +++ b/cmd/configure.go @@ -3,7 +3,9 @@ package cmd import ( "crypto/tls" "fmt" + "github.com/pelican-dev/wings/utils" "io" + "log" "net/http" "net/url" "os" @@ -53,13 +55,15 @@ func configureCmdRun(cmd *cobra.Command, args []string) { } if _, err := os.Stat(configureArgs.ConfigPath); err == nil && !configureArgs.Override { - survey.AskOne(&survey.Confirm{Message: "Override existing configuration file"}, &configureArgs.Override) + err := survey.AskOne(&survey.Confirm{Message: "Override existing configuration file"}, &configureArgs.Override) + if err != nil && err != terminal.InterruptErr { + log.Fatal(err) + } if !configureArgs.Override { - fmt.Println("Aborting process; a configuration file already exists for this node.") - os.Exit(1) + log.Fatal("Aborting process; a configuration file already exists for this node.") } } else if err != nil && !os.IsNotExist(err) { - panic(err) + log.Fatal(err) } var questions []*survey.Question @@ -111,8 +115,7 @@ func configureCmdRun(cmd *cobra.Command, args []string) { if err == terminal.InterruptErr { return } - - panic(err) + log.Fatal(err) } c := &http.Client{ @@ -121,7 +124,7 @@ func configureCmdRun(cmd *cobra.Command, args []string) { req, err := getRequest() if err != nil { - panic(err) + log.Fatal(err) } fmt.Printf("%+v", req.Header) @@ -132,7 +135,7 @@ func configureCmdRun(cmd *cobra.Command, args []string) { fmt.Println("Failed to fetch configuration from the panel.\n", err.Error()) os.Exit(1) } - defer res.Body.Close() + defer utils.CloseResponseBodyWithErrorHandling(res.Body) if res.StatusCode == http.StatusForbidden || res.StatusCode == http.StatusUnauthorized { fmt.Println("The authentication credentials provided were not valid.") @@ -148,18 +151,18 @@ func configureCmdRun(cmd *cobra.Command, args []string) { cfg, err := config.NewAtPath(configPath) if err != nil { - panic(err) + log.Fatal(err) } if err := json.Unmarshal(b, cfg); err != nil { - panic(err) + log.Fatal(err) } - + // Manually specify the Panel URL as it won't be decoded from JSON. cfg.PanelLocation = configureArgs.PanelURL if err = config.WriteToDisk(cfg); err != nil { - panic(err) + log.Fatal(err) } fmt.Println("Successfully configured wings.") @@ -168,7 +171,7 @@ func configureCmdRun(cmd *cobra.Command, args []string) { func getRequest() (*http.Request, error) { u, err := url.Parse(configureArgs.PanelURL) if err != nil { - panic(err) + log.Fatal(err) } u.Path = path.Join(u.Path, fmt.Sprintf("api/application/nodes/%s/configuration", configureArgs.Node)) diff --git a/cmd/diagnostics.go b/cmd/diagnostics.go index 462a74fb..777fec42 100644 --- a/cmd/diagnostics.go +++ b/cmd/diagnostics.go @@ -19,11 +19,11 @@ import ( "github.com/AlecAivazis/survey/v2/terminal" "github.com/apex/log" "github.com/docker/docker/api/types" + dockerSystem "github.com/docker/docker/api/types/system" // Alias the correct system package "github.com/docker/docker/pkg/parsers/kernel" "github.com/docker/docker/pkg/parsers/operatingsystem" "github.com/goccy/go-json" "github.com/spf13/cobra" - dockerSystem "github.com/docker/docker/api/types/system" // Alias the correct system package "github.com/pelican-dev/wings/config" "github.com/pelican-dev/wings/environment" @@ -97,89 +97,344 @@ func diagnosticsCmdRun(*cobra.Command, []string) { dockerVersion, dockerInfo, dockerErr := getDockerInfo() output := &strings.Builder{} - fmt.Fprintln(output, "Pelican Wings - Diagnostics Report") - printHeader(output, "Versions") - fmt.Fprintln(output, " Wings:", system.Version) - if dockerErr == nil { - fmt.Fprintln(output, " Docker:", dockerVersion.Version) + type diagnosticField struct { + name string + format string + args []any } - if v, err := kernel.GetKernelVersion(); err == nil { - fmt.Fprintln(output, " Kernel:", v) - } - if os, err := operatingsystem.GetOperatingSystem(); err == nil { - fmt.Fprintln(output, " OS:", os) + diagnosticFields := []diagnosticField{ + { + name: "title", + format: "Pelican Wings - Diagnostics Report", + }, + { + name: "versions header", + format: printHeader("Versions"), + }, + { + name: "wings version", + format: " Wings: %v", + args: []any{system.Version}, + }, + { + name: "docker version", + format: " Docker: %v", + args: []any{ + func() string { + version := "unknown" + if dockerErr != nil { + log.WithError(dockerErr).Warn("failed to get docker version") + } else { + version = dockerVersion.Version + } + return version + }(), + }, + }, + { + name: "kernel version", + format: " Kernel: %v", + args: []any{ + func() string { + version := "unknown" + kernelver, err := kernel.GetKernelVersion() + if err != nil { + log.WithError(err).Warn("failed to get kernel version") + } else { + version = fmt.Sprint(kernelver) + } + return version + }(), + }, + }, + { + name: "operating system", + format: " OS: %v", + args: []any{ + func() string { + os, err := operatingsystem.GetOperatingSystem() + if err != nil { + log.WithError(err).Warn("failed to get operating system") + os = "unknown" + } + return os + }(), + }, + }, } - - printHeader(output, "Wings Configuration") - if err := config.FromFile(config.DefaultLocation); err != nil { + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "wings configuration header", + format: printHeader("Wings Configuration"), + }) + err := config.FromFile(config.DefaultLocation) + if err != nil { + log.WithError(err).Warn("failed to load configuration so configuration information will not be included in the report") + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "wings configuration", + format: "Failed to load configuration", + }, + ) } cfg := config.Get() - fmt.Fprintln(output, " Panel Location:", redact(cfg.PanelLocation)) - fmt.Fprintln(output, "") - fmt.Fprintln(output, " Internal Webserver:", redact(cfg.Api.Host), ":", cfg.Api.Port) - fmt.Fprintln(output, " SSL Enabled:", cfg.Api.Ssl.Enabled) - fmt.Fprintln(output, " SSL Certificate:", redact(cfg.Api.Ssl.CertificateFile)) - fmt.Fprintln(output, " SSL Key:", redact(cfg.Api.Ssl.KeyFile)) - fmt.Fprintln(output, "") - fmt.Fprintln(output, " SFTP Server:", redact(cfg.System.Sftp.Address), ":", cfg.System.Sftp.Port) - fmt.Fprintln(output, " SFTP Read-Only:", cfg.System.Sftp.ReadOnly) - fmt.Fprintln(output, "") - fmt.Fprintln(output, " Root Directory:", cfg.System.RootDirectory) - fmt.Fprintln(output, " Logs Directory:", cfg.System.LogDirectory) - fmt.Fprintln(output, " Data Directory:", cfg.System.Data) - fmt.Fprintln(output, " Archive Directory:", cfg.System.ArchiveDirectory) - fmt.Fprintln(output, " Backup Directory:", cfg.System.BackupDirectory) - fmt.Fprintln(output, "") - fmt.Fprintln(output, " Username:", cfg.System.Username) - fmt.Fprintln(output, " Server Time:", time.Now().Format(time.RFC1123Z)) - fmt.Fprintln(output, " Debug Mode:", cfg.Debug) + if err == nil { + diagnosticFields = append(diagnosticFields, []diagnosticField{ + { + name: "wings configuration header", + format: printHeader("Wings Configuration"), + }, + { + name: "panel location", + format: " Panel Location: %v\n", + args: []any{ + redact(cfg.PanelLocation), + }, + }, + { + name: "internal webserver", + format: " Internal Webserver: %v : %v", + args: []any{ + redact(cfg.Api.Host), + cfg.Api.Port, + }, + }, + { + name: "ssl enabled", + format: " SSL Enabled: %v", + args: []any{ + cfg.Api.Ssl.Enabled, + }, + }, + { + name: "ssl certificate", + format: " SSL Certificate: %v", + args: []any{ + redact(cfg.Api.Ssl.CertificateFile), + }, + }, + { + name: "ssl key", + format: " SSL Key: %v\n", + args: []any{ + redact(cfg.Api.Ssl.KeyFile), + }, + }, + { + name: "sftp server", + format: " SFTP Server: %v : %v", + args: []any{ + redact(cfg.System.Sftp.Address), + cfg.System.Sftp.Port, + }, + }, + { + name: "sftp read-only", + format: " SFTP Read-Only: %v\n", + args: []any{ + cfg.System.Sftp.ReadOnly, + }, + }, + { + name: "root directory", + format: " Root Directory: %v", + args: []any{ + cfg.System.RootDirectory, + }, + }, + { + name: "logs directory", + format: " Logs Directory: %v", + args: []any{ + cfg.System.LogDirectory, + }, + }, + { + name: "data directory", + format: " Data Directory: %v", + args: []any{ + cfg.System.Data, + }, + }, + { + name: "archive directory", + format: " Archive Directory: %v", + args: []any{ + cfg.System.ArchiveDirectory, + }, + }, + { + name: "backup directory", + format: " Backup Directory: %v\n", + args: []any{ + cfg.System.BackupDirectory, + }, + }, + { + name: "username", + format: " Username: %v", + args: []any{ + cfg.System.Username, + }, + }, + { + name: "debug mode", + format: " Debug Mode: %v", + args: []any{ + cfg.Debug, + }, + }, + }...) + } + diagnosticFields = append(diagnosticFields, []diagnosticField{ + { + name: "server time", + format: " Server Time: %v", + args: []any{ + time.Now().Format(time.RFC1123Z), + }, + }, + { + name: "docker info header", + format: printHeader("Docker: Info"), + }, + }...) - printHeader(output, "Docker: Info") - if dockerErr == nil { - fmt.Fprintln(output, "Server Version:", dockerInfo.ServerVersion) - fmt.Fprintln(output, "Storage Driver:", dockerInfo.Driver) + if dockerErr != nil { + log.WithError(dockerErr).Warn("failed to get docker info, so docker information will not be included in the report") + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "docker info", + format: "Failed to get docker info due to error %v", + args: []any{ + dockerErr, + }, + }) + } else { + diagnosticFields = append(diagnosticFields, []diagnosticField{ + { + name: "docker server version", + format: "Server Version: %v", + args: []any{ + dockerInfo.ServerVersion, + }, + }, + { + name: "docker storage driver", + format: "Storage Driver: %v", + args: []any{ + dockerInfo.Driver, + }, + }, + }...) if dockerInfo.DriverStatus != nil { for _, pair := range dockerInfo.DriverStatus { - fmt.Fprintf(output, " %s: %s\n", pair[0], pair[1]) + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "docker driver status", + format: " %v: %v", + args: []any{ + pair[0], + pair[1], + }, + }) } } if dockerInfo.SystemStatus != nil { for _, pair := range dockerInfo.SystemStatus { - fmt.Fprintf(output, " %s: %s\n", pair[0], pair[1]) + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "docker driver status", + format: " %v: %v", + args: []any{ + pair[0], + pair[1], + }, + }) } } - fmt.Fprintln(output, "LoggingDriver:", dockerInfo.LoggingDriver) - fmt.Fprintln(output, " CgroupDriver:", dockerInfo.CgroupDriver) + diagnosticFields = append(diagnosticFields, []diagnosticField{ + { + name: "docker LoggingDriver", + format: "LoggingDriver: %v", + args: []any{ + dockerInfo.LoggingDriver, + }, + }, + { + name: "docker CgroupDriver", + format: " CgroupDriver: %v", + args: []any{ + dockerInfo.CgroupDriver, + }, + }, + }...) if len(dockerInfo.Warnings) > 0 { for _, w := range dockerInfo.Warnings { - fmt.Fprintln(output, w) + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "docker warning", + format: "%v", + args: []any{ + w, + }, + }) } } - } else { - fmt.Fprintln(output, dockerErr.Error()) } - printHeader(output, "Docker: Running Containers") + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "docker running containers header", + format: printHeader("Docker: Running Containers"), + }) c := exec.Command("docker", "ps") if co, err := c.Output(); err == nil { - output.Write(co) + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "docker running containers", + format: "%v", + args: []any{ + string(co), + }, + }) } else { - fmt.Fprint(output, "Couldn't list containers: ", err) + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "docker running containers", + format: "Couldn't list containers: %v", + args: []any{ + err, + }, + }) } - printHeader(output, "Latest Wings Logs") + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "latest wings logs header", + format: printHeader("Latest Wings Logs"), + }) if diagnosticsArgs.IncludeLogs { p := "/var/log/pelican/wings.log" if cfg != nil { p = path.Join(cfg.System.LogDirectory, "wings.log") } if c, err := exec.Command("tail", "-n", strconv.Itoa(diagnosticsArgs.LogLines), p).Output(); err != nil { - fmt.Fprintln(output, "No logs found or an error occurred.") + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "no logs", + format: "No logs found or an error occurred.", + }) } else { - fmt.Fprintf(output, "%s\n", string(c)) + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "logs", + format: "%v", + args: []any{ + string(c), + }, + }) } } else { - fmt.Fprintln(output, "Logs redacted.") + diagnosticFields = append(diagnosticFields, diagnosticField{ + name: "logs redacted", + format: "Logs redacted.", + }) + } + + for _, f := range diagnosticFields { + _, err := fmt.Fprintf(output, f.format+"\n", f.args...) + if err != nil { + log.WithError(err).Warnf("failed to write diagnostic field '%v'", f.name) + } } if !diagnosticsArgs.IncludeEndpoints { @@ -264,7 +519,9 @@ func redact(s string) string { return s } -func printHeader(w io.Writer, title string) { - fmt.Fprintln(w, "\n|\n|", title) - fmt.Fprintln(w, "| ------------------------------") +func printHeader(title string) string { + output := "" + output += fmt.Sprintln("\n|\n|", title) + output += fmt.Sprint("| ------------------------------") + return output } diff --git a/config/config.go b/config/config.go index c80fed35..114ae0ae 100644 --- a/config/config.go +++ b/config/config.go @@ -666,7 +666,12 @@ func EnableLogRotation() error { if err != nil { return err } - defer f.Close() + defer func(f *os.File) { + err := f.Close() + if err != nil { + log.WithError(err).Error("failed to close logrotate file") + } + }(f) t, err := template.New("logrotate").Parse(`{{.LogDirectory}}/wings.log { size 10M diff --git a/environment/docker/container.go b/environment/docker/container.go index 217f1bc9..b28ad5bc 100644 --- a/environment/docker/container.go +++ b/environment/docker/container.go @@ -336,7 +336,12 @@ func (e *Environment) Readlog(lines int) ([]string, error) { if err != nil { return nil, errors.WithStack(err) } - defer r.Close() + defer func(r io.ReadCloser) { + err := r.Close() + if err != nil { + log.WithError(err).Error("failed to close container logs reader") + } + }(r) var out []string scanner := bufio.NewScanner(r) @@ -423,7 +428,12 @@ func (e *Environment) ensureImageExists(image string) error { return errors.Wrapf(err, "environment/docker: failed to pull \"%s\" image for server", image) } - defer out.Close() + defer func(out io.ReadCloser) { + err := out.Close() + if err != nil { + log.WithError(err).Error("failed to close image pull reader") + } + }(out) log.WithField("image", image).Debug("pulling docker image... this could take a bit of time") diff --git a/environment/docker/stats.go b/environment/docker/stats.go index 37a6c648..224f4b77 100644 --- a/environment/docker/stats.go +++ b/environment/docker/stats.go @@ -2,6 +2,7 @@ package docker import ( "context" + "github.com/apex/log" "io" "math" "time" @@ -44,7 +45,12 @@ func (e *Environment) pollResources(ctx context.Context) error { if err != nil { return err } - defer stats.Body.Close() + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + log.WithError(err).Error("failed to close Docker container stats body") + } + }(stats.Body) uptime, err := e.Uptime(ctx) if err != nil { diff --git a/internal/ufs/removeall_unix.go b/internal/ufs/removeall_unix.go index d756021a..8936a65b 100644 --- a/internal/ufs/removeall_unix.go +++ b/internal/ufs/removeall_unix.go @@ -12,6 +12,7 @@ package ufs import ( "errors" + "github.com/apex/log" "io" "os" @@ -59,7 +60,12 @@ func removeAll(fs unixFS, path string) error { // If parent does not exist, base cannot exist. Fail silently return nil } - defer parent.Close() + defer func(parent File) { + err := parent.Close() + if err != nil { + log.WithError(err).Error("failed to close parent") + } + }(parent) if err := removeAllFrom(fs, parent, base); err != nil { if pathErr, ok := err.(*PathError); ok { @@ -96,7 +102,12 @@ func removeContents(fs unixFS, path string) error { // If parent does not exist, base cannot exist. Fail silently return nil } - defer parent.Close() + defer func(parent File) { + err := parent.Close() + if err != nil { + log.WithError(err).Error("failed to close parent") + } + }(parent) if err := removeContentsFrom(fs, parent, base); err != nil { if pathErr, ok := err.(*PathError); ok { diff --git a/internal/ufs/walk_unix.go b/internal/ufs/walk_unix.go index 065afc22..06fd40b1 100644 --- a/internal/ufs/walk_unix.go +++ b/internal/ufs/walk_unix.go @@ -9,6 +9,7 @@ package ufs import ( "bytes" "fmt" + "github.com/apex/log" iofs "io/fs" "os" "path" @@ -45,7 +46,12 @@ func (fs *UnixFS) walkDir(b []byte, parentfd int, name, relative string, d DirEn dirfd, err := fs.openat(parentfd, name, O_DIRECTORY|O_RDONLY, 0) if dirfd != 0 { - defer unix.Close(dirfd) + defer func(fd int) { + err := unix.Close(fd) + if err != nil { + log.WithError(err).Error("failed to close directory") + } + }(dirfd) } if err != nil { return err @@ -101,7 +107,12 @@ func ReadDirMap[T any](fs *UnixFS, path string, fn func(DirEntry) (T, error)) ([ if err != nil { return nil, err } - defer unix.Close(fd) + defer func(fd int) { + err := unix.Close(fd) + if err != nil { + log.WithError(err).Error("failed to close directory") + } + }(fd) entries, err := fs.readDir(fd, ".", path, nil) if err != nil { diff --git a/remote/http.go b/remote/http.go index 006a7f44..5ccbfcfc 100644 --- a/remote/http.go +++ b/remote/http.go @@ -154,7 +154,12 @@ func (c *client) request(ctx context.Context, method, path string, body *bytes.B res = r if r.HasError() { // Close the request body after returning the error to free up resources. - defer r.Body.Close() + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + log.WithError(err).Error("http: failed to close response body") + } + }(r.Body) // Don't keep attempting to access this endpoint if the response is a 4XX // level error which indicates a client mistake. Only retry when the error // is due to a server issue (5XX error). diff --git a/remote/servers.go b/remote/servers.go index 832afe07..b486120d 100644 --- a/remote/servers.go +++ b/remote/servers.go @@ -3,6 +3,7 @@ package remote import ( "context" "fmt" + "github.com/pelican-dev/wings/utils" "strconv" "sync" @@ -74,7 +75,7 @@ func (c *client) GetServerConfiguration(ctx context.Context, uuid string) (Serve if err != nil { return config, err } - defer res.Body.Close() + defer utils.CloseResponseBodyWithErrorHandling(res.Body) err = res.BindJSON(&config) return config, err @@ -85,7 +86,7 @@ func (c *client) GetInstallationScript(ctx context.Context, uuid string) (Instal if err != nil { return InstallationScript{}, err } - defer res.Body.Close() + defer utils.CloseResponseBodyWithErrorHandling(res.Body) var config InstallationScript err = res.BindJSON(&config) @@ -138,7 +139,7 @@ func (c *client) ValidateSftpCredentials(ctx context.Context, request SftpAuthRe } return auth, err } - defer res.Body.Close() + defer utils.CloseResponseBodyWithErrorHandling(res.Body) if err := res.BindJSON(&auth); err != nil { return auth, err @@ -152,7 +153,7 @@ func (c *client) GetBackupRemoteUploadURLs(ctx context.Context, backup string, s if err != nil { return data, err } - defer res.Body.Close() + defer utils.CloseResponseBodyWithErrorHandling(res.Body) if err := res.BindJSON(&data); err != nil { return data, err } @@ -205,7 +206,7 @@ func (c *client) getServersPaged(ctx context.Context, page, limit int) ([]RawSer if err != nil { return nil, r.Meta, err } - defer res.Body.Close() + defer utils.CloseResponseBodyWithErrorHandling(res.Body) if err := res.BindJSON(&r); err != nil { return nil, r.Meta, err } diff --git a/router/downloader/downloader.go b/router/downloader/downloader.go index b8fd660a..27aa1b06 100644 --- a/router/downloader/downloader.go +++ b/router/downloader/downloader.go @@ -3,6 +3,7 @@ package downloader import ( "context" "fmt" + "github.com/pelican-dev/wings/utils" "io" "mime" "net" @@ -194,7 +195,7 @@ func (dl *Download) Execute() error { if err != nil { return ErrDownloadFailed } - defer res.Body.Close() + defer utils.CloseResponseBodyWithErrorHandling(res.Body) if res.StatusCode != http.StatusOK { return errors.New("downloader: got bad response status from endpoint: " + res.Status) } diff --git a/server/backup.go b/server/backup.go index d48afd33..2b3ccf4f 100644 --- a/server/backup.go +++ b/server/backup.go @@ -1,6 +1,7 @@ package server import ( + "github.com/pelican-dev/wings/internal/ufs" "io" "io/fs" "os" @@ -42,7 +43,12 @@ func (s *Server) getServerwideIgnoredFiles() (string, error) { } return "", err } - defer f.Close() + defer func(f ufs.File) { + err := f.Close() + if err != nil { + log.WithError(err).Error("failed to close .pelicanignore file") + } + }(f) if st.Mode()&os.ModeSymlink != 0 || st.Size() > 32*1024 { // Don't read a symlinked ignore file, or a file larger than 32KiB in size. return "", nil @@ -152,7 +158,12 @@ func (s *Server) RestoreBackup(b backup.BackupInterface, reader io.ReadCloser) ( // in the file one at a time and writing them to the disk. s.Log().Debug("starting file writing process for backup restoration") err = b.Restore(s.Context(), reader, func(file string, info fs.FileInfo, r io.ReadCloser) error { - defer r.Close() + defer func(r io.ReadCloser) { + err := r.Close() + if err != nil { + log.WithError(err).Error("failed to close restore callback reader") + } + }(r) s.Events().Publish(DaemonMessageEvent, "(restoring): "+file) // TODO: since this will be called a lot, it may be worth adding an optimized // Write with Chtimes method to the UnixFS that is able to re-use the diff --git a/server/backup/backup_local.go b/server/backup/backup_local.go index 9513a7ed..a52de259 100644 --- a/server/backup/backup_local.go +++ b/server/backup/backup_local.go @@ -2,6 +2,7 @@ package backup import ( "context" + "github.com/apex/log" "io" "os" "path/filepath" @@ -107,7 +108,12 @@ func (b *LocalBackup) Restore(ctx context.Context, _ io.Reader, callback Restore if err != nil { return err } - defer f.Close() + defer func(f *os.File) { + err := f.Close() + if err != nil { + log.WithError(err).Error("failed to close local backup file") + } + }(f) var reader io.Reader = f // Steal the logic we use for making backups which will be applied when restoring diff --git a/server/backup/backup_s3.go b/server/backup/backup_s3.go index de4e8727..ab0dfc50 100644 --- a/server/backup/backup_s3.go +++ b/server/backup/backup_s3.go @@ -3,7 +3,9 @@ package backup import ( "context" "fmt" + "github.com/apex/log" "io" + iofs "io/fs" "net/http" "os" "path/filepath" @@ -51,7 +53,12 @@ func (s *S3Backup) WithLogContext(c map[string]interface{}) { // Generate creates a new backup on the disk, moves it into the S3 bucket via // the provided presigned URL, and then deletes the backup from the disk. func (s *S3Backup) Generate(ctx context.Context, fsys *filesystem.Filesystem, ignore string) (*ArchiveDetails, error) { - defer s.Remove() + defer func(s *S3Backup) { + err := s.Remove() + if err != nil { + log.WithError(err).WithField("path", s.Path()).Error("failed to remove backup from disk") + } + }(s) a := &filesystem.Archive{ Filesystem: fsys, @@ -74,7 +81,12 @@ func (s *S3Backup) Generate(ctx context.Context, fsys *filesystem.Filesystem, ig if err != nil { return nil, errors.Wrap(err, "backup: could not read archive from disk") } - defer rc.Close() + defer func(rc *os.File) { + err := rc.Close() + if err != nil { + log.WithError(err).Error("failed to close backup file") + } + }(rc) parts, err := s.generateRemoteRequest(ctx, rc) if err != nil { @@ -106,7 +118,12 @@ func (s *S3Backup) Restore(ctx context.Context, r io.Reader, callback RestoreCal if err != nil { return err } - defer r.Close() + defer func(r iofs.File) { + err := r.Close() + if err != nil { + log.WithError(err).Error("failed to close file during restore") + } + }(r) return callback(f.NameInArchive, f.FileInfo, r) }); err != nil { @@ -117,7 +134,12 @@ func (s *S3Backup) Restore(ctx context.Context, r io.Reader, callback RestoreCal // Generates the remote S3 request and begins the upload. func (s *S3Backup) generateRemoteRequest(ctx context.Context, rc io.ReadCloser) ([]remote.BackupPart, error) { - defer rc.Close() + defer func(rc io.ReadCloser) { + err := rc.Close() + if err != nil { + log.WithError(err).Error("failed to close local backup") + } + }(rc) s.log().Debug("attempting to get size of backup...") size, err := s.Backup.Size() diff --git a/server/filesystem/filesystem_test.go b/server/filesystem/filesystem_test.go index 0366dfc8..1beb8d48 100644 --- a/server/filesystem/filesystem_test.go +++ b/server/filesystem/filesystem_test.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "testing" + "time" "unicode/utf8" . "github.com/franela/goblin" @@ -17,7 +18,7 @@ import ( "github.com/pelican-dev/wings/config" ) -func NewFs() (*Filesystem, *rootFs) { +func NewFs() (*Filesystem, *RootFs) { config.Set(&config.Configuration{ AuthenticationToken: "abc", System: config.SystemConfiguration{ @@ -32,7 +33,7 @@ func NewFs() (*Filesystem, *rootFs) { return nil, nil } - rfs := rootFs{root: tmpDir} + rfs := RootFs{root: tmpDir} p := filepath.Join(tmpDir, "server") if err := os.Mkdir(p, 0o755); err != nil { @@ -50,7 +51,7 @@ func NewFs() (*Filesystem, *rootFs) { return fs, &rfs } -type rootFs struct { +type RootFs struct { root string } @@ -62,7 +63,7 @@ func getFileContent(file ufs.File) string { return w.String() } -func (rfs *rootFs) CreateServerFile(p string, c []byte) error { +func (rfs *RootFs) CreateServerFile(p string, c []byte) error { f, err := os.Create(filepath.Join(rfs.root, "server", p)) if err == nil { @@ -73,11 +74,11 @@ func (rfs *rootFs) CreateServerFile(p string, c []byte) error { return err } -func (rfs *rootFs) CreateServerFileFromString(p string, c string) error { +func (rfs *RootFs) CreateServerFileFromString(p string, c string) error { return rfs.CreateServerFile(p, []byte(c)) } -func (rfs *rootFs) StatServerFile(p string) (os.FileInfo, error) { +func (rfs *RootFs) StatServerFile(p string) (os.FileInfo, error) { return os.Stat(filepath.Join(rfs.root, "server", p)) } @@ -114,7 +115,10 @@ func TestFilesystem_Openfile(t *testing.T) { func TestFilesystem_Writefile(t *testing.T) { g := Goblin(t) fs, _ := NewFs() - + closeFileWithErrorChecking := func(f ufs.File, g *G) { + err := f.Close() + g.Assert(err).IsNil() + } g.Describe("Open and WriteFile", func() { buf := &bytes.Buffer{} @@ -130,7 +134,7 @@ func TestFilesystem_Writefile(t *testing.T) { f, _, err := fs.File("test.txt") g.Assert(err).IsNil() - defer f.Close() + defer closeFileWithErrorChecking(f, g) g.Assert(getFileContent(f)).Equal("test file content") g.Assert(fs.CachedUsage()).Equal(r.Size()) }) @@ -143,7 +147,7 @@ func TestFilesystem_Writefile(t *testing.T) { f, _, err := fs.File("/some/nested/test.txt") g.Assert(err).IsNil() - defer f.Close() + defer closeFileWithErrorChecking(f, g) g.Assert(getFileContent(f)).Equal("test file content") }) @@ -155,7 +159,7 @@ func TestFilesystem_Writefile(t *testing.T) { f, _, err := fs.File("foo/bar/test.txt") g.Assert(err).IsNil() - defer f.Close() + defer closeFileWithErrorChecking(f, g) g.Assert(getFileContent(f)).Equal("test file content") }) @@ -171,7 +175,8 @@ func TestFilesystem_Writefile(t *testing.T) { fs.SetDiskLimit(1024) b := make([]byte, 1025) - _, err := rand.Read(b) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + _, err := rng.Read(b) g.Assert(err).IsNil() g.Assert(len(b)).Equal(1025) @@ -192,7 +197,7 @@ func TestFilesystem_Writefile(t *testing.T) { f, _, err := fs.File("test.txt") g.Assert(err).IsNil() - defer f.Close() + defer closeFileWithErrorChecking(f, g) g.Assert(getFileContent(f)).Equal("new data") }) @@ -556,7 +561,7 @@ func TestFilesystem_Delete(t *testing.T) { err = os.Symlink(filepath.Join(rfs.root, "source.txt"), filepath.Join(rfs.root, "/server/symlink.txt")) g.Assert(err).IsNil() - // Delete the symlink. (This should pass as we will delete the symlink itself, not it's target) + // Delete the symlink. (This should pass as we will delete the symlink itself, not its target) err = fs.Delete("symlink.txt") g.Assert(err).IsNil() diff --git a/server/server.go b/server/server.go index 8e0b9518..48dc810d 100644 --- a/server/server.go +++ b/server/server.go @@ -75,9 +75,6 @@ type Server struct { wsBagLocker sync.Mutex sinks map[system.SinkName]*system.SinkPool - - logSink *system.SinkPool - installSink *system.SinkPool } // New returns a new server instance with a context and all of the default @@ -166,7 +163,6 @@ func DetermineServerTimezone(envvars map[string]interface{}, defaultTimezone str return defaultTimezone } - // parseInvocation parses the start command in the same way we already do in the entrypoint // We can use this to set the container command with all variables replaced. func parseInvocation(invocation string, envvars map[string]interface{}, memory int64, port int, ip string) (parsed string) { diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 00000000..2098391a --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,13 @@ +package utils + +import ( + "github.com/apex/log" + "io" +) + +func CloseResponseBodyWithErrorHandling(body io.ReadCloser) { + err := body.Close() + if err != nil { + log.WithError(err).Error("failed to close response body") + } +} diff --git a/wings.go b/wings.go index ff04ba36..f92c79eb 100644 --- a/wings.go +++ b/wings.go @@ -1,18 +1,10 @@ package main import ( - "math/rand" - "time" - "github.com/pelican-dev/wings/cmd" ) func main() { - // Since we make use of the math/rand package in the code, especially for generating - // non-cryptographically secure random strings we need to seed the RNG. Just make use - // of the current time for this. - rand.Seed(time.Now().UnixNano()) - // Execute the main binary code. cmd.Execute() }