function [u,errAll] = TV_SB_3D(J,f, N,mu, lambda, gamma, nInner, nBreg,varargin)
%TV_SB_3D  Total-variation regularized reconstruction with split Bregman.
%   [u,errAll] = TV_SB_3D(J,f,N,mu,lambda,gamma,nInner,nBreg) reconstructs
%   an array u of size N from the data f, using the matrix J as forward
%   operator (applied as J*u(:)). Each of the nBreg outer Bregman steps runs
%   nInner split-Bregman inner iterations and then adds the data residual
%   back to f.
%
%   Regularization: joint (isotropic) shrinkage of the differences along
%   dimensions 1 and 2, separate (anisotropic) shrinkage along dimension 3.
%   All finite differences use periodic (wrap-around) boundaries.
%
%   Inputs
%     J       - forward matrix; J*u(:) must be defined for u of size N.
%     f       - measured data, compatible with J*u(:).
%     N       - size vector of the reconstruction, e.g. [rows cols slices].
%     mu      - weight of the data-fidelity term.
%     lambda  - weight of the splitting coupling; shrinkage threshold is 1/lambda.
%     gamma   - weight of the identity (stability) term in the linear system.
%     nInner  - number of inner iterations per Bregman step.
%     nBreg   - number of outer Bregman iterations.
%     uTarget - (optional, 9th argument) reference solution. Enables the
%               error history and a progress figure.
%     R       - (optional, 10th argument) array multiplied element-wise into
%               f and into J*u(:). NOTE(review): the adjoint AT does not
%               re-apply R, so the normal equations are exact only when
%               R.^2 == R (e.g. a 0/1 mask) -- confirm with callers.
%
%   Outputs
%     u      - reconstruction of size N, rescaled to the units of J and f.
%     errAll - nBreg-by-1 relative errors norm(uTarget-|u|)/norm(uTarget);
%              all zeros when uTarget is not supplied.

% Only uTarget (9th) and R (10th) are accepted as optional arguments.
narginchk(8, 10);
hasTarget = nargin >= 9;
hasMask = nargin >= 10;

% Normalize the data (undone at the end through 'unscale')
normFactor = getNormalizationFactor(f,f);
f = normFactor*f;
if hasTarget
    uTarget = varargin{1};
end
if hasMask
    R = varargin{2};
    f = f.*R;
end
errAll = zeros(nBreg,1);

