Skip to content

Commit 73022b0

Browse files
authored
Add Thermal Neural Network (TNN) SciML example (#18)
* Add Thermal Neural Network (TNN) SciML example * Address review feedback * address reviewer feedback on math text * addressing review comment about required release
1 parent a34ab11 commit 73022b0

13 files changed

Lines changed: 1192 additions & 0 deletions
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
classdef DiffEqLayer < nnet.layer.Layer & nnet.layer.Formattable
2+
% DiffEqLayer Differential equation layer
3+
%
4+
% This layer holds a TNNCell instance, and performs prediction over
5+
% all time steps.
6+
7+
% Copyright 2025 The MathWorks, Inc.
8+
9+
properties (Learnable)
10+
Cell
11+
end
12+
13+
methods
14+
function obj = DiffEqLayer(cell)
15+
% Constructor to store the cell
16+
obj.Cell = cell;
17+
obj.NumInputs = 2;
18+
obj.NumOutputs = 2;
19+
end
20+
21+
function [outputs, state] = predict(this, input, state)
22+
% Initialize cell array for outputs
23+
numSteps = size(input, 3); % Assuming input is [features, batch, time]
24+
25+
% Preallocate the outputs:
26+
outputs = dlarray(zeros(sz, like=input),'CBT');
27+
28+
% Iterate over each time step
29+
thisCell = this.Cell;
30+
for tt = 1:numSteps
31+
outputs(:,:,tt) = predict(thisCell,squeeze(input(:, :, tt)),state);
32+
state = squeeze(outputs(:, :, tt));
33+
end
34+
end
35+
end
36+
end
37+
38+

thermal-neural-network/README.md

Lines changed: 329 additions & 0 deletions
Large diffs are not rendered by default.
64 KB
Loading
46.1 KB
Loading
64.5 KB
Loading
329 KB
Loading

thermal-neural-network/TNNCell.m

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
classdef TNNCell < nnet.layer.Layer
2+
% TNNCell Thermal neural network cell
3+
%
4+
% TNNCell performs the TNN forward pass for a single time step.
5+
6+
% Copyright 2025 The MathWorks, Inc.
7+
8+
properties
9+
SampleTime (1,1) double = 0.5; % in seconds
10+
OutputSize (1,1) double
11+
IncidenceMatrix_x (:,:) double {mustBeInteger}
12+
IncidenceMatrix_u (:,:) double {mustBeInteger}
13+
TemperatureIndices (:,1) double {mustBePositive,mustBeInteger}
14+
NonTemperatureIndices (:,1) double {mustBePositive,mustBeInteger}
15+
InputColumns (:,1) string
16+
TargetColumns (:,1) string
17+
TemperatureColumns (:,1) string
18+
end
19+
20+
properties (Learnable)
21+
ConductanceNet
22+
PowerLoss
23+
Capacitance
24+
end
25+
26+
methods
27+
28+
function this = TNNCell(inputStruct)
29+
% Construct a thermal neural network cell from column metadata.
30+
31+
arguments
32+
inputStruct (1,1) struct
33+
end
34+
35+
requiredFields = ["inputCols", "targetCols", "temperatureCols"];
36+
missingFields = requiredFields(~isfield(inputStruct, requiredFields));
37+
38+
if ~isempty(missingFields)
39+
error("TNNCell:MissingFields", ...
40+
"inputStruct must contain the following fields: %s. Missing: %s.", ...
41+
strjoin(requiredFields, ", "), strjoin(missingFields, ", "));
42+
end
43+
44+
% Construct TNNCell
45+
this.NumInputs = 2;
46+
this.OutputSize = length(inputStruct.targetCols);
47+
nTemps = length(inputStruct.temperatureCols);
48+
49+
% Build incidence matrices for fully connected graph
50+
[this.IncidenceMatrix_x, this.IncidenceMatrix_u] = buildIncidenceMatrices(this.OutputSize, this.NumInputs);
51+
52+
% Store column info
53+
this.InputColumns = strtrim(string(inputStruct.inputCols))';
54+
this.TargetColumns = strtrim(string(inputStruct.targetCols))';
55+
this.TemperatureColumns = strtrim(string(inputStruct.temperatureCols))';
56+
57+
% Indices for temperature and non-temperature columns
58+
this.TemperatureIndices = find(ismember(this.InputColumns, this.TemperatureColumns));
59+
this.NonTemperatureIndices = find(~ismember(this.InputColumns, [this.TemperatureColumns; "profile_id"]));
60+
end
61+
62+
63+
function this = generateNetworks(this)
64+
% Initialize learnable neural networks and parameters for the TNN cell.
65+
66+
nTemps = length(this.TemperatureColumns);
67+
nConds = 0.5 * nTemps * (nTemps - 1) - 1; % fully connected except between the two external nodes
68+
numNeurons = 16;
69+
70+
% By default, just use one dense layer + sigmoid activations
71+
this.ConductanceNet = dlnetwork([featureInputLayer(length(this.InputColumns) + this.OutputSize),...
72+
fullyConnectedLayer(nConds,Name = "conduc_fc1"),sigmoidLayer]);
73+
74+
% By default, use two dense layers + tanh activations
75+
this.PowerLoss = dlnetwork([featureInputLayer(length(this.InputColumns) + this.OutputSize),...
76+
fullyConnectedLayer(numNeurons,Name = "ploss_fc1"),...
77+
tanhLayer,...
78+
fullyConnectedLayer(this.OutputSize,Name="ploss_fc2")]);
79+
80+
this.Capacitance = dlarray(randn(this.OutputSize, 1,'single') * 0.5 - 9.2); % Initialize caps
81+
end
82+
83+
84+
function out = predict(this, input, prevOut)
85+
% Perform a single forward time-step update of the thermal state.
86+
87+
% Extract temperatures
88+
tempsInternal = prevOut; % internal nodes
89+
tempsExternal = input(this.TemperatureIndices,:); % external nodes
90+
subNNInput = [input; prevOut];
91+
92+
E_x = this.IncidenceMatrix_x;
93+
E_u = this.IncidenceMatrix_u;
94+
95+
% Conductance network forward pass
96+
g = abs(predict(this.ConductanceNet, subNNInput'))';
97+
98+
% Power loss network forward pass
99+
q = abs(predict(this.PowerLoss, subNNInput'))';
100+
101+
% Compute temperature differences across edges
102+
dT = E_x' * tempsInternal + E_u' * tempsExternal;
103+
104+
% Heat flow on edges
105+
phi = g .* dT;
106+
107+
% Net outflow from internal nodes
108+
netOutflow = E_x * phi;
109+
110+
% State derivative using incidence-based formulation
111+
dx = exp(this.Capacitance) .* (-netOutflow + q);
112+
113+
% Update temperatures
114+
out = prevOut + this.SampleTime .* dx;
115+
116+
% Clip output
117+
out = max(min(out, 5), -1);
118+
end
119+
120+
end
121+
end
122+
123+
124+
function [E_x, E_u] = buildIncidenceMatrices(numInternal, numExternal)
125+
% Construct incidence matrices for a fully connected thermal network.
126+
%
127+
% numInternal: number of internal nodes
128+
% numExternal: number of external nodes
129+
% Output:
130+
% E_x: [numInternal x L] incidence matrix for internal nodes (fully
131+
% connected graph)
132+
% E_u: [numExternal x L] incidence matrix for external nodes (fully
133+
% connected)
134+
135+
% Calculate number of edges for fully connected internal graph
136+
L_internal = nchoosek(numInternal, 2); % fully connected internal nodes
137+
L_external = numInternal * numExternal; % each external node connected to all internal nodes
138+
L = L_internal + L_external;
139+
140+
% Initialize matrices
141+
E_x = zeros(numInternal, L);
142+
E_u = zeros(numExternal, L);
143+
144+
edgeIdx = 1;
145+
146+
% Internal edges (fully connected)
147+
for i = 1:numInternal
148+
for j = i+1:numInternal
149+
E_x(i, edgeIdx) = 1; % source
150+
E_x(j, edgeIdx) = -1; % target
151+
edgeIdx = edgeIdx + 1;
152+
end
153+
end
154+
155+
% External edges (connect each external node to all internal nodes)
156+
for ext = 1:numExternal
157+
for int = 1:numInternal
158+
E_x(int, edgeIdx) = 1; % internal node as source
159+
E_u(ext, edgeIdx) = -1; % external node as target
160+
edgeIdx = edgeIdx + 1;
161+
end
162+
end
163+
end

0 commit comments

Comments
 (0)