diff --git a/src/PyCall.jl b/src/PyCall.jl index 48c28697..0ca5a849 100644 --- a/src/PyCall.jl +++ b/src/PyCall.jl @@ -22,7 +22,8 @@ import Base: size, ndims, similar, copy, getindex, setindex!, stride, filter!, hash, splice!, pop!, ==, isequal, push!, append!, insert!, prepend!, unsafe_convert, pushfirst!, popfirst!, firstindex, lastindex, - getproperty, setproperty!, propertynames + getproperty, setproperty!, propertynames, + ReinterpretArray, ReshapedArray if isdefined(Base, :hasproperty) # Julia 1.2 import Base: hasproperty diff --git a/src/gc.jl b/src/gc.jl index e000cc37..2057ddf4 100644 --- a/src/gc.jl +++ b/src/gc.jl @@ -29,7 +29,13 @@ const weakref_callback_meth = Ref{PyMethodDef}() function pyembed(po::PyObject, jo::Any) # If there's a need to support immutable embedding, # the API needs to be changed to return the pointer. - isimmutable(jo) && throw(ArgumentError("pyembed: immutable argument not allowed")) + if isimmutable(jo) + if applicable(parent, jo) + return pyembed(po, parent(jo) ) + else + throw(ArgumentError("pyembed: immutable argument not allowed")) + end + end if ispynull(weakref_callback_obj) cf = @cfunction(weakref_callback, PyPtr, (PyPtr,PyPtr)) weakref_callback_meth[] = PyMethodDef("weakref_callback", cf, METH_O) @@ -43,3 +49,10 @@ function pyembed(po::PyObject, jo::Any) pycall_gc[wo] = jo return po end + +# Embed the mutable type underlying the immutable view of the array +# See Base.unsafe_convert(::Type{Ptr{T}}, jo::ArrayType) for specific array types +pyembed(po::PyObject, jo::SubArray ) = pyembed(po, jo.parent) +pyembed(po::PyObject, jo::ReshapedArray ) = pyembed(po, jo.parent) +pyembed(po::PyObject, jo::ReinterpretArray ) = pyembed(po, jo.parent) +pyembed(po::PyObject, jo::PermutedDimsArray) = pyembed(po, jo.parent) \ No newline at end of file diff --git a/src/numpy.jl b/src/numpy.jl index f3d523aa..47dbe6d9 100644 --- a/src/numpy.jl +++ b/src/numpy.jl @@ -172,7 +172,7 @@ const NPY_ARRAY_WRITEABLE = Int32(0x0400) # dimensions. For example, although NumPy works with both row-major and # column-major data, some Python libraries like OpenCV seem to require # row-major data (the default in NumPy). In such cases, use PyReverseDims(array) -function NpyArray(a::StridedArray{T}, revdims::Bool) where T<:PYARR_TYPES +function NpyArray(a::AbstractArray{T}, revdims::Bool) where T<:PYARR_TYPES @npyinitialize size_a = revdims ? reverse(size(a)) : size(a) strides_a = revdims ? reverse(strides(a)) : strides(a) @@ -186,7 +186,7 @@ function NpyArray(a::StridedArray{T}, revdims::Bool) where T<:PYARR_TYPES return PyObject(p, a) end -function PyObject(a::StridedArray{T}) where T<:PYARR_TYPES +function PyObject(a::AbstractArray{T}) where T<:PYARR_TYPES try return NpyArray(a, false) catch @@ -194,7 +194,7 @@ function PyObject(a::StridedArray{T}) where T<:PYARR_TYPES end end -function PyReverseDims(a::StridedArray{T,N}) where {T<:PYARR_TYPES,N} +function PyReverseDims(a::AbstractArray{T,N}) where {T<:PYARR_TYPES,N} try return NpyArray(a, true) catch diff --git a/test/runtests.jl b/test/runtests.jl index dff36dbc..82a6700f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -72,6 +72,8 @@ const PyInt = pyversion < v"3" ? Int : Clonglong @test roundtripeq(Int32) @test roundtripeq(Dict(1 => "hello", 2 => "goodbye")) && roundtripeq(Dict()) @test roundtripeq(UInt8[1,3,4,5]) + @test roundtripeq(1:5) + @test roundtripeq(5:2:10) @test roundtrip(3 => 4) == (3,4) @test roundtrip(Pair{Int,Int}, 3 => 4) == Pair(3,4) @test eltype(roundtrip([Ref(1), Ref(2)])) == typeof(Ref(1)) @@ -107,6 +109,42 @@ const PyInt = pyversion < v"3" ? Int : Clonglong @test GC.@preserve(o, pyincref.(PyArray(o))) == a end end + let A = Float64[1 2; 3 4] + # Normal array + B = copy(A) + C = PyArray( PyObject(B) ) + @test C == B + B[1] = 3 + @test C == B && C[1] == B[1] + + # SubArray + B = view(A, 1:2, 2:2) + C = PyArray( PyObject(B) ) + @test C == B + A[3] = 5 + @test C == B && C[1] == A[3] + + # ReshapedArray + B = Base.ReshapedArray( A, (1,4), () ) + C = PyArray( PyObject(B) ) + @test C == B + A[2] = 6 + @test C == B && C[2] == A[2] + + # PermutedDimsArray + B = PermutedDimsArray(A, (2,1) ) + C = PyArray( PyObject(B) ) + @test C == B + A[1] == 7 + @test C == B && C[1] == A[1] + + # ReinterpretArray + B = reinterpret(UInt64, A) + C = PyArray( PyObject(B) ) + @test C == B + A[1] = 12 + @test C == B && C[1] == reinterpret(UInt64, A[1]) + end end @test PyVector(PyObject([1,3.2,"hello",true])) == [1,3.2,"hello",true] @test PyDict(PyObject(Dict(1 => "hello", 2 => "goodbye"))) == Dict(1 => "hello", 2 => "goodbye")