function x = PID_color(x, sigma2, reference)
%PID_COLOR Denoise an image with PID iterations followed by one DDID step.
%   X = PID_COLOR(X, SIGMA2) denoises the height-by-width-by-depth array X.
%   SIGMA2 is a positive scalar that scales every range kernel and every
%   Fourier-domain shrinkage kernel below (it is used as the noise variance).
%
%   X = PID_COLOR(X, SIGMA2, REFERENCE) additionally prints the PSNR of the
%   current estimate against REFERENCE (same size as X) before the first and
%   after every PID iteration. REFERENCE never influences the returned image;
%   omit it or pass [] to skip the progress report.
%
%   Processing steps:
%     1. Transform the channels with an orthonormal DCT along dimension 3
%        (DCTMTX), after which every channel is processed separately.
%     2. Run N PID iterations  x <- x - lambda * delta.  For every pixel,
%        delta is the mean over all frequencies of real(D) .* K, where D is
%        the FFT of the kernel-weighted differences to the centre pixel in a
%        (2r+1)x(2r+1) window and K = exp(-|D|^2 / V).
%     3. Run one DDID step: kernels built from the PID result (the guide)
%        are applied to the original, untouched input.
%     4. Undo the channel DCT.
%
%   Requires dctmtx/padarray (and psnr when REFERENCE is given) from the
%   Image Processing Toolbox; parfor runs serially without a parallel pool.

if nargin < 3
    reference = [];
end
report = ~isempty(reference);
if ~isscalar(sigma2) || ~(sigma2 > 0)
    % sigma2 <= 0 zeroes the kernel scales below and yields NaN/Inf output.
    error('PID_color:sigma2', 'sigma2 must be a positive scalar noise variance.');
end

[height, width, depth] = size(x);
s = size(x);
sc = [height * width, depth];
% Orthonormal DCT basis along the channel dimension (M' * M = I).
M = dctmtx(depth)';
x = reshape(reshape(x, sc) * M, s);
if report
    % Orthonormal transform preserves the MSE, so the PSNR can be measured
    % directly in the DCT domain.
    reference = reshape(reshape(reference, sc) * M, s);
end
x0 = x; % input in the DCT domain; the DDID step filters this

% --- PID ---
N = 30;                       % number of PID iterations
r = 15;                       % window radius: windows are (2r+1)x(2r+1)
sigma_s = 7;
gamma_r = 988.5;
gamma_s = 2/9;
alpha = 1.533;                % > 1: range scale shrinks, spatial scale grows
lambda = log(alpha) * 0.567;  % step size of each update
[dx, dy] = meshgrid(-r:r);
r2 = dx.^2 + dy.^2;           % squared distance to the window centre
if report
    fprintf('%d/%d: PSNR = %.2f\n', 0, N, psnr(reference, x));
end
for iter = 1:N
    xp = padarray(x, [r r], 'symmetric');
    % Kernel scales depend only on the iteration, not on the pixel, so they
    % are computed once here instead of inside the per-pixel loops.
    range_scale = sigma2 * gamma_r * alpha^(-iter);
    spatial_scale = sigma_s^2 * gamma_s * alpha^(iter/2);
    delta = zeros([height width depth]);
    parfor i = 1:height
        for j = 1:width
            % Spatial domain: differences to the centre pixel and the
            % bilateral (range x spatial) weight of every window position.
            d = bsxfun(@minus, xp(i:i+2*r, j:j+2*r, :), x(i, j, :));
            k = exp(- mean(d.^2, 3) / range_scale) .* exp(- r2 / spatial_scale);
            % Fourier domain: attenuate high-energy coefficients, then average.
            V = sigma2 * sum(k(:).^2);
            for c = 1:depth
                D = fft2(ifftshift(d(:, :, c) .* k));
                K = exp(- abs(D).^2 / V);
                delta(i, j, c) = sum(sum(real(D) .* K)) / numel(K);
            end
        end
    end
    x = x - lambda * delta;
    if report
        fprintf('%d/%d: PSNR = %.2f\n', iter, N, psnr(reference, x));
    end
end

% --- DDID step ---
r = 31;
sigma_s = 16;
gamma_r = 0.6;
gamma_f = 2.16;
[dx, dy] = meshgrid(-r:r);
r2 = (dx.^2 + dy.^2) / (2 * sigma_s^2);
yp = padarray(x0, [r r], 'symmetric');  % original input (values filtered)
xp = padarray(x, [r r], 'symmetric');   % PID result (guide for the kernels)
parfor i = 1:height
    for j = 1:width
        g = xp(i:i+2*r, j:j+2*r, :);
        y = yp(i:i+2*r, j:j+2*r, :);
        % Bilateral kernel from the guide window.
        d = bsxfun(@minus, g, g(1+r, 1+r, :));
        k = exp(- sum(d.^2, 3) ./ (gamma_r * sigma2)) .* exp(-r2);
        % Kernel-weighted means of guide and input windows, per channel.
        gt = sum(sum(bsxfun(@times, g, k))) / sum(k(:));
        st = sum(sum(bsxfun(@times, y, k))) / sum(k(:));
        V = sigma2 .* sum(k(:).^2);
        for c = 1:depth
            G = fft2(ifftshift((g(:, :, c) - gt(c)) .* k));
            Y = fft2(ifftshift((y(:, :, c) - st(c)) .* k));
            % Shrinkage weights come from the guide spectrum G.
            K = 1 - exp(- abs(G).^2 ./ (gamma_f * V));
            Yt = sum(sum(real(Y) .* K)) / numel(K);
            x(i, j, c) = st(c) + Yt;
        end
    end
end
% M is orthonormal, so its inverse is its transpose.
x = reshape(reshape(x, sc) * M', s);
end