% Reset the workspace and clear the command window before running.
clear all
clc

% Put the lab code folder on the MATLAB search path.
% NOTE(review): machine-specific absolute path; nothing in this script
% visibly calls into it -- confirm whether it is still required.
addpath('C:\Users\Steve\Desktop\Neuro_Utilities\Matlab Scripts\Hayashi Lab Code');

% Import Data
% Nine tables: three data folders (Ms1-Ms3, presumably one mouse each --
% confirm) times three session files (A, B, C are read from
% XYtotalspike1, XYtotalspike2 and XYtotalspike6 respectively).

% Ms1 <- folder "25"
Ms1A = readtable('D:/ACC Data/3ContextData/25/day9/XYtotalspike1.csv');
Ms1B = readtable('D:/ACC Data/3ContextData/25/day9/XYtotalspike2.csv');
Ms1C = readtable('D:/ACC Data/3ContextData/25/day9/XYtotalspike6.csv');

% Ms2 <- folder "532"
Ms2A = readtable('D:/ACC Data/3ContextData/532/day9/XYtotalspike1.csv');
Ms2B = readtable('D:/ACC Data/3ContextData/532/day9/XYtotalspike2.csv');
Ms2C = readtable('D:/ACC Data/3ContextData/532/day9/XYtotalspike6.csv');

% Ms3 <- folder "614"
Ms3A = readtable('D:/ACC Data/3ContextData/614/day9/XYtotalspike1.csv');
Ms3B = readtable('D:/ACC Data/3ContextData/614/day9/XYtotalspike2.csv');
Ms3C = readtable('D:/ACC Data/3ContextData/614/day9/XYtotalspike6.csv');

% Per-folder SC.csv tables.
CellIDsMs1 = readtable('D:/ACC Data/3ContextData/25/day9/SC.csv');
CellIDsMs2 = readtable('D:/ACC Data/3ContextData/532/day9/SC.csv');
CellIDsMs3 = readtable('D:/ACC Data/3ContextData/614/day9/SC.csv');

% Row indices of each SC.csv whose entries sum to a positive value.
% Later code uses these row indices as cell (column) indices into the
% spike data, so row i of SC.csv is assumed to describe cell i -- confirm.
SCcellsMs1 = find(sum(CellIDsMs1{:,:}, 2) > 0);
SCcellsMs2 = find(sum(CellIDsMs2{:,:}, 2) > 0);
SCcellsMs3 = find(sum(CellIDsMs3{:,:}, 2) > 0);



% Organize Data
% Builds, from the nine imported tables and the per-mouse SC indices:
%   Alldata        - samples x cells matrix; rows are sessions A, B, C
%                    stacked, columns are Ms1, Ms2, Ms3 cells side by side
%   AllSCCells     - SC cell indices expressed as columns of Alldata
%   labels         - N x 1 cell array of context names, one per row
%   SCcellsAlldata - Alldata restricted to the SC cell columns

% Truncate every session to the shortest recording across the three mice
% so that the per-mouse matrices have matching row counts.
SessionAlength = min([size(Ms1A,1), size(Ms2A,1), size(Ms3A,1)]);
SessionBlength = min([size(Ms1B,1), size(Ms2B,1), size(Ms3B,1)]);
SessionClength = min([size(Ms1C,1), size(Ms2C,1), size(Ms3C,1)]);

% Keep the first Session*length rows and drop the first five columns.
% NOTE(review): columns 1-5 are presumably non-cell metadata -- confirm.
Ms1A = Ms1A(1:SessionAlength, 6:end);
Ms2A = Ms2A(1:SessionAlength, 6:end);
Ms3A = Ms3A(1:SessionAlength, 6:end);

Ms1B = Ms1B(1:SessionBlength, 6:end);
Ms2B = Ms2B(1:SessionBlength, 6:end);
Ms3B = Ms3B(1:SessionBlength, 6:end);

Ms1C = Ms1C(1:SessionClength, 6:end);
Ms2C = Ms2C(1:SessionClength, 6:end);
Ms3C = Ms3C(1:SessionClength, 6:end);

% Cell counts per mouse, taken from session A (sessions B and C must have
% the same columns or the vertical concatenation below will error).
Ms1numcells = size(Ms1A,2);
Ms2numcells = size(Ms2A,2);
Ms3numcells = size(Ms3A,2);

% The SC indices are shifted by the preceding mice's cell counts below.
% An index larger than its own mouse's cell count would then silently
% point at a different mouse's cell, so fail loudly instead.
assert(all(SCcellsMs1 <= Ms1numcells), ...
    'SC.csv for Ms1 references cell %d but only %d cells are present.', ...
    max(SCcellsMs1), Ms1numcells);
assert(all(SCcellsMs2 <= Ms2numcells), ...
    'SC.csv for Ms2 references cell %d but only %d cells are present.', ...
    max(SCcellsMs2), Ms2numcells);
