clear all

path = 'E:\worm\4-13-2009\lin12A594_lag2Cy5_12.30pm\';
channels={'A594', 'Cy5'};

% Pick the DAPI image of each of the two recordings to merge. uigetfile's
% second argument is the dialog title and the third the initial location;
% it returns 0 when the dialog is cancelled. The worm number is parsed from
% file names of the form 'dapiNNN...'.
filename=uigetfile('dapi*', 'Select DAPI image of the first worm', path);
if isequal(filename,0)
    error('No file selected for the first worm.');
end
worm1=sscanf(filename, 'dapi%d');
if isempty(worm1)
    error('Cannot read a worm number from file name "%s".', filename);
end

filename=uigetfile('dapi*', 'Select DAPI image of the second worm', path);
if isequal(filename,0)
    error('No file selected for the second worm.');
end
worm2=sscanf(filename, 'dapi%d');
if isempty(worm2)
    error('Cannot read a worm number from file name "%s".', filename);
end

% useVPCmask=1;
% useACmask=1;
useMask=1;   % 1: keep only spots inside the AC / VPC masks (see "merge spots")

% load the mean projections of both worms, rescaled to [0,1] by mat2gray
proj_path=[path 'analyzed\projections\'];
infile=sprintf('dapi%03d_mean.tif',worm1);
im1=mat2gray(imread([proj_path infile]));
infile=sprintf('dapi%03d_mean.tif',worm2);
im2=mat2gray(imread([proj_path infile]));

%% load VPCs and determine which VPCs are annotated in both worms
% Loads the VPC, AC and DTC annotations of both worms and collects the x/y
% coordinates (rows 1:2 of VPC(n).coord) of every VPC entry that has the
% same nonzero cell count N in both worms. Cells are paired by column
% index, and the pairs are used below to register the two images.

inpath=[path 'analyzed\matlab_data\'];

infile=sprintf('celldata%03d.mat',worm1);
data1=load([inpath infile],'VPC','AC', 'DTC');

infile=sprintf('celldata%03d.mat',worm2);
data2=load([inpath infile],'VPC','AC', 'DTC');


overlapX1=[];
overlapX2=[];
% only VPC indices present in both files can be compared
for n=1:min(length(data1.VPC), length(data2.VPC))
    ncell=data1.VPC(n).N;
    if ncell > 0 && ncell == data2.VPC(n).N
        overlapX1=[overlapX1 data1.VPC(n).coord(1:2,1:ncell)];
        overlapX2=[overlapX2 data2.VPC(n).coord(1:2,1:ncell)];
    end
end

%% find the offset between the two worm images from overlapping VPCs
% DX(1) is the mean x shift and DX(2) the mean y shift (worm 1 minus worm 2)
% of the paired VPC cells, rounded to whole pixels.

if isempty(overlapX1)
    % mean() of an empty set would give NaN and fail much later
    error('No VPC is annotated with the same cell count in both worms; cannot compute the image offset.');
end
DX=round(mean(overlapX1-overlapX2,2));
% plot(overlapX1(1,:), overlapX1(2,:),'o',overlapX2(1,:)+DX(1), overlapX2(2,:)+DX(2),'o')

%% merge projections
% Paste both projections onto one canvas sized to hold them at their
% relative offset DX. Naming caution: x1/x2 are ROW offsets (set by DX(2),
% the y shift) and y1/y2 are COLUMN offsets (set by DX(1), the x shift);
% within each pair one value is 0. Worm 1 is pasted last, so its pixels
% are the ones kept where the two images overlap.

immerge=zeros(1024+abs(DX(2)),1024+abs(DX(1)));

x1=max(-DX(2),0);
x2=max(DX(2),0);
y1=max(-DX(1),0);
y2=max(DX(1),0);

tile=1:1024;
immerge(x2+tile,y2+tile)=im2;
immerge(x1+tile,y1+tile)=im1;
%% merge spots
% For each channel, load both worms' spot lists (merged_spot: column 1 = x,
% column 2 = y; column 3 is used later as a per-spot index from 1 to 25),
% optionally keep only spots inside the AC / VPC masks, shift them into the
% merged-image frame and drop worm-2 spots that fall where the two images
% overlap. data(i).cellid labels every kept spot: 1 = AC mask, 2 = VPC
% mask, 0 = no mask applied (useMask off).

w=[worm1 worm2];
matlab_path=[path 'analyzed\matlab_data\'];

for i=1:length(channels)
    
    data(i).spots=[];
    data(i).cellid=[];
    for n=1:2
        infile=sprintf('%s%03d_spotdata.mat', channels{i}, w(n));
        load([matlab_path infile], 'merged_spot');
        if isempty(merged_spot)
            continue    % nothing to add; indexing columns of [] below would fail
        end

        % label used when no mask is applied, so id_spot always exists
        id_spot=zeros(1,size(merged_spot,1));

        if useMask
            % load VPC mask
            infile=sprintf('VPCmask%03d.mat',w(n));
            load([matlab_path infile], 'border');  
            VPCmask=roipoly(1024,1024,border(1,:), border(2,:));
            
            if n==1
                % load AC mask (only worm 1 uses one)
                infile=sprintf('ACmask%03d.mat',w(n));
                load([matlab_path infile], 'border');  
                ACmask=roipoly(1024,1024,border(1,:), border(2,:));
                
                % exclude AC mask region from VPC mask
                VPCmask=VPCmask.*~ACmask;
            else
                ACmask=zeros(1024,1024);
            end

            % mask pixel under each spot; ceil turns sub-pixel coordinates
            % into pixel indices, clamped so border spots stay in range
            px=min(max(ceil(merged_spot(:,1)),1),1024);
            py=min(max(ceil(merged_spot(:,2)),1),1024);
            idx=sub2ind([1024 1024], py, px);
            inAC=ACmask(idx)==1;
            inVPC=VPCmask(idx)==1;

            % keep AC spots first, then VPC spots, labelled 1 and 2
            Nac=nnz(inAC); Nvpc=nnz(inVPC);
            merged_spot=[merged_spot(inAC,:); merged_spot(inVPC,:)];
            id_spot=[ones(1,Nac) 2*ones(1,Nvpc)];
        end

        if n==1
            merged_spot(:,1)=merged_spot(:,1)+y1;
            merged_spot(:,2)=merged_spot(:,2)+x1;
            keep=true(size(merged_spot,1),1);
        else
            merged_spot(:,1)=merged_spot(:,1)+y2;
            merged_spot(:,2)=merged_spot(:,2)+x2;
            % both images cover [abs(DX(1)),1024] x [abs(DX(2)),1024] of the
            % merged frame; worm 1's spots are used there, so keep only the
            % worm-2 spots outside that region
            keep=~(merged_spot(:,1)>abs(DX(1)) & merged_spot(:,1)<1024 & ...
                   merged_spot(:,2)>abs(DX(2)) & merged_spot(:,2)<1024);
        end
        % one spot per row for both worms
        data(i).spots=[data(i).spots; merged_spot(keep,:)];
        data(i).cellid=[data(i).cellid id_spot(keep)];
    end
end

%% merge VPCs, AC and DTCs

% Start from worm 1's VPCs, shifting x (row 1) and y (row 2) of each cell
% into the merged-image frame.
VPC=data1.VPC;
for n=1:length(VPC)
    for i=1:VPC(n).N
        VPC(n).coord(1,i)=VPC(n).coord(1,i)+y1;
        VPC(n).coord(2,i)=VPC(n).coord(2,i)+x1;
    end
end

% Add VPC cells that are annotated only in worm 2, shifted by worm 2's
% offsets. Cells are matched by column index, so worm-2 columns beyond
% worm 1's count are appended and N is raised so later loops over 1:N
% include them.
for n=1:min(length(VPC), length(data2.VPC))
    N1=VPC(n).N;
    N2=data2.VPC(n).N;
    if N1==0 && N2 ~= 0
        % no worm-1 annotation for this VPC: take worm 2's entry wholesale
        VPC(n).N=N2;
        VPC(n).coord=data2.VPC(n).coord;
        for i=1:N2
            VPC(n).coord(1,i)=data2.VPC(n).coord(1,i)+y2;
            VPC(n).coord(2,i)=data2.VPC(n).coord(2,i)+x2;
        end
    elseif N2 > N1
        % start after worm 1's last cell so none of worm 1's cells is overwritten
        nrow=min(size(VPC(n).coord,1), size(data2.VPC(n).coord,1));
        for i=N1+1:N2
            VPC(n).coord(1:nrow,i)=data2.VPC(n).coord(1:nrow,i);
            VPC(n).coord(1,i)=data2.VPC(n).coord(1,i)+y2;
            VPC(n).coord(2,i)=data2.VPC(n).coord(2,i)+x2;
        end
        VPC(n).N=N2;
    end
end

% copy AC: use worm 1's AC when it is annotated, otherwise worm 2's, shifted
% into the merged frame by that worm's column (y*) and row (x*) offsets
if isempty(data1.AC)
    AC = data2.AC;
    AC(1) = AC(1) + y2;
    AC(2) = AC(2) + x2;
else
    AC = data1.AC;
    AC(1) = AC(1) + y1;
    AC(2) = AC(2) + x1;
end

% copy DTCs: take worm 1's two DTC entries, shifting the annotated ones into
% the merged frame; an empty second entry is filled from worm 2
DTC=data1.DTC;
for k=1:2
    if isempty(DTC(k).coord)
        continue
    end
    DTC(k).coord(1)=DTC(k).coord(1)+y1;
    DTC(k).coord(2)=DTC(k).coord(2)+x1;
end

if isempty(DTC(2).coord)
    % NOTE(review): assumes worm 2's first DTC is the one missing from worm 1 -- confirm
    DTC(2).coord(1)=data2.DTC(1).coord(1)+y2;
    DTC(2).coord(2)=data2.DTC(1).coord(2)+x2;
end

%% get axis for merged worm
% Interactively trace the body axis on the merged image; ax holds one point
% per column (row 1 = x, row 2 = y).
%   left click  - append a point
%   right click - remove the last point (ignored when there is none)
%   Esc         - finish (needs at least two points for the straightening)

ax=[];

cont=1;
while cont==1

    imshow(immerge,[],'In', 50); hold on;
    if ~isempty(ax)
        plot(ax(1,:), ax(2,:), '-sb');
    end

    [x,y,but]=ginput(1);
    if isempty(but)
        but=0;     % Enter without a click: nothing to do (switch needs a scalar)
    end

    switch but
        case 1
            ax(:,end+1)=[x; y];
        case 3
            if ~isempty(ax)
                ax(:,end)=[];
            end
        case 27
            if size(ax,2) >= 2
                cont=0;
            else
                disp('At least two axis points are needed before finishing.');
            end
    end

    hold off;

end

%% plot resulting merger
% Overlay the merged annotations on the merged image: channel-1 spots in red
% and channel-2 spots in green, brighter for larger column-3 values (1..25);
% VPC cells as blue circles, the AC as a magenta square, the DTCs as yellow
% stars.
figure(1)
imshow(immerge, 'In', 50);
hold on
for z=1:25
    shade=z/25;
    inSlice1=data(1).spots(:,3)==z;
    plot(data(1).spots(inSlice1,1), data(1).spots(inSlice1,2), '.', 'color', [shade 0 0])
    inSlice2=data(2).spots(:,3)==z;
    plot(data(2).spots(inSlice2,1), data(2).spots(inSlice2,2), '.', 'color', [0 shade 0])
%     r=find(data(2).spots(:,3)==z & data(2).cellid'==1);
%     plot(data(2).spots(r,1),data(2).spots(r,2),'.', 'color', [z 0 0]/25)
end
for v=1:length(VPC)
    for c=1:VPC(v).N
        plot(VPC(v).coord(1,c), VPC(v).coord(2,c), 'ob');
    end
end
plot(AC(1), AC(2), 'sm');
for d=1:2
    plot(DTC(d).coord(1), DTC(d).coord(2), '*y');
end
% plot(ax(1,:),ax(2,:),'-o');

hold off;

%% straighten everything in the worm
% Describe the traced body axis ax (row 1 = x, row 2 = y) in two forms that
% are handed to findAxisIntersection below: the clicked polyline (ax_*) and
% a spline through the clicked points, densely resampled (pp_*).

if size(ax,2) < 2
    error('The body axis needs at least two points.');
end

% polyline: per-segment x/y steps, segment lengths, cumulative arc length
% and the negated unit direction of each segment (stored as "normal";
% NOTE(review): this is a direction, not a perpendicular -- confirm how
% findAxisIntersection uses it)
ax_dx=diff(ax(1,:));
ax_dy=diff(ax(2,:));
ax_dr=sqrt(ax_dx.^2+ax_dy.^2);
ax_length=cumsum(ax_dr);
ax_normal=-[ax_dx./ax_dr; ax_dy./ax_dr];

% total length of worm in pixels.
Ltot=ax_length(end);
if Ltot <= 0
    error('The body axis has zero length.');
end

% spline through the axis points, parameterized by point index 1..nax.
% ceil(2*Ltot)+1 parameter values spread evenly over the whole range give
% about 2 samples per pixel of axis length on average and make the spline
% end exactly on the last clicked point.
x=ax(1,:);        
y=ax(2,:);
yy=[x; y];
nax=length(x);
xx=1:nax;
xxx=linspace(1,nax,ceil(2*Ltot)+1);
pp=spline(xx,yy,xxx);

% per-sample steps, lengths, cumulative length and negated unit direction
% of the sampled spline
pp_dx=diff(pp(1,:));
pp_dy=diff(pp(2,:));
pp_dr=sqrt(pp_dx.^2+pp_dy.^2);
pp_length=cumsum(pp_dr);
pp_normal=-[pp_dx./pp_dr; pp_dy./pp_dr];

% hold on;
% plot(pp(1,:), pp(2,:))

% straighten mRNAs: data(i).mRNA is a copy of data(i).spots whose columns
% 1:2 are replaced by the two values findAxisIntersection returns for the
% spot; all other columns are kept. pt/st are used so the merge offsets
% x2 and DX from earlier in the script are not overwritten.
for i=1:length(channels)
    data(i).mRNA=data(i).spots;
    % size(...,1), not length(): length() returns the column count when
    % there are fewer spots than columns
    for t=1:size(data(i).spots,1)
        pt=data(i).spots(t,1:2)';
        st=findAxisIntersection(pt, ax, ax_dr, ax_length, ax_normal, pp, pp_dr, pp_length, pp_normal,0);
        data(i).mRNA(t,1)=st(1);
        data(i).mRNA(t,2)=st(2);
    end
end

% find coordinates with respect to body axis for VPCs; the two values
% returned by findAxisIntersection are stored in rows 4 and 5 of
% VPC(t).coord. pt/st are used so the merge offsets x2 and DX from earlier
% in the script are not overwritten.
for t=1:length(VPC)
    for i=1:VPC(t).N
        pt=VPC(t).coord(1:2,i);
        st=findAxisIntersection(pt, ax, ax_dr, ax_length, ax_normal, pp, pp_dr, pp_length, pp_normal,0);
        VPC(t).coord(4,i)=st(1);
        VPC(t).coord(5,i)=st(2);
    end
end

% find coordinates with respect to body axis for the AC; the two values
% returned by findAxisIntersection go into rows 4 and 5 of AC. pt/st are
% used so the merge offsets x2 and DX from earlier are not overwritten.
pt=AC(1:2,1);
st=findAxisIntersection(pt, ax, ax_dr, ax_length, ax_normal, pp, pp_dr, pp_length, pp_normal,0);

AC(4,1)=st(1);
AC(5,1)=st(2);

% find coordinates with respect to body axis for gonad (DTCs); the two
% values returned by findAxisIntersection go into elements 4 and 5 of
% DTC(n).coord. Unannotated (empty) DTCs are skipped. pt/st are used so the
% merge offsets x2 and DX from earlier are not overwritten.
for n=1:length(DTC)
    if isempty(DTC(n).coord)
        continue
    end
    pt=DTC(n).coord(1:2);
    pt=pt(:);    % column vector, as in the other findAxisIntersection calls
    st=findAxisIntersection(pt, ax, ax_dr, ax_length, ax_normal, pp, pp_dr, pp_length, pp_normal,0);

    DTC(n).coord(4)=st(1);
    DTC(n).coord(5)=st(2);
end
%% plot straightened worm data
% Plot in body-axis coordinates:
% - the axis itself along y = 0, with one marker per traced point;
% - spots with column-3 values 1..25, red for channel 1 and green for
%   channel 2, circles for cellid 1 (AC mask) and dots for cellid 2 (VPC
%   mask);
% - VPC cells coloured by their cell count N;
% - the AC as a green square and the DTCs as yellow diamonds.
figure(2)
plot([0 ax_length], zeros(1,length(ax_length)+1),'-ob');
hold on;

for z=1:25
    r=find(data(1).spots(:,3)==z & data(1).cellid'==1);
    plot(data(1).mRNA(r,1),data(1).mRNA(r,2),'o', 'color', [z 0 0]/25)
    r=find(data(1).spots(:,3)==z & data(1).cellid'==2);
    plot(data(1).mRNA(r,1),data(1).mRNA(r,2),'.', 'color', [z 0 0]/25)
    
    r=find(data(2).spots(:,3)==z & data(2).cellid'==1);
    plot(data(2).mRNA(r,1),data(2).mRNA(r,2),'o', 'color', [0 z 0]/25)
    r=find(data(2).spots(:,3)==z & data(2).cellid'==2);
    plot(data(2).mRNA(r,1),data(2).mRNA(r,2),'.', 'color', [0 z 0]/25)
end

% one colour per cell count N: jet(4) as before, extended when a VPC has
% more than 4 cells so that col(N,:) stays inside the colormap
col=jet(max([4 VPC.N]));
for t=1:length(VPC)
    for i=1:VPC(t).N
        % NOTE(review): VPC index 6 is drawn as a square instead of a circle; reason not visible here
        if t~=6
            marker='ok';
        else
            marker='sk';
        end
        plot(VPC(t).coord(4,i),VPC(t).coord(5,i), marker, 'MarkerFaceColor', col(VPC(t).N,:));
    end
end

plot(AC(4,1), AC(5,1), 'sk', 'MarkerFaceColor', [0 1 0]);

for i=1:2
    % skip DTCs without axis coordinates (not annotated in either worm)
    if numel(DTC(i).coord) >= 5
        plot(DTC(i).coord(4),DTC(i).coord(5),'dk', 'MarkerFaceColor', [1 1 0]);
    end
end


xlim([0 ax_length(end)]);
ylim(ax_length(end)/4*[-1 1]);

%% save data
% For each channel, write two files that each hold one variable named mRNA:
% - the straightened spots labelled cellid 1 (AC mask);
% - the straightened spots labelled cellid 2 (VPC mask).
% Then store the merged VPC/AC/DTC annotations and the traced axis. Files go
% to matlab_path, which is set in the "merge spots" section.

for chn=1:length(channels)
    isAC=data(chn).cellid'==1;
    isVPC=data(chn).cellid'==2;

    mRNA=data(chn).mRNA(isAC,:);
    outfile=sprintf('%s_merge%03d_%03d_mRNA_AC_straight.mat', channels{chn}, worm1, worm2);
    save([matlab_path outfile], 'mRNA');

    mRNA=data(chn).mRNA(isVPC,:);
    outfile=sprintf('%s_merge%03d_%03d_mRNA_VPC_straight.mat', channels{chn}, worm1, worm2);
    save([matlab_path outfile], 'mRNA');
end

outfile=sprintf('mergedata%03d_%03d.mat', worm1, worm2);
save([matlab_path outfile], 'VPC', 'AC', 'DTC', 'ax');


%% plot separate data of merged images
% figure(2)
% imshow(immerge, 'In', 33)
% 
% hold on;
% for n=1:length(data2.VPC)
%     for i=1:data2.VPC(n).N
%         plot(data2.VPC(n).coord(1,i)+y2,data2.VPC(n).coord(2,i)+x2,'ob');
%     end
% end
% for n=1:length(data1.VPC)
%     for i=1:data1.VPC(n).N
%         plot(data1.VPC(n).coord(1,i)+y1,data1.VPC(n).coord(2,i)+x1,'or');
%     end
% end
% 
% % plot(data1.ax(1,:)+y1, data1.ax(2,:)+x1, '-sb');
% % plot(data2.ax(1,:)+y2, data2.ax(2,:)+x2, '-sr');
% 
% cm=jet(4);
% matlab_path=[path 'analyzed\matlab_data\'];
% for i=1:2
%     infile=sprintf('%s%03d_spotdata.mat', channels{i}, worm1);
%     load([matlab_path infile], 'merged_spot');
%     plot(merged_spot(:,1)+y1,merged_spot(:,2)+x1, '.', 'color', cm(2*i,:))
%     
%     infile=sprintf('%s%03d_spotdata.mat', channels{i}, worm2);
%     load([matlab_path infile], 'merged_spot');
%     r=find(~(merged_spot(:,1)+y2>abs(DX(1)) & merged_spot(:,1)+y2<1024 & merged_spot(:,2)+x2>abs(DX(2)) & merged_spot(:,2)+x2<1024));
%     plot(merged_spot(r,1)+y2,merged_spot(r,2)+x2, '.', 'color', cm(2*i-1,:));
% 
% end
% 
% hold off;