function warpedSurface = warpSurface(varargin)
% function warpedSurface = warpSurface(varargin)
%
% Description:
%
% Warps a surface using AIH implementation of the MCT.
%
% The warp is estimated from a source point set and a target point set and
% is then applied to the vertices of the source surface. All arguments are
% passed as name/value pairs; parameter names are case-insensitive.
%
% Required parameters:
%   'targetPoints'  - target point set (one point per row). The verbose
%                     plots read columns 1 to 3.
%   'sourcePoints'  - source point set (one point per row).
%   'sourceSurface' - structure with field 'vertices' (and 'faces', which
%                     is only read when 'verbose' is set).
%
% Optional parameters (default in brackets):
%   'method'        - 'closed' for known correspondences, 'general' for
%                     unknown correspondences ['closed'].
%   'sigma'         - passed on to the MCT routines of both methods [1].
%   'lambda'        - passed on to the MCT routines of both methods [1].
%   'beta'          - 'general' method only; passed on to
%                     mctlibj.MCTReg [1].
%   'anneal'        - 'general' method only [0.97].
%   'max_it'        - 'general' method only [100].
%   'tol'           - 'general' method only [1e-3].
%   'verbose'       - if true, the point sets and surfaces are plotted [0].
%   'axesHandle'    - axes to plot into; a new figure is created if this is
%                     not a valid handle [-1].
%   'targetColours' - two rows: marker face colour, marker edge colour
%                     [['y'; 'r']].
%   'sourceColours' - two rows: marker face colour, marker edge colour
%                     [['b'; 'g']].
%   'targetLabel'   - text label of the target point set, 'general'
%                     method plots only ['Target Point Set'].
%   'sourceLabel'   - text label of the source point set, 'general'
%                     method plots only ['Source Point Set'].
%   'edges'         - rows of index pairs into the point sets; drawn as
%                     line segments in the 'closed' method plots [[]].
%
% Returns:
%   warpedSurface   - copy of sourceSurface whose 'vertices' field holds
%                     the warped vertex positions.
%
% Errors are raised for malformed name/value lists, unknown parameter
% names, missing required parameters, unknown methods, and for the
% 'closed' method when the two point sets differ in size.
%
% Author: Johann Strasser
% Date: 080114


% Parameters
targetPoints = [];
sourcePoints = [];
sourceSurface = [];
method = 'closed'; % 'closed' for known correspondence, 'general' for unknown correspondences

% Parameters required for both methods
sigma = 1;
lambda = 1;

% Parameters for method with unknown correspondences
beta = 1;
anneal = 0.97;
max_it = 100;
tol = 1e-3;

verbose = 0;
axesHandle = -1;
targetColours = ['y'; 'r'];
sourceColours = ['b'; 'g'];
targetLabel = 'Target Point Set';
sourceLabel = 'Source Point Set';

edges = [];

% A dangling parameter name would otherwise surface as an obscure
% index-out-of-range error inside the loop below.
if mod(numel(varargin), 2) ~= 0
    error([mfilename, ': Parameters must be passed as name/value pairs.']);
end

for k = 1:2:numel(varargin)
    name = localToChar(varargin{k});
    value = varargin{k + 1};

    if ~ischar(name)
        error([mfilename, ': Parameter names must be character strings.']);
    end

    switch lower(name)
        case 'targetpoints'
            targetPoints = value;
        case 'sourcepoints'
            sourcePoints = value;
        case 'sourcesurface'
            sourceSurface = value;
        case 'method'
            method = value;
        case 'sigma'
            sigma = value;
        case 'lambda'
            lambda = value;
        case 'beta'
            beta = value;
        case 'anneal'
            anneal = value;
        case 'max_it'
            max_it = value;
        case 'tol'
            tol = value;
        case 'verbose'
            % Flag, not text: must not be passed through lower()
            verbose = value;
        case 'axeshandle'
            % Graphics handle, not text: must not be passed through lower()
            axesHandle = value;
        case 'targetcolours'
            targetColours = localLowerIfChar(value);
        case 'sourcecolours'
            sourceColours = localLowerIfChar(value);
        case 'targetlabel'
            % Display text is kept exactly as supplied by the caller
            targetLabel = value;
        case 'sourcelabel'
            sourceLabel = value;
        case 'edges'
            edges = value;
        otherwise
            error(['Unknown parameter name passed to ', mfilename, '.  Name was ' name])
    end
end

if isempty(targetPoints)
    error([mfilename, ': Parameter ''targetPoints'' not specified or empty.']);
end

if isempty(sourcePoints)
    error([mfilename, ': Parameter ''sourcePoints'' not specified or empty.']);
end

if isempty(sourceSurface)
    error([mfilename, ': Parameter ''sourceSurface'' not specified or empty.']);
