1import torch
2from torch import pi, sin
3
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# script still works on machines without CUDA (a hard-coded 'cuda' device makes
# the later `model.to(device)` fail there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
6
7
# Fully connected network u(x1, x2): 2 inputs -> nh hidden Tanh layers of
# width nn -> 1 output. Submodules are named layer0, act0, ..., layer{nh} so
# saved models / state_dict keys keep the same names.
nh = 3   # number of hidden (Linear + Tanh) layers
nn = 50  # neurons per hidden layer
model = torch.nn.Sequential()
for idx, fan_in in enumerate([2] + [nn] * (nh - 1)):
    model.add_module(f'layer{idx}', torch.nn.Linear(fan_in, nn))
    model.add_module(f'act{idx}', torch.nn.Tanh())
# Linear read-out to a single scalar output.
model.add_module(f'layer{nh}', torch.nn.Linear(nn, 1))
model.to(device)
18
19
# Stochastic gradient descent with momentum over all network parameters.
optim = torch.optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
)
22
23
def f(X):
    """Evaluate the right-hand side of the Poisson problem at points X.

    The training loop penalises ``f + u_x1x1 + u_x2x2``, i.e. it solves
    ``-laplace(u) = f``. This f equals ``2*pi**2*sin(pi*x1)*sin(pi*x2)``, which
    is consistent with the manufactured solution ``u = sin(pi*x1)*sin(pi*x2)``.

    Args:
        X: 2-D tensor whose first two columns are the coordinates x1 and x2.

    Returns:
        Column tensor of shape (X.shape[0], 1) with the source values.
    """
    x1 = X[:, 0:1]
    x2 = X[:, 1:2]
    return 2. * pi**2 * sin(pi * x1) * sin(pi * x2)
26
27
# Samples drawn afresh every epoch: interior collocation points, and points
# per boundary edge.
ns_in, ns_bc = 400, 20

# Stopping criteria: maximum number of epochs, and the total-loss threshold
# at which training stops early.
nepochs = 50_000
tol = 1e-3
33
34
def _second_derivative(D1U, X, k):
    """Return d^2U/dx_k^2 as an (N, 1) column.

    D1U must be dU/dX computed with create_graph=True so it can be
    differentiated again; k selects the input coordinate (0 -> x1, 1 -> x2).
    """
    dU_dxk = D1U[:, k:k + 1]
    return torch.autograd.grad(
        dU_dxk, X,
        grad_outputs=torch.ones_like(dU_dxk),
        retain_graph=True,
        create_graph=True)[0][:, k:k + 1]


def _boundary_loss(Xbc):
    """Mean squared network output on boundary points (penalises u != 0 there)."""
    return torch.mean(model(Xbc)**2)


def _save_checkpoint(Xin, x1_bc, x2_bc):
    """Write the model and the latest sampled points to the working directory."""
    torch.save(model, 'model.pt')
    torch.save(Xin, 'Xin.pt')
    torch.save(x1_bc, 'x1_bc.pt')
    torch.save(x2_bc, 'x2_bc.pt')


for epoch in range(nepochs):

    optim.zero_grad()

    # Interior collocation points, resampled uniformly in [-1, 1]^2 each epoch.
    Xin = 2. * torch.rand(ns_in, 2, device=device) - 1.
    Xin.requires_grad = True  # needed to differentiate U w.r.t. the inputs

    U = model(Xin)

    # dU/dX for both coordinates. create_graph=True lets us differentiate again
    # and backpropagate through these derivatives into the network weights.
    D1U = torch.autograd.grad(
        U, Xin,
        grad_outputs=torch.ones_like(U),
        retain_graph=True,
        create_graph=True)[0]

    D2UX1 = _second_derivative(D1U, Xin, 0)
    D2UX2 = _second_derivative(D1U, Xin, 1)

    F = f(Xin)

    # Interior residual loss: mean of (f + u_x1x1 + u_x2x2)^2.
    lin = torch.mean((F + D2UX1 + D2UX2)**2)

    # Boundary samples: x1_bc is reused on the x2 = -1 and x2 = +1 edges,
    # x2_bc on the x1 = -1 and x1 = +1 edges.
    x1_bc = 2. * torch.rand(ns_bc, device=device) - 1.
    x2_bc = 2. * torch.rand(ns_bc, device=device) - 1.

    lcc1 = _boundary_loss(torch.stack((x1_bc, -torch.ones_like(x1_bc)), dim=1))  # x2 = -1
    lcc3 = _boundary_loss(torch.stack((x1_bc, torch.ones_like(x1_bc)), dim=1))   # x2 = +1
    lcc4 = _boundary_loss(torch.stack((-torch.ones_like(x2_bc), x2_bc), dim=1))  # x1 = -1
    lcc2 = _boundary_loss(torch.stack((torch.ones_like(x2_bc), x2_bc), dim=1))   # x1 = +1

    # Average of the four edge losses.
    lcc = 0.25 * (lcc1 + lcc2 + lcc3 + lcc4)

    loss = lin + lcc

    # Read the scalar once per epoch; on CUDA every .item() is a host/device sync.
    loss_val = loss.item()
    converged = loss_val < tol

    if epoch % 100 == 0 or converged:
        print(f'{epoch}: loss = {loss_val:.3e}, lin = {lin.item():.2e}, lcc = {lcc.item():.2e}')
        _save_checkpoint(Xin, x1_bc, x2_bc)

    if converged:
        break

    loss.backward()
    optim.step()
else:
    # Reached only when the loop ran out without hitting tol: the last periodic
    # checkpoint was taken before the final optimizer steps, so also persist
    # the end-of-training state (guarded because nothing was sampled if
    # nepochs == 0).
    if nepochs > 0:
        _save_checkpoint(Xin, x1_bc, x2_bc)