In [1]:
Copied!
import numpy as np
import matplotlib.pyplot as plt
import torch
import pyqtorch.modules as pyq
import numpy as np
import matplotlib.pyplot as plt
import torch
import pyqtorch.modules as pyq
/Users/niklas/Library/Application Support/hatch/env/virtual/qucint/6NOL9orC/qucint/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Fitting a 1D function¶
Let's define a target function we want to fit.
In [2]:
Copied!
def target_function(x, degree=3):
result = 0
for i in range(degree):
result += torch.cos(i*x) + torch.sin(i*x)
return .05 * result
x = torch.tensor(np.linspace(0, 10, 100))
target_y = target_function(x, 5)
plt.plot(x.numpy(), target_y.numpy(), label="truth")
plt.legend()
plt.show()
def target_function(x, degree=3):
result = 0
for i in range(degree):
result += torch.cos(i*x) + torch.sin(i*x)
return .05 * result
x = torch.tensor(np.linspace(0, 10, 100))
target_y = target_function(x, 5)
plt.plot(x.numpy(), target_y.numpy(), label="truth")
plt.legend()
plt.show()
To fit this function with a QNN we need an entangling ansatz. We will use a layer of U-gates and a layer of CNOTs:
In [3]:
Copied!
def ULayerAnsatz(n_qubits, n_layers):
ops = []
for _ in range(n_layers):
ops.append(pyq.VariationalLayer(n_qubits, pyq.U))
ops.append(pyq.EntanglingLayer(n_qubits))
return pyq.QuantumCircuit(n_qubits, ops)
ULayerAnsatz(3,1)
def ULayerAnsatz(n_qubits, n_layers):
ops = []
for _ in range(n_layers):
ops.append(pyq.VariationalLayer(n_qubits, pyq.U))
ops.append(pyq.EntanglingLayer(n_qubits))
return pyq.QuantumCircuit(n_qubits, ops)
ULayerAnsatz(3,1)
Out[3]:
QuantumCircuit(
(operations): ModuleList(
(0): VariationalLayer(
(operations): ModuleList(
(0): U(qubits=[0], n_qubits=3)
(1): U(qubits=[1], n_qubits=3)
(2): U(qubits=[2], n_qubits=3)
)
)
(1): EntanglingLayer(
(operations): ModuleList(
(0): CNOT(qubits=[0, 1], n_qubits=3)
(1): CNOT(qubits=[1, 2], n_qubits=3)
(2): CNOT(qubits=[2, 0], n_qubits=3)
)
)
)
)
We can define a QNN by implementing a custom torch.nn.Module.
In [4]:
Copied!
class Model(torch.nn.Module):
def __init__(self, n_qubits, n_layers):
super().__init__()
self.n_qubits = n_qubits
self.ansatz1 = ULayerAnsatz(n_qubits, n_layers)
self.embedding = pyq.FeaturemapLayer(n_qubits, pyq.RX)
self.ansatz2 = ULayerAnsatz(n_qubits, n_layers)
self.observable = pyq.Z([0], n_qubits)
def forward(self, x):
batch_size = len(x)
state = self.ansatz1.init_state(batch_size)
state = self.ansatz1(state)
state = self.embedding(state, x)
state = self.ansatz2(state)
new_state = self.observable(state)
state = state.reshape((2**self.n_qubits, batch_size))
new_state = new_state.reshape((2**self.n_qubits, batch_size))
return torch.real(torch.sum(torch.conj(state) * new_state, axis=0))
class Model(torch.nn.Module):
def __init__(self, n_qubits, n_layers):
super().__init__()
self.n_qubits = n_qubits
self.ansatz1 = ULayerAnsatz(n_qubits, n_layers)
self.embedding = pyq.FeaturemapLayer(n_qubits, pyq.RX)
self.ansatz2 = ULayerAnsatz(n_qubits, n_layers)
self.observable = pyq.Z([0], n_qubits)
def forward(self, x):
batch_size = len(x)
state = self.ansatz1.init_state(batch_size)
state = self.ansatz1(state)
state = self.embedding(state, x)
state = self.ansatz2(state)
new_state = self.observable(state)
state = state.reshape((2**self.n_qubits, batch_size))
new_state = new_state.reshape((2**self.n_qubits, batch_size))
return torch.real(torch.sum(torch.conj(state) * new_state, axis=0))
Let's verify that we are getting reasonable outputs from our untrained QNN
In [5]:
Copied!
n_qubits = 5
n_layers = 3
model = Model(n_qubits, n_layers)
with torch.no_grad():
y = model(x)
plt.plot(x.numpy(), target_y.numpy(), label="truth")
plt.plot(x.numpy(), y.numpy(), label="initial")
plt.legend()
plt.show()
n_qubits = 5
n_layers = 3
model = Model(n_qubits, n_layers)
with torch.no_grad():
y = model(x)
plt.plot(x.numpy(), target_y.numpy(), label="truth")
plt.plot(x.numpy(), y.numpy(), label="initial")
plt.legend()
plt.show()
Our QNN is implemented as a torch.nn.Module so we can use the usual torch optimizers to train it.
In [7]:
Copied!
import torch.nn.functional as F
optimizer = torch.optim.Adam(model.parameters(), lr=.01)
epochs = 200
for epoch in range(epochs):
optimizer.zero_grad()
y_pred = model(x)
loss = F.mse_loss(target_y, y_pred)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1:03d} | Loss {loss}")
import torch.nn.functional as F
optimizer = torch.optim.Adam(model.parameters(), lr=.01)
epochs = 200
for epoch in range(epochs):
optimizer.zero_grad()
y_pred = model(x)
loss = F.mse_loss(target_y, y_pred)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1:03d} | Loss {loss}")
Epoch 001 | Loss 1.181170973931877e-05 Epoch 002 | Loss 3.9005150204328295e-05 Epoch 003 | Loss 0.0005210030244242805 Epoch 004 | Loss 6.012409726695054e-05 Epoch 005 | Loss 0.00021009194032554318 Epoch 006 | Loss 0.0002914833038541067 Epoch 007 | Loss 0.0001454984851627904 Epoch 008 | Loss 4.142681408782311e-05 Epoch 009 | Loss 9.054882419142336e-05 Epoch 010 | Loss 0.0001531790494452143 Epoch 011 | Loss 0.00012015653011887463 Epoch 012 | Loss 5.790713974401939e-05 Epoch 013 | Loss 3.846728702473956e-05 Epoch 014 | Loss 5.630471569536954e-05 Epoch 015 | Loss 7.700979061854381e-05 Epoch 016 | Loss 7.391532378356758e-05 Epoch 017 | Loss 4.612078963466576e-05 Epoch 018 | Loss 2.155939995688723e-05 Epoch 019 | Loss 2.3721610576853985e-05 Epoch 020 | Loss 4.3170004830746536e-05 Epoch 021 | Loss 5.069962559255113e-05 Epoch 022 | Loss 3.5152058993235534e-05 Epoch 023 | Loss 1.4055374744605247e-05 Epoch 024 | Loss 1.023692829664193e-05 Epoch 025 | Loss 2.4446698606623847e-05 Epoch 026 | Loss 3.408834997147914e-05 Epoch 027 | Loss 2.4865126936362327e-05 Epoch 028 | Loss 9.659374207355973e-06 Epoch 029 | Loss 6.950788871336306e-06 Epoch 030 | Loss 1.566768989272486e-05 Epoch 031 | Loss 2.099277061607817e-05 Epoch 032 | Loss 1.6212614075325447e-05 Epoch 033 | Loss 8.235768022404368e-06 Epoch 034 | Loss 6.009959075292303e-06 Epoch 035 | Loss 9.772417521725285e-06 Epoch 036 | Loss 1.2892382127904045e-05 Epoch 037 | Loss 1.1361810341757317e-05 Epoch 038 | Loss 7.2561572882857865e-06 Epoch 039 | Loss 4.821231443489856e-06 Epoch 040 | Loss 6.128272875842876e-06 Epoch 041 | Loss 8.843188084005391e-06 Epoch 042 | Loss 8.675418546268711e-06 Epoch 043 | Loss 5.450684118847346e-06 Epoch 044 | Loss 3.5110836783759226e-06 Epoch 045 | Loss 4.989386678750474e-06 Epoch 046 | Loss 6.898259805053778e-06 Epoch 047 | Loss 6.05054456967633e-06 Epoch 048 | Loss 3.7389955465282206e-06 Epoch 049 | Loss 3.273212519167249e-06 Epoch 050 | Loss 4.637559523242552e-06 Epoch 051 | Loss 5.165131301134411e-06 Epoch 052 | Loss 4.142896350302605e-06 Epoch 053 | Loss 3.1492780545134018e-06 Epoch 054 | Loss 3.233757188669417e-06 Epoch 055 | Loss 3.86524696418792e-06 Epoch 056 | Loss 3.953219614950671e-06 Epoch 057 | Loss 3.312011206526524e-06 Epoch 058 | Loss 2.7661753462439286e-06 Epoch 059 | Loss 2.929670133548e-06 Epoch 060 | Loss 3.3791616358340403e-06 Epoch 061 | Loss 3.238319199906219e-06 Epoch 062 | Loss 2.630986577566763e-06 Epoch 063 | Loss 2.4992796964562005e-06 Epoch 064 | Loss 2.8735530062625166e-06 Epoch 065 | Loss 2.938593815700879e-06 Epoch 066 | Loss 2.5476755194736197e-06 Epoch 067 | Loss 2.3241505411727017e-06 Epoch 068 | Loss 2.49713839896336e-06 Epoch 069 | Loss 2.6112011957839276e-06 Epoch 070 | Loss 2.4362527269610276e-06 Epoch 071 | Loss 2.24253269756204e-06 Epoch 072 | Loss 2.2373816058577022e-06 Epoch 073 | Loss 2.3310248995766054e-06 Epoch 074 | Loss 2.3077427496625516e-06 Epoch 075 | Loss 2.1503126672635597e-06 Epoch 076 | Loss 2.070910242097326e-06 Epoch 077 | Loss 2.137149272933625e-06 Epoch 078 | Loss 2.1596597107860127e-06 Epoch 079 | Loss 2.0397936834282285e-06 Epoch 080 | Loss 1.9556617368367096e-06 Epoch 081 | Loss 1.998179918829005e-06 Epoch 082 | Loss 2.0117531938650416e-06 Epoch 083 | Loss 1.93268494654725e-06 Epoch 084 | Loss 1.8697853805020282e-06 Epoch 085 | Loss 1.875957033063329e-06 Epoch 086 | Loss 1.8812430094662359e-06 Epoch 087 | Loss 1.8393870837550013e-06 Epoch 088 | Loss 1.787910412014617e-06 Epoch 089 | Loss 1.7701482382035251e-06 Epoch 090 | Loss 1.7752261255685842e-06 Epoch 091 | Loss 1.7504566547702705e-06 Epoch 092 | Loss 1.7019486718406315e-06 Epoch 093 | Loss 1.6835082115283767e-06 Epoch 094 | Loss 1.685667875198429e-06 Epoch 095 | Loss 1.6616583019769176e-06 Epoch 096 | Loss 1.6231601566709535e-06 Epoch 097 | Loss 1.60813362494993e-06 Epoch 098 | Loss 1.6019701482892067e-06 Epoch 099 | Loss 1.5803956915397127e-06 Epoch 100 | Loss 1.5527873527936183e-06 Epoch 101 | Loss 1.5352883431312178e-06 Epoch 102 | Loss 1.5252310209689772e-06 Epoch 103 | Loss 1.5084573236758253e-06 Epoch 104 | Loss 1.4843679595562227e-06 Epoch 105 | Loss 1.4667658608496229e-06 Epoch 106 | Loss 1.457148855226587e-06 Epoch 107 | Loss 1.4404304018817315e-06 Epoch 108 | Loss 1.4187656091462116e-06 Epoch 109 | Loss 1.4045047593157939e-06 Epoch 110 | Loss 1.3931682693016124e-06 Epoch 111 | Loss 1.3763410548891802e-06 Epoch 112 | Loss 1.3585907789414708e-06 Epoch 113 | Loss 1.3450608769260442e-06 Epoch 114 | Loss 1.332534011814191e-06 Epoch 115 | Loss 1.317529909690219e-06 Epoch 116 | Loss 1.301470982727387e-06 Epoch 117 | Loss 1.2883992098664887e-06 Epoch 118 | Loss 1.2764695970430059e-06 Epoch 119 | Loss 1.2619454042812358e-06 Epoch 120 | Loss 1.247318667750338e-06 Epoch 121 | Loss 1.235492552299088e-06 Epoch 122 | Loss 1.2232337173189462e-06 Epoch 123 | Loss 1.2094041838489128e-06 Epoch 124 | Loss 1.196602229533154e-06 Epoch 125 | Loss 1.185027851699964e-06 Epoch 126 | Loss 1.1729201462357814e-06 Epoch 127 | Loss 1.1602693027487138e-06 Epoch 128 | Loss 1.1482910988692847e-06 Epoch 129 | Loss 1.1371439541573138e-06 Epoch 130 | Loss 1.1255718985465116e-06 Epoch 131 | Loss 1.1136108149941372e-06 Epoch 132 | Loss 1.1025357900284203e-06 Epoch 133 | Loss 1.091830958157926e-06 Epoch 134 | Loss 1.0805734231640713e-06 Epoch 135 | Loss 1.0695479487719928e-06 Epoch 136 | Loss 1.059127192210933e-06 Epoch 137 | Loss 1.0486653310712814e-06 Epoch 138 | Loss 1.0380690084037267e-06 Epoch 139 | Loss 1.0277505705549728e-06 Epoch 140 | Loss 1.0177783820600243e-06 Epoch 141 | Loss 1.0077843803092286e-06 Epoch 142 | Loss 9.977294830491515e-07 Epoch 143 | Loss 9.880084245920512e-07 Epoch 144 | Loss 9.78507612080833e-07 Epoch 145 | Loss 9.688991424059435e-07 Epoch 146 | Loss 9.59433082323173e-07 Epoch 147 | Loss 9.502412901981052e-07 Epoch 148 | Loss 9.410977540121816e-07 Epoch 149 | Loss 9.319727858151028e-07 Epoch 150 | Loss 9.230320009485983e-07 Epoch 151 | Loss 9.142515908872536e-07 Epoch 152 | Loss 9.055176796050045e-07 Epoch 153 | Loss 8.968476493496918e-07 Epoch 154 | Loss 8.883585070062157e-07 Epoch 155 | Loss 8.799798574579335e-07 Epoch 156 | Loss 8.716306904438497e-07 Epoch 157 | Loss 8.634115370027542e-07 Epoch 158 | Loss 8.553286106000076e-07 Epoch 159 | Loss 8.473066878770816e-07 Epoch 160 | Loss 8.393611035486784e-07 Epoch 161 | Loss 8.315383832451852e-07 Epoch 162 | Loss 8.23812730022084e-07 Epoch 163 | Loss 8.161585244572439e-07 Epoch 164 | Loss 8.085853311389707e-07 Epoch 165 | Loss 8.011311389518617e-07 Epoch 166 | Loss 7.937509776200442e-07 Epoch 167 | Loss 7.864414200490334e-07 Epoch 168 | Loss 7.792321202346505e-07 Epoch 169 | Loss 7.721129420068149e-07 Epoch 170 | Loss 7.650566056859965e-07 Epoch 171 | Loss 7.580867392307161e-07 Epoch 172 | Loss 7.512019255765706e-07 Epoch 173 | Loss 7.44392871495942e-07 Epoch 174 | Loss 7.376554924382388e-07 Epoch 175 | Loss 7.310041859797791e-07 Epoch 176 | Loss 7.244223352825202e-07 Epoch 177 | Loss 7.179107583362865e-07 Epoch 178 | Loss 7.114729205692211e-07 Epoch 179 | Loss 7.051138044328317e-07 Epoch 180 | Loss 6.988167086620162e-07 Epoch 181 | Loss 6.925921661054752e-07 Epoch 182 | Loss 6.864380917059668e-07 Epoch 183 | Loss 6.803475742422391e-07 Epoch 184 | Loss 6.743241679721031e-07 Epoch 185 | Loss 6.683699407928609e-07 Epoch 186 | Loss 6.624817735278676e-07 Epoch 187 | Loss 6.566483433856844e-07 Epoch 188 | Loss 6.508865052027779e-07 Epoch 189 | Loss 6.451858149751575e-07 Epoch 190 | Loss 6.395429157077213e-07 Epoch 191 | Loss 6.339613565507968e-07 Epoch 192 | Loss 6.284421739744305e-07 Epoch 193 | Loss 6.229794300541488e-07 Epoch 194 | Loss 6.175748470592854e-07 Epoch 195 | Loss 6.122284465298443e-07 Epoch 196 | Loss 6.069336984847245e-07 Epoch 197 | Loss 6.016991132192018e-07 Epoch 198 | Loss 5.965180591195436e-07 Epoch 199 | Loss 5.913902703863853e-07 Epoch 200 | Loss 5.863220844515445e-07
We can verify that the final prediction looks like the target function:
In [13]:
Copied!
with torch.no_grad():
y_final = model(x)
plt.plot(x.numpy(), target_y.numpy(), label="truth")
plt.plot(x.numpy(), y.numpy(), label="initial")
plt.plot(x.numpy(), y_final.numpy(), "--", label="final", linewidth=3)
plt.legend()
plt.show()
with torch.no_grad():
y_final = model(x)
plt.plot(x.numpy(), target_y.numpy(), label="truth")
plt.plot(x.numpy(), y.numpy(), label="initial")
plt.plot(x.numpy(), y_final.numpy(), "--", label="final", linewidth=3)
plt.legend()
plt.show()