end

% Method names are matched case-insensitively, like the parameter names
method = localToChar(method);
if ~ischar(method)
    error([mfilename, ': Parameter ''method'' must be ''general'' or ''closed''.']);
end
method = lower(method);

useGeneral = strcmp(method, 'general');
if ~(useGeneral || strcmp(method, 'closed'))
    error([mfilename, ': Unknown method: ', method]);
end

if ~useGeneral && ~isequal(size(targetPoints), size(sourcePoints))
    error([mfilename, ': Size of source point set and target point set not equal; ', ...
        'method ''closed'' requires known correspondences. Use method ''general'' instead.']);
end

warpedSurface = sourceSurface;

if useGeneral

    % METHOD #1: Use general MCT where point correspondences are not known

    % Calculate the coefficients that define the velocity field
    vectorField = mctlibj.MCTReg.calculateVectorField(targetPoints, sourcePoints, ...
        beta, sigma, lambda, anneal, tol, max_it);

    % Apply those coefficients to the source points to get the new target
    % points (only displayed in the verbose plot)
    targetPointsNew = mctlibj.MCTReg.applyVectorField(sourcePoints, sourcePoints, ...
        targetPoints, vectorField, beta);

    % Apply those coefficients to the vertices of the source surface to get
    % new target surface
    warpedSurface.vertices = mctlibj.MCTReg.applyVectorField(sourceSurface.vertices, sourcePoints, ...
        targetPoints, vectorField, beta);

    if verbose
        if ~localIsValidHandle(axesHandle)
            figure;
            axesHandle = gca;
        end

        % All graphics calls name the axes explicitly so that a supplied
        % axesHandle is honoured even when it is not the current axes.
        hold(axesHandle, 'on');

        localPlotPoints(axesHandle, sourcePoints, sourceColours);
        localPlotLabel(axesHandle, sourcePoints, sourceLabel, sourceColours);

        localPlotPoints(axesHandle, targetPoints, targetColours);
        localPlotLabel(axesHandle, targetPoints, targetLabel, targetColours);

        targetNewLabel = 'Target new';
        targetNewColours = ['b'; 'g'];

        localPlotPoints(axesHandle, targetPointsNew, targetNewColours);
        localPlotLabel(axesHandle, targetPointsNew, targetNewLabel, targetNewColours);

        % Render source surface and warped surface
        localRenderSurface(axesHandle, sourceSurface.faces, sourceSurface.vertices, ...
            sourceColours(1, :), 'gouraud');
        localRenderSurface(axesHandle, warpedSurface.faces, warpedSurface.vertices, ...
            targetColours(1, :), 'gouraud');

        % The light is placed at the camera position that is current before
        % the axis and view adjustments below.
        cameraPos = get(axesHandle, 'CameraPosition');
        light('Position', cameraPos, 'Parent', axesHandle);
        axis(axesHandle, 'tight');
        axis(axesHandle, 'equal');
        view(axesHandle, 3);
        axis(axesHandle, 'vis3d');
        box(axesHandle, 'on');
        hold(axesHandle, 'off');
    end
