@shbiswas834

No description provided.


  const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms();
- const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024;
+ const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;
Collaborator

We need to guard the change by #ifdef USE_ROCM

Author

Added
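
A minimal sketch of the guard requested above, assuming the intent is to keep 1024 as the non-deterministic limit on other builds and to use 4096 only when `USE_ROCM` is defined. The helper name `max_segment_length_per_cta_for` and its boolean parameter are hypothetical stand-ins for the inline expression and for `at::globalContext().deterministicAlgorithms()`; this is not necessarily the code that was committed.

```cpp
#include <climits>

// Hypothetical helper: the PR computes this value inline. The flag stands in
// for at::globalContext().deterministicAlgorithms().
constexpr int max_segment_length_per_cta_for(bool use_deterministic_algorithms) {
#ifdef USE_ROCM
  // ROCm builds take the larger limit proposed in this PR.
  constexpr int kNonDeterministicLimit = 4096;
#else
  // All other builds keep the existing limit.
  constexpr int kNonDeterministicLimit = 1024;
#endif
  // Deterministic mode is unchanged: the segment length is effectively uncapped.
  return use_deterministic_algorithms ? INT_MAX : kNonDeterministicLimit;
}
```

With this shape, only the non-deterministic limit differs between platforms, so non-ROCm behavior stays as it was before the PR.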

// Compute shared memory size for cta_per_row
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
- int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
+ int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize;
Collaborator

We need to guard the change by #ifdef USE_ROCM

Author

Added
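
A sketch of how the same guard could look for the group count. The values of `kMaxThreads` and `kWarpSize` below are assumptions for illustration (they are not shown in this thread), and `compute_num_cta_per_row_groups` is a hypothetical wrapper around the expression from the diff.

```cpp
#include <cstdint>

// Assumed values for illustration only; the real constants are defined
// elsewhere in the code base.
constexpr int32_t kMaxThreads = 1024;
#ifdef USE_ROCM
constexpr int32_t kWarpSize = 64;
#else
constexpr int32_t kWarpSize = 32;
#endif

// Hypothetical wrapper around the expression changed in this PR.
constexpr int32_t compute_num_cta_per_row_groups() {
#ifdef USE_ROCM
  // ROCm: a quarter of the maximum block size, as proposed in the diff.
  return (kMaxThreads / 4) / kWarpSize;
#else
  // Other builds: unchanged.
  return kMaxThreads / kWarpSize;
#endif
}
```

Under these assumed constants the result is 32 groups without `USE_ROCM` and 4 groups with it, which also changes the shared memory size computed from the group count.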

|| dev_weights.scalar_type() == at::ScalarType::Float;

- if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna())
+ if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna())
Collaborator

Does the order matter?

Order doesn't matter.
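
A small check of that claim, assuming every operand is a side-effect-free boolean (the thread does not show whether `rocm::is_supported_cdna()` has side effects). The function names are hypothetical; the parameters mirror the operands of the condition in the diff. Short-circuit evaluation changes which operands get evaluated, but not the result.

```cpp
// Operand order before the change.
constexpr bool original_order(bool use_hip_kernel, bool supported_weights_type,
                              bool mixed_D, bool is_supported_cdna) {
  return use_hip_kernel && supported_weights_type && !mixed_D && is_supported_cdna;
}

// Operand order after the change.
constexpr bool reordered(bool use_hip_kernel, bool supported_weights_type,
                         bool mixed_D, bool is_supported_cdna) {
  return use_hip_kernel && !mixed_D && supported_weights_type && is_supported_cdna;
}

// Exhaustively compare both orderings over all 16 input combinations.
constexpr bool orders_agree() {
  for (int mask = 0; mask < 16; ++mask) {
    const bool a = (mask & 1) != 0;
    const bool b = (mask & 2) != 0;
    const bool c = (mask & 4) != 0;
    const bool d = (mask & 8) != 0;
    if (original_order(a, b, c, d) != reordered(a, b, c, d)) {
      return false;
    }
  }
  return true;
}

static_assert(orders_agree(), "reordering pure && operands must not change the result");
```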

// (vector is very large).
std::sort(
- std::execution::par,
+ // std::execution::par,
Collaborator

Uncomment this line.

Author

Removed


const int32_t cta_per_row_grid_size = std::min(
- div_round_up(total_unique_indices, kMaxThreads),
+ div_round_up(total_unique_indices, (kMaxThreads/4)),

Add #ifdef USE_ROCM above this line

Author

Added
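
A sketch of the guarded grid-size computation. The second argument of `std::min` is cut off in the diff above, so `grid_size_cap` is a hypothetical placeholder for it; `kMaxThreads = 1024`, the `div_round_up` definition, and the wrapper name `cta_per_row_grid_size_for` are likewise assumptions made for illustration.

```cpp
#include <algorithm>
#include <cstdint>

// Assumed value for illustration; the real constant is defined elsewhere.
constexpr int32_t kMaxThreads = 1024;

// Stand-in for the code base's div_round_up: ceiling division for positive ints.
constexpr int32_t div_round_up(int32_t a, int32_t b) {
  return (a + b - 1) / b;
}

// grid_size_cap is a hypothetical placeholder for the elided second
// argument of std::min in the diff.
constexpr int32_t cta_per_row_grid_size_for(int32_t total_unique_indices,
                                            int32_t grid_size_cap) {
#ifdef USE_ROCM
  // ROCm: each CTA covers a quarter as many indices, so more CTAs are launched.
  constexpr int32_t kIndicesPerCta = kMaxThreads / 4;
#else
  constexpr int32_t kIndicesPerCta = kMaxThreads;
#endif
  return std::min(div_round_up(total_unique_indices, kIndicesPerCta), grid_size_cap);
}
```

Under these assumptions, 1000 unique indices need 1 CTA without `USE_ROCM` and 4 CTAs with it, before the cap is applied.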

@shbiswas834
Author

@liligwu The changes are tested and build successfully. A few backward and forward tests were run, with no regressions.

@avbokovoy left a comment

Overall LGTM

list(APPEND HIP_CXX_FLAGS -mf16c)
list(APPEND HIP_CXX_FLAGS -mfma)
list(APPEND HIP_CXX_FLAGS -std=c++20)
list(APPEND HIP_CXX_FLAGS -g)

I suggest reverting all changes to this file, since in this context they only increase build time and the size of the diff.

// Now sort the indices by their tags. Use parallel sort for some extra speed
// (vector is very large).
std::sort(
std::execution::par,

Should be reverted
