@@ -0,0 +1,166 @@
function H1 = h1Norm(X, params)
%H1NORM - Compute H1 norm on a grid.
% H1 = H1NORM(X) computes the H1 norm of the input array X
% with default parameters.
%
% H1 = H1NORM(X, Name=Value) specifies additional options using
% one or more name-value arguments:
%
% Spacings - 1xD vector of grid spacings [Δ1, Δ2, ..., ΔD].
% The default value is ones(1,D).
%
% IncludeL2 - If true, computes full H1 norm (L2 + gradient).
% If false, computes seminorm only (gradient).
% The default value is true.
%
% Reduction - Method for reducing the norm across batch.
% Options are 'mean', 'sum', or 'none'.
% The default value is 'mean'.
%
% Periodic - 1xD logical array indicating which spatial
% dimensions are periodic. The default value
% is true for all dimensions.
%
% SquareRoot - If false, returns the squared H1 norm.
% If true, returns the H1 norm. The default
% value is false.
%
% Normalize - If true, divides output by C*S1*S2*...*SD.
% The default value is false.
%
% The H1 norm is defined as:
% ||u||_{H^1} = (||u||_{L^2}^2 + ||∇u||_{L^2}^2)^{1/2}
% where ||∇u||_{L^2}^2 = Σ_i ||∂u/∂x_i||_{L^2}^2.
%
% Input X must be a numeric array of size [B, C, S1, S2, ..., SD]
% where B is batch size, C is number of channels, and S1...SD are
% spatial dimensions.
Comment on lines +35 to +37
Collaborator

Why BC(S..S)? That seems more like PyTorch's layout, whereas dlarray defaults to "SSCB" ordering when using labels.
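To illustrate the reordering this comment refers to, here is a minimal sketch (the sizes are arbitrary values chosen for illustration). It labels a PyTorch-style [B, C, S1, S2] array and inspects the layout dlarray actually stores:

```matlab
% Label a [B, C, S1, S2] array; dlarray reorders the data on construction.
B = 2; C = 1; S1 = 64; S2 = 32;
X = dlarray(rand(B, C, S1, S2), "BCSS");

fmt = dims(X);   % format after the automatic reordering
sz = size(X);    % sizes follow the reordered format, not the input layout

% Dropping the labels keeps the reordered layout.
Xu = stripdims(X);
```

After this runs, `fmt` is 'SSCB' and `sz` is [64 32 1 2], so code that indexes `sz(1)` as the batch size would silently read a spatial size instead.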

%
% Gradients are estimated using central differences. Along non-periodic
% dimensions, third-order one-sided differences are used at the boundaries.
%
% Example:
% B=2; C=1; S1=64; S2=64;
% X = randn(B,C,S1,S2);
% H1 = h1Norm(X);

% Copyright 2026 The MathWorks, Inc.

arguments
X dlarray {mustBeNumeric}
params.Spacings (1,:) double = []
params.IncludeL2 (1,1) logical = true
params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean"
params.Periodic (1,:) logical = true
params.SquareRoot (1,1) logical = false
params.Normalize (1,1) logical = false
end

sz = size(X);
nd = ndims(X);
if nd < 3
error('Input must be at least [B, C, S1].');
end
B = sz(1);
C = sz(2);
spatialSizes = sz(3:end);
D = numel(spatialSizes);

if isempty(params.Spacings)
params.Spacings = ones(1, D);
else
if numel(params.Spacings) ~= D
error('Spacings must have length equal to the number of spatial dimensions (D).');
end
end

if isscalar(params.Periodic)
params.Periodic = repmat(params.Periodic, 1, D);
elseif numel(params.Periodic) ~= D
error('Periodic must be scalar or 1xD logical.');
end

% Initialize H1 with the squared L2 norm, or with zeros for the seminorm.
if params.IncludeL2
H1 = lossFunctions.l2Norm(X, Reduction="none", SquareRoot=false, Normalize=false);
else
H1 = zeros(B, 1, 'like', X);
end

% Reshape to [B*C, S1, S2, ... Sn] so that all batch, channel
% combinations are handled independently.
X = reshape(X, [B*C spatialSizes]);

% Add the H1 seminorm using central differences.
for d = 1:D
delta = params.Spacings(d);

dm = 1 + d; % Dimension index of this spatial axis in reshaped X.

% Central difference with wrap.
fd = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta);
Collaborator

Just a warning that circshift isn't a dlarray method. It supports dlarray functionality like dlgradient and dlaccelerate only because we trace the dlarrays through the circshift implementation: if that implementation happens to use only dlarray-compatible methods and patterns, things should work out.

I expect you need dlgradient for a loss function, and dlaccelerate would be beneficial. A couple of reasons to be cautious with functions that aren't explicitly dlarray methods but work through this "tracing" approach:

  1. There are many codepaths underlying circshift and other functions. You'd need to verify that all of them are dlarray-compatible, or ensure that you only ever go down codepaths that are.

  2. Since circshift isn't a dlarray method, there's no reason it couldn't be replaced in a future release by a C/C++ built-in that does not support dlgradient or dlaccelerate. I wouldn't expect us to have internal tests that would catch this, because circshift isn't a dlarray method and we can't reasonably promise that every function in MATLAB that supports dlarray through tracing will always support it in future.
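One possible way to sidestep the concern, sketched here as an assumption rather than a vetted replacement, is to build the periodic shift from plain indexing, which dlarray lists among its supported operations. The variable names (`next`, `prev`, `fdIndexed`) are hypothetical, and the snippet uses an ordinary numeric array so it can be compared directly against circshift:

```matlab
% Periodic central difference along dimension dm using index vectors only.
X = rand(3, 5, 4);
dm = 2;
delta = 0.1;

n = size(X, dm);
idx = repmat({':'}, 1, max(ndims(X), dm));
next = idx; next{dm} = [2:n, 1];      % element i reads X(i+1), wrapping at the end
prev = idx; prev{dm} = [n, 1:n-1];    % element i reads X(i-1), wrapping at the start

fdIndexed = (X(next{:}) - X(prev{:})) / (2 * delta);

% Reference result from the circshift formulation used in the diff.
fdShift = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta);
maxDiff = max(abs(fdIndexed(:) - fdShift(:)));
```

Whether the indexed form is faster or slower than circshift under dlaccelerate is not something this sketch establishes; it only shows that the two formulations compute the same stencil.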


if ~params.Periodic(d)
% Replace first/last elements with one-sided (forward/backward) differences.

if spatialSizes(d) < 4
error("Non-periodic dimensions require at least 4 grid points for 3rd-order differences.");
end

fd = applyThirdOrderDifferenceAtBoundary(fd, X, dm, delta);
end

fd = abs(fd.^2); % abs() needed for complex values, as in l2Norm

% Reshape back to original size.
fd = reshape(fd, sz);

% Sum over channels and spatial dimensions, giving size of [B, 1].
fd = sum(fd, 2:nd);

% Accumulate per-batch sum.
H1 = H1 + fd;
end

if params.SquareRoot
H1 = sqrt(H1);
end

if params.Normalize
% Normalize by channels and number of spatial points
H1 = H1 / (C * prod(spatialSizes));
end

if strcmp(params.Reduction, "mean")
H1 = mean(H1, 1);
elseif strcmp(params.Reduction, "sum")
H1 = sum(H1, 1);
end
end

function fd = applyThirdOrderDifferenceAtBoundary(fd, X, d, delta)

% Get the indices of components for 3rd-order forward differences.
idx1 = makeIndex(ndims(fd), d, 1);
idx2 = makeIndex(ndims(fd), d, 2);
idx3 = makeIndex(ndims(fd), d, 3);
idx4 = makeIndex(ndims(fd), d, 4);

% Apply 3rd-order forward differences at left boundary.
fd(idx1{:}) = (-11*X(idx1{:}) + 18*X(idx2{:}) - 9*X(idx3{:}) + 2*X(idx4{:})) / (6 * delta);

% Get the indices of components for 3rd-order backward differences.
sz = size(fd, d);
idx1 = makeIndex(ndims(fd), d, sz);
idx2 = makeIndex(ndims(fd), d, sz-1);
idx3 = makeIndex(ndims(fd), d, sz-2);
idx4 = makeIndex(ndims(fd), d, sz-3);

% Apply 3rd-order backward differences at right boundary
fd(idx1{:}) = (-2*X(idx4{:}) + 9*X(idx3{:}) - 18*X(idx2{:}) + 11*X(idx1{:})) / (6 * delta);
end

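function idx = makeIndex(numDims, toChange, val)
% Build a subscript cell of ':' with VAL at position TOCHANGE.
% The argument is named numDims to avoid shadowing the built-in ndims.
idx = repmat({':'}, 1, numDims);
idx{toChange} = val;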
end
@@ -0,0 +1,59 @@
function L2 = l2Norm(X, params)
%L2NORM - Compute L2 norm on a grid.
% L2 = L2NORM(X) computes the L2 norm of the input array X
% with default parameters.
%
% L2 = L2NORM(X, Name=Value) specifies additional options using
% one or more name-value arguments:
%
% Reduction - Method for reducing the norm across batch.
% Options are 'mean', 'sum', or 'none'.
% The default value is 'mean'.
%
% SquareRoot - If false, returns the squared L2 norm.
% If true, returns the L2 norm. The default
% value is false.
%
% Normalize - If true, divides output by C*S1*S2*...*SD.
% The default value is false.
%
% Input X must be a numeric array of size [B, C, S1, S2, ..., SD]
% where B is batch size, C is number of channels, and S1...SD are
% spatial dimensions.
%
% Example:
% B=2; C=1; S1=64; S2=64;
% X = randn(B,C,S1,S2);
% L2 = l2Norm(X);

% Copyright 2026 The MathWorks, Inc.

arguments
X dlarray {mustBeNumeric}
params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean"
params.SquareRoot (1,1) logical = false
params.Normalize (1,1) logical = false
end

sz = size(X);

% Flatten to B-by-(C*S1*...*SD).
X = reshape(X, sz(1), []);

L2 = sum(abs(X.^2), 2); % Bx1, abs() needed for complex values

if params.SquareRoot
L2 = sqrt(L2);
end

if params.Reduction == "mean"
L2 = mean(L2);
elseif params.Reduction == "sum"
L2 = sum(L2);
end

if params.Normalize
L2 = L2/(prod(sz(2:end)));
end

end
@@ -0,0 +1,14 @@
function X = permuteDimFirst(X, dim)
%PERMUTEDIMFIRST - Permute specified dimension to be the first dimension.
% X = PERMUTEDIMFIRST(X, DIM) moves the dimension specified by DIM
% to the first position while maintaining the relative order of other
% dimensions.

% Copyright 2026 The MathWorks, Inc.

fmt = dims(X);
Dim = finddim(X, dim);
permuteOrder = [Dim setdiff(1:ndims(X), Dim, 'stable')];
X = permute(stripdims(X), permuteOrder);
X = dlarray(X, fmt);
Collaborator

Does it matter if the format still makes sense here? For example, x = dlarray(rand(5,4),"CB"); y = permuteDimFirst(x,"B") will re-label x's batch dim as y's channel dim.

I think if you need the dimensions in a particular layout, it's probably best to work without format labels for as long as that's needed, since the dlarray label auto-permutes will always fight back against non-default layouts. If you still need dlarray methods when you don't have format labels, most methods that require labelled data should also accept something like a DataFormat name-value argument.
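A minimal sketch of the label-free approach suggested above, using the same "CB" example. It assumes the input carries only S, C, and B labels (no T or U), and the variable names are hypothetical:

```matlab
X = dlarray(rand(5, 4), "CB");   % 5 channels, 4 observations

% Look up where each label lives, then drop the labels before permuting.
b = finddim(X, 'B');
c = finddim(X, 'C');
s = finddim(X, 'S');             % empty here: no spatial dimensions
Y = permute(stripdims(X), [b c s]);

szY = size(Y);                   % [B, C] layout with no labels attached
fmtY = dims(Y);                  % empty: Y is unformatted
```

Because `Y` is never re-labelled, nothing reorders it afterwards, and the batch dimension is not relabelled as a channel dimension.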

end
@@ -0,0 +1,108 @@
function loss = relativeH1Loss(pred, gt, params)
%RELATIVEH1LOSS - Compute the relative H1 norm loss between predictions and ground truth.
% LOSS = RELATIVEH1LOSS(PRED, GT) computes the relative H1 norm loss
% between predicted values PRED and ground truth values GT with default
% parameters.
%
% LOSS = RELATIVEH1LOSS(PRED, GT, Name=Value) specifies additional options
% using one or more name-value arguments:
%
% Normalize - If true, normalizes the H1 norm.
% The default value is false.
%
% SpatialSizes - 1xD vector of physical domain sizes for each spatial
% dimension. The default value is ones(1,D).
%
% SquareRoot - If true, returns the square root of the norm.
% If false, returns the squared norm.
% The default value is false.
%
% Reduction - Method for reducing the loss across batch.
% Options are 'mean', 'sum', or 'none'.
% The default value is 'mean'.
%
% Periodic - 1xD logical array indicating which spatial
% dimensions are periodic. The default value
% is true for all dimensions.
%
% Epsilon - Small constant added to the denominator to avoid
% division by zero. The value is stored in single
% precision. The default value is 2e-16.
%
% The relative H1 loss is defined as:
% loss = ||pred - gt||_{H^1} / ||gt||_{H^1}
% where the H1 norm measures both function values and their gradients.
% This was proposed by
% Czarnecki, Wojciech M., et al. "Sobolev Training for Neural Networks."
% Advances in Neural Information Processing Systems (2017).
%
% Inputs PRED and GT must be dlarrays of size [B, C, S1, S2, ..., SD]
% where B is batch size, C is number of channels, and S1...SD are
% spatial dimensions.
%
% The loss is calculated per sample in the batch and then reduced
% according to the Reduction parameter.
%
% Example:
% B=2; C=1; S1=64; S2=64;
% pred = dlarray(randn(B,C,S1,S2));
% gt = dlarray(randn(B,C,S1,S2));
% loss = relativeH1Loss(pred, gt);

% Copyright 2026 The MathWorks, Inc.

arguments
pred dlarray
gt dlarray
params.Normalize (1,1) logical = false
params.SpatialSizes (1,:) double = []
params.SquareRoot (1,1) logical = false
params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean"
params.Periodic (1,:) logical = true
params.Epsilon (1,1) single = 2e-16
end

if ~isequal(size(pred), size(gt))
error('pred and gt must have identical size.');
end

if isempty(params.SpatialSizes)
params.SpatialSizes = ones(1, ndims(gt) - 2);
elseif isscalar(params.SpatialSizes)
params.SpatialSizes = repmat(params.SpatialSizes, 1, ndims(gt) - 2);
elseif numel(params.SpatialSizes) ~= ndims(gt) - 2
error('SpatialSizes must have length equal to the number of spatial dimensions.');
end

% Ensure that dimension order is [B, C, S1, S2, ... Sn].
pred = lossFunctions.permuteDimFirst(pred, "C");
gt = lossFunctions.permuteDimFirst(gt, "C");
pred = lossFunctions.permuteDimFirst(pred, "B");
gt = lossFunctions.permuteDimFirst(gt, "B");

sz = size(pred);
quadrature = params.SpatialSizes./sz(3:end);

num = lossFunctions.h1Norm(gt - pred, ...
Spacings=quadrature, ...
Reduction='none', ...
Normalize=params.Normalize, ...
SquareRoot=params.SquareRoot, ...
Periodic=params.Periodic);

den = lossFunctions.h1Norm(gt, ...
Spacings=quadrature, ...
Reduction='none', ...
Normalize=params.Normalize, ...
SquareRoot=params.SquareRoot, ...
Periodic=params.Periodic);

loss = num./(den + params.Epsilon);

switch params.Reduction
case "mean"
loss = mean(loss);
case "sum"
loss = sum(loss);
end
end