diff --git a/docs/cli.md b/docs/cli.md index e851148200..bc45769904 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -403,3 +403,14 @@ cog predict --use-replicate-token -i prompt="Hello" # Multiple environment variables cog run -e CUDA_VISIBLE_DEVICES=0 -e BATCH_SIZE=32 python train.py ``` + +# Selecting Ubuntu version for CUDA base image + +To select a specific Ubuntu version for the CUDA base image, set the environment variable `COG_UBUNTU_VERSION` before building: + +```bash +export COG_UBUNTU_VERSION=22.04 +cog build --use-cog-base-image=false +``` + +If not set, the latest supported Ubuntu version will be used. diff --git a/pkg/config/compatibility.go b/pkg/config/compatibility.go index ae9f5c68ca..b9c3da29eb 100644 --- a/pkg/config/compatibility.go +++ b/pkg/config/compatibility.go @@ -5,6 +5,7 @@ import ( _ "embed" "encoding/json" "fmt" + "os" "sort" "strings" @@ -255,13 +256,20 @@ func versionGreater(a string, b string) (bool, error) { func CUDABaseImageFor(cuda string, cuDNN string) (string, error) { var images []CUDABaseImage + ubuntuEnv := os.Getenv("COG_UBUNTU_VERSION") for _, image := range CUDABaseImages { if version.Matches(cuda, image.CUDA) && image.CuDNN == cuDNN { - images = append(images, image) + if ubuntuEnv == "" || image.Ubuntu == ubuntuEnv { + images = append(images, image) + } } } if len(images) == 0 { - return "", fmt.Errorf("No matching base image for CUDA %s and CuDNN %s", cuda, cuDNN) + ubuntuMsg := ubuntuEnv + if ubuntuEnv == "" { + ubuntuMsg = "any" + } + return "", fmt.Errorf("No matching base image for CUDA %s, CuDNN %s, Ubuntu %s", cuda, cuDNN, ubuntuMsg) } sort.Slice(images, func(i, j int) bool { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 2aeba66454..0bcfaa9207 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -531,6 +531,46 @@ func TestCUDABaseImageTag(t *testing.T) { require.Equal(t, "nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04", imageTag) } +func TestCUDABaseImageTagWithUbuntuEnv(t *testing.T) { + // By default, CUDA 12.8 + Python 3.12 should select Ubuntu 24.04 + os.Unsetenv("COG_UBUNTU_VERSION") + configDefault := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.12", + CUDA: "12.8.0", + CuDNN: "9", + }, + } + + err := configDefault.ValidateAndComplete("") + require.NoError(t, err) + + imageTag, err := configDefault.CUDABaseImageTag() + require.NoError(t, err) + require.Equal(t, "nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04", imageTag) + + // If COG_UBUNTU_VERSION is set to 22.04, should select Ubuntu 22.04 image + os.Setenv("COG_UBUNTU_VERSION", "22.04") + configEnv := &Config{ + Build: &Build{ + GPU: true, + PythonVersion: "3.12", + CUDA: "12.8.0", + CuDNN: "9", + }, + } + + err = configEnv.ValidateAndComplete("") + require.NoError(t, err) + + imageTag, err = configEnv.CUDABaseImageTag() + require.NoError(t, err) + require.Equal(t, "nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", imageTag) + + os.Unsetenv("COG_UBUNTU_VERSION") +} + func TestBuildRunItemStringYAML(t *testing.T) { type BuildWrapper struct { Build *Build `yaml:"build"`