import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import torch
from torch import pi, sin, exp
5
6
# Fully connected network mapping a point (t, x) to a scalar u(t, x):
# Linear(2 -> hidden[0]), Linear(hidden[k] -> hidden[k+1]) for each consecutive
# pair of hidden widths, then Linear(hidden[-1] -> 1), with the activation
# after every layer except the last one.
hidden = [50]*8
activation = torch.nn.Tanh()  # Tanh is stateless, so one instance can be shared by all layers

widths = [2] + hidden + [1]  # input width, hidden widths, output width
layerList = []
for l, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
    layerList.append((f'layer_{l}', torch.nn.Linear(n_in, n_out)))
    if l < len(hidden):  # no activation after the output layer
        layerList.append((f'activation_{l}', activation))

layerDict = OrderedDict(layerList)
model = torch.nn.Sequential(layerDict)
19
20
21
22
# Adam over every network weight, starting from a learning rate of 1e-2.
optim = torch.optim.Adam(params=model.parameters(), lr=1e-2)

# Multiply the learning rate by factor=0.1 once the metric passed to
# scheduler.step() has gone more than patience=100 calls without improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optim,
    factor=0.1,
    patience=100,
)
28
29
# ---- Training points on (t, x) in [0, 1] x [-1, 1] ----
nt = 10
tt = torch.linspace(0., 1., nt+1)
nx = 20
xx = torch.linspace(-1., 1., nx+1)
T, X = torch.meshgrid(tt, xx, indexing='ij')
tt = tt.reshape(-1, 1)
xx = xx.reshape(-1, 1)

# Initial condition u(0, x) = sin(pi x) on every x node.
Sic = torch.hstack((torch.zeros_like(xx), xx))
Uic = sin(pi*xx)

# Boundary targets u(t, -1) = u(t, 1) = 0 for t > 0 (t = 0 is covered by the IC set).
Sbc0 = torch.hstack((tt[1:, :], -1.*torch.ones_like(tt[1:, :])))
Ubc0 = torch.zeros_like(tt[1:, :])

Sbc1 = torch.hstack((tt[1:, :], 1.*torch.ones_like(tt[1:, :])))
Ubc1 = torch.zeros_like(tt[1:, :])

# Interior points: every t > 0 crossed with the x nodes strictly inside (-1, 1),
# ordered t-major (all interior x for the first t, then the next t, ...).
Tin, Xin = torch.meshgrid(tt[1:, 0], xx[1:-1, 0], indexing='ij')
Tin = Tin.reshape(-1, 1)
Xin = Xin.reshape(-1, 1)
# Fresh leaf tensors so autograd can differentiate the network output w.r.t. t and x.
tin = Tin.clone().requires_grad_(True)
xin = Xin.clone().requires_grad_(True)
Sin = torch.hstack((tin, xin))
# Forcing f = u_t - u_xx for the reference solution u = exp(-t) sin(pi x).
Fin = (pi**2 - 1.)*exp(-Tin)*sin(pi*Xin)
61
nepochs = 50001
tol = 1e-4
nout = 100

# Evaluation grid for the snapshots. It is built once here, and kept under its
# own names, so the training-grid globals (nt, tt, nx, xx, T, X) stay intact.
nt_eval = 10
nx_eval = 20
T_eval, X_eval = torch.meshgrid(torch.linspace(0., 1., nt_eval+1),
                                torch.linspace(-1., 1., nx_eval+1),
                                indexing='ij')
# (t, x) pairs in row-major grid order, so the model output reshapes back onto the grid.
M_eval = torch.stack((T_eval.reshape(-1), X_eval.reshape(-1)), dim=1)
# Reference solution u(t, x) = exp(-t) sin(pi x), used for the relative L2 error.
Uesp = exp(-T_eval)*sin(pi*X_eval)

# savefig does not create missing directories.
os.makedirs('./results', exist_ok=True)


def _save_snapshot(epoch, loss_value):
    """Plot the reference (filled) vs. predicted (white contours) solution and save it.

    Args:
        epoch: current epoch; shown in the title and used for the file index
            ``epoch // nout``.
        loss_value: total training loss at this epoch, shown in the title.
    """
    with torch.no_grad():  # evaluation only; no autograd graph needed
        Uest = model(M_eval).reshape(nt_eval+1, nx_eval+1)
    l2rel = (torch.norm(Uest - Uesp)/torch.norm(Uesp)).item()

    fig, ax = plt.subplots(dpi=300)
    cb = ax.contourf(T_eval, X_eval, Uesp, levels=10)
    fig.colorbar(cb)
    cl = ax.contour(T_eval, X_eval, Uest, levels=10, colors='white')
    ax.clabel(cl, fmt='%.1f')
    ax.set_xlabel('$t$')
    ax.set_ylabel('$x$')
    ax.set_title(f'{epoch}: loss = {loss_value:.4e}, l2rel = {l2rel:.4e}')
    fig.savefig(f'./results/sol_{(epoch//nout):0>6}.png')
    plt.close(fig)  # release the figure once it has been written


for epoch in range(nepochs):
    # Initial-condition loss at t = 0.
    lic = torch.mean((model(Sic) - Uic)**2)

    # PDE residual u_t - u_xx - f on the interior points.
    U = model(Sin)
    ones = torch.ones_like(U)
    # create_graph=True keeps the derivatives differentiable: it is needed for
    # U_xx below and for backpropagating the residual into the weights.
    U_t, U_x = torch.autograd.grad(U, (tin, xin), grad_outputs=ones,
                                   create_graph=True)
    U_xx = torch.autograd.grad(U_x, xin, grad_outputs=ones,
                               create_graph=True)[0]
    res = U_t - U_xx - Fin
    lin = torch.mean(res**2)

    # Boundary losses at x = -1 and x = 1.
    lbc0 = torch.mean((model(Sbc0) - Ubc0)**2)
    lbc1 = torch.mean((model(Sbc1) - Ubc1)**2)

    loss = lin + lic + lbc0 + lbc1
    loss_value = loss.item()

    lr = optim.param_groups[-1]['lr']
    print(f'{epoch}: loss = {loss_value:.4e}, lr = {lr:.4e}')

    optim.zero_grad()
    loss.backward()
    optim.step()
    # Step the plateau scheduler after the optimizer, so the lr printed above
    # is the one actually used for this update.
    scheduler.step(loss_value)

    stop = (loss_value < tol) or (lr < 1e-6)
    # Save periodically, and always save the final state before stopping.
    if (epoch % nout == 0) or stop:
        _save_snapshot(epoch, loss_value)
    if stop:
        break