% Normalize the Jacobian so that the largest diagonal entry of J'*J is 1.
% diag(J'*J) equals the squared column norms of J; computing those directly
% avoids forming the numel(u)-by-numel(u) matrix J'*J.
normFactorJ = 1/sqrt(full(max(sum(abs(J).^2, 1))));
J = J*normFactorJ;

% Scale the forward and adjoint operators so the problem does not depend on
% the data size. A divides by scale and AT multiplies by it, so the scale
% cancels in AT(A(x)), while AT(f) is normalized to a maximum magnitude of 1.
scale = 1/max(abs(J'*f));
if hasMask
    A = @(x)(((J*x(:)).*R)/scale);
else
    A = @(x)(((J*x(:)))/scale);
end
AT = @(x)(reshape((J'*x)*scale,N));

% Factor that maps the internal (normalized) solution back to original units
unscale = normFactorJ/(normFactor*scale);

% Settings of the inner linear solver (used by krylov)
tolKrylov = 1e-2; % 1e-4
maxItKrylov = 100;

% Reserve memory for the auxiliary variables
f0 = f;
u = zeros(N);
x = zeros(N);
y = zeros(N);
z = zeros(N);
bx = zeros(N);
by = zeros(N);
bz = zeros(N);
murf = mu*AT(f);

% Slice drawn in the progress figure, clamped for volumes with < 5 slices
iSlice = min(5, size(u,3));
hFig = [];

% Do the reconstruction
for outer = 1:nBreg
    for inner = 1:nInner
        % update u: approximately solve (mu*AT*A + lambda*D'*D + gamma*I) u = rhs
        rhs = murf+lambda*Dxt(x-bx)+lambda*Dyt(y-by)+lambda*Dzt(z-bz);
        u = reshape(krylov(rhs(:)),N);
        dx = Dx(u);
        dy = Dy(u);
        dz = Dz(u);
        % update x and y (joint shrinkage) and z (separate shrinkage)
        [x,y] = shrink2(dx+bx,dy+by,1/lambda);
        z = shrink1(dz+bz,1/lambda);
        % update bregman parameters
        bx = bx+dx-x;
        by = by+dy-y;
        bz = bz+dz-z;
    end % inner loop
    % Bregman step: add the data residual back to the data
    fForw = A(u);
    f = f + f0-fForw;
    murf = mu*AT(f);
    if hasTarget
        % Relative solution error, measured in original units
        errAll(outer) = norm(uTarget(:)-abs(u(:)*unscale))/norm(uTarget(:));
        if outer == 1 || outer == 10 || rem(outer, 50) == 0
            showProgress();
        end
    end
end % outer
% undo the normalization so that results are scaled properly
u = u*unscale;

    function showProgress()
        % Draw the back-projection, the current reconstruction and the error
        % history. One figure is reused instead of closing whatever figure
        % happens to be current.
        if isempty(hFig) || ~ishandle(hFig)
            hFig = figure;
        else
            figure(hFig);
            clf;
        end
        subplot(2,2,1);
        imagesc(abs(murf(:,:,iSlice))); title(['retroprojection']); colorbar;
        subplot(2,2,2);
        imagesc(abs(u(:,:,iSlice)*unscale)); title(['u, iter. ' num2str(outer)]); colorbar;
        subplot(2,2,3);
        plot(errAll(1:outer)); axis tight; title(['Sol. error' ]);
        colormap gray;
        drawnow;
    end

    function normFactor = getNormalizationFactor(R,f)
        % size(R,1)/norm(f(:)); called here as getNormalizationFactor(f,f)
        normFactor = 1/norm(f(:)/size(R,1));
    end

    function d = Dx(u)
        % Backward difference along dimension 2, periodic boundary
        d = u - circshift(u, [0 1 0]);
    end

    function d = Dxt(u)
        % Transpose of Dx
        d = u - circshift(u, [0 -1 0]);
    end

    function d = Dy(u)
        % Backward difference along dimension 1, periodic boundary
        d = u - circshift(u, [1 0 0]);
    end

    function d = Dyt(u)
        % Transpose of Dy
        d = u - circshift(u, [-1 0 0]);
    end

    function d = Dz(u)
        % Backward difference along dimension 3 (time), periodic boundary
        d = u - circshift(u, [0 0 1]);
    end

    function d = Dzt(u)
        % Transpose of Dz
        d = u - circshift(u, [0 0 -1]);
    end

    function [xs,ys] = shrink2(px,py,thresh)
        % Joint soft-thresholding of the pair (px,py) by their common magnitude
        s = sqrt(px.*conj(px)+py.*conj(py));
        ss = s-thresh;
        ss = ss.*(ss>0);
        s = s+(s<thresh); % avoid 0/0 where the shrunk value is zero anyway
        ss = ss./s;
        xs = ss.*px;
        ys = ss.*py;
    end

    function xs = shrink1(px,thresh)
        % Element-wise soft-thresholding
        xs = sign(px).*max(abs(px)-thresh,0);
    end

    function v = krylov(r)
        % Approximately solve jtjx(v) = r with BiCGSTAB, starting from zero.
        % Requesting the flag output (discarded with ~) keeps bicgstab from
        % printing convergence messages on every call.
        [v, ~] = bicgstab(@jtjx, r, tolKrylov, maxItKrylov);
    end

    % =====================================================================
    % Callback function for matrix-vector product (called by krylov)
    function b = jtjx(sol)
        solMat = reshape(sol,N);
        % Difference-operator (negative Laplacian) part
        bTV = lambda*(Dxt(Dx(solMat))+Dyt(Dy(solMat))+Dzt(Dz(solMat)));
        % Jacobian part
        bJac = mu*AT(A(sol));
        % Stability term
        bG = gamma*sol;
        b = bTV(:) + bJac(:) + bG(:);
    end
    % =====================================================================
end
%