function [m,U,K,F] = totalK( m, useGrowthTensors )
%[m,U,K,F] = totalK( m, useGrowthTensors )
%    Assemble and solve the FEM elasticity equations for the mesh.
%
%    Inputs:
%    m: the mesh.  Besides the geometry (m.prismnodes, m.tricellvxs), this
%       reads the per-cell stiffness matrices m.cellstiffness(:,:,ci), the
%       quadrature data m.globalProps.gaussInfo, the strain-retention
%       morphogen in m.morphogens, m.fixedDFmap, m.drivennodes and
%       m.drivenpositions, and the solver settings in m.globalProps.
%    useGrowthTensors: if true (default false), the growth tensor of each
%       cell is taken from m.celldata(ci).cellThermExpGlobalTensor.
%       Otherwise growth tensors are first computed for the whole mesh (by
%       makeMeshGrowthTensors, or makeZeroGrowthTensors when
%       m.globalProps.flatten is set), and m.celldata(ci).Gglobal scaled by
%       the time step is used.
%
%    Outputs:
%    m: the updated mesh.  On success m.displacements is set to U and
%       computeResiduals is applied with the retained fraction of strain.
%    U: the displacement of every prism node, one row per row of
%       m.prismnodes (numNodes-by-3).  U is returned empty if the user
%       interrupts, if the stiffness matrix cannot be allocated, or if
%       m.globalProps.solver names an unknown solver.
%    K: the assembled stiffness matrix.  When not flattening and K is
%       requested, it is passed through insertFixedDFS2 to restore the
%       degrees of freedom removed by eliminateEquations.
%    F: the assembled force vector.  When requested it is likewise
%       restored (when not flattening) and reshaped to numNodes-by-3.

% Tasks remaining:
% 1. Implement per-cell anisotropic elastic and thermal moduli.
% 2. Experiment with tolerance to see how large we can set it and get
%    reasonable results.

    global CANUSEGPUARRAY
    if nargin < 2, useGrowthTensors = false; end

    % Give every output a value up front so that the early returns below
    % (user interrupt, failed allocation, unknown solver) never leave an
    % output unassigned.
    U = [];
    K = [];
    F = [];

    requireK = nargout >= 3;
    requireF = nargout >= 4;

    sb = findStopButton( m );
    numNodes = size(m.prismnodes,1);
    numCells = size(m.tricellvxs,1);
    vxsPerCell = 6;
    dfsPerNode = 3;
    dfsPerCell = vxsPerCell*dfsPerNode;
    numDFs = numNodes*dfsPerNode;
    m = makeTRIvalid( m );

    setGlobals();
    global gOLD_STRAINRET
    global gNEW_STRAINRET
    if m.versioninfo.mgenversion==0
        STRAINRET_MGEN = gOLD_STRAINRET;
    else
        STRAINRET_MGEN = gNEW_STRAINRET;
    end

    locnode = m.globalDynamicProps.locatenode;
    locDFs = m.globalDynamicProps.locateDFs;
    dolocate = (locnode ~= 0) && any(locDFs);

    if userinterrupt( sb )
        [m,U] = reportUserInterrupt( m );
        return;
    end

    % Try for a full K first; fall back to a sparse K if allowed.
    useSparse = false;
    useSingle = strcmp(m.globalProps.solverprecision,'single');
    try
        if useSingle
            K = zeros( numDFs, numDFs, 'single' );
        else
            K = zeros( numDFs, numDFs );
        end
    catch exc
        switch exc.identifier
            case 'MATLAB:nomem'
                reason = ': not enough memory';
            case { 'MATLAB:pmaxsize', 'MATLAB:array:SizeLimitExceeded' }
                reason = ': larger than allowed by Matlab';
            otherwise
                reason = '';
        end
        if m.globalProps.allowsparse
            usingsparse = '  Using a sparse array instead.';
        else
            usingsparse = '';
        end
        fprintf( 1, 'Cannot allocate a %d by %d full array%s.%s\n', ...
            numDFs, numDFs, reason, usingsparse );
        if ~m.globalProps.allowsparse
            return;
        end
        useSparse = true;
    end
    if useSparse
        % Sparse arrays must be doubles, so single precision is abandoned
        % (this also governs the class of F below).
        useSingle = false;
        % In the next line, 4 is just an estimate.  The proper value is
        % numDFs*dfsPerNode times the average number of nodes in all the
        % cells that a typical node is a member of.
        estimatedSpace = numDFs*dfsPerCell*4;
        K = spalloc( numDFs, numDFs, estimatedSpace );
    end
    if useSingle
        F = zeros( numDFs, 1, 'single' );
    else
        F = zeros( numDFs, 1 );
    end

    % Experimental switches, currently disabled.
    eliminateRigidMotion = false;
    exactinv = false;
    if eliminateRigidMotion
        % Rows 1-3 of R constrain total translation, rows 4-6 total
        % rotation about the centroid of the prism nodes.
        R = zeros( 6, numDFs );
        R( 1, 1:3:numDFs ) = 1;
        R( 2, 2:3:numDFs ) = 1;
        R( 3, 3:3:numDFs ) = 1;
        centroid = sum(m.prismnodes,1)/numNodes;
        for i = 1:numNodes
            d = m.prismnodes(i,:) - centroid;
            j = i*3;
            R([4 5 6],[j-2, j-1, j]) = [ [ 0 d(3) -d(2) ]; [ -d(3) 0 d(1) ]; [ d(2) -d(1) 0 ] ];
        end
        [~,n1] = min(m.prismnodes(:,1));
        [~,n2] = max(m.prismnodes(:,1));
        [~,n3] = max(m.prismnodes(:,2));
        n1 = n1*3;
        n2 = n2*3;
        n3 = n3*3;
        selectedDFs = [n1-2, n1-1, n1, n2-1, n2, n3];
    end

    % Per-node strain retention factor for this step, clamped to [0,1].
    sr = max( min( m.morphogens(:,STRAINRET_MGEN), 1 ), 0 );
    if m.globalProps.timestep==0
        % 0^0 is deemed to be 0, anything_else^0 is 1.
        residStrainPerStep = ones(size(sr));
        residStrainPerStep(sr==0) = 0;
    else
        residStrainPerStep = sr.^m.globalProps.timestep;
    end
    if ~useGrowthTensors
        if m.globalProps.flatten
            m = makeZeroGrowthTensors( m );
        else
            m = makeMeshGrowthTensors( m );
        end
    end
    if userinterrupt( sb )
        [m,U] = reportUserInterrupt( m );
        return;
    end
    if m.globalProps.freezing <= 0
        retainedStrain = 0;
    elseif m.globalProps.freezing >= 1
        retainedStrain = 1;
    else
        retainedStrain = m.globalProps.freezing^m.globalProps.timestep;
    end
    appliedStrain = 1 - retainedStrain;

    % Mean retention factor over the three vertices of each cell.  The
    % reshape keeps the result numCells-by-1 even when there is one cell.
    % NOTE(review): this is used even when m.globalProps.flatten is set;
    % an older (unused) per-cell variable forced the scale to 1 in that
    % case.  Confirm which is intended.
    residualScalePerFE = sum( reshape( residStrainPerStep(m.tricellvxs'), 3, [] )', 2 )/3;
    for ci=1:numCells
        if useGrowthTensors
            gt1 = ones(6,1) * m.celldata(ci).cellThermExpGlobalTensor';
        else
            gt1 = m.celldata(ci).Gglobal * m.globalProps.timestep;
        end
        % Every row of gt1 is a growth tensor, so every column of eps0 is a
        % negated growth tensor.
        eps0 = -gt1';
        trivxs = m.tricellvxs(ci,:);
        % Triangle vertex v corresponds to prism nodes 2v-1 and 2v.
        prismvxs = [ trivxs*2-1, trivxs*2 ];
        cellvxCoords = m.prismnodes( prismvxs, : )';
        [m.celldata(ci),k,f] = ...
            cellFEM( m.celldata(ci), ...
                     cellvxCoords, ...
                     m.globalProps.gaussInfo, ...
                     m.cellstiffness(:,:,ci), ...
                     eps0, ...
                     residualScalePerFE(ci) );
        % Prism node p owns degrees of freedom 3p-2..3p; scatter the cell's
        % 18-by-18 contribution into the global system.
        dfBase = prismvxs*3;
        newIndexes = reshape( [ dfBase-2; dfBase-1; dfBase ], 1, [] );
        K( newIndexes, newIndexes ) = K( newIndexes, newIndexes ) + k;
        F( newIndexes ) = F( newIndexes ) - f;
    end
    if eliminateRigidMotion
        K(selectedDFs,:) = K(selectedDFs,:) + R;
    end

    if ~m.globalProps.flatten
        dfmap = m.fixedDFmap';
        if m.globalProps.alwaysFlat
            % Additionally fix every z degree of freedom.
            dfmap( 3:3:numel(dfmap) ) = true;
        end
        fixedDFs = find( dfmap );
        stitchPairs = zeros(0,2);

        % Fixed DFs in the first half of each group of 6 (the lower prism
        % node) are paired with the matching DF of the upper prism node.
        lowerDFs = reshape( fixedDFs( mod( fixedDFs-1, 6 ) < 3 ), [], 1 );
        upperDFs = lowerDFs + 3;
        oppositePairs = [ lowerDFs, upperDFs ];

        % Eventually these will be user-accessible.  oppMoves does not
        % currently work; rowsToFix and fixedMoves do.  These specify the
        % degrees of freedom which are to take specified values, and the
        % values they are to take.
        oppMoves = zeros( size( oppositePairs, 1 ), 1 );
        if isempty(m.drivennodes)
            rowsToFix = [];
            fixedMoves = [];
        else
            % Driven vertex n owns DFs 6n-5..6n (both of its prism nodes).
            rowsToFix = repmat( m.drivennodes(:)'*6, 6, 1 ) ...
                        + repmat( (-5:0)', 1, length(m.drivennodes) );
            rowsToFix = rowsToFix(:);
            nodemoves = m.drivenpositions - m.nodes(m.drivennodes,:);
            % Both prism nodes of a driven vertex move by the same amount.
            fixedMoves = repmat( nodemoves', 2, 1 );
            fixedMoves = fixedMoves(:);
        end
        [K,F,renumber] = eliminateEquations( K, F, ...
            [], m.globalProps.stitchDFs, oppositePairs, stitchPairs, oppMoves, ...
            rowsToFix, fixedMoves );
    end
    if userinterrupt( sb )
        [m,U] = reportUserInterrupt( m );
        return;
    end

    cgmaxiter = size(K,1)*10;
    if exactinv
        UC = K\F;
        cgflag = 0;
        cgrelres = 0;
        m.globalProps.cgiters = 0;
    else
        sparseSolve = true;
        switch m.globalProps.solver
            case 'cgs'
                USERANDOMDISPLACEMENTS = true;
                if USERANDOMDISPLACEMENTS
                    if isempty( m.displacements ) ...
                            || (~m.globalProps.usePrevDispAsEstimate) ...
                            || (m.globalDynamicProps.currentIter <= 0) ...
                            || all(m.displacements(:)==0)
                        fprintf( 1, 'Using random displacement as initial guess.\n' );
                        initestimate = randomiseDisplacements( m );
                    else
                        fprintf( 1, 'Using previous displacement as initial guess.\n' );
                        initestimate = reshape( m.displacements', [], 1 );
                        if m.globalProps.resetRand
                            rand('twister',5489);
                        end
                        % Perturb each component by a random relative
                        % amount in +/- perturbRelGrowthEstimate/2.
                        initestimate = initestimate .* ...
                            (1 - m.globalProps.perturbRelGrowthEstimate/2 ...
                             + m.globalProps.perturbRelGrowthEstimate * rand( size(initestimate) ));
                    end
                else
                    initestimate = zeros(numel(m.fixedDFmap),1);
                end
                if ~m.globalProps.flatten
                    % Map the estimate onto the equations kept by
                    % eliminateEquations.  oppositePairs and renumber only
                    % exist when not flattening; when flattening K and F
                    % were not reduced, so the estimate is used as is.
                    initestimate(oppositePairs(:,1)) = ...
                        (initestimate(oppositePairs(:,1))-initestimate(oppositePairs(:,2)));
                    initestimate = initestimate(renumber);
                end
                fprintf( 1, 'Growth: ' );
                USEJACKET = false;
                if CANUSEGPUARRAY
                    tic;
                    [UC,cgflag,cgrelres,m.globalProps.cgiters] = ...
                        mycgs(gpuArray(K),gpuArray(F), ...
                              m.globalProps.solvertolerance, ...
                              m.globalProps.solvertolerancemethod, ...
                              cgmaxiter, ...
                              m.globalProps.maxsolvetime, ...
                              initestimate, ...
                              @testcallback, ...
                              m);
                    fprintf( 1, 'Computation time for growth (cgs,full,GPUArray) is %.6f seconds.\n', toc() );
                elseif USEJACKET
                    tic;
                    [UC,cgflag,cgrelres,m.globalProps.cgiters] = ...
                        mycgs(single(K),F, ...
                              m.globalProps.solvertolerance, ...
                              m.globalProps.solvertolerancemethod, ...
                              cgmaxiter, ...
                              m.globalProps.maxsolvetime, ...
                              initestimate, ...
                              @testcallback, ...
                              m);
                    fprintf( 1, 'Computation time for growth (cgs,full,single,JACKET) is %.6f seconds.\n', toc() );
                elseif sparseSolve || useSparse
                    tic;
                    % sparse() does not accept single-precision input, so
                    % convert explicitly (a no-op for doubles).
                    [UC,cgflag,cgrelres,m.globalProps.cgiters] = ...
                        mycgs(sparse(double(K)),double(F), ...
                              m.globalProps.solvertolerance, ...
                              m.globalProps.solvertolerancemethod, ...
                              cgmaxiter, ...
                              m.globalProps.maxsolvetime, ...
                              initestimate, ...
                              @testcallback, ...
                              m);
                    fprintf( 1, 'Computation time for growth (cgs,sparse,double) is %.6f seconds.\n', toc() );
                else
                    tic;
                    [UC,cgflag,cgrelres,m.globalProps.cgiters] = ...
                        mycgs(K,F, ...
                              m.globalProps.solvertolerance, ...
                              m.globalProps.solvertolerancemethod, ...
                              cgmaxiter, ...
                              m.globalProps.maxsolvetime, ...
                              initestimate, ...
                              @testcallback, ...
                              m);
                    fprintf( 1, 'Computation time for growth (cgs,full,double) is %.6f seconds.\n', toc() );
                end
            case 'lsqr'
                if useSparse
                    tic;
                    [UC,cgflag,cgrelres,m.globalProps.cgiters] = ...
                        mylsqr(sparse(K),F, ...
                               m.globalProps.solvertolerance, ...
                               cgmaxiter, ...
                               m.globalProps.maxsolvetime);
                    fprintf( 1, 'Computation time for growth (lsqr,sparse,double) is %.6f seconds.\n', toc() );
                else
                    tic;
                    [UC,cgflag,cgrelres,m.globalProps.cgiters] = ...
                        mylsqr(K,F, ...
                               m.globalProps.solvertolerance, ...
                               cgmaxiter, ...
                               m.globalProps.maxsolvetime);
                    fprintf( 1, 'Computation time for growth (lsqr,full,double) is %.6f seconds.\n', toc() );
                end
            case 'dgelsy'
                % NOTE(review): UC starts as a copy of F and is not
                % reassigned here; this presumably relies on test_gels
                % writing the solution into it in place -- confirm.
                UC = F;
                fprintf( 1, 'Growth by %s, size %d ... ', m.globalProps.solver, size(K,1) );
                tic;
                C = test_gels( K, UC );
                fprintf( 1, 'Computation time for growth (dgelsy,full,double) is %.6f seconds.\n', toc() );
                cgflag = 0;
                cgrelres = 0;
                m.globalProps.cgiters = 0;
            case 'culaSgesv'
                tic;
                [C,UC] = use_culaSgesv( K, F );
                fprintf( 1, 'Computation time for growth (culaSgesv,full,double) is %.6f seconds.\n', toc() );
                cgflag = 0;
                cgrelres = 0;
                m.globalProps.cgiters = 0;
                if C ~= 0
                    % Keep the failure visible in cgflag instead of
                    % overwriting it.
                    cgflag = -1;
                    switch C
                        case 1
                            culaerr = 'CULA has not been initialized';
                        case 2
                            culaerr = 'No hardware is available to run';
                        case 3
                            culaerr = 'CUDA runtime or driver is not supported';
                        case 4
                            culaerr = 'Available GPUs do not support the requested operation';
                        case 5
                            culaerr = 'There is insufficient memory to continue';
                        case 6
                            culaerr = 'The requested feature has not been implemented';
                        case 7
                            culaerr = 'An invalid argument was passed to a function';
                        case 8
                            culaerr = 'An operation could not complete because of singular data';
                        case 9
                            culaerr = 'A blas error was encountered';
                        case 10
                            culaerr = 'A runtime error has occurred';
                        otherwise
                            culaerr = 'Unknown error';
                    end
                    fprintf( 1, 'CULA error %d: %s.\n', C, culaerr );
                    UC = zeros( size(UC), 'double' );
                else
                    UC = double(UC);
                end
            otherwise
                complain( 'Warning: unknown solver "%s".  Elasticity equations not solved.', ...
                    m.globalProps.solver );
                % Nothing was solved, so there are no meaningful
                % displacements to return.
                U = [];
                return;
        end
    end

    % Restore the eliminated degrees of freedom and reshape to one row per
    % prism node.
    if m.globalProps.flatten
        U = UC;
    else
        U = insertFixedDFS( UC, renumber, numDFs, ...
            m.globalProps.stitchDFs, oppositePairs, stitchPairs, oppMoves, rowsToFix, fixedMoves );
        if requireK
            K = insertFixedDFS2( K, renumber, numDFs, m.globalProps.stitchDFs, oppositePairs, oppMoves );
        end
        if requireF
            F = insertFixedDFS( F, renumber, numDFs, ...
                m.globalProps.stitchDFs, oppositePairs, stitchPairs, oppMoves, rowsToFix, fixedMoves );
        end
    end
    U = reshape( U, dfsPerNode, numNodes )';
    if requireF
        F = reshape( F, dfsPerNode, numNodes )';
    end

    if cgflag ~= 0
        fprintf( 1, 'totalK: cgs error: ' );
        if cgflag==20
            fprintf( 1, 'CGS failed to converge to tolerance %g after %d seconds, %d of %d iterations.\n', ...
                cgrelres, round(m.globalProps.maxsolvetime), m.globalProps.cgiters, cgmaxiter );
        elseif cgflag==8
            fprintf( 1, 'CGS interrupted by user after %d steps.\n', ...
                m.globalProps.cgiters );
        elseif cgflag > 0
            cgsmsg( cgflag,cgrelres,m.globalProps.cgiters,cgmaxiter );
        else
            fprintf( 1, 'solver failed (flag %d).\n', cgflag );
        end
    end

    if cgflag==8
        % User interrupt.  The displacements might not be meaningful.
        U = [];
        m.displacements = [];
    else
        if m.globalProps.canceldrift
            % NOTE(review): any(...,2) gives one value per row of
            % m.fixedDFmap, so this combines rows 1-3 with rows 4-6.  If a
            % per-axis summary was intended, any(...,1) may be meant --
            % confirm against the layout of fixedDFmap.
            anyfixed = any( m.fixedDFmap, 2 );
            [~,U] = cancelMoment( m.prismnodes, U, ~(anyfixed(1:3) | anyfixed(4:6)) );
        end

        U = U*appliedStrain;

        if dolocate
            % Translate every prism node so that the midpoint of the two
            % prism nodes of vertex locnode does not move along the
            % selected axes.
            locpnode = 2*locnode;
            translation = -(U(locpnode-1,:) + U(locpnode,:))/2;
            translation( ~locDFs ) = 0;
            numShifted = 2*size(m.nodes,1);
            U(1:numShifted,:) = U(1:numShifted,:) + repmat( translation, numShifted, 1 );
        end

        m.displacements = U;
        m = computeResiduals( m, retainedStrain );
    end
end

function [m,U] = reportUserInterrupt( m )
%[m,U] = reportUserInterrupt( m )
%    Clear the mesh displacements, return an empty U, and report that the
%    user interrupted the simulation at the current step.
    m.displacements = [];
    U = [];
    fprintf( 1, 'Simulation interrupted by user at step %d.\n', ...
        m.globalDynamicProps.currentIter );
end
