Add Tensorization Example Applied to Battery Thermal Analysis #19

Open

jonahweiss wants to merge 2 commits into matlab-deep-learning:main from jonahweiss:feature/tfno_example

Conversation

@jonahweiss

The example is a live script, tensorizedFourierNeuralOperatorForBatteryCoolingAnalysis.m, demonstrating the application of the paper Multi-Grid Tensorized Fourier Neural Operator for High-Resolution PDEs to the Battery Heat Analysis example.

Once the support files containing pregenerated simulation data are live, the URL variable pregeneratedSimulationDataURL in the example will need to be set; the function downloadSimulationData.m can then download and unzip the data from the given URL.

The tfno/ folder includes the implementation of the TFNO 3D model.

The lossFunctions/ folder includes the implementation of the relative H1 loss.

The trainingPartitions.m and createBatteryModuleGeometry.m functions are taken from the existing Battery Heat Analysis example.

Comment on lines +35 to +37
% Input X must be a numeric array of size [B, C, S1, S2, ..., SD]
% where B is batch size, C is number of channels, and S1...SD are
% spatial dimensions.
Collaborator


Why B, C, (S1, ..., SD)? That seems more like PyTorch's layout, whereas dlarray defaults to "SSCB" ordering when using labels.
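For concreteness, a minimal sketch (array sizes illustrative) of moving between the two layouts, from a default "SSCB"-labelled dlarray to the [B, C, S1, S2] numeric layout described above:

```matlab
% Sketch: reorder a labelled "SSCB" dlarray into a PyTorch-style B x C x S1 x S2 array.
x = dlarray(rand(8, 8, 3, 16), "SSCB");   % S1 x S2 x C x B
X = permute(stripdims(x), [4 3 1 2]);     % unlabelled, B x C x S1 x S2
```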

dm = 1 + d; % Dimension index of this spatial axis in reshaped X.

% Central difference with wrap.
fd = (circshift(X, -1, dm) - circshift(X, 1, dm)) / (2 * delta);
Collaborator


Just a warning that circshift isn't a dlarray method, so the way it supports dlarray functionality like dlgradient and dlaccelerate is that we trace the dlarray-s through the circshift implementation: if that implementation happens to use only dlarray-compatible methods and patterns, things should work out.

I expect you need dlgradient, and dlaccelerate would be beneficial, for a loss function. A couple of reasons to be cautious with functions that aren't explicitly dlarray methods but work through this "tracing" approach:

  1. There are many codepaths underlying circshift and other functions - you'd need to verify that all of those are dlarray compatible code, or ensure that you only ever go down codepaths that are.

  2. Since circshift isn't a dlarray method, there's no reason it couldn't be replaced in a future release by a C/C++ built-in that would not support dlgradient or dlaccelerate. I wouldn't expect us to have internal tests that would catch this: circshift isn't a dlarray method, and we can't reasonably promise that every MATLAB function that currently supports dlarray through tracing will keep supporting it.
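If that risk is a concern, one possible workaround (a sketch only, on the assumption that subscripted indexing remains dlarray-compatible; the function name is illustrative) is to build the wrapped shifts from indexing rather than circshift:

```matlab
% Sketch: periodic central difference along dimension dm without circshift.
function fd = centralDiffWrap(X, dm, delta)
    n = size(X, dm);
    subs = repmat({':'}, 1, ndims(X));
    fwd = subs; fwd{dm} = [2:n 1];      % x(i+1), wrapping at the end
    bwd = subs; bwd{dm} = [n 1:n-1];    % x(i-1), wrapping at the start
    fd = (X(fwd{:}) - X(bwd{:})) / (2 * delta);
end
```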

Dim = finddim(X, dim);
permuteOrder = [Dim setdiff(1:ndims(X), Dim, 'stable')];
X = permute(stripdims(X), permuteOrder);
X = dlarray(X, fmt);
Collaborator


Does it matter whether the format still makes sense here? E.g. x = dlarray(rand(5,4),"CB"); y = permuteDimFirst(x,"B") will re-label x's batch dim as y's channel dim.

I think if you need the dimensions in a particular layout, it's probably best to just work without format labels for as long as that's needed, since the dlarray label auto-permutes are always going to fight back against non-default layouts. If you still need dlarray methods when you don't have format labels, most methods that require labelled data should also have something like a DataFormat name-value pair.
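A rough sketch of that pattern (sizes and values illustrative): strip the labels once, keep the data in the layout you need, and pass DataFormat where a method requires it:

```matlab
% Sketch: work unlabelled and supply DataFormat explicitly where needed.
x = dlarray(rand(5, 4), "CB");
xu = stripdims(x);                              % unlabelled, stays 5 x 4 (C x B)
w = dlarray(rand(3, 5)); b = dlarray(zeros(3, 1));
y = fullyconnect(xu, w, b, DataFormat="CB");    % no label auto-permute involved
```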

@@ -0,0 +1,73 @@
classdef depthwiseConv3dLayer < nnet.layer.Layer & ...
Collaborator


Could this be convolution3dLayer(1,numChannels) and convolution3dLayer(1,numChannels,BiasLearnRateFactor=0) when UseBias==false?
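For reference, a sketch of that suggestion (numChannels is illustrative); zero-initializing the bias alongside a zero learn rate keeps it fixed at zero, approximating UseBias == false:

```matlab
numChannels = 32;                               % illustrative
withBias = convolution3dLayer(1, numChannels);
noBias   = convolution3dLayer(1, numChannels, ...
    BiasLearnRateFactor=0, BiasInitializer="zeros");
```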

net = connectLayers(net, "channelSkip", "add2/in2");
else
net = connectLayers(net, "in", "add2/in2");
end
Collaborator


Could LinearFNOSkip and ChannelMLPSkip be merged into a single SkipConnectionMode = ["identity","linear"] option? That would drop the option of using "linear" for just one of the skips, but I don't expect that's common.
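A sketch of how the merged option might be declared (names are illustrative, not the PR's actual API):

```matlab
% Sketch: a single validated skip-connection mode in an arguments block.
function buildTFNO(opts)
    arguments
        opts.SkipConnectionMode (1,1) string ...
            {mustBeMember(opts.SkipConnectionMode, ["identity","linear"])} = "identity"
    end
    % ... build both skips as identity or 1x1x1 convolutions based on the mode
end
```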

@@ -0,0 +1,63 @@
# Tensorized Fourier Neural Operator for 3D Battery Heat Analysis

This example builds on the [Fourier Neural Operator for 3D Battery Heat Analysis](https://github.com/matlab-deep-learning/SciML-and-Physics-Informed-Machine-Learning-Examples/tree/main/battery-module-cooling-analysis-with-fourier-neural-operator) example, applying a Tensorized Fourier Neural Operator (TFNO) [1, 2] to heat analysis of a 3D battery module. The TFNO compresses the standard Fourier Neural Operator using tensorization, achieving a 14.3x parameter reduction while maintaining accuracy.
Collaborator


Might be able to link the example with a relative path in the repo.

Comment on lines +5 to +6
![](./images/prediction_vs_gt.png)
![](./images/absolute_error.png)
Collaborator


It's worth adding alt-text descriptions for these.


2 participants