forked from cortex-lab/KiloSort
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrezToPhy.m
110 lines (90 loc) · 3.94 KB
/
rezToPhy.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
function [spikeTimes, clusterIDs, amplitudes, templates, templateFeatures, ...
    templateFeatureInds, pcFeatures, pcFeatureInds] = rezToPhy(rez, savePath)
% REZTOPHY  Pull results out of kilosort's rez struct, either to return them
% to the workspace or to save them in the format the phy GUI reads.
%
% Inputs:
%   rez      - kilosort result struct. Fields read here: st3, U, W, cProj,
%              iNeigh, cProjPC, iNeighPC, Wrot, ops (Nchan, Nfilt, chanMap,
%              fbinary, NchanTOT, fs) and, optionally, simScore.
%   savePath - (optional) output folder. If omitted or empty, nothing is
%              read from or written to disk and the results are only
%              returned. If provided it should be a folder, and npy-matlab
%              must be available (https://github.com/kwikteam/npy-matlab).
%              NOTE: all existing *.npy files and any '.phy' sub-folder in
%              savePath are deleted before the new files are written.
%
% Outputs:
%   spikeTimes          - uint64, column 1 of rez.st3 (in samples, not seconds)
%   clusterIDs          - uint32, currently identical to the per-spike
%                         template IDs (column 2 of rez.st3, 1-based); no
%                         separate cluster assignment is exported
%   amplitudes          - column 3 of rez.st3
%   templates           - single, nTemplates x nSamples x nChannels
%   templateFeatures    - rez.cProj
%   templateFeatureInds - uint32(rez.iNeigh), 1-based
%   pcFeatures          - rez.cProjPC
%   pcFeatureInds       - uint32(rez.iNeighPC), 1-based
%
% Indices are returned 1-based; the files written to savePath are 0-based.

if nargin < 2
    savePath = '';
end
doSave = ~isempty(savePath);

if doSave
    % Remove stale results from a previous export. This must only happen
    % when a folder was actually given: with an empty savePath, fullfile
    % would resolve to the current working directory and wipe it instead.
    oldFiles = dir(fullfile(savePath, '*.npy'));
    for iFile = 1:length(oldFiles)
        delete(fullfile(savePath, oldFiles(iFile).name));
    end
    phyCacheDir = fullfile(savePath, '.phy');
    if exist(phyCacheDir, 'dir')
        rmdir(phyCacheDir, 's');
    end
end

spikeTimes = uint64(rez.st3(:,1));
spikeTemplates = uint32(rez.st3(:,2));
% Must be assigned, otherwise any call requesting two or more outputs fails.
clusterIDs = spikeTemplates;
amplitudes = rez.st3(:,3);

Nchan = rez.ops.Nchan;
nt0 = size(rez.W,1);
U = rez.U;
W = rez.W;

% Rebuild each template as the product of its spatial (U) and temporal (W)
% components.
templates = zeros(Nchan, nt0, rez.ops.Nfilt, 'single');
for iNN = 1:rez.ops.Nfilt
    templates(:,:,iNN) = squeeze(U(:,iNN,:)) * squeeze(W(:,iNN,:))';
end
templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels
templatesInds = repmat(0:size(templates,3)-1, size(templates,1), 1); % we include all channels so this is trivial

templateFeatures = rez.cProj;
templateFeatureInds = uint32(rez.iNeigh);
pcFeatures = rez.cProjPC;
pcFeatureInds = uint32(rez.iNeighPC);

if doSave
    % Channel map: load into a struct so variables stored in the file cannot
    % overwrite locals of this function. If loading fails (e.g. chanMap is
    % not a readable file), fall back to a trivial single-column layout
    % with all channels connected.
    try
        chanMapData = load(rez.ops.chanMap);
        mapLoaded = true;
    catch
        mapLoaded = false;
    end
    if mapLoaded
        chanMap0ind = chanMapData.chanMap0ind;
        connected = chanMapData.connected;
        xcoords = chanMapData.xcoords;
        ycoords = chanMapData.ycoords;
    else
        chanMap0ind = (0:Nchan-1)';
        connected = ones(Nchan, 1);
        xcoords = ones(Nchan, 1);
        ycoords = (1:Nchan)';
    end

    writeNPY(spikeTimes, fullfile(savePath, 'spike_times.npy'));
    writeNPY(uint32(spikeTemplates-1), fullfile(savePath, 'spike_templates.npy')); % -1 for zero indexing
    writeNPY(amplitudes, fullfile(savePath, 'amplitudes.npy'));
    writeNPY(templates, fullfile(savePath, 'templates.npy'));
    writeNPY(templatesInds, fullfile(savePath, 'templates_ind.npy'));

    % Only connected channels are exported to the map and position files.
    conn = logical(connected);
    chanMap0ind = int32(chanMap0ind);
    writeNPY(chanMap0ind(conn), fullfile(savePath, 'channel_map.npy'));
    writeNPY([xcoords(conn) ycoords(conn)], fullfile(savePath, 'channel_positions.npy'));

    writeNPY(templateFeatures, fullfile(savePath, 'template_features.npy'));
    writeNPY(templateFeatureInds'-1, fullfile(savePath, 'template_feature_ind.npy')); % -1 for zero indexing
    writeNPY(pcFeatures, fullfile(savePath, 'pc_features.npy'));
    writeNPY(pcFeatureInds'-1, fullfile(savePath, 'pc_feature_ind.npy')); % -1 for zero indexing

    % NOTE(review): the divisor 200 presumably undoes the data scaling
    % applied during preprocessing -- confirm against the whitening code.
    whiteningMatrix = rez.Wrot/200;
    whiteningMatrixInv = whiteningMatrix^-1;
    writeNPY(whiteningMatrix, fullfile(savePath, 'whitening_mat.npy'));
    writeNPY(whiteningMatrixInv, fullfile(savePath, 'whitening_mat_inv.npy'));

    if isfield(rez, 'simScore')
        similarTemplates = rez.simScore;
        writeNPY(similarTemplates, fullfile(savePath, 'similar_templates.npy'));
    end

    % Make the params file; an existing params.py is left untouched.
    paramsFile = fullfile(savePath, 'params.py');
    if ~exist(paramsFile, 'file')
        fid = fopen(paramsFile, 'w');
        if fid == -1
            error('rezToPhy:cannotOpenParams', ...
                'Could not open %s for writing.', paramsFile);
        end
        closeParams = onCleanup(@() fclose(fid)); % close even if a write fails
        [~, fname, ext] = fileparts(rez.ops.fbinary);
        % Pass the file name as data, never as the format string, so '%'
        % or '\' in the name is written literally.
        fprintf(fid, 'dat_path = ''%s''\n', [fname ext]);
        fprintf(fid, 'n_channels_dat = %i\n', rez.ops.NchanTOT);
        fprintf(fid, 'dtype = ''int16''\n');
        fprintf(fid, 'offset = 0\n');
        % %i silently switches to %e (losing precision) for non-integer
        % rates; %.12g prints integers unchanged and keeps fractional rates.
        fprintf(fid, 'sample_rate = %.12g\n', rez.ops.fs);
        fprintf(fid, 'hp_filtered = False');
        clear closeParams
    end
end