diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..ee781590c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,293 @@ +# This CMake config hopefully makes it easier to compile. +# Ensure the CUDA Toolkit is available on your path. Then run: +# For GCC: `cmake -B build . && cmake --build build` +# For MSVC: `cmake -B build . && cmake --build build --config Release` +# You can also use the following options and variables +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support +# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version +# is whatever CMake finds on your path. +# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. +# Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90` +# Check your compute capability here: https://developer.nvidia.com/cuda-gpus +# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler +cmake_minimum_required(VERSION 3.22.1) + +project(bitsandbytes LANGUAGES CXX) + +# If run without specifying a build type, default to using the Release configuration: +# optimizing the generated binaries for performance and also adds the `-DNDEBUG` flag, +# which turns off a bunch of asserts which seem to link to new symbols in libstdc++, +# worsening our many_linux compliance.. +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# Define included source files +set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.c) +#set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +set(CUDA_FILES csrc/ops.hip.cpp csrc/kernels.hip.cpp) +#set(CUDA_FILES obsolete/ops.cu obsolete/kernels.cu) +set(MPS_FILES csrc/mps_ops.mm) +set(METAL_FILES csrc/mps_kernels.metal) +# C++ sources are always included +list(APPEND SRC_FILES ${CPP_FILES}) + +set(COMPUTE_BACKEND "hip" CACHE STRING "The compute backend to use (cpu, cuda, mps, hip)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps hip) +option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) + +if(NOT DEFINED HIP_PATH) + if(NOT DEFINED ENV{HIP_PATH}) + set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed") + else() + set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed") + endif() +endif() +message("HIP_PATH: " ${HIP_PATH}) +set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) +find_package(HIP REQUIRED) +if (HIP_FOUND) + message(STATUS "Found HIP: " ${HIP_VERSION}) +else() + message(FATAL_ERROR "Could not find HIP") +endif() +find_package(rocthrust REQUIRED) +find_package(hipblas REQUIRED) +find_package(hipsparse REQUIRED) +find_package(rocrand REQUIRED) +# Search for rocm in common locations +list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm /opt/rocm) +list(APPEND HIP_PATH /opt/rocm/llvm/bin/) +# Find HIP. +# The user may override AMDGPU_TARGETS defined in the HIP config file +# to select the AMDGPU archs to compile for. +# ex. set(AMDGPU_TARGETS "gfx803;gfx900;gfx906") +# Find OpenMP. +#find_package(OpenMP REQUIRED) +# Set compiler and linker. +if(NOT WIN32) + set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXXFLAGS -D__HIP_PLATFORM_AMD__) + set(CMAKE_CFLAGS -D__HIP_PLATFORM_AMD__) +endif() +message("Current CMAKE_CXX_COMPILER (should show hipcc): " ${CMAKE_CXX_COMPILER}) +message("Current CMAKE_CXX_LINKER (should show hipcc): " ${CMAKE_CXX_LINKER}) + +if(APPLE) + set(CMAKE_OSX_DEPLOYMENT_TARGET 13.1) +endif() + +set(BNB_OUTPUT_NAME "bitsandbytes") + +message(STATUS "Configuring ${PROJECT_NAME} (Backend: ${COMPUTE_BACKEND})") + +if(${COMPUTE_BACKEND} STREQUAL "cuda") + if(APPLE) + message(FATAL_ERROR "CUDA is not supported on macOS" ) + endif() + option(NO_CUBLASLT "Disable CUBLAS" OFF) + set(BUILD_CUDA ON) + set(BUILD_MPS OFF) + message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") +elseif(${COMPUTE_BACKEND} STREQUAL "mps") + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_MPS ON) +elseif(${COMPUTE_BACKEND} STREQUAL "hip") + set(BUILD_HIP on) + set(BUILD_CUDA OFF) + set(BUILD_MPS OFF) + set(NO_CUBLASLT ON) +else() + set(BUILD_CUDA OFF) + set(BUILD_MPS OFF) +endif() + + +if(BUILD_CUDA) + enable_language(CUDA) # This will fail if CUDA is not found + find_package(CUDAToolkit REQUIRED) + + # Convert the CUDA version from X.Y.z to XY. There's probably a shorter way of doing this + string(REGEX MATCH "^[0-9]+.[0-9]+" _CUDA_VERSION_FIRST_TWO "${CMAKE_CUDA_COMPILER_VERSION}") + string(REPLACE "." "" CUDA_VERSION_SHORT "${_CUDA_VERSION_FIRST_TWO}") + + # Expose a cache variable that the user can set to ensure the correct version of CUDA is found + set(CUDA_VERSION "${CUDA_VERSION_SHORT}" CACHE STRING "Expected CUDA Version Shortcode") + + message(STATUS "CUDA Version: ${CUDA_VERSION_SHORT} (${CMAKE_CUDA_COMPILER_VERSION})") + message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") + + # It should match the discovered version + if(NOT CUDA_VERSION STREQUAL "${CUDA_VERSION_SHORT}") + message(FATAL_ERROR "You've specified CUDA version ${CUDA_VERSION} however the CUDA compiler found is ${CUDA_VERSION_SHORT}." + " Ensure the desired CUDA compiler is the first one available on your PATH." + ) + endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.0") + message(FATAL_ERROR "CUDA Version < 11 is not supported") + elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") + message(FATAL_ERROR "CUDA Version > 12 is not supported") + endif() + + # CMake < 3.23.0 does not define CMAKE_CUDA_ARCHITECTURES_ALL. + if(CMAKE_VERSION VERSION_LESS "3.23.0") + message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...") + + # 11.x and 12.x both support these at a minimum. + set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80) + set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80) + + # CUDA 11.1 adds Ampere support for GA102-GA107. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.1") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 86) + endif() + + # CUDA 11.4 adds Ampere support for GA10B. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.4") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 87) + endif() + + # CUDA 11.8 adds support for Ada and Hopper. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.8") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 89 90) + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 90) + endif() + endif() + + string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math") + + if(PTXAS_VERBOSE) + # Verbose? Outputs register usage information, and other things... + string(APPEND CMAKE_CUDA_FLAGS " -Xptxas=-v") + endif() + + foreach(capability ${CMAKE_CUDA_ARCHITECTURES_ALL}) + # Most of the items here are like: `xx-real`, so we just extract the `xx` portion + string(REGEX MATCH "[0-9]+" capability_id "${capability}") + if(capability_id GREATER 0) + list(APPEND POSSIBLE_CAPABILITIES ${capability_id}) + endif() + endforeach() + + # This can be changed via -D argument to CMake + # By default all possible capabilities are compiled + set(COMPUTE_CAPABILITY "${POSSIBLE_CAPABILITIES}" CACHE STRING "Compute Capabilities Targeted") + + message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}") + message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}") + + # Use the "real" option to build native cubin for all selections. + # Ensure we build the PTX for the latest version. + # This behavior of adding a PTX (virtual) target for the highest architecture + # is similar to how the "all" and "all-major" options would behave in CMake >= 3.23. + # TODO: Consider bumping CMake requirement and using CMAKE_CUDA_ARCHITECTURES=[all | native] by default + list(REMOVE_DUPLICATES COMPUTE_CAPABILITY) + list(SORT COMPUTE_CAPABILITY COMPARE NATURAL) + list(POP_BACK COMPUTE_CAPABILITY _LATEST_CAPABILITY) + list(TRANSFORM COMPUTE_CAPABILITY APPEND "-real" OUTPUT_VARIABLE CMAKE_CUDA_ARCHITECTURES) + list(APPEND CMAKE_CUDA_ARCHITECTURES ${_LATEST_CAPABILITY}) + + message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}") + + list(APPEND SRC_FILES ${CUDA_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") + if(NO_CUBLASLT) + string(APPEND BNB_OUTPUT_NAME "_cuda110_nocublaslt") + endif() + add_compile_definitions(BUILD_CUDA) +elseif(BUILD_MPS) + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + + enable_language(OBJCXX) + + list(APPEND SRC_FILES ${MPS_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_mps") + add_compile_definitions(BUILD_MPS) + file(MAKE_DIRECTORY "build") + add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib" + COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES} + COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib" + DEPENDS "${METAL_FILES}" + COMMENT "Compiling Metal kernels" + VERBATIM) + add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +elseif(BUILD_HIP) + list(APPEND SRC_FILES ${CUDA_FILES}) + # real name + string(APPEND BNB_OUTPUT_NAME "_hip_nohipblaslt") + add_compile_definitions(BUILD_HIP) +else() + string(APPEND BNB_OUTPUT_NAME "_cpu") + set(GPU_SOURCES) +endif() + + +if(WIN32) + # Export all symbols + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) +endif() + +# Weird MSVC hacks +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast") +endif() + +if (BUILD_HIP) + set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) + message("Working on: " ${CPP_FILES}) + add_library(bitsandbytes SHARED ${SRC_FILES}) + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include /opt/rocm/include/hipblas /opt/rocm/include/rocblas /opt/rocm/include/hipblaslt /opt/rocm/include/hipsparse /opt/rocm/include/hipcub /opt/rocm/include/rocwmma) + target_compile_features(bitsandbytes PUBLIC cxx_std_14) + target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) + target_include_directories(bitsandbytes PUBLIC csrc include) + target_link_libraries(bitsandbytes PUBLIC hip::device roc::rocthrust roc::hipblas roc::hipsparse roc::rocrand roc::rocprim) +else() + set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) + add_library(bitsandbytes SHARED ${SRC_FILES}) + target_compile_features(bitsandbytes PUBLIC cxx_std_14) + target_include_directories(bitsandbytes PUBLIC csrc include) + target_link_libraries(bitsandbytes PUBLIC hip::device) +endif() + +if(BUILD_CUDA) + target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) + if(NO_CUBLASLT) + target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) + else() + target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt) + endif() + + set_target_properties(bitsandbytes + PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + ) +endif() +if(BUILD_MPS) + add_dependencies(bitsandbytes metallib) + target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") +endif() + +if(WIN32) + set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") +endif() +set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME}) +if(MSVC) + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes") +endif() + +set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bitsandbytes") \ No newline at end of file diff --git a/CMakeLists_cpu.txt b/CMakeLists_cpu.txt new file mode 100644 index 000000000..8610ec023 --- /dev/null +++ b/CMakeLists_cpu.txt @@ -0,0 +1,70 @@ +#this method follows "Consuming the HIP API in C++ code" +#https://rocm.docs.amd.com/en/latest/conceptual/cmake-packages.html#consuming-the-hip-api-in-c-code + +project(bitsandbytes) + +cmake_minimum_required(VERSION 3.21) + +message("Is it Window?: " ${CMAKE_HOST_WIN32}) +message("Is it Linux?: " ${CMAKE_HOST_UNIX}) + +if(NOT DEFINED HIP_PATH) + if(NOT DEFINED ENV{HIP_PATH}) + set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed") + else() + set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed") + endif() +endif() +message("HIP_PATH: " ${HIP_PATH}) +set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) + +find_package(HIP REQUIRED) +if (HIP_FOUND) + message(STATUS "Found HIP: " ${HIP_VERSION}) +else() + message(FATAL_ERROR "Could not find HIP") +endif() + +find_package(rocThrust REQUIRED) +find_package(MIOpen REQUIRED) + +# Search for rocm in common locations +list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm-6.0.2 /opt/rocm) +list(APPEND HIP_PATH /opt/rocm-6.0.2/llvm/bin/) +# Find HIP. +# The user may override AMDGPU_TARGETS defined in the HIP config file +# to select the AMDGPU archs to compile for. +# ex. set(AMDGPU_TARGETS "gfx803;gfx900;gfx906") + +# Find OpenMP. +#find_package(OpenMP REQUIRED) + +# Set compiler and linker. +if(NOT WIN32) + set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE}) +endif() +message("Current CMAKE_CXX_COMPILER (should show hipcc): " ${CMAKE_CXX_COMPILER}) +message("Current CMAKE_CXX_LINKER (should show hipcc): " ${CMAKE_CXX_LINKER}) + +set(CMAKE_BUILD_TYPE Release) +set(BUILD_CUDA ON) +set(HIPCC_VERBOSE 1) + +# Source files. +# all file +#set(CPP_SOURCES ${CMAKE_SOURCE_DIR}/kernels.hip.cpp ${CMAKE_SOURCE_DIR}/kernels.hip.h ${CMAKE_SOURCE_DIR}/common.hip.cpp ${CMAKE_SOURCE_DIR}/common.hip.h ${CMAKE_SOURCE_DIR}/cpu_ops.hip.cpp ${CMAKE_SOURCE_DIR}/cpu_ops.hip.h ${CMAKE_SOURCE_DIR}/ops.hip.cpp ${CMAKE_SOURCE_DIR}/ops.hip.h ${CMAKE_SOURCE_DIR}/pythonInterface.hip.cpp) +set(CPP_SOURCES ${CMAKE_SOURCE_DIR}/csrc/cpu_ops.hip.cpp ${CMAKE_SOURCE_DIR}/csrc/cpu_ops.hip.h ${CMAKE_SOURCE_DIR}/csrc/common.hip.cpp ${CMAKE_SOURCE_DIR}/csrc/common.hip.h) +message("CMAKE_SOURCE_DIR: " ${CMAKE_SOURCE_DIR}) + +# Preparing the executable. +#Add an executable target called "test_openmp_helloworld" to be built from the source files listed in the command invocation. +add_executable(test_bitsandbytes ${CPP_SOURCES}) + +# Link Libraries - HIP Device and OpenMP. +#target_compile_options(test_vector_add PRIVATE ${OpenMP_CXX_FLAGS}) +target_link_libraries(test_bitsandbytes PUBLIC hip::device) +target_include_directories(test_bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/csrc ${CMAKE_SOURCE_DIR}/include /opt/rocm-6.0.2/include/hipblas /opt/rocm-6.0.2/include/hipblaslt /opt/rocm-6.0.2/include/hipsparse /opt/rocm-6.0.2/include/hipcub /opt/rocm-6.0.2/include/rocwmma) +target_compile_features(test_bitsandbytes PUBLIC cxx_std_14) +#set(CMAKE_CXX_FLAGS_INIT -nostartfiles) +#set(CMAKE_SHARED_LINKER_FLAGS "-nostartfiles") \ No newline at end of file diff --git a/README.md b/README.md index 882bd10de..93c1595b8 100644 --- a/README.md +++ b/README.md @@ -1,170 +1,96 @@ -# bitsandbytes-rocm +# `bitsandbytes` -The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. -This fork is the ROCm adaptation of bitsandbytes 0.39.1. The repo is inspired by [agrocylo/bitsandbytes-rocm](https://github.com/agrocylo/bitsandbytes-rocm/tree/main/bitsandbytes), which is a ROCm version of bitsandbytes 0.37. While this fork incorporating the majority of features from bitsandbytes 0.39.1, including the crucial 4 bit quantization feature, certain features such as hipblaslt and hip_bfloat16 have been disabled. Enabling these features is listed as a task for the future. +This is a ROCM port of the `bitsandbytes` package. Early version of this port has been attempted but they were poorly maintained and thus, no longer applicable starting from ROCM 6.0. This repo aims at resolving this pain point and bring the ability to fine tune LLM to AMD GPU +ROCM 6.0 official supports the following model in the `gfx1100` family (7900 XTX, 7900 XT, 7900 GRE). +For `gfx1101` (7700XT, 7800XT) and `gfx1102` (7600, 7600XT) as well as other not officially supported card, user can try setting the following flag in pyTorch -Resources: -- [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/) +HSA_OVERRIDE_GFX_VERSION=11.0.0. +https://github.com/ROCm/ROCm/issues/2901#issuecomment-1950136950 -- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) -## TL;DR -**Requirements** -Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + ROCm >= 5.4.2 or CUDA > 10.0 +As of ROCM 6.0, the following GPU is shown as supported architecture by the compiler but has not been tested. User feedback is encouraged. +-`gfx1030` (6800, 6800XT, 6900XT, 6950XT) -**Installation**: +-`gfx942` (MI300) +-`gfx908` (MI200) -You need to compile from source for ROCm. +-`gfx906` -Compilation quickstart: -```bash -# Run Docker -docker run -it --network=host --device=/dev/kfd --device=/dev/dri --name=bnb_test --shm-size=8g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --group-add video rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 +-`gfx900` +-`gfx90a` -# Install Dependencies -cd -git clone --recurse https://github.com/ROCmSoftwarePlatform/hipBLASLt -cd hipBLASLt -git checkout 4b3b34405e7e25cff404f69bfd0a832644430477 -./install.sh -idc - -cd .. -pip install einops lion_pytorch +Best effort has been made to to ensure the code runs as well as possible. However, feel free to open an issue or make a PR +The current repo has been compiled and test run on the following system config -# Install BitsandBytes -git clone --recurse https://github.com/ROCmSoftwarePlatform/bitsandbytes -cd bitsandbytes -git checkout rocm_enabled -make hip -python setup.py install +Intel i7-13700K +7900 XTX 24 GB -# Run the unit test. If it runs successfully, the library has been installed successfully. -pytest -vvv ./tests/ 2>&1 | tee BitsAndBytes_UT_summary.log -``` +64GB system memory -**Using Int8 inference with HuggingFace Transformers** +ROCM version: 6.0.2 -```python -from transformers import AutoModelForCausalLM -model = AutoModelForCausalLM.from_pretrained( - 'decapoda-research/llama-7b-hf', - device_map='auto', - load_in_8bit=True, - max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB') -``` +Ubuntu 22.04 -A more detailed example, can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py). +## Installation +Clone this repo or download the source code -**Using 8-bit optimizer**: -1. Comment out optimizer: ``#torch.optim.Adam(....)`` -2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same) -3. Replace embedding layer if necessary: ``torch.nn.Embedding(..) -> bnb.nn.Embedding(..)`` +Create virtual environment +conda create -n mypython39-rocm python=3.9 -**Using 8-bit Inference**: -1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)`` -2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same) -3. There are two modes: - - Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default) - - Int8 inference. Pass the argument ``has_fp16_weights=False`` -4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``. -```python -# LLM.int8() -linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0) -# inputs need to be fp16 -out = linear(x.to(torch.float16)) -``` +Install torch-rocm 6.0 +pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0 -## Features -- 8-bit Matrix multiplication with mixed precision decomposition -- LLM.int8() inference -- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory) -- Stable Embedding Layer: Improved stability through better initialization, and normalization -- 8-bit quantization: Quantile, Linear, and Dynamic quantization -- Fast quantile estimation: Up to 100x faster than other algorithms +Install this packge -## Using bitsandbytes +python3 -m pip install ./bitsandbytes.tar.gz -### Using Int8 Matrix Multiplication -For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: -```python -bnb.matmul(..., threshold=6.0) -``` +## ROCM Installation +Only Ubuntu 22.04 was tested on. Windows WSL have not been tested on. +Windows is not supported at this time. -For instructions how to use LLM.int8() inference layers in your own code, see the TL;DR above or for extended instruction see [this blog post](https://huggingface.co/blog/hf-bitsandbytes-integration). +For instruction on installing ROCM, check https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/install-overview.html -### Using the 8-bit Optimizers +## To check bitsandbytes Installation +>>> import bitsandbytes -With bitsandbytes 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we recommend also to use the StableEmbedding layers (see below) which improves results and helps with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way: -```python -import bitsandbytes as bnb +>>> import torch -# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer -adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer -adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent +>>> torch.cuda.get_device_name() +'Radeon RX 7900 XTX' -torch.nn.Embedding(...) -> bnb.nn.StableEmbedding(...) # recommended for NLP models -``` +>>> torch.cuda.get_device_properties(torch.device) -Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). You can change this behavior like so: -```python -# parameter tensors with less than 16384 values are optimized in 32-bit -# it is recommended to use multiplies of 4096 -adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) -``` +_CudaDeviceProperties(name='Radeon RX 7900 XTX', major=11, minor=0, gcnArchName='gfx1100', total_memory=24560MB, multi_processor_count=48) -### Change Bits and other Hyperparameters for Individual Parameters +>>> torch.cuda.get_arch_list() -If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details +['gfx900', 'gfx906', 'gfx908', 'gfx90a', 'gfx1030', 'gfx1100', 'gfx942'] -### Fairseq Users +# Original text from bitsandbytes repo -To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (nor replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.). +The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions. -## Release and Feature History +The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module. -For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md). +There are ongoing efforts to support further hardware backends, i.e. Intel CPU + GPU, AMD GPU, Apple Silicon. Windows support is quite far along and is on its way as well. -## Errors +**Please head to the official documentation page:** -1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available) -2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_) +**[https://huggingface.co/docs/bitsandbytes/main](https://huggingface.co/docs/bitsandbytes/main)** ## License -The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. +The majority of bitsandbytes is licensed under MIT, however small portions of the project are available under separate license terms, as the parts adapted from Pytorch are licensed under the BSD license. We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. - -## How to cite us -If you found this library and found LLM.int8() useful, please consider citing our work: - -```bibtex -@article{dettmers2022llmint8, - title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale}, - author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, - journal={arXiv preprint arXiv:2208.07339}, - year={2022} -} -``` - -For 8-bit optimizers or quantization routines, please consider citing the following work: - -```bibtex -@article{dettmers2022optimizers, - title={8-bit Optimizers via Block-wise Quantization}, - author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke}, - journal={9th International Conference on Learning Representations, ICLR}, - year={2022} -} -``` diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 01d5527f5..78c99355b 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,20 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, utils, research +from . import research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, matmul, + matmul_4bit, matmul_cublas, mm_cublas, - matmul_4bit ) -from .cextension import COMPILED_WITH_CUDA from .nn import modules - -if COMPILED_WITH_CUDA: - from .optim import adam +from .optim import adam __pdoc__ = { "libbitsandbytes": False, @@ -24,6 +21,4 @@ "optim.optimizer.MockArgs": False, } -__version__ = "0.42.0" - -PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" +__version__ = "0.44.0.dev" diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index ebbf2653e..e716b6f3f 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -1,155 +1,4 @@ -import os -import sys -import shlex -import subprocess +if __name__ == "__main__": + from bitsandbytes.diagnostics.main import main -from warnings import warn -from typing import Tuple -from os.path import isdir - -import torch - -HEADER_WIDTH = 60 - -def execute_and_return(command_string: str) -> Tuple[str, str]: - def _decode(subprocess_err_out_tuple): - return tuple( - to_decode.decode("UTF-8").strip() - for to_decode in subprocess_err_out_tuple - ) - - def execute_and_return_decoded_std_streams(command_string): - return _decode( - subprocess.Popen( - shlex.split(command_string), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ).communicate() - ) - - std_out, std_err = execute_and_return_decoded_std_streams(command_string) - return std_out, std_err - -def find_file_recursive(folder, filename): - folder = shlex.quote(folder) - filename = shlex.quote(filename) - cmd = f'find {folder} -name {filename}' - out, err = execute_and_return(cmd) - if len(err) > 0: - raise RuntimeError('Something when wrong when trying to find file. Maybe you do not have a linux system?') - - return out - - -def generate_bug_report_information(): - print_header("") - print_header("BUG REPORT INFORMATION") - print_header("") - print('') - - if 'CONDA_PREFIX' in os.environ: - paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so') - print_header("ANACONDA CUDA PATHS") - print(paths) - print('') - if isdir('/usr/local/'): - paths = find_file_recursive('/usr/local', '*cuda*so') - print_header("/usr/local CUDA PATHS") - print(paths) - print('') - - if isdir(os.getcwd()): - paths = find_file_recursive(os.getcwd(), '*cuda*so') - print_header("WORKING DIRECTORY CUDA PATHS") - print(paths) - print('') - - print_header("LD_LIBRARY CUDA PATHS") - if 'LD_LIBRARY_PATH' in os.environ: - lib_path = os.environ['LD_LIBRARY_PATH'].strip() - for path in set(lib_path.split(':')): - try: - if isdir(path): - print_header(f"{path} CUDA PATHS") - paths = find_file_recursive(path, '*cuda*so') - print(paths) - except: - print(f'Could not read LD_LIBRARY_PATH: {path}') - print('') - - - - - -def print_header( - txt: str, width: int = HEADER_WIDTH, filler: str = "+" -) -> None: - txt = f" {txt} " if txt else "" - print(txt.center(width, filler)) - - -def print_debug_info() -> None: - print( - "\nAbove we output some debug information. Please provide this info when " - f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" - ) - - -generate_bug_report_information() - - -from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL -from .cuda_setup.env_vars import to_be_ignored -from .cuda_setup.main import get_compute_capabilities - - -print_header("OTHER") -print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") -print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") -print_header("") -print_header("DEBUG INFO END") -print_header("") -print( - """ -Running a quick check that: - + library is importable - + CUDA function is callable -""" -) -print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n") - -try: - from bitsandbytes.optim import Adam - - p = torch.nn.Parameter(torch.rand(10, 10).cuda()) - a = torch.rand(10, 10).cuda() - - p1 = p.data.sum().item() - - adam = Adam([p]) - - out = a * p - loss = out.sum() - loss.backward() - adam.step() - - p2 = p.data.sum().item() - - assert p1 != p2 - print("SUCCESS!") - print("Installation was successful!") - sys.exit(0) - -except ImportError: - print() - warn( - f"WARNING: {__package__} is currently running as CPU-only!\n" - "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!" - ) - print_debug_info() - sys.exit(0) -except Exception as e: - print(e) - print_debug_info() - sys.exit(1) + main() diff --git a/bitsandbytes/archive_functional.py b/bitsandbytes/archive_functional.py deleted file mode 100644 index 226c9e51f..000000000 --- a/bitsandbytes/archive_functional.py +++ /dev/null @@ -1,2397 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -import ctypes as ct -import itertools -import operator -import random -import torch -import itertools -import math -from scipy.stats import norm -import numpy as np - -from functools import reduce # Required in Python 3 -from typing import Tuple -from torch import Tensor - -from .cextension import COMPILED_WITH_CUDA, lib - - -# math.prod not compatible with python < 3.8 -def prod(iterable): - return reduce(operator.mul, iterable, 1) - -name2qmap = {} - -if COMPILED_WITH_CUDA: - """C FUNCTIONS FOR OPTIMIZERS""" - str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16) #, lib.cadam32bit_grad_bf16) - str2optimizer32bit["momentum"] = ( - lib.cmomentum32bit_grad_32, - lib.cmomentum32bit_grad_16, - ) - str2optimizer32bit["rmsprop"] = ( - lib.crmsprop32bit_grad_32, - lib.crmsprop32bit_grad_16, - ) - str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16) #, lib.clion32bit_grad_bf16) - str2optimizer32bit["adagrad"] = ( - lib.cadagrad32bit_grad_32, - lib.cadagrad32bit_grad_16, - ) - - str2optimizer8bit = {} - str2optimizer8bit["adam"] = ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ) - str2optimizer8bit["momentum"] = ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ) - str2optimizer8bit["rmsprop"] = ( - lib.crmsprop_static_8bit_grad_32, - lib.crmsprop_static_8bit_grad_16, - ) - str2optimizer8bit["lion"] = ( - lib.clion_static_8bit_grad_32, - lib.clion_static_8bit_grad_16, - ) - str2optimizer8bit["lamb"] = ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ) - str2optimizer8bit["lars"] = ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ) - - str2optimizer8bit_blockwise = {} - str2optimizer8bit_blockwise["adam"] = ( - lib.cadam_8bit_blockwise_grad_fp32, - lib.cadam_8bit_blockwise_grad_fp16, - #lib.cadam_8bit_blockwise_grad_bf16, - ) - str2optimizer8bit_blockwise["momentum"] = ( - lib.cmomentum_8bit_blockwise_grad_fp32, - lib.cmomentum_8bit_blockwise_grad_fp16, - ) - str2optimizer8bit_blockwise["rmsprop"] = ( - lib.crmsprop_8bit_blockwise_grad_fp32, - lib.crmsprop_8bit_blockwise_grad_fp16, - ) - str2optimizer8bit_blockwise["lion"] = ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - #lib.clion_8bit_blockwise_grad_bf16, - ) - str2optimizer8bit_blockwise["adagrad"] = ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - ) - -class GlobalPageManager: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def initialize(self): - self.paged_tensors = [] - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.initialize() - return cls._instance - - def prefetch_all(self, to_cpu=False): - # assume the first added, will be hte - # ones that are used first, so swap them in last - # in the case they are evicted again - for t in self.paged_tensors[::-1]: - prefetch_tensor(t, to_cpu) - - - -class CUBLAS_Context: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def initialize(self): - self.context = {} - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.initialize() - return cls._instance - - def get_context(self, device): - if device.index not in self.context: - prev_device = torch.cuda.current_device() - torch.cuda.set_device(device) - self.context[device.index] = ct.c_void_p(lib.get_context()) - torch.cuda.set_device(prev_device) - return self.context[device.index] - - -class Cusparse_Context: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def initialize(self): - #self.context = ct.c_void_p(lib.get_cusparse()) - if torch.version.cuda: - self.context = ct.c_void_p(lib.get_cusparse()) - elif torch.version.hip: - self.context = ct.c_void_p(lib.get_hipsparse()) - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.initialize() - return cls._instance - -dtype2bytes = {} -dtype2bytes[torch.float32] = 4 -dtype2bytes[torch.float16] = 2 -dtype2bytes[torch.bfloat16] = 2 -dtype2bytes[torch.uint8] = 1 -dtype2bytes[torch.int8] = 1 - -def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): - num_bytes = dtype2bytes[dtype]*prod(shape) - cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) - c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) - new_array = np.ctypeslib.as_array(c_ptr, shape=shape) - out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape) - out.is_paged = True - out.page_deviceid = device.index - return out - -def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, 'Only paged tensors can be prefetched!' - if to_cpu: - deviceid = -1 - else: - deviceid = A.page_deviceid - - num_bytes = dtype2bytes[A.dtype]*A.numel() - lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) - -def elementwise_func(func_name, A, B, value, prefetch=True): - func = None - if A.dtype == torch.float32: - func = getattr(lib, f'c{func_name}_fp32', None) - cvalue = ct.c_float(value) - elif A.dtype == torch.uint8: - func = getattr(lib, f'c{func_name}_uint8', None) - cvalue = ct.c_uint8(value) - - if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') - - is_managed = getattr(A, 'is_managed', False) - if is_managed and prefetch: - prefetch_tensor(A) - if B is not None: prefetch_tensor(B) - - func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) - if A.is_paged or B.is_paged: - # paged function are fully asynchronous - # if we return from this function, we want to the tensor - # to be in the correct state, that is the final state after the - # operation occured. So we synchronize. - torch.cuda.synchronize() - -def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) -def arange(A, device=None): elementwise_func('arange', A, None, 0) -def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) - - -def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = (-1.0 if signed else 0.0) - total_values = 2**total_bits - if add_zero or total_bits < 8: - # add a zero - # since we simulate less bits by having zeros in the data type, we - # we need to center the quantization around zero and as such lose - # a single value - total_values = (2**total_bits if not signed else 2**total_bits-1) - - values = torch.linspace(sign, 1.0, total_values) - gap = 256 - values.numel() - if gap == 0: - return values - else: - l = values.numel()//2 - return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) - -def create_normal_map(offset=0.9677083, use_extra_value=True): - - if use_extra_value: - # one more positive value, this is an asymmetric type - v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0]*(256-15) ## we have 15 non-zero values in this data type - v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() - v = v1 + v2 + v3 - else: - v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0]*(256-14) ## we have 14 non-zero values in this data type - v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() - v = v1 + v2 + v3 - - values = torch.Tensor(v) - values = values.sort().values - values /= values.max() - assert values.numel() == 256 - return values - -def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): - e = exponent_bits - p = precision_bits - has_sign = 1 if signed else 0 - assert e+p == total_bits-has_sign - # the exponent is biased to 2^(e-1) -1 == 0 - evalues = [] - pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): - evalues.append(2**val) - - - values = [] - lst = list(itertools.product([0, 1], repeat=precision_bits)) - #for ev in evalues: - bias = 2**(exponent_bits-1) - for evalue in range(2**(exponent_bits)): - for bit_pattern in lst: - value = (1 if evalue != 0 else 0) - for i, pval in enumerate(list(bit_pattern)): - value += pval*(2**-(i+1)) - if evalue == 0: - # subnormals - value = value*2**-(bias) - else: - # normals - value = value*2**-(evalue-bias-1) - values.append(value) - if signed: - values.append(-value) - - - assert len(values) == 2**total_bits - values.sort() - if total_bits < 8: - gap = 256 - len(values) - for i in range(gap): - values.append(0) - values.sort() - code = torch.Tensor(values) - code /= code.max() - - return code - - - -def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): - """ - Creates the dynamic quantiztion map. - - The dynamic data type is made up of a dynamic exponent and - fraction. As the exponent increase from 0 to -7 the number - of bits available for the fraction shrinks. - - This is a generalization of the dynamic type where a certain - number of the bits and be reserved for the linear quantization - region (the fraction). n determines the maximum number of - exponent bits. - - For more details see - (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] - """ - - data = [] - # these are additional items that come from the case - # where all the exponent bits are zero and no - # indicator bit is present - non_sign_bits = total_bits - (1 if signed else 0) - additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 - if not signed: - additional_items = 2 * additional_items - for i in range(max_exponent_bits): - fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) - boundaries = torch.linspace(0.1, 1, fraction_items) - means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if signed: - data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - - if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items + 1) - means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if signed: - data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - - data.append(0) - data.append(1.0) - - gap = 256 - len(data) - for i in range(gap): - data.append(0) - - data.sort() - return Tensor(data) - -def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits-1) - q = q.tolist() - q.append(0) - - gap = 256 - len(q) - for i in range(gap): - q.append(0) - - q.sort() - - q = Tensor(q) - q = q/q.abs().max() - return q - -def get_special_format_str(): - if not torch.cuda.is_available(): return 'col_turing' - major, _minor = torch.cuda.get_device_capability() - if major <= 7: - return "col_turing" - if major == 8: - return "col_ampere" - return "col_turing" - - - -def is_on_gpu(tensors): - on_gpu = True - gpu_ids = set() - for t in tensors: - if t is None: continue # NULL pointers are fine - is_paged = getattr(t, 'is_paged', False) - on_gpu &= (t.device.type == 'cuda' or is_paged) - if not is_paged: - gpu_ids.add(t.device.index) - if not on_gpu: - raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') - if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') - return on_gpu - -def get_ptr(A: Tensor) -> ct.c_void_p: - """ - Get the ctypes pointer from a PyTorch Tensor. - - Parameters - ---------- - A : torch.tensor - The PyTorch tensor. - - Returns - ------- - ctypes.c_void_p - """ - if A is None: - return None - else: - return ct.c_void_p(A.data.data_ptr()) - - -def pre_call(device): - prev_device = torch.cuda.current_device() - torch.cuda.set_device(device) - return prev_device - - -def post_call(prev_device): - torch.cuda.set_device(prev_device) - - -def get_transform_func(dtype, orderA, orderOut, transpose=False): - name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' - if not hasattr(lib, name): - print(name) - raise ValueError( - f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" - ) - else: - return getattr(lib, name) - - -def get_transform_buffer( - shape, dtype, device, to_order, from_order="row", transpose=False -): - # init_func = torch.empty - init_func = torch.zeros - dims = len(shape) - - if dims == 2: - rows = shape[0] - elif dims == 3: - rows = shape[0] * shape[1] - cols = shape[-1] - - state = (shape, to_order) - if transpose: - # swap dims - tmp = rows - rows = cols - cols = tmp - state = (shape[::-1], to_order) - - if to_order == "row" or to_order == "col": - return init_func(shape, dtype=dtype, device=device), state - elif to_order == "col32": - # blocks of 32 columns (padded) - cols = 32 * ((cols + 31) // 32) - return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == "col_turing": - # blocks of 32 columns and 8 rows - cols = 32 * ((cols + 31) // 32) - rows = 8 * ((rows + 7) // 8) - return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == "col_ampere": - # blocks of 32 columns and 32 rows - cols = 32 * ((cols + 31) // 32) - rows = 32 * ((rows + 31) // 32) - return init_func((rows, cols), dtype=dtype, device=device), state - else: - raise NotImplementedError(f"To_order not supported: {to_order}") - - -def nvidia_transform( - A, - to_order, - from_order="row", - out=None, - transpose=False, - state=None, - ld=None, -): - if state is None: - state = (A.shape, from_order) - else: - from_order = state[1] - if out is None: - out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1] - ) - else: - new_state = (state[1], to_order) - func = get_transform_func(A.dtype, from_order, to_order, transpose) - - shape = state[0] - if len(shape) == 2: - dim1 = ct.c_int32(shape[0]) - dim2 = ct.c_int32(shape[1]) - elif ld is not None: - n = prod(shape) - dim1 = prod([shape[i] for i in ld]) - dim2 = ct.c_int32(n // dim1) - dim1 = ct.c_int32(dim1) - else: - dim1 = ct.c_int32(shape[0] * shape[1]) - dim2 = ct.c_int32(shape[2]) - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) - - return out, new_state - - -def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: - ''' - Estimates 256 equidistant quantiles on the input tensor eCDF. - - Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles - via the eCDF of the input tensor `A`. This is a fast but approximate algorithm - and the extreme quantiles close to 0 and 1 have high variance / large estimation - errors. These large errors can be avoided by using the offset variable which trims - the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it - trims 1/512 = 0.2% from each side of the distrivution. An offset value of 0.01 to 0.02 - usually has a much lower error but is not a minimum entropy encoding. Given an offset - of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles. - - Parameters - ---------- - A : torch.Tensor - The input tensor. Any shape. - out : torch.Tensor - Tensor with the 256 estimated quantiles. - offset : float - The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles) - num_quantiles : int - The number of equally spaced quantiles. - - Returns - ------- - torch.Tensor: - The 256 quantiles in float32 datatype. - ''' - if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') - if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") - if num_quantiles < 256 and offset == 1/(512): - # override default arguments - offset = 1/(2*num_quantiles) - - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) - is_on_gpu([A, out]) - device = pre_call(A.device) - if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) - else: - raise NotImplementedError(f"Not supported data type {A.dtype}") - post_call(device) - - if num_quantiles < 256: - step = round(256/num_quantiles) - idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) - out = out[idx] - - return out - - -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: - """ - Quantize tensor A in blocks of size 4096 values. - - Quantizes tensor A by dividing it into blocks of 4096 values. - Then the absolute maximum value within these blocks is calculated - for the non-linear quantization. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - code : torch.Tensor - The quantization map. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor (8-bit). - - Returns - ------- - torch.Tensor: - The 8-bit tensor. - tuple(torch.Tensor, torch.Tensor): - The quantization state to undo the quantization. - """ - - - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if absmax is None: - n = A.numel() - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device) - - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) - - if A.device.type != 'cpu': - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - cblocksize = ct.c_int32(blocksize) - prev_device = pre_call(A.device) - code = code.to(A.device) - is_on_gpu([code, A, out, absmax]) - if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - else: - # cpu - code = code.cpu() - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - - if nested: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - state = [qabsmax, code, blocksize, nested, offset, state2] - else: - state = [absmax, code, blocksize, nested, None, None] - - - - return out, state - - -def dequantize_blockwise( - A: Tensor, - quant_state: Tuple[Tensor, Tensor] = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, - blocksize: int = 4096, - nested=False -) -> Tensor: - """ - Dequantizes blockwise quantized values. - - Dequantizes the tensor A with maximum absolute values absmax in - blocks of size 4096. - - Parameters - ---------- - A : torch.Tensor - The input 8-bit tensor. - quant_state : tuple(torch.Tensor, torch.Tensor) - Tuple of code and absmax values. - absmax : torch.Tensor - The absmax values. - code : torch.Tensor - The quantization map. - out : torch.Tensor - Dequantized output tensor (default: float32) - - - Returns - ------- - torch.Tensor: - Dequantized tensor (default: float32) - """ - assert quant_state is not None or absmax is not None - if code is None and quant_state is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if out is None: - out = torch.zeros_like(A, dtype=torch.float32) - - if quant_state is None: - quant_state = (absmax, code, blocksize) - assert absmax is not None and out is not None - else: - absmax, code, blocksize, nested, offset, state2 = quant_state - if nested: - absmax = dequantize_blockwise(absmax, state2) - absmax += offset - - - if A.device.type != 'cpu': - device = pre_call(A.device) - code = code.to(A.device) - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - is_on_gpu([A, absmax, out]) - if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) - elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - else: - code = code.cpu() - lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - - return out - -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') - -def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') - -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: - """ - Quantize tensor A in blocks of 4-bit values. - - Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor (8-bit). - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - Returns - ------- - torch.Tensor: - The 8-bit tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype, int): - The quantization state to undo the quantization. - """ - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') - - n = A.numel() - input_shape = A.shape - - if absmax is None: - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device) - - - if out is None: - out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) - - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - prev_device = pre_call(A.device) - is_on_gpu([A, out, absmax]) - - if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - - if compress_statistics: - offset = absmax.mean() - absmax -= offset - #code = create_custom_map().to(absmax.device) - #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) - qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) - del absmax - state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] - else: - state = [absmax, input_shape, A.dtype, blocksize, None, quant_type] - - return out, state - -def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') - -def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') - -def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: - """ - Dequantizes FP4 blockwise quantized values. - - Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. - - Parameters - ---------- - A : torch.Tensor - The input 8-bit tensor (packed 4-bit values). - quant_state : tuple(torch.Tensor, torch.Size, torch.dtype) - Tuple of absmax values, original tensor shape and original dtype. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - Dequantized output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - - Returns - ------- - torch.Tensor: - Dequantized tensor. - """ - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') - - if quant_state is None: - assert absmax is not None and out is not None - shape = out.shape - dtype = out.dtype - else: - absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state - - - if compressed_stats is not None: - offset, state2 = compressed_stats - absmax = dequantize_blockwise(absmax, state2) - absmax += offset - - if out is None: - out = torch.empty(shape, dtype=dtype, device=A.device) - - n = out.numel() - - - device = pre_call(A.device) - is_on_gpu([A, absmax, out]) - if out.dtype == torch.float32: - if quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) - elif out.dtype == torch.float16: - if quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out - - -def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - code = code.to(A.device) - - absmax = torch.abs(A).max() - inp = A / absmax - out = quantize_no_absmax(inp, code, out) - return out, (absmax, code) - - -def dequantize( - A: Tensor, - quant_state: Tuple[Tensor, Tensor] = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, -) -> Tensor: - assert quant_state is not None or absmax is not None - if code is None and quant_state is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - code = code.to(A.device) - - if quant_state is None: - quant_state = (absmax, code) - out = dequantize_no_absmax(A, quant_state[1], out) - return out * quant_state[0] - - -def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' - Quantizes input tensor to 8-bit. - - Quantizes the 32-bit input tensor `A` to the 8-bit output tensor - `out` using the quantization map `code`. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - code : torch.Tensor - The quantization map. - out : torch.Tensor, optional - The output tensor. Needs to be of type byte. - - Returns - ------- - torch.Tensor: - Quantized 8-bit tensor. - ''' - prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - is_on_gpu([A, out]) - lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) - post_call(prev_device) - return out - - -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' - Dequantizes the 8-bit tensor to 32-bit. - - Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via - the quantization map `code`. - - Parameters - ---------- - A : torch.Tensor - The 8-bit input tensor. - code : torch.Tensor - The quantization map. - out : torch.Tensor - The 32-bit output tensor. - - Returns - ------- - torch.Tensor: - 32-bit output tensor. - ''' - prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) - is_on_gpu([code, A, out]) - lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) - post_call(prev_device) - return out - - -def optimizer_update_32bit( - optimizer_name: str, - g: Tensor, - p: Tensor, - state1: Tensor, - beta1: float, - eps: float, - step: int, - lr: float, - state2: Tensor = None, - beta2: float = 0.0, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, - max_unorm: float = 0.0, - skip_zeros=False, -) -> None: - """ - Performs an inplace optimizer update with one or two optimizer states. - - Universal optimizer update for 32-bit state and 32/16-bit gradients/weights. - - Parameters - ---------- - optimizer_name : str - The name of the optimizer: {adam}. - g : torch.Tensor - Gradient tensor. - p : torch.Tensor - Parameter tensor. - state1 : torch.Tensor - Optimizer state 1. - beta1 : float - Optimizer beta1. - eps : float - Optimizer epsilon. - weight_decay : float - Weight decay. - step : int - Current optimizer step. - lr : float - The learning rate. - state2 : torch.Tensor - Optimizer state 2. - beta2 : float - Optimizer beta2. - gnorm_scale : float - The factor to rescale the gradient to the max clip value. - unorm_vec : torch.Tensor - The tensor for the update norm. - max_unorm : float - The maximum update norm relative to the weight norm. - skip_zeros : bool - Whether to skip zero-valued gradients or not (default: False). - """ - - param_norm = 0.0 - if max_unorm > 0.0: - param_norm = torch.norm(p.data.float()) - - - optim_func = None - if g.dtype == torch.float32: - optim_func = str2optimizer32bit[optimizer_name][0] - elif g.dtype == torch.float16: - optim_func = str2optimizer32bit[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): - optim_func = str2optimizer32bit[optimizer_name][2] - else: - raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") - - is_on_gpu([g, p, state1, state2, unorm_vec]) - prev_device = pre_call(g.device) - optim_func( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel())) - post_call(prev_device) - - -def optimizer_update_8bit( - optimizer_name: str, - g: Tensor, - p: Tensor, - state1: Tensor, - state2: Tensor, - beta1: float, - beta2: float, - eps: float, - step: int, - lr: float, - qmap1: Tensor, - qmap2: Tensor, - max1: Tensor, - max2: Tensor, - new_max1: Tensor, - new_max2: Tensor, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, - max_unorm: float = 0.0, -) -> None: - """ - Performs an inplace Adam update. - - Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights. - Uses AdamW formulation if weight decay > 0.0. - - Parameters - ---------- - optimizer_name : str - The name of the optimizer. Choices {adam, momentum} - g : torch.Tensor - Gradient tensor. - p : torch.Tensor - Parameter tensor. - state1 : torch.Tensor - Adam state 1. - state2 : torch.Tensor - Adam state 2. - beta1 : float - Adam beta1. - beta2 : float - Adam beta2. - eps : float - Adam epsilon. - weight_decay : float - Weight decay. - step : int - Current optimizer step. - lr : float - The learning rate. - qmap1 : torch.Tensor - Quantization map for first Adam state. - qmap2 : torch.Tensor - Quantization map for second Adam state. - max1 : torch.Tensor - Max value for first Adam state update. - max2 : torch.Tensor - Max value for second Adam state update. - new_max1 : torch.Tensor - Max value for the next Adam update of the first state. - new_max2 : torch.Tensor - Max value for the next Adam update of the second state. - gnorm_scale : float - The factor to rescale the gradient to the max clip value. - unorm_vec : torch.Tensor - The tensor for the update norm. - max_unorm : float - The maximum update norm relative to the weight norm. - """ - - param_norm = 0.0 - if max_unorm > 0.0: - param_norm = torch.norm(p.data.float()) - - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][0]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(max1), - get_ptr(max2), - get_ptr(new_max1), - get_ptr(new_max2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_int32(g.numel()), - ) - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][1]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(max1), - get_ptr(max2), - get_ptr(new_max1), - get_ptr(new_max2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_int32(g.numel()), - ) - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" - ) - post_call(prev_device) - - -def optimizer_update_8bit_blockwise( - optimizer_name: str, - g: Tensor, - p: Tensor, - state1: Tensor, - state2: Tensor, - beta1: float, - beta2: float, - eps: float, - step: int, - lr: float, - qmap1: Tensor, - qmap2: Tensor, - absmax1: Tensor, - absmax2: Tensor, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - skip_zeros=False, -) -> None: - - optim_func = None - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][0] - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and - len(str2optimizer8bit_blockwise[optimizer_name])==3): - optim_func = str2optimizer8bit_blockwise[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" - ) - post_call(prev_device) - - is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - - prev_device = pre_call(g.device) - optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) - post_call(prev_device) - -def percentile_clipping( - grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 -): - """Applies percentile clipping - - grad: torch.Tensor - The gradient tensor. - gnorm_vec: torch.Tensor - Vector of gradient norms. 100 elements expected. - step: int - The current optimiation steps (number of past gradient norms). - - """ - prev_device = pre_call(grad.device) - is_on_gpu([grad, gnorm_vec]) - if grad.dtype == torch.float32: - lib.cpercentile_clipping_g32( - get_ptr(grad), - get_ptr(gnorm_vec), - ct.c_int32(step), - ct.c_int32(grad.numel()), - ) - elif grad.dtype == torch.float16: - lib.cpercentile_clipping_g16( - get_ptr(grad), - get_ptr(gnorm_vec), - ct.c_int32(step), - ct.c_int32(grad.numel()), - ) - else: - raise ValueError(f"Gradient type {grad.dtype} not supported!") - post_call(prev_device) - - current_gnorm = torch.sqrt(gnorm_vec[step % 100]) - vals, idx = torch.sort(gnorm_vec) - clip_value = torch.sqrt(vals[percentile]) - gnorm_scale = 1.0 - - if current_gnorm > clip_value: - gnorm_scale = clip_value / current_gnorm - - return current_gnorm, clip_value, gnorm_scale - - -def histogram_scatter_add_2d( - histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor -): - assert len(histogram.shape) == 2 - assert histogram.dtype == torch.float32 - assert source.dtype == torch.float32 - assert index1.dtype == torch.int32 - assert index2.dtype == torch.int32 - - assert histogram.device.type == "cuda" - assert index1.device.type == "cuda" - assert index2.device.type == "cuda" - assert source.device.type == "cuda" - - maxdim1 = ct.c_int32(histogram.shape[0]) - n = ct.c_int32(index1.numel()) - is_on_gpu([histogram, index1, index2, source]) - lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) - -def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() - if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError( - f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" - ) - - sA = A.shape - sB = B.shape - tA = transposed_A - tB = transposed_B - - correct = True - - if len(sA) == 2 and len(sB) == 2: - if not tA and not tB and A.shape[1] != B.shape[0]: - correct = False - elif tA and not tB and A.shape[0] != B.shape[0]: - correct = False - elif tA and tB and A.shape[0] != B.shape[1]: - correct = False - elif not tA and tB and A.shape[1] != B.shape[1]: - correct = False - elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB and A.shape[2] != B.shape[0]: - correct = False - elif tA and not tB and A.shape[1] != B.shape[0]: - correct = False - elif tA and tB and A.shape[1] != B.shape[1]: - correct = False - elif not tA and tB and A.shape[2] != B.shape[1]: - correct = False - elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB and A.shape[2] != B.shape[1]: - correct = False - elif tA and not tB and A.shape[1] != B.shape[1]: - correct = False - elif tA and tB and A.shape[1] != B.shape[2]: - correct = False - elif not tA and tB and A.shape[2] != B.shape[2]: - correct = False - - if out is not None: - sout = out.shape - # special case common in backprop - if not correct and len(sA) == 3 and len(sB) == 3: - if ( - sout[0] == sA[2] - and sout[1] == sB[2] - and sA[0] == sB[0] - and sA[1] == sB[1] - ): - correct = True - else: - if len(sA) == 2 and len(sB) == 2: - if not tA and not tB: - sout = (sA[0], sB[1]) - elif tA and tB: - sout = (sA[1], sB[0]) - elif tA and not tB: - sout = (sA[1], sB[1]) - elif not tA and tB: - sout = (sA[0], sB[0]) - elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB: - sout = (sA[0], sA[1], sB[1]) - elif tA and tB: - sout = (sA[0], sA[2], sB[0]) - elif tA and not tB: - sout = (sA[0], sA[2], sB[1]) - elif not tA and tB: - sout = (sA[0], sA[1], sB[0]) - elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB: - sout = (sA[0], sA[1], sB[2]) - elif tA and tB: - sout = (sA[0], sA[2], sB[1]) - elif tA and not tB: - sout = (sA[0], sA[2], sB[2]) - elif not tA and tB: - sout = (sA[0], sA[1], sB[1]) - - if not correct: - raise ValueError( - f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." - ) - - return sout - -def cutlass3_gemm( - A: Tensor, - B: Tensor, - out: Tensor = None, - transposed_A=False, - transposed_B=False, - state=None -): - #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) - if state is None: - Bshape = B.shape - bout = Bshape[1] - else: - Bshape = state[1] - bout = Bshape[0] - if out is None: - out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) - - sA = A.shape - sB = B.shape - if transposed_A and len(sA) == 2: - sA = (sA[1], sA[0]) - elif transposed_A and len(sA) == 3: - sA = (sA[0], sA[2], sA[0]) - if transposed_B and len(sB) == 2: - sB = (sB[1], sB[0]) - elif transposed_B and len(sB) == 3: - sB = (sB[0], sB[2], sB[0]) - # this is a mess: cuBLAS expect column major, but PyTorch is row major. - # So to perform the matrix multiplication, we have to treat A, B, and C matrices - # (transpose of row major is column major) - # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these - - # matrices in the input arguments for cuBLAS - # column major: A @ B = C: [m, k] @ [k, n] = [m, n] - # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] - # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] - if len(sB) == 2: - if B.stride()[0] == B.shape[1]: - transposed_B = False - elif B.stride()[1] == B.shape[0]: - transposed_B = True - if len(A.shape) == 2: - if A.stride()[0] == A.shape[1]: - transposed_A = False - elif A.stride()[1] == A.shape[0]: - transposed_A = True - else: - if A.stride()[1] == A.shape[2]: - transposed_A = False - elif A.stride()[2] == A.shape[1]: - transposed_A = True - - if len(sA) == 2: - n = sA[0] - ldb = A.stride()[1 if transposed_A else 0] - elif len(sA) == 3 and len(sB) == 2: - n = sA[0] * sA[1] - ldb = sA[2] - - m = sB[1] - k = sB[0] - lda = B.stride()[0] - ldc = sB[1] - elif len(sB) == 3: - # special case - assert len(sA) == 3 - if not (sA[0] == sB[0] and sA[1] == sB[1]): - raise ValueError( - f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" - ) - - transposed_A = True - transposed_B = False - - m = sB[2] - n = sA[2] - k = sB[0] * sB[1] - - lda = n - ldb = sA[2] - ldc = m - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - - # B^T @ A^T = C^T - # [km, nk -> mn] - #lda = ldb = ldc = 1 - #lda = 1 - if state is not None: - m = Bshape[0] - k = Bshape[1] - lda = Bshape[0] - ldc = Bshape[0] - ldb = (ldb+1)//2 - #print(m, n, k, lda, ldb, ldc) - is_on_gpu([B, A, out]) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - - if B.dtype == torch.uint8: - lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) - elif A.dtype == torch.float32: - lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) - elif A.dtype == torch.float16: - lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) - else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') - - return out - - - - -def igemm( - A: Tensor, - B: Tensor, - out: Tensor = None, - transposed_A=False, - transposed_B=False, -): - sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: - out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) - if len(A.shape) == 3 and len(B.shape) == 3: - if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]: - return batched_igemm(A, B, out) - - sA = A.shape - sB = B.shape - if transposed_A and len(sA) == 2: - sA = (sA[1], sA[0]) - elif transposed_A and len(sA) == 3: - sA = (sA[0], sA[2], sA[0]) - if transposed_B and len(sB) == 2: - sB = (sB[1], sB[0]) - elif transposed_B and len(sB) == 3: - sB = (sB[0], sB[2], sB[0]) - # this is a mess: cuBLAS expect column major, but PyTorch is row major. - # So to perform the matrix multiplication, we have to treat A, B, and C matrices - # (transpose of row major is column major) - # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these - - # matrices in the input arguments for cuBLAS - # column major: A @ B = C: [m, k] @ [k, n] = [m, n] - # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] - # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] - if len(sB) == 2: - if B.stride()[0] == B.shape[1]: - transposed_B = False - elif B.stride()[1] == B.shape[0]: - transposed_B = True - if len(A.shape) == 2: - if A.stride()[0] == A.shape[1]: - transposed_A = False - elif A.stride()[1] == A.shape[0]: - transposed_A = True - else: - if A.stride()[1] == A.shape[2]: - transposed_A = False - elif A.stride()[2] == A.shape[1]: - transposed_A = True - - if len(sA) == 2: - n = sA[0] - ldb = A.stride()[1 if transposed_A else 0] - elif len(sA) == 3 and len(sB) == 2: - n = sA[0] * sA[1] - ldb = sA[2] - - m = sB[1] - k = sB[0] - lda = B.stride()[(1 if transposed_B else 0)] - ldc = sB[1] - elif len(sB) == 3: - # special case - assert len(sA) == 3 - if not (sA[0] == sB[0] and sA[1] == sB[1]): - raise ValueError( - f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" - ) - - transposed_A = True - transposed_B = False - - m = sB[2] - n = sA[2] - k = sB[0] * sB[1] - - lda = m - ldb = sA[2] - ldc = m - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - - # B^T @ A^T = C^T - # [km, nk -> mn] - is_on_gpu([B, A, out]) - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) - return out - - -def batched_igemm( - A: Tensor, - B: Tensor, - out: Tensor = None, - transposed_A=False, - transposed_B=False, -): - if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError( - f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" - ) - sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: - out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) - - if B.is_contiguous(): - lda = B.stride()[1] - transposed_A = False - else: - s = B.stride() - if s[0] != B.shape[0]: - B = B.contiguous() - lda = B.stride()[1] - elif s[2] == B.shape[1]: - transposed_A = True - lda = B.stride()[2] - else: - if s[2] == 1: - B = B.contiguous() - lda = B.stride()[1] - elif s[1] == 1: - B = B.contiguous() - lda = B.stride()[1] - else: - B = B.contiguous() - lda = B.stride()[1] - - if A.is_contiguous(): - ldb = A.stride()[1] - transposed_B = False - else: - s = A.stride() - if s[0] != A.shape[0]: - A = A.contiguous() - ldb = A.stride()[1] - transposed_B = False - elif s[2] == A.shape[1]: - ldb = A.stride()[2] - transposed_B = True - else: - A = A.contiguous() - ldb = A.stride()[1] - transposed_B = False - - # this is a mess: cuBLAS expect column major, but PyTorch is row major. - # So to perform the matrix multiplication, we have to treat A, B, and C matrices - # (transpose of row major is column major) - # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these - # matrices in the input arguments for cuBLAS - - # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n] - # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n] - # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m] - num_batch = A.shape[0] - n = A.shape[1] - m = B.shape[2] - k = B.shape[1] - - ldc = m - - strideA = B.shape[1] * B.shape[2] - strideB = A.shape[1] * A.shape[2] - strideC = A.shape[1] * B.shape[2] - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - - is_on_gpu([B, A, out]) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) - return out - - -def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - shapeA = SA[0] - shapeB = SB[0] - dimsA = len(shapeA) - dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - - rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) - - if dimsA == 2 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) - elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) - - assert dimsB != 3, "len(B.shape)==3 not supported" - assert A.device.type == "cuda" - assert B.device.type == "cuda" - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - assert out.dtype == dtype - assert SA[1] == "col32" - assert SB[1] in ["col_turing", "col_ampere"] - assert Sout[1] == "col32" - assert ( - shapeA[-1] == shapeB[-1] - ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" - formatB = SB[1] - prev_device = A.device - torch.cuda.set_device(A.device) - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - - k = shapeA[-1] - lda = ct.c_int32(m * 32) - if formatB == "col_turing": - # turing: tiles with rows filled up to multiple of 8 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) - else: - # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - - ldc = ct.c_int32(m * 32) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - - has_error = 0 - ptrRowScale = get_ptr(None) - is_on_gpu([A, B, out]) - if formatB == 'col_turing': - if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - elif formatB == "col_ampere": - if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - - if has_error == 1: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') - - torch.cuda.set_device(prev_device) - - return out, Sout - - -def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): - assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 - out_shape = quant_state[0] - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) - - if out is None: - out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) - if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" - - prev_device = pre_call(A.device) - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNewRowStats = get_ptr(new_row_stats) - ptrNewColStats = get_ptr(new_col_stats) - ptrBias = get_ptr(bias) - numRows = ct.c_int32(out_shape[0]) - numCols = ct.c_int32(out_shape[1]) - - is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) - post_call(prev_device) - - return out - - -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): - assert A.dtype == torch.float16 - device = A.device - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - col_tiles = (cols + 255) // 256 - tiled_rows = ((rows + 15) // 16) * 16 - if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) - if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) - - if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) - - ptrA = get_ptr(A) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNnzrows = get_ptr(nnz_block_ptr) - rows = ct.c_int32(rows) - cols = ct.c_int32(cols) - - prev_device = pre_call(A.device) - is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) - lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) - post_call(prev_device) - - if threshold > 0.0: - nnz_block_ptr.cumsum_(0) - - return row_stats, col_stats, nnz_block_ptr - - -class COOSparseTensor: - def __init__(self, rows, cols, nnz, rowidx, colidx, values): - assert rowidx.dtype == torch.int32 - assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert rowidx.numel() == nnz - assert colidx.numel() == nnz - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.rowidx = rowidx - self.colidx = colidx - self.values = values - - -class CSRSparseTensor: - def __init__(self, rows, cols, nnz, rowptr, colidx, values): - assert rowptr.dtype == torch.int32 - assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert colidx.numel() == nnz - assert rowptr.numel() == rows + 1 - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.rowptr = rowptr - self.colidx = colidx - self.values = values - - -class CSCSparseTensor: - def __init__(self, rows, cols, nnz, colptr, rowidx, values): - assert colptr.dtype == torch.int32 - assert rowidx.dtype == torch.int32 - assert values.dtype == torch.float16 - assert values.numel() == nnz - assert rowidx.numel() == nnz - assert colptr.numel() == cols + 1 - - self.rows = rows - self.cols = cols - self.nnz = nnz - self.colptr = colptr - self.rowidx = rowidx - self.values = values - - -def coo2csr(cooA): - values, counts = torch.unique(cooA.rowidx, return_counts=True) - values.add_(1) - rowptr = torch.zeros( - (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device - ) - rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) - rowptr.cumsum_(0) - return CSRSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values - ) - - -def coo2csc(cooA): - val, col2rowidx = torch.sort(cooA.colidx) - rowidx = cooA.rowidx[col2rowidx] - values = cooA.values[col2rowidx] - colvalues, counts = torch.unique(val, return_counts=True) - colvalues.add_(1) - colptr = torch.zeros( - (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device - ) - colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) - colptr.cumsum_(0) - return CSCSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values - ) - - -def coo_zeros(rows, cols, nnz, device, dtype=torch.half): - rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) - values = torch.zeros((nnz,), dtype=dtype, device=device) - return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) - - -def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): - device = A.device - assert A.dtype == torch.half - assert device.type == "cuda" - prev_device = pre_call(A.device) - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) - - if out_col is None: - out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: - out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) - - coo_tensor = None - ptrA = get_ptr(A) - ptrColStats = get_ptr(col_stats) - ptrRowStats = get_ptr(row_stats) - ptrOutCol = get_ptr(out_col) - ptrOutRow = get_ptr(out_row) - - is_on_gpu([A, col_stats, row_stats, out_col, out_row]) - if threshold > 0.0: - nnz = nnz_row_ptr[-1].item() - if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) - ptrRowIdx = get_ptr(coo_tensor.rowidx) - ptrColIdx = get_ptr(coo_tensor.colidx) - ptrVal = get_ptr(coo_tensor.values) - ptrRowPtr = get_ptr(nnz_row_ptr) - - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - ptrRowIdx, - ptrColIdx, - ptrVal, - ptrRowPtr, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - val, idx = torch.sort(coo_tensor.rowidx) - coo_tensor.rowidx = val - coo_tensor.colidx = coo_tensor.colidx[idx] - coo_tensor.values = coo_tensor.values[idx] - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(0.0), - ct.c_int32(rows), - ct.c_int32(cols), - ) - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - post_call(prev_device) - - return out_row, out_col, row_stats, col_stats, coo_tensor - - -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) - - shape = state[0] - if len(shape) == 2: - dim1 = ct.c_int32(shape[0]) - dim2 = ct.c_int32(shape[1]) - else: - dim1 = ct.c_int32(shape[0] * shape[1]) - dim2 = ct.c_int32(shape[2]) - - is_on_gpu([A, out]) - if to_order == 'col32': - if transpose: - lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_turing": - if transpose: - lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_ampere": - if transpose: - lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "row": - if from_order == "col_turing": - lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == "col_ampere": - lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) - else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - post_call(prev_device) - - return out, new_state - - -def spmm_coo(cooA, B, out=None): - if out is None: - out = torch.empty( - (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype - ) - nnz = cooA.nnz - assert cooA.rowidx.numel() == nnz - assert cooA.colidx.numel() == nnz - assert cooA.values.numel() == nnz - assert cooA.cols == B.shape[0] - - transposed_B = False if B.is_contiguous() else True - - ldb = B.stride()[(1 if transposed_B else 0)] - ldc = B.shape[1] - - ptr = Cusparse_Context.get_instance().context - - ptrRowidx = get_ptr(cooA.rowidx) - ptrColidx = get_ptr(cooA.colidx) - ptrValues = get_ptr(cooA.values) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - cnnz = ct.c_int32(cooA.nnz) - crowsA = ct.c_int32(cooA.rows) - ccolsA = ct.c_int32(cooA.cols) - ccolsB = ct.c_int32(B.shape[1]) - cldb = ct.c_int32(ldb) - cldc = ct.c_int32(ldc) - - is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) - - return out - - -def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): - if out is None: - out = torch.zeros( - (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype - ) - nnz = cooA.nnz - prev_device = pre_call(B.device) - assert cooA.rowidx.numel() == nnz - assert cooA.colidx.numel() == nnz - assert cooA.values.numel() == nnz - assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}" - - transposed_B = False if B.is_contiguous() else True - - ldb = B.stride()[(1 if transposed_B else 0)] - ldc = B.shape[1] - - values, counts = torch.unique(cooA.rowidx, return_counts=True) - offset = counts.cumsum(0).int() - max_count, max_idx = torch.sort(counts, descending=True) - max_idx = max_idx.int() - max_count = max_count.int() - assert ( - max_count[0] <= 32 - ), f"Current max count per row is 8 but found {max_count[0]}." - assert B.dtype in [torch.float16, torch.int8] - ptrOffset = get_ptr(offset) - ptrMaxCount = get_ptr(max_count) - ptrMaxIdx = get_ptr(max_idx) - - ptrRowidx = get_ptr(cooA.rowidx) - ptrColidx = get_ptr(cooA.colidx) - ptrValues = get_ptr(cooA.values) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrDequantStats = get_ptr(dequant_stats) - cnnz_rows = ct.c_int32(counts.numel()) - cnnz = ct.c_int32(cooA.nnz) - crowsA = ct.c_int32(cooA.rows) - ccolsA = ct.c_int32(cooA.cols) - crowsB = ct.c_int32(B.shape[1]) - ccolsB = ct.c_int32(B.shape[1]) - cldb = ct.c_int32(ldb) - cldc = ct.c_int32(ldc) - - is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) - if B.dtype == torch.float16: - lib.cspmm_coo_very_sparse_naive_fp16( - ptrMaxCount, - ptrMaxIdx, - ptrOffset, - ptrRowidx, - ptrColidx, - ptrValues, - ptrB, - ptrC, - ptrDequantStats, - cnnz_rows, - cnnz, - crowsA, - crowsB, - ccolsB, - ) - elif B.dtype == torch.int8: - lib.cspmm_coo_very_sparse_naive_int8( - ptrMaxCount, - ptrMaxIdx, - ptrOffset, - ptrRowidx, - ptrColidx, - ptrValues, - ptrB, - ptrC, - ptrDequantStats, - cnnz_rows, - cnnz, - crowsA, - crowsB, - ccolsB, - ) - # else: assertion error - post_call(prev_device) - - return out - - -C = 127.0 - - -def vectorwise_quant(x, dim=1, quant_type="vector"): - if quant_type == "linear": - max1 = torch.abs(x).max().float() - xq = torch.round(x / max1 * 127).to(torch.int8) - return xq, max1 - elif quant_type in ["vector", "row"]: - max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x * (C / max1)).to(torch.int8) - return xq, max1 - elif quant_type == "zeropoint": - dtype = x.dtype - x = x.float() - dyna = x.max() - x.min() - if dyna == 0: - dyna = 1 - qx = 255.0 / dyna - minx = x.min() - zpx = torch.round(minx * qx) - x = torch.round(qx * x - zpx) + zpx - return x, qx - elif quant_type in ["vector-zeropoint", "row-zeropoint"]: - dtype = x.dtype - x = x.float() - dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( - x, dim=dim, keepdim=True - ) - dyna[dyna == 0] = 1 - qx = 255.0 / dyna - minx = torch.amin(x, dim=dim, keepdim=True) - zpx = torch.round(minx * qx) - x = torch.round(qx * x - zpx) + zpx - return x, qx - elif quant_type == "truncated-vector": - with torch.no_grad(): - absx = torch.abs(x) - max1 = torch.amax(absx, dim=dim, keepdim=True) - max1 = max1 * 0.7 - idx = absx > max1.expand_as(absx) - sign = torch.sign(x[idx]) - x[idx] = max1.expand_as(absx)[idx] * sign - xq = torch.round(x / max1 * C).to(torch.int8) - return xq, max1 - else: - return None - - -def vectorwise_dequant(xq, max1, quant_type="vector"): - if quant_type == "vector": - x = (xq / C * max1).to(torch.float32) - return x - else: - return None - - -def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): - if quant_type == "linear": - norm = S1 * S2 / (C * C) - # double cast needed to prevent overflows - return (xq.float() * norm).to(dtype) - elif quant_type == "zeropoint": - norm = 1.0 / (S1 * S2) - return (xq.float() * norm).to(dtype) - elif quant_type == "row-zeropoint": - norm = 1.0 / (S1 * S2) - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= norm - else: - x *= norm - return x.to(dtype) - elif quant_type == "vector-zeropoint": - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= 1.0 / S1 - else: - x *= 1.0 / S1 - x *= 1.0 / S2.t() - return x.to(dtype) - elif quant_type == "row": - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= S1 * S2 / (C * C) - else: - x *= S1 * S2 / (C * C) - return x.to(dtype) - elif quant_type in ["truncated-vector", "vector"]: - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= S1 / C - else: - x *= S1 / C - x *= S2 / C - return x.to(dtype) - else: - return None - - -def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): - offset = B.float().t().sum(0) * (SA[0] + SA[1]) - x = xq.float() - if len(xq.shape) == 2 and len(SB.shape) == 3: - SB = SB.squeeze(0) - if len(SB.shape) == 2: - x *= SB.t() / 127 - else: - x *= SB / 127 - x *= SA[1] / 127 - x += offset - return x.to(dtype) - - -def extract_outliers(A, SA, idx): - shapeA = SA[0] - formatA = SA[1] - assert formatA in ["col_turing", "col_ampere"] - assert A.device.type == "cuda" - - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) - - idx_size = ct.c_int32(idx.numel()) - rows = ct.c_int32(shapeA[0]) - cols = ct.c_int32(shapeA[1]) - ptrA = get_ptr(A) - ptrIdx = get_ptr(idx) - ptrOut = get_ptr(out) - - prev_device = pre_call(A.device) - if formatA == 'col_turing': - lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == "col_ampere": - lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - post_call(prev_device) - - return out - -def pipeline_test(A, batch_size): - out = torch.zeros_like(A) - lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) - return out diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py index 6b9a7e4d1..f262d89ed 100644 --- a/bitsandbytes/autograd/__init__.py +++ b/bitsandbytes/autograd/__init__.py @@ -1 +1 @@ -from ._functions import undo_layout, get_inverse_transform_indices +from ._functions import get_inverse_transform_indices, undo_layout diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 59b0ac7b2..e9821cd36 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,8 +1,8 @@ -import operator -import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional, List +import operator +from typing import Callable, Optional, Tuple +import warnings from warnings import warn import torch @@ -14,19 +14,18 @@ def prod(iterable): return reduce(operator.mul, iterable, 1) -tensor = torch.Tensor - # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py - """ This class pools outlier dimensions across layers. This is particularly important for small models where outlier features are less systematic and occur with low frequency. """ + + class GlobalOutlierPooler: _instance = None @@ -56,7 +55,10 @@ def get_current_outlier_idx(self): return torch.Tensor(list(self.outliers)).to(torch.int64) -def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]): +def get_inverse_transform_indices( + transform_tile: Callable[[torch.Tensor], torch.Tensor], + tile_size: Tuple[int, int], +): """ Compute a permutation of indices that invert the specified (tiled) matrix transformation @@ -83,6 +85,7 @@ def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int break # if all indices fit in i bytes, stop early return permuted_tile_indices + def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: """ Undo a tiled permutation such as turing or ampere layout @@ -159,20 +162,12 @@ def backward(ctx, grad_output): ) if not A.is_contiguous(): A = A.contiguous() - qA, S2 = F.vectorwise_quant( - A.view(-1, A.shape[2]), dim=0, quant_type=quant_type - ) + qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant( - igrad_B, S2.t(), S1, grad_output.dtype, quant_type - ) + grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) else: - qgrad_output, S1 = F.vectorwise_quant( - grad_output, dim=dims, quant_type=quant_type - ) - qA, S2 = F.vectorwise_quant( - A, dim=dims, quant_type=quant_type - ) + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) grad_B = F.vectorwise_mm_dequant( igrad_B, @@ -201,9 +196,7 @@ def backward(ctx, grad_output): with torch.no_grad(): grad_A = torch.matmul(grad_output, B.permute(permute_dim)) else: - qgrad_output, S1 = F.vectorwise_quant( - grad_output, dim=dims, quant_type=quant_type - ) + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) grad_A = F.vectorwise_mm_dequant( @@ -224,12 +217,10 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" - if torch.version.hip: - return True if torch.cuda.get_device_capability(device=device) < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) - nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series + nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores return True @@ -248,6 +239,7 @@ def get_tile_inds(format, device): with torch.no_grad(): return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device) + @dataclass class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None @@ -498,7 +490,7 @@ class MatMul4Bit(torch.autograd.Function): # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @staticmethod - def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None): + def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None): # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -512,7 +504,6 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None): else: return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) - # 1. Dequantize # 2. MatmulnN output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias) @@ -534,7 +525,7 @@ def backward(ctx, grad_output): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad + req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad A, B = ctx.tensors grad_A, grad_B, grad_bias = None, None, None @@ -544,19 +535,20 @@ def backward(ctx, grad_output): grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) # not supported by PyTorch. TODO: create work-around - #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) - if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) + # if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: + grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()) return grad_A, grad_B, None, grad_bias, None def matmul( - A: tensor, - B: tensor, - out: tensor = None, - state: MatmulLtState = None, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None + bias=None, ): state = state or MatmulLtState() if threshold > 0.0: @@ -564,11 +556,19 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) -def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None): +def matmul_4bit( + A: torch.Tensor, + B: torch.Tensor, + quant_state: F.QuantState, + out: Optional[torch.Tensor] = None, + bias=None, +): assert quant_state is not None if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: - warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}') + warn( + f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", + ) return MatMul4Bit.apply(A, B, out, bias, quant_state) else: out = F.gemv_4bit(A, B.t(), out, state=quant_state) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 03a208995..4bd698950 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,50 +1,124 @@ +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? +- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" + import ctypes as ct +import logging import os +from pathlib import Path + import torch -from pathlib import Path -from warnings import warn +from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs -from bitsandbytes.cuda_setup.main import CUDASetup +logger = logging.getLogger(__name__) -setup = CUDASetup.get_instance() -if setup.initialized != True: - setup.run_cuda_setup() +def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: + """ + Get the disk path to the CUDA BNB native library specified by the + given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable. -lib = setup.lib -try: - if lib is None and torch.cuda.is_available() : - CUDASetup.get_instance().generate_instructions() - CUDASetup.get_instance().print_log_stack() - raise RuntimeError(''' - CUDA Setup failed despite GPU being available. Please run the following command to get more information: - - python -m bitsandbytes - - Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them - to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes - and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') - - lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False - lib.get_context.restype = ct.c_void_p - - HIP_ENVIRONMENT = False - if torch.version.cuda: + The library is not guaranteed to exist at the returned path. + """ + library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" + if not cuda_specs.has_cublaslt: + # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt + library_name += "_nocublaslt" + library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}" + library_name = f"libbitsandbytes_hip_nohipblaslt{DYNAMIC_LIBRARY_SUFFIX}" + override_value = os.environ.get("BNB_CUDA_VERSION") + if override_value: + library_name_stem, _, library_name_ext = library_name.rpartition(".") + # `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`; + # let's remove any trailing numbers: + library_name_stem = library_name_stem.rstrip("0123456789") + # `library_name_stem` will now be e.g. `libbitsandbytes_cuda`; + # let's tack the new version number and the original extension back on. + library_name = f"{library_name_stem}{override_value}.{library_name_ext}" + logger.warning( + f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" + "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" + "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" + "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" + "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLibrary: + binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" + cuda_specs = get_cuda_specs() + if cuda_specs: + cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) + if cuda_binary_path.exists(): + binary_path = cuda_binary_path + else: + logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path) + logger.debug(f"Loading bitsandbytes native library from: {binary_path}") + dll = ct.cdll.LoadLibrary(str(binary_path)) + + if hasattr(dll, "get_context"): # only a CUDA-built library exposes this + return CudaBNBNativeLibrary(dll) + + logger.warning( + "The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.", + ) + return BNBNativeLibrary(dll) + + +try: + lib = get_native_library() +except Exception as e: + lib = None + logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True) + if torch.cuda.is_available(): + logger.warning( + """ +CUDA Setup failed despite CUDA being available. Please run the following command to get more information: + +python -m bitsandbytes + +Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them +to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes +and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues +""", + ) diff --git a/bitsandbytes/consts.py b/bitsandbytes/consts.py new file mode 100644 index 000000000..8242d104e --- /dev/null +++ b/bitsandbytes/consts.py @@ -0,0 +1,12 @@ +from pathlib import Path +import platform + +DYNAMIC_LIBRARY_SUFFIX = { + "Darwin": ".dylib", + "Linux": ".so", + "Windows": ".dll", +}.get(platform.system(), ".so") + +PACKAGE_DIR = Path(__file__).parent +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" +NONPYTORCH_DOC_URL = "https://github.com/TimDettmers/bitsandbytes/blob/main/docs/source/nonpytorchcuda.mdx" diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py deleted file mode 100644 index e8268fcaa..000000000 --- a/bitsandbytes/cuda_setup/env_vars.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from typing import Dict - - -def to_be_ignored(env_var: str, value: str) -> bool: - ignorable = { - "PWD", # PWD: this is how the shell keeps track of the current working dir - "OLDPWD", - "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated - "SSH_TTY", - "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks - "HOME", # Linux shell default - "TMUX", # Terminal Multiplexer - "XDG_DATA_DIRS", # XDG: Desktop environment stuff - "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff - "XDG_RUNTIME_DIR", - "MAIL", # something related to emails - "SHELL", # binary for currently invoked shell - "DBUS_SESSION_BUS_ADDRESS", # hardware related - "PATH", # this is for finding binaries, not libraries - "LESSOPEN", # related to the `less` command - "LESSCLOSE", - "_", # current Python interpreter - } - return env_var in ignorable - - -def might_contain_a_path(candidate: str) -> bool: - return "/" in candidate - - -def is_active_conda_env(env_var: str) -> bool: - return "CONDA_PREFIX" == env_var - - -def is_other_conda_env_var(env_var: str) -> bool: - return "CONDA" in env_var - - -def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: - return is_active_conda_env(env_var) or ( - might_contain_a_path(value) and not - is_other_conda_env_var(env_var) and not - to_be_ignored(env_var, value) - ) - - -def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: - return { - env_var: value - for env_var, value in os.environ.items() - if is_relevant_candidate_env_var(env_var, value) - } diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py deleted file mode 100644 index b4962c1a0..000000000 --- a/bitsandbytes/cuda_setup/main.py +++ /dev/null @@ -1,377 +0,0 @@ -""" -extract factors the build is dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? -- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - -import ctypes as ct -import os -import errno -import torch -from warnings import warn -from itertools import product - -from pathlib import Path -from typing import Set, Union -from .env_vars import get_potentially_lib_path_containing_env_vars - -# these are the most common libs names -# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead -# we have libcudart.so.11.0 which causes a lot of errors before -# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt -CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2'] - -# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths -backup_paths = [] -backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0') - -class CUDASetup: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def generate_instructions(self): - if getattr(self, 'error', False): return - print(self.error) - self.error = True - if not self.cuda_available: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed.') - self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') - self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:') - self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a') - self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)') - return - - if self.cudart_path is None: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.') - self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable') - self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a') - self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.') - self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh') - self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.') - self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local') - - return - - make_cmd = f'CUDA_VERSION={self.cuda_version_string}' - if len(self.cuda_version_string) < 3: - make_cmd += ' make cuda92' - elif self.cuda_version_string == '110': - make_cmd += ' make cuda110' - elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0: - make_cmd += ' make cuda11x' - elif self.cuda_version_string[:2] == '12' and 1 >= int(self.cuda_version_string[2]) >= 0: - make_cmd += ' make cuda12x' - elif self.cuda_version_string == '100': - self.add_log_entry('CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.') - self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.') - return - - - has_cublaslt = is_cublasLt_compatible(self.cc) - if not has_cublaslt: - make_cmd += '_nomatmul' - - self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:') - self.add_log_entry('git clone https://github.com/TimDettmers/bitsandbytes.git') - self.add_log_entry('cd bitsandbytes') - self.add_log_entry(make_cmd) - self.add_log_entry('python setup.py install') - - def initialize(self): - if not getattr(self, 'initialized', False): - self.has_printed = False - self.lib = None - self.initialized = False - self.error = False - - def manual_override(self): - if torch.cuda.is_available(): - if 'BNB_CUDA_VERSION' in os.environ: - if len(os.environ['BNB_CUDA_VERSION']) > 0: - warn((f'\n\n{"="*80}\n' - 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' - 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' - 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' - 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' - 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Set[Path]: - return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path} - - -def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: - existent_directories: Set[Path] = set() - for path in candidate_paths: - try: - if path.exists(): - existent_directories.add(path) - except PermissionError as pex: - # Handle the PermissionError first as it is a subtype of OSError - # https://docs.python.org/3/library/exceptions.html#exception-hierarchy - pass - except OSError as exc: - if exc.errno != errno.ENAMETOOLONG: - raise exc - - non_existent_directories: Set[Path] = candidate_paths - existent_directories - if non_existent_directories: - CUDASetup.get_instance().add_log_entry("The following directories listed in your path were found to " - f"be non-existent: {non_existent_directories}", is_warning=False) - - return existent_directories - - -def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: - paths = set() - for libname in CUDA_RUNTIME_LIBS: - for path in candidate_paths: - try: - if (path / libname).is_file(): - paths.add(path / libname) - except PermissionError: - pass - return paths - - -def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: - """ - Searches a given environmental var for the CUDA runtime library, - i.e. `libcudart.so`. - """ - return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate)) - - -def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: - return get_cuda_runtime_lib_paths( - resolve_paths_list(paths_list_candidate) - ) - - -def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: - if len(results_paths) > 1: - warning_msg = ( - f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. " - "We select the PyTorch default libcudart.so, which is {torch.version.cuda}," - "but this might missmatch with the CUDA version that is needed for bitsandbytes." - "To override this behavior set the BNB_CUDA_VERSION= environmental variable" - "For example, if you want to use the CUDA version 122" - "BNB_CUDA_VERSION=122 python ..." - "OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122" - "In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g." - "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2") - CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) - - -def determine_cuda_runtime_lib_path() -> Union[Path, None]: - """ - Searches for a cuda installations, in the following order of priority: - 1. active conda env - 2. LD_LIBRARY_PATH - 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) - - don't contain the path separator `/` - - If multiple libraries are found in part 3, we optimistically try one, - while giving a warning message. - """ - candidate_env_vars = get_potentially_lib_path_containing_env_vars() - - cuda_runtime_libs = set() - if "CONDA_PREFIX" in candidate_env_vars: - conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" - - conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) - warn_in_case_of_duplicates(conda_cuda_libs) - - if conda_cuda_libs: - cuda_runtime_libs.update(conda_cuda_libs) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) - - if "LD_LIBRARY_PATH" in candidate_env_vars: - lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) - - if lib_ld_cuda_libs: - cuda_runtime_libs.update(lib_ld_cuda_libs) - warn_in_case_of_duplicates(lib_ld_cuda_libs) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) - - remaining_candidate_env_vars = { - env_var: value for env_var, value in candidate_env_vars.items() - if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} - } - - cuda_runtime_libs = set() - for env_var, value in remaining_candidate_env_vars.items(): - cuda_runtime_libs.update(find_cuda_lib_in(value)) - - if len(cuda_runtime_libs) == 0: - CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...') - cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) - - warn_in_case_of_duplicates(cuda_runtime_libs) - - cuda_setup = CUDASetup.get_instance() - cuda_setup.add_log_entry(f'DEBUG: Possible options found for libcudart.so: {cuda_runtime_libs}') - - return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None - - -# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION -def get_cuda_version(): - major, minor = map(int, torch.version.cuda.split(".")) - - if major < 11: - CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') - - return f'{major}{minor}' - -def get_compute_capabilities(): - ccs = [] - for i in range(torch.cuda.device_count()): - cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i)) - ccs.append(f"{cc_major}.{cc_minor}") - - ccs.sort(key=lambda v: tuple(map(int, str(v).split(".")))) - - return ccs - - -def evaluate_cuda_setup(): - cuda_setup = CUDASetup.get_instance() - if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - cuda_setup.add_log_entry('') - cuda_setup.add_log_entry('='*35 + 'BUG REPORT' + '='*35) - cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'), - ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')) - cuda_setup.add_log_entry('='*80) - if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None - if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None - - cudart_path = determine_cuda_runtime_lib_path() - ccs = get_compute_capabilities() - ccs.sort() - cc = ccs[-1] # we take the highest capability - cuda_version_string = get_cuda_version() - - cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.") - cuda_setup.add_log_entry(f"CUDA SETUP: To manually override the PyTorch CUDA version please see:" - "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md") - - - # 7.5 is the minimum CC vor cublaslt - has_cublaslt = is_cublasLt_compatible(cc) - - # TODO: - # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) - # (2) Multiple CUDA versions installed - - # we use ls -l instead of nvcc to determine the cuda version - # since most installations will have the libcudart.so installed, but not the compiler - - if has_cublaslt: - binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so" - else: - "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" - binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so" - - return binary_name, cudart_path, cc, cuda_version_string diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py new file mode 100644 index 000000000..936037a51 --- /dev/null +++ b/bitsandbytes/cuda_specs.py @@ -0,0 +1,44 @@ +import dataclasses +from typing import List, Optional, Tuple + +import torch + + +@dataclasses.dataclass(frozen=True) +class CUDASpecs: + highest_compute_capability: Tuple[int, int] + cuda_version_string: str + cuda_version_tuple: Tuple[int, int] + + @property + def has_cublaslt(self) -> bool: + return self.highest_compute_capability >= (7, 5) + + +def get_compute_capabilities() -> List[Tuple[int, int]]: + return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count())) + + +def get_cuda_version_tuple() -> Tuple[int, int]: + # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION + #major, minor = map(int, torch.version.cuda.split(".")) + #Current ROCM version + ROCM_version = '11.0' + major, minor = map(int, ROCM_version.split(".")) + return major, minor + + +def get_cuda_version_string() -> str: + major, minor = get_cuda_version_tuple() + return f"{major}{minor}" + + +def get_cuda_specs() -> Optional[CUDASpecs]: + if not torch.cuda.is_available(): + return None + + return CUDASpecs( + highest_compute_capability=(get_compute_capabilities()[-1]), + cuda_version_string=(get_cuda_version_string()), + cuda_version_tuple=get_cuda_version_tuple(), + ) diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/diagnostics/__init__.py similarity index 100% rename from bitsandbytes/cuda_setup/__init__.py rename to bitsandbytes/diagnostics/__init__.py diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py new file mode 100644 index 000000000..f993dff7e --- /dev/null +++ b/bitsandbytes/diagnostics/cuda.py @@ -0,0 +1,176 @@ +import logging +import os +from pathlib import Path +from typing import Dict, Iterable, Iterator + +import torch + +from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.consts import NONPYTORCH_DOC_URL +from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.diagnostics.utils import print_dedented + +CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") + +CUDART_PATH_IGNORED_ENVVARS = { + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks + "HOME", # Linux shell default + "LESSCLOSE", + "LESSOPEN", # related to the `less` command + "MAIL", # something related to emails + "OLDPWD", + "PATH", # this is for finding binaries, not libraries + "PWD", # PWD: this is how the shell keeps track of the current working dir + "SHELL", # binary for currently invoked shell + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "_", # current Python interpreter +} + +CUDA_RUNTIME_LIB_PATTERNS = ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows +) + +logger = logging.getLogger(__name__) + + +def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]: + for dir_string in paths_list_candidate.split(os.pathsep): + if not dir_string: + continue + if os.sep not in dir_string: + continue + try: + dir = Path(dir_string) + try: + if not dir.exists(): + logger.warning(f"The directory listed in your path is found to be non-existent: {dir}") + continue + except OSError: # Assume an esoteric error trying to poke at the directory + pass + for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS: + for pth in dir.glob(lib_pattern): + if pth.is_file(): + yield pth + except PermissionError: + pass + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return ( + env_var in CUDART_PATH_PREFERRED_ENVVARS # is a preferred location + or ( + os.sep in value # might contain a path + and env_var not in CUDART_PATH_IGNORED_ENVVARS # not ignored + and "CONDA" not in env_var # not another conda envvar + and "BASH_FUNC" not in env_var # not a bash function defined via envvar + and "\n" not in value # likely e.g. a script or something? + ) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)} + + +def find_cudart_libraries() -> Iterator[Path]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + for envvar in CUDART_PATH_PREFERRED_ENVVARS: + if envvar in candidate_env_vars: + directory = candidate_env_vars[envvar] + yield from find_cuda_libraries_in_path_list(directory) + candidate_env_vars.pop(envvar) + + for env_var, value in candidate_env_vars.items(): + yield from find_cuda_libraries_in_path_list(value) + + +def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: + print( + f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " + f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", + ) + + binary_path = get_cuda_bnb_library_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. Maybe you need to compile it from source? + If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, + for example, `make CUDA_VERSION=113`. + + The CUDA version for the compile might depend on your conda install, if using conda. + Inspect CUDA version via `conda list | grep cuda`. + """, + ) + + cuda_major, cuda_minor = cuda_specs.cuda_version_tuple + if cuda_major < 11: + print_dedented( + """ + WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). + You will be only to use 8-bit optimizers and quantization routines! + """, + ) + + print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") + + # 7.5 is the minimum CC for cublaslt + if not cuda_specs.has_cublaslt: + print_dedented( + """ + WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! + If you run into issues with 8-bit matmul, you can try 4-bit quantization: + https://huggingface.co/blog/4bit-transformers-bitsandbytes + """, + ) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + +def print_cuda_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate CUDA runtime files (see below). + + We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, + but this might mismatch with the CUDA version that is needed for bitsandbytes. + To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. + + For example, if you want to use the CUDA version 122, + BNB_CUDA_VERSION=122 python ... + + OR set the environmental variable in your .bashrc: + export BNB_CUDA_VERSION=122 + + In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, + """, + ) + for pth in cudart_paths: + print(f"* Found CUDA runtime at: {pth}") diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py new file mode 100644 index 000000000..1ce096f69 --- /dev/null +++ b/bitsandbytes/diagnostics/main.py @@ -0,0 +1,85 @@ +import sys +import traceback + +import torch + +from bitsandbytes.consts import PACKAGE_GITHUB_URL +from bitsandbytes.cuda_specs import get_cuda_specs +from bitsandbytes.diagnostics.cuda import ( + print_cuda_diagnostics, + print_cuda_runtime_diagnostics, +) +from bitsandbytes.diagnostics.utils import print_dedented, print_header + + +def sanity_check(): + from bitsandbytes.cextension import lib + + if lib is None: + print_dedented( + """ + Couldn't load the bitsandbytes library, likely due to missing binaries. + Please ensure bitsandbytes is properly installed. + + For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`. + See the documentation for more details if needed. + + Trying a simple check anyway, but this will likely fail... + """, + ) + + from bitsandbytes.optim import Adam + + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() + p1 = p.data.sum().item() + adam = Adam([p]) + out = a * p + loss = out.sum() + loss.backward() + adam.step() + p2 = p.data.sum().item() + assert p1 != p2 + + +def main(): + print_header("") + print_header("BUG REPORT INFORMATION") + print_header("") + + print_header("OTHER") + cuda_specs = get_cuda_specs() + print("CUDA specs:", cuda_specs) + if not torch.cuda.is_available(): + print("Torch says CUDA is not available. Possible reasons:") + print("1. CUDA driver not installed") + print("2. CUDA not installed") + print("3. You have multiple conflicting CUDA libraries") + if cuda_specs: + print_cuda_diagnostics(cuda_specs) + print_cuda_runtime_diagnostics() + print_header("") + print_header("DEBUG INFO END") + print_header("") + print("Checking that the library is importable and CUDA is callable...") + try: + sanity_check() + print("SUCCESS!") + print("Installation was successful!") + return + except ImportError: + print( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!", + ) + except Exception: + traceback.print_exc() + print_dedented( + f""" + Above we output some debug information. + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose + WARNING: Please be sure to sanitize sensitive info from the output before posting it. + """, + ) + sys.exit(1) diff --git a/bitsandbytes/diagnostics/utils.py b/bitsandbytes/diagnostics/utils.py new file mode 100644 index 000000000..770209b9d --- /dev/null +++ b/bitsandbytes/diagnostics/utils.py @@ -0,0 +1,12 @@ +import textwrap + +HEADER_WIDTH = 60 + + +def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "+") -> None: + txt = f" {txt} " if txt else "" + print(txt.center(width, filler)) + + +def print_dedented(text): + print("\n".join(textwrap.dedent(text).strip().split("\n"))) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a4e93bf37..c4f5920e9 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -3,97 +3,106 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct +from functools import reduce # Required in Python 3 import itertools import operator -import random -import torch -import itertools -import math -import numpy as np +from typing import Any, Dict, Optional, Tuple -from functools import reduce # Required in Python 3 -from typing import Tuple, Any, Dict +import numpy as np +import torch from torch import Tensor -from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import COMPILED_WITH_CUDA, lib, HIP_ENVIRONMENT +from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -# Remark: for AMD GPU we need to disable blocksize == 64 +from .cextension import lib # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) + name2qmap = {} -if COMPILED_WITH_CUDA: +if lib and lib.compiled_with_cuda: """C FUNCTIONS FOR OPTIMIZERS""" - str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16) - str2optimizer32bit["momentum"] = ( - lib.cmomentum32bit_grad_32, - lib.cmomentum32bit_grad_16, - ) - str2optimizer32bit["rmsprop"] = ( - lib.crmsprop32bit_grad_32, - lib.crmsprop32bit_grad_16, - ) - str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16) - str2optimizer32bit["adagrad"] = ( - lib.cadagrad32bit_grad_32, - lib.cadagrad32bit_grad_16, - ) + str2optimizer32bit = { + "adam": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "momentum": ( + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, + ), + "rmsprop": ( + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, + ), + "lion": ( + lib.clion32bit_grad_fp32, + lib.clion32bit_grad_fp16, + lib.clion32bit_grad_bf16, + ), + "adagrad": ( + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, + ), + } + + str2optimizer8bit = { + "adam": ( + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, + ), + "momentum": ( + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, + ), + "rmsprop": ( + lib.crmsprop_static_8bit_grad_32, + lib.crmsprop_static_8bit_grad_16, + ), + "lion": ( + lib.clion_static_8bit_grad_32, + lib.clion_static_8bit_grad_16, + ), + "lamb": ( + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, + ), + "lars": ( + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, + ), + } + + str2optimizer8bit_blockwise = { + "adam": ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ), + "momentum": ( + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, + ), + "rmsprop": ( + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, + ), + "lion": ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, + ), + "adagrad": ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + ), + } - str2optimizer8bit = {} - str2optimizer8bit["adam"] = ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ) - str2optimizer8bit["momentum"] = ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ) - str2optimizer8bit["rmsprop"] = ( - lib.crmsprop_static_8bit_grad_32, - lib.crmsprop_static_8bit_grad_16, - ) - str2optimizer8bit["lion"] = ( - lib.clion_static_8bit_grad_32, - lib.clion_static_8bit_grad_16, - ) - str2optimizer8bit["lamb"] = ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ) - str2optimizer8bit["lars"] = ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, - ) - - str2optimizer8bit_blockwise = {} - str2optimizer8bit_blockwise["adam"] = ( - lib.cadam_8bit_blockwise_grad_fp32, - lib.cadam_8bit_blockwise_grad_fp16, - lib.cadam_8bit_blockwise_grad_bf16, - ) - str2optimizer8bit_blockwise["momentum"] = ( - lib.cmomentum_8bit_blockwise_grad_fp32, - lib.cmomentum_8bit_blockwise_grad_fp16, - ) - str2optimizer8bit_blockwise["rmsprop"] = ( - lib.crmsprop_8bit_blockwise_grad_fp32, - lib.crmsprop_8bit_blockwise_grad_fp16, - ) - str2optimizer8bit_blockwise["lion"] = ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - lib.clion_8bit_blockwise_grad_bf16, - ) - str2optimizer8bit_blockwise["adagrad"] = ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - ) class GlobalPageManager: _instance = None @@ -112,14 +121,13 @@ def get_instance(cls): return cls._instance def prefetch_all(self, to_cpu=False): - # assume the first added, will be hte + # assume the first added, will be the # ones that are used first, so swap them in last # in the case they are evicted again for t in self.paged_tensors[::-1]: prefetch_tensor(t, to_cpu) - class CUBLAS_Context: _instance = None @@ -152,11 +160,7 @@ def __init__(self): raise RuntimeError("Call get_instance() instead") def initialize(self): - #self.context = ct.c_void_p(lib.get_cusparse()) - if torch.version.cuda: - self.context = ct.c_void_p(lib.get_cusparse()) - elif torch.version.hip: - self.context = ct.c_void_p(lib.get_hipsparse()) + self.context = ct.c_void_p(lib.get_cusparse()) @classmethod def get_instance(cls): @@ -165,6 +169,7 @@ def get_instance(cls): cls._instance.initialize() return cls._instance + dtype2bytes = {} dtype2bytes[torch.float32] = 4 dtype2bytes[torch.float16] = 2 @@ -172,8 +177,11 @@ def get_instance(cls): dtype2bytes[torch.uint8] = 1 dtype2bytes[torch.int8] = 1 -def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): - num_bytes = dtype2bytes[dtype]*prod(shape) +FIRST_CUDA_DEVICE = torch.device("cuda", index=0) + + +def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): + num_bytes = dtype2bytes[dtype] * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -182,74 +190,92 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)) out.page_deviceid = device.index return out + def prefetch_tensor(A, to_cpu=False): - assert A.is_paged, 'Only paged tensors can be prefetched!' + assert A.is_paged, "Only paged tensors can be prefetched!" if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype]*A.numel() + num_bytes = dtype2bytes[A.dtype] * A.numel() lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + def elementwise_func(func_name, A, B, value, prefetch=True): func = None if A.dtype == torch.float32: - func = getattr(lib, f'c{func_name}_fp32', None) + func = getattr(lib, f"c{func_name}_fp32", None) cvalue = ct.c_float(value) elif A.dtype == torch.uint8: - func = getattr(lib, f'c{func_name}_uint8', None) + func = getattr(lib, f"c{func_name}_uint8", None) cvalue = ct.c_uint8(value) - if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + if func is None: + raise NotImplementedError(f"Function not implemented: {func_name}") - is_managed = getattr(A, 'is_managed', False) + is_managed = getattr(A, "is_managed", False) if is_managed and prefetch: prefetch_tensor(A) - if B is not None: prefetch_tensor(B) + if B is not None: + prefetch_tensor(B) func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) if A.is_paged or B.is_paged: # paged function are fully asynchronous # if we return from this function, we want to the tensor # to be in the correct state, that is the final state after the - # operation occured. So we synchronize. + # operation occurred. So we synchronize. torch.cuda.synchronize() -def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) -def arange(A, device=None): elementwise_func('arange', A, None, 0) -def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + +def fill(A, value, device=None, prefetch=True): + elementwise_func("fill", A, None, value) + + +def arange(A, device=None): + elementwise_func("arange", A, None, 0) + + +def _mul(A, B, device=None): + elementwise_func("_mul", A, B, 0) def create_linear_map(signed=True, total_bits=8, add_zero=True): - sign = (-1.0 if signed else 0.0) + sign = -1.0 if signed else 0.0 total_values = 2**total_bits if add_zero or total_bits < 8: # add a zero # since we simulate less bits by having zeros in the data type, we # we need to center the quantization around zero and as such lose # a single value - total_values = (2**total_bits if not signed else 2**total_bits-1) + total_values = 2**total_bits if not signed else 2**total_bits - 1 values = torch.linspace(sign, 1.0, total_values) gap = 256 - values.numel() if gap == 0: return values else: - l = values.numel()//2 - return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) + l = values.numel() // 2 # noqa: E741 + return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) + def create_normal_map(offset=0.9677083, use_extra_value=True): - from scipy.stats import norm + try: + from scipy.stats import norm + except ImportError as ie: + raise ImportError( + "Scipy is required for `create_normal_map`. Install `bitsandbytes` with the `[test]` extra.", + ) from ie if use_extra_value: # one more positive value, this is an asymmetric type v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() - v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() else: v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() - v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() v = v1 + v2 + v3 @@ -262,38 +288,37 @@ def create_normal_map(offset=0.9677083, use_extra_value=True): return values + def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits has_sign = 1 if signed else 0 - assert e+p == total_bits-has_sign + assert e + p == total_bits - has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): + for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)): evalues.append(2**val) - values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) - #for ev in evalues: - bias = 2**(exponent_bits-1) - for evalue in range(2**(exponent_bits)): + # for ev in evalues: + bias = 2 ** (exponent_bits - 1) + for evalue in range(2 ** (exponent_bits)): for bit_pattern in lst: - value = (1 if evalue != 0 else 0) + value = 1 if evalue != 0 else 0 for i, pval in enumerate(list(bit_pattern)): - value += pval*(2**-(i+1)) + value += pval * (2 ** -(i + 1)) if evalue == 0: # subnormals - value = value*2**-(bias) + value = value * 2**-(bias) else: # normals - value = value*2**-(evalue-bias-1) + value = value * 2 ** -(evalue - bias - 1) values.append(value) if signed: values.append(-value) - assert len(values) == 2**total_bits values.sort() if total_bits < 8: @@ -307,7 +332,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) return code - def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. @@ -332,7 +356,11 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): non_sign_bits = total_bits - (1 if signed else 1) additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 for i in range(max_exponent_bits): - fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) + fraction_items = int( + 2 ** (i + non_sign_bits - max_exponent_bits) + 1 + if signed + else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1, + ) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() @@ -358,8 +386,9 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): data.sort() return Tensor(data) + def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits-1) + q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) q = q.tolist() q.append(0) @@ -370,11 +399,13 @@ def create_quantile_map(A, total_bits=8): q.sort() q = Tensor(q) - q = q/q.abs().max() + q = q / q.abs().max() return q + def get_special_format_str(): - if not torch.cuda.is_available(): return 'col_turing' + if not torch.cuda.is_available(): + return "col_turing" major, _minor = torch.cuda.get_device_capability() if major <= 7: return "col_turing" @@ -383,23 +414,28 @@ def get_special_format_str(): return "col_turing" - def is_on_gpu(tensors): on_gpu = True gpu_ids = set() for t in tensors: - if t is None: continue # NULL pointers are fine - is_paged = getattr(t, 'is_paged', False) - on_gpu &= (t.device.type == 'cuda' or is_paged) + if t is None: + continue # NULL pointers are fine + is_paged = getattr(t, "is_paged", False) + on_gpu &= t.device.type == "cuda" or is_paged if not is_paged: gpu_ids.add(t.device.index) if not on_gpu: - raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}", + ) if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') + raise TypeError( + f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}", + ) return on_gpu -def get_ptr(A: Tensor) -> ct.c_void_p: + +def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: """ Get the ctypes pointer from a PyTorch Tensor. @@ -433,15 +469,13 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False): if not hasattr(lib, name): print(name) raise ValueError( - f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}", ) else: return getattr(lib, name) -def get_transform_buffer( - shape, dtype, device, to_order, from_order="row", transpose=False -): +def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): # init_func = torch.empty init_func = torch.zeros dims = len(shape) @@ -452,13 +486,13 @@ def get_transform_buffer( rows = shape[0] * shape[1] cols = shape[-1] + state = (shape, to_order) if transpose: # swap dims tmp = rows rows = cols cols = tmp - shape = shape[::-1] - state = (shape, to_order) + state = (shape[::-1], to_order) if to_order == "row" or to_order == "col": return init_func(shape, dtype=dtype, device=device), state @@ -489,18 +523,12 @@ def nvidia_transform( state=None, ld=None, ): - if HIP_ENVIRONMENT: - to_order = "col" if to_order in ["col32","col_turing","col_ampere"] else to_order - from_order = "col" if from_order in ["col32","col_turing","col_ampere"] else from_order - if state is None: state = (A.shape, from_order) else: from_order = state[1] if out is None: - out, new_state = get_transform_buffer( - state[0], A.dtype, A.device, to_order, state[1], transpose - ) + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) else: new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) @@ -524,8 +552,13 @@ def nvidia_transform( return out, new_state -def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: - ''' +def estimate_quantiles( + A: Tensor, + out: Optional[torch.Tensor] = None, + offset: float = 1 / 512, + num_quantiles=256, +) -> Tensor: + """ Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -552,14 +585,21 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n ------- torch.Tensor: The 256 quantiles in float32 datatype. - ''' - if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') - if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") - if num_quantiles < 256 and offset == 1/(512): + """ + if A.numel() < 256: + raise NotImplementedError( + f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.", + ) + if num_quantiles > 256: + raise NotImplementedError( + f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}", + ) + if num_quantiles < 256 and offset == 1 / (512): # override default arguments - offset = 1/(2*num_quantiles) + offset = 1 / (2 * num_quantiles) - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + if out is None: + out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) device = pre_call(A.device) if A.dtype == torch.float32: @@ -571,7 +611,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n post_call(device) if num_quantiles < 256: - step = round(256/num_quantiles) + step = round(256 / num_quantiles) idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) out = out[idx] @@ -579,13 +619,36 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n class QuantState: - """container for quantization state components to work with Params4bit and similar clases""" - valid_quant_types = ('fp4', 'nf4') - valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', - 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] + """container for quantization state components to work with Params4bit and similar classes""" - def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): + valid_quant_types = ("fp4", "nf4") + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = [ + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): self.absmax = absmax self.shape = shape self.code = code @@ -604,13 +667,20 @@ def __get_item__(self, idx): state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] """ if self.nested: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] else: list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] return list_repr[idx] @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": """ unpacks components of state_dict into QuantState where necessary, convert into strings, torch.dtype, ints, etc. @@ -622,37 +692,39 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState # unpacking tensor with non-tensor components qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and 'quant_type' not in qs_dict: + if not len(qs_key) and "quant_type" not in qs_dict: raise ValueError("Expected packed or unpacked quant_state items, found neither") elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) # unpacking minor and non-tensor quant state items if necessary if len(qs_key) == 1: - qs_key = qs_key[0] - qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key))) + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - if 'nested_absmax' in qs_dict: - offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) + if "nested_absmax" in qs_dict: + offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) state2 = cls( - absmax=qs_dict['nested_absmax'].to(device), - blocksize=qs_dict['nested_blocksize'], - code=qs_dict['nested_quant_map'].to(device), - dtype=getattr(torch, qs_dict['nested_dtype']), + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), ) else: offset, state2 = None, None quant_state = cls( - quant_type=qs_dict['quant_type'], - absmax=qs_dict['absmax'].to(device), - blocksize=qs_dict['blocksize'], - code=qs_dict['quant_map'].to(device), - dtype=getattr(torch, qs_dict['dtype']), - shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, offset=offset, state2=state2, ) @@ -664,21 +736,23 @@ def as_dict(self, packed=False): param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving """ qs_dict = { - 'quant_type': self.quant_type, - 'absmax': self.absmax, - 'blocksize': self.blocksize, - 'quant_map': self.code, - 'dtype': str(self.dtype).strip('torch.'), - 'shape': tuple(self.shape), + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("torch."), + "shape": tuple(self.shape), } if self.nested: - qs_dict.update({ - 'nested_absmax': self.state2.absmax, - 'nested_blocksize': self.state2.blocksize, - 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - 'nested_dtype': str(self.state2.dtype).strip('torch.'), - 'nested_offset': self.offset.item(), - }) + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + "nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("torch."), + "nested_offset": self.offset.item(), + }, + ) if not packed: return qs_dict @@ -696,8 +770,38 @@ def to(self, device): self.state2.absmax = self.state2.absmax.to(device) self.state2.code = self.state2.code.to(device) + def __eq__(self, other): + if not isinstance(other, QuantState): + return False + + return ( + torch.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and torch.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) + ) + -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: +def quantize_blockwise( + A: Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> Tuple[Tensor, QuantState]: """ Quantize tensor A in blocks of size 4096 values. @@ -724,7 +828,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou The quantization state to undo the quantization. """ - if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -739,34 +842,66 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - if A.device.type != 'cpu': - if not HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] + if A.device.type != "cpu": + assert blocksize in [4096, 2048, 1024, 512, 256, 128] cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + lib.cquantize_blockwise_bf16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + cblocksize, + ct.c_int(A.numel()), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) if nested: offset = absmax.mean() absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) - quant_state = QuantState(absmax=qabsmax, code=code, blocksize=blocksize, dtype=A.dtype, offset=offset, state2=state2) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) else: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) @@ -775,12 +910,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou def dequantize_blockwise( A: Tensor, - quant_state: QuantState = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, blocksize: int = 4096, - nested=False + nested=False, ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -814,46 +949,76 @@ def dequantize_blockwise( code = name2qmap["dynamic"] if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - if A.device.type != 'cpu': + if A.device.type != "cpu": device = pre_call(A.device) code = quant_state.code.to(A.device) - supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] - if HIP_ENVIRONMENT: - supported_blocksizes = supported_blocksizes[:-1] - if quant_state.blocksize not in supported_blocksizes: - raise ValueError(f"The blockwise of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}") + if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128]: + raise ValueError( + f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", + ) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(get_ptr(quant_state.code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_bf16( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(quant_state.absmax), get_ptr(out), ct.c_longlong(quant_state.blocksize), ct.c_longlong(A.numel())) + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(quant_state.absmax), + get_ptr(out), + ct.c_longlong(quant_state.blocksize), + ct.c_longlong(A.numel()), + ) return out -def get_4bit_type(typename, device=None, blocksize=64): - if device is None: device = 'cuda' + +def get_4bit_type(typename, device=None, blocksize=128): + if device is None: + device = "cuda" data = None - if typename == 'nf4': - ''' Implements the NF4 data type. + if typename == "nf4": + """ Implements the NF4 data type. Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that is normalized into the range [-1, 1]. @@ -862,12 +1027,26 @@ def get_4bit_type(typename, device=None, blocksize=64): Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - ''' - data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, - -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, - 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, - 0.7229568362236023, 1.0] - elif typename == 'fp4': + """ + data = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ] + elif typename == "fp4": # 0b000 = 0 # 0b001 = 0.0625 # 0b010 = 8 @@ -878,20 +1057,35 @@ def get_4bit_type(typename, device=None, blocksize=64): # 0b111 = 3 # can also be created with bnb.functional.create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4) data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0] - elif typename == 'int4': + elif typename == "int4": data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7] - elif typename == 'af4': + elif typename == "af4": # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good) # https://arxiv.org/abs/2306.06965 if blocksize == 64: - data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, - -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, - 0.42563882, 0.55496234, 0.72424863, 1.][::-1] + data = [ + -1.0, + -0.69441008, + -0.51243739, + -0.3736951, + -0.25607552, + -0.14982478, + -0.04934812, + 0.0, + 0.04273164, + 0.12934483, + 0.21961274, + 0.31675666, + 0.42563882, + 0.55496234, + 0.72424863, + 1.0, + ][::-1] else: - raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.') + raise NotImplementedError("4-bit AbnormalFloats currently only support blocksize 64.") if data is None: - raise NotImplementedError(f'Typename {typename} not supported') + raise NotImplementedError(f"Typename {typename} not supported") data = Tensor(data) data /= data.abs().max() @@ -900,17 +1094,37 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_storage=torch.uint8): - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage) +def quantize_fp4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=128, + compress_statistics=False, + quant_storage=torch.uint8, +): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) -def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_storage=torch.uint8): - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=None, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor: +def quantize_nf4( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=128, + compress_statistics=False, + quant_storage=torch.uint8, +): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) + + +def quantize_4bit( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=128, + compress_statistics=False, + quant_type="fp4", + quant_storage=torch.uint8, +) -> Tuple[Tensor, QuantState]: """ Quantize tensor A in blocks of 4-bit values. @@ -936,12 +1150,10 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if A.device.type != "cuda": + raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") n = A.numel() input_shape = A.shape @@ -951,34 +1163,72 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - if out is None: mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) + out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) - if not HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] + assert blocksize in [4096, 2048, 1024, 512, 256, 128] prev_device = pre_call(A.device) is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) elif A.dtype == torch.bfloat16: - if quant_type == 'fp4': - lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: - lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) @@ -990,25 +1240,57 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) return out, state -def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None) -> Tensor: - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') +def dequantize_fp4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 128, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") + + +def dequantize_nf4( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 128, +) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") -def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None) -> Tensor: - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 - - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') -def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = None, quant_type='fp4') -> Tensor: +def dequantize_4bit( + A: Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 128, + quant_type="fp4", +) -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1035,31 +1317,32 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = torch.Tensor: Dequantized tensor. """ - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 - - supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] - if HIP_ENVIRONMENT: - supported_blocksizes = supported_blocksizes[:-1] - - if blocksize not in supported_blocksizes: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if blocksize not in [2048, 4096, 1024, 512, 256, 128]: + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128]", + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") if quant_state is None: assert absmax is not None and out is not None - quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) else: absmax = quant_state.absmax - if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -1069,30 +1352,78 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.float16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out + is_transposed = True if A.shape[0] == 1 else False + if is_transposed: + return out.t() + else: + return out -def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: +def quantize( + A: Tensor, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, +) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -1100,7 +1431,8 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: code = code.to(A.device) absmax = torch.abs(A).max() - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) @@ -1108,10 +1440,10 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: def dequantize( A: Tensor, - state: Tuple[Tensor, Tensor] = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, + state: Optional[Tuple[Tensor, Tensor]] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, ) -> Tensor: assert state is not None or absmax is not None if code is None and state is None: @@ -1126,8 +1458,8 @@ def dequantize( return out * state[0] -def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' +def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: + """ Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -1146,17 +1478,18 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: Quantized 8-bit tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) is_on_gpu([A, out]) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) return out -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: - ''' +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: + """ Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -1175,9 +1508,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ------- torch.Tensor: 32-bit output tensor. - ''' + """ prev_device = pre_call(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) post_call(prev_device) @@ -1193,11 +1527,11 @@ def optimizer_update_32bit( eps: float, step: int, lr: float, - state2: Tensor = None, + state2: Optional[torch.Tensor] = None, beta2: float = 0.0, weight_decay: float = 0.0, gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, + unorm_vec: Optional[torch.Tensor] = None, max_unorm: float = 0.0, skip_zeros=False, ) -> None: @@ -1244,16 +1578,17 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - optim_func = None if g.dtype == torch.float32: optim_func = str2optimizer32bit[optimizer_name][0] elif g.dtype == torch.float16: optim_func = str2optimizer32bit[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3: optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) is_on_gpu([g, p, state1, state2, unorm_vec]) prev_device = pre_call(g.device) @@ -1273,7 +1608,8 @@ def optimizer_update_32bit( ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), - ct.c_int32(g.numel())) + ct.c_int32(g.numel()), + ) post_call(prev_device) @@ -1282,21 +1618,21 @@ def optimizer_update_8bit( g: Tensor, p: Tensor, state1: Tensor, - state2: Tensor, + state2: Optional[torch.Tensor], beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Tensor, + qmap2: Optional[torch.Tensor], max1: Tensor, - max2: Tensor, + max2: Optional[torch.Tensor], new_max1: Tensor, - new_max2: Tensor, + new_max2: Optional[torch.Tensor], weight_decay: float = 0.0, gnorm_scale: float = 1.0, - unorm_vec: Tensor = None, + unorm_vec: Optional[torch.Tensor] = None, max_unorm: float = 0.0, ) -> None: """ @@ -1405,7 +1741,7 @@ def optimizer_update_8bit( ) else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) post_call(prev_device) @@ -1415,21 +1751,20 @@ def optimizer_update_8bit_blockwise( g: Tensor, p: Tensor, state1: Tensor, - state2: Tensor, + state2: Optional[torch.Tensor], beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, - qmap2: Tensor, + qmap2: Optional[torch.Tensor], absmax1: Tensor, - absmax2: Tensor, + absmax2: Optional[torch.Tensor], weight_decay: float = 0.0, gnorm_scale: float = 1.0, skip_zeros=False, ) -> None: - optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) @@ -1437,12 +1772,15 @@ def optimizer_update_8bit_blockwise( optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and - len(str2optimizer8bit_blockwise[optimizer_name])==3): + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) post_call(prev_device) @@ -1470,9 +1808,8 @@ def optimizer_update_8bit_blockwise( ) post_call(prev_device) -def percentile_clipping( - grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 -): + +def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): """Applies percentile clipping grad: torch.Tensor @@ -1514,9 +1851,7 @@ def percentile_clipping( return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d( - histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor -): +def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 @@ -1533,12 +1868,12 @@ def histogram_scatter_add_2d( is_on_gpu([histogram, index1, index2, source]) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() + if not torch.cuda.is_initialized(): + torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError( - f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" - ) + raise TypeError(f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}") sA = A.shape sB = B.shape @@ -1579,12 +1914,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if ( - sout[0] == sA[2] - and sout[1] == sB[2] - and sA[0] == sB[0] - and sA[1] == sB[1] - ): + if sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]: correct = True else: if len(sA) == 2 and len(sB) == 2: @@ -1617,26 +1947,29 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 if not correct: raise ValueError( - f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.", ) return sout + def gemv_4bit( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, - state=None + state=None, ): prev_device = pre_call(A.device) - #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: - raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )') + raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )") if A.numel() != A.shape[-1]: - raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]') + raise ValueError( + 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]', + ) Bshape = state.shape bout = Bshape[0] @@ -1656,7 +1989,7 @@ def gemv_4bit( k = Bshape[1] lda = Bshape[0] ldc = Bshape[0] - ldb = (A.shape[-1]+1)//2 + ldb = (A.shape[-1] + 1) // 2 is_on_gpu([B, A, out, absmax, state.code]) m = ct.c_int32(m) n = ct.c_int32(n) @@ -1667,25 +2000,65 @@ def gemv_4bit( if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize)) + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(state.code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(state.blocksize), + ) else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") else: - raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") post_call(prev_device) return out + def igemm( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, ): @@ -1747,7 +2120,7 @@ def igemm( assert len(sA) == 3 if not (sA[0] == sB[0] and sA[1] == sB[1]): raise ValueError( - f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}", ) transposed_A = True @@ -1766,22 +2139,32 @@ def igemm( # B^T @ A^T = C^T # [km, nk -> mn] is_on_gpu([B, A, out]) - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + lib.cigemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ) return out def batched_igemm( A: Tensor, B: Tensor, - out: Tensor = None, + out: Optional[torch.Tensor] = None, transposed_A=False, transposed_B=False, ): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError( - f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" - ) + raise ValueError(f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}") sout = check_matmul(A, B, out, transposed_A, transposed_B) if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) @@ -1848,9 +2231,24 @@ def batched_igemm( ptr = CUBLAS_Context.get_instance().get_context(A.device) is_on_gpu([B, A, out]) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + lib.cbatched_igemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ct.c_long(strideA), + ct.c_long(strideB), + ct.c_long(strideC), + ct.c_uint32(num_batch), + ) return out @@ -1859,14 +2257,14 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: @@ -1875,23 +2273,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: - if HIP_ENVIRONMENT: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col", "row" - ) - else: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") elif dimsA == 3 and out is None: - if HIP_ENVIRONMENT: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col", "row" - ) - else: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") assert dimsB != 3, "len(B.shape)==3 not supported" assert A.device.type == "cuda" @@ -1899,14 +2283,9 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): assert A.dtype == torch.int8 assert B.dtype == torch.int8 assert out.dtype == dtype - if HIP_ENVIRONMENT: - assert SA[1] == "col" - assert SB[1] == "col" - assert Sout[1] == "col" - else: - assert SA[1] == "col32" - assert SB[1] in ["col_turing", "col_ampere"] - assert Sout[1] == "col32" + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" assert ( shapeA[-1] == shapeB[-1] ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" @@ -1920,21 +2299,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ptrC = get_ptr(out) k = shapeA[-1] - if HIP_ENVIRONMENT: - lda = ct.c_int32(m) - ldb = ct.c_int32(shapeB[0]) - ldc = ct.c_int32(m) + lda = ct.c_int32(m * 32) + if formatB == "col_turing": + # turing: tiles with rows filled up to multiple of 8 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) else: - lda = ct.c_int32(m * 32) - if formatB == "col_turing": - # turing: tiles with rows filled up to multiple of 8 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) - else: - # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - ldc = ct.c_int32(m * 32) + # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) + + ldc = ct.c_int32(m * 32) m = ct.c_int32(m) n = ct.c_int32(n) k = ct.c_int32(k) @@ -1942,48 +2317,33 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == 'col_turing' or HIP_ENVIRONMENT: + if formatB == "col_turing": if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + + if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") - if has_error == 1: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') + if has_error: + print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) return out, Sout -def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): - if HIP_ENVIRONMENT: - A, quant_state = nvidia_transform(A, "row", state = quant_state) +def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 + if bias is not None: + assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -1991,19 +2351,11 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" + assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" prev_device = pre_call(A.device) ptrA = get_ptr(A) @@ -2017,15 +2369,23 @@ def mm_dequant( numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrNewRowStats, + ptrNewColStats, + ptrBias, + numRows, + numCols, + ) post_call(prev_device) return out -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): +def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): assert A.dtype == torch.float16 device = A.device @@ -2038,18 +2398,12 @@ def get_colrow_absmax( col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) + nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -2123,14 +2477,10 @@ def __init__(self, rows, cols, nnz, colptr, rowidx, values): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros( - (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device - ) + rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values - ) + return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) def coo2csc(cooA): @@ -2139,14 +2489,10 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros( - (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device - ) + colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor( - cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values - ) + return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -2156,9 +2502,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -2171,9 +2515,7 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -2191,9 +2533,7 @@ def double_quant( if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) + coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) @@ -2252,15 +2592,16 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if HIP_ENVIRONMENT: - return nvidia_transform(A,to_order,from_order,out,transpose,state,ld) - +def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -2271,7 +2612,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No dim2 = ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -2292,7 +2633,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") post_call(prev_device) @@ -2301,9 +2642,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty( - (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype - ) + out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -2330,16 +2669,28 @@ def spmm_coo(cooA, B, out=None): cldc = ct.c_int32(ldc) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + lib.cspmm_coo( + ptr, + ptrRowidx, + ptrColidx, + ptrValues, + cnnz, + crowsA, + ccolsA, + ccolsB, + cldb, + ptrB, + cldc, + ptrC, + ct.c_bool(transposed_B), + ) return out def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): if out is None: - out = torch.zeros( - (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype - ) + out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) nnz = cooA.nnz prev_device = pre_call(B.device) assert cooA.rowidx.numel() == nnz @@ -2357,9 +2708,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert ( - max_count[0] <= 32 - ), f"Current max count per row is 8 but found {max_count[0]}." + assert max_count[0] <= 32, f"Current max count per row is 8 but found {max_count[0]}." assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -2447,9 +2796,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( - x, dim=dim, keepdim=True - ) + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True) dyna[dyna == 0] = 1 qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) @@ -2554,15 +2901,10 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): def extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] - if not HIP_ENVIRONMENT: - assert formatA in ["col_turing", "col_ampere"] - else: - assert formatA in ["col"] + assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -2572,7 +2914,7 @@ def extract_outliers(A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) - if formatA == 'col_turing' or HIP_ENVIRONMENT: + if formatA == "col_turing": lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -2580,6 +2922,7 @@ def extract_outliers(A, SA, idx): return out + def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 6fa6d1183..96f4359bf 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,5 +2,21 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb, Embedding -from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear +from .modules import ( + Embedding, + Int8Params, + Linear4bit, + Linear8bitLt, + LinearFP4, + LinearNF4, + OutlierAwareLinear, + Params4bit, + StableEmbedding, + SwitchBackLinearBnb, +) +from .triton_based_modules import ( + StandardLinear, + SwitchBackLinear, + SwitchBackLinearGlobal, + SwitchBackLinearVectorwise, +) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index e0d94d861..102f7dbd8 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -2,23 +2,49 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import copy from typing import Any, Dict, Optional, TypeVar, Union, overload - import warnings + import torch -import torch.nn.functional as F from torch import Tensor, device, dtype, nn +import torch.nn.functional as F import bitsandbytes as bnb +from bitsandbytes.autograd._functions import get_tile_inds, undo_layout from bitsandbytes.functional import QuantState -from bitsandbytes.autograd._functions import undo_layout, get_tile_inds from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims +from bitsandbytes.utils import OutlierTracer T = TypeVar("T", bound="torch.nn.Module") class StableEmbedding(torch.nn.Embedding): + """ + Custom embedding layer designed to improve stability during training for NLP tasks by using 32-bit optimizer states. It is designed to reduce gradient variations that can result from quantization. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization. + + Example: + + ``` + # Initialize StableEmbedding layer with vocabulary size 1000, embedding dimension 300 + embedding_layer = StableEmbedding(num_embeddings=1000, embedding_dim=300) + + # Reset embedding parameters + embedding_layer.reset_parameters() + + # Perform a forward pass with input tensor + input_tensor = torch.tensor([1, 2, 3]) + output_embedding = embedding_layer(input_tensor) + ``` + + Attributes: + norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding. + + Methods: + reset_parameters(): Reset embedding parameters using Xavier uniform initialization. + forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer. + """ + def __init__( self, num_embeddings: int, @@ -32,6 +58,25 @@ def __init__( device=None, dtype=None, ) -> None: + """ + Args: + num_embeddings (`int`): + The number of unique embeddings (vocabulary size). + embedding_dim (`int`): + The dimensionality of the embedding. + padding_idx (`Optional[int]`): + Pads the output with zeros at the given index. + max_norm (`Optional[float]`): + Renormalizes embeddings to have a maximum L2 norm. + norm_type (`float`, defaults to `2.0`): + The p-norm to compute for the `max_norm` option. + scale_grad_by_freq (`bool`, defaults to `False`): + Scale gradient by frequency during backpropagation. + sparse (`bool`, defaults to `False`): + Computes dense gradients. Set to `True` to compute sparse gradients instead. + _weight (`Optional[Tensor]`): + Pretrained embeddings. + """ super().__init__( num_embeddings, embedding_dim, @@ -45,9 +90,7 @@ def __init__( dtype, ) self.norm = torch.nn.LayerNorm(embedding_dim, device=device) - GlobalOptimManager.get_instance().register_module_override( - self, "weight", {"optim_bits": 32} - ) + GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -83,6 +126,10 @@ def forward(self, input: Tensor) -> Tensor: class Embedding(torch.nn.Embedding): + """ + Embedding class to store and retrieve word embeddings from their indices. + """ + def __init__( self, num_embeddings: int, @@ -95,6 +142,25 @@ def __init__( _weight: Optional[Tensor] = None, device: Optional[device] = None, ) -> None: + """ + Args: + num_embeddings (`int`): + The number of unique embeddings (vocabulary size). + embedding_dim (`int`): + The dimensionality of the embedding. + padding_idx (`Optional[int]`): + Pads the output with zeros at the given index. + max_norm (`Optional[float]`): + Renormalizes embeddings to have a maximum L2 norm. + norm_type (`float`, defaults to `2.0`): + The p-norm to compute for the `max_norm` option. + scale_grad_by_freq (`bool`, defaults to `False`): + Scale gradient by frequency during backpropagation. + sparse (`bool`, defaults to `False`): + Computes dense gradients. Set to `True` to compute sparse gradients instead. + _weight (`Optional[Tensor]`): + Pretrained embeddings. + """ super().__init__( num_embeddings, embedding_dim, @@ -104,11 +170,9 @@ def __init__( scale_grad_by_freq, sparse, _weight, - device=device - ) - GlobalOptimManager.get_instance().register_module_override( - self, "weight", {"optim_bits": 32} + device=device, ) + GlobalOptimManager.get_instance().register_module_override(self, "weight", {"optim_bits": 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -141,18 +205,17 @@ def forward(self, input: Tensor) -> Tensor: class Params4bit(torch.nn.Parameter): - # Remark: change blocksize to 128 for AMD gpu def __new__( - cls, - data: Optional[torch.Tensor] = None, - requires_grad=True, - quant_state: QuantState = None, - blocksize: int = 128, - compress_statistics: bool = True, - quant_type: str = 'fp4', - quant_storage: torch.dtype = torch.uint8, - module: Optional["Linear4bit"] = None, - bnb_quantized: bool = False + cls, + data: Optional[torch.Tensor] = None, + requires_grad=False, # quantized weights should be frozen by default + quant_state: Optional[QuantState] = None, + blocksize: int = 128, + compress_statistics: bool = True, + quant_type: str = "fp4", + quant_storage: torch.dtype = torch.uint8, + module: Optional["Linear4bit"] = None, + bnb_quantized: bool = False, ) -> "Params4bit": if data is None: data = torch.empty(0) @@ -168,8 +231,46 @@ def __new__( self.module = module return self + def __getstate__(self): + state = self.__dict__ + state["data"] = self.data + state["requires_grad"] = self.requires_grad + return state + + def __setstate__(self, state): + self.requires_grad = state["requires_grad"] + self.blocksize = state["blocksize"] + self.compress_statistics = state["compress_statistics"] + self.quant_type = state["quant_type"] + self.quant_state = state["quant_state"] + self.data = state["data"] + self.quant_storage = state["quant_storage"] + self.bnb_quantized = state["bnb_quantized"] + self.module = state["module"] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + new_instance.quant_state = copy.deepcopy(state["quant_state"]) + new_instance.data = copy.deepcopy(state["data"]) + return new_instance + + def __copy__(self): + new_instance = type(self).__new__(type(self)) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + @classmethod - def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit": + def from_prequantized( + cls, + data: torch.Tensor, + quantized_stats: Dict[str, Any], + requires_grad: bool = False, + device="cuda", + **kwargs, + ) -> "Params4bit": self = torch.Tensor._make_subclass(cls, data.to(device)) self.requires_grad = requires_grad self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device) @@ -181,8 +282,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], def _quantize(self, device): w = self.data.contiguous().cuda(device) - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, - quant_type=self.quant_type, quant_storage=self.quant_storage) + w_4bit, quant_state = bnb.functional.quantize_4bit( + w, + blocksize=self.blocksize, + compress_statistics=self.compress_statistics, + quant_type=self.quant_type, + quant_storage=self.quant_storage, + ) self.data = w_4bit self.quant_state = quant_state if self.module is not None: @@ -191,42 +297,107 @@ def _quantize(self, device): return self def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): - return self.to(device='cuda' if device is None else device, non_blocking=non_blocking) + return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) @overload - def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: - ... + def to( + self: T, + device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ..., + ) -> T: ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: - ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: - ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if (device is not None and device.type == "cuda" and not self.bnb_quantized): + if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: self.quant_state.to(device) - new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, quant_state=self.quant_state, - blocksize=self.blocksize, compress_statistics=self.compress_statistics, - quant_type=self.quant_type) + new_param = Params4bit( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + quant_state=self.quant_state, + blocksize=self.blocksize, + compress_statistics=self.compress_statistics, + quant_type=self.quant_type, + ) return new_param class Linear4bit(nn.Linear): + """ + This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314). + QLoRA 4-bit linear layers uses blockwise k-bit quantization under the hood, with the possibility of selecting various + compute datatypes such as FP4 and NF4. + + In order to quantize a linear layer one should first load the original fp16 / bf16 weights into + the Linear4bit module, then call `quantized_module.to("cuda")` to quantize the fp16 / bf16 weights. + + Example: - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None): + ```python + import torch + import torch.nn as nn + + import bitsandbytes as bnb + from bnb.nn import Linear4bit + + fp16_model = nn.Sequential( + nn.Linear(64, 64), + nn.Linear(64, 64) + ) + + quantized_model = nn.Sequential( + Linear4bit(64, 64), + Linear4bit(64, 64) + ) + + quantized_model.load_state_dict(fp16_model.state_dict()) + quantized_model = quantized_model.to(0) # Quantization happens here + ``` + """ + + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_type="fp4", + quant_storage=torch.uint8, + device=None, + ): + """ + Initialize Linear4bit class. + + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ super().__init__(input_features, output_features, bias, device) - self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self) + self.weight = Params4bit( + self.weight.data, + requires_grad=False, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=quant_storage, + module=self, + ) # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False @@ -243,11 +414,15 @@ def set_compute_type(self, x): if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]): # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast # warn the user about this - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.') - warnings.filterwarnings('ignore', message='.*inference.') + warnings.warn( + "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.", + ) + warnings.filterwarnings("ignore", message=".*inference.") if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]): - warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.') - warnings.filterwarnings('ignore', message='.*inference or training') + warnings.warn( + "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.", + ) + warnings.filterwarnings("ignore", message=".*inference or training") def _save_to_state_dict(self, destination, prefix, keep_vars): """ @@ -265,8 +440,8 @@ def forward(self, x: torch.Tensor): if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) - if getattr(self.weight, 'quant_state', None) is None: - if getattr(self, 'quant_state', None) is not None: + if getattr(self.weight, "quant_state", None) is None: + if getattr(self, "quant_state", None) is not None: # the quant state got lost when the parameter got converted. This happens for example for fsdp # since we registered the module, we can recover the state here assert self.weight.shape[1] == 1 @@ -274,7 +449,9 @@ def forward(self, x: torch.Tensor): self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) self.weight.quant_state = self.quant_state else: - print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') + print( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.", + ) if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True @@ -292,23 +469,82 @@ def forward(self, x: torch.Tensor): class LinearFP4(Linear4bit): - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device) + """ + Implements the FP4 data type. + """ + + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_storage=torch.uint8, + device=None, + ): + """ + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ + super().__init__( + input_features, + output_features, + bias, + compute_dtype, + compress_statistics, + "fp4", + quant_storage, + device, + ) class LinearNF4(Linear4bit): - ''' Implements the NF4 data type. + """Implements the NF4 data type. - Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that - is normalized into the range [-1, 1]. + Constructs a quantization data type where each bin has equal area under a standard normal distribution N(0, 1) that + is normalized into the range [-1, 1]. - For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) + For more information read the paper: QLoRA: Efficient Finetuning of Quantized LLMs (https://arxiv.org/abs/2305.14314) - Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in - the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. - ''' - def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None): - super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device) + Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in + the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236. + """ + + def __init__( + self, + input_features, + output_features, + bias=True, + compute_dtype=None, + compress_statistics=True, + quant_storage=torch.uint8, + device=None, + ): + """ + Args: + input_features (`str`): + Number of input features of the linear layer. + output_features (`str`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ + super().__init__( + input_features, + output_features, + bias, + compute_dtype, + compress_statistics, + "nf4", + quant_storage, + device, + ) class Int8Params(torch.nn.Parameter): @@ -325,7 +561,9 @@ def __new__( cls.SCB = None if data is None: data = torch.empty(0) - return torch.Tensor._make_subclass(cls, data, requires_grad) + obj = torch.Tensor._make_subclass(cls, data, requires_grad) + obj.CB, obj.SCB = cls.CB, cls.SCB + return obj def cuda(self, device): if self.has_fp16_weights: @@ -338,8 +576,8 @@ def cuda(self, device): del CBt del SCBt self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) + self.CB = CB + self.SCB = SCB return self @@ -349,33 +587,22 @@ def to( device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ..., - ) -> T: - ... + ) -> T: ... @overload - def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: - ... + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ... @overload - def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: - ... + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( - *args, **kwargs - ) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if ( - device is not None - and device.type == "cuda" - and self.data.device.type == "cpu" - ): + if device is not None and device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) else: new_param = Int8Params( - super().to( - device=device, dtype=dtype, non_blocking=non_blocking - ), + super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) @@ -398,8 +625,59 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k class Linear8bitLt(nn.Linear): - def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, - memory_efficient_backward=False, threshold=0.0, index=None, device=None): + """ + This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm. + To read more about it, have a look at the paper. + + In order to quantize a linear layer one should first load the original fp16 / bf16 weights into + the Linear8bitLt module, then call `int8_module.to("cuda")` to quantize the fp16 weights. + + Example: + + ```python + import torch + import torch.nn as nn + + import bitsandbytes as bnb + from bnb.nn import Linear8bitLt + + fp16_model = nn.Sequential( + nn.Linear(64, 64), + nn.Linear(64, 64) + ) + + int8_model = nn.Sequential( + Linear8bitLt(64, 64, has_fp16_weights=False), + Linear8bitLt(64, 64, has_fp16_weights=False) + ) + + int8_model.load_state_dict(fp16_model.state_dict()) + int8_model = int8_model.to(0) # Quantization happens here + ``` + """ + + def __init__( + self, + input_features: int, + output_features: int, + bias=True, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + device=None, + ): + """ + Initialize Linear8bitLt class. + + Args: + input_features (`int`): + Number of input features of the linear layer. + output_features (`int`): + Number of output features of the linear layer. + bias (`bool`, defaults to `True`): + Whether the linear class uses the bias term as well. + """ super().__init__(input_features, output_features, bias, device) assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" self.state = bnb.MatmulLtState() @@ -441,19 +719,36 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[key_name] = param_from_state if keep_vars else param_from_state.detach() destination[format_name] = self.state.formatB - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) unexpected_copy = list(unexpected_keys) for key in unexpected_copy: - input_name = key[len(prefix):] + input_name = key[len(prefix) :] if input_name == "SCB": if self.weight.SCB is None: # buffers not yet initialized, can't access them directly without quantizing first - raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " - "not supported. Please call module.cuda() before module.load_state_dict()") + raise RuntimeError( + "Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. Please call module.cuda() before module.load_state_dict()", + ) input_param = state_dict[key] self.weight.SCB.copy_(input_param) @@ -496,18 +791,18 @@ def __init__(self, input_features, output_features, bias=True, device=None): self.is_quantized = False def forward_with_outliers(self, x, outlier_idx): - raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') + raise NotImplementedError("Please override the `forward_with_outliers(self, x, outlier_idx)` function") def quantize_weight(self, w, outlier_idx): - raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') + raise NotImplementedError("Please override the `quantize_weights(self, w, outlier_idx)` function") def forward(self, x): if self.outlier_dim is None: tracer = OutlierTracer.get_instance() if not tracer.is_initialized(): - print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') + print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer") outlier_idx = tracer.get_outliers(self.weight) - #print(outlier_idx, tracer.get_hvalue(self.weight)) + # print(outlier_idx, tracer.get_hvalue(self.weight)) self.outlier_dim = outlier_idx if not self.is_quantized: @@ -515,6 +810,7 @@ def forward(self, x): self.weight.data.copy_(w) self.is_quantized = True + class SwitchBackLinearBnb(nn.Linear): def __init__( self, @@ -525,11 +821,9 @@ def __init__( memory_efficient_backward=False, threshold=0.0, index=None, - device=None + device=None, ): - super().__init__( - input_features, output_features, bias, device - ) + super().__init__(input_features, output_features, bias, device) self.state = bnb.MatmulLtState() self.index = index @@ -539,9 +833,7 @@ def __init__( if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py index de07ac647..aa8494942 100644 --- a/bitsandbytes/nn/triton_based_modules.py +++ b/bitsandbytes/nn/triton_based_modules.py @@ -1,20 +1,27 @@ -import torch -import torch.nn as nn -import time from functools import partial -from bitsandbytes.triton.triton_utils import is_triton_available +import torch +import torch.nn as nn from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise +from bitsandbytes.triton.int8_matmul_mixed_dequantize import ( + int8_matmul_mixed_dequantize, +) +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import ( + int8_matmul_rowwise_dequantize, +) +from bitsandbytes.triton.quantize_columnwise_and_transpose import ( + quantize_columnwise_and_transpose, +) +from bitsandbytes.triton.quantize_global import ( + quantize_global, + quantize_global_transpose, +) from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose -from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize -from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose -from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize +from bitsandbytes.triton.triton_utils import is_triton_available class _switchback_global(torch.autograd.Function): - @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -29,9 +36,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D.size()[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -48,7 +53,8 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_global_transpose(W) grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], -1 + *G_3D.size()[:-1], + -1, ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -58,8 +64,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias -class _switchback_vectorrize(torch.autograd.Function): +class _switchback_vectorrize(torch.autograd.Function): @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -73,9 +79,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call kernel which expects rowwise quantized X and W - return int8_matmul_rowwise_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D.size()[:-1], -1) + return int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D.size()[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -91,7 +95,8 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) W_int8, state_W = quantize_columnwise_and_transpose(W) grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D.size()[:-1], -1 + *G_3D.size()[:-1], + -1, ) if ctx.needs_input_grad[1]: # backward pass uses standard weight grad @@ -101,8 +106,8 @@ def backward(ctx, G_3D): return grad_X, grad_W, grad_bias -class _switchback_global_mem_efficient(torch.autograd.Function): +class _switchback_global_mem_efficient(torch.autograd.Function): @staticmethod def forward(ctx, X_3D, W, bias): # reshape input to [N * L, D] @@ -119,9 +124,7 @@ def forward(ctx, X_3D, W, bias): # matmult, fused dequant and add bias # call "mixed" because we are mixing rowwise quantized and global quantized - return int8_matmul_mixed_dequantize( - X_int8, W_int8.t(), state_X, state_W, bias - ).view(*X_3D_sz[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, W_int8.t(), state_X, state_W, bias).view(*X_3D_sz[:-1], -1) @staticmethod def backward(ctx, G_3D): @@ -143,35 +146,34 @@ def backward(ctx, G_3D): G_int8, state_G = quantize_rowwise(G) del G W_int8 = W_int8.t().contiguous() - grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( - *G_3D_sz[:-1], -1 - ) + grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(*G_3D_sz[:-1], -1) return grad_X, grad_W, grad_bias + class SwitchBackLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - vector_wise_quantization: bool = False, - mem_efficient : bool = False, - ): + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + vector_wise_quantization: bool = False, + mem_efficient: bool = False, + ): super().__init__(in_features, out_features, bias, device, dtype) - if not is_triton_available: - raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. - Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') + if not is_triton_available(): + raise ImportError("""Could not import triton. Please install triton to use SwitchBackLinear. + Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower""") # By default, we use the global quantization. self.vector_wise_quantization = vector_wise_quantization if self.vector_wise_quantization: self._fn = _switchback_vectorrize if mem_efficient: - print('mem efficient is not supported for vector-wise quantization.') + print("mem efficient is not supported for vector-wise quantization.") exit(1) else: if mem_efficient: @@ -187,7 +189,7 @@ def prepare_for_eval(self): # if hasattr(m, "prepare_for_eval"): # m.prepare_for_eval() # model.apply(cond_prepare) - print('=> preparing for eval.') + print("=> preparing for eval.") if self.vector_wise_quantization: W_int8, state_W = quantize_rowwise(self.weight) else: @@ -211,18 +213,22 @@ def forward(self, x): X_int8, state_X = quantize_rowwise(X) if self.vector_wise_quantization: - return int8_matmul_rowwise_dequantize( - X_int8, self.W_int8.t(), state_X, self.state_W, self.bias - ).view(*x.size()[:-1], -1) + return int8_matmul_rowwise_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( + *x.size()[:-1], + -1, + ) else: - return int8_matmul_mixed_dequantize( - X_int8, self.W_int8.t(), state_X, self.state_W, self.bias - ).view(*x.size()[:-1], -1) + return int8_matmul_mixed_dequantize(X_int8, self.W_int8.t(), state_X, self.state_W, self.bias).view( + *x.size()[:-1], + -1, + ) + SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) + # This is just the standard linear function. class StandardLinearFunction(torch.autograd.Function): @staticmethod @@ -252,7 +258,7 @@ def backward(ctx, grad_output_3D): return grad_input, grad_weight, grad_bias -class StandardLinear(nn.Linear): +class StandardLinear(nn.Linear): def forward(self, x): return StandardLinearFunction.apply(x, self.weight, self.bias) diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 83a57bd9f..b4c95793a 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -3,14 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from bitsandbytes.cextension import COMPILED_WITH_CUDA - from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit +from .adamw import ( + AdamW, + AdamW8bit, + AdamW32bit, + PagedAdamW, + PagedAdamW8bit, + PagedAdamW32bit, +) from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS +from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 7d8df58ac..7459dece1 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -20,12 +20,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -62,12 +87,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 8): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: @@ -105,12 +155,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit Adagrad optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + lr_decay (`int`, defaults to 0): + The learning rate decay. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + initial_accumulator_value (`int`, defaults to 0): + The initial momemtum values. + eps (`float`, defaults to 1e-10): + The epsilon value prevents division by zero in the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if initial_accumulator_value != 0.0: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 86981eb86..740db26ac 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,34 +14,370 @@ class Adam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Base Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Adam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Adam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 32-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class PagedAdam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Paged Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit paged Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Paged 32-bit Adam optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. @@ -119,9 +455,7 @@ def step(self, closure=None): if grad.dtype in {torch.float16, torch.bfloat16}: grad = grad.float() if grad.is_sparse: - raise RuntimeError( - "Adam does not support sparse gradients, please consider SparseAdam instead" - ) + raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") amsgrad = group.get("amsgrad", False) assert not amsgrad @@ -138,15 +472,9 @@ def step(self, closure=None): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state["abserrors"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) - state["relerrors"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) - state["counts"] = torch.zeros( - (256, 256), device=p_data_fp32.device - ) + state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -154,25 +482,19 @@ def step(self, closure=None): state["exp_avg"] = state["exp_avg"].to(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) if amsgrad: - state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( - p_data_fp32 - ) + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32) state["step"] += 1 beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = ( - group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - ) + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 e = state["abserrors"] rele = state["relerrors"] counts = state["counts"] if group["weight_decay"] != 0: - p_data_fp32.add_( - p_data_fp32, alpha=-group["weight_decay"] * group["lr"] - ) + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] if amsgrad: @@ -185,10 +507,7 @@ def step(self, closure=None): denom = exp_avg_sq.sqrt().add_(group["eps"]) update_fp32 = exp_avg / denom - if ( - p_data_fp32.numel() <= 8192 - or p_data_fp32.numel() > 50000 * 1000 - ): + if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: # embedding layer or too small p_data_fp32 += -step_size * update_fp32 else: @@ -227,9 +546,7 @@ def step(self, closure=None): # 3. dequantize # Error will be calculated automatically! else: - raise ValueError( - f"Invalid analysis value: {self.analysis}!" - ) + raise ValueError(f"Invalid analysis value: {self.analysis}!") denom = state2.sqrt().add_(group["eps"]) update_8bit = state1 / denom @@ -241,9 +558,7 @@ def step(self, closure=None): F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d( - counts, C1.int(), C2.int(), torch.ones_like(abserr) - ) + F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) p_data_fp32 += -step_size * update_fp32 @@ -251,18 +566,10 @@ def step(self, closure=None): if self.savedir != "" and state["step"] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = "_".join( - [str(dim) for dim in p_data_fp32.shape] - ) - pathe = os.path.join( - self.savedir, f"{p_id}_{shapestr}_abserr.pkl" - ) - pathrele = os.path.join( - self.savedir, f"{p_id}_{shapestr}_relerr.pkl" - ) - pathcounts = os.path.join( - self.savedir, f"{p_id}_{shapestr}_counts.pkl" - ) + shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl") + pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl") + pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl") torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 21077f1a0..4bf3f6436 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,35 +5,364 @@ from bitsandbytes.optim.optimizer import Optimizer2State - class AdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Base AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class AdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class AdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 32-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) class PagedAdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 8-bit AdamW optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedAdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, - args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 32-bit AdamW optimizer. + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 1fbb6fadc..8d29cbbfe 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -23,6 +23,39 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + Base LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, @@ -56,6 +89,37 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + 8-bit LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, @@ -89,6 +153,37 @@ def __init__( block_wise=False, max_unorm=1.0, ): + """ + 32-bit LAMB optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + bias_correction (`bool`, defaults to `True`): + Whether to apply bias correction to the first and second-order moments. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + adam_w_mode (`bool`, defaults to `True`): + Whether to use the AdamW variant. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 1.0): + The maximum gradient norm. + """ super().__init__( "lamb", params, diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 73554e3cc..90c3686fe 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -23,10 +23,35 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + Base LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -57,10 +82,33 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + 8-bit LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -91,10 +139,33 @@ def __init__( percentile_clipping=100, max_unorm=0.02, ): + """ + 32-bit LARS optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + max_unorm (`float`, defaults to 0.02): + The maximum gradient norm. + """ if momentum == 0: - raise NotImplementedError( - "LARS without momentum is not supported!" - ) + raise NotImplementedError("LARS without momentum is not supported!") super().__init__( "lars", params, @@ -127,9 +198,7 @@ def __init__( if momentum < 0.0: raise ValueError(f"Invalid momentum value: {momentum}") if weight_decay < 0.0: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, @@ -140,9 +209,7 @@ def __init__( max_unorm=max_unorm, ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError( - "Nesterov momentum requires a momentum and zero dampening" - ) + raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) def __setstate__(self, state): diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2bde1a447..2e4163694 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -4,27 +4,315 @@ # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class Lion(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Base Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Lion8bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + class Lion32bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): - super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 32-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) class PagedLion(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedLion8bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 8-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + class PagedLion32bit(Optimizer1State): - def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 32-bit Lion optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-4): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + weight_decay (`float`, defaults to 0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ + super().__init__( + "lion", + params, + lr, + betas, + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index fb83eddf0..f1e60e5e7 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -2,8 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import abc as container_abcs -from collections import defaultdict +from collections import abc as container_abcs, defaultdict from copy import deepcopy from itertools import chain @@ -19,6 +18,10 @@ def __init__(self, initial_data): class GlobalOptimManager: + """ + A global optimizer manager for enabling custom optimizer configs. + """ + _instance = None def __init__(self): @@ -46,30 +49,44 @@ def register_parameters(self, params): for group_index, group in enumerate(param_groups): for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[ - id(p) - ] + self.index2config[(group_index, p_index)] = self.pid2config[id(p)] - def override_config( - self, parameters, key=None, value=None, key_value_dict=None - ): + def override_config(self, parameters, key=None, value=None, key_value_dict=None): """ - Overrides initial optimizer config for specific parameters. + Override initial optimizer config with specific hyperparameters. The key-values of the optimizer config for the input parameters are overridden - This can be both, optimizer parameters like "betas", or "lr" or it can be - 8-bit specific parameters like "optim_bits", "percentile_clipping". - - Parameters - ---------- - parameters : torch.Tensor or list(torch.Tensors) - The input parameters. - key : str - The hyperparamter to override. - value : object - The value for the hyperparamters. - key_value_dict : dict - A dictionary with multiple key-values to override. + This can be both, optimizer parameters like `betas` or `lr`, or it can be + 8-bit specific parameters like `optim_bits` or `percentile_clipping`. + + Arguments: + parameters (`torch.Tensor` or `list(torch.Tensors)`): + The input parameters. + key (`str`): + The hyperparamter to override. + value: + The hyperparameter values. + key_value_dict (`dict`): + A dictionary with multiple key-values to override. + + Example: + + ```py + import torch + import bitsandbytes as bnb + + mng = bnb.optim.GlobalOptimManager.get_instance() + + model = MyModel() + mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU + + model = model.cuda() + # use 8-bit optimizer states for all parameters + adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) + + # 2. override: the parameter model.fc1.weight now uses 32-bit Adam + mng.override_config(model.fc1.weight, 'optim_bits', 32) + ``` """ self.uses_config_override = True if isinstance(parameters, torch.nn.Parameter): @@ -93,6 +110,17 @@ def register_module_override(self, module, param_name, config): class Optimizer8bit(torch.optim.Optimizer): def __init__(self, params, defaults, optim_bits=32, is_paged=False): + """ + Base 8-bit optimizer class. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ super().__init__(params, defaults) self.initialized = False self.name2qmap = {} @@ -101,18 +129,18 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False): self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { - "qmap1", - "qmap2", - "max1", - "max2", - "new_max1", - "new_max2", - "state1", - "state2", - "gnorm_vec", - "absmax1", - "absmax2", - "unorm_vec", + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", } if optim_bits == 8: @@ -126,11 +154,11 @@ def __setstate__(self, state): super().__setstate__(state) def load_state_dict(self, state_dict): - r"""Loads the optimizer state. + """Load an optimizer state. - Args: - state_dict (dict): optimizer state. Should be an object returned - from a call to :meth:`state_dict`. + Arguments: + state_dict (`dict`): + An optimizer state (should be returned from a call to `state_dict`) to load. """ # deepcopy, to be consistent with module API state_dict = deepcopy(state_dict) @@ -139,16 +167,12 @@ def load_state_dict(self, state_dict): saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError( - "loaded state dict has a different number of " - "parameter groups" - ) + raise ValueError("loaded state dict has a different number of parameter groups") param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): raise ValueError( - "loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group" + "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", ) # Update the state @@ -197,9 +221,7 @@ def update_group(group, new_group): new_group["params"] = group["params"] return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups) - ] + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): @@ -209,7 +231,7 @@ def to_gpu(self): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - is_paged = getattr(v, 'is_paged', False) + is_paged = getattr(v, "is_paged", False) if not is_paged: self.state[p][k] = v.to(p.device) @@ -217,9 +239,7 @@ def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance( - pmodule, torch.Parameter - ) + assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) found = False for gindex, group in enumerate(self.param_groups): if found: @@ -231,18 +251,16 @@ def check_overrides(self): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[ - (gindex, pindex) - ] = self.mng.pid2config[id(p)] + self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] found = True @torch.no_grad() def step(self, closure=None): - """Performs a single optimization step. + """Perform a single optimization step. Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + closure (`Callable`, *optional*, defaults to `None`): + A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: @@ -256,7 +274,7 @@ def step(self, closure=None): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True - #if self.is_paged: self.page_mng.prefetch_all() + # if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -273,7 +291,6 @@ def step(self, closure=None): # to sync to make sure all tensors are in the right state torch.cuda.synchronize() - return loss def get_config(self, gindex, pindex, group): @@ -297,9 +314,7 @@ def init_state(self, group, p, gindex, pindex): raise NotImplementedError("init_state method needs to be overridden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError( - "The update_step method needs to be overridden" - ) + raise NotImplementedError("The update_step method needs to be overridden") def get_state_buffer(self, p, dtype=torch.float32): if not self.is_paged or p.numel() < 1e5: @@ -314,12 +329,12 @@ def get_state_buffer(self, p, dtype=torch.float32): def prefetch_state(self, p): if self.is_paged: state = self.state[p] - s1 = state['state1'] - is_paged = getattr(s1, 'is_paged', False) + s1 = state["state1"] + is_paged = getattr(s1, "is_paged", False) if is_paged: - F.prefetch_tensor(state['state1']) - if 'state2' in state: - F.prefetch_tensor(state['state2']) + F.prefetch_tensor(state["state1"]) + if "state2" in state: + F.prefetch_tensor(state["state2"]) class Optimizer2State(Optimizer8bit): @@ -338,8 +353,41 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False + is_paged=False, ): + """ + Base 2-state update optimizer class. + + Arguments: + optimizer_name (`str`): + The name of the optimizer. + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple`, defaults to (0.9, 0.999)): + The beta values for the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value for the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 0.0): + The maximum value to normalize each block with. + skip_zeros (`bool`, defaults to `False`): + Whether to skip zero values for sparse gradients and models to ensure correct updates. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: @@ -350,13 +398,9 @@ def __init__( betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError( - f"Invalid beta parameter at index {i}: {betas[i]}" - ) + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -385,9 +429,7 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError( - f'Amount of optimizer bits not supported: {config["optim_bits"]}' - ) + raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -395,21 +437,15 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or ( - dtype == torch.uint8 and p.numel() < 4096 - ): + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( - p.device - ) - self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( - p.device - ) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -422,25 +458,13 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) - state["absmax2"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) else: - state["max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["max2"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max2"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -460,7 +484,10 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, state["gnorm_vec"], step, config["percentile_clipping"] + grad, + state["gnorm_vec"], + step, + config["percentile_clipping"], ) else: gnorm_scale = 1.0 @@ -504,9 +531,7 @@ def update_step(self, group, p, gindex, pindex): state["new_max2"], config["weight_decay"], gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] - if config["max_unorm"] > 0.0 - else None, + unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, max_unorm=config["max_unorm"], ) @@ -551,21 +576,50 @@ def __init__( block_wise=True, max_unorm=0.0, skip_zeros=False, - is_paged=False + is_paged=False, ): + """ + Base 1-state update optimizer class. + + Arguments: + optimizer_name (`str`): + The name of the optimizer. + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple`, defaults to (0.9, 0.0)): + The beta values for the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value for the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + max_unorm (`float`, defaults to 0.0): + The maximum value to normalize each block with. + skip_zeros (`bool`, defaults to `False`): + Whether to skip zero values for sparse gradients and models to ensure correct updates. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError( - f"Invalid beta parameter at index {i}: {betas[i]}" - ) + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: - raise ValueError( - f"Invalid weight_decay value: {weight_decay}" - ) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) @@ -594,9 +648,7 @@ def init_state(self, group, p, gindex, pindex): elif config["optim_bits"] == 8: dtype = torch.uint8 else: - raise NotImplementedError( - f'Amount of optimizer bits not supported: {config["optim_bits"]}' - ) + raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') if p.numel() < config["min_8bit_size"]: dtype = torch.float32 @@ -604,17 +656,13 @@ def init_state(self, group, p, gindex, pindex): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or ( - dtype == torch.uint8 and p.numel() < 4096 - ): + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( - p.device - ) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] @@ -624,16 +672,10 @@ def init_state(self, group, p, gindex, pindex): blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state["absmax1"] = torch.zeros( - (blocks,), dtype=torch.float32, device=p.device - ) + state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) else: - state["max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) - state["new_max1"] = torch.zeros( - (1,), dtype=torch.float32, device=p.device - ) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) if config["percentile_clipping"] < 100: state["gnorm_vec"] = torch.zeros((100,), device=p.device) @@ -653,7 +695,10 @@ def update_step(self, group, p, gindex, pindex): if config["percentile_clipping"] < 100: current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( - grad, state["gnorm_vec"], step, config["percentile_clipping"] + grad, + state["gnorm_vec"], + step, + config["percentile_clipping"], ) else: gnorm_scale = 1.0 @@ -669,7 +714,7 @@ def update_step(self, group, p, gindex, pindex): step, config["lr"], None, - config['betas'][1], + config["betas"][1], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 2853ca723..25611309b 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -21,10 +21,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -57,10 +84,37 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( @@ -93,11 +147,38 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit RMSprop optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-2): + The learning rate. + alpha (`float`, defaults to 0.99): + The alpha value is the decay rate of the squared gradients of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + centered (`bool`, defaults to `False`): + Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if alpha == 0: - raise NotImplementedError( - "RMSprop with alpha==0.0 is not supported!" - ) + raise NotImplementedError("RMSprop with alpha==0.0 is not supported!") if centered: raise NotImplementedError("Centered RMSprop is not supported!") super().__init__( diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index 3c0fc2b9f..ec18f036c 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -20,6 +20,33 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + Base SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( @@ -51,6 +78,31 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 8-bit SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( @@ -82,6 +134,31 @@ def __init__( percentile_clipping=100, block_wise=True, ): + """ + 32-bit SGD optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`): + The learning rate. + momentum (`float`, defaults to 0): + The momentum value speeds up the optimizer by taking bigger steps. + dampening (`float`, defaults to 0): + The dampening value reduces the momentum of the optimizer. + weight_decay (`float`, defaults to 0.0): + The weight decay value for the optimizer. + nesterov (`bool`, defaults to `False`): + Whether to use Nesterov momentum. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + """ if momentum == 0: raise NotImplementedError("SGD without momentum is not supported!") super().__init__( diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py index 47b720d78..31db4f282 100644 --- a/bitsandbytes/research/__init__.py +++ b/bitsandbytes/research/__init__.py @@ -1,6 +1,6 @@ from . import nn from .autograd._functions import ( - switchback_bnb, matmul_fp8_global, matmul_fp8_mixed, + switchback_bnb, ) diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py index 883121759..e2f85c868 100644 --- a/bitsandbytes/research/autograd/_functions.py +++ b/bitsandbytes/research/autograd/_functions.py @@ -1,21 +1,18 @@ +from functools import reduce # Required in Python 3 import operator +from typing import Optional import warnings -from dataclasses import dataclass -from functools import reduce # Required in Python 3 import torch +from bitsandbytes.autograd._functions import GlobalOutlierPooler, MatmulLtState import bitsandbytes.functional as F -from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler -from bitsandbytes.cextension import HIP_ENVIRONMENT - # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) -tensor = torch.Tensor class MatMulFP8Mixed(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs @@ -86,7 +83,7 @@ def backward(ctx, grad_output): # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) # not supported by PyTorch. TODO: create work-around - if req_gradA: + if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) if req_gradB: @@ -170,7 +167,7 @@ def backward(ctx, grad_output): # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) # not supported by PyTorch. TODO: create work-around - if req_gradA: + if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) if req_gradB: @@ -187,7 +184,9 @@ def backward(ctx, grad_output): class SwitchBackBnb(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): + # TODO: the B008 on the line below is a likely bug; the current implementation will + # have each SwitchBackBnb instance share a single MatmulLtState instance!!! + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008 # default to pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -196,9 +195,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B @@ -217,9 +216,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A.to(torch.float16), threshold=state.threshold - ) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -235,14 +232,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: - #print('A shape', A.shape) + # print('A shape', A.shape) if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None # 2. Quantize B if state.has_fp16_weights: - #print('B shape', B.shape) + # print('B shape', B.shape) has_grad = True if (getattr(B, "grad", None) is not None) else False is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) if is_transposed: @@ -273,12 +270,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # else: # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0) - .t() - .contiguous() - .to(A.dtype) - ) + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -321,14 +313,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - - clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + clone_func = torch.clone if len(output_shape) == 3 else lambda x: x return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA, A = ctx.tensors @@ -343,9 +334,7 @@ def backward(ctx, grad_output): # Cast grad_output to fp16 if len(grad_output.shape) == 3: - grad_output = grad_output.reshape( - -1, grad_output.shape[-1] - ).contiguous() + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) @@ -358,29 +347,25 @@ def backward(ctx, grad_output): if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) + state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) # print('back B shape', state.CxBt.shape) # print('back grad shape', C32grad.shape) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) else: - raise Exception('State must contain either CBt or CB matrix for backward') + raise Exception("State must contain either CBt or CB matrix for backward") return grad_A, grad_B, None, grad_bias, None + def get_block_sizes(input_matrix, weight_matrix): input_features = input_matrix.shape[-1] - output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) - if not HIP_ENVIRONMENT: - array = [4096, 2048, 1024, 512, 256, 128, 64, 0] - else: - array = [4096, 2048, 1024, 512, 256, 128, 0] + output_features = weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1] + array = [4096, 2048, 1024, 512, 256, 128, 0] bsz, bsz2 = 1024, 1024 for i, k in enumerate(array): if input_features > array[i + 1]: @@ -393,21 +378,42 @@ def get_block_sizes(input_matrix, weight_matrix): return bsz, bsz2 -def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + +def matmul_fp8_global( + A: torch.Tensor, + B: torch.Tensor, + fw_code: torch.Tensor, + bw_code: torch.Tensor, + out: Optional[torch.Tensor] = None, + bsz: int = -1, + bsz2: int = -1, +): + if bsz == -1 or bsz2 == -1: + bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) -def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): - if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + +def matmul_fp8_mixed( + A: torch.Tensor, + B: torch.Tensor, + fw_code: torch.Tensor, + bw_code: torch.Tensor, + out: Optional[torch.Tensor] = None, + bsz: int = -1, + bsz2: int = -1, +): + if bsz == -1 or bsz2 == -1: + bsz, bsz2 = get_block_sizes(A, B) return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + def switchback_bnb( - A: tensor, - B: tensor, - out: tensor = None, - state: MatmulLtState = None, + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + state: Optional[MatmulLtState] = None, threshold=0.0, - bias=None + bias=None, ): state = state or MatmulLtState() if threshold > 0.0: diff --git a/bitsandbytes/research/nn/__init__.py b/bitsandbytes/research/nn/__init__.py index 8faec10bb..417011218 100644 --- a/bitsandbytes/research/nn/__init__.py +++ b/bitsandbytes/research/nn/__init__.py @@ -1 +1 @@ -from .modules import LinearFP8Mixed, LinearFP8Global +from .modules import LinearFP8Global, LinearFP8Mixed diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py index 2a46b40c3..57c0f3358 100644 --- a/bitsandbytes/research/nn/modules.py +++ b/bitsandbytes/research/nn/modules.py @@ -1,12 +1,9 @@ -from typing import Optional, TypeVar, Union, overload +from typing import TypeVar import torch -import torch.nn.functional as F -from torch import Tensor, device, dtype, nn +from torch import nn import bitsandbytes as bnb -from bitsandbytes.optim import GlobalOptimManager -from bitsandbytes.utils import OutlierTracer, find_outlier_dims T = TypeVar("T", bound="torch.nn.Module") @@ -31,12 +28,20 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.research.matmul_fp8_mixed( + x, + self.weight.t(), + fw_code=self.fw_code, + bw_code=self.bw_code, + bsz=self.bsz, + bsz2=self.bsz2, + ) if self.bias is not None: out += self.bias return out + class LinearFP8Global(nn.Linear): def __init__(self, input_features, output_features, bias=True): super().__init__(input_features, output_features, bias) @@ -57,7 +62,14 @@ def forward(self, x: torch.Tensor): self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) - out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + out = bnb.matmul_fp8_global( + x, + self.weight.t(), + fw_code=self.fw_code, + bw_code=self.bw_code, + bsz=self.bsz, + bsz2=self.bsz2, + ) if self.bias is not None: out += self.bias diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py index e092680b8..26eab84f2 100644 --- a/bitsandbytes/triton/dequantize_rowwise.py +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -1,35 +1,36 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None -else: + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _dequantize_rowwise( @@ -50,7 +51,6 @@ def _dequantize_rowwise( max_val = tl.load(state_x + pid) output = max_val * x * inv_127 tl.store(output_ptr + offsets, output, mask=row_mask) - def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) @@ -60,5 +60,5 @@ def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() grid = lambda meta: (x.shape[0],) - _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py index b0961f558..583371d91 100644 --- a/bitsandbytes/triton/int8_matmul_mixed_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py @@ -1,15 +1,16 @@ import torch + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None -else: + def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): + return None +else: import triton import triton.language as tl from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time - # This is a matmul kernel based on triton.ops.matmul # It is modified to support rowwise quantized input and global quantized weight # It's purpose is fused matmul then dequantize @@ -26,57 +27,83 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, + num_stages=num_stages, + num_warps=num_warps, + ), + ) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ), + ) return configs - @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + *get_configs_io_bound(), + ], + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, + ) + @triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }, ) - @triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, - }) @triton.jit - def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): + def _int8_matmul_mixed_dequantize( + A, + B, + C, + bias, + state_x_ptr, + state_w_ptr, + M, + N, + K, + divfactor: tl.constexpr, + has_bias: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -113,13 +140,13 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - - acc = (w_factor * (x_factor * (acc * divfactor))) + + acc = w_factor * (x_factor * (acc * divfactor)) acc = acc.to(C.dtype.element_ty) # conditionally add bias @@ -135,10 +162,9 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, else: tl.atomic_add(C, acc, mask=mask) - def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): device = a.device - divfactor = 1. / (127. * 127.) + divfactor = 1.0 / (127.0 * 127.0) has_bias = 0 if bias is None else 1 # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -152,12 +178,28 @@ def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_mixed_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) + _int8_matmul_mixed_dequantize[grid]( + a, + b, + c, + bias, + state_x, + state_w, + M, + N, + K, + divfactor, + has_bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + ) return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py index 33f4d13f2..e3d192ded 100644 --- a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -3,7 +3,9 @@ from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None + + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): + return None else: import triton import triton.language as tl @@ -17,7 +19,6 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None def init_to_zero(name): return lambda nargs: nargs[name].zero_() - def get_configs_io_bound(): configs = [] for num_stages in [2, 3, 4, 5, 6]: @@ -26,57 +27,83 @@ def get_configs_io_bound(): for block_n in [32, 64, 128, 256]: num_warps = 2 if block_n <= 64 else 4 configs.append( - triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, - num_stages=num_stages, num_warps=num_warps)) + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, + num_stages=num_stages, + num_warps=num_warps, + ), + ) # split_k for split_k in [2, 4, 8, 16]: - configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": split_k}, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ), + ) return configs - @triton.autotune( configs=[ # basic configs for compute-bound matmuls - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), # good for int8 - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), - ] + get_configs_io_bound(), - key=['M', 'N', 'K'], - prune_configs_by={ - 'early_config_prune': early_config_prune, - 'perf_model': estimate_matmul_time, - 'top_k': 10 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + *get_configs_io_bound(), + ], + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, + ) + @triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }, ) - @triton.heuristics({ - 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, - }) @triton.jit - def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - ACC_TYPE: tl.constexpr - ): + def _int8_matmul_rowwise_dequantize( + A, + B, + C, + bias, + state_x_ptr, + state_w_ptr, + M, + N, + K, + divfactor, + has_bias: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + ): # matrix multiplication pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -113,13 +140,13 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) acc += tl.dot(a, b) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - - acc = (w_factor * (x_factor * (acc * divfactor))) + + acc = w_factor * (x_factor * (acc * divfactor)) acc = acc.to(C.dtype.element_ty) if has_bias: @@ -134,9 +161,8 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, else: tl.atomic_add(C, acc, mask=mask) - def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): - divfactor = 1. / (127. * 127.) + divfactor = 1.0 / (127.0 * 127.0) has_bias = 0 if bias is None else 1 @@ -153,12 +179,28 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): # allocates output c = torch.empty((M, N), device=device, dtype=torch.float16) # accumulator types - ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + ACC_TYPE = tl.float32 # if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 # launch int8_matmul_rowwise_dequantize kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - GROUP_M=8, ACC_TYPE=ACC_TYPE) + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) + _int8_matmul_rowwise_dequantize[grid]( + a, + b, + c, + bias, + state_x, + state_w, + M, + N, + K, + divfactor, + has_bias, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + ) return c diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py index 54220d95a..b8eeffd0c 100644 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -1,37 +1,38 @@ import math + import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_columnwise_and_transpose(x: torch.Tensor): return None -else: + def quantize_columnwise_and_transpose(x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # This kernel does fused columnwise quantization and transpose. # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_stages=16), - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=16, num_warps=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _quantize_columnwise_and_transpose( @@ -39,7 +40,8 @@ def _quantize_columnwise_and_transpose( output_ptr, output_maxs, n_elements, - M : tl.constexpr, N : tl.constexpr, + M: tl.constexpr, + N: tl.constexpr, BLOCK_SIZE: tl.constexpr, P2: tl.constexpr, ): @@ -47,14 +49,14 @@ def _quantize_columnwise_and_transpose( block_start = pid p2_arange = tl.arange(0, P2) p2_arange_mask = p2_arange < M - arange = p2_arange * N + arange = p2_arange * N offsets = block_start + arange x = tl.load(x_ptr + offsets, mask=p2_arange_mask) abs_x = tl.abs(x) max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) + output = tl.libdevice.llrint(127.0 * (x / max_val)) - new_start = pid * M + new_start = pid * M new_offsets = new_start + p2_arange tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) tl.store(output_maxs + pid, max_val) @@ -68,7 +70,6 @@ def quantize_columnwise_and_transpose(x: torch.Tensor): assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) return output, output_maxs - diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py index 845db6ecd..f35bdd304 100644 --- a/bitsandbytes/triton/quantize_global.py +++ b/bitsandbytes/triton/quantize_global.py @@ -1,25 +1,25 @@ -import math import torch -import time + from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_global_transpose(input): return None - def quantize_global(x: torch.Tensor): return None -else: + def quantize_global_transpose(input): + return None + + def quantize_global(x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # global quantize @triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), - triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), - - ], - key=['n_elements'] + configs=[ + triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config({"BLOCK_SIZE": 2048}, num_stages=1), + ], + key=["n_elements"], ) @triton.jit def _quantize_global( @@ -35,73 +35,90 @@ def _quantize_global( mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) - output = tl.libdevice.llrint(127. * (x * absmax_inv)) + output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) tl.store(output_ptr + offsets, output, mask=mask) def quantize_global(x: torch.Tensor): absmax = x.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax - output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + absmax_inv = 1.0 / absmax + output = torch.empty(*x.shape, device="cuda", dtype=torch.int8) assert x.is_cuda and output.is_cuda n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) _quantize_global[grid](x, absmax_inv, output, n_elements) return output, absmax - # global quantize and transpose @triton.autotune( - configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), - - # ... - ], - key=['M', 'N'] + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "GROUP_M": 8}, num_warps=4), + # ... + ], + key=["M", "N"], ) @triton.jit - def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, - BLOCK_M : tl.constexpr, - BLOCK_N : tl.constexpr, - GROUP_M : tl.constexpr): + def _quantize_global_transpose( + A, + absmax_inv_ptr, + B, + stride_am, + stride_an, + stride_bn, + stride_bm, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + GROUP_M: tl.constexpr, + ): pid = tl.program_id(0) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N - + width = GROUP_M * grid_n group_id = pid // width group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // group_size - + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) mask = (rm < M)[:, None] & (rn < N)[None, :] a = tl.load(A, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) - + # rematerialize to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) mask = (rm < M)[:, None] & (rn < N)[None, :] - output = tl.libdevice.llrint(127. * (a * absmax_inv)) + output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) tl.store(B, output, mask=mask) def quantize_global_transpose(input): absmax = input.abs().max().unsqueeze(0) - absmax_inv = 1./ absmax + absmax_inv = 1.0 / absmax M, N = input.shape - out = torch.empty(N, M, device='cuda', dtype=torch.int8) - + out = torch.empty(N, M, device="cuda", dtype=torch.int8) + assert out.size(0) == N and out.size(1) == M assert input.stride(0) == 1 or input.stride(1) == 1 assert out.stride(0) == 1 or out.stride(1) == 1 - - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) - _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) - return out, absmax + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + _quantize_global_transpose[grid]( + input, + absmax_inv, + out, + input.stride(0), + input.stride(1), + out.stride(0), + out.stride(1), + M, + N, + ) + return out, absmax diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py index 26d218321..f92ace02c 100644 --- a/bitsandbytes/triton/quantize_rowwise.py +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -1,36 +1,36 @@ import math + import torch -import time from bitsandbytes.triton.triton_utils import is_triton_available if not is_triton_available(): - def quantize_rowwise(x: torch.Tensor): return None -else: + def quantize_rowwise(x: torch.Tensor): + return None +else: import triton import triton.language as tl - from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time # rowwise quantize # TODO: autotune this better. @triton.autotune( - configs=[ - triton.Config({}, num_stages=1, num_warps=8), - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=4, num_warps=8), - triton.Config({}, num_stages=8, num_warps=8), - triton.Config({}, num_stages=1), - triton.Config({}, num_stages=2), - triton.Config({}, num_stages=4), - triton.Config({}, num_stages=8), - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - ], - key=['n_elements'] + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["n_elements"], ) @triton.jit def _quantize_rowwise( @@ -47,10 +47,10 @@ def _quantize_rowwise( offsets = block_start + arange row_mask = arange < BLOCK_SIZE x = tl.load(x_ptr + offsets, mask=row_mask) - + abs_x = tl.abs(x) max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127. * (x / max_val)) + output = tl.libdevice.llrint(127.0 * (x / max_val)) tl.store(output_ptr + offsets, output, mask=row_mask) tl.store(output_maxs + pid, max_val) @@ -65,4 +65,3 @@ def quantize_rowwise(x: torch.Tensor): grid = lambda meta: (x.shape[0],) _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) return output, output_maxs - diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py index c74c23962..6bbdbf1c1 100644 --- a/bitsandbytes/triton/triton_utils.py +++ b/bitsandbytes/triton/triton_utils.py @@ -1,4 +1,5 @@ import importlib + def is_triton_available(): return importlib.util.find_spec("triton") is not None diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py deleted file mode 100644 index 48373a1fe..000000000 --- a/bitsandbytes/utils.py +++ /dev/null @@ -1,194 +0,0 @@ -import json -import shlex -import subprocess -import torch -from typing import Tuple - -def outlier_hook(module, input): - assert isinstance(module, torch.nn.Linear) - tracer = OutlierTracer.get_instance() - hvalue = tracer.get_hvalue(module.weight) - if hvalue not in tracer.hvalue2outlier_idx: - outlier_idx = find_outlier_dims(module.weight) - tracer.outliers.append(outlier_idx) - tracer.hvalues.append(hvalue) - if len(tracer.outliers) > 1: - # assign the current layer the outlier idx found from the weight - # of the previous linear layer - if tracer.outliers[-1].numel() > 0: - assert tracer.outliers[-1].max() < module.weight.shape[1] - tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1] - - else: - # first layer, we cannot use the weight for outlier detection - # we follow a mixed approach: - # (1) zscore test of std of hidden dimension - # (2) magnitude > 6 test - merged = input[0].view(-1, input[0].shape[-1]) - # (1) zscore test of std of hidden dimension - outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) - # (2) magnitude > 6 test - dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) - outlier_idx2 = torch.where(dims > 0)[0] - outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() - tracer.hvalue2outlier_idx[hvalue] = outlier_idx - else: - for hook in tracer.hooks: - hook.remove() - - -class OutlierTracer(object): - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def initialize(self, model): - self.last_w = None - self.current_outlier_dims = None - self.hvalues = [] - self.outliers = [] - self.hvalue2outlier_idx = {} - self.initialized = True - self.hooks = [] - - for n, m in model.named_modules(): - if isinstance(m, torch.nn.Linear): - self.hooks.append(m.register_forward_pre_hook(outlier_hook)) - - def is_initialized(self): - return getattr(self, 'initialized', False) - - def get_hvalue(self, weight): - return weight.data.storage().data_ptr() - - def get_outliers(self, weight): - if not self.is_initialized(): - print('Outlier tracer is not initialized...') - return None - hvalue = self.get_hvalue(weight) - if hvalue in self.hvalue2outlier_idx: - return self.hvalue2outlier_idx[hvalue] - else: - return None - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls.__new__(cls) - return cls._instance - -def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): - if rdm: - return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() - - m = weight.mean(reduction_dim) - mm = m.mean() - mstd = m.std() - zm = (m-mm)/mstd - - std = weight.std(reduction_dim) - stdm = std.mean() - stdstd = std.std() - - zstd = (std-stdm)/stdstd - - if topk is not None: - val, idx = torch.topk(std.abs(), k=topk, dim=0) - else: - idx = torch.where(zstd > zscore)[0] - - return idx - - -def execute_and_return(command_string: str) -> Tuple[str, str]: - def _decode(subprocess_err_out_tuple): - return tuple( - to_decode.decode("UTF-8").strip() - for to_decode in subprocess_err_out_tuple - ) - - def execute_and_return_decoded_std_streams(command_string): - return _decode( - subprocess.Popen( - shlex.split(command_string), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ).communicate() - ) - - std_out, std_err = execute_and_return_decoded_std_streams(command_string) - return std_out, std_err - - - -def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): - """ - Replace linear modules with a new Linear module. - Parameters: - model (`torch.nn.Module`): - Input model or `torch.nn.Module` as the function is run recursively. - linear_replacement (`torch.nn.Module`): - The linear module that replaces the old one. Only expects standard arguments. - If other arguments need to be passed, use a lambda. - skip_modules (`List[str]`, *optional*, defaults to `lm_head`): - List of modules names not to convert. Defaults to `lm_head`. - copy_weights (`bool`): - Copy the weights from the old linear module to the new one - post_processing_fun_name (`str`): - A function name of the replacement linear class that is called - after processing. - """ - for name, module in model.named_children(): - if len(list(module.children())) > 0: - replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) - - if isinstance(module, torch.nn.Linear) and name not in skip_modules: - old_module = model._modules[name] - model._modules[name] = linear_replacement( - module.in_features, - module.out_features, - module.bias is not None, - ) - if copy_weights: - model._modules[name].weight = old_module.weight - model._modules[name].bias = old_module.bias - - if post_processing_function is not None: - func = getattr(module, post_processing_function, None) - if func is not None: func(module) - return model - - -def pack_dict_to_tensor(source_dict): - """ - Pack a dictionary into a torch tensor for storing quant_state items in state_dict. - - Parameters: - - source_dict: The dictionary to be packed. - - Returns: - A torch tensor containing the packed data. - """ - json_str = json.dumps(source_dict) - json_bytes = json_str.encode('utf-8') - tensor_data = torch.tensor(list(json_bytes), dtype=torch.uint8) - - return tensor_data - - -def unpack_tensor_to_dict(tensor_data): - """ - Unpack a torch tensor into a Python dictionary. - - Parameters: - - tensor_data: The torch tensor containing the packed data. - - Returns: - A Python dictionary containing the unpacked data. - """ - json_bytes = bytes(tensor_data.cpu().numpy()) - json_str = json_bytes.decode('utf-8') - unpacked_dict = json.loads(json_str) - - return unpacked_dict diff --git a/check_bnb_install.py b/check_bnb_install.py index e50afb0a1..5a7f74f89 100644 --- a/check_bnb_install.py +++ b/check_bnb_install.py @@ -1,6 +1,7 @@ -import bitsandbytes as bnb import torch +import bitsandbytes as bnb + p = torch.nn.Parameter(torch.rand(10,10).cuda()) a = torch.rand(10,10).cuda() @@ -8,8 +9,6 @@ adam = bnb.optim.Adam([p]) - - out = a*p loss = out.sum() loss.backward() diff --git a/csrc/common.cpp b/csrc/common.cpp index 52f029917..736af1d6f 100644 --- a/csrc/common.cpp +++ b/csrc/common.cpp @@ -1,39 +1,40 @@ + +#include #include #include -void *quantize_block(void *arguments) { +#define __HIP_PLATFORM_AMD__ + + +void quantize_block(const quantize_block_args& args) { // 1. find absmax in block // 2. divide input value by absmax to normalize into [-1.0, 1.0] // 3. do binary search to find the closest value // 4. check minimal distance // 5. store index - struct quantize_block_args *args = (quantize_block_args *) arguments; - // 1. find absmax in block float absmax_block = -FLT_MAX; - for (long long i = args->block_idx; i < args->block_end; i++) - absmax_block = fmax(absmax_block, fabs(args->A[i])); + for (long long i = args.block_idx; i < args.block_end; i++) + absmax_block = fmax(absmax_block, fabs(args.A[i])); - args->absmax[args->block_idx / args->blocksize] = absmax_block; + args.absmax[args.block_idx / args.blocksize] = absmax_block; - for (long long i = args->block_idx; i < args->block_end; i++) { + for (long long i = args.block_idx; i < args.block_end; i++) { // 2. divide input value by absmax to normalize into [-1.0, 1.0] // 3. do binary search to find the closest value - float normed_value = args->A[i] / absmax_block; - long long idx = args->bin_searcher->scalar(normed_value); + float normed_value = args.A[i] / absmax_block; + long long idx = args.bin_searcher->scalar(normed_value); // 4. check minimal distance // The binary search returns always the value to the left, which might not be the closest value if (idx < 255) { - float dist_left = fabs(normed_value - (args->code[idx])); - float dist_right = fabs(normed_value - (args->code[idx + 1])); + float dist_left = fabs(normed_value - (args.code[idx])); + float dist_right = fabs(normed_value - (args.code[idx + 1])); if (dist_right < dist_left) { idx += 1; } } // 5. store index - args->out[i] = (unsigned char) idx; + args.out[i] = (unsigned char) idx; } - - return NULL; } diff --git a/csrc/common.h b/csrc/common.h index c99034e78..9b1cfe1e0 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -1,8 +1,14 @@ + +#include #include #ifndef common #define common +#ifdef __HIP_PLATFORM_AMD__ +// Compiled with HIP-Clang +#endif + using namespace BinSearch; #define BLOCK_SIZE 16384 @@ -20,6 +26,6 @@ struct quantize_block_args { }; -void *quantize_block(void *arguments); +void quantize_block(const quantize_block_args& args); #endif diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index e28e7b2c2..6e06e3962 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -1,6 +1,8 @@ + +#include #include -#include #include +#include using namespace BinSearch; @@ -26,17 +28,13 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long BinAlgo bin_searcher(code, elements_code); int thread_wave_size = 256; - // we chunk the thresds into waves of 256 since the max limit is + // we chunk the threads into waves of 256 since the max limit is // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) { long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; - pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); - - struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); - - for(long long i = 0; i < valid_chunks; i++) - args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + std::vector threads(valid_chunks); + std::vector args(valid_chunks); int chunks_processed = 0; for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) @@ -44,30 +42,24 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; long long block_end = block_idx + valid_items; - struct quantize_block_args *arg = args[chunks_processed]; - arg->bin_searcher = &bin_searcher; - arg->code = code; - arg->A = A; - arg->absmax = absmax; - arg->out = out; - arg->block_end = block_end; - arg->block_idx = block_idx; - arg->threadidx = block_idx / blocksize; - arg->blocksize = blocksize; - - pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); + struct quantize_block_args& arg = args[chunks_processed]; + arg.bin_searcher = &bin_searcher; + arg.code = code; + arg.A = A; + arg.absmax = absmax; + arg.out = out; + arg.block_end = block_end; + arg.block_idx = block_idx; + arg.threadidx = block_idx / blocksize; + arg.blocksize = blocksize; + + threads[chunks_processed] = std::thread([arg] { quantize_block(arg); }); chunks_processed += 1; if(chunks_processed == valid_chunks){ break; } } for (int i = 0; i < valid_chunks; i++) - int err = pthread_join(threads[i], NULL); - - free(threads); - for (int i = 0; i < valid_chunks; i++) - free(args[i]); - free(args); - + threads[i].join(); } } diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 54b6afb9d..2bdc368fc 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -1,19 +1,24 @@ -// !!! This is a file automatically generated by hipify!!! #include "hip/hip_runtime.h" // Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -#include -#include +#include "kernels.hip.h" +#include +#include +#include +#include +#include +#include #include #include - #include #include -//#include +#define __syncwarp __syncthreads //TODO: HIP doesn't have this so just sync threads +//#include +#include #define HLF_MAX 65504 #define TH 1024 @@ -21,8 +26,31 @@ #define NUM_BLOCK 4096 -// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda -// Luckily we have atomicmax and atomicmin in ROCm +//source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +//__device__ float atomicMax(float* address, float val) { +// int* address_as_i = reinterpret_cast(address); +// int old = *address_as_i, assumed; +// do { +// assumed = old; +// old = atomicCAS( +// reinterpret_cast(address), assumed, +// __float_as_int(fmaxf(val, __int_as_float(assumed)))); +// } while (assumed != old); +// return __int_as_float(old); +//} +// +//__device__ float atomicMin(float* address, float val) { +// int* address_as_i = reinterpret_cast(address); +// int old = *address_as_i, assumed; +// do { +// assumed = old; +// old = atomicCAS( +// reinterpret_cast(address), assumed, +// __float_as_int(fminf(val, __int_as_float(assumed)))); +// } while (assumed != old); +// return __int_as_float(old); +//} + __device__ float dDequantizeFP4(unsigned char val, float absmax) { @@ -86,7 +114,7 @@ __device__ float dDequantizeFP4Tree(unsigned char val, float absmax) return 1.00000000f*absmax*sign; // 1011 else return 0.66666667f*absmax*sign; // 1010 - else + else if((val & 0b0001) == 1) // 100 return 5.208333333e-03f*absmax*sign; // 1001 else @@ -110,10 +138,10 @@ __device__ unsigned char dQuantizeFP4(float x) // we do a binary search // the pivots are divided by 12 (the FP4 absmax) - // since we assum input data is in [-1.0, 1.0] + // since we assume input data is in [-1.0, 1.0] // !be careful here, its easy to make a mistake - // that is difficult to noice if you add an extra + // that is difficult to notice if you add an extra // zero somewhere! int sign = x < 0 ? 0b1000 : 0b0000; @@ -150,36 +178,36 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -187,12 +215,12 @@ __device__ half dhDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -205,36 +233,36 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0100) == 4) // 1 if((val & 0b0010) == 2) // 11 if((val & 0b0001) == 1) // 111 - return 1.0f; + return 1.0f; else return 0.7229568362236023f; else if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; + return 0.5626170039176941f; else - return 0.44070982933044434f; + return 0.44070982933044434f; else if((val & 0b0010) == 2) //10 if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; + return 0.33791524171829224f; else - return 0.24611230194568634f; - else + return 0.24611230194568634f; + else if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; + return 0.16093020141124725f; else - return 0.07958029955625534f; + return 0.07958029955625534f; else if((val & 0b0100) == 4) // 0 if((val & 0b0010) == 2) //01 if((val & 0b0001) == 1) // 011 - return 0.0f; + return 0.0f; else - return -0.09105003625154495f; + return -0.09105003625154495f; else if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; + return -0.18477343022823334f; else return -0.28444138169288635f; else @@ -242,12 +270,12 @@ __device__ float dDequantizeNF4(unsigned char val) if((val & 0b0001) == 1) // 001 return -0.39491748809814453f; else - return -0.5250730514526367f; - else + return -0.5250730514526367f; + else if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; + return -0.6961928009986877f; else - return -1.0f; + return -1.0f; } @@ -500,7 +528,7 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index template __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n) { - typedef hipcub::WarpReduce WarpReduce; + typedef hipcub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage; typedef hipcub::BlockLoad LoadT; __shared__ typename LoadT::TempStorage loadt; @@ -565,7 +593,7 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou max2 = -64000.0f; } - //__syncwarp(); + __syncwarp(); } if(threadIdx.x % 32 < 8) @@ -756,7 +784,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float else local_abs_max = smem_absmax_value[0]; - //__syncwarp(); + __syncwarp(); local_abs_max = 1.0f/local_abs_max; @@ -963,7 +991,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, if(threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); - //__syncwarp(); + __syncwarp(); } } @@ -1136,7 +1164,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, if(threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); - //__syncwarp(); + __syncwarp(); } } @@ -1841,7 +1869,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; g_val *= gnorm_scale; - + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; @@ -2237,8 +2265,8 @@ template__global__ void kd // data is in 32 column-tile major with tile width 32 columns and numRows rows // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) // C2. Compute normalization values and store col values in register // S1. Store C1 into 16-bit output @@ -2300,16 +2328,13 @@ template __global__ void kd const int n_out = numRows*numCols; - //int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); + int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); // we have tiles of size numRows*32, thus col only increases every numRows // num_row_tiles is the tiles after which the column increases by 32 // blockIdx.x is the index of the current tile - //int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached - //int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); - - int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; - int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD @@ -2324,33 +2349,20 @@ template __global__ void kd int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; - //float local_rowStats[ITEMS_PER_THREAD]; - //__shared__ float smem_rowStats[SUBTILE_ROWS]; + float local_rowStats[ITEMS_PER_THREAD]; + __shared__ float smem_rowStats[SUBTILE_ROWS]; typedef hipcub::BlockLoad LoadInt32; - //typedef hipcub::BlockExchange ExchangeInt32; + typedef hipcub::BlockExchange ExchangeInt32; __shared__ typename LoadInt32::TempStorage loadint32; - //__shared__ typename ExchangeInt32::TempStorage exchangeint32; + __shared__ typename ExchangeInt32::TempStorage exchangeint32; // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. - //float colStat = col >= numCols ? 0.0f : colStats[col]; - //float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); - int row_idx, col_idx; - float colStat[ITEMS_PER_THREAD]; - float local_biasValue[ITEMS_PER_THREAD]; - float rowStat[ITEMS_PER_THREAD]; - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - row_idx = (block_offset + thread_offset + j) / numCols; - col_idx = (block_offset + thread_offset + j) % numCols; - colStat[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; - local_biasValue[j] = ((bias == NULL) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); - rowStat[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; - } + float colStat = col >= numCols ? 0.0f : colStats[col]; + float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); // no block loads for rows for now -- keep it simple - /*for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) { // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? int row = (base_row+j) % numRows; // wrap around @@ -2358,25 +2370,12 @@ template __global__ void kd // todo: update description about striped shared memory, it is not needed // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements smem_rowStats[j] = rowStats[row]; - }*/ + } __syncthreads(); - int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset; - LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]); - // each block processes SUBTILE_ROWS*32 elements - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - int outIdx = block_offset + thread_offset + j; - if(outIdx< n_out) - out[outIdx] = local_output[j]; - } - /*const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int items_per_load = THREADS*ITEMS_PER_THREAD; const int rows_per_load = items_per_load/32; int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile @@ -2390,14 +2389,14 @@ template __global__ void kd if(valid_items <= 0) // the sub-tile might have more elements than the tile itself break; - // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; - + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); @@ -2417,7 +2416,7 @@ template __global__ void kd } row_offset += rows_per_load; - }*/ + } } @@ -2657,7 +2656,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * { int local_colidx = idx[blockIdx.x]; - /*if(FORMAT==COL_TURING) + if(FORMAT==COL_TURING) { // TURING FORMAT: // 8*32 tiles with 4*4 subtiles @@ -3059,17 +3058,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * int out_idx = (row*idx_size) + blockIdx.x; out[out_idx] = val; } - }*/ - - //Only col format is used on ROCm - for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) - { - //col-major offset - int offset = local_colidx * rowsA + row; - - char val = A[offset]; - int out_idx = (row*idx_size) + blockIdx.x; - out[out_idx] = val; } } @@ -3081,17 +3069,17 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 2. Load k x k into registers //// 3. dequantize and store in second pair of k x k //// 4. matmul -//// 5. sum with hipcub +//// 5. sum with cub //// 6. store outputs //// TC kernel //// use k warps per thread block //// 1. threadblock use read-only cache to read in register tile for A into shared memory //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments -//// 3. each warp reads a segment of values 16x32 from B +//// 3. each warp reads a segment of values 16x32 from B //// 4. do dequantization from register of B into second pair of registers //// 5. store (4) into fragment //// 6. matmul aggregate into fragment C -//// 7. aggreecate files of C into shared memroy block C +//// 7. aggregate files of C into shared memory block C //// 8. sum (7) //// 9. write outputs to matmul output matrix //} @@ -3480,7 +3468,7 @@ template __global__ void kgemm_4bit_inference(int M, i //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); - } + } //printnonzero(local_B, 128, ""); } @@ -3549,11 +3537,11 @@ template __global__ void kgemm_4bit_inference(int M, i template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) { - // per threadblock: + // per threadblock: // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1x32 * 32x4 -> 1x4 outputs per thread block - typedef hipcub::WarpReduce WarpReduce; + typedef hipcub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; const int warp_idx = threadIdx.x / 32; @@ -3782,7 +3770,7 @@ template __global__ void kfunc(T *A, T *B, T value, long { switch(FUNC) { - case FILL: + case FILL: A[i] = (T)value; break; case ARANGE: @@ -3839,12 +3827,12 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); @@ -3911,7 +3899,6 @@ template __global__ void kOptimizer32bit2State(half* g, half* p, flo template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ float *unorm, \ @@ -3991,7 +3978,6 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) //MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) - MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) @@ -3999,7 +3985,6 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) //MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) - MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) @@ -4007,7 +3992,6 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) //MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) - MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) @@ -4016,7 +4000,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) //MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) - MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) @@ -4024,7 +4007,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) //MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) - MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) @@ -4041,7 +4023,6 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) //MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) - MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) @@ -4049,7 +4030,6 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) //MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) - MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) diff --git a/csrc/kernels.hiph b/csrc/kernels.hiph index c842cc754..d49dc0a9d 100644 --- a/csrc/kernels.hiph +++ b/csrc/kernels.hiph @@ -1,12 +1,12 @@ -// !!! This is a file automatically generated by hipify!!! -#include "hip/hip_runtime.h" // Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. + +#include #include -#include +#include #ifndef kernels #define kernels diff --git a/csrc/ops.hip b/csrc/ops.hip index e98e3b817..eb1f72980 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -1,23 +1,20 @@ -// !!! This is a file automatically generated by hipify!!! -#include "hip/hip_runtime.h" // Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#ifndef NO_HIPBLASLT -#include -#endif +#include "ops.hip.h" +#include "kernels.hip.h" +#include +#define HIPBLAS_V2 +#include +#include #include #include #include #include +#define ERR_NOT_IMPLEMENTED 100 + using namespace BinSearch; using std::cout; @@ -28,33 +25,33 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr int threads = 512; int num_blocks = n/threads; num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kHistogramScatterAdd2D), dim3(num_blocks), dim3(512), 0, 0, histogram, index1, index2, src, maxidx1, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } template void estimateQuantiles(T *A, float *code, float offset, int n) { int num_blocks = n/4096; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); - hipLaunchKernelGGL(( kEstimateQuantiles), dim3(num_blocks), dim3(512), 0, 0, A, code, offset, std::numeric_limits::max(), n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); + kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } void quantize(float *code, float *A, unsigned char *out, int n) { int num_blocks = n/1024; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kQuantize<<>>(code, A, out, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } void dequantize(float *code, unsigned char *A, float *out, int n) { int num_blocks = n/1024; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kDequantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kDequantize<<>>(code, A, out, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) @@ -63,22 +60,22 @@ template void quantizeBlockwise(floa num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; if(blocksize == 4096) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 512) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 256) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) - hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); //else if(blocksize == 64) - // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + // kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipPeekAtLastError()); } template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) @@ -88,11 +85,11 @@ template void dequantizeBlockwise(float *code, unsign int tile_size = (DATA_TYPE > 0) ? 1024 : 512; if(DATA_TYPE > 0) - hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, 0, code, A, absmax, out, blocksize/2, n); + kDequantizeBlockwise<<>>(code, A, absmax, out, blocksize/2, n); else - hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, 0, code, A, absmax, out, blocksize, n); + kDequantizeBlockwise<<>>(code, A, absmax, out, blocksize, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipPeekAtLastError()); } @@ -100,7 +97,7 @@ template void dequantizeBlockwise(float *code, unsign //{ // int num_blocks = (colsB+32-1)/32; // kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); -// CUDA_CHECK_RETURN(hipPeekAtLastError()); +// HIP_CHECK_RETURN(hipPeekAtLastError()); //} @@ -116,36 +113,36 @@ template void optimizer32bit(T* g, T* p, case ADAM: if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } - hipLaunchKernelGGL(( kOptimizer32bit2State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); break; case MOMENTUM: case RMSPROP: case ADAGRAD: if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } - hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); break; case LION: // in lion, the momentum update after the parameter update - hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } break; } @@ -164,38 +161,38 @@ template void optimizerStatic8bit(T* p, T* g, int num_blocks = n/4096; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } + if(max_unorm > 0.0f){ HIP_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } switch(OPTIMIZER) { case ADAM: - CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); - CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - hipLaunchKernelGGL(( kOptimizerStatic8bit2State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + HIP_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + HIP_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); + kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipPeekAtLastError()); break; case MOMENTUM: case RMSPROP: case ADAGRAD: - CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + HIP_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipPeekAtLastError()); break; case LION: // in lion, the momentum update happens after the parameter update - hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipPeekAtLastError()); - CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); break; default: break; @@ -218,9 +215,9 @@ template void optimizerStatic8bitBlockwise(T* p, T* g case ADAM: num_blocks = n/BLOCKSIZE_2STATE; num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipPeekAtLastError()); break; case MOMENTUM: case RMSPROP: @@ -228,9 +225,9 @@ template void optimizerStatic8bitBlockwise(T* p, T* g case LION: num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipPeekAtLastError()); break; } } @@ -241,9 +238,9 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, { int num_blocks = n/2048; num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); - hipLaunchKernelGGL(( kPercentileClipping), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + kPercentileClipping<<>>(g, gnorm_vec, step, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) @@ -264,7 +261,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in if (status != HIPBLAS_STATUS_SUCCESS) { - std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + std::cout << "CUBLAS ERROR: Status " << status << std::endl; } } @@ -294,7 +291,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i if (status != HIPBLAS_STATUS_SUCCESS) { - std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + std::cout << "CUBLAS ERROR: Status " << status << std::endl; } } @@ -303,8 +300,9 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; } +#ifdef NO_CUBLASLT +#else -#ifndef NO_HIPBLASLT template hipblasLtOrder_t get_order() { switch(ORDER) @@ -316,34 +314,31 @@ template hipblasLtOrder_t get_order() return HIPBLASLT_ORDER_COL; break; case COL32: - //return HIPBLASLT_ORDER_COL32; return HIPBLASLT_ORDER_COL; break; case COL_TURING: - //return HIPBLASLT_ORDER_COL4_4R2_8C; return HIPBLASLT_ORDER_COL; - break; + break; case COL_AMPERE: - //return HIPBLASLT_ORDER_COL32_2R_4R4; return HIPBLASLT_ORDER_COL; break; default: break; } - return HIPBLASLT_ORDER_ROW; } template hipblasLtOrder_t get_order(); template hipblasLtOrder_t get_order(); + template hipblasLtOrder_t get_order(); -//template hipblasLtOrder_t get_order(); -//template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); #endif - template int get_leading_dim(int dim1, int dim2) { + switch(ORDER) { case ROW: @@ -352,10 +347,7 @@ template int get_leading_dim(int dim1, int dim2) case COL: return dim1; break; - default: - return dim1; - break; - /*case COL32: + case COL32: // 32*row tiles return dim1*32; break; @@ -369,266 +361,114 @@ template int get_leading_dim(int dim1, int dim2) default: return 0; break; - */ } } - template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); -template int get_leading_dim(int dim1, int dim2); -#ifndef NO_HIPBLASLT +template int get_leading_dim(int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { + +#ifdef NO_CUBLASLT +#else hipblasLtOrder_t orderA = get_order(); hipblasLtOrder_t orderOut = get_order(); int ldA = get_leading_dim(dim1, dim2); - int ldOut; - if (TARGET==COL && transpose) { - ldOut = dim2; - } else { - ldOut = get_leading_dim(dim1, dim2); - } + int ldOut = get_leading_dim(dim1, dim2); - hipblasLtMatrixLayout_t A_desc = NULL, out_desc = NULL, B_desc = NULL; - T B = T(0); + hipblasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; hipblasLtMatrixTransformDesc_t A2Out_desc = NULL; hipblasOperation_t opTranspose = HIPBLAS_OP_T; float transformAlpha = 1.0f, transformBeta = 0.0f; + if(DTYPE == 8) { checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_8I, dim1, dim2, ldA)); - checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_8I, 0, 0, 0)); - if (TARGET==COL && transpose) { - checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim2, dim1, ldOut)); - } else { - checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut)); - } + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_8I, dim1, dim2, ldOut)); } else if(DTYPE == 32) { checkHipblasStatus(hipblasLtMatrixLayoutCreate(&A_desc, HIP_R_32I, dim1, dim2, ldA)); - checkHipblasStatus(hipblasLtMatrixLayoutCreate(&B_desc, HIP_R_32I, 0, 0, 0)); - if (TARGET==COL && transpose) { - checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim2, dim1, ldOut)); - } else { - checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut)); - } + checkHipblasStatus(hipblasLtMatrixLayoutCreate(&out_desc, HIP_R_32I, dim1, dim2, ldOut)); } else { printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); } - checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(A_desc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); - checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(out_desc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); + checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); + checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); checkHipblasStatus(hipblasLtMatrixTransformDescCreate(&A2Out_desc, HIP_R_32F)); if(transpose){ checkHipblasStatus(hipblasLtMatrixTransformDescSetAttribute(A2Out_desc, HIPBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } - checkHipblasStatus(hipblasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, A, B_desc, out, out_desc, 0)); + checkHipblasStatus(hipblasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); if (A_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(A_desc)); - if (B_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(B_desc)); if (out_desc) checkHipblasStatus(hipblasLtMatrixLayoutDestroy(out_desc)); if (A2Out_desc) checkHipblasStatus(hipblasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif } - template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); + template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -template void transform(hipblasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(hipblasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -#endif -static std::string hipError_to_string(const hipError_t ret) -{ - switch(ret) - { - case hipSuccess: - return "hipSuccess"; - case hipErrorInvalidContext: - return "hipErrorInvalidContext"; - case hipErrorInvalidKernelFile: - return "hipErrorInvalidKernelFile"; - case hipErrorMemoryAllocation: - return "hipErrorMemoryAllocation"; - case hipErrorInitializationError: - return "hipErrorInitializationError"; - case hipErrorLaunchFailure: - return "hipErrorLaunchFailure"; - case hipErrorLaunchOutOfResources: - return "hipErrorLaunchOutOfResources"; - case hipErrorInvalidDevice: - return "hipErrorInvalidDevice"; - case hipErrorInvalidValue: - return "hipErrorInvalidValue"; - case hipErrorInvalidDevicePointer: - return "hipErrorInvalidDevicePointer"; - case hipErrorInvalidMemcpyDirection: - return "hipErrorInvalidMemcpyDirection"; - case hipErrorUnknown: - return "hipErrorUnknown"; - case hipErrorInvalidResourceHandle: - return "hipErrorInvalidResourceHandle"; - case hipErrorNotReady: - return "hipErrorNotReady"; - case hipErrorNoDevice: - return "hipErrorNoDevice"; - case hipErrorPeerAccessAlreadyEnabled: - return "hipErrorPeerAccessAlreadyEnabled"; - case hipErrorPeerAccessNotEnabled: - return "hipErrorPeerAccessNotEnabled"; - case hipErrorRuntimeMemory: - return "hipErrorRuntimeMemory"; - case hipErrorRuntimeOther: - return "hipErrorRuntimeOther"; - case hipErrorHostMemoryAlreadyRegistered: - return "hipErrorHostMemoryAlreadyRegistered"; - case hipErrorHostMemoryNotRegistered: - return "hipErrorHostMemoryNotRegistered"; - case hipErrorMapBufferObjectFailed: - return "hipErrorMapBufferObjectFailed"; - case hipErrorTbd: - return "hipErrorTbd"; - default: - throw std::runtime_error("unknown hipError"); - } -} -#ifndef NO_HIPBLASLT template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { + #ifdef NO_CUBLASLT - cout << "" << endl; - cout << "=============================================" << endl; - cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; - cout << "=============================================" << endl; - cout << "" << endl; - assert(false); - - return 0; + return ERR_NOT_IMPLEMENTED; #else int has_error = 0; hipblasLtMatmulDesc_t matmulDesc = NULL; hipblasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; hipblasOperation_t opT = HIPBLAS_OP_T; - //hipblasLtPointerMode_t alphaVec = hipblasLt_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + hipblasLtPointerMode_t alphaVec = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; hipblasLtOrder_t col32 = HIPBLASLT_ORDER_COL; - hipblasLtOrder_t col_turing = HIPBLASLT_ORDER_COL; + hipblasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; hipblasLtOrder_t col_ampere = HIPBLASLT_ORDER_COL; has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Adesc, HIP_R_8I, m, k, lda)); has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Bdesc, HIP_R_8I, n, k, ldb)); - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); if(FORMATB == COL_TURING) - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); else - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); - - const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); if(DTYPE_OUT == 32) { has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32I)); - auto opA = HIPBLAS_OP_N; - has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(int32_t))); - has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(int32_t))); - hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; - checkHipblasStatus(hipblasLtMatmulDescSetAttribute( - matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_32I, m, n, ldc)); - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); int alpha = 1, beta = 0; - - - /* Algo and workspace TODO: need to rework to not be duplicated */ - // Set User Preference attributes - hipblasLtMatmulPreference_t pref; - checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref)); - checkHipblasStatus( - hipblasLtMatmulPreferenceSetAttribute(pref, - HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &max_workspace_size, - sizeof(max_workspace_size))); - - const int request_solutions = 1; - hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; - int returnedAlgoCount = 0; - checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle, - matmulDesc, - Adesc, - Bdesc, - Cdesc, - Cdesc, - pref, - request_solutions, - heuristicResult, - &returnedAlgoCount)); - - if (returnedAlgoCount == 0) - { - has_error = 1; - } - else - { - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); - } + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0)); } else { - has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_8I)); - hipblasOperation_t opA = HIPBLAS_OP_N; - has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA))); + has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32F)); has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_8I, m, n, ldc)); - has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - /* Algo and workspace TODO: need to rework to not be duplicated */ - // Set User Preference attributes - hipblasLtMatmulPreference_t pref; - checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref)); - checkHipblasStatus( - hipblasLtMatmulPreferenceSetAttribute(pref, - HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &max_workspace_size, - sizeof(max_workspace_size))); - - const int request_solutions = 1; - hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; - int returnedAlgoCount = 0; - checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle, - matmulDesc, - Adesc, - Bdesc, - Cdesc, - Cdesc, - pref, - request_solutions, - heuristicResult, - &returnedAlgoCount)); - + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); if(!SCALE_ROWS) { float alpha = 1.0f, beta = 0.0f; - - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); } else { - //has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); - float beta = 0.0f; - - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); } } @@ -641,37 +481,32 @@ template int igemmlt(hipblasLtHandl printf("error detected"); return has_error; -#endif +#endif // NO_CUBLASLT } -#endif - int fill_up_to_nearest_multiple(int value, int multiple) { + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } - void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) { + int threads = 512; - //int tileCols = fill_up_to_nearest_multiple(numCols, 32); - //int n = numRows*tileCols; - int tileCols = numCols; - int n = numRows*numCols; - //int subtile_rows = 128; - //int tilesize = 32*subtile_rows; - //int num_blocks = numRows/subtile_rows; - //num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; - //num_blocks = num_blocks*(tileCols/32); - //assert(threads <= tilesize); - int num_blocks = numRows * numCols / (threads * 4); - num_blocks += (numRows * numCols) % (threads * 4) == 0 ? 0 : 1; - - hipLaunchKernelGGL(( kdequant_mm_int32_fp16<4, 128, 512>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} + int tileCols = fill_up_to_nearest_multiple(numCols, 32); + int n = numRows*tileCols; + int subtile_rows = 128; + int tilesize = 32*subtile_rows; + int num_blocks = numRows/subtile_rows; + num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + num_blocks = num_blocks*(tileCols/32); + assert(threads <= tilesize); + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); +} #define STATS_THREADS 64 #define STATS_ITEMS 4 + #define STATS_ROWS 16 void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) { @@ -683,17 +518,17 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r row_tiles = row_tiles > 0 ? row_tiles : 1; col_tiles = col_tiles > 0 ? col_tiles : 1; int num_blocks = row_tiles * col_tiles; - if(nnz_threshold == 0.0) - hipLaunchKernelGGL(( kgetColRowStats), dim3(num_blocks), dim3(STATS_THREADS), 0, 0, A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); - else if(nnz_threshold != 0.0) - hipLaunchKernelGGL(( kgetColRowStats), dim3(num_blocks), dim3(STATS_THREADS), 0, 0, A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + else if(nnz_threshold != 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + HIP_CHECK_RETURN(hipPeekAtLastError()); } void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) { + int threads = 64; int items_per_thread = 4; int tile_cols = threads*items_per_thread; @@ -706,17 +541,17 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col col_tiles = col_tiles > 0 ? col_tiles : 1; int num_blocks = row_tiles * col_tiles; - if(threshold > 0.0f) - hipLaunchKernelGGL(( kDoubleRowColQuant<64, 4, 16, 64*4, 1>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); - else - hipLaunchKernelGGL(( kDoubleRowColQuant<64, 4, 16, 64*4, 0>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + else + kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + HIP_CHECK_RETURN(hipPeekAtLastError()); } template void transformRowToFormat(char * A, char *out, int rows, int cols) { + int threads = 256; int items_per_thread = 8; // we load 128 column values per warp @@ -729,9 +564,9 @@ template void transformRowToFormat(char * A, char *o row_tiles = row_tiles > 0 ? row_tiles : 1; col_tiles = col_tiles > 0 ? col_tiles : 1; int num_blocks = row_tiles * col_tiles; - int outCols = fill_up_to_nearest_multiple(cols, 32); int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) { if(TRANSPOSE) @@ -754,97 +589,98 @@ template void transformRowToFormat(char * A, char *o outRows = cols; } } + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); + HIP_CHECK_RETURN(hipPeekAtLastError()); - hipLaunchKernelGGL(( kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT>), dim3(num_blocks), dim3(threads), 0, 0, A, out, rows, cols, tiledCols, outRows, outCols); - CUDA_CHECK_RETURN(hipPeekAtLastError()); } - void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { - hipsparseSpMatDescr_t descA; - hipsparseDnMatDescr_t descB, descC; - - float alpha = 1.0f; - float beta = 0.0f; - void *dBuffer = NULL; - size_t bufferSize = 0; - - CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, - A_rowidx, A_colidx, A_vals, - HIPSPARSE_INDEX_32I, - HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); - // Create dense matrix C - CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, - HIP_R_16F, HIPSPARSE_ORDER_ROW) ); - // Create dense matrix B - if(transposed_B) - { - int tmp = A_cols; - A_cols = B_cols; - B_cols = tmp; - } - CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, - HIP_R_16F, HIPSPARSE_ORDER_ROW) ); - // allocate an external buffer if needed - CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( - handle, - HIPSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, HIP_R_32F, - HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); - CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); - - // execute SpMM - CHECK_HIPSPARSE( hipsparseSpMM(handle, - HIPSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, HIP_R_32F, - HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer)); - - // destroy matrix/vector descriptors - CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) ); - CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); - CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); - CUDA_CHECK_RETURN( hipFree(dBuffer) ); -} +#ifdef NO_CUBLASLT +#else + + hipsparseSpMatDescr_t descA; + hipsparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + + HIPSPARSE_INDEX_32I, + HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); + // Create dense matrix C + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + + // allocate an external buffer if needed + CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( + handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + HIP_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); + // execute SpMM + CHECK_HIPSPARSE( hipsparseSpMM(handle, + + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + // destroy matrix/vector descriptors + CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) ); + + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); + HIP_CHECK_RETURN( hipFree(dBuffer) ); +#endif +} template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { - hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} + kspmm_coo_very_sparse_naive<<>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + HIP_CHECK_RETURN(hipPeekAtLastError()); +} template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) { + int threads = 256; // we load 128 column values per warp int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); int tiledRows = 0; - int num_blocks = idx_size; - /*if(FORMAT == COL_TURING) + if(FORMAT == COL_TURING) { + tiledRows = fill_up_to_nearest_multiple(rows, 8); } else if(FORMAT == COL_AMPERE) { tiledRows = fill_up_to_nearest_multiple(rows, 32); - }*/ - - //for col format on ROCm - tiledRows = rows; + } + kExtractOutliers<<>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); + HIP_CHECK_RETURN(hipPeekAtLastError()); - hipLaunchKernelGGL(( kExtractOutliers), dim3(num_blocks), dim3(threads), 0, 0, A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); - CUDA_CHECK_RETURN(hipPeekAtLastError()); } - template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { @@ -852,29 +688,30 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //cout << num_blocks << endl; //cout << lda << endl; + //cout << ldb << endl; //cout << ldc << endl; - //cout << m << endl; //cout << n << endl; + //cout << k << endl; //if(bits == 32) //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(160), 0, 0 , m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } - template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { int num_blocks = (m+31)/32; + //cout << num_blocks << endl; //cout << lda << endl; //cout << ldb << endl; @@ -883,7 +720,7 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //cout << m << endl; //cout << n << endl; //cout << k << endl; - hipLaunchKernelGGL(( kgemm_4bit_inference), dim3(num_blocks), dim3(96), 0, 0 , m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference<<< dim3(num_blocks), dim3(96), 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); @@ -894,8 +731,8 @@ template void gemm_4bit_inference_naive(int m, int n, int int num_blocks = (m+3)/4; - hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, 0 , m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kgemm_4bit_inference_naive<<< dim3(num_blocks), dim3(128), 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + HIP_CHECK_RETURN(hipPeekAtLastError()); } template void func(T *A, T *B, T value, long n) @@ -904,8 +741,8 @@ template void func(T *A, T *B, T value, long n) int blocks = n/threads; blocks = n % threads == 0 ? blocks : blocks + 1; blocks = blocks > 65535 ? 65535 : blocks; - hipLaunchKernelGGL(( kfunc), dim3(blocks), dim3(512), 0, 0, A, B, value, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + kfunc<<>>(A, B, value, n); + HIP_CHECK_RETURN(hipPeekAtLastError()); } //============================================================== @@ -930,14 +767,12 @@ template void extractOutliers(char * A, int *idx, char *out, int idx template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -#ifndef NO_HIPBLASLT template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); -#endif template void transformRowToFormat(char * A, char *out, int rows, int cols); template void transformRowToFormat(char * A, char *out, int rows, int cols); diff --git a/csrc/ops.hiph b/csrc/ops.hiph index 8e41f852a..24309c265 100644 --- a/csrc/ops.hiph +++ b/csrc/ops.hiph @@ -1,35 +1,33 @@ -// !!! This is a file automatically generated by hipify!!! // Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +#ifdef __HIP_PLATFORM_AMD__ +// Compiled with HIP-Clang +#endif #ifndef ops_H #define ops_H #include #include -#include #include #include #include -#include -//#ifndef NO_HIPBLASLT -#include -//#endif -#include +#include +#include +#include #include #include -/* #include #include -*/ -#define CUDA_CHECK_RETURN(value) { \ + +#define HIP_CHECK_RETURN(value) { \ hipError_t _m_cudaStat = value; \ if (_m_cudaStat != hipSuccess) { \ fprintf(stderr, "Error %s at line %d in file %s\n", \ @@ -39,29 +37,28 @@ #define THREADS_PER_BLOCKS (512) -#define CHECK_HIPSPARSE(value) { \ - hipsparseStatus_t _m_hipStat = value; \ - if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ - fprintf(stderr, "Error at line %d in file %s\n", \ - __LINE__, __FILE__); \ +#define CHECK_CUSPARSE(value) { \ + hipsparseStatus_t _m_cudaStat = value; \ + if (_m_cudaStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + hipsparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ exit(1); \ } } - #define THREADS_PER_BLOCKS (512) -inline void v(hipError_t status) { +inline void checkCudaStatus(hipError_t status) { if (status != hipSuccess) { - printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status)); - throw std::logic_error("hip API failed"); + printf("cuda API failed with status %d: %s\n", status, hipGetErrorString(status)); + throw std::logic_error("cuda API failed"); } } -inline int checkHipblasStatus(hipblasStatus_t status) { +inline int checkCublasStatus(hipblasStatus_t status) { if (status != HIPBLAS_STATUS_SUCCESS) { - printf("hipBLAS API failed with status %d\n", status); + printf("cuBLAS API failed with status %d\n", status); //throw std::logic_error("cuBLAS API failed"); return 1; } @@ -109,18 +106,17 @@ typedef enum Funcs_t class Context { public: - rocblas_handle m_handle; + hipblasHandle_t m_handle; Context() { - rocblas_handle handle; - rocblas_create_handle(&handle); + hipblasHandle_t handle; + hipblasCreate(&handle); m_handle = handle; } }; -#ifndef NO_HIPBLASLT class ContextLt { public: @@ -132,15 +128,15 @@ class ContextLt hipblasLtCreate(&handle); m_handle = handle; } + }; -#endif -class ContextHipsparse +class ContextCusparse { public: hipsparseHandle_t m_handle; - ContextHipsparse() + ContextCusparse() { hipsparseHandle_t handle; hipsparseCreate(&handle); @@ -185,12 +181,9 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i long long int strideA, long long int strideB, long long int strideC, int batchCount); -#ifndef NO_HIPBLASLT template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); template void transform(hipblasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); -#endif - void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index c74357758..c2c06cdfb 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -3,11 +3,13 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -#if BUILD_CUDA -#include -#endif +//#if BUILD_CUDA +//#include #if BUILD_HIP -#include +#include +#endif +#if BUILD_MPS +// #include #endif #include @@ -18,7 +20,8 @@ // UNMANGLED CALLS //=================================================================================== -#if defined(BUILD_CUDA) || defined(BUILD_HIP) +//#if BUILD_CUDA +#if BUILD_HIP void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } @@ -34,13 +37,8 @@ void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, floa void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } -#if defined(BUILD_CUDA) -void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } -#elif defined(BUILD_HIP) void gemm_4bit_inference_naive_bf16(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } -#endif void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } @@ -65,20 +63,12 @@ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) MAKE_FUNC32(adam, ADAM, float, fp32) MAKE_FUNC32(adam, ADAM, half, fp16) -#if defined(BUILD_CUDA) -MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) -#elif defined(BUILD_HIP) MAKE_FUNC32(adam, ADAM, hip_bfloat16, bf16) -#endif MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(lion, LION, float, fp32) MAKE_FUNC32(lion, LION, half, fp16) -#if defined(BUILD_CUDA) -MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16) -#elif defined(BUILD_HIP) MAKE_FUNC32(lion, LION, hip_bfloat16, bf16) -#endif MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) @@ -118,18 +108,11 @@ MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) -#if defined(BUILD_CUDA) -MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) -#elif defined(BUILD_HIP) MAKE_BLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) -#endif MAKE_BLOCKWISE8(lion, LION, half, fp16) MAKE_BLOCKWISE8(lion, LION, float, fp32) -#if defined(BUILD_CUDA) -MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) -#elif defined(BUILD_HIP) MAKE_BLOCKWISE8(lion, LION, hip_bfloat16, bf16) -#endif + void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } @@ -138,15 +121,9 @@ void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } -#if defined(BUILD_CUDA) -void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); } -#elif defined(BUILD_HIP) void quantizeBlockwise_bf16(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_bf16_fp4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_bf16_nf4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } -#endif void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } @@ -160,34 +137,17 @@ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, floa void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } -#if defined(BUILD_CUDA) -void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); } -#elif defined(BUILD_HIP) void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } -#endif -#ifndef NO_HIPBLASLT -#if BUILD_CUDA -#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ -void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ -{ \ - transform(ltHandle, A, out, dim1, dim2); \ -} \ -#endif -#if BUILD_HIP #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(hipblasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ { \ transform(ltHandle, A, out, dim1, dim2); \ } \ -#endif - MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8); MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8); MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8); @@ -197,15 +157,6 @@ MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8); MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8); MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32); -#if defined(BUILD_HIP) -MAKE_FUNC_TRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8); -MAKE_FUNC_TRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32); -MAKE_FUNC_TRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32); -MAKE_FUNC_TRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8); -MAKE_FUNC_TRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32); -#endif -#endif - void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } @@ -216,29 +167,6 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } -#ifndef NO_HIPBLASLT - -#if defined(BUILD_CUDA) - int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } -#endif - -#if BUILD_HIP int igemmlt_turing_32(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } @@ -256,21 +184,18 @@ void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int row int igemmlt_ampere_8_rowscale(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } -#endif - -#endif void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } - #endif extern "C" { -#if defined(BUILD_CUDA) || defined(BUILD_HIP) +//#if BUILD_CUDA +#if BUILD_HIP void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } @@ -292,16 +217,6 @@ extern "C" void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } -#if defined(BUILD_CUDA) - void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } - - void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } - -#elif defined(BUILD_HIP) void cquantize_blockwise_bf16(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_fp4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_nf4(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } @@ -309,7 +224,7 @@ extern "C" void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } -#endif + #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ @@ -319,22 +234,14 @@ extern "C" MAKE_CFUNC32(adam, float, fp32) MAKE_CFUNC32(adam, half, fp16) - #if defined(BUILD_CUDA) - MAKE_CFUNC32(adam, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) MAKE_CFUNC32(adam, hip_bfloat16, bf16) - #endif MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(lion, float, fp32) MAKE_CFUNC32(lion, half, fp16) - #if defined(BUILD_CUDA) - MAKE_CFUNC32(lion, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) MAKE_CFUNC32(lion, hip_bfloat16, bf16) - #endif MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) @@ -374,18 +281,10 @@ extern "C" MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) - #if defined(BUILD_CUDA) - MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) MAKE_CBLOCKWISE8(adam, ADAM, hip_bfloat16, bf16) - #endif MAKE_CBLOCKWISE8(lion, LION, half, fp16) MAKE_CBLOCKWISE8(lion, LION, float, fp32) - #if defined(BUILD_CUDA) - MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) - #elif defined(BUILD_HIP) MAKE_CBLOCKWISE8(lion, LION, hip_bfloat16, bf16) - #endif void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } @@ -398,46 +297,8 @@ extern "C" { strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); } Context *get_context(){ return new Context(); } -#if BUILD_CUDA ContextCusparse *get_cusparse(){ return new ContextCusparse(); } -#endif - -#if BUILD_HIP - ContextHipsparse *get_hipsparse(){ return new ContextHipsparse(); } -#endif - - -#ifndef NO_HIPBLASLT -#if BUILD_CUDA - int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - //{ (cublasLtHandle_t)context->m_handle; return 0; } - //{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } - - #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ - void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ - { \ - transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ - } \ - -#endif -#if BUILD_HIP int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt_turing_32((hipblasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } //{ (hipblasLtHandle_t)context->m_handle; return 0; } @@ -464,10 +325,6 @@ extern "C" transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((hipblasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ } \ - -#endif - - MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8) MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8) MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8) @@ -477,14 +334,6 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - #if defined(BUILD_HIP) - MAKE_FUNC_CTRANSFORM(8, row, col, t, int8_t, ROW, COL, true, 8) - MAKE_FUNC_CTRANSFORM(32, row, col, n, int32_t, ROW, COL, false, 32) - MAKE_FUNC_CTRANSFORM(32, row, col, t, int32_t, ROW, COL, true, 32) - MAKE_FUNC_CTRANSFORM(8, col, row, n, int8_t, COL, ROW, false, 8) - MAKE_FUNC_CTRANSFORM(32, col, row, n, int32_t, COL, ROW, false, 32) - #endif -#endif void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols) { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) @@ -511,15 +360,8 @@ extern "C" void ctransform_row2ampereT(char * A, char *out, int rows, int cols) { transform_row2ampereT(A, out, rows, cols); } -#if BUILD_CUDA void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) - { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } -#endif - -#if BUILD_HIP - void cspmm_coo(ContextHipsparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { spmm_coo((hipsparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } -#endif void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } @@ -539,12 +381,11 @@ extern "C" void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -#if BUILD_CUDA void *cget_managed_ptr(size_t bytes) { void *ptr; - CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + HIP_CHECK_RETURN(hipMallocManaged(&ptr, bytes, hipMemAttachHost)); + HIP_CHECK_RETURN(hipPeekAtLastError()); return ptr; } @@ -553,35 +394,12 @@ extern "C" { int hasPrefetch = 0; - CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead + HIP_CHECK_RETURN(hipDeviceGetAttribute(&hasPrefetch, hipDeviceAttributeConcurrentManagedAccess, device)); // 40ns overhead if (hasPrefetch == 0) return; - - CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } -#endif - -#if BUILD_HIP - void *cget_managed_ptr(size_t bytes) - { - void *ptr; - CUDA_CHECK_RETURN(hipMallocManaged(&ptr, bytes, hipMemAttachHost)); - CUDA_CHECK_RETURN(hipPeekAtLastError()); - - return ptr; - } - - void cprefetch(void *ptr, size_t bytes, int device) - { - int hasPrefetch = 0; - CUDA_CHECK_RETURN(hipDeviceGetAttribute(&hasPrefetch, hipDeviceAttributeConcurrentManagedAccess, device)); // 40ns overhead - if (hasPrefetch == 0) return; - - CUDA_CHECK_RETURN(hipMemPrefetchAsync(ptr, bytes, device, 0)); - CUDA_CHECK_RETURN(hipPeekAtLastError()); + HIP_CHECK_RETURN(hipMemPrefetchAsync(ptr, bytes, device, 0)); + HIP_CHECK_RETURN(hipPeekAtLastError()); } -#endif #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \ @@ -594,17 +412,14 @@ extern "C" void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } - #if defined(BUILD_CUDA) - void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) - #elif defined(BUILD_HIP) void cgemm_4bit_inference_naive_bf16(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) - #endif { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } #endif + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } } diff --git a/deploy.sh b/deploy.sh index c261ee9a9..e60373627 100644 --- a/deploy.sh +++ b/deploy.sh @@ -5,7 +5,7 @@ echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!" echo $LD_LIBRARY_PATH if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -24,7 +24,7 @@ make cpuonly CUDA_VERSION="CPU" if [ ! -f "./bitsandbytes/libbitsandbytes_cpu.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -34,7 +34,7 @@ make cuda110 CUDA_VERSION=110 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -44,7 +44,7 @@ make cuda11x CUDA_VERSION=111 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -54,7 +54,7 @@ make cuda11x CUDA_VERSION=114 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -64,7 +64,7 @@ make cuda11x CUDA_VERSION=115 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -74,7 +74,7 @@ make cuda11x CUDA_VERSION=117 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -84,7 +84,7 @@ make cuda118 CUDA_VERSION=118 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -94,7 +94,7 @@ make cuda12x CUDA_VERSION=120 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -104,7 +104,7 @@ make cuda12x CUDA_VERSION=121 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -114,7 +114,7 @@ make cuda12x CUDA_VERSION=122 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -124,7 +124,7 @@ make cuda12x CUDA_VERSION=123 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda123.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -138,7 +138,7 @@ make cuda110_nomatmul CUDA_VERSION=110 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -149,7 +149,7 @@ make cuda11x_nomatmul CUDA_VERSION=111 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -159,7 +159,7 @@ make cuda11x_nomatmul CUDA_VERSION=114 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -169,7 +169,7 @@ make cuda11x_nomatmul CUDA_VERSION=115 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -179,7 +179,7 @@ make cuda11x_nomatmul CUDA_VERSION=117 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -189,7 +189,7 @@ make cuda118_nomatmul CUDA_VERSION=118 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -199,7 +199,7 @@ make cuda12x_nomatmul CUDA_VERSION=120 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -209,7 +209,7 @@ make cuda12x_nomatmul CUDA_VERSION=121 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -219,7 +219,7 @@ make cuda12x_nomatmul CUDA_VERSION=122 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda122_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi @@ -229,7 +229,7 @@ make cuda12x_nomatmul CUDA_VERSION=123 if [ ! -f "./bitsandbytes/libbitsandbytes_cuda123_nocublaslt.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 + echo "Compilation unsuccessful!" 1>&2 exit 64 fi diff --git a/environment-bnb.yml b/environment-bnb.yml new file mode 100644 index 000000000..1214f7930 --- /dev/null +++ b/environment-bnb.yml @@ -0,0 +1,21 @@ +# for cmake build +name: bnb +channels: + - pytorch + - nvidia + - conda-forge + +dependencies: + - python + #- accelerate + #- einops + - scipy + #- transformers + - pytest + - pytest-cases + - ipython + - debugpy + - yapf + - monkeytype + - rich + - pytest-sugar diff --git a/environment.yml b/environment.yml index c0e07f153..1232c4215 100644 --- a/environment.yml +++ b/environment.yml @@ -1,15 +1,10 @@ name: bnb channels: - - pytorch - - nvidia - conda-forge dependencies: # Base - - conda-forge::python=3.8 - - pytorch::pytorch=>2.1 - - pytorch::pytorch-cuda=11.8 - - nvidia::cuda=11.8 + - conda-forge::python=3.9 # Libraries - conda-forge::accelerate - conda-forge::einops @@ -27,13 +22,15 @@ dependencies: - conda-forge::monkeytype # infer type annotations - conda-forge::rich # better, colored tracebacks, etc - conda-forge::pytest-sugar # better pytest output + # - conda-forge::nodejs # for `doc-builder preview` (optional) ## ENV CREATION - steps to reproduce: -# mamba env remove -n bnb -# mamba create -y -n bnb python=3.8 # creating an empty env bypasses conda -# # and leads to much faster env resolution in the next step https://github.com/mamba-org/mamba/issues/633#issuecomment-812272143 -# mamba env update -n bnb -f environment.yml -# mamba activate bnb +# conda create -n mypython39-rocm python=3.9 +# install torch with ROCM 6.0 support +# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0 +# Install this repo +# python3 -m pip install ./bitsandbytes.tar.gz +# install the rest of the dependencies ## PIP dependencies (install *after* ENV CREATION): # pip install --no-cache-dir --no-deps lion_pytorch triton hf-doc-builder watchdog @@ -42,4 +39,4 @@ dependencies: ## ENV UPDATE: # # add new packages to environment.yml, then: -# mamba env update -n bnb -f environment.yml \ No newline at end of file +# mamba env updatpython diff --git a/include/Algo-Direct-Common.h b/include/Algo-Direct-Common.h index c97084904..7b40edea9 100644 --- a/include/Algo-Direct-Common.h +++ b/include/Algo-Direct-Common.h @@ -190,7 +190,7 @@ struct DirectInfo xi = xws; } else { - myassert(Gap==1, "if Gap>1 then X workspace must be provided"); + myassert((Gap==1), "if Gap>1 then X workspace must be provided"); xi = x; } diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h index 347ec9c5e..547ca9955 100644 --- a/include/Algo-Direct2.h +++ b/include/Algo-Direct2.h @@ -52,6 +52,7 @@ struct AlgoVecBase::val private: typedef AlgoScalarBase base_t; +#ifdef USE_SSE2 FORCE_INLINE //NO_INLINE void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const @@ -93,8 +94,8 @@ struct AlgoVecBase::val __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6)); #endif IVec i(u.vec); - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; i = i + vlem + vlep; i.store(pr); } @@ -123,8 +124,8 @@ struct AlgoVecBase::val __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3); IVec i(b1, b0); - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = (vz < vxm); + IVec vlep = (vz < vxp); i = i + vlem + vlep; union { @@ -135,6 +136,7 @@ struct AlgoVecBase::val pr[0] = u.ui32[0]; pr[1] = u.ui32[2]; } +#endif // USE_SSE2 #ifdef USE_AVX @@ -157,7 +159,7 @@ struct AlgoVecBase::val FVec vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float)); IVec ip = idxm; -#else // do not use gather instrucions +#else // do not use gather instructions union U { __m256i vec; @@ -227,8 +229,8 @@ struct AlgoVecBase::val #endif - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; ip = ip + vlem + vlep; ip.store(pr); @@ -277,8 +279,8 @@ struct AlgoVecBase::val // FVec vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1); IVec i(u.vec); - IVec vlem = operator< (vz, vxm); - IVec vlep = operator< (vz, vxp); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; i = i + vlem + vlep; i.extractLo32s().store(pr); } diff --git a/include/Portable.h b/include/Portable.h index 1710b0502..987f7726d 100644 --- a/include/Portable.h +++ b/include/Portable.h @@ -4,10 +4,40 @@ #include #include +#if defined(__aarch64__) +#ifdef __HIPCC__ +#undef USE_NEON // Doesn't work with nvcc, undefined symbols +#else +#include +#undef USE_NEON // Not yet implemented +#endif +#undef USE_AVX // x86_64 only +#undef USE_AVX2 // x86_64 only +#undef USE_SSE2 // x86_64 only +#undef USE_SSE41 // x86_64 only +#undef USE_SSE42 // x86_64 only +#undef USE_FMA // x86_64 only +#ifdef USE_NEON +typedef float32x4_t __m128; +typedef int32x4_t __m128i; +typedef float64x2_t __m128d; +#else +typedef struct {float a; float b; float c; float d;} __m128; +typedef struct {int a; int b; int c; int d;} __m128i; +typedef struct {double a; double b;} __m128d; +#endif +#else +#undef USE_NEON // ARM64 only #ifdef __FMA__ #define USE_FMA #endif +#if !defined(__SSE2__) && !defined(_MSC_VER) +#error Compiler must support SSE2 +#endif +#define USE_SSE2 +#if defined(__aarch64__) +#else #ifdef __AVX2__ #define USE_AVX2 #endif @@ -24,7 +54,8 @@ #ifdef __SSE4_2__ #define USE_SSE42 #endif - +#endif +#endif #ifndef _MSC_VER #include @@ -147,5 +178,5 @@ inline T prev(T x) return x; } -} // namepsace Details +} // namespace Details } // namespace BinSearch diff --git a/include/SIMD.h b/include/SIMD.h index a2ac1a9ae..9d1410c73 100644 --- a/include/SIMD.h +++ b/include/SIMD.h @@ -2,6 +2,46 @@ #include "Portable.h" +#ifdef USE_SSE2 +#include +#if defined(USE_AVX) || defined(USE_AVX2) +#include +#else +#ifdef USE_SSE41 +#include +#endif +#endif +#endif + +namespace BinSearch { +namespace Details { + +template +struct FTOITraits{}; + +template +struct FVec; + +template +struct IVec; + +template +struct FVec1; + +template <> struct InstrFloatTraits +{ + typedef __m128 vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m128d vec_t; +}; + +} +} + +#if !defined(__aarch64__) #ifdef USE_SSE42 #ifndef _MSC_VER #include @@ -26,29 +66,11 @@ FORCE_INLINE int popcnt32(int x32) } // namespace #endif -#if defined(USE_AVX) || defined(USE_AVX2) -#include -#else -#include -#ifdef USE_SSE41 -#include -#endif -#endif - #include "Type.h" namespace BinSearch { namespace Details { -template -struct FVec; - -template -struct IVec; - -template -struct FVec1; - template <> struct InstrIntTraits { typedef __m128i vec_t; @@ -64,8 +86,8 @@ template <> struct InstrFloatTraits typedef __m128d vec_t; }; -template -struct FTOITraits +template <> +struct FTOITraits { typedef IVec vec_t; }; @@ -285,9 +307,11 @@ FORCE_INLINE FVec operator- (const FVec& a, const FVec< FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_ps( a, b ); } FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_ps( a, b ); } FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttps_epi32(a); } +#ifndef __clang__ // Conflicts with builtin operator FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); } FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); } FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); } +#endif #ifdef USE_FMA FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm_fmsub_ps(a, b, c); } #endif @@ -339,9 +363,11 @@ FORCE_INLINE FVec operator- (const FVec& a, const FVec FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_pd( a, b ); } FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_pd( a, b ); } FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttpd_epi32(a); } +#ifndef __clang__ // Conflicts with builtin operator FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); } FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); } FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); } +#endif #ifdef USE_FMA FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c ) { return _mm_fmsub_pd(a, b, c); } #endif @@ -558,5 +584,6 @@ FORCE_INLINE FVec mulSub(const FVec& a, const FVec=42", - "wheel" -] +requires = [ "setuptools", "wheel" ] build-backend = "setuptools.build_meta" [tool.ruff] @@ -11,9 +8,7 @@ src = [ "tests", "benchmarking" ] -fix = true select = [ - "A", # prevent using keywords that clobber python builtins "B", # bugbear: security warnings "E", # pycodestyle "F", # pyflakes @@ -24,14 +19,39 @@ select = [ ] target-version = "py38" ignore = [ - "E712", # Allow using if x == False, as it's not always equivalent to if x. + "B007", # Loop control variable not used within the loop body (TODO: enable) + "B028", # Warning without stacklevel (TODO: enable) "E501", # Supress line-too-long warnings: trust yapf's judgement on this one. - "F401", + "E701", # Multiple statements on one line (TODO: enable) + "E712", # Allow using if x == False, as it's not always equivalent to if x. + "E731", # Do not use lambda + "F841", # Local assigned but not used (TODO: enable, these are likely bugs) + "RUF012", # Mutable class attribute annotations ] ignore-init-module-imports = true # allow to expose in __init__.py via imports +[tool.ruff.extend-per-file-ignores] +"**/__init__.py" = ["F401"] # allow unused imports in __init__.py +"{benchmarking,tests}/**/*.py" = [ + "B007", + "B011", + "B023", + "E701", + "E731", + "F841", + "UP030", +] + [tool.ruff.isort] combine-as-imports = true detect-same-package = true force-sort-within-sections = true -known-first-party = ["bitsandbytes"] \ No newline at end of file +known-first-party = ["bitsandbytes"] + +[[tool.mypy.overrides]] +module = "triton.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "scipy.stats" +ignore_missing_imports = true diff --git a/pytest.ini b/pytest.ini index 9902b98fa..ac6d72e63 100644 --- a/pytest.ini +++ b/pytest.ini @@ -7,4 +7,7 @@ addopts = -rP log_cli = True log_cli_level = INFO -log_file = logs/pytest.log \ No newline at end of file +log_file = logs/pytest.log +markers = + benchmark: mark test as benchmark + slow: mark test as slow diff --git a/requirements-ci.txt b/requirements-ci.txt new file mode 100644 index 000000000..46bd5b9cd --- /dev/null +++ b/requirements-ci.txt @@ -0,0 +1,7 @@ +# Requirements used for GitHub actions +pytest==7.2.2 +einops==0.6.0 +wheel==0.40.0 +lion-pytorch==0.0.6 +scipy==1.11.4 +pandas==2.2.0 diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..7ede5b061 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +# Requirements used for local development +setuptools>=63 +pytest~=7.2.2 +einops~=0.6.0 +wheel~=0.40.0 +lion-pytorch~=0.0.6 +scipy~=1.11.4 +pandas~=2.2.0 +matplotlib~=3.8.2 diff --git a/setup.py b/setup.py index c07451d20..13af2a39b 100644 --- a/setup.py +++ b/setup.py @@ -6,9 +6,9 @@ import os from setuptools import find_packages, setup +from setuptools.dist import Distribution - -libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so")) +libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.*")) libs = [os.path.basename(p) for p in libs] print("libs:", libs) @@ -17,9 +17,15 @@ def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() +# Tested with wheel v0.29.0 +class BinaryDistribution(Distribution): + def has_ext_modules(self): + return True + + setup( - name=f"bitsandbytes", - version="0.42.0", + name="bitsandbytes", + version="0.43.0.dev0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="k-bit optimizers and matrix multiplication routines.", @@ -28,12 +34,16 @@ def read(fname): url="https://github.com/TimDettmers/bitsandbytes", packages=find_packages(), package_data={"": libs}, - install_requires=['torch', 'numpy', 'scipy'], - extras_require={'benchmark': ['pandas', 'matplotlib']}, + install_requires=["torch", "numpy"], + extras_require={ + "benchmark": ["pandas", "matplotlib"], + "test": ["scipy"], + }, long_description=read("README.md"), long_description_content_type="text/markdown", classifiers=[ "Development Status :: 4 - Beta", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], + distclass=BinaryDistribution, ) diff --git a/tests/conftest.py b/tests/conftest.py index 0b4b91225..17ffd281c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,19 @@ def pytest_runtest_call(item): try: item.runtest() + except NotImplementedError as nie: + if "NO_CUBLASLT" in str(nie): + pytest.skip("CUBLASLT not available") + raise except AssertionError as ae: if str(ae) == "Torch not compiled with CUDA enabled": pytest.skip("Torch not compiled with CUDA enabled") raise + except RuntimeError as re: + # CUDA-enabled Torch build, but no CUDA-capable device found + if "Found no NVIDIA driver on your system" in str(re): + pytest.skip("No NVIDIA driver found") + raise @pytest.fixture(scope="session") diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 000000000..02cb881a3 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,64 @@ +from io import BytesIO +from itertools import product +import random +from typing import Any, List + +import torch + +test_dims_rng = random.Random(42) + + +TRUE_FALSE = (True, False) +BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3)) # all combinations of (bool, bool, bool) +BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool) + + +def torch_save_to_buffer(obj): + buffer = BytesIO() + torch.save(obj, buffer) + buffer.seek(0) + return buffer + + +def torch_load_from_buffer(buffer): + buffer.seek(0) + obj = torch.load(buffer) + buffer.seek(0) + return obj + + +def get_test_dims(min: int, max: int, *, n: int) -> List[int]: + return [test_dims_rng.randint(min, max) for _ in range(n)] + + +def format_with_label(label: str, value: Any) -> str: + if isinstance(value, bool): + formatted = "T" if value else "F" + elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value): + formatted = "".join("T" if b else "F" for b in value) + else: + formatted = str(value) + return f"{label}={formatted}" + + +def id_formatter(label: str): + """ + Return a function that formats the value given to it with the given label. + """ + return lambda value: format_with_label(label, value) + + +DTYPE_NAMES = { + torch.bfloat16: "bf16", + torch.bool: "bool", + torch.float16: "fp16", + torch.float32: "fp32", + torch.float64: "fp64", + torch.int32: "int32", + torch.int64: "int64", + torch.int8: "int8", +} + + +def describe_dtype(dtype: torch.dtype) -> str: + return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] diff --git a/tests/test_autograd.py b/tests/test_autograd.py index f045fda4c..6d76e57ce 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,61 +1,46 @@ -from itertools import permutations, product +from typing import Tuple import pytest import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT - -n = 1 -k = 25 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() -funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] -str_funcs = ["bmm", "matmul"] -req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad_str = ["FF", "TF", "TT", "FT"] -transpose = [(False, False), (False, True), (True, True), (True, False)] -str_transpose = ["FF", "FT", "TT", "TF"] -dtype = [torch.float32, torch.float16] -values = list( - product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose) +from tests.helpers import ( + BOOLEAN_TRIPLES, + BOOLEAN_TUPLES, + TRUE_FALSE, + describe_dtype, + get_test_dims, + id_formatter, ) -str_values = list( - product( - dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose - ) -) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format( - *vals - ) - for vals in str_values -] + +TRANSPOSE_VALS = [(False, True), (False, False)] +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", - values, - ids=names, + "funcs", + [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], + ids=["func=bmm", "func=matmul"], ) -def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) +def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]): if dim2 > 0: dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) - for i in range(k): - + for i in range(25): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) - target = torch.randn( - size=(dim2, dim4), device="cuda", requires_grad=req_grad[1] - ) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]) torch.nn.init.xavier_uniform_(B) if not transpose[0] and not transpose[1]: @@ -87,9 +72,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -97,18 +80,14 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) # batched matrix multiply if funcs[0] in [torch.bmm, torch.matmul]: @@ -135,9 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_close( - out_bnb, out_torch, atol=0.027, rtol=0.2 - ) + torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2) if any(req_grad): out_bnb.data.copy_(out_torch) @@ -149,9 +126,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -159,9 +134,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -208,9 +181,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -218,9 +189,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -229,83 +198,22 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() < n * 0.02 -n = 1 -k = 3 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() - -dim2.append(0) - -decomp = [0.0, 6.0] -funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)] -str_funcs = ["matmullt", 'switchback_bnb'] -req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.bfloat16, torch.float32] -has_fp16_weights = [True, False] -has_bias = [True, False] -values = list( - product( - dim1, - dim2, - dim3, - dim4, - funcs, - dtype, - req_grad, - transpose, - decomp, - has_fp16_weights, - has_bias - ) -) -str_values = list( - product( - dim1, - dim2, - dim3, - dim4, - str_funcs, - dtype, - req_grad_str, - str_transpose, - decomp, - has_fp16_weights, - has_bias - ) -) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values] - +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("decomp", [0.0, 6.0], ids=id_formatter("decomp")) @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias", - values, - ids=names, + "funcs", + [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)], + ids=["func=matmul", "func=switchback_bnb"], ) -def test_matmullt( - dim1, - dim2, - dim3, - dim4, - funcs, - dtype, - req_grad, - transpose, - decomp, - has_fp16_weights, - has_bias -): +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) +def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") @@ -313,19 +221,14 @@ def test_matmullt( req_grad = list(req_grad) req_grad[2] = False - for i in range(k): - + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn( - size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype - ) + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) if decomp == 6.0: with torch.no_grad(): A[:, outlier_dim] = 6.0 - B = torch.randn( - size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype - ) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) target = torch.randn( size=(dim2, dim4), device="cuda", @@ -335,7 +238,7 @@ def test_matmullt( bias = None bias2 = None if has_bias: - bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -380,9 +283,7 @@ def test_matmullt( if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss( - out_bnb, target - ).mean() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() loss_bnb.backward() gradA1 = A.grad gradB1 = B.grad @@ -392,9 +293,7 @@ def test_matmullt( gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss( - out_torch, target - ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -405,9 +304,7 @@ def test_matmullt( bias.grad = None if req_grad[0]: - torch.testing.assert_close( - gradA1, gradA2, atol=0.015, rtol=0.1 - ) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() if dim2 > 0: @@ -421,66 +318,81 @@ def test_matmullt( assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) -n = 1 -k = 3 -dim1 = torch.randint(16, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 96, size=(n,)).tolist() -dim3 = torch.randint(32, 96, size=(n,)).tolist() -dim4 = torch.randint(32, 96, size=(n,)).tolist() - -dim2.append(0) - -funcs = [(torch.matmul, bnb.matmul_4bit)] -str_funcs = ["matmul"] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.float32] -compress_statistics = [False, True] -has_fp16_weights = [True, False] -has_bias = [True, False] -quant_type = ['fp4', 'nf4'] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type)) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values] -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names) -def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"]) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"], ids=id_formatter("quant_type")) +def test_matmul_4bit( + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + has_bias, + compress_statistics, + quant_type, +): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + print("not transpose[0]") + print(not transpose[0]) + print(f"dimA is: {dimA}") dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + print("not transpose[1]") + print(not transpose[1]) + print(f"dimB is: {dimB}") if has_bias == False: req_grad = list(req_grad) req_grad[2] = False - for i in range(k): + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: + print(f"req_grad[0] is: {req_grad[0]}") A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + print(f"tensorA is: {A}") + print(f"tensorA's size is: {A.size()}") + print(f"req_grad[1] is: {req_grad[1]}") B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + print(f"tensorB is: {B}") + print(f"tensorB's size is: {B.size()}") + print(f"req_grad[2] is: {req_grad[2]}") target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) bias = None bias2 = None + print(f"target is: {target}") + print(f"target's size is: {target.size()}") + print(f"has_bias is: {has_bias}") + print(f"dim4 is: {dim4}") if has_bias: - bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias = torch.randn(dim4, device="cuda", dtype=dtype, requires_grad=req_grad[2]) bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) - B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) + B2, quant_state = bnb.functional.quantize_4bit( + B, + compress_statistics=compress_statistics, + quant_type=quant_type, + ) + + print(f"B2 is: {B2}") + print(f"B2's size is: {B2.size()}") + print(f"quant_state is: {quant_state}") if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) @@ -491,15 +403,18 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, if has_bias: out_torch += bias - + print(f"out_torch is: {out_torch}") + print(f"out_torch size: {out_bnb.size()}") + print(f"out_bnb is: {out_bnb}") + print(f"out_bnb size: {out_bnb.size()}") assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" - + n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).float().mean().item() if n > 0: - assert err < 0.115 + assert err < 0.12 - #assert err < 0.20 + # assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -513,7 +428,7 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, gradBias1 = bias.grad bias.grad = None - loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -524,38 +439,31 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, bias.grad = None if req_grad[0]: - torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) -funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)] -str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global'] -req_grad = list(product([True, False], repeat=3)) -req_grad_str = [] -for c in req_grad: - strval = '' - for v in c: - if v == True: strval += 'T' - else: strval += 'F' - req_grad_str.append(strval) - -transpose = [(False, True), (False, False)] -str_transpose = ["NT", "NN"] -dtype = [torch.float16, torch.float32] -has_fp16_weights = [True, False] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) -str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)) -names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values] -@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) -def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize( + "funcs", + [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], + ids=["matmul_fp8_mixed", "matmul_fp8_global"], +) +def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) req_grad = list(req_grad) req_grad[2] = False - for i in range(k): + for i in range(3): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) @@ -580,7 +488,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): err = torch.abs(out_bnb - out_torch).float().mean().item() if n > 0: assert err < 0.115 - #assert err < 0.20 + # assert err < 0.20 if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() @@ -591,7 +499,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -599,7 +507,7 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() @@ -614,9 +522,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - grad_err = (gradB1-gradB2).abs().mean() + grad_err = (gradB1 - gradB2).abs().mean() assert grad_err.item() < 0.003 - torch.testing.assert_close( - gradB1, gradB2, atol=0.18, rtol=0.3 - ) - + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 5e9ccf590..b33dfadda 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,28 +1,29 @@ -import os import pytest -import torch -from pathlib import Path -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cuda_specs import CUDASpecs -# hardcoded test. Not good, but a sanity check for now -# TODO: improve this -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -def test_manual_override(requires_cuda): - manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2')) - pytorch_version = torch.version.cuda.replace('.', '') - - assert pytorch_version != 122 # TODO: this will never be true... - - os.environ['CUDA_HOME']='{manual_cuda_path}' - os.environ['BNB_CUDA_VERSION']='122' - #assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] - import bitsandbytes as bnb - loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name - #assert loaded_lib == 'libbitsandbytes_cuda122.so' +@pytest.fixture +def hip_spec() -> CUDASpecs: + # All dummy values to pass test. This mainly test if libbitsandbytes_hip_nohipblaslt.so exist in bitsandbytes dir + return CUDASpecs( + cuda_version_string="120", + highest_compute_capability=(8, 6), + cuda_version_tuple=(12, 0), + ) +def test_get_cuda_bnb_library_path(monkeypatch, hip_spec): + monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) + assert get_cuda_bnb_library_path(hip_spec).stem == "libbitsandbytes_hip_nohipblaslt" +#def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): +# monkeypatch.setenv("BNB_CUDA_VERSION", "110") +# assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" +# assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? +#def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): +# monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) +# assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" diff --git a/tests/test_functional.py b/tests/test_functional.py index 5dba4ef5f..e85ee1859 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,31 +1,35 @@ +from itertools import product import math import random import time -from itertools import product import einops +import numpy as np import pytest +from scipy.stats import norm import torch -import numpy as np import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT -from scipy.stats import norm - -torch.set_printoptions( - precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 +from tests.helpers import ( + BOOLEAN_TUPLES, + TRUE_FALSE, + describe_dtype, + get_test_dims, + id_formatter, ) + +torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) k = 20 def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: if throw: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) return sumval @@ -91,9 +95,8 @@ def setup(): def teardown(): pass -@pytest.mark.parametrize( - "dtype", [torch.float32, torch.float16], ids=["float", "half"] -) + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_estimate_quantiles(dtype): A = torch.rand(1024, 1024, device="cuda") A = A.to(dtype) @@ -110,6 +113,7 @@ def test_estimate_quantiles(dtype): diff = torch.abs(code - quantiles) assert (diff > 5e-02).sum().item() == 0 + def test_quantile_quantization(): for i in range(100): A1 = torch.randn(1024, 1024, device="cuda") @@ -128,7 +132,6 @@ def test_quantile_quantization(): assert diff < 0.001 - def test_dynamic_quantization(): diffs = [] reldiffs = [] @@ -141,8 +144,8 @@ def test_dynamic_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - print(sum(diffs)/len(diffs)) - print(sum(reldiffs)/len(reldiffs)) + print(sum(diffs) / len(diffs)) + print(sum(reldiffs) / len(reldiffs)) for i in range(100): A1 = torch.rand(1024, 1024, device="cuda") @@ -153,17 +156,12 @@ def test_dynamic_quantization(): assert diff < 0.004 -def get_blocksizes(hip_env=False): - if not hip_env: - return [4096, 2048, 1024, 512, 256, 128, 64] - else: - return [4096, 2048, 1024, 512, 256, 128] -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) -@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) -@pytest.mark.parametrize("blocksize", get_blocksizes(HIP_ENVIRONMENT)) -@pytest.mark.parametrize("signed", [True, False], ids=['signed_True', 'signed_False']) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) +@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128]) +@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): - #print('') + # print('') diffs = [] reldiffs = [] for i in range(100): @@ -174,10 +172,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) - #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) assert abserr < 0.011 assert relerr < 0.018 assert A2.dtype == dtype @@ -192,9 +190,9 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): reldiff = diff / torch.abs(A1.float() + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) - #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) + # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) if signed: assert abserr < 0.0035 assert relerr < 0.015 @@ -202,14 +200,11 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): assert abserr < 0.00175 assert relerr < 0.012 assert A2.dtype == dtype - #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - #print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) - + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) -@pytest.mark.parametrize( - "gtype", [torch.float32, torch.float16], ids=["float", "half"] -) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_percentile_clipping(gtype): gnorm_vec1 = torch.zeros(100, device="cuda") gnorm_vec2 = torch.zeros(100, device="cuda") @@ -219,9 +214,7 @@ def test_percentile_clipping(gtype): for i in range(k): step += 1 g = torch.randn(n, n, dtype=gtype, device="cuda") - gnorm1, clip2, gnorm_scale = F.percentile_clipping( - g, gnorm_vec2, step, percentile=percentile - ) + gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 gnorm2 = torch.norm(g.float()) @@ -284,40 +277,28 @@ def mean(xx): return sum(xx) / float(len(xx)) -# dim1 = torch.randint(1,1024*4, size=(4,)).tolist() -# dim2 = torch.randint(1,1024*4, size=(4,)).tolist() -dim1 = [1024 * 2] -dim2 = [1024 * 16] -methods = [ - ( +methods = { + "linear": ( lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant, - ) -] -methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) -# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) -method_names = ["linear", "vectorwise"] -batched = [False, True] -values = list(product(dim1, dim2, methods, batched)) -values_names = list(product(dim1, dim2, method_names, batched)) -names = [ - "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals) - for vals in values_names -] + ), + "vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant), +} -@pytest.mark.parametrize( - "dim1, dim2, quant_methods, batched", values, ids=names -) +@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) +@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys()) +@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched")) def test_approx_igemm(dim1, dim2, quant_methods, batched): dim1 = dim1 - (dim1 % 32) dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - #print("") + # print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ -329,9 +310,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_close( - quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 - ) + torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) if batched: out2 = torch.bmm(A, B) C = torch.bmm(Ac.float(), Bc.float()) @@ -346,8 +325,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - #print(mean(errors)) - #print(mean(relerrors)) + # print(mean(errors)) + # print(mean(relerrors)) def test_stable_embedding(): @@ -355,36 +334,17 @@ def test_stable_embedding(): layer.reset_parameters() -n = 2 -hidden_dim = torch.randint(32, 256, size=(n,)).tolist() -batch_dim = torch.randint(16, 256, size=(n,)).tolist() -seq_dim = torch.randint(16, 256, size=(n,)).tolist() -transpose = [(False, False), (False, True), (True, False), (True, True)] -values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) -names = [ - "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize( - "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names -) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim")) +@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 16) seq_dim = seq_dim - (seq_dim % 16) for i in range(k): - shapeA = ( - (batch_dim, hidden_dim) - if not transpose[0] - else (hidden_dim, batch_dim) - ) - shapeB = ( - (32 * random.randint(1, 4), hidden_dim) - if transpose[1] - else (hidden_dim, 32 * random.randint(1, 4)) - ) + shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -404,11 +364,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) - shapeB = ( - (32 * random.randint(1, 4), hidden_dim) - if transpose[1] - else (hidden_dim, 32 * random.randint(1, 4)) - ) + shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: @@ -421,52 +377,27 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): torch.testing.assert_close(out.float(), out2) -n = 3 -seq_dim = torch.randint(32, 512, size=(n,)).tolist() -hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() -batch_dim = torch.randint(2, 16, size=(n,)).tolist() -values = list(product(seq_dim, hidden_dim, batch_dim)) -names = [ - "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) +@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim")) def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): seq_dim = seq_dim - (seq_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 2) for i in range(25): - A = torch.randint( - -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" - ).to(torch.int8) - B = torch.randint( - -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda" - ).to(torch.int8) + A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) - iout = torch.empty( - A.shape[2], B.shape[2], dtype=torch.int32, device=A.device - ) + iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) out = F.igemm(A, B, out=iout) torch.testing.assert_close(out.float(), out2) -n = 2 -seq_dim = torch.randint(32, 512, size=(n,)).tolist() -hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() -batch_dim = torch.randint(2, 16, size=(n,)).tolist() -transpose = [False, True] -values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) -names = [ - "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize( - "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names -) +@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim")) +@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim")) +@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) @@ -482,9 +413,7 @@ def min_max(x): errs2 = [] relerrs2 = [] for i in range(k): - A = torch.normal( - 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda" - ) + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") if transpose: B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") else: @@ -533,23 +462,14 @@ def min_max(x): # print(mean(errs2)) # print(mean(relerrs2)) assert mean(errs) < 0.015 - assert mean(relerrs) < 0.3 + assert mean(relerrs) < 0.31 -n = 2 -dim1 = torch.randint(1, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 128, size=(n,)).tolist() -dim3 = torch.randint(32, 256, size=(n,)).tolist() -dim4 = torch.randint(32, 256, size=(n,)).tolist() -transpose = [(False, False), (True, False), (False, True), (True, True)] -values = list(product(dim1, dim2, dim3, dim4, transpose)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_ibmm(dim1, dim2, dim3, dim4, transpose): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) @@ -570,22 +490,14 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out = F.igemm(A.permute([0, 2, 1]), B) elif transpose[0] and transpose[1]: - out2 = torch.bmm( - A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() - ) + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_close(out.float(), out2.float()) -n = 1 -dim1 = torch.randint(1, 64, size=(n,)).tolist() -dim2 = torch.randint(32, 128, size=(n,)).tolist() -dim3 = torch.randint(32, 256, size=(n,)).tolist() -values = list(product(dim1, dim2, dim3)) -names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) def test_vector_quant(dim1, dim2, dim3): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) @@ -594,31 +506,21 @@ def test_vector_quant(dim1, dim2, dim3): qA, SA = F.vectorwise_quant(A, dim=0) A1 = F.vectorwise_dequant(qA, SA) n = A1.numel() - assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002)) - - + assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) -n = 2 -dim1 = torch.randint(2, 256, size=(n,)).tolist() -dim2 = torch.randint(2, 256, size=(n,)).tolist() -dim3 = torch.randint(2, 256, size=(n,)).tolist() -# dim1, dim2 = (256,), (256,) -dtype = [torch.int8, torch.int32] -a_order = ["row"] -out_order = ["col", "row"] if HIP_ENVIRONMENT else ["col", "row", "col32"] -transpose = [False] -dims = [2, 3] -values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) - -names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values] - - -@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose")) +@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims")) def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - if dims == 3 and out_order != "col32": + if dims == 3 and orderOut != "col32": return - if dtype == torch.int32 and out_order != "col32": + if dtype == torch.int32 and orderOut != "col32": return try: func = F.get_transform_func(dtype, orderA, orderOut, transpose) @@ -628,9 +530,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( - dtype - ) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) out, S = F.nvidia_transform(A, to_order=orderOut) @@ -642,17 +542,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) elif dims == 3: - n = ( - A.shape[0] - * A.shape[1] - * (A.shape[2] + (32 - (A.shape[2] % 32))) - ) + n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) assert out.numel() == n elif orderOut == "col_turing": # 32 col 8 row tiles - n = (A.shape[0] + (8 - A.shape[0] % 8)) * ( - A.shape[1] + (32 - (A.shape[1] % 32)) - ) + n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32))) assert out.numel() == n total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) for row in range(A.shape[0]): @@ -661,9 +555,7 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans j = col coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ( - (row // 8) + (1 if row % 8 != 0 else 0) - ) * total_coltile + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile offset = 32 * 8 * (rowtile + coltile) col2 = col % 32 row2 = (row % 8) * 32 @@ -674,46 +566,23 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) if orderOut == "col32": - out2, S = F.nvidia_transform( - out, from_order=orderOut, to_order="row", state=S - ) + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) torch.testing.assert_close(A, out2) -n = 1 -dim1 = torch.randint(1, 256, size=(n,)).tolist() -dim2 = torch.randint(32, 512, size=(n,)).tolist() -dim3 = torch.randint(32, 1024, size=(n,)).tolist() -dim4 = torch.randint(32, 1024, size=(n,)).tolist() - -# dim1 = [2] -# dim2 = [2] -# dim3 = [2] -# dim4 = [2] - -dims = (2, 3) -ldb = [0] -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals) - for vals in values -] - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 256, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(32, 512, n=1), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 1024, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) +@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( - torch.int8 - ) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) elif dims == 3: - A = torch.randint( - -128, 127, size=(dim1, dim2, dim3), device="cuda" - ).to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to( - torch.int8 - ) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) A2, SA = F.transform(A, "col32") @@ -722,10 +591,8 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): C3, S = F.nvidia_transform(C2, "row", state=SC) torch.testing.assert_close(C1, C3.float()) - ## transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( - torch.int8 - ) + # transpose + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.float()) B2t, SBt = F.transform(B, "col_turing", transpose=True) @@ -734,29 +601,18 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): torch.testing.assert_close(C1, C3.float()) -dim1 = [32] -dim2 = [32] -dim3 = [32] -dim4 = [32] - -dims = (2,) -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dim3, dim4, dims)) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals) - for vals in values -] - -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) +@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): formatB = F.get_special_format_str() for i in range(k): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() elif dims == 3: - A = torch.normal( - 0, 0.5, size=(dim1, dim2, dim3), device="cuda" - ).half() + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) @@ -788,23 +644,15 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # torch.testing.assert_close(C1, C3.float()) -batch_size = 2 -seqdim = 512 -# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] -values = [ - (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024), - (batch_size, seqdim, 5120, 3 * 5120), - (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024), -] - - -# values = list(product(batch, seq, model, hidden)) -names = [ - "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [ + pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"), + pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"), + pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"), + ], +) +@pytest.mark.benchmark def test_bench_8bit_training(batch, seq, model, hidden): formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() @@ -825,7 +673,6 @@ def test_bench_8bit_training(batch, seq, model, hidden): torch.cuda.synchronize() t0 = time.time() for i in range(k): - out1 = torch.matmul(A, w1.t()) # fc1 # out2 = torch.matmul(out1, w2.t())# fc2 @@ -954,33 +801,23 @@ def test_bench_8bit_training(batch, seq, model, hidden): # print(t8) -n = 2 -dim1 = torch.randint(64, 256, size=(n,)).tolist() -dim4 = torch.randint(64, 1024, size=(n,)).tolist() - -#dim1 = [2*1024] -#dim4 = [2*1024] - -#dim1 = [4] -#dim4 = [4] - -dims = (2,) -formatB = ["col_turing", "col_ampere"] -has_bias = [True, False] -values = list(product(dim1, dim4, dims, formatB, has_bias)) -names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values] - -@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) +@pytest.mark.parametrize("formatB", ["col_turing", "col_ampere"], ids=id_formatter("formatB")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): inner = torch.randint(1, 128, size=(1,)).item() bias = None - if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) + if has_bias: + bias = torch.randn(dim4, device="cuda", dtype=torch.float16) formatB = F.get_special_format_str() for i in range(1): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) - if has_bias: C1 += bias + if has_bias: + C1 += bias A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) @@ -991,36 +828,27 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) - if has_bias: C4 += bias + if has_bias: + C4 += bias # TODO: is something wrong here? If so, the problem goes deeper - #n = C1.numel() - #p = 0.06 + # n = C1.numel() + # p = 0.06 std = C1.std(0).view(1, -1) C1 /= std C4 /= std - #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) - #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" + # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) + # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) - #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) + # torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() - assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n)) - + assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) -n = 2 -dim1 = [1 * 1024] -dim2 = [1 * 1024] -# dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dims = (2,) -# ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1, dim2, dims)) -names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_colrow_absmax(dim1, dim2, dims): for i in range(k): threshold = 3.0 @@ -1035,9 +863,7 @@ def test_colrow_absmax(dim1, dim2, dims): else: assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( - A, threshold=threshold - ) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) A_blocked = einops.rearrange( torch.abs(A), @@ -1057,26 +883,15 @@ def test_colrow_absmax(dim1, dim2, dims): torch.testing.assert_close(row_stats1_trunc, row_stats2) torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( - A, threshold=0.0 - ) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) torch.testing.assert_close(col_stats1, col_stats2) torch.testing.assert_close(row_stats1, row_stats2) assert nnz_block_ptr2 is None -n = 2 -# dim1 = [8*1024] -# dim2 = [4*1024] -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) def test_double_quant(dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() @@ -1090,40 +905,33 @@ def test_double_quant(dim1, dim2): torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) n = CAt.numel() - num_not_close_rows = ( - (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() - ) - num_not_close_cols = ( - (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() - ) + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() # allow for 1:500 error due to rounding differences min_error = 1 / 500 if num_not_close_cols > (min_error * n): - print( - f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" - ) + print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}") assert False if num_not_close_rows > (min_error * n): - print( - f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}" - ) + print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}") assert False torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) -n = 4 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + ( + pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") + for (dim1, dim4, inner) in zip( + get_test_dims(1, 4 * 1024, n=4), + get_test_dims(1, 4 * 1024, n=4), + get_test_dims(1, 4 * 1024, n=4), + ) + ), +) def test_integrated_igemmlt(dim1, dim4, inner): for i in range(k): A = torch.randn(dim1, inner, device="cuda").half() @@ -1158,16 +966,17 @@ def test_integrated_igemmlt(dim1, dim4, inner): assert err2 <= err1 * 1.025 -n = 6 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + ( + pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") + for (dim1, dim4, inner) in zip( + get_test_dims(1, 4 * 1024, n=6), + get_test_dims(1, 4 * 1024, n=6), + get_test_dims(1, 4 * 1024, n=6), + ) + ), +) @pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): formatB = F.get_special_format_str() @@ -1190,9 +999,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt( - A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale - ) + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: @@ -1234,17 +1041,17 @@ def test_igemmlt_row_scale(dim1, dim4, inner): print(sum(err3) / len(err3)) -dim1 = [1024, 2048] -inner = [12288 * 4, 4096 * 4] -dim4 = [12288, 4096] - -values = list(zip(dim1, dim4, inner)) -names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + [ + pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"), + pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"), + ], +) @pytest.mark.skip("Row scale has some bugs for ampere") +@pytest.mark.benchmark def test_row_scale_bench(dim1, dim4, inner): + formatB = F.get_special_format_str() err1, err2, err3 = [], [], [] relerr1, relerr2 = [], [] scale = 1 @@ -1273,9 +1080,7 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt( - A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale - ) + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1289,43 +1094,20 @@ def test_row_scale_bench(dim1, dim4, inner): print("vector-wise", time.time() - t0) -n = 2 -dim1 = torch.randint(2, 1024, size=(n,)).tolist() -dim2 = torch.randint(2, 1024, size=(n,)).tolist() -# dim1 = [8*1024] -# dim2 = [4*1024] - -dim3 = [0] -dtype = [torch.int8] -a_order = ["row"] -out_order = ["col32", "col_turing", "col_ampere"] -transpose = [False, True] -dims = [2] -values = list( - product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) -) -names = [ - "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format( - *vals - ) - for vals in values -] - -@pytest.mark.parametrize( - "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", - values, - ids=names, -) +@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) +@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to( - dtype - ) + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint( - 10, 99, size=(dim1, dim2, dim3), device="cuda" - ).to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) A.view(-1)[-1] = -1 if transpose: @@ -1337,29 +1119,14 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): assert S1[0][0] == S2[0][0] assert S1[0][1] == S2[0][1] - # print(out1) - # print(out2) + print("out1") + print(out1) + print("out2") + print(out2) torch.testing.assert_close(out1, out2) -n = 2 -# dim1 = torch.randint(2,1024, size=(n,)).tolist() -# dim2 = torch.randint(2,1024, size=(n,)).tolist() -dim1 = [1] -dim2 = [33] - -dtype = [torch.int8] -# a_order = ['col_turing', 'col_ampere'] -a_order = ["col_turing"] -out_order = ["row"] -values = list(product(dim1, dim2, dtype, a_order, out_order)) -names = [ - "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals) - for vals in values -] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_overflow(): formatB = F.get_special_format_str() print(formatB) @@ -1374,17 +1141,8 @@ def test_overflow(): c2 = torch.matmul(a.float(), b.float().t()) -n = 2 -dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() -# dim1 = [4] -# dim2 = [5] - -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) def test_coo_double_quant(dim1, dim2): threshold = 3.00 for i in range(k): @@ -1392,36 +1150,22 @@ def test_coo_double_quant(dim1, dim2): idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( - A, threshold=threshold - ) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) if coo_tensor is not None: A1 = A * idx A2 = torch.zeros_like(A) - A2[ - coo_tensor.rowidx.long(), coo_tensor.colidx.long() - ] = coo_tensor.values + A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close( - A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 - ) - + torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) -n = 2 -dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist() -dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() -# dim1 = [7] -# dim2 = [11] -transposed_B = [False, True] -values = list(product(dim1, dim2, transposed_B)) -names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B")) def test_spmm_coo(dim1, dim2, transposed_B): threshold = 1.5 dim3 = torch.randint(32, 128, size=(1,)).item() @@ -1437,9 +1181,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx if transposed_B: @@ -1452,7 +1194,7 @@ def test_spmm_coo(dim1, dim2, transposed_B): assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") +@pytest.mark.benchmark def test_spmm_bench(): batch = 2 model = 1024 * 1 @@ -1479,9 +1221,7 @@ def test_spmm_bench(): print(nnz / idx.numel()) rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) for i in range(10): out2 = F.spmm_coo(cooA, B) @@ -1496,14 +1236,8 @@ def test_spmm_bench(): print(tsp / t8) -n = 2 -dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() -dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() -values = list(product(dim1, dim2)) -names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2", values, ids=names) +@pytest.mark.parametrize("dim1", get_test_dims(256, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(256, 1024, n=2), ids=id_formatter("dim2")) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 formatB = "col_turing" @@ -1521,9 +1255,7 @@ def test_integrated_sparse_decomp(dim1, dim2): out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( - A, threshold=threshold - ) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) @@ -1553,23 +1285,10 @@ def test_matmuls(): print(err1, err2) -n = 2 -# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1 * 2048] -dim2 = [12288] -# dim1 = [32] -# dim2 = [32] -# dtype = [torch.float16, torch.int8] -dtype = [torch.float16] -out_function = ["zeros", "ones"] -values = list(product(dim1, dim2, dtype, out_function)) -names = [ - "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func")) def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): out_func = getattr(torch, out_func) @@ -1591,9 +1310,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx out1 = torch.matmul(A2.half(), B.half()) out = out_func(out1.shape, dtype=torch.float16, device=out1.device) @@ -1608,9 +1325,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): std = out1.std() out1 /= std out2 /= std - assert_all_approx_close( - out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count - ) + assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) idx_col = torch.randint(0, A2.shape[-1], size=(15,)) @@ -1638,9 +1353,7 @@ def test_coo2csr(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx csrA = F.coo2csr(cooA) counts = csrA.rowptr[1:] - csrA.rowptr[:-1] @@ -1658,9 +1371,7 @@ def test_coo2csc(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx cscA = F.coo2csc(cooA) counts = cscA.colptr[1:] - cscA.colptr[:-1] @@ -1672,21 +1383,9 @@ def test_coo2csc(): torch.testing.assert_close(A2.t()[idx], cscA.values) -n = 2 -# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1 * 2048] -# dim2 = [12288] -dim2 = [2048] -# dim1 = [2] -# dim2 = [2] -dtype = [torch.int8] -values = list(product(dim1, dim2, dtype)) -names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values] - - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) +@pytest.mark.parametrize("dim1", [1 * 2048]) +@pytest.mark.parametrize("dim2", [2048]) +@pytest.mark.parametrize("dtype", [torch.int8]) def test_spmm_coo_dequant(dim1, dim2, dtype): threshold = 6.0 # threshold = 2.8 @@ -1706,9 +1405,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor( - A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values - ) + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out1 = torch.matmul(A2, B.half()) @@ -1787,22 +1484,11 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): print("partial matmul", time.time() - t0) -batch_size = 1 -seqdim = 1 -values = [] -#values.append((batch_size, seqdim, 768, 4 * 768)) -#values.append((batch_size, seqdim, 1024, 4*1024)) -#values.append((batch_size, seqdim, 1536, 4*1536)) -#values.append((batch_size, seqdim, 2048, 4*2048)) -#values.append((batch_size, seqdim, 2560, 4*2560)) -#values.append((batch_size, seqdim, 4096, 4*4096)) -#values.append((batch_size, seqdim, 5120, 4*5120)) -values.append((batch_size, seqdim, 6656, 4*6656)) -#values.append((batch_size, seqdim, 8192, 4*8192)) -#values.append((batch_size, seqdim, 5140, 4*5140)) -#values.append((batch_size, seqdim, 12288, 4*12288)) -names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] -@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +@pytest.mark.parametrize( + ("batch", "seq", "model", "hidden"), + [pytest.param(1, 1, 6656, 4 * 6656, id="batch=1, seq=1, model=6656, hidden=26k")], +) +@pytest.mark.benchmark def test_bench_matmul(batch, seq, model, hidden): iters = 1000 formatB = F.get_special_format_str() @@ -1823,8 +1509,8 @@ def test_bench_matmul(batch, seq, model, hidden): outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) - #linearMixedBit.eval() + linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half() + # linearMixedBit.eval() linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() @@ -1841,121 +1527,123 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): torch.matmul(A, B.t()) torch.cuda.synchronize() - print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + print( + f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s", + ) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) - #torch.cuda.synchronize() - #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + # torch.cuda.synchronize() + # print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) - #torch.cuda.synchronize() - #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + # torch.cuda.synchronize() + # print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) torch.cuda.synchronize() - print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") torch.cuda.synchronize() t0 = time.time() for i in range(iters): bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c) torch.cuda.synchronize() - print( f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - + print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul(A, B) - #torch.cuda.synchronize() - #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # bnb.matmul(A, B, threshold=6.0) - #torch.cuda.synchronize() - #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - #C32A, SA = F.transform(CA, "col32") - #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - #CxB, SB = F.transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + # C32A, SA = F.transform(CA, "col32") + # CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + # CxB, SB = F.transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - #torch.cuda.synchronize() - #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #BA, statsB = F.vectorwise_quant(B, dim=1) - #CxB, SB = F.nvidia_transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1) + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1) # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) - #torch.cuda.synchronize() - #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - - #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - #CxB, SB = F.nvidia_transform(CB, to_order=formatB) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + # BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + # CxB, SB = F.nvidia_transform(CB, to_order=formatB) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # A2 = A.view(-1, A.shape[-1]).contiguous() # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") # C32A, SA = F.nvidia_transform(CA, "col32") # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) # out = Cout * statsB * statsA * (1.0 / (127 * 127)) - #torch.cuda.synchronize() - #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linearMixedBit(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linearMixedBit(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linearMixedBit(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit_train(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit_train(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit_train(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - #linear8bit_train_thresh(A) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # linear8bit_train_thresh(A) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # linear8bit_train(A) - #torch.cuda.synchronize() - #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + # torch.cuda.synchronize() + # print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + def test_zeropoint(): def quant_zp(x): @@ -1996,8 +1684,8 @@ def quant_zp(x): C2 -= A.sum(1).view(-1, 1) * zp ca, cqa, cza = quant_zp(A) - #print(ca.min(), ca.max()) - #print((ca - cza).min(), (ca - cza).max()) + # print(ca.min(), ca.max()) + # print((ca - cza).min(), (ca - cza).max()) zp = 1 scale = 2.0 @@ -2026,14 +1714,14 @@ def quant_zp(x): C7 -= zpa * zpb * A.shape[1] C7 /= qa * qb - #print("") + # print("") # print(C0.flatten()[:10]) - #print(C1.flatten()[:10]) - #print(C2.flatten()[:10]) - #print(C3.flatten()[:10]) - #print(C5.flatten()[:10]) - #print(C6.flatten()[:10]) - #print(C7.flatten()[:10]) + # print(C1.flatten()[:10]) + # print(C2.flatten()[:10]) + # print(C3.flatten()[:10]) + # print(C5.flatten()[:10]) + # print(C6.flatten()[:10]) + # print(C7.flatten()[:10]) err1 = torch.abs(C1 - C2).mean().item() err2 = torch.abs(C1 - C3).mean().item() err3 = torch.abs(C1 - C4).mean().item() @@ -2070,16 +1758,15 @@ def test_extract_outliers(): torch.testing.assert_close(outliers1, outliers2) - def test_blockwise_cpu_large(): diffs = [] reldiffs = [] batch = 128 seq = 128 - for hidden in [128]:#, 14336]: + for hidden in [128]: # , 14336]: for blocksize in [4096, 16384]: for i in range(2): - A1 = torch.randn(batch, seq, hidden, device='cpu') + A1 = torch.randn(batch, seq, hidden, device="cpu") t0 = time.time() C, S = F.quantize_blockwise(A1, blocksize=blocksize) A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) @@ -2093,10 +1780,9 @@ def test_blockwise_cpu_large(): # print(sum(reldiffs)/len(reldiffs)) - def test_fp8_quant(): for e_bits in range(1, 7): - p_bits = 7-e_bits + p_bits = 7 - e_bits code = F.create_fp8_map(True, e_bits, p_bits).cuda() abserr = [] @@ -2106,12 +1792,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2120,12 +1806,12 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1, code=code) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(sum(abserr)/len(abserr)) - #print(sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2134,50 +1820,48 @@ def test_fp8_quant(): C, SC = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, SC) diff = torch.abs(A1 - A2) - reldiff = diff/torch.abs(A1+1e-8) + reldiff = diff / torch.abs(A1 + 1e-8) abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) - #assert diff < 0.0075 - #print(3, sum(abserr)/len(abserr)) - #print(3, sum(relerr)/len(relerr)) + # assert diff < 0.0075 + # print(3, sum(abserr)/len(abserr)) + # print(3, sum(relerr)/len(relerr)) def test_few_bit_quant(): - - #print('') + # print('') for bits in range(2, 9): - #print('='*30, bits, '='*30) - for method in ['linear', 'fp8', 'dynamic', 'quantile']: + # print('='*30, bits, '='*30) + for method in ["linear", "fp8", "dynamic", "quantile"]: abserrs = [] relerrs = [] code = None - if method == 'linear': + if method == "linear": code = F.create_linear_map(True, total_bits=bits).cuda() - elif method == 'fp8': - ebits = math.ceil(bits/2) - pbits = bits-ebits-1 + elif method == "fp8": + ebits = math.ceil(bits / 2) + pbits = bits - ebits - 1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - elif method == 'dynamic': - code = F.create_dynamic_map(True, bits-0, bits).cuda() - elif method == 'quantile': - values = torch.randn(2048, 2048, device='cuda') + elif method == "dynamic": + code = F.create_dynamic_map(True, bits - 0, bits).cuda() + elif method == "quantile": + values = torch.randn(2048, 2048, device="cuda") code = F.create_quantile_map(values, bits).cuda() # for some data types we have no zero # for some data types we have one zero # for some data types we have two zeros - assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}' - #print(method, (code==0).sum()) + assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}" + # print(method, (code==0).sum()) assert code.numel() == 256 for i in range(10): - - values = torch.randn(1, 32, device='cuda') + values = torch.randn(1, 32, device="cuda") values /= values.abs().max() - #values[values.abs() < 1e-6] += 1e-5 + # values[values.abs() < 1e-6] += 1e-5 q1 = [] v1 = [] for v in values[0]: - idx = torch.abs(v-code).argmin() + idx = torch.abs(v - code).argmin() q1.append(idx.item()) v1.append(code[idx].item()) @@ -2188,65 +1872,64 @@ def test_few_bit_quant(): v2 = F.dequantize_blockwise(q2, S2) idx = torch.isclose(q1.int(), q2.int()) - err2 = torch.abs(v2-values) + err2 = torch.abs(v2 - values) abserrs.append(err2.mean().item()) - relerrs.append((err2/(1e-10+values).abs()).mean().item()) + relerrs.append((err2 / (1e-10 + values).abs()).mean().item()) if idx.sum(): # some weird cases - err1 = torch.abs(v1-values).mean() - #assert err2.mean() <= err1 + err1 = torch.abs(v1 - values).mean() + # assert err2.mean() <= err1 else: torch.testing.assert_close(q1, q2) - #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) - #assert False + # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + # assert False def test_kbit_quantile_estimation(): for i in range(100): - data = torch.randn(1024, 1024, device='cuda') + data = torch.randn(1024, 1024, device="cuda") for bits in range(2, 9): - p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) + p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits) val1 = torch.Tensor(norm.ppf(p)).cuda() val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) - err = torch.abs(val1-val2).mean() + err = torch.abs(val1 - val2).mean() assert err < 0.038 for i in range(100): - data = torch.randn(1024, 1024, device='cuda') + data = torch.randn(1024, 1024, device="cuda") for bits in range(2, 4): - total_values = 2**bits-1 - p = np.linspace(0, 1, 2*total_values+1) - idx = np.arange(1, 2*total_values+1, 2) + total_values = 2**bits - 1 + p = np.linspace(0, 1, 2 * total_values + 1) + idx = np.arange(1, 2 * total_values + 1, 2) p = p[idx] - offset = 1/(2*total_values) - p = np.linspace(offset, 1-offset, total_values) + offset = 1 / (2 * total_values) + p = np.linspace(offset, 1 - offset, total_values) val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1) - err = torch.abs(val1-val2).mean() + val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1) + err = torch.abs(val1 - val2).mean() assert err < 0.035 +@pytest.mark.benchmark def test_bench_dequantization(): - a = torch.rand(1024, 1024, device='cuda').half() - code =F.create_fp8_map(True, 3, 0, 4).cuda() + a = torch.rand(1024, 1024, device="cuda").half() + code = F.create_fp8_map(True, 3, 0, 4).cuda() qa, SA = F.quantize_blockwise(a, code=code) print(qa.max()) - max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 - #print(max_theoretical_mu) + max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000 + # print(max_theoretical_mu) torch.cuda.synchronize() t0 = time.time() for i in range(100): qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() - #print((time.time()-t0)/1e6) - + # print((time.time()-t0)/1e6) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) def test_fp4_quant(dtype): vals = list(product([0, 1], repeat=4)) @@ -2255,131 +1938,132 @@ def test_fp4_quant(dtype): result = 0 bias = 3 sign, e1, e2, p1 = bits - idx = sign*8 + e1*4 + e2*2 + p1*1 + idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1 sign = -1.0 if sign else 1.0 - exp = e1*2 + e2*1 + exp = e1 * 2 + e2 * 1 if exp == 0: # sub-normal - if p1 == 0: result = 0 - else: result = sign*0.0625 + if p1 == 0: + result = 0 + else: + result = sign * 0.0625 else: # normal - exp = 2**(-exp + bias + 1) + exp = 2 ** (-exp + bias + 1) frac = 1.5 if p1 else 1.0 - result = sign*exp*frac + result = sign * exp * frac code[idx] = result - A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) - qa, SA = F.quantize_fp4(A1, blocksize=64) + A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + qa, SA = F.quantize_fp4(A1, blocksize=128) A2 = F.dequantize_fp4(qa, SA) err = (A1 - A2).abs().float() - relerr = (err/(A1.abs().float()+1e-8)).mean() + relerr = (err / (A1.abs().float() + 1e-8)).mean() idx = err > 1.0 err = err.mean() assert A2.dtype == dtype - assert err.item() < 0.1 + assert err.item() < 0.11 assert relerr.item() < 0.28 -@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) def test_4bit_compressed_stats(quant_type): - blocksizes = [128, 64] if not HIP_ENVIRONMENT else [128] - for blocksize in blocksizes: + for blocksize in [256, 128]: errs1 = [] errs2 = [] for i in range(10): - A1 = torch.randn(1024, 1024, device='cuda').half() + A1 = torch.randn(1024, 1024, device="cuda").half() q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) - q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) - err = (A1 - A2).abs().float() - relerr = (err/(A1.abs().float()+1e-15)).mean() + relerr = (err / (A1.abs().float() + 1e-15)).mean() err = err.mean() errs1.append(err.item()) - assert err.item() < 0.11 - assert relerr.item() < 0.28 + print(err.item() ) + assert relerr.item() < 0.3 err = (A1 - A3).abs().float() - relerr = (err/(A1.abs().float()+1e-15)).mean() + relerr = (err / (A1.abs().float() + 1e-15)).mean() err = err.mean() errs2.append(err.item()) assert err.item() < 0.11 - assert relerr.item() < 0.28 - - #print(sum(errs1)/len(errs1), blocksize, quant_type) - #print(sum(errs2)/len(errs2), blocksize, quant_type) + assert relerr.item() < 0.3 + # print(sum(errs1)/len(errs1), blocksize, quant_type) + # print(sum(errs2)/len(errs2), blocksize, quant_type) - -#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) -@pytest.mark.parametrize("quant_type", ['nf4']) +# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +@pytest.mark.parametrize("quant_type", ["nf4"]) +@pytest.mark.benchmark def test_bench_4bit_dequant(quant_type): blocksize = 256 - a = torch.rand(1024*12*4, 1024*12, device='cuda').half() + a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half() qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) - input_size = a.numel()/2 - output_size = a.numel()*2 - num_bytes = input_size+output_size - GB = num_bytes/1e9 - max_theoretical_s = GB/768 - #print(max_theoretical_s*1e6) - b = torch.randn(128, 1024*12, device='cuda').half() + input_size = a.numel() / 2 + output_size = a.numel() * 2 + num_bytes = input_size + output_size + GB = num_bytes / 1e9 + max_theoretical_s = GB / 768 + # print(max_theoretical_s*1e6) + b = torch.randn(128, 1024 * 12, device="cuda").half() iters = 100 torch.cuda.synchronize() t0 = time.time() for i in range(iters): F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) - #b.copy_(a) + # b.copy_(a) torch.cuda.synchronize() - #print((time.time()-t0)/iters*1e6) + # print((time.time()-t0)/iters*1e6) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(iters): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): # torch.matmul(b, a.t()) - #torch.cuda.synchronize() - #print((time.time()-t0)/iters*1e6) - + # torch.cuda.synchronize() + # print((time.time()-t0)/iters*1e6) def test_normal_map_tree(): code = F.create_normal_map() - values =code[:8].tolist() + code[-8:].tolist() + values = code[:8].tolist() + code[-8:].tolist() num_pivots = 1 - #print(values) - while num_pivots <16: - idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) - #print(idx) + # print(values) + while num_pivots < 16: + idx = list(range(16 // num_pivots // 2, 16, 16 // num_pivots)) + # print(idx) num_pivots *= 2 pivots = [] for i in idx: - pivots.append((values[i-1]+values[i])/2) - #print(pivots) + pivots.append((values[i - 1] + values[i]) / 2) + # print(pivots) -@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False']) -@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) -@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed']) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32']) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64") +@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") +@pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize( + "quant_storage", + [torch.uint8, torch.float16, torch.bfloat16, torch.float32], + ids=describe_dtype, +) def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): for dim in [128, 256, 512, 1024]: - #for dim in [4*1024]: - #for dim in [1*16]: + # for dim in [4*1024]: + # for dim in [1*16]: errs1 = [] errs2 = [] errs3 = [] @@ -2390,38 +2074,42 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2 = [] max_errs3 = [] - for i in range(100): - if kind == 'fc1': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'fc2': - A = torch.randn(1, 4*dim, dtype=dtype, device='cuda') - B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'attn': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - elif kind == 'attn_packed': - A = torch.randn(1, dim, dtype=dtype, device='cuda') - B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim) - - qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage) + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=storage_type, + compress_statistics=double_quant, + quant_storage=quant_storage, + ) C3 = torch.matmul(A, B.t()) C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True C1 = bnb.matmul_4bit(A, qB.t(), state) - err1 = (C1-C2).abs().float() - err2 = (C3-C2).abs().float() - err3 = (C3-C1).abs().float() + err1 = (C1 - C2).abs().float() + err2 = (C3 - C2).abs().float() + err3 = (C3 - C1).abs().float() - mag1 = torch.abs(C1).float()+1e-5 - mag2 = torch.abs(C3).float()+1e-5 - mag3 = torch.abs(C3).float()+1e-5 + mag1 = torch.abs(C1).float() + 1e-5 + mag2 = torch.abs(C3).float() + 1e-5 + mag3 = torch.abs(C3).float() + 1e-5 - relerr1 = err1/mag1 - relerr2 = err2/mag2 - relerr3 = err3/mag3 + relerr1 = err1 / mag1 + relerr2 = err2 / mag2 + relerr3 = err3 / mag3 max_err1 = err1.max() max_err2 = err2.max() @@ -2439,38 +2127,38 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): max_errs2.append(max_err2.item()) max_errs3.append(max_err3.item()) - c = int(C1.numel()*0.0014*(dim/256))+1 + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) - err1 = sum(errs1)/len(errs1)/math.sqrt(dim) - err2 = sum(errs2)/len(errs2)/math.sqrt(dim) - err3 = sum(errs3)/len(errs3)/math.sqrt(dim) - relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim) - relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim) - relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim) - maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim) - maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim) - maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim) - absratio = err2/err3 - relratio = relerr2/relerr3 - maxratio = relerr2/relerr3 + err1 = sum(errs1) / len(errs1) / math.sqrt(dim) + err2 = sum(errs2) / len(errs2) / math.sqrt(dim) + err3 = sum(errs3) / len(errs3) / math.sqrt(dim) + relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) + relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) + relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) + maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) + maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) + maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) + absratio = err2 / err3 + relratio = relerr2 / relerr3 + maxratio = relerr2 / relerr3 # for debugging if the tests fails # - #print('='*80) - #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') - #print(C1.flatten()[-20:]) - #print(C2.flatten()[-20:]) - #print(f'inference vs training abs: {err1}') - #print(f'inference vs training rel: {relerr1}') - #print(f'inference vs training max: {maxerr1}') - #print(f'inference vs training vs torch err ratio abs: {absratio}') - #print(f'inference vs training vs torch err ratio rel: {relratio}') - #print(f'inference vs training vs torch err ratio max: {maxratio}') + # print('='*80) + # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') + # print(C1.flatten()[-20:]) + # print(C2.flatten()[-20:]) + # print(f'inference vs training abs: {err1}') + # print(f'inference vs training rel: {relerr1}') + # print(f'inference vs training max: {maxerr1}') + # print(f'inference vs training vs torch err ratio abs: {absratio}') + # print(f'inference vs training vs torch err ratio rel: {relratio}') + # print(f'inference vs training vs torch err ratio max: {maxratio}') if dtype == torch.float16: if dim <= 512: assert err1 < 7e-5 - assert relerr1 < 0.0008 + assert relerr1 < 0.0015 else: assert err1 < 6e-5 assert relerr1 < 2e-4 @@ -2479,13 +2167,13 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert maxratio < 1.005 and maxratio > 0.995 elif dtype == torch.float32: if dim <= 512: - assert err1 < 5e-8 + assert err1 < 5.5e-8 assert relerr1 < 1e-6 - assert maxerr1 < 1e-7 + assert maxerr1 < 4.5e-7 else: - assert err1 < 5e-8 + assert err1 < 5.5e-8 assert relerr1 < 8e-6 - assert maxerr1 < 1e-7 + assert maxerr1 < 4.5e-7 assert absratio < 1.005 and absratio > 0.995 assert relratio < 1.005 and relratio > 0.995 assert maxratio < 1.005 and maxratio > 0.995 @@ -2499,59 +2187,62 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert relerr1 < 0.002 assert maxerr1 < 0.0012 assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.04 and relratio > 0.96 - assert maxratio < 1.02 and maxratio > 0.98 + assert relratio < 1.04 and relratio > 0.95 + assert maxratio < 1.04 and maxratio > 0.97 + @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): - n = 32*10 + n = 32 * 10 A = F.get_paged(n, n, dtype=torch.float32) B = F.get_paged(n, n, dtype=torch.uint8) B2 = F.get_paged(n, n, dtype=torch.float32) assert A.is_paged assert B.is_paged - assert A.page_deviceid==0 - assert B.page_deviceid==0 + assert A.page_deviceid == 0 + assert B.page_deviceid == 0 F.fill(A, 17.0) F.fill(B, 17) F.fill(B2, 2) - assert (A==17).sum().item() == n*n - assert (B==17).sum().item() == n*n - C = A*B.float() - assert (C==289).sum().item() == n*n + assert (A == 17).sum().item() == n * n + assert (B == 17).sum().item() == n * n + C = A * B.float() + assert (C == 289).sum().item() == n * n F._mul(A, B2) F._mul(A, B2) F._mul(A, B2) - assert (A==17*(2**3)).sum().item() == n*n - # F.prefetch_tensor(A) - # F.prefetch_tensor(B) + assert (A == 17 * (2**3)).sum().item() == n * n + +# F.prefetch_tensor(A) +# F.prefetch_tensor(B) - # F.fill(B2, 17.0) - # F._mul(A, B2) - # F.prefetch_tensor(A, to_cpu=True) - # F.prefetch_tensor(B, to_cpu=True) - # F.prefetch_tensor(B2, to_cpu=True) - # torch.cuda.synchronize() +# F.fill(B2, 17.0) +# F._mul(A, B2) - # assert (A==17).sum().item() == n*n +# F.prefetch_tensor(A, to_cpu=True) +# F.prefetch_tensor(B, to_cpu=True) +# F.prefetch_tensor(B2, to_cpu=True) +# torch.cuda.synchronize() - # torch.testing.assert_close(A, torch.ones(A.shape)*289) +# assert (A==17).sum().item() == n*n +# torch.testing.assert_close(A, torch.ones(A.shape)*289) -@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4']) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -@pytest.mark.parametrize("double_quant", [False], ids=['DQ_True']) + +@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) - dims = torch.randint(0, 8192, size=(dims,)).tolist() - dims = [dim + (64-(dim % 64)) for dim in dims] - #for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: + dims = get_test_dims(0, 8192, n=dims) + dims = [dim + (64 - (dim % 64)) for dim in dims] + # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: for dim in dims: - A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device='cuda') - B = torch.eye(dim, dtype=dtype, device='cuda') + A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") + B = torch.eye(dim, dtype=dtype, device="cuda") qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) C3 = torch.matmul(A, B.t()) @@ -2562,7 +2253,5 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): torch.testing.assert_close(A, C3) torch.testing.assert_close(A, C1) torch.testing.assert_close(A, C2) - #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) - #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) - - + # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) + # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) diff --git a/tests/test_generation.py b/tests/test_generation.py index 54ec10475..911aa14da 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -1,91 +1,81 @@ -import pytest -import torch -import math - from itertools import product +import math -import transformers -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - GenerationConfig, - set_seed, +import pytest +import torch -) +from tests.helpers import TRUE_FALSE, describe_dtype, id_formatter -import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +transformers = pytest.importorskip("transformers") def get_4bit_config(): - return BitsAndBytesConfig( - load_in_4bit=True, - load_in_8bit=False, - llm_int8_threshold=6.0, - llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4', - ) + return transformers.BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) def get_model_and_tokenizer(config): model_name_or_path, quant_type = config bnb_config = get_4bit_config() - if quant_type == '16bit': + if quant_type == "16bit": bnb_config.load_in_4bit = False else: - bnb_config.bnb_4bit_quant_type= quant_type - model = AutoModelForCausalLM.from_pretrained(model_name_or_path, + bnb_config.bnb_4bit_quant_type = quant_type + model = transformers.AutoModelForCausalLM.from_pretrained( + model_name_or_path, quantization_config=bnb_config, - max_memory={0:'48GB'}, - device_map='auto', - torch_dtype=torch.bfloat16 - ).eval() + max_memory={0: "48GB"}, + device_map="auto", + torch_dtype=torch.bfloat16, + ).eval() tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) return model, tokenizer + def get_prompt_for_generation_eval(text, add_roles=True): description = ( "A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." ) if add_roles: - prompt = f'{description} ### Human: {text} ### Assistant:' + prompt = f"{description} ### Human: {text} ### Assistant:" else: - prompt = f'{description} {text}' + prompt = f"{description} {text}" return prompt + def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_for_generation_eval): text = prompt_func(text) - inputs = tokenizer(text, return_tensors="pt").to('cuda:0') - outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config) + inputs = tokenizer(text, return_tensors="pt").to("cuda:0") + outputs = model.generate(inputs=inputs["input_ids"], generation_config=generation_config) return tokenizer.decode(outputs[0], skip_special_tokens=True) -models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7'] -dtypes = ['nf4', 'fp4'] -load_in_4bit = [True, False] -values = list(product(models, dtypes)) -strfunc = lambda lst: [str(x) for x in lst] -ids = ['_'.join(strfunc(x)) for x in values] -@pytest.fixture(scope='session', params=values, ids=ids) + +models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"] +dtypes = ["nf4", "fp4"] + + +@pytest.fixture(scope="session", params=product(models, dtypes)) def model_and_tokenizer(request): model, tokenizer = get_model_and_tokenizer(request.param) yield request.param, model, tokenizer del model -@pytest.mark.parametrize("DQ", [True, False], ids=['DQ_True', 'DQ_False']) -@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False']) -#@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32']) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ): - print('') - dtype = torch.float16 +@pytest.mark.parametrize("DQ", TRUE_FALSE, ids=id_formatter("dq")) +@pytest.mark.parametrize("inference_kernel", TRUE_FALSE, ids=id_formatter("inference_kernel")) +@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) +@pytest.mark.slow +def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype): fixture_config, model, tokenizer = model_and_tokenizer generation_config = transformers.GenerationConfig( @@ -96,20 +86,19 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ): ) generation_config.max_new_tokens = 20 - - #text = 'Please write down the first 50 digits of pi.' - #text = get_prompt_for_generation_eval(text) - #text += ' Sure, here the first 50 digits of pi: 3.14159' + # text = 'Please write down the first 50 digits of pi.' + # text = get_prompt_for_generation_eval(text) + # text += ' Sure, here the first 50 digits of pi: 3.14159' n_cases = 6 - text = '3.14159' - if hasattr(model.config, 'quantization_config'): + text = "3.14159" + if hasattr(model.config, "quantization_config"): model.config.quantization_config.bnb_4bit_compute_dtype = dtype model.config.quantization_config.bnb_4bit_use_double_quant = DQ if not inference_kernel: - text = [text]*n_cases - inputs = tokenizer(text, return_tensors="pt").to('cuda:0') - x = inputs['input_ids'] + text = [text] * n_cases + inputs = tokenizer(text, return_tensors="pt").to("cuda:0") + x = inputs["input_ids"] outputs = [] if inference_kernel: for i in range(n_cases): @@ -120,18 +109,14 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ): outputs = model.generate(x, generation_config=generation_config) outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] - assert len(outputs) == n_cases failure_count = 0 for i in range(n_cases): - if not outputs[i][:len(str(math.pi))] == str(math.pi): + if not outputs[i][: len(str(math.pi))] == str(math.pi): failure_count += 1 - failure_max = (2 if fixture_config[0] == 'huggyllama/llama-7b' else 4) + failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4 if failure_count > failure_max: print(math.pi) for out in outputs: print(out) - raise ValueError(f'Failure count: {failure_count}/{n_cases}') - - - + raise ValueError(f"Failure count: {failure_count}/{n_cases}") diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 478255eee..bbbd05335 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -1,25 +1,28 @@ +import copy import os -from contextlib import nullcontext -from itertools import product +import pickle from tempfile import TemporaryDirectory import pytest import torch import bitsandbytes as bnb +from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer storage = { - 'uint8': torch.uint8, - 'float16': torch.float16, - 'bfloat16': torch.bfloat16, - 'float32': torch.float32 + "uint8": torch.uint8, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, } -@pytest.mark.parametrize( - "quant_type, compress_statistics, bias, quant_storage", - list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])), -) -def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage): + +@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) +@pytest.mark.parametrize("bias", TRUE_FALSE) +@pytest.mark.parametrize("compress_statistics", TRUE_FALSE) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("save_before_forward", TRUE_FALSE) +def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward): original_dtype = torch.float16 compute_dtype = None device = "cuda" @@ -81,7 +84,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora quant_storage=storage[quant_storage], device="meta", ) - linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage]) + linear_qs.weight = bnb.nn.Params4bit( + data=linear.weight, + requires_grad=False, + quant_type=quant_type, + quant_storage=storage[quant_storage], + ) if bias: linear_qs.bias = torch.nn.Parameter(linear.bias) linear_qs = linear_qs.to(device) @@ -92,7 +100,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora q0 = a.quant_state q1 = b.quant_state - for attr in ('code', 'dtype', 'blocksize', 'absmax'): + for attr in ("code", "dtype", "blocksize", "absmax"): c, d = getattr(q0, attr), getattr(q1, attr) if isinstance(c, torch.Tensor): assert torch.equal(c, d) @@ -100,7 +108,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert c == d, f"{c} != {d}" if q0.state2 is not None: - for attr in ('code', 'dtype', 'blocksize', 'absmax'): + for attr in ("code", "dtype", "blocksize", "absmax"): c, d = getattr(q0.state2, attr), getattr(q1.state2, attr) if isinstance(c, torch.Tensor): assert torch.equal(c, d) @@ -113,6 +121,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert a.dtype == b.dtype assert torch.equal(a, b) + if save_before_forward: + bytes_4bit = torch_save_to_buffer(linear_q) + # Forward test x = torch.rand(42, layer_shape[0], device=device) a = linear_q(x) @@ -125,14 +136,23 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora assert torch.equal(a, b) assert torch.equal(a, c) + if not save_before_forward: + bytes_4bit = torch_save_to_buffer(linear_q) + linear_q3 = torch_load_from_buffer(bytes_4bit) + # Test moving to CPU and back to GPU - linear_q2.to('cpu') + linear_q2.to("cpu") linear_q2.to(device) d = linear_qs(x) assert c.dtype == d.dtype assert c.device == d.device assert torch.equal(c, d) + d = linear_q3(x) + assert c.dtype == d.dtype + assert c.device == d.device + assert torch.equal(c, d) + # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias with TemporaryDirectory() as tmpdir: state_path_4bit = os.path.join(tmpdir, "state_4bit.pth") @@ -140,10 +160,49 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora torch.save(linear.state_dict(), state_path) torch.save(linear_q.state_dict(), state_path_4bit) - size_orig, size_4 = os.path.getsize(state_path), os.path.getsize( - state_path_4bit + size_orig, size_4 = ( + os.path.getsize(state_path), + os.path.getsize(state_path_4bit), ) size_ratio = size_4 / size_orig - target_compression = 0.143 if original_dtype == torch.float32 else 0.29 # these numbers get lower as weight shape increases - ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" + target_compression = ( + 0.143 if original_dtype == torch.float32 else 0.29 + ) # these numbers get lower as weight shape increases + ratio_error_msg = ( + f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}" + ) assert size_ratio < target_compression, ratio_error_msg + + +def test_copy_param(): + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) + param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) + + shallow_copy_param = copy.copy(param) + assert param.quant_state is shallow_copy_param.quant_state + assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() + + +def test_deepcopy_param(): + tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) + param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0) + copy_param = copy.deepcopy(param) + assert param.quant_state is not copy_param.quant_state + assert param.data.data_ptr() != copy_param.data.data_ptr() + + +def test_params4bit_real_serialization(): + original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) + original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4") + + original_param.cuda(0) # move to CUDA to trigger quantization + + serialized_param = pickle.dumps(original_param) + deserialized_param = pickle.loads(serialized_param) + + assert torch.equal(original_param.data, deserialized_param.data) + assert original_param.requires_grad == deserialized_param.requires_grad == False + assert original_param.quant_type == deserialized_param.quant_type + assert original_param.blocksize == deserialized_param.blocksize + assert original_param.compress_statistics == deserialized_param.compress_statistics + assert original_param.quant_state == deserialized_param.quant_state diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 6d5fc6a82..4b62abd6d 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -1,6 +1,5 @@ -import os from contextlib import nullcontext -from itertools import product +import os from tempfile import TemporaryDirectory import pytest @@ -10,12 +9,17 @@ from bitsandbytes import functional as F from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout from bitsandbytes.nn.modules import Linear8bitLt -from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import ( + TRUE_FALSE, + id_formatter, + torch_load_from_buffer, + torch_save_to_buffer, +) # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") + @pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", @@ -47,7 +51,9 @@ def test_linear_no_igemmlt(): linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False + linear.weight.data.clone(), + requires_grad=False, + has_fp16_weights=False, ).to(linear.weight.dtype) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() @@ -68,10 +74,20 @@ def test_linear_no_igemmlt(): assert linear_custom.state.CxB is None -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt", - list(product([False, True], [False, True], [False, True], [False, True]))) -def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt): +@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) +@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) +@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda")) +@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt")) +@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) +@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda")) +def test_linear_serialization( + has_fp16_weights, + serialize_before_forward, + deserialize_before_cuda, + force_no_igemmlt, + save_before_forward, + load_before_cuda, +): linear = torch.nn.Linear(32, 96) x = torch.randn(3, 32, dtype=torch.half) @@ -86,7 +102,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri linear_custom.state.force_no_igemmlt = True linear_custom.weight = bnb.nn.Int8Params( - linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights + linear.weight.data.clone(), + requires_grad=has_fp16_weights, + has_fp16_weights=has_fp16_weights, ) linear_custom.bias = linear.bias linear_custom = linear_custom.cuda() @@ -94,6 +112,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri if serialize_before_forward: state_dict_8bit = linear_custom.state_dict() + if save_before_forward: + bytes_8bit = torch_save_to_buffer(linear_custom) + x_first = x.clone().cuda().requires_grad_(True) fx_first = linear_custom(x_first).float() grad_proj = torch.randn_like(fx_first) @@ -102,6 +123,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri if not serialize_before_forward: state_dict_8bit = linear_custom.state_dict() + if not save_before_forward: + bytes_8bit = torch_save_to_buffer(linear_custom) + with TemporaryDirectory() as tmpdir: state_path_8bit = os.path.join(tmpdir, "state_8bit.pth") state_path = os.path.join(tmpdir, "state.pth") @@ -128,16 +152,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): new_linear_custom.load_state_dict(new_state_dict, strict=True) + if load_before_cuda: + new_linear_custom2 = torch_load_from_buffer(bytes_8bit) + new_linear_custom = new_linear_custom.cuda() if not deserialize_before_cuda: new_linear_custom.load_state_dict(new_state_dict, strict=True) + if not load_before_cuda: + new_linear_custom2 = torch_load_from_buffer(bytes_8bit) + x_second = x.clone().cuda().requires_grad_(True) fx_second = new_linear_custom(x_second).float() (fx_second * grad_proj).mean().backward() + x_third = x.clone().cuda().requires_grad_(True) + fx_third = new_linear_custom2(x_third).float() + (fx_third * grad_proj).mean().backward() + # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised if has_fp16_weights or not deserialize_before_cuda: assert torch.allclose(fx_first, fx_second, atol=1e-5) assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) + assert torch.allclose(fx_first, fx_third, atol=1e-5) + assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5) diff --git a/tests/test_modules.py b/tests/test_modules.py index 3e28a0f21..f7e439e3b 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,11 +1,13 @@ -from itertools import product +import math +import einops import pytest import torch from torch import nn import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import id_formatter + class MockArgs: def __init__(self, initial_data): @@ -17,12 +19,18 @@ class MLP8bit(torch.nn.Module): def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): super().__init__() self.fc1 = bnb.nn.Linear8bitLt( - dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, - threshold=threshold + dim1, + dim2, + has_fp16_weights=has_fp16_weights, + memory_efficient_backward=memory_efficient_backward, + threshold=threshold, ) self.fc2 = bnb.nn.Linear8bitLt( - dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, - threshold=threshold + dim2, + dim1, + has_fp16_weights=has_fp16_weights, + memory_efficient_backward=memory_efficient_backward, + threshold=threshold, ) def forward(self, x): @@ -40,19 +48,17 @@ def get_args(): def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) sumval = (idx == 0).sum().item() if sumval > count: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_close(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round norm = math.sqrt(math.pi) / math.sqrt(2.0) # std = torch.abs(x).mean()*norm std = torch.std(x) @@ -120,9 +126,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype): return x.to(dtype) def get_8bit_linear(x, stochastic=False): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.abs(x).max() x = x / max1 * 127 x = round_func(x) / 127 * max1 @@ -131,9 +135,7 @@ def get_8bit_linear(x, stochastic=False): @staticmethod def get_8bit_vector_wise(x, dim, stochastic=False): - round_func = ( - LinearFunction.round_stoachastic if stochastic else torch.round - ) + round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1[max1 == 0] = 1.0 x = (x * 127) / max1 @@ -217,9 +219,7 @@ def forward(ctx, x, weight, bias=None, args=None): weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) outputq = bnb.functional.igemm(x8, weight8.t()) - output = LinearFunction.dequant( - outputq, S1, S2, x.dtype, args.quant_type - ) + output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) # if torch.rand(1) < 0.01: # output32 = torch.matmul(x, weight.t()) # err = torch.abs(output-output32).float() @@ -248,37 +248,25 @@ def backward(ctx, grad_output): # weight and x are already 8bit # -> transform grad_output to 8-bit if args.use_8bit_training == "forward+wgrad": - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=[0, 1] - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = bnb.functional.igemm(grad_output8, x8) - grad_weight = LinearFunction.dequant( - grad_weight8, S1, S2, grad_output.dtype, args.quant_type - ) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) grad_input = grad_output.matmul(weight) elif args.use_8bit_training == "full": - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=[0, 1] - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) bnb.functional.igemm(grad_output8, x8, out=grad_weight8) - grad_weight = LinearFunction.dequant( - grad_weight8, S1, S2, grad_output.dtype, args.quant_type - ) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) - grad_output8, S1 = LinearFunction.quant( - grad_output, args.quant_type, dim=2 - ) + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) grad_input8 = bnb.functional.igemm(grad_output8, weight8) - grad_input = LinearFunction.dequant( - grad_input8, S1, S3, grad_output.dtype, args.quant_type - ) + grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) else: grad_input = grad_output.matmul(weight) @@ -310,13 +298,7 @@ def forward(self, x): return LinearFunction.apply(x, self.weight, self.bias, self.args) -threshold = [0.0, 3.0] -values = threshold -names = [f"threshold_{vals}" for vals in values] - - -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("threshold", values, ids=names) +@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("threshold")) def test_linear8bitlt_inference(threshold): l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() assert l1.weight.device.type == "cuda" @@ -330,7 +312,6 @@ def test_linear8bitlt_inference(threshold): assert l1.state.CxB is not None -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") def test_linear8bitlt_accumulated_gradient(): l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) @@ -361,12 +342,8 @@ def test_linear8bitlt_accumulated_gradient(): opt1.zero_grad(True) opt2.step() opt2.zero_grad(True) - assert_all_approx_close( - l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2 - ) - assert_all_approx_close( - l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2 - ) + assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) + assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) @@ -380,7 +357,17 @@ def test_linear8bitlt_accumulated_gradient(): @pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) + l1 = ( + bnb.nn.Linear8bitLt( + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) + .cuda() + .half() + ) assert l1.weight.dtype == torch.int8 l1.eval() @@ -402,11 +389,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) - .cuda() - .half() - ) + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -419,11 +402,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) - .half() - .cuda() - ) + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -436,7 +415,17 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) + mlp = ( + MLP8bit( + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) + .half() + .to("cuda") + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -452,8 +441,12 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc2.weight.device.type == "cuda" mlp = MLP8bit( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) + 32, + 64, + threshold=threshold, + has_fp16_weights=False, + memory_efficient_backward=memory_efficient_backward, + ) w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, mlp = mlp.cuda().half() # and this line triggers quantization @@ -488,7 +481,14 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert (idx == 0).sum().item() <= b1.numel() * 0.005 -@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +@pytest.mark.parametrize( + "module", + [ + lambda n_in, n_out, bias=True: bnb.nn.Linear8bitLt(n_in, n_out, bias=bias, has_fp16_weights=False), + bnb.nn.LinearFP4, + ], + ids=["Int8Lt", "FP4"], +) def test_linear_kbit_fp32_bias(module): # casts model to fp16 -> int8 automatically l1 = module(32, 64).cuda() @@ -511,19 +511,21 @@ def test_linear_kbit_fp32_bias(module): o1 = l1(b1) assert l1.bias is None -modules = [] -modules.append(bnb.nn.Linear8bitLt) -modules.append(bnb.nn.Linear4bit) -modules.append(bnb.nn.LinearFP4) -modules.append(bnb.nn.LinearNF4) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)) -modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True)) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32)) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16)) -modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16)) -names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16'] -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.parametrize("module", modules, ids=names) + +module_dict = { + "Int8Lt": bnb.nn.Linear8bitLt, + "4bit": bnb.nn.Linear4bit, + "FP4": bnb.nn.LinearFP4, + "NF4": bnb.nn.LinearNF4, + "FP4+C": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True), + "NF4+C": lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True), + "NF4+fp32": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32), + "NF4+fp16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16), + "NF4+bf16": lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16), +} + + +@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys()) def test_kbit_backprop(module): b = 17 dim1 = 37 @@ -540,7 +542,7 @@ def test_kbit_backprop(module): kbit[1].bias.detach().copy_(ref[1].bias) ref = ref.half().cuda() kbit = kbit.half().cuda() - kbit = kbit.half().to('cuda') + kbit = kbit.half().to("cuda") errs1 = [] errs2 = [] @@ -558,10 +560,10 @@ def test_kbit_backprop(module): bgrad1 = ref[0].bias.grad bgrad2 = kbit[0].bias.grad - err1 = (out1-out2).abs().float() - err2 = (grad1-grad2).abs().float() - relerr1 = (err1/(out1.abs().float()+1e-9)) - relerr2 = (err2/(grad1.abs().float()+1e-9)) + err1 = (out1 - out2).abs().float() + err2 = (grad1 - grad2).abs().float() + relerr1 = err1 / (out1.abs().float() + 1e-9) + relerr2 = err2 / (grad1.abs().float() + 1e-9) errs1.append(err1.mean().item()) errs2.append(err2.mean().item()) relerrs1.append(relerr1.mean().item()) @@ -571,27 +573,27 @@ def test_kbit_backprop(module): assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1) torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05) else: - assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1) - torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05) + assert_all_approx_close(grad1, grad2, atol=0.0165, rtol=0.055, count=1) + torch.testing.assert_close(bgrad1, bgrad2, atol=0.025, rtol=0.055) ref.zero_grad() kbit.zero_grad() assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 - #print('out', sum(errs1)/len(errs1)) - #print('grad', sum(errs2)/len(errs2)) - #print('rel out', sum(relerrs1)/len(relerrs1)) - #print('rel grad', sum(relerrs2)/len(relerrs2)) + # print('out', sum(errs1)/len(errs1)) + # print('grad', sum(errs2)/len(errs2)) + # print('rel out', sum(relerrs1)/len(relerrs1)) + # print('rel grad', sum(relerrs2)/len(relerrs2)) -def test_fp8linear(): +def test_fp8linear(): b = 10 h = 1024 inp = torch.randn(b, h).cuda() - fp32 = torch.nn.Linear(h, h*2).cuda() - fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda() - fp32b = torch.nn.Linear(h*2, h).cuda() - fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda() + fp32 = torch.nn.Linear(h, h * 2).cuda() + fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda() + fp32b = torch.nn.Linear(h * 2, h).cuda() + fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda() fp8.weight.data.copy_(fp32.weight.data) fp8.bias.data.copy_(fp32.bias.data) @@ -601,34 +603,34 @@ def test_fp8linear(): a = fp32b(torch.nn.functional.gelu(fp32(inp))) b = fp8b(torch.nn.functional.gelu(fp8(inp))) - err = (a-b).abs().mean() + err = (a - b).abs().mean() a.mean().backward() b.mean().backward() - graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean() - bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean() + graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean() + bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean() assert err < 0.05 assert graderr < 0.00002 assert bgraderr < 0.00002 + def test_4bit_warnings(): dim1 = 64 - with pytest.warns(UserWarning, match=r'inference or training'): + with pytest.warns(UserWarning, match=r"inference or training"): net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() net(inp) - with pytest.warns(UserWarning, match=r'inference.'): + with pytest.warns(UserWarning, match=r"inference."): net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(1, dim1).cuda().half() net(inp) with pytest.warns(UserWarning) as record: - net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)]) net = net.cuda() inp = torch.rand(10, dim1).cuda().half() @@ -640,6 +642,3 @@ def test_4bit_warnings(): net(inp) assert len(record) == 2 - - - diff --git a/tests/test_optim.py b/tests/test_optim.py index c373a4f14..b8b4baa2c 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,24 +1,22 @@ -import ctypes import os +from os.path import join import shutil import time import uuid -from itertools import product -from os.path import join -import pytest from lion_pytorch import Lion - +import pytest import torch import bitsandbytes as bnb import bitsandbytes.functional as F -from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import describe_dtype, id_formatter # import apex k = 20 + def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): idx = torch.isclose(a, b, rtol=rtol, atol=atol) error_count = (idx == 0).sum().item() @@ -28,7 +26,7 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): def get_temp_dir(): - path = f"/tmp/autoswap/{str(uuid.uuid4())}" + path = f"/tmp/autoswap/{uuid.uuid4()}" os.makedirs(path, exist_ok=True) return path @@ -36,6 +34,7 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) + str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) @@ -69,8 +68,14 @@ def rm_path(path): ) str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) -str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = ( + torch.optim.AdamW, + lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True), +) +str2optimizers["paged_adam8bit_blockwise"] = ( + torch.optim.Adam, + lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True), +) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( @@ -93,9 +98,18 @@ def rm_path(path): str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] -str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] -str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] -str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["paged_adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["paged_adamw8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] @@ -104,15 +118,16 @@ def rm_path(path): str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] -dim1 = [1024] -dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16, torch.bfloat16] -optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"] + + +@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) def test_optimizer32bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() + if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]: + pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -123,7 +138,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) if gtype == torch.float32: - atol, rtol = 1e-6, 1e-5 + atol, rtol = 1.5e-6, 1.5e-5 elif gtype == torch.bfloat16: atol, rtol = 1e-3, 1e-2 else: @@ -137,7 +152,6 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch_optimizer.step() - for name1, name2 in str2statenames[optim_name]: torch.testing.assert_close( torch_optimizer.state[p1][name1], @@ -148,7 +162,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -160,13 +174,17 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): rm_path(path) # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) for name1, name2 in str2statenames[optim_name]: # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion - assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], - atol=atol, rtol=rtol, - max_error_count=10) + assert_most_approx_close( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2], + atol=atol, + rtol=rtol, + max_error_count=10, + ) if gtype != torch.float32: # the adam buffers should also be close because they are 32-bit @@ -180,14 +198,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16] -values = list(product(dim1, dim2, gtype)) -names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values] - - -@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) def test_global_config(dim1, dim2, gtype): if dim1 == 1 and dim2 == 1: return @@ -201,13 +214,9 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config( - p3, "optim_bits", 8 - ) + bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) - bnb.optim.GlobalOptimManager.get_instance().register_parameters( - [p1, p2, p3] - ) + bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) p1 = p1.cuda() p2 = p2.cuda() p3 = p3.cuda() @@ -233,10 +242,7 @@ def test_global_config(dim1, dim2, gtype): assert adam2.state[p3]["state2"].dtype == torch.uint8 -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16, torch.bfloat16] -optimizer_names = [ +optimizer_names_8bit = [ "adam8bit", "lion8bit", "momentum8bit", @@ -246,15 +252,15 @@ def test_global_config(dim1, dim2, gtype): "momentum8bit_blockwise", "rmsprop8bit_blockwise", ] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values -] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +@pytest.mark.parametrize("optim_name", optimizer_names_8bit, ids=id_formatter("opt")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_optimizer8bit(dim1, dim2, gtype, optim_name): - if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() + if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]: + pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -306,17 +312,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - num_not_close = ( - torch.isclose( - torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol - ) - == 0 - ) - #assert num_not_close.sum().item() < 20 + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 + # assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) - relerr = err / (torch.abs(p1)+1e-9) + relerr = err / (torch.abs(p1) + 1e-9) if g.dtype == torch.bfloat16: assert err.mean() < 0.00015 assert relerr.mean() < 0.0016 @@ -328,9 +329,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors.append(relerr.mean().item()) if i % 10 == 0 and i > 0: - for (name1, name2, qmap, max_val), s in zip( - str2statenames[optim_name], dequant_states - ): + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): s1cpy = s.clone() raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() @@ -360,7 +359,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) torch.testing.assert_close(s1cpy, s1) - num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 assert num_not_close.sum().item() < 20 # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 5 errors for Lion @@ -378,18 +377,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): # print(sum(relerrors)/len(relerrors)) -dim1 = [1024] -dim2 = [32, 1024, 4097] -gtype = [torch.float32] -optim_bits = [32, 8] -values = list(product(dim1, dim2, gtype, optim_bits)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals) - for vals in values -] - - -@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) +@pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits")) +@pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): if dim1 == 1 and dim2 == 1: return @@ -415,15 +406,11 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): for i in range(50): step += 1 - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + ( - 0.01 * i - ) + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) g2 = g1.clone() p2.grad = g2 - current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( - g1, gnorm_vec, step, 5 - ) + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) g1 = (g1.float() * gnorm_scale).to(gtype) p1.grad = g1 @@ -477,22 +464,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): adam2.load_state_dict(torch.load(join(path, "opt.pt"))) -dim1 = [4096] -dim2 = [4096] -gtype = [torch.float32, torch.float16] -# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] -# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] -# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] -# optimizer_names = ['lamb_apex', 'lamb8bit'] -# optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise'] -values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values +optimizer_names_benchmark = [ + "adam8bit_blockwise", + "paged_adam8bit_blockwise", + "paged_adamw8bit_blockwise", + "paged_lion8bit_blockwise", ] -@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2")) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt")) +@pytest.mark.benchmark def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): if dim1 == 1 and dim2 == 1: return @@ -517,39 +501,36 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): print(optim_name, gtype, s / params) # assert s < 3.9 -dim1 = [2*1024] -gtype = [torch.float16] -#mode = ['torch', 'bnb'] -mode = ['bnb'] -optimizer_names = ['paged_adamw'] -#optimizer_names = ['paged_adamw8bit_blockwise'] -values = list(product(dim1,gtype, optimizer_names, mode)) -names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names) + +@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name")) +@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode")) +@pytest.mark.benchmark def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) layers1 = layers1.to(gtype) layers1 = layers1.cuda() large_tensor = None - if mode == 'torch': + if mode == "torch": optim = str2optimizers[optim_name][0](layers1.parameters()) else: optim = str2optimizers[optim_name][1](layers1.parameters()) # 12 GB - large_tensor = torch.empty((int(4.5e9),), device='cuda') + large_tensor = torch.empty((int(4.5e9),), device="cuda") torch.cuda.synchronize() time.sleep(5) num_batches = 5 - batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) - lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() + batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype) + lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() for i in range(num_batches): print(i) b = batches[i] - if i ==2: + if i == 2: torch.cuda.synchronize() t0 = time.time() diff --git a/tests/test_triton.py b/tests/test_triton.py index 8890193fc..3624fb5e9 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -1,21 +1,24 @@ import pytest import torch -from bitsandbytes.triton.triton_utils import is_triton_available -from bitsandbytes.nn.triton_based_modules import SwitchBackLinear from bitsandbytes.nn import Linear8bitLt -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear +from bitsandbytes.triton.triton_utils import is_triton_available +from tests.helpers import TRUE_FALSE -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") -@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, - reason="This test requires triton and a GPU with compute capability 8.0 or higher.") -@pytest.mark.parametrize("vector_wise_quantization", [False, True]) + +@pytest.mark.skipif( + not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, + reason="This test requires triton and a GPU with compute capability 8.0 or higher.", +) +@pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE) def test_switchback(vector_wise_quantization): for dim in [83]: for batch in [13]: - standard = torch.nn.Linear(dim, 4 * dim).cuda().half() - switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() + switchback = ( + SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() + ) baseline = Linear8bitLt(dim, 4 * dim).cuda().half() switchback.weight.data.copy_(standard.weight) switchback.bias.data.copy_(standard.bias) @@ -38,24 +41,23 @@ def test_switchback(vector_wise_quantization): err_sb = (out_standard - out_sb).abs().mean() err_baseline = (out_standard - out_baseline).abs().mean() - print('OUT', err_sb, err_baseline) + print("OUT", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() - print('GW2', err_sb, err_baseline) + print("GW2", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() - print('GW1', err_sb, err_baseline) + print("GW1", err_sb, err_baseline) assert err_sb < 2 * err_baseline err_sb = (x1.grad - x2.grad).abs().mean() err_baseline = (x1.grad - x3.grad).abs().mean() - print('GX1', err_sb, err_baseline) + print("GX1", err_sb, err_baseline) assert err_sb < 2 * err_baseline -