1import torch
2
3
# Network hyper-parameters: `nh` tanh hidden layers, each `nn` neurons wide.
# NOTE: `nn` is a layer width here, not the usual `torch.nn` alias.
nh = 4
nn = 50
fun = torch.nn.Tanh()  # a single stateless activation instance shared by all layers

# Fully connected network u(t, x): 2 inputs -> nh hidden layers -> 1 output.
# Modules are registered as layer_1, fun_1, ..., layer_nh, fun_nh, layer_{nh+1};
# the Linear layers are constructed (and randomly initialised) in that order.
fan_in = [2] + [nn] * nh
fan_out = [nn] * nh + [1]
model = torch.nn.Sequential()
for l, (n_in, n_out) in enumerate(zip(fan_in, fan_out), start=1):
    model.add_module(f'layer_{l}', torch.nn.Linear(n_in, n_out))
    if l <= nh:  # no activation after the output layer
        model.add_module(f'fun_{l}', fun)

# Trainable PDE coefficient, used in the residual of the training loop.
# Assigning a Parameter to a Module attribute registers it, so it is part of
# model.parameters() handed to the optimizer. Initialised uniformly in [rgn[0], rgn[1]].
rgn = [5., 7]
model.lmbda = torch.nn.Parameter(
    data=torch.rand(1) * (rgn[1] - rgn[0]) + rgn[0])

optim = torch.optim.Adam(model.parameters(), lr=0.001)

tf = 1.  # collocation times are sampled in [0, tf)
25
26
# True PDE coefficient used to generate the reference (exact) solution.
lmbda = torch.tensor([6.])


def ua(t, x, lmbda=lmbda):
    """Return the closed-form travelling-wave solution of Fisher's equation.

    For u_t = u_xx + lmbda*u*(1-u) (the PDE whose residual the training loop
    penalises), evaluates elementwise

        u(t, x) = 1 / (1 + exp(sqrt(lmbda/6)*x - 5/6*lmbda*t))**2

    Args:
        t: time(s); tensor or scalar, broadcast against ``x`` and ``lmbda``.
        x: position(s); tensor broadcast against ``t`` and ``lmbda``.
        lmbda: PDE coefficient as a tensor (defaults to the true value 6).

    Returns:
        Tensor of u values with the broadcast shape of the inputs.
    """
    return 1./(1.+torch.exp(torch.sqrt(lmbda/6.)*x-5./6*lmbda*t))**2


def u0(x, lmbda=lmbda):
    """Return the initial condition u(0, x).

    Delegates to ``ua`` at t = 0 so the initial condition can never drift
    from the exact solution it is derived from.

    Args:
        x: position(s); tensor.
        lmbda: PDE coefficient as a tensor (defaults to the true value 6).

    Returns:
        Tensor of u(0, x) values.
    """
    return ua(0., x, lmbda)
34
35
# Sparse sample data: the exact solution `ua` evaluated on a 3 x 3 grid of
# times `ts` and positions `xs`, flattened in 'ij' (time-major) order.
ts = torch.tensor([0.1, 0.2, 0.3])
xs = torch.tensor([0.25, 0.5, 0.75])
T, X = torch.meshgrid(ts, xs, indexing='ij')
Ss = torch.stack((T.flatten(), X.flatten()), dim=1)  # (9, 2): columns (t, x)
Us_exp = ua(Ss[:, 0:1], Ss[:, 1:2])                   # (9, 1): target u values
41
42
# Training hyper-parameters.
nepochs = 50000   # maximum number of optimisation epochs
tol = 1e-5        # stop as soon as the total loss drops below this
eout = 100        # print progress every `eout` epochs
sin = 50          # interior collocation points drawn per epoch
penalty = 1e1     # weight of the data-misfit term `ls`

for epoch in range(nepochs):
    # Fresh random collocation points in [0, tf) x [0, 1). The draw order
    # (t first, then x) determines the RNG stream, so keep it.
    tsin = tf*torch.rand(sin, 1)
    xsin = torch.rand(sin, 1)
    Sin = torch.hstack((tsin, xsin)).requires_grad_()

    # PDE residual u_t - u_xx - lmbda*u*(1-u) at the collocation points.
    Uin = model(Sin)
    # create_graph=True keeps the derivatives differentiable so the residual
    # loss can be back-propagated to the weights; retain_graph defaults to
    # create_graph, so the graph is also kept for the second derivative.
    DUin = torch.autograd.grad(
        Uin, Sin, torch.ones_like(Uin), create_graph=True)[0]
    Uin_t = DUin[:, 0:1]
    Uin_x = DUin[:, 1:2]
    Uin_xx = torch.autograd.grad(
        Uin_x, Sin, torch.ones_like(Uin_x), create_graph=True)[0][:, 1:2]
    lin = torch.mean((Uin_t - Uin_xx - model.lmbda*Uin*(1 - Uin))**2)

    # Initial condition at t = 0, reusing the sampled x positions.
    S0 = torch.hstack((torch.zeros_like(xsin), xsin))
    U0 = model(S0)
    l0 = torch.mean((U0 - u0(xsin))**2)

    # Boundary conditions at x = 0 and x = 1, matched to the exact solution.
    Sbc0 = torch.hstack((tsin, torch.zeros_like(xsin)))
    Sbc1 = torch.hstack((tsin, torch.ones_like(xsin)))
    Sbc = torch.vstack((Sbc0, Sbc1))
    Ubc_exp = ua(Sbc[:, 0:1], Sbc[:, 1:2])
    Ubc_est = model(Sbc)
    lbc = torch.mean((Ubc_est - Ubc_exp)**2)

    # Misfit against the fixed sample data (Ss, Us_exp), weighted by `penalty`.
    Us_est = model(Ss)
    ls = torch.mean((Us_est - Us_exp)**2)

    loss = lin + l0 + lbc + penalty*ls

    # Read the scalar once per epoch instead of issuing up to three .item() syncs.
    loss_val = loss.item()
    converged = loss_val < tol
    if epoch % eout == 0 or converged:
        print(f'epoch: {epoch}, loss={loss_val:.4e}, '
              f'lmbda={model.lmbda.item():.3f}')
    if converged:
        break

    optim.zero_grad()
    loss.backward()
    optim.step()