Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions src/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,23 +150,47 @@ macro timing(cond,code)
end

function SparseMatrix.sparse(I,J,V, M, N;keepzeros=false)
@assert length(I) == length(J) == length(V)
if(!keepzeros)
return sparse(I,J,V,M,N)
else
full = sparse(I,J,ones(Float64,length(I)),M,N)
actual = sparse(I,J,V,M,N)
fill!(full.nzval,0.0)
mergednnz = [0]
mergedmap = zeros(Int,length(I))
idxmap = zeros(Int,length(I))
mergedindices = zeros(Int,length(I))
function combine(idx1,idx2)
# @show idx1, idx2
idx1 = round(Int,idx1)
idx2 = round(Int,idx2)
@inbounds @assert mergedmap[idx2] == 0 && (mergedmap[idx1] == idx1 || mergedmap[idx1] == 0)
@inbounds mergednnz[1] += 1
@inbounds mergedmap[idx1] = idx1
@inbounds mergedmap[idx2] = idx1
@inbounds mergedindices[mergednnz[1]] = idx2
return idx1
end

full = sparse(I,J,[float(i) for i in 1:length(I)],M,N,combine)
for col in 1:N
@inbounds for pos in full.colptr[col]:(full.colptr[col+1]-1)
@inbounds row = full.rowval[pos]
@inbounds origidx = round(Int,full.nzval[pos]) # this is the original index (on JJ) of this element
@inbounds idxmap[origidx] = pos
end
end

for c = 1:N
for i=nzrange(actual,c)
r = actual.rowval[i]
v = actual.nzval[i]
if(v!=0)
full[r,c] = v
end
end
# full.nzval[crange] = actual.nzval[crange]
end
@inbounds for k in 1:mergednnz[1]
@inbounds origidx = mergedindices[k]
@inbounds mergedwith = mergedmap[origidx]
@inbounds @assert idxmap[origidx] == 0
@inbounds @assert idxmap[mergedwith] != 0
@inbounds idxmap[origidx] = idxmap[mergedwith]
end

fill!(full.nzval,0.0)
for i in 1:length(I)
@inbounds full.nzval[idxmap[i]] += V[i]
end
return full
end
end
Expand Down