Skip to content
This repository was archived by the owner on Sep 27, 2021. It is now read-only.

Commit ec65d5a

Browse files
MaaarcocrSimonDanisch
authored andcommitted
Added support for Adapt.jl (#16)
* Added support for Adapt.jl * avoid conflict * now we can convert an array of Duals with cl() * added tests
1 parent cb54d77 commit ec65d5a

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ GPUArrays 0.2.0
55
StaticArrays
66
ColorTypes
77

8+
Adapt
89
Transpiler 0.4.3
910
Sugar 0.4.1
1011
Matcha 0.1.1

src/array.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,14 @@ context(p::CLArray) = context(pointer(p))
2626

2727
# Avoid conflict with OpenCL.cl
2828
module Shorthands
29-
using ..CLArrays
30-
cl(x) = x
31-
cl(x::CLArrays.CLArray) = x
32-
cl(xs::AbstractArray) = isbits(xs) ? xs : CLArrays.CLArray(xs)
29+
using ..CLArrays: CLArray
30+
import Adapt: adapt, adapt_
31+
32+
adapt_(::Type{<:CLArray}, xs::AbstractArray) = isbits(xs) ? xs : convert(CLArray, xs)
33+
34+
cl(x) = adapt(CLArray{Float32}, x)
35+
36+
export cl
3337
end
3438

3539
function (::Type{CLArray{T, N}})(size::NTuple{N, Integer}, ctx::cl.Context = global_context()) where {T, N}

test/runtests.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using CLArrays
1+
using CLArrays, CLArrays.Shorthands
22
using GPUArrays.TestSuite, Base.Test
33

44
for dev in CLArrays.devices()
@@ -30,6 +30,13 @@ for dev in CLArrays.devices()
3030
# in copy to convert to array type, but that actually convert Array{Bool} to BitArray
3131
# against_base((a, b)-> a .& b, CLArray{Bool}, (10,), (10,))
3232
end
33+
34+
@testset "Shorthand Test" begin
35+
GPUArrays.allowslow(true)
36+
@test collect(cl([1,2])) == [1,2]
37+
@test collect(cl([1 2;3 4])) == [1 2;3 4]
38+
@test cl([1,2,3]) == CLArray([1,2,3])
39+
end
3340
end
3441
end
3542

@@ -132,4 +139,3 @@ end
132139
# # out[15] = sizeof(x15)
133140
# return
134141
# end
135-

0 commit comments

Comments
 (0)