11 changes: 9 additions & 2 deletions .gitignore
@@ -84,11 +84,18 @@ celerybeat-schedule
.env

# virtualenv
venv/
ENV/
.venv*
venv*
ENV*
.env*
*.venv

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# ---
*.h5
load_model.py
7 changes: 2 additions & 5 deletions README.md
@@ -5,12 +5,9 @@ See: https://www.nature.com/articles/s41598-017-11266-1
Make your own decoder with:

```
train_network.py 5 output.model \
--onthefly 10000000 50000 \
--Xstab --Zstab \
--epochs 10 --prob 0.9 \
--learningrate .000001 --normcenterstab --layers 4 4 4 4 4 4 4
train_network.py 5 output.model --onthefly 10000000 50000 --Xstab --Zstab --epochs 10 --prob 0.9 --learningrate .000001 --normcenterstab --layers 4 4 4 4 4 4 4
```

Test a network by adding the `--eval` flag.

See `train_network.py -h` for a description of each flag.
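
A hedged example of an evaluation run, reusing the flags from the training example above (the exact arguments accepted together with `--eval` are an assumption here; check `train_network.py -h`):

```
train_network.py 5 output.model --onthefly 10000000 50000 --Xstab --Zstab --prob 0.9 --eval
```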
266 changes: 149 additions & 117 deletions codes.py

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions find_threshold.py
@@ -1,3 +1,6 @@
from tqdm import tqdm
import numpy as np
from codes import find_threshold, sample
import argparse

parser = argparse.ArgumentParser(description='Find the threshold of a code.',
@@ -42,17 +45,15 @@
args = parser.parse_args()
print(args)

from codes import find_threshold, sample
import numpy as np
from tqdm import tqdm

if args.dist2:
find_threshold(Lsmall=args.dist, Llarge=args.dist2,
p=(args.plow+args.phigh)/2, high=args.phigh, low=args.plow,
samples=args.samples, logfile=args.out)
p=(args.plow+args.phigh)/2, high=args.phigh, low=args.plow,
samples=args.samples, logfile=args.out)
else:
ps = np.linspace(args.plow, args.phigh, args.steps+1)[:-1]
r = []
for p in tqdm(ps):
r.append(sample(args.dist, p, args.samples))
np.savetxt(args.out, np.vstack([ps[:len(r)], np.array(r).T]), fmt='%.8e')
np.savetxt(args.out, np.vstack(
[ps[:len(r)], np.array(r).T]), fmt='%.8e')
6 changes: 3 additions & 3 deletions generate_training_data.py
@@ -1,3 +1,5 @@
import numpy as np
from codes import generate_training_data
import argparse

parser = argparse.ArgumentParser(description='Generate single-shot training data.',
@@ -22,13 +24,11 @@

args = parser.parse_args()

from codes import generate_training_data
import numpy as np

res, _ = generate_training_data(l=args.dist,
p=args.prob,
train_size=args.ntrain,
test_size=args.nval,
)
)

np.savez_compressed(args.out, *res)
123 changes: 81 additions & 42 deletions neural.py
@@ -5,67 +5,93 @@
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import Nadam
from keras.objectives import binary_crossentropy
from keras.layers.normalization import BatchNormalization
from keras.losses import binary_crossentropy
from keras.layers import BatchNormalization
import tensorflow as tf

F = lambda _: K.cast(_, 'float32') # TODO XXX there must be a better way to calculate mean than this cast-first approach

# TODO XXX there must be a better way to calculate mean than this cast-first approach
def F(_): return tf.cast(_, 'float32')


class CodeCosts:
def __init__(self, L, code, Z, X, normcentererr_p=None):
if normcentererr_p:
raise NotImplementedError('Throughout the entire codebase, the normalization and centering of the error, might be wrong... Or to be more precise, it might just be plain stupid, given that we are using binary crossentropy as loss.')
raise NotImplementedError(
'Throughout the entire codebase, the normalization and centering of the error, might be wrong... Or to be more precise, it might just be plain stupid, given that we are using binary crossentropy as loss.')
self.L = L
code = code(L)
H = code.H(Z,X)
E = code.E(Z,X)
self.H = K.variable(value=H) # TODO should be sparse
self.E = K.variable(value=E) # TODO should be sparse
H = code.H(Z, X)
E = code.E(Z, X)
self.H = tf.Variable(initial_value=H, trainable=False,
dtype=tf.float32) # TODO should be sparse
self.E = tf.Variable(initial_value=E, trainable=False,
dtype=tf.float32) # TODO should be sparse
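        # Presumably H is the code's parity-check matrix (it maps an error vector to its syndrome)
        # and E collects the logical operators (it maps an error vector to its logical effect);
        # both are used below through matmul mod 2.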
self.p = normcentererr_p

def exact_reversal(self, y_true, y_pred):
"Fraction exactly predicted qubit flips."
if self.p:
y_pred = undo_normcentererr(y_pred, self.p)
y_true = undo_normcentererr(y_true, self.p)
return K.mean(F(K.all(K.equal(y_true, K.round(y_pred)), axis=-1)))
return tf.reduce_mean(F(tf.reduce_all(tf.equal(y_true, tf.round(y_pred)), axis=-1)))

def non_triv_stab_expanded(self, y_true, y_pred):
"Whether the stabilizer after correction is not trivial."
if self.p:
y_pred = undo_normcentererr(y_pred, self.p)
y_true = undo_normcentererr(y_true, self.p)
return K.any(K.dot(self.H, K.transpose((K.round(y_pred)+y_true)%2))%2, axis=0)
# Cast to same dtype to avoid type mismatch
y_pred_rounded = tf.cast(tf.round(y_pred), tf.float32)
y_true_cast = tf.cast(y_true, tf.float32)
correction = tf.cast((y_pred_rounded + y_true_cast) % 2, tf.float32)
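        # (round(y_pred) + y_true) mod 2 is the residual error left after applying the predicted
        # correction; H @ residual mod 2 is its syndrome, and any nonzero entry means some
        # stabilizer is still violated.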
return tf.reduce_any(tf.cast(tf.matmul(self.H, tf.transpose(correction)) % 2, tf.bool), axis=0)

def logic_error_expanded(self, y_true, y_pred):
"Whether there is a logical error after correction."
if self.p:
y_pred = undo_normcentererr(y_pred, self.p)
y_true = undo_normcentererr(y_true, self.p)
return K.any(K.dot(self.E, K.transpose((K.round(y_pred)+y_true)%2))%2, axis=0)
# Cast to same dtype to avoid type mismatch
y_pred_rounded = tf.cast(tf.round(y_pred), tf.float32)
y_true_cast = tf.cast(y_true, tf.float32)
correction = tf.cast((y_pred_rounded + y_true_cast) % 2, tf.float32)
return tf.reduce_any(tf.cast(tf.matmul(self.E, tf.transpose(correction)) % 2, tf.bool), axis=0)

def triv_stab(self, y_true, y_pred):
"Fraction trivial stabilizer after corrections."
return 1-K.mean(F(self.non_triv_stab_expanded(y_true, y_pred)))
return 1-tf.reduce_mean(F(self.non_triv_stab_expanded(y_true, y_pred)))

def no_error(self, y_true, y_pred):
"Fraction no logical errors after corrections."
return 1-K.mean(F(self.logic_error_expanded(y_true, y_pred)))
return 1-tf.reduce_mean(F(self.logic_error_expanded(y_true, y_pred)))

def triv_no_error(self, y_true, y_pred):
"Fraction with trivial stabilizer and no error."
# TODO XXX Those casts (the F function) should not be there! This should be logical operations
triv_stab = 1 - F(self.non_triv_stab_expanded(y_true, y_pred))
no_err = 1 - F(self.logic_error_expanded(y_true, y_pred))
return K.mean(no_err*triv_stab)
no_err = 1 - F(self.logic_error_expanded(y_true, y_pred))
return tf.reduce_mean(no_err*triv_stab)

def e_binary_crossentropy(self, y_true, y_pred):
if self.p:
y_pred = undo_normcentererr(y_pred, self.p)
y_true = undo_normcentererr(y_true, self.p)
return K.mean(K.binary_crossentropy(y_pred, y_true), axis=-1)
return tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred), axis=-1)

def s_binary_crossentropy(self, y_true, y_pred):
if self.p:
y_pred = undo_normcentererr(y_pred, self.p)
y_true = undo_normcentererr(y_true, self.p)
s_true = K.dot(y_true, K.transpose(self.H))%2
# Cast to avoid type mismatch
y_true_cast = tf.cast(y_true, tf.float32)
s_true = tf.cast(
tf.matmul(y_true_cast, tf.transpose(self.H)) % 2, tf.float32)
twopminusone = 2*y_pred-1
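        # For each stabilizer, P(syndrome bit = 1) = (1 - prod_i(2*p_i - 1)) / 2, the product
        # running over the qubits i in its support; the matmul of log(2*p - 1) with H^T below sums
        # the logs (i.e. multiplies the factors), and complex64 keeps the log defined for
        # negative factors.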
s_pred = ( 1 - tf.real(K.exp(K.dot(K.log(tf.cast(twopminusone, tf.complex64)), tf.cast(K.transpose(self.H), tf.complex64)))) ) / 2
return K.mean(K.binary_crossentropy(s_pred, s_true), axis=-1)
s_pred = (1 - tf.math.real(tf.exp(tf.matmul(tf.math.log(tf.cast(twopminusone,
tf.complex64)), tf.cast(tf.transpose(self.H), tf.complex64))))) / 2
return tf.reduce_mean(tf.keras.losses.binary_crossentropy(s_true, s_pred), axis=-1)

def se_binary_crossentropy(self, y_true, y_pred):
return 2./3.*self.e_binary_crossentropy(y_true, y_pred) + 1./3.*self.s_binary_crossentropy(y_true, y_pred)

@@ -76,7 +102,8 @@ def create_model(L, hidden_sizes=[4], hidden_act='tanh', act='sigmoid', loss='bi
in_dim = L**2 * (X+Z)
out_dim = 2*L**2 * (X+Z)
model = Sequential()
model.add(Dense(int(hidden_sizes[0]*out_dim), input_dim=in_dim, kernel_initializer='glorot_uniform'))
model.add(Dense(int(hidden_sizes[0]*out_dim),
input_dim=in_dim, kernel_initializer='glorot_uniform'))
if batchnorm:
model.add(BatchNormalization(momentum=batchnorm))
model.add(Activation(hidden_act))
@@ -90,75 +117,86 @@ def create_model(L, hidden_sizes=[4], hidden_act='tanh', act='sigmoid', loss='bi
model.add(BatchNormalization(momentum=batchnorm))
model.add(Activation(act))
c = CodeCosts(L, ToricCode, Z, X, normcentererr_p)
losses = {'e_binary_crossentropy':c.e_binary_crossentropy,
's_binary_crossentropy':c.s_binary_crossentropy,
'se_binary_crossentropy':c.se_binary_crossentropy}
model.compile(loss=losses.get(loss,loss),
optimizer=Nadam(lr=learning_rate),
metrics=[c.triv_no_error, c.e_binary_crossentropy, c.s_binary_crossentropy]
)
losses = {'e_binary_crossentropy': c.e_binary_crossentropy,
's_binary_crossentropy': c.s_binary_crossentropy,
'se_binary_crossentropy': c.se_binary_crossentropy}
model.compile(loss=losses.get(loss, loss),
optimizer=Nadam(learning_rate=learning_rate),
metrics=[c.triv_no_error, c.e_binary_crossentropy,
c.s_binary_crossentropy]
)
return model


def makeflips(q, out_dimZ, out_dimX):
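    # q is the probability of each single-qubit Pauli error: the uniform sample below is
    # partitioned so that rand < q gives an X flip, q <= rand < 2q a Z flip, and
    # 2q <= rand < 3q both flips (a Y error); anything above 3q leaves the qubit untouched.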
flips = np.zeros((out_dimZ+out_dimX,), dtype=np.dtype('b'))
rand = np.random.rand(out_dimZ or out_dimX) # if neither is zero they have to necessarily be the same (equal to the number of physical qubits)
both_flips = (2*q<=rand) & (rand<3*q)
if out_dimZ: # non-trivial Z stabilizer is caused by flips in the X basis
x_flips = rand< q
# if neither is zero they have to necessarily be the same (equal to the number of physical qubits)
rand = np.random.rand(out_dimZ or out_dimX)
both_flips = (2*q <= rand) & (rand < 3*q)
if out_dimZ: # non-trivial Z stabilizer is caused by flips in the X basis
x_flips = rand < q
flips[:out_dimZ] ^= x_flips
flips[:out_dimZ] ^= both_flips
if out_dimX: # non-trivial X stabilizer is caused by flips in the Z basis
z_flips = (q<=rand) & (rand<2*q)
if out_dimX: # non-trivial X stabilizer is caused by flips in the Z basis
z_flips = (q <= rand) & (rand < 2*q)
flips[out_dimZ:out_dimZ+out_dimX] ^= z_flips
flips[out_dimZ:out_dimZ+out_dimX] ^= both_flips
return flips


def nonzeroflips(q, out_dimZ, out_dimX):
flips = makeflips(q, out_dimZ, out_dimX)
while not np.any(flips):
flips = makeflips(q, out_dimZ, out_dimX)
return flips


def data_generator(H, out_dimZ, out_dimX, in_dim, p, batch_size=512, size=None,
normcenterstab=False, normcentererr=False):
c = 0
q = (1-p)/3
while True:
flips = np.empty((batch_size, out_dimZ+out_dimX), dtype=int) # TODO dtype? byte?
flips = np.empty((batch_size, out_dimZ+out_dimX),
dtype=int) # TODO dtype? byte?
for i in range(batch_size):
flips[i,:] = nonzeroflips(q, out_dimZ, out_dimX)
stabs = np.dot(flips,H.T)%2
flips[i, :] = nonzeroflips(q, out_dimZ, out_dimX)
stabs = np.dot(flips, H.T) % 2
if normcenterstab:
stabs = do_normcenterstab(stabs, p)
if normcentererr:
flips = do_normcentererr(flips, p)
yield (stabs, flips)
c += 1
if size and c==size:
raise StopIteration
if size and c >= size:
return


def do_normcenterstab(stabs, p):
avg = (1-p)*2/3
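    # A toric-code stabilizer touches 4 qubits; with per-qubit flip probability avg it is
    # non-trivial exactly when an odd number of them flip, i.e. with probability
    # C(4,1)*avg*(1-avg)**3 + C(4,3)*avg**3*(1-avg), the avg_stab below.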
avg_stab = 4*avg*(1-avg)**3 + 4*avg**3*(1-avg)
var_stab = avg_stab-avg_stab**2
return (stabs - avg_stab)/var_stab**0.5


def undo_normcenterstab(stabs, p):
avg = (1-p)*2/3
avg_stab = 4*avg*(1-avg)**3 + 4*avg**3*(1-avg)
var_stab = avg_stab-avg_stab**2
return stabs*var_stab**0.5 + avg_stab


def do_normcentererr(flips, p):
avg = (1-p)*2/3
var = avg-avg**2
return (flips-avg)/var**0.5


def undo_normcentererr(flips, p):
avg = (1-p)*2/3
var = avg-avg**2
return flips*var**0.5 + avg


def smart_sample(H, stab, pred, sample, giveup):
'''Sample `pred` until `H@sample==stab`.

@@ -169,10 +207,11 @@ def smart_sample(H, stab, pred, sample, giveup):
npsum = np.sum
npdot = np.dot
attempts = 1
mismatch = stab!=npdot(H,sample)%2
mismatch = stab != npdot(H, sample) % 2
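    # Resample only the qubits that touch a currently mismatched stabilizer, drawing each as a
    # Bernoulli variable with probability given by the network's prediction pred, until the
    # syndrome matches or the `giveup` attempt budget runs out.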
while npany(mismatch) and attempts < giveup:
propagated = npany(H[mismatch,:], axis=0)
sample[propagated] = pred[propagated]>nprandomuniform(size=npsum(propagated))
mismatch = stab!=npdot(H,sample)%2
propagated = npany(H[mismatch, :], axis=0)
sample[propagated] = pred[propagated] > nprandomuniform(
size=npsum(propagated))
mismatch = stab != npdot(H, sample) % 2
attempts += 1
return attempts
31 changes: 31 additions & 0 deletions requirements.txt
@@ -0,0 +1,31 @@
# Neural Network Decoders for Quantum Error Correcting Codes
# Compatible with Python 3.10 or 3.11 (TensorFlow compatibility issues with 3.12)

# Core ML/DL frameworks
tensorflow>=2.16.0,<2.21.0
keras>=3.0.0

# Scientific computing
numpy>=1.21.0,<2.0.0
scipy>=1.9.0

# Graph algorithms (for MWPM)
networkx>=2.8

# Progress bars
tqdm>=4.64.0

# Optional: Jupyter notebook support
jupyter>=1.0.0
ipython>=8.0.0

# Optional: Plotting (mentioned in codes.py)
matplotlib>=3.5.0

# Optional: For better performance with large arrays
# numba>=0.56.0

# Development dependencies (optional)
# pytest>=7.0.0
# black>=22.0.0
# flake8>=5.0.0
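
A typical environment setup from this file, as a sketch only (assuming a suitable interpreter is available as `python3.11`, per the compatibility note above):

```
python3.11 -m venv .venv
. .venv/bin/activate
pip install -r requirements.txt
```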