else
    % METHOD #2: Use MCT where point correspondences are known

    % Warp from source to target. (Fitting the warp from target to source
    % and negating it was tried experimentally; it works for the points but
    % not for the surface, since only the source surface is given.)

    [targetPointsNorm, datay] = mctlib_normalise('pts', targetPoints);
    [sourcePointsNorm, datax] = mctlib_normalise('pts', sourcePoints);

    % Displacements between corresponding normalised points
    V = targetPointsNorm - sourcePointsNorm;

    % Kept in its own variable so the 'beta' input parameter is not
    % overwritten by the fitted coefficients
    coeffs = mctlib_CalcParam('x0', sourcePointsNorm, 'f', V, 'sigma', sigma, 'lambda', lambda);

    % Normalise the surface vertices with the source mean and scale
    sourceVertices = sourceSurface.vertices;
    numVertices = size(sourceVertices, 1);
    sVMuMat = repmat(datax.mu, numVertices, 1);
    tVMuMat = repmat(datay.mu, numVertices, 1);
    sourceVerticesNorm = (sourceVertices - sVMuMat) ./ datax.scale;
    targetVerticesNewNorm = sourceVerticesNorm + mctlib_CalcVel('x0', sourcePointsNorm, ...
        'y0', sourceVerticesNorm, 'sigma', sigma, 'beta', coeffs);

    % Note that we have to use the target scale and mean here, since it is the
    % target positions we want to end up at
    warpedSurface.vertices = targetVerticesNewNorm .* datay.scale + tVMuMat;

    if verbose
        % --- Figure 1: the scale-normalised data ---
        if ~localIsValidHandle(axesHandle)
            figure('Position', [50, 50, 660, 660]);
            axesHandle = gca;
        end

        hold(axesHandle, 'on');
        view(axesHandle, 3);
        axis(axesHandle, 'vis3d');
        box(axesHandle, 'on');

        sourceNormHandle = localPlotPoints(axesHandle, sourcePointsNorm, sourceColours);
        localPlotEdges(axesHandle, sourcePointsNorm, edges, sourceColours(2, :));

        targetNormHandle = localPlotPoints(axesHandle, targetPointsNorm, targetColours);
        localPlotEdges(axesHandle, targetPointsNorm, edges, targetColours(2, :));

        localRenderSurface(axesHandle, sourceSurface.faces, sourceVerticesNorm, [0, 0, 1], 'phong');
        localRenderSurface(axesHandle, warpedSurface.faces, targetVerticesNewNorm, [0, 1, 0], 'phong');

        localFinaliseClosedView(axesHandle, [sourceNormHandle, targetNormHandle], ...
            {'Aligned template', 'Model point set'});

        % --- Figure 2: the raw data (always in a new figure) ---
        figure('Position', [50, 50, 660, 660]);
        rawAxes = gca;
        hold(rawAxes, 'on');
        axis(rawAxes, 'equal');
        view(rawAxes, 3);
        axis(rawAxes, 'vis3d');
        box(rawAxes, 'on');

        sourceHandle = localPlotPoints(rawAxes, sourcePoints, sourceColours);
        localPlotEdges(rawAxes, sourcePoints, edges, sourceColours(2, :));

        targetHandle = localPlotPoints(rawAxes, targetPoints, targetColours);
        localPlotEdges(rawAxes, targetPoints, edges, targetColours(2, :));

        localRenderSurface(rawAxes, sourceSurface.faces, sourceSurface.vertices, [0, 0, 1], 'phong');
        localRenderSurface(rawAxes, warpedSurface.faces, warpedSurface.vertices, [0, 1, 0], 'phong');

        localFinaliseClosedView(rawAxes, [sourceHandle, targetHandle], ...
            {'Aligned template', 'Model point set and warped template surface'});
    end
end


function value = localToChar(value)
% Converts a scalar string object to a character vector; any other input
% is returned unchanged.
if isa(value, 'string') && isscalar(value)
    value = char(value);
end


function value = localLowerIfChar(value)
% Lower-cases character colour specifications; numeric colour values are
% returned unchanged.
if ischar(value)
    value = lower(value);
end


function tf = localIsValidHandle(h)
% True only for a single, currently valid graphics handle.
tf = ~isempty(h) && isscalar(h) && ishandle(h);


function h = localPlotPoints(ax, pts, colours)
% Plots the first three columns of pts as circular markers. Row 1 of
% colours is the marker face colour, row 2 the marker edge colour.
h = plot3(ax, pts(:, 1), pts(:, 2), pts(:, 3), 'o', ...
    'MarkerFaceColor', colours(1, :), 'MarkerEdgeColor', colours(2, :));


function localPlotLabel(ax, pts, label, colours)
% Places a text label at the first point of pts.
text(pts(1, 1), pts(1, 2), pts(1, 3), label, 'Color', colours(2, :), ...
    'BackgroundColor', colours(1, :), 'EdgeColor', colours(2, :), ...
    'Parent', ax);


function localPlotEdges(ax, pts, edges, colour)
% Draws one line segment per row of edges between the two indexed points.
for k = 1:size(edges, 1)
    idx = edges(k, 1:2);
    plot3(ax, pts(idx, 1), pts(idx, 2), pts(idx, 3), ...
        '-', 'Color', colour, 'LineWidth', 2);
end


function localRenderSurface(ax, faces, vertices, faceColour, lighting)
% Renders a triangulated surface as a patch without edges.
patch('Faces', faces, 'Vertices', vertices, ...
    'EdgeColor', 'none', 'FaceColor', faceColour, 'FaceLighting', lighting, ...
    'Parent', ax);


function localFinaliseClosedView(ax, pointHandles, legendNames)
% Applies the common axis, light, viewpoint and legend settings of the
% 'closed' method plots.
axis(ax, 'tight');
axis(ax, 'equal');
view(ax, 3);
axis(ax, 'vis3d');
box(ax, 'on');

axis(ax, 'ij');

% The light is positioned at the camera position of the default 3-D view,
% i.e. before the final viewpoint is selected.
cameraPos = get(ax, 'CameraPosition');
light('Position', cameraPos, 'Parent', ax);

% Alternative viewpoints: view(-140, 10) and view(264, 10) (lateral)
view(ax, -6, 10); % Adaxial viewpoint

legend(pointHandles, legendNames, 'Location', 'Northeast');

hold(ax, 'off');