-
Notifications
You must be signed in to change notification settings - Fork 48
Add Tensorization Example Applied to Battery Thermal Analysis #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
function H1 = h1Norm(X, params)
%H1NORM - Compute H1 norm on a grid.
%   H1 = H1NORM(X) computes the H1 norm of the input array X
%   with default parameters.
%
%   H1 = H1NORM(X, Name=Value) specifies additional options using
%   one or more name-value arguments:
%
%   Spacings   - 1xD vector of grid spacings [D1, D2, ..., DD].
%                The default value is ones(1,D).
%
%   IncludeL2  - If true, computes full H1 norm (L2 + gradient).
%                If false, computes seminorm only (gradient).
%                The default value is true.
%
%   Reduction  - Method for reducing the norm across batch.
%                Options are 'mean', 'sum', or 'none'.
%                The default value is 'mean'.
%
%   Periodic   - 1xD logical array indicating which spatial
%                dimensions are periodic. The default value
%                is true for all dimensions.
%
%   SquareRoot - If false, returns the squared H1 norm.
%                If true, returns the H1 norm. The default
%                value is false.
%
%   Normalize  - If true, divides output by C*prod(S1, S2, ...).
%                The default value is false.
%
%   The H1 norm is defined as:
%       ||u||_{H^1} = (||u||_{L^2}^2 + ||grad u||_{L^2}^2)^{1/2}
%   where ||grad u||_{L^2}^2 = sum_i ||du/dx_i||_{L^2}^2.
%
%   Input X must be a numeric array of size [B, C, S1, S2, ..., SD]
%   where B is batch size, C is number of channels, and S1...SD are
%   spatial dimensions.
%
%   Gradients are estimated using central differences and one-sided
%   differences at boundaries (unless periodic boundary conditions).
%
%   Example:
%       B=2; C=1; S1=64; S2=64;
%       X = randn(B,C,S1,S2);
%       H1 = h1Norm(X);

%   Copyright 2026 The MathWorks, Inc.

arguments
    X dlarray {mustBeNumeric}
    params.Spacings (1,:) double = []
    params.IncludeL2 (1,1) logical = true
    params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean"
    params.Periodic (1,:) logical = true
    params.SquareRoot (1,1) logical = false
    params.Normalize (1,1) logical = false
end

sz = size(X);
nd = ndims(X);
if nd < 3
    error('Input must be at least [B, C, S1].');
end
B = sz(1);
C = sz(2);
spatialSizes = sz(3:end);
D = numel(spatialSizes);

if isempty(params.Spacings)
    params.Spacings = ones(1, D);
elseif numel(params.Spacings) ~= D
    error('Spacings must have length equal to the number of spatial dimensions (D).');
end

% Allow a scalar Periodic flag to apply to every spatial dimension.
if isscalar(params.Periodic)
    params.Periodic = repmat(params.Periodic, 1, D);
elseif numel(params.Periodic) ~= D
    error('Periodic must be scalar or 1xD logical.');
end

% Initialize H1 as the squared L2 norm (or zero for the seminorm).
if params.IncludeL2
    H1 = lossFunctions.l2Norm(X, Reduction="none", SquareRoot=false, Normalize=false);
else
    H1 = zeros(B, 1, 'like', X);
end

% Reshape to [B*C, S1, S2, ... Sn] so that all batch, channel
% combinations are handled independently.
X = reshape(X, [B*C spatialSizes]);

% Add the H1 seminorm using central differences (one-sided differences
% replace the wrapped values at non-periodic boundaries below).
for d = 1:D
    delta = params.Spacings(d);

    dm = 1 + d; % Dimension index of this spatial axis in reshaped X.

    % Central difference with wrap.
    fd = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta);

    if ~params.Periodic(d)
        % Replace first/last elements with forward/reverse differences.
        % Only the extent of THIS dimension matters for the 4-point
        % one-sided stencils, so do not reject small periodic dims.
        if spatialSizes(d) < 4
            error("Non-periodic dimensions require at least 4 grid points for 3rd-order differences.");
        end

        fd = applyThirdOrderDifferenceAtBoundary(fd, X, dm, delta);
    end

    fd = fd.^2;

    % Reshape back to original size.
    fd = reshape(fd, sz);

    % Sum over channels and spatial dimensions, giving size of [B, 1].
    fd = sum(fd, 2:nd);

    % Accumulate per-batch sum.
    H1 = H1 + fd;
end

if params.SquareRoot
    H1 = sqrt(H1);
end

if params.Normalize
    % Normalize by channels and number of spatial points.
    H1 = H1 / (C * prod(spatialSizes));
end

if strcmp(params.Reduction, "mean")
    H1 = mean(H1, 1);
