Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 151 additions & 8 deletions csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ static std::string hipError_to_string(const hipError_t ret)
}
}


template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_HIPBLASLT
Expand All @@ -524,28 +525,101 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
hipblasLtOrder_t col_ampere = HIPBLASLT_ORDER_COL;

has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Adesc, HIP_R_8I, m, k, lda));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutCreate for Adesc:"<<m<<" "<<k<<" "<<lda<<std::endl;
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Bdesc, HIP_R_8I, n, k, ldb));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutCreate for Bdesc:"<<n<<" "<<k<<" "<<ldb<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutSetAttribute for Adesc"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
return has_error;
}


if(FORMATB == COL_TURING)
//if(FORMATB == COL_TURING)
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
else
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutSetAttribute for Bdesc"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
return has_error;
}
//else
// has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));

const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel

if(DTYPE_OUT == 32)
{
has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32I));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescCreate"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
return has_error;
}
auto opA = HIPBLAS_OP_N;
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(int32_t)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescSetAttribute HIPBLASLT_MATMUL_DESC_TRANSA"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(int32_t)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescSetAttribute HIPBLASLT_MATMUL_DESC_TRANSB"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
checkHipblasStatus(hipblasLtMatmulDescSetAttribute(
matmulDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescSetAttribute HIPBLASLT_MATMUL_DESC_EPILOGUE"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_32I, m, n, ldc));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutCreate for Cdesc"<<m<<" "<<n<<" "<<ldc<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutSetAttribute for Cdesc"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatrixLayoutDestroy(Cdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
int alpha = 1, beta = 0;


Expand Down Expand Up @@ -580,17 +654,72 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
}
else
{
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
void* d_workspace=nullptr;
uint64_t workspace_size = 0;
for(int i = 0; i < returnedAlgoCount; i++)
workspace_size = max(workspace_size, heuristicResult[i].workspaceSize);
hipMalloc(&d_workspace, workspace_size);
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, workspace_size, 0));
hipFree(d_workspace);
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmul"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatrixLayoutDestroy(Cdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
}
}
else
{
has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_8I));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescCreate for int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
return has_error;
}
hipblasOperation_t opA = HIPBLAS_OP_N;
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescSetAttribute HIPBLASLT_MATMUL_DESC_TRANSA for int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescSetAttribute HIPBLASLT_MATMUL_DESC_TRANSB for int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_8I, m, n, ldc));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutCreate Cdesc for int8"<<m<<" "<<n<<" "<<ldc<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutSetAttribute Cdesc for int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatrixLayoutDestroy(Cdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
/* Algo and workspace TODO: need to rework to not be duplicated */
// Set User Preference attributes
hipblasLtMatmulPreference_t pref;
Expand Down Expand Up @@ -622,17 +751,30 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
}
else
{
uint64_t workspace_size = 0;
// Size the workspace for the worst case across all returned heuristics so
// heuristicResult[0].algo is guaranteed to fit.
for(int i = 0; i < returnedAlgoCount; i++)
workspace_size = max(workspace_size, heuristicResult[i].workspaceSize);
void* d_workspace = nullptr;
hipMalloc(&d_workspace, workspace_size);
// Guard: if the allocation failed, do not advertise workspace we don't have.
if(d_workspace == nullptr)
workspace_size = 0;
if(!SCALE_ROWS)
{
float alpha = 1.0f, beta = 0.0f;
// BUG FIX: both calls below previously passed `nullptr` as the workspace
// pointer while still reporting `workspace_size` bytes, inviting
// hipblasLtMatmul to write through a null pointer whenever the selected
// algorithm actually needs workspace. Pass the real allocation.
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, workspace_size, 0));
}
else
{
float beta = 0.0f;
// SCALE_ROWS path: per-row scaling factors are supplied through `row_scale`
// in place of a scalar alpha.
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, workspace_size, 0));
}
hipFree(d_workspace);
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmul with int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(Adesc);
hipblasLtMatrixLayoutDestroy(Bdesc);
hipblasLtMatrixLayoutDestroy(Cdesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
}
}
Expand All @@ -649,6 +791,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
#endif // NO_HIPBLASLT
}


// Round `value` up to the next multiple of `multiple`; returns `value`
// unchanged when it is already an exact multiple.
// NOTE(review): assumes `multiple` > 0 and no overflow near INT_MAX —
// matches how callers use it for padding dimensions.
int fill_up_to_nearest_multiple(int value, int multiple)
{
    const int remainder = value % multiple;
    return (remainder == 0) ? value : value + multiple - remainder;
}
Expand Down
12 changes: 8 additions & 4 deletions csrc/ops_hip.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,20 @@ typedef enum Funcs_t
// Holds the hipBLASLt handle used by the Lt-based GEMM entry points.
// (This view contained interleaved old rocBLAS lines and new hipBLASLt
// lines from a diff paste — two `m_handle` declarations; reconstructed to
// the intended post-change state.)
class Context
{
public:
    hipblasLtHandle_t m_handle;

    Context()
    {
        // Create the handle directly into the member; the old code created a
        // local handle and then copied it, which added nothing.
        hipblasLtCreate(&m_handle);
        // NOTE(review): the handle is never destroyed (no hipblasLtDestroy in
        // a destructor), so it lives for the process lifetime — confirm this
        // is intentional before adding teardown, since a naive destructor
        // plus implicit copies would double-destroy.
    }
};

/*
class ContextLt
{
public:
Expand All @@ -124,6 +127,7 @@ class ContextLt
m_handle = handle;
}
};
*/

class ContextHipsparse
{
Expand Down