Closed

23 commits
416f212  Fix requirements install order for bark to prevent pulling CUDA depen… (rampa3, Aug 7, 2025)
b50cdd2  Use CPU Torch for CPU build of Chatterbox (rampa3, Aug 7, 2025)
a572ddf  Patch libbackend to build Intel builds only if one is requested by bu… (rampa3, Aug 7, 2025)
9e65421  Use CPU Torch for CPU build of Coqui (rampa3, Aug 7, 2025)
09a32ed  Use CPU Torch for CPU build of diffusers (rampa3, Aug 7, 2025)
704753d  Only use XPU in diffusers if available when requested (rampa3, Aug 7, 2025)
6ba3b94  Ensure CPU mode usage if running diffusers on CPU (rampa3, Aug 7, 2025)
e16f605  Force diffusers to use float32 only if running on CPU no matter what … (rampa3, Aug 7, 2025)
64d4b70  Block bfloat16-only diffusers pipelines from running on CPU - deadloc… (rampa3, Aug 7, 2025)
a25ff94  Extra CPU optimizations in diffusers (rampa3, Aug 7, 2025)
131a590  Use CPU Torch for CPU build of faster-whisper (rampa3, Aug 7, 2025)
1d94f2d  Add device type switching logic using Torch into faster-whisper (rampa3, Aug 7, 2025)
39f32b0  Use CPU Torch for CPU build of kokoro (rampa3, Aug 7, 2025)
2644b31  Use CPU Torch for CPU build of rerankers (rampa3, Aug 7, 2025)
6938d6d  Use CPU Torch for CPU build of rfdetr (rampa3, Aug 7, 2025)
3c09e79  Use CPU Torch for CPU build of transformers (rampa3, Aug 7, 2025)
076fd9c  Create lib import code for CPU mode & only use XPU if available in tr… (rampa3, Aug 7, 2025)
c77851d  Force transformers to use float32 only if running on CPU - deadlock p… (rampa3, Aug 7, 2025)
5d4aad5  Update vLLM files for building CPU version from source & add CPU vLLM… (rampa3, Aug 7, 2025)
fd5656b  Add CPU vLLM build logic into main Makefile (rampa3, Aug 7, 2025)
e19f8b7  Revert "Use CPU Torch for CPU build of kokoro" (rampa3, Aug 7, 2025)
7986a67  Pin kokoro Torch in CPU requirements to CPU version of the same relea… (rampa3, Aug 7, 2025)
515ab68  Resolve conflicts and merge branch 'master' into python_builds_pr (rampa3, Aug 9, 2025)
4 changes: 4 additions & 0 deletions Makefile
@@ -424,7 +424,11 @@ docker-build-rerankers:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:rerankers -f backend/Dockerfile.python --build-arg BACKEND=rerankers .

docker-build-vllm:
ifeq ($(BUILD_TYPE),)
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:vllm -f backend/Dockerfile.vllmcpu --build-arg BACKEND=vllm .
else
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:vllm -f backend/Dockerfile.python --build-arg BACKEND=vllm .
endif

docker-build-transformers:
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:transformers -f backend/Dockerfile.python --build-arg BACKEND=transformers .
64 changes: 64 additions & 0 deletions backend/Dockerfile.vllmcpu
@@ -0,0 +1,64 @@
ARG BASE_IMAGE=ubuntu:22.04
Owner:
why having a separate Dockerfile? when build-type is empty we already treat it as a CPU build

Author (@rampa3, Aug 7, 2025):
> why having a separate Dockerfile? when build-type is empty we already treat it as a CPU build

This Dockerfile is supposed to build vLLM from source, because on PyPI, vLLM, like Torch, only has a CUDA release. The aim behind the CPU builds is to use CPU-specific builds of libraries wherever no other changes are needed: just installing torch from the default PyPI index on a CPU image adds more than 4 GB of NVIDIA CUDA dependencies that the package pulls in. It is clearly visible on the master CI build of Kitten TTS right now - the TTS itself is not GPU accelerated, but because one of its libs wants Torch, you get more than 5 GB of extra dependencies in Torch + CUDA. Just for fun I built it locally with edited requirements that preinstall CPU Torch, and the image size fell to 1.16 GB. That is why I went and blanket-added an extra index pointing to the CPU releases of Torch everywhere.

With vLLM it is a bit more complicated - to get a CPU release, it has to be built from source. We have a part of install.sh for that, but that part never runs with the normal Dockerfile, because the FROM_SOURCE argument was removed at some point. Since the build also has its own specific deps, I made an extra Dockerfile that installs the build deps according to the vLLM docs on building the CPU version. It builds successfully, but for some reason it crashes on init when called by LocalAI, and I have no idea how to properly get the whole stack trace - gRPC returns only part of it. This is one of the reasons why the PR is a draft rather than a regular PR - I want to try to get this CPU build working first.
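
As a minimal illustration of the CPU-index pattern described above (versions taken from this PR's bark requirements-cpu.txt change; other backends substitute their own pins), installing Torch against the CPU wheel index looks like:

# Point pip at the PyTorch CPU wheel index so the +cpu builds are selected
# instead of the default CUDA-enabled wheels.
pip install --extra-index-url https://download.pytorch.org/whl/cpu \
    torch==2.4.1+cpu torchaudio==2.4.1+cpu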

Author (@rampa3, Aug 7, 2025):
In the end, that Dockerfile could become .vllm instead of .vllmcpu - for every GPU vendor other than NVIDIA, vLLM needs to be built from source as well. But for a start, I focused on CPU, as that is the only platform I can reliably test.

(two images attached)

Author (@rampa3):
Should I rename the file in preparation for potential addition of ROCm and XPU parts into it?

Owner:
I get the point of building vLLM for CPU, but I've just run a diff manually here against the two Dockerfiles (Dockerfile.python and Dockerfile.vllmcpu) and I don't see notable differences. My point is more that I think we can still use the same Dockerfile and handle the installation bits directly in the make/install of the backend, unless I am missing something?

--- backend/Dockerfile.vllmcpu	2025-08-08 16:43:25.145194390 +0200
+++ backend/Dockerfile.python	2025-08-08 16:43:15.812600946 +0200
@@ -1,11 +1,9 @@
 ARG BASE_IMAGE=ubuntu:22.04
 
 FROM ${BASE_IMAGE} AS builder
-ARG BACKEND=vllm
+ARG BACKEND=rerankers
 ARG BUILD_TYPE
 ENV BUILD_TYPE=${BUILD_TYPE}
-ARG FROM_SOURCE=true
-ENV FROM_SOURCE=${FROM_SOURCE}
 ARG CUDA_MAJOR_VERSION
 ARG CUDA_MINOR_VERSION
 ARG SKIP_DRIVERS=false
@@ -30,20 +28,81 @@ RUN apt-get update && \
         curl python3-pip \
         python-is-python3 \
         python3-dev llvm \
-        python3-venv make \
-        wget \
-        gcc-12 g++-12 \
-        libtcmalloc-minimal4 \
-        libnuma-dev \
-        ffmpeg \
-        libsm6 libxext6 \
-        libgl1 \
-        jq lsof && \
+        python3-venv make && \
     apt-get clean && \
-    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && \
     rm -rf /var/lib/apt/lists/* && \
     pip install --upgrade pip
 
+
+# Cuda
+ENV PATH=/usr/local/cuda/bin:${PATH}
+
+# HipBLAS requirements
+ENV PATH=/opt/rocm/bin:${PATH}
+
+# Vulkan requirements
+RUN <<EOT bash
+    if [ "${BUILD_TYPE}" = "vulkan" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
+        apt-get update && \
+        apt-get install -y  --no-install-recommends \
+            software-properties-common pciutils wget gpg-agent && \
+        wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
+        wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
+        apt-get update && \
+        apt-get install -y \
+            vulkan-sdk && \
+        apt-get clean && \
+        rm -rf /var/lib/apt/lists/*
+    fi
+EOT
+
+# CuBLAS requirements
+RUN <<EOT bash
+    if [ "${BUILD_TYPE}" = "cublas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then
+        apt-get update && \
+        apt-get install -y  --no-install-recommends \
+            software-properties-common pciutils
+        if [ "amd64" = "$TARGETARCH" ]; then
+            curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+        fi
+        if [ "arm64" = "$TARGETARCH" ]; then
+            curl -O https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/arm64/cuda-keyring_1.1-1_all.deb
+        fi
+        dpkg -i cuda-keyring_1.1-1_all.deb && \
+        rm -f cuda-keyring_1.1-1_all.deb && \
+        apt-get update && \
+        apt-get install -y --no-install-recommends \
+            cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
+            libcufft-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
+            libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
+            libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
+            libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} \
+            libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} && \
+        apt-get clean && \
+        rm -rf /var/lib/apt/lists/*
+    fi
+EOT
+
+# If we are building with clblas support, we need the libraries for the builds
+RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
+        apt-get update && \
+        apt-get install -y --no-install-recommends \
+            libclblast-dev && \
+        apt-get clean && \
+        rm -rf /var/lib/apt/lists/* \
+    ; fi
+
+RUN if [ "${BUILD_TYPE}" = "hipblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
+        apt-get update && \
+        apt-get install -y --no-install-recommends \
+            hipblas-dev \
+            rocblas-dev && \
+        apt-get clean && \
+        rm -rf /var/lib/apt/lists/* && \
+        # I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able
+        # to locate the libraries. We run ldconfig ourselves to work around this packaging deficiency
+        ldconfig \
+    ; fi
 # Install uv as a system package
 RUN curl -LsSf https://astral.sh/uv/install.sh | UV_INSTALL_DIR=/usr/bin sh
 ENV PATH="/root/.cargo/bin:${PATH}"
@@ -60,5 +119,5 @@ COPY python/common/ /${BACKEND}/common
 RUN cd /${BACKEND} && make
 
 FROM scratch
-ARG BACKEND=vllm
-COPY --from=builder /${BACKEND}/ /
+ARG BACKEND=rerankers
+COPY --from=builder /${BACKEND}/ /

Author (@rampa3, Aug 11, 2025):
Well, the dependency block I am talking about consists of APT dependencies, as listed in the vLLM docs. That means requirements-cpu.txt is not the way to handle them. They are GCC, the C++ libraries required to compile vLLM, and a few tools vLLM uses in its makefiles. Here is the block from the vLLM docs that dictates the extra dependencies:

sudo apt-get update  -y
sudo apt-get install -y --no-install-recommends ccache git curl wget ca-certificates gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

We already have some, but we are missing these:

  • python3-venv
  • make
  • wget
  • gcc-12
  • g++-12
  • libtcmalloc-minimal4
  • libnuma-dev
  • ffmpeg
  • libsm6
  • libxext6
  • libgl1
  • jq
  • lsof

all of which have to be installed as APT packages, since vLLM is compiled from C++ code. Normally we just pull Python dependencies, because even CPU Torch already ships pre-compiled binaries for its C++ parts. The CPU build of vLLM is compiled fully from scratch, so unless we decide not to ship CPU vLLM, we have to provide these somehow.

I can see if I can make it work with install.sh; since it is just a shell script and the builder runs as root, it should work. The only thing is, if I put them there, those building from custom Dockerfiles won't thank me - people build on Arch builders, for example, and putting it there limits the build platform to Debian and derivative distros only (without manual intervention). The Dockerfile was chosen not only as an experimentation shortcut (only the fact that it was a separate file was a testing shortcut), but also to keep the backend source directory platform agnostic.

Owner:
Fair enough, I think it's OK to put it in the Dockerfile.python builder, especially because at the end of the day that container is used only for building, so in the worst case we would have to copy the libraries to the final backend during the packaging phase.
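
For illustration only, a minimal sketch of how those packages could be gated inside the shared Dockerfile.python builder, assuming a RUN <<EOT bash block like the existing driver-specific ones and an empty BUILD_TYPE meaning a CPU build (the package list is the one quoted from the vLLM docs above; this is a sketch, not code from this PR):

# Sketch: install the vLLM CPU build dependencies only for CPU builds,
# i.e. when BUILD_TYPE is empty, mirroring the existing BUILD_TYPE-gated blocks.
if [ "${BUILD_TYPE}" = "" ]; then
    apt-get update && \
    apt-get install -y --no-install-recommends \
        wget gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev \
        ffmpeg libsm6 libxext6 libgl1 jq lsof && \
    update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 \
        --slave /usr/bin/g++ g++ /usr/bin/g++-12 && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
fi

Because the deps would live only in the builder stage, the final scratch-based backend image would stay unchanged, which matches the point above about the container being used only for building.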

Owner:
my suggestion here probably would be to do this step-by-step for each backend, or at least treat vLLM separately to not make this PR go stale.

Author (@rampa3):
> my suggestion here probably would be to do this step-by-step for each backend, or at least treat vLLM separately to not make this PR go stale.

I agree. I think splitting it backend by backend will be the best way. I will prepare per-backend branches and PRs for the ready ones ASAP. For the working ones I will just have to figure out the CI; the rest will be opened whenever I get a moment to sit down and finish them. The last few weeks have been a bit busy, as I am in the middle of the autumn term of my bachelor finals. With that, I think I will be closing this one then?

Owner:
Yes, sounds good to me, we can follow up on the other PRs. Thanks! (And good luck with your finals!)


FROM ${BASE_IMAGE} AS builder
ARG BACKEND=vllm
ARG BUILD_TYPE
ENV BUILD_TYPE=${BUILD_TYPE}
ARG FROM_SOURCE=true
ENV FROM_SOURCE=${FROM_SOURCE}
ARG CUDA_MAJOR_VERSION
ARG CUDA_MINOR_VERSION
ARG SKIP_DRIVERS=false
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETARCH
ARG TARGETVARIANT

RUN apt-get update && \
apt-get install -y --no-install-recommends \
build-essential \
ccache \
ca-certificates \
espeak-ng \
curl \
libssl-dev \
git \
git-lfs \
unzip \
upx-ucl \
curl python3-pip \
python-is-python3 \
python3-dev llvm \
python3-venv make \
wget \
gcc-12 g++-12 \
libtcmalloc-minimal4 \
libnuma-dev \
ffmpeg \
libsm6 libxext6 \
libgl1 \
jq lsof && \
apt-get clean && \
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && \
rm -rf /var/lib/apt/lists/* && \
pip install --upgrade pip

# Install uv as a system package
RUN curl -LsSf https://astral.sh/uv/install.sh | UV_INSTALL_DIR=/usr/bin sh
ENV PATH="/root/.cargo/bin:${PATH}"

RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

# Install grpcio-tools (the version in 22.04 is too old)
RUN pip install --user grpcio-tools==1.71.0 grpcio==1.71.0

COPY python/${BACKEND} /${BACKEND}
COPY backend.proto /${BACKEND}/backend.proto
COPY python/common/ /${BACKEND}/common

RUN cd /${BACKEND} && make

FROM scratch
ARG BACKEND=vllm
COPY --from=builder /${BACKEND}/ /
6 changes: 4 additions & 2 deletions backend/python/bark/requirements-cpu.txt
@@ -1,4 +1,6 @@
bark==0.1.5
transformers
accelerate
torch==2.4.1
torchaudio==2.4.1
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.4.1+cpu
torchaudio==2.4.1+cpu
3 changes: 2 additions & 1 deletion backend/python/bark/requirements-cublas11.txt
@@ -1,5 +1,6 @@
bark==0.1.5
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.4.1+cu118
torchaudio==2.4.1+cu118
transformers
accelerate
accelerate
3 changes: 2 additions & 1 deletion backend/python/bark/requirements-cublas12.txt
@@ -1,4 +1,5 @@
bark==0.1.5
torch==2.4.1
torchaudio==2.4.1
transformers
accelerate
accelerate
3 changes: 2 additions & 1 deletion backend/python/bark/requirements-hipblas.txt
@@ -1,5 +1,6 @@
bark==0.1.5
--extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.4.1+rocm6.0
torchaudio==2.4.1+rocm6.0
transformers
accelerate
accelerate
3 changes: 2 additions & 1 deletion backend/python/bark/requirements-intel.txt
@@ -1,3 +1,4 @@
bark==0.1.5
--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi
@@ -6,4 +7,4 @@ oneccl_bind_pt==2.3.100+xpu
optimum[openvino]
setuptools
transformers
accelerate
accelerate
3 changes: 1 addition & 2 deletions backend/python/bark/requirements.txt
@@ -1,4 +1,3 @@
bark==0.1.5
grpcio==1.71.0
protobuf
certifi
certifi
5 changes: 3 additions & 2 deletions backend/python/chatterbox/requirements-cpu.txt
@@ -1,5 +1,6 @@
accelerate
torch==2.6.0
torchaudio==2.6.0
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0+cpu
torchaudio==2.6.0+cpu
transformers==4.46.3
chatterbox-tts
4 changes: 2 additions & 2 deletions backend/python/common/libbackend.sh
@@ -74,8 +74,8 @@ function getBuildProfile() {
return 0
fi

# If /opt/intel exists, then we are doing an intel/ARC build
if [ -d "/opt/intel" ]; then
# If /opt/intel exists and BUILD_TYPE is one of the Intel ones, then we are doing an intel/ARC build
if [[ -d "/opt/intel" && ( x"${BUILD_TYPE}" == "xintel" || ( x"${BUILD_TYPE}" == "xsycl_f16" || x"${BUILD_TYPE}" == "xsycl_f32" ) ) ]]; then
echo "intel"
return 0
fi
5 changes: 3 additions & 2 deletions backend/python/coqui/requirements-cpu.txt
@@ -1,4 +1,5 @@
transformers==4.48.3
accelerate
torch==2.4.1
coqui-tts
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.4.1+cpu
coqui-tts
35 changes: 29 additions & 6 deletions backend/python/diffusers/backend.py
@@ -29,14 +29,20 @@

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
COMPEL = os.environ.get("COMPEL", "0") == "1"
XPU = os.environ.get("XPU", "0") == "1"
# Attempt to use XPU only if Torch says it is available when asking for it
XPU = ((os.environ.get("XPU", "0") == "1") & (torch.xpu.is_available()))
CLIPSKIP = os.environ.get("CLIPSKIP", "1") == "1"
SAFETENSORS = os.environ.get("SAFETENSORS", "1") == "1"
CHUNK_SIZE = os.environ.get("CHUNK_SIZE", "8")
FPS = os.environ.get("FPS", "7")
DISABLE_CPU_OFFLOAD = os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
FRAMES = os.environ.get("FRAMES", "64")

# Set Torch to use all logical CPU cores for CPU mode
num_cores = os.cpu_count()
torch.set_num_threads(max(1, num_cores // 2))
torch.set_num_interop_threads(num_cores)

if XPU:
print(torch.xpu.get_device_name(0))

@@ -166,7 +172,8 @@ def LoadModel(self, request, context):
torchType = torch.float32
variant = None

if request.F16Memory:
# Only use f16 if not running on CPU - forcing f16 on CPU causes freezes (https://github.com/pytorch/pytorch/issues/75458)
if (request.F16Memory & ((request.CUDA & torch.cuda.is_available()) | XPU)):
torchType = torch.float16
variant = "fp16"

@@ -189,12 +196,18 @@ def LoadModel(self, request, context):
value = int(value)
self.options[key] = value

# From options, extract if present "torch_dtype" and set it to the appropriate type
# From options, extract if present "torch_dtype" and set it to the appropriate type; if on CPU, always force float32
if "torch_dtype" in self.options:
if self.options["torch_dtype"] == "fp16":
torchType = torch.float16
if not ((request.CUDA & torch.cuda.is_available()) | XPU):
torchType = torch.float32
else:
torchType = torch.float16
elif self.options["torch_dtype"] == "bf16":
torchType = torch.bfloat16
if not ((request.CUDA & torch.cuda.is_available()) | XPU):
torchType = torch.float32
else:
torchType = torch.bfloat16
elif self.options["torch_dtype"] == "fp32":
torchType = torch.float32
# remove it from options
@@ -290,6 +303,8 @@ def LoadModel(self, request, context):
use_safetensors=True,
variant=variant)
elif request.PipelineType == "FluxPipeline":
if not ((request.CUDA & torch.cuda.is_available()) | XPU):
raise RuntimeError("Flux requires f16. Cannot run diffusers using f16 on CPU - doing so causes deadlocks. Refer to: https://github.com/pytorch/pytorch/issues/75458")
if fromSingleFile:
self.pipe = FluxPipeline.from_single_file(modelFile,
torch_dtype=torchType,
@@ -301,6 +316,8 @@ def LoadModel(self, request, context):
if request.LowVRAM:
self.pipe.enable_model_cpu_offload()
elif request.PipelineType == "FluxTransformer2DModel":
if not ((request.CUDA & torch.cuda.is_available()) | XPU):
raise RuntimeError("Flux requires f16. Cannot run diffusers using f16 on CPU - doing so causes deadlocks. Refer to: https://github.com/pytorch/pytorch/issues/75458")
dtype = torch.bfloat16
# specify from environment or default to "ChuckMcSneed/FLUX.1-dev"
bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
@@ -319,12 +336,16 @@ def LoadModel(self, request, context):
if request.LowVRAM:
self.pipe.enable_model_cpu_offload()
elif request.PipelineType == "Lumina2Text2ImgPipeline":
if not ((request.CUDA & torch.cuda.is_available()) | XPU):
raise RuntimeError("Lumina requires f16. Cannot run diffusers using f16 on CPU - doing so causes deadlocks. Refer to: https://github.com/pytorch/pytorch/issues/75458")
self.pipe = Lumina2Text2ImgPipeline.from_pretrained(
request.Model,
torch_dtype=torch.bfloat16)
if request.LowVRAM:
self.pipe.enable_model_cpu_offload()
elif request.PipelineType == "SanaPipeline":
if not ((request.CUDA & torch.cuda.is_available()) | XPU):
raise RuntimeError("Sana requires f16. Cannot run diffusers using f16 on CPU - doing so causes deadlocks. Refer to: https://github.com/pytorch/pytorch/issues/75458")
self.pipe = SanaPipeline.from_pretrained(
request.Model,
variant="bf16",
@@ -362,7 +383,7 @@ def LoadModel(self, request, context):
# modify LoraAdapter to be relative to modelFileBase
request.LoraAdapter = os.path.join(request.ModelPath, request.LoraAdapter)

device = "cpu" if not request.CUDA else "cuda"
device = "cpu" if not (request.CUDA & torch.cuda.is_available()) else "cuda"
if XPU:
device = "xpu"
self.device = device
@@ -392,6 +413,8 @@ def LoadModel(self, request, context):
self.pipe.to(device)
if self.controlnet:
self.controlnet.to(device)
else:
self.pipe.to("cpu")

except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
5 changes: 3 additions & 2 deletions backend/python/diffusers/requirements-cpu.txt
@@ -5,5 +5,6 @@ accelerate
compel
peft
sentencepiece
torch==2.7.1
optimum-quanto
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.7.1+cpu
optimum-quanto
4 changes: 4 additions & 0 deletions backend/python/diffusers/run.sh
@@ -6,6 +6,10 @@ else
source $backend_dir/../common/libbackend.sh
fi

# Set thread counts for CPU mode
export OMP_NUM_THREADS=$(nproc)
export MKL_NUM_THREADS=$(nproc)

if [ -d "/opt/intel" ]; then
# Assumes we are using the Intel oneAPI container image
# https://github.com/intel/intel-extension-for-pytorch/issues/538
8 changes: 6 additions & 2 deletions backend/python/faster-whisper/backend.py
@@ -8,6 +8,7 @@
import signal
import sys
import os
import torch
import backend_pb2
import backend_pb2_grpc

@@ -31,14 +32,17 @@ def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
device = "cpu"
precision = "float32"
# Get device
# device = "cuda" if request.CUDA else "cpu"
if request.CUDA:
# Detecting CUDA availability using Torch.
if (request.CUDA & torch.cuda.is_available()):
device = "cuda"
precision="float16"

try:
print("Preparing models, please wait", file=sys.stderr)
self.model = WhisperModel(request.Model, device=device, compute_type="float16")
self.model = WhisperModel(request.Model, device=device, compute_type=precision)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
5 changes: 3 additions & 2 deletions backend/python/faster-whisper/requirements-cpu.txt
@@ -4,5 +4,6 @@ accelerate
compel
peft
sentencepiece
torch==2.4.1
optimum-quanto
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.4.1+cpu
optimum-quanto
4 changes: 2 additions & 2 deletions backend/python/kokoro/requirements-cpu.txt
@@ -1,6 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
transformers
accelerate
torch
torch==2.7.1+cpu
kokoro
soundfile
soundfile
5 changes: 3 additions & 2 deletions backend/python/rerankers/requirements-cpu.txt
@@ -1,4 +1,5 @@
transformers
accelerate
torch==2.4.1
rerankers[transformers]
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.4.1+cpu
rerankers[transformers]
5 changes: 3 additions & 2 deletions backend/python/rfdetr/requirements-cpu.txt
@@ -3,5 +3,6 @@ opencv-python
accelerate
peft
inference
torch==2.7.1
optimum-quanto
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.7.1+cpu
optimum-quanto
17 changes: 13 additions & 4 deletions backend/python/transformers/backend.py
@@ -19,8 +19,8 @@
import torch
import torch.cuda


XPU=os.environ.get("XPU", "0") == "1"
# Attempt to use XPU only if Torch says it is available when asking for it
XPU = ((os.environ.get("XPU", "0") == "1") & (torch.xpu.is_available()))
from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM
from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration
from scipy.io import wavfile
@@ -83,8 +83,14 @@ def LoadModel(self, request, context):
if os.path.exists(request.ModelFile):
model_name = request.ModelFile

compute = torch.float16
if request.F16Memory == True:
# Use float32 for CPU inference
if (torch.cuda.is_available() | XPU):
compute = torch.float16
else:
compute = torch.float32

# Only use f16 if not running on CPU - forcing f16 on CPU causes freezes (https://github.com/pytorch/pytorch/issues/75458)
if (request.F16Memory & (torch.cuda.is_available() | XPU)) == True:
compute=torch.bfloat16

self.CUDA = torch.cuda.is_available()
@@ -122,6 +128,9 @@ def LoadModel(self, request, context):

print(f"Parsed options: {self.options}", file=sys.stderr)

if not (self.CUDA | XPU):
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

if self.CUDA:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
if request.MainGPU: