%-------------------------------------------------------------------------%
% Background
% Compute the transient displacement responses by Newmark-beta time integration
% Created by Lei, TU/e, 2018
%-------------------------------------------------------------------------%

function TR = funTR(beta,gamma,freq,aUPres,uAmp,K,M,cpDofs,totDofs,presDofs,freeDofs,elemXID,numNd)
%FUNTR  Transient response of an undamped system with prescribed kinematics.
%
%   TR = funTR(beta,gamma,freq,aUPres,uAmp,K,M,cpDofs,totDofs,presDofs, ...
%              freeDofs,elemXID,numNd)
%
%   Integrates M*a + K*u = R in time with the Newmark-beta scheme. The DOFs
%   in presDofs follow u_p(t) = aUPres*(1 - cos(2*pi*freq*t)) (uAmp cancels
%   out of the kinematics, but must be nonzero). The free DOFs carry no
%   external load and start with zero displacement and velocity.
%
%   Inputs
%     beta, gamma - Newmark parameters.
%     freq        - excitation frequency; 4 periods are simulated with
%                   100 time steps per period.
%     aUPres      - vector of prescribed values on presDofs.
%     uAmp        - scalar used to normalize aUPres (must be nonzero).
%     K, M        - stiffness and mass matrices, indexed by DOF number.
%     cpDofs      - DOFs whose kinematics are stored in TR. The field
%                   reconstruction treats odd/even entries as the first/
%                   second component of consecutive control points.
%     totDofs     - all DOFs; only its number of elements is used.
%     presDofs    - DOFs with prescribed kinematics.
%     freeDofs    - DOFs solved for.
%     elemXID     - nElem-by-nSample matrix of sample-point identifiers;
%                   repeated identifiers are merged in the output.
%     numNd       - nCp-by-nElem cell array; numNd{iN,jN} is presumably
%                   basis function iN sampled on element jN (a 1-by-nSample
%                   row) - TODO confirm against the caller.
%
%   Output struct TR
%     freq                - excitation frequency (copied from input).
%     T                   - stored time instants (every sampleT-th step).
%     cpaUt, cpaVt, cpaAt - displacement/velocity/acceleration on cpDofs,
%                           one column per stored instant.
%     dttr                - wall-clock time [s] of initialization + time
%                           integration.
%     XID                 - unique sample-point identifiers (sorted).
%     Ut                  - interleaved (ip, op) displacement field at XID,
%                           one column per stored instant.

if uAmp == 0
    error('funTR:zeroAmplitude', 'uAmp must be nonzero (aUPres is normalized by it).');
end

aN = aUPres(:)/uAmp;                % normalized prescribed kinematics, as a column

% time axis
TR.freq = freq;                     % excitation frequency
omega = 2*pi*freq;

nPeriod = 4;                        % number of simulated excitation periods
nStepPerPeriod = 100;               % time steps per period
nStep = nPeriod*nStepPerPeriod;     % total number of time steps
dt = 1/freq/nStepPerPeriod;         % time step
% Build t from an integer step count: a floating-point colon range
% (0:dt:tEnd) can gain or lose the last instant through round-off.
t = (0:nStep)*dt;

% store 1 per 2 steps, same as that used in COMSOL
sampleT = 2;
TR.T = t(1:sampleT:end);
nT = numel(TR.T);                   % number of stored instants (always an integer)

% prescribed kinematics: u = (1-cos(wt))*uAmp and its first two time derivatives
ut = (1-cos(omega*t)) * uAmp;
vt = sin(omega*t) * uAmp * omega;
at = cos(omega*t) * uAmp * omega^2;

% initialize
ttr = tic;

% State vectors are columns regardless of the orientation of the DOF
% lists, so the matrix-vector products below are always well defined.
nDof = numel(totDofs);

aUtCurrent = zeros(nDof,1);         % displacement: zero on the free DOFs
aUtPres = aN * ut(1);
aUtCurrent(presDofs) = aUtPres;

aVtCurrent = zeros(nDof,1);         % velocity: zero on the free DOFs
aVtPres = aN * vt(1);
aVtCurrent(presDofs) = aVtPres;

aAtCurrent = zeros(nDof,1);         % acceleration: free part solved below
aAtPres = aN * at(1);
aAtCurrent(presDofs) = aAtPres;

aRFree = zeros(numel(freeDofs),1);  % no external load on the free DOFs

% partition
Kff = K(freeDofs,freeDofs);
Mff = M(freeDofs,freeDofs);

Kfp = K(freeDofs,presDofs);
Mfp = M(freeDofs,presDofs);

% initial acceleration from equilibrium at t = 0:
%   Mff*af = Rf - Kfp*up - Mfp*ap - Kff*uf
aRFreeEff = aRFree - Kfp*aUtPres - Mfp*aAtPres;
aAtCurrent(freeDofs) = Mff\(aRFreeEff - Kff*aUtCurrent(freeDofs));

% store the kinematics on the control points
nCpDofs = numel(cpDofs);
TR.cpaUt = zeros(nCpDofs,nT);
TR.cpaVt = zeros(nCpDofs,nT);
TR.cpaAt = zeros(nCpDofs,nT);
TR.cpaUt(:,1) = aUtCurrent(cpDofs);
TR.cpaVt(:,1) = aVtCurrent(cpDofs);
TR.cpaAt(:,1) = aAtCurrent(cpDofs);

%-------------------------------------------------------------------------%
% time integration (for free control points)

% effective Newmark matrix: constant in time, so it is built once
Sff = Mff + dt^2*beta*Kff;

for it = 1:nStep                    % advance from t(it) to t(it+1)

    % prediction on the velocity and displacement
    aVtFreeNextPred = aVtCurrent(freeDofs) + (1-gamma)*dt*aAtCurrent(freeDofs);
    aUtFreeNextPred = aUtCurrent(freeDofs) + dt*aVtCurrent(freeDofs) + (0.5-beta)*dt*dt*aAtCurrent(freeDofs);

    % prescribed kinematics at the new instant
    aUtPresNext = aN * ut(it+1);
    aVtPresNext = aN * vt(it+1);
    aAtPresNext = aN * at(it+1);

    % acceleration from equilibrium at t(it+1)
    aRFreeEff = aRFree - Kfp*aUtPresNext - Mfp*aAtPresNext;
    aAtFreeNext = Sff\(aRFreeEff - Kff*aUtFreeNextPred);

    % correction
    aVtFreeNext = aVtFreeNextPred + dt*gamma*aAtFreeNext;
    aUtFreeNext = aUtFreeNextPred + dt^2*beta*aAtFreeNext;

    % update the current kinematics
    aUtCurrent(presDofs) = aUtPresNext;
    aUtCurrent(freeDofs) = aUtFreeNext;

    aVtCurrent(presDofs) = aVtPresNext;
    aVtCurrent(freeDofs) = aVtFreeNext;

    aAtCurrent(presDofs) = aAtPresNext;
    aAtCurrent(freeDofs) = aAtFreeNext;

    % store kinematics on the control points
    if mod(it,sampleT) == 0
        iT = it/sampleT + 1;        % column of TR.T that holds t(it+1)
        TR.cpaUt(:,iT) = aUtCurrent(cpDofs);
        TR.cpaVt(:,iT) = aVtCurrent(cpDofs);
        TR.cpaAt(:,iT) = aAtCurrent(cpDofs);
    end

end

TR.dttr = toc(ttr);                 % computation time for the transient analysis

%-------------------------------------------------------------------------%
% compute the transient displacement field at the sample points

% flatten element by element (row-wise) and drop repeated sample points
nXID = numel(elemXID);
XID = reshape(elemXID',nXID,1);

[~, uniID] = unique(XID);
TR.XID = XID(uniID);

nCp = size(numNd,1);
nElem = size(numNd,2);

TR.Ut = zeros(2*numel(uniID),nT);   % transient displacement field (interleaved ip/op)

for iT = 1:nT
    cpaUtip = TR.cpaUt(1:2:end,iT);
    cpaUtop = TR.cpaUt(2:2:end,iT);

    elemUtip = zeros(size(elemXID));
    elemUtop = zeros(size(elemXID));

    for jN = 1:nElem
        % The sum starts at iN = jN: entries numNd{iN,jN} with iN < jN are
        % skipped, i.e. assumed to vanish on element jN.
        % NOTE(review): confirm this holds for the caller's basis numbering.
        for iN = jN:nCp
            elemUtip(jN,:) = elemUtip(jN,:) + numNd{iN,jN}*cpaUtip(iN);
            elemUtop(jN,:) = elemUtop(jN,:) + numNd{iN,jN}*cpaUtop(iN);
        end
    end

    % reassemble as columns in the same row-wise order as XID
    Utip = reshape(elemUtip',nXID,1);
    Utop = reshape(elemUtop',nXID,1);

    % remove the repeated values
    TR.Ut(1:2:end,iT) = Utip(uniID);
    TR.Ut(2:2:end,iT) = Utop(uniID);
end

end