switching to high quality piper tts and added label translations
This commit is contained in:
@@ -0,0 +1,370 @@
|
||||
"""Matplotlib based plotting of quantum circuits.
|
||||
|
||||
Todo:
|
||||
|
||||
* Optimize printing of large circuits.
|
||||
* Get this to work with single gates.
|
||||
* Do a better job checking the form of circuits to make sure it is a Mul of
|
||||
Gates.
|
||||
* Get multi-target gates plotting.
|
||||
* Get initial and final states to plot.
|
||||
* Get measurements to plot. Might need to rethink measurement as a gate
|
||||
issue.
|
||||
* Get scale and figsize to be handled in a better way.
|
||||
* Write some tests/examples!
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.external import import_module
|
||||
from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS
|
||||
|
||||
|
||||
__all__ = [
|
||||
'CircuitPlot',
|
||||
'circuit_plot',
|
||||
'labeller',
|
||||
'Mz',
|
||||
'Mx',
|
||||
'CreateOneQubitGate',
|
||||
'CreateCGate',
|
||||
]
|
||||
|
||||
np = import_module('numpy')
|
||||
matplotlib = import_module(
|
||||
'matplotlib', import_kwargs={'fromlist': ['pyplot']},
|
||||
catch=(RuntimeError,)) # This is raised in environments that have no display.
|
||||
|
||||
if np and matplotlib:
|
||||
pyplot = matplotlib.pyplot
|
||||
Line2D = matplotlib.lines.Line2D
|
||||
Circle = matplotlib.patches.Circle
|
||||
|
||||
#from matplotlib import rc
|
||||
#rc('text',usetex=True)
|
||||
|
||||
class CircuitPlot:
|
||||
"""A class for managing a circuit plot."""
|
||||
|
||||
scale = 1.0
|
||||
fontsize = 20.0
|
||||
linewidth = 1.0
|
||||
control_radius = 0.05
|
||||
not_radius = 0.15
|
||||
swap_delta = 0.05
|
||||
labels: list[str] = []
|
||||
inits: dict[str, str] = {}
|
||||
label_buffer = 0.5
|
||||
|
||||
def __init__(self, c, nqubits, **kwargs):
|
||||
if not np or not matplotlib:
|
||||
raise ImportError('numpy or matplotlib not available.')
|
||||
self.circuit = c
|
||||
self.ngates = len(self.circuit.args)
|
||||
self.nqubits = nqubits
|
||||
self.update(kwargs)
|
||||
self._create_grid()
|
||||
self._create_figure()
|
||||
self._plot_wires()
|
||||
self._plot_gates()
|
||||
self._finish()
|
||||
|
||||
def update(self, kwargs):
|
||||
"""Load the kwargs into the instance dict."""
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def _create_grid(self):
|
||||
"""Create the grid of wires."""
|
||||
scale = self.scale
|
||||
wire_grid = np.arange(0.0, self.nqubits*scale, scale, dtype=float)
|
||||
gate_grid = np.arange(0.0, self.ngates*scale, scale, dtype=float)
|
||||
self._wire_grid = wire_grid
|
||||
self._gate_grid = gate_grid
|
||||
|
||||
def _create_figure(self):
|
||||
"""Create the main matplotlib figure."""
|
||||
self._figure = pyplot.figure(
|
||||
figsize=(self.ngates*self.scale, self.nqubits*self.scale),
|
||||
facecolor='w',
|
||||
edgecolor='w'
|
||||
)
|
||||
ax = self._figure.add_subplot(
|
||||
1, 1, 1,
|
||||
frameon=True
|
||||
)
|
||||
ax.set_axis_off()
|
||||
offset = 0.5*self.scale
|
||||
ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset)
|
||||
ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset)
|
||||
ax.set_aspect('equal')
|
||||
self._axes = ax
|
||||
|
||||
def _plot_wires(self):
|
||||
"""Plot the wires of the circuit diagram."""
|
||||
xstart = self._gate_grid[0]
|
||||
xstop = self._gate_grid[-1]
|
||||
xdata = (xstart - self.scale, xstop + self.scale)
|
||||
for i in range(self.nqubits):
|
||||
ydata = (self._wire_grid[i], self._wire_grid[i])
|
||||
line = Line2D(
|
||||
xdata, ydata,
|
||||
color='k',
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_line(line)
|
||||
if self.labels:
|
||||
init_label_buffer = 0
|
||||
if self.inits.get(self.labels[i]): init_label_buffer = 0.25
|
||||
self._axes.text(
|
||||
xdata[0]-self.label_buffer-init_label_buffer,ydata[0],
|
||||
render_label(self.labels[i],self.inits),
|
||||
size=self.fontsize,
|
||||
color='k',ha='center',va='center')
|
||||
self._plot_measured_wires()
|
||||
|
||||
def _plot_measured_wires(self):
|
||||
ismeasured = self._measurements()
|
||||
xstop = self._gate_grid[-1]
|
||||
dy = 0.04 # amount to shift wires when doubled
|
||||
# Plot doubled wires after they are measured
|
||||
for im in ismeasured:
|
||||
xdata = (self._gate_grid[ismeasured[im]],xstop+self.scale)
|
||||
ydata = (self._wire_grid[im]+dy,self._wire_grid[im]+dy)
|
||||
line = Line2D(
|
||||
xdata, ydata,
|
||||
color='k',
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_line(line)
|
||||
# Also double any controlled lines off these wires
|
||||
for i,g in enumerate(self._gates()):
|
||||
if isinstance(g, (CGate, CGateS)):
|
||||
wires = g.controls + g.targets
|
||||
for wire in wires:
|
||||
if wire in ismeasured and \
|
||||
self._gate_grid[i] > self._gate_grid[ismeasured[wire]]:
|
||||
ydata = min(wires), max(wires)
|
||||
xdata = self._gate_grid[i]-dy, self._gate_grid[i]-dy
|
||||
line = Line2D(
|
||||
xdata, ydata,
|
||||
color='k',
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_line(line)
|
||||
def _gates(self):
|
||||
"""Create a list of all gates in the circuit plot."""
|
||||
gates = []
|
||||
if isinstance(self.circuit, Mul):
|
||||
for g in reversed(self.circuit.args):
|
||||
if isinstance(g, Gate):
|
||||
gates.append(g)
|
||||
elif isinstance(self.circuit, Gate):
|
||||
gates.append(self.circuit)
|
||||
return gates
|
||||
|
||||
def _plot_gates(self):
|
||||
"""Iterate through the gates and plot each of them."""
|
||||
for i, gate in enumerate(self._gates()):
|
||||
gate.plot_gate(self, i)
|
||||
|
||||
def _measurements(self):
|
||||
"""Return a dict ``{i:j}`` where i is the index of the wire that has
|
||||
been measured, and j is the gate where the wire is measured.
|
||||
"""
|
||||
ismeasured = {}
|
||||
for i,g in enumerate(self._gates()):
|
||||
if getattr(g,'measurement',False):
|
||||
for target in g.targets:
|
||||
if target in ismeasured:
|
||||
if ismeasured[target] > i:
|
||||
ismeasured[target] = i
|
||||
else:
|
||||
ismeasured[target] = i
|
||||
return ismeasured
|
||||
|
||||
def _finish(self):
|
||||
# Disable clipping to make panning work well for large circuits.
|
||||
for o in self._figure.findobj():
|
||||
o.set_clip_on(False)
|
||||
|
||||
def one_qubit_box(self, t, gate_idx, wire_idx):
|
||||
"""Draw a box for a single qubit gate."""
|
||||
x = self._gate_grid[gate_idx]
|
||||
y = self._wire_grid[wire_idx]
|
||||
self._axes.text(
|
||||
x, y, t,
|
||||
color='k',
|
||||
ha='center',
|
||||
va='center',
|
||||
bbox={"ec": 'k', "fc": 'w', "fill": True, "lw": self.linewidth},
|
||||
size=self.fontsize
|
||||
)
|
||||
|
||||
def two_qubit_box(self, t, gate_idx, wire_idx):
|
||||
"""Draw a box for a two qubit gate. Does not work yet.
|
||||
"""
|
||||
# x = self._gate_grid[gate_idx]
|
||||
# y = self._wire_grid[wire_idx]+0.5
|
||||
print(self._gate_grid)
|
||||
print(self._wire_grid)
|
||||
# unused:
|
||||
# obj = self._axes.text(
|
||||
# x, y, t,
|
||||
# color='k',
|
||||
# ha='center',
|
||||
# va='center',
|
||||
# bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth),
|
||||
# size=self.fontsize
|
||||
# )
|
||||
|
||||
def control_line(self, gate_idx, min_wire, max_wire):
|
||||
"""Draw a vertical control line."""
|
||||
xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx])
|
||||
ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire])
|
||||
line = Line2D(
|
||||
xdata, ydata,
|
||||
color='k',
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_line(line)
|
||||
|
||||
def control_point(self, gate_idx, wire_idx):
|
||||
"""Draw a control point."""
|
||||
x = self._gate_grid[gate_idx]
|
||||
y = self._wire_grid[wire_idx]
|
||||
radius = self.control_radius
|
||||
c = Circle(
|
||||
(x, y),
|
||||
radius*self.scale,
|
||||
ec='k',
|
||||
fc='k',
|
||||
fill=True,
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_patch(c)
|
||||
|
||||
def not_point(self, gate_idx, wire_idx):
|
||||
"""Draw a NOT gates as the circle with plus in the middle."""
|
||||
x = self._gate_grid[gate_idx]
|
||||
y = self._wire_grid[wire_idx]
|
||||
radius = self.not_radius
|
||||
c = Circle(
|
||||
(x, y),
|
||||
radius,
|
||||
ec='k',
|
||||
fc='w',
|
||||
fill=False,
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_patch(c)
|
||||
l = Line2D(
|
||||
(x, x), (y - radius, y + radius),
|
||||
color='k',
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_line(l)
|
||||
|
||||
def swap_point(self, gate_idx, wire_idx):
|
||||
"""Draw a swap point as a cross."""
|
||||
x = self._gate_grid[gate_idx]
|
||||
y = self._wire_grid[wire_idx]
|
||||
d = self.swap_delta
|
||||
l1 = Line2D(
|
||||
(x - d, x + d),
|
||||
(y - d, y + d),
|
||||
color='k',
|
||||
lw=self.linewidth
|
||||
)
|
||||
l2 = Line2D(
|
||||
(x - d, x + d),
|
||||
(y + d, y - d),
|
||||
color='k',
|
||||
lw=self.linewidth
|
||||
)
|
||||
self._axes.add_line(l1)
|
||||
self._axes.add_line(l2)
|
||||
|
||||
def circuit_plot(c, nqubits, **kwargs):
|
||||
"""Draw the circuit diagram for the circuit with nqubits.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
c : circuit
|
||||
The circuit to plot. Should be a product of Gate instances.
|
||||
nqubits : int
|
||||
The number of qubits to include in the circuit. Must be at least
|
||||
as big as the largest ``min_qubits`` of the gates.
|
||||
"""
|
||||
return CircuitPlot(c, nqubits, **kwargs)
|
||||
|
||||
def render_label(label, inits={}):
|
||||
"""Slightly more flexible way to render labels.
|
||||
|
||||
>>> from sympy.physics.quantum.circuitplot import render_label
|
||||
>>> render_label('q0')
|
||||
'$\\\\left|q0\\\\right\\\\rangle$'
|
||||
>>> render_label('q0', {'q0':'0'})
|
||||
'$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$'
|
||||
"""
|
||||
init = inits.get(label)
|
||||
if init:
|
||||
return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init)
|
||||
return r'$\left|%s\right\rangle$' % label
|
||||
|
||||
def labeller(n, symbol='q'):
|
||||
"""Autogenerate labels for wires of quantum circuits.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
n : int
|
||||
number of qubits in the circuit.
|
||||
symbol : string
|
||||
A character string to precede all gate labels. E.g. 'q_0', 'q_1', etc.
|
||||
|
||||
>>> from sympy.physics.quantum.circuitplot import labeller
|
||||
>>> labeller(2)
|
||||
['q_1', 'q_0']
|
||||
>>> labeller(3,'j')
|
||||
['j_2', 'j_1', 'j_0']
|
||||
"""
|
||||
return ['%s_%d' % (symbol,n-i-1) for i in range(n)]
|
||||
|
||||
class Mz(OneQubitGate):
|
||||
"""Mock-up of a z measurement gate.
|
||||
|
||||
This is in circuitplot rather than gate.py because it's not a real
|
||||
gate, it just draws one.
|
||||
"""
|
||||
measurement = True
|
||||
gate_name='Mz'
|
||||
gate_name_latex='M_z'
|
||||
|
||||
class Mx(OneQubitGate):
|
||||
"""Mock-up of an x measurement gate.
|
||||
|
||||
This is in circuitplot rather than gate.py because it's not a real
|
||||
gate, it just draws one.
|
||||
"""
|
||||
measurement = True
|
||||
gate_name='Mx'
|
||||
gate_name_latex='M_x'
|
||||
|
||||
class CreateOneQubitGate(type):
|
||||
def __new__(mcl, name, latexname=None):
|
||||
if not latexname:
|
||||
latexname = name
|
||||
return type(name + "Gate", (OneQubitGate,),
|
||||
{'gate_name': name, 'gate_name_latex': latexname})
|
||||
|
||||
def CreateCGate(name, latexname=None):
|
||||
"""Use a lexical closure to make a controlled gate.
|
||||
"""
|
||||
if not latexname:
|
||||
latexname = name
|
||||
onequbitgate = CreateOneQubitGate(name, latexname)
|
||||
def ControlledGate(ctrls,target):
|
||||
return CGate(tuple(ctrls),onequbitgate(target))
|
||||
return ControlledGate
|
||||
Reference in New Issue
Block a user