From 313e853177c5ca95b201f6c4124a2e3c03ae9cc1 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 25 Apr 2025 11:28:28 -0300 Subject: [PATCH] FIx findall output type --- src/indexing.jl | 2 +- test/array.jl | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/indexing.jl b/src/indexing.jl index 2597e223a..44d6fc782 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -26,7 +26,7 @@ function Base.findall(bools::WrappedMtlArray{Bool}) indices = cumsum(reshape(bools, prod(size(bools)))) n = @allowscalar indices[end] - ys = MtlArray{I}(undef, n) + ys = similar(bools, I, n) if n > 0 function kernel(ys::MtlDeviceArray, bools, indices) diff --git a/test/array.jl b/test/array.jl index 874b328b9..12b844a44 100644 --- a/test/array.jl +++ b/test/array.jl @@ -526,6 +526,13 @@ end @test testf(x->findall(x), rand(Bool, 1000)) @test testf(x->findall(y->y>Float32(0.5), x), rand(Float32,1000)) + # Set storage mode to a different one than the default + let storage=Metal.DefaultStorageMode == Metal.PrivateStorage ? Metal.SharedStorage : Metal.PrivateStorage + x = mtl(rand(Float32,100); storage) + out = findall(y->y>Float32(0.5), x) + @test Metal.storagemode(x) == Metal.storagemode(out) + end + # ND let x = rand(Bool, 1000, 1000) @test findall(x) == Array(findall(MtlArray(x)))