elseif strcmp(params.Reduction, "sum")
    H1 = sum(H1, 1);
end
end
|
|
||
function fd = applyThirdOrderDifferenceAtBoundary(fd, X, d, delta)
%APPLYTHIRDORDERDIFFERENCEATBOUNDARY - Overwrite the boundary slices of
%   the central-difference estimate FD along dimension D with 3rd-order
%   one-sided approximations of the first derivative of X, using grid
%   spacing DELTA. Only the first and last slices along D are modified.

n = ndims(fd);

% Colon-index cells selecting the first four slices along dimension d.
s1 = makeIndex(n, d, 1);
s2 = makeIndex(n, d, 2);
s3 = makeIndex(n, d, 3);
s4 = makeIndex(n, d, 4);

% Left boundary: 3rd-order forward-difference stencil.
fd(s1{:}) = (-11*X(s1{:}) + 18*X(s2{:}) - 9*X(s3{:}) + 2*X(s4{:})) / (6 * delta);

% Colon-index cells selecting the last four slices along dimension d.
last = size(fd, d);
e1 = makeIndex(n, d, last);
e2 = makeIndex(n, d, last-1);
e3 = makeIndex(n, d, last-2);
e4 = makeIndex(n, d, last-3);

% Right boundary: 3rd-order backward-difference stencil.
fd(e1{:}) = (-2*X(e4{:}) + 9*X(e3{:}) - 18*X(e2{:}) + 11*X(e1{:})) / (6 * delta);
end
|
|
||
function idx = makeIndex(ndims, toChange, val)
%MAKEINDEX - Build a colon-subscript cell with one dimension pinned.
%   IDX = MAKEINDEX(NDIMS, TOCHANGE, VAL) returns a 1xNDIMS cell array of
%   ':' subscripts in which position TOCHANGE is replaced by VAL, suitable
%   for indexing an array as A(idx{:}).

idx = cell(1, ndims);
idx(:) = {':'};
idx{toChange} = val;
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
function L2 = l2Norm(X, params)
%L2NORM - Compute L2 norm on a grid.
%   L2 = L2NORM(X) computes the L2 norm of the input array X
%   with default parameters.
%
%   L2 = L2NORM(X, Name=Value) specifies additional options using
%   one or more name-value arguments:
%
%   Reduction  - Method for reducing the norm across batch.
%                Options are 'mean', 'sum', or 'none'.
%                The default value is 'mean'.
%
%   SquareRoot - If false, returns the squared L2 norm.
%                If true, returns the L2 norm. The default
%                value is false.
%
%   Normalize  - If true, divides output by C*prod(S1, S2, ...).
%                The default value is false.
%
%   Input X must be a numeric array of size [B, C, S1, S2, ..., SD]
%   where B is batch size, C is number of channels, and S1...SD are
%   spatial dimensions.
%
%   Example:
%       B=2; C=1; S1=64; S2=64;
%       X = randn(B,C,S1,S2);
%       L2 = l2Norm(X);

%   Copyright 2026 The MathWorks, Inc.

arguments
    X dlarray {mustBeNumeric}
    params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean"
    params.SquareRoot (1,1) logical = false
    params.Normalize (1,1) logical = false
end

sz = size(X);

% Flatten each batch sample into a row: [B, C*S1*...*SD].
Xflat = reshape(X, sz(1), []);

% Per-sample sum of squared magnitudes; abs() handles complex values.
L2 = sum(abs(Xflat.^2), 2); % Bx1

if params.SquareRoot
    L2 = sqrt(L2);
end

switch params.Reduction
    case "mean"
        L2 = mean(L2);
    case "sum"
        L2 = sum(L2);
end

if params.Normalize
    % Divide by number of channels times number of spatial points.
    L2 = L2 / prod(sz(2:end));
end

end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
function X = permuteDimFirst(X, dim)
%PERMUTEDIMFIRST - Permute specified dimension to be the first dimension.
%   X = PERMUTEDIMFIRST(X, DIM) moves the labeled dimension DIM of the
%   formatted dlarray X to the first position while maintaining the
%   relative order of the other dimensions.

%   Copyright 2026 The MathWorks, Inc.

% Capture the incoming format and locate the requested labeled dimension.
fmt = dims(X);
target = finddim(X, dim);

% Permute the raw (unformatted) data so the target dimension leads.
order = [target setdiff(1:ndims(X), target, 'stable')];
X = permute(stripdims(X), order);

% Reapply the original format string positionally.
% NOTE(review): the labels are not permuted together with the data, and
% dlarray re-sorts labeled dimensions into its canonical order, so the
% resulting labels/layout may not match the permuted data - confirm this
% is intended, or work unformatted while a specific layout is required.
X = dlarray(X, fmt);

