import torch
from torch import pi, sin
3
4
# Width of every hidden layer.
# NOTE: `nn` here is an int, not an alias for torch.nn.
nn = 50

# Fully connected network R^2 -> R: three tanh-activated hidden layers of
# width `nn` followed by a linear scalar output. Each sub-module is registered
# under an explicit name, so it is reachable as model.layer_1, model.fun_1, ...
_named_layers = (
    ('layer_1', torch.nn.Linear(2, nn)),
    ('fun_1', torch.nn.Tanh()),
    ('layer_2', torch.nn.Linear(nn, nn)),
    ('fun_2', torch.nn.Tanh()),
    ('layer_3', torch.nn.Linear(nn, nn)),
    ('fun_3', torch.nn.Tanh()),
    ('layer_4', torch.nn.Linear(nn, 1)),
)
model = torch.nn.Sequential()
for _name, _module in _named_layers:
    model.add_module(_name, _module)
del _named_layers, _name, _module

# Stochastic gradient descent with momentum 0.9 over all network weights.
optim = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
18
19
def f(x1, x2):
    """Evaluate the PDE source term f(x1, x2) = 2*pi^2 * sin(pi*x1) * sin(pi*x2).

    The training loss drives F + (u_x1x1 + u_x2x2) to zero, i.e. -Laplace(u) = f.
    For this f, u = sin(pi*x1) * sin(pi*x2) satisfies -Laplace(u) = f and
    vanishes on the boundary of [-1, 1]^2.

    Args:
        x1, x2: tensors of coordinates; evaluated elementwise with torch.sin
            (broadcast together by torch).

    Returns:
        Tensor of source-term values.
    """
    amplitude = 2. * pi**2
    return amplitude * sin(pi * x1) * sin(pi * x2)
22
23
# Training hyper-parameters.
ns_in = 400      # interior collocation points sampled each epoch
ns_cc = 20       # points sampled per boundary edge each epoch
nepochs = 50000  # upper bound on optimisation steps
tol = 1e-3       # training stops once the total loss drops below this

# Uniform ns_val x ns_val grid on [-1, 1]^2, flattened to (ns_val**2, 2)
# points in 'ij' order (x1 varies slowest). Column 0 holds x1, column 1 x2.
# NOTE(review): X_val is not used in the visible part of this script —
# confirm it is consumed later (e.g. for evaluation or plotting).
ns_val = 50
x1_val = torch.linspace(-1., 1., steps=ns_val)
x2_val = torch.linspace(-1., 1., steps=ns_val)
X1_val, X2_val = torch.meshgrid(x1_val, x2_val, indexing='ij')
X_val = torch.stack((X1_val.flatten(), X2_val.flatten()), dim=1)
36
def _laplacian(u, x):
    """Return sum_i d^2 u / d x_i^2 for every row of `x`, shaped like `u`.

    `u` must have been computed from `x` (with x.requires_grad=True) and have
    a single output column. create_graph=True keeps the result differentiable
    so it can appear in a loss; it also retains the graph for the follow-up
    grad calls (retain_graph defaults to create_graph).
    """
    grad_u = torch.autograd.grad(
        u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    lap = torch.zeros_like(u)
    for i in range(x.shape[1]):
        du_i = grad_u[:, i:i + 1]
        d2u = torch.autograd.grad(
            du_i, x, grad_outputs=torch.ones_like(du_i), create_graph=True)[0]
        # Only column i of d(du/dx_i)/dx is the pure second derivative.
        lap = lap + d2u[:, i:i + 1]
    return lap


def _boundary_points(n):
    """Sample n points on each of the four edges of [-1, 1]^2; shape (4n, 2)."""
    # Same RNG draw order as before: x1 samples first, then x2 samples.
    s1 = 2. * torch.rand(n, 1) - 1.
    s2 = 2. * torch.rand(n, 1) - 1.
    ones = torch.ones((n, 1))
    return torch.cat((
        torch.hstack((s1, -ones)),  # edge x2 = -1
        torch.hstack((s1, ones)),   # edge x2 = +1
        torch.hstack((-ones, s2)),  # edge x1 = -1
        torch.hstack((ones, s2)),   # edge x1 = +1
    ))


for epoch in range(nepochs):
    # Interior collocation points, uniform in [-1, 1)^2, resampled every epoch.
    X1 = 2. * torch.rand(ns_in, 1) - 1.
    X2 = 2. * torch.rand(ns_in, 1) - 1.
    X = torch.hstack((X1, X2))
    X.requires_grad = True

    U = model(X)

    # PDE residual: the loss drives F + Laplace(U) -> 0, i.e. -Laplace(u) = f.
    # F is built from X1/X2 (not X) so it stays outside the autograd graph.
    F = f(X1, X2)
    lin = torch.mean((F + _laplacian(U, X))**2)

    # Boundary term u = 0: mean of u^2 over all 4*ns_cc edge samples, which
    # equals the four per-edge sums divided by 4*ns_cc.
    lcc = torch.mean(model(_boundary_points(ns_cc))**2)

    loss = lin + lcc
    # Read the scalar once per epoch (each .item() is a host sync).
    loss_value = loss.item()
    converged = loss_value < tol

    if epoch % 500 == 0 or converged:
        print(f'{epoch}: loss = {loss_value:.4e}')

    if converged:
        break

    optim.zero_grad()
    loss.backward()
    optim.step()