
Create specialized versions of BLAS subroutines with fewer arguments and fewer run-time decisions #65

Open
Beliavsky opened this issue Nov 13, 2024 · 0 comments

@Beliavsky

I asked ChatGPT o1-preview to create specialized versions of stdlib_dgemm for the two cases where the product uses the matrix a as is or its transpose. It produced stdlib_dgemm_a_orig and stdlib_dgemm_a_trans, shown below, which drop the transa argument. I have not checked them. In general, a choice that is known at compile time (here, whether to use a or its transpose) should not be made at run time.
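Applied to both operands, the same principle removes every character argument. The following is a minimal sketch, not part of stdlib: the module `gemm_nn_sketch` and the routine `dgemm_nn` are hypothetical names, `dp` is defined locally as `kind(1.0d0)` on the assumption that it matches stdlib's `dp`, and argument checking (the `stdlib_xerbla` calls) is omitted. It hard-codes C := alpha*A*B + beta*C using the same column-oriented loop order as the reference DGEMM.

```fortran
module gemm_nn_sketch
    ! Hypothetical fully specialized kernel: op(A) = A and op(B) = B are fixed
    ! at compile time, so no transa/transb arguments are needed.
    implicit none
    private
    public :: dp, dgemm_nn
    integer, parameter :: dp = kind(1.0d0)   ! assumed to match stdlib's dp
contains
    pure subroutine dgemm_nn(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
        ! C := alpha*A*B + beta*C with A m-by-k, B k-by-n, C m-by-n.
        ! Argument checking is omitted in this sketch.
        integer, intent(in) :: m, n, k, lda, ldb, ldc
        real(dp), intent(in) :: alpha, beta
        real(dp), intent(in) :: a(lda, *), b(ldb, *)
        real(dp), intent(inout) :: c(ldc, *)
        integer :: j, l
        real(dp) :: temp
        if (m == 0 .or. n == 0) return
        do j = 1, n
            ! Scale column j of C first; beta == 0 overwrites C instead of scaling it.
            if (beta == 0.0_dp) then
                c(1:m, j) = 0.0_dp
            else if (beta /= 1.0_dp) then
                c(1:m, j) = beta * c(1:m, j)
            end if
            if (alpha == 0.0_dp) cycle
            ! Add alpha*B(l,j) times column l of A (unit stride in A and C).
            do l = 1, k
                temp = alpha * b(l, j)
                c(1:m, j) = c(1:m, j) + temp * a(1:m, l)
            end do
        end do
    end subroutine dgemm_nn
end module gemm_nn_sketch
```

With `a` of shape (m, k), `b` of shape (k, n) and `c` of shape (m, n), `call dgemm_nn(m, n, k, alpha, a, m, b, k, beta, c, m)` should agree, up to rounding, with `c = alpha*matmul(a, b) + beta*c`.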

pure subroutine stdlib_dgemm_a_orig(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    ! Specialized DGEMM subroutine where A is not transposed (transa = 'N')
    ! Performs the operation: C = alpha * A * op(B) + beta * C

    ! Scalar Arguments
    character, intent(in) :: transb
    integer, intent(in) :: m, n, k
    real(dp), intent(in) :: alpha, beta
    integer, intent(in) :: lda, ldb, ldc

    ! Array Arguments
    real(dp), intent(in) :: a(lda, *)
    real(dp), intent(in) :: b(ldb, *)
    real(dp), intent(inout) :: c(ldc, *)

    ! Local Scalars
    integer :: i, j, l, info, nrowb
    real(dp) :: temp
    logical :: notb
    real(dp), parameter :: zero = 0.0_dp, one = 1.0_dp

    ! Intrinsic Functions
    intrinsic :: max

    ! Set notb as true if B is not transposed
    notb = (transb == 'N' .or. transb == 'n')
    if (notb) then
        nrowb = k
    else
        nrowb = n
    end if

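    ! Test the input parameters. INFO values are argument positions in this
    ! reduced argument list (transb is argument 1), not in the full DGEMM one.
    info = 0
    ! As in reference DGEMM, transb = 'C' is accepted and treated as 'T' for real data.
    if ((.not. notb) .and. (.not. (transb == 'T' .or. transb == 't' .or. &
                                   transb == 'C' .or. transb == 'c'))) then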
        info = 1
    else if (m < 0) then
        info = 2
    else if (n < 0) then
        info = 3
    else if (k < 0) then
        info = 4
    else if (lda < max(1, m)) then
        info = 7
    else if (ldb < max(1, nrowb)) then
        info = 9
    else if (ldc < max(1, m)) then
        info = 12
    end if
    if (info /= 0) then
        call stdlib_xerbla('DGEMM ', info)
        return
    end if

    ! Quick return if possible.
    if ((m == 0) .or. (n == 0) .or. (((alpha == zero) .or. (k == 0)) .and. (beta == one))) return

    ! If alpha is zero.
    if (alpha == zero) then
        if (beta == zero) then
            do j = 1, n
                do i = 1, m
                    c(i, j) = zero
                end do
            end do
        else
            do j = 1, n
                do i = 1, m
                    c(i, j) = beta * c(i, j)
                end do
            end do
        end if
        return
    end if

    ! Start the operations.
    if (notb) then
        ! Form C := alpha*A*B + beta*C.
        do j = 1, n
            if (beta == zero) then
                do i = 1, m
                    c(i, j) = zero
                end do
            else if (beta /= one) then
                do i = 1, m
                    c(i, j) = beta * c(i, j)
                end do
            end if
            do l = 1, k
                temp = alpha * b(l, j)
                do i = 1, m
                    c(i, j) = c(i, j) + temp * a(i, l)
                end do
            end do
        end do
    else
        ! Form C := alpha*A*B**T + beta*C
        do j = 1, n
            if (beta == zero) then
                do i = 1, m
                    c(i, j) = zero
                end do
            else if (beta /= one) then
                do i = 1, m
                    c(i, j) = beta * c(i, j)
                end do
            end if
            do l = 1, k
                temp = alpha * b(j, l)
                do i = 1, m
                    c(i, j) = c(i, j) + temp * a(i, l)
                end do
            end do
        end do
    end if

    return
end subroutine stdlib_dgemm_a_orig
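The remaining run-time branch in stdlib_dgemm_a_orig is on transb. Moving its `else` branch into a routine of its own gives a kernel for C := alpha*A*B**T + beta*C with no character arguments at all. This is a hypothetical sketch: `gemm_nt_sketch`, `dgemm_nt` and the locally defined `dp` are assumptions, and argument checking is omitted.

```fortran
module gemm_nt_sketch
    ! Hypothetical kernel for C := alpha*A*B**T + beta*C
    ! (A is m-by-k, B is n-by-k, C is m-by-n).
    implicit none
    private
    public :: dp, dgemm_nt
    integer, parameter :: dp = kind(1.0d0)   ! assumed to match stdlib's dp
contains
    pure subroutine dgemm_nt(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
        integer, intent(in) :: m, n, k, lda, ldb, ldc
        real(dp), intent(in) :: alpha, beta
        real(dp), intent(in) :: a(lda, *), b(ldb, *)
        real(dp), intent(inout) :: c(ldc, *)
        integer :: j, l
        real(dp) :: temp
        if (m == 0 .or. n == 0) return
        do j = 1, n
            ! Scale column j of C first; beta == 0 overwrites C instead of scaling it.
            if (beta == 0.0_dp) then
                c(1:m, j) = 0.0_dp
            else if (beta /= 1.0_dp) then
                c(1:m, j) = beta * c(1:m, j)
            end if
            if (alpha == 0.0_dp) cycle
            do l = 1, k
                ! B is used transposed, so row j of B supplies the scale factors.
                temp = alpha * b(j, l)
                c(1:m, j) = c(1:m, j) + temp * a(1:m, l)
            end do
        end do
    end subroutine dgemm_nt
end module gemm_nt_sketch
```

Here `b` has shape (n, k), so the leading dimension passed for it is n, matching `nrowb = n` in the transposed branch of stdlib_dgemm_a_orig.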

pure subroutine stdlib_dgemm_a_trans(transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    ! Specialized DGEMM subroutine where A is transposed (transa = 'T')
    ! Performs the operation: C = alpha * A**T * op(B) + beta * C

    ! Scalar Arguments
    character, intent(in) :: transb
    integer, intent(in) :: m, n, k
    real(dp), intent(in) :: alpha, beta
    integer, intent(in) :: lda, ldb, ldc

    ! Array Arguments
    real(dp), intent(in) :: a(lda, *)
    real(dp), intent(in) :: b(ldb, *)
    real(dp), intent(inout) :: c(ldc, *)

    ! Local Scalars
    integer :: i, j, l, info, nrowb
    real(dp) :: temp
    logical :: notb
    real(dp), parameter :: zero = 0.0_dp, one = 1.0_dp

    ! Intrinsic Functions
    intrinsic :: max

    ! Set notb as true if B is not transposed
    notb = (transb == 'N' .or. transb == 'n')
    if (notb) then
        nrowb = k
    else
        nrowb = n
    end if

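    ! Test the input parameters. INFO values are argument positions in this
    ! reduced argument list (transb is argument 1), not in the full DGEMM one.
    info = 0
    ! As in reference DGEMM, transb = 'C' is accepted and treated as 'T' for real data.
    if ((.not. notb) .and. (.not. (transb == 'T' .or. transb == 't' .or. &
                                   transb == 'C' .or. transb == 'c'))) then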
        info = 1
    else if (m < 0) then
        info = 2
    else if (n < 0) then
        info = 3
    else if (k < 0) then
        info = 4
    else if (lda < max(1, k)) then
        info = 7
    else if (ldb < max(1, nrowb)) then
        info = 9
    else if (ldc < max(1, m)) then
        info = 12
    end if
    if (info /= 0) then
        call stdlib_xerbla('DGEMM ', info)
        return
    end if

    ! Quick return if possible.
    if ((m == 0) .or. (n == 0) .or. (((alpha == zero) .or. (k == 0)) .and. (beta == one))) return

    ! If alpha is zero.
    if (alpha == zero) then
        if (beta == zero) then
            do j = 1, n
                do i = 1, m
                    c(i, j) = zero
                end do
            end do
        else
            do j = 1, n
                do i = 1, m
                    c(i, j) = beta * c(i, j)
                end do
            end do
        end if
        return
    end if

    ! Start the operations.
    if (notb) then
        ! Form C := alpha*A**T*B + beta*C
        do j = 1, n
            do i = 1, m
                temp = zero
                do l = 1, k
                    temp = temp + a(l, i) * b(l, j)
                end do
                if (beta == zero) then
                    c(i, j) = alpha * temp
                else
                    c(i, j) = alpha * temp + beta * c(i, j)
                end if
            end do
        end do
    else
        ! Form C := alpha*A**T*B**T + beta*C
        do j = 1, n
            do i = 1, m
                temp = zero
                do l = 1, k
                    temp = temp + a(l, i) * b(j, l)
                end do
                if (beta == zero) then
                    c(i, j) = alpha * temp
                else
                    c(i, j) = alpha * temp + beta * c(i, j)
                end if
            end do
        end do
    end if

    return
end subroutine stdlib_dgemm_a_trans
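stdlib_dgemm_a_trans uses a different loop structure: each C(i,j) is a dot product of column i of A with column j of B, which in the `notb` case reads both operands with unit stride. A hypothetical fully specialized version for C := alpha*A**T*B + beta*C might look like the sketch below. It swaps the explicit inner loop for the intrinsic `dot_product`; the names `gemm_tn_sketch` and `dgemm_tn` and the local `dp` are assumptions, and argument checking is omitted.

```fortran
module gemm_tn_sketch
    ! Hypothetical kernel for C := alpha*A**T*B + beta*C
    ! (A is k-by-m, B is k-by-n, C is m-by-n).
    implicit none
    private
    public :: dp, dgemm_tn
    integer, parameter :: dp = kind(1.0d0)   ! assumed to match stdlib's dp
contains
    pure subroutine dgemm_tn(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
        integer, intent(in) :: m, n, k, lda, ldb, ldc
        real(dp), intent(in) :: alpha, beta
        real(dp), intent(in) :: a(lda, *), b(ldb, *)
        real(dp), intent(inout) :: c(ldc, *)
        integer :: i, j
        real(dp) :: temp
        if (m == 0 .or. n == 0) return
        if (alpha == 0.0_dp) then
            ! Only scale C; A and B are not referenced, as in reference DGEMM.
            do j = 1, n
                if (beta == 0.0_dp) then
                    c(1:m, j) = 0.0_dp
                else
                    c(1:m, j) = beta * c(1:m, j)
                end if
            end do
            return
        end if
        do j = 1, n
            do i = 1, m
                ! Column i of A and column j of B are both contiguous in memory.
                temp = dot_product(a(1:k, i), b(1:k, j))
                if (beta == 0.0_dp) then
                    c(i, j) = alpha * temp
                else
                    c(i, j) = alpha * temp + beta * c(i, j)
                end if
            end do
        end do
    end subroutine dgemm_tn
end module gemm_tn_sketch
```

For `a` of shape (k, m) the leading dimension is k, which is why stdlib_dgemm_a_trans checks `lda < max(1, k)` rather than `max(1, m)`.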