end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
function loss = relativeH1Loss(pred, gt, params)
%RELATIVEH1LOSS - Compute the relative H1 norm loss between predictions and ground truth.
%   LOSS = RELATIVEH1LOSS(PRED, GT) computes the relative H1 norm loss
%   between predicted values PRED and ground truth values GT with default
%   parameters.
%
%   LOSS = RELATIVEH1LOSS(PRED, GT, Name=Value) specifies additional options
%   using one or more name-value arguments:
%
%   Normalize    - If true, normalizes the H1 norm.
%                  The default value is false.
%
%   SpatialSizes - 1xD vector of physical domain sizes for each spatial
%                  dimension. The default value is ones(1,D).
%
%   SquareRoot   - If true, returns the square root of the norm.
%                  If false, returns the squared norm.
%                  The default value is false.
%
%   Reduction    - Method for reducing the loss across batch.
%                  Options are 'mean', 'sum', or 'none'.
%                  The default value is 'mean'.
%
%   Periodic     - 1xD logical array indicating which spatial
%                  dimensions are periodic. The default value
%                  is true for all dimensions.
%
%   Epsilon      - Small constant added to the denominator to avoid
%                  division by zero. Stored in single precision.
%                  The default value is 2e-16.
%
%   The relative H1 loss is defined as:
%       loss = ||pred - gt||_{H^1} / ||gt||_{H^1}
%   where the H1 norm measures both function values and their gradients.
%   This was proposed by
%   Czarnecki, Wojciech M., et al. "Sobolev Training for Neural Networks."
%   Advances in Neural Information Processing Systems (2017).
%
%   Inputs PRED and GT must be dlarrays of size [B, C, S1, S2, ..., SD]
%   where B is batch size, C is number of channels, and S1...SD are
%   spatial dimensions.
%
%   The loss is calculated per sample in the batch and then reduced
%   according to the Reduction parameter.
%
%   Example:
%       B=2; C=1; S1=64; S2=64;
%       pred = dlarray(randn(B,C,S1,S2));
%       gt = dlarray(randn(B,C,S1,S2));
%       loss = relativeH1Loss(pred, gt);

%   Copyright 2026 The MathWorks, Inc.

arguments
    pred dlarray
    gt dlarray
    params.Normalize (1,1) logical = false
    params.SpatialSizes (1,:) double = []
    params.SquareRoot (1,1) logical = false
    params.Reduction (1,1) string {mustBeMember(params.Reduction, {'mean', 'sum', 'none'})} = "mean"
    params.Periodic (1,:) logical = true
    % NOTE(review): 2e-16 is far below eps('single') (~1.2e-7), so for
    % nonzero single-precision denominators the addition is a no-op; it
    % only guards an exactly-zero denominator. Confirm this is intended.
    params.Epsilon (1, 1) single = 2e-16
end

if ~isequal(size(pred), size(gt))
    error('pred and gt must have identical size.');
end

% Default / broadcast / validate the physical domain sizes (one per
% spatial dimension; the first two dims are batch and channel).
if isempty(params.SpatialSizes)
    params.SpatialSizes = ones(1, ndims(gt) - 2);
elseif isscalar(params.SpatialSizes)
    params.SpatialSizes = repmat(params.SpatialSizes, 1, ndims(gt) - 2);
elseif numel(params.SpatialSizes) ~= ndims(gt) - 2
    error('SpatialSizes must have length equal to the number of spatial dimensions.');
end

% Ensure that dimension order is [B, C, S1, S2, ... Sn].
pred = lossFunctions.permuteDimFirst(pred, "C");
gt = lossFunctions.permuteDimFirst(gt, "C");
pred = lossFunctions.permuteDimFirst(pred, "B");
gt = lossFunctions.permuteDimFirst(gt, "B");

% Grid spacing per spatial dimension: physical size / number of points.
sz = size(pred);
quadrature = params.SpatialSizes./sz(3:end);

% Per-sample H1 norm of the error (numerator).
num = lossFunctions.h1Norm(gt - pred, ...
    Spacings=quadrature, ...
    Reduction='none', ...
    Normalize=params.Normalize, ...
    SquareRoot=params.SquareRoot, ...
    Periodic=params.Periodic);

% Per-sample H1 norm of the ground truth (denominator).
den = lossFunctions.h1Norm(gt, ...
    Spacings=quadrature, ...
    Reduction='none', ...
    Normalize=params.Normalize, ...
    SquareRoot=params.SquareRoot, ...
    Periodic=params.Periodic);

loss = num./(den + params.Epsilon);

switch params.Reduction
    case "mean"
        loss = mean(loss);
    case "sum"
        loss = sum(loss);
end
end
Uh oh!
There was an error while loading. Please reload this page.