assert(all(SCcellsMs3 <= Ms3numcells), ...
    'SC.csv for Ms3 references cell %d but only %d cells are present.', ...
    max(SCcellsMs3), Ms3numcells);

% Stack sessions A/B/C vertically per mouse, then place the mice side by
% side so each column of Alldata is one cell.
Alldata = [table2array([Ms1A; Ms1B; Ms1C]), ...
           table2array([Ms2A; Ms2B; Ms2C]), ...
           table2array([Ms3A; Ms3B; Ms3C])];
%Alldata=sparse(Alldata);

% Convert per-mouse SC indices into column indices of Alldata.
SCcellsMs2 = SCcellsMs2 + Ms1numcells;
SCcellsMs3 = SCcellsMs3 + Ms1numcells + Ms2numcells;
AllSCCells = [SCcellsMs1; SCcellsMs2; SCcellsMs3];

% The per-session tables are no longer needed.
clear Ms1A Ms1B Ms1C Ms2A Ms2B Ms2C Ms3A Ms3B Ms3C temp*

% One context label per row of Alldata, in the same A, B, C stacking order.
labels = [repmat({'Context1'}, SessionAlength, 1); ...
          repmat({'Context2'}, SessionBlength, 1); ...
          repmat({'Context3'}, SessionClength, 1)]; % This might need to change to trackB

SCcellsAlldata = Alldata(:, AllSCCells);

% Decode context from the SC cell population.
% The classifier is an error-correcting output codes (ECOC) model with
% k-nearest-neighbour binary learners. (An SVM template used to be built
% here but was never passed to fitcecoc, so it has been removed to avoid
% suggesting that an SVM is being trained.)

fraction_train = 0.8;

% Random hold-out split of the rows: fraction_train for training, no
% validation set, the remainder for testing.
[trainInd,~,testInd] = dividerand(size(SCcellsAlldata,1),fraction_train,0,(1-fraction_train));

Model = fitcecoc(SCcellsAlldata(trainInd,:),labels(trainInd,:),'Learners','knn');%,'OptimizeHyperparameters','auto');

% Predict the context label of every held-out row.
newly_labelled_data = predict(Model,SCcellsAlldata(testInd,:));

% Percentage of held-out rows whose predicted label matches the true one.
tf = strcmp(newly_labelled_data,labels(testInd));
decodingaccuracy = mean(tf)*100;

% Surrogate distribution: repeat the decoding many times using randomly
% chosen non-SC cells, matched in number to the SC population, to obtain
% a reference distribution of decoding accuracies.
x = setdiff(1:size(Alldata,2), AllSCCells);   % columns of Alldata that are not SC cells
numShuffles = 1000;                           % number of surrogate draws
numSCcells  = length(AllSCCells);             % surrogate population size
Surrogate_Decoding_Vals = NaN(numShuffles,1);

% Sampling is without replacement, so there must be at least as many
% non-SC cells as SC cells.
if numel(x) < numSCcells
    error('SurrogateDecoding:tooFewCells', ...
        'Cannot draw %d surrogate cells without replacement from only %d non-SC cells.', ...
        numSCcells, numel(x));
end

for num_shuffs = 1:numShuffles

    % Draw numSCcells distinct non-SC columns. Indexing x with randperm
    % (rather than randsample(x,k)) stays correct when x has a single
    % element, which randsample would interpret as the population 1:x.
    rx = x(randperm(numel(x), numSCcells));
    tempdata = Alldata(:,rx);

    % Fresh random train/test split for every surrogate draw.
    [trainInd,~,testInd] = dividerand(size(Alldata,1),fraction_train,0,(1-fraction_train));
    Model = fitcecoc(tempdata(trainInd,:),labels(trainInd,:),'Learners','knn');%,'OptimizeHyperparameters','auto');

    % Percentage of held-out rows decoded correctly by the surrogate cells.
    PCCnewly_labelled_data = predict(Model,tempdata(testInd,:));
    PCCtf = strcmp(PCCnewly_labelled_data,labels(testInd));
    Surrogate_Decoding_Vals(num_shuffs) = mean(PCCtf)*100;

end

% Plot the surrogate accuracy distribution with the real (SC cell)
% decoding accuracy marked as a vertical line.
figure;
histogram(Surrogate_Decoding_Vals,'facecolor',[141/255, 198/255, 63/255]);
hold on
xline(decodingaccuracy,'color',[251/255, 176/255, 64/255]);
xlabel('Decoding accuracy (%)')
ylabel('Number of surrogate draws')

% Default view is 30-70 %. Widen it only when a surrogate value or the
% real accuracy falls outside that window, so nothing is clipped from view.
% min/max ignore NaN, so unfilled surrogate entries do not affect the limits.
plottedVals = [Surrogate_Decoding_Vals(:); decodingaccuracy];
xlim([min(30, floor(min(plottedVals))), max(70, ceil(max(plottedVals)))])